diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD
index 32518725a..bd57ffadb 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD
@@ -45,6 +45,7 @@ cc_binary(
deps = [
"//mediapipe/calculators/core:flow_limiter_calculator",
"//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
+ "//mediapipe/tasks/cc/vision/face_detector:face_detector_graph",
"//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph",
"//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
"//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph",
@@ -235,6 +236,7 @@ android_library(
android_library(
name = "facedetector",
srcs = [
+ "facedetector/FaceDetector.java",
"facedetector/FaceDetectorResult.java",
],
javacopts = [
@@ -245,7 +247,10 @@ android_library(
":core",
"//mediapipe/framework:calculator_options_java_proto_lite",
"//mediapipe/framework/formats:detection_java_proto_lite",
+ "//mediapipe/java/com/google/mediapipe/framework:android_framework",
+ "//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
+ "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//third_party:autovalue",
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facedetector/FaceDetector.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facedetector/FaceDetector.java
new file mode 100644
index 000000000..c23432c1b
--- /dev/null
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facedetector/FaceDetector.java
@@ -0,0 +1,463 @@
+// Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.mediapipe.tasks.vision.facedetector;
+
+import android.content.Context;
+import android.os.ParcelFileDescriptor;
+import com.google.auto.value.AutoValue;
+import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
+import com.google.mediapipe.framework.AndroidPacketGetter;
+import com.google.mediapipe.framework.Packet;
+import com.google.mediapipe.framework.PacketGetter;
+import com.google.mediapipe.framework.image.BitmapImageBuilder;
+import com.google.mediapipe.framework.image.MPImage;
+import com.google.mediapipe.tasks.core.BaseOptions;
+import com.google.mediapipe.tasks.core.ErrorListener;
+import com.google.mediapipe.tasks.core.OutputHandler;
+import com.google.mediapipe.tasks.core.OutputHandler.ResultListener;
+import com.google.mediapipe.tasks.core.TaskInfo;
+import com.google.mediapipe.tasks.core.TaskOptions;
+import com.google.mediapipe.tasks.core.TaskRunner;
+import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
+import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
+import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
+import com.google.mediapipe.tasks.vision.core.RunningMode;
+import com.google.mediapipe.tasks.vision.facedetector.proto.FaceDetectorGraphOptionsProto;
+import com.google.mediapipe.formats.proto.DetectionProto.Detection;
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+
+/**
+ * Performs face detection on images.
+ *
+ * <p>The API expects a TFLite model with <a
+ * href="https://www.tensorflow.org/lite/convert/metadata">TFLite Model Metadata</a>.
+ *
+ * <ul>
+ *   <li>Input image {@link MPImage}
+ *       <ul>
+ *         <li>The image that the face detector runs on.
+ *       </ul>
+ *   <li>Output FaceDetectorResult {@link FaceDetectorResult}
+ *       <ul>
+ *         <li>A FaceDetectorResult containing detected faces.
+ *       </ul>
+ * </ul>
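+ *
+ * <p>Sample usage (a sketch; the asset name, {@code context}, and {@code mpImage} are
+ * illustrative placeholders, not part of this change):
+ *
+ * <pre>{@code
+ * // "face_detector.tflite" is a placeholder asset; substitute your own model.
+ * FaceDetectorOptions options =
+ *     FaceDetectorOptions.builder()
+ *         .setBaseOptions(
+ *             BaseOptions.builder().setModelAssetPath("face_detector.tflite").build())
+ *         .setRunningMode(RunningMode.IMAGE)
+ *         .build();
+ * FaceDetector faceDetector = FaceDetector.createFromOptions(context, options);
+ * FaceDetectorResult result = faceDetector.detect(mpImage); // Detected faces and keypoints.
+ * }</pre>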
+ */
+public final class FaceDetector extends BaseVisionTaskApi {
+ private static final String TAG = FaceDetector.class.getSimpleName();
+ private static final String IMAGE_IN_STREAM_NAME = "image_in";
+ private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in";
+
+ @SuppressWarnings("ConstantCaseForConstants")
+ private static final List<String> INPUT_STREAMS =
+ Collections.unmodifiableList(
+ Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
+
+ @SuppressWarnings("ConstantCaseForConstants")
+ private static final List<String> OUTPUT_STREAMS =
+ Collections.unmodifiableList(Arrays.asList("DETECTIONS:detections_out", "IMAGE:image_out"));
+
+ private static final int DETECTIONS_OUT_STREAM_INDEX = 0;
+ private static final int IMAGE_OUT_STREAM_INDEX = 1;
+ private static final String TASK_GRAPH_NAME =
+ "mediapipe.tasks.vision.face_detector.FaceDetectorGraph";
+
+ /**
+ * Creates a {@link FaceDetector} instance from a model file and the default {@link
+ * FaceDetectorOptions}.
+ *
+ * @param context an Android {@link Context}.
+ * @param modelPath path to the detection model with metadata in the assets.
+ * @throws MediaPipeException if there is an error during {@link FaceDetector} creation.
+ */
+ public static FaceDetector createFromFile(Context context, String modelPath) {
+ BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build();
+ return createFromOptions(
+ context, FaceDetectorOptions.builder().setBaseOptions(baseOptions).build());
+ }
+
+ /**
+ * Creates a {@link FaceDetector} instance from a model file and the default {@link
+ * FaceDetectorOptions}.
+ *
+ * @param context an Android {@link Context}.
+ * @param modelFile the detection model {@link File} instance.
+ * @throws IOException if an I/O error occurs when opening the tflite model file.
+ * @throws MediaPipeException if there is an error during {@link FaceDetector} creation.
+ */
+ public static FaceDetector createFromFile(Context context, File modelFile) throws IOException {
+ try (ParcelFileDescriptor descriptor =
+ ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
+ BaseOptions baseOptions =
+ BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build();
+ return createFromOptions(
+ context, FaceDetectorOptions.builder().setBaseOptions(baseOptions).build());
+ }
+ }
+
+ /**
+ * Creates a {@link FaceDetector} instance from a model buffer and the default {@link
+ * FaceDetectorOptions}.
+ *
+ * @param context an Android {@link Context}.
+ * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the detection
+ * model.
+ * @throws MediaPipeException if there is an error during {@link FaceDetector} creation.
+ */
+ public static FaceDetector createFromBuffer(Context context, final ByteBuffer modelBuffer) {
+ BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelBuffer).build();
+ return createFromOptions(
+ context, FaceDetectorOptions.builder().setBaseOptions(baseOptions).build());
+ }
+
+ /**
+ * Creates a {@link FaceDetector} instance from a {@link FaceDetectorOptions}.
+ *
+ * @param context an Android {@link Context}.
+ * @param detectorOptions a {@link FaceDetectorOptions} instance.
+ * @throws MediaPipeException if there is an error during {@link FaceDetector} creation.
+ */
+ public static FaceDetector createFromOptions(
+ Context context, FaceDetectorOptions detectorOptions) {
+ // TODO: Consolidate OutputHandler and TaskRunner.
+ OutputHandler<FaceDetectorResult, MPImage> handler = new OutputHandler<>();
+ handler.setOutputPacketConverter(
+ new OutputHandler.OutputPacketConverter<FaceDetectorResult, MPImage>() {
+ @Override
+ public FaceDetectorResult convertToTaskResult(List<Packet> packets) {
+ // If no faces are detected in the image, just return an empty detection list.
+ if (packets.get(DETECTIONS_OUT_STREAM_INDEX).isEmpty()) {
+ return FaceDetectorResult.create(
+ new ArrayList<>(),
+ BaseVisionTaskApi.generateResultTimestampMs(
+ detectorOptions.runningMode(), packets.get(DETECTIONS_OUT_STREAM_INDEX)));
+ }
+ return FaceDetectorResult.create(
+ PacketGetter.getProtoVector(
+ packets.get(DETECTIONS_OUT_STREAM_INDEX), Detection.parser()),
+ BaseVisionTaskApi.generateResultTimestampMs(
+ detectorOptions.runningMode(), packets.get(DETECTIONS_OUT_STREAM_INDEX)));
+ }
+
+ @Override
+ public MPImage convertToTaskInput(List<Packet> packets) {
+ return new BitmapImageBuilder(
+ AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
+ .build();
+ }
+ });
+ detectorOptions.resultListener().ifPresent(handler::setResultListener);
+ detectorOptions.errorListener().ifPresent(handler::setErrorListener);
+ TaskRunner runner =
+ TaskRunner.create(
+ context,
+ TaskInfo.builder()
+ .setTaskName(FaceDetector.class.getSimpleName())
+ .setTaskRunningModeName(detectorOptions.runningMode().name())
+ .setTaskGraphName(TASK_GRAPH_NAME)
+ .setInputStreams(INPUT_STREAMS)
+ .setOutputStreams(OUTPUT_STREAMS)
+ .setTaskOptions(detectorOptions)
+ .setEnableFlowLimiting(detectorOptions.runningMode() == RunningMode.LIVE_STREAM)
+ .build(),
+ handler);
+ return new FaceDetector(runner, detectorOptions.runningMode());
+ }
+
+ /**
+ * Constructor to initialize a {@link FaceDetector} from a {@link TaskRunner} and a {@link
+ * RunningMode}.
+ *
+ * @param taskRunner a {@link TaskRunner}.
+ * @param runningMode a mediapipe vision task {@link RunningMode}.
+ */
+ private FaceDetector(TaskRunner taskRunner, RunningMode runningMode) {
+ super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME);
+ }
+
+ /**
+ * Performs face detection on the provided single image with default image processing options,
+ * i.e. without any rotation applied. Only use this method when the {@link FaceDetector} is
+ * created with {@link RunningMode.IMAGE}.
+ *
+ * <p>{@link FaceDetector} supports the following color space types:
+ *
+ * <ul>
+ *   <li>{@link Bitmap.Config.ARGB_8888}
+ * </ul>
+ *
+ * @param image a MediaPipe {@link MPImage} object for processing.
+ * @throws MediaPipeException if there is an internal error.
+ */
+ public FaceDetectorResult detect(MPImage image) {
+ return detect(image, ImageProcessingOptions.builder().build());
+ }
+
+ /**
+ * Performs face detection on the provided single image. Only use this method when the {@link
+ * FaceDetector} is created with {@link RunningMode.IMAGE}.
+ *
+ * <p>{@link FaceDetector} supports the following color space types:
+ *
+ * <ul>
+ *   <li>{@link Bitmap.Config.ARGB_8888}
+ * </ul>
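+ *
+ * <p>For example, to compensate for a rotated input (a sketch; {@code mpImage} and the
+ * rotation value are illustrative):
+ *
+ * <pre>{@code
+ * // Rotate the input by -90 degrees before inference.
+ * ImageProcessingOptions imageProcessingOptions =
+ *     ImageProcessingOptions.builder().setRotationDegrees(-90).build();
+ * FaceDetectorResult result = faceDetector.detect(mpImage, imageProcessingOptions);
+ * }</pre>
+ *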
+ * @param image a MediaPipe {@link MPImage} object for processing.
+ * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
+ * input image before running inference. Note that region-of-interest is not supported
+ * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
+ * this method throwing an IllegalArgumentException.
+ * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
+ * region-of-interest.
+ * @throws MediaPipeException if there is an internal error.
+ */
+ public FaceDetectorResult detect(MPImage image, ImageProcessingOptions imageProcessingOptions) {
+ validateImageProcessingOptions(imageProcessingOptions);
+ return (FaceDetectorResult) processImageData(image, imageProcessingOptions);
+ }
+
+ /**
+ * Performs face detection on the provided video frame with default image processing options, i.e.
+ * without any rotation applied. Only use this method when the {@link FaceDetector} is created
+ * with {@link RunningMode.VIDEO}.
+ *
+ * <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
+ * must be monotonically increasing.
+ *
+ * <p>{@link FaceDetector} supports the following color space types:
+ *
+ * <ul>
+ *   <li>{@link Bitmap.Config.ARGB_8888}
+ * </ul>
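+ *
+ * <p>For example, a caller might feed decoded frames like this (a sketch; {@code frameAt} is a
+ * hypothetical helper that returns the {@link MPImage} decoded at the given timestamp):
+ *
+ * <pre>{@code
+ * // frameAt, videoDurationMs, and frameIntervalMs are assumed to be provided by the caller.
+ * for (long timestampMs = 0; timestampMs < videoDurationMs; timestampMs += frameIntervalMs) {
+ *   FaceDetectorResult result = faceDetector.detectForVideo(frameAt(timestampMs), timestampMs);
+ * }
+ * }</pre>
+ *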
+ * @param image a MediaPipe {@link MPImage} object for processing.
+ * @param timestampMs the input timestamp (in milliseconds).
+ * @throws MediaPipeException if there is an internal error.
+ */
+ public FaceDetectorResult detectForVideo(MPImage image, long timestampMs) {
+ return detectForVideo(image, ImageProcessingOptions.builder().build(), timestampMs);
+ }
+
+ /**
+ * Performs face detection on the provided video frame. Only use this method when the {@link
+ * FaceDetector} is created with {@link RunningMode.VIDEO}.
+ *
+ * <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
+ * must be monotonically increasing.
+ *
+ * <p>{@link FaceDetector} supports the following color space types:
+ *
+ * <ul>
+ *   <li>{@link Bitmap.Config.ARGB_8888}
+ * </ul>
+ *
+ * @param image a MediaPipe {@link MPImage} object for processing.
+ * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
+ * input image before running inference. Note that region-of-interest is not supported
+ * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
+ * this method throwing an IllegalArgumentException.
+ * @param timestampMs the input timestamp (in milliseconds).
+ * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
+ * region-of-interest.
+ * @throws MediaPipeException if there is an internal error.
+ */
+ public FaceDetectorResult detectForVideo(
+ MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
+ validateImageProcessingOptions(imageProcessingOptions);
+ return (FaceDetectorResult) processVideoData(image, imageProcessingOptions, timestampMs);
+ }
+
+ /**
+ * Sends live image data to perform face detection with default image processing options, i.e.
+ * without any rotation applied, and the results will be available via the {@link ResultListener}
+ * provided in the {@link FaceDetectorOptions}. Only use this method when the {@link FaceDetector}
+ * is created with {@link RunningMode.LIVE_STREAM}.
+ *
+ * <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
+ * sent to the face detector. The input timestamps must be monotonically increasing.
+ *
+ * <p>{@link FaceDetector} supports the following color space types:
+ *
+ * <ul>
+ *   <li>{@link Bitmap.Config.ARGB_8888}
+ * </ul>
+ *
+ * @param image a MediaPipe {@link MPImage} object for processing.
+ * @param timestampMs the input timestamp (in milliseconds).
+ * @throws MediaPipeException if there is an internal error.
+ */
+ public void detectAsync(MPImage image, long timestampMs) {
+ detectAsync(image, ImageProcessingOptions.builder().build(), timestampMs);
+ }
+
+ /**
+ * Sends live image data to perform face detection, and the results will be available via the
+ * {@link ResultListener} provided in the {@link FaceDetectorOptions}. Only use this method when
+ * the {@link FaceDetector} is created with {@link RunningMode.LIVE_STREAM}.
+ *
+ * <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
+ * sent to the face detector. The input timestamps must be monotonically increasing.
+ *
+ * <p>{@link FaceDetector} supports the following color space types:
+ *
+ * <ul>
+ *   <li>{@link Bitmap.Config.ARGB_8888}
+ * </ul>
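+ *
+ * <p>For example (a sketch; {@code mpImage} is an assumed camera frame, and the monotonic
+ * clock is one reasonable timestamp source):
+ *
+ * <pre>{@code
+ * // Timestamps must increase monotonically across calls.
+ * faceDetector.detectAsync(mpImage, SystemClock.uptimeMillis());
+ * }</pre>
+ *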
+ * @param image a MediaPipe {@link MPImage} object for processing.
+ * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
+ * input image before running inference. Note that region-of-interest is not supported
+ * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
+ * this method throwing an IllegalArgumentException.
+ * @param timestampMs the input timestamp (in milliseconds).
+ * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
+ * region-of-interest.
+ * @throws MediaPipeException if there is an internal error.
+ */
+ public void detectAsync(
+ MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
+ validateImageProcessingOptions(imageProcessingOptions);
+ sendLiveStreamData(image, imageProcessingOptions, timestampMs);
+ }
+
+ /** Options for setting up a {@link FaceDetector}. */
+ @AutoValue
+ public abstract static class FaceDetectorOptions extends TaskOptions {
+
+ /** Builder for {@link FaceDetectorOptions}. */
+ @AutoValue.Builder
+ public abstract static class Builder {
+ /** Sets the {@link BaseOptions} for the face detector task. */
+ public abstract Builder setBaseOptions(BaseOptions value);
+
+ /**
+ * Sets the {@link RunningMode} for the face detector task. Defaults to the image mode. The
+ * face detector has three modes:
+ *
+ * <ul>
+ *   <li>IMAGE: The mode for detecting faces on single image inputs.
+ *   <li>VIDEO: The mode for detecting faces on the decoded frames of a video.
+ *   <li>LIVE_STREAM: The mode for detecting faces on a live stream of input data, such as
+ *       from camera. In this mode, {@code setResultListener} must be called to set up a
+ *       listener to receive the detection results asynchronously.
+ * </ul>
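+ *
+ * <p>For example, a live-stream configuration might look like the following sketch (the
+ * {@code baseOptions} value and the listener body are illustrative, not part of this change):
+ *
+ * <pre>{@code
+ * FaceDetectorOptions options =
+ *     FaceDetectorOptions.builder()
+ *         .setBaseOptions(baseOptions) // Assumed to be built elsewhere.
+ *         .setRunningMode(RunningMode.LIVE_STREAM)
+ *         .setResultListener(
+ *             (result, inputImage) -> {
+ *               // Consume the asynchronously delivered FaceDetectorResult here.
+ *             })
+ *         .build();
+ * }</pre>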
+ */
+ public abstract Builder setRunningMode(RunningMode value);
+
+ /**
+ * Sets the minimum confidence score for the face detection to be considered successful. The
+ * default minDetectionConfidence is 0.5.
+ */
+ public abstract Builder setMinDetectionConfidence(Float value);
+
+ /**
+ * Sets the minimum non-maximum-suppression threshold for face detection to be considered
+ * overlapped. The default minSuppressionThreshold is 0.3.
+ */
+ public abstract Builder setMinSuppressionThreshold(Float value);
+
+ /**
+ * Sets the {@link ResultListener} to receive the detection results asynchronously when the
+ * face detector is in the live stream mode.
+ */
+ public abstract Builder setResultListener(ResultListener<FaceDetectorResult, MPImage> value);
+
+ /** Sets an optional {@link ErrorListener}. */
+ public abstract Builder setErrorListener(ErrorListener value);
+
+ abstract FaceDetectorOptions autoBuild();
+
+ /**
+ * Validates and builds the {@link FaceDetectorOptions} instance.
+ *
+ * @throws IllegalArgumentException if the result listener and the running mode are not
+ * properly configured. The result listener should only be set when the face detector is
+ * in the live stream mode.
+ */
+ public final FaceDetectorOptions build() {
+ FaceDetectorOptions options = autoBuild();
+ if (options.runningMode() == RunningMode.LIVE_STREAM) {
+ if (!options.resultListener().isPresent()) {
+ throw new IllegalArgumentException(
+ "The face detector is in the live stream mode, a user-defined result listener"
+ + " must be provided in FaceDetectorOptions.");
+ }
+ } else if (options.resultListener().isPresent()) {
+ throw new IllegalArgumentException(
+ "The face detector is in the image or the video mode, a user-defined result"
+ + " listener shouldn't be provided in FaceDetectorOptions.");
+ }
+ return options;
+ }
+ }
+
+ abstract BaseOptions baseOptions();
+
+ abstract RunningMode runningMode();
+
+ abstract float minDetectionConfidence();
+
+ abstract float minSuppressionThreshold();
+
+ abstract Optional<ResultListener<FaceDetectorResult, MPImage>> resultListener();
+
+ abstract Optional<ErrorListener> errorListener();
+
+ public static Builder builder() {
+ return new AutoValue_FaceDetector_FaceDetectorOptions.Builder()
+ .setRunningMode(RunningMode.IMAGE)
+ .setMinDetectionConfidence(0.5f)
+ .setMinSuppressionThreshold(0.3f);
+ }
+
+ /** Converts a {@link FaceDetectorOptions} to a {@link CalculatorOptions} protobuf message. */
+ @Override
+ public CalculatorOptions convertToCalculatorOptionsProto() {
+ BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder =
+ BaseOptionsProto.BaseOptions.newBuilder();
+ baseOptionsBuilder.setUseStreamMode(runningMode() != RunningMode.IMAGE);
+ baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
+ FaceDetectorGraphOptionsProto.FaceDetectorGraphOptions.Builder taskOptionsBuilder =
+ FaceDetectorGraphOptionsProto.FaceDetectorGraphOptions.newBuilder()
+ .setBaseOptions(baseOptionsBuilder);
+ taskOptionsBuilder.setMinDetectionConfidence(minDetectionConfidence());
+ taskOptionsBuilder.setMinSuppressionThreshold(minSuppressionThreshold());
+ return CalculatorOptions.newBuilder()
+ .setExtension(
+ FaceDetectorGraphOptionsProto.FaceDetectorGraphOptions.ext,
+ taskOptionsBuilder.build())
+ .build();
+ }
+ }
+
+ /**
+ * Validates that the provided {@link ImageProcessingOptions} doesn't contain a
+ * region-of-interest.
+ */
+ private static void validateImageProcessingOptions(
+ ImageProcessingOptions imageProcessingOptions) {
+ if (imageProcessingOptions.regionOfInterest().isPresent()) {
+ throw new IllegalArgumentException("FaceDetector doesn't support region-of-interest.");
+ }
+ }
+}
diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facedetector/AndroidManifest.xml b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facedetector/AndroidManifest.xml
new file mode 100644
index 000000000..01cbc3a6f
--- /dev/null
+++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facedetector/AndroidManifest.xml
@@ -0,0 +1,24 @@
+<?xml version="1.0" encoding="utf-8"?>
+<manifest xmlns:android="http://schemas.android.com/apk/res/android"
+    package="com.google.mediapipe.tasks.vision.facedetector">
+
+  <uses-sdk
+      android:minSdkVersion="24"
+      android:targetSdkVersion="30" />
+
+  <application>
+    <uses-library android:name="android.test.runner" />
+  </application>
+
+  <instrumentation
+      android:name="androidx.test.runner.AndroidJUnitRunner"
+      android:targetPackage="com.google.mediapipe.tasks.vision.facedetector" />
+</manifest>
diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facedetector/BUILD b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facedetector/BUILD
new file mode 100644
index 000000000..c14486766
--- /dev/null
+++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facedetector/BUILD
@@ -0,0 +1,19 @@
+# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(default_visibility = ["//mediapipe/tasks:internal"])
+
+licenses(["notice"])
+
+# TODO: Enable this in OSS
diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facedetector/FaceDetectorTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facedetector/FaceDetectorTest.java
new file mode 100644
index 000000000..d995accd5
--- /dev/null
+++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facedetector/FaceDetectorTest.java
@@ -0,0 +1,455 @@
+// Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.mediapipe.tasks.vision.facedetector;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.assertThrows;
+
+import android.content.res.AssetManager;
+import android.graphics.BitmapFactory;
+import android.graphics.RectF;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import com.google.mediapipe.framework.MediaPipeException;
+import com.google.mediapipe.framework.image.BitmapImageBuilder;
+import com.google.mediapipe.framework.image.MPImage;
+import com.google.mediapipe.tasks.components.containers.NormalizedKeypoint;
+import com.google.mediapipe.tasks.core.BaseOptions;
+import com.google.mediapipe.tasks.core.TestUtils;
+import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
+import com.google.mediapipe.tasks.vision.core.RunningMode;
+import com.google.mediapipe.tasks.vision.facedetector.FaceDetector.FaceDetectorOptions;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Suite;
+import org.junit.runners.Suite.SuiteClasses;
+
+/** Test for {@link FaceDetector}. */
+@RunWith(Suite.class)
+@SuiteClasses({FaceDetectorTest.General.class, FaceDetectorTest.RunningModeTest.class})
+public class FaceDetectorTest {
+ private static final String MODEL_FILE = "face_detection_short_range.tflite";
+ private static final String CAT_IMAGE = "cat.jpg";
+ private static final String PORTRAIT_IMAGE = "portrait.jpg";
+ private static final String PORTRAIT_ROTATED_IMAGE = "portrait_rotated.jpg";
+ private static final float KEYPOINTS_DIFF_TOLERANCE = 0.01f;
+ private static final float PIXEL_DIFF_TOLERANCE = 5.0f;
+ private static final RectF PORTRAIT_FACE_BOUNDING_BOX = new RectF(283, 115, 514, 349);
+ private static final List<NormalizedKeypoint> PORTRAIT_FACE_KEYPOINTS =
+ Collections.unmodifiableList(
+ Arrays.asList(
+ NormalizedKeypoint.create(0.44416f, 0.17643f),
+ NormalizedKeypoint.create(0.55514f, 0.17731f),
+ NormalizedKeypoint.create(0.50467f, 0.22657f),
+ NormalizedKeypoint.create(0.50227f, 0.27199f),
+ NormalizedKeypoint.create(0.36063f, 0.20143f),
+ NormalizedKeypoint.create(0.60841f, 0.20409f)));
+ private static final RectF PORTRAIT_ROTATED_FACE_BOUNDING_BOX = new RectF(674, 283, 910, 519);
+ private static final List<NormalizedKeypoint> PORTRAIT_ROTATED_FACE_KEYPOINTS =
+ Collections.unmodifiableList(
+ Arrays.asList(
+ NormalizedKeypoint.create(0.82075f, 0.44679f),
+ NormalizedKeypoint.create(0.81965f, 0.56261f),
+ NormalizedKeypoint.create(0.76194f, 0.51719f),
+ NormalizedKeypoint.create(0.71993f, 0.51360f),
+ NormalizedKeypoint.create(0.80700f, 0.36298f),
+ NormalizedKeypoint.create(0.80882f, 0.61204f)));
+
+ @RunWith(AndroidJUnit4.class)
+ public static final class General extends FaceDetectorTest {
+
+ @Test
+ public void detect_successWithValidModels() throws Exception {
+ FaceDetectorOptions options =
+ FaceDetectorOptions.builder()
+ .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
+ .build();
+ FaceDetector faceDetector =
+ FaceDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+ FaceDetectorResult results = faceDetector.detect(getImageFromAsset(PORTRAIT_IMAGE));
+ assertContainsSinglePortraitFace(
+ results, PORTRAIT_FACE_BOUNDING_BOX, PORTRAIT_FACE_KEYPOINTS);
+ }
+
+ @Test
+ public void detect_succeedsWithMinDetectionConfidence() throws Exception {
+ FaceDetectorOptions options =
+ FaceDetectorOptions.builder()
+ .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
+ .setMinDetectionConfidence(1.0f)
+ .build();
+ FaceDetector faceDetector =
+ FaceDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+ FaceDetectorResult results = faceDetector.detect(getImageFromAsset(PORTRAIT_IMAGE));
+ // minDetectionConfidence is set to 1.0, so all detected faces should be filtered out.
+ assertThat(results.detections().isEmpty()).isTrue();
+ }
+
+ @Test
+ public void detect_succeedsWithEmptyFace() throws Exception {
+ FaceDetectorOptions options =
+ FaceDetectorOptions.builder()
+ .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
+ .setMinDetectionConfidence(1.0f)
+ .build();
+ FaceDetector faceDetector =
+ FaceDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+ FaceDetectorResult results = faceDetector.detect(getImageFromAsset(CAT_IMAGE));
+ assertThat(results.detections().isEmpty()).isTrue();
+ }
+
+ @Test
+ public void detect_succeedsWithModelFileObject() throws Exception {
+ FaceDetector faceDetector =
+ FaceDetector.createFromFile(
+ ApplicationProvider.getApplicationContext(),
+ TestUtils.loadFile(ApplicationProvider.getApplicationContext(), MODEL_FILE));
+ FaceDetectorResult results = faceDetector.detect(getImageFromAsset(PORTRAIT_IMAGE));
+ assertContainsSinglePortraitFace(
+ results, PORTRAIT_FACE_BOUNDING_BOX, PORTRAIT_FACE_KEYPOINTS);
+ }
+
+ @Test
+ public void detect_succeedsWithModelBuffer() throws Exception {
+ FaceDetector faceDetector =
+ FaceDetector.createFromBuffer(
+ ApplicationProvider.getApplicationContext(),
+ TestUtils.loadToDirectByteBuffer(
+ ApplicationProvider.getApplicationContext(), MODEL_FILE));
+ FaceDetectorResult results = faceDetector.detect(getImageFromAsset(PORTRAIT_IMAGE));
+ assertContainsSinglePortraitFace(
+ results, PORTRAIT_FACE_BOUNDING_BOX, PORTRAIT_FACE_KEYPOINTS);
+ }
+
+ @Test
+ public void detect_succeedsWithModelBufferAndOptions() throws Exception {
+ FaceDetectorOptions options =
+ FaceDetectorOptions.builder()
+ .setBaseOptions(
+ BaseOptions.builder()
+ .setModelAssetBuffer(
+ TestUtils.loadToDirectByteBuffer(
+ ApplicationProvider.getApplicationContext(), MODEL_FILE))
+ .build())
+ .build();
+ FaceDetector faceDetector =
+ FaceDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+ FaceDetectorResult results = faceDetector.detect(getImageFromAsset(PORTRAIT_IMAGE));
+ assertContainsSinglePortraitFace(
+ results, PORTRAIT_FACE_BOUNDING_BOX, PORTRAIT_FACE_KEYPOINTS);
+ }
+
+ @Test
+ public void create_failsWithMissingModel() throws Exception {
+ String nonexistentFile = "/path/to/non/existent/file";
+ MediaPipeException exception =
+ assertThrows(
+ MediaPipeException.class,
+ () ->
+ FaceDetector.createFromFile(
+ ApplicationProvider.getApplicationContext(), nonexistentFile));
+ assertThat(exception).hasMessageThat().contains(nonexistentFile);
+ }
+
+ @Test
+ public void create_failsWithInvalidModelBuffer() throws Exception {
+ // Create a non-direct model ByteBuffer.
+ ByteBuffer modelBuffer =
+ TestUtils.loadToNonDirectByteBuffer(
+ ApplicationProvider.getApplicationContext(), MODEL_FILE);
+
+ IllegalArgumentException exception =
+ assertThrows(
+ IllegalArgumentException.class,
+ () ->
+ FaceDetector.createFromBuffer(
+ ApplicationProvider.getApplicationContext(), modelBuffer));
+
+ assertThat(exception)
+ .hasMessageThat()
+ .contains("The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
+ }
+
+ @Test
+ public void detect_succeedsWithRotation() throws Exception {
+ FaceDetectorOptions options =
+ FaceDetectorOptions.builder()
+ .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
+ .build();
+ FaceDetector faceDetector =
+ FaceDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+ ImageProcessingOptions imageProcessingOptions =
+ ImageProcessingOptions.builder().setRotationDegrees(-90).build();
+ FaceDetectorResult results =
+ faceDetector.detect(getImageFromAsset(PORTRAIT_ROTATED_IMAGE), imageProcessingOptions);
+ assertContainsSinglePortraitFace(
+ results, PORTRAIT_ROTATED_FACE_BOUNDING_BOX, PORTRAIT_ROTATED_FACE_KEYPOINTS);
+ }
+
+ @Test
+ public void detect_failsWithRegionOfInterest() throws Exception {
+ FaceDetectorOptions options =
+ FaceDetectorOptions.builder()
+ .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
+ .build();
+ FaceDetector faceDetector =
+ FaceDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+ ImageProcessingOptions imageProcessingOptions =
+ ImageProcessingOptions.builder().setRegionOfInterest(new RectF(0, 0, 1, 1)).build();
+ IllegalArgumentException exception =
+ assertThrows(
+ IllegalArgumentException.class,
+ () -> faceDetector.detect(getImageFromAsset(PORTRAIT_IMAGE), imageProcessingOptions));
+ assertThat(exception)
+ .hasMessageThat()
+ .contains("FaceDetector doesn't support region-of-interest");
+ }
+ }
+
+ @RunWith(AndroidJUnit4.class)
+ public static final class RunningModeTest extends FaceDetectorTest {
+
+ @Test
+ public void create_failsWithIllegalResultListenerInNonLiveStreamMode() throws Exception {
+ for (RunningMode mode : new RunningMode[] {RunningMode.IMAGE, RunningMode.VIDEO}) {
+ IllegalArgumentException exception =
+ assertThrows(
+ IllegalArgumentException.class,
+ () ->
+ FaceDetectorOptions.builder()
+ .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
+ .setRunningMode(mode)
+ .setResultListener((faceDetectorResult, inputImage) -> {})
+ .build());
+ assertThat(exception)
+ .hasMessageThat()
+ .contains("a user-defined result listener shouldn't be provided");
+ }
+ }
+
+ @Test
+ public void create_failsWithMissingResultListenerInLiveSteamMode() throws Exception {
+ IllegalArgumentException exception =
+ assertThrows(
+ IllegalArgumentException.class,
+ () ->
+ FaceDetectorOptions.builder()
+ .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
+ .setRunningMode(RunningMode.LIVE_STREAM)
+ .build());
+ assertThat(exception)
+ .hasMessageThat()
+ .contains("a user-defined result listener must be provided");
+ }
+
+ @Test
+ public void detect_failsWithCallingWrongApiInImageMode() throws Exception {
+ FaceDetectorOptions options =
+ FaceDetectorOptions.builder()
+ .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
+ .setRunningMode(RunningMode.IMAGE)
+ .build();
+
+ FaceDetector faceDetector =
+ FaceDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+ MediaPipeException exception =
+ assertThrows(
+ MediaPipeException.class,
+ () ->
+ faceDetector.detectForVideo(
+ getImageFromAsset(PORTRAIT_IMAGE), /* timestampMs= */ 0));
+ assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
+ exception =
+ assertThrows(
+ MediaPipeException.class,
+ () ->
+ faceDetector.detectAsync(
+ getImageFromAsset(PORTRAIT_IMAGE), /* timestampMs= */ 0));
+ assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
+ }
+
+ @Test
+ public void detect_failsWithCallingWrongApiInVideoMode() throws Exception {
+ FaceDetectorOptions options =
+ FaceDetectorOptions.builder()
+ .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
+ .setRunningMode(RunningMode.VIDEO)
+ .build();
+
+ FaceDetector faceDetector =
+ FaceDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+ MediaPipeException exception =
+ assertThrows(
+ MediaPipeException.class,
+ () -> faceDetector.detect(getImageFromAsset(PORTRAIT_IMAGE)));
+ assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
+ exception =
+ assertThrows(
+ MediaPipeException.class,
+ () ->
+ faceDetector.detectAsync(
+ getImageFromAsset(PORTRAIT_IMAGE), /* timestampMs= */ 0));
+ assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
+ }
+
+ @Test
+ public void detect_failsWithCallingWrongApiInLiveSteamMode() throws Exception {
+ FaceDetectorOptions options =
+ FaceDetectorOptions.builder()
+ .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
+ .setRunningMode(RunningMode.LIVE_STREAM)
+ .setResultListener((faceDetectorResult, inputImage) -> {})
+ .build();
+
+ FaceDetector faceDetector =
+ FaceDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+
+ MediaPipeException exception =
+ assertThrows(
+ MediaPipeException.class,
+ () -> faceDetector.detect(getImageFromAsset(PORTRAIT_IMAGE)));
+ assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
+ exception =
+ assertThrows(
+ MediaPipeException.class,
+ () ->
+ faceDetector.detectForVideo(
+ getImageFromAsset(PORTRAIT_IMAGE), /* timestampMs= */ 0));
+ assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
+ }
+
+ @Test
+ public void detect_successWithImageMode() throws Exception {
+ FaceDetectorOptions options =
+ FaceDetectorOptions.builder()
+ .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
+ .setRunningMode(RunningMode.IMAGE)
+ .build();
+ FaceDetector faceDetector =
+ FaceDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+ FaceDetectorResult results = faceDetector.detect(getImageFromAsset(PORTRAIT_IMAGE));
+ assertContainsSinglePortraitFace(
+ results, PORTRAIT_FACE_BOUNDING_BOX, PORTRAIT_FACE_KEYPOINTS);
+ }
+
+ @Test
+ public void detect_successWithVideoMode() throws Exception {
+ FaceDetectorOptions options =
+ FaceDetectorOptions.builder()
+ .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
+ .setRunningMode(RunningMode.VIDEO)
+ .build();
+ FaceDetector faceDetector =
+ FaceDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+ for (int i = 0; i < 3; i++) {
+ FaceDetectorResult results =
+ faceDetector.detectForVideo(getImageFromAsset(PORTRAIT_IMAGE), /* timestampMs= */ i);
+ assertContainsSinglePortraitFace(
+ results, PORTRAIT_FACE_BOUNDING_BOX, PORTRAIT_FACE_KEYPOINTS);
+ }
+ }
+
+ @Test
+ public void detect_failsWithOutOfOrderInputTimestamps() throws Exception {
+ MPImage image = getImageFromAsset(PORTRAIT_IMAGE);
+ FaceDetectorOptions options =
+ FaceDetectorOptions.builder()
+ .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
+ .setRunningMode(RunningMode.LIVE_STREAM)
+ .setResultListener(
+ (faceDetectorResult, inputImage) -> {
+ assertContainsSinglePortraitFace(
+ faceDetectorResult, PORTRAIT_FACE_BOUNDING_BOX, PORTRAIT_FACE_KEYPOINTS);
+ })
+ .build();
+ try (FaceDetector faceDetector =
+ FaceDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
+ faceDetector.detectAsync(image, /* timestampMs= */ 1);
+ MediaPipeException exception =
+ assertThrows(
+ MediaPipeException.class,
+ () -> faceDetector.detectAsync(image, /* timestampMs= */ 0));
+ assertThat(exception)
+ .hasMessageThat()
+ .contains("having a smaller timestamp than the processed timestamp");
+ }
+ }
+
+ @Test
+ public void detect_successWithLiveSteamMode() throws Exception {
+ MPImage image = getImageFromAsset(PORTRAIT_IMAGE);
+ FaceDetectorOptions options =
+ FaceDetectorOptions.builder()
+ .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
+ .setRunningMode(RunningMode.LIVE_STREAM)
+ .setResultListener(
+ (faceDetectorResult, inputImage) -> {
+ assertContainsSinglePortraitFace(
+ faceDetectorResult, PORTRAIT_FACE_BOUNDING_BOX, PORTRAIT_FACE_KEYPOINTS);
+ })
+ .build();
+ try (FaceDetector faceDetector =
+ FaceDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
+ for (int i = 0; i < 3; i++) {
+ faceDetector.detectAsync(image, /* timestampMs= */ i);
+ }
+ }
+ }
+ }
+
+ private static MPImage getImageFromAsset(String filePath) throws Exception {
+ AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
+ InputStream istr = assetManager.open(filePath);
+ return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build();
+ }
+
+ private static void assertContainsSinglePortraitFace(
+ FaceDetectorResult results,
+ RectF expectedBoundingBox,
+ List<NormalizedKeypoint> expectedKeypoints) {
+ assertThat(results.detections()).hasSize(1);
+ assertApproximatelyEqualBoundingBoxes(
+ results.detections().get(0).boundingBox(), expectedBoundingBox);
+ assertThat(results.detections().get(0).keypoints().isPresent()).isTrue();
+ assertApproximatelyEqualKeypoints(
+ results.detections().get(0).keypoints().get(), expectedKeypoints);
+ }
+
+ private static void assertApproximatelyEqualBoundingBoxes(
+ RectF boundingBox1, RectF boundingBox2) {
+ assertThat(boundingBox1.left).isWithin(PIXEL_DIFF_TOLERANCE).of(boundingBox2.left);
+ assertThat(boundingBox1.top).isWithin(PIXEL_DIFF_TOLERANCE).of(boundingBox2.top);
+ assertThat(boundingBox1.right).isWithin(PIXEL_DIFF_TOLERANCE).of(boundingBox2.right);
+ assertThat(boundingBox1.bottom).isWithin(PIXEL_DIFF_TOLERANCE).of(boundingBox2.bottom);
+ }
+
+ private static void assertApproximatelyEqualKeypoints(
+ List<NormalizedKeypoint> keypoints1, List<NormalizedKeypoint> keypoints2) {
+ assertThat(keypoints1.size()).isEqualTo(keypoints2.size());
+ for (int i = 0; i < keypoints1.size(); i++) {
+ assertThat(keypoints1.get(i).x())
+ .isWithin(KEYPOINTS_DIFF_TOLERANCE)
+ .of(keypoints2.get(i).x());
+ assertThat(keypoints1.get(i).y())
+ .isWithin(KEYPOINTS_DIFF_TOLERANCE)
+ .of(keypoints2.get(i).y());
+ }
+ }
+}