Support scribble input for Interactive Segmenter Java API

PiperOrigin-RevId: 529177660
This commit is contained in:
MediaPipe Team 2023-05-03 13:19:51 -07:00 committed by Copybara-Service
parent e84e90e5b2
commit 7c955246aa
2 changed files with 78 additions and 2 deletions

View File

@ -502,6 +502,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
/** The Region-Of-Interest (ROI) to interact with. */
public static class RegionOfInterest {
private NormalizedKeypoint keypoint;
private List<NormalizedKeypoint> scribble;
private RegionOfInterest() {}
@ -514,6 +515,16 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
roi.keypoint = keypoint;
return roi;
}
/**
* Creates a {@link RegionOfInterest} instance representing scribbles over the object that the
* user wants to segment.
*/
public static RegionOfInterest create(List<NormalizedKeypoint> scribble) {
RegionOfInterest roi = new RegionOfInterest();
roi.scribble = scribble;
return roi;
}
}
/**
@ -535,6 +546,18 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
.setX(roi.keypoint.x())
.setY(roi.keypoint.y())))
.build();
} else if (roi.scribble != null) {
RenderAnnotation.Scribble.Builder scribbleBuilder = RenderAnnotation.Scribble.newBuilder();
for (NormalizedKeypoint p : roi.scribble) {
scribbleBuilder.addPoint(RenderAnnotation.Point.newBuilder().setX(p.x()).setY(p.y()));
}
return builder
.addRenderAnnotations(
RenderAnnotation.newBuilder()
.setColor(Color.newBuilder().setR(255))
.setScribble(scribbleBuilder))
.build();
}
throw new IllegalArgumentException(

View File

@ -27,6 +27,7 @@ import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenterResult;
import com.google.mediapipe.tasks.vision.interactivesegmenter.InteractiveSegmenter.InteractiveSegmenterOptions;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;
import org.junit.Test;
import org.junit.runner.RunWith;
@ -36,7 +37,8 @@ import org.junit.runners.Suite.SuiteClasses;
/** Test for {@link InteractiveSegmenter}. */
@RunWith(Suite.class)
@SuiteClasses({
InteractiveSegmenterTest.General.class,
InteractiveSegmenterTest.KeypointRoi.class,
InteractiveSegmenterTest.ScribbleRoi.class,
})
public class InteractiveSegmenterTest {
private static final String DEEPLAB_MODEL_FILE = "ptm_512_hdt_ptm_woid.tflite";
@ -44,7 +46,7 @@ public class InteractiveSegmenterTest {
private static final int MAGNIFICATION_FACTOR = 10;
@RunWith(AndroidJUnit4.class)
public static final class General extends InteractiveSegmenterTest {
public static final class KeypointRoi extends InteractiveSegmenterTest {
@Test
public void segment_successWithCategoryMask() throws Exception {
final String inputImageName = CATS_AND_DOGS_IMAGE;
@ -86,6 +88,57 @@ public class InteractiveSegmenterTest {
}
}
@RunWith(AndroidJUnit4.class)
public static final class ScribbleRoi extends InteractiveSegmenterTest {
@Test
public void segment_successWithCategoryMask() throws Exception {
final String inputImageName = CATS_AND_DOGS_IMAGE;
ArrayList<NormalizedKeypoint> scribble = new ArrayList<>();
scribble.add(NormalizedKeypoint.create(0.25f, 0.9f));
scribble.add(NormalizedKeypoint.create(0.25f, 0.91f));
scribble.add(NormalizedKeypoint.create(0.25f, 0.92f));
final InteractiveSegmenter.RegionOfInterest roi =
InteractiveSegmenter.RegionOfInterest.create(scribble);
InteractiveSegmenterOptions options =
InteractiveSegmenterOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
.setOutputConfidenceMasks(false)
.setOutputCategoryMask(true)
.build();
InteractiveSegmenter imageSegmenter =
InteractiveSegmenter.createFromOptions(
ApplicationProvider.getApplicationContext(), options);
MPImage image = getImageFromAsset(inputImageName);
ImageSegmenterResult actualResult = imageSegmenter.segment(image, roi);
assertThat(actualResult.categoryMask().isPresent()).isTrue();
}
@Test
public void segment_successWithConfidenceMask() throws Exception {
final String inputImageName = CATS_AND_DOGS_IMAGE;
ArrayList<NormalizedKeypoint> scribble = new ArrayList<>();
scribble.add(NormalizedKeypoint.create(0.25f, 0.9f));
scribble.add(NormalizedKeypoint.create(0.25f, 0.91f));
scribble.add(NormalizedKeypoint.create(0.25f, 0.92f));
final InteractiveSegmenter.RegionOfInterest roi =
InteractiveSegmenter.RegionOfInterest.create(scribble);
InteractiveSegmenterOptions options =
InteractiveSegmenterOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
.setOutputConfidenceMasks(true)
.setOutputCategoryMask(false)
.build();
InteractiveSegmenter imageSegmenter =
InteractiveSegmenter.createFromOptions(
ApplicationProvider.getApplicationContext(), options);
ImageSegmenterResult actualResult =
imageSegmenter.segment(getImageFromAsset(inputImageName), roi);
assertThat(actualResult.confidenceMasks().isPresent()).isTrue();
List<MPImage> confidenceMasks = actualResult.confidenceMasks().get();
assertThat(confidenceMasks.size()).isEqualTo(2);
}
}
private static MPImage getImageFromAsset(String filePath) throws Exception {
AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
InputStream istr = assetManager.open(filePath);