Update ImageSegmenter API for image/video mode to have both callback API and returned result API.
PiperOrigin-RevId: 512697585
This commit is contained in:
parent
aa61abe386
commit
a60d67eb10
|
@ -47,10 +47,14 @@ import java.util.Optional;
|
|||
/**
|
||||
* Performs image segmentation on images.
|
||||
*
|
||||
* <p>Note that, unlike other vision tasks, the output of ImageSegmenter is provided through a
|
||||
* user-defined callback function even for the synchronous API. This makes it possible for
|
||||
* ImageSegmenter to return the output masks without any copy. {@link ResultListener} must be set in
|
||||
* the {@link ImageSegmenterOptions} for all {@link RunningMode}.
|
||||
* <p>Note that, in addition to the standard segmentation API, {@link segment} and {@link
|
||||
* segmentForVideo}, that take an input image and return the outputs, but involve a deep copy of the
|
||||
* returns, ImageSegmenter also supports the callback API, {@link segmentWithResultListener} and
|
||||
* {@link segmentForVideoWithResultListener}, which allow you to access the outputs through zero
|
||||
* copy.
|
||||
*
|
||||
* <p>The callback API is available for all {@link RunningMode} in ImageSegmenter. Set {@link
|
||||
* ResultListener} in {@link ImageSegmenterOptions} properly to use the callback API.
|
||||
*
|
||||
* <p>The API expects a TFLite model with <a
|
||||
* href="https://www.tensorflow.org/lite/convert/metadata">TFLite Model Metadata</a>.
|
||||
|
@ -85,6 +89,8 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
|||
private static final String TASK_GRAPH_NAME =
|
||||
"mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph";
|
||||
|
||||
private boolean hasResultListener = false;
|
||||
|
||||
/**
|
||||
* Creates an {@link ImageSegmenter} instance from an {@link ImageSegmenterOptions}.
|
||||
*
|
||||
|
@ -116,8 +122,19 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
|||
int imageListSize =
|
||||
PacketGetter.getImageListSize(packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX));
|
||||
ByteBuffer[] buffersArray = new ByteBuffer[imageListSize];
|
||||
// If resultListener is not provided, the resulting MPImage is deep copied from the mediapipe
|
||||
// graph. If provided, the result MPImage wraps the mediapipe packet memory.
|
||||
if (!segmenterOptions.resultListener().isPresent()) {
|
||||
for (int i = 0; i < imageListSize; i++) {
|
||||
buffersArray[i] =
|
||||
ByteBuffer.allocateDirect(
|
||||
width * height * (imageFormat == MPImage.IMAGE_FORMAT_VEC32F1 ? 4 : 1));
|
||||
}
|
||||
}
|
||||
if (!PacketGetter.getImageList(
|
||||
packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX), buffersArray, false)) {
|
||||
packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX),
|
||||
buffersArray,
|
||||
!segmenterOptions.resultListener().isPresent())) {
|
||||
throw new MediaPipeException(
|
||||
MediaPipeException.StatusCode.INTERNAL.ordinal(),
|
||||
"There is an error getting segmented masks. It usually results from incorrect"
|
||||
|
@ -143,7 +160,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
|||
.build();
|
||||
}
|
||||
});
|
||||
handler.setResultListener(segmenterOptions.resultListener());
|
||||
segmenterOptions.resultListener().ifPresent(handler::setResultListener);
|
||||
segmenterOptions.errorListener().ifPresent(handler::setErrorListener);
|
||||
TaskRunner runner =
|
||||
TaskRunner.create(
|
||||
|
@ -158,7 +175,8 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
|||
.setEnableFlowLimiting(segmenterOptions.runningMode() == RunningMode.LIVE_STREAM)
|
||||
.build(),
|
||||
handler);
|
||||
return new ImageSegmenter(runner, segmenterOptions.runningMode());
|
||||
return new ImageSegmenter(
|
||||
runner, segmenterOptions.runningMode(), segmenterOptions.resultListener().isPresent());
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -168,16 +186,17 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
|||
* @param taskRunner a {@link TaskRunner}.
|
||||
* @param runningMode a mediapipe vision task {@link RunningMode}.
|
||||
*/
|
||||
private ImageSegmenter(TaskRunner taskRunner, RunningMode runningMode) {
|
||||
private ImageSegmenter(
|
||||
TaskRunner taskRunner, RunningMode runningMode, boolean hasResultListener) {
|
||||
super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME);
|
||||
this.hasResultListener = hasResultListener;
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs image segmentation on the provided single image with default image processing options,
|
||||
* i.e. without any rotation applied, and the results will be available via the {@link
|
||||
* ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method when the
|
||||
* {@link ImageSegmenter} is created with {@link RunningMode.IMAGE}. TODO update java
|
||||
* doc for input image format.
|
||||
* i.e. without any rotation applied. Only use this method when the {@link ImageSegmenter} is
|
||||
* created with {@link RunningMode.IMAGE}. TODO update java doc for input image
|
||||
* format.
|
||||
*
|
||||
* <p>{@link ImageSegmenter} supports the following color space types:
|
||||
*
|
||||
|
@ -186,19 +205,19 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
|||
* </ul>
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
* @throws MediaPipeException if there is an internal error, or if {@link ImageSegmenter} is
|
||||
* created with a {@link ResultListener}.
|
||||
*/
|
||||
public void segment(MPImage image) {
|
||||
segment(image, ImageProcessingOptions.builder().build());
|
||||
public ImageSegmenterResult segment(MPImage image) {
|
||||
return segment(image, ImageProcessingOptions.builder().build());
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs image segmentation on the provided single image, and the results will be available via
|
||||
* the {@link ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method
|
||||
* when the {@link ImageSegmenter} is created with {@link RunningMode.IMAGE}. TODO
|
||||
* update java doc for input image format.
|
||||
* Performs image segmentation on the provided single image. Only use this method when the {@link
|
||||
* ImageSegmenter} is created with {@link RunningMode.IMAGE}. TODO update java doc
|
||||
* for input image format.
|
||||
*
|
||||
* <p>{@link HandLandmarker} supports the following color space types:
|
||||
* <p>{@link ImageSegmenter} supports the following color space types:
|
||||
*
|
||||
* <ul>
|
||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||
|
@ -211,19 +230,84 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
|||
* this method throwing an IllegalArgumentException.
|
||||
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
|
||||
* region-of-interest.
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
* @throws MediaPipeException if there is an internal error, or if {@link ImageSegmenter} is
|
||||
* created with a {@link ResultListener}.
|
||||
*/
|
||||
public void segment(MPImage image, ImageProcessingOptions imageProcessingOptions) {
|
||||
public ImageSegmenterResult segment(
|
||||
MPImage image, ImageProcessingOptions imageProcessingOptions) {
|
||||
if (hasResultListener) {
|
||||
throw new MediaPipeException(
|
||||
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
|
||||
"ResultListener is provided in the ImageSegmenterOptions, but this method will return an"
|
||||
+ " ImageSegmentationResult.");
|
||||
}
|
||||
validateImageProcessingOptions(imageProcessingOptions);
|
||||
ImageSegmenterResult unused =
|
||||
(ImageSegmenterResult) processImageData(image, imageProcessingOptions);
|
||||
return (ImageSegmenterResult) processImageData(image, imageProcessingOptions);
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs image segmentation on the provided single image with default image processing options,
|
||||
* i.e. without any rotation applied, and provides zero-copied results via {@link ResultListener}
|
||||
* in {@link ImageSegmenterOptions}. Only use this method when the {@link ImageSegmenter} is
|
||||
* created with {@link RunningMode.IMAGE}.
|
||||
*
|
||||
* <p>TODO update java doc for input image format.
|
||||
*
|
||||
* <p>{@link ImageSegmenter} supports the following color space types:
|
||||
*
|
||||
* <ul>
|
||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
|
||||
* region-of-interest.
|
||||
* @throws MediaPipeException if there is an internal error, or if {@link ImageSegmenter} is not
|
||||
* created with {@link ResultListener} set in {@link ImageSegmenterOptions}.
|
||||
*/
|
||||
public void segmentWithResultListener(MPImage image) {
|
||||
segmentWithResultListener(image, ImageProcessingOptions.builder().build());
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs image segmentation on the provided single image, and provides zero-copied results via
|
||||
* {@link ResultListener} in {@link ImageSegmenterOptions}. Only use this method when the {@link
|
||||
* ImageSegmenter} is created with {@link RunningMode.IMAGE}.
|
||||
*
|
||||
* <p>TODO update java doc for input image format.
|
||||
*
|
||||
* <p>{@link ImageSegmenter} supports the following color space types:
|
||||
*
|
||||
* <ul>
|
||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
||||
* input image before running inference. Note that region-of-interest is <b>not</b> supported
|
||||
* by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
|
||||
* this method throwing an IllegalArgumentException.
|
||||
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
|
||||
* region-of-interest.
|
||||
* @throws MediaPipeException if there is an internal error, or if {@link ImageSegmenter} is not
|
||||
* created with {@link ResultListener} set in {@link ImageSegmenterOptions}.
|
||||
*/
|
||||
public void segmentWithResultListener(
|
||||
MPImage image, ImageProcessingOptions imageProcessingOptions) {
|
||||
if (!hasResultListener) {
|
||||
throw new MediaPipeException(
|
||||
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
|
||||
"ResultListener is not set in the ImageSegmenterOptions, but this method expects a"
|
||||
+ " ResultListener to process ImageSegmentationResult.");
|
||||
}
|
||||
validateImageProcessingOptions(imageProcessingOptions);
|
||||
var unused = processImageData(image, imageProcessingOptions);
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs image segmentation on the provided video frame with default image processing options,
|
||||
* i.e. without any rotation applied, and the results will be available via the {@link
|
||||
* ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method when the
|
||||
* {@link HandLandmarker} is created with {@link RunningMode.VIDEO}.
|
||||
* i.e. without any rotation applied. Only use this method when the {@link ImageSegmenter} is
|
||||
* created with {@link RunningMode.VIDEO}.
|
||||
*
|
||||
* <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
|
||||
* must be monotonically increasing.
|
||||
|
@ -236,21 +320,21 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
|||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @param timestampMs the input timestamp (in milliseconds).
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
* @throws MediaPipeException if there is an internal error, or if {@link ImageSegmenter} is
|
||||
* created with a {@link ResultListener}.
|
||||
*/
|
||||
public void segmentForVideo(MPImage image, long timestampMs) {
|
||||
segmentForVideo(image, ImageProcessingOptions.builder().build(), timestampMs);
|
||||
public ImageSegmenterResult segmentForVideo(MPImage image, long timestampMs) {
|
||||
return segmentForVideo(image, ImageProcessingOptions.builder().build(), timestampMs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs image segmentation on the provided video frame, and the results will be available via
|
||||
* the {@link ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method
|
||||
* when the {@link ImageSegmenter} is created with {@link RunningMode.VIDEO}.
|
||||
* Performs image segmentation on the provided video frame. Only use this method when the {@link
|
||||
* ImageSegmenter} is created with {@link RunningMode.VIDEO}.
|
||||
*
|
||||
* <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
|
||||
* must be monotonically increasing.
|
||||
*
|
||||
* <p>{@link HandLandmarker} supports the following color space types:
|
||||
* <p>{@link ImageSegmenter} supports the following color space types:
|
||||
*
|
||||
* <ul>
|
||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||
|
@ -264,20 +348,81 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
|||
* @param timestampMs the input timestamp (in milliseconds).
|
||||
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
|
||||
* region-of-interest.
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
* @throws MediaPipeException if there is an internal error, or if {@link ImageSegmenter} is
|
||||
* created with a {@link ResultListener}.
|
||||
*/
|
||||
public void segmentForVideo(
|
||||
public ImageSegmenterResult segmentForVideo(
|
||||
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
|
||||
if (hasResultListener) {
|
||||
throw new MediaPipeException(
|
||||
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
|
||||
"ResultListener is provided in the ImageSegmenterOptions, but this method will return an"
|
||||
+ " ImageSegmentationResult.");
|
||||
}
|
||||
validateImageProcessingOptions(imageProcessingOptions);
|
||||
ImageSegmenterResult unused =
|
||||
(ImageSegmenterResult) processVideoData(image, imageProcessingOptions, timestampMs);
|
||||
return (ImageSegmenterResult) processVideoData(image, imageProcessingOptions, timestampMs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sends live image data to perform hand landmarks detection with default image processing
|
||||
* options, i.e. without any rotation applied, and the results will be available via the {@link
|
||||
* ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method when the
|
||||
* {@link ImageSegmenter } is created with {@link RunningMode.LIVE_STREAM}.
|
||||
* Performs image segmentation on the provided video frame with default image processing options,
|
||||
* i.e. without any rotation applied, and provides zero-copied results via {@link ResultListener}
|
||||
* in {@link ImageSegmenterOptions}. Only use this method when the {@link ImageSegmenter} is
|
||||
* created with {@link RunningMode.VIDEO}.
|
||||
*
|
||||
* <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
|
||||
* must be monotonically increasing.
|
||||
*
|
||||
* <p>{@link ImageSegmenter} supports the following color space types:
|
||||
*
|
||||
* <ul>
|
||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @param timestampMs the input timestamp (in milliseconds).
|
||||
* @throws MediaPipeException if there is an internal error, or if {@link ImageSegmenter} is not
|
||||
* created with {@link ResultListener} set in {@link ImageSegmenterOptions}.
|
||||
*/
|
||||
public void segmentForVideoWithResultListener(MPImage image, long timestampMs) {
|
||||
segmentForVideoWithResultListener(image, ImageProcessingOptions.builder().build(), timestampMs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs image segmentation on the provided video frame, and provides zero-copied results via
|
||||
* {@link ResultListener} in {@link ImageSegmenterOptions}. Only use this method when the {@link
|
||||
* ImageSegmenter} is created with {@link RunningMode.VIDEO}.
|
||||
*
|
||||
* <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
|
||||
* must be monotonically increasing.
|
||||
*
|
||||
* <p>{@link ImageSegmenter} supports the following color space types:
|
||||
*
|
||||
* <ul>
|
||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @param timestampMs the input timestamp (in milliseconds).
|
||||
* @throws MediaPipeException if there is an internal error, or if {@link ImageSegmenter} is not
|
||||
* created with {@link ResultListener} set in {@link ImageSegmenterOptions}.
|
||||
*/
|
||||
public void segmentForVideoWithResultListener(
|
||||
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
|
||||
if (!hasResultListener) {
|
||||
throw new MediaPipeException(
|
||||
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
|
||||
"ResultListener is not set in the ImageSegmenterOptions, but this method expects a"
|
||||
+ " ResultListener to process ImageSegmentationResult.");
|
||||
}
|
||||
validateImageProcessingOptions(imageProcessingOptions);
|
||||
var unused = processVideoData(image, imageProcessingOptions, timestampMs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sends live image data to perform image segmentation with default image processing options, i.e.
|
||||
* without any rotation applied, and the results will be available via the {@link ResultListener}
|
||||
* provided in the {@link ImageSegmenterOptions}. Only use this method when the {@link
|
||||
* ImageSegmenter} is created with {@link RunningMode.LIVE_STREAM}.
|
||||
*
|
||||
* <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
|
||||
* sent to the image segmenter. The input timestamps must be monotonically increasing.
|
||||
|
@ -360,8 +505,8 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
|||
public abstract Builder setOutputType(OutputType value);
|
||||
|
||||
/**
|
||||
* Sets the {@link ResultListener} to receive the segmentation results when the graph pipeline
|
||||
* is done processing an image.
|
||||
* Sets an optional {@link ResultListener} to receive the segmentation results when the graph
|
||||
* pipeline is done processing an image.
|
||||
*/
|
||||
public abstract Builder setResultListener(
|
||||
ResultListener<ImageSegmenterResult, MPImage> value);
|
||||
|
@ -375,11 +520,18 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
|||
* Validates and builds the {@link ImageSegmenterOptions} instance.
|
||||
*
|
||||
* @throws IllegalArgumentException if the result listener and the running mode are not
|
||||
* properly configured. The result listener should only be set when the image segmenter is
|
||||
* in the live stream mode.
|
||||
* properly configured. The result listener must be set when the image segmenter is in the
|
||||
* live stream mode.
|
||||
*/
|
||||
public final ImageSegmenterOptions build() {
|
||||
ImageSegmenterOptions options = autoBuild();
|
||||
if (options.runningMode() == RunningMode.LIVE_STREAM) {
|
||||
if (!options.resultListener().isPresent()) {
|
||||
throw new IllegalArgumentException(
|
||||
"The image segmenter is in the live stream mode, a user-defined result listener"
|
||||
+ " must be provided in ImageSegmenterOptions.");
|
||||
}
|
||||
}
|
||||
return options;
|
||||
}
|
||||
}
|
||||
|
@ -392,7 +544,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
|||
|
||||
abstract OutputType outputType();
|
||||
|
||||
abstract ResultListener<ImageSegmenterResult, MPImage> resultListener();
|
||||
abstract Optional<ResultListener<ImageSegmenterResult, MPImage>> resultListener();
|
||||
|
||||
abstract Optional<ErrorListener> errorListener();
|
||||
|
||||
|
@ -410,8 +562,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
|||
return new AutoValue_ImageSegmenter_ImageSegmenterOptions.Builder()
|
||||
.setRunningMode(RunningMode.IMAGE)
|
||||
.setDisplayNamesLocale("en")
|
||||
.setOutputType(OutputType.CATEGORY_MASK)
|
||||
.setResultListener((result, image) -> {});
|
||||
.setOutputType(OutputType.CATEGORY_MASK);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -437,6 +588,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
|||
segmenterOptionsBuilder.setOutputType(
|
||||
SegmenterOptionsProto.SegmenterOptions.OutputType.CATEGORY_MASK);
|
||||
}
|
||||
|
||||
// TODO: remove this once activation is handled in metadata and graph level.
|
||||
segmenterOptionsBuilder.setActivation(
|
||||
SegmenterOptionsProto.SegmenterOptions.Activation.SOFTMAX);
|
||||
|
|
|
@ -53,112 +53,108 @@ public class ImageSegmenterTest {
|
|||
|
||||
@RunWith(AndroidJUnit4.class)
|
||||
public static final class General extends ImageSegmenterTest {
|
||||
|
||||
@Test
|
||||
public void segment_successWithCategoryMask() throws Exception {
|
||||
final String inputImageName = "segmentation_input_rotation0.jpg";
|
||||
final String goldenImageName = "segmentation_golden_rotation0.png";
|
||||
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
|
||||
ImageSegmenterOptions options =
|
||||
ImageSegmenterOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||
.setOutputType(ImageSegmenterOptions.OutputType.CATEGORY_MASK)
|
||||
.setResultListener(
|
||||
(actualResult, inputImage) -> {
|
||||
List<MPImage> segmentations = actualResult.segmentations();
|
||||
assertThat(segmentations.size()).isEqualTo(1);
|
||||
MPImage actualMaskBuffer = actualResult.segmentations().get(0);
|
||||
verifyCategoryMask(
|
||||
actualMaskBuffer,
|
||||
expectedMaskBuffer,
|
||||
GOLDEN_MASK_SIMILARITY,
|
||||
MAGNIFICATION_FACTOR);
|
||||
})
|
||||
.build();
|
||||
ImageSegmenter imageSegmenter =
|
||||
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
imageSegmenter.segment(getImageFromAsset(inputImageName));
|
||||
ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName));
|
||||
List<MPImage> segmentations = actualResult.segmentations();
|
||||
assertThat(segmentations.size()).isEqualTo(1);
|
||||
MPImage actualMaskBuffer = actualResult.segmentations().get(0);
|
||||
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
|
||||
verifyCategoryMask(
|
||||
actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY, MAGNIFICATION_FACTOR);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void segment_successWithConfidenceMask() throws Exception {
|
||||
final String inputImageName = "cat.jpg";
|
||||
final String goldenImageName = "cat_mask.jpg";
|
||||
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
|
||||
ImageSegmenterOptions options =
|
||||
ImageSegmenterOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
|
||||
.setResultListener(
|
||||
(actualResult, inputImage) -> {
|
||||
List<MPImage> segmentations = actualResult.segmentations();
|
||||
assertThat(segmentations.size()).isEqualTo(21);
|
||||
// Cat category index 8.
|
||||
MPImage actualMaskBuffer = actualResult.segmentations().get(8);
|
||||
verifyConfidenceMask(
|
||||
actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
|
||||
})
|
||||
.build();
|
||||
ImageSegmenter imageSegmenter =
|
||||
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
imageSegmenter.segment(getImageFromAsset(inputImageName));
|
||||
ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName));
|
||||
List<MPImage> segmentations = actualResult.segmentations();
|
||||
assertThat(segmentations.size()).isEqualTo(21);
|
||||
// Cat category index 8.
|
||||
MPImage actualMaskBuffer = actualResult.segmentations().get(8);
|
||||
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
|
||||
verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void segment_successWith128x128Segmentation() throws Exception {
|
||||
final String inputImageName = "mozart_square.jpg";
|
||||
final String goldenImageName = "selfie_segm_128_128_3_expected_mask.jpg";
|
||||
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
|
||||
ImageSegmenterOptions options =
|
||||
ImageSegmenterOptions.builder()
|
||||
.setBaseOptions(
|
||||
BaseOptions.builder().setModelAssetPath(SELFIE_128x128_MODEL_FILE).build())
|
||||
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
|
||||
.setResultListener(
|
||||
(actualResult, inputImage) -> {
|
||||
List<MPImage> segmentations = actualResult.segmentations();
|
||||
assertThat(segmentations.size()).isEqualTo(2);
|
||||
// Selfie category index 1.
|
||||
MPImage actualMaskBuffer = actualResult.segmentations().get(1);
|
||||
verifyConfidenceMask(
|
||||
actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
|
||||
})
|
||||
.build();
|
||||
ImageSegmenter imageSegmenter =
|
||||
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
imageSegmenter.segment(getImageFromAsset(inputImageName));
|
||||
ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName));
|
||||
List<MPImage> segmentations = actualResult.segmentations();
|
||||
assertThat(segmentations.size()).isEqualTo(2);
|
||||
// Selfie category index 1.
|
||||
MPImage actualMaskBuffer = actualResult.segmentations().get(1);
|
||||
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
|
||||
verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
|
||||
}
|
||||
|
||||
// TODO: enable this unit test once activation option is supported in metadata.
|
||||
// @Test
|
||||
// public void segment_successWith144x256Segmentation() throws Exception {
|
||||
// final String inputImageName = "mozart_square.jpg";
|
||||
// final String goldenImageName = "selfie_segm_144_256_3_expected_mask.jpg";
|
||||
// MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
|
||||
// ImageSegmenterOptions options =
|
||||
// ImageSegmenterOptions.builder()
|
||||
// .setBaseOptions(
|
||||
// BaseOptions.builder().setModelAssetPath(SELFIE_144x256_MODEL_FILE).build())
|
||||
// .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
|
||||
// .setActivation(ImageSegmenterOptions.Activation.NONE)
|
||||
// .setResultListener(
|
||||
// (actualResult, inputImage) -> {
|
||||
// List<MPImage> segmentations = actualResult.segmentations();
|
||||
// assertThat(segmentations.size()).isEqualTo(1);
|
||||
// MPImage actualMaskBuffer = actualResult.segmentations().get(0);
|
||||
// verifyConfidenceMask(
|
||||
// actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
|
||||
// })
|
||||
// .build();
|
||||
// ImageSegmenter imageSegmenter =
|
||||
// ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(),
|
||||
// options);
|
||||
// imageSegmenter.segment(getImageFromAsset(inputImageName));
|
||||
// }
|
||||
// @Test
|
||||
// public void segment_successWith144x256Segmentation() throws Exception {
|
||||
// final String inputImageName = "mozart_square.jpg";
|
||||
// final String goldenImageName = "selfie_segm_144_256_3_expected_mask.jpg";
|
||||
// ImageSegmenterOptions options =
|
||||
// ImageSegmenterOptions.builder()
|
||||
// .setBaseOptions(
|
||||
// BaseOptions.builder().setModelAssetPath(SELFIE_144x256_MODEL_FILE).build())
|
||||
// .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
|
||||
// .build();
|
||||
// ImageSegmenter imageSegmenter =
|
||||
// ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
// ImageSegmenterResult actualResult =
|
||||
// imageSegmenter.segment(getImageFromAsset(inputImageName));
|
||||
// List<MPImage> segmentations = actualResult.segmentations();
|
||||
// assertThat(segmentations.size()).isEqualTo(1);
|
||||
// MPImage actualMaskBuffer = actualResult.segmentations().get(0);
|
||||
// MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
|
||||
// verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
|
||||
// }
|
||||
}
|
||||
|
||||
@RunWith(AndroidJUnit4.class)
|
||||
public static final class RunningModeTest extends ImageSegmenterTest {
|
||||
@Test
|
||||
public void create_failsWithMissingResultListenerInLiveSteamMode() throws Exception {
|
||||
IllegalArgumentException exception =
|
||||
assertThrows(
|
||||
IllegalArgumentException.class,
|
||||
() ->
|
||||
ImageSegmenterOptions.builder()
|
||||
.setBaseOptions(
|
||||
BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||
.build());
|
||||
assertThat(exception)
|
||||
.hasMessageThat()
|
||||
.contains("a user-defined result listener must be provided");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void segment_failsWithCallingWrongApiInImageMode() throws Exception {
|
||||
ImageSegmenterOptions options =
|
||||
|
@ -166,7 +162,6 @@ public class ImageSegmenterTest {
|
|||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||
.setRunningMode(RunningMode.IMAGE)
|
||||
.build();
|
||||
|
||||
ImageSegmenter imageSegmenter =
|
||||
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
MediaPipeException exception =
|
||||
|
@ -182,6 +177,13 @@ public class ImageSegmenterTest {
|
|||
() ->
|
||||
imageSegmenter.segmentAsync(getImageFromAsset(CAT_IMAGE), /* timestampsMs= */ 0));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
|
||||
exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() -> imageSegmenter.segmentWithResultListener(getImageFromAsset(CAT_IMAGE)));
|
||||
assertThat(exception)
|
||||
.hasMessageThat()
|
||||
.contains("ResultListener is not set in the ImageSegmenterOptions");
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -191,7 +193,6 @@ public class ImageSegmenterTest {
|
|||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||
.setRunningMode(RunningMode.VIDEO)
|
||||
.build();
|
||||
|
||||
ImageSegmenter imageSegmenter =
|
||||
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
MediaPipeException exception =
|
||||
|
@ -204,6 +205,15 @@ public class ImageSegmenterTest {
|
|||
() ->
|
||||
imageSegmenter.segmentAsync(getImageFromAsset(CAT_IMAGE), /* timestampsMs= */ 0));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
|
||||
exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() ->
|
||||
imageSegmenter.segmentForVideoWithResultListener(
|
||||
getImageFromAsset(CAT_IMAGE), /* timestampsMs= */ 0));
|
||||
assertThat(exception)
|
||||
.hasMessageThat()
|
||||
.contains("ResultListener is not set in the ImageSegmenterOptions");
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -214,18 +224,18 @@ public class ImageSegmenterTest {
|
|||
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||
.setResultListener((result, inputImage) -> {})
|
||||
.build();
|
||||
|
||||
ImageSegmenter imageSegmenter =
|
||||
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
MediaPipeException exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class, () -> imageSegmenter.segment(getImageFromAsset(CAT_IMAGE)));
|
||||
MediaPipeException.class,
|
||||
() -> imageSegmenter.segmentWithResultListener(getImageFromAsset(CAT_IMAGE)));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
|
||||
exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() ->
|
||||
imageSegmenter.segmentForVideo(
|
||||
imageSegmenter.segmentForVideoWithResultListener(
|
||||
getImageFromAsset(CAT_IMAGE), /* timestampsMs= */ 0));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
|
||||
}
|
||||
|
@ -234,51 +244,94 @@ public class ImageSegmenterTest {
|
|||
// Exercises the result-returning API in IMAGE mode: the options below set no result listener,
// and the masks are read from the ImageSegmenterResult returned by segment().
public void segment_successWithImageMode() throws Exception {
|
||||
final String inputImageName = "cat.jpg";
|
||||
final String goldenImageName = "cat_mask.jpg";
|
||||
ImageSegmenterOptions options =
|
||||
ImageSegmenterOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
|
||||
.setRunningMode(RunningMode.IMAGE)
|
||||
.build();
|
||||
ImageSegmenter imageSegmenter =
|
||||
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName));
|
||||
List<MPImage> segmentations = actualResult.segmentations();
|
||||
// The test expects exactly 21 masks; presumably one per category of the model - TODO confirm.
assertThat(segmentations.size()).isEqualTo(21);
|
||||
// Cat category index 8.
|
||||
MPImage actualMaskBuffer = actualResult.segmentations().get(8);
|
||||
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
|
||||
// Compares the cat mask with the golden image using the GOLDEN_MASK_SIMILARITY constant;
// the comparison metric is defined by verifyConfidenceMask, which is outside this view.
verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
|
||||
}
|
||||
|
||||
/**
 * Exercises the callback API in IMAGE mode: a result listener is registered in the options and
 * the test drives it through segmentWithResultListener(); the listener compares the mask at
 * index 8 with the golden image loaded into expectedResult.
 *
 * <p>NOTE(review): this span looks like a rendered diff with removed and added lines shown
 * together and no +/- markers. The lambda header appears twice, verifyConfidenceMask has two
 * argument lists, and both segment() and segmentWithResultListener() are called. The lines that
 * use actualResult / actualMaskBuffer / expectedMaskBuffer (expectedMaskBuffer is not declared
 * in this method) and the plain segment() call are presumably the pre-change version - confirm
 * against the real source file.
 */
@Test
|
||||
public void segment_successWithImageModeWithResultListener() throws Exception {
|
||||
final String inputImageName = "cat.jpg";
|
||||
final String goldenImageName = "cat_mask.jpg";
|
||||
// Loaded before the options are built so the listener lambda can capture it.
MPImage expectedResult = getImageFromAsset(goldenImageName);
|
||||
ImageSegmenterOptions options =
|
||||
ImageSegmenterOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
|
||||
.setRunningMode(RunningMode.IMAGE)
|
||||
.setResultListener(
|
||||
(actualResult, inputImage) -> {
|
||||
List<MPImage> segmentations = actualResult.segmentations();
|
||||
assertThat(segmentations.size()).isEqualTo(21);
|
||||
// Cat category index 8.
|
||||
MPImage actualMaskBuffer = actualResult.segmentations().get(8);
|
||||
(segmenterResult, inputImage) -> {
|
||||
verifyConfidenceMask(
|
||||
actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
|
||||
segmenterResult.segmentations().get(8),
|
||||
expectedResult,
|
||||
GOLDEN_MASK_SIMILARITY);
|
||||
})
|
||||
.build();
|
||||
ImageSegmenter imageSegmenter =
|
||||
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
imageSegmenter.segment(getImageFromAsset(inputImageName));
|
||||
// No return value is used here; the verification happens inside the listener set above.
imageSegmenter.segmentWithResultListener(getImageFromAsset(inputImageName));
|
||||
}
|
||||
|
||||
/**
 * Exercises the result-returning API in VIDEO mode: the options set no result listener, and the
 * same input image is segmented three times with timestamps 0, 1 and 2, checking each returned
 * result against the golden cat mask.
 */
@Test
|
||||
public void segment_successWithVideoMode() throws Exception {
|
||||
final String inputImageName = "cat.jpg";
|
||||
final String goldenImageName = "cat_mask.jpg";
|
||||
ImageSegmenterOptions options =
|
||||
ImageSegmenterOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
|
||||
.setRunningMode(RunningMode.VIDEO)
|
||||
.build();
|
||||
ImageSegmenter imageSegmenter =
|
||||
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
// The golden mask is loaded once and reused for every frame in the loop.
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
|
||||
for (int i = 0; i < 3; i++) {
|
||||
ImageSegmenterResult actualResult =
|
||||
imageSegmenter.segmentForVideo(
|
||||
// The loop index doubles as the frame timestamp, so timestamps increase per call.
getImageFromAsset(inputImageName), /* timestampsMs= */ i);
|
||||
List<MPImage> segmentations = actualResult.segmentations();
|
||||
// The test expects exactly 21 masks; presumably one per category of the model - TODO confirm.
assertThat(segmentations.size()).isEqualTo(21);
|
||||
// Cat category index 8.
|
||||
MPImage actualMaskBuffer = actualResult.segmentations().get(8);
|
||||
verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Exercises the callback API in VIDEO mode: a result listener is registered in the options and
 * the test calls segmentForVideoWithResultListener() three times with timestamps 0, 1 and 2; the
 * listener compares the mask at index 8 with the golden image loaded into expectedResult.
 *
 * <p>NOTE(review): this span looks like a rendered diff with removed and added lines shown
 * together and no +/- markers. The lambda header appears twice, verifyConfidenceMask has two
 * argument lists, and the loop calls both segmentForVideo() and
 * segmentForVideoWithResultListener(). The lines that use actualResult / actualMaskBuffer /
 * expectedMaskBuffer (expectedMaskBuffer is not declared in this method) and the plain
 * segmentForVideo() call are presumably the pre-change version - confirm against the real
 * source file.
 */
@Test
|
||||
public void segment_successWithVideoModeWithResultListener() throws Exception {
|
||||
final String inputImageName = "cat.jpg";
|
||||
final String goldenImageName = "cat_mask.jpg";
|
||||
// Loaded before the options are built so the listener lambda can capture it.
MPImage expectedResult = getImageFromAsset(goldenImageName);
|
||||
ImageSegmenterOptions options =
|
||||
ImageSegmenterOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
|
||||
.setRunningMode(RunningMode.VIDEO)
|
||||
.setResultListener(
|
||||
(actualResult, inputImage) -> {
|
||||
List<MPImage> segmentations = actualResult.segmentations();
|
||||
assertThat(segmentations.size()).isEqualTo(21);
|
||||
// Cat category index 8.
|
||||
MPImage actualMaskBuffer = actualResult.segmentations().get(8);
|
||||
(segmenterResult, inputImage) -> {
|
||||
verifyConfidenceMask(
|
||||
actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
|
||||
segmenterResult.segmentations().get(8),
|
||||
expectedResult,
|
||||
GOLDEN_MASK_SIMILARITY);
|
||||
})
|
||||
.build();
|
||||
ImageSegmenter imageSegmenter =
|
||||
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
for (int i = 0; i < 3; i++) {
|
||||
imageSegmenter.segmentForVideo(getImageFromAsset(inputImageName), /* timestampsMs= */ i);
|
||||
// No return value is used here; the verification happens inside the listener set above.
imageSegmenter.segmentForVideoWithResultListener(
|
||||
// The loop index doubles as the frame timestamp, so timestamps increase per call.
getImageFromAsset(inputImageName), /* timestampsMs= */ i);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user