Internal change

PiperOrigin-RevId: 527430483
This commit is contained in:
MediaPipe Team 2023-04-26 18:15:46 -07:00 committed by Copybara-Service
parent b05fd21709
commit 2122b5d7be
2 changed files with 20 additions and 22 deletions

View File

@ -80,7 +80,6 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
private static final List<String> INPUT_STREAMS = private static final List<String> INPUT_STREAMS =
Collections.unmodifiableList( Collections.unmodifiableList(
Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME)); Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
private static final int IMAGE_OUT_STREAM_INDEX = 0;
private static final String TASK_GRAPH_NAME = private static final String TASK_GRAPH_NAME =
"mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph"; "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph";
private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME = private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
@ -102,27 +101,21 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
"At least one of `outputConfidenceMasks` and `outputCategoryMask` must be set."); "At least one of `outputConfidenceMasks` and `outputCategoryMask` must be set.");
} }
List<String> outputStreams = new ArrayList<>(); List<String> outputStreams = new ArrayList<>();
outputStreams.add("IMAGE:image_out");
// Add an output stream to the output stream list, and get the added output stream index. // Add an output stream to the output stream list, and get the added output stream index.
BiFunction<List<String>, String, Integer> getStreamIndex = BiFunction<List<String>, String, Integer> getStreamIndex =
(List<String> streams, String streamName) -> { (List<String> streams, String streamName) -> {
streams.add(streamName); streams.add(streamName);
return streams.size() - 1; return streams.size() - 1;
}; };
final int confidenceMasksOutStreamIndex =
int confidenceMasksOutStreamIndex =
segmenterOptions.outputConfidenceMasks() segmenterOptions.outputConfidenceMasks()
? getStreamIndex.apply(outputStreams, "CONFIDENCE_MASKS:confidence_masks") ? getStreamIndex.apply(outputStreams, "CONFIDENCE_MASKS:confidence_masks")
: -1; : -1;
int confidenceMaskOutStreamIndex = final int categoryMaskOutStreamIndex =
segmenterOptions.outputConfidenceMasks()
? getStreamIndex.apply(outputStreams, "CONFIDENCE_MASK:0:confidence_mask")
: -1;
int categoryMaskOutStreamIndex =
segmenterOptions.outputCategoryMask() segmenterOptions.outputCategoryMask()
? getStreamIndex.apply(outputStreams, "CATEGORY_MASK:category_mask") ? getStreamIndex.apply(outputStreams, "CATEGORY_MASK:category_mask")
: -1; : -1;
final int imageOutStreamIndex = getStreamIndex.apply(outputStreams, "IMAGE:image_out");
// TODO: Consolidate OutputHandler and TaskRunner. // TODO: Consolidate OutputHandler and TaskRunner.
OutputHandler<ImageSegmenterResult, MPImage> handler = new OutputHandler<>(); OutputHandler<ImageSegmenterResult, MPImage> handler = new OutputHandler<>();
@ -131,17 +124,21 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
@Override @Override
public ImageSegmenterResult convertToTaskResult(List<Packet> packets) public ImageSegmenterResult convertToTaskResult(List<Packet> packets)
throws MediaPipeException { throws MediaPipeException {
if (packets.get(IMAGE_OUT_STREAM_INDEX).isEmpty()) { if (packets.get(imageOutStreamIndex).isEmpty()) {
return ImageSegmenterResult.create( return ImageSegmenterResult.create(
Optional.empty(), Optional.empty(),
Optional.empty(), Optional.empty(),
packets.get(IMAGE_OUT_STREAM_INDEX).getTimestamp()); packets.get(imageOutStreamIndex).getTimestamp());
} }
boolean copyImage = !segmenterOptions.resultListener().isPresent(); boolean copyImage = !segmenterOptions.resultListener().isPresent();
Optional<List<MPImage>> confidenceMasks = Optional.empty(); Optional<List<MPImage>> confidenceMasks = Optional.empty();
if (segmenterOptions.outputConfidenceMasks()) { if (segmenterOptions.outputConfidenceMasks()) {
int width = PacketGetter.getImageWidth(packets.get(confidenceMaskOutStreamIndex)); int width =
int height = PacketGetter.getImageHeight(packets.get(confidenceMaskOutStreamIndex)); PacketGetter.getImageWidthFromImageList(
packets.get(confidenceMasksOutStreamIndex));
int height =
PacketGetter.getImageHeightFromImageList(
packets.get(confidenceMasksOutStreamIndex));
confidenceMasks = Optional.of(new ArrayList<MPImage>()); confidenceMasks = Optional.of(new ArrayList<MPImage>());
int confidenceMasksListSize = int confidenceMasksListSize =
PacketGetter.getImageListSize(packets.get(confidenceMasksOutStreamIndex)); PacketGetter.getImageListSize(packets.get(confidenceMasksOutStreamIndex));
@ -189,13 +186,13 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
confidenceMasks, confidenceMasks,
categoryMask, categoryMask,
BaseVisionTaskApi.generateResultTimestampMs( BaseVisionTaskApi.generateResultTimestampMs(
segmenterOptions.runningMode(), packets.get(IMAGE_OUT_STREAM_INDEX))); segmenterOptions.runningMode(), packets.get(imageOutStreamIndex)));
} }
@Override @Override
public MPImage convertToTaskInput(List<Packet> packets) { public MPImage convertToTaskInput(List<Packet> packets) {
return new BitmapImageBuilder( return new BitmapImageBuilder(
AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX))) AndroidPacketGetter.getBitmapFromRgb(packets.get(imageOutStreamIndex)))
.build(); .build();
} }
}); });

View File

@ -94,7 +94,6 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
"IMAGE:" + IMAGE_IN_STREAM_NAME, "IMAGE:" + IMAGE_IN_STREAM_NAME,
"ROI:" + ROI_IN_STREAM_NAME, "ROI:" + ROI_IN_STREAM_NAME,
"NORM_RECT:" + NORM_RECT_IN_STREAM_NAME)); "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
private static final int IMAGE_OUT_STREAM_INDEX = 0;
private static final String TASK_GRAPH_NAME = private static final String TASK_GRAPH_NAME =
"mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"; "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph";
private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME = private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
@ -120,7 +119,6 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
"At least one of `outputConfidenceMasks` and `outputCategoryMask` must be set."); "At least one of `outputConfidenceMasks` and `outputCategoryMask` must be set.");
} }
List<String> outputStreams = new ArrayList<>(); List<String> outputStreams = new ArrayList<>();
outputStreams.add("IMAGE:image_out");
if (segmenterOptions.outputConfidenceMasks()) { if (segmenterOptions.outputConfidenceMasks()) {
outputStreams.add("CONFIDENCE_MASKS:confidence_masks"); outputStreams.add("CONFIDENCE_MASKS:confidence_masks");
} }
@ -129,6 +127,9 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
outputStreams.add("CATEGORY_MASK:category_mask"); outputStreams.add("CATEGORY_MASK:category_mask");
} }
final int categoryMaskOutStreamIndex = outputStreams.size() - 1; final int categoryMaskOutStreamIndex = outputStreams.size() - 1;
outputStreams.add("IMAGE:image_out");
// TODO: add test for stream indices.
final int imageOutStreamIndex = outputStreams.size() - 1;
// TODO: Consolidate OutputHandler and TaskRunner. // TODO: Consolidate OutputHandler and TaskRunner.
OutputHandler<ImageSegmenterResult, MPImage> handler = new OutputHandler<>(); OutputHandler<ImageSegmenterResult, MPImage> handler = new OutputHandler<>();
@ -137,11 +138,11 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
@Override @Override
public ImageSegmenterResult convertToTaskResult(List<Packet> packets) public ImageSegmenterResult convertToTaskResult(List<Packet> packets)
throws MediaPipeException { throws MediaPipeException {
if (packets.get(IMAGE_OUT_STREAM_INDEX).isEmpty()) { if (packets.get(imageOutStreamIndex).isEmpty()) {
return ImageSegmenterResult.create( return ImageSegmenterResult.create(
Optional.empty(), Optional.empty(),
Optional.empty(), Optional.empty(),
packets.get(IMAGE_OUT_STREAM_INDEX).getTimestamp()); packets.get(imageOutStreamIndex).getTimestamp());
} }
// If resultListener is not provided, the resulted MPImage is deep copied from // If resultListener is not provided, the resulted MPImage is deep copied from
// mediapipe graph. If provided, the result MPImage is wrapping the mediapipe packet // mediapipe graph. If provided, the result MPImage is wrapping the mediapipe packet
@ -202,13 +203,13 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
confidenceMasks, confidenceMasks,
categoryMask, categoryMask,
BaseVisionTaskApi.generateResultTimestampMs( BaseVisionTaskApi.generateResultTimestampMs(
RunningMode.IMAGE, packets.get(IMAGE_OUT_STREAM_INDEX))); RunningMode.IMAGE, packets.get(imageOutStreamIndex)));
} }
@Override @Override
public MPImage convertToTaskInput(List<Packet> packets) { public MPImage convertToTaskInput(List<Packet> packets) {
return new BitmapImageBuilder( return new BitmapImageBuilder(
AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX))) AndroidPacketGetter.getBitmapFromRgb(packets.get(imageOutStreamIndex)))
.build(); .build();
} }
}); });