From 9554836145e528ac6a8e3abfc32271606d39c2b0 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 4 Apr 2023 13:37:10 -0700 Subject: [PATCH] Update java image segmenter to always output confidence masks and optionally output category mask. PiperOrigin-RevId: 521852718 --- .../vision/imagesegmenter/ImageSegmenter.java | 112 +++++++++--------- .../imagesegmenter/ImageSegmenterResult.java | 19 +-- .../InteractiveSegmenter.java | 2 - .../imagesegmenter/ImageSegmenterTest.java | 79 ++++++------ .../InteractiveSegmenterTest.java | 7 +- 5 files changed, 112 insertions(+), 107 deletions(-) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java index e8e0e4051..b809ab963 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java @@ -79,10 +79,15 @@ public final class ImageSegmenter extends BaseVisionTaskApi { private static final List INPUT_STREAMS = Collections.unmodifiableList( Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME)); - private static final int CONFIDENCE_MASKS_OUT_STREAM_INDEX = 0; + private static final List OUTPUT_STREAMS = + Collections.unmodifiableList( + Arrays.asList( + "GROUPED_SEGMENTATION:segmented_mask_out", + "IMAGE:image_out", + "SEGMENTATION:0:segmentation")); + private static final int GROUPED_SEGMENTATION_OUT_STREAM_INDEX = 0; private static final int IMAGE_OUT_STREAM_INDEX = 1; - private static final int CONFIDENCE_MASK_OUT_STREAM_INDEX = 2; - private static final int CATEGORY_MASK_OUT_STREAM_INDEX = 3; + private static final int SEGMENTATION_OUT_STREAM_INDEX = 2; private static final String TASK_GRAPH_NAME = "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph"; private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME = @@ -99,13 +104,6 @@ public final class ImageSegmenter extends BaseVisionTaskApi { */ public static ImageSegmenter createFromOptions( Context context, ImageSegmenterOptions segmenterOptions) { - List outputStreams = new ArrayList<>(); - outputStreams.add("CONFIDENCE_MASKS:confidence_masks"); - outputStreams.add("IMAGE:image_out"); - outputStreams.add("CONFIDENCE_MASK:0:confidence_mask"); - if (segmenterOptions.outputCategoryMask()) { - outputStreams.add("CATEGORY_MASK:category_mask"); - } // TODO: Consolidate OutputHandler and TaskRunner. OutputHandler handler = new OutputHandler<>(); handler.setOutputPacketConverter( @@ -113,62 +111,50 @@ public final class ImageSegmenter extends BaseVisionTaskApi { @Override public ImageSegmenterResult convertToTaskResult(List packets) throws MediaPipeException { - if (packets.get(CONFIDENCE_MASKS_OUT_STREAM_INDEX).isEmpty()) { + if (packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).isEmpty()) { return ImageSegmenterResult.create( new ArrayList<>(), - Optional.empty(), - packets.get(CONFIDENCE_MASKS_OUT_STREAM_INDEX).getTimestamp()); + packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).getTimestamp()); } - List confidenceMasks = new ArrayList<>(); - int width = PacketGetter.getImageWidth(packets.get(CONFIDENCE_MASK_OUT_STREAM_INDEX)); - int height = PacketGetter.getImageHeight(packets.get(CONFIDENCE_MASK_OUT_STREAM_INDEX)); - int confidenceMasksListSize = - PacketGetter.getImageListSize(packets.get(CONFIDENCE_MASKS_OUT_STREAM_INDEX)); - ByteBuffer[] buffersArray = new ByteBuffer[confidenceMasksListSize]; + List segmentedMasks = new ArrayList<>(); + int width = PacketGetter.getImageWidth(packets.get(SEGMENTATION_OUT_STREAM_INDEX)); + int height = PacketGetter.getImageHeight(packets.get(SEGMENTATION_OUT_STREAM_INDEX)); + int imageFormat = + segmenterOptions.outputType() == ImageSegmenterOptions.OutputType.CONFIDENCE_MASK + ? MPImage.IMAGE_FORMAT_VEC32F1 + : MPImage.IMAGE_FORMAT_ALPHA; + int imageListSize = + PacketGetter.getImageListSize(packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX)); + ByteBuffer[] buffersArray = new ByteBuffer[imageListSize]; // If resultListener is not provided, the resulted MPImage is deep copied from mediapipe // graph. If provided, the result MPImage is wrapping the mediapipe packet memory. - boolean copyImage = !segmenterOptions.resultListener().isPresent(); - if (copyImage) { - for (int i = 0; i < confidenceMasksListSize; i++) { - buffersArray[i] = ByteBuffer.allocateDirect(width * height * 4); + if (!segmenterOptions.resultListener().isPresent()) { + for (int i = 0; i < imageListSize; i++) { + buffersArray[i] = + ByteBuffer.allocateDirect( + width * height * (imageFormat == MPImage.IMAGE_FORMAT_VEC32F1 ? 4 : 1)); } } if (!PacketGetter.getImageList( - packets.get(CONFIDENCE_MASKS_OUT_STREAM_INDEX), buffersArray, copyImage)) { + packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX), + buffersArray, + !segmenterOptions.resultListener().isPresent())) { throw new MediaPipeException( MediaPipeException.StatusCode.INTERNAL.ordinal(), - "There is an error getting segmented masks."); + "There is an error getting segmented masks. It usually results from incorrect" + + " options of unsupported OutputType of given model."); } for (ByteBuffer buffer : buffersArray) { ByteBufferImageBuilder builder = - new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_VEC32F1); - confidenceMasks.add(builder.build()); - } - Optional categoryMask = Optional.empty(); - if (segmenterOptions.outputCategoryMask()) { - ByteBuffer buffer; - if (copyImage) { - buffer = ByteBuffer.allocateDirect(width * height); - if (!PacketGetter.getImageData( - packets.get(CATEGORY_MASK_OUT_STREAM_INDEX), buffer)) { - throw new MediaPipeException( - MediaPipeException.StatusCode.INTERNAL.ordinal(), - "There is an error getting category mask."); - } - } else { - buffer = - PacketGetter.getImageDataDirectly(packets.get(CATEGORY_MASK_OUT_STREAM_INDEX)); - } - ByteBufferImageBuilder builder = - new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_ALPHA); - categoryMask = Optional.of(builder.build()); + new ByteBufferImageBuilder(buffer, width, height, imageFormat); + segmentedMasks.add(builder.build()); } + return ImageSegmenterResult.create( - confidenceMasks, - categoryMask, + segmentedMasks, BaseVisionTaskApi.generateResultTimestampMs( segmenterOptions.runningMode(), - packets.get(CONFIDENCE_MASKS_OUT_STREAM_INDEX))); + packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX))); } @Override @@ -188,7 +174,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi { .setTaskRunningModeName(segmenterOptions.runningMode().name()) .setTaskGraphName(TASK_GRAPH_NAME) .setInputStreams(INPUT_STREAMS) - .setOutputStreams(outputStreams) + .setOutputStreams(OUTPUT_STREAMS) .setTaskOptions(segmenterOptions) .setEnableFlowLimiting(segmenterOptions.runningMode() == RunningMode.LIVE_STREAM) .build(), @@ -567,8 +553,8 @@ public final class ImageSegmenter extends BaseVisionTaskApi { */ public abstract Builder setDisplayNamesLocale(String value); - /** Whether to output category mask. */ - public abstract Builder setOutputCategoryMask(boolean value); + /** The output type from image segmenter. */ + public abstract Builder setOutputType(OutputType value); /** * Sets an optional {@link ResultListener} to receive the segmentation results when the graph @@ -608,17 +594,27 @@ public final class ImageSegmenter extends BaseVisionTaskApi { abstract String displayNamesLocale(); - abstract boolean outputCategoryMask(); + abstract OutputType outputType(); abstract Optional> resultListener(); abstract Optional errorListener(); + /** The output type of segmentation results. */ + public enum OutputType { + // Gives a single output mask where each pixel represents the class which + // the pixel in the original image was predicted to belong to. + CATEGORY_MASK, + // Gives a list of output masks where, for each mask, each pixel represents + // the prediction confidence, usually in the [0, 1] range. + CONFIDENCE_MASK + } + public static Builder builder() { return new AutoValue_ImageSegmenter_ImageSegmenterOptions.Builder() .setRunningMode(RunningMode.IMAGE) .setDisplayNamesLocale("en") - .setOutputCategoryMask(false); + .setOutputType(OutputType.CATEGORY_MASK); } /** @@ -637,6 +633,14 @@ public final class ImageSegmenter extends BaseVisionTaskApi { SegmenterOptionsProto.SegmenterOptions.Builder segmenterOptionsBuilder = SegmenterOptionsProto.SegmenterOptions.newBuilder(); + if (outputType() == OutputType.CONFIDENCE_MASK) { + segmenterOptionsBuilder.setOutputType( + SegmenterOptionsProto.SegmenterOptions.OutputType.CONFIDENCE_MASK); + } else if (outputType() == OutputType.CATEGORY_MASK) { + segmenterOptionsBuilder.setOutputType( + SegmenterOptionsProto.SegmenterOptions.OutputType.CATEGORY_MASK); + } + taskOptionsBuilder.setSegmenterOptions(segmenterOptionsBuilder); return CalculatorOptions.newBuilder() .setExtension( diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterResult.java index 400894a66..69ab79c13 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterResult.java @@ -19,7 +19,6 @@ import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.core.TaskResult; import java.util.Collections; import java.util.List; -import java.util.Optional; /** Represents the segmentation results generated by {@link ImageSegmenter}. */ @AutoValue @@ -28,24 +27,18 @@ public abstract class ImageSegmenterResult implements TaskResult { /** * Creates an {@link ImageSegmenterResult} instance from a list of segmentation MPImage. * - * @param confidenceMasks a {@link List} of MPImage in IMAGE_FORMAT_VEC32F1 format representing - * the confidence masks, where, for each mask, each pixel represents the prediction - * confidence, usually in the [0, 1] range. - * @param categoryMask an {@link Optional} MPImage in IMAGE_FORMAT_ALPHA format representing a - * category mask, where each pixel represents the class which the pixel in the original image - * was predicted to belong to. + * @param segmentations a {@link List} of MPImage representing the segmented masks. If OutputType + * is CATEGORY_MASK, the masks will be in IMAGE_FORMAT_ALPHA format. If OutputType is + * CONFIDENCE_MASK, the masks will be in IMAGE_FORMAT_VEC32F1 format. * @param timestampMs a timestamp for this result. */ // TODO: consolidate output formats across platforms. - public static ImageSegmenterResult create( - List confidenceMasks, Optional categoryMask, long timestampMs) { + public static ImageSegmenterResult create(List segmentations, long timestampMs) { return new AutoValue_ImageSegmenterResult( - Collections.unmodifiableList(confidenceMasks), categoryMask, timestampMs); + Collections.unmodifiableList(segmentations), timestampMs); } - public abstract List confidenceMasks(); - - public abstract Optional categoryMask(); + public abstract List segmentations(); @Override public abstract long timestampMs(); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java index 2348aaadd..657716b6b 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java @@ -133,7 +133,6 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi { if (packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).isEmpty()) { return ImageSegmenterResult.create( new ArrayList<>(), - Optional.empty(), packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).getTimestamp()); } List segmentedMasks = new ArrayList<>(); @@ -173,7 +172,6 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi { return ImageSegmenterResult.create( segmentedMasks, - Optional.empty(), BaseVisionTaskApi.generateResultTimestampMs( RunningMode.IMAGE, packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX))); } diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterTest.java index 7acf1377e..3b35c21bc 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterTest.java @@ -61,13 +61,14 @@ public class ImageSegmenterTest { ImageSegmenterOptions options = ImageSegmenterOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) - .setOutputCategoryMask(true) + .setOutputType(ImageSegmenterOptions.OutputType.CATEGORY_MASK) .build(); ImageSegmenter imageSegmenter = ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName)); - assertThat(actualResult.categoryMask().isPresent()).isTrue(); - MPImage actualMaskBuffer = actualResult.categoryMask().get(); + List segmentations = actualResult.segmentations(); + assertThat(segmentations.size()).isEqualTo(1); + MPImage actualMaskBuffer = actualResult.segmentations().get(0); MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); verifyCategoryMask( actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY, MAGNIFICATION_FACTOR); @@ -80,14 +81,15 @@ public class ImageSegmenterTest { ImageSegmenterOptions options = ImageSegmenterOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) + .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) .build(); ImageSegmenter imageSegmenter = ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName)); - List segmentations = actualResult.confidenceMasks(); + List segmentations = actualResult.segmentations(); assertThat(segmentations.size()).isEqualTo(21); // Cat category index 8. - MPImage actualMaskBuffer = segmentations.get(8); + MPImage actualMaskBuffer = actualResult.segmentations().get(8); MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); } @@ -100,36 +102,40 @@ public class ImageSegmenterTest { ImageSegmenterOptions.builder() .setBaseOptions( BaseOptions.builder().setModelAssetPath(SELFIE_128x128_MODEL_FILE).build()) + .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) .build(); ImageSegmenter imageSegmenter = ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName)); - List segmentations = actualResult.confidenceMasks(); + List segmentations = actualResult.segmentations(); assertThat(segmentations.size()).isEqualTo(2); // Selfie category index 1. - MPImage actualMaskBuffer = segmentations.get(1); + MPImage actualMaskBuffer = actualResult.segmentations().get(1); MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); } - @Test - public void segment_successWith144x256Segmentation() throws Exception { - final String inputImageName = "mozart_square.jpg"; - final String goldenImageName = "selfie_segm_144_256_3_expected_mask.jpg"; - ImageSegmenterOptions options = - ImageSegmenterOptions.builder() - .setBaseOptions( - BaseOptions.builder().setModelAssetPath(SELFIE_144x256_MODEL_FILE).build()) - .build(); - ImageSegmenter imageSegmenter = - ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); - ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName)); - List segmentations = actualResult.confidenceMasks(); - assertThat(segmentations.size()).isEqualTo(1); - MPImage actualMaskBuffer = segmentations.get(0); - MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); - verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); - } + // TODO: enable this unit test once activation option is supported in metadata. + // @Test + // public void segment_successWith144x256Segmentation() throws Exception { + // final String inputImageName = "mozart_square.jpg"; + // final String goldenImageName = "selfie_segm_144_256_3_expected_mask.jpg"; + // ImageSegmenterOptions options = + // ImageSegmenterOptions.builder() + // .setBaseOptions( + // BaseOptions.builder().setModelAssetPath(SELFIE_144x256_MODEL_FILE).build()) + // .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) + // .build(); + // ImageSegmenter imageSegmenter = + // ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); + // ImageSegmenterResult actualResult = + // imageSegmenter.segment(getImageFromAsset(inputImageName)); + // List segmentations = actualResult.segmentations(); + // assertThat(segmentations.size()).isEqualTo(1); + // MPImage actualMaskBuffer = actualResult.segmentations().get(0); + // MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); + // verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); + // } @Test public void getLabels_success() throws Exception { @@ -159,6 +165,7 @@ public class ImageSegmenterTest { ImageSegmenterOptions options = ImageSegmenterOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) + .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) .build(); ImageSegmenter imageSegmenter = ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -280,15 +287,16 @@ public class ImageSegmenterTest { ImageSegmenterOptions options = ImageSegmenterOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) + .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) .setRunningMode(RunningMode.IMAGE) .build(); ImageSegmenter imageSegmenter = ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName)); - List segmentations = actualResult.confidenceMasks(); + List segmentations = actualResult.segmentations(); assertThat(segmentations.size()).isEqualTo(21); // Cat category index 8. - MPImage actualMaskBuffer = segmentations.get(8); + MPImage actualMaskBuffer = actualResult.segmentations().get(8); MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); } @@ -301,11 +309,12 @@ public class ImageSegmenterTest { ImageSegmenterOptions options = ImageSegmenterOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) + .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) .setRunningMode(RunningMode.IMAGE) .setResultListener( (segmenterResult, inputImage) -> { verifyConfidenceMask( - segmenterResult.confidenceMasks().get(8), + segmenterResult.segmentations().get(8), expectedResult, GOLDEN_MASK_SIMILARITY); }) @@ -322,6 +331,7 @@ public class ImageSegmenterTest { ImageSegmenterOptions options = ImageSegmenterOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) + .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) .setRunningMode(RunningMode.VIDEO) .build(); ImageSegmenter imageSegmenter = @@ -331,10 +341,10 @@ public class ImageSegmenterTest { ImageSegmenterResult actualResult = imageSegmenter.segmentForVideo( getImageFromAsset(inputImageName), /* timestampsMs= */ i); - List segmentations = actualResult.confidenceMasks(); + List segmentations = actualResult.segmentations(); assertThat(segmentations.size()).isEqualTo(21); // Cat category index 8. - MPImage actualMaskBuffer = segmentations.get(8); + MPImage actualMaskBuffer = actualResult.segmentations().get(8); verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); } } @@ -347,11 +357,12 @@ public class ImageSegmenterTest { ImageSegmenterOptions options = ImageSegmenterOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) + .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) .setRunningMode(RunningMode.VIDEO) .setResultListener( (segmenterResult, inputImage) -> { verifyConfidenceMask( - segmenterResult.confidenceMasks().get(8), + segmenterResult.segmentations().get(8), expectedResult, GOLDEN_MASK_SIMILARITY); }) @@ -373,11 +384,12 @@ public class ImageSegmenterTest { ImageSegmenterOptions options = ImageSegmenterOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) + .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) .setRunningMode(RunningMode.LIVE_STREAM) .setResultListener( (segmenterResult, inputImage) -> { verifyConfidenceMask( - segmenterResult.confidenceMasks().get(8), + segmenterResult.segmentations().get(8), expectedResult, GOLDEN_MASK_SIMILARITY); }) @@ -399,11 +411,12 @@ public class ImageSegmenterTest { ImageSegmenterOptions options = ImageSegmenterOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) + .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) .setRunningMode(RunningMode.LIVE_STREAM) .setResultListener( (segmenterResult, inputImage) -> { verifyConfidenceMask( - segmenterResult.confidenceMasks().get(8), + segmenterResult.segmentations().get(8), expectedResult, GOLDEN_MASK_SIMILARITY); }) diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenterTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenterTest.java index 9351bc721..0d9581437 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenterTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenterTest.java @@ -60,10 +60,7 @@ public class InteractiveSegmenterTest { ApplicationProvider.getApplicationContext(), options); MPImage image = getImageFromAsset(inputImageName); ImageSegmenterResult actualResult = imageSegmenter.segment(image, roi); - // TODO update to correct category mask output. - // After InteractiveSegmenter updated according to (b/276519300), update this to use - // categoryMask field instead of confidenceMasks. - List segmentations = actualResult.confidenceMasks(); + List segmentations = actualResult.segmentations(); assertThat(segmentations.size()).isEqualTo(1); } @@ -82,7 +79,7 @@ public class InteractiveSegmenterTest { ApplicationProvider.getApplicationContext(), options); ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName), roi); - List segmentations = actualResult.confidenceMasks(); + List segmentations = actualResult.segmentations(); assertThat(segmentations.size()).isEqualTo(2); } }