diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index 32518725a..bd57ffadb 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -45,6 +45,7 @@ cc_binary( deps = [ "//mediapipe/calculators/core:flow_limiter_calculator", "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", + "//mediapipe/tasks/cc/vision/face_detector:face_detector_graph", "//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph", "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph", "//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph", @@ -235,6 +236,7 @@ android_library( android_library( name = "facedetector", srcs = [ + "facedetector/FaceDetector.java", "facedetector/FaceDetectorResult.java", ], javacopts = [ @@ -245,7 +247,10 @@ android_library( ":core", "//mediapipe/framework:calculator_options_java_proto_lite", "//mediapipe/framework/formats:detection_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/java/com/google/mediapipe/framework/image", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//third_party:autovalue", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facedetector/FaceDetector.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facedetector/FaceDetector.java new file mode 100644 index 000000000..c23432c1b --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facedetector/FaceDetector.java @@ -0,0 +1,463 @@ +// Copyright 2023 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package com.google.mediapipe.tasks.vision.facedetector; + +import android.content.Context; +import android.os.ParcelFileDescriptor; +import com.google.auto.value.AutoValue; +import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; +import com.google.mediapipe.framework.AndroidPacketGetter; +import com.google.mediapipe.framework.Packet; +import com.google.mediapipe.framework.PacketGetter; +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.core.ErrorListener; +import com.google.mediapipe.tasks.core.OutputHandler; +import com.google.mediapipe.tasks.core.OutputHandler.ResultListener; +import com.google.mediapipe.tasks.core.TaskInfo; +import com.google.mediapipe.tasks.core.TaskOptions; +import com.google.mediapipe.tasks.core.TaskRunner; +import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; +import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; +import com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.tasks.vision.facedetector.proto.FaceDetectorGraphOptionsProto; +import com.google.mediapipe.formats.proto.DetectionProto.Detection; +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +/** + * Performs face detection on images. + * + *

The API expects a TFLite model with TFLite Model Metadata.
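+ *
+ * A minimal usage sketch (illustrative; the model filename, {@code context}, and {@code bitmap}
+ * are assumptions, not part of this API):
+ *
+ *   FaceDetector faceDetector = FaceDetector.createFromFile(context, "face_detector.tflite");
+ *   MPImage mpImage = new BitmapImageBuilder(bitmap).build();
+ *   FaceDetectorResult result = faceDetector.detect(mpImage);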

+ */
+public final class FaceDetector extends BaseVisionTaskApi {
+  private static final String TAG = FaceDetector.class.getSimpleName();
+  private static final String IMAGE_IN_STREAM_NAME = "image_in";
+  private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in";
+
+  @SuppressWarnings("ConstantCaseForConstants")
+  private static final List<String> INPUT_STREAMS =
+      Collections.unmodifiableList(
+          Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
+
+  @SuppressWarnings("ConstantCaseForConstants")
+  private static final List<String> OUTPUT_STREAMS =
+      Collections.unmodifiableList(Arrays.asList("DETECTIONS:detections_out", "IMAGE:image_out"));
+
+  private static final int DETECTIONS_OUT_STREAM_INDEX = 0;
+  private static final int IMAGE_OUT_STREAM_INDEX = 1;
+  private static final String TASK_GRAPH_NAME =
+      "mediapipe.tasks.vision.face_detector.FaceDetectorGraph";
+
+  /**
+   * Creates a {@link FaceDetector} instance from a model file and the default {@link
+   * FaceDetectorOptions}.
+   *
+   * @param context an Android {@link Context}.
+   * @param modelPath path to the detection model with metadata in the assets.
+   * @throws MediaPipeException if there is an error during {@link FaceDetector} creation.
+   */
+  public static FaceDetector createFromFile(Context context, String modelPath) {
+    BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build();
+    return createFromOptions(
+        context, FaceDetectorOptions.builder().setBaseOptions(baseOptions).build());
+  }
+
+  /**
+   * Creates a {@link FaceDetector} instance from a model file and the default {@link
+   * FaceDetectorOptions}.
+   *
+   * @param context an Android {@link Context}.
+   * @param modelFile the detection model {@link File} instance.
+   * @throws IOException if an I/O error occurs when opening the tflite model file.
+   * @throws MediaPipeException if there is an error during {@link FaceDetector} creation.
+   */
+  public static FaceDetector createFromFile(Context context, File modelFile) throws IOException {
+    try (ParcelFileDescriptor descriptor =
+        ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
+      BaseOptions baseOptions =
+          BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build();
+      return createFromOptions(
+          context, FaceDetectorOptions.builder().setBaseOptions(baseOptions).build());
+    }
+  }
+
+  /**
+   * Creates a {@link FaceDetector} instance from a model buffer and the default {@link
+   * FaceDetectorOptions}.
+   *
+   * @param context an Android {@link Context}.
+   * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the detection
+   *     model.
+   * @throws MediaPipeException if there is an error during {@link FaceDetector} creation.
+   */
+  public static FaceDetector createFromBuffer(Context context, final ByteBuffer modelBuffer) {
+    BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelBuffer).build();
+    return createFromOptions(
+        context, FaceDetectorOptions.builder().setBaseOptions(baseOptions).build());
+  }
+
+  /**
+   * Creates a {@link FaceDetector} instance from a {@link FaceDetectorOptions}.
+   *
+   * @param context an Android {@link Context}.
+   * @param detectorOptions a {@link FaceDetectorOptions} instance.
+   * @throws MediaPipeException if there is an error during {@link FaceDetector} creation.
+   */
+  public static FaceDetector createFromOptions(
+      Context context, FaceDetectorOptions detectorOptions) {
+    // TODO: Consolidate OutputHandler and TaskRunner.
+    OutputHandler<FaceDetectorResult, MPImage> handler = new OutputHandler<>();
+    handler.setOutputPacketConverter(
+        new OutputHandler.OutputPacketConverter<FaceDetectorResult, MPImage>() {
+          @Override
+          public FaceDetectorResult convertToTaskResult(List<Packet> packets) {
+            // If no faces are detected in the image, just return an empty result.
+            if (packets.get(DETECTIONS_OUT_STREAM_INDEX).isEmpty()) {
+              return FaceDetectorResult.create(
+                  new ArrayList<>(),
+                  BaseVisionTaskApi.generateResultTimestampMs(
+                      detectorOptions.runningMode(), packets.get(DETECTIONS_OUT_STREAM_INDEX)));
+            }
+            return FaceDetectorResult.create(
+                PacketGetter.getProtoVector(
+                    packets.get(DETECTIONS_OUT_STREAM_INDEX), Detection.parser()),
+                BaseVisionTaskApi.generateResultTimestampMs(
+                    detectorOptions.runningMode(), packets.get(DETECTIONS_OUT_STREAM_INDEX)));
+          }
+
+          @Override
+          public MPImage convertToTaskInput(List<Packet> packets) {
+            return new BitmapImageBuilder(
+                    AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
+                .build();
+          }
+        });
+    detectorOptions.resultListener().ifPresent(handler::setResultListener);
+    detectorOptions.errorListener().ifPresent(handler::setErrorListener);
+    TaskRunner runner =
+        TaskRunner.create(
+            context,
+            TaskInfo.<FaceDetectorOptions>builder()
+                .setTaskName(FaceDetector.class.getSimpleName())
+                .setTaskRunningModeName(detectorOptions.runningMode().name())
+                .setTaskGraphName(TASK_GRAPH_NAME)
+                .setInputStreams(INPUT_STREAMS)
+                .setOutputStreams(OUTPUT_STREAMS)
+                .setTaskOptions(detectorOptions)
+                .setEnableFlowLimiting(detectorOptions.runningMode() == RunningMode.LIVE_STREAM)
+                .build(),
+            handler);
+    return new FaceDetector(runner, detectorOptions.runningMode());
+  }
+
+  /**
+   * Constructor to initialize a {@link FaceDetector} from a {@link TaskRunner} and a {@link
+   * RunningMode}.
+   *
+   * @param taskRunner a {@link TaskRunner}.
+   * @param runningMode a mediapipe vision task {@link RunningMode}.
+   */
+  private FaceDetector(TaskRunner taskRunner, RunningMode runningMode) {
+    super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME);
+  }
+
+  /**
+   * Performs face detection on the provided single image with default image processing options,
+   * i.e. without any rotation applied. Only use this method when the {@link FaceDetector} is
+   * created with {@link RunningMode.IMAGE}.
+   *

{@link FaceDetector} supports the following color space types: + * + *

    + *
• {@link Bitmap.Config#ARGB_8888}
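+   *
+   * For example, an input {@link MPImage} can be built from an ARGB_8888 bitmap (a sketch;
+   * {@code bitmap} and {@code faceDetector} are assumed to exist):
+   *
+   *   MPImage mpImage = new BitmapImageBuilder(bitmap).build();
+   *   FaceDetectorResult result = faceDetector.detect(mpImage);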
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @throws MediaPipeException if there is an internal error. + */ + public FaceDetectorResult detect(MPImage image) { + return detect(image, ImageProcessingOptions.builder().build()); + } + + /** + * Performs face detection on the provided single image. Only use this method when the {@link + * FaceDetector} is created with {@link RunningMode.IMAGE}. + * + *

{@link FaceDetector} supports the following color space types: + * + *

    + *
• {@link Bitmap.Config#ARGB_8888}
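+   *
+   * For example, a rotated input can be compensated for before inference (a sketch; the
+   * rotation value is illustrative):
+   *
+   *   ImageProcessingOptions imageProcessingOptions =
+   *       ImageProcessingOptions.builder().setRotationDegrees(-90).build();
+   *   FaceDetectorResult result = faceDetector.detect(mpImage, imageProcessingOptions);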
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. + * @throws MediaPipeException if there is an internal error. + */ + public FaceDetectorResult detect(MPImage image, ImageProcessingOptions imageProcessingOptions) { + validateImageProcessingOptions(imageProcessingOptions); + return (FaceDetectorResult) processImageData(image, imageProcessingOptions); + } + + /** + * Performs face detection on the provided video frame with default image processing options, i.e. + * without any rotation applied. Only use this method when the {@link FaceDetector} is created + * with {@link RunningMode.VIDEO}. + * + *

It's required to provide the video frame's timestamp (in milliseconds). The input timestamps + * must be monotonically increasing. + * + *

{@link FaceDetector} supports the following color space types: + * + *

    + *
• {@link Bitmap.Config#ARGB_8888}
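+   *
+   * For example, decoded frames can be fed with monotonically increasing timestamps (a sketch;
+   * {@code frames} is an assumed list of {@link MPImage} video frames):
+   *
+   *   for (int i = 0; i < frames.size(); i++) {
+   *     FaceDetectorResult result = faceDetector.detectForVideo(frames.get(i), i);
+   *   }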
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public FaceDetectorResult detectForVideo(MPImage image, long timestampMs) { + return detectForVideo(image, ImageProcessingOptions.builder().build(), timestampMs); + } + + /** + * Performs face detection on the provided video frame. Only use this method when the {@link + * FaceDetector} is created with {@link RunningMode.VIDEO}. + * + *

It's required to provide the video frame's timestamp (in milliseconds). The input timestamps + * must be monotonically increasing. + * + *

{@link FaceDetector} supports the following color space types: + * + *

    + *
• {@link Bitmap.Config#ARGB_8888}
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @param timestampMs the input timestamp (in milliseconds). + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. + * @throws MediaPipeException if there is an internal error. + */ + public FaceDetectorResult detectForVideo( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + validateImageProcessingOptions(imageProcessingOptions); + return (FaceDetectorResult) processVideoData(image, imageProcessingOptions, timestampMs); + } + + /** + * Sends live image data to perform face detection with default image processing options, i.e. + * without any rotation applied, and the results will be available via the {@link ResultListener} + * provided in the {@link FaceDetectorOptions}. Only use this method when the {@link FaceDetector} + * is created with {@link RunningMode.LIVE_STREAM}. + * + *

It's required to provide a timestamp (in milliseconds) to indicate when the input image is + * sent to the face detector. The input timestamps must be monotonically increasing. + * + *

{@link FaceDetector} supports the following color space types: + * + *

    + *
• {@link Bitmap.Config#ARGB_8888}
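+   *
+   * For example (a sketch; {@code frameTimestampMs} is assumed to come from the camera and to
+   * increase monotonically):
+   *
+   *   faceDetector.detectAsync(mpImage, frameTimestampMs);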
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public void detectAsync(MPImage image, long timestampMs) { + detectAsync(image, ImageProcessingOptions.builder().build(), timestampMs); + } + + /** + * Sends live image data to perform face detection, and the results will be available via the + * {@link ResultListener} provided in the {@link FaceDetectorOptions}. Only use this method when + * the {@link FaceDetector} is created with {@link RunningMode.LIVE_STREAM}. + * + *

It's required to provide a timestamp (in milliseconds) to indicate when the input image is + * sent to the face detector. The input timestamps must be monotonically increasing. + * + *

{@link FaceDetector} supports the following color space types: + * + *

    + *
• {@link Bitmap.Config#ARGB_8888}
+   *
+   * @param image a MediaPipe {@link MPImage} object for processing.
+   * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
+   *     input image before running inference. Note that region-of-interest is not supported
+   *     by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
+   *     this method throwing an IllegalArgumentException.
+   * @param timestampMs the input timestamp (in milliseconds).
+   * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
+   *     region-of-interest.
+   * @throws MediaPipeException if there is an internal error.
+   */
+  public void detectAsync(
+      MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
+    validateImageProcessingOptions(imageProcessingOptions);
+    sendLiveStreamData(image, imageProcessingOptions, timestampMs);
+  }
+
+  /** Options for setting up a {@link FaceDetector}. */
+  @AutoValue
+  public abstract static class FaceDetectorOptions extends TaskOptions {
+
+    /** Builder for {@link FaceDetectorOptions}. */
+    @AutoValue.Builder
+    public abstract static class Builder {
+      /** Sets the {@link BaseOptions} for the face detector task. */
+      public abstract Builder setBaseOptions(BaseOptions value);
+
+      /**
+       * Sets the {@link RunningMode} for the face detector task. Defaults to the image mode. The
+       * face detector has three modes:
+       *
    + *
  • IMAGE: The mode for detecting faces on single image inputs. + *
  • VIDEO: The mode for detecting faces on the decoded frames of a video. + *
• LIVE_STREAM: The mode for detecting faces on a live stream of input data, such as from a camera. In this mode, {@code setResultListener} must be called to set up a listener to receive the detection results asynchronously.
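+       *
+       * For example, a live-stream setup might look like this (a sketch; the model path and the
+       * {@code handleResult} callback are illustrative, application-defined names):
+       *
+       *   FaceDetectorOptions options =
+       *       FaceDetectorOptions.builder()
+       *           .setBaseOptions(BaseOptions.builder().setModelAssetPath("face_detector.tflite").build())
+       *           .setRunningMode(RunningMode.LIVE_STREAM)
+       *           .setResultListener((result, inputImage) -> handleResult(result))
+       *           .build();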
+       */
+      public abstract Builder setRunningMode(RunningMode value);
+
+      /**
+       * Sets the minimum confidence score for the face detection to be considered successful. The
+       * default minDetectionConfidence is 0.5.
+       */
+      public abstract Builder setMinDetectionConfidence(Float value);
+
+      /**
+       * Sets the minimum non-maximum-suppression threshold for face detection to be considered
+       * overlapped. The default minSuppressionThreshold is 0.3.
+       */
+      public abstract Builder setMinSuppressionThreshold(Float value);
+
+      /**
+       * Sets the {@link ResultListener} to receive the detection results asynchronously when the
+       * face detector is in the live stream mode.
+       */
+      public abstract Builder setResultListener(
+          ResultListener<FaceDetectorResult, MPImage> value);
+
+      /** Sets an optional {@link ErrorListener}. */
+      public abstract Builder setErrorListener(ErrorListener value);
+
+      abstract FaceDetectorOptions autoBuild();
+
+      /**
+       * Validates and builds the {@link FaceDetectorOptions} instance.
+       *
+       * @throws IllegalArgumentException if the result listener and the running mode are not
+       *     properly configured. The result listener should only be set when the face detector is
+       *     in the live stream mode.
+       */
+      public final FaceDetectorOptions build() {
+        FaceDetectorOptions options = autoBuild();
+        if (options.runningMode() == RunningMode.LIVE_STREAM) {
+          if (!options.resultListener().isPresent()) {
+            throw new IllegalArgumentException(
+                "The face detector is in the live stream mode, a user-defined result listener"
+                    + " must be provided in FaceDetectorOptions.");
+          }
+        } else if (options.resultListener().isPresent()) {
+          throw new IllegalArgumentException(
+              "The face detector is in the image or the video mode, a user-defined result"
+                  + " listener shouldn't be provided in FaceDetectorOptions.");
+        }
+        return options;
+      }
+    }
+
+    abstract BaseOptions baseOptions();
+
+    abstract RunningMode runningMode();
+
+    abstract float minDetectionConfidence();
+
+    abstract float minSuppressionThreshold();
+
+    abstract Optional<ResultListener<FaceDetectorResult, MPImage>> resultListener();
+
+    abstract Optional<ErrorListener> errorListener();
+
+    public static Builder builder() {
+      return new AutoValue_FaceDetector_FaceDetectorOptions.Builder()
+          .setRunningMode(RunningMode.IMAGE)
+          .setMinDetectionConfidence(0.5f)
+          .setMinSuppressionThreshold(0.3f);
+    }
+
+    /** Converts a {@link FaceDetectorOptions} to a {@link CalculatorOptions} protobuf message. */
+    @Override
+    public CalculatorOptions convertToCalculatorOptionsProto() {
+      BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder =
+          BaseOptionsProto.BaseOptions.newBuilder();
+      baseOptionsBuilder.setUseStreamMode(runningMode() != RunningMode.IMAGE);
+      baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
+      FaceDetectorGraphOptionsProto.FaceDetectorGraphOptions.Builder taskOptionsBuilder =
+          FaceDetectorGraphOptionsProto.FaceDetectorGraphOptions.newBuilder()
+              .setBaseOptions(baseOptionsBuilder);
+      taskOptionsBuilder.setMinDetectionConfidence(minDetectionConfidence());
+      taskOptionsBuilder.setMinSuppressionThreshold(minSuppressionThreshold());
+      return CalculatorOptions.newBuilder()
+          .setExtension(
+              FaceDetectorGraphOptionsProto.FaceDetectorGraphOptions.ext,
+              taskOptionsBuilder.build())
+          .build();
+    }
+  }
+
+  /**
+   * Validates that the provided {@link ImageProcessingOptions} doesn't contain a
+   * region-of-interest.
+ */ + private static void validateImageProcessingOptions( + ImageProcessingOptions imageProcessingOptions) { + if (imageProcessingOptions.regionOfInterest().isPresent()) { + throw new IllegalArgumentException("FaceDetector doesn't support region-of-interest."); + } + } +} diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facedetector/AndroidManifest.xml b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facedetector/AndroidManifest.xml new file mode 100644 index 000000000..01cbc3a6f --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facedetector/AndroidManifest.xml @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facedetector/BUILD b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facedetector/BUILD new file mode 100644 index 000000000..c14486766 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facedetector/BUILD @@ -0,0 +1,19 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +# TODO: Enable this in OSS diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facedetector/FaceDetectorTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facedetector/FaceDetectorTest.java new file mode 100644 index 000000000..d995accd5 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facedetector/FaceDetectorTest.java @@ -0,0 +1,455 @@ +// Copyright 2023 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+package com.google.mediapipe.tasks.vision.facedetector;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.assertThrows;
+
+import android.content.res.AssetManager;
+import android.graphics.BitmapFactory;
+import android.graphics.RectF;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import com.google.mediapipe.framework.MediaPipeException;
+import com.google.mediapipe.framework.image.BitmapImageBuilder;
+import com.google.mediapipe.framework.image.MPImage;
+import com.google.mediapipe.tasks.components.containers.NormalizedKeypoint;
+import com.google.mediapipe.tasks.core.BaseOptions;
+import com.google.mediapipe.tasks.core.TestUtils;
+import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
+import com.google.mediapipe.tasks.vision.core.RunningMode;
+import com.google.mediapipe.tasks.vision.facedetector.FaceDetector.FaceDetectorOptions;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Suite;
+import org.junit.runners.Suite.SuiteClasses;
+
+/** Test for {@link FaceDetector}. */
+@RunWith(Suite.class)
+@SuiteClasses({FaceDetectorTest.General.class, FaceDetectorTest.RunningModeTest.class})
+public class FaceDetectorTest {
+  private static final String MODEL_FILE = "face_detection_short_range.tflite";
+  private static final String CAT_IMAGE = "cat.jpg";
+  private static final String PORTRAIT_IMAGE = "portrait.jpg";
+  private static final String PORTRAIT_ROTATED_IMAGE = "portrait_rotated.jpg";
+  private static final float KEYPOINTS_DIFF_TOLERANCE = 0.01f;
+  private static final float PIXEL_DIFF_TOLERANCE = 5.0f;
+  private static final RectF PORTRAIT_FACE_BOUNDING_BOX = new RectF(283, 115, 514, 349);
+  private static final List<NormalizedKeypoint> PORTRAIT_FACE_KEYPOINTS =
+      Collections.unmodifiableList(
+          Arrays.asList(
+              NormalizedKeypoint.create(0.44416f, 0.17643f),
+              NormalizedKeypoint.create(0.55514f, 0.17731f),
+              NormalizedKeypoint.create(0.50467f, 0.22657f),
+              NormalizedKeypoint.create(0.50227f, 0.27199f),
+              NormalizedKeypoint.create(0.36063f, 0.20143f),
+              NormalizedKeypoint.create(0.60841f, 0.20409f)));
+  private static final RectF PORTRAIT_ROTATED_FACE_BOUNDING_BOX = new RectF(674, 283, 910, 519);
+  private static final List<NormalizedKeypoint> PORTRAIT_ROTATED_FACE_KEYPOINTS =
+      Collections.unmodifiableList(
+          Arrays.asList(
+              NormalizedKeypoint.create(0.82075f, 0.44679f),
+              NormalizedKeypoint.create(0.81965f, 0.56261f),
+              NormalizedKeypoint.create(0.76194f, 0.51719f),
+              NormalizedKeypoint.create(0.71993f, 0.51360f),
+              NormalizedKeypoint.create(0.80700f, 0.36298f),
+              NormalizedKeypoint.create(0.80882f, 0.61204f)));
+
+  @RunWith(AndroidJUnit4.class)
+  public static final class General extends FaceDetectorTest {
+
+    @Test
+    public void detect_successWithValidModels() throws Exception {
+      FaceDetectorOptions options =
+          FaceDetectorOptions.builder()
+              .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
+              .build();
+      FaceDetector faceDetector =
+          FaceDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+      FaceDetectorResult results = faceDetector.detect(getImageFromAsset(PORTRAIT_IMAGE));
+      assertContainsSinglePortraitFace(
+          results, PORTRAIT_FACE_BOUNDING_BOX, PORTRAIT_FACE_KEYPOINTS);
+    }
+
+    @Test
+    public void detect_succeedsWithMinDetectionConfidence() throws Exception {
FaceDetectorOptions options = + FaceDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .setMinDetectionConfidence(1.0f) + .build(); + FaceDetector faceDetector = + FaceDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); + FaceDetectorResult results = faceDetector.detect(getImageFromAsset(PORTRAIT_IMAGE)); + // Set minDetectionConfidence to 1.0, so the detected face should be all filtered out. + assertThat(results.detections().isEmpty()).isTrue(); + } + + @Test + public void detect_succeedsWithEmptyFace() throws Exception { + FaceDetectorOptions options = + FaceDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .setMinDetectionConfidence(1.0f) + .build(); + FaceDetector faceDetector = + FaceDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); + FaceDetectorResult results = faceDetector.detect(getImageFromAsset(CAT_IMAGE)); + assertThat(results.detections().isEmpty()).isTrue(); + } + + @Test + public void detect_succeedsWithModelFileObject() throws Exception { + FaceDetector faceDetector = + FaceDetector.createFromFile( + ApplicationProvider.getApplicationContext(), + TestUtils.loadFile(ApplicationProvider.getApplicationContext(), MODEL_FILE)); + FaceDetectorResult results = faceDetector.detect(getImageFromAsset(PORTRAIT_IMAGE)); + assertContainsSinglePortraitFace( + results, PORTRAIT_FACE_BOUNDING_BOX, PORTRAIT_FACE_KEYPOINTS); + } + + @Test + public void detect_succeedsWithModelBuffer() throws Exception { + FaceDetector faceDetector = + FaceDetector.createFromBuffer( + ApplicationProvider.getApplicationContext(), + TestUtils.loadToDirectByteBuffer( + ApplicationProvider.getApplicationContext(), MODEL_FILE)); + FaceDetectorResult results = faceDetector.detect(getImageFromAsset(PORTRAIT_IMAGE)); + assertContainsSinglePortraitFace( + results, PORTRAIT_FACE_BOUNDING_BOX, PORTRAIT_FACE_KEYPOINTS); + } + + @Test + public void detect_succeedsWithModelBufferAndOptions() throws Exception { + FaceDetectorOptions options = + FaceDetectorOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetBuffer( + TestUtils.loadToDirectByteBuffer( + ApplicationProvider.getApplicationContext(), MODEL_FILE)) + .build()) + .build(); + FaceDetector faceDetector = + FaceDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); + FaceDetectorResult results = faceDetector.detect(getImageFromAsset(PORTRAIT_IMAGE)); + assertContainsSinglePortraitFace( + results, PORTRAIT_FACE_BOUNDING_BOX, PORTRAIT_FACE_KEYPOINTS); + } + + @Test + public void create_failsWithMissingModel() throws Exception { + String nonexistentFile = "/path/to/non/existent/file"; + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> + FaceDetector.createFromFile( + ApplicationProvider.getApplicationContext(), nonexistentFile)); + assertThat(exception).hasMessageThat().contains(nonexistentFile); + } + + @Test + public void create_failsWithInvalidModelBuffer() throws Exception { + // Create a non-direct model ByteBuffer. 
+ ByteBuffer modelBuffer = + TestUtils.loadToNonDirectByteBuffer( + ApplicationProvider.getApplicationContext(), MODEL_FILE); + + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + FaceDetector.createFromBuffer( + ApplicationProvider.getApplicationContext(), modelBuffer)); + + assertThat(exception) + .hasMessageThat() + .contains("The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); + } + + @Test + public void detect_succeedsWithRotation() throws Exception { + FaceDetectorOptions options = + FaceDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .build(); + FaceDetector faceDetector = + FaceDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRotationDegrees(-90).build(); + FaceDetectorResult results = + faceDetector.detect(getImageFromAsset(PORTRAIT_ROTATED_IMAGE), imageProcessingOptions); + assertContainsSinglePortraitFace( + results, PORTRAIT_ROTATED_FACE_BOUNDING_BOX, PORTRAIT_ROTATED_FACE_KEYPOINTS); + } + + @Test + public void detect_failsWithRegionOfInterest() throws Exception { + FaceDetectorOptions options = + FaceDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .build(); + FaceDetector faceDetector = + FaceDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRegionOfInterest(new RectF(0, 0, 1, 1)).build(); + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> faceDetector.detect(getImageFromAsset(PORTRAIT_IMAGE), imageProcessingOptions)); + assertThat(exception) + .hasMessageThat() + .contains("FaceDetector doesn't support region-of-interest"); + } + } + + @RunWith(AndroidJUnit4.class) + public static final class RunningModeTest extends FaceDetectorTest { + + @Test + public void create_failsWithIllegalResultListenerInNonLiveStreamMode() throws Exception { + for (RunningMode mode : new RunningMode[] {RunningMode.IMAGE, RunningMode.VIDEO}) { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + FaceDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .setRunningMode(mode) + .setResultListener((faceDetectorResult, inputImage) -> {}) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("a user-defined result listener shouldn't be provided"); + } + } + + @Test + public void create_failsWithMissingResultListenerInLiveSteamMode() throws Exception { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + FaceDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("a user-defined result listener must be provided"); + } + + @Test + public void detect_failsWithCallingWrongApiInImageMode() throws Exception { + FaceDetectorOptions options = + FaceDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .setRunningMode(RunningMode.IMAGE) + .build(); + + FaceDetector faceDetector = + FaceDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); + 
MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> + faceDetector.detectForVideo( + getImageFromAsset(PORTRAIT_IMAGE), /* timestampsMs= */ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> + faceDetector.detectAsync( + getImageFromAsset(PORTRAIT_IMAGE), /* timestampsMs= */ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); + } + + @Test + public void detect_failsWithCallingWrongApiInVideoMode() throws Exception { + FaceDetectorOptions options = + FaceDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .setRunningMode(RunningMode.VIDEO) + .build(); + + FaceDetector faceDetector = + FaceDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> faceDetector.detect(getImageFromAsset(PORTRAIT_IMAGE))); + assertThat(exception).hasMessageThat().contains("not initialized with the image mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> + faceDetector.detectAsync( + getImageFromAsset(PORTRAIT_IMAGE), /* timestampsMs= */ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); + } + + @Test + public void detect_failsWithCallingWrongApiInLiveSteamMode() throws Exception { + FaceDetectorOptions options = + FaceDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener((faceDetectorResult, inputImage) -> {}) + .build(); + + FaceDetector faceDetector = + FaceDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); + + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> faceDetector.detect(getImageFromAsset(PORTRAIT_IMAGE))); + assertThat(exception).hasMessageThat().contains("not initialized with the image mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> + faceDetector.detectForVideo( + getImageFromAsset(PORTRAIT_IMAGE), /* timestampsMs= */ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); + } + + @Test + public void detect_successWithImageMode() throws Exception { + FaceDetectorOptions options = + FaceDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .setRunningMode(RunningMode.IMAGE) + .build(); + FaceDetector faceDetector = + FaceDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); + FaceDetectorResult results = faceDetector.detect(getImageFromAsset(PORTRAIT_IMAGE)); + assertContainsSinglePortraitFace( + results, PORTRAIT_FACE_BOUNDING_BOX, PORTRAIT_FACE_KEYPOINTS); + } + + @Test + public void detect_successWithVideoMode() throws Exception { + FaceDetectorOptions options = + FaceDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .setRunningMode(RunningMode.VIDEO) + .build(); + FaceDetector faceDetector = + FaceDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); + for (int i = 0; i < 3; i++) { + FaceDetectorResult results = + faceDetector.detectForVideo(getImageFromAsset(PORTRAIT_IMAGE), /* timestampsMs= */ i); + assertContainsSinglePortraitFace( + results, PORTRAIT_FACE_BOUNDING_BOX, 
PORTRAIT_FACE_KEYPOINTS);
+      }
+    }
+
+    @Test
+    public void detect_failsWithOutOfOrderInputTimestamps() throws Exception {
+      MPImage image = getImageFromAsset(PORTRAIT_IMAGE);
+      FaceDetectorOptions options =
+          FaceDetectorOptions.builder()
+              .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
+              .setRunningMode(RunningMode.LIVE_STREAM)
+              .setResultListener(
+                  (faceDetectorResult, inputImage) -> {
+                    assertContainsSinglePortraitFace(
+                        faceDetectorResult, PORTRAIT_FACE_BOUNDING_BOX, PORTRAIT_FACE_KEYPOINTS);
+                  })
+              .build();
+      try (FaceDetector faceDetector =
+          FaceDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
+        faceDetector.detectAsync(image, /* timestampsMs= */ 1);
+        MediaPipeException exception =
+            assertThrows(
+                MediaPipeException.class,
+                () -> faceDetector.detectAsync(image, /* timestampsMs= */ 0));
+        assertThat(exception)
+            .hasMessageThat()
+            .contains("having a smaller timestamp than the processed timestamp");
+      }
+    }
+
+    @Test
+    public void detect_successWithLiveSteamMode() throws Exception {
+      MPImage image = getImageFromAsset(PORTRAIT_IMAGE);
+      FaceDetectorOptions options =
+          FaceDetectorOptions.builder()
+              .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
+              .setRunningMode(RunningMode.LIVE_STREAM)
+              .setResultListener(
+                  (faceDetectorResult, inputImage) -> {
+                    assertContainsSinglePortraitFace(
+                        faceDetectorResult, PORTRAIT_FACE_BOUNDING_BOX, PORTRAIT_FACE_KEYPOINTS);
+                  })
+              .build();
+      try (FaceDetector faceDetector =
+          FaceDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
+        for (int i = 0; i < 3; i++) {
+          faceDetector.detectAsync(image, /* timestampsMs= */ i);
+        }
+      }
+    }
+  }
+
+  private static MPImage getImageFromAsset(String filePath) throws Exception {
+    AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
+    InputStream istr = assetManager.open(filePath);
+    return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build();
+  }
+
+  private static void assertContainsSinglePortraitFace(
+      FaceDetectorResult results,
+      RectF expectedBoundingBox,
+      List<NormalizedKeypoint> expectedKeypoints) {
+    assertThat(results.detections()).hasSize(1);
+    assertApproximatelyEqualBoundingBoxes(
+        results.detections().get(0).boundingBox(), expectedBoundingBox);
+    assertThat(results.detections().get(0).keypoints().isPresent()).isTrue();
+    assertApproximatelyEqualKeypoints(
+        results.detections().get(0).keypoints().get(), expectedKeypoints);
+  }
+
+  private static void assertApproximatelyEqualBoundingBoxes(
+      RectF boundingBox1, RectF boundingBox2) {
+    assertThat(boundingBox1.left).isWithin(PIXEL_DIFF_TOLERANCE).of(boundingBox2.left);
+    assertThat(boundingBox1.top).isWithin(PIXEL_DIFF_TOLERANCE).of(boundingBox2.top);
+    assertThat(boundingBox1.right).isWithin(PIXEL_DIFF_TOLERANCE).of(boundingBox2.right);
+    assertThat(boundingBox1.bottom).isWithin(PIXEL_DIFF_TOLERANCE).of(boundingBox2.bottom);
+  }
+
+  private static void assertApproximatelyEqualKeypoints(
+      List<NormalizedKeypoint> keypoints1, List<NormalizedKeypoint> keypoints2) {
+    assertThat(keypoints1.size()).isEqualTo(keypoints2.size());
+    for (int i = 0; i < keypoints1.size(); i++) {
+      assertThat(keypoints1.get(i).x())
+          .isWithin(KEYPOINTS_DIFF_TOLERANCE)
+          .of(keypoints2.get(i).x());
+      assertThat(keypoints1.get(i).y())
+          .isWithin(KEYPOINTS_DIFF_TOLERANCE)
+          .of(keypoints2.get(i).y());
+    }
+  }
+}