Update Java interactive segmenter to output both confidence masks and category mask optionally.
PiperOrigin-RevId: 524442070
This commit is contained in:
parent
e14a88052a
commit
411ffaeb43
|
@ -94,15 +94,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
|||
"IMAGE:" + IMAGE_IN_STREAM_NAME,
|
||||
"ROI:" + ROI_IN_STREAM_NAME,
|
||||
"NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
|
||||
private static final List<String> OUTPUT_STREAMS =
|
||||
Collections.unmodifiableList(
|
||||
Arrays.asList(
|
||||
"GROUPED_SEGMENTATION:segmented_mask_out",
|
||||
"IMAGE:image_out",
|
||||
"SEGMENTATION:0:segmentation"));
|
||||
private static final int GROUPED_SEGMENTATION_OUT_STREAM_INDEX = 0;
|
||||
private static final int IMAGE_OUT_STREAM_INDEX = 1;
|
||||
private static final int SEGMENTATION_OUT_STREAM_INDEX = 2;
|
||||
private static final int IMAGE_OUT_STREAM_INDEX = 0;
|
||||
private static final String TASK_GRAPH_NAME =
|
||||
"mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph";
|
||||
private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
|
||||
|
@ -123,6 +115,21 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
|||
*/
|
||||
public static InteractiveSegmenter createFromOptions(
|
||||
Context context, InteractiveSegmenterOptions segmenterOptions) {
|
||||
if (!segmenterOptions.outputConfidenceMasks() && !segmenterOptions.outputCategoryMask()) {
|
||||
throw new IllegalArgumentException(
|
||||
"At least one of `outputConfidenceMasks` and `outputCategoryMask` must be set.");
|
||||
}
|
||||
List<String> outputStreams = new ArrayList<>();
|
||||
outputStreams.add("IMAGE:image_out");
|
||||
if (segmenterOptions.outputConfidenceMasks()) {
|
||||
outputStreams.add("CONFIDENCE_MASKS:confidence_masks");
|
||||
}
|
||||
final int confidenceMasksOutStreamIndex = outputStreams.size() - 1;
|
||||
if (segmenterOptions.outputCategoryMask()) {
|
||||
outputStreams.add("CATEGORY_MASK:category_mask");
|
||||
}
|
||||
final int categoryMaskOutStreamIndex = outputStreams.size() - 1;
|
||||
|
||||
// TODO: Consolidate OutputHandler and TaskRunner.
|
||||
OutputHandler<ImageSegmenterResult, MPImage> handler = new OutputHandler<>();
|
||||
handler.setOutputPacketConverter(
|
||||
|
@ -130,52 +137,72 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
|||
@Override
|
||||
public ImageSegmenterResult convertToTaskResult(List<Packet> packets)
|
||||
throws MediaPipeException {
|
||||
if (packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).isEmpty()) {
|
||||
if (packets.get(IMAGE_OUT_STREAM_INDEX).isEmpty()) {
|
||||
return ImageSegmenterResult.create(
|
||||
Optional.empty(),
|
||||
Optional.empty(),
|
||||
packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).getTimestamp());
|
||||
packets.get(IMAGE_OUT_STREAM_INDEX).getTimestamp());
|
||||
}
|
||||
List<MPImage> segmentedMasks = new ArrayList<>();
|
||||
int width = PacketGetter.getImageWidth(packets.get(SEGMENTATION_OUT_STREAM_INDEX));
|
||||
int height = PacketGetter.getImageHeight(packets.get(SEGMENTATION_OUT_STREAM_INDEX));
|
||||
int imageFormat =
|
||||
segmenterOptions.outputType()
|
||||
== InteractiveSegmenterOptions.OutputType.CONFIDENCE_MASK
|
||||
? MPImage.IMAGE_FORMAT_VEC32F1
|
||||
: MPImage.IMAGE_FORMAT_ALPHA;
|
||||
// If resultListener is not provided, the resulted MPImage is deep copied from
|
||||
// mediapipe graph. If provided, the result MPImage is wrapping the mediapipe packet
|
||||
// memory.
|
||||
boolean copyImage = !segmenterOptions.resultListener().isPresent();
|
||||
Optional<List<MPImage>> confidenceMasks = Optional.empty();
|
||||
if (segmenterOptions.outputConfidenceMasks()) {
|
||||
confidenceMasks = Optional.of(new ArrayList<>());
|
||||
int width =
|
||||
PacketGetter.getImageWidthFromImageList(
|
||||
packets.get(confidenceMasksOutStreamIndex));
|
||||
int height =
|
||||
PacketGetter.getImageHeightFromImageList(
|
||||
packets.get(confidenceMasksOutStreamIndex));
|
||||
int imageListSize =
|
||||
PacketGetter.getImageListSize(packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX));
|
||||
PacketGetter.getImageListSize(packets.get(confidenceMasksOutStreamIndex));
|
||||
ByteBuffer[] buffersArray = new ByteBuffer[imageListSize];
|
||||
// If resultListener is not provided, the resulted MPImage is deep copied from mediapipe
|
||||
// graph. If provided, the result MPImage is wrapping the mediapipe packet memory.
|
||||
if (!segmenterOptions.resultListener().isPresent()) {
|
||||
// confidence masks are float type image.
|
||||
final int numBytes = 4;
|
||||
if (copyImage) {
|
||||
for (int i = 0; i < imageListSize; i++) {
|
||||
buffersArray[i] =
|
||||
ByteBuffer.allocateDirect(
|
||||
width * height * (imageFormat == MPImage.IMAGE_FORMAT_VEC32F1 ? 4 : 1));
|
||||
buffersArray[i] = ByteBuffer.allocateDirect(width * height * numBytes);
|
||||
}
|
||||
}
|
||||
if (!PacketGetter.getImageList(
|
||||
packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX),
|
||||
buffersArray,
|
||||
!segmenterOptions.resultListener().isPresent())) {
|
||||
packets.get(confidenceMasksOutStreamIndex), buffersArray, copyImage)) {
|
||||
throw new MediaPipeException(
|
||||
MediaPipeException.StatusCode.INTERNAL.ordinal(),
|
||||
"There is an error getting segmented masks. It usually results from incorrect"
|
||||
+ " options of unsupported OutputType of given model.");
|
||||
"There is an error getting confidence masks.");
|
||||
}
|
||||
for (ByteBuffer buffer : buffersArray) {
|
||||
ByteBufferImageBuilder builder =
|
||||
new ByteBufferImageBuilder(buffer, width, height, imageFormat);
|
||||
segmentedMasks.add(builder.build());
|
||||
new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_VEC32F1);
|
||||
confidenceMasks.get().add(builder.build());
|
||||
}
|
||||
}
|
||||
Optional<MPImage> categoryMask = Optional.empty();
|
||||
if (segmenterOptions.outputCategoryMask()) {
|
||||
int width = PacketGetter.getImageWidth(packets.get(categoryMaskOutStreamIndex));
|
||||
int height = PacketGetter.getImageHeight(packets.get(categoryMaskOutStreamIndex));
|
||||
ByteBuffer buffer;
|
||||
if (copyImage) {
|
||||
buffer = ByteBuffer.allocateDirect(width * height);
|
||||
if (!PacketGetter.getImageData(packets.get(categoryMaskOutStreamIndex), buffer)) {
|
||||
throw new MediaPipeException(
|
||||
MediaPipeException.StatusCode.INTERNAL.ordinal(),
|
||||
"There is an error getting category mask.");
|
||||
}
|
||||
} else {
|
||||
buffer = PacketGetter.getImageDataDirectly(packets.get(categoryMaskOutStreamIndex));
|
||||
}
|
||||
ByteBufferImageBuilder builder =
|
||||
new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_ALPHA);
|
||||
categoryMask = Optional.of(builder.build());
|
||||
}
|
||||
|
||||
return ImageSegmenterResult.create(
|
||||
Optional.of(segmentedMasks),
|
||||
Optional.empty(),
|
||||
confidenceMasks,
|
||||
categoryMask,
|
||||
BaseVisionTaskApi.generateResultTimestampMs(
|
||||
RunningMode.IMAGE, packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX)));
|
||||
RunningMode.IMAGE, packets.get(IMAGE_OUT_STREAM_INDEX)));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -195,7 +222,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
|||
.setTaskRunningModeName(RunningMode.IMAGE.name())
|
||||
.setTaskGraphName(TASK_GRAPH_NAME)
|
||||
.setInputStreams(INPUT_STREAMS)
|
||||
.setOutputStreams(OUTPUT_STREAMS)
|
||||
.setOutputStreams(outputStreams)
|
||||
.setTaskOptions(segmenterOptions)
|
||||
.setEnableFlowLimiting(false)
|
||||
.build(),
|
||||
|
@ -394,8 +421,11 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
|||
/** Sets the base options for the image segmenter task. */
|
||||
public abstract Builder setBaseOptions(BaseOptions value);
|
||||
|
||||
/** The output type from image segmenter. */
|
||||
public abstract Builder setOutputType(OutputType value);
|
||||
/** Sets whether to output confidence masks. Default to true. */
|
||||
public abstract Builder setOutputConfidenceMasks(boolean value);
|
||||
|
||||
/** Sets whether to output category mask. Default to false. */
|
||||
public abstract Builder setOutputCategoryMask(boolean value);
|
||||
|
||||
/**
|
||||
* Sets an optional {@link ResultListener} to receive the segmentation results when the graph
|
||||
|
@ -417,25 +447,18 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
|||
|
||||
abstract BaseOptions baseOptions();
|
||||
|
||||
abstract OutputType outputType();
|
||||
abstract boolean outputConfidenceMasks();
|
||||
|
||||
abstract boolean outputCategoryMask();
|
||||
|
||||
abstract Optional<ResultListener<ImageSegmenterResult, MPImage>> resultListener();
|
||||
|
||||
abstract Optional<ErrorListener> errorListener();
|
||||
|
||||
/** The output type of segmentation results. */
|
||||
public enum OutputType {
|
||||
// Gives a single output mask where each pixel represents the class which
|
||||
// the pixel in the original image was predicted to belong to.
|
||||
CATEGORY_MASK,
|
||||
// Gives a list of output masks where, for each mask, each pixel represents
|
||||
// the prediction confidence, usually in the [0, 1] range.
|
||||
CONFIDENCE_MASK
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new AutoValue_InteractiveSegmenter_InteractiveSegmenterOptions.Builder()
|
||||
.setOutputType(OutputType.CATEGORY_MASK);
|
||||
.setOutputConfidenceMasks(true)
|
||||
.setOutputCategoryMask(false);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -454,14 +477,6 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
|||
|
||||
SegmenterOptionsProto.SegmenterOptions.Builder segmenterOptionsBuilder =
|
||||
SegmenterOptionsProto.SegmenterOptions.newBuilder();
|
||||
if (outputType() == OutputType.CONFIDENCE_MASK) {
|
||||
segmenterOptionsBuilder.setOutputType(
|
||||
SegmenterOptionsProto.SegmenterOptions.OutputType.CONFIDENCE_MASK);
|
||||
} else if (outputType() == OutputType.CATEGORY_MASK) {
|
||||
segmenterOptionsBuilder.setOutputType(
|
||||
SegmenterOptionsProto.SegmenterOptions.OutputType.CATEGORY_MASK);
|
||||
}
|
||||
|
||||
taskOptionsBuilder.setSegmenterOptions(segmenterOptionsBuilder);
|
||||
return CalculatorOptions.newBuilder()
|
||||
.setExtension(
|
||||
|
|
|
@ -53,18 +53,15 @@ public class InteractiveSegmenterTest {
|
|||
InteractiveSegmenterOptions options =
|
||||
InteractiveSegmenterOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||
.setOutputType(InteractiveSegmenterOptions.OutputType.CATEGORY_MASK)
|
||||
.setOutputConfidenceMasks(false)
|
||||
.setOutputCategoryMask(true)
|
||||
.build();
|
||||
InteractiveSegmenter imageSegmenter =
|
||||
InteractiveSegmenter.createFromOptions(
|
||||
ApplicationProvider.getApplicationContext(), options);
|
||||
MPImage image = getImageFromAsset(inputImageName);
|
||||
ImageSegmenterResult actualResult = imageSegmenter.segment(image, roi);
|
||||
// TODO update to correct category mask output.
|
||||
// After InteractiveSegmenter updated according to (b/276519300), update this to use
|
||||
// categoryMask field instead of confidenceMasks.
|
||||
List<MPImage> segmentations = actualResult.confidenceMasks().get();
|
||||
assertThat(segmentations.size()).isEqualTo(1);
|
||||
assertThat(actualResult.categoryMask().isPresent()).isTrue();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -75,15 +72,17 @@ public class InteractiveSegmenterTest {
|
|||
InteractiveSegmenterOptions options =
|
||||
InteractiveSegmenterOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||
.setOutputType(InteractiveSegmenterOptions.OutputType.CONFIDENCE_MASK)
|
||||
.setOutputConfidenceMasks(true)
|
||||
.setOutputCategoryMask(false)
|
||||
.build();
|
||||
InteractiveSegmenter imageSegmenter =
|
||||
InteractiveSegmenter.createFromOptions(
|
||||
ApplicationProvider.getApplicationContext(), options);
|
||||
ImageSegmenterResult actualResult =
|
||||
imageSegmenter.segment(getImageFromAsset(inputImageName), roi);
|
||||
List<MPImage> segmentations = actualResult.confidenceMasks().get();
|
||||
assertThat(segmentations.size()).isEqualTo(2);
|
||||
assertThat(actualResult.confidenceMasks().isPresent()).isTrue();
|
||||
List<MPImage> confidenceMasks = actualResult.confidenceMasks().get();
|
||||
assertThat(confidenceMasks.size()).isEqualTo(2);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user