Create shared utilities to construct landmark lists

PiperOrigin-RevId: 581970043
This commit is contained in:
Sebastian Schmidt 2023-11-13 08:20:21 -08:00 committed by Copybara-Service
parent 939a9c2a37
commit d504d3bf22
8 changed files with 226 additions and 100 deletions

View File

@ -96,6 +96,7 @@ android_library(
"-Xep:AndroidJdkLibsChecker:OFF", "-Xep:AndroidJdkLibsChecker:OFF",
], ],
deps = [ deps = [
"//mediapipe/framework/formats:landmark_java_proto_lite",
"//third_party:autovalue", "//third_party:autovalue",
"@maven//:com_google_guava_guava", "@maven//:com_google_guava_guava",
], ],
@ -108,6 +109,7 @@ android_library(
"-Xep:AndroidJdkLibsChecker:OFF", "-Xep:AndroidJdkLibsChecker:OFF",
], ],
deps = [ deps = [
"//mediapipe/framework/formats:landmark_java_proto_lite",
"//third_party:autovalue", "//third_party:autovalue",
"@maven//:com_google_guava_guava", "@maven//:com_google_guava_guava",
], ],

View File

@ -14,7 +14,11 @@
package com.google.mediapipe.tasks.components.containers; package com.google.mediapipe.tasks.components.containers;
import android.annotation.TargetApi;
import com.google.auto.value.AutoValue; import com.google.auto.value.AutoValue;
import com.google.mediapipe.formats.proto.LandmarkProto;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.Optional; import java.util.Optional;
@ -24,18 +28,44 @@ import java.util.Optional;
* is to the camera. * is to the camera.
*/ */
@AutoValue @AutoValue
@TargetApi(31)
public abstract class Landmark { public abstract class Landmark {
private static final float TOLERANCE = 1e-6f; private static final float TOLERANCE = 1e-6f;
/** Creates a landmark from x, y, z coordinates. */
public static Landmark create(float x, float y, float z) { public static Landmark create(float x, float y, float z) {
return new AutoValue_Landmark(x, y, z, Optional.empty(), Optional.empty()); return new AutoValue_Landmark(x, y, z, Optional.empty(), Optional.empty());
} }
/**
* Creates a normalized landmark from x, y, z coordinates with optional visibility and presence.
*/
public static Landmark create( public static Landmark create(
float x, float y, float z, Optional<Float> visibility, Optional<Float> presence) { float x, float y, float z, Optional<Float> visibility, Optional<Float> presence) {
return new AutoValue_Landmark(x, y, z, visibility, presence); return new AutoValue_Landmark(x, y, z, visibility, presence);
} }
/** Creates a landmark from a landmark proto. */
public static Landmark createFromProto(LandmarkProto.Landmark landmarkProto) {
return Landmark.create(
landmarkProto.getX(),
landmarkProto.getY(),
landmarkProto.getZ(),
landmarkProto.hasVisibility()
? Optional.of(landmarkProto.getVisibility())
: Optional.empty(),
landmarkProto.hasPresence() ? Optional.of(landmarkProto.getPresence()) : Optional.empty());
}
/** Creates a list of landmarks from a {@link LandmarkList}. */
public static List<Landmark> createListFromProto(LandmarkProto.LandmarkList landmarkListProto) {
List<Landmark> landmarkList = new ArrayList<>();
for (LandmarkProto.Landmark landmarkProto : landmarkListProto.getLandmarkList()) {
landmarkList.add(createFromProto(landmarkProto));
}
return landmarkList;
}
// The x coordinates of the landmark. // The x coordinates of the landmark.
public abstract float x(); public abstract float x();

View File

@ -14,7 +14,11 @@
package com.google.mediapipe.tasks.components.containers; package com.google.mediapipe.tasks.components.containers;
import android.annotation.TargetApi;
import com.google.auto.value.AutoValue; import com.google.auto.value.AutoValue;
import com.google.mediapipe.formats.proto.LandmarkProto;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.Optional; import java.util.Optional;
@ -25,18 +29,45 @@ import java.util.Optional;
* uses roughly the same scale as x. * uses roughly the same scale as x.
*/ */
@AutoValue @AutoValue
@TargetApi(31)
public abstract class NormalizedLandmark { public abstract class NormalizedLandmark {
private static final float TOLERANCE = 1e-6f; private static final float TOLERANCE = 1e-6f;
/** Creates a normalized landmark from x, y, z coordinates. */
public static NormalizedLandmark create(float x, float y, float z) { public static NormalizedLandmark create(float x, float y, float z) {
return new AutoValue_NormalizedLandmark(x, y, z, Optional.empty(), Optional.empty()); return new AutoValue_NormalizedLandmark(x, y, z, Optional.empty(), Optional.empty());
} }
/**
* Creates a normalized landmark from x, y, z coordinates with optional visibility and presence.
*/
public static NormalizedLandmark create( public static NormalizedLandmark create(
float x, float y, float z, Optional<Float> visibility, Optional<Float> presence) { float x, float y, float z, Optional<Float> visibility, Optional<Float> presence) {
return new AutoValue_NormalizedLandmark(x, y, z, visibility, presence); return new AutoValue_NormalizedLandmark(x, y, z, visibility, presence);
} }
/** Creates a normalized landmark from a normalized landmark proto. */
public static NormalizedLandmark createFromProto(LandmarkProto.NormalizedLandmark landmarkProto) {
return NormalizedLandmark.create(
landmarkProto.getX(),
landmarkProto.getY(),
landmarkProto.getZ(),
landmarkProto.hasVisibility()
? Optional.of(landmarkProto.getVisibility())
: Optional.empty(),
landmarkProto.hasPresence() ? Optional.of(landmarkProto.getPresence()) : Optional.empty());
}
/** Creates a list of normalized landmarks from a {@link NormalizedLandmarkList}. */
public static List<NormalizedLandmark> createListFromProto(
LandmarkProto.NormalizedLandmarkList landmarkListProto) {
List<NormalizedLandmark> landmarkList = new ArrayList<>();
for (LandmarkProto.NormalizedLandmark landmarkProto : landmarkListProto.getLandmarkList()) {
landmarkList.add(createFromProto(landmarkProto));
}
return landmarkList;
}
// The x coordinates of the normalized landmark. // The x coordinates of the normalized landmark.
public abstract float x(); public abstract float x();

View File

@ -46,23 +46,11 @@ public abstract class FaceLandmarkerResult implements TaskResult {
long timestampMs) { long timestampMs) {
List<List<NormalizedLandmark>> multiFaceLandmarks = new ArrayList<>(); List<List<NormalizedLandmark>> multiFaceLandmarks = new ArrayList<>();
for (LandmarkProto.NormalizedLandmarkList faceLandmarksProto : multiFaceLandmarksProto) { for (LandmarkProto.NormalizedLandmarkList faceLandmarksProto : multiFaceLandmarksProto) {
List<NormalizedLandmark> faceLandmarks = new ArrayList<>(); List<NormalizedLandmark> faceLandmarks =
multiFaceLandmarks.add(faceLandmarks); NormalizedLandmark.createListFromProto(faceLandmarksProto);
for (LandmarkProto.NormalizedLandmark faceLandmarkProto : multiFaceLandmarks.add(Collections.unmodifiableList(faceLandmarks));
faceLandmarksProto.getLandmarkList()) {
faceLandmarks.add(
NormalizedLandmark.create(
faceLandmarkProto.getX(),
faceLandmarkProto.getY(),
faceLandmarkProto.getZ(),
faceLandmarkProto.hasVisibility()
? Optional.of(faceLandmarkProto.getVisibility())
: Optional.empty(),
faceLandmarkProto.hasPresence()
? Optional.of(faceLandmarkProto.getPresence())
: Optional.empty()));
}
} }
Optional<List<List<Category>>> multiFaceBlendshapes = Optional.empty(); Optional<List<List<Category>>> multiFaceBlendshapes = Optional.empty();
if (multiFaceBendshapesProto.isPresent()) { if (multiFaceBendshapesProto.isPresent()) {
List<List<Category>> blendshapes = new ArrayList<>(); List<List<Category>> blendshapes = new ArrayList<>();
@ -72,6 +60,7 @@ public abstract class FaceLandmarkerResult implements TaskResult {
} }
multiFaceBlendshapes = Optional.of(Collections.unmodifiableList(blendshapes)); multiFaceBlendshapes = Optional.of(Collections.unmodifiableList(blendshapes));
} }
Optional<List<float[]>> multiFaceTransformationMatrixes = Optional.empty(); Optional<List<float[]>> multiFaceTransformationMatrixes = Optional.empty();
if (multiFaceTransformationMatrixesProto.isPresent()) { if (multiFaceTransformationMatrixesProto.isPresent()) {
List<float[]> matrixes = new ArrayList<>(); List<float[]> matrixes = new ArrayList<>();
@ -90,6 +79,7 @@ public abstract class FaceLandmarkerResult implements TaskResult {
} }
multiFaceTransformationMatrixes = Optional.of(Collections.unmodifiableList(matrixes)); multiFaceTransformationMatrixes = Optional.of(Collections.unmodifiableList(matrixes));
} }
return new AutoValue_FaceLandmarkerResult( return new AutoValue_FaceLandmarkerResult(
timestampMs, timestampMs,
Collections.unmodifiableList(multiFaceLandmarks), Collections.unmodifiableList(multiFaceLandmarks),

View File

@ -24,7 +24,6 @@ import com.google.mediapipe.tasks.core.TaskResult;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Optional;
/** Represents the hand landmarks deection results generated by {@link HandLandmarker}. */ /** Represents the hand landmarks deection results generated by {@link HandLandmarker}. */
@AutoValue @AutoValue
@ -34,63 +33,38 @@ public abstract class HandLandmarkerResult implements TaskResult {
* Creates a {@link HandLandmarkerResult} instance from the lists of landmarks and handedness * Creates a {@link HandLandmarkerResult} instance from the lists of landmarks and handedness
* protobuf messages. * protobuf messages.
* *
* @param landmarksProto a List of {@link NormalizedLandmarkList} * @param landmarksProtos a List of {@link NormalizedLandmarkList}
* @param worldLandmarksProto a List of {@link LandmarkList} * @param worldLandmarksProtos a List of {@link LandmarkList}
* @param handednessesProto a List of {@link ClassificationList} * @param handednessesProtos a List of {@link ClassificationList}
*/ */
static HandLandmarkerResult create( static HandLandmarkerResult create(
List<LandmarkProto.NormalizedLandmarkList> landmarksProto, List<LandmarkProto.NormalizedLandmarkList> landmarksProtos,
List<LandmarkProto.LandmarkList> worldLandmarksProto, List<LandmarkProto.LandmarkList> worldLandmarksProtos,
List<ClassificationList> handednessesProto, List<ClassificationList> handednessesProtos,
long timestampMs) { long timestampMs) {
List<List<NormalizedLandmark>> multiHandLandmarks = new ArrayList<>(); List<List<NormalizedLandmark>> handLandmarks = new ArrayList<>();
List<List<Landmark>> multiHandWorldLandmarks = new ArrayList<>(); for (LandmarkProto.NormalizedLandmarkList handLandmarksProto : landmarksProtos) {
List<List<Category>> multiHandHandednesses = new ArrayList<>();
for (LandmarkProto.NormalizedLandmarkList handLandmarksProto : landmarksProto) {
List<NormalizedLandmark> handLandmarks = new ArrayList<>();
multiHandLandmarks.add(handLandmarks);
for (LandmarkProto.NormalizedLandmark handLandmarkProto :
handLandmarksProto.getLandmarkList()) {
handLandmarks.add( handLandmarks.add(
NormalizedLandmark.create( Collections.unmodifiableList(NormalizedLandmark.createListFromProto(handLandmarksProto)));
handLandmarkProto.getX(),
handLandmarkProto.getY(),
handLandmarkProto.getZ(),
handLandmarkProto.hasVisibility()
? Optional.of(handLandmarkProto.getVisibility())
: Optional.empty(),
handLandmarkProto.hasPresence()
? Optional.of(handLandmarkProto.getPresence())
: Optional.empty()));
} }
}
for (LandmarkProto.LandmarkList handWorldLandmarksProto : worldLandmarksProto) { List<List<Landmark>> handWorldLandmarks = new ArrayList<>();
List<Landmark> handWorldLandmarks = new ArrayList<>(); for (LandmarkProto.LandmarkList handWorldLandmarksProto : worldLandmarksProtos) {
multiHandWorldLandmarks.add(handWorldLandmarks);
for (LandmarkProto.Landmark handWorldLandmarkProto :
handWorldLandmarksProto.getLandmarkList()) {
handWorldLandmarks.add( handWorldLandmarks.add(
com.google.mediapipe.tasks.components.containers.Landmark.create( Collections.unmodifiableList(Landmark.createListFromProto(handWorldLandmarksProto)));
handWorldLandmarkProto.getX(),
handWorldLandmarkProto.getY(),
handWorldLandmarkProto.getZ(),
handWorldLandmarkProto.hasVisibility()
? Optional.of(handWorldLandmarkProto.getVisibility())
: Optional.empty(),
handWorldLandmarkProto.hasPresence()
? Optional.of(handWorldLandmarkProto.getPresence())
: Optional.empty()));
} }
List<List<Category>> handHandednesses = new ArrayList<>();
for (ClassificationList handednessProto : handednessesProtos) {
handHandednesses.add(
Collections.unmodifiableList(Category.createListFromProto(handednessProto)));
} }
for (ClassificationList handednessProto : handednessesProto) {
List<Category> handedness = Category.createListFromProto(handednessProto);
multiHandHandednesses.add(Collections.unmodifiableList(handedness));
}
return new AutoValue_HandLandmarkerResult( return new AutoValue_HandLandmarkerResult(
timestampMs, timestampMs,
Collections.unmodifiableList(multiHandLandmarks), Collections.unmodifiableList(handLandmarks),
Collections.unmodifiableList(multiHandWorldLandmarks), Collections.unmodifiableList(handWorldLandmarks),
Collections.unmodifiableList(multiHandHandednesses)); Collections.unmodifiableList(handHandednesses));
} }
@Override @Override

View File

@ -50,43 +50,18 @@ public abstract class PoseLandmarkerResult implements TaskResult {
} }
List<List<NormalizedLandmark>> multiPoseLandmarks = new ArrayList<>(); List<List<NormalizedLandmark>> multiPoseLandmarks = new ArrayList<>();
for (LandmarkProto.NormalizedLandmarkList handLandmarksProto : landmarksProto) {
List<NormalizedLandmark> poseLandmarks =
NormalizedLandmark.createListFromProto(handLandmarksProto);
multiPoseLandmarks.add(Collections.unmodifiableList(poseLandmarks));
}
List<List<Landmark>> multiPoseWorldLandmarks = new ArrayList<>(); List<List<Landmark>> multiPoseWorldLandmarks = new ArrayList<>();
for (LandmarkProto.NormalizedLandmarkList poseLandmarksProto : landmarksProto) {
List<NormalizedLandmark> poseLandmarks = new ArrayList<>();
multiPoseLandmarks.add(poseLandmarks);
for (LandmarkProto.NormalizedLandmark poseLandmarkProto :
poseLandmarksProto.getLandmarkList()) {
poseLandmarks.add(
NormalizedLandmark.create(
poseLandmarkProto.getX(),
poseLandmarkProto.getY(),
poseLandmarkProto.getZ(),
poseLandmarkProto.hasVisibility()
? Optional.of(poseLandmarkProto.getVisibility())
: Optional.empty(),
poseLandmarkProto.hasPresence()
? Optional.of(poseLandmarkProto.getPresence())
: Optional.empty()));
}
}
for (LandmarkProto.LandmarkList poseWorldLandmarksProto : worldLandmarksProto) { for (LandmarkProto.LandmarkList poseWorldLandmarksProto : worldLandmarksProto) {
List<Landmark> poseWorldLandmarks = new ArrayList<>(); List<Landmark> poseWorldLandmarks = Landmark.createListFromProto(poseWorldLandmarksProto);
multiPoseWorldLandmarks.add(poseWorldLandmarks); multiPoseWorldLandmarks.add(Collections.unmodifiableList(poseWorldLandmarks));
for (LandmarkProto.Landmark poseWorldLandmarkProto :
poseWorldLandmarksProto.getLandmarkList()) {
poseWorldLandmarks.add(
Landmark.create(
poseWorldLandmarkProto.getX(),
poseWorldLandmarkProto.getY(),
poseWorldLandmarkProto.getZ(),
poseWorldLandmarkProto.hasVisibility()
? Optional.of(poseWorldLandmarkProto.getVisibility())
: Optional.empty(),
poseWorldLandmarkProto.hasPresence()
? Optional.of(poseWorldLandmarkProto.getPresence())
: Optional.empty()));
}
} }
return new AutoValue_PoseLandmarkerResult( return new AutoValue_PoseLandmarkerResult(
timestampMs, timestampMs,
Collections.unmodifiableList(multiPoseLandmarks), Collections.unmodifiableList(multiPoseLandmarks),

View File

@ -0,0 +1,62 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.components.containers;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import com.google.mediapipe.formats.proto.LandmarkProto;
import java.util.List;
import org.junit.Test;
import org.junit.runner.RunWith;
@RunWith(AndroidJUnit4.class)
public final class LandmarkTest {
@Test
public void createFromProto_succeedsWithCoordinates() {
LandmarkProto.Landmark input =
LandmarkProto.Landmark.newBuilder().setX(1.0f).setY(2.0f).setZ(3.0f).build();
Landmark output = Landmark.createFromProto(input);
assertThat(output.x()).isEqualTo(1.0f);
assertThat(output.y()).isEqualTo(2.0f);
assertThat(output.z()).isEqualTo(3.0f);
assertFalse(output.visibility().isPresent());
assertFalse(output.presence().isPresent());
}
@Test
public void createFromProto_succeedsWithVisibility() {
LandmarkProto.Landmark input =
LandmarkProto.Landmark.newBuilder().setVisibility(0.4f).setPresence(0.5f).build();
Landmark output = Landmark.createFromProto(input);
assertTrue(output.visibility().isPresent());
assertThat(output.visibility().get()).isEqualTo(0.4f);
assertTrue(output.presence().isPresent());
assertThat(output.presence().get()).isEqualTo(0.5f);
}
@Test
public void createListFromProto_succeeds() {
LandmarkProto.Landmark element =
LandmarkProto.Landmark.newBuilder().setX(1.0f).setY(2.0f).setZ(3.0f).build();
LandmarkProto.LandmarkList input =
LandmarkProto.LandmarkList.newBuilder().addLandmark(element).build();
List<Landmark> output = Landmark.createListFromProto(input);
assertThat(output).hasSize(1);
}
}

View File

@ -0,0 +1,62 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.components.containers;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import com.google.mediapipe.formats.proto.LandmarkProto;
import java.util.List;
import org.junit.Test;
import org.junit.runner.RunWith;
@RunWith(AndroidJUnit4.class)
public final class NormalizedLandmarkTest {
@Test
public void createFromProto_succeedsWithCoordinates() {
LandmarkProto.NormalizedLandmark input =
LandmarkProto.NormalizedLandmark.newBuilder().setX(0.1f).setY(0.2f).setZ(0.3f).build();
NormalizedLandmark output = NormalizedLandmark.createFromProto(input);
assertThat(output.x()).isEqualTo(0.1f);
assertThat(output.y()).isEqualTo(0.2f);
assertThat(output.z()).isEqualTo(0.3f);
assertFalse(output.visibility().isPresent());
assertFalse(output.presence().isPresent());
}
@Test
public void createFromProto_succeedsWithVisibility() {
LandmarkProto.NormalizedLandmark input =
LandmarkProto.NormalizedLandmark.newBuilder().setVisibility(0.4f).setPresence(0.5f).build();
NormalizedLandmark output = NormalizedLandmark.createFromProto(input);
assertTrue(output.visibility().isPresent());
assertThat(output.visibility().get()).isEqualTo(0.4f);
assertTrue(output.presence().isPresent());
assertThat(output.presence().get()).isEqualTo(0.5f);
}
@Test
public void createListFromProto_succeeds() {
LandmarkProto.NormalizedLandmark element =
LandmarkProto.NormalizedLandmark.newBuilder().setX(0.1f).setY(0.2f).setZ(0.3f).build();
LandmarkProto.NormalizedLandmarkList input =
LandmarkProto.NormalizedLandmarkList.newBuilder().addLandmark(element).build();
List<NormalizedLandmark> output = NormalizedLandmark.createListFromProto(input);
assertThat(output).hasSize(1);
}
}