From d504d3bf2202d7ebb087a348b123f26f7bbad976 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Mon, 13 Nov 2023 08:20:21 -0800 Subject: [PATCH] Create shared utilities to construct landmark lists PiperOrigin-RevId: 581970043 --- .../tasks/components/containers/BUILD | 2 + .../tasks/components/containers/Landmark.java | 30 ++++++++ .../containers/NormalizedLandmark.java | 31 ++++++++ .../facelandmarker/FaceLandmarkerResult.java | 22 ++---- .../handlandmarker/HandLandmarkerResult.java | 74 ++++++------------- .../poselandmarker/PoseLandmarkerResult.java | 43 +++-------- .../components/containers/LandmarkTest.java | 62 ++++++++++++++++ .../containers/NormalizedLandmarkTest.java | 62 ++++++++++++++++ 8 files changed, 226 insertions(+), 100 deletions(-) create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/components/containers/LandmarkTest.java create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/components/containers/NormalizedLandmarkTest.java diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD index bcdc0e5e5..1149ea036 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD @@ -96,6 +96,7 @@ android_library( "-Xep:AndroidJdkLibsChecker:OFF", ], deps = [ + "//mediapipe/framework/formats:landmark_java_proto_lite", "//third_party:autovalue", "@maven//:com_google_guava_guava", ], @@ -108,6 +109,7 @@ android_library( "-Xep:AndroidJdkLibsChecker:OFF", ], deps = [ + "//mediapipe/framework/formats:landmark_java_proto_lite", "//third_party:autovalue", "@maven//:com_google_guava_guava", ], diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Landmark.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Landmark.java index e23d9115d..b3bb2d52e 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Landmark.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Landmark.java @@ -14,7 +14,11 @@ package com.google.mediapipe.tasks.components.containers; +import android.annotation.TargetApi; import com.google.auto.value.AutoValue; +import com.google.mediapipe.formats.proto.LandmarkProto; +import java.util.ArrayList; +import java.util.List; import java.util.Objects; import java.util.Optional; @@ -24,18 +28,44 @@ import java.util.Optional; * is to the camera. */ @AutoValue +@TargetApi(31) public abstract class Landmark { private static final float TOLERANCE = 1e-6f; + /** Creates a landmark from x, y, z coordinates. */ public static Landmark create(float x, float y, float z) { return new AutoValue_Landmark(x, y, z, Optional.empty(), Optional.empty()); } + /** + * Creates a normalized landmark from x, y, z coordinates with optional visibility and presence. + */ public static Landmark create( float x, float y, float z, Optional visibility, Optional presence) { return new AutoValue_Landmark(x, y, z, visibility, presence); } + /** Creates a landmark from a landmark proto. */ + public static Landmark createFromProto(LandmarkProto.Landmark landmarkProto) { + return Landmark.create( + landmarkProto.getX(), + landmarkProto.getY(), + landmarkProto.getZ(), + landmarkProto.hasVisibility() + ? Optional.of(landmarkProto.getVisibility()) + : Optional.empty(), + landmarkProto.hasPresence() ? Optional.of(landmarkProto.getPresence()) : Optional.empty()); + } + + /** Creates a list of landmarks from a {@link LandmarkList}. */ + public static List createListFromProto(LandmarkProto.LandmarkList landmarkListProto) { + List landmarkList = new ArrayList<>(); + for (LandmarkProto.Landmark landmarkProto : landmarkListProto.getLandmarkList()) { + landmarkList.add(createFromProto(landmarkProto)); + } + return landmarkList; + } + // The x coordinates of the landmark. public abstract float x(); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/NormalizedLandmark.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/NormalizedLandmark.java index 50a95d565..d6fa618a3 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/NormalizedLandmark.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/NormalizedLandmark.java @@ -14,7 +14,11 @@ package com.google.mediapipe.tasks.components.containers; +import android.annotation.TargetApi; import com.google.auto.value.AutoValue; +import com.google.mediapipe.formats.proto.LandmarkProto; +import java.util.ArrayList; +import java.util.List; import java.util.Objects; import java.util.Optional; @@ -25,18 +29,45 @@ import java.util.Optional; * uses roughly the same scale as x. */ @AutoValue +@TargetApi(31) public abstract class NormalizedLandmark { private static final float TOLERANCE = 1e-6f; + /** Creates a normalized landmark from x, y, z coordinates. */ public static NormalizedLandmark create(float x, float y, float z) { return new AutoValue_NormalizedLandmark(x, y, z, Optional.empty(), Optional.empty()); } + /** + * Creates a normalized landmark from x, y, z coordinates with optional visibility and presence. + */ public static NormalizedLandmark create( float x, float y, float z, Optional visibility, Optional presence) { return new AutoValue_NormalizedLandmark(x, y, z, visibility, presence); } + /** Creates a normalized landmark from a normalized landmark proto. */ + public static NormalizedLandmark createFromProto(LandmarkProto.NormalizedLandmark landmarkProto) { + return NormalizedLandmark.create( + landmarkProto.getX(), + landmarkProto.getY(), + landmarkProto.getZ(), + landmarkProto.hasVisibility() + ? Optional.of(landmarkProto.getVisibility()) + : Optional.empty(), + landmarkProto.hasPresence() ? Optional.of(landmarkProto.getPresence()) : Optional.empty()); + } + + /** Creates a list of normalized landmarks from a {@link NormalizedLandmarkList}. */ + public static List createListFromProto( + LandmarkProto.NormalizedLandmarkList landmarkListProto) { + List landmarkList = new ArrayList<>(); + for (LandmarkProto.NormalizedLandmark landmarkProto : landmarkListProto.getLandmarkList()) { + landmarkList.add(createFromProto(landmarkProto)); + } + return landmarkList; + } + // The x coordinates of the normalized landmark. public abstract float x(); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarkerResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarkerResult.java index 98fb9376f..78bc7efb9 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarkerResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarkerResult.java @@ -46,23 +46,11 @@ public abstract class FaceLandmarkerResult implements TaskResult { long timestampMs) { List> multiFaceLandmarks = new ArrayList<>(); for (LandmarkProto.NormalizedLandmarkList faceLandmarksProto : multiFaceLandmarksProto) { - List faceLandmarks = new ArrayList<>(); - multiFaceLandmarks.add(faceLandmarks); - for (LandmarkProto.NormalizedLandmark faceLandmarkProto : - faceLandmarksProto.getLandmarkList()) { - faceLandmarks.add( - NormalizedLandmark.create( - faceLandmarkProto.getX(), - faceLandmarkProto.getY(), - faceLandmarkProto.getZ(), - faceLandmarkProto.hasVisibility() - ? Optional.of(faceLandmarkProto.getVisibility()) - : Optional.empty(), - faceLandmarkProto.hasPresence() - ? Optional.of(faceLandmarkProto.getPresence()) - : Optional.empty())); - } + List faceLandmarks = + NormalizedLandmark.createListFromProto(faceLandmarksProto); + multiFaceLandmarks.add(Collections.unmodifiableList(faceLandmarks)); } + Optional>> multiFaceBlendshapes = Optional.empty(); if (multiFaceBendshapesProto.isPresent()) { List> blendshapes = new ArrayList<>(); @@ -72,6 +60,7 @@ public abstract class FaceLandmarkerResult implements TaskResult { } multiFaceBlendshapes = Optional.of(Collections.unmodifiableList(blendshapes)); } + Optional> multiFaceTransformationMatrixes = Optional.empty(); if (multiFaceTransformationMatrixesProto.isPresent()) { List matrixes = new ArrayList<>(); @@ -90,6 +79,7 @@ public abstract class FaceLandmarkerResult implements TaskResult { } multiFaceTransformationMatrixes = Optional.of(Collections.unmodifiableList(matrixes)); } + return new AutoValue_FaceLandmarkerResult( timestampMs, Collections.unmodifiableList(multiFaceLandmarks), diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerResult.java index 54ed04848..5a1661a52 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerResult.java @@ -24,7 +24,6 @@ import com.google.mediapipe.tasks.core.TaskResult; import java.util.ArrayList; import java.util.Collections; import java.util.List; -import java.util.Optional; /** Represents the hand landmarks deection results generated by {@link HandLandmarker}. */ @AutoValue @@ -34,63 +33,38 @@ public abstract class HandLandmarkerResult implements TaskResult { * Creates a {@link HandLandmarkerResult} instance from the lists of landmarks and handedness * protobuf messages. * - * @param landmarksProto a List of {@link NormalizedLandmarkList} - * @param worldLandmarksProto a List of {@link LandmarkList} - * @param handednessesProto a List of {@link ClassificationList} + * @param landmarksProtos a List of {@link NormalizedLandmarkList} + * @param worldLandmarksProtos a List of {@link LandmarkList} + * @param handednessesProtos a List of {@link ClassificationList} */ static HandLandmarkerResult create( - List landmarksProto, - List worldLandmarksProto, - List handednessesProto, + List landmarksProtos, + List worldLandmarksProtos, + List handednessesProtos, long timestampMs) { - List> multiHandLandmarks = new ArrayList<>(); - List> multiHandWorldLandmarks = new ArrayList<>(); - List> multiHandHandednesses = new ArrayList<>(); - for (LandmarkProto.NormalizedLandmarkList handLandmarksProto : landmarksProto) { - List handLandmarks = new ArrayList<>(); - multiHandLandmarks.add(handLandmarks); - for (LandmarkProto.NormalizedLandmark handLandmarkProto : - handLandmarksProto.getLandmarkList()) { - handLandmarks.add( - NormalizedLandmark.create( - handLandmarkProto.getX(), - handLandmarkProto.getY(), - handLandmarkProto.getZ(), - handLandmarkProto.hasVisibility() - ? Optional.of(handLandmarkProto.getVisibility()) - : Optional.empty(), - handLandmarkProto.hasPresence() - ? Optional.of(handLandmarkProto.getPresence()) - : Optional.empty())); - } + List> handLandmarks = new ArrayList<>(); + for (LandmarkProto.NormalizedLandmarkList handLandmarksProto : landmarksProtos) { + handLandmarks.add( + Collections.unmodifiableList(NormalizedLandmark.createListFromProto(handLandmarksProto))); } - for (LandmarkProto.LandmarkList handWorldLandmarksProto : worldLandmarksProto) { - List handWorldLandmarks = new ArrayList<>(); - multiHandWorldLandmarks.add(handWorldLandmarks); - for (LandmarkProto.Landmark handWorldLandmarkProto : - handWorldLandmarksProto.getLandmarkList()) { - handWorldLandmarks.add( - com.google.mediapipe.tasks.components.containers.Landmark.create( - handWorldLandmarkProto.getX(), - handWorldLandmarkProto.getY(), - handWorldLandmarkProto.getZ(), - handWorldLandmarkProto.hasVisibility() - ? Optional.of(handWorldLandmarkProto.getVisibility()) - : Optional.empty(), - handWorldLandmarkProto.hasPresence() - ? Optional.of(handWorldLandmarkProto.getPresence()) - : Optional.empty())); - } + + List> handWorldLandmarks = new ArrayList<>(); + for (LandmarkProto.LandmarkList handWorldLandmarksProto : worldLandmarksProtos) { + handWorldLandmarks.add( + Collections.unmodifiableList(Landmark.createListFromProto(handWorldLandmarksProto))); } - for (ClassificationList handednessProto : handednessesProto) { - List handedness = Category.createListFromProto(handednessProto); - multiHandHandednesses.add(Collections.unmodifiableList(handedness)); + + List> handHandednesses = new ArrayList<>(); + for (ClassificationList handednessProto : handednessesProtos) { + handHandednesses.add( + Collections.unmodifiableList(Category.createListFromProto(handednessProto))); } + return new AutoValue_HandLandmarkerResult( timestampMs, - Collections.unmodifiableList(multiHandLandmarks), - Collections.unmodifiableList(multiHandWorldLandmarks), - Collections.unmodifiableList(multiHandHandednesses)); + Collections.unmodifiableList(handLandmarks), + Collections.unmodifiableList(handWorldLandmarks), + Collections.unmodifiableList(handHandednesses)); } @Override diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarkerResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarkerResult.java index e693c7b88..792d7407d 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarkerResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarkerResult.java @@ -50,43 +50,18 @@ public abstract class PoseLandmarkerResult implements TaskResult { } List> multiPoseLandmarks = new ArrayList<>(); + for (LandmarkProto.NormalizedLandmarkList handLandmarksProto : landmarksProto) { + List poseLandmarks = + NormalizedLandmark.createListFromProto(handLandmarksProto); + multiPoseLandmarks.add(Collections.unmodifiableList(poseLandmarks)); + } + List> multiPoseWorldLandmarks = new ArrayList<>(); - for (LandmarkProto.NormalizedLandmarkList poseLandmarksProto : landmarksProto) { - List poseLandmarks = new ArrayList<>(); - multiPoseLandmarks.add(poseLandmarks); - for (LandmarkProto.NormalizedLandmark poseLandmarkProto : - poseLandmarksProto.getLandmarkList()) { - poseLandmarks.add( - NormalizedLandmark.create( - poseLandmarkProto.getX(), - poseLandmarkProto.getY(), - poseLandmarkProto.getZ(), - poseLandmarkProto.hasVisibility() - ? Optional.of(poseLandmarkProto.getVisibility()) - : Optional.empty(), - poseLandmarkProto.hasPresence() - ? Optional.of(poseLandmarkProto.getPresence()) - : Optional.empty())); - } - } for (LandmarkProto.LandmarkList poseWorldLandmarksProto : worldLandmarksProto) { - List poseWorldLandmarks = new ArrayList<>(); - multiPoseWorldLandmarks.add(poseWorldLandmarks); - for (LandmarkProto.Landmark poseWorldLandmarkProto : - poseWorldLandmarksProto.getLandmarkList()) { - poseWorldLandmarks.add( - Landmark.create( - poseWorldLandmarkProto.getX(), - poseWorldLandmarkProto.getY(), - poseWorldLandmarkProto.getZ(), - poseWorldLandmarkProto.hasVisibility() - ? Optional.of(poseWorldLandmarkProto.getVisibility()) - : Optional.empty(), - poseWorldLandmarkProto.hasPresence() - ? Optional.of(poseWorldLandmarkProto.getPresence()) - : Optional.empty())); - } + List poseWorldLandmarks = Landmark.createListFromProto(poseWorldLandmarksProto); + multiPoseWorldLandmarks.add(Collections.unmodifiableList(poseWorldLandmarks)); } + return new AutoValue_PoseLandmarkerResult( timestampMs, Collections.unmodifiableList(multiPoseLandmarks), diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/components/containers/LandmarkTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/components/containers/LandmarkTest.java new file mode 100644 index 000000000..b5ff0564a --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/components/containers/LandmarkTest.java @@ -0,0 +1,62 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.components.containers; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import androidx.test.ext.junit.runners.AndroidJUnit4; +import com.google.mediapipe.formats.proto.LandmarkProto; +import java.util.List; +import org.junit.Test; +import org.junit.runner.RunWith; + +@RunWith(AndroidJUnit4.class) +public final class LandmarkTest { + + @Test + public void createFromProto_succeedsWithCoordinates() { + LandmarkProto.Landmark input = + LandmarkProto.Landmark.newBuilder().setX(1.0f).setY(2.0f).setZ(3.0f).build(); + Landmark output = Landmark.createFromProto(input); + assertThat(output.x()).isEqualTo(1.0f); + assertThat(output.y()).isEqualTo(2.0f); + assertThat(output.z()).isEqualTo(3.0f); + assertFalse(output.visibility().isPresent()); + assertFalse(output.presence().isPresent()); + } + + @Test + public void createFromProto_succeedsWithVisibility() { + LandmarkProto.Landmark input = + LandmarkProto.Landmark.newBuilder().setVisibility(0.4f).setPresence(0.5f).build(); + Landmark output = Landmark.createFromProto(input); + assertTrue(output.visibility().isPresent()); + assertThat(output.visibility().get()).isEqualTo(0.4f); + assertTrue(output.presence().isPresent()); + assertThat(output.presence().get()).isEqualTo(0.5f); + } + + @Test + public void createListFromProto_succeeds() { + LandmarkProto.Landmark element = + LandmarkProto.Landmark.newBuilder().setX(1.0f).setY(2.0f).setZ(3.0f).build(); + LandmarkProto.LandmarkList input = + LandmarkProto.LandmarkList.newBuilder().addLandmark(element).build(); + List output = Landmark.createListFromProto(input); + assertThat(output).hasSize(1); + } +} diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/components/containers/NormalizedLandmarkTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/components/containers/NormalizedLandmarkTest.java new file mode 100644 index 000000000..64b61d263 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/components/containers/NormalizedLandmarkTest.java @@ -0,0 +1,62 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.components.containers; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import androidx.test.ext.junit.runners.AndroidJUnit4; +import com.google.mediapipe.formats.proto.LandmarkProto; +import java.util.List; +import org.junit.Test; +import org.junit.runner.RunWith; + +@RunWith(AndroidJUnit4.class) +public final class NormalizedLandmarkTest { + + @Test + public void createFromProto_succeedsWithCoordinates() { + LandmarkProto.NormalizedLandmark input = + LandmarkProto.NormalizedLandmark.newBuilder().setX(0.1f).setY(0.2f).setZ(0.3f).build(); + NormalizedLandmark output = NormalizedLandmark.createFromProto(input); + assertThat(output.x()).isEqualTo(0.1f); + assertThat(output.y()).isEqualTo(0.2f); + assertThat(output.z()).isEqualTo(0.3f); + assertFalse(output.visibility().isPresent()); + assertFalse(output.presence().isPresent()); + } + + @Test + public void createFromProto_succeedsWithVisibility() { + LandmarkProto.NormalizedLandmark input = + LandmarkProto.NormalizedLandmark.newBuilder().setVisibility(0.4f).setPresence(0.5f).build(); + NormalizedLandmark output = NormalizedLandmark.createFromProto(input); + assertTrue(output.visibility().isPresent()); + assertThat(output.visibility().get()).isEqualTo(0.4f); + assertTrue(output.presence().isPresent()); + assertThat(output.presence().get()).isEqualTo(0.5f); + } + + @Test + public void createListFromProto_succeeds() { + LandmarkProto.NormalizedLandmark element = + LandmarkProto.NormalizedLandmark.newBuilder().setX(0.1f).setY(0.2f).setZ(0.3f).build(); + LandmarkProto.NormalizedLandmarkList input = + LandmarkProto.NormalizedLandmarkList.newBuilder().addLandmark(element).build(); + List output = NormalizedLandmark.createListFromProto(input); + assertThat(output).hasSize(1); + } +}