diff --git a/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.proto b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.proto
index b0fdfdd32..f267bf09b 100644
--- a/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.proto
+++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.proto
@@ -22,6 +22,9 @@ import "mediapipe/framework/calculator.proto";
 import "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto";
 import "mediapipe/util/label_map.proto";
 
+option java_package = "com.google.mediapipe.tasks";
+option java_outer_classname = "TensorsToSegmentationCalculatorOptionsProto";
+
 message TensorsToSegmentationCalculatorOptions {
   extend mediapipe.CalculatorOptions {
     optional TensorsToSegmentationCalculatorOptions ext = 458105876;
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java
index 1a128c538..11d385890 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java
@@ -16,6 +16,7 @@ package com.google.mediapipe.tasks.core;
 
 import android.content.Context;
 import android.util.Log;
+import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig;
 import com.google.mediapipe.framework.AndroidAssetUtil;
 import com.google.mediapipe.framework.AndroidPacketCreator;
 import com.google.mediapipe.framework.Graph;
@@ -201,6 +202,11 @@ public class TaskRunner implements AutoCloseable {
     }
   }
 
+  /** Returns the {@link CalculatorGraphConfig} reported by the underlying graph. */
+  public CalculatorGraphConfig getCalculatorGraphConfig() {
+    return graph.getCalculatorGraphConfig();
+  }
+
   private synchronized void addPackets(Map inputs, long inputTimestamp) {
     if (!graphStarted.get()) {
       reportError(
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl
index 5c5a154d8..d8a237e8d 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl
@@ -41,6 +41,7 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [
     "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite",
     "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite",
     "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite",
+    "//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_java_proto_lite",
     "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite",
     "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite",
     "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite",
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD
index bd57ffadb..a5b036924 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD
@@ -197,6 +197,7 @@ android_library(
         "//mediapipe/java/com/google/mediapipe/framework:android_framework",
         "//mediapipe/java/com/google/mediapipe/framework/image",
         "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
+        "//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_java_proto_lite",
         "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite",
         "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite",
         "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java
index 76b33fb97..299423003 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java
@@ -17,6 +17,7 @@ package com.google.mediapipe.tasks.vision.imagesegmenter;
 import android.content.Context;
 import com.google.auto.value.AutoValue;
 import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
+import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig;
 import com.google.mediapipe.framework.AndroidPacketGetter;
 import com.google.mediapipe.framework.MediaPipeException;
 import com.google.mediapipe.framework.Packet;
@@ -24,6 +25,7 @@ import com.google.mediapipe.framework.PacketGetter;
 import com.google.mediapipe.framework.image.BitmapImageBuilder;
 import com.google.mediapipe.framework.image.ByteBufferImageBuilder;
 import com.google.mediapipe.framework.image.MPImage;
+import com.google.mediapipe.tasks.TensorsToSegmentationCalculatorOptionsProto;
 import com.google.mediapipe.tasks.core.BaseOptions;
 import com.google.mediapipe.tasks.core.ErrorListener;
 import com.google.mediapipe.tasks.core.OutputHandler;
@@ -88,8 +90,10 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
   private static final int SEGMENTATION_OUT_STREAM_INDEX = 2;
   private static final String TASK_GRAPH_NAME =
       "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph";
-
+  private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
+      "mediapipe.tasks.TensorsToSegmentationCalculator";
   private boolean hasResultListener = false;
+  private final List<String> labels = new ArrayList<>();
 
   /**
    * Creates an {@link ImageSegmenter} instance from an {@link ImageSegmenterOptions}.
@@ -190,6 +194,44 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
       TaskRunner taskRunner, RunningMode runningMode, boolean hasResultListener) {
     super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME);
     this.hasResultListener = hasResultListener;
+    populateLabels();
+  }
+
+  /**
+   * Populates the labels field from the labelmap in the TensorsToSegmentationCalculator options.
+   *
+   * @throws MediaPipeException if the graph has more than one TensorsToSegmentationCalculator, or
+   *     if the labelmap keys are not the consecutive indices starting from 0.
+   */
+  private void populateLabels() {
+    CalculatorGraphConfig graphConfig = this.runner.getCalculatorGraphConfig();
+
+    boolean foundTensorsToSegmentation = false;
+    for (CalculatorGraphConfig.Node node : graphConfig.getNodeList()) {
+      if (node.getName().contains(TENSORS_TO_SEGMENTATION_CALCULATOR_NAME)) {
+        if (foundTensorsToSegmentation) {
+          throw new MediaPipeException(
+              MediaPipeException.StatusCode.INTERNAL.ordinal(),
+              "The graph has more than one mediapipe.tasks.TensorsToSegmentationCalculator.");
+        }
+        foundTensorsToSegmentation = true;
+        TensorsToSegmentationCalculatorOptionsProto.TensorsToSegmentationCalculatorOptions options =
+            node.getOptions()
+                .getExtension(
+                    TensorsToSegmentationCalculatorOptionsProto
+                        .TensorsToSegmentationCalculatorOptions.ext);
+        // Labels are appended in key order, so list position i holds the label for key i.
+        int labelCount = options.getLabelItemsCount();
+        for (long labelKey = 0; labelKey < labelCount; labelKey++) {
+          if (!options.containsLabelItems(labelKey)) {
+            throw new MediaPipeException(
+                MediaPipeException.StatusCode.INTERNAL.ordinal(),
+                "The labelmap has no expected key: " + labelKey);
+          }
+          labels.add(options.getLabelItemsOrThrow(labelKey).getName());
+        }
+      }
+    }
   }
 
   /**
@@ -473,6 +515,19 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
     sendLiveStreamData(image, imageProcessingOptions, timestampMs);
   }
 
+  /**
+   * Returns the category labels that the ImageSegmenter can recognize. For CATEGORY_MASK type, the
+   * index in the category mask corresponds to the category in the label list. For CONFIDENCE_MASK
+   * type, the output mask list at index corresponds to the category in the label list.
+   *
+   * <p>If there is no labelmap provided in the model file, an empty label list is returned.
+   *
+   * @return an unmodifiable view of the label list, ordered by label index.
+   */
+  public List<String> getLabels() {
+    return java.util.Collections.unmodifiableList(labels);
+  }
+
   /** Options for setting up an {@link ImageSegmenter}. */
   @AutoValue
   public abstract static class ImageSegmenterOptions extends TaskOptions {
diff --git a/mediapipe/tasks/java/version_script.lds b/mediapipe/tasks/java/version_script.lds
index 08577b101..13f36f21e 100644
--- a/mediapipe/tasks/java/version_script.lds
+++ b/mediapipe/tasks/java/version_script.lds
@@ -7,6 +7,7 @@ VERS_1.0 {
     Java_com_google_mediapipe_framework_Graph_nativeAddPacketToInputStream;
     Java_com_google_mediapipe_framework_Graph_nativeCloseAllPacketSources;
     Java_com_google_mediapipe_framework_Graph_nativeCreateGraph;
+    Java_com_google_mediapipe_framework_Graph_nativeGetCalculatorGraphConfig;
     Java_com_google_mediapipe_framework_Graph_nativeLoadBinaryGraph*;
     Java_com_google_mediapipe_framework_Graph_nativeMovePacketToInputStream;
     Java_com_google_mediapipe_framework_Graph_nativeReleaseGraph;
diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterTest.java
index 16f591c40..3b35c21bc 100644
--- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterTest.java
+++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterTest.java
@@ -34,6 +34,7 @@ import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenter.ImageSegm
 import java.io.InputStream;
 import java.nio.ByteBuffer;
 import java.nio.FloatBuffer;
+import java.util.Arrays;
 import java.util.List;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -135,6 +136,45 @@ public class ImageSegmenterTest {
     //   MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
     //   verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
     // }
+
+    @Test
+    public void getLabels_success() throws Exception {
+      final List<String> expectedLabels =
+          Arrays.asList(
+              "background",
+              "aeroplane",
+              "bicycle",
+              "bird",
+              "boat",
+              "bottle",
+              "bus",
+              "car",
+              "cat",
+              "chair",
+              "cow",
+              "dining table",
+              "dog",
+              "horse",
+              "motorbike",
+              "person",
+              "potted plant",
+              "sheep",
+              "sofa",
+              "train",
+              "tv");
+      ImageSegmenterOptions options =
+          ImageSegmenterOptions.builder()
+              .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
+              .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
+              .build();
+      ImageSegmenter imageSegmenter =
+          ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+      List<String> actualLabels = imageSegmenter.getLabels();
+      assertThat(actualLabels.size()).isEqualTo(expectedLabels.size());
+      for (int i = 0; i < actualLabels.size(); i++) {
+        assertThat(actualLabels.get(i)).isEqualTo(expectedLabels.get(i));
+      }
+    }
   }
 
   @RunWith(AndroidJUnit4.class)