diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD index 07106985d..bcdc0e5e5 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD @@ -92,6 +92,9 @@ android_library( android_library( name = "landmark", srcs = ["Landmark.java"], + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], deps = [ "//third_party:autovalue", "@maven//:com_google_guava_guava", @@ -101,6 +104,9 @@ android_library( android_library( name = "normalized_landmark", srcs = ["NormalizedLandmark.java"], + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], deps = [ "//third_party:autovalue", "@maven//:com_google_guava_guava", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Landmark.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Landmark.java index c3e9f2715..e23d9115d 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Landmark.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Landmark.java @@ -16,6 +16,7 @@ package com.google.mediapipe.tasks.components.containers; import com.google.auto.value.AutoValue; import java.util.Objects; +import java.util.Optional; /** * Landmark represents a point in 3D space with x, y, z coordinates. The landmark coordinates are in @@ -27,7 +28,12 @@ public abstract class Landmark { private static final float TOLERANCE = 1e-6f; public static Landmark create(float x, float y, float z) { - return new AutoValue_Landmark(x, y, z); + return new AutoValue_Landmark(x, y, z, Optional.empty(), Optional.empty()); + } + + public static Landmark create( + float x, float y, float z, Optional visibility, Optional presence) { + return new AutoValue_Landmark(x, y, z, visibility, presence); } // The x coordinates of the landmark. @@ -39,6 +45,12 @@ public abstract class Landmark { // The z coordinates of the landmark. public abstract float z(); + // Visibility of the normalized landmark. + public abstract Optional visibility(); + + // Presence of the normalized landmark. + public abstract Optional presence(); + @Override public final boolean equals(Object o) { if (!(o instanceof Landmark)) { @@ -57,6 +69,16 @@ public abstract class Landmark { @Override public final String toString() { - return ""; + return ""; } } diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/NormalizedLandmark.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/NormalizedLandmark.java index f96e434ca..50a95d565 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/NormalizedLandmark.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/NormalizedLandmark.java @@ -16,6 +16,7 @@ package com.google.mediapipe.tasks.components.containers; import com.google.auto.value.AutoValue; import java.util.Objects; +import java.util.Optional; /** * Normalized Landmark represents a point in 3D space with x, y, z coordinates. x and y are @@ -28,7 +29,12 @@ public abstract class NormalizedLandmark { private static final float TOLERANCE = 1e-6f; public static NormalizedLandmark create(float x, float y, float z) { - return new AutoValue_NormalizedLandmark(x, y, z); + return new AutoValue_NormalizedLandmark(x, y, z, Optional.empty(), Optional.empty()); + } + + public static NormalizedLandmark create( + float x, float y, float z, Optional visibility, Optional presence) { + return new AutoValue_NormalizedLandmark(x, y, z, visibility, presence); } // The x coordinates of the normalized landmark. @@ -40,6 +46,12 @@ public abstract class NormalizedLandmark { // The z coordinates of the normalized landmark. public abstract float z(); + // Visibility of the normalized landmark. + public abstract Optional visibility(); + + // Presence of the normalized landmark. + public abstract Optional presence(); + @Override public final boolean equals(Object o) { if (!(o instanceof NormalizedLandmark)) { @@ -58,6 +70,16 @@ public abstract class NormalizedLandmark { @Override public final String toString() { - return ""; + return ""; } } diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarkerResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarkerResult.java index c91477e10..0429ecacb 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarkerResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarkerResult.java @@ -53,7 +53,15 @@ public abstract class FaceLandmarkerResult implements TaskResult { faceLandmarksProto.getLandmarkList()) { faceLandmarks.add( NormalizedLandmark.create( - faceLandmarkProto.getX(), faceLandmarkProto.getY(), faceLandmarkProto.getZ())); + faceLandmarkProto.getX(), + faceLandmarkProto.getY(), + faceLandmarkProto.getZ(), + faceLandmarkProto.hasVisibility() + ? Optional.of(faceLandmarkProto.getVisibility()) + : Optional.empty(), + faceLandmarkProto.hasPresence() + ? Optional.of(faceLandmarkProto.getPresence()) + : Optional.empty())); } } Optional>> multiFaceBlendshapes = Optional.empty(); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerResult.java index 467e871b2..b8b236d42 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerResult.java @@ -25,6 +25,7 @@ import com.google.mediapipe.tasks.core.TaskResult; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Optional; /** Represents the hand landmarks deection results generated by {@link HandLandmarker}. */ @AutoValue @@ -53,7 +54,15 @@ public abstract class HandLandmarkerResult implements TaskResult { handLandmarksProto.getLandmarkList()) { handLandmarks.add( NormalizedLandmark.create( - handLandmarkProto.getX(), handLandmarkProto.getY(), handLandmarkProto.getZ())); + handLandmarkProto.getX(), + handLandmarkProto.getY(), + handLandmarkProto.getZ(), + handLandmarkProto.hasVisibility() + ? Optional.of(handLandmarkProto.getVisibility()) + : Optional.empty(), + handLandmarkProto.hasPresence() + ? Optional.of(handLandmarkProto.getPresence()) + : Optional.empty())); } } for (LandmarkProto.LandmarkList handWorldLandmarksProto : worldLandmarksProto) { @@ -65,7 +74,13 @@ public abstract class HandLandmarkerResult implements TaskResult { com.google.mediapipe.tasks.components.containers.Landmark.create( handWorldLandmarkProto.getX(), handWorldLandmarkProto.getY(), - handWorldLandmarkProto.getZ())); + handWorldLandmarkProto.getZ(), + handWorldLandmarkProto.hasVisibility() + ? Optional.of(handWorldLandmarkProto.getVisibility()) + : Optional.empty(), + handWorldLandmarkProto.hasPresence() + ? Optional.of(handWorldLandmarkProto.getPresence()) + : Optional.empty())); } } for (ClassificationList handednessProto : handednessesProto) { diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarkerResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarkerResult.java index 389e78266..0dde56700 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarkerResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarkerResult.java @@ -58,7 +58,15 @@ public abstract class PoseLandmarkerResult implements TaskResult { poseLandmarksProto.getLandmarkList()) { poseLandmarks.add( NormalizedLandmark.create( - poseLandmarkProto.getX(), poseLandmarkProto.getY(), poseLandmarkProto.getZ())); + poseLandmarkProto.getX(), + poseLandmarkProto.getY(), + poseLandmarkProto.getZ(), + poseLandmarkProto.hasVisibility() + ? Optional.of(poseLandmarkProto.getVisibility()) + : Optional.empty(), + poseLandmarkProto.hasPresence() + ? Optional.of(poseLandmarkProto.getPresence()) + : Optional.empty())); } } for (LandmarkProto.LandmarkList poseWorldLandmarksProto : worldLandmarksProto) { @@ -70,7 +78,13 @@ public abstract class PoseLandmarkerResult implements TaskResult { Landmark.create( poseWorldLandmarkProto.getX(), poseWorldLandmarkProto.getY(), - poseWorldLandmarkProto.getZ())); + poseWorldLandmarkProto.getZ(), + poseWorldLandmarkProto.hasVisibility() + ? Optional.of(poseWorldLandmarkProto.getVisibility()) + : Optional.empty(), + poseWorldLandmarkProto.hasPresence() + ? Optional.of(poseWorldLandmarkProto.getPresence()) + : Optional.empty())); } } return new AutoValue_PoseLandmarkerResult( diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarkerTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarkerTest.java index 7adef9e27..508709ab0 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarkerTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarkerTest.java @@ -15,6 +15,7 @@ package com.google.mediapipe.tasks.vision.poselandmarker; import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.Truth.assertWithMessage; import static org.junit.Assert.assertThrows; import android.content.res.AssetManager; @@ -26,6 +27,7 @@ import com.google.common.truth.Correspondence; import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.image.BitmapImageBuilder; import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.tasks.components.containers.Landmark; import com.google.mediapipe.tasks.components.containers.NormalizedLandmark; import com.google.mediapipe.tasks.components.containers.proto.LandmarksDetectionResultProto.LandmarksDetectionResult; import com.google.mediapipe.tasks.core.BaseOptions; @@ -34,6 +36,7 @@ import com.google.mediapipe.tasks.vision.core.RunningMode; import com.google.mediapipe.tasks.vision.poselandmarker.PoseLandmarker.PoseLandmarkerOptions; import java.io.InputStream; import java.util.Arrays; +import java.util.List; import java.util.Optional; import org.junit.Test; import org.junit.runner.RunWith; @@ -50,6 +53,8 @@ public class PoseLandmarkerTest { private static final String NO_POSES_IMAGE = "burger.jpg"; private static final String TAG = "Pose Landmarker Test"; private static final float LANDMARKS_ERROR_TOLERANCE = 0.03f; + private static final float VISIBILITY_TOLERANCE = 0.9f; + private static final float PRESENCE_TOLERANCE = 0.9f; private static final int IMAGE_WIDTH = 1000; private static final int IMAGE_HEIGHT = 667; @@ -70,6 +75,8 @@ public class PoseLandmarkerTest { PoseLandmarkerResult actualResult = poseLandmarker.detect(getImageFromAsset(POSE_IMAGE)); PoseLandmarkerResult expectedResult = getExpectedPoseLandmarkerResult(POSE_LANDMARKS); assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult); + assertAllLandmarksAreVisibleAndPresent( + actualResult, VISIBILITY_TOLERANCE, PRESENCE_TOLERANCE); } @Test @@ -361,4 +368,40 @@ public class PoseLandmarkerTest { assertThat(inputImage.getWidth()).isEqualTo(IMAGE_WIDTH); assertThat(inputImage.getHeight()).isEqualTo(IMAGE_HEIGHT); } + + private static void assertAllLandmarksAreVisibleAndPresent( + PoseLandmarkerResult result, float visbilityThreshold, float presenceThreshold) { + for (int i = 0; i < result.landmarks().size(); i++) { + List landmarks = result.landmarks().get(i); + for (int j = 0; j < landmarks.size(); j++) { + NormalizedLandmark landmark = landmarks.get(j); + String landmarkMessage = "Landmark List " + i + " landmark " + j + ": " + landmark; + landmark + .visibility() + .ifPresent( + val -> + assertWithMessage(landmarkMessage).that(val).isAtLeast((visbilityThreshold))); + landmark + .presence() + .ifPresent( + val -> assertWithMessage(landmarkMessage).that(val).isAtLeast((presenceThreshold))); + } + } + for (int i = 0; i < result.worldLandmarks().size(); i++) { + List landmarks = result.worldLandmarks().get(i); + for (int j = 0; j < landmarks.size(); j++) { + Landmark landmark = landmarks.get(j); + String landmarkMessage = "World Landmark List " + i + " landmark " + j + ": " + landmark; + landmark + .visibility() + .ifPresent( + val -> + assertWithMessage(landmarkMessage).that(val).isAtLeast((visbilityThreshold))); + landmark + .presence() + .ifPresent( + val -> assertWithMessage(landmarkMessage).that(val).isAtLeast((presenceThreshold))); + } + } + } }