Add getLabels to ImageSegmeter Java API
PiperOrigin-RevId: 516683339
This commit is contained in:
parent
51d9640d88
commit
141cf843ae
|
@ -22,6 +22,9 @@ import "mediapipe/framework/calculator.proto";
|
||||||
import "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto";
|
import "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto";
|
||||||
import "mediapipe/util/label_map.proto";
|
import "mediapipe/util/label_map.proto";
|
||||||
|
|
||||||
|
option java_package = "com.google.mediapipe.tasks";
|
||||||
|
option java_outer_classname = "TensorsToSegmentationCalculatorOptionsProto";
|
||||||
|
|
||||||
message TensorsToSegmentationCalculatorOptions {
|
message TensorsToSegmentationCalculatorOptions {
|
||||||
extend mediapipe.CalculatorOptions {
|
extend mediapipe.CalculatorOptions {
|
||||||
optional TensorsToSegmentationCalculatorOptions ext = 458105876;
|
optional TensorsToSegmentationCalculatorOptions ext = 458105876;
|
||||||
|
|
|
@ -16,6 +16,7 @@ package com.google.mediapipe.tasks.core;
|
||||||
|
|
||||||
import android.content.Context;
|
import android.content.Context;
|
||||||
import android.util.Log;
|
import android.util.Log;
|
||||||
|
import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig;
|
||||||
import com.google.mediapipe.framework.AndroidAssetUtil;
|
import com.google.mediapipe.framework.AndroidAssetUtil;
|
||||||
import com.google.mediapipe.framework.AndroidPacketCreator;
|
import com.google.mediapipe.framework.AndroidPacketCreator;
|
||||||
import com.google.mediapipe.framework.Graph;
|
import com.google.mediapipe.framework.Graph;
|
||||||
|
@ -201,6 +202,10 @@ public class TaskRunner implements AutoCloseable {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public CalculatorGraphConfig getCalculatorGraphConfig() {
|
||||||
|
return graph.getCalculatorGraphConfig();
|
||||||
|
}
|
||||||
|
|
||||||
private synchronized void addPackets(Map<String, Packet> inputs, long inputTimestamp) {
|
private synchronized void addPackets(Map<String, Packet> inputs, long inputTimestamp) {
|
||||||
if (!graphStarted.get()) {
|
if (!graphStarted.get()) {
|
||||||
reportError(
|
reportError(
|
||||||
|
|
|
@ -41,6 +41,7 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [
|
||||||
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite",
|
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite",
|
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite",
|
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite",
|
||||||
|
"//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite",
|
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite",
|
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite",
|
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite",
|
||||||
|
|
|
@ -197,6 +197,7 @@ android_library(
|
||||||
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
||||||
"//mediapipe/java/com/google/mediapipe/framework/image",
|
"//mediapipe/java/com/google/mediapipe/framework/image",
|
||||||
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
||||||
|
"//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite",
|
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite",
|
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
||||||
|
|
|
@ -17,6 +17,7 @@ package com.google.mediapipe.tasks.vision.imagesegmenter;
|
||||||
import android.content.Context;
|
import android.content.Context;
|
||||||
import com.google.auto.value.AutoValue;
|
import com.google.auto.value.AutoValue;
|
||||||
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
|
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
|
||||||
|
import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig;
|
||||||
import com.google.mediapipe.framework.AndroidPacketGetter;
|
import com.google.mediapipe.framework.AndroidPacketGetter;
|
||||||
import com.google.mediapipe.framework.MediaPipeException;
|
import com.google.mediapipe.framework.MediaPipeException;
|
||||||
import com.google.mediapipe.framework.Packet;
|
import com.google.mediapipe.framework.Packet;
|
||||||
|
@ -24,6 +25,7 @@ import com.google.mediapipe.framework.PacketGetter;
|
||||||
import com.google.mediapipe.framework.image.BitmapImageBuilder;
|
import com.google.mediapipe.framework.image.BitmapImageBuilder;
|
||||||
import com.google.mediapipe.framework.image.ByteBufferImageBuilder;
|
import com.google.mediapipe.framework.image.ByteBufferImageBuilder;
|
||||||
import com.google.mediapipe.framework.image.MPImage;
|
import com.google.mediapipe.framework.image.MPImage;
|
||||||
|
import com.google.mediapipe.tasks.TensorsToSegmentationCalculatorOptionsProto;
|
||||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||||
import com.google.mediapipe.tasks.core.ErrorListener;
|
import com.google.mediapipe.tasks.core.ErrorListener;
|
||||||
import com.google.mediapipe.tasks.core.OutputHandler;
|
import com.google.mediapipe.tasks.core.OutputHandler;
|
||||||
|
@ -88,8 +90,10 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
||||||
private static final int SEGMENTATION_OUT_STREAM_INDEX = 2;
|
private static final int SEGMENTATION_OUT_STREAM_INDEX = 2;
|
||||||
private static final String TASK_GRAPH_NAME =
|
private static final String TASK_GRAPH_NAME =
|
||||||
"mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph";
|
"mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph";
|
||||||
|
private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
|
||||||
|
"mediapipe.tasks.TensorsToSegmentationCalculator";
|
||||||
private boolean hasResultListener = false;
|
private boolean hasResultListener = false;
|
||||||
|
private List<String> labels = new ArrayList<>();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates an {@link ImageSegmenter} instance from an {@link ImageSegmenterOptions}.
|
* Creates an {@link ImageSegmenter} instance from an {@link ImageSegmenterOptions}.
|
||||||
|
@ -190,6 +194,41 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
||||||
TaskRunner taskRunner, RunningMode runningMode, boolean hasResultListener) {
|
TaskRunner taskRunner, RunningMode runningMode, boolean hasResultListener) {
|
||||||
super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME);
|
super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME);
|
||||||
this.hasResultListener = hasResultListener;
|
this.hasResultListener = hasResultListener;
|
||||||
|
populateLabels();
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* Populate the labelmap in TensorsToSegmentationCalculator to labels field.
|
||||||
|
*
|
||||||
|
* @throws MediaPipeException if there is an error during finding TensorsToSegmentationCalculator.
|
||||||
|
*/
|
||||||
|
private void populateLabels() {
|
||||||
|
CalculatorGraphConfig graphConfig = this.runner.getCalculatorGraphConfig();
|
||||||
|
|
||||||
|
boolean foundTensorsToSegmentation = false;
|
||||||
|
for (CalculatorGraphConfig.Node node : graphConfig.getNodeList()) {
|
||||||
|
if (node.getName().contains(TENSORS_TO_SEGMENTATION_CALCULATOR_NAME)) {
|
||||||
|
if (foundTensorsToSegmentation) {
|
||||||
|
throw new MediaPipeException(
|
||||||
|
MediaPipeException.StatusCode.INTERNAL.ordinal(),
|
||||||
|
"The graph has more than one mediapipe.tasks.TensorsToSegmentationCalculator.");
|
||||||
|
}
|
||||||
|
foundTensorsToSegmentation = true;
|
||||||
|
TensorsToSegmentationCalculatorOptionsProto.TensorsToSegmentationCalculatorOptions options =
|
||||||
|
node.getOptions()
|
||||||
|
.getExtension(
|
||||||
|
TensorsToSegmentationCalculatorOptionsProto
|
||||||
|
.TensorsToSegmentationCalculatorOptions.ext);
|
||||||
|
for (int i = 0; i < options.getLabelItemsMap().size(); i++) {
|
||||||
|
Long labelKey = Long.valueOf(i);
|
||||||
|
if (!options.getLabelItemsMap().containsKey(labelKey)) {
|
||||||
|
throw new MediaPipeException(
|
||||||
|
MediaPipeException.StatusCode.INTERNAL.ordinal(),
|
||||||
|
"The lablemap have no expected key: " + labelKey);
|
||||||
|
}
|
||||||
|
labels.add(options.getLabelItemsMap().get(labelKey).getName());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -473,6 +512,17 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
||||||
sendLiveStreamData(image, imageProcessingOptions, timestampMs);
|
sendLiveStreamData(image, imageProcessingOptions, timestampMs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the category label list of the ImageSegmenter can recognize. For CATEGORY_MASK type, the
|
||||||
|
* index in the category mask corresponds to the category in the label list. For CONFIDENCE_MASK
|
||||||
|
* type, the output mask list at index corresponds to the category in the label list.
|
||||||
|
*
|
||||||
|
* <p>If there is no labelmap provided in the model file, empty label list is returned.
|
||||||
|
*/
|
||||||
|
List<String> getLabels() {
|
||||||
|
return labels;
|
||||||
|
}
|
||||||
|
|
||||||
/** Options for setting up an {@link ImageSegmenter}. */
|
/** Options for setting up an {@link ImageSegmenter}. */
|
||||||
@AutoValue
|
@AutoValue
|
||||||
public abstract static class ImageSegmenterOptions extends TaskOptions {
|
public abstract static class ImageSegmenterOptions extends TaskOptions {
|
||||||
|
|
|
@ -7,6 +7,7 @@ VERS_1.0 {
|
||||||
Java_com_google_mediapipe_framework_Graph_nativeAddPacketToInputStream;
|
Java_com_google_mediapipe_framework_Graph_nativeAddPacketToInputStream;
|
||||||
Java_com_google_mediapipe_framework_Graph_nativeCloseAllPacketSources;
|
Java_com_google_mediapipe_framework_Graph_nativeCloseAllPacketSources;
|
||||||
Java_com_google_mediapipe_framework_Graph_nativeCreateGraph;
|
Java_com_google_mediapipe_framework_Graph_nativeCreateGraph;
|
||||||
|
Java_com_google_mediapipe_framework_Graph_nativeGetCalculatorGraphConfig;
|
||||||
Java_com_google_mediapipe_framework_Graph_nativeLoadBinaryGraph*;
|
Java_com_google_mediapipe_framework_Graph_nativeLoadBinaryGraph*;
|
||||||
Java_com_google_mediapipe_framework_Graph_nativeMovePacketToInputStream;
|
Java_com_google_mediapipe_framework_Graph_nativeMovePacketToInputStream;
|
||||||
Java_com_google_mediapipe_framework_Graph_nativeReleaseGraph;
|
Java_com_google_mediapipe_framework_Graph_nativeReleaseGraph;
|
||||||
|
|
|
@ -34,6 +34,7 @@ import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenter.ImageSegm
|
||||||
import java.io.InputStream;
|
import java.io.InputStream;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
import java.nio.FloatBuffer;
|
import java.nio.FloatBuffer;
|
||||||
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.junit.runner.RunWith;
|
import org.junit.runner.RunWith;
|
||||||
|
@ -135,6 +136,45 @@ public class ImageSegmenterTest {
|
||||||
// MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
|
// MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
|
||||||
// verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
|
// verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
|
||||||
// }
|
// }
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void getLabels_success() throws Exception {
|
||||||
|
final List<String> expectedLabels =
|
||||||
|
Arrays.asList(
|
||||||
|
"background",
|
||||||
|
"aeroplane",
|
||||||
|
"bicycle",
|
||||||
|
"bird",
|
||||||
|
"boat",
|
||||||
|
"bottle",
|
||||||
|
"bus",
|
||||||
|
"car",
|
||||||
|
"cat",
|
||||||
|
"chair",
|
||||||
|
"cow",
|
||||||
|
"dining table",
|
||||||
|
"dog",
|
||||||
|
"horse",
|
||||||
|
"motorbike",
|
||||||
|
"person",
|
||||||
|
"potted plant",
|
||||||
|
"sheep",
|
||||||
|
"sofa",
|
||||||
|
"train",
|
||||||
|
"tv");
|
||||||
|
ImageSegmenterOptions options =
|
||||||
|
ImageSegmenterOptions.builder()
|
||||||
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||||
|
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
|
||||||
|
.build();
|
||||||
|
ImageSegmenter imageSegmenter =
|
||||||
|
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
|
List<String> actualLabels = imageSegmenter.getLabels();
|
||||||
|
assertThat(actualLabels.size()).isEqualTo(expectedLabels.size());
|
||||||
|
for (int i = 0; i < actualLabels.size(); i++) {
|
||||||
|
assertThat(actualLabels.get(i)).isEqualTo(expectedLabels.get(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@RunWith(AndroidJUnit4.class)
|
@RunWith(AndroidJUnit4.class)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user