From edca85c5d3fa4ef848340fcc9f88e9d95db05688 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Thu, 9 Nov 2023 13:34:48 -0800 Subject: [PATCH] Create shared utilities to construct category lists PiperOrigin-RevId: 581009898 --- .../tasks/components/containers/Category.java | 20 +++++++ .../containers/Classifications.java | 8 +-- .../com/google/mediapipe/tasks/vision/BUILD | 4 -- .../facelandmarker/FaceLandmarkerResult.java | 13 +---- .../GestureRecognizerResult.java | 12 +---- .../handlandmarker/HandLandmarkerResult.java | 13 +---- .../components/containers/AndroidManifest.xml | 24 +++++++++ .../tasks/components/containers/BUILD | 19 +++++++ .../components/containers/CategoryTest.java | 52 +++++++++++++++++++ 9 files changed, 122 insertions(+), 43 deletions(-) create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/components/containers/AndroidManifest.xml create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/components/containers/BUILD create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/components/containers/CategoryTest.java diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Category.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Category.java index 65996c2af..916ad1bed 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Category.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Category.java @@ -16,6 +16,10 @@ package com.google.mediapipe.tasks.components.containers; import com.google.auto.value.AutoValue; import com.google.mediapipe.formats.proto.ClassificationProto; +import com.google.mediapipe.formats.proto.ClassificationProto.Classification; +import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList; +import java.util.ArrayList; +import java.util.List; import java.util.Objects; /** @@ -49,6 +53,22 @@ public abstract class Category { return create(proto.getScore(), proto.getIndex(), proto.getLabel(), proto.getDisplayName()); } + /** + * Creates a list of {@link Category} objects from a {@link + * ClassificationProto.ClassificationList}. + * + * @param classificationListProto the {@link ClassificationProto.ClassificationList} protobuf + * message to convert. + * @return A list of {@link Category} objects. + */ + public static List createListFromProto(ClassificationList classificationListProto) { + List categoryList = new ArrayList<>(); + for (Classification classification : classificationListProto.getClassificationList()) { + categoryList.add(createFromProto(classification)); + } + return categoryList; + } + /** The probability score of this label category. */ public abstract float score(); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Classifications.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Classifications.java index 9e53590d7..7c2a1fc21 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Classifications.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Classifications.java @@ -15,9 +15,7 @@ package com.google.mediapipe.tasks.components.containers; import com.google.auto.value.AutoValue; -import com.google.mediapipe.formats.proto.ClassificationProto; import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; -import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Optional; @@ -49,11 +47,7 @@ public abstract class Classifications { * @param proto the {@link ClassificationsProto.Classifications} protobuf message to convert. */ public static Classifications createFromProto(ClassificationsProto.Classifications proto) { - List categories = new ArrayList<>(); - for (ClassificationProto.Classification classificationProto : - proto.getClassificationList().getClassificationList()) { - categories.add(Category.createFromProto(classificationProto)); - } + List categories = Category.createListFromProto(proto.getClassificationList()); Optional headName = proto.hasHeadName() ? Optional.of(proto.getHeadName()) : Optional.empty(); return create(categories, proto.getHeadIndex(), headName); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index 60a9806e9..181b45dc8 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -208,11 +208,9 @@ android_library( deps = [ ":core", "//mediapipe/framework:calculator_options_java_proto_lite", - "//mediapipe/framework/formats:classification_java_proto_lite", "//mediapipe/framework/formats:landmark_java_proto_lite", "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/framework/image", - "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarker_graph_options_java_proto_lite", @@ -222,7 +220,6 @@ android_library( "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:normalized_landmark", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//third_party:autovalue", - "@maven//:androidx_annotation_annotation", "@maven//:com_google_guava_guava", ], ) @@ -246,7 +243,6 @@ android_library( "//mediapipe/framework/formats:landmark_java_proto_lite", "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/framework/image", - "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarkerResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarkerResult.java index 0429ecacb..98fb9376f 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarkerResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarkerResult.java @@ -16,7 +16,6 @@ package com.google.mediapipe.tasks.vision.facelandmarker; import com.google.auto.value.AutoValue; import com.google.mediapipe.formats.proto.LandmarkProto; -import com.google.mediapipe.formats.proto.ClassificationProto.Classification; import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList; import com.google.mediapipe.tasks.components.containers.Category; import com.google.mediapipe.tasks.components.containers.NormalizedLandmark; @@ -68,16 +67,8 @@ public abstract class FaceLandmarkerResult implements TaskResult { if (multiFaceBendshapesProto.isPresent()) { List> blendshapes = new ArrayList<>(); for (ClassificationList faceBendshapeProto : multiFaceBendshapesProto.get()) { - List blendshape = new ArrayList<>(); - blendshapes.add(blendshape); - for (Classification classification : faceBendshapeProto.getClassificationList()) { - blendshape.add( - Category.create( - classification.getScore(), - classification.getIndex(), - classification.getLabel(), - classification.getDisplayName())); - } + List blendshape = Category.createListFromProto(faceBendshapeProto); + blendshapes.add(Collections.unmodifiableList(blendshape)); } multiFaceBlendshapes = Optional.of(Collections.unmodifiableList(blendshapes)); } diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerResult.java index c8d43e2ca..09ceac215 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerResult.java @@ -75,16 +75,8 @@ public abstract class GestureRecognizerResult implements TaskResult { } } for (ClassificationList handednessProto : handednessesProto) { - List handedness = new ArrayList<>(); - multiHandHandednesses.add(handedness); - for (Classification classification : handednessProto.getClassificationList()) { - handedness.add( - Category.create( - classification.getScore(), - classification.getIndex(), - classification.getLabel(), - classification.getDisplayName())); - } + List handedness = Category.createListFromProto(handednessProto); + multiHandHandednesses.add(Collections.unmodifiableList(handedness)); } for (ClassificationList gestureProto : gesturesProto) { List gestures = new ArrayList<>(); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerResult.java index 14d2fa926..54ed04848 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerResult.java @@ -16,7 +16,6 @@ package com.google.mediapipe.tasks.vision.handlandmarker; import com.google.auto.value.AutoValue; import com.google.mediapipe.formats.proto.LandmarkProto; -import com.google.mediapipe.formats.proto.ClassificationProto.Classification; import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList; import com.google.mediapipe.tasks.components.containers.Category; import com.google.mediapipe.tasks.components.containers.Landmark; @@ -84,16 +83,8 @@ public abstract class HandLandmarkerResult implements TaskResult { } } for (ClassificationList handednessProto : handednessesProto) { - List handedness = new ArrayList<>(); - multiHandHandednesses.add(handedness); - for (Classification classification : handednessProto.getClassificationList()) { - handedness.add( - Category.create( - classification.getScore(), - classification.getIndex(), - classification.getLabel(), - classification.getDisplayName())); - } + List handedness = Category.createListFromProto(handednessProto); + multiHandHandednesses.add(Collections.unmodifiableList(handedness)); } return new AutoValue_HandLandmarkerResult( timestampMs, diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/components/containers/AndroidManifest.xml b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/components/containers/AndroidManifest.xml new file mode 100644 index 000000000..4a6416933 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/components/containers/AndroidManifest.xml @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/components/containers/BUILD b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/components/containers/BUILD new file mode 100644 index 000000000..7363a23e0 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/components/containers/BUILD @@ -0,0 +1,19 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +# TODO: Enable these tests in OSS diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/components/containers/CategoryTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/components/containers/CategoryTest.java new file mode 100644 index 000000000..ed501ac57 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/components/containers/CategoryTest.java @@ -0,0 +1,52 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.components.containers; + +import static com.google.common.truth.Truth.assertThat; + +import androidx.test.ext.junit.runners.AndroidJUnit4; +import com.google.mediapipe.formats.proto.ClassificationProto.Classification; +import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList; +import java.util.List; +import org.junit.Test; +import org.junit.runner.RunWith; + +@RunWith(AndroidJUnit4.class) +public final class CategoryTest { + + @Test + public void create_succeedsWithClassificationProto() { + Classification input = + Classification.newBuilder() + .setScore(0.1f) + .setIndex(1) + .setLabel("label") + .setDisplayName("displayName") + .build(); + Category output = Category.createFromProto(input); + assertThat(output.score()).isEqualTo(0.1f); + assertThat(output.index()).isEqualTo(1); + assertThat(output.categoryName()).isEqualTo("label"); + assertThat(output.displayName()).isEqualTo("displayName"); + } + + @Test + public void create_succeedsWithClassificationListProto() { + Classification element = Classification.newBuilder().setScore(0.1f).build(); + ClassificationList input = ClassificationList.newBuilder().addClassification(element).build(); + List output = Category.createListFromProto(input); + assertThat(output).containsExactly(Category.create(0.1f, 0, "", "")); + } +}