Create shared utilities to construct category lists

PiperOrigin-RevId: 581009898
This commit is contained in:
Sebastian Schmidt 2023-11-09 13:34:48 -08:00 committed by Copybara-Service
parent 6532ce5c59
commit edca85c5d3
9 changed files with 122 additions and 43 deletions

View File

@ -16,6 +16,10 @@ package com.google.mediapipe.tasks.components.containers;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.formats.proto.ClassificationProto;
import com.google.mediapipe.formats.proto.ClassificationProto.Classification;
import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
/**
@ -49,6 +53,22 @@ public abstract class Category {
return create(proto.getScore(), proto.getIndex(), proto.getLabel(), proto.getDisplayName());
}
/**
* Creates a list of {@link Category} objects from a {@link
* ClassificationProto.ClassificationList}.
*
* @param classificationListProto the {@link ClassificationProto.ClassificationList} protobuf
* message to convert.
* @return A list of {@link Category} objects.
*/
public static List<Category> createListFromProto(ClassificationList classificationListProto) {
List<Category> categoryList = new ArrayList<>();
for (Classification classification : classificationListProto.getClassificationList()) {
categoryList.add(createFromProto(classification));
}
return categoryList;
}
/** The probability score of this label category. */
public abstract float score();

View File

@ -15,9 +15,7 @@
package com.google.mediapipe.tasks.components.containers;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.formats.proto.ClassificationProto;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
@ -49,11 +47,7 @@ public abstract class Classifications {
* @param proto the {@link ClassificationsProto.Classifications} protobuf message to convert.
*/
public static Classifications createFromProto(ClassificationsProto.Classifications proto) {
List<Category> categories = new ArrayList<>();
for (ClassificationProto.Classification classificationProto :
proto.getClassificationList().getClassificationList()) {
categories.add(Category.createFromProto(classificationProto));
}
List<Category> categories = Category.createListFromProto(proto.getClassificationList());
Optional<String> headName =
proto.hasHeadName() ? Optional.of(proto.getHeadName()) : Optional.empty();
return create(categories, proto.getHeadIndex(), headName);

View File

@ -208,11 +208,9 @@ android_library(
deps = [
":core",
"//mediapipe/framework:calculator_options_java_proto_lite",
"//mediapipe/framework/formats:classification_java_proto_lite",
"//mediapipe/framework/formats:landmark_java_proto_lite",
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarker_graph_options_java_proto_lite",
@ -222,7 +220,6 @@ android_library(
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:normalized_landmark",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//third_party:autovalue",
"@maven//:androidx_annotation_annotation",
"@maven//:com_google_guava_guava",
],
)
@ -246,7 +243,6 @@ android_library(
"//mediapipe/framework/formats:landmark_java_proto_lite",
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite",

View File

@ -16,7 +16,6 @@ package com.google.mediapipe.tasks.vision.facelandmarker;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.formats.proto.LandmarkProto;
import com.google.mediapipe.formats.proto.ClassificationProto.Classification;
import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList;
import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.components.containers.NormalizedLandmark;
@ -68,16 +67,8 @@ public abstract class FaceLandmarkerResult implements TaskResult {
if (multiFaceBendshapesProto.isPresent()) {
List<List<Category>> blendshapes = new ArrayList<>();
for (ClassificationList faceBendshapeProto : multiFaceBendshapesProto.get()) {
List<Category> blendshape = new ArrayList<>();
blendshapes.add(blendshape);
for (Classification classification : faceBendshapeProto.getClassificationList()) {
blendshape.add(
Category.create(
classification.getScore(),
classification.getIndex(),
classification.getLabel(),
classification.getDisplayName()));
}
List<Category> blendshape = Category.createListFromProto(faceBendshapeProto);
blendshapes.add(Collections.unmodifiableList(blendshape));
}
multiFaceBlendshapes = Optional.of(Collections.unmodifiableList(blendshapes));
}

View File

@ -75,16 +75,8 @@ public abstract class GestureRecognizerResult implements TaskResult {
}
}
for (ClassificationList handednessProto : handednessesProto) {
List<Category> handedness = new ArrayList<>();
multiHandHandednesses.add(handedness);
for (Classification classification : handednessProto.getClassificationList()) {
handedness.add(
Category.create(
classification.getScore(),
classification.getIndex(),
classification.getLabel(),
classification.getDisplayName()));
}
List<Category> handedness = Category.createListFromProto(handednessProto);
multiHandHandednesses.add(Collections.unmodifiableList(handedness));
}
for (ClassificationList gestureProto : gesturesProto) {
List<Category> gestures = new ArrayList<>();

View File

@ -16,7 +16,6 @@ package com.google.mediapipe.tasks.vision.handlandmarker;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.formats.proto.LandmarkProto;
import com.google.mediapipe.formats.proto.ClassificationProto.Classification;
import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList;
import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.components.containers.Landmark;
@ -84,16 +83,8 @@ public abstract class HandLandmarkerResult implements TaskResult {
}
}
for (ClassificationList handednessProto : handednessesProto) {
List<Category> handedness = new ArrayList<>();
multiHandHandednesses.add(handedness);
for (Classification classification : handednessProto.getClassificationList()) {
handedness.add(
Category.create(
classification.getScore(),
classification.getIndex(),
classification.getLabel(),
classification.getDisplayName()));
}
List<Category> handedness = Category.createListFromProto(handednessProto);
multiHandHandednesses.add(Collections.unmodifiableList(handedness));
}
return new AutoValue_HandLandmarkerResult(
timestampMs,

View File

@ -0,0 +1,24 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.google.mediapipe.tasks.components.containerstest"
android:versionCode="1"
android:versionName="1.0" >
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
<uses-sdk android:minSdkVersion="24"
android:targetSdkVersion="31" />
<application
android:label="utilstest"
android:name="android.support.multidex.MultiDexApplication"
android:taskAffinity="">
<uses-library android:name="android.test.runner" />
</application>
<instrumentation
android:name="com.google.android.apps.common.testing.testrunner.GoogleInstrumentationTestRunner"
android:targetPackage="com.google.mediapipe.tasks.components.containerstest" />
</manifest>

View File

@ -0,0 +1,19 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
# TODO: Enable these tests in OSS

View File

@ -0,0 +1,52 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.components.containers;
import static com.google.common.truth.Truth.assertThat;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import com.google.mediapipe.formats.proto.ClassificationProto.Classification;
import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList;
import java.util.List;
import org.junit.Test;
import org.junit.runner.RunWith;
@RunWith(AndroidJUnit4.class)
public final class CategoryTest {
@Test
public void create_succeedsWithClassificationProto() {
Classification input =
Classification.newBuilder()
.setScore(0.1f)
.setIndex(1)
.setLabel("label")
.setDisplayName("displayName")
.build();
Category output = Category.createFromProto(input);
assertThat(output.score()).isEqualTo(0.1f);
assertThat(output.index()).isEqualTo(1);
assertThat(output.categoryName()).isEqualTo("label");
assertThat(output.displayName()).isEqualTo("displayName");
}
@Test
public void create_succeedsWithClassificationListProto() {
Classification element = Classification.newBuilder().setScore(0.1f).build();
ClassificationList input = ClassificationList.newBuilder().addClassification(element).build();
List<Category> output = Category.createListFromProto(input);
assertThat(output).containsExactly(Category.create(0.1f, 0, "", ""));
}
}