diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index ddff069af..be5f8a240 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -46,6 +46,7 @@ cc_binary( "//mediapipe/calculators/core:flow_limiter_calculator", "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", "//mediapipe/tasks/cc/vision/face_detector:face_detector_graph", + "//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph", "//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph", "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph", "//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph", @@ -289,6 +290,37 @@ android_library( ], ) +android_library( + name = "facelandmarker", + srcs = [ + "facelandmarker/FaceLandmarker.java", + "facelandmarker/FaceLandmarkerResult.java", + ], + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], + manifest = "facedetector/AndroidManifest.xml", + deps = [ + ":core", + "//mediapipe/framework:calculator_options_java_proto_lite", + "//mediapipe/framework/formats:classification_java_proto_lite", + "//mediapipe/framework/formats:landmark_java_proto_lite", + "//mediapipe/framework/formats:matrix_data_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/java/com/google/mediapipe/framework/image", + "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_java_proto_lite", + "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_java_proto_lite", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:normalized_landmark", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "//third_party:autovalue", + "@maven//:com_google_guava_guava", + ], +) + load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_vision_aar") mediapipe_tasks_vision_aar( diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/AndroidManifest.xml b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/AndroidManifest.xml new file mode 100644 index 000000000..26ea44cef --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/AndroidManifest.xml @@ -0,0 +1,8 @@ + + + + + + diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarker.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarker.java new file mode 100644 index 000000000..599113aa2 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarker.java @@ -0,0 +1,550 @@ +// Copyright 2023 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.facelandmarker; + +import android.content.Context; +import android.os.ParcelFileDescriptor; +import com.google.auto.value.AutoValue; +import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList; +import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; +import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList; +import com.google.mediapipe.framework.AndroidPacketGetter; +import com.google.mediapipe.framework.Packet; +import com.google.mediapipe.framework.PacketGetter; +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.core.ErrorListener; +import com.google.mediapipe.tasks.core.OutputHandler; +import com.google.mediapipe.tasks.core.OutputHandler.ResultListener; +import com.google.mediapipe.tasks.core.TaskInfo; +import com.google.mediapipe.tasks.core.TaskOptions; +import com.google.mediapipe.tasks.core.TaskRunner; +import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; +import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; +import com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.tasks.vision.facedetector.proto.FaceDetectorGraphOptionsProto; +import com.google.mediapipe.tasks.vision.facegeometry.proto.FaceGeometryProto.FaceGeometry; +import com.google.mediapipe.tasks.vision.facelandmarker.proto.FaceLandmarkerGraphOptionsProto; +import com.google.mediapipe.tasks.vision.facelandmarker.proto.FaceLandmarksDetectorGraphOptionsProto; +import com.google.mediapipe.formats.proto.MatrixDataProto.MatrixData; +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +/** + * Performs face landmarks detection on images. + * + *

This API expects a pre-trained face landmarks model asset bundle. See . + * + *

+ */ +public final class FaceLandmarker extends BaseVisionTaskApi { + private static final String TAG = FaceLandmarker.class.getSimpleName(); + private static final String IMAGE_IN_STREAM_NAME = "image_in"; + private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in"; + + @SuppressWarnings("ConstantCaseForConstants") + private static final List INPUT_STREAMS = + Collections.unmodifiableList( + Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME)); + + private static final int LANDMARKS_OUT_STREAM_INDEX = 0; + private static final int IMAGE_OUT_STREAM_INDEX = 1; + private static int blendshapesOutStreamIndex = -1; + private static int faceGeometryOutStreamIndex = -1; + private static final String TASK_GRAPH_NAME = + "mediapipe.tasks.vision.face_landmarker.FaceLandmarkerGraph"; + + /** + * Creates a {@link FaceLandmarker} instance from a model asset bundle path and the default {@link + * FaceLandmarkerOptions}. + * + * @param context an Android {@link Context}. + * @param modelAssetPath path to the face landmarks model with metadata in the assets. + * @throws MediaPipeException if there is an error during {@link FaceLandmarker} creation. + */ + public static FaceLandmarker createFromFile(Context context, String modelAssetPath) { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelAssetPath).build(); + return createFromOptions( + context, FaceLandmarkerOptions.builder().setBaseOptions(baseOptions).build()); + } + + /** + * Creates a {@link FaceLandmarker} instance from a model asset bundle file and the default {@link + * FaceLandmarkerOptions}. + * + * @param context an Android {@link Context}. + * @param modelAssetFile the face landmarks model {@link File} instance. + * @throws IOException if an I/O error occurs when opening the tflite model file. + * @throws MediaPipeException if there is an error during {@link FaceLandmarker} creation. + */ + public static FaceLandmarker createFromFile(Context context, File modelAssetFile) + throws IOException { + try (ParcelFileDescriptor descriptor = + ParcelFileDescriptor.open(modelAssetFile, ParcelFileDescriptor.MODE_READ_ONLY)) { + BaseOptions baseOptions = + BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build(); + return createFromOptions( + context, FaceLandmarkerOptions.builder().setBaseOptions(baseOptions).build()); + } + } + + /** + * Creates a {@link FaceLandmarker} instance from a model asset bundle buffer and the default + * {@link FaceLandmarkerOptions}. + * + * @param context an Android {@link Context}. + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the detection + * model. + * @throws MediaPipeException if there is an error during {@link FaceLandmarker} creation. + */ + public static FaceLandmarker createFromBuffer( + Context context, final ByteBuffer modelAssetBuffer) { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelAssetBuffer).build(); + return createFromOptions( + context, FaceLandmarkerOptions.builder().setBaseOptions(baseOptions).build()); + } + + /** + * Creates a {@link FaceLandmarker} instance from a {@link FaceLandmarkerOptions}. + * + * @param context an Android {@link Context}. + * @param landmarkerOptions a {@link FaceLandmarkerOptions} instance. + * @throws MediaPipeException if there is an error during {@link FaceLandmarker} creation. + */ + public static FaceLandmarker createFromOptions( + Context context, FaceLandmarkerOptions landmarkerOptions) { + List outputStreams = new ArrayList<>(); + outputStreams.add("NORM_LANDMARKS:face_landmarks"); + outputStreams.add("IMAGE:image_out"); + if (landmarkerOptions.outputFaceBlendshapes()) { + outputStreams.add("BLENDSHAPES:face_blendshapes"); + blendshapesOutStreamIndex = outputStreams.size() - 1; + } + if (landmarkerOptions.outputFacialTransformationMatrixes()) { + outputStreams.add("FACE_GEOMETRY:face_geometry"); + faceGeometryOutStreamIndex = outputStreams.size() - 1; + } + // TODO: Consolidate OutputHandler and TaskRunner. + OutputHandler handler = new OutputHandler<>(); + handler.setOutputPacketConverter( + new OutputHandler.OutputPacketConverter() { + @Override + public FaceLandmarkerResult convertToTaskResult(List packets) { + // If there is no faces detected in the image, just returns empty lists. + if (packets.get(LANDMARKS_OUT_STREAM_INDEX).isEmpty()) { + return FaceLandmarkerResult.create( + new ArrayList<>(), + Optional.empty(), + Optional.empty(), + BaseVisionTaskApi.generateResultTimestampMs( + landmarkerOptions.runningMode(), packets.get(LANDMARKS_OUT_STREAM_INDEX))); + } + + Optional> blendshapes = Optional.empty(); + if (landmarkerOptions.outputFaceBlendshapes()) { + blendshapes = + Optional.of( + PacketGetter.getProtoVector( + packets.get(blendshapesOutStreamIndex), ClassificationList.parser())); + } + + Optional> facialTransformationMatrixes = Optional.empty(); + if (landmarkerOptions.outputFacialTransformationMatrixes()) { + List faceGeometryList = + PacketGetter.getProtoVector( + packets.get(faceGeometryOutStreamIndex), FaceGeometry.parser()); + facialTransformationMatrixes = Optional.of(new ArrayList<>()); + for (FaceGeometry faceGeometry : faceGeometryList) { + facialTransformationMatrixes.get().add(faceGeometry.getPoseTransformMatrix()); + } + } + + return FaceLandmarkerResult.create( + PacketGetter.getProtoVector( + packets.get(LANDMARKS_OUT_STREAM_INDEX), NormalizedLandmarkList.parser()), + blendshapes, + facialTransformationMatrixes, + BaseVisionTaskApi.generateResultTimestampMs( + landmarkerOptions.runningMode(), packets.get(LANDMARKS_OUT_STREAM_INDEX))); + } + + @Override + public MPImage convertToTaskInput(List packets) { + return new BitmapImageBuilder( + AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX))) + .build(); + } + }); + landmarkerOptions.resultListener().ifPresent(handler::setResultListener); + landmarkerOptions.errorListener().ifPresent(handler::setErrorListener); + TaskRunner runner = + TaskRunner.create( + context, + TaskInfo.builder() + .setTaskName(FaceLandmarker.class.getSimpleName()) + .setTaskRunningModeName(landmarkerOptions.runningMode().name()) + .setTaskGraphName(TASK_GRAPH_NAME) + .setInputStreams(INPUT_STREAMS) + .setOutputStreams(outputStreams) + .setTaskOptions(landmarkerOptions) + .setEnableFlowLimiting(landmarkerOptions.runningMode() == RunningMode.LIVE_STREAM) + .build(), + handler); + return new FaceLandmarker(runner, landmarkerOptions.runningMode()); + } + + /** + * Constructor to initialize an {@link FaceLandmarker} from a {@link TaskRunner} and a {@link + * RunningMode}. + * + * @param taskRunner a {@link TaskRunner}. + * @param runningMode a mediapipe vision task {@link RunningMode}. + */ + private FaceLandmarker(TaskRunner taskRunner, RunningMode runningMode) { + super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME); + } + + /** + * Performs face landmarks detection on the provided single image with default image processing + * options, i.e. without any rotation applied. Only use this method when the {@link + * FaceLandmarker} is created with {@link RunningMode.IMAGE}. TODO update java doc + * for input image format. + * + *

{@link FaceLandmarker} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @throws MediaPipeException if there is an internal error. + */ + public FaceLandmarkerResult detect(MPImage image) { + return detect(image, ImageProcessingOptions.builder().build()); + } + + /** + * Performs face landmarks detection on the provided single image. Only use this method when the + * {@link FaceLandmarker} is created with {@link RunningMode.IMAGE}. TODO update java + * doc for input image format. + * + *

{@link FaceLandmarker} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. + * @throws MediaPipeException if there is an internal error. + */ + public FaceLandmarkerResult detect(MPImage image, ImageProcessingOptions imageProcessingOptions) { + validateImageProcessingOptions(imageProcessingOptions); + return (FaceLandmarkerResult) processImageData(image, imageProcessingOptions); + } + + /** + * Performs face landmarks detection on the provided video frame with default image processing + * options, i.e. without any rotation applied. Only use this method when the {@link + * FaceLandmarker} is created with {@link RunningMode.VIDEO}. + * + *

It's required to provide the video frame's timestamp (in milliseconds). The input timestamps + * must be monotonically increasing. + * + *

{@link FaceLandmarker} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public FaceLandmarkerResult detectForVideo(MPImage image, long timestampMs) { + return detectForVideo(image, ImageProcessingOptions.builder().build(), timestampMs); + } + + /** + * Performs face landmarks detection on the provided video frame. Only use this method when the + * {@link FaceLandmarker} is created with {@link RunningMode.VIDEO}. + * + *

It's required to provide the video frame's timestamp (in milliseconds). The input timestamps + * must be monotonically increasing. + * + *

{@link FaceLandmarker} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @param timestampMs the input timestamp (in milliseconds). + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. + * @throws MediaPipeException if there is an internal error. + */ + public FaceLandmarkerResult detectForVideo( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + validateImageProcessingOptions(imageProcessingOptions); + return (FaceLandmarkerResult) processVideoData(image, imageProcessingOptions, timestampMs); + } + + /** + * Sends live image data to perform face landmarks detection with default image processing + * options, i.e. without any rotation applied, and the results will be available via the {@link + * ResultListener} provided in the {@link FaceLandmarkerOptions}. Only use this method when the + * {@link FaceLandmarker } is created with {@link RunningMode.LIVE_STREAM}. + * + *

It's required to provide a timestamp (in milliseconds) to indicate when the input image is + * sent to the face landmarker. The input timestamps must be monotonically increasing. + * + *

{@link FaceLandmarker} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public void detectAsync(MPImage image, long timestampMs) { + detectAsync(image, ImageProcessingOptions.builder().build(), timestampMs); + } + + /** + * Sends live image data to perform face landmarks detection, and the results will be available + * via the {@link ResultListener} provided in the {@link FaceLandmarkerOptions}. Only use this + * method when the {@link FaceLandmarker} is created with {@link RunningMode.LIVE_STREAM}. + * + *

It's required to provide a timestamp (in milliseconds) to indicate when the input image is + * sent to the face landmarker. The input timestamps must be monotonically increasing. + * + *

{@link FaceLandmarker} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @param timestampMs the input timestamp (in milliseconds). + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. + * @throws MediaPipeException if there is an internal error. + */ + public void detectAsync( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + validateImageProcessingOptions(imageProcessingOptions); + sendLiveStreamData(image, imageProcessingOptions, timestampMs); + } + + /** Options for setting up an {@link FaceLandmarker}. */ + @AutoValue + public abstract static class FaceLandmarkerOptions extends TaskOptions { + + /** Builder for {@link FaceLandmarkerOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + /** Sets the base options for the face landmarker task. */ + public abstract Builder setBaseOptions(BaseOptions value); + + /** + * Sets the running mode for the face landmarker task. Default to the image mode. Hand + * landmarker has three modes: + * + *
    + *
  • IMAGE: The mode for detecting face landmarks on single image inputs. + *
  • VIDEO: The mode for detecting face landmarks on the decoded frames of a video. + *
  • LIVE_STREAM: The mode for for detecting face landmarks on a live stream of input + * data, such as from camera. In this mode, {@code setResultListener} must be called to + * set up a listener to receive the detection results asynchronously. + *
+ */ + public abstract Builder setRunningMode(RunningMode value); + + /** Sets the maximum number of faces can be detected by the FaceLandmarker. */ + public abstract Builder setNumFaces(Integer value); + + /** Sets minimum confidence score for the face detection to be considered successful */ + public abstract Builder setMinFaceDetectionConfidence(Float value); + + /** Sets minimum confidence score of face presence score in the face landmark detection. */ + public abstract Builder setMinFacePresenceConfidence(Float value); + + /** Sets the minimum confidence score for the face tracking to be considered successful. */ + public abstract Builder setMinTrackingConfidence(Float value); + + /** + * Whether FaceLandmarker outputs face blendshapes classification. Face blendshapes are used + * for rendering the 3D face model. + */ + public abstract Builder setOutputFaceBlendshapes(Boolean value); + + /** + * Whether FaceLandmarker outptus facial transformation_matrix. Facial transformation matrix + * is used to transform the face landmarks in canonical face to the detected face, so that + * users can apply face effects on the detected landmarks. + */ + public abstract Builder setOutputFacialTransformationMatrixes(Boolean value); + + /** + * Sets the result listener to receive the detection results asynchronously when the face + * landmarker is in the live stream mode. + */ + public abstract Builder setResultListener( + ResultListener value); + + /** Sets an optional error listener. */ + public abstract Builder setErrorListener(ErrorListener value); + + abstract FaceLandmarkerOptions autoBuild(); + + /** + * Validates and builds the {@link FaceLandmarkerOptions} instance. + * + * @throws IllegalArgumentException if the result listener and the running mode are not + * properly configured. The result listener should only be set when the face landmarker is + * in the live stream mode. + */ + public final FaceLandmarkerOptions build() { + FaceLandmarkerOptions options = autoBuild(); + if (options.runningMode() == RunningMode.LIVE_STREAM) { + if (!options.resultListener().isPresent()) { + throw new IllegalArgumentException( + "The face landmarker is in the live stream mode, a user-defined result listener" + + " must be provided in FaceLandmarkerOptions."); + } + } else if (options.resultListener().isPresent()) { + throw new IllegalArgumentException( + "The face landmarker is in the image or the video mode, a user-defined result" + + " listener shouldn't be provided in FaceLandmarkerOptions."); + } + return options; + } + } + + abstract BaseOptions baseOptions(); + + abstract RunningMode runningMode(); + + abstract Optional numFaces(); + + abstract Optional minFaceDetectionConfidence(); + + abstract Optional minFacePresenceConfidence(); + + abstract Optional minTrackingConfidence(); + + abstract Boolean outputFaceBlendshapes(); + + abstract Boolean outputFacialTransformationMatrixes(); + + abstract Optional> resultListener(); + + abstract Optional errorListener(); + + public static Builder builder() { + return new AutoValue_FaceLandmarker_FaceLandmarkerOptions.Builder() + .setRunningMode(RunningMode.IMAGE) + .setNumFaces(1) + .setMinFaceDetectionConfidence(0.5f) + .setMinFacePresenceConfidence(0.5f) + .setMinTrackingConfidence(0.5f) + .setOutputFaceBlendshapes(false) + .setOutputFacialTransformationMatrixes(false); + } + + /** Converts a {@link FaceLandmarkerOptions} to a {@link CalculatorOptions} protobuf message. */ + @Override + public CalculatorOptions convertToCalculatorOptionsProto() { + FaceLandmarkerGraphOptionsProto.FaceLandmarkerGraphOptions.Builder taskOptionsBuilder = + FaceLandmarkerGraphOptionsProto.FaceLandmarkerGraphOptions.newBuilder() + .setBaseOptions( + BaseOptionsProto.BaseOptions.newBuilder() + .setUseStreamMode(runningMode() != RunningMode.IMAGE) + .mergeFrom(convertBaseOptionsToProto(baseOptions())) + .build()); + + // Setup FaceDetectorGraphOptions. + FaceDetectorGraphOptionsProto.FaceDetectorGraphOptions.Builder + faceDetectorGraphOptionsBuilder = + FaceDetectorGraphOptionsProto.FaceDetectorGraphOptions.newBuilder(); + numFaces().ifPresent(faceDetectorGraphOptionsBuilder::setNumFaces); + minFaceDetectionConfidence() + .ifPresent(faceDetectorGraphOptionsBuilder::setMinDetectionConfidence); + + // Setup FaceLandmarkerGraphOptions. + FaceLandmarksDetectorGraphOptionsProto.FaceLandmarksDetectorGraphOptions.Builder + faceLandmarksDetectorGraphOptionsBuilder = + FaceLandmarksDetectorGraphOptionsProto.FaceLandmarksDetectorGraphOptions.newBuilder(); + minFacePresenceConfidence() + .ifPresent(faceLandmarksDetectorGraphOptionsBuilder::setMinDetectionConfidence); + minTrackingConfidence().ifPresent(taskOptionsBuilder::setMinTrackingConfidence); + + taskOptionsBuilder + .setFaceDetectorGraphOptions(faceDetectorGraphOptionsBuilder.build()) + .setFaceLandmarksDetectorGraphOptions(faceLandmarksDetectorGraphOptionsBuilder.build()); + + return CalculatorOptions.newBuilder() + .setExtension( + FaceLandmarkerGraphOptionsProto.FaceLandmarkerGraphOptions.ext, + taskOptionsBuilder.build()) + .build(); + } + } + + /** + * Validates that the provided {@link ImageProcessingOptions} doesn't contain a + * region-of-interest. + */ + private static void validateImageProcessingOptions( + ImageProcessingOptions imageProcessingOptions) { + if (imageProcessingOptions.regionOfInterest().isPresent()) { + throw new IllegalArgumentException("FaceLandmarker doesn't support region-of-interest."); + } + } +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarkerResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarkerResult.java new file mode 100644 index 000000000..6493a6b5f --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarkerResult.java @@ -0,0 +1,115 @@ +// Copyright 2023 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.facelandmarker; + +import com.google.auto.value.AutoValue; +import com.google.mediapipe.formats.proto.LandmarkProto; +import com.google.mediapipe.formats.proto.ClassificationProto.Classification; +import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList; +import com.google.mediapipe.tasks.components.containers.Category; +import com.google.mediapipe.tasks.components.containers.NormalizedLandmark; +import com.google.mediapipe.tasks.core.TaskResult; +import com.google.mediapipe.formats.proto.MatrixDataProto.MatrixData; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +/** Represents the face landmarks detection results generated by {@link FaceLandmarker}. */ +@AutoValue +public abstract class FaceLandmarkerResult implements TaskResult { + + /** + * Creates a {@link FaceLandmarkerResult} instance from the list of landmarks, list of face + * blendshapes classification, and list of facial transformation matrixes protobuf message. + * + * @param multiFaceLandmarksProto a List of {@link NormalizedLandmarkList} + * @param multiFaceBendshapesProto an Optional List of {@link ClassificationList} + * @param multiFaceTransformationMatrixesProto an Optional List of {@link MatrixData} + * @throws IllegalArgumentException if there is error creating {@link FaceLandmarkerResult} + */ + static FaceLandmarkerResult create( + List multiFaceLandmarksProto, + Optional> multiFaceBendshapesProto, + Optional> multiFaceTransformationMatrixesProto, + long timestampMs) { + List> multiFaceLandmarks = new ArrayList<>(); + for (LandmarkProto.NormalizedLandmarkList faceLandmarksProto : multiFaceLandmarksProto) { + List faceLandmarks = new ArrayList<>(); + multiFaceLandmarks.add(faceLandmarks); + for (LandmarkProto.NormalizedLandmark faceLandmarkProto : + faceLandmarksProto.getLandmarkList()) { + faceLandmarks.add( + NormalizedLandmark.create( + faceLandmarkProto.getX(), faceLandmarkProto.getY(), faceLandmarkProto.getZ())); + } + } + Optional>> multiFaceBlendshapes = Optional.empty(); + if (!multiFaceBendshapesProto.isEmpty()) { + List> blendshapes = new ArrayList<>(); + for (ClassificationList faceBendshapeProto : multiFaceBendshapesProto.get()) { + List blendshape = new ArrayList<>(); + blendshapes.add(blendshape); + for (Classification classification : faceBendshapeProto.getClassificationList()) { + blendshape.add( + Category.create( + classification.getScore(), + classification.getIndex(), + classification.getLabel(), + classification.getDisplayName())); + } + } + multiFaceBlendshapes = Optional.of(Collections.unmodifiableList(blendshapes)); + } + Optional> multiFaceTransformationMatrixes = Optional.empty(); + if (!multiFaceTransformationMatrixesProto.isEmpty()) { + List matrixes = new ArrayList<>(); + for (MatrixData matrixProto : multiFaceTransformationMatrixesProto.get()) { + if (matrixProto.getPackedDataCount() != 16) { + throw new IllegalArgumentException( + "MatrixData must contain 4x4 matrix as a size 16 float array, but get size " + + matrixProto.getPackedDataCount() + + " float array."); + } + float[] matrixData = new float[matrixProto.getPackedDataCount()]; + for (int i = 0; i < matrixData.length; i++) { + matrixData[i] = matrixProto.getPackedData(i); + } + matrixes.add(matrixData); + } + multiFaceTransformationMatrixes = Optional.of(Collections.unmodifiableList(matrixes)); + } + return new AutoValue_FaceLandmarkerResult( + timestampMs, + Collections.unmodifiableList(multiFaceLandmarks), + multiFaceBlendshapes, + multiFaceTransformationMatrixes); + } + + @Override + public abstract long timestampMs(); + + /** Face landmarks of detected faces. */ + public abstract List> faceLandmarks(); + + /** Optional face blendshapes classifications. */ + public abstract Optional>> faceBlendshapes(); + + /** + * Optional facial transformation matrix list from cannonical face to the detected face landmarks. + * The 4x4 facial transformation matrix is represetned as a flat column-major float array. + */ + public abstract Optional> facialTransformationMatrixes(); +} diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facelandmarker/AndroidManifest.xml b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facelandmarker/AndroidManifest.xml new file mode 100644 index 000000000..7bf30e28e --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facelandmarker/AndroidManifest.xml @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facelandmarker/BUILD b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facelandmarker/BUILD new file mode 100644 index 000000000..c14486766 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facelandmarker/BUILD @@ -0,0 +1,19 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +# TODO: Enable this in OSS diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarkerTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarkerTest.java new file mode 100644 index 000000000..4f8de86d8 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarkerTest.java @@ -0,0 +1,514 @@ +// Copyright 2023 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.facelandmarker; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import android.content.res.AssetManager; +import android.graphics.BitmapFactory; +import android.graphics.RectF; +import androidx.test.core.app.ApplicationProvider; +import androidx.test.ext.junit.runners.AndroidJUnit4; +import com.google.common.truth.Correspondence; +import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList; +import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.tasks.components.containers.Category; +import com.google.mediapipe.tasks.components.containers.NormalizedLandmark; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; +import com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.tasks.vision.facegeometry.proto.FaceGeometryProto.FaceGeometry; +import com.google.mediapipe.tasks.vision.facelandmarker.FaceLandmarker.FaceLandmarkerOptions; +import com.google.mediapipe.formats.proto.MatrixDataProto.MatrixData; +import java.io.InputStream; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.junit.runners.Suite.SuiteClasses; + +/** Test for {@link FaceLandmarker}. */ +@RunWith(Suite.class) +@SuiteClasses({FaceLandmarkerTest.General.class, FaceLandmarkerTest.RunningModeTest.class}) +public class FaceLandmarkerTest { + private static final String FACE_LANDMARKER_BUNDLE_ASSET_FILE = + "face_landmarker_with_blendshapes.task"; + private static final String PORTRAIT_IMAGE = "portrait.jpg"; + private static final String CAT_IMAGE = "cat.jpg"; + private static final String PORTRAIT_FACE_LANDMARKS = + "portrait_expected_face_landmarks_with_attention.pb"; + private static final String PORTRAIT_FACE_BLENDSHAPES = + "portrait_expected_blendshapes_with_attention.pb"; + private static final String PORTRAIT_FACE_GEOMETRY = + "portrait_expected_face_geometry_with_attention.pb"; + private static final String TAG = "Face Landmarker Test"; + private static final float FACE_LANDMARKS_ERROR_TOLERANCE = 0.01f; + private static final float FACE_BLENDSHAPES_ERROR_TOLERANCE = 0.1f; + private static final float FACIAL_TRANSFORMATION_MATRIX_ERROR_TOLERANCE = 0.01f; + private static final int IMAGE_WIDTH = 820; + private static final int IMAGE_HEIGHT = 1024; + + @RunWith(AndroidJUnit4.class) + public static final class General extends FaceLandmarkerTest { + + @Test + public void detect_successWithValidModels() throws Exception { + FaceLandmarkerOptions options = + FaceLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(FACE_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .build(); + FaceLandmarker faceLandmarker = + FaceLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options); + FaceLandmarkerResult actualResult = faceLandmarker.detect(getImageFromAsset(PORTRAIT_IMAGE)); + FaceLandmarkerResult expectedResult = + getExpectedFaceLandmarkerResult( + PORTRAIT_FACE_LANDMARKS, Optional.empty(), Optional.empty()); + assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult); + } + + @Test + public void detect_successWithBlendshapes() throws Exception { + FaceLandmarkerOptions options = + FaceLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(FACE_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .setOutputFaceBlendshapes(true) + .build(); + FaceLandmarker faceLandmarker = + FaceLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options); + FaceLandmarkerResult actualResult = faceLandmarker.detect(getImageFromAsset(PORTRAIT_IMAGE)); + FaceLandmarkerResult expectedResult = + getExpectedFaceLandmarkerResult( + PORTRAIT_FACE_LANDMARKS, Optional.of(PORTRAIT_FACE_BLENDSHAPES), Optional.empty()); + assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult); + } + + @Test + public void detect_successWithFacialTransformationMatrix() throws Exception { + FaceLandmarkerOptions options = + FaceLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(FACE_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .setOutputFacialTransformationMatrixes(true) + .build(); + FaceLandmarker faceLandmarker = + FaceLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options); + FaceLandmarkerResult actualResult = faceLandmarker.detect(getImageFromAsset(PORTRAIT_IMAGE)); + FaceLandmarkerResult expectedResult = + getExpectedFaceLandmarkerResult( + PORTRAIT_FACE_LANDMARKS, Optional.empty(), Optional.of(PORTRAIT_FACE_GEOMETRY)); + assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult); + } + + @Test + public void detect_successWithBlendshapesWithFacialTransformationMatrix() throws Exception { + FaceLandmarkerOptions options = + FaceLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(FACE_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .setOutputFaceBlendshapes(true) + .setOutputFacialTransformationMatrixes(true) + .build(); + FaceLandmarker faceLandmarker = + FaceLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options); + FaceLandmarkerResult actualResult = faceLandmarker.detect(getImageFromAsset(PORTRAIT_IMAGE)); + FaceLandmarkerResult expectedResult = + getExpectedFaceLandmarkerResult( + PORTRAIT_FACE_LANDMARKS, + Optional.of(PORTRAIT_FACE_BLENDSHAPES), + Optional.of(PORTRAIT_FACE_GEOMETRY)); + assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult); + } + + @Test + public void detect_successWithEmptyResult() throws Exception { + FaceLandmarkerOptions options = + FaceLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(FACE_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .build(); + FaceLandmarker faceLandmarker = + FaceLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options); + FaceLandmarkerResult actualResult = faceLandmarker.detect(getImageFromAsset(CAT_IMAGE)); + assertThat(actualResult.faceLandmarks()).isEmpty(); + } + + @Test + public void detect_failsWithRegionOfInterest() throws Exception { + FaceLandmarkerOptions options = + FaceLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(FACE_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .setNumFaces(1) + .build(); + FaceLandmarker faceLandmarker = + FaceLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRegionOfInterest(new RectF(0, 0, 1, 1)).build(); + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + faceLandmarker.detect(getImageFromAsset(PORTRAIT_IMAGE), imageProcessingOptions)); + assertThat(exception) + .hasMessageThat() + .contains("FaceLandmarker doesn't support region-of-interest"); + } + } + + @RunWith(AndroidJUnit4.class) + public static final class RunningModeTest extends FaceLandmarkerTest { + @Test + public void create_failsWithIllegalResultListenerInNonLiveStreamMode() throws Exception { + for (RunningMode mode : new RunningMode[] {RunningMode.IMAGE, RunningMode.VIDEO}) { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + FaceLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(FACE_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .setRunningMode(mode) + .setResultListener((FaceLandmarkerResult, inputImage) -> {}) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("a user-defined result listener shouldn't be provided"); + } + } + + @Test + public void create_failsWithMissingResultListenerInLiveSteamMode() throws Exception { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + FaceLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(FACE_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("a user-defined result listener must be provided"); + } + + @Test + public void detect_failsWithCallingWrongApiInImageMode() throws Exception { + FaceLandmarkerOptions options = + FaceLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(FACE_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .setRunningMode(RunningMode.IMAGE) + .build(); + + FaceLandmarker faceLandmarker = + FaceLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> + faceLandmarker.detectForVideo( + getImageFromAsset(PORTRAIT_IMAGE), /* timestampsMs= */ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> + faceLandmarker.detectAsync( + getImageFromAsset(PORTRAIT_IMAGE), /* timestampsMs= */ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); + } + + @Test + public void detect_failsWithCallingWrongApiInVideoMode() throws Exception { + FaceLandmarkerOptions options = + FaceLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(FACE_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .setRunningMode(RunningMode.VIDEO) + .build(); + + FaceLandmarker faceLandmarker = + FaceLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> faceLandmarker.detect(getImageFromAsset(PORTRAIT_IMAGE))); + assertThat(exception).hasMessageThat().contains("not initialized with the image mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> + faceLandmarker.detectAsync( + getImageFromAsset(PORTRAIT_IMAGE), /* timestampsMs= */ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); + } + + @Test + public void detect_failsWithCallingWrongApiInLiveSteamMode() throws Exception { + FaceLandmarkerOptions options = + FaceLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(FACE_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener((FaceLandmarkerResult, inputImage) -> {}) + .build(); + + FaceLandmarker faceLandmarker = + FaceLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> faceLandmarker.detect(getImageFromAsset(PORTRAIT_IMAGE))); + assertThat(exception).hasMessageThat().contains("not initialized with the image mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> + faceLandmarker.detectForVideo( + getImageFromAsset(PORTRAIT_IMAGE), /* timestampsMs= */ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); + } + + @Test + public void detect_successWithImageMode() throws Exception { + FaceLandmarkerOptions options = + FaceLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(FACE_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .setOutputFaceBlendshapes(true) + .setOutputFacialTransformationMatrixes(true) + .setRunningMode(RunningMode.IMAGE) + .build(); + + FaceLandmarker faceLandmarker = + FaceLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options); + FaceLandmarkerResult actualResult = faceLandmarker.detect(getImageFromAsset(PORTRAIT_IMAGE)); + FaceLandmarkerResult expectedResult = + getExpectedFaceLandmarkerResult( + PORTRAIT_FACE_LANDMARKS, + Optional.of(PORTRAIT_FACE_BLENDSHAPES), + Optional.of(PORTRAIT_FACE_GEOMETRY)); + assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult); + } + + @Test + public void detect_successWithVideoMode() throws Exception { + FaceLandmarkerOptions options = + FaceLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(FACE_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .setOutputFaceBlendshapes(true) + .setRunningMode(RunningMode.VIDEO) + .build(); + FaceLandmarker faceLandmarker = + FaceLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options); + FaceLandmarkerResult expectedResult = + getExpectedFaceLandmarkerResult( + PORTRAIT_FACE_LANDMARKS, Optional.of(PORTRAIT_FACE_BLENDSHAPES), Optional.empty()); + for (int i = 0; i < 3; i++) { + FaceLandmarkerResult actualResult = + faceLandmarker.detectForVideo(getImageFromAsset(PORTRAIT_IMAGE), /* timestampsMs= */ i); + assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult); + } + } + + @Test + public void detect_failsWithOutOfOrderInputTimestamps() throws Exception { + MPImage image = getImageFromAsset(PORTRAIT_IMAGE); + FaceLandmarkerResult expectedResult = + getExpectedFaceLandmarkerResult( + PORTRAIT_FACE_LANDMARKS, Optional.of(PORTRAIT_FACE_BLENDSHAPES), Optional.empty()); + FaceLandmarkerOptions options = + FaceLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(FACE_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .setOutputFaceBlendshapes(true) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener( + (actualResult, inputImage) -> { + assertActualResultApproximatelyEqualsToExpectedResult( + actualResult, expectedResult); + assertImageSizeIsExpected(inputImage); + }) + .build(); + try (FaceLandmarker faceLandmarker = + FaceLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { + faceLandmarker.detectAsync(image, /* timestampsMs= */ 1); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> faceLandmarker.detectAsync(image, /* timestampsMs= */ 0)); + assertThat(exception) + .hasMessageThat() + .contains("having a smaller timestamp than the processed timestamp"); + } + } + + @Test + public void detect_successWithLiveSteamMode() throws Exception { + MPImage image = getImageFromAsset(PORTRAIT_IMAGE); + FaceLandmarkerResult expectedResult = + getExpectedFaceLandmarkerResult( + PORTRAIT_FACE_LANDMARKS, Optional.of(PORTRAIT_FACE_BLENDSHAPES), Optional.empty()); + FaceLandmarkerOptions options = + FaceLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(FACE_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .setOutputFaceBlendshapes(true) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener( + (actualResult, inputImage) -> { + assertActualResultApproximatelyEqualsToExpectedResult( + actualResult, expectedResult); + assertImageSizeIsExpected(inputImage); + }) + .build(); + try (FaceLandmarker faceLandmarker = + FaceLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { + for (int i = 0; i < 3; i++) { + faceLandmarker.detectAsync(image, /* timestampsMs= */ i); + } + } + } + } + + private static MPImage getImageFromAsset(String filePath) throws Exception { + AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets(); + InputStream istr = assetManager.open(filePath); + return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build(); + } + + private static FaceLandmarkerResult getExpectedFaceLandmarkerResult( + String faceLandmarksFilePath, + Optional faceBlendshapesFilePath, + Optional faceGeometryFilePath) + throws Exception { + AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets(); + + List faceLandmarks = + Arrays.asList( + NormalizedLandmarkList.parser().parseFrom(assetManager.open(faceLandmarksFilePath))); + Optional> faceBlendshapes = Optional.empty(); + if (faceBlendshapesFilePath.isPresent()) { + faceBlendshapes = + Optional.of( + Arrays.asList( + ClassificationList.parser() + .parseFrom(assetManager.open(faceBlendshapesFilePath.get())))); + } + Optional> facialTransformationMatrixes = Optional.empty(); + if (faceGeometryFilePath.isPresent()) { + FaceGeometry faceGeometry = + FaceGeometry.parser().parseFrom(assetManager.open(faceGeometryFilePath.get())); + facialTransformationMatrixes = + Optional.of(Arrays.asList(faceGeometry.getPoseTransformMatrix())); + } + + return FaceLandmarkerResult.create( + faceLandmarks, faceBlendshapes, facialTransformationMatrixes, /* timestampMs= */ 0); + } + + private static void assertActualResultApproximatelyEqualsToExpectedResult( + FaceLandmarkerResult actualResult, FaceLandmarkerResult expectedResult) { + // Expects to have the same number of faces detected. + assertThat(actualResult.faceLandmarks()).hasSize(expectedResult.faceLandmarks().size()); + assertThat(actualResult.faceBlendshapes().isPresent()) + .isEqualTo(expectedResult.faceBlendshapes().isPresent()); + assertThat(actualResult.facialTransformationMatrixes().isPresent()) + .isEqualTo(expectedResult.facialTransformationMatrixes().isPresent()); + + // Actual face landmarks match expected face landmarks. + assertThat(actualResult.faceLandmarks().get(0)) + .comparingElementsUsing( + Correspondence.from( + (Correspondence.BinaryPredicate) + (actual, expected) -> { + return Correspondence.tolerance(FACE_LANDMARKS_ERROR_TOLERANCE) + .compare(actual.x(), expected.x()) + && Correspondence.tolerance(FACE_LANDMARKS_ERROR_TOLERANCE) + .compare(actual.y(), expected.y()); + }, + "face landmarks approximately equal to")) + .containsExactlyElementsIn(expectedResult.faceLandmarks().get(0)); + + // Actual face blendshapes match expected face blendshapes. + if (actualResult.faceBlendshapes().isPresent()) { + assertThat(actualResult.faceBlendshapes().get().get(0)) + .comparingElementsUsing( + Correspondence.from( + (Correspondence.BinaryPredicate) + (actual, expected) -> { + return Correspondence.tolerance(FACE_BLENDSHAPES_ERROR_TOLERANCE) + .compare(actual.score(), expected.score()) + && actual.index() == expected.index() + && actual.categoryName().equals(expected.categoryName()); + }, + "face blendshapes approximately equal to")) + .containsExactlyElementsIn(expectedResult.faceBlendshapes().get().get(0)); + } + + // Actual transformation matrix match expected transformation matrix; + if (actualResult.facialTransformationMatrixes().isPresent()) { + assertThat(actualResult.facialTransformationMatrixes().get().get(0)) + .usingTolerance(FACIAL_TRANSFORMATION_MATRIX_ERROR_TOLERANCE) + .containsExactly(expectedResult.facialTransformationMatrixes().get().get(0)); + } + } + + private static void assertImageSizeIsExpected(MPImage inputImage) { + assertThat(inputImage).isNotNull(); + assertThat(inputImage.getWidth()).isEqualTo(IMAGE_WIDTH); + assertThat(inputImage.getHeight()).isEqualTo(IMAGE_HEIGHT); + } +}