diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java index b809ab963..e8e0e4051 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java @@ -79,15 +79,10 @@ public final class ImageSegmenter extends BaseVisionTaskApi { private static final List INPUT_STREAMS = Collections.unmodifiableList( Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME)); - private static final List OUTPUT_STREAMS = - Collections.unmodifiableList( - Arrays.asList( - "GROUPED_SEGMENTATION:segmented_mask_out", - "IMAGE:image_out", - "SEGMENTATION:0:segmentation")); - private static final int GROUPED_SEGMENTATION_OUT_STREAM_INDEX = 0; + private static final int CONFIDENCE_MASKS_OUT_STREAM_INDEX = 0; private static final int IMAGE_OUT_STREAM_INDEX = 1; - private static final int SEGMENTATION_OUT_STREAM_INDEX = 2; + private static final int CONFIDENCE_MASK_OUT_STREAM_INDEX = 2; + private static final int CATEGORY_MASK_OUT_STREAM_INDEX = 3; private static final String TASK_GRAPH_NAME = "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph"; private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME = @@ -104,6 +99,13 @@ public final class ImageSegmenter extends BaseVisionTaskApi { */ public static ImageSegmenter createFromOptions( Context context, ImageSegmenterOptions segmenterOptions) { + List outputStreams = new ArrayList<>(); + outputStreams.add("CONFIDENCE_MASKS:confidence_masks"); + outputStreams.add("IMAGE:image_out"); + outputStreams.add("CONFIDENCE_MASK:0:confidence_mask"); + if (segmenterOptions.outputCategoryMask()) { + outputStreams.add("CATEGORY_MASK:category_mask"); + } // TODO: Consolidate OutputHandler and TaskRunner. OutputHandler handler = new OutputHandler<>(); handler.setOutputPacketConverter( @@ -111,50 +113,62 @@ public final class ImageSegmenter extends BaseVisionTaskApi { @Override public ImageSegmenterResult convertToTaskResult(List packets) throws MediaPipeException { - if (packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).isEmpty()) { + if (packets.get(CONFIDENCE_MASKS_OUT_STREAM_INDEX).isEmpty()) { return ImageSegmenterResult.create( new ArrayList<>(), - packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).getTimestamp()); + Optional.empty(), + packets.get(CONFIDENCE_MASKS_OUT_STREAM_INDEX).getTimestamp()); } - List segmentedMasks = new ArrayList<>(); - int width = PacketGetter.getImageWidth(packets.get(SEGMENTATION_OUT_STREAM_INDEX)); - int height = PacketGetter.getImageHeight(packets.get(SEGMENTATION_OUT_STREAM_INDEX)); - int imageFormat = - segmenterOptions.outputType() == ImageSegmenterOptions.OutputType.CONFIDENCE_MASK - ? MPImage.IMAGE_FORMAT_VEC32F1 - : MPImage.IMAGE_FORMAT_ALPHA; - int imageListSize = - PacketGetter.getImageListSize(packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX)); - ByteBuffer[] buffersArray = new ByteBuffer[imageListSize]; + List confidenceMasks = new ArrayList<>(); + int width = PacketGetter.getImageWidth(packets.get(CONFIDENCE_MASK_OUT_STREAM_INDEX)); + int height = PacketGetter.getImageHeight(packets.get(CONFIDENCE_MASK_OUT_STREAM_INDEX)); + int confidenceMasksListSize = + PacketGetter.getImageListSize(packets.get(CONFIDENCE_MASKS_OUT_STREAM_INDEX)); + ByteBuffer[] buffersArray = new ByteBuffer[confidenceMasksListSize]; // If resultListener is not provided, the resulted MPImage is deep copied from mediapipe // graph. If provided, the result MPImage is wrapping the mediapipe packet memory. - if (!segmenterOptions.resultListener().isPresent()) { - for (int i = 0; i < imageListSize; i++) { - buffersArray[i] = - ByteBuffer.allocateDirect( - width * height * (imageFormat == MPImage.IMAGE_FORMAT_VEC32F1 ? 4 : 1)); + boolean copyImage = !segmenterOptions.resultListener().isPresent(); + if (copyImage) { + for (int i = 0; i < confidenceMasksListSize; i++) { + buffersArray[i] = ByteBuffer.allocateDirect(width * height * 4); } } if (!PacketGetter.getImageList( - packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX), - buffersArray, - !segmenterOptions.resultListener().isPresent())) { + packets.get(CONFIDENCE_MASKS_OUT_STREAM_INDEX), buffersArray, copyImage)) { throw new MediaPipeException( MediaPipeException.StatusCode.INTERNAL.ordinal(), - "There is an error getting segmented masks. It usually results from incorrect" - + " options of unsupported OutputType of given model."); + "There is an error getting segmented masks."); } for (ByteBuffer buffer : buffersArray) { ByteBufferImageBuilder builder = - new ByteBufferImageBuilder(buffer, width, height, imageFormat); - segmentedMasks.add(builder.build()); + new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_VEC32F1); + confidenceMasks.add(builder.build()); + } + Optional categoryMask = Optional.empty(); + if (segmenterOptions.outputCategoryMask()) { + ByteBuffer buffer; + if (copyImage) { + buffer = ByteBuffer.allocateDirect(width * height); + if (!PacketGetter.getImageData( + packets.get(CATEGORY_MASK_OUT_STREAM_INDEX), buffer)) { + throw new MediaPipeException( + MediaPipeException.StatusCode.INTERNAL.ordinal(), + "There is an error getting category mask."); + } + } else { + buffer = + PacketGetter.getImageDataDirectly(packets.get(CATEGORY_MASK_OUT_STREAM_INDEX)); + } + ByteBufferImageBuilder builder = + new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_ALPHA); + categoryMask = Optional.of(builder.build()); } - return ImageSegmenterResult.create( - segmentedMasks, + confidenceMasks, + categoryMask, BaseVisionTaskApi.generateResultTimestampMs( segmenterOptions.runningMode(), - packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX))); + packets.get(CONFIDENCE_MASKS_OUT_STREAM_INDEX))); } @Override @@ -174,7 +188,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi { .setTaskRunningModeName(segmenterOptions.runningMode().name()) .setTaskGraphName(TASK_GRAPH_NAME) .setInputStreams(INPUT_STREAMS) - .setOutputStreams(OUTPUT_STREAMS) + .setOutputStreams(outputStreams) .setTaskOptions(segmenterOptions) .setEnableFlowLimiting(segmenterOptions.runningMode() == RunningMode.LIVE_STREAM) .build(), @@ -553,8 +567,8 @@ public final class ImageSegmenter extends BaseVisionTaskApi { */ public abstract Builder setDisplayNamesLocale(String value); - /** The output type from image segmenter. */ - public abstract Builder setOutputType(OutputType value); + /** Whether to output category mask. */ + public abstract Builder setOutputCategoryMask(boolean value); /** * Sets an optional {@link ResultListener} to receive the segmentation results when the graph @@ -594,27 +608,17 @@ public final class ImageSegmenter extends BaseVisionTaskApi { abstract String displayNamesLocale(); - abstract OutputType outputType(); + abstract boolean outputCategoryMask(); abstract Optional> resultListener(); abstract Optional errorListener(); - /** The output type of segmentation results. */ - public enum OutputType { - // Gives a single output mask where each pixel represents the class which - // the pixel in the original image was predicted to belong to. - CATEGORY_MASK, - // Gives a list of output masks where, for each mask, each pixel represents - // the prediction confidence, usually in the [0, 1] range. - CONFIDENCE_MASK - } - public static Builder builder() { return new AutoValue_ImageSegmenter_ImageSegmenterOptions.Builder() .setRunningMode(RunningMode.IMAGE) .setDisplayNamesLocale("en") - .setOutputType(OutputType.CATEGORY_MASK); + .setOutputCategoryMask(false); } /** @@ -633,14 +637,6 @@ public final class ImageSegmenter extends BaseVisionTaskApi { SegmenterOptionsProto.SegmenterOptions.Builder segmenterOptionsBuilder = SegmenterOptionsProto.SegmenterOptions.newBuilder(); - if (outputType() == OutputType.CONFIDENCE_MASK) { - segmenterOptionsBuilder.setOutputType( - SegmenterOptionsProto.SegmenterOptions.OutputType.CONFIDENCE_MASK); - } else if (outputType() == OutputType.CATEGORY_MASK) { - segmenterOptionsBuilder.setOutputType( - SegmenterOptionsProto.SegmenterOptions.OutputType.CATEGORY_MASK); - } - taskOptionsBuilder.setSegmenterOptions(segmenterOptionsBuilder); return CalculatorOptions.newBuilder() .setExtension( diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterResult.java index 69ab79c13..400894a66 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterResult.java @@ -19,6 +19,7 @@ import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.core.TaskResult; import java.util.Collections; import java.util.List; +import java.util.Optional; /** Represents the segmentation results generated by {@link ImageSegmenter}. */ @AutoValue @@ -27,18 +28,24 @@ public abstract class ImageSegmenterResult implements TaskResult { /** * Creates an {@link ImageSegmenterResult} instance from a list of segmentation MPImage. * - * @param segmentations a {@link List} of MPImage representing the segmented masks. If OutputType - * is CATEGORY_MASK, the masks will be in IMAGE_FORMAT_ALPHA format. If OutputType is - * CONFIDENCE_MASK, the masks will be in IMAGE_FORMAT_VEC32F1 format. + * @param confidenceMasks a {@link List} of MPImage in IMAGE_FORMAT_VEC32F1 format representing + * the confidence masks, where, for each mask, each pixel represents the prediction + * confidence, usually in the [0, 1] range. + * @param categoryMask an {@link Optional} MPImage in IMAGE_FORMAT_ALPHA format representing a + * category mask, where each pixel represents the class which the pixel in the original image + * was predicted to belong to. * @param timestampMs a timestamp for this result. */ // TODO: consolidate output formats across platforms. - public static ImageSegmenterResult create(List segmentations, long timestampMs) { + public static ImageSegmenterResult create( + List confidenceMasks, Optional categoryMask, long timestampMs) { return new AutoValue_ImageSegmenterResult( - Collections.unmodifiableList(segmentations), timestampMs); + Collections.unmodifiableList(confidenceMasks), categoryMask, timestampMs); } - public abstract List segmentations(); + public abstract List confidenceMasks(); + + public abstract Optional categoryMask(); @Override public abstract long timestampMs(); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java index 657716b6b..2348aaadd 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java @@ -133,6 +133,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi { if (packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).isEmpty()) { return ImageSegmenterResult.create( new ArrayList<>(), + Optional.empty(), packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).getTimestamp()); } List segmentedMasks = new ArrayList<>(); @@ -172,6 +173,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi { return ImageSegmenterResult.create( segmentedMasks, + Optional.empty(), BaseVisionTaskApi.generateResultTimestampMs( RunningMode.IMAGE, packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX))); } diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterTest.java index 3b35c21bc..7acf1377e 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterTest.java @@ -61,14 +61,13 @@ public class ImageSegmenterTest { ImageSegmenterOptions options = ImageSegmenterOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) - .setOutputType(ImageSegmenterOptions.OutputType.CATEGORY_MASK) + .setOutputCategoryMask(true) .build(); ImageSegmenter imageSegmenter = ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName)); - List segmentations = actualResult.segmentations(); - assertThat(segmentations.size()).isEqualTo(1); - MPImage actualMaskBuffer = actualResult.segmentations().get(0); + assertThat(actualResult.categoryMask().isPresent()).isTrue(); + MPImage actualMaskBuffer = actualResult.categoryMask().get(); MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); verifyCategoryMask( actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY, MAGNIFICATION_FACTOR); @@ -81,15 +80,14 @@ public class ImageSegmenterTest { ImageSegmenterOptions options = ImageSegmenterOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) - .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) .build(); ImageSegmenter imageSegmenter = ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName)); - List segmentations = actualResult.segmentations(); + List segmentations = actualResult.confidenceMasks(); assertThat(segmentations.size()).isEqualTo(21); // Cat category index 8. - MPImage actualMaskBuffer = actualResult.segmentations().get(8); + MPImage actualMaskBuffer = segmentations.get(8); MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); } @@ -102,40 +100,36 @@ public class ImageSegmenterTest { ImageSegmenterOptions.builder() .setBaseOptions( BaseOptions.builder().setModelAssetPath(SELFIE_128x128_MODEL_FILE).build()) - .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) .build(); ImageSegmenter imageSegmenter = ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName)); - List segmentations = actualResult.segmentations(); + List segmentations = actualResult.confidenceMasks(); assertThat(segmentations.size()).isEqualTo(2); // Selfie category index 1. - MPImage actualMaskBuffer = actualResult.segmentations().get(1); + MPImage actualMaskBuffer = segmentations.get(1); MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); } - // TODO: enable this unit test once activation option is supported in metadata. - // @Test - // public void segment_successWith144x256Segmentation() throws Exception { - // final String inputImageName = "mozart_square.jpg"; - // final String goldenImageName = "selfie_segm_144_256_3_expected_mask.jpg"; - // ImageSegmenterOptions options = - // ImageSegmenterOptions.builder() - // .setBaseOptions( - // BaseOptions.builder().setModelAssetPath(SELFIE_144x256_MODEL_FILE).build()) - // .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) - // .build(); - // ImageSegmenter imageSegmenter = - // ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); - // ImageSegmenterResult actualResult = - // imageSegmenter.segment(getImageFromAsset(inputImageName)); - // List segmentations = actualResult.segmentations(); - // assertThat(segmentations.size()).isEqualTo(1); - // MPImage actualMaskBuffer = actualResult.segmentations().get(0); - // MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); - // verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); - // } + @Test + public void segment_successWith144x256Segmentation() throws Exception { + final String inputImageName = "mozart_square.jpg"; + final String goldenImageName = "selfie_segm_144_256_3_expected_mask.jpg"; + ImageSegmenterOptions options = + ImageSegmenterOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(SELFIE_144x256_MODEL_FILE).build()) + .build(); + ImageSegmenter imageSegmenter = + ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName)); + List segmentations = actualResult.confidenceMasks(); + assertThat(segmentations.size()).isEqualTo(1); + MPImage actualMaskBuffer = segmentations.get(0); + MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); + verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); + } @Test public void getLabels_success() throws Exception { @@ -165,7 +159,6 @@ public class ImageSegmenterTest { ImageSegmenterOptions options = ImageSegmenterOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) - .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) .build(); ImageSegmenter imageSegmenter = ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -287,16 +280,15 @@ public class ImageSegmenterTest { ImageSegmenterOptions options = ImageSegmenterOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) - .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) .setRunningMode(RunningMode.IMAGE) .build(); ImageSegmenter imageSegmenter = ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName)); - List segmentations = actualResult.segmentations(); + List segmentations = actualResult.confidenceMasks(); assertThat(segmentations.size()).isEqualTo(21); // Cat category index 8. - MPImage actualMaskBuffer = actualResult.segmentations().get(8); + MPImage actualMaskBuffer = segmentations.get(8); MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); } @@ -309,12 +301,11 @@ public class ImageSegmenterTest { ImageSegmenterOptions options = ImageSegmenterOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) - .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) .setRunningMode(RunningMode.IMAGE) .setResultListener( (segmenterResult, inputImage) -> { verifyConfidenceMask( - segmenterResult.segmentations().get(8), + segmenterResult.confidenceMasks().get(8), expectedResult, GOLDEN_MASK_SIMILARITY); }) @@ -331,7 +322,6 @@ public class ImageSegmenterTest { ImageSegmenterOptions options = ImageSegmenterOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) - .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) .setRunningMode(RunningMode.VIDEO) .build(); ImageSegmenter imageSegmenter = @@ -341,10 +331,10 @@ public class ImageSegmenterTest { ImageSegmenterResult actualResult = imageSegmenter.segmentForVideo( getImageFromAsset(inputImageName), /* timestampsMs= */ i); - List segmentations = actualResult.segmentations(); + List segmentations = actualResult.confidenceMasks(); assertThat(segmentations.size()).isEqualTo(21); // Cat category index 8. - MPImage actualMaskBuffer = actualResult.segmentations().get(8); + MPImage actualMaskBuffer = segmentations.get(8); verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); } } @@ -357,12 +347,11 @@ public class ImageSegmenterTest { ImageSegmenterOptions options = ImageSegmenterOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) - .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) .setRunningMode(RunningMode.VIDEO) .setResultListener( (segmenterResult, inputImage) -> { verifyConfidenceMask( - segmenterResult.segmentations().get(8), + segmenterResult.confidenceMasks().get(8), expectedResult, GOLDEN_MASK_SIMILARITY); }) @@ -384,12 +373,11 @@ public class ImageSegmenterTest { ImageSegmenterOptions options = ImageSegmenterOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) - .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) .setRunningMode(RunningMode.LIVE_STREAM) .setResultListener( (segmenterResult, inputImage) -> { verifyConfidenceMask( - segmenterResult.segmentations().get(8), + segmenterResult.confidenceMasks().get(8), expectedResult, GOLDEN_MASK_SIMILARITY); }) @@ -411,12 +399,11 @@ public class ImageSegmenterTest { ImageSegmenterOptions options = ImageSegmenterOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) - .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) .setRunningMode(RunningMode.LIVE_STREAM) .setResultListener( (segmenterResult, inputImage) -> { verifyConfidenceMask( - segmenterResult.segmentations().get(8), + segmenterResult.confidenceMasks().get(8), expectedResult, GOLDEN_MASK_SIMILARITY); }) diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenterTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenterTest.java index 0d9581437..9351bc721 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenterTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenterTest.java @@ -60,7 +60,10 @@ public class InteractiveSegmenterTest { ApplicationProvider.getApplicationContext(), options); MPImage image = getImageFromAsset(inputImageName); ImageSegmenterResult actualResult = imageSegmenter.segment(image, roi); - List segmentations = actualResult.segmentations(); + // TODO update to correct category mask output. + // After InteractiveSegmenter updated according to (b/276519300), update this to use + // categoryMask field instead of confidenceMasks. + List segmentations = actualResult.confidenceMasks(); assertThat(segmentations.size()).isEqualTo(1); } @@ -79,7 +82,7 @@ public class InteractiveSegmenterTest { ApplicationProvider.getApplicationContext(), options); ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName), roi); - List segmentations = actualResult.segmentations(); + List segmentations = actualResult.confidenceMasks(); assertThat(segmentations.size()).isEqualTo(2); } }