diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h index 17c9cc921..53b824e25 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h @@ -87,9 +87,8 @@ struct GestureRecognizerOptions { // Performs hand gesture recognition on the given image. // // TODO add the link to DevSite. -// This API expects expects a pre-trained hand gesture model asset bundle, or a -// custom one created using Model Maker. See . +// This API expects a pre-trained hand gesture model asset bundle, or a custom +// one created using Model Maker. See . // // Inputs: // Image diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BUILD index 8df9173b2..453ae9a90 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BUILD @@ -40,6 +40,7 @@ cc_binary( deps = [ "//mediapipe/calculators/core:flow_limiter_calculator", "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", + "//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph", "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph", "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/BUILD index eb3eca52b..7782a747e 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/BUILD @@ -20,6 +20,7 @@ android_library( name = "gesturerecognizer", srcs = [ 
"GestureRecognitionResult.java", + "GestureRecognizer.java", ], javacopts = [ "-Xep:AndroidJdkLibsChecker:OFF", @@ -29,11 +30,19 @@ android_library( "//mediapipe/framework:calculator_options_java_proto_lite", "//mediapipe/framework/formats:classification_java_proto_lite", "//mediapipe/framework/formats:landmark_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/java/com/google/mediapipe/framework/image", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core", "//third_party:autovalue", "@maven//:com_google_guava_guava", ], diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java new file mode 100644 index 000000000..e429cc6dc --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java @@ -0,0 +1,466 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.gesturerecognizer; + +import android.content.Context; +import android.os.ParcelFileDescriptor; +import com.google.auto.value.AutoValue; +import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList; +import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList; +import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; +import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList; +import com.google.mediapipe.framework.AndroidPacketGetter; +import com.google.mediapipe.framework.Packet; +import com.google.mediapipe.framework.PacketGetter; +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.Image; +import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.core.ErrorListener; +import com.google.mediapipe.tasks.core.OutputHandler; +import com.google.mediapipe.tasks.core.OutputHandler.ResultListener; +import com.google.mediapipe.tasks.core.TaskInfo; +import com.google.mediapipe.tasks.core.TaskOptions; +import com.google.mediapipe.tasks.core.TaskRunner; +import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; +import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi; +import 
com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.tasks.vision.handdetector.HandDetectorGraphOptionsProto; +import com.google.mediapipe.tasks.vision.handlandmarker.HandLandmarkerGraphOptionsProto; +import com.google.mediapipe.tasks.vision.handlandmarker.HandLandmarksDetectorGraphOptionsProto; +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +/** + * Performs gesture recognition on images. + * + *

This API expects a pre-trained hand gesture model asset bundle, or a custom one created using + * Model Maker. See . + * + *

+ */ +public final class GestureRecognizer extends BaseVisionTaskApi { + private static final String TAG = GestureRecognizer.class.getSimpleName(); + private static final String IMAGE_IN_STREAM_NAME = "image_in"; + private static final List INPUT_STREAMS = + Collections.unmodifiableList(Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME)); + private static final List OUTPUT_STREAMS = + Collections.unmodifiableList( + Arrays.asList( + "LANDMARKS:hand_landmarks", + "WORLD_LANDMARKS:world_hand_landmarks", + "HANDEDNESS:handedness", + "HAND_GESTURES:hand_gestures", + "IMAGE:image_out")); + private static final int LANDMARKS_OUT_STREAM_INDEX = 0; + private static final int WORLD_LANDMARKS_OUT_STREAM_INDEX = 1; + private static final int HANDEDNESS_OUT_STREAM_INDEX = 2; + private static final int HAND_GESTURES_OUT_STREAM_INDEX = 3; + private static final int IMAGE_OUT_STREAM_INDEX = 4; + private static final String TASK_GRAPH_NAME = + "mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph"; + + /** + * Creates a {@link GestureRecognizer} instance from a model file and the default {@link + * GestureRecognizerOptions}. + * + * @param context an Android {@link Context}. + * @param modelPath path to the gesture recognition model with metadata in the assets. + * @throws MediaPipeException if there is an error during {@link GestureRecognizer} creation. + */ + public static GestureRecognizer createFromFile(Context context, String modelPath) { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build(); + return createFromOptions( + context, GestureRecognizerOptions.builder().setBaseOptions(baseOptions).build()); + } + + /** + * Creates a {@link GestureRecognizer} instance from a model file and the default {@link + * GestureRecognizerOptions}. + * + * @param context an Android {@link Context}. + * @param modelFile the gesture recognition model {@link File} instance. 
+ * @throws IOException if an I/O error occurs when opening the tflite model file. + * @throws MediaPipeException if there is an error during {@link GestureRecognizer} creation. + */ + public static GestureRecognizer createFromFile(Context context, File modelFile) + throws IOException { + try (ParcelFileDescriptor descriptor = + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { + BaseOptions baseOptions = + BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build(); + return createFromOptions( + context, GestureRecognizerOptions.builder().setBaseOptions(baseOptions).build()); + } + } + + /** + * Creates a {@link GestureRecognizer} instance from a model buffer and the default {@link + * GestureRecognizerOptions}. + * + * @param context an Android {@link Context}. + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the detection + * model. + * @throws MediaPipeException if there is an error during {@link GestureRecognizer} creation. + */ + public static GestureRecognizer createFromBuffer(Context context, final ByteBuffer modelBuffer) { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelBuffer).build(); + return createFromOptions( + context, GestureRecognizerOptions.builder().setBaseOptions(baseOptions).build()); + } + + /** + * Creates a {@link GestureRecognizer} instance from a {@link GestureRecognizerOptions}. + * + * @param context an Android {@link Context}. + * @param recognizerOptions a {@link GestureRecognizerOptions} instance. + * @throws MediaPipeException if there is an error during {@link GestureRecognizer} creation. + */ + public static GestureRecognizer createFromOptions( + Context context, GestureRecognizerOptions recognizerOptions) { + // TODO: Consolidate OutputHandler and TaskRunner. 
+ OutputHandler handler = new OutputHandler<>(); + handler.setOutputPacketConverter( + new OutputHandler.OutputPacketConverter() { + @Override + public GestureRecognitionResult convertToTaskResult(List packets) { + // If there is no hands detected in the image, just returns empty lists. + if (packets.get(HAND_GESTURES_OUT_STREAM_INDEX).isEmpty()) { + return GestureRecognitionResult.create( + new ArrayList<>(), + new ArrayList<>(), + new ArrayList<>(), + new ArrayList<>(), + packets.get(HAND_GESTURES_OUT_STREAM_INDEX).getTimestamp()); + } + return GestureRecognitionResult.create( + PacketGetter.getProtoVector( + packets.get(LANDMARKS_OUT_STREAM_INDEX), NormalizedLandmarkList.parser()), + PacketGetter.getProtoVector( + packets.get(WORLD_LANDMARKS_OUT_STREAM_INDEX), LandmarkList.parser()), + PacketGetter.getProtoVector( + packets.get(HANDEDNESS_OUT_STREAM_INDEX), ClassificationList.parser()), + PacketGetter.getProtoVector( + packets.get(HAND_GESTURES_OUT_STREAM_INDEX), ClassificationList.parser()), + packets.get(HAND_GESTURES_OUT_STREAM_INDEX).getTimestamp()); + } + + @Override + public Image convertToTaskInput(List packets) { + return new BitmapImageBuilder( + AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX))) + .build(); + } + }); + recognizerOptions.resultListener().ifPresent(handler::setResultListener); + recognizerOptions.errorListener().ifPresent(handler::setErrorListener); + TaskRunner runner = + TaskRunner.create( + context, + TaskInfo.builder() + .setTaskGraphName(TASK_GRAPH_NAME) + .setInputStreams(INPUT_STREAMS) + .setOutputStreams(OUTPUT_STREAMS) + .setTaskOptions(recognizerOptions) + .setEnableFlowLimiting(recognizerOptions.runningMode() == RunningMode.LIVE_STREAM) + .build(), + handler); + return new GestureRecognizer(runner, recognizerOptions.runningMode()); + } + + /** + * Constructor to initialize an {@link GestureRecognizer} from a {@link TaskRunner} and a {@link + * RunningMode}. 
+ * + * @param taskRunner a {@link TaskRunner}. + * @param runningMode a MediaPipe vision task {@link RunningMode}. + */ + private GestureRecognizer(TaskRunner taskRunner, RunningMode runningMode) { + super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME); + } + + /** + * Performs gesture recognition on the provided single image. Only use this method when the {@link + * GestureRecognizer} is created with {@link RunningMode#IMAGE}. TODO update java doc + * for input image format. + * + *

{@link GestureRecognizer} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param inputImage a MediaPipe {@link Image} object for processing. + * @throws MediaPipeException if there is an internal error. + */ + public GestureRecognitionResult recognize(Image inputImage) { + return (GestureRecognitionResult) processImageData(inputImage); + } + + /** + * Performs gesture recognition on the provided video frame. Only use this method when the {@link + * GestureRecognizer} is created with {@link RunningMode#VIDEO}. + * + *

It's required to provide the video frame's timestamp (in milliseconds). The input timestamps + * must be monotonically increasing. + * + *

{@link GestureRecognizer} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param inputImage a MediaPipe {@link Image} object for processing. + * @param inputTimestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public GestureRecognitionResult recognizeForVideo(Image inputImage, long inputTimestampMs) { + return (GestureRecognitionResult) processVideoData(inputImage, inputTimestampMs); + } + + /** + * Sends live image data to perform gesture recognition, and the results will be available via the + * {@link ResultListener} provided in the {@link GestureRecognizerOptions}. Only use this method + * when the {@link GestureRecognizer} is created with {@link RunningMode#LIVE_STREAM}. + * + *

It's required to provide a timestamp (in milliseconds) to indicate when the input image is + * sent to the gesture recognizer. The input timestamps must be monotonically increasing. + * + *

{@link GestureRecognizer} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param inputImage a MediaPipe {@link Image} object for processing. + * @param inputTimestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public void recognizeAsync(Image inputImage, long inputTimestampMs) { + sendLiveStreamData(inputImage, inputTimestampMs); + } + + /** Options for setting up an {@link GestureRecognizer}. */ + @AutoValue + public abstract static class GestureRecognizerOptions extends TaskOptions { + + /** Builder for {@link GestureRecognizerOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + /** Sets the base options for the gesture recognizer task. */ + public abstract Builder setBaseOptions(BaseOptions value); + + /** + * Sets the running mode for the gesture recognizer task. Default to the image mode. Gesture + * recognizer has three modes: + * + *
    + *
  • IMAGE: The mode for recognizing gestures on single image inputs. + *
  • VIDEO: The mode for recognizing gestures on the decoded frames of a video. + *
  • LIVE_STREAM: The mode for recognizing gestures on a live stream of input data, + * such as from camera. In this mode, {@code setResultListener} must be called to set up + * a listener to receive the recognition results asynchronously. + *
+ */ + public abstract Builder setRunningMode(RunningMode value); + + // TODO: remove these. Temporary solutions before bundle asset is ready. + public abstract Builder setBaseOptionsHandDetector(BaseOptions value); + + public abstract Builder setBaseOptionsHandLandmarker(BaseOptions value); + + public abstract Builder setBaseOptionsGestureRecognizer(BaseOptions value); + + /** Sets the maximum number of hands that can be detected by the GestureRecognizer. */ + public abstract Builder setNumHands(Integer value); + + /** Sets the minimum confidence score for the hand detection to be considered successful. */ + public abstract Builder setMinHandDetectionConfidence(Float value); + + /** Sets the minimum confidence score of hand presence in the hand landmark detection. */ + public abstract Builder setMinHandPresenceConfidence(Float value); + + /** Sets the minimum confidence score for the hand tracking to be considered successful. */ + public abstract Builder setMinTrackingConfidence(Float value); + + /** + * Sets the minimum confidence score for the gestures to be considered successful. If < 0, + * the gesture confidence threshold=0.5 for the model is used. + * + *

TODO Note this option is subject to change, after scoring merging + * calculator is implemented. + */ + public abstract Builder setMinGestureConfidence(Float value); + + /** + * Sets the result listener to receive the detection results asynchronously when the gesture + * recognizer is in the live stream mode. + */ + public abstract Builder setResultListener( + ResultListener value); + + /** Sets an optional error listener. */ + public abstract Builder setErrorListener(ErrorListener value); + + abstract GestureRecognizerOptions autoBuild(); + + /** + * Validates and builds the {@link GestureRecognizerOptions} instance. + * + * @throws IllegalArgumentException if the result listener and the running mode are not + * properly configured. The result listener should only be set when the object detector is + * in the live stream mode. + */ + public final GestureRecognizerOptions build() { + GestureRecognizerOptions options = autoBuild(); + if (options.runningMode() == RunningMode.LIVE_STREAM) { + if (!options.resultListener().isPresent()) { + throw new IllegalArgumentException( + "The gesture recognizer is in the live stream mode, a user-defined result listener" + + " must be provided in GestureRecognizerOptions."); + } + } else if (options.resultListener().isPresent()) { + throw new IllegalArgumentException( + "The gesture recognizer is in the image or the video mode, a user-defined result" + + " listener shouldn't be provided in GestureRecognizerOptions."); + } + return options; + } + } + + abstract BaseOptions baseOptions(); + + // TODO: remove these. Temporary solutions before bundle asset is ready. 
+ abstract BaseOptions baseOptionsHandDetector(); + + abstract BaseOptions baseOptionsHandLandmarker(); + + abstract BaseOptions baseOptionsGestureRecognizer(); + + abstract RunningMode runningMode(); + + abstract Optional numHands(); + + abstract Optional minHandDetectionConfidence(); + + abstract Optional minHandPresenceConfidence(); + + abstract Optional minTrackingConfidence(); + + // TODO update gesture confidence options after score merging calculator is ready. + abstract Optional minGestureConfidence(); + + abstract Optional> resultListener(); + + abstract Optional errorListener(); + + public static Builder builder() { + return new AutoValue_GestureRecognizer_GestureRecognizerOptions.Builder() + .setRunningMode(RunningMode.IMAGE) + .setNumHands(1) + .setMinHandDetectionConfidence(0.5f) + .setMinHandPresenceConfidence(0.5f) + .setMinTrackingConfidence(0.5f) + .setMinGestureConfidence(-1f); + } + + /** + * Converts a {@link GestureRecognizerOptions} to a {@link CalculatorOptions} protobuf message. + */ + @Override + public CalculatorOptions convertToCalculatorOptionsProto() { + BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder = + BaseOptionsProto.BaseOptions.newBuilder() + .setUseStreamMode(runningMode() != RunningMode.IMAGE) + .mergeFrom(convertBaseOptionsToProto(baseOptions())); + GestureRecognizerGraphOptionsProto.GestureRecognizerGraphOptions.Builder taskOptionsBuilder = + GestureRecognizerGraphOptionsProto.GestureRecognizerGraphOptions.newBuilder() + .setBaseOptions(baseOptionsBuilder); + + // Setup HandDetectorGraphOptions. 
+ HandDetectorGraphOptionsProto.HandDetectorGraphOptions.Builder + handDetectorGraphOptionsBuilder = + HandDetectorGraphOptionsProto.HandDetectorGraphOptions.newBuilder() + .setBaseOptions( + BaseOptionsProto.BaseOptions.newBuilder() + .setUseStreamMode(runningMode() != RunningMode.IMAGE) + .mergeFrom(convertBaseOptionsToProto(baseOptionsHandDetector()))); + numHands().ifPresent(handDetectorGraphOptionsBuilder::setNumHands); + minHandDetectionConfidence() + .ifPresent(handDetectorGraphOptionsBuilder::setMinDetectionConfidence); + + // Setup HandLandmarkerGraphOptions. + HandLandmarksDetectorGraphOptionsProto.HandLandmarksDetectorGraphOptions.Builder + handLandmarksDetectorGraphOptionsBuilder = + HandLandmarksDetectorGraphOptionsProto.HandLandmarksDetectorGraphOptions.newBuilder() + .setBaseOptions( + BaseOptionsProto.BaseOptions.newBuilder() + .setUseStreamMode(runningMode() != RunningMode.IMAGE) + .mergeFrom(convertBaseOptionsToProto(baseOptionsHandLandmarker()))); + minHandPresenceConfidence() + .ifPresent(handLandmarksDetectorGraphOptionsBuilder::setMinDetectionConfidence); + HandLandmarkerGraphOptionsProto.HandLandmarkerGraphOptions.Builder + handLandmarkerGraphOptionsBuilder = + HandLandmarkerGraphOptionsProto.HandLandmarkerGraphOptions.newBuilder() + .setBaseOptions( + BaseOptionsProto.BaseOptions.newBuilder() + .setUseStreamMode(runningMode() != RunningMode.IMAGE) + .mergeFrom(convertBaseOptionsToProto(baseOptionsHandLandmarker()))); + minTrackingConfidence() + .ifPresent(handLandmarkerGraphOptionsBuilder::setMinTrackingConfidence); + handLandmarkerGraphOptionsBuilder + .setHandDetectorGraphOptions(handDetectorGraphOptionsBuilder.build()) + .setHandLandmarksDetectorGraphOptions(handLandmarksDetectorGraphOptionsBuilder.build()); + + // Setup HandGestureRecognizerGraphOptions. 
+ HandGestureRecognizerGraphOptionsProto.HandGestureRecognizerGraphOptions.Builder + handGestureRecognizerGraphOptionsBuilder = + HandGestureRecognizerGraphOptionsProto.HandGestureRecognizerGraphOptions.newBuilder() + .setBaseOptions( + BaseOptionsProto.BaseOptions.newBuilder() + .setUseStreamMode(runningMode() != RunningMode.IMAGE) + .mergeFrom(convertBaseOptionsToProto(baseOptionsGestureRecognizer()))); + ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder = + ClassifierOptionsProto.ClassifierOptions.newBuilder(); + minGestureConfidence().ifPresent(classifierOptionsBuilder::setScoreThreshold); + handGestureRecognizerGraphOptionsBuilder.setClassifierOptions( + classifierOptionsBuilder.build()); + + taskOptionsBuilder + .setHandLandmarkerGraphOptions(handLandmarkerGraphOptionsBuilder.build()) + .setHandGestureRecognizerGraphOptions(handGestureRecognizerGraphOptionsBuilder.build()); + return CalculatorOptions.newBuilder() + .setExtension( + GestureRecognizerGraphOptionsProto.GestureRecognizerGraphOptions.ext, + taskOptionsBuilder.build()) + .build(); + } + } +} diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/AndroidManifest.xml b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/AndroidManifest.xml new file mode 100644 index 000000000..dd3ceb848 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/AndroidManifest.xml @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/BUILD b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/BUILD new file mode 100644 index 000000000..a7f804c64 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/BUILD @@ -0,0 +1,19 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +# TODO: Enable this in OSS diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java new file mode 100644 index 000000000..efec02b2a --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java @@ -0,0 +1,495 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package com.google.mediapipe.tasks.vision.gesturerecognizer; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import android.content.res.AssetManager; +import android.graphics.BitmapFactory; +import androidx.test.core.app.ApplicationProvider; +import androidx.test.ext.junit.runners.AndroidJUnit4; +import com.google.common.truth.Correspondence; +import com.google.mediapipe.formats.proto.ClassificationProto; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.Image; +import com.google.mediapipe.tasks.components.containers.Category; +import com.google.mediapipe.tasks.components.containers.Landmark; +import com.google.mediapipe.tasks.components.containers.proto.LandmarksDetectionResultProto.LandmarksDetectionResult; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.tasks.vision.gesturerecognizer.GestureRecognizer.GestureRecognizerOptions; +import java.io.InputStream; +import java.util.Arrays; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.junit.runners.Suite.SuiteClasses; + +/** Test for {@link GestureRecognizer}. 
*/ +@RunWith(Suite.class) +@SuiteClasses({GestureRecognizerTest.General.class, GestureRecognizerTest.RunningModeTest.class}) +public class GestureRecognizerTest { + private static final String HAND_DETECTOR_MODEL_FILE = "palm_detection_full.tflite"; + private static final String HAND_LANDMARKER_MODEL_FILE = "hand_landmark_full.tflite"; + private static final String GESTURE_RECOGNIZER_MODEL_FILE = + "cg_classifier_screen3d_landmark_features_nn_2022_08_04_base_simple_model.tflite"; + private static final String TWO_HANDS_IMAGE = "right_hands.jpg"; + private static final String THUMB_UP_IMAGE = "thumb_up.jpg"; + private static final String NO_HANDS_IMAGE = "cats_and_dogs.jpg"; + private static final String THUMB_UP_LANDMARKS = "thumb_up_landmarks.pb"; + private static final String TAG = "Gesture Recognizer Test"; + private static final String THUMB_UP_LABEL = "Thumb_Up"; + private static final int THUMB_UP_INDEX = 5; + private static final float LANDMARKS_ERROR_TOLERANCE = 0.03f; + private static final int IMAGE_WIDTH = 382; + private static final int IMAGE_HEIGHT = 406; + + @RunWith(AndroidJUnit4.class) + public static final class General extends GestureRecognizerTest { + + @Test + public void recognize_successWithValidModels() throws Exception { + GestureRecognizerOptions options = + GestureRecognizerOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .setBaseOptionsHandDetector( + BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) + .setBaseOptionsHandLandmarker( + BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) + .setBaseOptionsGestureRecognizer( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .build(); + GestureRecognizer gestureRecognizer = + GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options); + GestureRecognitionResult actualResult = + 
gestureRecognizer.recognize(getImageFromAsset(THUMB_UP_IMAGE)); + GestureRecognitionResult expectedResult = + getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX); + assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult); + } + + @Test + public void recognize_successWithEmptyResult() throws Exception { + GestureRecognizerOptions options = + GestureRecognizerOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .setBaseOptionsHandDetector( + BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) + .setBaseOptionsHandLandmarker( + BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) + .setBaseOptionsGestureRecognizer( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .build(); + GestureRecognizer gestureRecognizer = + GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options); + GestureRecognitionResult actualResult = + gestureRecognizer.recognize(getImageFromAsset(NO_HANDS_IMAGE)); + assertThat(actualResult.landmarks()).isEmpty(); + assertThat(actualResult.worldLandmarks()).isEmpty(); + assertThat(actualResult.handednesses()).isEmpty(); + assertThat(actualResult.gestures()).isEmpty(); + } + + @Test + public void recognize_successWithMinGestureConfidence() throws Exception { + GestureRecognizerOptions options = + GestureRecognizerOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .setBaseOptionsHandDetector( + BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) + .setBaseOptionsHandLandmarker( + BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) + .setBaseOptionsGestureRecognizer( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + // TODO update the confidence to be in range [0,1] 
after embedding model + // and scoring calculator is integrated. + .setMinGestureConfidence(3.0f) + .build(); + GestureRecognizer gestureRecognizer = + GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options); + GestureRecognitionResult actualResult = + gestureRecognizer.recognize(getImageFromAsset(THUMB_UP_IMAGE)); + GestureRecognitionResult expectedResult = + getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX); + // Only contains one top scoring gesture. + assertThat(actualResult.gestures().get(0)).hasSize(1); + assertActualGestureEqualExpectedGesture( + actualResult.gestures().get(0).get(0), expectedResult.gestures().get(0).get(0)); + } + + @Test + public void recognize_successWithNumHands() throws Exception { + GestureRecognizerOptions options = + GestureRecognizerOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .setBaseOptionsHandDetector( + BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) + .setBaseOptionsHandLandmarker( + BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) + .setBaseOptionsGestureRecognizer( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .setNumHands(2) + .build(); + GestureRecognizer gestureRecognizer = + GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options); + GestureRecognitionResult actualResult = + gestureRecognizer.recognize(getImageFromAsset(TWO_HANDS_IMAGE)); + assertThat(actualResult.handednesses()).hasSize(2); + } + } + + @RunWith(AndroidJUnit4.class) + public static final class RunningModeTest extends GestureRecognizerTest { + @Test + public void create_failsWithIllegalResultListenerInNonLiveStreamMode() throws Exception { + for (RunningMode mode : new RunningMode[] {RunningMode.IMAGE, RunningMode.VIDEO}) { + IllegalArgumentException exception = + assertThrows( + 
IllegalArgumentException.class, + () -> + GestureRecognizerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE) + .build()) + .setBaseOptionsHandDetector( + BaseOptions.builder() + .setModelAssetPath(HAND_DETECTOR_MODEL_FILE) + .build()) + .setBaseOptionsHandLandmarker( + BaseOptions.builder() + .setModelAssetPath(HAND_LANDMARKER_MODEL_FILE) + .build()) + .setBaseOptionsGestureRecognizer( + BaseOptions.builder() + .setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE) + .build()) + .setRunningMode(mode) + .setResultListener((gestureRecognitionResult, inputImage) -> {}) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("a user-defined result listener shouldn't be provided"); + } + } + } + + /** Building options in {@code LIVE_STREAM} mode without a result listener is expected to throw {@link IllegalArgumentException}. */ + @Test + public void create_failsWithMissingResultListenerInLiveSteamMode() throws Exception { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + GestureRecognizerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE) + .build()) + .setBaseOptionsHandDetector( + BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) + .setBaseOptionsHandLandmarker( + BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) + .setBaseOptionsGestureRecognizer( + BaseOptions.builder() + .setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE) + .build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("a user-defined result listener must be provided"); + } + + /** In {@code IMAGE} mode, {@code recognizeForVideo} and {@code recognizeAsync} are each expected to throw {@link MediaPipeException}. */ + @Test + public void recognize_failsWithCallingWrongApiInImageMode() throws Exception { + GestureRecognizerOptions options = + GestureRecognizerOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .setBaseOptionsHandDetector( + BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) + 
.setBaseOptionsHandLandmarker( + BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) + .setBaseOptionsGestureRecognizer( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .setRunningMode(RunningMode.IMAGE) + .build(); + + GestureRecognizer gestureRecognizer = + GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> gestureRecognizer.recognizeForVideo(getImageFromAsset(THUMB_UP_IMAGE), 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> gestureRecognizer.recognizeAsync(getImageFromAsset(THUMB_UP_IMAGE), 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); + } + + /** In {@code VIDEO} mode, {@code recognize} and {@code recognizeAsync} are each expected to throw {@link MediaPipeException}. */ + @Test + public void recognize_failsWithCallingWrongApiInVideoMode() throws Exception { + GestureRecognizerOptions options = + GestureRecognizerOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .setBaseOptionsHandDetector( + BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) + .setBaseOptionsHandLandmarker( + BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) + .setBaseOptionsGestureRecognizer( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .setRunningMode(RunningMode.VIDEO) + .build(); + + GestureRecognizer gestureRecognizer = + GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> gestureRecognizer.recognize(getImageFromAsset(THUMB_UP_IMAGE))); + assertThat(exception).hasMessageThat().contains("not initialized with the image mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> 
gestureRecognizer.recognizeAsync(getImageFromAsset(THUMB_UP_IMAGE), 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); + } + + /** In {@code LIVE_STREAM} mode, {@code recognize} and {@code recognizeForVideo} are each expected to throw {@link MediaPipeException}. */ + @Test + public void recognize_failsWithCallingWrongApiInLiveSteamMode() throws Exception { + GestureRecognizerOptions options = + GestureRecognizerOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .setBaseOptionsHandDetector( + BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) + .setBaseOptionsHandLandmarker( + BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) + .setBaseOptionsGestureRecognizer( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener((gestureRecognitionResult, inputImage) -> {}) + .build(); + + GestureRecognizer gestureRecognizer = + GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> gestureRecognizer.recognize(getImageFromAsset(THUMB_UP_IMAGE))); + assertThat(exception).hasMessageThat().contains("not initialized with the image mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> gestureRecognizer.recognizeForVideo(getImageFromAsset(THUMB_UP_IMAGE), 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); + } + + /** In {@code IMAGE} mode, recognizing {@code THUMB_UP_IMAGE} is expected to approximately match the expected thumb-up result. */ + @Test + public void recognize_successWithImageMode() throws Exception { + GestureRecognizerOptions options = + GestureRecognizerOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .setBaseOptionsHandDetector( + BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) + .setBaseOptionsHandLandmarker( + BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) + 
.setBaseOptionsGestureRecognizer( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .setRunningMode(RunningMode.IMAGE) + .build(); + + GestureRecognizer gestureRecognizer = + GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options); + GestureRecognitionResult actualResult = + gestureRecognizer.recognize(getImageFromAsset(THUMB_UP_IMAGE)); + GestureRecognitionResult expectedResult = + getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX); + assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult); + } + + /** In {@code VIDEO} mode, recognizing {@code THUMB_UP_IMAGE} at timestamps 0, 1 and 2 is expected to approximately match the expected thumb-up result each time. */ + @Test + public void recognize_successWithVideoMode() throws Exception { + GestureRecognizerOptions options = + GestureRecognizerOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .setBaseOptionsHandDetector( + BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) + .setBaseOptionsHandLandmarker( + BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) + .setBaseOptionsGestureRecognizer( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .setRunningMode(RunningMode.VIDEO) + .build(); + GestureRecognizer gestureRecognizer = + GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options); + GestureRecognitionResult expectedResult = + getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX); + for (int i = 0; i < 3; i++) { + GestureRecognitionResult actualResult = + gestureRecognizer.recognizeForVideo(getImageFromAsset(THUMB_UP_IMAGE), i); + assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult); + } + } + + /** In {@code LIVE_STREAM} mode, sending timestamp 0 after timestamp 1 is expected to throw {@link MediaPipeException}. */ + @Test + public void recognize_failsWithOutOfOrderInputTimestamps() throws Exception { + Image image = getImageFromAsset(THUMB_UP_IMAGE); + GestureRecognitionResult expectedResult = + 
getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX); + GestureRecognizerOptions options = + GestureRecognizerOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .setBaseOptionsHandDetector( + BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) + .setBaseOptionsHandLandmarker( + BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) + .setBaseOptionsGestureRecognizer( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener( + (actualResult, inputImage) -> { + assertActualResultApproximatelyEqualsToExpectedResult( + actualResult, expectedResult); + assertImageSizeIsExpected(inputImage); + }) + .build(); + try (GestureRecognizer gestureRecognizer = + GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { + gestureRecognizer.recognizeAsync(image, 1); + MediaPipeException exception = + assertThrows(MediaPipeException.class, () -> gestureRecognizer.recognizeAsync(image, 0)); + assertThat(exception) + .hasMessageThat() + .contains("having a smaller timestamp than the processed timestamp"); + } + } + + /** In {@code LIVE_STREAM} mode, sends {@code THUMB_UP_IMAGE} at timestamps 0, 1 and 2; the result listener checks each result and the input image size. NOTE(review): the assertions run inside the listener, so confirm that a failure there is surfaced to the test. */ + @Test + public void recognize_successWithLiveSteamMode() throws Exception { + Image image = getImageFromAsset(THUMB_UP_IMAGE); + GestureRecognitionResult expectedResult = + getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX); + GestureRecognizerOptions options = + GestureRecognizerOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .setBaseOptionsHandDetector( + BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) + .setBaseOptionsHandLandmarker( + BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) + .setBaseOptionsGestureRecognizer( + 
BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener( + (actualResult, inputImage) -> { + assertActualResultApproximatelyEqualsToExpectedResult( + actualResult, expectedResult); + assertImageSizeIsExpected(inputImage); + }) + .build(); + try (GestureRecognizer gestureRecognizer = + GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { + for (int i = 0; i < 3; i++) { + gestureRecognizer.recognizeAsync(image, i); + } + } + } + + /** Decodes the asset at {@code filePath} into an {@link Image}; the asset stream is closed once decoding has finished. */ + private static Image getImageFromAsset(String filePath) throws Exception { + AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets(); + try (InputStream istr = assetManager.open(filePath)) { + return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build(); + } + } + + /** Builds the expected result from the {@code LandmarksDetectionResult} proto stored in the asset at {@code filePath} plus a single gesture classification carrying {@code gestureLabel} and {@code gestureIndex}; the timestamp is 0. The asset stream is closed once parsing has finished. */ + private static GestureRecognitionResult getExpectedGestureRecognitionResult( + String filePath, String gestureLabel, int gestureIndex) throws Exception { + AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets(); + LandmarksDetectionResult landmarksDetectionResultProto; + try (InputStream istr = assetManager.open(filePath)) { + landmarksDetectionResultProto = LandmarksDetectionResult.parser().parseFrom(istr); + } + ClassificationProto.ClassificationList gesturesProto = + ClassificationProto.ClassificationList.newBuilder() + .addClassification( + ClassificationProto.Classification.newBuilder() + .setLabel(gestureLabel) + .setIndex(gestureIndex)) + .build(); + return GestureRecognitionResult.create( + Arrays.asList(landmarksDetectionResultProto.getLandmarks()), + Arrays.asList(landmarksDetectionResultProto.getWorldLandmarks()), + Arrays.asList(landmarksDetectionResultProto.getClassifications()), + Arrays.asList(gesturesProto), + /*timestampMs=*/ 0); + } + + /** Asserts that both results hold the same number of hands and that the first hand's landmarks (x and y within tolerance), top handedness and top gesture match. */ + private static void assertActualResultApproximatelyEqualsToExpectedResult( + GestureRecognitionResult actualResult, GestureRecognitionResult expectedResult) { + // Expects to have the 
same number of hands detected. + assertThat(actualResult.landmarks()).hasSize(expectedResult.landmarks().size()); + assertThat(actualResult.worldLandmarks()).hasSize(expectedResult.worldLandmarks().size()); + assertThat(actualResult.handednesses()).hasSize(expectedResult.handednesses().size()); + assertThat(actualResult.gestures()).hasSize(expectedResult.gestures().size()); + + // Actual landmarks of the first hand match expected landmarks (x and y only, each within LANDMARKS_ERROR_TOLERANCE). + assertThat(actualResult.landmarks().get(0)) + .comparingElementsUsing( + Correspondence.from( + (Correspondence.BinaryPredicate) + (actual, expected) -> { + return Correspondence.tolerance(LANDMARKS_ERROR_TOLERANCE) + .compare(actual.x(), expected.x()) + && Correspondence.tolerance(LANDMARKS_ERROR_TOLERANCE) + .compare(actual.y(), expected.y()); + }, + "landmarks approximately equal to")) + .containsExactlyElementsIn(expectedResult.landmarks().get(0)); + + // Top handedness of the first hand matches expected handedness (index and category name). + Category actualTopHandedness = actualResult.handednesses().get(0).get(0); + Category expectedTopHandedness = expectedResult.handednesses().get(0).get(0); + assertThat(actualTopHandedness.index()).isEqualTo(expectedTopHandedness.index()); + assertThat(actualTopHandedness.categoryName()).isEqualTo(expectedTopHandedness.categoryName()); + + // Actual gesture with top score matches expected gesture. 
+ Category actualTopGesture = actualResult.gestures().get(0).get(0); + Category expectedTopGesture = expectedResult.gestures().get(0).get(0); + assertActualGestureEqualExpectedGesture(actualTopGesture, expectedTopGesture); + } + + /** Asserts that {@code actualGesture} has the same index and category name as {@code expectedGesture}. Each assertion compares the actual value against the expected one; comparing a value with itself would always pass. */ + private static void assertActualGestureEqualExpectedGesture( + Category actualGesture, Category expectedGesture) { + assertThat(actualGesture.index()).isEqualTo(expectedGesture.index()); + assertThat(actualGesture.categoryName()).isEqualTo(expectedGesture.categoryName()); + } + + /** Asserts that {@code inputImage} is non-null and that its width and height equal {@code IMAGE_WIDTH} and {@code IMAGE_HEIGHT}. */ + private static void assertImageSizeIsExpected(Image inputImage) { + assertThat(inputImage).isNotNull(); + assertThat(inputImage.getWidth()).isEqualTo(IMAGE_WIDTH); + assertThat(inputImage.getHeight()).isEqualTo(IMAGE_HEIGHT); + } +}