diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java index 52a5f2a67..e9ff1f2b5 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java @@ -502,6 +502,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi { /** The Region-Of-Interest (ROI) to interact with. */ public static class RegionOfInterest { private NormalizedKeypoint keypoint; + private List<NormalizedKeypoint> scribble; private RegionOfInterest() {} @@ -514,6 +515,16 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi { roi.keypoint = keypoint; return roi; } + + /** + * Creates a {@link RegionOfInterest} instance representing scribbles over the object that the + * user wants to segment. + */ + public static RegionOfInterest create(List<NormalizedKeypoint> scribble) { + RegionOfInterest roi = new RegionOfInterest(); + roi.scribble = scribble; + return roi; + } } /** @@ -535,6 +546,18 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi { .setX(roi.keypoint.x()) .setY(roi.keypoint.y()))) .build(); + } else if (roi.scribble != null) { + RenderAnnotation.Scribble.Builder scribbleBuilder = RenderAnnotation.Scribble.newBuilder(); + for (NormalizedKeypoint p : roi.scribble) { + scribbleBuilder.addPoint(RenderAnnotation.Point.newBuilder().setX(p.x()).setY(p.y())); + } + + return builder + .addRenderAnnotations( + RenderAnnotation.newBuilder() + .setColor(Color.newBuilder().setR(255)) + .setScribble(scribbleBuilder)) + .build(); } throw new IllegalArgumentException( diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenterTest.java 
b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenterTest.java index 506036ba2..a534970f7 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenterTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenterTest.java @@ -27,6 +27,7 @@ import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenterResult; import com.google.mediapipe.tasks.vision.interactivesegmenter.InteractiveSegmenter.InteractiveSegmenterOptions; import java.io.InputStream; +import java.util.ArrayList; import java.util.List; import org.junit.Test; import org.junit.runner.RunWith; @@ -36,7 +37,8 @@ import org.junit.runners.Suite.SuiteClasses; /** Test for {@link InteractiveSegmenter}. */ @RunWith(Suite.class) @SuiteClasses({ - InteractiveSegmenterTest.General.class, + InteractiveSegmenterTest.KeypointRoi.class, + InteractiveSegmenterTest.ScribbleRoi.class, }) public class InteractiveSegmenterTest { private static final String DEEPLAB_MODEL_FILE = "ptm_512_hdt_ptm_woid.tflite"; @@ -44,7 +46,7 @@ public class InteractiveSegmenterTest { private static final int MAGNIFICATION_FACTOR = 10; @RunWith(AndroidJUnit4.class) - public static final class General extends InteractiveSegmenterTest { + public static final class KeypointRoi extends InteractiveSegmenterTest { @Test public void segment_successWithCategoryMask() throws Exception { final String inputImageName = CATS_AND_DOGS_IMAGE; @@ -86,6 +88,57 @@ public class InteractiveSegmenterTest { } } + @RunWith(AndroidJUnit4.class) + public static final class ScribbleRoi extends InteractiveSegmenterTest { + @Test + public void segment_successWithCategoryMask() throws Exception { + final String inputImageName = CATS_AND_DOGS_IMAGE; + ArrayList<NormalizedKeypoint> scribble = new ArrayList<>(); + scribble.add(NormalizedKeypoint.create(0.25f, 0.9f)); + 
scribble.add(NormalizedKeypoint.create(0.25f, 0.91f)); + scribble.add(NormalizedKeypoint.create(0.25f, 0.92f)); + final InteractiveSegmenter.RegionOfInterest roi = + InteractiveSegmenter.RegionOfInterest.create(scribble); + InteractiveSegmenterOptions options = + InteractiveSegmenterOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) + .setOutputConfidenceMasks(false) + .setOutputCategoryMask(true) + .build(); + InteractiveSegmenter imageSegmenter = + InteractiveSegmenter.createFromOptions( + ApplicationProvider.getApplicationContext(), options); + MPImage image = getImageFromAsset(inputImageName); + ImageSegmenterResult actualResult = imageSegmenter.segment(image, roi); + assertThat(actualResult.categoryMask().isPresent()).isTrue(); + } + + @Test + public void segment_successWithConfidenceMask() throws Exception { + final String inputImageName = CATS_AND_DOGS_IMAGE; + ArrayList<NormalizedKeypoint> scribble = new ArrayList<>(); + scribble.add(NormalizedKeypoint.create(0.25f, 0.9f)); + scribble.add(NormalizedKeypoint.create(0.25f, 0.91f)); + scribble.add(NormalizedKeypoint.create(0.25f, 0.92f)); + final InteractiveSegmenter.RegionOfInterest roi = + InteractiveSegmenter.RegionOfInterest.create(scribble); + InteractiveSegmenterOptions options = + InteractiveSegmenterOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) + .setOutputConfidenceMasks(true) + .setOutputCategoryMask(false) + .build(); + InteractiveSegmenter imageSegmenter = + InteractiveSegmenter.createFromOptions( + ApplicationProvider.getApplicationContext(), options); + ImageSegmenterResult actualResult = + imageSegmenter.segment(getImageFromAsset(inputImageName), roi); + assertThat(actualResult.confidenceMasks().isPresent()).isTrue(); + List<MPImage> confidenceMasks = actualResult.confidenceMasks().get(); + assertThat(confidenceMasks.size()).isEqualTo(2); + } + } + private static MPImage getImageFromAsset(String filePath) throws 
Exception { AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets(); InputStream istr = assetManager.open(filePath);