Create shared utilities to construct category lists
PiperOrigin-RevId: 581009898
This commit is contained in:
parent
6532ce5c59
commit
edca85c5d3
|
@ -16,6 +16,10 @@ package com.google.mediapipe.tasks.components.containers;
|
|||
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.formats.proto.ClassificationProto;
|
||||
import com.google.mediapipe.formats.proto.ClassificationProto.Classification;
|
||||
import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
|
@ -49,6 +53,22 @@ public abstract class Category {
|
|||
return create(proto.getScore(), proto.getIndex(), proto.getLabel(), proto.getDisplayName());
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a list of {@link Category} objects from a {@link
|
||||
* ClassificationProto.ClassificationList}.
|
||||
*
|
||||
* @param classificationListProto the {@link ClassificationProto.ClassificationList} protobuf
|
||||
* message to convert.
|
||||
* @return A list of {@link Category} objects.
|
||||
*/
|
||||
public static List<Category> createListFromProto(ClassificationList classificationListProto) {
|
||||
List<Category> categoryList = new ArrayList<>();
|
||||
for (Classification classification : classificationListProto.getClassificationList()) {
|
||||
categoryList.add(createFromProto(classification));
|
||||
}
|
||||
return categoryList;
|
||||
}
|
||||
|
||||
/** The probability score of this label category. */
|
||||
public abstract float score();
|
||||
|
||||
|
|
|
@ -15,9 +15,7 @@
|
|||
package com.google.mediapipe.tasks.components.containers;
|
||||
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.formats.proto.ClassificationProto;
|
||||
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
@ -49,11 +47,7 @@ public abstract class Classifications {
|
|||
* @param proto the {@link ClassificationsProto.Classifications} protobuf message to convert.
|
||||
*/
|
||||
public static Classifications createFromProto(ClassificationsProto.Classifications proto) {
|
||||
List<Category> categories = new ArrayList<>();
|
||||
for (ClassificationProto.Classification classificationProto :
|
||||
proto.getClassificationList().getClassificationList()) {
|
||||
categories.add(Category.createFromProto(classificationProto));
|
||||
}
|
||||
List<Category> categories = Category.createListFromProto(proto.getClassificationList());
|
||||
Optional<String> headName =
|
||||
proto.hasHeadName() ? Optional.of(proto.getHeadName()) : Optional.empty();
|
||||
return create(categories, proto.getHeadIndex(), headName);
|
||||
|
|
|
@ -208,11 +208,9 @@ android_library(
|
|||
deps = [
|
||||
":core",
|
||||
"//mediapipe/framework:calculator_options_java_proto_lite",
|
||||
"//mediapipe/framework/formats:classification_java_proto_lite",
|
||||
"//mediapipe/framework/formats:landmark_java_proto_lite",
|
||||
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
||||
"//mediapipe/java/com/google/mediapipe/framework/image",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarker_graph_options_java_proto_lite",
|
||||
|
@ -222,7 +220,6 @@ android_library(
|
|||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:normalized_landmark",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
||||
"//third_party:autovalue",
|
||||
"@maven//:androidx_annotation_annotation",
|
||||
"@maven//:com_google_guava_guava",
|
||||
],
|
||||
)
|
||||
|
@ -246,7 +243,6 @@ android_library(
|
|||
"//mediapipe/framework/formats:landmark_java_proto_lite",
|
||||
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
||||
"//mediapipe/java/com/google/mediapipe/framework/image",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite",
|
||||
|
|
|
@ -16,7 +16,6 @@ package com.google.mediapipe.tasks.vision.facelandmarker;
|
|||
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.formats.proto.LandmarkProto;
|
||||
import com.google.mediapipe.formats.proto.ClassificationProto.Classification;
|
||||
import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList;
|
||||
import com.google.mediapipe.tasks.components.containers.Category;
|
||||
import com.google.mediapipe.tasks.components.containers.NormalizedLandmark;
|
||||
|
@ -68,16 +67,8 @@ public abstract class FaceLandmarkerResult implements TaskResult {
|
|||
if (multiFaceBendshapesProto.isPresent()) {
|
||||
List<List<Category>> blendshapes = new ArrayList<>();
|
||||
for (ClassificationList faceBendshapeProto : multiFaceBendshapesProto.get()) {
|
||||
List<Category> blendshape = new ArrayList<>();
|
||||
blendshapes.add(blendshape);
|
||||
for (Classification classification : faceBendshapeProto.getClassificationList()) {
|
||||
blendshape.add(
|
||||
Category.create(
|
||||
classification.getScore(),
|
||||
classification.getIndex(),
|
||||
classification.getLabel(),
|
||||
classification.getDisplayName()));
|
||||
}
|
||||
List<Category> blendshape = Category.createListFromProto(faceBendshapeProto);
|
||||
blendshapes.add(Collections.unmodifiableList(blendshape));
|
||||
}
|
||||
multiFaceBlendshapes = Optional.of(Collections.unmodifiableList(blendshapes));
|
||||
}
|
||||
|
|
|
@ -75,16 +75,8 @@ public abstract class GestureRecognizerResult implements TaskResult {
|
|||
}
|
||||
}
|
||||
for (ClassificationList handednessProto : handednessesProto) {
|
||||
List<Category> handedness = new ArrayList<>();
|
||||
multiHandHandednesses.add(handedness);
|
||||
for (Classification classification : handednessProto.getClassificationList()) {
|
||||
handedness.add(
|
||||
Category.create(
|
||||
classification.getScore(),
|
||||
classification.getIndex(),
|
||||
classification.getLabel(),
|
||||
classification.getDisplayName()));
|
||||
}
|
||||
List<Category> handedness = Category.createListFromProto(handednessProto);
|
||||
multiHandHandednesses.add(Collections.unmodifiableList(handedness));
|
||||
}
|
||||
for (ClassificationList gestureProto : gesturesProto) {
|
||||
List<Category> gestures = new ArrayList<>();
|
||||
|
|
|
@ -16,7 +16,6 @@ package com.google.mediapipe.tasks.vision.handlandmarker;
|
|||
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.formats.proto.LandmarkProto;
|
||||
import com.google.mediapipe.formats.proto.ClassificationProto.Classification;
|
||||
import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList;
|
||||
import com.google.mediapipe.tasks.components.containers.Category;
|
||||
import com.google.mediapipe.tasks.components.containers.Landmark;
|
||||
|
@ -84,16 +83,8 @@ public abstract class HandLandmarkerResult implements TaskResult {
|
|||
}
|
||||
}
|
||||
for (ClassificationList handednessProto : handednessesProto) {
|
||||
List<Category> handedness = new ArrayList<>();
|
||||
multiHandHandednesses.add(handedness);
|
||||
for (Classification classification : handednessProto.getClassificationList()) {
|
||||
handedness.add(
|
||||
Category.create(
|
||||
classification.getScore(),
|
||||
classification.getIndex(),
|
||||
classification.getLabel(),
|
||||
classification.getDisplayName()));
|
||||
}
|
||||
List<Category> handedness = Category.createListFromProto(handednessProto);
|
||||
multiHandHandednesses.add(Collections.unmodifiableList(handedness));
|
||||
}
|
||||
return new AutoValue_HandLandmarkerResult(
|
||||
timestampMs,
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
package="com.google.mediapipe.tasks.components.containerstest"
|
||||
android:versionCode="1"
|
||||
android:versionName="1.0" >
|
||||
|
||||
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
|
||||
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
|
||||
|
||||
<uses-sdk android:minSdkVersion="24"
|
||||
android:targetSdkVersion="31" />
|
||||
|
||||
<application
|
||||
android:label="utilstest"
|
||||
android:name="android.support.multidex.MultiDexApplication"
|
||||
android:taskAffinity="">
|
||||
<uses-library android:name="android.test.runner" />
|
||||
</application>
|
||||
|
||||
<instrumentation
|
||||
android:name="com.google.android.apps.common.testing.testrunner.GoogleInstrumentationTestRunner"
|
||||
android:targetPackage="com.google.mediapipe.tasks.components.containerstest" />
|
||||
|
||||
</manifest>
|
|
@ -0,0 +1,19 @@
|
|||
# Copyright 2023 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
# TODO: Enable these tests in OSS
|
|
@ -0,0 +1,52 @@
|
|||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.components.containers;
|
||||
|
||||
import static com.google.common.truth.Truth.assertThat;
|
||||
|
||||
import androidx.test.ext.junit.runners.AndroidJUnit4;
|
||||
import com.google.mediapipe.formats.proto.ClassificationProto.Classification;
|
||||
import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList;
|
||||
import java.util.List;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
|
||||
@RunWith(AndroidJUnit4.class)
|
||||
public final class CategoryTest {
|
||||
|
||||
@Test
|
||||
public void create_succeedsWithClassificationProto() {
|
||||
Classification input =
|
||||
Classification.newBuilder()
|
||||
.setScore(0.1f)
|
||||
.setIndex(1)
|
||||
.setLabel("label")
|
||||
.setDisplayName("displayName")
|
||||
.build();
|
||||
Category output = Category.createFromProto(input);
|
||||
assertThat(output.score()).isEqualTo(0.1f);
|
||||
assertThat(output.index()).isEqualTo(1);
|
||||
assertThat(output.categoryName()).isEqualTo("label");
|
||||
assertThat(output.displayName()).isEqualTo("displayName");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void create_succeedsWithClassificationListProto() {
|
||||
Classification element = Classification.newBuilder().setScore(0.1f).build();
|
||||
ClassificationList input = ClassificationList.newBuilder().addClassification(element).build();
|
||||
List<Category> output = Category.createListFromProto(input);
|
||||
assertThat(output).containsExactly(Category.create(0.1f, 0, "", ""));
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user