Internal change
PiperOrigin-RevId: 527430483
This commit is contained in:
		
							parent
							
								
									b05fd21709
								
							
						
					
					
						commit
						2122b5d7be
					
				| 
						 | 
				
			
			@ -80,7 +80,6 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
 | 
			
		|||
  private static final List<String> INPUT_STREAMS =
 | 
			
		||||
      Collections.unmodifiableList(
 | 
			
		||||
          Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
 | 
			
		||||
  private static final int IMAGE_OUT_STREAM_INDEX = 0;
 | 
			
		||||
  private static final String TASK_GRAPH_NAME =
 | 
			
		||||
      "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph";
 | 
			
		||||
  private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
 | 
			
		||||
| 
						 | 
				
			
			@ -102,27 +101,21 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
 | 
			
		|||
          "At least one of `outputConfidenceMasks` and `outputCategoryMask` must be set.");
 | 
			
		||||
    }
 | 
			
		||||
    List<String> outputStreams = new ArrayList<>();
 | 
			
		||||
    outputStreams.add("IMAGE:image_out");
 | 
			
		||||
 | 
			
		||||
    // Add an output stream to the output stream list, and get the added output stream index.
 | 
			
		||||
    BiFunction<List<String>, String, Integer> getStreamIndex =
 | 
			
		||||
        (List<String> streams, String streamName) -> {
 | 
			
		||||
          streams.add(streamName);
 | 
			
		||||
          return streams.size() - 1;
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
    int confidenceMasksOutStreamIndex =
 | 
			
		||||
    final int confidenceMasksOutStreamIndex =
 | 
			
		||||
        segmenterOptions.outputConfidenceMasks()
 | 
			
		||||
            ? getStreamIndex.apply(outputStreams, "CONFIDENCE_MASKS:confidence_masks")
 | 
			
		||||
            : -1;
 | 
			
		||||
    int confidenceMaskOutStreamIndex =
 | 
			
		||||
        segmenterOptions.outputConfidenceMasks()
 | 
			
		||||
            ? getStreamIndex.apply(outputStreams, "CONFIDENCE_MASK:0:confidence_mask")
 | 
			
		||||
            : -1;
 | 
			
		||||
    int categoryMaskOutStreamIndex =
 | 
			
		||||
    final int categoryMaskOutStreamIndex =
 | 
			
		||||
        segmenterOptions.outputCategoryMask()
 | 
			
		||||
            ? getStreamIndex.apply(outputStreams, "CATEGORY_MASK:category_mask")
 | 
			
		||||
            : -1;
 | 
			
		||||
    final int imageOutStreamIndex = getStreamIndex.apply(outputStreams, "IMAGE:image_out");
 | 
			
		||||
 | 
			
		||||
    // TODO: Consolidate OutputHandler and TaskRunner.
 | 
			
		||||
    OutputHandler<ImageSegmenterResult, MPImage> handler = new OutputHandler<>();
 | 
			
		||||
| 
						 | 
				
			
			@ -131,17 +124,21 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
 | 
			
		|||
          @Override
 | 
			
		||||
          public ImageSegmenterResult convertToTaskResult(List<Packet> packets)
 | 
			
		||||
              throws MediaPipeException {
 | 
			
		||||
            if (packets.get(IMAGE_OUT_STREAM_INDEX).isEmpty()) {
 | 
			
		||||
            if (packets.get(imageOutStreamIndex).isEmpty()) {
 | 
			
		||||
              return ImageSegmenterResult.create(
 | 
			
		||||
                  Optional.empty(),
 | 
			
		||||
                  Optional.empty(),
 | 
			
		||||
                  packets.get(IMAGE_OUT_STREAM_INDEX).getTimestamp());
 | 
			
		||||
                  packets.get(imageOutStreamIndex).getTimestamp());
 | 
			
		||||
            }
 | 
			
		||||
            boolean copyImage = !segmenterOptions.resultListener().isPresent();
 | 
			
		||||
            Optional<List<MPImage>> confidenceMasks = Optional.empty();
 | 
			
		||||
            if (segmenterOptions.outputConfidenceMasks()) {
 | 
			
		||||
              int width = PacketGetter.getImageWidth(packets.get(confidenceMaskOutStreamIndex));
 | 
			
		||||
              int height = PacketGetter.getImageHeight(packets.get(confidenceMaskOutStreamIndex));
 | 
			
		||||
              int width =
 | 
			
		||||
                  PacketGetter.getImageWidthFromImageList(
 | 
			
		||||
                      packets.get(confidenceMasksOutStreamIndex));
 | 
			
		||||
              int height =
 | 
			
		||||
                  PacketGetter.getImageHeightFromImageList(
 | 
			
		||||
                      packets.get(confidenceMasksOutStreamIndex));
 | 
			
		||||
              confidenceMasks = Optional.of(new ArrayList<MPImage>());
 | 
			
		||||
              int confidenceMasksListSize =
 | 
			
		||||
                  PacketGetter.getImageListSize(packets.get(confidenceMasksOutStreamIndex));
 | 
			
		||||
| 
						 | 
				
			
			@ -189,13 +186,13 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
 | 
			
		|||
                confidenceMasks,
 | 
			
		||||
                categoryMask,
 | 
			
		||||
                BaseVisionTaskApi.generateResultTimestampMs(
 | 
			
		||||
                    segmenterOptions.runningMode(), packets.get(IMAGE_OUT_STREAM_INDEX)));
 | 
			
		||||
                    segmenterOptions.runningMode(), packets.get(imageOutStreamIndex)));
 | 
			
		||||
          }
 | 
			
		||||
 | 
			
		||||
          @Override
 | 
			
		||||
          public MPImage convertToTaskInput(List<Packet> packets) {
 | 
			
		||||
            return new BitmapImageBuilder(
 | 
			
		||||
                    AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
 | 
			
		||||
                    AndroidPacketGetter.getBitmapFromRgb(packets.get(imageOutStreamIndex)))
 | 
			
		||||
                .build();
 | 
			
		||||
          }
 | 
			
		||||
        });
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -94,7 +94,6 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
 | 
			
		|||
              "IMAGE:" + IMAGE_IN_STREAM_NAME,
 | 
			
		||||
              "ROI:" + ROI_IN_STREAM_NAME,
 | 
			
		||||
              "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
 | 
			
		||||
  private static final int IMAGE_OUT_STREAM_INDEX = 0;
 | 
			
		||||
  private static final String TASK_GRAPH_NAME =
 | 
			
		||||
      "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph";
 | 
			
		||||
  private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
 | 
			
		||||
| 
						 | 
				
			
			@ -120,7 +119,6 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
 | 
			
		|||
          "At least one of `outputConfidenceMasks` and `outputCategoryMask` must be set.");
 | 
			
		||||
    }
 | 
			
		||||
    List<String> outputStreams = new ArrayList<>();
 | 
			
		||||
    outputStreams.add("IMAGE:image_out");
 | 
			
		||||
    if (segmenterOptions.outputConfidenceMasks()) {
 | 
			
		||||
      outputStreams.add("CONFIDENCE_MASKS:confidence_masks");
 | 
			
		||||
    }
 | 
			
		||||
| 
						 | 
				
			
			@ -129,6 +127,9 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
 | 
			
		|||
      outputStreams.add("CATEGORY_MASK:category_mask");
 | 
			
		||||
    }
 | 
			
		||||
    final int categoryMaskOutStreamIndex = outputStreams.size() - 1;
 | 
			
		||||
    outputStreams.add("IMAGE:image_out");
 | 
			
		||||
    // TODO: add test for stream indices.
 | 
			
		||||
    final int imageOutStreamIndex = outputStreams.size() - 1;
 | 
			
		||||
 | 
			
		||||
    // TODO: Consolidate OutputHandler and TaskRunner.
 | 
			
		||||
    OutputHandler<ImageSegmenterResult, MPImage> handler = new OutputHandler<>();
 | 
			
		||||
| 
						 | 
				
			
			@ -137,11 +138,11 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
 | 
			
		|||
          @Override
 | 
			
		||||
          public ImageSegmenterResult convertToTaskResult(List<Packet> packets)
 | 
			
		||||
              throws MediaPipeException {
 | 
			
		||||
            if (packets.get(IMAGE_OUT_STREAM_INDEX).isEmpty()) {
 | 
			
		||||
            if (packets.get(imageOutStreamIndex).isEmpty()) {
 | 
			
		||||
              return ImageSegmenterResult.create(
 | 
			
		||||
                  Optional.empty(),
 | 
			
		||||
                  Optional.empty(),
 | 
			
		||||
                  packets.get(IMAGE_OUT_STREAM_INDEX).getTimestamp());
 | 
			
		||||
                  packets.get(imageOutStreamIndex).getTimestamp());
 | 
			
		||||
            }
 | 
			
		||||
            // If resultListener is not provided, the resulted MPImage is deep copied from
 | 
			
		||||
            // mediapipe graph. If provided, the result MPImage is wrapping the mediapipe packet
 | 
			
		||||
| 
						 | 
				
			
			@ -202,13 +203,13 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
 | 
			
		|||
                confidenceMasks,
 | 
			
		||||
                categoryMask,
 | 
			
		||||
                BaseVisionTaskApi.generateResultTimestampMs(
 | 
			
		||||
                    RunningMode.IMAGE, packets.get(IMAGE_OUT_STREAM_INDEX)));
 | 
			
		||||
                    RunningMode.IMAGE, packets.get(imageOutStreamIndex)));
 | 
			
		||||
          }
 | 
			
		||||
 | 
			
		||||
          @Override
 | 
			
		||||
          public MPImage convertToTaskInput(List<Packet> packets) {
 | 
			
		||||
            return new BitmapImageBuilder(
 | 
			
		||||
                    AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
 | 
			
		||||
                    AndroidPacketGetter.getBitmapFromRgb(packets.get(imageOutStreamIndex)))
 | 
			
		||||
                .build();
 | 
			
		||||
          }
 | 
			
		||||
        });
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user