Add getLabels to ImageSegmeter Java API
PiperOrigin-RevId: 516683339
This commit is contained in:
		
							parent
							
								
									51d9640d88
								
							
						
					
					
						commit
						141cf843ae
					
				|  | @ -22,6 +22,9 @@ import "mediapipe/framework/calculator.proto"; | ||||||
| import "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto"; | import "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto"; | ||||||
| import "mediapipe/util/label_map.proto"; | import "mediapipe/util/label_map.proto"; | ||||||
| 
 | 
 | ||||||
|  | option java_package = "com.google.mediapipe.tasks"; | ||||||
|  | option java_outer_classname = "TensorsToSegmentationCalculatorOptionsProto"; | ||||||
|  | 
 | ||||||
| message TensorsToSegmentationCalculatorOptions { | message TensorsToSegmentationCalculatorOptions { | ||||||
|   extend mediapipe.CalculatorOptions { |   extend mediapipe.CalculatorOptions { | ||||||
|     optional TensorsToSegmentationCalculatorOptions ext = 458105876; |     optional TensorsToSegmentationCalculatorOptions ext = 458105876; | ||||||
|  |  | ||||||
|  | @ -16,6 +16,7 @@ package com.google.mediapipe.tasks.core; | ||||||
| 
 | 
 | ||||||
| import android.content.Context; | import android.content.Context; | ||||||
| import android.util.Log; | import android.util.Log; | ||||||
|  | import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig; | ||||||
| import com.google.mediapipe.framework.AndroidAssetUtil; | import com.google.mediapipe.framework.AndroidAssetUtil; | ||||||
| import com.google.mediapipe.framework.AndroidPacketCreator; | import com.google.mediapipe.framework.AndroidPacketCreator; | ||||||
| import com.google.mediapipe.framework.Graph; | import com.google.mediapipe.framework.Graph; | ||||||
|  | @ -201,6 +202,10 @@ public class TaskRunner implements AutoCloseable { | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |   public CalculatorGraphConfig getCalculatorGraphConfig() { | ||||||
|  |     return graph.getCalculatorGraphConfig(); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|   private synchronized void addPackets(Map<String, Packet> inputs, long inputTimestamp) { |   private synchronized void addPackets(Map<String, Packet> inputs, long inputTimestamp) { | ||||||
|     if (!graphStarted.get()) { |     if (!graphStarted.get()) { | ||||||
|       reportError( |       reportError( | ||||||
|  |  | ||||||
|  | @ -41,6 +41,7 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [ | ||||||
|     "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite", |     "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite", | ||||||
|     "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite", |     "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite", | ||||||
|     "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite", |     "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite", | ||||||
|  |     "//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_java_proto_lite", | ||||||
|     "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite", |     "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite", | ||||||
|     "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite", |     "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite", | ||||||
|     "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite", |     "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite", | ||||||
|  |  | ||||||
|  | @ -197,6 +197,7 @@ android_library( | ||||||
|         "//mediapipe/java/com/google/mediapipe/framework:android_framework", |         "//mediapipe/java/com/google/mediapipe/framework:android_framework", | ||||||
|         "//mediapipe/java/com/google/mediapipe/framework/image", |         "//mediapipe/java/com/google/mediapipe/framework/image", | ||||||
|         "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", |         "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", | ||||||
|  |         "//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_java_proto_lite", | ||||||
|         "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite", |         "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite", | ||||||
|         "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite", |         "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite", | ||||||
|         "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", |         "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", | ||||||
|  |  | ||||||
|  | @ -17,6 +17,7 @@ package com.google.mediapipe.tasks.vision.imagesegmenter; | ||||||
| import android.content.Context; | import android.content.Context; | ||||||
| import com.google.auto.value.AutoValue; | import com.google.auto.value.AutoValue; | ||||||
| import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; | import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; | ||||||
|  | import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig; | ||||||
| import com.google.mediapipe.framework.AndroidPacketGetter; | import com.google.mediapipe.framework.AndroidPacketGetter; | ||||||
| import com.google.mediapipe.framework.MediaPipeException; | import com.google.mediapipe.framework.MediaPipeException; | ||||||
| import com.google.mediapipe.framework.Packet; | import com.google.mediapipe.framework.Packet; | ||||||
|  | @ -24,6 +25,7 @@ import com.google.mediapipe.framework.PacketGetter; | ||||||
| import com.google.mediapipe.framework.image.BitmapImageBuilder; | import com.google.mediapipe.framework.image.BitmapImageBuilder; | ||||||
| import com.google.mediapipe.framework.image.ByteBufferImageBuilder; | import com.google.mediapipe.framework.image.ByteBufferImageBuilder; | ||||||
| import com.google.mediapipe.framework.image.MPImage; | import com.google.mediapipe.framework.image.MPImage; | ||||||
|  | import com.google.mediapipe.tasks.TensorsToSegmentationCalculatorOptionsProto; | ||||||
| import com.google.mediapipe.tasks.core.BaseOptions; | import com.google.mediapipe.tasks.core.BaseOptions; | ||||||
| import com.google.mediapipe.tasks.core.ErrorListener; | import com.google.mediapipe.tasks.core.ErrorListener; | ||||||
| import com.google.mediapipe.tasks.core.OutputHandler; | import com.google.mediapipe.tasks.core.OutputHandler; | ||||||
|  | @ -88,8 +90,10 @@ public final class ImageSegmenter extends BaseVisionTaskApi { | ||||||
|   private static final int SEGMENTATION_OUT_STREAM_INDEX = 2; |   private static final int SEGMENTATION_OUT_STREAM_INDEX = 2; | ||||||
|   private static final String TASK_GRAPH_NAME = |   private static final String TASK_GRAPH_NAME = | ||||||
|       "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph"; |       "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph"; | ||||||
| 
 |   private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME = | ||||||
|  |       "mediapipe.tasks.TensorsToSegmentationCalculator"; | ||||||
|   private boolean hasResultListener = false; |   private boolean hasResultListener = false; | ||||||
|  |   private List<String> labels = new ArrayList<>(); | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|    * Creates an {@link ImageSegmenter} instance from an {@link ImageSegmenterOptions}. |    * Creates an {@link ImageSegmenter} instance from an {@link ImageSegmenterOptions}. | ||||||
|  | @ -190,6 +194,41 @@ public final class ImageSegmenter extends BaseVisionTaskApi { | ||||||
|       TaskRunner taskRunner, RunningMode runningMode, boolean hasResultListener) { |       TaskRunner taskRunner, RunningMode runningMode, boolean hasResultListener) { | ||||||
|     super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME); |     super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME); | ||||||
|     this.hasResultListener = hasResultListener; |     this.hasResultListener = hasResultListener; | ||||||
|  |     populateLabels(); | ||||||
|  |   } | ||||||
|  |   /** | ||||||
|  |    * Populate the labelmap in TensorsToSegmentationCalculator to labels field. | ||||||
|  |    * | ||||||
|  |    * @throws MediaPipeException if there is an error during finding TensorsToSegmentationCalculator. | ||||||
|  |    */ | ||||||
|  |   private void populateLabels() { | ||||||
|  |     CalculatorGraphConfig graphConfig = this.runner.getCalculatorGraphConfig(); | ||||||
|  | 
 | ||||||
|  |     boolean foundTensorsToSegmentation = false; | ||||||
|  |     for (CalculatorGraphConfig.Node node : graphConfig.getNodeList()) { | ||||||
|  |       if (node.getName().contains(TENSORS_TO_SEGMENTATION_CALCULATOR_NAME)) { | ||||||
|  |         if (foundTensorsToSegmentation) { | ||||||
|  |           throw new MediaPipeException( | ||||||
|  |               MediaPipeException.StatusCode.INTERNAL.ordinal(), | ||||||
|  |               "The graph has more than one mediapipe.tasks.TensorsToSegmentationCalculator."); | ||||||
|  |         } | ||||||
|  |         foundTensorsToSegmentation = true; | ||||||
|  |         TensorsToSegmentationCalculatorOptionsProto.TensorsToSegmentationCalculatorOptions options = | ||||||
|  |             node.getOptions() | ||||||
|  |                 .getExtension( | ||||||
|  |                     TensorsToSegmentationCalculatorOptionsProto | ||||||
|  |                         .TensorsToSegmentationCalculatorOptions.ext); | ||||||
|  |         for (int i = 0; i < options.getLabelItemsMap().size(); i++) { | ||||||
|  |           Long labelKey = Long.valueOf(i); | ||||||
|  |           if (!options.getLabelItemsMap().containsKey(labelKey)) { | ||||||
|  |             throw new MediaPipeException( | ||||||
|  |                 MediaPipeException.StatusCode.INTERNAL.ordinal(), | ||||||
|  |                 "The lablemap have no expected key: " + labelKey); | ||||||
|  |           } | ||||||
|  |           labels.add(options.getLabelItemsMap().get(labelKey).getName()); | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |     } | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|  | @ -473,6 +512,17 @@ public final class ImageSegmenter extends BaseVisionTaskApi { | ||||||
|     sendLiveStreamData(image, imageProcessingOptions, timestampMs); |     sendLiveStreamData(image, imageProcessingOptions, timestampMs); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |   /** | ||||||
|  |    * Get the category label list of the ImageSegmenter can recognize. For CATEGORY_MASK type, the | ||||||
|  |    * index in the category mask corresponds to the category in the label list. For CONFIDENCE_MASK | ||||||
|  |    * type, the output mask list at index corresponds to the category in the label list. | ||||||
|  |    * | ||||||
|  |    * <p>If there is no labelmap provided in the model file, empty label list is returned. | ||||||
|  |    */ | ||||||
|  |   List<String> getLabels() { | ||||||
|  |     return labels; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|   /** Options for setting up an {@link ImageSegmenter}. */ |   /** Options for setting up an {@link ImageSegmenter}. */ | ||||||
|   @AutoValue |   @AutoValue | ||||||
|   public abstract static class ImageSegmenterOptions extends TaskOptions { |   public abstract static class ImageSegmenterOptions extends TaskOptions { | ||||||
|  |  | ||||||
|  | @ -7,6 +7,7 @@ VERS_1.0 { | ||||||
|     Java_com_google_mediapipe_framework_Graph_nativeAddPacketToInputStream; |     Java_com_google_mediapipe_framework_Graph_nativeAddPacketToInputStream; | ||||||
|     Java_com_google_mediapipe_framework_Graph_nativeCloseAllPacketSources; |     Java_com_google_mediapipe_framework_Graph_nativeCloseAllPacketSources; | ||||||
|     Java_com_google_mediapipe_framework_Graph_nativeCreateGraph; |     Java_com_google_mediapipe_framework_Graph_nativeCreateGraph; | ||||||
|  |     Java_com_google_mediapipe_framework_Graph_nativeGetCalculatorGraphConfig; | ||||||
|     Java_com_google_mediapipe_framework_Graph_nativeLoadBinaryGraph*; |     Java_com_google_mediapipe_framework_Graph_nativeLoadBinaryGraph*; | ||||||
|     Java_com_google_mediapipe_framework_Graph_nativeMovePacketToInputStream; |     Java_com_google_mediapipe_framework_Graph_nativeMovePacketToInputStream; | ||||||
|     Java_com_google_mediapipe_framework_Graph_nativeReleaseGraph; |     Java_com_google_mediapipe_framework_Graph_nativeReleaseGraph; | ||||||
|  |  | ||||||
|  | @ -34,6 +34,7 @@ import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenter.ImageSegm | ||||||
| import java.io.InputStream; | import java.io.InputStream; | ||||||
| import java.nio.ByteBuffer; | import java.nio.ByteBuffer; | ||||||
| import java.nio.FloatBuffer; | import java.nio.FloatBuffer; | ||||||
|  | import java.util.Arrays; | ||||||
| import java.util.List; | import java.util.List; | ||||||
| import org.junit.Test; | import org.junit.Test; | ||||||
| import org.junit.runner.RunWith; | import org.junit.runner.RunWith; | ||||||
|  | @ -135,6 +136,45 @@ public class ImageSegmenterTest { | ||||||
|     //   MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); |     //   MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); | ||||||
|     //   verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); |     //   verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); | ||||||
|     // } |     // } | ||||||
|  | 
 | ||||||
|  |     @Test | ||||||
|  |     public void getLabels_success() throws Exception { | ||||||
|  |       final List<String> expectedLabels = | ||||||
|  |           Arrays.asList( | ||||||
|  |               "background", | ||||||
|  |               "aeroplane", | ||||||
|  |               "bicycle", | ||||||
|  |               "bird", | ||||||
|  |               "boat", | ||||||
|  |               "bottle", | ||||||
|  |               "bus", | ||||||
|  |               "car", | ||||||
|  |               "cat", | ||||||
|  |               "chair", | ||||||
|  |               "cow", | ||||||
|  |               "dining table", | ||||||
|  |               "dog", | ||||||
|  |               "horse", | ||||||
|  |               "motorbike", | ||||||
|  |               "person", | ||||||
|  |               "potted plant", | ||||||
|  |               "sheep", | ||||||
|  |               "sofa", | ||||||
|  |               "train", | ||||||
|  |               "tv"); | ||||||
|  |       ImageSegmenterOptions options = | ||||||
|  |           ImageSegmenterOptions.builder() | ||||||
|  |               .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) | ||||||
|  |               .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) | ||||||
|  |               .build(); | ||||||
|  |       ImageSegmenter imageSegmenter = | ||||||
|  |           ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); | ||||||
|  |       List<String> actualLabels = imageSegmenter.getLabels(); | ||||||
|  |       assertThat(actualLabels.size()).isEqualTo(expectedLabels.size()); | ||||||
|  |       for (int i = 0; i < actualLabels.size(); i++) { | ||||||
|  |         assertThat(actualLabels.get(i)).isEqualTo(expectedLabels.get(i)); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   @RunWith(AndroidJUnit4.class) |   @RunWith(AndroidJUnit4.class) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user