Update java image segmenter to always output confidence masks and optionally output category mask.
PiperOrigin-RevId: 521804641
This commit is contained in:
parent
7c2930102d
commit
33cad24a5a
|
@ -79,15 +79,10 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
||||||
private static final List<String> INPUT_STREAMS =
|
private static final List<String> INPUT_STREAMS =
|
||||||
Collections.unmodifiableList(
|
Collections.unmodifiableList(
|
||||||
Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
|
Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
|
||||||
private static final List<String> OUTPUT_STREAMS =
|
private static final int CONFIDENCE_MASKS_OUT_STREAM_INDEX = 0;
|
||||||
Collections.unmodifiableList(
|
|
||||||
Arrays.asList(
|
|
||||||
"GROUPED_SEGMENTATION:segmented_mask_out",
|
|
||||||
"IMAGE:image_out",
|
|
||||||
"SEGMENTATION:0:segmentation"));
|
|
||||||
private static final int GROUPED_SEGMENTATION_OUT_STREAM_INDEX = 0;
|
|
||||||
private static final int IMAGE_OUT_STREAM_INDEX = 1;
|
private static final int IMAGE_OUT_STREAM_INDEX = 1;
|
||||||
private static final int SEGMENTATION_OUT_STREAM_INDEX = 2;
|
private static final int CONFIDENCE_MASK_OUT_STREAM_INDEX = 2;
|
||||||
|
private static final int CATEGORY_MASK_OUT_STREAM_INDEX = 3;
|
||||||
private static final String TASK_GRAPH_NAME =
|
private static final String TASK_GRAPH_NAME =
|
||||||
"mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph";
|
"mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph";
|
||||||
private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
|
private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
|
||||||
|
@ -104,6 +99,13 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
||||||
*/
|
*/
|
||||||
public static ImageSegmenter createFromOptions(
|
public static ImageSegmenter createFromOptions(
|
||||||
Context context, ImageSegmenterOptions segmenterOptions) {
|
Context context, ImageSegmenterOptions segmenterOptions) {
|
||||||
|
List<String> outputStreams = new ArrayList<>();
|
||||||
|
outputStreams.add("CONFIDENCE_MASKS:confidence_masks");
|
||||||
|
outputStreams.add("IMAGE:image_out");
|
||||||
|
outputStreams.add("CONFIDENCE_MASK:0:confidence_mask");
|
||||||
|
if (segmenterOptions.outputCategoryMask()) {
|
||||||
|
outputStreams.add("CATEGORY_MASK:category_mask");
|
||||||
|
}
|
||||||
// TODO: Consolidate OutputHandler and TaskRunner.
|
// TODO: Consolidate OutputHandler and TaskRunner.
|
||||||
OutputHandler<ImageSegmenterResult, MPImage> handler = new OutputHandler<>();
|
OutputHandler<ImageSegmenterResult, MPImage> handler = new OutputHandler<>();
|
||||||
handler.setOutputPacketConverter(
|
handler.setOutputPacketConverter(
|
||||||
|
@ -111,50 +113,62 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
||||||
@Override
|
@Override
|
||||||
public ImageSegmenterResult convertToTaskResult(List<Packet> packets)
|
public ImageSegmenterResult convertToTaskResult(List<Packet> packets)
|
||||||
throws MediaPipeException {
|
throws MediaPipeException {
|
||||||
if (packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).isEmpty()) {
|
if (packets.get(CONFIDENCE_MASKS_OUT_STREAM_INDEX).isEmpty()) {
|
||||||
return ImageSegmenterResult.create(
|
return ImageSegmenterResult.create(
|
||||||
new ArrayList<>(),
|
new ArrayList<>(),
|
||||||
packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).getTimestamp());
|
Optional.empty(),
|
||||||
|
packets.get(CONFIDENCE_MASKS_OUT_STREAM_INDEX).getTimestamp());
|
||||||
}
|
}
|
||||||
List<MPImage> segmentedMasks = new ArrayList<>();
|
List<MPImage> confidenceMasks = new ArrayList<>();
|
||||||
int width = PacketGetter.getImageWidth(packets.get(SEGMENTATION_OUT_STREAM_INDEX));
|
int width = PacketGetter.getImageWidth(packets.get(CONFIDENCE_MASK_OUT_STREAM_INDEX));
|
||||||
int height = PacketGetter.getImageHeight(packets.get(SEGMENTATION_OUT_STREAM_INDEX));
|
int height = PacketGetter.getImageHeight(packets.get(CONFIDENCE_MASK_OUT_STREAM_INDEX));
|
||||||
int imageFormat =
|
int confidenceMasksListSize =
|
||||||
segmenterOptions.outputType() == ImageSegmenterOptions.OutputType.CONFIDENCE_MASK
|
PacketGetter.getImageListSize(packets.get(CONFIDENCE_MASKS_OUT_STREAM_INDEX));
|
||||||
? MPImage.IMAGE_FORMAT_VEC32F1
|
ByteBuffer[] buffersArray = new ByteBuffer[confidenceMasksListSize];
|
||||||
: MPImage.IMAGE_FORMAT_ALPHA;
|
|
||||||
int imageListSize =
|
|
||||||
PacketGetter.getImageListSize(packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX));
|
|
||||||
ByteBuffer[] buffersArray = new ByteBuffer[imageListSize];
|
|
||||||
// If resultListener is not provided, the resulted MPImage is deep copied from mediapipe
|
// If resultListener is not provided, the resulted MPImage is deep copied from mediapipe
|
||||||
// graph. If provided, the result MPImage is wrapping the mediapipe packet memory.
|
// graph. If provided, the result MPImage is wrapping the mediapipe packet memory.
|
||||||
if (!segmenterOptions.resultListener().isPresent()) {
|
boolean copyImage = !segmenterOptions.resultListener().isPresent();
|
||||||
for (int i = 0; i < imageListSize; i++) {
|
if (copyImage) {
|
||||||
buffersArray[i] =
|
for (int i = 0; i < confidenceMasksListSize; i++) {
|
||||||
ByteBuffer.allocateDirect(
|
buffersArray[i] = ByteBuffer.allocateDirect(width * height * 4);
|
||||||
width * height * (imageFormat == MPImage.IMAGE_FORMAT_VEC32F1 ? 4 : 1));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!PacketGetter.getImageList(
|
if (!PacketGetter.getImageList(
|
||||||
packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX),
|
packets.get(CONFIDENCE_MASKS_OUT_STREAM_INDEX), buffersArray, copyImage)) {
|
||||||
buffersArray,
|
|
||||||
!segmenterOptions.resultListener().isPresent())) {
|
|
||||||
throw new MediaPipeException(
|
throw new MediaPipeException(
|
||||||
MediaPipeException.StatusCode.INTERNAL.ordinal(),
|
MediaPipeException.StatusCode.INTERNAL.ordinal(),
|
||||||
"There is an error getting segmented masks. It usually results from incorrect"
|
"There is an error getting segmented masks.");
|
||||||
+ " options of unsupported OutputType of given model.");
|
|
||||||
}
|
}
|
||||||
for (ByteBuffer buffer : buffersArray) {
|
for (ByteBuffer buffer : buffersArray) {
|
||||||
ByteBufferImageBuilder builder =
|
ByteBufferImageBuilder builder =
|
||||||
new ByteBufferImageBuilder(buffer, width, height, imageFormat);
|
new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_VEC32F1);
|
||||||
segmentedMasks.add(builder.build());
|
confidenceMasks.add(builder.build());
|
||||||
|
}
|
||||||
|
Optional<MPImage> categoryMask = Optional.empty();
|
||||||
|
if (segmenterOptions.outputCategoryMask()) {
|
||||||
|
ByteBuffer buffer;
|
||||||
|
if (copyImage) {
|
||||||
|
buffer = ByteBuffer.allocateDirect(width * height);
|
||||||
|
if (!PacketGetter.getImageData(
|
||||||
|
packets.get(CATEGORY_MASK_OUT_STREAM_INDEX), buffer)) {
|
||||||
|
throw new MediaPipeException(
|
||||||
|
MediaPipeException.StatusCode.INTERNAL.ordinal(),
|
||||||
|
"There is an error getting category mask.");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
buffer =
|
||||||
|
PacketGetter.getImageDataDirectly(packets.get(CATEGORY_MASK_OUT_STREAM_INDEX));
|
||||||
|
}
|
||||||
|
ByteBufferImageBuilder builder =
|
||||||
|
new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_ALPHA);
|
||||||
|
categoryMask = Optional.of(builder.build());
|
||||||
}
|
}
|
||||||
|
|
||||||
return ImageSegmenterResult.create(
|
return ImageSegmenterResult.create(
|
||||||
segmentedMasks,
|
confidenceMasks,
|
||||||
|
categoryMask,
|
||||||
BaseVisionTaskApi.generateResultTimestampMs(
|
BaseVisionTaskApi.generateResultTimestampMs(
|
||||||
segmenterOptions.runningMode(),
|
segmenterOptions.runningMode(),
|
||||||
packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX)));
|
packets.get(CONFIDENCE_MASKS_OUT_STREAM_INDEX)));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -174,7 +188,7 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
||||||
.setTaskRunningModeName(segmenterOptions.runningMode().name())
|
.setTaskRunningModeName(segmenterOptions.runningMode().name())
|
||||||
.setTaskGraphName(TASK_GRAPH_NAME)
|
.setTaskGraphName(TASK_GRAPH_NAME)
|
||||||
.setInputStreams(INPUT_STREAMS)
|
.setInputStreams(INPUT_STREAMS)
|
||||||
.setOutputStreams(OUTPUT_STREAMS)
|
.setOutputStreams(outputStreams)
|
||||||
.setTaskOptions(segmenterOptions)
|
.setTaskOptions(segmenterOptions)
|
||||||
.setEnableFlowLimiting(segmenterOptions.runningMode() == RunningMode.LIVE_STREAM)
|
.setEnableFlowLimiting(segmenterOptions.runningMode() == RunningMode.LIVE_STREAM)
|
||||||
.build(),
|
.build(),
|
||||||
|
@ -553,8 +567,8 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
||||||
*/
|
*/
|
||||||
public abstract Builder setDisplayNamesLocale(String value);
|
public abstract Builder setDisplayNamesLocale(String value);
|
||||||
|
|
||||||
/** The output type from image segmenter. */
|
/** Whether to output category mask. */
|
||||||
public abstract Builder setOutputType(OutputType value);
|
public abstract Builder setOutputCategoryMask(boolean value);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sets an optional {@link ResultListener} to receive the segmentation results when the graph
|
* Sets an optional {@link ResultListener} to receive the segmentation results when the graph
|
||||||
|
@ -594,27 +608,17 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
||||||
|
|
||||||
abstract String displayNamesLocale();
|
abstract String displayNamesLocale();
|
||||||
|
|
||||||
abstract OutputType outputType();
|
abstract boolean outputCategoryMask();
|
||||||
|
|
||||||
abstract Optional<ResultListener<ImageSegmenterResult, MPImage>> resultListener();
|
abstract Optional<ResultListener<ImageSegmenterResult, MPImage>> resultListener();
|
||||||
|
|
||||||
abstract Optional<ErrorListener> errorListener();
|
abstract Optional<ErrorListener> errorListener();
|
||||||
|
|
||||||
/** The output type of segmentation results. */
|
|
||||||
public enum OutputType {
|
|
||||||
// Gives a single output mask where each pixel represents the class which
|
|
||||||
// the pixel in the original image was predicted to belong to.
|
|
||||||
CATEGORY_MASK,
|
|
||||||
// Gives a list of output masks where, for each mask, each pixel represents
|
|
||||||
// the prediction confidence, usually in the [0, 1] range.
|
|
||||||
CONFIDENCE_MASK
|
|
||||||
}
|
|
||||||
|
|
||||||
public static Builder builder() {
|
public static Builder builder() {
|
||||||
return new AutoValue_ImageSegmenter_ImageSegmenterOptions.Builder()
|
return new AutoValue_ImageSegmenter_ImageSegmenterOptions.Builder()
|
||||||
.setRunningMode(RunningMode.IMAGE)
|
.setRunningMode(RunningMode.IMAGE)
|
||||||
.setDisplayNamesLocale("en")
|
.setDisplayNamesLocale("en")
|
||||||
.setOutputType(OutputType.CATEGORY_MASK);
|
.setOutputCategoryMask(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -633,14 +637,6 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
||||||
|
|
||||||
SegmenterOptionsProto.SegmenterOptions.Builder segmenterOptionsBuilder =
|
SegmenterOptionsProto.SegmenterOptions.Builder segmenterOptionsBuilder =
|
||||||
SegmenterOptionsProto.SegmenterOptions.newBuilder();
|
SegmenterOptionsProto.SegmenterOptions.newBuilder();
|
||||||
if (outputType() == OutputType.CONFIDENCE_MASK) {
|
|
||||||
segmenterOptionsBuilder.setOutputType(
|
|
||||||
SegmenterOptionsProto.SegmenterOptions.OutputType.CONFIDENCE_MASK);
|
|
||||||
} else if (outputType() == OutputType.CATEGORY_MASK) {
|
|
||||||
segmenterOptionsBuilder.setOutputType(
|
|
||||||
SegmenterOptionsProto.SegmenterOptions.OutputType.CATEGORY_MASK);
|
|
||||||
}
|
|
||||||
|
|
||||||
taskOptionsBuilder.setSegmenterOptions(segmenterOptionsBuilder);
|
taskOptionsBuilder.setSegmenterOptions(segmenterOptionsBuilder);
|
||||||
return CalculatorOptions.newBuilder()
|
return CalculatorOptions.newBuilder()
|
||||||
.setExtension(
|
.setExtension(
|
||||||
|
|
|
@ -19,6 +19,7 @@ import com.google.mediapipe.framework.image.MPImage;
|
||||||
import com.google.mediapipe.tasks.core.TaskResult;
|
import com.google.mediapipe.tasks.core.TaskResult;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Optional;
|
||||||
|
|
||||||
/** Represents the segmentation results generated by {@link ImageSegmenter}. */
|
/** Represents the segmentation results generated by {@link ImageSegmenter}. */
|
||||||
@AutoValue
|
@AutoValue
|
||||||
|
@ -27,18 +28,24 @@ public abstract class ImageSegmenterResult implements TaskResult {
|
||||||
/**
|
/**
|
||||||
* Creates an {@link ImageSegmenterResult} instance from a list of segmentation MPImage.
|
* Creates an {@link ImageSegmenterResult} instance from a list of segmentation MPImage.
|
||||||
*
|
*
|
||||||
* @param segmentations a {@link List} of MPImage representing the segmented masks. If OutputType
|
* @param confidenceMasks a {@link List} of MPImage in IMAGE_FORMAT_VEC32F1 format representing
|
||||||
* is CATEGORY_MASK, the masks will be in IMAGE_FORMAT_ALPHA format. If OutputType is
|
* the confidence masks, where, for each mask, each pixel represents the prediction
|
||||||
* CONFIDENCE_MASK, the masks will be in IMAGE_FORMAT_VEC32F1 format.
|
* confidence, usually in the [0, 1] range.
|
||||||
|
* @param categoryMask an {@link Optional} MPImage in IMAGE_FORMAT_ALPHA format representing a
|
||||||
|
* category mask, where each pixel represents the class which the pixel in the original image
|
||||||
|
* was predicted to belong to.
|
||||||
* @param timestampMs a timestamp for this result.
|
* @param timestampMs a timestamp for this result.
|
||||||
*/
|
*/
|
||||||
// TODO: consolidate output formats across platforms.
|
// TODO: consolidate output formats across platforms.
|
||||||
public static ImageSegmenterResult create(List<MPImage> segmentations, long timestampMs) {
|
public static ImageSegmenterResult create(
|
||||||
|
List<MPImage> confidenceMasks, Optional<MPImage> categoryMask, long timestampMs) {
|
||||||
return new AutoValue_ImageSegmenterResult(
|
return new AutoValue_ImageSegmenterResult(
|
||||||
Collections.unmodifiableList(segmentations), timestampMs);
|
Collections.unmodifiableList(confidenceMasks), categoryMask, timestampMs);
|
||||||
}
|
}
|
||||||
|
|
||||||
public abstract List<MPImage> segmentations();
|
public abstract List<MPImage> confidenceMasks();
|
||||||
|
|
||||||
|
public abstract Optional<MPImage> categoryMask();
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public abstract long timestampMs();
|
public abstract long timestampMs();
|
||||||
|
|
|
@ -133,6 +133,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
||||||
if (packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).isEmpty()) {
|
if (packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).isEmpty()) {
|
||||||
return ImageSegmenterResult.create(
|
return ImageSegmenterResult.create(
|
||||||
new ArrayList<>(),
|
new ArrayList<>(),
|
||||||
|
Optional.empty(),
|
||||||
packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).getTimestamp());
|
packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).getTimestamp());
|
||||||
}
|
}
|
||||||
List<MPImage> segmentedMasks = new ArrayList<>();
|
List<MPImage> segmentedMasks = new ArrayList<>();
|
||||||
|
@ -172,6 +173,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
||||||
|
|
||||||
return ImageSegmenterResult.create(
|
return ImageSegmenterResult.create(
|
||||||
segmentedMasks,
|
segmentedMasks,
|
||||||
|
Optional.empty(),
|
||||||
BaseVisionTaskApi.generateResultTimestampMs(
|
BaseVisionTaskApi.generateResultTimestampMs(
|
||||||
RunningMode.IMAGE, packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX)));
|
RunningMode.IMAGE, packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX)));
|
||||||
}
|
}
|
||||||
|
|
|
@ -61,14 +61,13 @@ public class ImageSegmenterTest {
|
||||||
ImageSegmenterOptions options =
|
ImageSegmenterOptions options =
|
||||||
ImageSegmenterOptions.builder()
|
ImageSegmenterOptions.builder()
|
||||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||||
.setOutputType(ImageSegmenterOptions.OutputType.CATEGORY_MASK)
|
.setOutputCategoryMask(true)
|
||||||
.build();
|
.build();
|
||||||
ImageSegmenter imageSegmenter =
|
ImageSegmenter imageSegmenter =
|
||||||
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName));
|
ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName));
|
||||||
List<MPImage> segmentations = actualResult.segmentations();
|
assertThat(actualResult.categoryMask().isPresent()).isTrue();
|
||||||
assertThat(segmentations.size()).isEqualTo(1);
|
MPImage actualMaskBuffer = actualResult.categoryMask().get();
|
||||||
MPImage actualMaskBuffer = actualResult.segmentations().get(0);
|
|
||||||
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
|
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
|
||||||
verifyCategoryMask(
|
verifyCategoryMask(
|
||||||
actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY, MAGNIFICATION_FACTOR);
|
actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY, MAGNIFICATION_FACTOR);
|
||||||
|
@ -81,15 +80,14 @@ public class ImageSegmenterTest {
|
||||||
ImageSegmenterOptions options =
|
ImageSegmenterOptions options =
|
||||||
ImageSegmenterOptions.builder()
|
ImageSegmenterOptions.builder()
|
||||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||||
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
|
|
||||||
.build();
|
.build();
|
||||||
ImageSegmenter imageSegmenter =
|
ImageSegmenter imageSegmenter =
|
||||||
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName));
|
ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName));
|
||||||
List<MPImage> segmentations = actualResult.segmentations();
|
List<MPImage> segmentations = actualResult.confidenceMasks();
|
||||||
assertThat(segmentations.size()).isEqualTo(21);
|
assertThat(segmentations.size()).isEqualTo(21);
|
||||||
// Cat category index 8.
|
// Cat category index 8.
|
||||||
MPImage actualMaskBuffer = actualResult.segmentations().get(8);
|
MPImage actualMaskBuffer = segmentations.get(8);
|
||||||
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
|
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
|
||||||
verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
|
verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
|
||||||
}
|
}
|
||||||
|
@ -102,40 +100,36 @@ public class ImageSegmenterTest {
|
||||||
ImageSegmenterOptions.builder()
|
ImageSegmenterOptions.builder()
|
||||||
.setBaseOptions(
|
.setBaseOptions(
|
||||||
BaseOptions.builder().setModelAssetPath(SELFIE_128x128_MODEL_FILE).build())
|
BaseOptions.builder().setModelAssetPath(SELFIE_128x128_MODEL_FILE).build())
|
||||||
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
|
|
||||||
.build();
|
.build();
|
||||||
ImageSegmenter imageSegmenter =
|
ImageSegmenter imageSegmenter =
|
||||||
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName));
|
ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName));
|
||||||
List<MPImage> segmentations = actualResult.segmentations();
|
List<MPImage> segmentations = actualResult.confidenceMasks();
|
||||||
assertThat(segmentations.size()).isEqualTo(2);
|
assertThat(segmentations.size()).isEqualTo(2);
|
||||||
// Selfie category index 1.
|
// Selfie category index 1.
|
||||||
MPImage actualMaskBuffer = actualResult.segmentations().get(1);
|
MPImage actualMaskBuffer = segmentations.get(1);
|
||||||
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
|
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
|
||||||
verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
|
verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: enable this unit test once activation option is supported in metadata.
|
@Test
|
||||||
// @Test
|
public void segment_successWith144x256Segmentation() throws Exception {
|
||||||
// public void segment_successWith144x256Segmentation() throws Exception {
|
final String inputImageName = "mozart_square.jpg";
|
||||||
// final String inputImageName = "mozart_square.jpg";
|
final String goldenImageName = "selfie_segm_144_256_3_expected_mask.jpg";
|
||||||
// final String goldenImageName = "selfie_segm_144_256_3_expected_mask.jpg";
|
ImageSegmenterOptions options =
|
||||||
// ImageSegmenterOptions options =
|
ImageSegmenterOptions.builder()
|
||||||
// ImageSegmenterOptions.builder()
|
.setBaseOptions(
|
||||||
// .setBaseOptions(
|
BaseOptions.builder().setModelAssetPath(SELFIE_144x256_MODEL_FILE).build())
|
||||||
// BaseOptions.builder().setModelAssetPath(SELFIE_144x256_MODEL_FILE).build())
|
.build();
|
||||||
// .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
|
ImageSegmenter imageSegmenter =
|
||||||
// .build();
|
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
// ImageSegmenter imageSegmenter =
|
ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName));
|
||||||
// ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
List<MPImage> segmentations = actualResult.confidenceMasks();
|
||||||
// ImageSegmenterResult actualResult =
|
assertThat(segmentations.size()).isEqualTo(1);
|
||||||
// imageSegmenter.segment(getImageFromAsset(inputImageName));
|
MPImage actualMaskBuffer = segmentations.get(0);
|
||||||
// List<MPImage> segmentations = actualResult.segmentations();
|
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
|
||||||
// assertThat(segmentations.size()).isEqualTo(1);
|
verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
|
||||||
// MPImage actualMaskBuffer = actualResult.segmentations().get(0);
|
}
|
||||||
// MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
|
|
||||||
// verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
|
|
||||||
// }
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void getLabels_success() throws Exception {
|
public void getLabels_success() throws Exception {
|
||||||
|
@ -165,7 +159,6 @@ public class ImageSegmenterTest {
|
||||||
ImageSegmenterOptions options =
|
ImageSegmenterOptions options =
|
||||||
ImageSegmenterOptions.builder()
|
ImageSegmenterOptions.builder()
|
||||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||||
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
|
|
||||||
.build();
|
.build();
|
||||||
ImageSegmenter imageSegmenter =
|
ImageSegmenter imageSegmenter =
|
||||||
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
|
@ -287,16 +280,15 @@ public class ImageSegmenterTest {
|
||||||
ImageSegmenterOptions options =
|
ImageSegmenterOptions options =
|
||||||
ImageSegmenterOptions.builder()
|
ImageSegmenterOptions.builder()
|
||||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||||
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
|
|
||||||
.setRunningMode(RunningMode.IMAGE)
|
.setRunningMode(RunningMode.IMAGE)
|
||||||
.build();
|
.build();
|
||||||
ImageSegmenter imageSegmenter =
|
ImageSegmenter imageSegmenter =
|
||||||
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName));
|
ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName));
|
||||||
List<MPImage> segmentations = actualResult.segmentations();
|
List<MPImage> segmentations = actualResult.confidenceMasks();
|
||||||
assertThat(segmentations.size()).isEqualTo(21);
|
assertThat(segmentations.size()).isEqualTo(21);
|
||||||
// Cat category index 8.
|
// Cat category index 8.
|
||||||
MPImage actualMaskBuffer = actualResult.segmentations().get(8);
|
MPImage actualMaskBuffer = segmentations.get(8);
|
||||||
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
|
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
|
||||||
verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
|
verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
|
||||||
}
|
}
|
||||||
|
@ -309,12 +301,11 @@ public class ImageSegmenterTest {
|
||||||
ImageSegmenterOptions options =
|
ImageSegmenterOptions options =
|
||||||
ImageSegmenterOptions.builder()
|
ImageSegmenterOptions.builder()
|
||||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||||
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
|
|
||||||
.setRunningMode(RunningMode.IMAGE)
|
.setRunningMode(RunningMode.IMAGE)
|
||||||
.setResultListener(
|
.setResultListener(
|
||||||
(segmenterResult, inputImage) -> {
|
(segmenterResult, inputImage) -> {
|
||||||
verifyConfidenceMask(
|
verifyConfidenceMask(
|
||||||
segmenterResult.segmentations().get(8),
|
segmenterResult.confidenceMasks().get(8),
|
||||||
expectedResult,
|
expectedResult,
|
||||||
GOLDEN_MASK_SIMILARITY);
|
GOLDEN_MASK_SIMILARITY);
|
||||||
})
|
})
|
||||||
|
@ -331,7 +322,6 @@ public class ImageSegmenterTest {
|
||||||
ImageSegmenterOptions options =
|
ImageSegmenterOptions options =
|
||||||
ImageSegmenterOptions.builder()
|
ImageSegmenterOptions.builder()
|
||||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||||
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
|
|
||||||
.setRunningMode(RunningMode.VIDEO)
|
.setRunningMode(RunningMode.VIDEO)
|
||||||
.build();
|
.build();
|
||||||
ImageSegmenter imageSegmenter =
|
ImageSegmenter imageSegmenter =
|
||||||
|
@ -341,10 +331,10 @@ public class ImageSegmenterTest {
|
||||||
ImageSegmenterResult actualResult =
|
ImageSegmenterResult actualResult =
|
||||||
imageSegmenter.segmentForVideo(
|
imageSegmenter.segmentForVideo(
|
||||||
getImageFromAsset(inputImageName), /* timestampsMs= */ i);
|
getImageFromAsset(inputImageName), /* timestampsMs= */ i);
|
||||||
List<MPImage> segmentations = actualResult.segmentations();
|
List<MPImage> segmentations = actualResult.confidenceMasks();
|
||||||
assertThat(segmentations.size()).isEqualTo(21);
|
assertThat(segmentations.size()).isEqualTo(21);
|
||||||
// Cat category index 8.
|
// Cat category index 8.
|
||||||
MPImage actualMaskBuffer = actualResult.segmentations().get(8);
|
MPImage actualMaskBuffer = segmentations.get(8);
|
||||||
verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
|
verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -357,12 +347,11 @@ public class ImageSegmenterTest {
|
||||||
ImageSegmenterOptions options =
|
ImageSegmenterOptions options =
|
||||||
ImageSegmenterOptions.builder()
|
ImageSegmenterOptions.builder()
|
||||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||||
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
|
|
||||||
.setRunningMode(RunningMode.VIDEO)
|
.setRunningMode(RunningMode.VIDEO)
|
||||||
.setResultListener(
|
.setResultListener(
|
||||||
(segmenterResult, inputImage) -> {
|
(segmenterResult, inputImage) -> {
|
||||||
verifyConfidenceMask(
|
verifyConfidenceMask(
|
||||||
segmenterResult.segmentations().get(8),
|
segmenterResult.confidenceMasks().get(8),
|
||||||
expectedResult,
|
expectedResult,
|
||||||
GOLDEN_MASK_SIMILARITY);
|
GOLDEN_MASK_SIMILARITY);
|
||||||
})
|
})
|
||||||
|
@ -384,12 +373,11 @@ public class ImageSegmenterTest {
|
||||||
ImageSegmenterOptions options =
|
ImageSegmenterOptions options =
|
||||||
ImageSegmenterOptions.builder()
|
ImageSegmenterOptions.builder()
|
||||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||||
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
|
|
||||||
.setRunningMode(RunningMode.LIVE_STREAM)
|
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||||
.setResultListener(
|
.setResultListener(
|
||||||
(segmenterResult, inputImage) -> {
|
(segmenterResult, inputImage) -> {
|
||||||
verifyConfidenceMask(
|
verifyConfidenceMask(
|
||||||
segmenterResult.segmentations().get(8),
|
segmenterResult.confidenceMasks().get(8),
|
||||||
expectedResult,
|
expectedResult,
|
||||||
GOLDEN_MASK_SIMILARITY);
|
GOLDEN_MASK_SIMILARITY);
|
||||||
})
|
})
|
||||||
|
@ -411,12 +399,11 @@ public class ImageSegmenterTest {
|
||||||
ImageSegmenterOptions options =
|
ImageSegmenterOptions options =
|
||||||
ImageSegmenterOptions.builder()
|
ImageSegmenterOptions.builder()
|
||||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||||
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
|
|
||||||
.setRunningMode(RunningMode.LIVE_STREAM)
|
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||||
.setResultListener(
|
.setResultListener(
|
||||||
(segmenterResult, inputImage) -> {
|
(segmenterResult, inputImage) -> {
|
||||||
verifyConfidenceMask(
|
verifyConfidenceMask(
|
||||||
segmenterResult.segmentations().get(8),
|
segmenterResult.confidenceMasks().get(8),
|
||||||
expectedResult,
|
expectedResult,
|
||||||
GOLDEN_MASK_SIMILARITY);
|
GOLDEN_MASK_SIMILARITY);
|
||||||
})
|
})
|
||||||
|
|
|
@ -60,7 +60,10 @@ public class InteractiveSegmenterTest {
|
||||||
ApplicationProvider.getApplicationContext(), options);
|
ApplicationProvider.getApplicationContext(), options);
|
||||||
MPImage image = getImageFromAsset(inputImageName);
|
MPImage image = getImageFromAsset(inputImageName);
|
||||||
ImageSegmenterResult actualResult = imageSegmenter.segment(image, roi);
|
ImageSegmenterResult actualResult = imageSegmenter.segment(image, roi);
|
||||||
List<MPImage> segmentations = actualResult.segmentations();
|
// TODO update to correct category mask output.
|
||||||
|
// After InteractiveSegmenter updated according to (b/276519300), update this to use
|
||||||
|
// categoryMask field instead of confidenceMasks.
|
||||||
|
List<MPImage> segmentations = actualResult.confidenceMasks();
|
||||||
assertThat(segmentations.size()).isEqualTo(1);
|
assertThat(segmentations.size()).isEqualTo(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -79,7 +82,7 @@ public class InteractiveSegmenterTest {
|
||||||
ApplicationProvider.getApplicationContext(), options);
|
ApplicationProvider.getApplicationContext(), options);
|
||||||
ImageSegmenterResult actualResult =
|
ImageSegmenterResult actualResult =
|
||||||
imageSegmenter.segment(getImageFromAsset(inputImageName), roi);
|
imageSegmenter.segment(getImageFromAsset(inputImageName), roi);
|
||||||
List<MPImage> segmentations = actualResult.segmentations();
|
List<MPImage> segmentations = actualResult.confidenceMasks();
|
||||||
assertThat(segmentations.size()).isEqualTo(2);
|
assertThat(segmentations.size()).isEqualTo(2);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user