Add getLabels to ImageSegmeter Java API
PiperOrigin-RevId: 516683339
This commit is contained in:
parent
51d9640d88
commit
141cf843ae
|
@ -22,6 +22,9 @@ import "mediapipe/framework/calculator.proto";
|
|||
import "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto";
|
||||
import "mediapipe/util/label_map.proto";
|
||||
|
||||
option java_package = "com.google.mediapipe.tasks";
|
||||
option java_outer_classname = "TensorsToSegmentationCalculatorOptionsProto";
|
||||
|
||||
message TensorsToSegmentationCalculatorOptions {
|
||||
extend mediapipe.CalculatorOptions {
|
||||
optional TensorsToSegmentationCalculatorOptions ext = 458105876;
|
||||
|
|
|
@ -16,6 +16,7 @@ package com.google.mediapipe.tasks.core;
|
|||
|
||||
import android.content.Context;
|
||||
import android.util.Log;
|
||||
import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig;
|
||||
import com.google.mediapipe.framework.AndroidAssetUtil;
|
||||
import com.google.mediapipe.framework.AndroidPacketCreator;
|
||||
import com.google.mediapipe.framework.Graph;
|
||||
|
@ -201,6 +202,10 @@ public class TaskRunner implements AutoCloseable {
|
|||
}
|
||||
}
|
||||
|
||||
public CalculatorGraphConfig getCalculatorGraphConfig() {
|
||||
return graph.getCalculatorGraphConfig();
|
||||
}
|
||||
|
||||
private synchronized void addPackets(Map<String, Packet> inputs, long inputTimestamp) {
|
||||
if (!graphStarted.get()) {
|
||||
reportError(
|
||||
|
|
|
@ -41,6 +41,7 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [
|
|||
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite",
|
||||
|
|
|
@ -197,6 +197,7 @@ android_library(
|
|||
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
||||
"//mediapipe/java/com/google/mediapipe/framework/image",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
||||
|
|
|
@ -17,6 +17,7 @@ package com.google.mediapipe.tasks.vision.imagesegmenter;
|
|||
import android.content.Context;
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
|
||||
import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig;
|
||||
import com.google.mediapipe.framework.AndroidPacketGetter;
|
||||
import com.google.mediapipe.framework.MediaPipeException;
|
||||
import com.google.mediapipe.framework.Packet;
|
||||
|
@ -24,6 +25,7 @@ import com.google.mediapipe.framework.PacketGetter;
|
|||
import com.google.mediapipe.framework.image.BitmapImageBuilder;
|
||||
import com.google.mediapipe.framework.image.ByteBufferImageBuilder;
|
||||
import com.google.mediapipe.framework.image.MPImage;
|
||||
import com.google.mediapipe.tasks.TensorsToSegmentationCalculatorOptionsProto;
|
||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||
import com.google.mediapipe.tasks.core.ErrorListener;
|
||||
import com.google.mediapipe.tasks.core.OutputHandler;
|
||||
|
@ -88,8 +90,10 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
|||
private static final int SEGMENTATION_OUT_STREAM_INDEX = 2;
|
||||
private static final String TASK_GRAPH_NAME =
|
||||
"mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph";
|
||||
|
||||
private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
|
||||
"mediapipe.tasks.TensorsToSegmentationCalculator";
|
||||
private boolean hasResultListener = false;
|
||||
private List<String> labels = new ArrayList<>();
|
||||
|
||||
/**
|
||||
* Creates an {@link ImageSegmenter} instance from an {@link ImageSegmenterOptions}.
|
||||
|
@ -190,6 +194,41 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
|||
TaskRunner taskRunner, RunningMode runningMode, boolean hasResultListener) {
|
||||
super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME);
|
||||
this.hasResultListener = hasResultListener;
|
||||
populateLabels();
|
||||
}
|
||||
/**
|
||||
* Populate the labelmap in TensorsToSegmentationCalculator to labels field.
|
||||
*
|
||||
* @throws MediaPipeException if there is an error during finding TensorsToSegmentationCalculator.
|
||||
*/
|
||||
private void populateLabels() {
|
||||
CalculatorGraphConfig graphConfig = this.runner.getCalculatorGraphConfig();
|
||||
|
||||
boolean foundTensorsToSegmentation = false;
|
||||
for (CalculatorGraphConfig.Node node : graphConfig.getNodeList()) {
|
||||
if (node.getName().contains(TENSORS_TO_SEGMENTATION_CALCULATOR_NAME)) {
|
||||
if (foundTensorsToSegmentation) {
|
||||
throw new MediaPipeException(
|
||||
MediaPipeException.StatusCode.INTERNAL.ordinal(),
|
||||
"The graph has more than one mediapipe.tasks.TensorsToSegmentationCalculator.");
|
||||
}
|
||||
foundTensorsToSegmentation = true;
|
||||
TensorsToSegmentationCalculatorOptionsProto.TensorsToSegmentationCalculatorOptions options =
|
||||
node.getOptions()
|
||||
.getExtension(
|
||||
TensorsToSegmentationCalculatorOptionsProto
|
||||
.TensorsToSegmentationCalculatorOptions.ext);
|
||||
for (int i = 0; i < options.getLabelItemsMap().size(); i++) {
|
||||
Long labelKey = Long.valueOf(i);
|
||||
if (!options.getLabelItemsMap().containsKey(labelKey)) {
|
||||
throw new MediaPipeException(
|
||||
MediaPipeException.StatusCode.INTERNAL.ordinal(),
|
||||
"The lablemap have no expected key: " + labelKey);
|
||||
}
|
||||
labels.add(options.getLabelItemsMap().get(labelKey).getName());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -473,6 +512,17 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
|
|||
sendLiveStreamData(image, imageProcessingOptions, timestampMs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the category label list of the ImageSegmenter can recognize. For CATEGORY_MASK type, the
|
||||
* index in the category mask corresponds to the category in the label list. For CONFIDENCE_MASK
|
||||
* type, the output mask list at index corresponds to the category in the label list.
|
||||
*
|
||||
* <p>If there is no labelmap provided in the model file, empty label list is returned.
|
||||
*/
|
||||
List<String> getLabels() {
|
||||
return labels;
|
||||
}
|
||||
|
||||
/** Options for setting up an {@link ImageSegmenter}. */
|
||||
@AutoValue
|
||||
public abstract static class ImageSegmenterOptions extends TaskOptions {
|
||||
|
|
|
@ -7,6 +7,7 @@ VERS_1.0 {
|
|||
Java_com_google_mediapipe_framework_Graph_nativeAddPacketToInputStream;
|
||||
Java_com_google_mediapipe_framework_Graph_nativeCloseAllPacketSources;
|
||||
Java_com_google_mediapipe_framework_Graph_nativeCreateGraph;
|
||||
Java_com_google_mediapipe_framework_Graph_nativeGetCalculatorGraphConfig;
|
||||
Java_com_google_mediapipe_framework_Graph_nativeLoadBinaryGraph*;
|
||||
Java_com_google_mediapipe_framework_Graph_nativeMovePacketToInputStream;
|
||||
Java_com_google_mediapipe_framework_Graph_nativeReleaseGraph;
|
||||
|
|
|
@ -34,6 +34,7 @@ import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenter.ImageSegm
|
|||
import java.io.InputStream;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.FloatBuffer;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
|
@ -135,6 +136,45 @@ public class ImageSegmenterTest {
|
|||
// MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
|
||||
// verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
|
||||
// }
|
||||
|
||||
@Test
|
||||
public void getLabels_success() throws Exception {
|
||||
final List<String> expectedLabels =
|
||||
Arrays.asList(
|
||||
"background",
|
||||
"aeroplane",
|
||||
"bicycle",
|
||||
"bird",
|
||||
"boat",
|
||||
"bottle",
|
||||
"bus",
|
||||
"car",
|
||||
"cat",
|
||||
"chair",
|
||||
"cow",
|
||||
"dining table",
|
||||
"dog",
|
||||
"horse",
|
||||
"motorbike",
|
||||
"person",
|
||||
"potted plant",
|
||||
"sheep",
|
||||
"sofa",
|
||||
"train",
|
||||
"tv");
|
||||
ImageSegmenterOptions options =
|
||||
ImageSegmenterOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
|
||||
.build();
|
||||
ImageSegmenter imageSegmenter =
|
||||
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
List<String> actualLabels = imageSegmenter.getLabels();
|
||||
assertThat(actualLabels.size()).isEqualTo(expectedLabels.size());
|
||||
for (int i = 0; i < actualLabels.size(); i++) {
|
||||
assertThat(actualLabels.get(i)).isEqualTo(expectedLabels.get(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@RunWith(AndroidJUnit4.class)
|
||||
|
|
Loading…
Reference in New Issue
Block a user