Support scribble input for Interactive Segmenter Java API
PiperOrigin-RevId: 529177660
This commit is contained in:
		
							parent
							
								
									e84e90e5b2
								
							
						
					
					
						commit
						7c955246aa
					
				| 
						 | 
				
			
			@ -502,6 +502,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
 | 
			
		|||
  /** The Region-Of-Interest (ROI) to interact with. */
 | 
			
		||||
  public static class RegionOfInterest {
 | 
			
		||||
    private NormalizedKeypoint keypoint;
 | 
			
		||||
    private List<NormalizedKeypoint> scribble;
 | 
			
		||||
 | 
			
		||||
    private RegionOfInterest() {}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -514,6 +515,16 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
 | 
			
		|||
      roi.keypoint = keypoint;
 | 
			
		||||
      return roi;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Creates a {@link RegionOfInterest} instance representing scribbles over the object that the
 | 
			
		||||
     * user wants to segment.
 | 
			
		||||
     */
 | 
			
		||||
    public static RegionOfInterest create(List<NormalizedKeypoint> scribble) {
 | 
			
		||||
      RegionOfInterest roi = new RegionOfInterest();
 | 
			
		||||
      roi.scribble = scribble;
 | 
			
		||||
      return roi;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
| 
						 | 
				
			
			@ -535,6 +546,18 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
 | 
			
		|||
                          .setX(roi.keypoint.x())
 | 
			
		||||
                          .setY(roi.keypoint.y())))
 | 
			
		||||
          .build();
 | 
			
		||||
    } else if (roi.scribble != null) {
 | 
			
		||||
      RenderAnnotation.Scribble.Builder scribbleBuilder = RenderAnnotation.Scribble.newBuilder();
 | 
			
		||||
      for (NormalizedKeypoint p : roi.scribble) {
 | 
			
		||||
        scribbleBuilder.addPoint(RenderAnnotation.Point.newBuilder().setX(p.x()).setY(p.y()));
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      return builder
 | 
			
		||||
          .addRenderAnnotations(
 | 
			
		||||
              RenderAnnotation.newBuilder()
 | 
			
		||||
                  .setColor(Color.newBuilder().setR(255))
 | 
			
		||||
                  .setScribble(scribbleBuilder))
 | 
			
		||||
          .build();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    throw new IllegalArgumentException(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -27,6 +27,7 @@ import com.google.mediapipe.tasks.core.BaseOptions;
 | 
			
		|||
import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenterResult;
 | 
			
		||||
import com.google.mediapipe.tasks.vision.interactivesegmenter.InteractiveSegmenter.InteractiveSegmenterOptions;
 | 
			
		||||
import java.io.InputStream;
 | 
			
		||||
import java.util.ArrayList;
 | 
			
		||||
import java.util.List;
 | 
			
		||||
import org.junit.Test;
 | 
			
		||||
import org.junit.runner.RunWith;
 | 
			
		||||
| 
						 | 
				
			
			@ -36,7 +37,8 @@ import org.junit.runners.Suite.SuiteClasses;
 | 
			
		|||
/** Test for {@link InteractiveSegmenter}. */
 | 
			
		||||
@RunWith(Suite.class)
 | 
			
		||||
@SuiteClasses({
 | 
			
		||||
  InteractiveSegmenterTest.General.class,
 | 
			
		||||
  InteractiveSegmenterTest.KeypointRoi.class,
 | 
			
		||||
  InteractiveSegmenterTest.ScribbleRoi.class,
 | 
			
		||||
})
 | 
			
		||||
public class InteractiveSegmenterTest {
 | 
			
		||||
  private static final String DEEPLAB_MODEL_FILE = "ptm_512_hdt_ptm_woid.tflite";
 | 
			
		||||
| 
						 | 
				
			
			@ -44,7 +46,7 @@ public class InteractiveSegmenterTest {
 | 
			
		|||
  private static final int MAGNIFICATION_FACTOR = 10;
 | 
			
		||||
 | 
			
		||||
  @RunWith(AndroidJUnit4.class)
 | 
			
		||||
  public static final class General extends InteractiveSegmenterTest {
 | 
			
		||||
  public static final class KeypointRoi extends InteractiveSegmenterTest {
 | 
			
		||||
    @Test
 | 
			
		||||
    public void segment_successWithCategoryMask() throws Exception {
 | 
			
		||||
      final String inputImageName = CATS_AND_DOGS_IMAGE;
 | 
			
		||||
| 
						 | 
				
			
			@ -86,6 +88,57 @@ public class InteractiveSegmenterTest {
 | 
			
		|||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  @RunWith(AndroidJUnit4.class)
 | 
			
		||||
  public static final class ScribbleRoi extends InteractiveSegmenterTest {
 | 
			
		||||
    @Test
 | 
			
		||||
    public void segment_successWithCategoryMask() throws Exception {
 | 
			
		||||
      final String inputImageName = CATS_AND_DOGS_IMAGE;
 | 
			
		||||
      ArrayList<NormalizedKeypoint> scribble = new ArrayList<>();
 | 
			
		||||
      scribble.add(NormalizedKeypoint.create(0.25f, 0.9f));
 | 
			
		||||
      scribble.add(NormalizedKeypoint.create(0.25f, 0.91f));
 | 
			
		||||
      scribble.add(NormalizedKeypoint.create(0.25f, 0.92f));
 | 
			
		||||
      final InteractiveSegmenter.RegionOfInterest roi =
 | 
			
		||||
          InteractiveSegmenter.RegionOfInterest.create(scribble);
 | 
			
		||||
      InteractiveSegmenterOptions options =
 | 
			
		||||
          InteractiveSegmenterOptions.builder()
 | 
			
		||||
              .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
 | 
			
		||||
              .setOutputConfidenceMasks(false)
 | 
			
		||||
              .setOutputCategoryMask(true)
 | 
			
		||||
              .build();
 | 
			
		||||
      InteractiveSegmenter imageSegmenter =
 | 
			
		||||
          InteractiveSegmenter.createFromOptions(
 | 
			
		||||
              ApplicationProvider.getApplicationContext(), options);
 | 
			
		||||
      MPImage image = getImageFromAsset(inputImageName);
 | 
			
		||||
      ImageSegmenterResult actualResult = imageSegmenter.segment(image, roi);
 | 
			
		||||
      assertThat(actualResult.categoryMask().isPresent()).isTrue();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void segment_successWithConfidenceMask() throws Exception {
 | 
			
		||||
      final String inputImageName = CATS_AND_DOGS_IMAGE;
 | 
			
		||||
      ArrayList<NormalizedKeypoint> scribble = new ArrayList<>();
 | 
			
		||||
      scribble.add(NormalizedKeypoint.create(0.25f, 0.9f));
 | 
			
		||||
      scribble.add(NormalizedKeypoint.create(0.25f, 0.91f));
 | 
			
		||||
      scribble.add(NormalizedKeypoint.create(0.25f, 0.92f));
 | 
			
		||||
      final InteractiveSegmenter.RegionOfInterest roi =
 | 
			
		||||
          InteractiveSegmenter.RegionOfInterest.create(scribble);
 | 
			
		||||
      InteractiveSegmenterOptions options =
 | 
			
		||||
          InteractiveSegmenterOptions.builder()
 | 
			
		||||
              .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
 | 
			
		||||
              .setOutputConfidenceMasks(true)
 | 
			
		||||
              .setOutputCategoryMask(false)
 | 
			
		||||
              .build();
 | 
			
		||||
      InteractiveSegmenter imageSegmenter =
 | 
			
		||||
          InteractiveSegmenter.createFromOptions(
 | 
			
		||||
              ApplicationProvider.getApplicationContext(), options);
 | 
			
		||||
      ImageSegmenterResult actualResult =
 | 
			
		||||
          imageSegmenter.segment(getImageFromAsset(inputImageName), roi);
 | 
			
		||||
      assertThat(actualResult.confidenceMasks().isPresent()).isTrue();
 | 
			
		||||
      List<MPImage> confidenceMasks = actualResult.confidenceMasks().get();
 | 
			
		||||
      assertThat(confidenceMasks.size()).isEqualTo(2);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  private static MPImage getImageFromAsset(String filePath) throws Exception {
 | 
			
		||||
    AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
 | 
			
		||||
    InputStream istr = assetManager.open(filePath);
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user