Internal change
PiperOrigin-RevId: 527430483
This commit is contained in:
		
							parent
							
								
									b05fd21709
								
							
						
					
					
						commit
						2122b5d7be
					
				| 
						 | 
					@ -80,7 +80,6 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
 | 
				
			||||||
  private static final List<String> INPUT_STREAMS =
 | 
					  private static final List<String> INPUT_STREAMS =
 | 
				
			||||||
      Collections.unmodifiableList(
 | 
					      Collections.unmodifiableList(
 | 
				
			||||||
          Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
 | 
					          Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
 | 
				
			||||||
  private static final int IMAGE_OUT_STREAM_INDEX = 0;
 | 
					 | 
				
			||||||
  private static final String TASK_GRAPH_NAME =
 | 
					  private static final String TASK_GRAPH_NAME =
 | 
				
			||||||
      "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph";
 | 
					      "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph";
 | 
				
			||||||
  private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
 | 
					  private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
 | 
				
			||||||
| 
						 | 
					@ -102,27 +101,21 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
 | 
				
			||||||
          "At least one of `outputConfidenceMasks` and `outputCategoryMask` must be set.");
 | 
					          "At least one of `outputConfidenceMasks` and `outputCategoryMask` must be set.");
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    List<String> outputStreams = new ArrayList<>();
 | 
					    List<String> outputStreams = new ArrayList<>();
 | 
				
			||||||
    outputStreams.add("IMAGE:image_out");
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // Add an output stream to the output stream list, and get the added output stream index.
 | 
					    // Add an output stream to the output stream list, and get the added output stream index.
 | 
				
			||||||
    BiFunction<List<String>, String, Integer> getStreamIndex =
 | 
					    BiFunction<List<String>, String, Integer> getStreamIndex =
 | 
				
			||||||
        (List<String> streams, String streamName) -> {
 | 
					        (List<String> streams, String streamName) -> {
 | 
				
			||||||
          streams.add(streamName);
 | 
					          streams.add(streamName);
 | 
				
			||||||
          return streams.size() - 1;
 | 
					          return streams.size() - 1;
 | 
				
			||||||
        };
 | 
					        };
 | 
				
			||||||
 | 
					    final int confidenceMasksOutStreamIndex =
 | 
				
			||||||
    int confidenceMasksOutStreamIndex =
 | 
					 | 
				
			||||||
        segmenterOptions.outputConfidenceMasks()
 | 
					        segmenterOptions.outputConfidenceMasks()
 | 
				
			||||||
            ? getStreamIndex.apply(outputStreams, "CONFIDENCE_MASKS:confidence_masks")
 | 
					            ? getStreamIndex.apply(outputStreams, "CONFIDENCE_MASKS:confidence_masks")
 | 
				
			||||||
            : -1;
 | 
					            : -1;
 | 
				
			||||||
    int confidenceMaskOutStreamIndex =
 | 
					    final int categoryMaskOutStreamIndex =
 | 
				
			||||||
        segmenterOptions.outputConfidenceMasks()
 | 
					 | 
				
			||||||
            ? getStreamIndex.apply(outputStreams, "CONFIDENCE_MASK:0:confidence_mask")
 | 
					 | 
				
			||||||
            : -1;
 | 
					 | 
				
			||||||
    int categoryMaskOutStreamIndex =
 | 
					 | 
				
			||||||
        segmenterOptions.outputCategoryMask()
 | 
					        segmenterOptions.outputCategoryMask()
 | 
				
			||||||
            ? getStreamIndex.apply(outputStreams, "CATEGORY_MASK:category_mask")
 | 
					            ? getStreamIndex.apply(outputStreams, "CATEGORY_MASK:category_mask")
 | 
				
			||||||
            : -1;
 | 
					            : -1;
 | 
				
			||||||
 | 
					    final int imageOutStreamIndex = getStreamIndex.apply(outputStreams, "IMAGE:image_out");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // TODO: Consolidate OutputHandler and TaskRunner.
 | 
					    // TODO: Consolidate OutputHandler and TaskRunner.
 | 
				
			||||||
    OutputHandler<ImageSegmenterResult, MPImage> handler = new OutputHandler<>();
 | 
					    OutputHandler<ImageSegmenterResult, MPImage> handler = new OutputHandler<>();
 | 
				
			||||||
| 
						 | 
					@ -131,17 +124,21 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
 | 
				
			||||||
          @Override
 | 
					          @Override
 | 
				
			||||||
          public ImageSegmenterResult convertToTaskResult(List<Packet> packets)
 | 
					          public ImageSegmenterResult convertToTaskResult(List<Packet> packets)
 | 
				
			||||||
              throws MediaPipeException {
 | 
					              throws MediaPipeException {
 | 
				
			||||||
            if (packets.get(IMAGE_OUT_STREAM_INDEX).isEmpty()) {
 | 
					            if (packets.get(imageOutStreamIndex).isEmpty()) {
 | 
				
			||||||
              return ImageSegmenterResult.create(
 | 
					              return ImageSegmenterResult.create(
 | 
				
			||||||
                  Optional.empty(),
 | 
					                  Optional.empty(),
 | 
				
			||||||
                  Optional.empty(),
 | 
					                  Optional.empty(),
 | 
				
			||||||
                  packets.get(IMAGE_OUT_STREAM_INDEX).getTimestamp());
 | 
					                  packets.get(imageOutStreamIndex).getTimestamp());
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
            boolean copyImage = !segmenterOptions.resultListener().isPresent();
 | 
					            boolean copyImage = !segmenterOptions.resultListener().isPresent();
 | 
				
			||||||
            Optional<List<MPImage>> confidenceMasks = Optional.empty();
 | 
					            Optional<List<MPImage>> confidenceMasks = Optional.empty();
 | 
				
			||||||
            if (segmenterOptions.outputConfidenceMasks()) {
 | 
					            if (segmenterOptions.outputConfidenceMasks()) {
 | 
				
			||||||
              int width = PacketGetter.getImageWidth(packets.get(confidenceMaskOutStreamIndex));
 | 
					              int width =
 | 
				
			||||||
              int height = PacketGetter.getImageHeight(packets.get(confidenceMaskOutStreamIndex));
 | 
					                  PacketGetter.getImageWidthFromImageList(
 | 
				
			||||||
 | 
					                      packets.get(confidenceMasksOutStreamIndex));
 | 
				
			||||||
 | 
					              int height =
 | 
				
			||||||
 | 
					                  PacketGetter.getImageHeightFromImageList(
 | 
				
			||||||
 | 
					                      packets.get(confidenceMasksOutStreamIndex));
 | 
				
			||||||
              confidenceMasks = Optional.of(new ArrayList<MPImage>());
 | 
					              confidenceMasks = Optional.of(new ArrayList<MPImage>());
 | 
				
			||||||
              int confidenceMasksListSize =
 | 
					              int confidenceMasksListSize =
 | 
				
			||||||
                  PacketGetter.getImageListSize(packets.get(confidenceMasksOutStreamIndex));
 | 
					                  PacketGetter.getImageListSize(packets.get(confidenceMasksOutStreamIndex));
 | 
				
			||||||
| 
						 | 
					@ -189,13 +186,13 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
 | 
				
			||||||
                confidenceMasks,
 | 
					                confidenceMasks,
 | 
				
			||||||
                categoryMask,
 | 
					                categoryMask,
 | 
				
			||||||
                BaseVisionTaskApi.generateResultTimestampMs(
 | 
					                BaseVisionTaskApi.generateResultTimestampMs(
 | 
				
			||||||
                    segmenterOptions.runningMode(), packets.get(IMAGE_OUT_STREAM_INDEX)));
 | 
					                    segmenterOptions.runningMode(), packets.get(imageOutStreamIndex)));
 | 
				
			||||||
          }
 | 
					          }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
          @Override
 | 
					          @Override
 | 
				
			||||||
          public MPImage convertToTaskInput(List<Packet> packets) {
 | 
					          public MPImage convertToTaskInput(List<Packet> packets) {
 | 
				
			||||||
            return new BitmapImageBuilder(
 | 
					            return new BitmapImageBuilder(
 | 
				
			||||||
                    AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
 | 
					                    AndroidPacketGetter.getBitmapFromRgb(packets.get(imageOutStreamIndex)))
 | 
				
			||||||
                .build();
 | 
					                .build();
 | 
				
			||||||
          }
 | 
					          }
 | 
				
			||||||
        });
 | 
					        });
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -94,7 +94,6 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
 | 
				
			||||||
              "IMAGE:" + IMAGE_IN_STREAM_NAME,
 | 
					              "IMAGE:" + IMAGE_IN_STREAM_NAME,
 | 
				
			||||||
              "ROI:" + ROI_IN_STREAM_NAME,
 | 
					              "ROI:" + ROI_IN_STREAM_NAME,
 | 
				
			||||||
              "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
 | 
					              "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
 | 
				
			||||||
  private static final int IMAGE_OUT_STREAM_INDEX = 0;
 | 
					 | 
				
			||||||
  private static final String TASK_GRAPH_NAME =
 | 
					  private static final String TASK_GRAPH_NAME =
 | 
				
			||||||
      "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph";
 | 
					      "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph";
 | 
				
			||||||
  private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
 | 
					  private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
 | 
				
			||||||
| 
						 | 
					@ -120,7 +119,6 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
 | 
				
			||||||
          "At least one of `outputConfidenceMasks` and `outputCategoryMask` must be set.");
 | 
					          "At least one of `outputConfidenceMasks` and `outputCategoryMask` must be set.");
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    List<String> outputStreams = new ArrayList<>();
 | 
					    List<String> outputStreams = new ArrayList<>();
 | 
				
			||||||
    outputStreams.add("IMAGE:image_out");
 | 
					 | 
				
			||||||
    if (segmenterOptions.outputConfidenceMasks()) {
 | 
					    if (segmenterOptions.outputConfidenceMasks()) {
 | 
				
			||||||
      outputStreams.add("CONFIDENCE_MASKS:confidence_masks");
 | 
					      outputStreams.add("CONFIDENCE_MASKS:confidence_masks");
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
| 
						 | 
					@ -129,6 +127,9 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
 | 
				
			||||||
      outputStreams.add("CATEGORY_MASK:category_mask");
 | 
					      outputStreams.add("CATEGORY_MASK:category_mask");
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    final int categoryMaskOutStreamIndex = outputStreams.size() - 1;
 | 
					    final int categoryMaskOutStreamIndex = outputStreams.size() - 1;
 | 
				
			||||||
 | 
					    outputStreams.add("IMAGE:image_out");
 | 
				
			||||||
 | 
					    // TODO: add test for stream indices.
 | 
				
			||||||
 | 
					    final int imageOutStreamIndex = outputStreams.size() - 1;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // TODO: Consolidate OutputHandler and TaskRunner.
 | 
					    // TODO: Consolidate OutputHandler and TaskRunner.
 | 
				
			||||||
    OutputHandler<ImageSegmenterResult, MPImage> handler = new OutputHandler<>();
 | 
					    OutputHandler<ImageSegmenterResult, MPImage> handler = new OutputHandler<>();
 | 
				
			||||||
| 
						 | 
					@ -137,11 +138,11 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
 | 
				
			||||||
          @Override
 | 
					          @Override
 | 
				
			||||||
          public ImageSegmenterResult convertToTaskResult(List<Packet> packets)
 | 
					          public ImageSegmenterResult convertToTaskResult(List<Packet> packets)
 | 
				
			||||||
              throws MediaPipeException {
 | 
					              throws MediaPipeException {
 | 
				
			||||||
            if (packets.get(IMAGE_OUT_STREAM_INDEX).isEmpty()) {
 | 
					            if (packets.get(imageOutStreamIndex).isEmpty()) {
 | 
				
			||||||
              return ImageSegmenterResult.create(
 | 
					              return ImageSegmenterResult.create(
 | 
				
			||||||
                  Optional.empty(),
 | 
					                  Optional.empty(),
 | 
				
			||||||
                  Optional.empty(),
 | 
					                  Optional.empty(),
 | 
				
			||||||
                  packets.get(IMAGE_OUT_STREAM_INDEX).getTimestamp());
 | 
					                  packets.get(imageOutStreamIndex).getTimestamp());
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
            // If resultListener is not provided, the resulted MPImage is deep copied from
 | 
					            // If resultListener is not provided, the resulted MPImage is deep copied from
 | 
				
			||||||
            // mediapipe graph. If provided, the result MPImage is wrapping the mediapipe packet
 | 
					            // mediapipe graph. If provided, the result MPImage is wrapping the mediapipe packet
 | 
				
			||||||
| 
						 | 
					@ -202,13 +203,13 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
 | 
				
			||||||
                confidenceMasks,
 | 
					                confidenceMasks,
 | 
				
			||||||
                categoryMask,
 | 
					                categoryMask,
 | 
				
			||||||
                BaseVisionTaskApi.generateResultTimestampMs(
 | 
					                BaseVisionTaskApi.generateResultTimestampMs(
 | 
				
			||||||
                    RunningMode.IMAGE, packets.get(IMAGE_OUT_STREAM_INDEX)));
 | 
					                    RunningMode.IMAGE, packets.get(imageOutStreamIndex)));
 | 
				
			||||||
          }
 | 
					          }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
          @Override
 | 
					          @Override
 | 
				
			||||||
          public MPImage convertToTaskInput(List<Packet> packets) {
 | 
					          public MPImage convertToTaskInput(List<Packet> packets) {
 | 
				
			||||||
            return new BitmapImageBuilder(
 | 
					            return new BitmapImageBuilder(
 | 
				
			||||||
                    AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
 | 
					                    AndroidPacketGetter.getBitmapFromRgb(packets.get(imageOutStreamIndex)))
 | 
				
			||||||
                .build();
 | 
					                .build();
 | 
				
			||||||
          }
 | 
					          }
 | 
				
			||||||
        });
 | 
					        });
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user