Update Java interactive segmenter to output both confidence masks and category mask optionally.

PiperOrigin-RevId: 524442070
This commit is contained in:
MediaPipe Team 2023-04-14 19:43:04 -07:00 committed by Copybara-Service
parent e14a88052a
commit 411ffaeb43
2 changed files with 90 additions and 76 deletions

View File

@ -94,15 +94,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
"IMAGE:" + IMAGE_IN_STREAM_NAME, "IMAGE:" + IMAGE_IN_STREAM_NAME,
"ROI:" + ROI_IN_STREAM_NAME, "ROI:" + ROI_IN_STREAM_NAME,
"NORM_RECT:" + NORM_RECT_IN_STREAM_NAME)); "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
private static final List<String> OUTPUT_STREAMS = private static final int IMAGE_OUT_STREAM_INDEX = 0;
Collections.unmodifiableList(
Arrays.asList(
"GROUPED_SEGMENTATION:segmented_mask_out",
"IMAGE:image_out",
"SEGMENTATION:0:segmentation"));
private static final int GROUPED_SEGMENTATION_OUT_STREAM_INDEX = 0;
private static final int IMAGE_OUT_STREAM_INDEX = 1;
private static final int SEGMENTATION_OUT_STREAM_INDEX = 2;
private static final String TASK_GRAPH_NAME = private static final String TASK_GRAPH_NAME =
"mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"; "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph";
private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME = private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
@ -123,6 +115,21 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
*/ */
public static InteractiveSegmenter createFromOptions( public static InteractiveSegmenter createFromOptions(
Context context, InteractiveSegmenterOptions segmenterOptions) { Context context, InteractiveSegmenterOptions segmenterOptions) {
if (!segmenterOptions.outputConfidenceMasks() && !segmenterOptions.outputCategoryMask()) {
throw new IllegalArgumentException(
"At least one of `outputConfidenceMasks` and `outputCategoryMask` must be set.");
}
List<String> outputStreams = new ArrayList<>();
outputStreams.add("IMAGE:image_out");
if (segmenterOptions.outputConfidenceMasks()) {
outputStreams.add("CONFIDENCE_MASKS:confidence_masks");
}
final int confidenceMasksOutStreamIndex = outputStreams.size() - 1;
if (segmenterOptions.outputCategoryMask()) {
outputStreams.add("CATEGORY_MASK:category_mask");
}
final int categoryMaskOutStreamIndex = outputStreams.size() - 1;
// TODO: Consolidate OutputHandler and TaskRunner. // TODO: Consolidate OutputHandler and TaskRunner.
OutputHandler<ImageSegmenterResult, MPImage> handler = new OutputHandler<>(); OutputHandler<ImageSegmenterResult, MPImage> handler = new OutputHandler<>();
handler.setOutputPacketConverter( handler.setOutputPacketConverter(
@ -130,52 +137,72 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
@Override @Override
public ImageSegmenterResult convertToTaskResult(List<Packet> packets) public ImageSegmenterResult convertToTaskResult(List<Packet> packets)
throws MediaPipeException { throws MediaPipeException {
if (packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).isEmpty()) { if (packets.get(IMAGE_OUT_STREAM_INDEX).isEmpty()) {
return ImageSegmenterResult.create( return ImageSegmenterResult.create(
Optional.empty(), Optional.empty(),
Optional.empty(), Optional.empty(),
packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).getTimestamp()); packets.get(IMAGE_OUT_STREAM_INDEX).getTimestamp());
} }
List<MPImage> segmentedMasks = new ArrayList<>(); // If resultListener is not provided, the resulted MPImage is deep copied from
int width = PacketGetter.getImageWidth(packets.get(SEGMENTATION_OUT_STREAM_INDEX)); // mediapipe graph. If provided, the result MPImage is wrapping the mediapipe packet
int height = PacketGetter.getImageHeight(packets.get(SEGMENTATION_OUT_STREAM_INDEX)); // memory.
int imageFormat = boolean copyImage = !segmenterOptions.resultListener().isPresent();
segmenterOptions.outputType() Optional<List<MPImage>> confidenceMasks = Optional.empty();
== InteractiveSegmenterOptions.OutputType.CONFIDENCE_MASK if (segmenterOptions.outputConfidenceMasks()) {
? MPImage.IMAGE_FORMAT_VEC32F1 confidenceMasks = Optional.of(new ArrayList<>());
: MPImage.IMAGE_FORMAT_ALPHA; int width =
int imageListSize = PacketGetter.getImageWidthFromImageList(
PacketGetter.getImageListSize(packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX)); packets.get(confidenceMasksOutStreamIndex));
ByteBuffer[] buffersArray = new ByteBuffer[imageListSize]; int height =
// If resultListener is not provided, the resulted MPImage is deep copied from mediapipe PacketGetter.getImageHeightFromImageList(
// graph. If provided, the result MPImage is wrapping the mediapipe packet memory. packets.get(confidenceMasksOutStreamIndex));
if (!segmenterOptions.resultListener().isPresent()) { int imageListSize =
for (int i = 0; i < imageListSize; i++) { PacketGetter.getImageListSize(packets.get(confidenceMasksOutStreamIndex));
buffersArray[i] = ByteBuffer[] buffersArray = new ByteBuffer[imageListSize];
ByteBuffer.allocateDirect( // confidence masks are float type image.
width * height * (imageFormat == MPImage.IMAGE_FORMAT_VEC32F1 ? 4 : 1)); final int numBytes = 4;
if (copyImage) {
for (int i = 0; i < imageListSize; i++) {
buffersArray[i] = ByteBuffer.allocateDirect(width * height * numBytes);
}
}
if (!PacketGetter.getImageList(
packets.get(confidenceMasksOutStreamIndex), buffersArray, copyImage)) {
throw new MediaPipeException(
MediaPipeException.StatusCode.INTERNAL.ordinal(),
"There is an error getting confidence masks.");
}
for (ByteBuffer buffer : buffersArray) {
ByteBufferImageBuilder builder =
new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_VEC32F1);
confidenceMasks.get().add(builder.build());
} }
} }
if (!PacketGetter.getImageList( Optional<MPImage> categoryMask = Optional.empty();
packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX), if (segmenterOptions.outputCategoryMask()) {
buffersArray, int width = PacketGetter.getImageWidth(packets.get(categoryMaskOutStreamIndex));
!segmenterOptions.resultListener().isPresent())) { int height = PacketGetter.getImageHeight(packets.get(categoryMaskOutStreamIndex));
throw new MediaPipeException( ByteBuffer buffer;
MediaPipeException.StatusCode.INTERNAL.ordinal(), if (copyImage) {
"There is an error getting segmented masks. It usually results from incorrect" buffer = ByteBuffer.allocateDirect(width * height);
+ " options of unsupported OutputType of given model."); if (!PacketGetter.getImageData(packets.get(categoryMaskOutStreamIndex), buffer)) {
} throw new MediaPipeException(
for (ByteBuffer buffer : buffersArray) { MediaPipeException.StatusCode.INTERNAL.ordinal(),
"There is an error getting category mask.");
}
} else {
buffer = PacketGetter.getImageDataDirectly(packets.get(categoryMaskOutStreamIndex));
}
ByteBufferImageBuilder builder = ByteBufferImageBuilder builder =
new ByteBufferImageBuilder(buffer, width, height, imageFormat); new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_ALPHA);
segmentedMasks.add(builder.build()); categoryMask = Optional.of(builder.build());
} }
return ImageSegmenterResult.create( return ImageSegmenterResult.create(
Optional.of(segmentedMasks), confidenceMasks,
Optional.empty(), categoryMask,
BaseVisionTaskApi.generateResultTimestampMs( BaseVisionTaskApi.generateResultTimestampMs(
RunningMode.IMAGE, packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX))); RunningMode.IMAGE, packets.get(IMAGE_OUT_STREAM_INDEX)));
} }
@Override @Override
@ -195,7 +222,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
.setTaskRunningModeName(RunningMode.IMAGE.name()) .setTaskRunningModeName(RunningMode.IMAGE.name())
.setTaskGraphName(TASK_GRAPH_NAME) .setTaskGraphName(TASK_GRAPH_NAME)
.setInputStreams(INPUT_STREAMS) .setInputStreams(INPUT_STREAMS)
.setOutputStreams(OUTPUT_STREAMS) .setOutputStreams(outputStreams)
.setTaskOptions(segmenterOptions) .setTaskOptions(segmenterOptions)
.setEnableFlowLimiting(false) .setEnableFlowLimiting(false)
.build(), .build(),
@ -394,8 +421,11 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
/** Sets the base options for the image segmenter task. */ /** Sets the base options for the image segmenter task. */
public abstract Builder setBaseOptions(BaseOptions value); public abstract Builder setBaseOptions(BaseOptions value);
/** The output type from image segmenter. */ /** Sets whether to output confidence masks. Default to true. */
public abstract Builder setOutputType(OutputType value); public abstract Builder setOutputConfidenceMasks(boolean value);
/** Sets whether to output category mask. Default to false. */
public abstract Builder setOutputCategoryMask(boolean value);
/** /**
* Sets an optional {@link ResultListener} to receive the segmentation results when the graph * Sets an optional {@link ResultListener} to receive the segmentation results when the graph
@ -417,25 +447,18 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
abstract BaseOptions baseOptions(); abstract BaseOptions baseOptions();
abstract OutputType outputType(); abstract boolean outputConfidenceMasks();
abstract boolean outputCategoryMask();
abstract Optional<ResultListener<ImageSegmenterResult, MPImage>> resultListener(); abstract Optional<ResultListener<ImageSegmenterResult, MPImage>> resultListener();
abstract Optional<ErrorListener> errorListener(); abstract Optional<ErrorListener> errorListener();
/** The output type of segmentation results. */
public enum OutputType {
// Gives a single output mask where each pixel represents the class which
// the pixel in the original image was predicted to belong to.
CATEGORY_MASK,
// Gives a list of output masks where, for each mask, each pixel represents
// the prediction confidence, usually in the [0, 1] range.
CONFIDENCE_MASK
}
public static Builder builder() { public static Builder builder() {
return new AutoValue_InteractiveSegmenter_InteractiveSegmenterOptions.Builder() return new AutoValue_InteractiveSegmenter_InteractiveSegmenterOptions.Builder()
.setOutputType(OutputType.CATEGORY_MASK); .setOutputConfidenceMasks(true)
.setOutputCategoryMask(false);
} }
/** /**
@ -454,14 +477,6 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
SegmenterOptionsProto.SegmenterOptions.Builder segmenterOptionsBuilder = SegmenterOptionsProto.SegmenterOptions.Builder segmenterOptionsBuilder =
SegmenterOptionsProto.SegmenterOptions.newBuilder(); SegmenterOptionsProto.SegmenterOptions.newBuilder();
if (outputType() == OutputType.CONFIDENCE_MASK) {
segmenterOptionsBuilder.setOutputType(
SegmenterOptionsProto.SegmenterOptions.OutputType.CONFIDENCE_MASK);
} else if (outputType() == OutputType.CATEGORY_MASK) {
segmenterOptionsBuilder.setOutputType(
SegmenterOptionsProto.SegmenterOptions.OutputType.CATEGORY_MASK);
}
taskOptionsBuilder.setSegmenterOptions(segmenterOptionsBuilder); taskOptionsBuilder.setSegmenterOptions(segmenterOptionsBuilder);
return CalculatorOptions.newBuilder() return CalculatorOptions.newBuilder()
.setExtension( .setExtension(

View File

@ -53,18 +53,15 @@ public class InteractiveSegmenterTest {
InteractiveSegmenterOptions options = InteractiveSegmenterOptions options =
InteractiveSegmenterOptions.builder() InteractiveSegmenterOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
.setOutputType(InteractiveSegmenterOptions.OutputType.CATEGORY_MASK) .setOutputConfidenceMasks(false)
.setOutputCategoryMask(true)
.build(); .build();
InteractiveSegmenter imageSegmenter = InteractiveSegmenter imageSegmenter =
InteractiveSegmenter.createFromOptions( InteractiveSegmenter.createFromOptions(
ApplicationProvider.getApplicationContext(), options); ApplicationProvider.getApplicationContext(), options);
MPImage image = getImageFromAsset(inputImageName); MPImage image = getImageFromAsset(inputImageName);
ImageSegmenterResult actualResult = imageSegmenter.segment(image, roi); ImageSegmenterResult actualResult = imageSegmenter.segment(image, roi);
// TODO update to correct category mask output. assertThat(actualResult.categoryMask().isPresent()).isTrue();
// After InteractiveSegmenter updated according to (b/276519300), update this to use
// categoryMask field instead of confidenceMasks.
List<MPImage> segmentations = actualResult.confidenceMasks().get();
assertThat(segmentations.size()).isEqualTo(1);
} }
@Test @Test
@ -75,15 +72,17 @@ public class InteractiveSegmenterTest {
InteractiveSegmenterOptions options = InteractiveSegmenterOptions options =
InteractiveSegmenterOptions.builder() InteractiveSegmenterOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
.setOutputType(InteractiveSegmenterOptions.OutputType.CONFIDENCE_MASK) .setOutputConfidenceMasks(true)
.setOutputCategoryMask(false)
.build(); .build();
InteractiveSegmenter imageSegmenter = InteractiveSegmenter imageSegmenter =
InteractiveSegmenter.createFromOptions( InteractiveSegmenter.createFromOptions(
ApplicationProvider.getApplicationContext(), options); ApplicationProvider.getApplicationContext(), options);
ImageSegmenterResult actualResult = ImageSegmenterResult actualResult =
imageSegmenter.segment(getImageFromAsset(inputImageName), roi); imageSegmenter.segment(getImageFromAsset(inputImageName), roi);
List<MPImage> segmentations = actualResult.confidenceMasks().get(); assertThat(actualResult.confidenceMasks().isPresent()).isTrue();
assertThat(segmentations.size()).isEqualTo(2); List<MPImage> confidenceMasks = actualResult.confidenceMasks().get();
assertThat(confidenceMasks.size()).isEqualTo(2);
} }
} }