Add Pose Landmarker Java API
PiperOrigin-RevId: 524359521
This commit is contained in:
parent
3f1fc6f520
commit
dbeb5a8126
|
@ -54,6 +54,9 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [
|
|||
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarker_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_java_proto_lite",
|
||||
]
|
||||
|
||||
_TEXT_TASKS_JAVA_PROTO_LITE_TARGETS = [
|
||||
|
|
|
@ -54,6 +54,7 @@ cc_binary(
|
|||
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
|
||||
"//mediapipe/tasks/cc/vision/interactive_segmenter:interactive_segmenter_graph",
|
||||
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
|
||||
"//mediapipe/tasks/cc/vision/pose_landmarker:pose_landmarker_graph",
|
||||
"//mediapipe/tasks/java:version_script.lds",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
|
||||
],
|
||||
|
@ -174,6 +175,37 @@ android_library(
|
|||
],
|
||||
)
|
||||
|
||||
# Android library exposing the Pose Landmarker Java task API (PoseLandmarker and
# PoseLandmarkerResult) together with the proto and container targets it depends on.
android_library(
    name = "poselandmarker",
    srcs = [
        "poselandmarker/PoseLandmarker.java",
        "poselandmarker/PoseLandmarkerResult.java",
    ],
    javacopts = [
        "-Xep:AndroidJdkLibsChecker:OFF",
    ],
    manifest = "poselandmarker/AndroidManifest.xml",
    deps = [
        ":core",
        "//mediapipe/framework:calculator_options_java_proto_lite",
        "//mediapipe/framework/formats:classification_java_proto_lite",
        "//mediapipe/framework/formats:landmark_java_proto_lite",
        "//mediapipe/java/com/google/mediapipe/framework:android_framework",
        "//mediapipe/java/com/google/mediapipe/framework/image",
        "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
        "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
        "//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_java_proto_lite",
        "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarker_graph_options_java_proto_lite",
        "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_java_proto_lite",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:normalized_landmark",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
        "//third_party:autovalue",
        "@maven//:androidx_annotation_annotation",
        "@maven//:com_google_guava_guava",
    ],
)
|
||||
|
||||
android_library(
|
||||
name = "handlandmarker",
|
||||
srcs = [
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
package="com.google.mediapipe.tasks.vision.poselandmarker">
|
||||
|
||||
<uses-sdk android:minSdkVersion="24"
|
||||
android:targetSdkVersion="30" />
|
||||
|
||||
</manifest>
|
|
@ -0,0 +1,557 @@
|
|||
// Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.vision.poselandmarker;
|
||||
|
||||
import android.content.Context;
|
||||
import android.os.ParcelFileDescriptor;
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList;
|
||||
import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList;
|
||||
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
|
||||
import com.google.mediapipe.framework.AndroidPacketGetter;
|
||||
import com.google.mediapipe.framework.MediaPipeException;
|
||||
import com.google.mediapipe.framework.Packet;
|
||||
import com.google.mediapipe.framework.PacketGetter;
|
||||
import com.google.mediapipe.framework.image.BitmapImageBuilder;
|
||||
import com.google.mediapipe.framework.image.ByteBufferImageBuilder;
|
||||
import com.google.mediapipe.framework.image.MPImage;
|
||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||
import com.google.mediapipe.tasks.core.ErrorListener;
|
||||
import com.google.mediapipe.tasks.core.OutputHandler;
|
||||
import com.google.mediapipe.tasks.core.OutputHandler.ResultListener;
|
||||
import com.google.mediapipe.tasks.core.TaskInfo;
|
||||
import com.google.mediapipe.tasks.core.TaskOptions;
|
||||
import com.google.mediapipe.tasks.core.TaskRunner;
|
||||
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
|
||||
import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
|
||||
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
|
||||
import com.google.mediapipe.tasks.vision.core.RunningMode;
|
||||
import com.google.mediapipe.tasks.vision.posedetector.proto.PoseDetectorGraphOptionsProto;
|
||||
import com.google.mediapipe.tasks.vision.poselandmarker.proto.PoseLandmarkerGraphOptionsProto;
|
||||
import com.google.mediapipe.tasks.vision.poselandmarker.proto.PoseLandmarksDetectorGraphOptionsProto;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
/**
|
||||
* Performs pose landmarks detection on images.
|
||||
*
* <p>This API expects a pre-trained pose landmarks model asset bundle. TODO: link to the DevSite
* documentation page.
|
||||
*
|
||||
* <ul>
|
||||
* <li>Input image {@link MPImage}
|
||||
* <ul>
|
||||
* <li>The image that pose landmarks detection runs on.
|
||||
* </ul>
|
||||
* <li>Output PoseLandmarkerResult {@link PoseLandmarkerResult}
|
||||
* <ul>
|
||||
* <li>A PoseLandmarkerResult containing pose landmarks.
|
||||
* </ul>
|
||||
* </ul>
|
||||
*/
|
||||
public final class PoseLandmarker extends BaseVisionTaskApi {
|
||||
private static final String TAG = PoseLandmarker.class.getSimpleName();
|
||||
private static final String IMAGE_IN_STREAM_NAME = "image_in";
|
||||
private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in";
|
||||
|
||||
@SuppressWarnings("ConstantCaseForConstants")
|
||||
private static final List<String> INPUT_STREAMS =
|
||||
Collections.unmodifiableList(
|
||||
Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
|
||||
|
||||
private static final int LANDMARKS_OUT_STREAM_INDEX = 0;
|
||||
private static final int WORLD_LANDMARKS_OUT_STREAM_INDEX = 1;
|
||||
private static final int AUXILIARY_LANDMARKS_OUT_STREAM_INDEX = 2;
|
||||
private static final int IMAGE_OUT_STREAM_INDEX = 3;
|
||||
private static int segmentationMasksOutStreamIndex = -1;
|
||||
private static final String TASK_GRAPH_NAME =
|
||||
"mediapipe.tasks.vision.pose_landmarker.PoseLandmarkerGraph";
|
||||
|
||||
/**
|
||||
* Creates a {@link PoseLandmarker} instance from a model file and the default {@link
|
||||
* PoseLandmarkerOptions}.
|
||||
*
|
||||
* @param context an Android {@link Context}.
|
||||
* @param modelPath path to the pose landmarks model with metadata in the assets.
|
||||
* @throws MediaPipeException if there is an error during {@link PoseLandmarker} creation.
|
||||
*/
|
||||
public static PoseLandmarker createFromFile(Context context, String modelPath) {
|
||||
BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build();
|
||||
return createFromOptions(
|
||||
context, PoseLandmarkerOptions.builder().setBaseOptions(baseOptions).build());
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a {@link PoseLandmarker} instance from a model file and the default {@link
|
||||
* PoseLandmarkerOptions}.
|
||||
*
|
||||
* @param context an Android {@link Context}.
|
||||
* @param modelFile the pose landmarks model {@link File} instance.
|
||||
* @throws IOException if an I/O error occurs when opening the tflite model file.
|
||||
* @throws MediaPipeException if there is an error during {@link PoseLandmarker} creation.
|
||||
*/
|
||||
public static PoseLandmarker createFromFile(Context context, File modelFile) throws IOException {
|
||||
try (ParcelFileDescriptor descriptor =
|
||||
ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
|
||||
BaseOptions baseOptions =
|
||||
BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build();
|
||||
return createFromOptions(
|
||||
context, PoseLandmarkerOptions.builder().setBaseOptions(baseOptions).build());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a {@link PoseLandmarker} instance from a model buffer and the default {@link
|
||||
* PoseLandmarkerOptions}.
|
||||
*
|
||||
* @param context an Android {@link Context}.
|
||||
* @param modelBuffer a direct {@link ByteBuffer} or a {@link java.nio.MappedByteBuffer} of the detection
|
||||
* model.
|
||||
* @throws MediaPipeException if there is an error during {@link PoseLandmarker} creation.
|
||||
*/
|
||||
public static PoseLandmarker createFromBuffer(Context context, final ByteBuffer modelBuffer) {
|
||||
BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelBuffer).build();
|
||||
return createFromOptions(
|
||||
context, PoseLandmarkerOptions.builder().setBaseOptions(baseOptions).build());
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a {@link PoseLandmarker} instance from a {@link PoseLandmarkerOptions}.
|
||||
*
|
||||
* @param context an Android {@link Context}.
|
||||
* @param landmarkerOptions a {@link PoseLandmarkerOptions} instance.
|
||||
* @throws MediaPipeException if there is an error during {@link PoseLandmarker} creation.
|
||||
*/
|
||||
public static PoseLandmarker createFromOptions(
|
||||
Context context, PoseLandmarkerOptions landmarkerOptions) {
|
||||
List<String> outputStreams = new ArrayList<>();
|
||||
outputStreams.add("NORM_LANDMARKS:pose_landmarks");
|
||||
outputStreams.add("WORLD_LANDMARKS:world_landmarks");
|
||||
outputStreams.add("AUXILIARY_LANDMARKS:auxiliary_landmarks");
|
||||
outputStreams.add("IMAGE:image_out");
|
||||
if (landmarkerOptions.outputSegmentationMasks()) {
|
||||
outputStreams.add("SEGMENTATION_MASK:segmentation_masks");
|
||||
segmentationMasksOutStreamIndex = outputStreams.size() - 1;
|
||||
}
|
||||
|
||||
// TODO: Consolidate OutputHandler and TaskRunner.
|
||||
OutputHandler<PoseLandmarkerResult, MPImage> handler = new OutputHandler<>();
|
||||
handler.setOutputPacketConverter(
|
||||
new OutputHandler.OutputPacketConverter<PoseLandmarkerResult, MPImage>() {
|
||||
@Override
|
||||
public PoseLandmarkerResult convertToTaskResult(List<Packet> packets) {
|
||||
// If there is no poses detected in the image, just returns empty lists.
|
||||
if (packets.get(LANDMARKS_OUT_STREAM_INDEX).isEmpty()) {
|
||||
return PoseLandmarkerResult.create(
|
||||
new ArrayList<>(),
|
||||
new ArrayList<>(),
|
||||
new ArrayList<>(),
|
||||
Optional.empty(),
|
||||
BaseVisionTaskApi.generateResultTimestampMs(
|
||||
landmarkerOptions.runningMode(), packets.get(LANDMARKS_OUT_STREAM_INDEX)));
|
||||
}
|
||||
/** Get segmentation masks */
|
||||
Optional<List<MPImage>> segmentedMasks = Optional.empty();
|
||||
if (landmarkerOptions.outputSegmentationMasks()) {
|
||||
segmentedMasks = getSegmentationMasks(packets);
|
||||
}
|
||||
|
||||
return PoseLandmarkerResult.create(
|
||||
PacketGetter.getProtoVector(
|
||||
packets.get(LANDMARKS_OUT_STREAM_INDEX), NormalizedLandmarkList.parser()),
|
||||
PacketGetter.getProtoVector(
|
||||
packets.get(WORLD_LANDMARKS_OUT_STREAM_INDEX), LandmarkList.parser()),
|
||||
PacketGetter.getProtoVector(
|
||||
packets.get(AUXILIARY_LANDMARKS_OUT_STREAM_INDEX),
|
||||
NormalizedLandmarkList.parser()),
|
||||
segmentedMasks,
|
||||
BaseVisionTaskApi.generateResultTimestampMs(
|
||||
landmarkerOptions.runningMode(), packets.get(LANDMARKS_OUT_STREAM_INDEX)));
|
||||
}
|
||||
|
||||
@Override
|
||||
public MPImage convertToTaskInput(List<Packet> packets) {
|
||||
return new BitmapImageBuilder(
|
||||
AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
|
||||
.build();
|
||||
}
|
||||
});
|
||||
landmarkerOptions.resultListener().ifPresent(handler::setResultListener);
|
||||
landmarkerOptions.errorListener().ifPresent(handler::setErrorListener);
|
||||
TaskRunner runner =
|
||||
TaskRunner.create(
|
||||
context,
|
||||
TaskInfo.<PoseLandmarkerOptions>builder()
|
||||
.setTaskName(PoseLandmarker.class.getSimpleName())
|
||||
.setTaskRunningModeName(landmarkerOptions.runningMode().name())
|
||||
.setTaskGraphName(TASK_GRAPH_NAME)
|
||||
.setInputStreams(INPUT_STREAMS)
|
||||
.setOutputStreams(outputStreams)
|
||||
.setTaskOptions(landmarkerOptions)
|
||||
.setEnableFlowLimiting(landmarkerOptions.runningMode() == RunningMode.LIVE_STREAM)
|
||||
.build(),
|
||||
handler);
|
||||
return new PoseLandmarker(runner, landmarkerOptions.runningMode());
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructor to initialize a {@link PoseLandmarker} from a {@link TaskRunner} and a {@link
|
||||
* RunningMode}.
|
||||
*
|
||||
* @param taskRunner a {@link TaskRunner}.
|
||||
* @param runningMode a mediapipe vision task {@link RunningMode}.
|
||||
*/
|
||||
private PoseLandmarker(TaskRunner taskRunner, RunningMode runningMode) {
  // Pass the runner, the mode and the names of the two graph input streams to the base class.
  super(
      taskRunner,
      runningMode,
      IMAGE_IN_STREAM_NAME,
      NORM_RECT_IN_STREAM_NAME);
}
|
||||
|
||||
/**
|
||||
* Performs pose landmarks detection on the provided single image with default image processing
|
||||
* options, i.e. without any rotation applied. Only use this method when the {@link
|
||||
* PoseLandmarker} is created with {@link RunningMode#IMAGE}. TODO update java doc
|
||||
* for input image format.
|
||||
*
|
||||
* <p>{@link PoseLandmarker} supports the following color space types:
|
||||
*
* <ul>
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public PoseLandmarkerResult detect(MPImage image) {
|
||||
return detect(image, ImageProcessingOptions.builder().build());
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs pose landmarks detection on the provided single image. Only use this method when the
|
||||
* {@link PoseLandmarker} is created with {@link RunningMode#IMAGE}. TODO update java
|
||||
* doc for input image format.
|
||||
*
|
||||
* <p>{@link PoseLandmarker} supports the following color space types:
|
||||
*
* <ul>
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
||||
* input image before running inference. Note that region-of-interest is <b>not</b> supported
|
||||
* by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
|
||||
* this method throwing an IllegalArgumentException.
|
||||
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
|
||||
* region-of-interest.
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public PoseLandmarkerResult detect(MPImage image, ImageProcessingOptions imageProcessingOptions) {
|
||||
validateImageProcessingOptions(imageProcessingOptions);
|
||||
return (PoseLandmarkerResult) processImageData(image, imageProcessingOptions);
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs pose landmarks detection on the provided video frame with default image processing
|
||||
* options, i.e. without any rotation applied. Only use this method when the {@link
|
||||
* PoseLandmarker} is created with {@link RunningMode#VIDEO}.
|
||||
*
|
||||
* <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
|
||||
* must be monotonically increasing.
|
||||
*
|
||||
* <p>{@link PoseLandmarker} supports the following color space types:
|
||||
*
* <ul>
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @param timestampMs the input timestamp (in milliseconds).
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public PoseLandmarkerResult detectForVideo(MPImage image, long timestampMs) {
|
||||
return detectForVideo(image, ImageProcessingOptions.builder().build(), timestampMs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs pose landmarks detection on the provided video frame. Only use this method when the
|
||||
* {@link PoseLandmarker} is created with {@link RunningMode#VIDEO}.
|
||||
*
|
||||
* <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
|
||||
* must be monotonically increasing.
|
||||
*
|
||||
* <p>{@link PoseLandmarker} supports the following color space types:
|
||||
*
* <ul>
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
||||
* input image before running inference. Note that region-of-interest is <b>not</b> supported
|
||||
* by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
|
||||
* this method throwing an IllegalArgumentException.
|
||||
* @param timestampMs the input timestamp (in milliseconds).
|
||||
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
|
||||
* region-of-interest.
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public PoseLandmarkerResult detectForVideo(
|
||||
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
|
||||
validateImageProcessingOptions(imageProcessingOptions);
|
||||
return (PoseLandmarkerResult) processVideoData(image, imageProcessingOptions, timestampMs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sends live image data to perform pose landmarks detection with default image processing
|
||||
* options, i.e. without any rotation applied, and the results will be available via the {@link
|
||||
* ResultListener} provided in the {@link PoseLandmarkerOptions}. Only use this method when the
|
||||
* {@link PoseLandmarker} is created with {@link RunningMode#LIVE_STREAM}.
|
||||
*
|
||||
* <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
|
||||
* sent to the pose landmarker. The input timestamps must be monotonically increasing.
|
||||
*
|
||||
* <p>{@link PoseLandmarker} supports the following color space types:
|
||||
*
* <ul>
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @param timestampMs the input timestamp (in milliseconds).
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public void detectAsync(MPImage image, long timestampMs) {
|
||||
detectAsync(image, ImageProcessingOptions.builder().build(), timestampMs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sends live image data to perform pose landmarks detection, and the results will be available
|
||||
* via the {@link ResultListener} provided in the {@link PoseLandmarkerOptions}. Only use this
|
||||
* method when the {@link PoseLandmarker} is created with {@link RunningMode#LIVE_STREAM}.
|
||||
*
|
||||
* <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
|
||||
* sent to the pose landmarker. The input timestamps must be monotonically increasing.
|
||||
*
|
||||
* <p>{@link PoseLandmarker} supports the following color space types:
|
||||
*
* <ul>
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
||||
* input image before running inference. Note that region-of-interest is <b>not</b> supported
|
||||
* by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
|
||||
* this method throwing an IllegalArgumentException.
|
||||
* @param timestampMs the input timestamp (in milliseconds).
|
||||
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
|
||||
* region-of-interest.
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public void detectAsync(
|
||||
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
|
||||
validateImageProcessingOptions(imageProcessingOptions);
|
||||
sendLiveStreamData(image, imageProcessingOptions, timestampMs);
|
||||
}
|
||||
|
||||
/** Options for setting up an {@link PoseLandmarker}. */
|
||||
@AutoValue
|
||||
public abstract static class PoseLandmarkerOptions extends TaskOptions {
|
||||
|
||||
/** Builder for {@link PoseLandmarkerOptions}. */
|
||||
@AutoValue.Builder
|
||||
public abstract static class Builder {
|
||||
/** Sets the base options for the pose landmarker task. */
|
||||
public abstract Builder setBaseOptions(BaseOptions value);
|
||||
|
||||
/**
|
||||
* Sets the running mode for the pose landmarker task. Default to the image mode. Pose
|
||||
* landmarker has three modes:
|
||||
*
|
||||
* <ul>
|
||||
* <li>IMAGE: The mode for detecting pose landmarks on single image inputs.
|
||||
* <li>VIDEO: The mode for detecting pose landmarks on the decoded frames of a video.
|
||||
* <li>LIVE_STREAM: The mode for for detecting pose landmarks on a live stream of input
|
||||
* data, such as from camera. In this mode, {@code setResultListener} must be called to
|
||||
* set up a listener to receive the detection results asynchronously.
|
||||
* </ul>
|
||||
*/
|
||||
public abstract Builder setRunningMode(RunningMode value);
|
||||
|
||||
/** Sets the maximum number of poses can be detected by the PoseLandmarker. */
|
||||
public abstract Builder setNumPoses(Integer value);
|
||||
|
||||
/** Sets minimum confidence score for the pose detection to be considered successful */
|
||||
public abstract Builder setMinPoseDetectionConfidence(Float value);
|
||||
|
||||
/** Sets minimum confidence score of pose presence score in the pose landmark detection. */
|
||||
public abstract Builder setMinPosePresenceConfidence(Float value);
|
||||
|
||||
/** Sets the minimum confidence score for the pose tracking to be considered successful. */
|
||||
public abstract Builder setMinTrackingConfidence(Float value);
|
||||
|
||||
public abstract Builder setOutputSegmentationMasks(Boolean value);
|
||||
|
||||
/**
|
||||
* Sets the result listener to receive the detection results asynchronously when the pose
|
||||
* landmarker is in the live stream mode.
|
||||
*/
|
||||
public abstract Builder setResultListener(
|
||||
ResultListener<PoseLandmarkerResult, MPImage> value);
|
||||
|
||||
/** Sets an optional error listener. */
|
||||
public abstract Builder setErrorListener(ErrorListener value);
|
||||
|
||||
abstract PoseLandmarkerOptions autoBuild();
|
||||
|
||||
/**
|
||||
* Validates and builds the {@link PoseLandmarkerOptions} instance.
|
||||
*
|
||||
* @throws IllegalArgumentException if the result listener and the running mode are not
|
||||
* properly configured. The result listener should only be set when the pose landmarker is
|
||||
* in the live stream mode.
|
||||
*/
|
||||
public final PoseLandmarkerOptions build() {
|
||||
PoseLandmarkerOptions options = autoBuild();
|
||||
if (options.runningMode() == RunningMode.LIVE_STREAM) {
|
||||
if (!options.resultListener().isPresent()) {
|
||||
throw new IllegalArgumentException(
|
||||
"The pose landmarker is in the live stream mode, a user-defined result listener"
|
||||
+ " must be provided in PoseLandmarkerOptions.");
|
||||
}
|
||||
} else if (options.resultListener().isPresent()) {
|
||||
throw new IllegalArgumentException(
|
||||
"The pose landmarker is in the image or the video mode, a user-defined result"
|
||||
+ " listener shouldn't be provided in PoseLandmarkerOptions.");
|
||||
}
|
||||
return options;
|
||||
}
|
||||
}
|
||||
|
||||
abstract BaseOptions baseOptions();
|
||||
|
||||
abstract RunningMode runningMode();
|
||||
|
||||
abstract Optional<Integer> numPoses();
|
||||
|
||||
abstract Optional<Float> minPoseDetectionConfidence();
|
||||
|
||||
abstract Optional<Float> minPosePresenceConfidence();
|
||||
|
||||
abstract Optional<Float> minTrackingConfidence();
|
||||
|
||||
abstract Boolean outputSegmentationMasks();
|
||||
|
||||
abstract Optional<ResultListener<PoseLandmarkerResult, MPImage>> resultListener();
|
||||
|
||||
abstract Optional<ErrorListener> errorListener();
|
||||
|
||||
public static Builder builder() {
|
||||
return new AutoValue_PoseLandmarker_PoseLandmarkerOptions.Builder()
|
||||
.setRunningMode(RunningMode.IMAGE)
|
||||
.setNumPoses(1)
|
||||
.setMinPoseDetectionConfidence(0.5f)
|
||||
.setMinPosePresenceConfidence(0.5f)
|
||||
.setMinTrackingConfidence(0.5f)
|
||||
.setOutputSegmentationMasks(false);
|
||||
}
|
||||
|
||||
/** Converts a {@link PoseLandmarkerOptions} to a {@link CalculatorOptions} protobuf message. */
|
||||
@Override
|
||||
public CalculatorOptions convertToCalculatorOptionsProto() {
|
||||
PoseLandmarkerGraphOptionsProto.PoseLandmarkerGraphOptions.Builder taskOptionsBuilder =
|
||||
PoseLandmarkerGraphOptionsProto.PoseLandmarkerGraphOptions.newBuilder()
|
||||
.setBaseOptions(
|
||||
BaseOptionsProto.BaseOptions.newBuilder()
|
||||
.setUseStreamMode(runningMode() != RunningMode.IMAGE)
|
||||
.mergeFrom(convertBaseOptionsToProto(baseOptions()))
|
||||
.build());
|
||||
|
||||
// Setup PoseDetectorGraphOptions.
|
||||
PoseDetectorGraphOptionsProto.PoseDetectorGraphOptions.Builder
|
||||
poseDetectorGraphOptionsBuilder =
|
||||
PoseDetectorGraphOptionsProto.PoseDetectorGraphOptions.newBuilder();
|
||||
numPoses().ifPresent(poseDetectorGraphOptionsBuilder::setNumPoses);
|
||||
minPoseDetectionConfidence()
|
||||
.ifPresent(poseDetectorGraphOptionsBuilder::setMinDetectionConfidence);
|
||||
|
||||
// Setup PoseLandmarkerGraphOptions.
|
||||
PoseLandmarksDetectorGraphOptionsProto.PoseLandmarksDetectorGraphOptions.Builder
|
||||
poseLandmarksDetectorGraphOptionsBuilder =
|
||||
PoseLandmarksDetectorGraphOptionsProto.PoseLandmarksDetectorGraphOptions.newBuilder();
|
||||
minPosePresenceConfidence()
|
||||
.ifPresent(poseLandmarksDetectorGraphOptionsBuilder::setMinDetectionConfidence);
|
||||
minTrackingConfidence().ifPresent(taskOptionsBuilder::setMinTrackingConfidence);
|
||||
|
||||
taskOptionsBuilder
|
||||
.setPoseDetectorGraphOptions(poseDetectorGraphOptionsBuilder.build())
|
||||
.setPoseLandmarksDetectorGraphOptions(poseLandmarksDetectorGraphOptionsBuilder.build());
|
||||
|
||||
return CalculatorOptions.newBuilder()
|
||||
.setExtension(
|
||||
PoseLandmarkerGraphOptionsProto.PoseLandmarkerGraphOptions.ext,
|
||||
taskOptionsBuilder.build())
|
||||
.build();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates that the provided {@link ImageProcessingOptions} doesn't contain a
|
||||
* region-of-interest.
|
||||
*/
|
||||
private static void validateImageProcessingOptions(
|
||||
ImageProcessingOptions imageProcessingOptions) {
|
||||
if (imageProcessingOptions.regionOfInterest().isPresent()) {
|
||||
throw new IllegalArgumentException("PoseLandmarker doesn't support region-of-interest.");
|
||||
}
|
||||
}
|
||||
|
||||
private static Optional<List<MPImage>> getSegmentationMasks(List<Packet> packets) {
|
||||
Optional<List<MPImage>> segmentedMasks = Optional.of(new ArrayList<>());
|
||||
int width =
|
||||
PacketGetter.getImageWidthFromImageList(packets.get(segmentationMasksOutStreamIndex));
|
||||
int height =
|
||||
PacketGetter.getImageHeightFromImageList(packets.get(segmentationMasksOutStreamIndex));
|
||||
int imageListSize = PacketGetter.getImageListSize(packets.get(segmentationMasksOutStreamIndex));
|
||||
ByteBuffer[] buffersArray = new ByteBuffer[imageListSize];
|
||||
|
||||
// Segmentation mask is a float type image.
|
||||
int numBytes = 4;
|
||||
for (int i = 0; i < imageListSize; i++) {
|
||||
buffersArray[i] = ByteBuffer.allocateDirect(width * height * numBytes);
|
||||
}
|
||||
|
||||
if (!PacketGetter.getImageList(
|
||||
packets.get(segmentationMasksOutStreamIndex),
|
||||
buffersArray,
|
||||
/** deepCopy= */
|
||||
true)) {
|
||||
throw new MediaPipeException(
|
||||
MediaPipeException.StatusCode.INTERNAL.ordinal(),
|
||||
"There is an error getting segmented masks.");
|
||||
}
|
||||
for (ByteBuffer buffer : buffersArray) {
|
||||
ByteBufferImageBuilder builder =
|
||||
new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_VEC32F1);
|
||||
segmentedMasks.get().add(builder.build());
|
||||
}
|
||||
return segmentedMasks;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,113 @@
|
|||
// Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.vision.poselandmarker;
|
||||
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.formats.proto.LandmarkProto;
|
||||
import com.google.mediapipe.framework.image.MPImage;
|
||||
import com.google.mediapipe.tasks.components.containers.Landmark;
|
||||
import com.google.mediapipe.tasks.components.containers.NormalizedLandmark;
|
||||
import com.google.mediapipe.tasks.core.TaskResult;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
/** Represents the pose landmarks detection results generated by {@link PoseLandmarker}. */
|
||||
@AutoValue
|
||||
public abstract class PoseLandmarkerResult implements TaskResult {
|
||||
|
||||
/**
|
||||
* Creates a {@link PoseLandmarkerResult} instance from the lists of landmarks and
|
||||
* segmentationMask protobuf messages.
|
||||
*
|
||||
* @param landmarksProto a List of {@link NormalizedLandmarkList}
|
||||
* @param worldLandmarksProto a List of {@link LandmarkList}
* @param auxiliaryLandmarksProto a List of {@link NormalizedLandmarkList}
|
||||
* @param segmentationMasksData an optional List of {@link MPImage}
* @param timestampMs a timestamp in milliseconds
|
||||
*/
|
||||
static PoseLandmarkerResult create(
|
||||
List<LandmarkProto.NormalizedLandmarkList> landmarksProto,
|
||||
List<LandmarkProto.LandmarkList> worldLandmarksProto,
|
||||
List<LandmarkProto.NormalizedLandmarkList> auxiliaryLandmarksProto,
|
||||
Optional<List<MPImage>> segmentationMasksData,
|
||||
long timestampMs) {
|
||||
|
||||
Optional<List<MPImage>> multiPoseSegmentationMasks = Optional.empty();
|
||||
if (segmentationMasksData.isPresent()) {
|
||||
multiPoseSegmentationMasks =
|
||||
Optional.of(Collections.unmodifiableList(segmentationMasksData.get()));
|
||||
}
|
||||
|
||||
List<List<NormalizedLandmark>> multiPoseLandmarks = new ArrayList<>();
|
||||
List<List<Landmark>> multiPoseWorldLandmarks = new ArrayList<>();
|
||||
List<List<NormalizedLandmark>> multiPoseAuxiliaryLandmarks = new ArrayList<>();
|
||||
for (LandmarkProto.NormalizedLandmarkList poseLandmarksProto : landmarksProto) {
|
||||
List<NormalizedLandmark> poseLandmarks = new ArrayList<>();
|
||||
multiPoseLandmarks.add(poseLandmarks);
|
||||
for (LandmarkProto.NormalizedLandmark poseLandmarkProto :
|
||||
poseLandmarksProto.getLandmarkList()) {
|
||||
poseLandmarks.add(
|
||||
NormalizedLandmark.create(
|
||||
poseLandmarkProto.getX(), poseLandmarkProto.getY(), poseLandmarkProto.getZ()));
|
||||
}
|
||||
}
|
||||
for (LandmarkProto.LandmarkList poseWorldLandmarksProto : worldLandmarksProto) {
|
||||
List<Landmark> poseWorldLandmarks = new ArrayList<>();
|
||||
multiPoseWorldLandmarks.add(poseWorldLandmarks);
|
||||
for (LandmarkProto.Landmark poseWorldLandmarkProto :
|
||||
poseWorldLandmarksProto.getLandmarkList()) {
|
||||
poseWorldLandmarks.add(
|
||||
Landmark.create(
|
||||
poseWorldLandmarkProto.getX(),
|
||||
poseWorldLandmarkProto.getY(),
|
||||
poseWorldLandmarkProto.getZ()));
|
||||
}
|
||||
}
|
||||
for (LandmarkProto.NormalizedLandmarkList poseAuxiliaryLandmarksProto :
|
||||
auxiliaryLandmarksProto) {
|
||||
List<NormalizedLandmark> poseAuxiliaryLandmarks = new ArrayList<>();
|
||||
multiPoseAuxiliaryLandmarks.add(poseAuxiliaryLandmarks);
|
||||
for (LandmarkProto.NormalizedLandmark poseAuxiliaryLandmarkProto :
|
||||
poseAuxiliaryLandmarksProto.getLandmarkList()) {
|
||||
poseAuxiliaryLandmarks.add(
|
||||
NormalizedLandmark.create(
|
||||
poseAuxiliaryLandmarkProto.getX(),
|
||||
poseAuxiliaryLandmarkProto.getY(),
|
||||
poseAuxiliaryLandmarkProto.getZ()));
|
||||
}
|
||||
}
|
||||
return new AutoValue_PoseLandmarkerResult(
|
||||
timestampMs,
|
||||
Collections.unmodifiableList(multiPoseLandmarks),
|
||||
Collections.unmodifiableList(multiPoseWorldLandmarks),
|
||||
Collections.unmodifiableList(multiPoseAuxiliaryLandmarks),
|
||||
multiPoseSegmentationMasks);
|
||||
}
|
||||
|
||||
@Override
|
||||
public abstract long timestampMs();
|
||||
|
||||
/** Pose landmarks of detected poses. */
|
||||
public abstract List<List<NormalizedLandmark>> landmarks();
|
||||
|
||||
/** Pose landmarks in world coordinates of detected poses. */
|
||||
public abstract List<List<Landmark>> worldLandmarks();
|
||||
|
||||
/** Pose auxiliary landmarks. */
|
||||
public abstract List<List<NormalizedLandmark>> auxiliaryLandmarks();
|
||||
|
||||
/** Pose segmentation masks. */
|
||||
public abstract Optional<List<MPImage>> segmentationMasks();
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
package="com.google.mediapipe.tasks.vision.poselandmarkertest"
|
||||
android:versionCode="1"
|
||||
android:versionName="1.0" >
|
||||
|
||||
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
|
||||
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
|
||||
|
||||
<uses-sdk android:minSdkVersion="24"
|
||||
android:targetSdkVersion="30" />
|
||||
|
||||
<application
|
||||
android:label="poselandmarkertest"
|
||||
android:name="android.support.multidex.MultiDexApplication"
|
||||
android:taskAffinity="">
|
||||
<uses-library android:name="android.test.runner" />
|
||||
</application>
|
||||
|
||||
<instrumentation
|
||||
android:name="com.google.android.apps.common.testing.testrunner.GoogleInstrumentationTestRunner"
|
||||
android:targetPackage="com.google.mediapipe.tasks.vision.poselandmarkertest" />
|
||||
|
||||
</manifest>
|
|
@ -0,0 +1,19 @@
|
|||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
# TODO: Enable this in OSS
|
|
@ -0,0 +1,365 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.vision.poselandmarker;
|
||||
|
||||
import static com.google.common.truth.Truth.assertThat;
|
||||
import static org.junit.Assert.assertThrows;
|
||||
|
||||
import android.content.res.AssetManager;
|
||||
import android.graphics.BitmapFactory;
|
||||
import android.graphics.RectF;
|
||||
import androidx.test.core.app.ApplicationProvider;
|
||||
import androidx.test.ext.junit.runners.AndroidJUnit4;
|
||||
import com.google.common.truth.Correspondence;
|
||||
import com.google.mediapipe.framework.MediaPipeException;
|
||||
import com.google.mediapipe.framework.image.BitmapImageBuilder;
|
||||
import com.google.mediapipe.framework.image.MPImage;
|
||||
import com.google.mediapipe.tasks.components.containers.NormalizedLandmark;
|
||||
import com.google.mediapipe.tasks.components.containers.proto.LandmarksDetectionResultProto.LandmarksDetectionResult;
|
||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
|
||||
import com.google.mediapipe.tasks.vision.core.RunningMode;
|
||||
import com.google.mediapipe.tasks.vision.poselandmarker.PoseLandmarker.PoseLandmarkerOptions;
|
||||
import java.io.InputStream;
|
||||
import java.util.Arrays;
|
||||
import java.util.Optional;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Suite;
|
||||
import org.junit.runners.Suite.SuiteClasses;
|
||||
|
||||
/** Test for {@link PoseLandmarker}. */
|
||||
@RunWith(Suite.class)
|
||||
@SuiteClasses({PoseLandmarkerTest.General.class, PoseLandmarkerTest.RunningModeTest.class})
|
||||
public class PoseLandmarkerTest {
|
||||
private static final String POSE_LANDMARKER_BUNDLE_ASSET_FILE = "pose_landmarker.task";
|
||||
private static final String POSE_IMAGE = "pose.jpg";
|
||||
private static final String POSE_LANDMARKS = "pose_landmarks.pb";
|
||||
private static final String NO_POSES_IMAGE = "burger.jpg";
|
||||
private static final String TAG = "Pose Landmarker Test";
|
||||
private static final float LANDMARKS_ERROR_TOLERANCE = 0.03f;
|
||||
private static final int IMAGE_WIDTH = 1000;
|
||||
private static final int IMAGE_HEIGHT = 667;
|
||||
|
||||
@RunWith(AndroidJUnit4.class)
|
||||
public static final class General extends PoseLandmarkerTest {
|
||||
|
||||
@Test
|
||||
public void detect_successWithValidModels() throws Exception {
|
||||
PoseLandmarkerOptions options =
|
||||
PoseLandmarkerOptions.builder()
|
||||
.setBaseOptions(
|
||||
BaseOptions.builder()
|
||||
.setModelAssetPath(POSE_LANDMARKER_BUNDLE_ASSET_FILE)
|
||||
.build())
|
||||
.build();
|
||||
PoseLandmarker poseLandmarker =
|
||||
PoseLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
PoseLandmarkerResult actualResult = poseLandmarker.detect(getImageFromAsset(POSE_IMAGE));
|
||||
PoseLandmarkerResult expectedResult = getExpectedPoseLandmarkerResult(POSE_LANDMARKS);
|
||||
assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void detect_successWithEmptyResult() throws Exception {
|
||||
PoseLandmarkerOptions options =
|
||||
PoseLandmarkerOptions.builder()
|
||||
.setBaseOptions(
|
||||
BaseOptions.builder()
|
||||
.setModelAssetPath(POSE_LANDMARKER_BUNDLE_ASSET_FILE)
|
||||
.build())
|
||||
.build();
|
||||
PoseLandmarker poseLandmarker =
|
||||
PoseLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
PoseLandmarkerResult actualResult = poseLandmarker.detect(getImageFromAsset(NO_POSES_IMAGE));
|
||||
assertThat(actualResult.landmarks()).isEmpty();
|
||||
assertThat(actualResult.worldLandmarks()).isEmpty();
|
||||
// TODO: Add additional tests for MP Tasks Pose Graphs
|
||||
// Add tests for segmentation masks.
|
||||
}
|
||||
|
||||
@Test
|
||||
public void recognize_failsWithRegionOfInterest() throws Exception {
|
||||
PoseLandmarkerOptions options =
|
||||
PoseLandmarkerOptions.builder()
|
||||
.setBaseOptions(
|
||||
BaseOptions.builder()
|
||||
.setModelAssetPath(POSE_LANDMARKER_BUNDLE_ASSET_FILE)
|
||||
.build())
|
||||
.setNumPoses(1)
|
||||
.build();
|
||||
PoseLandmarker poseLandmarker =
|
||||
PoseLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
ImageProcessingOptions imageProcessingOptions =
|
||||
ImageProcessingOptions.builder().setRegionOfInterest(new RectF(0, 0, 1, 1)).build();
|
||||
IllegalArgumentException exception =
|
||||
assertThrows(
|
||||
IllegalArgumentException.class,
|
||||
() -> poseLandmarker.detect(getImageFromAsset(POSE_IMAGE), imageProcessingOptions));
|
||||
assertThat(exception)
|
||||
.hasMessageThat()
|
||||
.contains("PoseLandmarker doesn't support region-of-interest");
|
||||
}
|
||||
}
|
||||
|
||||
@RunWith(AndroidJUnit4.class)
|
||||
public static final class RunningModeTest extends PoseLandmarkerTest {
|
||||
@Test
|
||||
public void create_failsWithIllegalResultListenerInNonLiveStreamMode() throws Exception {
|
||||
for (RunningMode mode : new RunningMode[] {RunningMode.IMAGE, RunningMode.VIDEO}) {
|
||||
IllegalArgumentException exception =
|
||||
assertThrows(
|
||||
IllegalArgumentException.class,
|
||||
() ->
|
||||
PoseLandmarkerOptions.builder()
|
||||
.setBaseOptions(
|
||||
BaseOptions.builder()
|
||||
.setModelAssetPath(POSE_LANDMARKER_BUNDLE_ASSET_FILE)
|
||||
.build())
|
||||
.setRunningMode(mode)
|
||||
.setResultListener((PoseLandmarkerResults, inputImage) -> {})
|
||||
.build());
|
||||
assertThat(exception)
|
||||
.hasMessageThat()
|
||||
.contains("a user-defined result listener shouldn't be provided");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void create_failsWithMissingResultListenerInLiveSteamMode() throws Exception {
|
||||
IllegalArgumentException exception =
|
||||
assertThrows(
|
||||
IllegalArgumentException.class,
|
||||
() ->
|
||||
PoseLandmarkerOptions.builder()
|
||||
.setBaseOptions(
|
||||
BaseOptions.builder()
|
||||
.setModelAssetPath(POSE_LANDMARKER_BUNDLE_ASSET_FILE)
|
||||
.build())
|
||||
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||
.build());
|
||||
assertThat(exception)
|
||||
.hasMessageThat()
|
||||
.contains("a user-defined result listener must be provided");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void recognize_failsWithCallingWrongApiInImageMode() throws Exception {
|
||||
PoseLandmarkerOptions options =
|
||||
PoseLandmarkerOptions.builder()
|
||||
.setBaseOptions(
|
||||
BaseOptions.builder().setModelAssetPath(POSE_LANDMARKER_BUNDLE_ASSET_FILE).build())
|
||||
.setRunningMode(RunningMode.IMAGE)
|
||||
.build();
|
||||
|
||||
PoseLandmarker poseLandmarker =
|
||||
PoseLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
MediaPipeException exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() ->
|
||||
poseLandmarker.detectForVideo(
|
||||
getImageFromAsset(POSE_IMAGE), /* timestampsMs= */ 0));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
|
||||
exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() -> poseLandmarker.detectAsync(getImageFromAsset(POSE_IMAGE), /* timestampsMs= */ 0));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void recognize_failsWithCallingWrongApiInVideoMode() throws Exception {
|
||||
PoseLandmarkerOptions options =
|
||||
PoseLandmarkerOptions.builder()
|
||||
.setBaseOptions(
|
||||
BaseOptions.builder().setModelAssetPath(POSE_LANDMARKER_BUNDLE_ASSET_FILE).build())
|
||||
.setRunningMode(RunningMode.VIDEO)
|
||||
.build();
|
||||
|
||||
PoseLandmarker poseLandmarker =
|
||||
PoseLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
MediaPipeException exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class, () -> poseLandmarker.detect(getImageFromAsset(POSE_IMAGE)));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
|
||||
exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() -> poseLandmarker.detectAsync(getImageFromAsset(POSE_IMAGE), /* timestampsMs= */ 0));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void recognize_failsWithCallingWrongApiInLiveSteamMode() throws Exception {
|
||||
PoseLandmarkerOptions options =
|
||||
PoseLandmarkerOptions.builder()
|
||||
.setBaseOptions(
|
||||
BaseOptions.builder().setModelAssetPath(POSE_LANDMARKER_BUNDLE_ASSET_FILE).build())
|
||||
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||
.setResultListener((PoseLandmarkerResults, inputImage) -> {})
|
||||
.build();
|
||||
|
||||
PoseLandmarker poseLandmarker =
|
||||
PoseLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
MediaPipeException exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class, () -> poseLandmarker.detect(getImageFromAsset(POSE_IMAGE)));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
|
||||
exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() ->
|
||||
poseLandmarker.detectForVideo(
|
||||
getImageFromAsset(POSE_IMAGE), /* timestampsMs= */ 0));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void recognize_successWithImageMode() throws Exception {
|
||||
PoseLandmarkerOptions options =
|
||||
PoseLandmarkerOptions.builder()
|
||||
.setBaseOptions(
|
||||
BaseOptions.builder().setModelAssetPath(POSE_LANDMARKER_BUNDLE_ASSET_FILE).build())
|
||||
.setRunningMode(RunningMode.IMAGE)
|
||||
.build();
|
||||
|
||||
PoseLandmarker poseLandmarker =
|
||||
PoseLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
PoseLandmarkerResult actualResult = poseLandmarker.detect(getImageFromAsset(POSE_IMAGE));
|
||||
PoseLandmarkerResult expectedResult = getExpectedPoseLandmarkerResult(POSE_LANDMARKS);
|
||||
assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void recognize_successWithVideoMode() throws Exception {
|
||||
PoseLandmarkerOptions options =
|
||||
PoseLandmarkerOptions.builder()
|
||||
.setBaseOptions(
|
||||
BaseOptions.builder().setModelAssetPath(POSE_LANDMARKER_BUNDLE_ASSET_FILE).build())
|
||||
.setRunningMode(RunningMode.VIDEO)
|
||||
.build();
|
||||
PoseLandmarker poseLandmarker =
|
||||
PoseLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
PoseLandmarkerResult expectedResult = getExpectedPoseLandmarkerResult(POSE_LANDMARKS);
|
||||
for (int i = 0; i < 3; i++) {
|
||||
PoseLandmarkerResult actualResult =
|
||||
poseLandmarker.detectForVideo(getImageFromAsset(POSE_IMAGE), /* timestampsMs= */ i);
|
||||
assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void recognize_failsWithOutOfOrderInputTimestamps() throws Exception {
|
||||
MPImage image = getImageFromAsset(POSE_IMAGE);
|
||||
PoseLandmarkerResult expectedResult = getExpectedPoseLandmarkerResult(POSE_LANDMARKS);
|
||||
PoseLandmarkerOptions options =
|
||||
PoseLandmarkerOptions.builder()
|
||||
.setBaseOptions(
|
||||
BaseOptions.builder().setModelAssetPath(POSE_LANDMARKER_BUNDLE_ASSET_FILE).build())
|
||||
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||
.setResultListener(
|
||||
(actualResult, inputImage) -> {
|
||||
assertActualResultApproximatelyEqualsToExpectedResult(
|
||||
actualResult, expectedResult);
|
||||
assertImageSizeIsExpected(inputImage);
|
||||
})
|
||||
.build();
|
||||
try (PoseLandmarker poseLandmarker =
|
||||
PoseLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
|
||||
poseLandmarker.detectAsync(image, /* timestampsMs= */ 1);
|
||||
MediaPipeException exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() -> poseLandmarker.detectAsync(image, /* timestampsMs= */ 0));
|
||||
assertThat(exception)
|
||||
.hasMessageThat()
|
||||
.contains("having a smaller timestamp than the processed timestamp");
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void recognize_successWithLiveSteamMode() throws Exception {
|
||||
MPImage image = getImageFromAsset(POSE_IMAGE);
|
||||
PoseLandmarkerResult expectedResult = getExpectedPoseLandmarkerResult(POSE_LANDMARKS);
|
||||
PoseLandmarkerOptions options =
|
||||
PoseLandmarkerOptions.builder()
|
||||
.setBaseOptions(
|
||||
BaseOptions.builder().setModelAssetPath(POSE_LANDMARKER_BUNDLE_ASSET_FILE).build())
|
||||
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||
.setResultListener(
|
||||
(actualResult, inputImage) -> {
|
||||
assertActualResultApproximatelyEqualsToExpectedResult(
|
||||
actualResult, expectedResult);
|
||||
assertImageSizeIsExpected(inputImage);
|
||||
})
|
||||
.build();
|
||||
try (PoseLandmarker poseLandmarker =
|
||||
PoseLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
|
||||
for (int i = 0; i < 3; i++) {
|
||||
poseLandmarker.detectAsync(image, /* timestampsMs= */ i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static MPImage getImageFromAsset(String filePath) throws Exception {
|
||||
AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
|
||||
InputStream istr = assetManager.open(filePath);
|
||||
return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build();
|
||||
}
|
||||
|
||||
private static PoseLandmarkerResult getExpectedPoseLandmarkerResult(String filePath)
|
||||
throws Exception {
|
||||
AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
|
||||
InputStream istr = assetManager.open(filePath);
|
||||
LandmarksDetectionResult landmarksDetectionResultProto =
|
||||
LandmarksDetectionResult.parser().parseFrom(istr);
|
||||
return PoseLandmarkerResult.create(
|
||||
Arrays.asList(landmarksDetectionResultProto.getLandmarks()),
|
||||
Arrays.asList(landmarksDetectionResultProto.getWorldLandmarks()),
|
||||
Arrays.asList(),
|
||||
Optional.empty(),
|
||||
/* timestampMs= */ 0);
|
||||
}
|
||||
|
||||
private static void assertActualResultApproximatelyEqualsToExpectedResult(
|
||||
PoseLandmarkerResult actualResult, PoseLandmarkerResult expectedResult) {
|
||||
// TODO: Add additional tests for MP Tasks Pose Graphs
|
||||
// Add additional tests for auxiliary, world landmarks and segmentation masks.
|
||||
// Expects to have the same number of poses detected.
|
||||
assertThat(actualResult.landmarks()).hasSize(expectedResult.landmarks().size());
|
||||
|
||||
// Actual landmarks match expected landmarks.
|
||||
assertThat(actualResult.landmarks().get(0))
|
||||
.comparingElementsUsing(
|
||||
Correspondence.from(
|
||||
(Correspondence.BinaryPredicate<NormalizedLandmark, NormalizedLandmark>)
|
||||
(actual, expected) -> {
|
||||
return Correspondence.tolerance(LANDMARKS_ERROR_TOLERANCE)
|
||||
.compare(actual.x(), expected.x())
|
||||
&& Correspondence.tolerance(LANDMARKS_ERROR_TOLERANCE)
|
||||
.compare(actual.y(), expected.y());
|
||||
},
|
||||
"landmarks approximately equal to"))
|
||||
.containsExactlyElementsIn(expectedResult.landmarks().get(0));
|
||||
}
|
||||
|
||||
private static void assertImageSizeIsExpected(MPImage inputImage) {
|
||||
assertThat(inputImage).isNotNull();
|
||||
assertThat(inputImage.getWidth()).isEqualTo(IMAGE_WIDTH);
|
||||
assertThat(inputImage.getHeight()).isEqualTo(IMAGE_HEIGHT);
|
||||
}
|
||||
}
|
6
third_party/external_files.bzl
vendored
6
third_party/external_files.bzl
vendored
|
@ -307,7 +307,7 @@ def external_files():
|
|||
http_file(
|
||||
name = "com_google_mediapipe_expected_pose_landmarks_prototxt",
|
||||
sha256 = "eed8dfa169b0abee60cde01496599b0bc75d91a82594a1bdf59be2f76f45d7f5",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/expected_pose_landmarks.prototxt?generation=1681244232522990"],
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/expected_pose_landmarks.prototxt?generation=16812442325229901681244235071100"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
|
@ -996,8 +996,8 @@ def external_files():
|
|||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_pose_landmarks_pbtxt",
|
||||
sha256 = "305a71fbff83e270a5dbd81fb7cf65203f56e0b1caba8ea42edc16c6e8a2ba18",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/pose_landmarks.pbtxt?generation=1681244254964356"],
|
||||
sha256 = "69c79cdf3964d7819776eab1172e47e70684139d72a6d7edcbdd62dbb2ca5527",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/pose_landmarks.pbtxt?generation=1681425322701589"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
|
|
Loading…
Reference in New Issue
Block a user