From 62d2ae601e5663d0b1d1fb4c87267aa4a9673237 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 10 Oct 2022 02:52:27 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 480034669 --- .../google/mediapipe/tasks/vision/core/BUILD | 5 + .../tasks/vision/core/BaseVisionTaskApi.java | 170 ++++++++++++++++-- .../objectdetector/ObjectDetectionResult.java | 3 +- .../vision/objectdetector/ObjectDetector.java | 28 ++- .../vision/objectdetector/AndroidManifest.xml | 2 +- 5 files changed, 179 insertions(+), 29 deletions(-) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BUILD index 94f77ea68..8df9173b2 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BUILD @@ -19,8 +19,12 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) android_library( name = "core", srcs = glob(["*.java"]), + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], deps = [ ":libmediapipe_tasks_vision_jni_lib", + "//mediapipe/framework/formats:rect_java_proto_lite", "//mediapipe/java/com/google/mediapipe/framework:android_framework_no_mff", "//mediapipe/java/com/google/mediapipe/framework/image", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", @@ -36,6 +40,7 @@ cc_binary( deps = [ "//mediapipe/calculators/core:flow_limiter_calculator", "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", + "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph", "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni", ], diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java index 92f64e898..7ab8e75a1 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java @@ -14,101 +14,247 @@ package com.google.mediapipe.tasks.vision.core; +import android.graphics.RectF; +import com.google.mediapipe.formats.proto.RectProto.NormalizedRect; import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.Packet; +import com.google.mediapipe.framework.ProtoUtil; import com.google.mediapipe.framework.image.Image; import com.google.mediapipe.tasks.core.TaskResult; import com.google.mediapipe.tasks.core.TaskRunner; import java.util.HashMap; import java.util.Map; +import java.util.Optional; /** The base class of MediaPipe vision tasks. */ public class BaseVisionTaskApi implements AutoCloseable { private static final long MICROSECONDS_PER_MILLISECOND = 1000; private final TaskRunner runner; private final RunningMode runningMode; + private final String imageStreamName; + private final Optional normRectStreamName; static { System.loadLibrary("mediapipe_tasks_vision_jni"); + ProtoUtil.registerTypeName(NormalizedRect.class, "mediapipe.NormalizedRect"); } /** - * Constructor to initialize an {@link BaseVisionTaskApi} from a {@link TaskRunner} and a vision - * task {@link RunningMode}. + * Constructor to initialize a {@link BaseVisionTaskApi} only taking images as input. * * @param runner a {@link TaskRunner}. * @param runningMode a mediapipe vision task {@link RunningMode}. + * @param imageStreamName the name of the input image stream. */ - public BaseVisionTaskApi(TaskRunner runner, RunningMode runningMode) { + public BaseVisionTaskApi(TaskRunner runner, RunningMode runningMode, String imageStreamName) { this.runner = runner; this.runningMode = runningMode; + this.imageStreamName = imageStreamName; + this.normRectStreamName = Optional.empty(); + } + + /** + * Constructor to initialize a {@link BaseVisionTaskApi} taking images and normalized rects as + * input. + * + * @param runner a {@link TaskRunner}. + * @param runningMode a mediapipe vision task {@link RunningMode}. + * @param imageStreamName the name of the input image stream. + * @param normRectStreamName the name of the input normalized rect image stream. + */ + public BaseVisionTaskApi( + TaskRunner runner, + RunningMode runningMode, + String imageStreamName, + String normRectStreamName) { + this.runner = runner; + this.runningMode = runningMode; + this.imageStreamName = imageStreamName; + this.normRectStreamName = Optional.of(normRectStreamName); } /** * A synchronous method to process single image inputs. The call blocks the current thread until a * failure status or a successful result is returned. * - * @param imageStreamName the image input stream name. * @param image a MediaPipe {@link Image} object for processing. - * @throws MediaPipeException if the task is not in the image mode. + * @throws MediaPipeException if the task is not in the image mode or requires a normalized rect + * input. */ - protected TaskResult processImageData(String imageStreamName, Image image) { + protected TaskResult processImageData(Image image) { if (runningMode != RunningMode.IMAGE) { throw new MediaPipeException( MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), "Task is not initialized with the image mode. Current running mode:" + runningMode.name()); } + if (normRectStreamName.isPresent()) { + throw new MediaPipeException( + MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), + "Task expects a normalized rect as input."); + } Map inputPackets = new HashMap<>(); inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); return runner.process(inputPackets); } + /** + * A synchronous method to process single image inputs. The call blocks the current thread until a + * failure status or a successful result is returned. + * + * @param image a MediaPipe {@link Image} object for processing. + * @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates + * are expected to be specified as normalized values in [0,1]. + * @throws MediaPipeException if the task is not in the image mode or doesn't require a normalized + * rect. + */ + protected TaskResult processImageData(Image image, RectF roi) { + if (runningMode != RunningMode.IMAGE) { + throw new MediaPipeException( + MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), + "Task is not initialized with the image mode. Current running mode:" + + runningMode.name()); + } + if (!normRectStreamName.isPresent()) { + throw new MediaPipeException( + MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), + "Task doesn't expect a normalized rect as input."); + } + Map inputPackets = new HashMap<>(); + inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); + inputPackets.put( + normRectStreamName.get(), + runner.getPacketCreator().createProto(convertToNormalizedRect(roi))); + return runner.process(inputPackets); + } + /** * A synchronous method to process continuous video frames. The call blocks the current thread * until a failure status or a successful result is returned. * - * @param imageStreamName the image input stream name. * @param image a MediaPipe {@link Image} object for processing. * @param timestampMs the corresponding timestamp of the input image in milliseconds. - * @throws MediaPipeException if the task is not in the video mode. + * @throws MediaPipeException if the task is not in the video mode or requires a normalized rect + * input. */ - protected TaskResult processVideoData(String imageStreamName, Image image, long timestampMs) { + protected TaskResult processVideoData(Image image, long timestampMs) { if (runningMode != RunningMode.VIDEO) { throw new MediaPipeException( MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), "Task is not initialized with the video mode. Current running mode:" + runningMode.name()); } + if (normRectStreamName.isPresent()) { + throw new MediaPipeException( + MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), + "Task expects a normalized rect as input."); + } Map inputPackets = new HashMap<>(); inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND); } + /** + * A synchronous method to process continuous video frames. The call blocks the current thread + * until a failure status or a successful result is returned. + * + * @param image a MediaPipe {@link Image} object for processing. + * @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates + * are expected to be specified as normalized values in [0,1]. + * @param timestampMs the corresponding timestamp of the input image in milliseconds. + * @throws MediaPipeException if the task is not in the video mode or doesn't require a normalized + * rect. + */ + protected TaskResult processVideoData(Image image, RectF roi, long timestampMs) { + if (runningMode != RunningMode.VIDEO) { + throw new MediaPipeException( + MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), + "Task is not initialized with the video mode. Current running mode:" + + runningMode.name()); + } + if (!normRectStreamName.isPresent()) { + throw new MediaPipeException( + MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), + "Task doesn't expect a normalized rect as input."); + } + Map inputPackets = new HashMap<>(); + inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); + inputPackets.put( + normRectStreamName.get(), + runner.getPacketCreator().createProto(convertToNormalizedRect(roi))); + return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND); + } + /** * An asynchronous method to send live stream data to the {@link TaskRunner}. The results will be * available in the user-defined result listener. * - * @param imageStreamName the image input stream name. * @param image a MediaPipe {@link Image} object for processing. * @param timestampMs the corresponding timestamp of the input image in milliseconds. - * @throws MediaPipeException if the task is not in the video mode. + * @throws MediaPipeException if the task is not in the video mode or requires a normalized rect + * input. */ - protected void sendLiveStreamData(String imageStreamName, Image image, long timestampMs) { + protected void sendLiveStreamData(Image image, long timestampMs) { if (runningMode != RunningMode.LIVE_STREAM) { throw new MediaPipeException( MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), "Task is not initialized with the live stream mode. Current running mode:" + runningMode.name()); } + if (normRectStreamName.isPresent()) { + throw new MediaPipeException( + MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), + "Task expects a normalized rect as input."); + } Map inputPackets = new HashMap<>(); inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND); } + /** + * An asynchronous method to send live stream data to the {@link TaskRunner}. The results will be + * available in the user-defined result listener. + * + * @param image a MediaPipe {@link Image} object for processing. + * @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates + * are expected to be specified as normalized values in [0,1]. + * @param timestampMs the corresponding timestamp of the input image in milliseconds. + * @throws MediaPipeException if the task is not in the video mode or doesn't require a normalized + * rect. + */ + protected void sendLiveStreamData(Image image, RectF roi, long timestampMs) { + if (runningMode != RunningMode.LIVE_STREAM) { + throw new MediaPipeException( + MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), + "Task is not initialized with the live stream mode. Current running mode:" + + runningMode.name()); + } + if (!normRectStreamName.isPresent()) { + throw new MediaPipeException( + MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), + "Task doesn't expect a normalized rect as input."); + } + Map inputPackets = new HashMap<>(); + inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); + inputPackets.put( + normRectStreamName.get(), + runner.getPacketCreator().createProto(convertToNormalizedRect(roi))); + runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND); + } + /** Closes and cleans up the MediaPipe vision task. */ @Override public void close() { runner.close(); } + + /** Converts a {@link RectF} object into a {@link NormalizedRect} protobuf message. */ + private static NormalizedRect convertToNormalizedRect(RectF rect) { + return NormalizedRect.newBuilder() + .setXCenter(rect.centerX()) + .setYCenter(rect.centerY()) + .setWidth(rect.width()) + .setHeight(rect.height()) + .build(); + } } diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectionResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectionResult.java index 9a0c7e8f6..108c021ea 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectionResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectionResult.java @@ -38,7 +38,8 @@ public abstract class ObjectDetectionResult implements TaskResult { * Creates an {@link ObjectDetectionResult} instance from a list of {@link Detection} protobuf * messages. * - * @param detectionList a list of {@link Detection} protobuf messages. + * @param detectionList a list of {@link DetectionOuterClass.Detection} protobuf messages. + * @param timestampMs a timestamp for this result. */ static ObjectDetectionResult create(List detectionList, long timestampMs) { List detections = new ArrayList<>(); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java index 463ab4c43..b64992d3e 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java @@ -155,7 +155,7 @@ public final class ObjectDetector extends BaseVisionTaskApi { * Creates an {@link ObjectDetector} instance from an {@link ObjectDetectorOptions}. * * @param context an Android {@link Context}. - * @param detectorOptions a {@link ObjectDetectorOptions} instance. + * @param detectorOptions an {@link ObjectDetectorOptions} instance. * @throws MediaPipeException if there is an error during {@link ObjectDetector} creation. */ public static ObjectDetector createFromOptions( @@ -192,7 +192,6 @@ public final class ObjectDetector extends BaseVisionTaskApi { .setEnableFlowLimiting(detectorOptions.runningMode() == RunningMode.LIVE_STREAM) .build(), handler); - detectorOptions.errorListener().ifPresent(runner::setErrorListener); return new ObjectDetector(runner, detectorOptions.runningMode()); } @@ -204,7 +203,7 @@ public final class ObjectDetector extends BaseVisionTaskApi { * @param runningMode a mediapipe vision task {@link RunningMode}. */ private ObjectDetector(TaskRunner taskRunner, RunningMode runningMode) { - super(taskRunner, runningMode); + super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME); } /** @@ -221,7 +220,7 @@ public final class ObjectDetector extends BaseVisionTaskApi { * @throws MediaPipeException if there is an internal error. */ public ObjectDetectionResult detect(Image inputImage) { - return (ObjectDetectionResult) processImageData(IMAGE_IN_STREAM_NAME, inputImage); + return (ObjectDetectionResult) processImageData(inputImage); } /** @@ -242,8 +241,7 @@ public final class ObjectDetector extends BaseVisionTaskApi { * @throws MediaPipeException if there is an internal error. */ public ObjectDetectionResult detectForVideo(Image inputImage, long inputTimestampMs) { - return (ObjectDetectionResult) - processVideoData(IMAGE_IN_STREAM_NAME, inputImage, inputTimestampMs); + return (ObjectDetectionResult) processVideoData(inputImage, inputTimestampMs); } /** @@ -265,7 +263,7 @@ public final class ObjectDetector extends BaseVisionTaskApi { * @throws MediaPipeException if there is an internal error. */ public void detectAsync(Image inputImage, long inputTimestampMs) { - sendLiveStreamData(IMAGE_IN_STREAM_NAME, inputImage, inputTimestampMs); + sendLiveStreamData(inputImage, inputTimestampMs); } /** Options for setting up an {@link ObjectDetector}. */ @@ -275,12 +273,12 @@ public final class ObjectDetector extends BaseVisionTaskApi { /** Builder for {@link ObjectDetectorOptions}. */ @AutoValue.Builder public abstract static class Builder { - /** Sets the base options for the object detector task. */ + /** Sets the {@link BaseOptions} for the object detector task. */ public abstract Builder setBaseOptions(BaseOptions value); /** - * Sets the running mode for the object detector task. Default to the image mode. Object - * detector has three modes: + * Sets the {@link RunningMode} for the object detector task. Default to the image mode. + * Object detector has three modes: * *
    *
  • IMAGE: The mode for detecting objects on single image inputs. @@ -293,8 +291,8 @@ public final class ObjectDetector extends BaseVisionTaskApi { public abstract Builder setRunningMode(RunningMode value); /** - * Sets the locale to use for display names specified through the TFLite Model Metadata, if - * any. Defaults to English. + * Sets the optional locale to use for display names specified through the TFLite Model + * Metadata, if any. */ public abstract Builder setDisplayNamesLocale(String value); @@ -331,12 +329,12 @@ public final class ObjectDetector extends BaseVisionTaskApi { public abstract Builder setCategoryDenylist(List value); /** - * Sets the result listener to receive the detection results asynchronously when the object - * detector is in the live stream mode. + * Sets the {@link ResultListener} to receive the detection results asynchronously when the + * object detector is in the live stream mode. */ public abstract Builder setResultListener(ResultListener value); - /** Sets an optional error listener. */ + /** Sets an optional {@link ErrorListener}}. */ public abstract Builder setErrorListener(ErrorListener value); abstract ObjectDetectorOptions autoBuild(); diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/AndroidManifest.xml b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/AndroidManifest.xml index 3e5e81920..19bd638e9 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/AndroidManifest.xml +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/AndroidManifest.xml @@ -11,7 +11,7 @@ android:targetSdkVersion="30" />