internal change

PiperOrigin-RevId: 523773255
This commit is contained in:
MediaPipe Team 2023-04-12 12:25:47 -07:00 committed by Copybara-Service
parent ca0da8d26f
commit 9a10375de6
5 changed files with 147 additions and 124 deletions

View File

@ -45,6 +45,7 @@ import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.function.BiFunction;
/** /**
* Performs image segmentation on images. * Performs image segmentation on images.
@ -79,15 +80,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
private static final List<String> INPUT_STREAMS = private static final List<String> INPUT_STREAMS =
Collections.unmodifiableList( Collections.unmodifiableList(
Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME)); Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
private static final List<String> OUTPUT_STREAMS = private static final int IMAGE_OUT_STREAM_INDEX = 0;
Collections.unmodifiableList(
Arrays.asList(
"GROUPED_SEGMENTATION:segmented_mask_out",
"IMAGE:image_out",
"SEGMENTATION:0:segmentation"));
private static final int GROUPED_SEGMENTATION_OUT_STREAM_INDEX = 0;
private static final int IMAGE_OUT_STREAM_INDEX = 1;
private static final int SEGMENTATION_OUT_STREAM_INDEX = 2;
private static final String TASK_GRAPH_NAME = private static final String TASK_GRAPH_NAME =
"mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph"; "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph";
private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME = private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
@ -104,6 +97,33 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
*/ */
public static ImageSegmenter createFromOptions( public static ImageSegmenter createFromOptions(
Context context, ImageSegmenterOptions segmenterOptions) { Context context, ImageSegmenterOptions segmenterOptions) {
if (!segmenterOptions.outputConfidenceMasks() && !segmenterOptions.outputCategoryMask()) {
throw new IllegalArgumentException(
"At least one of `outputConfidenceMasks` and `outputCategoryMask` must be set.");
}
List<String> outputStreams = new ArrayList<>();
outputStreams.add("IMAGE:image_out");
// Add an output stream to the output stream list, and get the added output stream index.
BiFunction<List<String>, String, Integer> getStreamIndex =
(List<String> streams, String streamName) -> {
streams.add(streamName);
return streams.size() - 1;
};
int confidenceMasksOutStreamIndex =
segmenterOptions.outputConfidenceMasks()
? getStreamIndex.apply(outputStreams, "CONFIDENCE_MASKS:confidence_masks")
: -1;
int confidenceMaskOutStreamIndex =
segmenterOptions.outputConfidenceMasks()
? getStreamIndex.apply(outputStreams, "CONFIDENCE_MASK:0:confidence_mask")
: -1;
int categoryMaskOutStreamIndex =
segmenterOptions.outputCategoryMask()
? getStreamIndex.apply(outputStreams, "CATEGORY_MASK:category_mask")
: -1;
// TODO: Consolidate OutputHandler and TaskRunner. // TODO: Consolidate OutputHandler and TaskRunner.
OutputHandler<ImageSegmenterResult, MPImage> handler = new OutputHandler<>(); OutputHandler<ImageSegmenterResult, MPImage> handler = new OutputHandler<>();
handler.setOutputPacketConverter( handler.setOutputPacketConverter(
@ -111,50 +131,65 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
@Override @Override
public ImageSegmenterResult convertToTaskResult(List<Packet> packets) public ImageSegmenterResult convertToTaskResult(List<Packet> packets)
throws MediaPipeException { throws MediaPipeException {
if (packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).isEmpty()) { if (packets.get(IMAGE_OUT_STREAM_INDEX).isEmpty()) {
return ImageSegmenterResult.create( return ImageSegmenterResult.create(
new ArrayList<>(), Optional.empty(),
packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).getTimestamp()); Optional.empty(),
packets.get(IMAGE_OUT_STREAM_INDEX).getTimestamp());
} }
List<MPImage> segmentedMasks = new ArrayList<>(); boolean copyImage = !segmenterOptions.resultListener().isPresent();
int width = PacketGetter.getImageWidth(packets.get(SEGMENTATION_OUT_STREAM_INDEX)); Optional<List<MPImage>> confidenceMasks = Optional.empty();
int height = PacketGetter.getImageHeight(packets.get(SEGMENTATION_OUT_STREAM_INDEX)); if (segmenterOptions.outputConfidenceMasks()) {
int imageFormat = int width = PacketGetter.getImageWidth(packets.get(confidenceMaskOutStreamIndex));
segmenterOptions.outputType() == ImageSegmenterOptions.OutputType.CONFIDENCE_MASK int height = PacketGetter.getImageHeight(packets.get(confidenceMaskOutStreamIndex));
? MPImage.IMAGE_FORMAT_VEC32F1 confidenceMasks = Optional.of(new ArrayList<MPImage>());
: MPImage.IMAGE_FORMAT_ALPHA; int confidenceMasksListSize =
int imageListSize = PacketGetter.getImageListSize(packets.get(confidenceMasksOutStreamIndex));
PacketGetter.getImageListSize(packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX)); ByteBuffer[] buffersArray = new ByteBuffer[confidenceMasksListSize];
ByteBuffer[] buffersArray = new ByteBuffer[imageListSize]; // If resultListener is not provided, the resulted MPImage is deep copied from
// If resultListener is not provided, the resulted MPImage is deep copied from mediapipe // mediapipe graph. If provided, the result MPImage is wrapping the mediapipe packet
// graph. If provided, the result MPImage is wrapping the mediapipe packet memory. // memory.
if (!segmenterOptions.resultListener().isPresent()) { if (copyImage) {
for (int i = 0; i < imageListSize; i++) { for (int i = 0; i < confidenceMasksListSize; i++) {
buffersArray[i] = buffersArray[i] = ByteBuffer.allocateDirect(width * height * 4);
ByteBuffer.allocateDirect(
width * height * (imageFormat == MPImage.IMAGE_FORMAT_VEC32F1 ? 4 : 1));
} }
} }
if (!PacketGetter.getImageList( if (!PacketGetter.getImageList(
packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX), packets.get(confidenceMasksOutStreamIndex), buffersArray, copyImage)) {
buffersArray,
!segmenterOptions.resultListener().isPresent())) {
throw new MediaPipeException( throw new MediaPipeException(
MediaPipeException.StatusCode.INTERNAL.ordinal(), MediaPipeException.StatusCode.INTERNAL.ordinal(),
"There is an error getting segmented masks. It usually results from incorrect" "There is an error getting confidence masks.");
+ " options of unsupported OutputType of given model.");
} }
for (ByteBuffer buffer : buffersArray) { for (ByteBuffer buffer : buffersArray) {
ByteBufferImageBuilder builder = ByteBufferImageBuilder builder =
new ByteBufferImageBuilder(buffer, width, height, imageFormat); new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_VEC32F1);
segmentedMasks.add(builder.build()); confidenceMasks.get().add(builder.build());
}
}
Optional<MPImage> categoryMask = Optional.empty();
if (segmenterOptions.outputCategoryMask()) {
int width = PacketGetter.getImageWidth(packets.get(categoryMaskOutStreamIndex));
int height = PacketGetter.getImageHeight(packets.get(categoryMaskOutStreamIndex));
ByteBuffer buffer;
if (copyImage) {
buffer = ByteBuffer.allocateDirect(width * height);
if (!PacketGetter.getImageData(packets.get(categoryMaskOutStreamIndex), buffer)) {
throw new MediaPipeException(
MediaPipeException.StatusCode.INTERNAL.ordinal(),
"There is an error getting category mask.");
}
} else {
buffer = PacketGetter.getImageDataDirectly(packets.get(categoryMaskOutStreamIndex));
}
ByteBufferImageBuilder builder =
new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_ALPHA);
categoryMask = Optional.of(builder.build());
} }
return ImageSegmenterResult.create( return ImageSegmenterResult.create(
segmentedMasks, confidenceMasks,
categoryMask,
BaseVisionTaskApi.generateResultTimestampMs( BaseVisionTaskApi.generateResultTimestampMs(
segmenterOptions.runningMode(), segmenterOptions.runningMode(), packets.get(IMAGE_OUT_STREAM_INDEX)));
packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX)));
} }
@Override @Override
@ -174,7 +209,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
.setTaskRunningModeName(segmenterOptions.runningMode().name()) .setTaskRunningModeName(segmenterOptions.runningMode().name())
.setTaskGraphName(TASK_GRAPH_NAME) .setTaskGraphName(TASK_GRAPH_NAME)
.setInputStreams(INPUT_STREAMS) .setInputStreams(INPUT_STREAMS)
.setOutputStreams(OUTPUT_STREAMS) .setOutputStreams(outputStreams)
.setTaskOptions(segmenterOptions) .setTaskOptions(segmenterOptions)
.setEnableFlowLimiting(segmenterOptions.runningMode() == RunningMode.LIVE_STREAM) .setEnableFlowLimiting(segmenterOptions.runningMode() == RunningMode.LIVE_STREAM)
.build(), .build(),
@ -553,8 +588,11 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
*/ */
public abstract Builder setDisplayNamesLocale(String value); public abstract Builder setDisplayNamesLocale(String value);
/** The output type from image segmenter. */ /** Whether to output confidence masks. */
public abstract Builder setOutputType(OutputType value); public abstract Builder setOutputConfidenceMasks(boolean value);
/** Whether to output category mask. */
public abstract Builder setOutputCategoryMask(boolean value);
/** /**
* Sets an optional {@link ResultListener} to receive the segmentation results when the graph * Sets an optional {@link ResultListener} to receive the segmentation results when the graph
@ -594,27 +632,20 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
abstract String displayNamesLocale(); abstract String displayNamesLocale();
abstract OutputType outputType(); abstract boolean outputConfidenceMasks();
abstract boolean outputCategoryMask();
abstract Optional<ResultListener<ImageSegmenterResult, MPImage>> resultListener(); abstract Optional<ResultListener<ImageSegmenterResult, MPImage>> resultListener();
abstract Optional<ErrorListener> errorListener(); abstract Optional<ErrorListener> errorListener();
/** The output type of segmentation results. */
public enum OutputType {
// Gives a single output mask where each pixel represents the class which
// the pixel in the original image was predicted to belong to.
CATEGORY_MASK,
// Gives a list of output masks where, for each mask, each pixel represents
// the prediction confidence, usually in the [0, 1] range.
CONFIDENCE_MASK
}
public static Builder builder() { public static Builder builder() {
return new AutoValue_ImageSegmenter_ImageSegmenterOptions.Builder() return new AutoValue_ImageSegmenter_ImageSegmenterOptions.Builder()
.setRunningMode(RunningMode.IMAGE) .setRunningMode(RunningMode.IMAGE)
.setDisplayNamesLocale("en") .setDisplayNamesLocale("en")
.setOutputType(OutputType.CATEGORY_MASK); .setOutputConfidenceMasks(true)
.setOutputCategoryMask(false);
} }
/** /**
@ -633,14 +664,6 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
SegmenterOptionsProto.SegmenterOptions.Builder segmenterOptionsBuilder = SegmenterOptionsProto.SegmenterOptions.Builder segmenterOptionsBuilder =
SegmenterOptionsProto.SegmenterOptions.newBuilder(); SegmenterOptionsProto.SegmenterOptions.newBuilder();
if (outputType() == OutputType.CONFIDENCE_MASK) {
segmenterOptionsBuilder.setOutputType(
SegmenterOptionsProto.SegmenterOptions.OutputType.CONFIDENCE_MASK);
} else if (outputType() == OutputType.CATEGORY_MASK) {
segmenterOptionsBuilder.setOutputType(
SegmenterOptionsProto.SegmenterOptions.OutputType.CATEGORY_MASK);
}
taskOptionsBuilder.setSegmenterOptions(segmenterOptionsBuilder); taskOptionsBuilder.setSegmenterOptions(segmenterOptionsBuilder);
return CalculatorOptions.newBuilder() return CalculatorOptions.newBuilder()
.setExtension( .setExtension(

View File

@ -19,6 +19,7 @@ import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.core.TaskResult; import com.google.mediapipe.tasks.core.TaskResult;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Optional;
/** Represents the segmentation results generated by {@link ImageSegmenter}. */ /** Represents the segmentation results generated by {@link ImageSegmenter}. */
@AutoValue @AutoValue
@ -27,18 +28,24 @@ public abstract class ImageSegmenterResult implements TaskResult {
/** /**
* Creates an {@link ImageSegmenterResult} instance from a list of segmentation MPImage. * Creates an {@link ImageSegmenterResult} instance from a list of segmentation MPImage.
* *
* @param segmentations a {@link List} of MPImage representing the segmented masks. If OutputType * @param confidenceMasks an {@link Optional} of {@link List} of MPImage in IMAGE_FORMAT_VEC32F1
* is CATEGORY_MASK, the masks will be in IMAGE_FORMAT_ALPHA format. If OutputType is * format representing the confidence masks, where, for each mask, each pixel represents the
* CONFIDENCE_MASK, the masks will be in IMAGE_FORMAT_VEC32F1 format. * prediction confidence, usually in the [0, 1] range.
* @param categoryMask an {@link Optional} MPImage in IMAGE_FORMAT_ALPHA format representing a
* category mask, where each pixel represents the class which the pixel in the original image
* was predicted to belong to.
* @param timestampMs a timestamp for this result. * @param timestampMs a timestamp for this result.
*/ */
// TODO: consolidate output formats across platforms. // TODO: consolidate output formats across platforms.
public static ImageSegmenterResult create(List<MPImage> segmentations, long timestampMs) { public static ImageSegmenterResult create(
Optional<List<MPImage>> confidenceMasks, Optional<MPImage> categoryMask, long timestampMs) {
return new AutoValue_ImageSegmenterResult( return new AutoValue_ImageSegmenterResult(
Collections.unmodifiableList(segmentations), timestampMs); confidenceMasks.map(Collections::unmodifiableList), categoryMask, timestampMs);
} }
public abstract List<MPImage> segmentations(); public abstract Optional<List<MPImage>> confidenceMasks();
public abstract Optional<MPImage> categoryMask();
@Override @Override
public abstract long timestampMs(); public abstract long timestampMs();

View File

@ -132,7 +132,8 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
throws MediaPipeException { throws MediaPipeException {
if (packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).isEmpty()) { if (packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).isEmpty()) {
return ImageSegmenterResult.create( return ImageSegmenterResult.create(
new ArrayList<>(), Optional.empty(),
Optional.empty(),
packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).getTimestamp()); packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).getTimestamp());
} }
List<MPImage> segmentedMasks = new ArrayList<>(); List<MPImage> segmentedMasks = new ArrayList<>();
@ -171,7 +172,8 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
} }
return ImageSegmenterResult.create( return ImageSegmenterResult.create(
segmentedMasks, Optional.of(segmentedMasks),
Optional.empty(),
BaseVisionTaskApi.generateResultTimestampMs( BaseVisionTaskApi.generateResultTimestampMs(
RunningMode.IMAGE, packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX))); RunningMode.IMAGE, packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX)));
} }

View File

@ -61,14 +61,14 @@ public class ImageSegmenterTest {
ImageSegmenterOptions options = ImageSegmenterOptions options =
ImageSegmenterOptions.builder() ImageSegmenterOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
.setOutputType(ImageSegmenterOptions.OutputType.CATEGORY_MASK) .setOutputConfidenceMasks(false)
.setOutputCategoryMask(true)
.build(); .build();
ImageSegmenter imageSegmenter = ImageSegmenter imageSegmenter =
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName)); ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName));
List<MPImage> segmentations = actualResult.segmentations(); assertThat(actualResult.categoryMask().isPresent()).isTrue();
assertThat(segmentations.size()).isEqualTo(1); MPImage actualMaskBuffer = actualResult.categoryMask().get();
MPImage actualMaskBuffer = actualResult.segmentations().get(0);
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
verifyCategoryMask( verifyCategoryMask(
actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY, MAGNIFICATION_FACTOR); actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY, MAGNIFICATION_FACTOR);
@ -81,15 +81,14 @@ public class ImageSegmenterTest {
ImageSegmenterOptions options = ImageSegmenterOptions options =
ImageSegmenterOptions.builder() ImageSegmenterOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
.build(); .build();
ImageSegmenter imageSegmenter = ImageSegmenter imageSegmenter =
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName)); ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName));
List<MPImage> segmentations = actualResult.segmentations(); List<MPImage> segmentations = actualResult.confidenceMasks().get();
assertThat(segmentations.size()).isEqualTo(21); assertThat(segmentations.size()).isEqualTo(21);
// Cat category index 8. // Cat category index 8.
MPImage actualMaskBuffer = actualResult.segmentations().get(8); MPImage actualMaskBuffer = segmentations.get(8);
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
} }
@ -102,40 +101,36 @@ public class ImageSegmenterTest {
ImageSegmenterOptions.builder() ImageSegmenterOptions.builder()
.setBaseOptions( .setBaseOptions(
BaseOptions.builder().setModelAssetPath(SELFIE_128x128_MODEL_FILE).build()) BaseOptions.builder().setModelAssetPath(SELFIE_128x128_MODEL_FILE).build())
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
.build(); .build();
ImageSegmenter imageSegmenter = ImageSegmenter imageSegmenter =
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName)); ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName));
List<MPImage> segmentations = actualResult.segmentations(); List<MPImage> segmentations = actualResult.confidenceMasks().get();
assertThat(segmentations.size()).isEqualTo(2); assertThat(segmentations.size()).isEqualTo(2);
// Selfie category index 1. // Selfie category index 1.
MPImage actualMaskBuffer = actualResult.segmentations().get(1); MPImage actualMaskBuffer = segmentations.get(1);
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
} }
// TODO: enable this unit test once activation option is supported in metadata. @Test
// @Test public void segment_successWith144x256Segmentation() throws Exception {
// public void segment_successWith144x256Segmentation() throws Exception { final String inputImageName = "mozart_square.jpg";
// final String inputImageName = "mozart_square.jpg"; final String goldenImageName = "selfie_segm_144_256_3_expected_mask.jpg";
// final String goldenImageName = "selfie_segm_144_256_3_expected_mask.jpg"; ImageSegmenterOptions options =
// ImageSegmenterOptions options = ImageSegmenterOptions.builder()
// ImageSegmenterOptions.builder() .setBaseOptions(
// .setBaseOptions( BaseOptions.builder().setModelAssetPath(SELFIE_144x256_MODEL_FILE).build())
// BaseOptions.builder().setModelAssetPath(SELFIE_144x256_MODEL_FILE).build()) .build();
// .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) ImageSegmenter imageSegmenter =
// .build(); ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
// ImageSegmenter imageSegmenter = ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName));
// ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); List<MPImage> segmentations = actualResult.confidenceMasks().get();
// ImageSegmenterResult actualResult = assertThat(segmentations.size()).isEqualTo(1);
// imageSegmenter.segment(getImageFromAsset(inputImageName)); MPImage actualMaskBuffer = segmentations.get(0);
// List<MPImage> segmentations = actualResult.segmentations(); MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
// assertThat(segmentations.size()).isEqualTo(1); verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
// MPImage actualMaskBuffer = actualResult.segmentations().get(0); }
// MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
// verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
// }
@Test @Test
public void getLabels_success() throws Exception { public void getLabels_success() throws Exception {
@ -165,7 +160,6 @@ public class ImageSegmenterTest {
ImageSegmenterOptions options = ImageSegmenterOptions options =
ImageSegmenterOptions.builder() ImageSegmenterOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
.build(); .build();
ImageSegmenter imageSegmenter = ImageSegmenter imageSegmenter =
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
@ -287,16 +281,15 @@ public class ImageSegmenterTest {
ImageSegmenterOptions options = ImageSegmenterOptions options =
ImageSegmenterOptions.builder() ImageSegmenterOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
.setRunningMode(RunningMode.IMAGE) .setRunningMode(RunningMode.IMAGE)
.build(); .build();
ImageSegmenter imageSegmenter = ImageSegmenter imageSegmenter =
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName)); ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName));
List<MPImage> segmentations = actualResult.segmentations(); List<MPImage> segmentations = actualResult.confidenceMasks().get();
assertThat(segmentations.size()).isEqualTo(21); assertThat(segmentations.size()).isEqualTo(21);
// Cat category index 8. // Cat category index 8.
MPImage actualMaskBuffer = actualResult.segmentations().get(8); MPImage actualMaskBuffer = segmentations.get(8);
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
} }
@ -309,12 +302,11 @@ public class ImageSegmenterTest {
ImageSegmenterOptions options = ImageSegmenterOptions options =
ImageSegmenterOptions.builder() ImageSegmenterOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
.setRunningMode(RunningMode.IMAGE) .setRunningMode(RunningMode.IMAGE)
.setResultListener( .setResultListener(
(segmenterResult, inputImage) -> { (segmenterResult, inputImage) -> {
verifyConfidenceMask( verifyConfidenceMask(
segmenterResult.segmentations().get(8), segmenterResult.confidenceMasks().get().get(8),
expectedResult, expectedResult,
GOLDEN_MASK_SIMILARITY); GOLDEN_MASK_SIMILARITY);
}) })
@ -331,7 +323,6 @@ public class ImageSegmenterTest {
ImageSegmenterOptions options = ImageSegmenterOptions options =
ImageSegmenterOptions.builder() ImageSegmenterOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
.setRunningMode(RunningMode.VIDEO) .setRunningMode(RunningMode.VIDEO)
.build(); .build();
ImageSegmenter imageSegmenter = ImageSegmenter imageSegmenter =
@ -341,10 +332,10 @@ public class ImageSegmenterTest {
ImageSegmenterResult actualResult = ImageSegmenterResult actualResult =
imageSegmenter.segmentForVideo( imageSegmenter.segmentForVideo(
getImageFromAsset(inputImageName), /* timestampsMs= */ i); getImageFromAsset(inputImageName), /* timestampsMs= */ i);
List<MPImage> segmentations = actualResult.segmentations(); List<MPImage> segmentations = actualResult.confidenceMasks().get();
assertThat(segmentations.size()).isEqualTo(21); assertThat(segmentations.size()).isEqualTo(21);
// Cat category index 8. // Cat category index 8.
MPImage actualMaskBuffer = actualResult.segmentations().get(8); MPImage actualMaskBuffer = segmentations.get(8);
verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
} }
} }
@ -357,12 +348,11 @@ public class ImageSegmenterTest {
ImageSegmenterOptions options = ImageSegmenterOptions options =
ImageSegmenterOptions.builder() ImageSegmenterOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
.setRunningMode(RunningMode.VIDEO) .setRunningMode(RunningMode.VIDEO)
.setResultListener( .setResultListener(
(segmenterResult, inputImage) -> { (segmenterResult, inputImage) -> {
verifyConfidenceMask( verifyConfidenceMask(
segmenterResult.segmentations().get(8), segmenterResult.confidenceMasks().get().get(8),
expectedResult, expectedResult,
GOLDEN_MASK_SIMILARITY); GOLDEN_MASK_SIMILARITY);
}) })
@ -384,12 +374,11 @@ public class ImageSegmenterTest {
ImageSegmenterOptions options = ImageSegmenterOptions options =
ImageSegmenterOptions.builder() ImageSegmenterOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
.setRunningMode(RunningMode.LIVE_STREAM) .setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener( .setResultListener(
(segmenterResult, inputImage) -> { (segmenterResult, inputImage) -> {
verifyConfidenceMask( verifyConfidenceMask(
segmenterResult.segmentations().get(8), segmenterResult.confidenceMasks().get().get(8),
expectedResult, expectedResult,
GOLDEN_MASK_SIMILARITY); GOLDEN_MASK_SIMILARITY);
}) })
@ -411,12 +400,11 @@ public class ImageSegmenterTest {
ImageSegmenterOptions options = ImageSegmenterOptions options =
ImageSegmenterOptions.builder() ImageSegmenterOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
.setRunningMode(RunningMode.LIVE_STREAM) .setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener( .setResultListener(
(segmenterResult, inputImage) -> { (segmenterResult, inputImage) -> {
verifyConfidenceMask( verifyConfidenceMask(
segmenterResult.segmentations().get(8), segmenterResult.confidenceMasks().get().get(8),
expectedResult, expectedResult,
GOLDEN_MASK_SIMILARITY); GOLDEN_MASK_SIMILARITY);
}) })

View File

@ -60,7 +60,10 @@ public class InteractiveSegmenterTest {
ApplicationProvider.getApplicationContext(), options); ApplicationProvider.getApplicationContext(), options);
MPImage image = getImageFromAsset(inputImageName); MPImage image = getImageFromAsset(inputImageName);
ImageSegmenterResult actualResult = imageSegmenter.segment(image, roi); ImageSegmenterResult actualResult = imageSegmenter.segment(image, roi);
List<MPImage> segmentations = actualResult.segmentations(); // TODO update to correct category mask output.
// After InteractiveSegmenter updated according to (b/276519300), update this to use
// categoryMask field instead of confidenceMasks.
List<MPImage> segmentations = actualResult.confidenceMasks().get();
assertThat(segmentations.size()).isEqualTo(1); assertThat(segmentations.size()).isEqualTo(1);
} }
@ -79,7 +82,7 @@ public class InteractiveSegmenterTest {
ApplicationProvider.getApplicationContext(), options); ApplicationProvider.getApplicationContext(), options);
ImageSegmenterResult actualResult = ImageSegmenterResult actualResult =
imageSegmenter.segment(getImageFromAsset(inputImageName), roi); imageSegmenter.segment(getImageFromAsset(inputImageName), roi);
List<MPImage> segmentations = actualResult.segmentations(); List<MPImage> segmentations = actualResult.confidenceMasks().get();
assertThat(segmentations.size()).isEqualTo(2); assertThat(segmentations.size()).isEqualTo(2);
} }
} }