From 94cd1348096bae81f9cf6bdeb5ed5b5de96b66b3 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 24 Oct 2022 14:59:09 -0700 Subject: [PATCH] Add support for image rotation in Java vision tasks. PiperOrigin-RevId: 483493729 --- .../android/objectdetector/src/main/BUILD | 1 + .../examples/objectdetector/MainActivity.java | 38 ++-- .../com/google/mediapipe/tasks/vision/BUILD | 1 + .../tasks/vision/core/BaseVisionTaskApi.java | 181 ++++-------------- .../vision/core/ImageProcessingOptions.java | 92 +++++++++ .../gesturerecognizer/GestureRecognizer.java | 128 +++++++++++-- .../imageclassifier/ImageClassifier.java | 121 ++++++------ .../vision/objectdetector/ObjectDetector.java | 127 ++++++++++-- .../tasks/vision/core/AndroidManifest.xml | 24 +++ .../google/mediapipe/tasks/vision/core/BUILD | 19 ++ .../core/ImageProcessingOptionsTest.java | 70 +++++++ .../GestureRecognizerTest.java | 79 +++++++- .../imageclassifier/ImageClassifierTest.java | 79 +++++++- .../objectdetector/ObjectDetectorTest.java | 85 ++++++-- 14 files changed, 762 insertions(+), 283 deletions(-) create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/ImageProcessingOptions.java create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/core/AndroidManifest.xml create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/core/BUILD create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/core/ImageProcessingOptionsTest.java diff --git a/mediapipe/tasks/examples/android/objectdetector/src/main/BUILD b/mediapipe/tasks/examples/android/objectdetector/src/main/BUILD index acbdbd6eb..89c1edcb3 100644 --- a/mediapipe/tasks/examples/android/objectdetector/src/main/BUILD +++ b/mediapipe/tasks/examples/android/objectdetector/src/main/BUILD @@ -31,6 +31,7 @@ android_binary( multidex = "native", resource_files = ["//mediapipe/tasks/examples/android:resource_files"], deps = [ + 
"//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/framework/image", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", diff --git a/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/MainActivity.java b/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/MainActivity.java index 11c8c1837..18c010a00 100644 --- a/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/MainActivity.java +++ b/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/MainActivity.java @@ -16,7 +16,6 @@ package com.google.mediapipe.tasks.examples.objectdetector; import android.content.Intent; import android.graphics.Bitmap; -import android.graphics.Matrix; import android.media.MediaMetadataRetriever; import android.os.Bundle; import android.provider.MediaStore; @@ -29,9 +28,11 @@ import androidx.activity.result.ActivityResultLauncher; import androidx.activity.result.contract.ActivityResultContracts; import androidx.exifinterface.media.ExifInterface; // ContentResolver dependency +import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.image.BitmapImageBuilder; import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; import com.google.mediapipe.tasks.vision.core.RunningMode; import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetectionResult; import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector; @@ -82,6 +83,7 @@ public class MainActivity extends AppCompatActivity { if (resultIntent != null) { if (result.getResultCode() == 
RESULT_OK) { Bitmap bitmap = null; + int rotation = 0; try { bitmap = downscaleBitmap( @@ -93,13 +95,16 @@ public class MainActivity extends AppCompatActivity { try { InputStream imageData = this.getContentResolver().openInputStream(resultIntent.getData()); - bitmap = rotateBitmap(bitmap, imageData); - } catch (IOException e) { + rotation = getImageRotation(imageData); + } catch (IOException | MediaPipeException e) { Log.e(TAG, "Bitmap rotation error:" + e); } if (bitmap != null) { MPImage image = new BitmapImageBuilder(bitmap).build(); - ObjectDetectionResult detectionResult = objectDetector.detect(image); + ObjectDetectionResult detectionResult = + objectDetector.detect( + image, + ImageProcessingOptions.builder().setRotationDegrees(rotation).build()); imageView.setData(image, detectionResult); runOnUiThread(() -> imageView.update()); } @@ -210,28 +215,25 @@ public class MainActivity extends AppCompatActivity { return Bitmap.createScaledBitmap(originalBitmap, width, height, false); } - private Bitmap rotateBitmap(Bitmap inputBitmap, InputStream imageData) throws IOException { + private int getImageRotation(InputStream imageData) throws IOException, MediaPipeException { int orientation = new ExifInterface(imageData) .getAttributeInt(ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL); - if (orientation == ExifInterface.ORIENTATION_NORMAL) { - return inputBitmap; - } - Matrix matrix = new Matrix(); switch (orientation) { + case ExifInterface.ORIENTATION_NORMAL: + return 0; case ExifInterface.ORIENTATION_ROTATE_90: - matrix.postRotate(90); - break; + return 90; case ExifInterface.ORIENTATION_ROTATE_180: - matrix.postRotate(180); - break; + return 180; case ExifInterface.ORIENTATION_ROTATE_270: - matrix.postRotate(270); - break; + return 270; default: - matrix.postRotate(0); + // TODO: use getRotationDegrees() and isFlipped() instead of switch once flip + // is supported. 
+ throw new MediaPipeException( + MediaPipeException.StatusCode.UNIMPLEMENTED.ordinal(), + "Flipped images are not supported yet."); } - return Bitmap.createBitmap( - inputBitmap, 0, 0, inputBitmap.getWidth(), inputBitmap.getHeight(), matrix, true); } } diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index 5ea465d47..ed65fbcac 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -28,6 +28,7 @@ android_library( "//mediapipe/java/com/google/mediapipe/framework:android_framework_no_mff", "//mediapipe/java/com/google/mediapipe/framework/image", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "//third_party:autovalue", "@maven//:com_google_guava_guava", ], ) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java index 49dab408c..0774b69a2 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java @@ -24,7 +24,6 @@ import com.google.mediapipe.tasks.core.TaskResult; import com.google.mediapipe.tasks.core.TaskRunner; import java.util.HashMap; import java.util.Map; -import java.util.Optional; /** The base class of MediaPipe vision tasks. 
*/ public class BaseVisionTaskApi implements AutoCloseable { @@ -32,7 +31,7 @@ public class BaseVisionTaskApi implements AutoCloseable { private final TaskRunner runner; private final RunningMode runningMode; private final String imageStreamName; - private final Optional normRectStreamName; + private final String normRectStreamName; static { System.loadLibrary("mediapipe_tasks_vision_jni"); @@ -40,27 +39,13 @@ public class BaseVisionTaskApi implements AutoCloseable { } /** - * Constructor to initialize a {@link BaseVisionTaskApi} only taking images as input. + * Constructor to initialize a {@link BaseVisionTaskApi}. * * @param runner a {@link TaskRunner}. * @param runningMode a mediapipe vision task {@link RunningMode}. * @param imageStreamName the name of the input image stream. - */ - public BaseVisionTaskApi(TaskRunner runner, RunningMode runningMode, String imageStreamName) { - this.runner = runner; - this.runningMode = runningMode; - this.imageStreamName = imageStreamName; - this.normRectStreamName = Optional.empty(); - } - - /** - * Constructor to initialize a {@link BaseVisionTaskApi} taking images and normalized rects as - * input. - * - * @param runner a {@link TaskRunner}. - * @param runningMode a mediapipe vision task {@link RunningMode}. - * @param imageStreamName the name of the input image stream. - * @param normRectStreamName the name of the input normalized rect image stream. + * @param normRectStreamName the name of the input normalized rect image stream used to provide + * (mandatory) rotation and (optional) region-of-interest. 
*/ public BaseVisionTaskApi( TaskRunner runner, @@ -70,7 +55,7 @@ public class BaseVisionTaskApi implements AutoCloseable { this.runner = runner; this.runningMode = runningMode; this.imageStreamName = imageStreamName; - this.normRectStreamName = Optional.of(normRectStreamName); + this.normRectStreamName = normRectStreamName; } /** @@ -78,53 +63,23 @@ public class BaseVisionTaskApi implements AutoCloseable { * failure status or a successful result is returned. * * @param image a MediaPipe {@link MPImage} object for processing. - * @throws MediaPipeException if the task is not in the image mode or requires a normalized rect - * input. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. + * @throws MediaPipeException if the task is not in the image mode. */ - protected TaskResult processImageData(MPImage image) { + protected TaskResult processImageData( + MPImage image, ImageProcessingOptions imageProcessingOptions) { if (runningMode != RunningMode.IMAGE) { throw new MediaPipeException( MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), "Task is not initialized with the image mode. Current running mode:" + runningMode.name()); } - if (normRectStreamName.isPresent()) { - throw new MediaPipeException( - MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), - "Task expects a normalized rect as input."); - } - Map inputPackets = new HashMap<>(); - inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); - return runner.process(inputPackets); - } - - /** - * A synchronous method to process single image inputs. The call blocks the current thread until a - * failure status or a successful result is returned. - * - * @param image a MediaPipe {@link MPImage} object for processing. - * @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates - * are expected to be specified as normalized values in [0,1]. 
- * @throws MediaPipeException if the task is not in the image mode or doesn't require a normalized - * rect. - */ - protected TaskResult processImageData(MPImage image, RectF roi) { - if (runningMode != RunningMode.IMAGE) { - throw new MediaPipeException( - MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), - "Task is not initialized with the image mode. Current running mode:" - + runningMode.name()); - } - if (!normRectStreamName.isPresent()) { - throw new MediaPipeException( - MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), - "Task doesn't expect a normalized rect as input."); - } Map inputPackets = new HashMap<>(); inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); inputPackets.put( - normRectStreamName.get(), - runner.getPacketCreator().createProto(convertToNormalizedRect(roi))); + normRectStreamName, + runner.getPacketCreator().createProto(convertToNormalizedRect(imageProcessingOptions))); return runner.process(inputPackets); } @@ -133,55 +88,24 @@ public class BaseVisionTaskApi implements AutoCloseable { * until a failure status or a successful result is returned. * * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. * @param timestampMs the corresponding timestamp of the input image in milliseconds. - * @throws MediaPipeException if the task is not in the video mode or requires a normalized rect - * input. + * @throws MediaPipeException if the task is not in the video mode. */ - protected TaskResult processVideoData(MPImage image, long timestampMs) { + protected TaskResult processVideoData( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { if (runningMode != RunningMode.VIDEO) { throw new MediaPipeException( MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), "Task is not initialized with the video mode. 
Current running mode:" + runningMode.name()); } - if (normRectStreamName.isPresent()) { - throw new MediaPipeException( - MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), - "Task expects a normalized rect as input."); - } - Map inputPackets = new HashMap<>(); - inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); - return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND); - } - - /** - * A synchronous method to process continuous video frames. The call blocks the current thread - * until a failure status or a successful result is returned. - * - * @param image a MediaPipe {@link MPImage} object for processing. - * @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates - * are expected to be specified as normalized values in [0,1]. - * @param timestampMs the corresponding timestamp of the input image in milliseconds. - * @throws MediaPipeException if the task is not in the video mode or doesn't require a normalized - * rect. - */ - protected TaskResult processVideoData(MPImage image, RectF roi, long timestampMs) { - if (runningMode != RunningMode.VIDEO) { - throw new MediaPipeException( - MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), - "Task is not initialized with the video mode. 
Current running mode:" - + runningMode.name()); - } - if (!normRectStreamName.isPresent()) { - throw new MediaPipeException( - MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), - "Task doesn't expect a normalized rect as input."); - } Map inputPackets = new HashMap<>(); inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); inputPackets.put( - normRectStreamName.get(), - runner.getPacketCreator().createProto(convertToNormalizedRect(roi))); + normRectStreamName, + runner.getPacketCreator().createProto(convertToNormalizedRect(imageProcessingOptions))); return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND); } @@ -190,55 +114,24 @@ public class BaseVisionTaskApi implements AutoCloseable { * available in the user-defined result listener. * * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. * @param timestampMs the corresponding timestamp of the input image in milliseconds. - * @throws MediaPipeException if the task is not in the video mode or requires a normalized rect - * input. + * @throws MediaPipeException if the task is not in the stream mode. */ - protected void sendLiveStreamData(MPImage image, long timestampMs) { + protected void sendLiveStreamData( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { if (runningMode != RunningMode.LIVE_STREAM) { throw new MediaPipeException( MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), "Task is not initialized with the live stream mode. 
Current running mode:" + runningMode.name()); } - if (normRectStreamName.isPresent()) { - throw new MediaPipeException( - MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), - "Task expects a normalized rect as input."); - } - Map inputPackets = new HashMap<>(); - inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); - runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND); - } - - /** - * An asynchronous method to send live stream data to the {@link TaskRunner}. The results will be - * available in the user-defined result listener. - * - * @param image a MediaPipe {@link MPImage} object for processing. - * @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates - * are expected to be specified as normalized values in [0,1]. - * @param timestampMs the corresponding timestamp of the input image in milliseconds. - * @throws MediaPipeException if the task is not in the video mode or doesn't require a normalized - * rect. - */ - protected void sendLiveStreamData(MPImage image, RectF roi, long timestampMs) { - if (runningMode != RunningMode.LIVE_STREAM) { - throw new MediaPipeException( - MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), - "Task is not initialized with the live stream mode. 
Current running mode:" - + runningMode.name()); - } - if (!normRectStreamName.isPresent()) { - throw new MediaPipeException( - MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), - "Task doesn't expect a normalized rect as input."); - } Map inputPackets = new HashMap<>(); inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); inputPackets.put( - normRectStreamName.get(), - runner.getPacketCreator().createProto(convertToNormalizedRect(roi))); + normRectStreamName, + runner.getPacketCreator().createProto(convertToNormalizedRect(imageProcessingOptions))); runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND); } @@ -248,13 +141,23 @@ public class BaseVisionTaskApi implements AutoCloseable { runner.close(); } - /** Converts a {@link RectF} object into a {@link NormalizedRect} protobuf message. */ - private static NormalizedRect convertToNormalizedRect(RectF rect) { + /** + * Converts an {@link ImageProcessingOptions} instance into a {@link NormalizedRect} protobuf + * message. + */ + private static NormalizedRect convertToNormalizedRect( + ImageProcessingOptions imageProcessingOptions) { + RectF regionOfInterest = + imageProcessingOptions.regionOfInterest().isPresent() + ? imageProcessingOptions.regionOfInterest().get() + : new RectF(0, 0, 1, 1); return NormalizedRect.newBuilder() - .setXCenter(rect.centerX()) - .setYCenter(rect.centerY()) - .setWidth(rect.width()) - .setHeight(rect.height()) + .setXCenter(regionOfInterest.centerX()) + .setYCenter(regionOfInterest.centerY()) + .setWidth(regionOfInterest.width()) + .setHeight(regionOfInterest.height()) + // Convert to radians anti-clockwise. 
+ .setRotation(-(float) Math.PI * imageProcessingOptions.rotationDegrees() / 180.0f) .build(); } } diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/ImageProcessingOptions.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/ImageProcessingOptions.java new file mode 100644 index 000000000..a34a9787d --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/ImageProcessingOptions.java @@ -0,0 +1,92 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.core; + +import android.graphics.RectF; +import com.google.auto.value.AutoValue; +import java.util.Optional; + +// TODO: add support for image flipping. +/** Options for image processing. */ +@AutoValue +public abstract class ImageProcessingOptions { + + /** + * Builder for {@link ImageProcessingOptions}. + * + *

If both region-of-interest and rotation are specified, the crop around the + * region-of-interest is extracted first, then the specified rotation is applied to the crop. + */ + @AutoValue.Builder + public abstract static class Builder { + /** + * Sets the optional region-of-interest to crop from the image. If not specified, the full image + * is used. + * + *

Coordinates must be in [0,1], {@code left} must be < {@code right} and {@code top} must be + * < {@code bottom}, otherwise an IllegalArgumentException will be thrown when {@link #build()} + * is called. + */ + public abstract Builder setRegionOfInterest(RectF value); + + /** + * Sets the rotation to apply to the image (or cropped region-of-interest), in degrees + * clockwise. Defaults to 0. + * + *

The rotation must be a multiple (positive or negative) of 90°, otherwise an + * IllegalArgumentException will be thrown when {@link #build()} is called. + */ + public abstract Builder setRotationDegrees(int value); + + abstract ImageProcessingOptions autoBuild(); + + /** + * Validates and builds the {@link ImageProcessingOptions} instance. + * + * @throws IllegalArgumentException if some of the provided values do not meet their + * requirements. + */ + public final ImageProcessingOptions build() { + ImageProcessingOptions options = autoBuild(); + if (options.regionOfInterest().isPresent()) { + RectF roi = options.regionOfInterest().get(); + if (roi.left >= roi.right || roi.top >= roi.bottom) { + throw new IllegalArgumentException( + String.format( + "Expected left < right and top < bottom, found: %s.", roi.toShortString())); + } + if (roi.left < 0 || roi.right > 1 || roi.top < 0 || roi.bottom > 1) { + throw new IllegalArgumentException( + String.format("Expected RectF values in [0,1], found: %s.", roi.toShortString())); + } + } + if (options.rotationDegrees() % 90 != 0) { + throw new IllegalArgumentException( + String.format( + "Expected rotation to be a multiple of 90°, found: %d.", + options.rotationDegrees())); + } + return options; + } + } + + public abstract Optional regionOfInterest(); + + public abstract int rotationDegrees(); + + public static Builder builder() { + return new AutoValue_ImageProcessingOptions.Builder().setRotationDegrees(0); + } +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java index 55cf275e9..8e5a30eab 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java @@ -15,7 +15,6 @@ package 
com.google.mediapipe.tasks.vision.gesturerecognizer; import android.content.Context; -import android.graphics.RectF; import android.os.ParcelFileDescriptor; import com.google.auto.value.AutoValue; import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList; @@ -37,6 +36,7 @@ import com.google.mediapipe.tasks.core.TaskOptions; import com.google.mediapipe.tasks.core.TaskRunner; import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; import com.google.mediapipe.tasks.vision.core.RunningMode; import com.google.mediapipe.tasks.vision.gesturerecognizer.proto.GestureClassifierGraphOptionsProto; import com.google.mediapipe.tasks.vision.gesturerecognizer.proto.GestureRecognizerGraphOptionsProto; @@ -212,6 +212,25 @@ public final class GestureRecognizer extends BaseVisionTaskApi { super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME); } + /** + * Performs gesture recognition on the provided single image with default image processing + * options, i.e. without any rotation applied. Only use this method when the {@link + * GestureRecognizer} is created with {@link RunningMode.IMAGE}. TODO update java doc + * for input image format. + * + *

{@link GestureRecognizer} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @throws MediaPipeException if there is an internal error. + */ + public GestureRecognitionResult recognize(MPImage image) { + return recognize(image, ImageProcessingOptions.builder().build()); + } + /** * Performs gesture recognition on the provided single image. Only use this method when the {@link * GestureRecognizer} is created with {@link RunningMode.IMAGE}. TODO update java doc @@ -223,12 +242,41 @@ public final class GestureRecognizer extends BaseVisionTaskApi { *
  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link MPImage} object for processing. + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. * @throws MediaPipeException if there is an internal error. */ - public GestureRecognitionResult recognize(MPImage inputImage) { - // TODO: add proper support for rotations. - return (GestureRecognitionResult) processImageData(inputImage, buildFullImageRectF()); + public GestureRecognitionResult recognize( + MPImage image, ImageProcessingOptions imageProcessingOptions) { + validateImageProcessingOptions(imageProcessingOptions); + return (GestureRecognitionResult) processImageData(image, imageProcessingOptions); + } + + /** + * Performs gesture recognition on the provided video frame with default image processing options, + * i.e. without any rotation applied. Only use this method when the {@link GestureRecognizer} is + * created with {@link RunningMode.VIDEO}. + * + *

    It's required to provide the video frame's timestamp (in milliseconds). The input timestamps + * must be monotonically increasing. + * + *

    {@link GestureRecognizer} supports the following color space types: + * + *

      + *
    • {@link Bitmap.Config.ARGB_8888} + *
    + * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public GestureRecognitionResult recognizeForVideo(MPImage image, long timestampMs) { + return recognizeForVideo(image, ImageProcessingOptions.builder().build(), timestampMs); } /** @@ -244,14 +292,43 @@ public final class GestureRecognizer extends BaseVisionTaskApi { *
  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link MPImage} object for processing. - * @param inputTimestampMs the input timestamp (in milliseconds). + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @param timestampMs the input timestamp (in milliseconds). + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. * @throws MediaPipeException if there is an internal error. */ - public GestureRecognitionResult recognizeForVideo(MPImage inputImage, long inputTimestampMs) { - // TODO: add proper support for rotations. - return (GestureRecognitionResult) - processVideoData(inputImage, buildFullImageRectF(), inputTimestampMs); + public GestureRecognitionResult recognizeForVideo( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + validateImageProcessingOptions(imageProcessingOptions); + return (GestureRecognitionResult) processVideoData(image, imageProcessingOptions, timestampMs); + } + + /** + * Sends live image data to perform gesture recognition with default image processing options, + * i.e. without any rotation applied, and the results will be available via the {@link + * ResultListener} provided in the {@link GestureRecognizerOptions}. Only use this method when the + * {@link GestureRecognizer} is created with {@link RunningMode.LIVE_STREAM}. + * + *

    It's required to provide a timestamp (in milliseconds) to indicate when the input image is + * sent to the gesture recognizer. The input timestamps must be monotonically increasing. + * + *

    {@link GestureRecognizer} supports the following color space types: + * + *

      + *
    • {@link Bitmap.Config.ARGB_8888} + *
    + * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public void recognizeAsync(MPImage image, long timestampMs) { + recognizeAsync(image, ImageProcessingOptions.builder().build(), timestampMs); } /** @@ -268,13 +345,20 @@ public final class GestureRecognizer extends BaseVisionTaskApi { *
  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link MPImage} object for processing. - * @param inputTimestampMs the input timestamp (in milliseconds). + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @param timestampMs the input timestamp (in milliseconds). + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. * @throws MediaPipeException if there is an internal error. */ - public void recognizeAsync(MPImage inputImage, long inputTimestampMs) { - // TODO: add proper support for rotations. - sendLiveStreamData(inputImage, buildFullImageRectF(), inputTimestampMs); + public void recognizeAsync( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + validateImageProcessingOptions(imageProcessingOptions); + sendLiveStreamData(image, imageProcessingOptions, timestampMs); } /** Options for setting up an {@link GestureRecognizer}. */ @@ -445,8 +529,14 @@ public final class GestureRecognizer extends BaseVisionTaskApi { } } - /** Creates a RectF covering the full image. */ - private static RectF buildFullImageRectF() { - return new RectF(0, 0, 1, 1); + /** + * Validates that the provided {@link ImageProcessingOptions} doesn't contain a + * region-of-interest. 
+ */ + private static void validateImageProcessingOptions( + ImageProcessingOptions imageProcessingOptions) { + if (imageProcessingOptions.regionOfInterest().isPresent()) { + throw new IllegalArgumentException("GestureRecognizer doesn't support region-of-interest."); + } } } diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java index 75e2de13a..3863b6fe0 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java @@ -15,7 +15,6 @@ package com.google.mediapipe.tasks.vision.imageclassifier; import android.content.Context; -import android.graphics.RectF; import android.os.ParcelFileDescriptor; import com.google.auto.value.AutoValue; import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; @@ -37,6 +36,7 @@ import com.google.mediapipe.tasks.core.TaskOptions; import com.google.mediapipe.tasks.core.TaskRunner; import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; import com.google.mediapipe.tasks.vision.core.RunningMode; import com.google.mediapipe.tasks.vision.imageclassifier.proto.ImageClassifierGraphOptionsProto; import java.io.File; @@ -215,6 +215,24 @@ public final class ImageClassifier extends BaseVisionTaskApi { super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME); } + /** + * Performs classification on the provided single image with default image processing options, + * i.e. using the whole image as region-of-interest and without any rotation applied. Only use + * this method when the {@link ImageClassifier} is created with {@link RunningMode.IMAGE}. + * + *

    {@link ImageClassifier} supports the following color space types: + * + *

      + *
    • {@link Bitmap.Config.ARGB_8888} + *
    + * + * @param image a MediaPipe {@link MPImage} object for processing. + * @throws MediaPipeException if there is an internal error. + */ + public ImageClassificationResult classify(MPImage image) { + return classify(image, ImageProcessingOptions.builder().build()); + } + /** * Performs classification on the provided single image. Only use this method when the {@link * ImageClassifier} is created with {@link RunningMode.IMAGE}. @@ -225,16 +243,23 @@ public final class ImageClassifier extends BaseVisionTaskApi { *
  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link MPImage} object for processing. + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. * @throws MediaPipeException if there is an internal error. */ - public ImageClassificationResult classify(MPImage inputImage) { - return (ImageClassificationResult) processImageData(inputImage, buildFullImageRectF()); + public ImageClassificationResult classify( + MPImage image, ImageProcessingOptions imageProcessingOptions) { + return (ImageClassificationResult) processImageData(image, imageProcessingOptions); } /** - * Performs classification on the provided single image and region-of-interest. Only use this - * method when the {@link ImageClassifier} is created with {@link RunningMode.IMAGE}. + * Performs classification on the provided video frame with default image processing options, i.e. + * using the whole image as region-of-interest and without any rotation applied. Only use this + * method when the {@link ImageClassifier} is created with {@link RunningMode.VIDEO}. + * + *

    It's required to provide the video frame's timestamp (in milliseconds). The input timestamps + * must be monotonically increasing. * *

    {@link ImageClassifier} supports the following color space types: * @@ -242,13 +267,12 @@ public final class ImageClassifier extends BaseVisionTaskApi { *

  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link MPImage} object for processing. - * @param roi a {@link RectF} specifying the region of interest on which to perform - * classification. Coordinates are expected to be specified as normalized values in [0,1]. + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). * @throws MediaPipeException if there is an internal error. */ - public ImageClassificationResult classify(MPImage inputImage, RectF roi) { - return (ImageClassificationResult) processImageData(inputImage, roi); + public ImageClassificationResult classifyForVideo(MPImage image, long timestampMs) { + return classifyForVideo(image, ImageProcessingOptions.builder().build(), timestampMs); } /** @@ -264,21 +288,26 @@ public final class ImageClassifier extends BaseVisionTaskApi { *
  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link MPImage} object for processing. - * @param inputTimestampMs the input timestamp (in milliseconds). + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. + * @param timestampMs the input timestamp (in milliseconds). * @throws MediaPipeException if there is an internal error. */ - public ImageClassificationResult classifyForVideo(MPImage inputImage, long inputTimestampMs) { - return (ImageClassificationResult) - processVideoData(inputImage, buildFullImageRectF(), inputTimestampMs); + public ImageClassificationResult classifyForVideo( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + return (ImageClassificationResult) processVideoData(image, imageProcessingOptions, timestampMs); } /** - * Performs classification on the provided video frame with additional region-of-interest. Only - * use this method when the {@link ImageClassifier} is created with {@link RunningMode.VIDEO}. + * Sends live image data to perform classification with default image processing options, i.e. + * using the whole image as region-of-interest and without any rotation applied, and the results + * will be available via the {@link ResultListener} provided in the {@link + * ImageClassifierOptions}. Only use this method when the {@link ImageClassifier} is created with + * {@link RunningMode.LIVE_STREAM}. * - *

    It's required to provide the video frame's timestamp (in milliseconds). The input timestamps - * must be monotonically increasing. + *

    It's required to provide a timestamp (in milliseconds) to indicate when the input image is + * sent to the image classifier. The input timestamps must be monotonically increasing. + * + *

    {@link ImageClassifier} supports the following color space types: * @@ -286,15 +315,12 @@ public final class ImageClassifier extends BaseVisionTaskApi { *

  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link MPImage} object for processing. - * @param roi a {@link RectF} specifying the region of interest on which to perform - * classification. Coordinates are expected to be specified as normalized values in [0,1]. - * @param inputTimestampMs the input timestamp (in milliseconds). + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). * @throws MediaPipeException if there is an internal error. */ - public ImageClassificationResult classifyForVideo( - MPImage inputImage, RectF roi, long inputTimestampMs) { - return (ImageClassificationResult) processVideoData(inputImage, roi, inputTimestampMs); + public void classifyAsync(MPImage image, long timestampMs) { + classifyAsync(image, ImageProcessingOptions.builder().build(), timestampMs); } /** @@ -311,37 +337,15 @@ public final class ImageClassifier extends BaseVisionTaskApi { *
  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link MPImage} object for processing. - * @param inputTimestampMs the input timestamp (in milliseconds). + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. + * @param timestampMs the input timestamp (in milliseconds). * @throws MediaPipeException if there is an internal error. */ - public void classifyAsync(MPImage inputImage, long inputTimestampMs) { - sendLiveStreamData(inputImage, buildFullImageRectF(), inputTimestampMs); - } - - /** - * Sends live image data and additional region-of-interest to perform classification, and the - * results will be available via the {@link ResultListener} provided in the {@link - * ImageClassifierOptions}. Only use this method when the {@link ImageClassifier} is created with - * {@link RunningMode.LIVE_STREAM}. - * - *

    It's required to provide a timestamp (in milliseconds) to indicate when the input image is - * sent to the object detector. The input timestamps must be monotonically increasing. - * - *

    {@link ImageClassifier} supports the following color space types: - * - *

      - *
    • {@link Bitmap.Config.ARGB_8888} - *
    - * - * @param inputImage a MediaPipe {@link MPImage} object for processing. - * @param roi a {@link RectF} specifying the region of interest on which to perform - * classification. Coordinates are expected to be specified as normalized values in [0,1]. - * @param inputTimestampMs the input timestamp (in milliseconds). - * @throws MediaPipeException if there is an internal error. - */ - public void classifyAsync(MPImage inputImage, RectF roi, long inputTimestampMs) { - sendLiveStreamData(inputImage, roi, inputTimestampMs); + public void classifyAsync( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + sendLiveStreamData(image, imageProcessingOptions, timestampMs); } /** Options for setting up and {@link ImageClassifier}. */ @@ -447,9 +451,4 @@ public final class ImageClassifier extends BaseVisionTaskApi { .build(); } } - - /** Creates a RectF covering the full image. */ - private static RectF buildFullImageRectF() { - return new RectF(0, 0, 1, 1); - } } diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java index 0f2e7b540..3f944eaee 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java @@ -32,6 +32,7 @@ import com.google.mediapipe.tasks.core.TaskOptions; import com.google.mediapipe.tasks.core.TaskRunner; import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; import com.google.mediapipe.tasks.vision.core.RunningMode; import com.google.mediapipe.tasks.vision.objectdetector.proto.ObjectDetectorOptionsProto; import com.google.mediapipe.formats.proto.DetectionProto.Detection; @@ -96,8 +97,10 @@ import 
java.util.Optional; public final class ObjectDetector extends BaseVisionTaskApi { private static final String TAG = ObjectDetector.class.getSimpleName(); private static final String IMAGE_IN_STREAM_NAME = "image_in"; + private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in"; private static final List INPUT_STREAMS = - Collections.unmodifiableList(Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME)); + Collections.unmodifiableList( + Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME)); private static final List OUTPUT_STREAMS = Collections.unmodifiableList(Arrays.asList("DETECTIONS:detections_out", "IMAGE:image_out")); private static final int DETECTIONS_OUT_STREAM_INDEX = 0; @@ -204,7 +207,25 @@ public final class ObjectDetector extends BaseVisionTaskApi { * @param runningMode a mediapipe vision task {@link RunningMode}. */ private ObjectDetector(TaskRunner taskRunner, RunningMode runningMode) { - super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME); + super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME); + } + + /** + * Performs object detection on the provided single image with default image processing options, + * i.e. without any rotation applied. Only use this method when the {@link ObjectDetector} is + * created with {@link RunningMode.IMAGE}. + * + *

    {@link ObjectDetector} supports the following color space types: + * + *

      + *
    • {@link Bitmap.Config.ARGB_8888} + *
    + * + * @param image a MediaPipe {@link MPImage} object for processing. + * @throws MediaPipeException if there is an internal error. + */ + public ObjectDetectionResult detect(MPImage image) { + return detect(image, ImageProcessingOptions.builder().build()); } /** @@ -217,11 +238,41 @@ public final class ObjectDetector extends BaseVisionTaskApi { *
  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link MPImage} object for processing. + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. * @throws MediaPipeException if there is an internal error. */ - public ObjectDetectionResult detect(MPImage inputImage) { - return (ObjectDetectionResult) processImageData(inputImage); + public ObjectDetectionResult detect( + MPImage image, ImageProcessingOptions imageProcessingOptions) { + validateImageProcessingOptions(imageProcessingOptions); + return (ObjectDetectionResult) processImageData(image, imageProcessingOptions); + } + + /** + * Performs object detection on the provided video frame with default image processing options, + * i.e. without any rotation applied. Only use this method when the {@link ObjectDetector} is + * created with {@link RunningMode.VIDEO}. + * + *

    It's required to provide the video frame's timestamp (in milliseconds). The input timestamps + * must be monotonically increasing. + * + *

    {@link ObjectDetector} supports the following color space types: + * + *

      + *
    • {@link Bitmap.Config.ARGB_8888} + *
    + * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public ObjectDetectionResult detectForVideo(MPImage image, long timestampMs) { + return detectForVideo(image, ImageProcessingOptions.builder().build(), timestampMs); } /** @@ -237,12 +288,43 @@ public final class ObjectDetector extends BaseVisionTaskApi { *
  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link MPImage} object for processing. - * @param inputTimestampMs the input timestamp (in milliseconds). + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @param timestampMs the input timestamp (in milliseconds). + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. * @throws MediaPipeException if there is an internal error. */ - public ObjectDetectionResult detectForVideo(MPImage inputImage, long inputTimestampMs) { - return (ObjectDetectionResult) processVideoData(inputImage, inputTimestampMs); + public ObjectDetectionResult detectForVideo( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + validateImageProcessingOptions(imageProcessingOptions); + return (ObjectDetectionResult) processVideoData(image, imageProcessingOptions, timestampMs); + } + + /** + * Sends live image data to perform object detection with default image processing options, i.e. + * without any rotation applied, and the results will be available via the {@link ResultListener} + * provided in the {@link ObjectDetectorOptions}. Only use this method when the {@link + * ObjectDetector} is created with {@link RunningMode.LIVE_STREAM}. + * + *

    It's required to provide a timestamp (in milliseconds) to indicate when the input image is + * sent to the object detector. The input timestamps must be monotonically increasing. + * + *

    {@link ObjectDetector} supports the following color space types: + * + *

      + *
    • {@link Bitmap.Config.ARGB_8888} + *
    + * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public void detectAsync(MPImage image, long timestampMs) { + detectAsync(image, ImageProcessingOptions.builder().build(), timestampMs); } /** @@ -259,12 +341,20 @@ public final class ObjectDetector extends BaseVisionTaskApi { *
  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link MPImage} object for processing. - * @param inputTimestampMs the input timestamp (in milliseconds). + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @param timestampMs the input timestamp (in milliseconds). + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. * @throws MediaPipeException if there is an internal error. */ - public void detectAsync(MPImage inputImage, long inputTimestampMs) { - sendLiveStreamData(inputImage, inputTimestampMs); + public void detectAsync( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + validateImageProcessingOptions(imageProcessingOptions); + sendLiveStreamData(image, imageProcessingOptions, timestampMs); } /** Options for setting up an {@link ObjectDetector}. */ @@ -415,4 +505,15 @@ public final class ObjectDetector extends BaseVisionTaskApi { .build(); } } + + /** + * Validates that the provided {@link ImageProcessingOptions} doesn't contain a + * region-of-interest. 
+ */ + private static void validateImageProcessingOptions( + ImageProcessingOptions imageProcessingOptions) { + if (imageProcessingOptions.regionOfInterest().isPresent()) { + throw new IllegalArgumentException("ObjectDetector doesn't support region-of-interest."); + } + } } diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/core/AndroidManifest.xml b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/core/AndroidManifest.xml new file mode 100644 index 000000000..aa2df6baf --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/core/AndroidManifest.xml @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/core/BUILD b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/core/BUILD new file mode 100644 index 000000000..a7f804c64 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/core/BUILD @@ -0,0 +1,19 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +# TODO: Enable this in OSS diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/core/ImageProcessingOptionsTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/core/ImageProcessingOptionsTest.java new file mode 100644 index 000000000..078b62af1 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/core/ImageProcessingOptionsTest.java @@ -0,0 +1,70 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package com.google.mediapipe.tasks.vision.core; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import android.graphics.RectF; +import androidx.test.ext.junit.runners.AndroidJUnit4; +import org.junit.Test; +import org.junit.runner.RunWith; + +/** Test for {@link ImageProcessingOptions}/ */ +@RunWith(AndroidJUnit4.class) +public final class ImageProcessingOptionsTest { + + @Test + public void succeedsWithValidInputs() throws Exception { + ImageProcessingOptions options = + ImageProcessingOptions.builder() + .setRegionOfInterest(new RectF(0.0f, 0.1f, 1.0f, 0.9f)) + .setRotationDegrees(270) + .build(); + } + + @Test + public void failsWithLeftHigherThanRight() { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ImageProcessingOptions.builder() + .setRegionOfInterest(new RectF(0.9f, 0.0f, 0.1f, 1.0f)) + .build()); + assertThat(exception).hasMessageThat().contains("Expected left < right and top < bottom"); + } + + @Test + public void failsWithBottomHigherThanTop() { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ImageProcessingOptions.builder() + .setRegionOfInterest(new RectF(0.0f, 0.9f, 1.0f, 0.1f)) + .build()); + assertThat(exception).hasMessageThat().contains("Expected left < right and top < bottom"); + } + + @Test + public void failsWithInvalidRotation() { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> ImageProcessingOptions.builder().setRotationDegrees(1).build()); + assertThat(exception).hasMessageThat().contains("Expected rotation to be a multiple of 90°"); + } +} diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java index 31e59a259..eca5d35c2 100644 --- 
a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java @@ -19,6 +19,7 @@ import static org.junit.Assert.assertThrows; import android.content.res.AssetManager; import android.graphics.BitmapFactory; +import android.graphics.RectF; import androidx.test.core.app.ApplicationProvider; import androidx.test.ext.junit.runners.AndroidJUnit4; import com.google.common.truth.Correspondence; @@ -30,6 +31,7 @@ import com.google.mediapipe.tasks.components.containers.Category; import com.google.mediapipe.tasks.components.containers.Landmark; import com.google.mediapipe.tasks.components.containers.proto.LandmarksDetectionResultProto.LandmarksDetectionResult; import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; import com.google.mediapipe.tasks.vision.core.RunningMode; import com.google.mediapipe.tasks.vision.gesturerecognizer.GestureRecognizer.GestureRecognizerOptions; import java.io.InputStream; @@ -46,11 +48,14 @@ public class GestureRecognizerTest { private static final String GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE = "gesture_recognizer.task"; private static final String TWO_HANDS_IMAGE = "right_hands.jpg"; private static final String THUMB_UP_IMAGE = "thumb_up.jpg"; + private static final String POINTING_UP_ROTATED_IMAGE = "pointing_up_rotated.jpg"; private static final String NO_HANDS_IMAGE = "cats_and_dogs.jpg"; private static final String THUMB_UP_LANDMARKS = "thumb_up_landmarks.pb"; private static final String TAG = "Gesture Recognizer Test"; private static final String THUMB_UP_LABEL = "Thumb_Up"; private static final int THUMB_UP_INDEX = 5; + private static final String POINTING_UP_LABEL = "Pointing_Up"; + private static final int POINTING_UP_INDEX = 3; private static final float LANDMARKS_ERROR_TOLERANCE = 0.03f; private static final int 
IMAGE_WIDTH = 382; private static final int IMAGE_HEIGHT = 406; @@ -135,6 +140,53 @@ public class GestureRecognizerTest { gestureRecognizer.recognize(getImageFromAsset(TWO_HANDS_IMAGE)); assertThat(actualResult.handednesses()).hasSize(2); } + + @Test + public void recognize_successWithRotation() throws Exception { + GestureRecognizerOptions options = + GestureRecognizerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE) + .build()) + .setNumHands(1) + .build(); + GestureRecognizer gestureRecognizer = + GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRotationDegrees(-90).build(); + GestureRecognitionResult actualResult = + gestureRecognizer.recognize( + getImageFromAsset(POINTING_UP_ROTATED_IMAGE), imageProcessingOptions); + assertThat(actualResult.gestures()).hasSize(1); + assertThat(actualResult.gestures().get(0).get(0).index()).isEqualTo(POINTING_UP_INDEX); + assertThat(actualResult.gestures().get(0).get(0).categoryName()).isEqualTo(POINTING_UP_LABEL); + } + + @Test + public void recognize_failsWithRegionOfInterest() throws Exception { + GestureRecognizerOptions options = + GestureRecognizerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE) + .build()) + .setNumHands(1) + .build(); + GestureRecognizer gestureRecognizer = + GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRegionOfInterest(new RectF(0, 0, 1, 1)).build(); + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + gestureRecognizer.recognize( + getImageFromAsset(THUMB_UP_IMAGE), imageProcessingOptions)); + assertThat(exception) + .hasMessageThat() + 
.contains("GestureRecognizer doesn't support region-of-interest"); + } } @RunWith(AndroidJUnit4.class) @@ -195,12 +247,16 @@ public class GestureRecognizerTest { MediaPipeException exception = assertThrows( MediaPipeException.class, - () -> gestureRecognizer.recognizeForVideo(getImageFromAsset(THUMB_UP_IMAGE), 0)); + () -> + gestureRecognizer.recognizeForVideo( + getImageFromAsset(THUMB_UP_IMAGE), /*timestampsMs=*/ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); exception = assertThrows( MediaPipeException.class, - () -> gestureRecognizer.recognizeAsync(getImageFromAsset(THUMB_UP_IMAGE), 0)); + () -> + gestureRecognizer.recognizeAsync( + getImageFromAsset(THUMB_UP_IMAGE), /*timestampsMs=*/ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); } @@ -225,7 +281,9 @@ public class GestureRecognizerTest { exception = assertThrows( MediaPipeException.class, - () -> gestureRecognizer.recognizeAsync(getImageFromAsset(THUMB_UP_IMAGE), 0)); + () -> + gestureRecognizer.recognizeAsync( + getImageFromAsset(THUMB_UP_IMAGE), /*timestampsMs=*/ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); } @@ -251,7 +309,9 @@ public class GestureRecognizerTest { exception = assertThrows( MediaPipeException.class, - () -> gestureRecognizer.recognizeForVideo(getImageFromAsset(THUMB_UP_IMAGE), 0)); + () -> + gestureRecognizer.recognizeForVideo( + getImageFromAsset(THUMB_UP_IMAGE), /*timestampsMs=*/ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); } @@ -291,7 +351,8 @@ public class GestureRecognizerTest { getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX); for (int i = 0; i < 3; i++) { GestureRecognitionResult actualResult = - gestureRecognizer.recognizeForVideo(getImageFromAsset(THUMB_UP_IMAGE), i); + gestureRecognizer.recognizeForVideo( + getImageFromAsset(THUMB_UP_IMAGE), 
/*timestampsMs=*/ i); assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult); } } @@ -317,9 +378,11 @@ public class GestureRecognizerTest { .build(); try (GestureRecognizer gestureRecognizer = GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { - gestureRecognizer.recognizeAsync(image, 1); + gestureRecognizer.recognizeAsync(image, /*timestampsMs=*/ 1); MediaPipeException exception = - assertThrows(MediaPipeException.class, () -> gestureRecognizer.recognizeAsync(image, 0)); + assertThrows( + MediaPipeException.class, + () -> gestureRecognizer.recognizeAsync(image, /*timestampsMs=*/ 0)); assertThat(exception) .hasMessageThat() .contains("having a smaller timestamp than the processed timestamp"); @@ -348,7 +411,7 @@ public class GestureRecognizerTest { try (GestureRecognizer gestureRecognizer = GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { for (int i = 0; i < 3; i++) { - gestureRecognizer.recognizeAsync(image, i); + gestureRecognizer.recognizeAsync(image, /*timestampsMs=*/ i); } } } diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java index 966e4ff4a..99ebd9777 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java @@ -29,6 +29,7 @@ import com.google.mediapipe.tasks.components.containers.Category; import com.google.mediapipe.tasks.components.processors.ClassifierOptions; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.TestUtils; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; import com.google.mediapipe.tasks.vision.core.RunningMode; import 
com.google.mediapipe.tasks.vision.imageclassifier.ImageClassifier.ImageClassifierOptions; import java.io.InputStream; @@ -47,7 +48,9 @@ public class ImageClassifierTest { private static final String FLOAT_MODEL_FILE = "mobilenet_v2_1.0_224.tflite"; private static final String QUANTIZED_MODEL_FILE = "mobilenet_v1_0.25_224_quant.tflite"; private static final String BURGER_IMAGE = "burger.jpg"; + private static final String BURGER_ROTATED_IMAGE = "burger_rotated.jpg"; private static final String MULTI_OBJECTS_IMAGE = "multi_objects.jpg"; + private static final String MULTI_OBJECTS_ROTATED_IMAGE = "multi_objects_rotated.jpg"; @RunWith(AndroidJUnit4.class) public static final class General extends ImageClassifierTest { @@ -209,13 +212,60 @@ public class ImageClassifierTest { ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); // RectF around the soccer ball. RectF roi = new RectF(0.450f, 0.308f, 0.614f, 0.734f); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRegionOfInterest(roi).build(); ImageClassificationResult results = - imageClassifier.classify(getImageFromAsset(MULTI_OBJECTS_IMAGE), roi); + imageClassifier.classify(getImageFromAsset(MULTI_OBJECTS_IMAGE), imageProcessingOptions); assertHasOneHeadAndOneTimestamp(results, 0); assertCategoriesAre( results, Arrays.asList(Category.create(0.9969325f, 806, "soccer ball", ""))); } + + @Test + public void classify_succeedsWithRotation() throws Exception { + ImageClassifierOptions options = + ImageClassifierOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setClassifierOptions(ClassifierOptions.builder().setMaxResults(3).build()) + .build(); + ImageClassifier imageClassifier = + ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRotationDegrees(-90).build(); + 
ImageClassificationResult results = + imageClassifier.classify(getImageFromAsset(BURGER_ROTATED_IMAGE), imageProcessingOptions); + + assertHasOneHeadAndOneTimestamp(results, 0); + assertCategoriesAre( + results, + Arrays.asList( + Category.create(0.6390683f, 934, "cheeseburger", ""), + Category.create(0.0495407f, 963, "meat loaf", ""), + Category.create(0.0469720f, 925, "guacamole", ""))); + } + + @Test + public void classify_succeedsWithRegionOfInterestAndRotation() throws Exception { + ImageClassifierOptions options = + ImageClassifierOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .build(); + ImageClassifier imageClassifier = + ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); + // RectF around the chair. + RectF roi = new RectF(0.0f, 0.1763f, 0.5642f, 0.3049f); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRegionOfInterest(roi).setRotationDegrees(-90).build(); + ImageClassificationResult results = + imageClassifier.classify( + getImageFromAsset(MULTI_OBJECTS_ROTATED_IMAGE), imageProcessingOptions); + + assertHasOneHeadAndOneTimestamp(results, 0); + assertCategoriesAre( + results, Arrays.asList(Category.create(0.686824f, 560, "folding chair", ""))); + } } @RunWith(AndroidJUnit4.class) @@ -269,12 +319,16 @@ public class ImageClassifierTest { MediaPipeException exception = assertThrows( MediaPipeException.class, - () -> imageClassifier.classifyForVideo(getImageFromAsset(BURGER_IMAGE), 0)); + () -> + imageClassifier.classifyForVideo( + getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); exception = assertThrows( MediaPipeException.class, - () -> imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), 0)); + () -> + imageClassifier.classifyAsync( + 
getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); } @@ -296,7 +350,9 @@ public class ImageClassifierTest { exception = assertThrows( MediaPipeException.class, - () -> imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), 0)); + () -> + imageClassifier.classifyAsync( + getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); } @@ -320,7 +376,9 @@ public class ImageClassifierTest { exception = assertThrows( MediaPipeException.class, - () -> imageClassifier.classifyForVideo(getImageFromAsset(BURGER_IMAGE), 0)); + () -> + imageClassifier.classifyForVideo( + getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); } @@ -352,7 +410,8 @@ public class ImageClassifierTest { ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); for (int i = 0; i < 3; i++) { - ImageClassificationResult results = imageClassifier.classifyForVideo(image, i); + ImageClassificationResult results = + imageClassifier.classifyForVideo(image, /*timestampMs=*/ i); assertHasOneHeadAndOneTimestamp(results, i); assertCategoriesAre( results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", ""))); @@ -377,9 +436,11 @@ public class ImageClassifierTest { .build(); try (ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { - imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), 1); + imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 1); MediaPipeException exception = - assertThrows(MediaPipeException.class, () -> imageClassifier.classifyAsync(image, 0)); + assertThrows( + MediaPipeException.class, + () -> 
imageClassifier.classifyAsync(image, /*timestampMs=*/ 0)); assertThat(exception) .hasMessageThat() .contains("having a smaller timestamp than the processed timestamp"); @@ -405,7 +466,7 @@ public class ImageClassifierTest { try (ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { for (int i = 0; i < 3; ++i) { - imageClassifier.classifyAsync(image, i); + imageClassifier.classifyAsync(image, /*timestampMs=*/ i); } } } diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectorTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectorTest.java index 91ffa9273..2878c380d 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectorTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectorTest.java @@ -29,6 +29,7 @@ import com.google.mediapipe.tasks.components.containers.Category; import com.google.mediapipe.tasks.components.containers.Detection; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.TestUtils; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; import com.google.mediapipe.tasks.vision.core.RunningMode; import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector.ObjectDetectorOptions; import java.io.InputStream; @@ -45,10 +46,11 @@ import org.junit.runners.Suite.SuiteClasses; public class ObjectDetectorTest { private static final String MODEL_FILE = "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite"; private static final String CAT_AND_DOG_IMAGE = "cats_and_dogs.jpg"; + private static final String CAT_AND_DOG_ROTATED_IMAGE = "cats_and_dogs_rotated.jpg"; private static final int IMAGE_WIDTH = 1200; private static final int IMAGE_HEIGHT = 600; private static final float CAT_SCORE = 0.69f; - private static final RectF catBoundingBox = new 
RectF(611, 164, 986, 596); + private static final RectF CAT_BOUNDING_BOX = new RectF(611, 164, 986, 596); // TODO: Figure out why android_x86 and android_arm tests have slightly different // scores (0.6875 vs 0.69921875). private static final float SCORE_DIFF_TOLERANCE = 0.01f; @@ -67,7 +69,7 @@ public class ObjectDetectorTest { ObjectDetector objectDetector = ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); - assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE); + assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE); } @Test @@ -104,7 +106,7 @@ public class ObjectDetectorTest { ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); // The score threshold should block all other other objects, except cat. - assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE); + assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE); } @Test @@ -175,7 +177,7 @@ public class ObjectDetectorTest { ObjectDetector objectDetector = ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); - assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE); + assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE); } @Test @@ -228,6 +230,46 @@ public class ObjectDetectorTest { .contains("`category_allowlist` and `category_denylist` are mutually exclusive options."); } + @Test + public void detect_succeedsWithRotation() throws Exception { + ObjectDetectorOptions options = + ObjectDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .setMaxResults(1) + .setCategoryAllowlist(Arrays.asList("cat")) + .build(); + ObjectDetector objectDetector = + 
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRotationDegrees(-90).build(); + ObjectDetectionResult results = + objectDetector.detect( + getImageFromAsset(CAT_AND_DOG_ROTATED_IMAGE), imageProcessingOptions); + + assertContainsOnlyCat(results, new RectF(22.0f, 611.0f, 452.0f, 890.0f), 0.7109375f); + } + + @Test + public void detect_failsWithRegionOfInterest() throws Exception { + ObjectDetectorOptions options = + ObjectDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .build(); + ObjectDetector objectDetector = + ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRegionOfInterest(new RectF(0, 0, 1, 1)).build(); + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + objectDetector.detect( + getImageFromAsset(CAT_AND_DOG_IMAGE), imageProcessingOptions)); + assertThat(exception) + .hasMessageThat() + .contains("ObjectDetector doesn't support region-of-interest"); + } + // TODO: Implement detect_succeedsWithFloatImages, detect_succeedsWithOrientation, // detect_succeedsWithNumThreads, detect_successWithNumThreadsFromBaseOptions, // detect_failsWithInvalidNegativeNumThreads, detect_failsWithInvalidNumThreadsAsZero. 
@@ -282,12 +324,16 @@ public class ObjectDetectorTest { MediaPipeException exception = assertThrows( MediaPipeException.class, - () -> objectDetector.detectForVideo(getImageFromAsset(CAT_AND_DOG_IMAGE), 0)); + () -> + objectDetector.detectForVideo( + getImageFromAsset(CAT_AND_DOG_IMAGE), /*timestampsMs=*/ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); exception = assertThrows( MediaPipeException.class, - () -> objectDetector.detectAsync(getImageFromAsset(CAT_AND_DOG_IMAGE), 0)); + () -> + objectDetector.detectAsync( + getImageFromAsset(CAT_AND_DOG_IMAGE), /*timestampsMs=*/ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); } @@ -309,7 +355,9 @@ public class ObjectDetectorTest { exception = assertThrows( MediaPipeException.class, - () -> objectDetector.detectAsync(getImageFromAsset(CAT_AND_DOG_IMAGE), 0)); + () -> + objectDetector.detectAsync( + getImageFromAsset(CAT_AND_DOG_IMAGE), /*timestampsMs=*/ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); } @@ -333,7 +381,9 @@ public class ObjectDetectorTest { exception = assertThrows( MediaPipeException.class, - () -> objectDetector.detectForVideo(getImageFromAsset(CAT_AND_DOG_IMAGE), 0)); + () -> + objectDetector.detectForVideo( + getImageFromAsset(CAT_AND_DOG_IMAGE), /*timestampsMs=*/ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); } @@ -348,7 +398,7 @@ public class ObjectDetectorTest { ObjectDetector objectDetector = ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); - assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE); + assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE); } @Test @@ -363,8 +413,9 @@ public class ObjectDetectorTest { 
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); for (int i = 0; i < 3; i++) { ObjectDetectionResult results = - objectDetector.detectForVideo(getImageFromAsset(CAT_AND_DOG_IMAGE), i); - assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE); + objectDetector.detectForVideo( + getImageFromAsset(CAT_AND_DOG_IMAGE), /*timestampsMs=*/ i); + assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE); } } @@ -377,16 +428,18 @@ public class ObjectDetectorTest { .setRunningMode(RunningMode.LIVE_STREAM) .setResultListener( (objectDetectionResult, inputImage) -> { - assertContainsOnlyCat(objectDetectionResult, catBoundingBox, CAT_SCORE); + assertContainsOnlyCat(objectDetectionResult, CAT_BOUNDING_BOX, CAT_SCORE); assertImageSizeIsExpected(inputImage); }) .setMaxResults(1) .build(); try (ObjectDetector objectDetector = ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { - objectDetector.detectAsync(image, 1); + objectDetector.detectAsync(image, /*timestampsMs=*/ 1); MediaPipeException exception = - assertThrows(MediaPipeException.class, () -> objectDetector.detectAsync(image, 0)); + assertThrows( + MediaPipeException.class, + () -> objectDetector.detectAsync(image, /*timestampsMs=*/ 0)); assertThat(exception) .hasMessageThat() .contains("having a smaller timestamp than the processed timestamp"); @@ -402,7 +455,7 @@ public class ObjectDetectorTest { .setRunningMode(RunningMode.LIVE_STREAM) .setResultListener( (objectDetectionResult, inputImage) -> { - assertContainsOnlyCat(objectDetectionResult, catBoundingBox, CAT_SCORE); + assertContainsOnlyCat(objectDetectionResult, CAT_BOUNDING_BOX, CAT_SCORE); assertImageSizeIsExpected(inputImage); }) .setMaxResults(1) @@ -410,7 +463,7 @@ public class ObjectDetectorTest { try (ObjectDetector objectDetector = ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { for (int i = 0; i < 3; i++) { - 
objectDetector.detectAsync(image, i); + objectDetector.detectAsync(image, /*timestampsMs=*/ i); } } }