Support scribble input for Interactive Segmenter Java API
PiperOrigin-RevId: 529177660
This commit is contained in:
parent
e84e90e5b2
commit
7c955246aa
|
@ -502,6 +502,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
|||
/** The Region-Of-Interest (ROI) to interact with. */
|
||||
public static class RegionOfInterest {
|
||||
private NormalizedKeypoint keypoint;
|
||||
private List<NormalizedKeypoint> scribble;
|
||||
|
||||
private RegionOfInterest() {}
|
||||
|
||||
|
@ -514,6 +515,16 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
|||
roi.keypoint = keypoint;
|
||||
return roi;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a {@link RegionOfInterest} instance representing scribbles over the object that the
|
||||
* user wants to segment.
|
||||
*/
|
||||
public static RegionOfInterest create(List<NormalizedKeypoint> scribble) {
|
||||
RegionOfInterest roi = new RegionOfInterest();
|
||||
roi.scribble = scribble;
|
||||
return roi;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -535,6 +546,18 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
|||
.setX(roi.keypoint.x())
|
||||
.setY(roi.keypoint.y())))
|
||||
.build();
|
||||
} else if (roi.scribble != null) {
|
||||
RenderAnnotation.Scribble.Builder scribbleBuilder = RenderAnnotation.Scribble.newBuilder();
|
||||
for (NormalizedKeypoint p : roi.scribble) {
|
||||
scribbleBuilder.addPoint(RenderAnnotation.Point.newBuilder().setX(p.x()).setY(p.y()));
|
||||
}
|
||||
|
||||
return builder
|
||||
.addRenderAnnotations(
|
||||
RenderAnnotation.newBuilder()
|
||||
.setColor(Color.newBuilder().setR(255))
|
||||
.setScribble(scribbleBuilder))
|
||||
.build();
|
||||
}
|
||||
|
||||
throw new IllegalArgumentException(
|
||||
|
|
|
@ -27,6 +27,7 @@ import com.google.mediapipe.tasks.core.BaseOptions;
|
|||
import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenterResult;
|
||||
import com.google.mediapipe.tasks.vision.interactivesegmenter.InteractiveSegmenter.InteractiveSegmenterOptions;
|
||||
import java.io.InputStream;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
|
@ -36,7 +37,8 @@ import org.junit.runners.Suite.SuiteClasses;
|
|||
/** Test for {@link InteractiveSegmenter}. */
|
||||
@RunWith(Suite.class)
|
||||
@SuiteClasses({
|
||||
InteractiveSegmenterTest.General.class,
|
||||
InteractiveSegmenterTest.KeypointRoi.class,
|
||||
InteractiveSegmenterTest.ScribbleRoi.class,
|
||||
})
|
||||
public class InteractiveSegmenterTest {
|
||||
private static final String DEEPLAB_MODEL_FILE = "ptm_512_hdt_ptm_woid.tflite";
|
||||
|
@ -44,7 +46,7 @@ public class InteractiveSegmenterTest {
|
|||
private static final int MAGNIFICATION_FACTOR = 10;
|
||||
|
||||
@RunWith(AndroidJUnit4.class)
|
||||
public static final class General extends InteractiveSegmenterTest {
|
||||
public static final class KeypointRoi extends InteractiveSegmenterTest {
|
||||
@Test
|
||||
public void segment_successWithCategoryMask() throws Exception {
|
||||
final String inputImageName = CATS_AND_DOGS_IMAGE;
|
||||
|
@ -86,6 +88,57 @@ public class InteractiveSegmenterTest {
|
|||
}
|
||||
}
|
||||
|
||||
@RunWith(AndroidJUnit4.class)
|
||||
public static final class ScribbleRoi extends InteractiveSegmenterTest {
|
||||
@Test
|
||||
public void segment_successWithCategoryMask() throws Exception {
|
||||
final String inputImageName = CATS_AND_DOGS_IMAGE;
|
||||
ArrayList<NormalizedKeypoint> scribble = new ArrayList<>();
|
||||
scribble.add(NormalizedKeypoint.create(0.25f, 0.9f));
|
||||
scribble.add(NormalizedKeypoint.create(0.25f, 0.91f));
|
||||
scribble.add(NormalizedKeypoint.create(0.25f, 0.92f));
|
||||
final InteractiveSegmenter.RegionOfInterest roi =
|
||||
InteractiveSegmenter.RegionOfInterest.create(scribble);
|
||||
InteractiveSegmenterOptions options =
|
||||
InteractiveSegmenterOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||
.setOutputConfidenceMasks(false)
|
||||
.setOutputCategoryMask(true)
|
||||
.build();
|
||||
InteractiveSegmenter imageSegmenter =
|
||||
InteractiveSegmenter.createFromOptions(
|
||||
ApplicationProvider.getApplicationContext(), options);
|
||||
MPImage image = getImageFromAsset(inputImageName);
|
||||
ImageSegmenterResult actualResult = imageSegmenter.segment(image, roi);
|
||||
assertThat(actualResult.categoryMask().isPresent()).isTrue();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void segment_successWithConfidenceMask() throws Exception {
|
||||
final String inputImageName = CATS_AND_DOGS_IMAGE;
|
||||
ArrayList<NormalizedKeypoint> scribble = new ArrayList<>();
|
||||
scribble.add(NormalizedKeypoint.create(0.25f, 0.9f));
|
||||
scribble.add(NormalizedKeypoint.create(0.25f, 0.91f));
|
||||
scribble.add(NormalizedKeypoint.create(0.25f, 0.92f));
|
||||
final InteractiveSegmenter.RegionOfInterest roi =
|
||||
InteractiveSegmenter.RegionOfInterest.create(scribble);
|
||||
InteractiveSegmenterOptions options =
|
||||
InteractiveSegmenterOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||
.setOutputConfidenceMasks(true)
|
||||
.setOutputCategoryMask(false)
|
||||
.build();
|
||||
InteractiveSegmenter imageSegmenter =
|
||||
InteractiveSegmenter.createFromOptions(
|
||||
ApplicationProvider.getApplicationContext(), options);
|
||||
ImageSegmenterResult actualResult =
|
||||
imageSegmenter.segment(getImageFromAsset(inputImageName), roi);
|
||||
assertThat(actualResult.confidenceMasks().isPresent()).isTrue();
|
||||
List<MPImage> confidenceMasks = actualResult.confidenceMasks().get();
|
||||
assertThat(confidenceMasks.size()).isEqualTo(2);
|
||||
}
|
||||
}
|
||||
|
||||
private static MPImage getImageFromAsset(String filePath) throws Exception {
|
||||
AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
|
||||
InputStream istr = assetManager.open(filePath);
|
||||
|
|
Loading…
Reference in New Issue
Block a user