Add interactive segmenter java API

PiperOrigin-RevId: 518303391
This commit is contained in:
MediaPipe Team 2023-03-21 09:55:59 -07:00 committed by Copybara-Service
parent 6e0542c16a
commit 2be66e8eb0
7 changed files with 739 additions and 0 deletions

View File

@ -358,11 +358,21 @@ def mediapipe_java_proto_srcs(name = ""):
src_out = "com/google/mediapipe/formats/proto/RectProto.java", src_out = "com/google/mediapipe/formats/proto/RectProto.java",
)) ))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/util:color_java_proto_lite",
src_out = "com/google/mediapipe/util/proto/Color.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor( proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/util:label_map_java_proto_lite", target = "//mediapipe/util:label_map_java_proto_lite",
src_out = "com/google/mediapipe/util/proto/LabelMapProto.java", src_out = "com/google/mediapipe/util/proto/LabelMapProto.java",
)) ))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/util:render_data_java_proto_lite",
src_out = "com/google/mediapipe/util/proto/RenderData.java",
))
return proto_src_list return proto_src_list
def mediapipe_logging_java_proto_srcs(name = ""): def mediapipe_logging_java_proto_srcs(name = ""):

View File

@ -50,6 +50,7 @@ cc_binary(
"//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph", "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
"//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph", "//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph",
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph", "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
"//mediapipe/tasks/cc/vision/interactive_segmenter:interactive_segmenter_graph",
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
"//mediapipe/tasks/java:version_script.lds", "//mediapipe/tasks/java:version_script.lds",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
@ -206,6 +207,35 @@ android_library(
], ],
) )
android_library(
name = "interactivesegmenter",
srcs = [
"imagesegmenter/ImageSegmenterResult.java",
"interactivesegmenter/InteractiveSegmenter.java",
],
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
manifest = "interactivesegmenter/AndroidManifest.xml",
deps = [
":core",
"//mediapipe/framework:calculator_options_java_proto_lite",
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:normalizedkeypoint",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//mediapipe/util:color_java_proto_lite",
"//mediapipe/util:render_data_java_proto_lite",
"//third_party:autovalue",
"@maven//:androidx_annotation_annotation",
"@maven//:com_google_guava_guava",
],
)
android_library( android_library(
name = "imageembedder", name = "imageembedder",
srcs = [ srcs = [

View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.google.mediapipe.tasks.vision.interactivesegmenter">
<uses-sdk android:minSdkVersion="24"
android:targetSdkVersion="30" />
</manifest>

View File

@ -0,0 +1,556 @@
// Copyright 2023 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.interactivesegmenter;
import android.content.Context;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig;
import com.google.mediapipe.framework.AndroidPacketGetter;
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.ProtoUtil;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.ByteBufferImageBuilder;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.TensorsToSegmentationCalculatorOptionsProto;
import com.google.mediapipe.tasks.components.containers.NormalizedKeypoint;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.ErrorListener;
import com.google.mediapipe.tasks.core.OutputHandler;
import com.google.mediapipe.tasks.core.OutputHandler.ResultListener;
import com.google.mediapipe.tasks.core.TaskInfo;
import com.google.mediapipe.tasks.core.TaskOptions;
import com.google.mediapipe.tasks.core.TaskRunner;
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenterResult;
import com.google.mediapipe.tasks.vision.imagesegmenter.proto.ImageSegmenterGraphOptionsProto;
import com.google.mediapipe.tasks.vision.imagesegmenter.proto.SegmenterOptionsProto;
import com.google.mediapipe.util.proto.ColorProto.Color;
import com.google.mediapipe.util.proto.RenderDataProto.RenderAnnotation;
import com.google.mediapipe.util.proto.RenderDataProto.RenderData;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
/**
* Performs interactive segmentation on images.
*
* <p>Note that, in addition to the standard segmentation API {@link segment} that takes an input
* image and returns the outputs, but involves deep copy of the returns, InteractiveSegmenter also
* supports the callback API, {@link segmentWithResultListener}, which allows you to access the
* outputs through zero copy. Set {@link ResultListener} in {@link InteractiveSegmenterOptions}
* properly to use the callback API.
*
* <p>The API expects a TFLite model with,<a
* href="https://www.tensorflow.org/lite/convert/metadata">TFLite Model Metadata.</a>. The model
* expects input with 4 channels, where the first 3 channels represent RGB image, and the last
* channel represents the user's region of interest.
*
* <ul>
* <li>Input image {@link MPImage}
* <ul>
* <li>The image that image segmenter runs on.
* </ul>
* <li>Input roi {@link RegionOfInterest}
* <ul>
* <li>Region of interest based on user interaction.
* </ul>
* <li>Output ImageSegmenterResult {@link ImageSegmenterResult}
* <ul>
* <li>An ImageSegmenterResult containing segmented masks.
* </ul>
* </ul>
*/
public final class InteractiveSegmenter extends BaseVisionTaskApi {
private static final String TAG = InteractiveSegmenter.class.getSimpleName();
private static final String IMAGE_IN_STREAM_NAME = "image_in";
private static final String ROI_IN_STREAM_NAME = "roi_in";
private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in";
private static final List<String> INPUT_STREAMS =
Collections.unmodifiableList(
Arrays.asList(
"IMAGE:" + IMAGE_IN_STREAM_NAME,
"ROI:" + ROI_IN_STREAM_NAME,
"NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
private static final List<String> OUTPUT_STREAMS =
Collections.unmodifiableList(
Arrays.asList(
"GROUPED_SEGMENTATION:segmented_mask_out",
"IMAGE:image_out",
"SEGMENTATION:0:segmentation"));
private static final int GROUPED_SEGMENTATION_OUT_STREAM_INDEX = 0;
private static final int IMAGE_OUT_STREAM_INDEX = 1;
private static final int SEGMENTATION_OUT_STREAM_INDEX = 2;
private static final String TASK_GRAPH_NAME =
"mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph";
private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
"mediapipe.tasks.TensorsToSegmentationCalculator";
private boolean hasResultListener = false;
private List<String> labels = new ArrayList<>();
static {
ProtoUtil.registerTypeName(RenderData.class, "mediapipe.RenderData");
}
/**
* Creates an {@link InteractiveSegmenter} instance from an {@link InteractiveSegmenterOptions}.
*
* @param context an Android {@link Context}.
* @param segmenterOptions an {@link InteractiveSegmenterOptions} instance.
* @throws MediaPipeException if there is an error during {@link InteractiveSegmenter} creation.
*/
public static InteractiveSegmenter createFromOptions(
Context context, InteractiveSegmenterOptions segmenterOptions) {
// TODO: Consolidate OutputHandler and TaskRunner.
OutputHandler<ImageSegmenterResult, MPImage> handler = new OutputHandler<>();
handler.setOutputPacketConverter(
new OutputHandler.OutputPacketConverter<ImageSegmenterResult, MPImage>() {
@Override
public ImageSegmenterResult convertToTaskResult(List<Packet> packets)
throws MediaPipeException {
if (packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).isEmpty()) {
return ImageSegmenterResult.create(
new ArrayList<>(),
packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).getTimestamp());
}
List<MPImage> segmentedMasks = new ArrayList<>();
int width = PacketGetter.getImageWidth(packets.get(SEGMENTATION_OUT_STREAM_INDEX));
int height = PacketGetter.getImageHeight(packets.get(SEGMENTATION_OUT_STREAM_INDEX));
int imageFormat =
segmenterOptions.outputType()
== InteractiveSegmenterOptions.OutputType.CONFIDENCE_MASK
? MPImage.IMAGE_FORMAT_VEC32F1
: MPImage.IMAGE_FORMAT_ALPHA;
int imageListSize =
PacketGetter.getImageListSize(packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX));
ByteBuffer[] buffersArray = new ByteBuffer[imageListSize];
// If resultListener is not provided, the resulted MPImage is deep copied from mediapipe
// graph. If provided, the result MPImage is wrapping the mediapipe packet memory.
if (!segmenterOptions.resultListener().isPresent()) {
for (int i = 0; i < imageListSize; i++) {
buffersArray[i] =
ByteBuffer.allocateDirect(
width * height * (imageFormat == MPImage.IMAGE_FORMAT_VEC32F1 ? 4 : 1));
}
}
if (!PacketGetter.getImageList(
packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX),
buffersArray,
!segmenterOptions.resultListener().isPresent())) {
throw new MediaPipeException(
MediaPipeException.StatusCode.INTERNAL.ordinal(),
"There is an error getting segmented masks. It usually results from incorrect"
+ " options of unsupported OutputType of given model.");
}
for (ByteBuffer buffer : buffersArray) {
ByteBufferImageBuilder builder =
new ByteBufferImageBuilder(buffer, width, height, imageFormat);
segmentedMasks.add(builder.build());
}
return ImageSegmenterResult.create(
segmentedMasks,
BaseVisionTaskApi.generateResultTimestampMs(
RunningMode.IMAGE, packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX)));
}
@Override
public MPImage convertToTaskInput(List<Packet> packets) {
return new BitmapImageBuilder(
AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
.build();
}
});
segmenterOptions.resultListener().ifPresent(handler::setResultListener);
segmenterOptions.errorListener().ifPresent(handler::setErrorListener);
TaskRunner runner =
TaskRunner.create(
context,
TaskInfo.<InteractiveSegmenterOptions>builder()
.setTaskName(InteractiveSegmenter.class.getSimpleName())
.setTaskRunningModeName(RunningMode.IMAGE.name())
.setTaskGraphName(TASK_GRAPH_NAME)
.setInputStreams(INPUT_STREAMS)
.setOutputStreams(OUTPUT_STREAMS)
.setTaskOptions(segmenterOptions)
.setEnableFlowLimiting(false)
.build(),
handler);
return new InteractiveSegmenter(runner, segmenterOptions.resultListener().isPresent());
}
/**
* Constructor to initialize an {@link InteractiveSegmenter} from a {@link TaskRunner}.
*
* @param taskRunner a {@link TaskRunner}.
*/
private InteractiveSegmenter(TaskRunner taskRunner, boolean hasResultListener) {
super(taskRunner, RunningMode.IMAGE, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME);
this.hasResultListener = hasResultListener;
populateLabels();
}
/**
* Populate the labelmap in TensorsToSegmentationCalculator to labels field.
*
* @throws MediaPipeException if there is an error during finding TensorsToSegmentationCalculator.
*/
private void populateLabels() {
CalculatorGraphConfig graphConfig = this.runner.getCalculatorGraphConfig();
boolean foundTensorsToSegmentation = false;
for (CalculatorGraphConfig.Node node : graphConfig.getNodeList()) {
if (node.getName().contains(TENSORS_TO_SEGMENTATION_CALCULATOR_NAME)) {
if (foundTensorsToSegmentation) {
throw new MediaPipeException(
MediaPipeException.StatusCode.INTERNAL.ordinal(),
"The graph has more than one mediapipe.tasks.TensorsToSegmentationCalculator.");
}
foundTensorsToSegmentation = true;
TensorsToSegmentationCalculatorOptionsProto.TensorsToSegmentationCalculatorOptions options =
node.getOptions()
.getExtension(
TensorsToSegmentationCalculatorOptionsProto
.TensorsToSegmentationCalculatorOptions.ext);
for (int i = 0; i < options.getLabelItemsMap().size(); i++) {
Long labelKey = Long.valueOf(i);
if (!options.getLabelItemsMap().containsKey(labelKey)) {
throw new MediaPipeException(
MediaPipeException.StatusCode.INTERNAL.ordinal(),
"The lablemap have no expected key: " + labelKey);
}
labels.add(options.getLabelItemsMap().get(labelKey).getName());
}
}
}
}
/**
* Performs segmentation on the provided single image with default image processing options, given
* user's region-of-interest, i.e. without any rotation applied. TODO update java doc
* for input image format.
*
* <p>Users can represent user interaction through {@link RegionOfInterest}, which gives a hint to
* perform segmentation focusing on the given region of interest.
*
* <p>{@link InteractiveSegmenter} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @param roi a {@link RegionOfInterest} object to represent user interaction.
* @throws MediaPipeException if there is an internal error. Or if {@link InteractiveSegmenter} is
* created with a {@link ResultListener}.
*/
public ImageSegmenterResult segment(MPImage image, RegionOfInterest roi) {
return segment(image, roi, ImageProcessingOptions.builder().build());
}
/**
* Performs segmentation on the provided single image, given user's region-of-interest.
* TODO update java doc for input image format.
*
* <p>Users can represent user interaction through {@link RegionOfInterest}, which gives a hint to
* perform segmentation focusing on the given region of interest.
*
* <p>{@link InteractiveSegmenter} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @param roi a {@link RegionOfInterest} object to represent user interaction.
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
* input image before running inference. Note that region-of-interest is <b>not</b> supported
* by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
* this method throwing an IllegalArgumentException.
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
* region-of-interest.
* @throws MediaPipeException if there is an internal error. Or if {@link InteractiveSegmenter} is
* created with a {@link ResultListener}.
*/
public ImageSegmenterResult segment(
MPImage image, RegionOfInterest roi, ImageProcessingOptions imageProcessingOptions) {
if (hasResultListener) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"ResultListener is provided in the InteractiveSegmenterOptions, but this method will"
+ " return an ImageSegmentationResult.");
}
validateImageProcessingOptions(imageProcessingOptions);
return processImageWithRoi(image, roi, imageProcessingOptions);
}
/**
* Performs segmentation on the provided single image with default image processing options, given
* user's region-of-interest, i.e. without any rotation applied, and provides zero-copied results
* via {@link ResultListener} in {@link InteractiveSegmenterOptions}.
*
* <p>TODO update java doc for input image format.
*
* <p>Users can represent user interaction through {@link RegionOfInterest}, which gives a hint to
* perform segmentation focusing on the given region of interest.
*
* <p>{@link InteractiveSegmenter} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @param roi a {@link RegionOfInterest} object to represent user interaction.
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
* region-of-interest.
* @throws MediaPipeException if there is an internal error. Or if {@link InteractiveSegmenter} is
* not created wtih {@link ResultListener} set in {@link InteractiveSegmenterOptions}.
*/
public void segmentWithResultListener(MPImage image, RegionOfInterest roi) {
segmentWithResultListener(image, roi, ImageProcessingOptions.builder().build());
}
/**
* Performs segmentation on the provided single image given user's region-of-interest, and
* provides zero-copied results via {@link ResultListener} in {@link InteractiveSegmenterOptions}.
*
* <p>TODO update java doc for input image format.
*
* <p>Users can represent user interaction through {@link RegionOfInterest}, which gives a hint to
* perform segmentation focusing on the given region of interest.
*
* <p>{@link InteractiveSegmenter} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @param roi a {@link RegionOfInterest} object to represent user interaction.
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
* input image before running inference. Note that region-of-interest is <b>not</b> supported
* by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
* this method throwing an IllegalArgumentException.
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
* region-of-interest.
* @throws MediaPipeException if there is an internal error. Or if {@link InteractiveSegmenter} is
* not created wtih {@link ResultListener} set in {@link InteractiveSegmenterOptions}.
*/
public void segmentWithResultListener(
MPImage image, RegionOfInterest roi, ImageProcessingOptions imageProcessingOptions) {
if (!hasResultListener) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"ResultListener is not set in the InteractiveSegmenterOptions, but this method expects a"
+ " ResultListener to process ImageSegmentationResult.");
}
validateImageProcessingOptions(imageProcessingOptions);
ImageSegmenterResult unused = processImageWithRoi(image, roi, imageProcessingOptions);
}
/**
* Get the category label list of the ImageSegmenter can recognize. For CATEGORY_MASK type, the
* index in the category mask corresponds to the category in the label list. For CONFIDENCE_MASK
* type, the output mask list at index corresponds to the category in the label list.
*
* <p>If there is no labelmap provided in the model file, empty label list is returned.
*/
List<String> getLabels() {
return labels;
}
/** Options for setting up an {@link InteractiveSegmenter}. */
@AutoValue
public abstract static class InteractiveSegmenterOptions extends TaskOptions {
/** Builder for {@link InteractiveSegmenterOptions}. */
@AutoValue.Builder
public abstract static class Builder {
/** Sets the base options for the image segmenter task. */
public abstract Builder setBaseOptions(BaseOptions value);
/** The output type from image segmenter. */
public abstract Builder setOutputType(OutputType value);
/**
* Sets an optional {@link ResultListener} to receive the segmentation results when the graph
* pipeline is done processing an image.
*/
public abstract Builder setResultListener(
ResultListener<ImageSegmenterResult, MPImage> value);
/** Sets an optional {@link ErrorListener}}. */
public abstract Builder setErrorListener(ErrorListener value);
abstract InteractiveSegmenterOptions autoBuild();
/** Builds the {@link InteractiveSegmenterOptions} instance. */
public final InteractiveSegmenterOptions build() {
return autoBuild();
}
}
abstract BaseOptions baseOptions();
abstract OutputType outputType();
abstract Optional<ResultListener<ImageSegmenterResult, MPImage>> resultListener();
abstract Optional<ErrorListener> errorListener();
/** The output type of segmentation results. */
public enum OutputType {
// Gives a single output mask where each pixel represents the class which
// the pixel in the original image was predicted to belong to.
CATEGORY_MASK,
// Gives a list of output masks where, for each mask, each pixel represents
// the prediction confidence, usually in the [0, 1] range.
CONFIDENCE_MASK
}
public static Builder builder() {
return new AutoValue_InteractiveSegmenter_InteractiveSegmenterOptions.Builder()
.setOutputType(OutputType.CATEGORY_MASK);
}
/**
* Converts an {@link InteractiveSegmenterOptions} to a {@link CalculatorOptions} protobuf
* message.
*/
@Override
public CalculatorOptions convertToCalculatorOptionsProto() {
ImageSegmenterGraphOptionsProto.ImageSegmenterGraphOptions.Builder taskOptionsBuilder =
ImageSegmenterGraphOptionsProto.ImageSegmenterGraphOptions.newBuilder()
.setBaseOptions(
BaseOptionsProto.BaseOptions.newBuilder()
.setUseStreamMode(false)
.mergeFrom(convertBaseOptionsToProto(baseOptions()))
.build());
SegmenterOptionsProto.SegmenterOptions.Builder segmenterOptionsBuilder =
SegmenterOptionsProto.SegmenterOptions.newBuilder();
if (outputType() == OutputType.CONFIDENCE_MASK) {
segmenterOptionsBuilder.setOutputType(
SegmenterOptionsProto.SegmenterOptions.OutputType.CONFIDENCE_MASK);
} else if (outputType() == OutputType.CATEGORY_MASK) {
segmenterOptionsBuilder.setOutputType(
SegmenterOptionsProto.SegmenterOptions.OutputType.CATEGORY_MASK);
}
taskOptionsBuilder.setSegmenterOptions(segmenterOptionsBuilder);
return CalculatorOptions.newBuilder()
.setExtension(
ImageSegmenterGraphOptionsProto.ImageSegmenterGraphOptions.ext,
taskOptionsBuilder.build())
.build();
}
}
/**
* Validates that the provided {@link ImageProcessingOptions} doesn't contain a
* region-of-interest.
*/
private static void validateImageProcessingOptions(
ImageProcessingOptions imageProcessingOptions) {
if (imageProcessingOptions.regionOfInterest().isPresent()) {
throw new IllegalArgumentException(
"InteractiveSegmenter doesn't support region-of-interest.");
}
}
/** The Region-Of-Interest (ROI) to interact with. */
public static class RegionOfInterest {
private NormalizedKeypoint keypoint;
private RegionOfInterest() {}
/**
* Creates a {@link RegionOfInterest} instance representing a single point pointing to the
* object that the user wants to segment.
*/
public static RegionOfInterest create(NormalizedKeypoint keypoint) {
RegionOfInterest roi = new RegionOfInterest();
roi.keypoint = keypoint;
return roi;
}
}
/**
* Converts a {@link RegionOfInterest} instance into a {@link RenderData} protobuf message
*
* @param roi a {@link RegionOfInterest} object to represent user interaction.
* @throws IllegalArgumentException if {@link RegionOfInterest} does not represent a valid user
* interaction.
*/
private static RenderData convertToRenderData(RegionOfInterest roi) {
RenderData.Builder builder = RenderData.newBuilder();
if (roi.keypoint != null) {
return builder
.addRenderAnnotations(
RenderAnnotation.newBuilder()
.setColor(Color.newBuilder().setR(255))
.setPoint(
RenderAnnotation.Point.newBuilder()
.setX(roi.keypoint.x())
.setY(roi.keypoint.y())))
.build();
}
throw new IllegalArgumentException(
"RegionOfInterest does not include a valid user interaction");
}
/**
* A synchronous method to process single image inputs. The call blocks the current thread until a
* failure status or a successful result is returned.
*
* <p>This is almost the same as {@link BaseVisionTaskApi.processImageData} except accepting an
* additional {@link RegionOfInterest}.
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @param roi a {@link RegionOfInterest} object to represent user interaction.
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
* input image before running inference.
* @throws MediaPipeException if the task is not in the image mode.
*/
private ImageSegmenterResult processImageWithRoi(
MPImage image, RegionOfInterest roi, ImageProcessingOptions imageProcessingOptions) {
if (runningMode != RunningMode.IMAGE) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task is not initialized with the image mode. Current running mode:"
+ runningMode.name());
}
Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(IMAGE_IN_STREAM_NAME, runner.getPacketCreator().createImage(image));
RenderData renderData = convertToRenderData(roi);
inputPackets.put(ROI_IN_STREAM_NAME, runner.getPacketCreator().createProto(renderData));
inputPackets.put(
NORM_RECT_IN_STREAM_NAME,
runner.getPacketCreator().createProto(convertToNormalizedRect(imageProcessingOptions)));
return (ImageSegmenterResult) runner.process(inputPackets);
}
}

View File

@ -0,0 +1,24 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.google.mediapipe.tasks.vision.interactivesegmentertest"
android:versionCode="1"
android:versionName="1.0" >
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
<uses-sdk android:minSdkVersion="24"
android:targetSdkVersion="30" />
<application
android:label="interactivesegmentertest"
android:name="android.support.multidex.MultiDexApplication"
android:taskAffinity="">
<uses-library android:name="android.test.runner" />
</application>
<instrumentation
android:name="com.google.android.apps.common.testing.testrunner.GoogleInstrumentationTestRunner"
android:targetPackage="com.google.mediapipe.tasks.vision.interactivesegmentertest" />
</manifest>

View File

@ -0,0 +1,19 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
# TODO: Enable this in OSS

View File

@ -0,0 +1,92 @@
// Copyright 2023 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.interactivesegmenter;
import static com.google.common.truth.Truth.assertThat;
import android.content.res.AssetManager;
import android.graphics.BitmapFactory;
import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.components.containers.NormalizedKeypoint;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenterResult;
import com.google.mediapipe.tasks.vision.interactivesegmenter.InteractiveSegmenter.InteractiveSegmenterOptions;
import java.io.InputStream;
import java.util.List;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Suite;
import org.junit.runners.Suite.SuiteClasses;
/** Test for {@link InteractiveSegmenter}. */
@RunWith(Suite.class)
@SuiteClasses({
InteractiveSegmenterTest.General.class,
})
public class InteractiveSegmenterTest {
private static final String DEEPLAB_MODEL_FILE = "ptm_512_hdt_ptm_woid.tflite";
private static final String CATS_AND_DOGS_IMAGE = "cats_and_dogs.jpg";
private static final int MAGNIFICATION_FACTOR = 10;
@RunWith(AndroidJUnit4.class)
public static final class General extends InteractiveSegmenterTest {
@Test
public void segment_successWithCategoryMask() throws Exception {
final String inputImageName = CATS_AND_DOGS_IMAGE;
final InteractiveSegmenter.RegionOfInterest roi =
InteractiveSegmenter.RegionOfInterest.create(NormalizedKeypoint.create(0.25f, 0.9f));
InteractiveSegmenterOptions options =
InteractiveSegmenterOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
.setOutputType(InteractiveSegmenterOptions.OutputType.CATEGORY_MASK)
.build();
InteractiveSegmenter imageSegmenter =
InteractiveSegmenter.createFromOptions(
ApplicationProvider.getApplicationContext(), options);
MPImage image = getImageFromAsset(inputImageName);
ImageSegmenterResult actualResult = imageSegmenter.segment(image, roi);
List<MPImage> segmentations = actualResult.segmentations();
assertThat(segmentations.size()).isEqualTo(1);
}
@Test
public void segment_successWithConfidenceMask() throws Exception {
final String inputImageName = CATS_AND_DOGS_IMAGE;
final InteractiveSegmenter.RegionOfInterest roi =
InteractiveSegmenter.RegionOfInterest.create(NormalizedKeypoint.create(0.25f, 0.9f));
InteractiveSegmenterOptions options =
InteractiveSegmenterOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
.setOutputType(InteractiveSegmenterOptions.OutputType.CONFIDENCE_MASK)
.build();
InteractiveSegmenter imageSegmenter =
InteractiveSegmenter.createFromOptions(
ApplicationProvider.getApplicationContext(), options);
ImageSegmenterResult actualResult =
imageSegmenter.segment(getImageFromAsset(inputImageName), roi);
List<MPImage> segmentations = actualResult.segmentations();
assertThat(segmentations.size()).isEqualTo(2);
}
}
private static MPImage getImageFromAsset(String filePath) throws Exception {
AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
InputStream istr = assetManager.open(filePath);
return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build();
}
}