Internal change
PiperOrigin-RevId: 527430483
This commit is contained in:
parent
b05fd21709
commit
2122b5d7be
|
@ -80,7 +80,6 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
||||||
private static final List<String> INPUT_STREAMS =
|
private static final List<String> INPUT_STREAMS =
|
||||||
Collections.unmodifiableList(
|
Collections.unmodifiableList(
|
||||||
Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
|
Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
|
||||||
private static final int IMAGE_OUT_STREAM_INDEX = 0;
|
|
||||||
private static final String TASK_GRAPH_NAME =
|
private static final String TASK_GRAPH_NAME =
|
||||||
"mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph";
|
"mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph";
|
||||||
private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
|
private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
|
||||||
|
@ -102,27 +101,21 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
||||||
"At least one of `outputConfidenceMasks` and `outputCategoryMask` must be set.");
|
"At least one of `outputConfidenceMasks` and `outputCategoryMask` must be set.");
|
||||||
}
|
}
|
||||||
List<String> outputStreams = new ArrayList<>();
|
List<String> outputStreams = new ArrayList<>();
|
||||||
outputStreams.add("IMAGE:image_out");
|
|
||||||
|
|
||||||
// Add an output stream to the output stream list, and get the added output stream index.
|
// Add an output stream to the output stream list, and get the added output stream index.
|
||||||
BiFunction<List<String>, String, Integer> getStreamIndex =
|
BiFunction<List<String>, String, Integer> getStreamIndex =
|
||||||
(List<String> streams, String streamName) -> {
|
(List<String> streams, String streamName) -> {
|
||||||
streams.add(streamName);
|
streams.add(streamName);
|
||||||
return streams.size() - 1;
|
return streams.size() - 1;
|
||||||
};
|
};
|
||||||
|
final int confidenceMasksOutStreamIndex =
|
||||||
int confidenceMasksOutStreamIndex =
|
|
||||||
segmenterOptions.outputConfidenceMasks()
|
segmenterOptions.outputConfidenceMasks()
|
||||||
? getStreamIndex.apply(outputStreams, "CONFIDENCE_MASKS:confidence_masks")
|
? getStreamIndex.apply(outputStreams, "CONFIDENCE_MASKS:confidence_masks")
|
||||||
: -1;
|
: -1;
|
||||||
int confidenceMaskOutStreamIndex =
|
final int categoryMaskOutStreamIndex =
|
||||||
segmenterOptions.outputConfidenceMasks()
|
|
||||||
? getStreamIndex.apply(outputStreams, "CONFIDENCE_MASK:0:confidence_mask")
|
|
||||||
: -1;
|
|
||||||
int categoryMaskOutStreamIndex =
|
|
||||||
segmenterOptions.outputCategoryMask()
|
segmenterOptions.outputCategoryMask()
|
||||||
? getStreamIndex.apply(outputStreams, "CATEGORY_MASK:category_mask")
|
? getStreamIndex.apply(outputStreams, "CATEGORY_MASK:category_mask")
|
||||||
: -1;
|
: -1;
|
||||||
|
final int imageOutStreamIndex = getStreamIndex.apply(outputStreams, "IMAGE:image_out");
|
||||||
|
|
||||||
// TODO: Consolidate OutputHandler and TaskRunner.
|
// TODO: Consolidate OutputHandler and TaskRunner.
|
||||||
OutputHandler<ImageSegmenterResult, MPImage> handler = new OutputHandler<>();
|
OutputHandler<ImageSegmenterResult, MPImage> handler = new OutputHandler<>();
|
||||||
|
@ -131,17 +124,21 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
||||||
@Override
|
@Override
|
||||||
public ImageSegmenterResult convertToTaskResult(List<Packet> packets)
|
public ImageSegmenterResult convertToTaskResult(List<Packet> packets)
|
||||||
throws MediaPipeException {
|
throws MediaPipeException {
|
||||||
if (packets.get(IMAGE_OUT_STREAM_INDEX).isEmpty()) {
|
if (packets.get(imageOutStreamIndex).isEmpty()) {
|
||||||
return ImageSegmenterResult.create(
|
return ImageSegmenterResult.create(
|
||||||
Optional.empty(),
|
Optional.empty(),
|
||||||
Optional.empty(),
|
Optional.empty(),
|
||||||
packets.get(IMAGE_OUT_STREAM_INDEX).getTimestamp());
|
packets.get(imageOutStreamIndex).getTimestamp());
|
||||||
}
|
}
|
||||||
boolean copyImage = !segmenterOptions.resultListener().isPresent();
|
boolean copyImage = !segmenterOptions.resultListener().isPresent();
|
||||||
Optional<List<MPImage>> confidenceMasks = Optional.empty();
|
Optional<List<MPImage>> confidenceMasks = Optional.empty();
|
||||||
if (segmenterOptions.outputConfidenceMasks()) {
|
if (segmenterOptions.outputConfidenceMasks()) {
|
||||||
int width = PacketGetter.getImageWidth(packets.get(confidenceMaskOutStreamIndex));
|
int width =
|
||||||
int height = PacketGetter.getImageHeight(packets.get(confidenceMaskOutStreamIndex));
|
PacketGetter.getImageWidthFromImageList(
|
||||||
|
packets.get(confidenceMasksOutStreamIndex));
|
||||||
|
int height =
|
||||||
|
PacketGetter.getImageHeightFromImageList(
|
||||||
|
packets.get(confidenceMasksOutStreamIndex));
|
||||||
confidenceMasks = Optional.of(new ArrayList<MPImage>());
|
confidenceMasks = Optional.of(new ArrayList<MPImage>());
|
||||||
int confidenceMasksListSize =
|
int confidenceMasksListSize =
|
||||||
PacketGetter.getImageListSize(packets.get(confidenceMasksOutStreamIndex));
|
PacketGetter.getImageListSize(packets.get(confidenceMasksOutStreamIndex));
|
||||||
|
@ -189,13 +186,13 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
||||||
confidenceMasks,
|
confidenceMasks,
|
||||||
categoryMask,
|
categoryMask,
|
||||||
BaseVisionTaskApi.generateResultTimestampMs(
|
BaseVisionTaskApi.generateResultTimestampMs(
|
||||||
segmenterOptions.runningMode(), packets.get(IMAGE_OUT_STREAM_INDEX)));
|
segmenterOptions.runningMode(), packets.get(imageOutStreamIndex)));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public MPImage convertToTaskInput(List<Packet> packets) {
|
public MPImage convertToTaskInput(List<Packet> packets) {
|
||||||
return new BitmapImageBuilder(
|
return new BitmapImageBuilder(
|
||||||
AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
|
AndroidPacketGetter.getBitmapFromRgb(packets.get(imageOutStreamIndex)))
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
|
@ -94,7 +94,6 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
||||||
"IMAGE:" + IMAGE_IN_STREAM_NAME,
|
"IMAGE:" + IMAGE_IN_STREAM_NAME,
|
||||||
"ROI:" + ROI_IN_STREAM_NAME,
|
"ROI:" + ROI_IN_STREAM_NAME,
|
||||||
"NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
|
"NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
|
||||||
private static final int IMAGE_OUT_STREAM_INDEX = 0;
|
|
||||||
private static final String TASK_GRAPH_NAME =
|
private static final String TASK_GRAPH_NAME =
|
||||||
"mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph";
|
"mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph";
|
||||||
private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
|
private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
|
||||||
|
@ -120,7 +119,6 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
||||||
"At least one of `outputConfidenceMasks` and `outputCategoryMask` must be set.");
|
"At least one of `outputConfidenceMasks` and `outputCategoryMask` must be set.");
|
||||||
}
|
}
|
||||||
List<String> outputStreams = new ArrayList<>();
|
List<String> outputStreams = new ArrayList<>();
|
||||||
outputStreams.add("IMAGE:image_out");
|
|
||||||
if (segmenterOptions.outputConfidenceMasks()) {
|
if (segmenterOptions.outputConfidenceMasks()) {
|
||||||
outputStreams.add("CONFIDENCE_MASKS:confidence_masks");
|
outputStreams.add("CONFIDENCE_MASKS:confidence_masks");
|
||||||
}
|
}
|
||||||
|
@ -129,6 +127,9 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
||||||
outputStreams.add("CATEGORY_MASK:category_mask");
|
outputStreams.add("CATEGORY_MASK:category_mask");
|
||||||
}
|
}
|
||||||
final int categoryMaskOutStreamIndex = outputStreams.size() - 1;
|
final int categoryMaskOutStreamIndex = outputStreams.size() - 1;
|
||||||
|
outputStreams.add("IMAGE:image_out");
|
||||||
|
// TODO: add test for stream indices.
|
||||||
|
final int imageOutStreamIndex = outputStreams.size() - 1;
|
||||||
|
|
||||||
// TODO: Consolidate OutputHandler and TaskRunner.
|
// TODO: Consolidate OutputHandler and TaskRunner.
|
||||||
OutputHandler<ImageSegmenterResult, MPImage> handler = new OutputHandler<>();
|
OutputHandler<ImageSegmenterResult, MPImage> handler = new OutputHandler<>();
|
||||||
|
@ -137,11 +138,11 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
||||||
@Override
|
@Override
|
||||||
public ImageSegmenterResult convertToTaskResult(List<Packet> packets)
|
public ImageSegmenterResult convertToTaskResult(List<Packet> packets)
|
||||||
throws MediaPipeException {
|
throws MediaPipeException {
|
||||||
if (packets.get(IMAGE_OUT_STREAM_INDEX).isEmpty()) {
|
if (packets.get(imageOutStreamIndex).isEmpty()) {
|
||||||
return ImageSegmenterResult.create(
|
return ImageSegmenterResult.create(
|
||||||
Optional.empty(),
|
Optional.empty(),
|
||||||
Optional.empty(),
|
Optional.empty(),
|
||||||
packets.get(IMAGE_OUT_STREAM_INDEX).getTimestamp());
|
packets.get(imageOutStreamIndex).getTimestamp());
|
||||||
}
|
}
|
||||||
// If resultListener is not provided, the resulted MPImage is deep copied from
|
// If resultListener is not provided, the resulted MPImage is deep copied from
|
||||||
// mediapipe graph. If provided, the result MPImage is wrapping the mediapipe packet
|
// mediapipe graph. If provided, the result MPImage is wrapping the mediapipe packet
|
||||||
|
@ -202,13 +203,13 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
||||||
confidenceMasks,
|
confidenceMasks,
|
||||||
categoryMask,
|
categoryMask,
|
||||||
BaseVisionTaskApi.generateResultTimestampMs(
|
BaseVisionTaskApi.generateResultTimestampMs(
|
||||||
RunningMode.IMAGE, packets.get(IMAGE_OUT_STREAM_INDEX)));
|
RunningMode.IMAGE, packets.get(imageOutStreamIndex)));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public MPImage convertToTaskInput(List<Packet> packets) {
|
public MPImage convertToTaskInput(List<Packet> packets) {
|
||||||
return new BitmapImageBuilder(
|
return new BitmapImageBuilder(
|
||||||
AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
|
AndroidPacketGetter.getBitmapFromRgb(packets.get(imageOutStreamIndex)))
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
Loading…
Reference in New Issue
Block a user