Support scribble input for Interactive Segmenter Java API
PiperOrigin-RevId: 529177660
This commit is contained in:
parent
e84e90e5b2
commit
7c955246aa
|
@ -502,6 +502,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
||||||
/** The Region-Of-Interest (ROI) to interact with. */
|
/** The Region-Of-Interest (ROI) to interact with. */
|
||||||
public static class RegionOfInterest {
|
public static class RegionOfInterest {
|
||||||
private NormalizedKeypoint keypoint;
|
private NormalizedKeypoint keypoint;
|
||||||
|
private List<NormalizedKeypoint> scribble;
|
||||||
|
|
||||||
private RegionOfInterest() {}
|
private RegionOfInterest() {}
|
||||||
|
|
||||||
|
@ -514,6 +515,16 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
||||||
roi.keypoint = keypoint;
|
roi.keypoint = keypoint;
|
||||||
return roi;
|
return roi;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a {@link RegionOfInterest} instance representing scribbles over the object that the
|
||||||
|
* user wants to segment.
|
||||||
|
*/
|
||||||
|
public static RegionOfInterest create(List<NormalizedKeypoint> scribble) {
|
||||||
|
RegionOfInterest roi = new RegionOfInterest();
|
||||||
|
roi.scribble = scribble;
|
||||||
|
return roi;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -535,6 +546,18 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
||||||
.setX(roi.keypoint.x())
|
.setX(roi.keypoint.x())
|
||||||
.setY(roi.keypoint.y())))
|
.setY(roi.keypoint.y())))
|
||||||
.build();
|
.build();
|
||||||
|
} else if (roi.scribble != null) {
|
||||||
|
RenderAnnotation.Scribble.Builder scribbleBuilder = RenderAnnotation.Scribble.newBuilder();
|
||||||
|
for (NormalizedKeypoint p : roi.scribble) {
|
||||||
|
scribbleBuilder.addPoint(RenderAnnotation.Point.newBuilder().setX(p.x()).setY(p.y()));
|
||||||
|
}
|
||||||
|
|
||||||
|
return builder
|
||||||
|
.addRenderAnnotations(
|
||||||
|
RenderAnnotation.newBuilder()
|
||||||
|
.setColor(Color.newBuilder().setR(255))
|
||||||
|
.setScribble(scribbleBuilder))
|
||||||
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
throw new IllegalArgumentException(
|
throw new IllegalArgumentException(
|
||||||
|
|
|
@ -27,6 +27,7 @@ import com.google.mediapipe.tasks.core.BaseOptions;
|
||||||
import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenterResult;
|
import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenterResult;
|
||||||
import com.google.mediapipe.tasks.vision.interactivesegmenter.InteractiveSegmenter.InteractiveSegmenterOptions;
|
import com.google.mediapipe.tasks.vision.interactivesegmenter.InteractiveSegmenter.InteractiveSegmenterOptions;
|
||||||
import java.io.InputStream;
|
import java.io.InputStream;
|
||||||
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.junit.runner.RunWith;
|
import org.junit.runner.RunWith;
|
||||||
|
@ -36,7 +37,8 @@ import org.junit.runners.Suite.SuiteClasses;
|
||||||
/** Test for {@link InteractiveSegmenter}. */
|
/** Test for {@link InteractiveSegmenter}. */
|
||||||
@RunWith(Suite.class)
|
@RunWith(Suite.class)
|
||||||
@SuiteClasses({
|
@SuiteClasses({
|
||||||
InteractiveSegmenterTest.General.class,
|
InteractiveSegmenterTest.KeypointRoi.class,
|
||||||
|
InteractiveSegmenterTest.ScribbleRoi.class,
|
||||||
})
|
})
|
||||||
public class InteractiveSegmenterTest {
|
public class InteractiveSegmenterTest {
|
||||||
private static final String DEEPLAB_MODEL_FILE = "ptm_512_hdt_ptm_woid.tflite";
|
private static final String DEEPLAB_MODEL_FILE = "ptm_512_hdt_ptm_woid.tflite";
|
||||||
|
@ -44,7 +46,7 @@ public class InteractiveSegmenterTest {
|
||||||
private static final int MAGNIFICATION_FACTOR = 10;
|
private static final int MAGNIFICATION_FACTOR = 10;
|
||||||
|
|
||||||
@RunWith(AndroidJUnit4.class)
|
@RunWith(AndroidJUnit4.class)
|
||||||
public static final class General extends InteractiveSegmenterTest {
|
public static final class KeypointRoi extends InteractiveSegmenterTest {
|
||||||
@Test
|
@Test
|
||||||
public void segment_successWithCategoryMask() throws Exception {
|
public void segment_successWithCategoryMask() throws Exception {
|
||||||
final String inputImageName = CATS_AND_DOGS_IMAGE;
|
final String inputImageName = CATS_AND_DOGS_IMAGE;
|
||||||
|
@ -86,6 +88,57 @@ public class InteractiveSegmenterTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@RunWith(AndroidJUnit4.class)
|
||||||
|
public static final class ScribbleRoi extends InteractiveSegmenterTest {
|
||||||
|
@Test
|
||||||
|
public void segment_successWithCategoryMask() throws Exception {
|
||||||
|
final String inputImageName = CATS_AND_DOGS_IMAGE;
|
||||||
|
ArrayList<NormalizedKeypoint> scribble = new ArrayList<>();
|
||||||
|
scribble.add(NormalizedKeypoint.create(0.25f, 0.9f));
|
||||||
|
scribble.add(NormalizedKeypoint.create(0.25f, 0.91f));
|
||||||
|
scribble.add(NormalizedKeypoint.create(0.25f, 0.92f));
|
||||||
|
final InteractiveSegmenter.RegionOfInterest roi =
|
||||||
|
InteractiveSegmenter.RegionOfInterest.create(scribble);
|
||||||
|
InteractiveSegmenterOptions options =
|
||||||
|
InteractiveSegmenterOptions.builder()
|
||||||
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||||
|
.setOutputConfidenceMasks(false)
|
||||||
|
.setOutputCategoryMask(true)
|
||||||
|
.build();
|
||||||
|
InteractiveSegmenter imageSegmenter =
|
||||||
|
InteractiveSegmenter.createFromOptions(
|
||||||
|
ApplicationProvider.getApplicationContext(), options);
|
||||||
|
MPImage image = getImageFromAsset(inputImageName);
|
||||||
|
ImageSegmenterResult actualResult = imageSegmenter.segment(image, roi);
|
||||||
|
assertThat(actualResult.categoryMask().isPresent()).isTrue();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void segment_successWithConfidenceMask() throws Exception {
|
||||||
|
final String inputImageName = CATS_AND_DOGS_IMAGE;
|
||||||
|
ArrayList<NormalizedKeypoint> scribble = new ArrayList<>();
|
||||||
|
scribble.add(NormalizedKeypoint.create(0.25f, 0.9f));
|
||||||
|
scribble.add(NormalizedKeypoint.create(0.25f, 0.91f));
|
||||||
|
scribble.add(NormalizedKeypoint.create(0.25f, 0.92f));
|
||||||
|
final InteractiveSegmenter.RegionOfInterest roi =
|
||||||
|
InteractiveSegmenter.RegionOfInterest.create(scribble);
|
||||||
|
InteractiveSegmenterOptions options =
|
||||||
|
InteractiveSegmenterOptions.builder()
|
||||||
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||||
|
.setOutputConfidenceMasks(true)
|
||||||
|
.setOutputCategoryMask(false)
|
||||||
|
.build();
|
||||||
|
InteractiveSegmenter imageSegmenter =
|
||||||
|
InteractiveSegmenter.createFromOptions(
|
||||||
|
ApplicationProvider.getApplicationContext(), options);
|
||||||
|
ImageSegmenterResult actualResult =
|
||||||
|
imageSegmenter.segment(getImageFromAsset(inputImageName), roi);
|
||||||
|
assertThat(actualResult.confidenceMasks().isPresent()).isTrue();
|
||||||
|
List<MPImage> confidenceMasks = actualResult.confidenceMasks().get();
|
||||||
|
assertThat(confidenceMasks.size()).isEqualTo(2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private static MPImage getImageFromAsset(String filePath) throws Exception {
|
private static MPImage getImageFromAsset(String filePath) throws Exception {
|
||||||
AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
|
AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
|
||||||
InputStream istr = assetManager.open(filePath);
|
InputStream istr = assetManager.open(filePath);
|
||||||
|
|
Loading…
Reference in New Issue
Block a user