internal change
PiperOrigin-RevId: 523773255
This commit is contained in:
		
							parent
							
								
									ca0da8d26f
								
							
						
					
					
						commit
						9a10375de6
					
				| 
						 | 
					@ -45,6 +45,7 @@ import java.util.Arrays;
 | 
				
			||||||
import java.util.Collections;
 | 
					import java.util.Collections;
 | 
				
			||||||
import java.util.List;
 | 
					import java.util.List;
 | 
				
			||||||
import java.util.Optional;
 | 
					import java.util.Optional;
 | 
				
			||||||
 | 
					import java.util.function.BiFunction;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/**
 | 
					/**
 | 
				
			||||||
 * Performs image segmentation on images.
 | 
					 * Performs image segmentation on images.
 | 
				
			||||||
| 
						 | 
					@ -79,15 +80,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
 | 
				
			||||||
  private static final List<String> INPUT_STREAMS =
 | 
					  private static final List<String> INPUT_STREAMS =
 | 
				
			||||||
      Collections.unmodifiableList(
 | 
					      Collections.unmodifiableList(
 | 
				
			||||||
          Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
 | 
					          Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
 | 
				
			||||||
  private static final List<String> OUTPUT_STREAMS =
 | 
					  private static final int IMAGE_OUT_STREAM_INDEX = 0;
 | 
				
			||||||
      Collections.unmodifiableList(
 | 
					 | 
				
			||||||
          Arrays.asList(
 | 
					 | 
				
			||||||
              "GROUPED_SEGMENTATION:segmented_mask_out",
 | 
					 | 
				
			||||||
              "IMAGE:image_out",
 | 
					 | 
				
			||||||
              "SEGMENTATION:0:segmentation"));
 | 
					 | 
				
			||||||
  private static final int GROUPED_SEGMENTATION_OUT_STREAM_INDEX = 0;
 | 
					 | 
				
			||||||
  private static final int IMAGE_OUT_STREAM_INDEX = 1;
 | 
					 | 
				
			||||||
  private static final int SEGMENTATION_OUT_STREAM_INDEX = 2;
 | 
					 | 
				
			||||||
  private static final String TASK_GRAPH_NAME =
 | 
					  private static final String TASK_GRAPH_NAME =
 | 
				
			||||||
      "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph";
 | 
					      "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph";
 | 
				
			||||||
  private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
 | 
					  private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
 | 
				
			||||||
| 
						 | 
					@ -104,6 +97,33 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
 | 
				
			||||||
   */
 | 
					   */
 | 
				
			||||||
  public static ImageSegmenter createFromOptions(
 | 
					  public static ImageSegmenter createFromOptions(
 | 
				
			||||||
      Context context, ImageSegmenterOptions segmenterOptions) {
 | 
					      Context context, ImageSegmenterOptions segmenterOptions) {
 | 
				
			||||||
 | 
					    if (!segmenterOptions.outputConfidenceMasks() && !segmenterOptions.outputCategoryMask()) {
 | 
				
			||||||
 | 
					      throw new IllegalArgumentException(
 | 
				
			||||||
 | 
					          "At least one of `outputConfidenceMasks` and `outputCategoryMask` must be set.");
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    List<String> outputStreams = new ArrayList<>();
 | 
				
			||||||
 | 
					    outputStreams.add("IMAGE:image_out");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Add an output stream to the output stream list, and get the added output stream index.
 | 
				
			||||||
 | 
					    BiFunction<List<String>, String, Integer> getStreamIndex =
 | 
				
			||||||
 | 
					        (List<String> streams, String streamName) -> {
 | 
				
			||||||
 | 
					          streams.add(streamName);
 | 
				
			||||||
 | 
					          return streams.size() - 1;
 | 
				
			||||||
 | 
					        };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    int confidenceMasksOutStreamIndex =
 | 
				
			||||||
 | 
					        segmenterOptions.outputConfidenceMasks()
 | 
				
			||||||
 | 
					            ? getStreamIndex.apply(outputStreams, "CONFIDENCE_MASKS:confidence_masks")
 | 
				
			||||||
 | 
					            : -1;
 | 
				
			||||||
 | 
					    int confidenceMaskOutStreamIndex =
 | 
				
			||||||
 | 
					        segmenterOptions.outputConfidenceMasks()
 | 
				
			||||||
 | 
					            ? getStreamIndex.apply(outputStreams, "CONFIDENCE_MASK:0:confidence_mask")
 | 
				
			||||||
 | 
					            : -1;
 | 
				
			||||||
 | 
					    int categoryMaskOutStreamIndex =
 | 
				
			||||||
 | 
					        segmenterOptions.outputCategoryMask()
 | 
				
			||||||
 | 
					            ? getStreamIndex.apply(outputStreams, "CATEGORY_MASK:category_mask")
 | 
				
			||||||
 | 
					            : -1;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // TODO: Consolidate OutputHandler and TaskRunner.
 | 
					    // TODO: Consolidate OutputHandler and TaskRunner.
 | 
				
			||||||
    OutputHandler<ImageSegmenterResult, MPImage> handler = new OutputHandler<>();
 | 
					    OutputHandler<ImageSegmenterResult, MPImage> handler = new OutputHandler<>();
 | 
				
			||||||
    handler.setOutputPacketConverter(
 | 
					    handler.setOutputPacketConverter(
 | 
				
			||||||
| 
						 | 
					@ -111,50 +131,65 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
 | 
				
			||||||
          @Override
 | 
					          @Override
 | 
				
			||||||
          public ImageSegmenterResult convertToTaskResult(List<Packet> packets)
 | 
					          public ImageSegmenterResult convertToTaskResult(List<Packet> packets)
 | 
				
			||||||
              throws MediaPipeException {
 | 
					              throws MediaPipeException {
 | 
				
			||||||
            if (packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).isEmpty()) {
 | 
					            if (packets.get(IMAGE_OUT_STREAM_INDEX).isEmpty()) {
 | 
				
			||||||
              return ImageSegmenterResult.create(
 | 
					              return ImageSegmenterResult.create(
 | 
				
			||||||
                  new ArrayList<>(),
 | 
					                  Optional.empty(),
 | 
				
			||||||
                  packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).getTimestamp());
 | 
					                  Optional.empty(),
 | 
				
			||||||
 | 
					                  packets.get(IMAGE_OUT_STREAM_INDEX).getTimestamp());
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
            List<MPImage> segmentedMasks = new ArrayList<>();
 | 
					            boolean copyImage = !segmenterOptions.resultListener().isPresent();
 | 
				
			||||||
            int width = PacketGetter.getImageWidth(packets.get(SEGMENTATION_OUT_STREAM_INDEX));
 | 
					            Optional<List<MPImage>> confidenceMasks = Optional.empty();
 | 
				
			||||||
            int height = PacketGetter.getImageHeight(packets.get(SEGMENTATION_OUT_STREAM_INDEX));
 | 
					            if (segmenterOptions.outputConfidenceMasks()) {
 | 
				
			||||||
            int imageFormat =
 | 
					              int width = PacketGetter.getImageWidth(packets.get(confidenceMaskOutStreamIndex));
 | 
				
			||||||
                segmenterOptions.outputType() == ImageSegmenterOptions.OutputType.CONFIDENCE_MASK
 | 
					              int height = PacketGetter.getImageHeight(packets.get(confidenceMaskOutStreamIndex));
 | 
				
			||||||
                    ? MPImage.IMAGE_FORMAT_VEC32F1
 | 
					              confidenceMasks = Optional.of(new ArrayList<MPImage>());
 | 
				
			||||||
                    : MPImage.IMAGE_FORMAT_ALPHA;
 | 
					              int confidenceMasksListSize =
 | 
				
			||||||
            int imageListSize =
 | 
					                  PacketGetter.getImageListSize(packets.get(confidenceMasksOutStreamIndex));
 | 
				
			||||||
                PacketGetter.getImageListSize(packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX));
 | 
					              ByteBuffer[] buffersArray = new ByteBuffer[confidenceMasksListSize];
 | 
				
			||||||
            ByteBuffer[] buffersArray = new ByteBuffer[imageListSize];
 | 
					              // If resultListener is not provided, the resulted MPImage is deep copied from
 | 
				
			||||||
            // If resultListener is not provided, the resulted MPImage is deep copied from mediapipe
 | 
					              // mediapipe graph. If provided, the result MPImage is wrapping the mediapipe packet
 | 
				
			||||||
            // graph. If provided, the result MPImage is wrapping the mediapipe packet memory.
 | 
					              // memory.
 | 
				
			||||||
            if (!segmenterOptions.resultListener().isPresent()) {
 | 
					              if (copyImage) {
 | 
				
			||||||
              for (int i = 0; i < imageListSize; i++) {
 | 
					                for (int i = 0; i < confidenceMasksListSize; i++) {
 | 
				
			||||||
                buffersArray[i] =
 | 
					                  buffersArray[i] = ByteBuffer.allocateDirect(width * height * 4);
 | 
				
			||||||
                    ByteBuffer.allocateDirect(
 | 
					                }
 | 
				
			||||||
                        width * height * (imageFormat == MPImage.IMAGE_FORMAT_VEC32F1 ? 4 : 1));
 | 
					              }
 | 
				
			||||||
 | 
					              if (!PacketGetter.getImageList(
 | 
				
			||||||
 | 
					                  packets.get(confidenceMasksOutStreamIndex), buffersArray, copyImage)) {
 | 
				
			||||||
 | 
					                throw new MediaPipeException(
 | 
				
			||||||
 | 
					                    MediaPipeException.StatusCode.INTERNAL.ordinal(),
 | 
				
			||||||
 | 
					                    "There is an error getting confidence masks.");
 | 
				
			||||||
 | 
					              }
 | 
				
			||||||
 | 
					              for (ByteBuffer buffer : buffersArray) {
 | 
				
			||||||
 | 
					                ByteBufferImageBuilder builder =
 | 
				
			||||||
 | 
					                    new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_VEC32F1);
 | 
				
			||||||
 | 
					                confidenceMasks.get().add(builder.build());
 | 
				
			||||||
              }
 | 
					              }
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
            if (!PacketGetter.getImageList(
 | 
					            Optional<MPImage> categoryMask = Optional.empty();
 | 
				
			||||||
                packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX),
 | 
					            if (segmenterOptions.outputCategoryMask()) {
 | 
				
			||||||
                buffersArray,
 | 
					              int width = PacketGetter.getImageWidth(packets.get(categoryMaskOutStreamIndex));
 | 
				
			||||||
                !segmenterOptions.resultListener().isPresent())) {
 | 
					              int height = PacketGetter.getImageHeight(packets.get(categoryMaskOutStreamIndex));
 | 
				
			||||||
              throw new MediaPipeException(
 | 
					              ByteBuffer buffer;
 | 
				
			||||||
                  MediaPipeException.StatusCode.INTERNAL.ordinal(),
 | 
					              if (copyImage) {
 | 
				
			||||||
                  "There is an error getting segmented masks. It usually results from incorrect"
 | 
					                buffer = ByteBuffer.allocateDirect(width * height);
 | 
				
			||||||
                      + " options of unsupported OutputType of given model.");
 | 
					                if (!PacketGetter.getImageData(packets.get(categoryMaskOutStreamIndex), buffer)) {
 | 
				
			||||||
            }
 | 
					                  throw new MediaPipeException(
 | 
				
			||||||
            for (ByteBuffer buffer : buffersArray) {
 | 
					                      MediaPipeException.StatusCode.INTERNAL.ordinal(),
 | 
				
			||||||
 | 
					                      "There is an error getting category mask.");
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					              } else {
 | 
				
			||||||
 | 
					                buffer = PacketGetter.getImageDataDirectly(packets.get(categoryMaskOutStreamIndex));
 | 
				
			||||||
 | 
					              }
 | 
				
			||||||
              ByteBufferImageBuilder builder =
 | 
					              ByteBufferImageBuilder builder =
 | 
				
			||||||
                  new ByteBufferImageBuilder(buffer, width, height, imageFormat);
 | 
					                  new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_ALPHA);
 | 
				
			||||||
              segmentedMasks.add(builder.build());
 | 
					              categoryMask = Optional.of(builder.build());
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 | 
					 | 
				
			||||||
            return ImageSegmenterResult.create(
 | 
					            return ImageSegmenterResult.create(
 | 
				
			||||||
                segmentedMasks,
 | 
					                confidenceMasks,
 | 
				
			||||||
 | 
					                categoryMask,
 | 
				
			||||||
                BaseVisionTaskApi.generateResultTimestampMs(
 | 
					                BaseVisionTaskApi.generateResultTimestampMs(
 | 
				
			||||||
                    segmenterOptions.runningMode(),
 | 
					                    segmenterOptions.runningMode(), packets.get(IMAGE_OUT_STREAM_INDEX)));
 | 
				
			||||||
                    packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX)));
 | 
					 | 
				
			||||||
          }
 | 
					          }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
          @Override
 | 
					          @Override
 | 
				
			||||||
| 
						 | 
					@ -174,7 +209,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
 | 
				
			||||||
                .setTaskRunningModeName(segmenterOptions.runningMode().name())
 | 
					                .setTaskRunningModeName(segmenterOptions.runningMode().name())
 | 
				
			||||||
                .setTaskGraphName(TASK_GRAPH_NAME)
 | 
					                .setTaskGraphName(TASK_GRAPH_NAME)
 | 
				
			||||||
                .setInputStreams(INPUT_STREAMS)
 | 
					                .setInputStreams(INPUT_STREAMS)
 | 
				
			||||||
                .setOutputStreams(OUTPUT_STREAMS)
 | 
					                .setOutputStreams(outputStreams)
 | 
				
			||||||
                .setTaskOptions(segmenterOptions)
 | 
					                .setTaskOptions(segmenterOptions)
 | 
				
			||||||
                .setEnableFlowLimiting(segmenterOptions.runningMode() == RunningMode.LIVE_STREAM)
 | 
					                .setEnableFlowLimiting(segmenterOptions.runningMode() == RunningMode.LIVE_STREAM)
 | 
				
			||||||
                .build(),
 | 
					                .build(),
 | 
				
			||||||
| 
						 | 
					@ -553,8 +588,11 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
 | 
				
			||||||
       */
 | 
					       */
 | 
				
			||||||
      public abstract Builder setDisplayNamesLocale(String value);
 | 
					      public abstract Builder setDisplayNamesLocale(String value);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      /** The output type from image segmenter. */
 | 
					      /** Whether to output confidence masks. */
 | 
				
			||||||
      public abstract Builder setOutputType(OutputType value);
 | 
					      public abstract Builder setOutputConfidenceMasks(boolean value);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      /** Whether to output category mask. */
 | 
				
			||||||
 | 
					      public abstract Builder setOutputCategoryMask(boolean value);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      /**
 | 
					      /**
 | 
				
			||||||
       * Sets an optional {@link ResultListener} to receive the segmentation results when the graph
 | 
					       * Sets an optional {@link ResultListener} to receive the segmentation results when the graph
 | 
				
			||||||
| 
						 | 
					@ -594,27 +632,20 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    abstract String displayNamesLocale();
 | 
					    abstract String displayNamesLocale();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    abstract OutputType outputType();
 | 
					    abstract boolean outputConfidenceMasks();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    abstract boolean outputCategoryMask();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    abstract Optional<ResultListener<ImageSegmenterResult, MPImage>> resultListener();
 | 
					    abstract Optional<ResultListener<ImageSegmenterResult, MPImage>> resultListener();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    abstract Optional<ErrorListener> errorListener();
 | 
					    abstract Optional<ErrorListener> errorListener();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /** The output type of segmentation results. */
 | 
					 | 
				
			||||||
    public enum OutputType {
 | 
					 | 
				
			||||||
      // Gives a single output mask where each pixel represents the class which
 | 
					 | 
				
			||||||
      // the pixel in the original image was predicted to belong to.
 | 
					 | 
				
			||||||
      CATEGORY_MASK,
 | 
					 | 
				
			||||||
      // Gives a list of output masks where, for each mask, each pixel represents
 | 
					 | 
				
			||||||
      // the prediction confidence, usually in the [0, 1] range.
 | 
					 | 
				
			||||||
      CONFIDENCE_MASK
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    public static Builder builder() {
 | 
					    public static Builder builder() {
 | 
				
			||||||
      return new AutoValue_ImageSegmenter_ImageSegmenterOptions.Builder()
 | 
					      return new AutoValue_ImageSegmenter_ImageSegmenterOptions.Builder()
 | 
				
			||||||
          .setRunningMode(RunningMode.IMAGE)
 | 
					          .setRunningMode(RunningMode.IMAGE)
 | 
				
			||||||
          .setDisplayNamesLocale("en")
 | 
					          .setDisplayNamesLocale("en")
 | 
				
			||||||
          .setOutputType(OutputType.CATEGORY_MASK);
 | 
					          .setOutputConfidenceMasks(true)
 | 
				
			||||||
 | 
					          .setOutputCategoryMask(false);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /**
 | 
					    /**
 | 
				
			||||||
| 
						 | 
					@ -633,14 +664,6 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      SegmenterOptionsProto.SegmenterOptions.Builder segmenterOptionsBuilder =
 | 
					      SegmenterOptionsProto.SegmenterOptions.Builder segmenterOptionsBuilder =
 | 
				
			||||||
          SegmenterOptionsProto.SegmenterOptions.newBuilder();
 | 
					          SegmenterOptionsProto.SegmenterOptions.newBuilder();
 | 
				
			||||||
      if (outputType() == OutputType.CONFIDENCE_MASK) {
 | 
					 | 
				
			||||||
        segmenterOptionsBuilder.setOutputType(
 | 
					 | 
				
			||||||
            SegmenterOptionsProto.SegmenterOptions.OutputType.CONFIDENCE_MASK);
 | 
					 | 
				
			||||||
      } else if (outputType() == OutputType.CATEGORY_MASK) {
 | 
					 | 
				
			||||||
        segmenterOptionsBuilder.setOutputType(
 | 
					 | 
				
			||||||
            SegmenterOptionsProto.SegmenterOptions.OutputType.CATEGORY_MASK);
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
      taskOptionsBuilder.setSegmenterOptions(segmenterOptionsBuilder);
 | 
					      taskOptionsBuilder.setSegmenterOptions(segmenterOptionsBuilder);
 | 
				
			||||||
      return CalculatorOptions.newBuilder()
 | 
					      return CalculatorOptions.newBuilder()
 | 
				
			||||||
          .setExtension(
 | 
					          .setExtension(
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -19,6 +19,7 @@ import com.google.mediapipe.framework.image.MPImage;
 | 
				
			||||||
import com.google.mediapipe.tasks.core.TaskResult;
 | 
					import com.google.mediapipe.tasks.core.TaskResult;
 | 
				
			||||||
import java.util.Collections;
 | 
					import java.util.Collections;
 | 
				
			||||||
import java.util.List;
 | 
					import java.util.List;
 | 
				
			||||||
 | 
					import java.util.Optional;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/** Represents the segmentation results generated by {@link ImageSegmenter}. */
 | 
					/** Represents the segmentation results generated by {@link ImageSegmenter}. */
 | 
				
			||||||
@AutoValue
 | 
					@AutoValue
 | 
				
			||||||
| 
						 | 
					@ -27,18 +28,24 @@ public abstract class ImageSegmenterResult implements TaskResult {
 | 
				
			||||||
  /**
 | 
					  /**
 | 
				
			||||||
   * Creates an {@link ImageSegmenterResult} instance from a list of segmentation MPImage.
 | 
					   * Creates an {@link ImageSegmenterResult} instance from a list of segmentation MPImage.
 | 
				
			||||||
   *
 | 
					   *
 | 
				
			||||||
   * @param segmentations a {@link List} of MPImage representing the segmented masks. If OutputType
 | 
					   * @param confidenceMasks an {@link Optional} of {@link List} of MPImage in IMAGE_FORMAT_VEC32F1
 | 
				
			||||||
   *     is CATEGORY_MASK, the masks will be in IMAGE_FORMAT_ALPHA format. If OutputType is
 | 
					   *     format representing the confidence masks, where, for each mask, each pixel represents the
 | 
				
			||||||
   *     CONFIDENCE_MASK, the masks will be in IMAGE_FORMAT_VEC32F1 format.
 | 
					   *     prediction confidence, usually in the [0, 1] range.
 | 
				
			||||||
 | 
					   * @param categoryMask an {@link Optional} MPImage in IMAGE_FORMAT_ALPHA format representing a
 | 
				
			||||||
 | 
					   *     category mask, where each pixel represents the class which the pixel in the original image
 | 
				
			||||||
 | 
					   *     was predicted to belong to.
 | 
				
			||||||
   * @param timestampMs a timestamp for this result.
 | 
					   * @param timestampMs a timestamp for this result.
 | 
				
			||||||
   */
 | 
					   */
 | 
				
			||||||
  // TODO: consolidate output formats across platforms.
 | 
					  // TODO: consolidate output formats across platforms.
 | 
				
			||||||
  public static ImageSegmenterResult create(List<MPImage> segmentations, long timestampMs) {
 | 
					  public static ImageSegmenterResult create(
 | 
				
			||||||
 | 
					      Optional<List<MPImage>> confidenceMasks, Optional<MPImage> categoryMask, long timestampMs) {
 | 
				
			||||||
    return new AutoValue_ImageSegmenterResult(
 | 
					    return new AutoValue_ImageSegmenterResult(
 | 
				
			||||||
        Collections.unmodifiableList(segmentations), timestampMs);
 | 
					        confidenceMasks.map(Collections::unmodifiableList), categoryMask, timestampMs);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  public abstract List<MPImage> segmentations();
 | 
					  public abstract Optional<List<MPImage>> confidenceMasks();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  public abstract Optional<MPImage> categoryMask();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  @Override
 | 
					  @Override
 | 
				
			||||||
  public abstract long timestampMs();
 | 
					  public abstract long timestampMs();
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -132,7 +132,8 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
 | 
				
			||||||
              throws MediaPipeException {
 | 
					              throws MediaPipeException {
 | 
				
			||||||
            if (packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).isEmpty()) {
 | 
					            if (packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).isEmpty()) {
 | 
				
			||||||
              return ImageSegmenterResult.create(
 | 
					              return ImageSegmenterResult.create(
 | 
				
			||||||
                  new ArrayList<>(),
 | 
					                  Optional.empty(),
 | 
				
			||||||
 | 
					                  Optional.empty(),
 | 
				
			||||||
                  packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).getTimestamp());
 | 
					                  packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).getTimestamp());
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
            List<MPImage> segmentedMasks = new ArrayList<>();
 | 
					            List<MPImage> segmentedMasks = new ArrayList<>();
 | 
				
			||||||
| 
						 | 
					@ -171,7 +172,8 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            return ImageSegmenterResult.create(
 | 
					            return ImageSegmenterResult.create(
 | 
				
			||||||
                segmentedMasks,
 | 
					                Optional.of(segmentedMasks),
 | 
				
			||||||
 | 
					                Optional.empty(),
 | 
				
			||||||
                BaseVisionTaskApi.generateResultTimestampMs(
 | 
					                BaseVisionTaskApi.generateResultTimestampMs(
 | 
				
			||||||
                    RunningMode.IMAGE, packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX)));
 | 
					                    RunningMode.IMAGE, packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX)));
 | 
				
			||||||
          }
 | 
					          }
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -61,14 +61,14 @@ public class ImageSegmenterTest {
 | 
				
			||||||
      ImageSegmenterOptions options =
 | 
					      ImageSegmenterOptions options =
 | 
				
			||||||
          ImageSegmenterOptions.builder()
 | 
					          ImageSegmenterOptions.builder()
 | 
				
			||||||
              .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
 | 
					              .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
 | 
				
			||||||
              .setOutputType(ImageSegmenterOptions.OutputType.CATEGORY_MASK)
 | 
					              .setOutputConfidenceMasks(false)
 | 
				
			||||||
 | 
					              .setOutputCategoryMask(true)
 | 
				
			||||||
              .build();
 | 
					              .build();
 | 
				
			||||||
      ImageSegmenter imageSegmenter =
 | 
					      ImageSegmenter imageSegmenter =
 | 
				
			||||||
          ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
					          ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
				
			||||||
      ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName));
 | 
					      ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName));
 | 
				
			||||||
      List<MPImage> segmentations = actualResult.segmentations();
 | 
					      assertThat(actualResult.categoryMask().isPresent()).isTrue();
 | 
				
			||||||
      assertThat(segmentations.size()).isEqualTo(1);
 | 
					      MPImage actualMaskBuffer = actualResult.categoryMask().get();
 | 
				
			||||||
      MPImage actualMaskBuffer = actualResult.segmentations().get(0);
 | 
					 | 
				
			||||||
      MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
 | 
					      MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
 | 
				
			||||||
      verifyCategoryMask(
 | 
					      verifyCategoryMask(
 | 
				
			||||||
          actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY, MAGNIFICATION_FACTOR);
 | 
					          actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY, MAGNIFICATION_FACTOR);
 | 
				
			||||||
| 
						 | 
					@ -81,15 +81,14 @@ public class ImageSegmenterTest {
 | 
				
			||||||
      ImageSegmenterOptions options =
 | 
					      ImageSegmenterOptions options =
 | 
				
			||||||
          ImageSegmenterOptions.builder()
 | 
					          ImageSegmenterOptions.builder()
 | 
				
			||||||
              .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
 | 
					              .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
 | 
				
			||||||
              .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
 | 
					 | 
				
			||||||
              .build();
 | 
					              .build();
 | 
				
			||||||
      ImageSegmenter imageSegmenter =
 | 
					      ImageSegmenter imageSegmenter =
 | 
				
			||||||
          ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
					          ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
				
			||||||
      ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName));
 | 
					      ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName));
 | 
				
			||||||
      List<MPImage> segmentations = actualResult.segmentations();
 | 
					      List<MPImage> segmentations = actualResult.confidenceMasks().get();
 | 
				
			||||||
      assertThat(segmentations.size()).isEqualTo(21);
 | 
					      assertThat(segmentations.size()).isEqualTo(21);
 | 
				
			||||||
      // Cat category index 8.
 | 
					      // Cat category index 8.
 | 
				
			||||||
      MPImage actualMaskBuffer = actualResult.segmentations().get(8);
 | 
					      MPImage actualMaskBuffer = segmentations.get(8);
 | 
				
			||||||
      MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
 | 
					      MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
 | 
				
			||||||
      verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
 | 
					      verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
| 
						 | 
					@ -102,40 +101,36 @@ public class ImageSegmenterTest {
 | 
				
			||||||
          ImageSegmenterOptions.builder()
 | 
					          ImageSegmenterOptions.builder()
 | 
				
			||||||
              .setBaseOptions(
 | 
					              .setBaseOptions(
 | 
				
			||||||
                  BaseOptions.builder().setModelAssetPath(SELFIE_128x128_MODEL_FILE).build())
 | 
					                  BaseOptions.builder().setModelAssetPath(SELFIE_128x128_MODEL_FILE).build())
 | 
				
			||||||
              .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
 | 
					 | 
				
			||||||
              .build();
 | 
					              .build();
 | 
				
			||||||
      ImageSegmenter imageSegmenter =
 | 
					      ImageSegmenter imageSegmenter =
 | 
				
			||||||
          ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
					          ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
				
			||||||
      ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName));
 | 
					      ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName));
 | 
				
			||||||
      List<MPImage> segmentations = actualResult.segmentations();
 | 
					      List<MPImage> segmentations = actualResult.confidenceMasks().get();
 | 
				
			||||||
      assertThat(segmentations.size()).isEqualTo(2);
 | 
					      assertThat(segmentations.size()).isEqualTo(2);
 | 
				
			||||||
      // Selfie category index 1.
 | 
					      // Selfie category index 1.
 | 
				
			||||||
      MPImage actualMaskBuffer = actualResult.segmentations().get(1);
 | 
					      MPImage actualMaskBuffer = segmentations.get(1);
 | 
				
			||||||
      MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
 | 
					      MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
 | 
				
			||||||
      verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
 | 
					      verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // TODO: enable this unit test once activation option is supported in metadata.
 | 
					    @Test
 | 
				
			||||||
    // @Test
 | 
					    public void segment_successWith144x256Segmentation() throws Exception {
 | 
				
			||||||
    // public void segment_successWith144x256Segmentation() throws Exception {
 | 
					      final String inputImageName = "mozart_square.jpg";
 | 
				
			||||||
    //   final String inputImageName = "mozart_square.jpg";
 | 
					      final String goldenImageName = "selfie_segm_144_256_3_expected_mask.jpg";
 | 
				
			||||||
    //   final String goldenImageName = "selfie_segm_144_256_3_expected_mask.jpg";
 | 
					      ImageSegmenterOptions options =
 | 
				
			||||||
    //   ImageSegmenterOptions options =
 | 
					          ImageSegmenterOptions.builder()
 | 
				
			||||||
    //       ImageSegmenterOptions.builder()
 | 
					              .setBaseOptions(
 | 
				
			||||||
    //           .setBaseOptions(
 | 
					                  BaseOptions.builder().setModelAssetPath(SELFIE_144x256_MODEL_FILE).build())
 | 
				
			||||||
    //               BaseOptions.builder().setModelAssetPath(SELFIE_144x256_MODEL_FILE).build())
 | 
					              .build();
 | 
				
			||||||
    //           .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
 | 
					      ImageSegmenter imageSegmenter =
 | 
				
			||||||
    //           .build();
 | 
					          ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
				
			||||||
    //   ImageSegmenter imageSegmenter =
 | 
					      ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName));
 | 
				
			||||||
    //       ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
					      List<MPImage> segmentations = actualResult.confidenceMasks().get();
 | 
				
			||||||
    //   ImageSegmenterResult actualResult =
 | 
					      assertThat(segmentations.size()).isEqualTo(1);
 | 
				
			||||||
    // imageSegmenter.segment(getImageFromAsset(inputImageName));
 | 
					      MPImage actualMaskBuffer = segmentations.get(0);
 | 
				
			||||||
    //   List<MPImage> segmentations = actualResult.segmentations();
 | 
					      MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
 | 
				
			||||||
    //   assertThat(segmentations.size()).isEqualTo(1);
 | 
					      verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
 | 
				
			||||||
    //   MPImage actualMaskBuffer = actualResult.segmentations().get(0);
 | 
					    }
 | 
				
			||||||
    //   MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
 | 
					 | 
				
			||||||
    //   verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
 | 
					 | 
				
			||||||
    // }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Test
 | 
					    @Test
 | 
				
			||||||
    public void getLabels_success() throws Exception {
 | 
					    public void getLabels_success() throws Exception {
 | 
				
			||||||
| 
						 | 
					@ -165,7 +160,6 @@ public class ImageSegmenterTest {
 | 
				
			||||||
      ImageSegmenterOptions options =
 | 
					      ImageSegmenterOptions options =
 | 
				
			||||||
          ImageSegmenterOptions.builder()
 | 
					          ImageSegmenterOptions.builder()
 | 
				
			||||||
              .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
 | 
					              .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
 | 
				
			||||||
              .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
 | 
					 | 
				
			||||||
              .build();
 | 
					              .build();
 | 
				
			||||||
      ImageSegmenter imageSegmenter =
 | 
					      ImageSegmenter imageSegmenter =
 | 
				
			||||||
          ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
					          ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
				
			||||||
| 
						 | 
					@ -287,16 +281,15 @@ public class ImageSegmenterTest {
 | 
				
			||||||
      ImageSegmenterOptions options =
 | 
					      ImageSegmenterOptions options =
 | 
				
			||||||
          ImageSegmenterOptions.builder()
 | 
					          ImageSegmenterOptions.builder()
 | 
				
			||||||
              .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
 | 
					              .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
 | 
				
			||||||
              .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
 | 
					 | 
				
			||||||
              .setRunningMode(RunningMode.IMAGE)
 | 
					              .setRunningMode(RunningMode.IMAGE)
 | 
				
			||||||
              .build();
 | 
					              .build();
 | 
				
			||||||
      ImageSegmenter imageSegmenter =
 | 
					      ImageSegmenter imageSegmenter =
 | 
				
			||||||
          ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
					          ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
				
			||||||
      ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName));
 | 
					      ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName));
 | 
				
			||||||
      List<MPImage> segmentations = actualResult.segmentations();
 | 
					      List<MPImage> segmentations = actualResult.confidenceMasks().get();
 | 
				
			||||||
      assertThat(segmentations.size()).isEqualTo(21);
 | 
					      assertThat(segmentations.size()).isEqualTo(21);
 | 
				
			||||||
      // Cat category index 8.
 | 
					      // Cat category index 8.
 | 
				
			||||||
      MPImage actualMaskBuffer = actualResult.segmentations().get(8);
 | 
					      MPImage actualMaskBuffer = segmentations.get(8);
 | 
				
			||||||
      MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
 | 
					      MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
 | 
				
			||||||
      verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
 | 
					      verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
| 
						 | 
					@ -309,12 +302,11 @@ public class ImageSegmenterTest {
 | 
				
			||||||
      ImageSegmenterOptions options =
 | 
					      ImageSegmenterOptions options =
 | 
				
			||||||
          ImageSegmenterOptions.builder()
 | 
					          ImageSegmenterOptions.builder()
 | 
				
			||||||
              .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
 | 
					              .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
 | 
				
			||||||
              .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
 | 
					 | 
				
			||||||
              .setRunningMode(RunningMode.IMAGE)
 | 
					              .setRunningMode(RunningMode.IMAGE)
 | 
				
			||||||
              .setResultListener(
 | 
					              .setResultListener(
 | 
				
			||||||
                  (segmenterResult, inputImage) -> {
 | 
					                  (segmenterResult, inputImage) -> {
 | 
				
			||||||
                    verifyConfidenceMask(
 | 
					                    verifyConfidenceMask(
 | 
				
			||||||
                        segmenterResult.segmentations().get(8),
 | 
					                        segmenterResult.confidenceMasks().get().get(8),
 | 
				
			||||||
                        expectedResult,
 | 
					                        expectedResult,
 | 
				
			||||||
                        GOLDEN_MASK_SIMILARITY);
 | 
					                        GOLDEN_MASK_SIMILARITY);
 | 
				
			||||||
                  })
 | 
					                  })
 | 
				
			||||||
| 
						 | 
					@ -331,7 +323,6 @@ public class ImageSegmenterTest {
 | 
				
			||||||
      ImageSegmenterOptions options =
 | 
					      ImageSegmenterOptions options =
 | 
				
			||||||
          ImageSegmenterOptions.builder()
 | 
					          ImageSegmenterOptions.builder()
 | 
				
			||||||
              .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
 | 
					              .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
 | 
				
			||||||
              .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
 | 
					 | 
				
			||||||
              .setRunningMode(RunningMode.VIDEO)
 | 
					              .setRunningMode(RunningMode.VIDEO)
 | 
				
			||||||
              .build();
 | 
					              .build();
 | 
				
			||||||
      ImageSegmenter imageSegmenter =
 | 
					      ImageSegmenter imageSegmenter =
 | 
				
			||||||
| 
						 | 
					@ -341,10 +332,10 @@ public class ImageSegmenterTest {
 | 
				
			||||||
        ImageSegmenterResult actualResult =
 | 
					        ImageSegmenterResult actualResult =
 | 
				
			||||||
            imageSegmenter.segmentForVideo(
 | 
					            imageSegmenter.segmentForVideo(
 | 
				
			||||||
                getImageFromAsset(inputImageName), /* timestampsMs= */ i);
 | 
					                getImageFromAsset(inputImageName), /* timestampsMs= */ i);
 | 
				
			||||||
        List<MPImage> segmentations = actualResult.segmentations();
 | 
					        List<MPImage> segmentations = actualResult.confidenceMasks().get();
 | 
				
			||||||
        assertThat(segmentations.size()).isEqualTo(21);
 | 
					        assertThat(segmentations.size()).isEqualTo(21);
 | 
				
			||||||
        // Cat category index 8.
 | 
					        // Cat category index 8.
 | 
				
			||||||
        MPImage actualMaskBuffer = actualResult.segmentations().get(8);
 | 
					        MPImage actualMaskBuffer = segmentations.get(8);
 | 
				
			||||||
        verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
 | 
					        verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
| 
						 | 
					@ -357,12 +348,11 @@ public class ImageSegmenterTest {
 | 
				
			||||||
      ImageSegmenterOptions options =
 | 
					      ImageSegmenterOptions options =
 | 
				
			||||||
          ImageSegmenterOptions.builder()
 | 
					          ImageSegmenterOptions.builder()
 | 
				
			||||||
              .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
 | 
					              .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
 | 
				
			||||||
              .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
 | 
					 | 
				
			||||||
              .setRunningMode(RunningMode.VIDEO)
 | 
					              .setRunningMode(RunningMode.VIDEO)
 | 
				
			||||||
              .setResultListener(
 | 
					              .setResultListener(
 | 
				
			||||||
                  (segmenterResult, inputImage) -> {
 | 
					                  (segmenterResult, inputImage) -> {
 | 
				
			||||||
                    verifyConfidenceMask(
 | 
					                    verifyConfidenceMask(
 | 
				
			||||||
                        segmenterResult.segmentations().get(8),
 | 
					                        segmenterResult.confidenceMasks().get().get(8),
 | 
				
			||||||
                        expectedResult,
 | 
					                        expectedResult,
 | 
				
			||||||
                        GOLDEN_MASK_SIMILARITY);
 | 
					                        GOLDEN_MASK_SIMILARITY);
 | 
				
			||||||
                  })
 | 
					                  })
 | 
				
			||||||
| 
						 | 
					@ -384,12 +374,11 @@ public class ImageSegmenterTest {
 | 
				
			||||||
      ImageSegmenterOptions options =
 | 
					      ImageSegmenterOptions options =
 | 
				
			||||||
          ImageSegmenterOptions.builder()
 | 
					          ImageSegmenterOptions.builder()
 | 
				
			||||||
              .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
 | 
					              .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
 | 
				
			||||||
              .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
 | 
					 | 
				
			||||||
              .setRunningMode(RunningMode.LIVE_STREAM)
 | 
					              .setRunningMode(RunningMode.LIVE_STREAM)
 | 
				
			||||||
              .setResultListener(
 | 
					              .setResultListener(
 | 
				
			||||||
                  (segmenterResult, inputImage) -> {
 | 
					                  (segmenterResult, inputImage) -> {
 | 
				
			||||||
                    verifyConfidenceMask(
 | 
					                    verifyConfidenceMask(
 | 
				
			||||||
                        segmenterResult.segmentations().get(8),
 | 
					                        segmenterResult.confidenceMasks().get().get(8),
 | 
				
			||||||
                        expectedResult,
 | 
					                        expectedResult,
 | 
				
			||||||
                        GOLDEN_MASK_SIMILARITY);
 | 
					                        GOLDEN_MASK_SIMILARITY);
 | 
				
			||||||
                  })
 | 
					                  })
 | 
				
			||||||
| 
						 | 
					@ -411,12 +400,11 @@ public class ImageSegmenterTest {
 | 
				
			||||||
      ImageSegmenterOptions options =
 | 
					      ImageSegmenterOptions options =
 | 
				
			||||||
          ImageSegmenterOptions.builder()
 | 
					          ImageSegmenterOptions.builder()
 | 
				
			||||||
              .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
 | 
					              .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
 | 
				
			||||||
              .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
 | 
					 | 
				
			||||||
              .setRunningMode(RunningMode.LIVE_STREAM)
 | 
					              .setRunningMode(RunningMode.LIVE_STREAM)
 | 
				
			||||||
              .setResultListener(
 | 
					              .setResultListener(
 | 
				
			||||||
                  (segmenterResult, inputImage) -> {
 | 
					                  (segmenterResult, inputImage) -> {
 | 
				
			||||||
                    verifyConfidenceMask(
 | 
					                    verifyConfidenceMask(
 | 
				
			||||||
                        segmenterResult.segmentations().get(8),
 | 
					                        segmenterResult.confidenceMasks().get().get(8),
 | 
				
			||||||
                        expectedResult,
 | 
					                        expectedResult,
 | 
				
			||||||
                        GOLDEN_MASK_SIMILARITY);
 | 
					                        GOLDEN_MASK_SIMILARITY);
 | 
				
			||||||
                  })
 | 
					                  })
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -60,7 +60,10 @@ public class InteractiveSegmenterTest {
 | 
				
			||||||
              ApplicationProvider.getApplicationContext(), options);
 | 
					              ApplicationProvider.getApplicationContext(), options);
 | 
				
			||||||
      MPImage image = getImageFromAsset(inputImageName);
 | 
					      MPImage image = getImageFromAsset(inputImageName);
 | 
				
			||||||
      ImageSegmenterResult actualResult = imageSegmenter.segment(image, roi);
 | 
					      ImageSegmenterResult actualResult = imageSegmenter.segment(image, roi);
 | 
				
			||||||
      List<MPImage> segmentations = actualResult.segmentations();
 | 
					      // TODO update to correct category mask output.
 | 
				
			||||||
 | 
					      // After InteractiveSegmenter updated according to (b/276519300), update this to use
 | 
				
			||||||
 | 
					      // categoryMask field instead of confidenceMasks.
 | 
				
			||||||
 | 
					      List<MPImage> segmentations = actualResult.confidenceMasks().get();
 | 
				
			||||||
      assertThat(segmentations.size()).isEqualTo(1);
 | 
					      assertThat(segmentations.size()).isEqualTo(1);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -79,7 +82,7 @@ public class InteractiveSegmenterTest {
 | 
				
			||||||
              ApplicationProvider.getApplicationContext(), options);
 | 
					              ApplicationProvider.getApplicationContext(), options);
 | 
				
			||||||
      ImageSegmenterResult actualResult =
 | 
					      ImageSegmenterResult actualResult =
 | 
				
			||||||
          imageSegmenter.segment(getImageFromAsset(inputImageName), roi);
 | 
					          imageSegmenter.segment(getImageFromAsset(inputImageName), roi);
 | 
				
			||||||
      List<MPImage> segmentations = actualResult.segmentations();
 | 
					      List<MPImage> segmentations = actualResult.confidenceMasks().get();
 | 
				
			||||||
      assertThat(segmentations.size()).isEqualTo(2);
 | 
					      assertThat(segmentations.size()).isEqualTo(2);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user