Internal change

PiperOrigin-RevId: 527430483
This commit is contained in:
MediaPipe Team 2023-04-26 18:15:46 -07:00 committed by Copybara-Service
parent b05fd21709
commit 2122b5d7be
2 changed files with 20 additions and 22 deletions

View File

@ -80,7 +80,6 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
private static final List<String> INPUT_STREAMS =
Collections.unmodifiableList(
Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
private static final int IMAGE_OUT_STREAM_INDEX = 0;
private static final String TASK_GRAPH_NAME =
"mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph";
private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
@ -102,27 +101,21 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
"At least one of `outputConfidenceMasks` and `outputCategoryMask` must be set.");
}
List<String> outputStreams = new ArrayList<>();
outputStreams.add("IMAGE:image_out");
// Add an output stream to the output stream list, and get the added output stream index.
BiFunction<List<String>, String, Integer> getStreamIndex =
(List<String> streams, String streamName) -> {
streams.add(streamName);
return streams.size() - 1;
};
int confidenceMasksOutStreamIndex =
final int confidenceMasksOutStreamIndex =
segmenterOptions.outputConfidenceMasks()
? getStreamIndex.apply(outputStreams, "CONFIDENCE_MASKS:confidence_masks")
: -1;
int confidenceMaskOutStreamIndex =
segmenterOptions.outputConfidenceMasks()
? getStreamIndex.apply(outputStreams, "CONFIDENCE_MASK:0:confidence_mask")
: -1;
int categoryMaskOutStreamIndex =
final int categoryMaskOutStreamIndex =
segmenterOptions.outputCategoryMask()
? getStreamIndex.apply(outputStreams, "CATEGORY_MASK:category_mask")
: -1;
final int imageOutStreamIndex = getStreamIndex.apply(outputStreams, "IMAGE:image_out");
// TODO: Consolidate OutputHandler and TaskRunner.
OutputHandler<ImageSegmenterResult, MPImage> handler = new OutputHandler<>();
@ -131,17 +124,21 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
@Override
public ImageSegmenterResult convertToTaskResult(List<Packet> packets)
throws MediaPipeException {
if (packets.get(IMAGE_OUT_STREAM_INDEX).isEmpty()) {
if (packets.get(imageOutStreamIndex).isEmpty()) {
return ImageSegmenterResult.create(
Optional.empty(),
Optional.empty(),
packets.get(IMAGE_OUT_STREAM_INDEX).getTimestamp());
packets.get(imageOutStreamIndex).getTimestamp());
}
boolean copyImage = !segmenterOptions.resultListener().isPresent();
Optional<List<MPImage>> confidenceMasks = Optional.empty();
if (segmenterOptions.outputConfidenceMasks()) {
int width = PacketGetter.getImageWidth(packets.get(confidenceMaskOutStreamIndex));
int height = PacketGetter.getImageHeight(packets.get(confidenceMaskOutStreamIndex));
int width =
PacketGetter.getImageWidthFromImageList(
packets.get(confidenceMasksOutStreamIndex));
int height =
PacketGetter.getImageHeightFromImageList(
packets.get(confidenceMasksOutStreamIndex));
confidenceMasks = Optional.of(new ArrayList<MPImage>());
int confidenceMasksListSize =
PacketGetter.getImageListSize(packets.get(confidenceMasksOutStreamIndex));
@ -189,13 +186,13 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
confidenceMasks,
categoryMask,
BaseVisionTaskApi.generateResultTimestampMs(
segmenterOptions.runningMode(), packets.get(IMAGE_OUT_STREAM_INDEX)));
segmenterOptions.runningMode(), packets.get(imageOutStreamIndex)));
}
@Override
public MPImage convertToTaskInput(List<Packet> packets) {
return new BitmapImageBuilder(
AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
AndroidPacketGetter.getBitmapFromRgb(packets.get(imageOutStreamIndex)))
.build();
}
});

View File

@ -94,7 +94,6 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
"IMAGE:" + IMAGE_IN_STREAM_NAME,
"ROI:" + ROI_IN_STREAM_NAME,
"NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
private static final int IMAGE_OUT_STREAM_INDEX = 0;
private static final String TASK_GRAPH_NAME =
"mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph";
private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
@ -120,7 +119,6 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
"At least one of `outputConfidenceMasks` and `outputCategoryMask` must be set.");
}
List<String> outputStreams = new ArrayList<>();
outputStreams.add("IMAGE:image_out");
if (segmenterOptions.outputConfidenceMasks()) {
outputStreams.add("CONFIDENCE_MASKS:confidence_masks");
}
@ -129,6 +127,9 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
outputStreams.add("CATEGORY_MASK:category_mask");
}
final int categoryMaskOutStreamIndex = outputStreams.size() - 1;
outputStreams.add("IMAGE:image_out");
// TODO: add test for stream indices.
final int imageOutStreamIndex = outputStreams.size() - 1;
// TODO: Consolidate OutputHandler and TaskRunner.
OutputHandler<ImageSegmenterResult, MPImage> handler = new OutputHandler<>();
@ -137,11 +138,11 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
@Override
public ImageSegmenterResult convertToTaskResult(List<Packet> packets)
throws MediaPipeException {
if (packets.get(IMAGE_OUT_STREAM_INDEX).isEmpty()) {
if (packets.get(imageOutStreamIndex).isEmpty()) {
return ImageSegmenterResult.create(
Optional.empty(),
Optional.empty(),
packets.get(IMAGE_OUT_STREAM_INDEX).getTimestamp());
packets.get(imageOutStreamIndex).getTimestamp());
}
// If resultListener is not provided, the resulted MPImage is deep copied from
// mediapipe graph. If provided, the result MPImage is wrapping the mediapipe packet
@ -202,13 +203,13 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
confidenceMasks,
categoryMask,
BaseVisionTaskApi.generateResultTimestampMs(
RunningMode.IMAGE, packets.get(IMAGE_OUT_STREAM_INDEX)));
RunningMode.IMAGE, packets.get(imageOutStreamIndex)));
}
@Override
public MPImage convertToTaskInput(List<Packet> packets) {
return new BitmapImageBuilder(
AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
AndroidPacketGetter.getBitmapFromRgb(packets.get(imageOutStreamIndex)))
.build();
}
});