From 62bafd39bb5ff53ffecbfff2de85b00d875e86d7 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 28 Nov 2023 14:57:41 -0800 Subject: [PATCH] HolisticLandmarker Java API PiperOrigin-RevId: 586113048 --- .../mediapipe/tasks/mediapipe_tasks_aar.bzl | 7 +- .../com/google/mediapipe/tasks/vision/BUILD | 12 + .../holisticlandmarker/AndroidManifest.xml | 8 + .../HolisticLandmarker.java | 668 ++++++++++++++++++ .../holisticlandmarker/AndroidManifest.xml | 24 + .../tasks/vision/holisticlandmarker/BUILD | 19 + .../HolisticLandmarkerTest.java | 512 ++++++++++++++ mediapipe/tasks/testdata/vision/BUILD | 1 + 8 files changed, 1248 insertions(+), 3 deletions(-) create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/holisticlandmarker/AndroidManifest.xml create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/holisticlandmarker/HolisticLandmarker.java create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/holisticlandmarker/AndroidManifest.xml create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/holisticlandmarker/BUILD create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/holisticlandmarker/HolisticLandmarkerTest.java diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl index 916323372..e63695e31 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl @@ -47,13 +47,14 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [ "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_embedder_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/holistic_landmarker/proto:holistic_landmarker_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite", - "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite", - "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite", - "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_java_proto_lite", "//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarker_graph_options_java_proto_lite", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index fc56bfa27..2d5ef7a9c 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -67,6 +67,7 @@ cc_binary( "//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph", "//mediapipe/tasks/cc/vision/face_stylizer:face_stylizer_graph", "//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph", + "//mediapipe/tasks/cc/vision/holistic_landmarker:holistic_landmarker_graph", "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph", "//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph", "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph", @@ -429,6 +430,7 @@ filegroup( android_library( name = "holisticlandmarker", srcs = [ + "holisticlandmarker/HolisticLandmarker.java", "holisticlandmarker/HolisticLandmarkerResult.java", ], javacopts = [ @@ -439,10 +441,20 @@ android_library( ":core", "//mediapipe/framework/formats:classification_java_proto_lite", "//mediapipe/framework/formats:landmark_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/framework/image", + "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/holistic_landmarker/proto:holistic_landmarker_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:normalized_landmark", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "//third_party:any_java_proto", "//third_party:autovalue", "@maven//:com_google_guava_guava", ], diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/holisticlandmarker/AndroidManifest.xml b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/holisticlandmarker/AndroidManifest.xml new file mode 100644 index 000000000..a90c388f4 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/holisticlandmarker/AndroidManifest.xml @@ -0,0 +1,8 @@ + + + + + + diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/holisticlandmarker/HolisticLandmarker.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/holisticlandmarker/HolisticLandmarker.java new file mode 100644 index 000000000..e80da4fca --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/holisticlandmarker/HolisticLandmarker.java @@ -0,0 +1,668 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.holisticlandmarker; + +import android.content.Context; +import android.os.ParcelFileDescriptor; +import com.google.auto.value.AutoValue; +import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList; +import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList; +import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList; +import com.google.mediapipe.framework.AndroidPacketGetter; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.Packet; +import com.google.mediapipe.framework.PacketGetter; +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.ByteBufferImageBuilder; +import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.core.ErrorListener; +import com.google.mediapipe.tasks.core.OutputHandler; +import com.google.mediapipe.tasks.core.OutputHandler.ResultListener; +import com.google.mediapipe.tasks.core.TaskInfo; +import com.google.mediapipe.tasks.core.TaskOptions; +import com.google.mediapipe.tasks.core.TaskRunner; +import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; +import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; +import com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.tasks.vision.facedetector.proto.FaceDetectorGraphOptionsProto.FaceDetectorGraphOptions; +import com.google.mediapipe.tasks.vision.facelandmarker.proto.FaceLandmarksDetectorGraphOptionsProto.FaceLandmarksDetectorGraphOptions; +import com.google.mediapipe.tasks.vision.handlandmarker.proto.HandLandmarksDetectorGraphOptionsProto.HandLandmarksDetectorGraphOptions; +import com.google.mediapipe.tasks.vision.holisticlandmarker.proto.HolisticLandmarkerGraphOptionsProto.HolisticLandmarkerGraphOptions; +import com.google.mediapipe.tasks.vision.posedetector.proto.PoseDetectorGraphOptionsProto.PoseDetectorGraphOptions; +import com.google.mediapipe.tasks.vision.poselandmarker.proto.PoseLandmarksDetectorGraphOptionsProto.PoseLandmarksDetectorGraphOptions; +import com.google.protobuf.Any; +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +/** + * Performs holistic landmarks detection on images. + * + *

This API expects a pre-trained holistic landmarks model asset bundle. + * + *

+ */ +public final class HolisticLandmarker extends BaseVisionTaskApi { + private static final String TAG = HolisticLandmarker.class.getSimpleName(); + + private static final String IMAGE_IN_STREAM_NAME = "image_in"; + private static final String POSE_LANDMARKS_STREAM = "pose_landmarks"; + private static final String POSE_WORLD_LANDMARKS_STREAM = "pose_world_landmarks"; + private static final String POSE_SEGMENTATION_MASK_STREAM = "pose_segmentation_mask"; + private static final String FACE_LANDMARKS_STREAM = "face_landmarks"; + private static final String FACE_BLENDSHAPES_STREAM = "extra_blendshapes"; + private static final String LEFT_HAND_LANDMARKS_STREAM = "left_hand_landmarks"; + private static final String LEFT_HAND_WORLD_LANDMARKS_STREAM = "left_hand_world_landmarks"; + private static final String RIGHT_HAND_LANDMARKS_STREAM = "right_hand_landmarks"; + private static final String RIGHT_HAND_WORLD_LANDMARKS_STREAM = "right_hand_world_landmarks"; + private static final String IMAGE_OUT_STREAM_NAME = "image_out"; + + private static final int FACE_LANDMARKS_OUT_STREAM_INDEX = 0; + private static final int POSE_LANDMARKS_OUT_STREAM_INDEX = 1; + private static final int POSE_WORLD_LANDMARKS_OUT_STREAM_INDEX = 2; + private static final int LEFT_HAND_LANDMARKS_OUT_STREAM_INDEX = 3; + private static final int LEFT_HAND_WORLD_LANDMARKS_OUT_STREAM_INDEX = 4; + private static final int RIGHT_HAND_LANDMARKS_OUT_STREAM_INDEX = 5; + private static final int RIGHT_HAND_WORLD_LANDMARKS_OUT_STREAM_INDEX = 6; + private static final int IMAGE_OUT_STREAM_INDEX = 7; + + private static final float DEFAULT_PRESENCE_THRESHOLD = 0.5f; + private static final float DEFAULT_SUPPRESION_THRESHOLD = 0.3f; + private static final boolean DEFAULT_OUTPUT_FACE_BLENDSHAPES = false; + private static final boolean DEFAULT_OUTPUT_SEGMENTATION_MASKS = false; + + private static final String TASK_GRAPH_NAME = + "mediapipe.tasks.vision.holistic_landmarker.HolisticLandmarkerGraph"; + + @SuppressWarnings("ConstantCaseForConstants") + private static final List INPUT_STREAMS = + Collections.unmodifiableList(Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME)); + + static { + System.loadLibrary("mediapipe_tasks_vision_jni"); + } + + /** + * Creates a {@link HolisticLandmarker} instance from a model asset bundle path and the default + * {@link HolisticLandmarkerOptions}. + * + * @param context an Android {@link Context}. + * @param modelAssetPath path to the holistic landmarks model with metadata in the assets. + * @throws MediaPipeException if there is an error during {@link HolisticLandmarker} creation. + */ + public static HolisticLandmarker createFromFile(Context context, String modelAssetPath) { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelAssetPath).build(); + return createFromOptions( + context, HolisticLandmarkerOptions.builder().setBaseOptions(baseOptions).build()); + } + + /** + * Creates a {@link HolisticLandmarker} instance from a model asset bundle file and the default + * {@link HolisticLandmarkerOptions}. + * + * @param context an Android {@link Context}. + * @param modelAssetFile the holistic landmarks model {@link File} instance. + * @throws IOException if an I/O error occurs when opening the tflite model file. + * @throws MediaPipeException if there is an error during {@link HolisticLandmarker} creation. + */ + public static HolisticLandmarker createFromFile(Context context, File modelAssetFile) + throws IOException { + try (ParcelFileDescriptor descriptor = + ParcelFileDescriptor.open(modelAssetFile, ParcelFileDescriptor.MODE_READ_ONLY)) { + BaseOptions baseOptions = + BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build(); + return createFromOptions( + context, HolisticLandmarkerOptions.builder().setBaseOptions(baseOptions).build()); + } + } + + /** + * Creates a {@link HolisticLandmarker} instance from a model asset bundle buffer and the default + * {@link HolisticLandmarkerOptions}. + * + * @param context an Android {@link Context}. + * @param modelAssetBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the + * detection model. + * @throws MediaPipeException if there is an error during {@link HolisticLandmarker} creation. + */ + public static HolisticLandmarker createFromBuffer( + Context context, final ByteBuffer modelAssetBuffer) { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelAssetBuffer).build(); + return createFromOptions( + context, HolisticLandmarkerOptions.builder().setBaseOptions(baseOptions).build()); + } + + /** + * Creates a {@link HolisticLandmarker} instance from a {@link HolisticLandmarkerOptions}. + * + * @param context an Android {@link Context}. + * @param landmarkerOptions a {@link HolisticLandmarkerOptions} instance. + * @throws MediaPipeException if there is an error during {@link HolisticLandmarker} creation. + */ + public static HolisticLandmarker createFromOptions( + Context context, HolisticLandmarkerOptions landmarkerOptions) { + List outputStreams = new ArrayList<>(); + outputStreams.add("FACE_LANDMARKS:" + FACE_LANDMARKS_STREAM); + outputStreams.add("POSE_LANDMARKS:" + POSE_LANDMARKS_STREAM); + outputStreams.add("POSE_WORLD_LANDMARKS:" + POSE_WORLD_LANDMARKS_STREAM); + outputStreams.add("LEFT_HAND_LANDMARKS:" + LEFT_HAND_LANDMARKS_STREAM); + outputStreams.add("LEFT_HAND_WORLD_LANDMARKS:" + LEFT_HAND_WORLD_LANDMARKS_STREAM); + outputStreams.add("RIGHT_HAND_LANDMARKS:" + RIGHT_HAND_LANDMARKS_STREAM); + outputStreams.add("RIGHT_HAND_WORLD_LANDMARKS:" + RIGHT_HAND_WORLD_LANDMARKS_STREAM); + outputStreams.add("IMAGE:" + IMAGE_OUT_STREAM_NAME); + + int[] faceBlendshapesOutStreamIndex = new int[] {-1}; + if (landmarkerOptions.outputFaceBlendshapes()) { + outputStreams.add("FACE_BLENDSHAPES:" + FACE_BLENDSHAPES_STREAM); + faceBlendshapesOutStreamIndex[0] = outputStreams.size() - 1; + } + + int[] poseSegmentationMasksOutStreamIndex = new int[] {-1}; + if (landmarkerOptions.outputPoseSegmentationMasks()) { + outputStreams.add("POSE_SEGMENTATION_MASK:" + POSE_SEGMENTATION_MASK_STREAM); + poseSegmentationMasksOutStreamIndex[0] = outputStreams.size() - 1; + } + + OutputHandler handler = new OutputHandler<>(); + handler.setOutputPacketConverter( + new OutputHandler.OutputPacketConverter() { + @Override + public HolisticLandmarkerResult convertToTaskResult(List packets) { + // If there are no detected landmarks, just returns empty lists. + if (packets.get(FACE_LANDMARKS_OUT_STREAM_INDEX).isEmpty()) { + return HolisticLandmarkerResult.createEmpty( + BaseVisionTaskApi.generateResultTimestampMs( + landmarkerOptions.runningMode(), + packets.get(FACE_LANDMARKS_OUT_STREAM_INDEX))); + } + + NormalizedLandmarkList faceLandmarkProtos = + PacketGetter.getProto( + packets.get(FACE_LANDMARKS_OUT_STREAM_INDEX), NormalizedLandmarkList.parser()); + Optional faceBlendshapeProtos = + landmarkerOptions.outputFaceBlendshapes() + ? Optional.of( + PacketGetter.getProto( + packets.get(faceBlendshapesOutStreamIndex[0]), + ClassificationList.parser())) + : Optional.empty(); + NormalizedLandmarkList poseLandmarkProtos = + PacketGetter.getProto( + packets.get(POSE_LANDMARKS_OUT_STREAM_INDEX), NormalizedLandmarkList.parser()); + LandmarkList poseWorldLandmarkProtos = + PacketGetter.getProto( + packets.get(POSE_WORLD_LANDMARKS_OUT_STREAM_INDEX), LandmarkList.parser()); + Optional segmentationMask = + landmarkerOptions.outputPoseSegmentationMasks() + ? Optional.of( + getSegmentationMask(packets, poseSegmentationMasksOutStreamIndex[0])) + : Optional.empty(); + NormalizedLandmarkList leftHandLandmarkProtos = + PacketGetter.getProto( + packets.get(LEFT_HAND_LANDMARKS_OUT_STREAM_INDEX), + NormalizedLandmarkList.parser()); + LandmarkList leftHandWorldLandmarkProtos = + PacketGetter.getProto( + packets.get(LEFT_HAND_WORLD_LANDMARKS_OUT_STREAM_INDEX), LandmarkList.parser()); + NormalizedLandmarkList rightHandLandmarkProtos = + PacketGetter.getProto( + packets.get(RIGHT_HAND_LANDMARKS_OUT_STREAM_INDEX), + NormalizedLandmarkList.parser()); + LandmarkList rightHandWorldLandmarkProtos = + PacketGetter.getProto( + packets.get(RIGHT_HAND_WORLD_LANDMARKS_OUT_STREAM_INDEX), + LandmarkList.parser()); + + return HolisticLandmarkerResult.create( + faceLandmarkProtos, + faceBlendshapeProtos, + poseLandmarkProtos, + poseWorldLandmarkProtos, + segmentationMask, + leftHandLandmarkProtos, + leftHandWorldLandmarkProtos, + rightHandLandmarkProtos, + rightHandWorldLandmarkProtos, + BaseVisionTaskApi.generateResultTimestampMs( + landmarkerOptions.runningMode(), packets.get(FACE_LANDMARKS_OUT_STREAM_INDEX))); + } + + @Override + public MPImage convertToTaskInput(List packets) { + return new BitmapImageBuilder( + AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX))) + .build(); + } + }); + landmarkerOptions.resultListener().ifPresent(handler::setResultListener); + landmarkerOptions.errorListener().ifPresent(handler::setErrorListener); + TaskRunner runner = + TaskRunner.create( + context, + TaskInfo.builder() + .setTaskName(HolisticLandmarker.class.getSimpleName()) + .setTaskRunningModeName(landmarkerOptions.runningMode().name()) + .setTaskGraphName(TASK_GRAPH_NAME) + .setInputStreams(INPUT_STREAMS) + .setOutputStreams(outputStreams) + .setTaskOptions(landmarkerOptions) + .setEnableFlowLimiting(landmarkerOptions.runningMode() == RunningMode.LIVE_STREAM) + .build(), + handler); + return new HolisticLandmarker(runner, landmarkerOptions.runningMode()); + } + + /** + * Constructor to initialize an {@link HolisticLandmarker} from a {@link TaskRunner} and a {@link + * RunningMode}. + * + * @param taskRunner a {@link TaskRunner}. + * @param runningMode a mediapipe vision task {@link RunningMode}. + */ + private HolisticLandmarker(TaskRunner taskRunner, RunningMode runningMode) { + super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, /* normRectStreamName= */ ""); + } + + /** + * Performs holistic landmarks detection on the provided single image with default image + * processing options, i.e. without any rotation applied. Only use this method when the {@link + * HolisticLandmarker} is created with {@link RunningMode.IMAGE}. + * + *

{@link HolisticLandmarker} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @throws MediaPipeException if there is an internal error. + */ + public HolisticLandmarkerResult detect(MPImage image) { + return detect(image, ImageProcessingOptions.builder().build()); + } + + /** + * Performs holistic landmarks detection on the provided single image. Only use this method when + * the {@link HolisticLandmarker} is created with {@link RunningMode.IMAGE}. + * + *

{@link HolisticLandmarker} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. + * @throws MediaPipeException if there is an internal error. + */ + public HolisticLandmarkerResult detect( + MPImage image, ImageProcessingOptions imageProcessingOptions) { + validateImageProcessingOptions(imageProcessingOptions); + return (HolisticLandmarkerResult) processImageData(image, imageProcessingOptions); + } + + /** + * Performs holistic landmarks detection on the provided video frame with default image processing + * options, i.e. without any rotation applied. Only use this method when the {@link + * HolisticLandmarker} is created with {@link RunningMode.VIDEO}. + * + *

It's required to provide the video frame"s timestamp (in milliseconds). The input timestamps + * must be monotonically increasing. + * + *

{@link HolisticLandmarker} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public HolisticLandmarkerResult detectForVideo(MPImage image, long timestampMs) { + return detectForVideo(image, ImageProcessingOptions.builder().build(), timestampMs); + } + + /** + * Performs holistic landmarks detection on the provided video frame. Only use this method when + * the {@link HolisticLandmarker} is created with {@link RunningMode.VIDEO}. + * + *

It's required to provide the video frame"s timestamp (in milliseconds). The input timestamps + * must be monotonically increasing. + * + *

{@link HolisticLandmarker} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @param timestampMs the input timestamp (in milliseconds). + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. + * @throws MediaPipeException if there is an internal error. + */ + public HolisticLandmarkerResult detectForVideo( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + validateImageProcessingOptions(imageProcessingOptions); + return (HolisticLandmarkerResult) processVideoData(image, imageProcessingOptions, timestampMs); + } + + /** + * Sends live image data to perform holistic landmarks detection with default image processing + * options, i.e. without any rotation applied, and the results will be available via the {@link + * ResultListener} provided in the {@link HolisticLandmarkerOptions}. Only use this method when + * the {@link HolisticLandmarker } is created with {@link RunningMode.LIVE_STREAM}. + * + *

It's required to provide a timestamp (in milliseconds) to indicate when the input image is + * sent to the holistic landmarker. The input timestamps must be monotonically increasing. + * + *

{@link HolisticLandmarker} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public void detectAsync(MPImage image, long timestampMs) { + detectAsync(image, ImageProcessingOptions.builder().build(), timestampMs); + } + + /** + * Sends live image data to perform holistic landmarks detection, and the results will be + * available via the {@link ResultListener} provided in the {@link HolisticLandmarkerOptions}. + * Only use this method when the {@link HolisticLandmarker} is created with {@link + * RunningMode.LIVE_STREAM}. + * + *

It's required to provide a timestamp (in milliseconds) to indicate when the input image is + * sent to the holistic landmarker. The input timestamps must be monotonically increasing. + * + *

{@link HolisticLandmarker} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @param timestampMs the input timestamp (in milliseconds). + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. + * @throws MediaPipeException if there is an internal error. + */ + public void detectAsync( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + validateImageProcessingOptions(imageProcessingOptions); + sendLiveStreamData(image, imageProcessingOptions, timestampMs); + } + + /** Options for setting up an {@link HolisticLandmarker}. */ + @AutoValue + public abstract static class HolisticLandmarkerOptions extends TaskOptions { + + /** Builder for {@link HolisticLandmarkerOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + /** Sets the base options for the holistic landmarker task. */ + public abstract Builder setBaseOptions(BaseOptions value); + + /** + * Sets the running mode for the holistic landmarker task. Defaults to the image mode. + * Holistic landmarker has three modes: + * + *
    + *
  • IMAGE: The mode for detecting holistic landmarks on single image inputs. + *
  • VIDEO: The mode for detecting holistic landmarks on the decoded frames of a video. + *
  • LIVE_STREAM: The mode for for detecting holistic landmarks on a live stream of input + * data, such as from camera. In this mode, {@code setResultListener} must be called to + * set up a listener to receive the detection results asynchronously. + *
+ */ + public abstract Builder setRunningMode(RunningMode value); + + /** + * Sets minimum confidence score for the face detection to be considered successful. Defaults + * to 0.5. + */ + public abstract Builder setMinFaceDetectionConfidence(Float value); + + /** + * The minimum threshold for the face suppression score in the face detection. Defaults to + * 0.3. + */ + public abstract Builder setMinFaceSuppressionThreshold(Float value); + + /** + * Sets minimum confidence score for the face landmark detection to be considered successful. + * Defaults to 0.5. + */ + public abstract Builder setMinFaceLandmarksConfidence(Float value); + + /** + * The minimum confidence score for the pose detection to be considered successful. Defaults + * to 0.5. + */ + public abstract Builder setMinPoseDetectionConfidence(Float value); + + /** + * The minimum threshold for the pose suppression score in the pose detection. Defaults to + * 0.3. + */ + public abstract Builder setMinPoseSuppressionThreshold(Float value); + + /** + * The minimum confidence score for the pose landmarks detection to be considered successful. + * Defaults to 0.5. + */ + public abstract Builder setMinPoseLandmarksConfidence(Float value); + + /** + * The minimum confidence score for the hand landmark detection to be considered successful. + * Defaults to 0.5. + */ + public abstract Builder setMinHandLandmarksConfidence(Float value); + + /** Whether to output segmentation masks. Defaults to false. */ + public abstract Builder setOutputPoseSegmentationMasks(Boolean value); + + /** Whether to output face blendshapes. Defaults to false. */ + public abstract Builder setOutputFaceBlendshapes(Boolean value); + + /** + * Sets the result listener to receive the detection results asynchronously when the holistic + * landmarker is in the live stream mode. + */ + public abstract Builder setResultListener( + ResultListener value); + + /** Sets an optional error listener. */ + public abstract Builder setErrorListener(ErrorListener value); + + abstract HolisticLandmarkerOptions autoBuild(); + + /** + * Validates and builds the {@link HolisticLandmarkerOptions} instance. + * + * @throws IllegalArgumentException if the result listener and the running mode are not + * properly configured. The result listener should only be set when the holistic + * landmarker is in the live stream mode. + */ + public final HolisticLandmarkerOptions build() { + HolisticLandmarkerOptions options = autoBuild(); + if (options.runningMode() == RunningMode.LIVE_STREAM) { + if (!options.resultListener().isPresent()) { + throw new IllegalArgumentException( + "The holistic landmarker is in the live stream mode, a user-defined result listener" + + " must be provided in HolisticLandmarkerOptions."); + } + } else if (options.resultListener().isPresent()) { + throw new IllegalArgumentException( + "The holistic landmarker is in the image or the video mode, a user-defined result" + + " listener shouldn't be provided in HolisticLandmarkerOptions."); + } + return options; + } + } + + abstract BaseOptions baseOptions(); + + abstract RunningMode runningMode(); + + abstract Optional minFaceDetectionConfidence(); + + abstract Optional minFaceSuppressionThreshold(); + + abstract Optional minFaceLandmarksConfidence(); + + abstract Optional minPoseDetectionConfidence(); + + abstract Optional minPoseSuppressionThreshold(); + + abstract Optional minPoseLandmarksConfidence(); + + abstract Optional minHandLandmarksConfidence(); + + abstract Boolean outputFaceBlendshapes(); + + abstract Boolean outputPoseSegmentationMasks(); + + abstract Optional> resultListener(); + + abstract Optional errorListener(); + + public static Builder builder() { + return new AutoValue_HolisticLandmarker_HolisticLandmarkerOptions.Builder() + .setRunningMode(RunningMode.IMAGE) + .setMinFaceDetectionConfidence(DEFAULT_PRESENCE_THRESHOLD) + .setMinFaceSuppressionThreshold(DEFAULT_SUPPRESION_THRESHOLD) + .setMinFaceLandmarksConfidence(DEFAULT_PRESENCE_THRESHOLD) + .setMinPoseDetectionConfidence(DEFAULT_PRESENCE_THRESHOLD) + .setMinPoseSuppressionThreshold(DEFAULT_SUPPRESION_THRESHOLD) + .setMinPoseLandmarksConfidence(DEFAULT_PRESENCE_THRESHOLD) + .setMinHandLandmarksConfidence(DEFAULT_PRESENCE_THRESHOLD) + .setOutputFaceBlendshapes(DEFAULT_OUTPUT_FACE_BLENDSHAPES) + .setOutputPoseSegmentationMasks(DEFAULT_OUTPUT_SEGMENTATION_MASKS); + } + + /** Converts a {@link HolisticLandmarkerOptions} to a {@link Any} protobuf message. */ + @Override + public Any convertToAnyProto() { + HolisticLandmarkerGraphOptions.Builder holisticLandmarkerGraphOptions = + HolisticLandmarkerGraphOptions.newBuilder() + .setBaseOptions( + BaseOptionsProto.BaseOptions.newBuilder() + .setUseStreamMode(runningMode() != RunningMode.IMAGE) + .mergeFrom(convertBaseOptionsToProto(baseOptions())) + .build()); + + HandLandmarksDetectorGraphOptions.Builder handLandmarksDetectorGraphOptions = + HandLandmarksDetectorGraphOptions.newBuilder(); + FaceDetectorGraphOptions.Builder faceDetectorGraphOptions = + FaceDetectorGraphOptions.newBuilder(); + FaceLandmarksDetectorGraphOptions.Builder faceLandmarksDetectorGraphOptions = + FaceLandmarksDetectorGraphOptions.newBuilder(); + PoseDetectorGraphOptions.Builder poseDetectorGraphOptions = + PoseDetectorGraphOptions.newBuilder(); + PoseLandmarksDetectorGraphOptions.Builder poseLandmarkerGraphOptions = + PoseLandmarksDetectorGraphOptions.newBuilder(); + + // Configure hand detector options. + minHandLandmarksConfidence() + .ifPresent(handLandmarksDetectorGraphOptions::setMinDetectionConfidence); + + // Configure pose detector options. + minPoseDetectionConfidence().ifPresent(poseDetectorGraphOptions::setMinDetectionConfidence); + minPoseSuppressionThreshold().ifPresent(poseDetectorGraphOptions::setMinSuppressionThreshold); + minPoseLandmarksConfidence().ifPresent(poseLandmarkerGraphOptions::setMinDetectionConfidence); + + // Configure face detector options. + minFaceDetectionConfidence().ifPresent(faceDetectorGraphOptions::setMinDetectionConfidence); + minFaceSuppressionThreshold().ifPresent(faceDetectorGraphOptions::setMinSuppressionThreshold); + minFaceLandmarksConfidence() + .ifPresent(faceLandmarksDetectorGraphOptions::setMinDetectionConfidence); + + holisticLandmarkerGraphOptions + .setHandLandmarksDetectorGraphOptions(handLandmarksDetectorGraphOptions.build()) + .setFaceDetectorGraphOptions(faceDetectorGraphOptions.build()) + .setFaceLandmarksDetectorGraphOptions(faceLandmarksDetectorGraphOptions.build()) + .setPoseDetectorGraphOptions(poseDetectorGraphOptions.build()) + .setPoseLandmarksDetectorGraphOptions(poseLandmarkerGraphOptions.build()); + + return Any.newBuilder() + .setTypeUrl( + "type.googleapis.com/mediapipe.tasks.vision.holistic_landmarker.proto.HolisticLandmarkerGraphOptions") + .setValue(holisticLandmarkerGraphOptions.build().toByteString()) + .build(); + } + } + + /** + * Validates that the provided {@link ImageProcessingOptions} doesn"t contain a + * region-of-interest. + */ + private static void validateImageProcessingOptions( + ImageProcessingOptions imageProcessingOptions) { + if (imageProcessingOptions.regionOfInterest().isPresent()) { + throw new IllegalArgumentException("HolisticLandmarker doesn't support region-of-interest."); + } + } + + private static MPImage getSegmentationMask(List packets, int packetIndex) { + int width = PacketGetter.getImageWidth(packets.get(packetIndex)); + int height = PacketGetter.getImageHeight(packets.get(packetIndex)); + ByteBuffer buffer = ByteBuffer.allocateDirect(width * height * 4); + + if (!PacketGetter.getImageData(packets.get(packetIndex), buffer)) { + throw new MediaPipeException( + MediaPipeException.StatusCode.INTERNAL.ordinal(), + "There was an error getting the sefmentation mask."); + } + + ByteBufferImageBuilder builder = + new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_VEC32F1); + return builder.build(); + } +} diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/holisticlandmarker/AndroidManifest.xml b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/holisticlandmarker/AndroidManifest.xml new file mode 100644 index 000000000..22b19b702 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/holisticlandmarker/AndroidManifest.xml @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/holisticlandmarker/BUILD b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/holisticlandmarker/BUILD new file mode 100644 index 000000000..287602c85 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/holisticlandmarker/BUILD @@ -0,0 +1,19 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +# TODO: Enable this in OSS diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/holisticlandmarker/HolisticLandmarkerTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/holisticlandmarker/HolisticLandmarkerTest.java new file mode 100644 index 000000000..f8c87c798 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/holisticlandmarker/HolisticLandmarkerTest.java @@ -0,0 +1,512 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.holisticlandmarker; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import android.content.res.AssetManager; +import android.graphics.BitmapFactory; +import android.graphics.RectF; +import androidx.test.core.app.ApplicationProvider; +import androidx.test.ext.junit.runners.AndroidJUnit4; +import com.google.common.truth.Correspondence; +import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList; +import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.ByteBufferImageBuilder; +import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.tasks.components.containers.Category; +import com.google.mediapipe.tasks.components.containers.NormalizedLandmark; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; +import com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.tasks.vision.holisticlandmarker.HolisticLandmarker.HolisticLandmarkerOptions; +import com.google.mediapipe.tasks.vision.holisticlandmarker.HolisticResultProto.HolisticResult; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.util.Optional; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.junit.runners.Suite.SuiteClasses; + +/** Test for {@link HolisticLandmarker}. */ +@RunWith(Suite.class) +@SuiteClasses({HolisticLandmarkerTest.General.class, HolisticLandmarkerTest.RunningModeTest.class}) +public class HolisticLandmarkerTest { + private static final String HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE = "holistic_landmarker.task"; + private static final String POSE_IMAGE = "male_full_height_hands.jpg"; + private static final String CAT_IMAGE = "cat.jpg"; + private static final String HOLISTIC_RESULT = "male_full_height_hands_result_cpu.pb"; + private static final String TAG = "Holistic Landmarker Test"; + private static final float FACE_LANDMARKS_ERROR_TOLERANCE = 0.03f; + private static final float FACE_BLENDSHAPES_ERROR_TOLERANCE = 0.13f; + private static final MPImage PLACEHOLDER_MASK = + new ByteBufferImageBuilder( + ByteBuffer.allocate(0), /* widht= */ 0, /* height= */ 0, MPImage.IMAGE_FORMAT_VEC32F1) + .build(); + private static final int IMAGE_WIDTH = 638; + private static final int IMAGE_HEIGHT = 1000; + + private static final Correspondence VALIDATE_LANDMARRKS = + Correspondence.from( + (Correspondence.BinaryPredicate) + (actual, expected) -> { + return Correspondence.tolerance(FACE_LANDMARKS_ERROR_TOLERANCE) + .compare(actual.x(), expected.x()) + && Correspondence.tolerance(FACE_LANDMARKS_ERROR_TOLERANCE) + .compare(actual.y(), expected.y()); + }, + "landmarks approximately equal to"); + + private static final Correspondence VALIDATE_BLENDSHAPES = + Correspondence.from( + (Correspondence.BinaryPredicate) + (actual, expected) -> + Correspondence.tolerance(FACE_BLENDSHAPES_ERROR_TOLERANCE) + .compare(actual.score(), expected.score()) + && actual.index() == expected.index() + && actual.categoryName().equals(expected.categoryName()), + "face blendshapes approximately equal to"); + + @RunWith(AndroidJUnit4.class) + public static final class General extends HolisticLandmarkerTest { + + @Test + public void detect_successWithValidModels() throws Exception { + HolisticLandmarkerOptions options = + HolisticLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .build(); + HolisticLandmarker holisticLandmarker = + HolisticLandmarker.createFromOptions( + ApplicationProvider.getApplicationContext(), options); + HolisticLandmarkerResult actualResult = + holisticLandmarker.detect(getImageFromAsset(POSE_IMAGE)); + HolisticLandmarkerResult expectedResult = + getExpectedHolisticLandmarkerResult( + HOLISTIC_RESULT, /* hasFaceBlendshapes= */ false, /* hasSegmentationMask= */ false); + assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult); + } + + @Test + public void detect_successWithBlendshapes() throws Exception { + HolisticLandmarkerOptions options = + HolisticLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .setOutputFaceBlendshapes(true) + .build(); + HolisticLandmarker holisticLandmarker = + HolisticLandmarker.createFromOptions( + ApplicationProvider.getApplicationContext(), options); + HolisticLandmarkerResult actualResult = + holisticLandmarker.detect(getImageFromAsset(POSE_IMAGE)); + HolisticLandmarkerResult expectedResult = + getExpectedHolisticLandmarkerResult( + HOLISTIC_RESULT, /* hasFaceBlendshapes= */ true, /* hasSegmentationMask= */ false); + assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult); + } + + @Test + public void detect_successWithSegmentationMasks() throws Exception { + HolisticLandmarkerOptions options = + HolisticLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .setOutputPoseSegmentationMasks(true) + .build(); + HolisticLandmarker holisticLandmarker = + HolisticLandmarker.createFromOptions( + ApplicationProvider.getApplicationContext(), options); + HolisticLandmarkerResult actualResult = + holisticLandmarker.detect(getImageFromAsset(POSE_IMAGE)); + HolisticLandmarkerResult expectedResult = + getExpectedHolisticLandmarkerResult( + HOLISTIC_RESULT, /* hasFaceBlendshapes= */ false, /* hasSegmentationMask= */ true); + assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult); + } + + @Test + public void detect_successWithEmptyResult() throws Exception { + HolisticLandmarkerOptions options = + HolisticLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .build(); + HolisticLandmarker holisticLandmarker = + HolisticLandmarker.createFromOptions( + ApplicationProvider.getApplicationContext(), options); + HolisticLandmarkerResult actualResult = + holisticLandmarker.detect(getImageFromAsset(CAT_IMAGE)); + assertThat(actualResult.faceLandmarks()).isEmpty(); + } + + @Test + public void detect_failsWithRegionOfInterest() throws Exception { + HolisticLandmarkerOptions options = + HolisticLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .build(); + HolisticLandmarker holisticLandmarker = + HolisticLandmarker.createFromOptions( + ApplicationProvider.getApplicationContext(), options); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRegionOfInterest(new RectF(0, 0, 1, 1)).build(); + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + holisticLandmarker.detect(getImageFromAsset(POSE_IMAGE), imageProcessingOptions)); + assertThat(exception) + .hasMessageThat() + .contains("HolisticLandmarker doesn't support region-of-interest"); + } + } + + @RunWith(AndroidJUnit4.class) + public static final class RunningModeTest extends HolisticLandmarkerTest { + private void assertCreationFailsWithResultListenerInNonLiveStreamMode(RunningMode runningMode) + throws Exception { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + HolisticLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .setRunningMode(runningMode) + .setResultListener((HolisticLandmarkerResult, inputImage) -> {}) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("a user-defined result listener shouldn't be provided"); + } + + @Test + public void create_failsWithIllegalResultListenerInVideoMode() throws Exception { + assertCreationFailsWithResultListenerInNonLiveStreamMode(RunningMode.VIDEO); + } + + @Test + public void create_failsWithIllegalResultListenerInImageMode() throws Exception { + assertCreationFailsWithResultListenerInNonLiveStreamMode(RunningMode.IMAGE); + } + + @Test + public void create_failsWithMissingResultListenerInLiveSteamMode() throws Exception { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + HolisticLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("a user-defined result listener must be provided"); + } + + @Test + public void detect_failsWithCallingWrongApiInImageMode() throws Exception { + HolisticLandmarkerOptions options = + HolisticLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .setRunningMode(RunningMode.IMAGE) + .build(); + + HolisticLandmarker holisticLandmarker = + HolisticLandmarker.createFromOptions( + ApplicationProvider.getApplicationContext(), options); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> + holisticLandmarker.detectForVideo( + getImageFromAsset(POSE_IMAGE), /* timestampsMs= */ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> + holisticLandmarker.detectAsync( + getImageFromAsset(POSE_IMAGE), /* timestampsMs= */ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); + } + + @Test + public void detect_failsWithCallingWrongApiInVideoMode() throws Exception { + HolisticLandmarkerOptions options = + HolisticLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .setRunningMode(RunningMode.VIDEO) + .build(); + + HolisticLandmarker holisticLandmarker = + HolisticLandmarker.createFromOptions( + ApplicationProvider.getApplicationContext(), options); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> holisticLandmarker.detect(getImageFromAsset(POSE_IMAGE))); + assertThat(exception).hasMessageThat().contains("not initialized with the image mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> + holisticLandmarker.detectAsync( + getImageFromAsset(POSE_IMAGE), /* timestampsMs= */ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); + } + + @Test + public void detect_failsWithCallingWrongApiInLiveSteamMode() throws Exception { + HolisticLandmarkerOptions options = + HolisticLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener((HolisticLandmarkerResult, inputImage) -> {}) + .build(); + + HolisticLandmarker holisticLandmarker = + HolisticLandmarker.createFromOptions( + ApplicationProvider.getApplicationContext(), options); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> holisticLandmarker.detect(getImageFromAsset(POSE_IMAGE))); + assertThat(exception).hasMessageThat().contains("not initialized with the image mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> + holisticLandmarker.detectForVideo( + getImageFromAsset(POSE_IMAGE), /* timestampsMs= */ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); + } + + @Test + public void detect_successWithImageMode() throws Exception { + HolisticLandmarkerOptions options = + HolisticLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .setRunningMode(RunningMode.IMAGE) + .build(); + + HolisticLandmarker holisticLandmarker = + HolisticLandmarker.createFromOptions( + ApplicationProvider.getApplicationContext(), options); + HolisticLandmarkerResult actualResult = + holisticLandmarker.detect(getImageFromAsset(POSE_IMAGE)); + HolisticLandmarkerResult expectedResult = + getExpectedHolisticLandmarkerResult( + HOLISTIC_RESULT, /* hasFaceBlendshapes= */ false, /* hasSegmentationMask= */ false); + assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult); + } + + @Test + public void detect_successWithVideoMode() throws Exception { + HolisticLandmarkerOptions options = + HolisticLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .setRunningMode(RunningMode.VIDEO) + .build(); + HolisticLandmarker holisticLandmarker = + HolisticLandmarker.createFromOptions( + ApplicationProvider.getApplicationContext(), options); + HolisticLandmarkerResult expectedResult = + getExpectedHolisticLandmarkerResult( + HOLISTIC_RESULT, /* hasFaceBlendshapes= */ false, /* hasSegmentationMask= */ false); + for (int i = 0; i < 3; i++) { + HolisticLandmarkerResult actualResult = + holisticLandmarker.detectForVideo(getImageFromAsset(POSE_IMAGE), /* timestampsMs= */ i); + assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult); + } + } + + @Test + public void detect_failsWithOutOfOrderInputTimestamps() throws Exception { + MPImage image = getImageFromAsset(POSE_IMAGE); + HolisticLandmarkerOptions options = + HolisticLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener((actualResult, inputImage) -> {}) + .build(); + try (HolisticLandmarker holisticLandmarker = + HolisticLandmarker.createFromOptions( + ApplicationProvider.getApplicationContext(), options)) { + holisticLandmarker.detectAsync(image, /* timestampsMs= */ 1); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> holisticLandmarker.detectAsync(image, /* timestampsMs= */ 0)); + assertThat(exception) + .hasMessageThat() + .contains("having a smaller timestamp than the processed timestamp"); + } + } + + @Test + public void detect_successWithLiveSteamMode() throws Exception { + MPImage image = getImageFromAsset(POSE_IMAGE); + HolisticLandmarkerResult expectedResult = + getExpectedHolisticLandmarkerResult( + HOLISTIC_RESULT, /* hasFaceBlendshapes= */ false, /* hasSegmentationMask= */ false); + HolisticLandmarkerOptions options = + HolisticLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener( + (actualResult, inputImage) -> { + assertActualResultApproximatelyEqualsToExpectedResult( + actualResult, expectedResult); + assertImageSizeIsExpected(inputImage); + }) + .build(); + try (HolisticLandmarker holisticLandmarker = + HolisticLandmarker.createFromOptions( + ApplicationProvider.getApplicationContext(), options)) { + for (int i = 0; i < 3; i++) { + holisticLandmarker.detectAsync(image, /* timestampsMs= */ i); + } + } + } + } + + private static MPImage getImageFromAsset(String filePath) throws Exception { + AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets(); + InputStream istr = assetManager.open(filePath); + return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build(); + } + + private static HolisticLandmarkerResult getExpectedHolisticLandmarkerResult( + String resultPath, boolean hasFaceBlendshapes, boolean hasSegmentationMask) throws Exception { + AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets(); + + HolisticResult holisticResult = HolisticResult.parseFrom(assetManager.open(resultPath)); + + Optional blendshapes = + hasFaceBlendshapes + ? Optional.of(holisticResult.getFaceBlendshapes()) + : Optional.empty(); + Optional segmentationMask = + hasSegmentationMask ? Optional.of(PLACEHOLDER_MASK) : Optional.empty(); + + return HolisticLandmarkerResult.create( + holisticResult.getFaceLandmarks(), + blendshapes, + holisticResult.getPoseLandmarks(), + LandmarkList.getDefaultInstance(), + segmentationMask, + holisticResult.getLeftHandLandmarks(), + LandmarkList.getDefaultInstance(), + holisticResult.getRightHandLandmarks(), + LandmarkList.getDefaultInstance(), + /* timestampMs= */ 0); + } + + private static void assertActualResultApproximatelyEqualsToExpectedResult( + HolisticLandmarkerResult actualResult, HolisticLandmarkerResult expectedResult) { + // Expects to have the same number of holistics detected. + assertThat(actualResult.faceLandmarks()).hasSize(expectedResult.faceLandmarks().size()); + assertThat(actualResult.faceBlendshapes().isPresent()) + .isEqualTo(expectedResult.faceBlendshapes().isPresent()); + assertThat(actualResult.poseLandmarks()).hasSize(expectedResult.poseLandmarks().size()); + assertThat(actualResult.segmentationMask().isPresent()) + .isEqualTo(expectedResult.segmentationMask().isPresent()); + assertThat(actualResult.leftHandLandmarks()).hasSize(expectedResult.leftHandLandmarks().size()); + assertThat(actualResult.rightHandLandmarks()) + .hasSize(expectedResult.rightHandLandmarks().size()); + + // Actual face landmarks match expected face landmarks. + assertThat(actualResult.faceLandmarks()) + .comparingElementsUsing(VALIDATE_LANDMARRKS) + .containsExactlyElementsIn(expectedResult.faceLandmarks()); + + // Actual face blendshapes match expected face blendshapes. + if (actualResult.faceBlendshapes().isPresent()) { + assertThat(actualResult.faceBlendshapes().get()) + .comparingElementsUsing(VALIDATE_BLENDSHAPES) + .containsExactlyElementsIn(expectedResult.faceBlendshapes().get()); + } + + // Actual pose landmarks match expected pose landmarks. + assertThat(actualResult.poseLandmarks()) + .comparingElementsUsing(VALIDATE_LANDMARRKS) + .containsExactlyElementsIn(expectedResult.poseLandmarks()); + + if (actualResult.segmentationMask().isPresent()) { + assertImageSizeIsExpected(actualResult.segmentationMask().get()); + } + + // Actual left hand landmarks match expected left hand landmarks. + assertThat(actualResult.leftHandLandmarks()) + .comparingElementsUsing(VALIDATE_LANDMARRKS) + .containsExactlyElementsIn(expectedResult.leftHandLandmarks()); + + // Actual right hand landmarks match expected right hand landmarks. + assertThat(actualResult.rightHandLandmarks()) + .comparingElementsUsing(VALIDATE_LANDMARRKS) + .containsExactlyElementsIn(expectedResult.rightHandLandmarks()); + } + + private static void assertImageSizeIsExpected(MPImage inputImage) { + assertThat(inputImage).isNotNull(); + assertThat(inputImage.getWidth()).isEqualTo(IMAGE_WIDTH); + assertThat(inputImage.getHeight()).isEqualTo(IMAGE_HEIGHT); + } +} diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD index 3f83118b0..422241081 100644 --- a/mediapipe/tasks/testdata/vision/BUILD +++ b/mediapipe/tasks/testdata/vision/BUILD @@ -224,6 +224,7 @@ filegroup( "hand_detector_result_one_hand.pbtxt", "hand_detector_result_one_hand_rotated.pbtxt", "hand_detector_result_two_hands.pbtxt", + "male_full_height_hands_result_cpu.pbtxt", "pointing_up_landmarks.pbtxt", "pointing_up_rotated_landmarks.pbtxt", "portrait_expected_detection.pbtxt",