Internal change

PiperOrigin-RevId: 480034669
MediaPipe Team 2022-10-10 02:52:27 -07:00 committed by Copybara-Service
parent 1ab332835a
commit 62d2ae601e
5 changed files with 179 additions and 29 deletions


@@ -19,8 +19,12 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
 android_library(
     name = "core",
     srcs = glob(["*.java"]),
+    javacopts = [
+        "-Xep:AndroidJdkLibsChecker:OFF",
+    ],
     deps = [
         ":libmediapipe_tasks_vision_jni_lib",
+        "//mediapipe/framework/formats:rect_java_proto_lite",
         "//mediapipe/java/com/google/mediapipe/framework:android_framework_no_mff",
         "//mediapipe/java/com/google/mediapipe/framework/image",
         "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
@@ -36,6 +40,7 @@ cc_binary(
     deps = [
         "//mediapipe/calculators/core:flow_limiter_calculator",
         "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
+        "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
         "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
         "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
     ],


@@ -14,101 +14,247 @@
 package com.google.mediapipe.tasks.vision.core;

+import android.graphics.RectF;
+import com.google.mediapipe.formats.proto.RectProto.NormalizedRect;
 import com.google.mediapipe.framework.MediaPipeException;
 import com.google.mediapipe.framework.Packet;
+import com.google.mediapipe.framework.ProtoUtil;
 import com.google.mediapipe.framework.image.Image;
 import com.google.mediapipe.tasks.core.TaskResult;
 import com.google.mediapipe.tasks.core.TaskRunner;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.Optional;

 /** The base class of MediaPipe vision tasks. */
 public class BaseVisionTaskApi implements AutoCloseable {
   private static final long MICROSECONDS_PER_MILLISECOND = 1000;

   private final TaskRunner runner;
   private final RunningMode runningMode;
+  private final String imageStreamName;
+  private final Optional<String> normRectStreamName;

   static {
     System.loadLibrary("mediapipe_tasks_vision_jni");
+    ProtoUtil.registerTypeName(NormalizedRect.class, "mediapipe.NormalizedRect");
   }

   /**
-   * Constructor to initialize an {@link BaseVisionTaskApi} from a {@link TaskRunner} and a vision
-   * task {@link RunningMode}.
+   * Constructor to initialize a {@link BaseVisionTaskApi} only taking images as input.
    *
    * @param runner a {@link TaskRunner}.
    * @param runningMode a mediapipe vision task {@link RunningMode}.
+   * @param imageStreamName the name of the input image stream.
    */
-  public BaseVisionTaskApi(TaskRunner runner, RunningMode runningMode) {
+  public BaseVisionTaskApi(TaskRunner runner, RunningMode runningMode, String imageStreamName) {
     this.runner = runner;
     this.runningMode = runningMode;
+    this.imageStreamName = imageStreamName;
+    this.normRectStreamName = Optional.empty();
+  }
+
+  /**
+   * Constructor to initialize a {@link BaseVisionTaskApi} taking images and normalized rects as
+   * input.
+   *
+   * @param runner a {@link TaskRunner}.
+   * @param runningMode a mediapipe vision task {@link RunningMode}.
+   * @param imageStreamName the name of the input image stream.
+   * @param normRectStreamName the name of the input normalized rect image stream.
+   */
+  public BaseVisionTaskApi(
+      TaskRunner runner,
+      RunningMode runningMode,
+      String imageStreamName,
+      String normRectStreamName) {
+    this.runner = runner;
+    this.runningMode = runningMode;
+    this.imageStreamName = imageStreamName;
+    this.normRectStreamName = Optional.of(normRectStreamName);
   }

   /**
    * A synchronous method to process single image inputs. The call blocks the current thread until a
    * failure status or a successful result is returned.
    *
-   * @param imageStreamName the image input stream name.
    * @param image a MediaPipe {@link Image} object for processing.
-   * @throws MediaPipeException if the task is not in the image mode.
+   * @throws MediaPipeException if the task is not in the image mode or requires a normalized rect
+   *     input.
    */
-  protected TaskResult processImageData(String imageStreamName, Image image) {
+  protected TaskResult processImageData(Image image) {
     if (runningMode != RunningMode.IMAGE) {
       throw new MediaPipeException(
           MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
           "Task is not initialized with the image mode. Current running mode:"
               + runningMode.name());
     }
+    if (normRectStreamName.isPresent()) {
+      throw new MediaPipeException(
+          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
+          "Task expects a normalized rect as input.");
+    }
     Map<String, Packet> inputPackets = new HashMap<>();
     inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
     return runner.process(inputPackets);
   }

+  /**
+   * A synchronous method to process single image inputs. The call blocks the current thread until a
+   * failure status or a successful result is returned.
+   *
+   * @param image a MediaPipe {@link Image} object for processing.
+   * @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates
+   *     are expected to be specified as normalized values in [0,1].
+   * @throws MediaPipeException if the task is not in the image mode or doesn't require a normalized
+   *     rect.
+   */
+  protected TaskResult processImageData(Image image, RectF roi) {
+    if (runningMode != RunningMode.IMAGE) {
+      throw new MediaPipeException(
+          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
+          "Task is not initialized with the image mode. Current running mode:"
+              + runningMode.name());
+    }
+    if (!normRectStreamName.isPresent()) {
+      throw new MediaPipeException(
+          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
+          "Task doesn't expect a normalized rect as input.");
+    }
+    Map<String, Packet> inputPackets = new HashMap<>();
+    inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
+    inputPackets.put(
+        normRectStreamName.get(),
+        runner.getPacketCreator().createProto(convertToNormalizedRect(roi)));
+    return runner.process(inputPackets);
+  }

   /**
    * A synchronous method to process continuous video frames. The call blocks the current thread
    * until a failure status or a successful result is returned.
    *
-   * @param imageStreamName the image input stream name.
    * @param image a MediaPipe {@link Image} object for processing.
    * @param timestampMs the corresponding timestamp of the input image in milliseconds.
-   * @throws MediaPipeException if the task is not in the video mode.
+   * @throws MediaPipeException if the task is not in the video mode or requires a normalized rect
+   *     input.
    */
-  protected TaskResult processVideoData(String imageStreamName, Image image, long timestampMs) {
+  protected TaskResult processVideoData(Image image, long timestampMs) {
     if (runningMode != RunningMode.VIDEO) {
       throw new MediaPipeException(
           MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
           "Task is not initialized with the video mode. Current running mode:"
               + runningMode.name());
     }
+    if (normRectStreamName.isPresent()) {
+      throw new MediaPipeException(
+          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
+          "Task expects a normalized rect as input.");
+    }
     Map<String, Packet> inputPackets = new HashMap<>();
     inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
     return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
   }

+  /**
+   * A synchronous method to process continuous video frames. The call blocks the current thread
+   * until a failure status or a successful result is returned.
+   *
+   * @param image a MediaPipe {@link Image} object for processing.
+   * @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates
+   *     are expected to be specified as normalized values in [0,1].
+   * @param timestampMs the corresponding timestamp of the input image in milliseconds.
+   * @throws MediaPipeException if the task is not in the video mode or doesn't require a normalized
+   *     rect.
+   */
+  protected TaskResult processVideoData(Image image, RectF roi, long timestampMs) {
+    if (runningMode != RunningMode.VIDEO) {
+      throw new MediaPipeException(
+          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
+          "Task is not initialized with the video mode. Current running mode:"
+              + runningMode.name());
+    }
+    if (!normRectStreamName.isPresent()) {
+      throw new MediaPipeException(
+          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
+          "Task doesn't expect a normalized rect as input.");
+    }
+    Map<String, Packet> inputPackets = new HashMap<>();
+    inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
+    inputPackets.put(
+        normRectStreamName.get(),
+        runner.getPacketCreator().createProto(convertToNormalizedRect(roi)));
+    return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
+  }

   /**
    * An asynchronous method to send live stream data to the {@link TaskRunner}. The results will be
    * available in the user-defined result listener.
    *
-   * @param imageStreamName the image input stream name.
    * @param image a MediaPipe {@link Image} object for processing.
    * @param timestampMs the corresponding timestamp of the input image in milliseconds.
-   * @throws MediaPipeException if the task is not in the video mode.
+   * @throws MediaPipeException if the task is not in the video mode or requires a normalized rect
+   *     input.
    */
-  protected void sendLiveStreamData(String imageStreamName, Image image, long timestampMs) {
+  protected void sendLiveStreamData(Image image, long timestampMs) {
     if (runningMode != RunningMode.LIVE_STREAM) {
       throw new MediaPipeException(
           MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
           "Task is not initialized with the live stream mode. Current running mode:"
               + runningMode.name());
     }
+    if (normRectStreamName.isPresent()) {
+      throw new MediaPipeException(
+          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
+          "Task expects a normalized rect as input.");
+    }
     Map<String, Packet> inputPackets = new HashMap<>();
     inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
     runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
   }

+  /**
+   * An asynchronous method to send live stream data to the {@link TaskRunner}. The results will be
+   * available in the user-defined result listener.
+   *
+   * @param image a MediaPipe {@link Image} object for processing.
+   * @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates
+   *     are expected to be specified as normalized values in [0,1].
+   * @param timestampMs the corresponding timestamp of the input image in milliseconds.
+   * @throws MediaPipeException if the task is not in the video mode or doesn't require a normalized
+   *     rect.
+   */
+  protected void sendLiveStreamData(Image image, RectF roi, long timestampMs) {
+    if (runningMode != RunningMode.LIVE_STREAM) {
+      throw new MediaPipeException(
+          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
+          "Task is not initialized with the live stream mode. Current running mode:"
+              + runningMode.name());
+    }
+    if (!normRectStreamName.isPresent()) {
+      throw new MediaPipeException(
+          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
+          "Task doesn't expect a normalized rect as input.");
+    }
+    Map<String, Packet> inputPackets = new HashMap<>();
+    inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
+    inputPackets.put(
+        normRectStreamName.get(),
+        runner.getPacketCreator().createProto(convertToNormalizedRect(roi)));
+    runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
+  }

   /** Closes and cleans up the MediaPipe vision task. */
   @Override
   public void close() {
     runner.close();
   }
+
+  /** Converts a {@link RectF} object into a {@link NormalizedRect} protobuf message. */
+  private static NormalizedRect convertToNormalizedRect(RectF rect) {
+    return NormalizedRect.newBuilder()
+        .setXCenter(rect.centerX())
+        .setYCenter(rect.centerY())
+        .setWidth(rect.width())
+        .setHeight(rect.height())
+        .build();
+  }
 }
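
Note: the change above replaces the per-call imageStreamName argument with stream names fixed at construction time and adds region-of-interest overloads that forward a NormalizedRect packet. The following is a minimal sketch of a hypothetical subclass using the new constructor and ROI overload; the class name and the stream names "image_in" and "norm_rect_in" are illustrative only and are not part of this change.

import android.graphics.RectF;
import com.google.mediapipe.framework.image.Image;
import com.google.mediapipe.tasks.core.TaskResult;
import com.google.mediapipe.tasks.core.TaskRunner;
import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
import com.google.mediapipe.tasks.vision.core.RunningMode;

/** Hypothetical wrapper task; only BaseVisionTaskApi calls shown in this diff are used. */
final class RoiAwareTask extends BaseVisionTaskApi {
  // Illustrative stream names; a real task graph defines its own.
  private static final String IMAGE_IN_STREAM_NAME = "image_in";
  private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in";

  RoiAwareTask(TaskRunner runner) {
    // Registering both streams enables the (Image, RectF) overloads.
    super(runner, RunningMode.IMAGE, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME);
  }

  /** Runs the graph on a region of interest given in normalized [0,1] coordinates. */
  TaskResult run(Image image, RectF roi) {
    return processImageData(image, roi);
  }
}

A full-image ROI would be new RectF(0f, 0f, 1f, 1f), since the base class maps the rect's center, width, and height directly onto the NormalizedRect fields.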


@@ -38,7 +38,8 @@ public abstract class ObjectDetectionResult implements TaskResult {
    * Creates an {@link ObjectDetectionResult} instance from a list of {@link Detection} protobuf
    * messages.
    *
-   * @param detectionList a list of {@link Detection} protobuf messages.
+   * @param detectionList a list of {@link DetectionOuterClass.Detection} protobuf messages.
+   * @param timestampMs a timestamp for this result.
    */
   static ObjectDetectionResult create(List<Detection> detectionList, long timestampMs) {
     List<com.google.mediapipe.tasks.components.containers.Detection> detections = new ArrayList<>();


@@ -155,7 +155,7 @@ public final class ObjectDetector extends BaseVisionTaskApi {
    * Creates an {@link ObjectDetector} instance from an {@link ObjectDetectorOptions}.
    *
    * @param context an Android {@link Context}.
-   * @param detectorOptions a {@link ObjectDetectorOptions} instance.
+   * @param detectorOptions an {@link ObjectDetectorOptions} instance.
    * @throws MediaPipeException if there is an error during {@link ObjectDetector} creation.
    */
   public static ObjectDetector createFromOptions(
@@ -192,7 +192,6 @@ public final class ObjectDetector extends BaseVisionTaskApi {
                 .setEnableFlowLimiting(detectorOptions.runningMode() == RunningMode.LIVE_STREAM)
                 .build(),
             handler);
-    detectorOptions.errorListener().ifPresent(runner::setErrorListener);
     return new ObjectDetector(runner, detectorOptions.runningMode());
   }
@@ -204,7 +203,7 @@ public final class ObjectDetector extends BaseVisionTaskApi {
    * @param runningMode a mediapipe vision task {@link RunningMode}.
    */
   private ObjectDetector(TaskRunner taskRunner, RunningMode runningMode) {
-    super(taskRunner, runningMode);
+    super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME);
   }

   /**
@@ -221,7 +220,7 @@ public final class ObjectDetector extends BaseVisionTaskApi {
    * @throws MediaPipeException if there is an internal error.
    */
   public ObjectDetectionResult detect(Image inputImage) {
-    return (ObjectDetectionResult) processImageData(IMAGE_IN_STREAM_NAME, inputImage);
+    return (ObjectDetectionResult) processImageData(inputImage);
   }

   /**
@@ -242,8 +241,7 @@ public final class ObjectDetector extends BaseVisionTaskApi {
    * @throws MediaPipeException if there is an internal error.
    */
   public ObjectDetectionResult detectForVideo(Image inputImage, long inputTimestampMs) {
-    return (ObjectDetectionResult)
-        processVideoData(IMAGE_IN_STREAM_NAME, inputImage, inputTimestampMs);
+    return (ObjectDetectionResult) processVideoData(inputImage, inputTimestampMs);
   }

   /**
@@ -265,7 +263,7 @@ public final class ObjectDetector extends BaseVisionTaskApi {
    * @throws MediaPipeException if there is an internal error.
    */
   public void detectAsync(Image inputImage, long inputTimestampMs) {
-    sendLiveStreamData(IMAGE_IN_STREAM_NAME, inputImage, inputTimestampMs);
+    sendLiveStreamData(inputImage, inputTimestampMs);
   }

   /** Options for setting up an {@link ObjectDetector}. */
@@ -275,12 +273,12 @@ public final class ObjectDetector extends BaseVisionTaskApi {
     /** Builder for {@link ObjectDetectorOptions}. */
     @AutoValue.Builder
     public abstract static class Builder {
-      /** Sets the base options for the object detector task. */
+      /** Sets the {@link BaseOptions} for the object detector task. */
       public abstract Builder setBaseOptions(BaseOptions value);

       /**
-       * Sets the running mode for the object detector task. Default to the image mode. Object
-       * detector has three modes:
+       * Sets the {@link RunningMode} for the object detector task. Default to the image mode.
+       * Object detector has three modes:
        *
        * <ul>
        *   <li>IMAGE: The mode for detecting objects on single image inputs.
@@ -293,8 +291,8 @@ public final class ObjectDetector extends BaseVisionTaskApi {
       public abstract Builder setRunningMode(RunningMode value);

       /**
-       * Sets the locale to use for display names specified through the TFLite Model Metadata, if
-       * any. Defaults to English.
+       * Sets the optional locale to use for display names specified through the TFLite Model
+       * Metadata, if any.
        */
       public abstract Builder setDisplayNamesLocale(String value);
@@ -331,12 +329,12 @@ public final class ObjectDetector extends BaseVisionTaskApi {
       public abstract Builder setCategoryDenylist(List<String> value);

       /**
-       * Sets the result listener to receive the detection results asynchronously when the object
-       * detector is in the live stream mode.
+       * Sets the {@link ResultListener} to receive the detection results asynchronously when the
+       * object detector is in the live stream mode.
        */
       public abstract Builder setResultListener(ResultListener<ObjectDetectionResult, Image> value);

-      /** Sets an optional error listener. */
+      /** Sets an optional {@link ErrorListener}}. */
       public abstract Builder setErrorListener(ErrorListener value);

       abstract ObjectDetectorOptions autoBuild();
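
Note: with the input stream name now supplied to the BaseVisionTaskApi constructor, callers interact only with detect(), detectForVideo(), and detectAsync(). The sketch below shows one way the updated image-mode API might be used; ObjectDetectorOptions.builder(), BaseOptions.builder(), setModelAssetPath(), and build() are assumed entry points not shown in this diff, and the model path is a placeholder.

import android.content.Context;
import com.google.mediapipe.framework.image.Image;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetectionResult;
import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector;
import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector.ObjectDetectorOptions;

final class ObjectDetectorUsageSketch {
  /** Runs single-image detection with the updated detect(Image) signature. */
  static ObjectDetectionResult detectOnce(Context context, Image image) {
    ObjectDetectorOptions options =
        ObjectDetectorOptions.builder()
            // Placeholder model path; substitute a real asset bundled with the app.
            .setBaseOptions(BaseOptions.builder().setModelAssetPath("model.tflite").build())
            .setRunningMode(RunningMode.IMAGE) // detect() requires the image mode.
            .build();
    try (ObjectDetector detector = ObjectDetector.createFromOptions(context, options)) {
      // The input stream name is no longer passed per call after this change.
      return detector.detect(image);
    }
  }
}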


@@ -11,7 +11,7 @@
         android:targetSdkVersion="30" />

     <application
-        android:label="facedetectiontest"
+        android:label="objectdetectortest"
         android:name="android.support.multidex.MultiDexApplication"
         android:taskAffinity="">
         <uses-library android:name="android.test.runner" />