Internal change

PiperOrigin-RevId: 480034669
This commit is contained in:
MediaPipe Team 2022-10-10 02:52:27 -07:00 committed by Copybara-Service
parent 1ab332835a
commit 62d2ae601e
5 changed files with 179 additions and 29 deletions

View File

@ -19,8 +19,12 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
android_library(
name = "core",
srcs = glob(["*.java"]),
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
deps = [
":libmediapipe_tasks_vision_jni_lib",
"//mediapipe/framework/formats:rect_java_proto_lite",
"//mediapipe/java/com/google/mediapipe/framework:android_framework_no_mff",
"//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
@ -36,6 +40,7 @@ cc_binary(
deps = [
"//mediapipe/calculators/core:flow_limiter_calculator",
"//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
"//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
],

View File

@ -14,101 +14,247 @@
package com.google.mediapipe.tasks.vision.core;
import android.graphics.RectF;
import com.google.mediapipe.formats.proto.RectProto.NormalizedRect;
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.ProtoUtil;
import com.google.mediapipe.framework.image.Image;
import com.google.mediapipe.tasks.core.TaskResult;
import com.google.mediapipe.tasks.core.TaskRunner;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
/** The base class of MediaPipe vision tasks. */
public class BaseVisionTaskApi implements AutoCloseable {
private static final long MICROSECONDS_PER_MILLISECOND = 1000;
private final TaskRunner runner;
private final RunningMode runningMode;
private final String imageStreamName;
private final Optional<String> normRectStreamName;
static {
System.loadLibrary("mediapipe_tasks_vision_jni");
ProtoUtil.registerTypeName(NormalizedRect.class, "mediapipe.NormalizedRect");
}
/**
* Constructor to initialize an {@link BaseVisionTaskApi} from a {@link TaskRunner} and a vision
* task {@link RunningMode}.
* Constructor to initialize a {@link BaseVisionTaskApi} only taking images as input.
*
* @param runner a {@link TaskRunner}.
* @param runningMode a mediapipe vision task {@link RunningMode}.
* @param imageStreamName the name of the input image stream.
*/
public BaseVisionTaskApi(TaskRunner runner, RunningMode runningMode) {
public BaseVisionTaskApi(TaskRunner runner, RunningMode runningMode, String imageStreamName) {
this.runner = runner;
this.runningMode = runningMode;
this.imageStreamName = imageStreamName;
this.normRectStreamName = Optional.empty();
}
/**
* Constructor to initialize a {@link BaseVisionTaskApi} taking images and normalized rects as
* input.
*
* @param runner a {@link TaskRunner}.
* @param runningMode a mediapipe vision task {@link RunningMode}.
* @param imageStreamName the name of the input image stream.
* @param normRectStreamName the name of the input normalized rect image stream.
*/
public BaseVisionTaskApi(
TaskRunner runner,
RunningMode runningMode,
String imageStreamName,
String normRectStreamName) {
this.runner = runner;
this.runningMode = runningMode;
this.imageStreamName = imageStreamName;
this.normRectStreamName = Optional.of(normRectStreamName);
}
/**
* A synchronous method to process single image inputs. The call blocks the current thread until a
* failure status or a successful result is returned.
*
* @param imageStreamName the image input stream name.
* @param image a MediaPipe {@link Image} object for processing.
* @throws MediaPipeException if the task is not in the image mode.
* @throws MediaPipeException if the task is not in the image mode or requires a normalized rect
* input.
*/
protected TaskResult processImageData(String imageStreamName, Image image) {
protected TaskResult processImageData(Image image) {
if (runningMode != RunningMode.IMAGE) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task is not initialized with the image mode. Current running mode:"
+ runningMode.name());
}
if (normRectStreamName.isPresent()) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task expects a normalized rect as input.");
}
Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
return runner.process(inputPackets);
}
/**
 * Synchronously runs the task on a single image restricted to a region of interest, blocking the
 * calling thread until a result or a failure status is produced.
 *
 * @param image a MediaPipe {@link Image} object for processing.
 * @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates
 *     are expected to be specified as normalized values in [0,1].
 * @throws MediaPipeException if the task is not in the image mode or doesn't require a normalized
 *     rect.
 */
protected TaskResult processImageData(Image image, RectF roi) {
  // Guard: this overload is only valid for IMAGE mode tasks configured with a rect stream.
  if (runningMode != RunningMode.IMAGE) {
    throw new MediaPipeException(
        MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
        "Task is not initialized with the image mode. Current running mode:"
            + runningMode.name());
  }
  if (!normRectStreamName.isPresent()) {
    throw new MediaPipeException(
        MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
        "Task doesn't expect a normalized rect as input.");
  }
  // Bundle the image and the ROI (converted to a NormalizedRect proto) into input packets.
  Map<String, Packet> packets = new HashMap<>();
  packets.put(imageStreamName, runner.getPacketCreator().createImage(image));
  packets.put(
      normRectStreamName.get(),
      runner.getPacketCreator().createProto(convertToNormalizedRect(roi)));
  return runner.process(packets);
}
/**
* A synchronous method to process continuous video frames. The call blocks the current thread
* until a failure status or a successful result is returned.
*
* @param imageStreamName the image input stream name.
* @param image a MediaPipe {@link Image} object for processing.
* @param timestampMs the corresponding timestamp of the input image in milliseconds.
* @throws MediaPipeException if the task is not in the video mode.
* @throws MediaPipeException if the task is not in the video mode or requires a normalized rect
* input.
*/
protected TaskResult processVideoData(String imageStreamName, Image image, long timestampMs) {
protected TaskResult processVideoData(Image image, long timestampMs) {
if (runningMode != RunningMode.VIDEO) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task is not initialized with the video mode. Current running mode:"
+ runningMode.name());
}
if (normRectStreamName.isPresent()) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task expects a normalized rect as input.");
}
Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
}
/**
 * Synchronously processes one video frame restricted to a region of interest, blocking the
 * calling thread until a result or a failure status is produced.
 *
 * @param image a MediaPipe {@link Image} object for processing.
 * @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates
 *     are expected to be specified as normalized values in [0,1].
 * @param timestampMs the corresponding timestamp of the input image in milliseconds.
 * @throws MediaPipeException if the task is not in the video mode or doesn't require a normalized
 *     rect.
 */
protected TaskResult processVideoData(Image image, RectF roi, long timestampMs) {
  // Guard: this overload is only valid for VIDEO mode tasks configured with a rect stream.
  if (runningMode != RunningMode.VIDEO) {
    throw new MediaPipeException(
        MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
        "Task is not initialized with the video mode. Current running mode:"
            + runningMode.name());
  }
  if (!normRectStreamName.isPresent()) {
    throw new MediaPipeException(
        MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
        "Task doesn't expect a normalized rect as input.");
  }
  // Bundle the frame and the ROI (converted to a NormalizedRect proto) into input packets.
  Map<String, Packet> packets = new HashMap<>();
  packets.put(imageStreamName, runner.getPacketCreator().createImage(image));
  packets.put(
      normRectStreamName.get(),
      runner.getPacketCreator().createProto(convertToNormalizedRect(roi)));
  // TaskRunner timestamps are in microseconds; callers pass milliseconds.
  return runner.process(packets, timestampMs * MICROSECONDS_PER_MILLISECOND);
}
/**
* An asynchronous method to send live stream data to the {@link TaskRunner}. The results will be
* available in the user-defined result listener.
*
* @param imageStreamName the image input stream name.
* @param image a MediaPipe {@link Image} object for processing.
* @param timestampMs the corresponding timestamp of the input image in milliseconds.
* @throws MediaPipeException if the task is not in the video mode.
* @throws MediaPipeException if the task is not in the live stream mode or requires a normalized
* rect input.
*/
protected void sendLiveStreamData(String imageStreamName, Image image, long timestampMs) {
protected void sendLiveStreamData(Image image, long timestampMs) {
if (runningMode != RunningMode.LIVE_STREAM) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task is not initialized with the live stream mode. Current running mode:"
+ runningMode.name());
}
if (normRectStreamName.isPresent()) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task expects a normalized rect as input.");
}
Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
}
/**
 * An asynchronous method to send live stream data to the {@link TaskRunner}. The results will be
 * available in the user-defined result listener.
 *
 * @param image a MediaPipe {@link Image} object for processing.
 * @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates
 *     are expected to be specified as normalized values in [0,1].
 * @param timestampMs the corresponding timestamp of the input image in milliseconds.
 * @throws MediaPipeException if the task is not in the live stream mode or doesn't require a
 *     normalized rect.
 */
protected void sendLiveStreamData(Image image, RectF roi, long timestampMs) {
if (runningMode != RunningMode.LIVE_STREAM) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task is not initialized with the live stream mode. Current running mode:"
+ runningMode.name());
}
if (!normRectStreamName.isPresent()) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task doesn't expect a normalized rect as input.");
}
// Pack the image and the ROI (as a NormalizedRect proto) and send them without blocking;
// timestamps are converted from milliseconds to the runner's microsecond clock.
Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
inputPackets.put(
normRectStreamName.get(),
runner.getPacketCreator().createProto(convertToNormalizedRect(roi)));
runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
}
/**
 * Closes and cleans up the MediaPipe vision task by closing the underlying {@link TaskRunner},
 * which releases its resources. The task must not be used after this call.
 */
@Override
public void close() {
runner.close();
}
/**
 * Converts a corner-based {@link RectF} into the center/size representation used by the
 * {@link NormalizedRect} protobuf message.
 */
private static NormalizedRect convertToNormalizedRect(RectF rect) {
  NormalizedRect.Builder builder = NormalizedRect.newBuilder();
  builder.setXCenter(rect.centerX());
  builder.setYCenter(rect.centerY());
  builder.setWidth(rect.width());
  builder.setHeight(rect.height());
  return builder.build();
}
}

View File

@ -38,7 +38,8 @@ public abstract class ObjectDetectionResult implements TaskResult {
* Creates an {@link ObjectDetectionResult} instance from a list of {@link Detection} protobuf
* messages.
*
* @param detectionList a list of {@link Detection} protobuf messages.
* @param detectionList a list of {@link DetectionOuterClass.Detection} protobuf messages.
* @param timestampMs a timestamp for this result.
*/
static ObjectDetectionResult create(List<Detection> detectionList, long timestampMs) {
List<com.google.mediapipe.tasks.components.containers.Detection> detections = new ArrayList<>();

View File

@ -155,7 +155,7 @@ public final class ObjectDetector extends BaseVisionTaskApi {
* Creates an {@link ObjectDetector} instance from an {@link ObjectDetectorOptions}.
*
* @param context an Android {@link Context}.
* @param detectorOptions a {@link ObjectDetectorOptions} instance.
* @param detectorOptions an {@link ObjectDetectorOptions} instance.
* @throws MediaPipeException if there is an error during {@link ObjectDetector} creation.
*/
public static ObjectDetector createFromOptions(
@ -192,7 +192,6 @@ public final class ObjectDetector extends BaseVisionTaskApi {
.setEnableFlowLimiting(detectorOptions.runningMode() == RunningMode.LIVE_STREAM)
.build(),
handler);
detectorOptions.errorListener().ifPresent(runner::setErrorListener);
return new ObjectDetector(runner, detectorOptions.runningMode());
}
@ -204,7 +203,7 @@ public final class ObjectDetector extends BaseVisionTaskApi {
* @param runningMode a mediapipe vision task {@link RunningMode}.
*/
private ObjectDetector(TaskRunner taskRunner, RunningMode runningMode) {
super(taskRunner, runningMode);
super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME);
}
/**
@ -221,7 +220,7 @@ public final class ObjectDetector extends BaseVisionTaskApi {
* @throws MediaPipeException if there is an internal error.
*/
public ObjectDetectionResult detect(Image inputImage) {
return (ObjectDetectionResult) processImageData(IMAGE_IN_STREAM_NAME, inputImage);
return (ObjectDetectionResult) processImageData(inputImage);
}
/**
@ -242,8 +241,7 @@ public final class ObjectDetector extends BaseVisionTaskApi {
* @throws MediaPipeException if there is an internal error.
*/
public ObjectDetectionResult detectForVideo(Image inputImage, long inputTimestampMs) {
return (ObjectDetectionResult)
processVideoData(IMAGE_IN_STREAM_NAME, inputImage, inputTimestampMs);
return (ObjectDetectionResult) processVideoData(inputImage, inputTimestampMs);
}
/**
@ -265,7 +263,7 @@ public final class ObjectDetector extends BaseVisionTaskApi {
* @throws MediaPipeException if there is an internal error.
*/
public void detectAsync(Image inputImage, long inputTimestampMs) {
sendLiveStreamData(IMAGE_IN_STREAM_NAME, inputImage, inputTimestampMs);
sendLiveStreamData(inputImage, inputTimestampMs);
}
/** Options for setting up an {@link ObjectDetector}. */
@ -275,12 +273,12 @@ public final class ObjectDetector extends BaseVisionTaskApi {
/** Builder for {@link ObjectDetectorOptions}. */
@AutoValue.Builder
public abstract static class Builder {
/** Sets the base options for the object detector task. */
/** Sets the {@link BaseOptions} for the object detector task. */
public abstract Builder setBaseOptions(BaseOptions value);
/**
* Sets the running mode for the object detector task. Default to the image mode. Object
* detector has three modes:
* Sets the {@link RunningMode} for the object detector task. Defaults to the image mode.
* Object detector has three modes:
*
* <ul>
* <li>IMAGE: The mode for detecting objects on single image inputs.
@ -293,8 +291,8 @@ public final class ObjectDetector extends BaseVisionTaskApi {
public abstract Builder setRunningMode(RunningMode value);
/**
* Sets the locale to use for display names specified through the TFLite Model Metadata, if
* any. Defaults to English.
* Sets the optional locale to use for display names specified through the TFLite Model
* Metadata, if any.
*/
public abstract Builder setDisplayNamesLocale(String value);
@ -331,12 +329,12 @@ public final class ObjectDetector extends BaseVisionTaskApi {
public abstract Builder setCategoryDenylist(List<String> value);
/**
* Sets the result listener to receive the detection results asynchronously when the object
* detector is in the live stream mode.
* Sets the {@link ResultListener} to receive the detection results asynchronously when the
* object detector is in the live stream mode.
*/
public abstract Builder setResultListener(ResultListener<ObjectDetectionResult, Image> value);
/** Sets an optional error listener. */
/** Sets an optional {@link ErrorListener}. */
public abstract Builder setErrorListener(ErrorListener value);
abstract ObjectDetectorOptions autoBuild();

View File

@ -11,7 +11,7 @@
android:targetSdkVersion="30" />
<application
android:label="facedetectiontest"
android:label="objectdetectortest"
android:name="android.support.multidex.MultiDexApplication"
android:taskAffinity="">
<uses-library android:name="android.test.runner" />