Add getLabels to ImageSegmenter Java API

PiperOrigin-RevId: 516683339
This commit is contained in:
MediaPipe Team 2023-03-14 18:01:49 -07:00 committed by Copybara-Service
parent 51d9640d88
commit 141cf843ae
7 changed files with 102 additions and 1 deletions

View File

@ -22,6 +22,9 @@ import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto"; import "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto";
import "mediapipe/util/label_map.proto"; import "mediapipe/util/label_map.proto";
option java_package = "com.google.mediapipe.tasks";
option java_outer_classname = "TensorsToSegmentationCalculatorOptionsProto";
message TensorsToSegmentationCalculatorOptions { message TensorsToSegmentationCalculatorOptions {
extend mediapipe.CalculatorOptions { extend mediapipe.CalculatorOptions {
optional TensorsToSegmentationCalculatorOptions ext = 458105876; optional TensorsToSegmentationCalculatorOptions ext = 458105876;

View File

@ -16,6 +16,7 @@ package com.google.mediapipe.tasks.core;
import android.content.Context; import android.content.Context;
import android.util.Log; import android.util.Log;
import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig;
import com.google.mediapipe.framework.AndroidAssetUtil; import com.google.mediapipe.framework.AndroidAssetUtil;
import com.google.mediapipe.framework.AndroidPacketCreator; import com.google.mediapipe.framework.AndroidPacketCreator;
import com.google.mediapipe.framework.Graph; import com.google.mediapipe.framework.Graph;
@ -201,6 +202,10 @@ public class TaskRunner implements AutoCloseable {
} }
} }
/**
 * Returns the {@link CalculatorGraphConfig} reported by the graph held by this runner.
 *
 * @return the value of {@code graph.getCalculatorGraphConfig()}
 */
public CalculatorGraphConfig getCalculatorGraphConfig() {
  return this.graph.getCalculatorGraphConfig();
}
private synchronized void addPackets(Map<String, Packet> inputs, long inputTimestamp) { private synchronized void addPackets(Map<String, Packet> inputs, long inputTimestamp) {
if (!graphStarted.get()) { if (!graphStarted.get()) {
reportError( reportError(

View File

@ -41,6 +41,7 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite",

View File

@ -197,6 +197,7 @@ android_library(
"//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/java/com/google/mediapipe/framework/image", "//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",

View File

@ -17,6 +17,7 @@ package com.google.mediapipe.tasks.vision.imagesegmenter;
import android.content.Context; import android.content.Context;
import com.google.auto.value.AutoValue; import com.google.auto.value.AutoValue;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig;
import com.google.mediapipe.framework.AndroidPacketGetter; import com.google.mediapipe.framework.AndroidPacketGetter;
import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet; import com.google.mediapipe.framework.Packet;
@ -24,6 +25,7 @@ import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.image.BitmapImageBuilder; import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.ByteBufferImageBuilder; import com.google.mediapipe.framework.image.ByteBufferImageBuilder;
import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.TensorsToSegmentationCalculatorOptionsProto;
import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.ErrorListener; import com.google.mediapipe.tasks.core.ErrorListener;
import com.google.mediapipe.tasks.core.OutputHandler; import com.google.mediapipe.tasks.core.OutputHandler;
@ -88,8 +90,10 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
private static final int SEGMENTATION_OUT_STREAM_INDEX = 2; private static final int SEGMENTATION_OUT_STREAM_INDEX = 2;
private static final String TASK_GRAPH_NAME = private static final String TASK_GRAPH_NAME =
"mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph"; "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph";
private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
"mediapipe.tasks.TensorsToSegmentationCalculator";
private boolean hasResultListener = false; private boolean hasResultListener = false;
private List<String> labels = new ArrayList<>();
/** /**
* Creates an {@link ImageSegmenter} instance from an {@link ImageSegmenterOptions}. * Creates an {@link ImageSegmenter} instance from an {@link ImageSegmenterOptions}.
@ -190,6 +194,41 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
TaskRunner taskRunner, RunningMode runningMode, boolean hasResultListener) { TaskRunner taskRunner, RunningMode runningMode, boolean hasResultListener) {
super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME); super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME);
this.hasResultListener = hasResultListener; this.hasResultListener = hasResultListener;
populateLabels();
}
/**
* Populate the labelmap in TensorsToSegmentationCalculator to labels field.
*
* @throws MediaPipeException if there is an error during finding TensorsToSegmentationCalculator.
*/
private void populateLabels() {
CalculatorGraphConfig graphConfig = this.runner.getCalculatorGraphConfig();
boolean foundTensorsToSegmentation = false;
for (CalculatorGraphConfig.Node node : graphConfig.getNodeList()) {
if (node.getName().contains(TENSORS_TO_SEGMENTATION_CALCULATOR_NAME)) {
if (foundTensorsToSegmentation) {
throw new MediaPipeException(
MediaPipeException.StatusCode.INTERNAL.ordinal(),
"The graph has more than one mediapipe.tasks.TensorsToSegmentationCalculator.");
}
foundTensorsToSegmentation = true;
TensorsToSegmentationCalculatorOptionsProto.TensorsToSegmentationCalculatorOptions options =
node.getOptions()
.getExtension(
TensorsToSegmentationCalculatorOptionsProto
.TensorsToSegmentationCalculatorOptions.ext);
for (int i = 0; i < options.getLabelItemsMap().size(); i++) {
Long labelKey = Long.valueOf(i);
if (!options.getLabelItemsMap().containsKey(labelKey)) {
throw new MediaPipeException(
MediaPipeException.StatusCode.INTERNAL.ordinal(),
"The lablemap have no expected key: " + labelKey);
}
labels.add(options.getLabelItemsMap().get(labelKey).getName());
}
}
}
} }
/** /**
@ -473,6 +512,17 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
sendLiveStreamData(image, imageProcessingOptions, timestampMs); sendLiveStreamData(image, imageProcessingOptions, timestampMs);
} }
/**
 * Returns the list of category labels this ImageSegmenter can recognize. For the CATEGORY_MASK
 * output type, each index in the category mask corresponds to the category at that index in the
 * label list. For the CONFIDENCE_MASK output type, the output mask at a given index corresponds
 * to the category at that index in the label list.
 *
 * <p>If the model file provides no labelmap, an empty list is returned.
 *
 * @return an unmodifiable view of the label list; attempts to modify it throw {@link
 *     UnsupportedOperationException}
 */
public List<String> getLabels() {
  // Hand out a read-only view so callers cannot mutate the segmenter's internal label list.
  return java.util.Collections.unmodifiableList(labels);
}
/** Options for setting up an {@link ImageSegmenter}. */ /** Options for setting up an {@link ImageSegmenter}. */
@AutoValue @AutoValue
public abstract static class ImageSegmenterOptions extends TaskOptions { public abstract static class ImageSegmenterOptions extends TaskOptions {

View File

@ -7,6 +7,7 @@ VERS_1.0 {
Java_com_google_mediapipe_framework_Graph_nativeAddPacketToInputStream; Java_com_google_mediapipe_framework_Graph_nativeAddPacketToInputStream;
Java_com_google_mediapipe_framework_Graph_nativeCloseAllPacketSources; Java_com_google_mediapipe_framework_Graph_nativeCloseAllPacketSources;
Java_com_google_mediapipe_framework_Graph_nativeCreateGraph; Java_com_google_mediapipe_framework_Graph_nativeCreateGraph;
Java_com_google_mediapipe_framework_Graph_nativeGetCalculatorGraphConfig;
Java_com_google_mediapipe_framework_Graph_nativeLoadBinaryGraph*; Java_com_google_mediapipe_framework_Graph_nativeLoadBinaryGraph*;
Java_com_google_mediapipe_framework_Graph_nativeMovePacketToInputStream; Java_com_google_mediapipe_framework_Graph_nativeMovePacketToInputStream;
Java_com_google_mediapipe_framework_Graph_nativeReleaseGraph; Java_com_google_mediapipe_framework_Graph_nativeReleaseGraph;

View File

@ -34,6 +34,7 @@ import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenter.ImageSegm
import java.io.InputStream; import java.io.InputStream;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.FloatBuffer; import java.nio.FloatBuffer;
import java.util.Arrays;
import java.util.List; import java.util.List;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
@ -135,6 +136,45 @@ public class ImageSegmenterTest {
// MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); // MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
// verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); // verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
// } // }
// Checks that a segmenter created from DEEPLAB_MODEL_FILE reports the expected labels in order.
@Test
public void getLabels_success() throws Exception {
  final List<String> expectedLabels =
      Arrays.asList(
          "background", "aeroplane", "bicycle", "bird", "boat", "bottle", "bus",
          "car", "cat", "chair", "cow", "dining table", "dog", "horse",
          "motorbike", "person", "potted plant", "sheep", "sofa", "train", "tv");

  BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build();
  ImageSegmenterOptions segmenterOptions =
      ImageSegmenterOptions.builder()
          .setBaseOptions(baseOptions)
          .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
          .build();
  ImageSegmenter segmenter =
      ImageSegmenter.createFromOptions(
          ApplicationProvider.getApplicationContext(), segmenterOptions);

  List<String> reportedLabels = segmenter.getLabels();

  // Same size and element-wise equal in order, i.e. the two lists are equal.
  assertThat(reportedLabels.size()).isEqualTo(expectedLabels.size());
  assertThat(reportedLabels).isEqualTo(expectedLabels);
}
} }
@RunWith(AndroidJUnit4.class) @RunWith(AndroidJUnit4.class)