Java API add visibility and presence for landmarks.

PiperOrigin-RevId: 549709256
This commit is contained in:
MediaPipe Team 2023-07-20 12:38:17 -07:00 committed by Copybara-Service
parent 236a36e39a
commit 9af637b125
7 changed files with 139 additions and 9 deletions

View File

@ -92,6 +92,9 @@ android_library(
android_library( android_library(
name = "landmark", name = "landmark",
srcs = ["Landmark.java"], srcs = ["Landmark.java"],
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
deps = [ deps = [
"//third_party:autovalue", "//third_party:autovalue",
"@maven//:com_google_guava_guava", "@maven//:com_google_guava_guava",
@ -101,6 +104,9 @@ android_library(
android_library( android_library(
name = "normalized_landmark", name = "normalized_landmark",
srcs = ["NormalizedLandmark.java"], srcs = ["NormalizedLandmark.java"],
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
deps = [ deps = [
"//third_party:autovalue", "//third_party:autovalue",
"@maven//:com_google_guava_guava", "@maven//:com_google_guava_guava",

View File

@ -16,6 +16,7 @@ package com.google.mediapipe.tasks.components.containers;
import com.google.auto.value.AutoValue; import com.google.auto.value.AutoValue;
import java.util.Objects; import java.util.Objects;
import java.util.Optional;
/** /**
* Landmark represents a point in 3D space with x, y, z coordinates. The landmark coordinates are in * Landmark represents a point in 3D space with x, y, z coordinates. The landmark coordinates are in
@ -27,7 +28,12 @@ public abstract class Landmark {
private static final float TOLERANCE = 1e-6f; private static final float TOLERANCE = 1e-6f;
public static Landmark create(float x, float y, float z) { public static Landmark create(float x, float y, float z) {
return new AutoValue_Landmark(x, y, z); return new AutoValue_Landmark(x, y, z, Optional.empty(), Optional.empty());
}
public static Landmark create(
float x, float y, float z, Optional<Float> visibility, Optional<Float> presence) {
return new AutoValue_Landmark(x, y, z, visibility, presence);
} }
// The x coordinates of the landmark. // The x coordinates of the landmark.
@ -39,6 +45,12 @@ public abstract class Landmark {
// The z coordinates of the landmark. // The z coordinates of the landmark.
public abstract float z(); public abstract float z();
// Visibility of the normalized landmark.
public abstract Optional<Float> visibility();
// Presence of the normalized landmark.
public abstract Optional<Float> presence();
@Override @Override
public final boolean equals(Object o) { public final boolean equals(Object o) {
if (!(o instanceof Landmark)) { if (!(o instanceof Landmark)) {
@ -57,6 +69,16 @@ public abstract class Landmark {
@Override @Override
public final String toString() { public final String toString() {
return "<Landmark (x=" + x() + " y=" + y() + " z=" + z() + ")>"; return "<Landmark (x="
+ x()
+ " y="
+ y()
+ " z="
+ z()
+ " visibility= "
+ visibility()
+ " presence="
+ presence()
+ ")>";
} }
} }

View File

@ -16,6 +16,7 @@ package com.google.mediapipe.tasks.components.containers;
import com.google.auto.value.AutoValue; import com.google.auto.value.AutoValue;
import java.util.Objects; import java.util.Objects;
import java.util.Optional;
/** /**
* Normalized Landmark represents a point in 3D space with x, y, z coordinates. x and y are * Normalized Landmark represents a point in 3D space with x, y, z coordinates. x and y are
@ -28,7 +29,12 @@ public abstract class NormalizedLandmark {
private static final float TOLERANCE = 1e-6f; private static final float TOLERANCE = 1e-6f;
public static NormalizedLandmark create(float x, float y, float z) { public static NormalizedLandmark create(float x, float y, float z) {
return new AutoValue_NormalizedLandmark(x, y, z); return new AutoValue_NormalizedLandmark(x, y, z, Optional.empty(), Optional.empty());
}
public static NormalizedLandmark create(
float x, float y, float z, Optional<Float> visibility, Optional<Float> presence) {
return new AutoValue_NormalizedLandmark(x, y, z, visibility, presence);
} }
// The x coordinates of the normalized landmark. // The x coordinates of the normalized landmark.
@ -40,6 +46,12 @@ public abstract class NormalizedLandmark {
// The z coordinates of the normalized landmark. // The z coordinates of the normalized landmark.
public abstract float z(); public abstract float z();
// Visibility of the normalized landmark.
public abstract Optional<Float> visibility();
// Presence of the normalized landmark.
public abstract Optional<Float> presence();
@Override @Override
public final boolean equals(Object o) { public final boolean equals(Object o) {
if (!(o instanceof NormalizedLandmark)) { if (!(o instanceof NormalizedLandmark)) {
@ -58,6 +70,16 @@ public abstract class NormalizedLandmark {
@Override @Override
public final String toString() { public final String toString() {
return "<Normalized Landmark (x=" + x() + " y=" + y() + " z=" + z() + ")>"; return "<Normalized Landmark (x="
+ x()
+ " y="
+ y()
+ " z="
+ z()
+ " visibility= "
+ visibility()
+ " presence="
+ presence()
+ ")>";
} }
} }

View File

@ -53,7 +53,15 @@ public abstract class FaceLandmarkerResult implements TaskResult {
faceLandmarksProto.getLandmarkList()) { faceLandmarksProto.getLandmarkList()) {
faceLandmarks.add( faceLandmarks.add(
NormalizedLandmark.create( NormalizedLandmark.create(
faceLandmarkProto.getX(), faceLandmarkProto.getY(), faceLandmarkProto.getZ())); faceLandmarkProto.getX(),
faceLandmarkProto.getY(),
faceLandmarkProto.getZ(),
faceLandmarkProto.hasVisibility()
? Optional.of(faceLandmarkProto.getVisibility())
: Optional.empty(),
faceLandmarkProto.hasPresence()
? Optional.of(faceLandmarkProto.getPresence())
: Optional.empty()));
} }
} }
Optional<List<List<Category>>> multiFaceBlendshapes = Optional.empty(); Optional<List<List<Category>>> multiFaceBlendshapes = Optional.empty();

View File

@ -25,6 +25,7 @@ import com.google.mediapipe.tasks.core.TaskResult;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Optional;
/** Represents the hand landmarks deection results generated by {@link HandLandmarker}. */ /** Represents the hand landmarks deection results generated by {@link HandLandmarker}. */
@AutoValue @AutoValue
@ -53,7 +54,15 @@ public abstract class HandLandmarkerResult implements TaskResult {
handLandmarksProto.getLandmarkList()) { handLandmarksProto.getLandmarkList()) {
handLandmarks.add( handLandmarks.add(
NormalizedLandmark.create( NormalizedLandmark.create(
handLandmarkProto.getX(), handLandmarkProto.getY(), handLandmarkProto.getZ())); handLandmarkProto.getX(),
handLandmarkProto.getY(),
handLandmarkProto.getZ(),
handLandmarkProto.hasVisibility()
? Optional.of(handLandmarkProto.getVisibility())
: Optional.empty(),
handLandmarkProto.hasPresence()
? Optional.of(handLandmarkProto.getPresence())
: Optional.empty()));
} }
} }
for (LandmarkProto.LandmarkList handWorldLandmarksProto : worldLandmarksProto) { for (LandmarkProto.LandmarkList handWorldLandmarksProto : worldLandmarksProto) {
@ -65,7 +74,13 @@ public abstract class HandLandmarkerResult implements TaskResult {
com.google.mediapipe.tasks.components.containers.Landmark.create( com.google.mediapipe.tasks.components.containers.Landmark.create(
handWorldLandmarkProto.getX(), handWorldLandmarkProto.getX(),
handWorldLandmarkProto.getY(), handWorldLandmarkProto.getY(),
handWorldLandmarkProto.getZ())); handWorldLandmarkProto.getZ(),
handWorldLandmarkProto.hasVisibility()
? Optional.of(handWorldLandmarkProto.getVisibility())
: Optional.empty(),
handWorldLandmarkProto.hasPresence()
? Optional.of(handWorldLandmarkProto.getPresence())
: Optional.empty()));
} }
} }
for (ClassificationList handednessProto : handednessesProto) { for (ClassificationList handednessProto : handednessesProto) {

View File

@ -58,7 +58,15 @@ public abstract class PoseLandmarkerResult implements TaskResult {
poseLandmarksProto.getLandmarkList()) { poseLandmarksProto.getLandmarkList()) {
poseLandmarks.add( poseLandmarks.add(
NormalizedLandmark.create( NormalizedLandmark.create(
poseLandmarkProto.getX(), poseLandmarkProto.getY(), poseLandmarkProto.getZ())); poseLandmarkProto.getX(),
poseLandmarkProto.getY(),
poseLandmarkProto.getZ(),
poseLandmarkProto.hasVisibility()
? Optional.of(poseLandmarkProto.getVisibility())
: Optional.empty(),
poseLandmarkProto.hasPresence()
? Optional.of(poseLandmarkProto.getPresence())
: Optional.empty()));
} }
} }
for (LandmarkProto.LandmarkList poseWorldLandmarksProto : worldLandmarksProto) { for (LandmarkProto.LandmarkList poseWorldLandmarksProto : worldLandmarksProto) {
@ -70,7 +78,13 @@ public abstract class PoseLandmarkerResult implements TaskResult {
Landmark.create( Landmark.create(
poseWorldLandmarkProto.getX(), poseWorldLandmarkProto.getX(),
poseWorldLandmarkProto.getY(), poseWorldLandmarkProto.getY(),
poseWorldLandmarkProto.getZ())); poseWorldLandmarkProto.getZ(),
poseWorldLandmarkProto.hasVisibility()
? Optional.of(poseWorldLandmarkProto.getVisibility())
: Optional.empty(),
poseWorldLandmarkProto.hasPresence()
? Optional.of(poseWorldLandmarkProto.getPresence())
: Optional.empty()));
} }
} }
return new AutoValue_PoseLandmarkerResult( return new AutoValue_PoseLandmarkerResult(

View File

@ -15,6 +15,7 @@
package com.google.mediapipe.tasks.vision.poselandmarker; package com.google.mediapipe.tasks.vision.poselandmarker;
import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertThat;
import static com.google.common.truth.Truth.assertWithMessage;
import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertThrows;
import android.content.res.AssetManager; import android.content.res.AssetManager;
@ -26,6 +27,7 @@ import com.google.common.truth.Correspondence;
import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.image.BitmapImageBuilder; import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.components.containers.Landmark;
import com.google.mediapipe.tasks.components.containers.NormalizedLandmark; import com.google.mediapipe.tasks.components.containers.NormalizedLandmark;
import com.google.mediapipe.tasks.components.containers.proto.LandmarksDetectionResultProto.LandmarksDetectionResult; import com.google.mediapipe.tasks.components.containers.proto.LandmarksDetectionResultProto.LandmarksDetectionResult;
import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.BaseOptions;
@ -34,6 +36,7 @@ import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.poselandmarker.PoseLandmarker.PoseLandmarkerOptions; import com.google.mediapipe.tasks.vision.poselandmarker.PoseLandmarker.PoseLandmarkerOptions;
import java.io.InputStream; import java.io.InputStream;
import java.util.Arrays; import java.util.Arrays;
import java.util.List;
import java.util.Optional; import java.util.Optional;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
@ -50,6 +53,8 @@ public class PoseLandmarkerTest {
private static final String NO_POSES_IMAGE = "burger.jpg"; private static final String NO_POSES_IMAGE = "burger.jpg";
private static final String TAG = "Pose Landmarker Test"; private static final String TAG = "Pose Landmarker Test";
private static final float LANDMARKS_ERROR_TOLERANCE = 0.03f; private static final float LANDMARKS_ERROR_TOLERANCE = 0.03f;
private static final float VISIBILITY_TOLERANCE = 0.9f;
private static final float PRESENCE_TOLERANCE = 0.9f;
private static final int IMAGE_WIDTH = 1000; private static final int IMAGE_WIDTH = 1000;
private static final int IMAGE_HEIGHT = 667; private static final int IMAGE_HEIGHT = 667;
@ -70,6 +75,8 @@ public class PoseLandmarkerTest {
PoseLandmarkerResult actualResult = poseLandmarker.detect(getImageFromAsset(POSE_IMAGE)); PoseLandmarkerResult actualResult = poseLandmarker.detect(getImageFromAsset(POSE_IMAGE));
PoseLandmarkerResult expectedResult = getExpectedPoseLandmarkerResult(POSE_LANDMARKS); PoseLandmarkerResult expectedResult = getExpectedPoseLandmarkerResult(POSE_LANDMARKS);
assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult); assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
assertAllLandmarksAreVisibleAndPresent(
actualResult, VISIBILITY_TOLERANCE, PRESENCE_TOLERANCE);
} }
@Test @Test
@ -361,4 +368,40 @@ public class PoseLandmarkerTest {
assertThat(inputImage.getWidth()).isEqualTo(IMAGE_WIDTH); assertThat(inputImage.getWidth()).isEqualTo(IMAGE_WIDTH);
assertThat(inputImage.getHeight()).isEqualTo(IMAGE_HEIGHT); assertThat(inputImage.getHeight()).isEqualTo(IMAGE_HEIGHT);
} }
private static void assertAllLandmarksAreVisibleAndPresent(
PoseLandmarkerResult result, float visbilityThreshold, float presenceThreshold) {
for (int i = 0; i < result.landmarks().size(); i++) {
List<NormalizedLandmark> landmarks = result.landmarks().get(i);
for (int j = 0; j < landmarks.size(); j++) {
NormalizedLandmark landmark = landmarks.get(j);
String landmarkMessage = "Landmark List " + i + " landmark " + j + ": " + landmark;
landmark
.visibility()
.ifPresent(
val ->
assertWithMessage(landmarkMessage).that(val).isAtLeast((visbilityThreshold)));
landmark
.presence()
.ifPresent(
val -> assertWithMessage(landmarkMessage).that(val).isAtLeast((presenceThreshold)));
}
}
for (int i = 0; i < result.worldLandmarks().size(); i++) {
List<Landmark> landmarks = result.worldLandmarks().get(i);
for (int j = 0; j < landmarks.size(); j++) {
Landmark landmark = landmarks.get(j);
String landmarkMessage = "World Landmark List " + i + " landmark " + j + ": " + landmark;
landmark
.visibility()
.ifPresent(
val ->
assertWithMessage(landmarkMessage).that(val).isAtLeast((visbilityThreshold)));
landmark
.presence()
.ifPresent(
val -> assertWithMessage(landmarkMessage).that(val).isAtLeast((presenceThreshold)));
}
}
}
} }