Internal change
PiperOrigin-RevId: 527430483
This commit is contained in:
parent
b05fd21709
commit
2122b5d7be
|
@ -80,7 +80,6 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
|||
private static final List<String> INPUT_STREAMS =
|
||||
Collections.unmodifiableList(
|
||||
Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
|
||||
private static final int IMAGE_OUT_STREAM_INDEX = 0;
|
||||
private static final String TASK_GRAPH_NAME =
|
||||
"mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph";
|
||||
private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
|
||||
|
@ -102,27 +101,21 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
|||
"At least one of `outputConfidenceMasks` and `outputCategoryMask` must be set.");
|
||||
}
|
||||
List<String> outputStreams = new ArrayList<>();
|
||||
outputStreams.add("IMAGE:image_out");
|
||||
|
||||
// Add an output stream to the output stream list, and get the added output stream index.
|
||||
BiFunction<List<String>, String, Integer> getStreamIndex =
|
||||
(List<String> streams, String streamName) -> {
|
||||
streams.add(streamName);
|
||||
return streams.size() - 1;
|
||||
};
|
||||
|
||||
int confidenceMasksOutStreamIndex =
|
||||
final int confidenceMasksOutStreamIndex =
|
||||
segmenterOptions.outputConfidenceMasks()
|
||||
? getStreamIndex.apply(outputStreams, "CONFIDENCE_MASKS:confidence_masks")
|
||||
: -1;
|
||||
int confidenceMaskOutStreamIndex =
|
||||
segmenterOptions.outputConfidenceMasks()
|
||||
? getStreamIndex.apply(outputStreams, "CONFIDENCE_MASK:0:confidence_mask")
|
||||
: -1;
|
||||
int categoryMaskOutStreamIndex =
|
||||
final int categoryMaskOutStreamIndex =
|
||||
segmenterOptions.outputCategoryMask()
|
||||
? getStreamIndex.apply(outputStreams, "CATEGORY_MASK:category_mask")
|
||||
: -1;
|
||||
final int imageOutStreamIndex = getStreamIndex.apply(outputStreams, "IMAGE:image_out");
|
||||
|
||||
// TODO: Consolidate OutputHandler and TaskRunner.
|
||||
OutputHandler<ImageSegmenterResult, MPImage> handler = new OutputHandler<>();
|
||||
|
@ -131,17 +124,21 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
|||
@Override
|
||||
public ImageSegmenterResult convertToTaskResult(List<Packet> packets)
|
||||
throws MediaPipeException {
|
||||
if (packets.get(IMAGE_OUT_STREAM_INDEX).isEmpty()) {
|
||||
if (packets.get(imageOutStreamIndex).isEmpty()) {
|
||||
return ImageSegmenterResult.create(
|
||||
Optional.empty(),
|
||||
Optional.empty(),
|
||||
packets.get(IMAGE_OUT_STREAM_INDEX).getTimestamp());
|
||||
packets.get(imageOutStreamIndex).getTimestamp());
|
||||
}
|
||||
boolean copyImage = !segmenterOptions.resultListener().isPresent();
|
||||
Optional<List<MPImage>> confidenceMasks = Optional.empty();
|
||||
if (segmenterOptions.outputConfidenceMasks()) {
|
||||
int width = PacketGetter.getImageWidth(packets.get(confidenceMaskOutStreamIndex));
|
||||
int height = PacketGetter.getImageHeight(packets.get(confidenceMaskOutStreamIndex));
|
||||
int width =
|
||||
PacketGetter.getImageWidthFromImageList(
|
||||
packets.get(confidenceMasksOutStreamIndex));
|
||||
int height =
|
||||
PacketGetter.getImageHeightFromImageList(
|
||||
packets.get(confidenceMasksOutStreamIndex));
|
||||
confidenceMasks = Optional.of(new ArrayList<MPImage>());
|
||||
int confidenceMasksListSize =
|
||||
PacketGetter.getImageListSize(packets.get(confidenceMasksOutStreamIndex));
|
||||
|
@ -189,13 +186,13 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
|||
confidenceMasks,
|
||||
categoryMask,
|
||||
BaseVisionTaskApi.generateResultTimestampMs(
|
||||
segmenterOptions.runningMode(), packets.get(IMAGE_OUT_STREAM_INDEX)));
|
||||
segmenterOptions.runningMode(), packets.get(imageOutStreamIndex)));
|
||||
}
|
||||
|
||||
@Override
|
||||
public MPImage convertToTaskInput(List<Packet> packets) {
|
||||
return new BitmapImageBuilder(
|
||||
AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
|
||||
AndroidPacketGetter.getBitmapFromRgb(packets.get(imageOutStreamIndex)))
|
||||
.build();
|
||||
}
|
||||
});
|
||||
|
|
|
@ -94,7 +94,6 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
|||
"IMAGE:" + IMAGE_IN_STREAM_NAME,
|
||||
"ROI:" + ROI_IN_STREAM_NAME,
|
||||
"NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
|
||||
private static final int IMAGE_OUT_STREAM_INDEX = 0;
|
||||
private static final String TASK_GRAPH_NAME =
|
||||
"mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph";
|
||||
private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
|
||||
|
@ -120,7 +119,6 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
|||
"At least one of `outputConfidenceMasks` and `outputCategoryMask` must be set.");
|
||||
}
|
||||
List<String> outputStreams = new ArrayList<>();
|
||||
outputStreams.add("IMAGE:image_out");
|
||||
if (segmenterOptions.outputConfidenceMasks()) {
|
||||
outputStreams.add("CONFIDENCE_MASKS:confidence_masks");
|
||||
}
|
||||
|
@ -129,6 +127,9 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
|||
outputStreams.add("CATEGORY_MASK:category_mask");
|
||||
}
|
||||
final int categoryMaskOutStreamIndex = outputStreams.size() - 1;
|
||||
outputStreams.add("IMAGE:image_out");
|
||||
// TODO: add test for stream indices.
|
||||
final int imageOutStreamIndex = outputStreams.size() - 1;
|
||||
|
||||
// TODO: Consolidate OutputHandler and TaskRunner.
|
||||
OutputHandler<ImageSegmenterResult, MPImage> handler = new OutputHandler<>();
|
||||
|
@ -137,11 +138,11 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
|||
@Override
|
||||
public ImageSegmenterResult convertToTaskResult(List<Packet> packets)
|
||||
throws MediaPipeException {
|
||||
if (packets.get(IMAGE_OUT_STREAM_INDEX).isEmpty()) {
|
||||
if (packets.get(imageOutStreamIndex).isEmpty()) {
|
||||
return ImageSegmenterResult.create(
|
||||
Optional.empty(),
|
||||
Optional.empty(),
|
||||
packets.get(IMAGE_OUT_STREAM_INDEX).getTimestamp());
|
||||
packets.get(imageOutStreamIndex).getTimestamp());
|
||||
}
|
||||
// If resultListener is not provided, the resulted MPImage is deep copied from
|
||||
// mediapipe graph. If provided, the result MPImage is wrapping the mediapipe packet
|
||||
|
@ -202,13 +203,13 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
|||
confidenceMasks,
|
||||
categoryMask,
|
||||
BaseVisionTaskApi.generateResultTimestampMs(
|
||||
RunningMode.IMAGE, packets.get(IMAGE_OUT_STREAM_INDEX)));
|
||||
RunningMode.IMAGE, packets.get(imageOutStreamIndex)));
|
||||
}
|
||||
|
||||
@Override
|
||||
public MPImage convertToTaskInput(List<Packet> packets) {
|
||||
return new BitmapImageBuilder(
|
||||
AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
|
||||
AndroidPacketGetter.getBitmapFromRgb(packets.get(imageOutStreamIndex)))
|
||||
.build();
|
||||
}
|
||||
});
|
||||
|
|
Loading…
Reference in New Issue
Block a user