Add Pose Landmarker Java API
PiperOrigin-RevId: 524359521
This commit is contained in:
		
							parent
							
								
									3f1fc6f520
								
							
						
					
					
						commit
						dbeb5a8126
					
				|  | @ -54,6 +54,9 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [ | |||
|     "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite", | ||||
|     "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite", | ||||
|     "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_java_proto_lite", | ||||
|     "//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_java_proto_lite", | ||||
|     "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarker_graph_options_java_proto_lite", | ||||
|     "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_java_proto_lite", | ||||
| ] | ||||
| 
 | ||||
| _TEXT_TASKS_JAVA_PROTO_LITE_TARGETS = [ | ||||
|  |  | |||
|  | @ -54,6 +54,7 @@ cc_binary( | |||
|         "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph", | ||||
|         "//mediapipe/tasks/cc/vision/interactive_segmenter:interactive_segmenter_graph", | ||||
|         "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", | ||||
|         "//mediapipe/tasks/cc/vision/pose_landmarker:pose_landmarker_graph", | ||||
|         "//mediapipe/tasks/java:version_script.lds", | ||||
|         "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni", | ||||
|     ], | ||||
|  | @ -174,6 +175,37 @@ android_library( | |||
|     ], | ||||
| ) | ||||
| 
 | ||||
| android_library( | ||||
|     name = "poselandmarker", | ||||
|     srcs = [ | ||||
|         "poselandmarker/PoseLandmarker.java", | ||||
|         "poselandmarker/PoseLandmarkerResult.java", | ||||
|     ], | ||||
|     javacopts = [ | ||||
|         "-Xep:AndroidJdkLibsChecker:OFF", | ||||
|     ], | ||||
|     manifest = "poselandmarker/AndroidManifest.xml", | ||||
|     deps = [ | ||||
|         ":core", | ||||
|         "//mediapipe/framework:calculator_options_java_proto_lite", | ||||
|         "//mediapipe/framework/formats:classification_java_proto_lite", | ||||
|         "//mediapipe/framework/formats:landmark_java_proto_lite", | ||||
|         "//mediapipe/java/com/google/mediapipe/framework:android_framework", | ||||
|         "//mediapipe/java/com/google/mediapipe/framework/image", | ||||
|         "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", | ||||
|         "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", | ||||
|         "//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_java_proto_lite", | ||||
|         "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarker_graph_options_java_proto_lite", | ||||
|         "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_java_proto_lite", | ||||
|         "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark", | ||||
|         "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:normalized_landmark", | ||||
|         "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", | ||||
|         "//third_party:autovalue", | ||||
|         "@maven//:androidx_annotation_annotation", | ||||
|         "@maven//:com_google_guava_guava", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| android_library( | ||||
|     name = "handlandmarker", | ||||
|     srcs = [ | ||||
|  |  | |||
|  | @ -0,0 +1,8 @@ | |||
| <?xml version="1.0" encoding="utf-8"?> | ||||
| <manifest xmlns:android="http://schemas.android.com/apk/res/android" | ||||
|     package="com.google.mediapipe.tasks.vision.poselandmarker"> | ||||
| 
 | ||||
|     <uses-sdk android:minSdkVersion="24" | ||||
|         android:targetSdkVersion="30" /> | ||||
| 
 | ||||
| </manifest> | ||||
|  | @ -0,0 +1,557 @@ | |||
| // Copyright 2023 The MediaPipe Authors. All Rights Reserved. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| // you may not use this file except in compliance with the License. | ||||
| // You may obtain a copy of the License at | ||||
| // | ||||
| //      http://www.apache.org/licenses/LICENSE-2.0 | ||||
| // | ||||
| // Unless required by applicable law or agreed to in writing, software | ||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | ||||
| 
 | ||||
| package com.google.mediapipe.tasks.vision.poselandmarker; | ||||
| 
 | ||||
| import android.content.Context; | ||||
| import android.os.ParcelFileDescriptor; | ||||
| import com.google.auto.value.AutoValue; | ||||
| import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList; | ||||
| import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList; | ||||
| import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; | ||||
| import com.google.mediapipe.framework.AndroidPacketGetter; | ||||
| import com.google.mediapipe.framework.MediaPipeException; | ||||
| import com.google.mediapipe.framework.Packet; | ||||
| import com.google.mediapipe.framework.PacketGetter; | ||||
| import com.google.mediapipe.framework.image.BitmapImageBuilder; | ||||
| import com.google.mediapipe.framework.image.ByteBufferImageBuilder; | ||||
| import com.google.mediapipe.framework.image.MPImage; | ||||
| import com.google.mediapipe.tasks.core.BaseOptions; | ||||
| import com.google.mediapipe.tasks.core.ErrorListener; | ||||
| import com.google.mediapipe.tasks.core.OutputHandler; | ||||
| import com.google.mediapipe.tasks.core.OutputHandler.ResultListener; | ||||
| import com.google.mediapipe.tasks.core.TaskInfo; | ||||
| import com.google.mediapipe.tasks.core.TaskOptions; | ||||
| import com.google.mediapipe.tasks.core.TaskRunner; | ||||
| import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; | ||||
| import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi; | ||||
| import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; | ||||
| import com.google.mediapipe.tasks.vision.core.RunningMode; | ||||
| import com.google.mediapipe.tasks.vision.posedetector.proto.PoseDetectorGraphOptionsProto; | ||||
| import com.google.mediapipe.tasks.vision.poselandmarker.proto.PoseLandmarkerGraphOptionsProto; | ||||
| import com.google.mediapipe.tasks.vision.poselandmarker.proto.PoseLandmarksDetectorGraphOptionsProto; | ||||
| import java.io.File; | ||||
| import java.io.IOException; | ||||
| import java.nio.ByteBuffer; | ||||
| import java.util.ArrayList; | ||||
| import java.util.Arrays; | ||||
| import java.util.Collections; | ||||
| import java.util.List; | ||||
| import java.util.Optional; | ||||
| 
 | ||||
| /** | ||||
|  * Performs pose landmarks detection on images. | ||||
|  * | ||||
|  * <p>This API expects a pre-trained pose landmarks model asset bundle. See <TODO link | ||||
|  * to the DevSite documentation page>. | ||||
|  * | ||||
|  * <ul> | ||||
|  *   <li>Input image {@link MPImage} | ||||
|  *       <ul> | ||||
|  *         <li>The image that pose landmarks detection runs on. | ||||
|  *       </ul> | ||||
|  *   <li>Output PoseLandmarkerResult {@link PoseLandmarkerResult} | ||||
|  *       <ul> | ||||
|  *         <li>A PoseLandmarkerResult containing pose landmarks. | ||||
|  *       </ul> | ||||
|  * </ul> | ||||
|  */ | ||||
| public final class PoseLandmarker extends BaseVisionTaskApi { | ||||
|   private static final String TAG = PoseLandmarker.class.getSimpleName(); | ||||
|   private static final String IMAGE_IN_STREAM_NAME = "image_in"; | ||||
|   private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in"; | ||||
| 
 | ||||
|   @SuppressWarnings("ConstantCaseForConstants") | ||||
|   private static final List<String> INPUT_STREAMS = | ||||
|       Collections.unmodifiableList( | ||||
|           Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME)); | ||||
| 
 | ||||
|   private static final int LANDMARKS_OUT_STREAM_INDEX = 0; | ||||
|   private static final int WORLD_LANDMARKS_OUT_STREAM_INDEX = 1; | ||||
|   private static final int AUXILIARY_LANDMARKS_OUT_STREAM_INDEX = 2; | ||||
|   private static final int IMAGE_OUT_STREAM_INDEX = 3; | ||||
|   private static int segmentationMasksOutStreamIndex = -1; | ||||
|   private static final String TASK_GRAPH_NAME = | ||||
|       "mediapipe.tasks.vision.pose_landmarker.PoseLandmarkerGraph"; | ||||
| 
 | ||||
|   /** | ||||
|    * Creates a {@link PoseLandmarker} instance from a model file and the default {@link | ||||
|    * PoseLandmarkerOptions}. | ||||
|    * | ||||
|    * @param context an Android {@link Context}. | ||||
|    * @param modelPath path to the pose landmarks model with metadata in the assets. | ||||
|    * @throws MediaPipeException if there is an error during {@link PoseLandmarker} creation. | ||||
|    */ | ||||
|   public static PoseLandmarker createFromFile(Context context, String modelPath) { | ||||
|     BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build(); | ||||
|     return createFromOptions( | ||||
|         context, PoseLandmarkerOptions.builder().setBaseOptions(baseOptions).build()); | ||||
|   } | ||||
| 
 | ||||
|   /** | ||||
|    * Creates a {@link PoseLandmarker} instance from a model file and the default {@link | ||||
|    * PoseLandmarkerOptions}. | ||||
|    * | ||||
|    * @param context an Android {@link Context}. | ||||
|    * @param modelFile the pose landmarks model {@link File} instance. | ||||
|    * @throws IOException if an I/O error occurs when opening the tflite model file. | ||||
|    * @throws MediaPipeException if there is an error during {@link PoseLandmarker} creation. | ||||
|    */ | ||||
|   public static PoseLandmarker createFromFile(Context context, File modelFile) throws IOException { | ||||
|     try (ParcelFileDescriptor descriptor = | ||||
|         ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { | ||||
|       BaseOptions baseOptions = | ||||
|           BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build(); | ||||
|       return createFromOptions( | ||||
|           context, PoseLandmarkerOptions.builder().setBaseOptions(baseOptions).build()); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   /** | ||||
|    * Creates a {@link PoseLandmarker} instance from a model buffer and the default {@link | ||||
|    * PoseLandmarkerOptions}. | ||||
|    * | ||||
|    * @param context an Android {@link Context}. | ||||
|    * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the detection | ||||
|    *     model. | ||||
|    * @throws MediaPipeException if there is an error during {@link PoseLandmarker} creation. | ||||
|    */ | ||||
|   public static PoseLandmarker createFromBuffer(Context context, final ByteBuffer modelBuffer) { | ||||
|     BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelBuffer).build(); | ||||
|     return createFromOptions( | ||||
|         context, PoseLandmarkerOptions.builder().setBaseOptions(baseOptions).build()); | ||||
|   } | ||||
| 
 | ||||
|   /** | ||||
|    * Creates a {@link PoseLandmarker} instance from a {@link PoseLandmarkerOptions}. | ||||
|    * | ||||
|    * @param context an Android {@link Context}. | ||||
|    * @param landmarkerOptions a {@link PoseLandmarkerOptions} instance. | ||||
|    * @throws MediaPipeException if there is an error during {@link PoseLandmarker} creation. | ||||
|    */ | ||||
|   public static PoseLandmarker createFromOptions( | ||||
|       Context context, PoseLandmarkerOptions landmarkerOptions) { | ||||
|     List<String> outputStreams = new ArrayList<>(); | ||||
|     outputStreams.add("NORM_LANDMARKS:pose_landmarks"); | ||||
|     outputStreams.add("WORLD_LANDMARKS:world_landmarks"); | ||||
|     outputStreams.add("AUXILIARY_LANDMARKS:auxiliary_landmarks"); | ||||
|     outputStreams.add("IMAGE:image_out"); | ||||
|     if (landmarkerOptions.outputSegmentationMasks()) { | ||||
|       outputStreams.add("SEGMENTATION_MASK:segmentation_masks"); | ||||
|       segmentationMasksOutStreamIndex = outputStreams.size() - 1; | ||||
|     } | ||||
| 
 | ||||
|     // TODO: Consolidate OutputHandler and TaskRunner. | ||||
|     OutputHandler<PoseLandmarkerResult, MPImage> handler = new OutputHandler<>(); | ||||
|     handler.setOutputPacketConverter( | ||||
|         new OutputHandler.OutputPacketConverter<PoseLandmarkerResult, MPImage>() { | ||||
|           @Override | ||||
|           public PoseLandmarkerResult convertToTaskResult(List<Packet> packets) { | ||||
|             // If there is no poses detected in the image, just returns empty lists. | ||||
|             if (packets.get(LANDMARKS_OUT_STREAM_INDEX).isEmpty()) { | ||||
|               return PoseLandmarkerResult.create( | ||||
|                   new ArrayList<>(), | ||||
|                   new ArrayList<>(), | ||||
|                   new ArrayList<>(), | ||||
|                   Optional.empty(), | ||||
|                   BaseVisionTaskApi.generateResultTimestampMs( | ||||
|                       landmarkerOptions.runningMode(), packets.get(LANDMARKS_OUT_STREAM_INDEX))); | ||||
|             } | ||||
|             /** Get segmentation masks */ | ||||
|             Optional<List<MPImage>> segmentedMasks = Optional.empty(); | ||||
|             if (landmarkerOptions.outputSegmentationMasks()) { | ||||
|               segmentedMasks = getSegmentationMasks(packets); | ||||
|             } | ||||
| 
 | ||||
|             return PoseLandmarkerResult.create( | ||||
|                 PacketGetter.getProtoVector( | ||||
|                     packets.get(LANDMARKS_OUT_STREAM_INDEX), NormalizedLandmarkList.parser()), | ||||
|                 PacketGetter.getProtoVector( | ||||
|                     packets.get(WORLD_LANDMARKS_OUT_STREAM_INDEX), LandmarkList.parser()), | ||||
|                 PacketGetter.getProtoVector( | ||||
|                     packets.get(AUXILIARY_LANDMARKS_OUT_STREAM_INDEX), | ||||
|                     NormalizedLandmarkList.parser()), | ||||
|                 segmentedMasks, | ||||
|                 BaseVisionTaskApi.generateResultTimestampMs( | ||||
|                     landmarkerOptions.runningMode(), packets.get(LANDMARKS_OUT_STREAM_INDEX))); | ||||
|           } | ||||
| 
 | ||||
|           @Override | ||||
|           public MPImage convertToTaskInput(List<Packet> packets) { | ||||
|             return new BitmapImageBuilder( | ||||
|                     AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX))) | ||||
|                 .build(); | ||||
|           } | ||||
|         }); | ||||
|     landmarkerOptions.resultListener().ifPresent(handler::setResultListener); | ||||
|     landmarkerOptions.errorListener().ifPresent(handler::setErrorListener); | ||||
|     TaskRunner runner = | ||||
|         TaskRunner.create( | ||||
|             context, | ||||
|             TaskInfo.<PoseLandmarkerOptions>builder() | ||||
|                 .setTaskName(PoseLandmarker.class.getSimpleName()) | ||||
|                 .setTaskRunningModeName(landmarkerOptions.runningMode().name()) | ||||
|                 .setTaskGraphName(TASK_GRAPH_NAME) | ||||
|                 .setInputStreams(INPUT_STREAMS) | ||||
|                 .setOutputStreams(outputStreams) | ||||
|                 .setTaskOptions(landmarkerOptions) | ||||
|                 .setEnableFlowLimiting(landmarkerOptions.runningMode() == RunningMode.LIVE_STREAM) | ||||
|                 .build(), | ||||
|             handler); | ||||
|     return new PoseLandmarker(runner, landmarkerOptions.runningMode()); | ||||
|   } | ||||
| 
 | ||||
|   /** | ||||
|    * Constructor to initialize a {@link PoseLandmarker} from a {@link TaskRunner} and a {@link | ||||
|    * RunningMode}. | ||||
|    * | ||||
|    * @param taskRunner a {@link TaskRunner}. | ||||
|    * @param runningMode a mediapipe vision task {@link RunningMode}. | ||||
|    */ | ||||
|   private PoseLandmarker(TaskRunner taskRunner, RunningMode runningMode) { | ||||
|     super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME); | ||||
|   } | ||||
| 
 | ||||
|   /** | ||||
|    * Performs pose landmarks detection on the provided single image with default image processing | ||||
|    * options, i.e. without any rotation applied. Only use this method when the {@link | ||||
|    * PoseLandmarker} is created with {@link RunningMode.IMAGE}. TODO update java doc | ||||
|    * for input image format. | ||||
|    * | ||||
|    * <p>{@link PoseLandmarker} supports the following color space types: | ||||
|    * | ||||
|    * <ul> | ||||
|    *   <li>{@link Bitmap.Config.ARGB_8888} | ||||
|    * </ul> | ||||
|    * | ||||
|    * @param image a MediaPipe {@link MPImage} object for processing. | ||||
|    * @throws MediaPipeException if there is an internal error. | ||||
|    */ | ||||
|   public PoseLandmarkerResult detect(MPImage image) { | ||||
|     return detect(image, ImageProcessingOptions.builder().build()); | ||||
|   } | ||||
| 
 | ||||
|   /** | ||||
|    * Performs pose landmarks detection on the provided single image. Only use this method when the | ||||
|    * {@link PoseLandmarker} is created with {@link RunningMode.IMAGE}. TODO update java | ||||
|    * doc for input image format. | ||||
|    * | ||||
|    * <p>{@link PoseLandmarker} supports the following color space types: | ||||
|    * | ||||
|    * <ul> | ||||
|    *   <li>{@link Bitmap.Config.ARGB_8888} | ||||
|    * </ul> | ||||
|    * | ||||
|    * @param image a MediaPipe {@link MPImage} object for processing. | ||||
|    * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the | ||||
|    *     input image before running inference. Note that region-of-interest is <b>not</b> supported | ||||
|    *     by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in | ||||
|    *     this method throwing an IllegalArgumentException. | ||||
|    * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a | ||||
|    *     region-of-interest. | ||||
|    * @throws MediaPipeException if there is an internal error. | ||||
|    */ | ||||
|   public PoseLandmarkerResult detect(MPImage image, ImageProcessingOptions imageProcessingOptions) { | ||||
|     validateImageProcessingOptions(imageProcessingOptions); | ||||
|     return (PoseLandmarkerResult) processImageData(image, imageProcessingOptions); | ||||
|   } | ||||
| 
 | ||||
|   /** | ||||
|    * Performs pose landmarks detection on the provided video frame with default image processing | ||||
|    * options, i.e. without any rotation applied. Only use this method when the {@link | ||||
|    * PoseLandmarker} is created with {@link RunningMode.VIDEO}. | ||||
|    * | ||||
|    * <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps | ||||
|    * must be monotonically increasing. | ||||
|    * | ||||
|    * <p>{@link PoseLandmarker} supports the following color space types: | ||||
|    * | ||||
|    * <ul> | ||||
|    *   <li>{@link Bitmap.Config.ARGB_8888} | ||||
|    * </ul> | ||||
|    * | ||||
|    * @param image a MediaPipe {@link MPImage} object for processing. | ||||
|    * @param timestampMs the input timestamp (in milliseconds). | ||||
|    * @throws MediaPipeException if there is an internal error. | ||||
|    */ | ||||
|   public PoseLandmarkerResult detectForVideo(MPImage image, long timestampMs) { | ||||
|     return detectForVideo(image, ImageProcessingOptions.builder().build(), timestampMs); | ||||
|   } | ||||
| 
 | ||||
|   /** | ||||
|    * Performs pose landmarks detection on the provided video frame. Only use this method when the | ||||
|    * {@link PoseLandmarker} is created with {@link RunningMode.VIDEO}. | ||||
|    * | ||||
|    * <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps | ||||
|    * must be monotonically increasing. | ||||
|    * | ||||
|    * <p>{@link PoseLandmarker} supports the following color space types: | ||||
|    * | ||||
|    * <ul> | ||||
|    *   <li>{@link Bitmap.Config.ARGB_8888} | ||||
|    * </ul> | ||||
|    * | ||||
|    * @param image a MediaPipe {@link MPImage} object for processing. | ||||
|    * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the | ||||
|    *     input image before running inference. Note that region-of-interest is <b>not</b> supported | ||||
|    *     by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in | ||||
|    *     this method throwing an IllegalArgumentException. | ||||
|    * @param timestampMs the input timestamp (in milliseconds). | ||||
|    * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a | ||||
|    *     region-of-interest. | ||||
|    * @throws MediaPipeException if there is an internal error. | ||||
|    */ | ||||
|   public PoseLandmarkerResult detectForVideo( | ||||
|       MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { | ||||
|     validateImageProcessingOptions(imageProcessingOptions); | ||||
|     return (PoseLandmarkerResult) processVideoData(image, imageProcessingOptions, timestampMs); | ||||
|   } | ||||
| 
 | ||||
|   /** | ||||
|    * Sends live image data to perform pose landmarks detection with default image processing | ||||
|    * options, i.e. without any rotation applied, and the results will be available via the {@link | ||||
|    * ResultListener} provided in the {@link PoseLandmarkerOptions}. Only use this method when the | ||||
|    * {@link PoseLandmarker } is created with {@link RunningMode.LIVE_STREAM}. | ||||
|    * | ||||
|    * <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is | ||||
|    * sent to the pose landmarker. The input timestamps must be monotonically increasing. | ||||
|    * | ||||
|    * <p>{@link PoseLandmarker} supports the following color space types: | ||||
|    * | ||||
|    * <ul> | ||||
|    *   <li>{@link Bitmap.Config.ARGB_8888} | ||||
|    * </ul> | ||||
|    * | ||||
|    * @param image a MediaPipe {@link MPImage} object for processing. | ||||
|    * @param timestampMs the input timestamp (in milliseconds). | ||||
|    * @throws MediaPipeException if there is an internal error. | ||||
|    */ | ||||
|   public void detectAsync(MPImage image, long timestampMs) { | ||||
|     detectAsync(image, ImageProcessingOptions.builder().build(), timestampMs); | ||||
|   } | ||||
| 
 | ||||
|   /** | ||||
|    * Sends live image data to perform pose landmarks detection, and the results will be available | ||||
|    * via the {@link ResultListener} provided in the {@link PoseLandmarkerOptions}. Only use this | ||||
|    * method when the {@link PoseLandmarker} is created with {@link RunningMode.LIVE_STREAM}. | ||||
|    * | ||||
|    * <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is | ||||
|    * sent to the pose landmarker. The input timestamps must be monotonically increasing. | ||||
|    * | ||||
|    * <p>{@link PoseLandmarker} supports the following color space types: | ||||
|    * | ||||
|    * <ul> | ||||
|    *   <li>{@link Bitmap.Config.ARGB_8888} | ||||
|    * </ul> | ||||
|    * | ||||
|    * @param image a MediaPipe {@link MPImage} object for processing. | ||||
|    * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the | ||||
|    *     input image before running inference. Note that region-of-interest is <b>not</b> supported | ||||
|    *     by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in | ||||
|    *     this method throwing an IllegalArgumentException. | ||||
|    * @param timestampMs the input timestamp (in milliseconds). | ||||
|    * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a | ||||
|    *     region-of-interest. | ||||
|    * @throws MediaPipeException if there is an internal error. | ||||
|    */ | ||||
|   public void detectAsync( | ||||
|       MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { | ||||
|     validateImageProcessingOptions(imageProcessingOptions); | ||||
|     sendLiveStreamData(image, imageProcessingOptions, timestampMs); | ||||
|   } | ||||
| 
 | ||||
|   /** Options for setting up an {@link PoseLandmarker}. */ | ||||
|   @AutoValue | ||||
|   public abstract static class PoseLandmarkerOptions extends TaskOptions { | ||||
| 
 | ||||
|     /** Builder for {@link PoseLandmarkerOptions}. */ | ||||
|     @AutoValue.Builder | ||||
|     public abstract static class Builder { | ||||
|       /** Sets the base options for the pose landmarker task. */ | ||||
|       public abstract Builder setBaseOptions(BaseOptions value); | ||||
| 
 | ||||
|       /** | ||||
|        * Sets the running mode for the pose landmarker task. Default to the image mode. Pose | ||||
|        * landmarker has three modes: | ||||
|        * | ||||
|        * <ul> | ||||
|        *   <li>IMAGE: The mode for detecting pose landmarks on single image inputs. | ||||
|        *   <li>VIDEO: The mode for detecting pose landmarks on the decoded frames of a video. | ||||
|        *   <li>LIVE_STREAM: The mode for for detecting pose landmarks on a live stream of input | ||||
|        *       data, such as from camera. In this mode, {@code setResultListener} must be called to | ||||
|        *       set up a listener to receive the detection results asynchronously. | ||||
|        * </ul> | ||||
|        */ | ||||
|       public abstract Builder setRunningMode(RunningMode value); | ||||
| 
 | ||||
|       /** Sets the maximum number of poses can be detected by the PoseLandmarker. */ | ||||
|       public abstract Builder setNumPoses(Integer value); | ||||
| 
 | ||||
|       /** Sets minimum confidence score for the pose detection to be considered successful */ | ||||
|       public abstract Builder setMinPoseDetectionConfidence(Float value); | ||||
| 
 | ||||
|       /** Sets minimum confidence score of pose presence score in the pose landmark detection. */ | ||||
|       public abstract Builder setMinPosePresenceConfidence(Float value); | ||||
| 
 | ||||
|       /** Sets the minimum confidence score for the pose tracking to be considered successful. */ | ||||
|       public abstract Builder setMinTrackingConfidence(Float value); | ||||
| 
 | ||||
|       public abstract Builder setOutputSegmentationMasks(Boolean value); | ||||
| 
 | ||||
|       /** | ||||
|        * Sets the result listener to receive the detection results asynchronously when the pose | ||||
|        * landmarker is in the live stream mode. | ||||
|        */ | ||||
|       public abstract Builder setResultListener( | ||||
|           ResultListener<PoseLandmarkerResult, MPImage> value); | ||||
| 
 | ||||
|       /** Sets an optional error listener. */ | ||||
|       public abstract Builder setErrorListener(ErrorListener value); | ||||
| 
 | ||||
|       abstract PoseLandmarkerOptions autoBuild(); | ||||
| 
 | ||||
|       /** | ||||
|        * Validates and builds the {@link PoseLandmarkerOptions} instance. | ||||
|        * | ||||
|        * @throws IllegalArgumentException if the result listener and the running mode are not | ||||
|        *     properly configured. The result listener should only be set when the pose landmarker is | ||||
|        *     in the live stream mode. | ||||
|        */ | ||||
|       public final PoseLandmarkerOptions build() { | ||||
|         PoseLandmarkerOptions options = autoBuild(); | ||||
|         if (options.runningMode() == RunningMode.LIVE_STREAM) { | ||||
|           if (!options.resultListener().isPresent()) { | ||||
|             throw new IllegalArgumentException( | ||||
|                 "The pose landmarker is in the live stream mode, a user-defined result listener" | ||||
|                     + " must be provided in PoseLandmarkerOptions."); | ||||
|           } | ||||
|         } else if (options.resultListener().isPresent()) { | ||||
|           throw new IllegalArgumentException( | ||||
|               "The pose landmarker is in the image or the video mode, a user-defined result" | ||||
|                   + " listener shouldn't be provided in PoseLandmarkerOptions."); | ||||
|         } | ||||
|         return options; | ||||
|       } | ||||
|     } | ||||
| 
 | ||||
|     abstract BaseOptions baseOptions(); | ||||
| 
 | ||||
|     abstract RunningMode runningMode(); | ||||
| 
 | ||||
|     abstract Optional<Integer> numPoses(); | ||||
| 
 | ||||
|     abstract Optional<Float> minPoseDetectionConfidence(); | ||||
| 
 | ||||
|     abstract Optional<Float> minPosePresenceConfidence(); | ||||
| 
 | ||||
|     abstract Optional<Float> minTrackingConfidence(); | ||||
| 
 | ||||
|     abstract Boolean outputSegmentationMasks(); | ||||
| 
 | ||||
|     abstract Optional<ResultListener<PoseLandmarkerResult, MPImage>> resultListener(); | ||||
| 
 | ||||
|     abstract Optional<ErrorListener> errorListener(); | ||||
| 
 | ||||
|     public static Builder builder() { | ||||
|       return new AutoValue_PoseLandmarker_PoseLandmarkerOptions.Builder() | ||||
|           .setRunningMode(RunningMode.IMAGE) | ||||
|           .setNumPoses(1) | ||||
|           .setMinPoseDetectionConfidence(0.5f) | ||||
|           .setMinPosePresenceConfidence(0.5f) | ||||
|           .setMinTrackingConfidence(0.5f) | ||||
|           .setOutputSegmentationMasks(false); | ||||
|     } | ||||
| 
 | ||||
|     /** Converts a {@link PoseLandmarkerOptions} to a {@link CalculatorOptions} protobuf message. */ | ||||
|     @Override | ||||
|     public CalculatorOptions convertToCalculatorOptionsProto() { | ||||
|       PoseLandmarkerGraphOptionsProto.PoseLandmarkerGraphOptions.Builder taskOptionsBuilder = | ||||
|           PoseLandmarkerGraphOptionsProto.PoseLandmarkerGraphOptions.newBuilder() | ||||
|               .setBaseOptions( | ||||
|                   BaseOptionsProto.BaseOptions.newBuilder() | ||||
|                       .setUseStreamMode(runningMode() != RunningMode.IMAGE) | ||||
|                       .mergeFrom(convertBaseOptionsToProto(baseOptions())) | ||||
|                       .build()); | ||||
| 
 | ||||
|       // Setup PoseDetectorGraphOptions. | ||||
|       PoseDetectorGraphOptionsProto.PoseDetectorGraphOptions.Builder | ||||
|           poseDetectorGraphOptionsBuilder = | ||||
|               PoseDetectorGraphOptionsProto.PoseDetectorGraphOptions.newBuilder(); | ||||
|       numPoses().ifPresent(poseDetectorGraphOptionsBuilder::setNumPoses); | ||||
|       minPoseDetectionConfidence() | ||||
|           .ifPresent(poseDetectorGraphOptionsBuilder::setMinDetectionConfidence); | ||||
| 
 | ||||
|       // Setup PoseLandmarkerGraphOptions. | ||||
|       PoseLandmarksDetectorGraphOptionsProto.PoseLandmarksDetectorGraphOptions.Builder | ||||
|           poseLandmarksDetectorGraphOptionsBuilder = | ||||
|               PoseLandmarksDetectorGraphOptionsProto.PoseLandmarksDetectorGraphOptions.newBuilder(); | ||||
|       minPosePresenceConfidence() | ||||
|           .ifPresent(poseLandmarksDetectorGraphOptionsBuilder::setMinDetectionConfidence); | ||||
|       minTrackingConfidence().ifPresent(taskOptionsBuilder::setMinTrackingConfidence); | ||||
| 
 | ||||
|       taskOptionsBuilder | ||||
|           .setPoseDetectorGraphOptions(poseDetectorGraphOptionsBuilder.build()) | ||||
|           .setPoseLandmarksDetectorGraphOptions(poseLandmarksDetectorGraphOptionsBuilder.build()); | ||||
| 
 | ||||
|       return CalculatorOptions.newBuilder() | ||||
|           .setExtension( | ||||
|               PoseLandmarkerGraphOptionsProto.PoseLandmarkerGraphOptions.ext, | ||||
|               taskOptionsBuilder.build()) | ||||
|           .build(); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   /** | ||||
|    * Validates that the provided {@link ImageProcessingOptions} doesn't contain a | ||||
|    * region-of-interest. | ||||
|    */ | ||||
|   private static void validateImageProcessingOptions( | ||||
|       ImageProcessingOptions imageProcessingOptions) { | ||||
|     if (imageProcessingOptions.regionOfInterest().isPresent()) { | ||||
|       throw new IllegalArgumentException("PoseLandmarker doesn't support region-of-interest."); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   private static Optional<List<MPImage>> getSegmentationMasks(List<Packet> packets) { | ||||
|     Optional<List<MPImage>> segmentedMasks = Optional.of(new ArrayList<>()); | ||||
|     int width = | ||||
|         PacketGetter.getImageWidthFromImageList(packets.get(segmentationMasksOutStreamIndex)); | ||||
|     int height = | ||||
|         PacketGetter.getImageHeightFromImageList(packets.get(segmentationMasksOutStreamIndex)); | ||||
|     int imageListSize = PacketGetter.getImageListSize(packets.get(segmentationMasksOutStreamIndex)); | ||||
|     ByteBuffer[] buffersArray = new ByteBuffer[imageListSize]; | ||||
| 
 | ||||
|     // Segmentation mask is a float type image. | ||||
|     int numBytes = 4; | ||||
|     for (int i = 0; i < imageListSize; i++) { | ||||
|       buffersArray[i] = ByteBuffer.allocateDirect(width * height * numBytes); | ||||
|     } | ||||
| 
 | ||||
|     if (!PacketGetter.getImageList( | ||||
|         packets.get(segmentationMasksOutStreamIndex), | ||||
|         buffersArray, | ||||
|         /** deepCopy= */ | ||||
|         true)) { | ||||
|       throw new MediaPipeException( | ||||
|           MediaPipeException.StatusCode.INTERNAL.ordinal(), | ||||
|           "There is an error getting segmented masks."); | ||||
|     } | ||||
|     for (ByteBuffer buffer : buffersArray) { | ||||
|       ByteBufferImageBuilder builder = | ||||
|           new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_VEC32F1); | ||||
|       segmentedMasks.get().add(builder.build()); | ||||
|     } | ||||
|     return segmentedMasks; | ||||
|   } | ||||
| } | ||||
|  | @ -0,0 +1,113 @@ | |||
| // Copyright 2023 The MediaPipe Authors. All Rights Reserved. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| // you may not use this file except in compliance with the License. | ||||
| // You may obtain a copy of the License at | ||||
| // | ||||
| //      http://www.apache.org/licenses/LICENSE-2.0 | ||||
| // | ||||
| // Unless required by applicable law or agreed to in writing, software | ||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | ||||
| 
 | ||||
| package com.google.mediapipe.tasks.vision.poselandmarker; | ||||
| 
 | ||||
| import com.google.auto.value.AutoValue; | ||||
| import com.google.mediapipe.formats.proto.LandmarkProto; | ||||
| import com.google.mediapipe.framework.image.MPImage; | ||||
| import com.google.mediapipe.tasks.components.containers.Landmark; | ||||
| import com.google.mediapipe.tasks.components.containers.NormalizedLandmark; | ||||
| import com.google.mediapipe.tasks.core.TaskResult; | ||||
| import java.util.ArrayList; | ||||
| import java.util.Collections; | ||||
| import java.util.List; | ||||
| import java.util.Optional; | ||||
| 
 | ||||
| /** Represents the pose landmarks deection results generated by {@link PoseLandmarker}. */ | ||||
| @AutoValue | ||||
| public abstract class PoseLandmarkerResult implements TaskResult { | ||||
| 
 | ||||
|   /** | ||||
|    * Creates a {@link PoseLandmarkerResult} instance from the lists of landmarks and | ||||
|    * segmentationMask protobuf messages. | ||||
|    * | ||||
|    * @param landmarksProto a List of {@link NormalizedLandmarkList} | ||||
|    * @param worldLandmarksProto a List of {@link LandmarkList} | ||||
|    * @param segmentationMasksData a List of {@link MPImage} | ||||
|    */ | ||||
|   static PoseLandmarkerResult create( | ||||
|       List<LandmarkProto.NormalizedLandmarkList> landmarksProto, | ||||
|       List<LandmarkProto.LandmarkList> worldLandmarksProto, | ||||
|       List<LandmarkProto.NormalizedLandmarkList> auxiliaryLandmarksProto, | ||||
|       Optional<List<MPImage>> segmentationMasksData, | ||||
|       long timestampMs) { | ||||
| 
 | ||||
|     Optional<List<MPImage>> multiPoseSegmentationMasks = Optional.empty(); | ||||
|     if (segmentationMasksData.isPresent()) { | ||||
|       multiPoseSegmentationMasks = | ||||
|           Optional.of(Collections.unmodifiableList(segmentationMasksData.get())); | ||||
|     } | ||||
| 
 | ||||
|     List<List<NormalizedLandmark>> multiPoseLandmarks = new ArrayList<>(); | ||||
|     List<List<Landmark>> multiPoseWorldLandmarks = new ArrayList<>(); | ||||
|     List<List<NormalizedLandmark>> multiPoseAuxiliaryLandmarks = new ArrayList<>(); | ||||
|     for (LandmarkProto.NormalizedLandmarkList poseLandmarksProto : landmarksProto) { | ||||
|       List<NormalizedLandmark> poseLandmarks = new ArrayList<>(); | ||||
|       multiPoseLandmarks.add(poseLandmarks); | ||||
|       for (LandmarkProto.NormalizedLandmark poseLandmarkProto : | ||||
|           poseLandmarksProto.getLandmarkList()) { | ||||
|         poseLandmarks.add( | ||||
|             NormalizedLandmark.create( | ||||
|                 poseLandmarkProto.getX(), poseLandmarkProto.getY(), poseLandmarkProto.getZ())); | ||||
|       } | ||||
|     } | ||||
|     for (LandmarkProto.LandmarkList poseWorldLandmarksProto : worldLandmarksProto) { | ||||
|       List<Landmark> poseWorldLandmarks = new ArrayList<>(); | ||||
|       multiPoseWorldLandmarks.add(poseWorldLandmarks); | ||||
|       for (LandmarkProto.Landmark poseWorldLandmarkProto : | ||||
|           poseWorldLandmarksProto.getLandmarkList()) { | ||||
|         poseWorldLandmarks.add( | ||||
|             Landmark.create( | ||||
|                 poseWorldLandmarkProto.getX(), | ||||
|                 poseWorldLandmarkProto.getY(), | ||||
|                 poseWorldLandmarkProto.getZ())); | ||||
|       } | ||||
|     } | ||||
|     for (LandmarkProto.NormalizedLandmarkList poseAuxiliaryLandmarksProto : | ||||
|         auxiliaryLandmarksProto) { | ||||
|       List<NormalizedLandmark> poseAuxiliaryLandmarks = new ArrayList<>(); | ||||
|       multiPoseAuxiliaryLandmarks.add(poseAuxiliaryLandmarks); | ||||
|       for (LandmarkProto.NormalizedLandmark poseAuxiliaryLandmarkProto : | ||||
|           poseAuxiliaryLandmarksProto.getLandmarkList()) { | ||||
|         poseAuxiliaryLandmarks.add( | ||||
|             NormalizedLandmark.create( | ||||
|                 poseAuxiliaryLandmarkProto.getX(), | ||||
|                 poseAuxiliaryLandmarkProto.getY(), | ||||
|                 poseAuxiliaryLandmarkProto.getZ())); | ||||
|       } | ||||
|     } | ||||
|     return new AutoValue_PoseLandmarkerResult( | ||||
|         timestampMs, | ||||
|         Collections.unmodifiableList(multiPoseLandmarks), | ||||
|         Collections.unmodifiableList(multiPoseWorldLandmarks), | ||||
|         Collections.unmodifiableList(multiPoseAuxiliaryLandmarks), | ||||
|         multiPoseSegmentationMasks); | ||||
|   } | ||||
| 
 | ||||
|   @Override | ||||
|   public abstract long timestampMs(); | ||||
| 
 | ||||
|   /** Pose landmarks of detected poses. */ | ||||
|   public abstract List<List<NormalizedLandmark>> landmarks(); | ||||
| 
 | ||||
|   /** Pose landmarks in world coordniates of detected poses. */ | ||||
|   public abstract List<List<Landmark>> worldLandmarks(); | ||||
| 
 | ||||
|   /** Pose auxiliary landmarks. */ | ||||
|   public abstract List<List<NormalizedLandmark>> auxiliaryLandmarks(); | ||||
| 
 | ||||
|   /** Pose segmentation masks. */ | ||||
|   public abstract Optional<List<MPImage>> segmentationMasks(); | ||||
| } | ||||
|  | @ -0,0 +1,24 @@ | |||
| <?xml version="1.0" encoding="utf-8"?> | ||||
| <manifest xmlns:android="http://schemas.android.com/apk/res/android" | ||||
|     package="com.google.mediapipe.tasks.vision.poselandmarkertest" | ||||
|     android:versionCode="1" | ||||
|     android:versionName="1.0" > | ||||
| 
 | ||||
|     <uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/> | ||||
|     <uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/> | ||||
| 
 | ||||
|     <uses-sdk android:minSdkVersion="24" | ||||
|         android:targetSdkVersion="30" /> | ||||
| 
 | ||||
|     <application | ||||
|         android:label="poselandmarkertest" | ||||
|         android:name="android.support.multidex.MultiDexApplication" | ||||
|         android:taskAffinity=""> | ||||
|         <uses-library android:name="android.test.runner" /> | ||||
|     </application> | ||||
| 
 | ||||
|     <instrumentation | ||||
|         android:name="com.google.android.apps.common.testing.testrunner.GoogleInstrumentationTestRunner" | ||||
|         android:targetPackage="com.google.mediapipe.tasks.vision.poselandmarkertest" /> | ||||
| 
 | ||||
| </manifest> | ||||
|  | @ -0,0 +1,19 @@ | |||
| # Copyright 2023 The MediaPipe Authors. All Rights Reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #      http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| package(default_visibility = ["//mediapipe/tasks:internal"]) | ||||
| 
 | ||||
| licenses(["notice"]) | ||||
| 
 | ||||
| # TODO: Enable this in OSS | ||||
|  | @ -0,0 +1,365 @@ | |||
| // Copyright 2022 The MediaPipe Authors. All Rights Reserved. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| // you may not use this file except in compliance with the License. | ||||
| // You may obtain a copy of the License at | ||||
| // | ||||
| //      http://www.apache.org/licenses/LICENSE-2.0 | ||||
| // | ||||
| // Unless required by applicable law or agreed to in writing, software | ||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | ||||
| 
 | ||||
| package com.google.mediapipe.tasks.vision.poselandmarker; | ||||
| 
 | ||||
| import static com.google.common.truth.Truth.assertThat; | ||||
| import static org.junit.Assert.assertThrows; | ||||
| 
 | ||||
| import android.content.res.AssetManager; | ||||
| import android.graphics.BitmapFactory; | ||||
| import android.graphics.RectF; | ||||
| import androidx.test.core.app.ApplicationProvider; | ||||
| import androidx.test.ext.junit.runners.AndroidJUnit4; | ||||
| import com.google.common.truth.Correspondence; | ||||
| import com.google.mediapipe.framework.MediaPipeException; | ||||
| import com.google.mediapipe.framework.image.BitmapImageBuilder; | ||||
| import com.google.mediapipe.framework.image.MPImage; | ||||
| import com.google.mediapipe.tasks.components.containers.NormalizedLandmark; | ||||
| import com.google.mediapipe.tasks.components.containers.proto.LandmarksDetectionResultProto.LandmarksDetectionResult; | ||||
| import com.google.mediapipe.tasks.core.BaseOptions; | ||||
| import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; | ||||
| import com.google.mediapipe.tasks.vision.core.RunningMode; | ||||
| import com.google.mediapipe.tasks.vision.poselandmarker.PoseLandmarker.PoseLandmarkerOptions; | ||||
| import java.io.InputStream; | ||||
| import java.util.Arrays; | ||||
| import java.util.Optional; | ||||
| import org.junit.Test; | ||||
| import org.junit.runner.RunWith; | ||||
| import org.junit.runners.Suite; | ||||
| import org.junit.runners.Suite.SuiteClasses; | ||||
| 
 | ||||
| /** Test for {@link PoseLandmarker}. */ | ||||
| @RunWith(Suite.class) | ||||
| @SuiteClasses({PoseLandmarkerTest.General.class, PoseLandmarkerTest.RunningModeTest.class}) | ||||
| public class PoseLandmarkerTest { | ||||
|   private static final String POSE_LANDMARKER_BUNDLE_ASSET_FILE = "pose_landmarker.task"; | ||||
|   private static final String POSE_IMAGE = "pose.jpg"; | ||||
|   private static final String POSE_LANDMARKS = "pose_landmarks.pb"; | ||||
|   private static final String NO_POSES_IMAGE = "burger.jpg"; | ||||
|   private static final String TAG = "Pose Landmarker Test"; | ||||
|   private static final float LANDMARKS_ERROR_TOLERANCE = 0.03f; | ||||
|   private static final int IMAGE_WIDTH = 1000; | ||||
|   private static final int IMAGE_HEIGHT = 667; | ||||
| 
 | ||||
|   @RunWith(AndroidJUnit4.class) | ||||
|   public static final class General extends PoseLandmarkerTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void detect_successWithValidModels() throws Exception { | ||||
|       PoseLandmarkerOptions options = | ||||
|           PoseLandmarkerOptions.builder() | ||||
|               .setBaseOptions( | ||||
|                   BaseOptions.builder() | ||||
|                       .setModelAssetPath(POSE_LANDMARKER_BUNDLE_ASSET_FILE) | ||||
|                       .build()) | ||||
|               .build(); | ||||
|       PoseLandmarker poseLandmarker = | ||||
|           PoseLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options); | ||||
|       PoseLandmarkerResult actualResult = poseLandmarker.detect(getImageFromAsset(POSE_IMAGE)); | ||||
|       PoseLandmarkerResult expectedResult = getExpectedPoseLandmarkerResult(POSE_LANDMARKS); | ||||
|       assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void detect_successWithEmptyResult() throws Exception { | ||||
|       PoseLandmarkerOptions options = | ||||
|           PoseLandmarkerOptions.builder() | ||||
|               .setBaseOptions( | ||||
|                   BaseOptions.builder() | ||||
|                       .setModelAssetPath(POSE_LANDMARKER_BUNDLE_ASSET_FILE) | ||||
|                       .build()) | ||||
|               .build(); | ||||
|       PoseLandmarker poseLandmarker = | ||||
|           PoseLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options); | ||||
|       PoseLandmarkerResult actualResult = poseLandmarker.detect(getImageFromAsset(NO_POSES_IMAGE)); | ||||
|       assertThat(actualResult.landmarks()).isEmpty(); | ||||
|       assertThat(actualResult.worldLandmarks()).isEmpty(); | ||||
|       // TODO: Add additional tests for MP Tasks Pose Graphs | ||||
|       // Add tests for segmentation masks. | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void recognize_failsWithRegionOfInterest() throws Exception { | ||||
|       PoseLandmarkerOptions options = | ||||
|           PoseLandmarkerOptions.builder() | ||||
|               .setBaseOptions( | ||||
|                   BaseOptions.builder() | ||||
|                       .setModelAssetPath(POSE_LANDMARKER_BUNDLE_ASSET_FILE) | ||||
|                       .build()) | ||||
|               .setNumPoses(1) | ||||
|               .build(); | ||||
|       PoseLandmarker poseLandmarker = | ||||
|           PoseLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options); | ||||
|       ImageProcessingOptions imageProcessingOptions = | ||||
|           ImageProcessingOptions.builder().setRegionOfInterest(new RectF(0, 0, 1, 1)).build(); | ||||
|       IllegalArgumentException exception = | ||||
|           assertThrows( | ||||
|               IllegalArgumentException.class, | ||||
|               () -> poseLandmarker.detect(getImageFromAsset(POSE_IMAGE), imageProcessingOptions)); | ||||
|       assertThat(exception) | ||||
|           .hasMessageThat() | ||||
|           .contains("PoseLandmarker doesn't support region-of-interest"); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   @RunWith(AndroidJUnit4.class) | ||||
|   public static final class RunningModeTest extends PoseLandmarkerTest { | ||||
|     @Test | ||||
|     public void create_failsWithIllegalResultListenerInNonLiveStreamMode() throws Exception { | ||||
|       for (RunningMode mode : new RunningMode[] {RunningMode.IMAGE, RunningMode.VIDEO}) { | ||||
|         IllegalArgumentException exception = | ||||
|             assertThrows( | ||||
|                 IllegalArgumentException.class, | ||||
|                 () -> | ||||
|                     PoseLandmarkerOptions.builder() | ||||
|                         .setBaseOptions( | ||||
|                             BaseOptions.builder() | ||||
|                                 .setModelAssetPath(POSE_LANDMARKER_BUNDLE_ASSET_FILE) | ||||
|                                 .build()) | ||||
|                         .setRunningMode(mode) | ||||
|                         .setResultListener((PoseLandmarkerResults, inputImage) -> {}) | ||||
|                         .build()); | ||||
|         assertThat(exception) | ||||
|             .hasMessageThat() | ||||
|             .contains("a user-defined result listener shouldn't be provided"); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   @Test | ||||
|   public void create_failsWithMissingResultListenerInLiveSteamMode() throws Exception { | ||||
|     IllegalArgumentException exception = | ||||
|         assertThrows( | ||||
|             IllegalArgumentException.class, | ||||
|             () -> | ||||
|                 PoseLandmarkerOptions.builder() | ||||
|                     .setBaseOptions( | ||||
|                         BaseOptions.builder() | ||||
|                             .setModelAssetPath(POSE_LANDMARKER_BUNDLE_ASSET_FILE) | ||||
|                             .build()) | ||||
|                     .setRunningMode(RunningMode.LIVE_STREAM) | ||||
|                     .build()); | ||||
|     assertThat(exception) | ||||
|         .hasMessageThat() | ||||
|         .contains("a user-defined result listener must be provided"); | ||||
|   } | ||||
| 
 | ||||
|   @Test | ||||
|   public void recognize_failsWithCallingWrongApiInImageMode() throws Exception { | ||||
|     PoseLandmarkerOptions options = | ||||
|         PoseLandmarkerOptions.builder() | ||||
|             .setBaseOptions( | ||||
|                 BaseOptions.builder().setModelAssetPath(POSE_LANDMARKER_BUNDLE_ASSET_FILE).build()) | ||||
|             .setRunningMode(RunningMode.IMAGE) | ||||
|             .build(); | ||||
| 
 | ||||
|     PoseLandmarker poseLandmarker = | ||||
|         PoseLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options); | ||||
|     MediaPipeException exception = | ||||
|         assertThrows( | ||||
|             MediaPipeException.class, | ||||
|             () -> | ||||
|                 poseLandmarker.detectForVideo( | ||||
|                     getImageFromAsset(POSE_IMAGE), /* timestampsMs= */ 0)); | ||||
|     assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); | ||||
|     exception = | ||||
|         assertThrows( | ||||
|             MediaPipeException.class, | ||||
|             () -> poseLandmarker.detectAsync(getImageFromAsset(POSE_IMAGE), /* timestampsMs= */ 0)); | ||||
|     assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); | ||||
|   } | ||||
| 
 | ||||
|   @Test | ||||
|   public void recognize_failsWithCallingWrongApiInVideoMode() throws Exception { | ||||
|     PoseLandmarkerOptions options = | ||||
|         PoseLandmarkerOptions.builder() | ||||
|             .setBaseOptions( | ||||
|                 BaseOptions.builder().setModelAssetPath(POSE_LANDMARKER_BUNDLE_ASSET_FILE).build()) | ||||
|             .setRunningMode(RunningMode.VIDEO) | ||||
|             .build(); | ||||
| 
 | ||||
|     PoseLandmarker poseLandmarker = | ||||
|         PoseLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options); | ||||
|     MediaPipeException exception = | ||||
|         assertThrows( | ||||
|             MediaPipeException.class, () -> poseLandmarker.detect(getImageFromAsset(POSE_IMAGE))); | ||||
|     assertThat(exception).hasMessageThat().contains("not initialized with the image mode"); | ||||
|     exception = | ||||
|         assertThrows( | ||||
|             MediaPipeException.class, | ||||
|             () -> poseLandmarker.detectAsync(getImageFromAsset(POSE_IMAGE), /* timestampsMs= */ 0)); | ||||
|     assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); | ||||
|   } | ||||
| 
 | ||||
|   @Test | ||||
|   public void recognize_failsWithCallingWrongApiInLiveSteamMode() throws Exception { | ||||
|     PoseLandmarkerOptions options = | ||||
|         PoseLandmarkerOptions.builder() | ||||
|             .setBaseOptions( | ||||
|                 BaseOptions.builder().setModelAssetPath(POSE_LANDMARKER_BUNDLE_ASSET_FILE).build()) | ||||
|             .setRunningMode(RunningMode.LIVE_STREAM) | ||||
|             .setResultListener((PoseLandmarkerResults, inputImage) -> {}) | ||||
|             .build(); | ||||
| 
 | ||||
|     PoseLandmarker poseLandmarker = | ||||
|         PoseLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options); | ||||
|     MediaPipeException exception = | ||||
|         assertThrows( | ||||
|             MediaPipeException.class, () -> poseLandmarker.detect(getImageFromAsset(POSE_IMAGE))); | ||||
|     assertThat(exception).hasMessageThat().contains("not initialized with the image mode"); | ||||
|     exception = | ||||
|         assertThrows( | ||||
|             MediaPipeException.class, | ||||
|             () -> | ||||
|                 poseLandmarker.detectForVideo( | ||||
|                     getImageFromAsset(POSE_IMAGE), /* timestampsMs= */ 0)); | ||||
|     assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); | ||||
|   } | ||||
| 
 | ||||
|   @Test | ||||
|   public void recognize_successWithImageMode() throws Exception { | ||||
|     PoseLandmarkerOptions options = | ||||
|         PoseLandmarkerOptions.builder() | ||||
|             .setBaseOptions( | ||||
|                 BaseOptions.builder().setModelAssetPath(POSE_LANDMARKER_BUNDLE_ASSET_FILE).build()) | ||||
|             .setRunningMode(RunningMode.IMAGE) | ||||
|             .build(); | ||||
| 
 | ||||
|     PoseLandmarker poseLandmarker = | ||||
|         PoseLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options); | ||||
|     PoseLandmarkerResult actualResult = poseLandmarker.detect(getImageFromAsset(POSE_IMAGE)); | ||||
|     PoseLandmarkerResult expectedResult = getExpectedPoseLandmarkerResult(POSE_LANDMARKS); | ||||
|     assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult); | ||||
|   } | ||||
| 
 | ||||
|   @Test | ||||
|   public void recognize_successWithVideoMode() throws Exception { | ||||
|     PoseLandmarkerOptions options = | ||||
|         PoseLandmarkerOptions.builder() | ||||
|             .setBaseOptions( | ||||
|                 BaseOptions.builder().setModelAssetPath(POSE_LANDMARKER_BUNDLE_ASSET_FILE).build()) | ||||
|             .setRunningMode(RunningMode.VIDEO) | ||||
|             .build(); | ||||
|     PoseLandmarker poseLandmarker = | ||||
|         PoseLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options); | ||||
|     PoseLandmarkerResult expectedResult = getExpectedPoseLandmarkerResult(POSE_LANDMARKS); | ||||
|     for (int i = 0; i < 3; i++) { | ||||
|       PoseLandmarkerResult actualResult = | ||||
|           poseLandmarker.detectForVideo(getImageFromAsset(POSE_IMAGE), /* timestampsMs= */ i); | ||||
|       assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   @Test | ||||
|   public void recognize_failsWithOutOfOrderInputTimestamps() throws Exception { | ||||
|     MPImage image = getImageFromAsset(POSE_IMAGE); | ||||
|     PoseLandmarkerResult expectedResult = getExpectedPoseLandmarkerResult(POSE_LANDMARKS); | ||||
|     PoseLandmarkerOptions options = | ||||
|         PoseLandmarkerOptions.builder() | ||||
|             .setBaseOptions( | ||||
|                 BaseOptions.builder().setModelAssetPath(POSE_LANDMARKER_BUNDLE_ASSET_FILE).build()) | ||||
|             .setRunningMode(RunningMode.LIVE_STREAM) | ||||
|             .setResultListener( | ||||
|                 (actualResult, inputImage) -> { | ||||
|                   assertActualResultApproximatelyEqualsToExpectedResult( | ||||
|                       actualResult, expectedResult); | ||||
|                   assertImageSizeIsExpected(inputImage); | ||||
|                 }) | ||||
|             .build(); | ||||
|     try (PoseLandmarker poseLandmarker = | ||||
|         PoseLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { | ||||
|       poseLandmarker.detectAsync(image, /* timestampsMs= */ 1); | ||||
|       MediaPipeException exception = | ||||
|           assertThrows( | ||||
|               MediaPipeException.class, | ||||
|               () -> poseLandmarker.detectAsync(image, /* timestampsMs= */ 0)); | ||||
|       assertThat(exception) | ||||
|           .hasMessageThat() | ||||
|           .contains("having a smaller timestamp than the processed timestamp"); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   @Test | ||||
|   public void recognize_successWithLiveSteamMode() throws Exception { | ||||
|     MPImage image = getImageFromAsset(POSE_IMAGE); | ||||
|     PoseLandmarkerResult expectedResult = getExpectedPoseLandmarkerResult(POSE_LANDMARKS); | ||||
|     PoseLandmarkerOptions options = | ||||
|         PoseLandmarkerOptions.builder() | ||||
|             .setBaseOptions( | ||||
|                 BaseOptions.builder().setModelAssetPath(POSE_LANDMARKER_BUNDLE_ASSET_FILE).build()) | ||||
|             .setRunningMode(RunningMode.LIVE_STREAM) | ||||
|             .setResultListener( | ||||
|                 (actualResult, inputImage) -> { | ||||
|                   assertActualResultApproximatelyEqualsToExpectedResult( | ||||
|                       actualResult, expectedResult); | ||||
|                   assertImageSizeIsExpected(inputImage); | ||||
|                 }) | ||||
|             .build(); | ||||
|     try (PoseLandmarker poseLandmarker = | ||||
|         PoseLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { | ||||
|       for (int i = 0; i < 3; i++) { | ||||
|         poseLandmarker.detectAsync(image, /* timestampsMs= */ i); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   private static MPImage getImageFromAsset(String filePath) throws Exception { | ||||
|     AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets(); | ||||
|     InputStream istr = assetManager.open(filePath); | ||||
|     return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build(); | ||||
|   } | ||||
| 
 | ||||
|   private static PoseLandmarkerResult getExpectedPoseLandmarkerResult(String filePath) | ||||
|       throws Exception { | ||||
|     AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets(); | ||||
|     InputStream istr = assetManager.open(filePath); | ||||
|     LandmarksDetectionResult landmarksDetectionResultProto = | ||||
|         LandmarksDetectionResult.parser().parseFrom(istr); | ||||
|     return PoseLandmarkerResult.create( | ||||
|         Arrays.asList(landmarksDetectionResultProto.getLandmarks()), | ||||
|         Arrays.asList(landmarksDetectionResultProto.getWorldLandmarks()), | ||||
|         Arrays.asList(), | ||||
|         Optional.empty(), | ||||
|         /* timestampMs= */ 0); | ||||
|   } | ||||
| 
 | ||||
|   private static void assertActualResultApproximatelyEqualsToExpectedResult( | ||||
|       PoseLandmarkerResult actualResult, PoseLandmarkerResult expectedResult) { | ||||
|     // TODO: Add additional tests for MP Tasks Pose Graphs | ||||
|     // Add additional tests for auxiliary, world landmarks and segmentation masks. | ||||
|     // Expects to have the same number of poses detected. | ||||
|     assertThat(actualResult.landmarks()).hasSize(expectedResult.landmarks().size()); | ||||
| 
 | ||||
|     // Actual landmarks match expected landmarks. | ||||
|     assertThat(actualResult.landmarks().get(0)) | ||||
|         .comparingElementsUsing( | ||||
|             Correspondence.from( | ||||
|                 (Correspondence.BinaryPredicate<NormalizedLandmark, NormalizedLandmark>) | ||||
|                     (actual, expected) -> { | ||||
|                       return Correspondence.tolerance(LANDMARKS_ERROR_TOLERANCE) | ||||
|                               .compare(actual.x(), expected.x()) | ||||
|                           && Correspondence.tolerance(LANDMARKS_ERROR_TOLERANCE) | ||||
|                               .compare(actual.y(), expected.y()); | ||||
|                     }, | ||||
|                 "landmarks approximately equal to")) | ||||
|         .containsExactlyElementsIn(expectedResult.landmarks().get(0)); | ||||
|   } | ||||
| 
 | ||||
|   private static void assertImageSizeIsExpected(MPImage inputImage) { | ||||
|     assertThat(inputImage).isNotNull(); | ||||
|     assertThat(inputImage.getWidth()).isEqualTo(IMAGE_WIDTH); | ||||
|     assertThat(inputImage.getHeight()).isEqualTo(IMAGE_HEIGHT); | ||||
|   } | ||||
| } | ||||
							
								
								
									
										6
									
								
								third_party/external_files.bzl
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										6
									
								
								third_party/external_files.bzl
									
									
									
									
										vendored
									
									
								
							|  | @ -307,7 +307,7 @@ def external_files(): | |||
|     http_file( | ||||
|         name = "com_google_mediapipe_expected_pose_landmarks_prototxt", | ||||
|         sha256 = "eed8dfa169b0abee60cde01496599b0bc75d91a82594a1bdf59be2f76f45d7f5", | ||||
|         urls = ["https://storage.googleapis.com/mediapipe-assets/expected_pose_landmarks.prototxt?generation=1681244232522990"], | ||||
|         urls = ["https://storage.googleapis.com/mediapipe-assets/expected_pose_landmarks.prototxt?generation=16812442325229901681244235071100"], | ||||
|     ) | ||||
| 
 | ||||
|     http_file( | ||||
|  | @ -996,8 +996,8 @@ def external_files(): | |||
| 
 | ||||
|     http_file( | ||||
|         name = "com_google_mediapipe_pose_landmarks_pbtxt", | ||||
|         sha256 = "305a71fbff83e270a5dbd81fb7cf65203f56e0b1caba8ea42edc16c6e8a2ba18", | ||||
|         urls = ["https://storage.googleapis.com/mediapipe-assets/pose_landmarks.pbtxt?generation=1681244254964356"], | ||||
|         sha256 = "69c79cdf3964d7819776eab1172e47e70684139d72a6d7edcbdd62dbb2ca5527", | ||||
|         urls = ["https://storage.googleapis.com/mediapipe-assets/pose_landmarks.pbtxt?generation=1681425322701589"], | ||||
|     ) | ||||
| 
 | ||||
|     http_file( | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user