Create shared utilities to construct category lists
PiperOrigin-RevId: 581009898
This commit is contained in:
parent
6532ce5c59
commit
edca85c5d3
|
@ -16,6 +16,10 @@ package com.google.mediapipe.tasks.components.containers;
|
||||||
|
|
||||||
import com.google.auto.value.AutoValue;
|
import com.google.auto.value.AutoValue;
|
||||||
import com.google.mediapipe.formats.proto.ClassificationProto;
|
import com.google.mediapipe.formats.proto.ClassificationProto;
|
||||||
|
import com.google.mediapipe.formats.proto.ClassificationProto.Classification;
|
||||||
|
import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -49,6 +53,22 @@ public abstract class Category {
|
||||||
return create(proto.getScore(), proto.getIndex(), proto.getLabel(), proto.getDisplayName());
|
return create(proto.getScore(), proto.getIndex(), proto.getLabel(), proto.getDisplayName());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a list of {@link Category} objects from a {@link
|
||||||
|
* ClassificationProto.ClassificationList}.
|
||||||
|
*
|
||||||
|
* @param classificationListProto the {@link ClassificationProto.ClassificationList} protobuf
|
||||||
|
* message to convert.
|
||||||
|
* @return A list of {@link Category} objects.
|
||||||
|
*/
|
||||||
|
public static List<Category> createListFromProto(ClassificationList classificationListProto) {
|
||||||
|
List<Category> categoryList = new ArrayList<>();
|
||||||
|
for (Classification classification : classificationListProto.getClassificationList()) {
|
||||||
|
categoryList.add(createFromProto(classification));
|
||||||
|
}
|
||||||
|
return categoryList;
|
||||||
|
}
|
||||||
|
|
||||||
/** The probability score of this label category. */
|
/** The probability score of this label category. */
|
||||||
public abstract float score();
|
public abstract float score();
|
||||||
|
|
||||||
|
|
|
@ -15,9 +15,7 @@
|
||||||
package com.google.mediapipe.tasks.components.containers;
|
package com.google.mediapipe.tasks.components.containers;
|
||||||
|
|
||||||
import com.google.auto.value.AutoValue;
|
import com.google.auto.value.AutoValue;
|
||||||
import com.google.mediapipe.formats.proto.ClassificationProto;
|
|
||||||
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
|
@ -49,11 +47,7 @@ public abstract class Classifications {
|
||||||
* @param proto the {@link ClassificationsProto.Classifications} protobuf message to convert.
|
* @param proto the {@link ClassificationsProto.Classifications} protobuf message to convert.
|
||||||
*/
|
*/
|
||||||
public static Classifications createFromProto(ClassificationsProto.Classifications proto) {
|
public static Classifications createFromProto(ClassificationsProto.Classifications proto) {
|
||||||
List<Category> categories = new ArrayList<>();
|
List<Category> categories = Category.createListFromProto(proto.getClassificationList());
|
||||||
for (ClassificationProto.Classification classificationProto :
|
|
||||||
proto.getClassificationList().getClassificationList()) {
|
|
||||||
categories.add(Category.createFromProto(classificationProto));
|
|
||||||
}
|
|
||||||
Optional<String> headName =
|
Optional<String> headName =
|
||||||
proto.hasHeadName() ? Optional.of(proto.getHeadName()) : Optional.empty();
|
proto.hasHeadName() ? Optional.of(proto.getHeadName()) : Optional.empty();
|
||||||
return create(categories, proto.getHeadIndex(), headName);
|
return create(categories, proto.getHeadIndex(), headName);
|
||||||
|
|
|
@ -208,11 +208,9 @@ android_library(
|
||||||
deps = [
|
deps = [
|
||||||
":core",
|
":core",
|
||||||
"//mediapipe/framework:calculator_options_java_proto_lite",
|
"//mediapipe/framework:calculator_options_java_proto_lite",
|
||||||
"//mediapipe/framework/formats:classification_java_proto_lite",
|
|
||||||
"//mediapipe/framework/formats:landmark_java_proto_lite",
|
"//mediapipe/framework/formats:landmark_java_proto_lite",
|
||||||
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
||||||
"//mediapipe/java/com/google/mediapipe/framework/image",
|
"//mediapipe/java/com/google/mediapipe/framework/image",
|
||||||
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
|
|
||||||
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_java_proto_lite",
|
"//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarker_graph_options_java_proto_lite",
|
"//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarker_graph_options_java_proto_lite",
|
||||||
|
@ -222,7 +220,6 @@ android_library(
|
||||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:normalized_landmark",
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:normalized_landmark",
|
||||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
||||||
"//third_party:autovalue",
|
"//third_party:autovalue",
|
||||||
"@maven//:androidx_annotation_annotation",
|
|
||||||
"@maven//:com_google_guava_guava",
|
"@maven//:com_google_guava_guava",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -246,7 +243,6 @@ android_library(
|
||||||
"//mediapipe/framework/formats:landmark_java_proto_lite",
|
"//mediapipe/framework/formats:landmark_java_proto_lite",
|
||||||
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
||||||
"//mediapipe/java/com/google/mediapipe/framework/image",
|
"//mediapipe/java/com/google/mediapipe/framework/image",
|
||||||
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
|
|
||||||
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite",
|
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite",
|
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite",
|
||||||
|
|
|
@ -16,7 +16,6 @@ package com.google.mediapipe.tasks.vision.facelandmarker;
|
||||||
|
|
||||||
import com.google.auto.value.AutoValue;
|
import com.google.auto.value.AutoValue;
|
||||||
import com.google.mediapipe.formats.proto.LandmarkProto;
|
import com.google.mediapipe.formats.proto.LandmarkProto;
|
||||||
import com.google.mediapipe.formats.proto.ClassificationProto.Classification;
|
|
||||||
import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList;
|
import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList;
|
||||||
import com.google.mediapipe.tasks.components.containers.Category;
|
import com.google.mediapipe.tasks.components.containers.Category;
|
||||||
import com.google.mediapipe.tasks.components.containers.NormalizedLandmark;
|
import com.google.mediapipe.tasks.components.containers.NormalizedLandmark;
|
||||||
|
@ -68,16 +67,8 @@ public abstract class FaceLandmarkerResult implements TaskResult {
|
||||||
if (multiFaceBendshapesProto.isPresent()) {
|
if (multiFaceBendshapesProto.isPresent()) {
|
||||||
List<List<Category>> blendshapes = new ArrayList<>();
|
List<List<Category>> blendshapes = new ArrayList<>();
|
||||||
for (ClassificationList faceBendshapeProto : multiFaceBendshapesProto.get()) {
|
for (ClassificationList faceBendshapeProto : multiFaceBendshapesProto.get()) {
|
||||||
List<Category> blendshape = new ArrayList<>();
|
List<Category> blendshape = Category.createListFromProto(faceBendshapeProto);
|
||||||
blendshapes.add(blendshape);
|
blendshapes.add(Collections.unmodifiableList(blendshape));
|
||||||
for (Classification classification : faceBendshapeProto.getClassificationList()) {
|
|
||||||
blendshape.add(
|
|
||||||
Category.create(
|
|
||||||
classification.getScore(),
|
|
||||||
classification.getIndex(),
|
|
||||||
classification.getLabel(),
|
|
||||||
classification.getDisplayName()));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
multiFaceBlendshapes = Optional.of(Collections.unmodifiableList(blendshapes));
|
multiFaceBlendshapes = Optional.of(Collections.unmodifiableList(blendshapes));
|
||||||
}
|
}
|
||||||
|
|
|
@ -75,16 +75,8 @@ public abstract class GestureRecognizerResult implements TaskResult {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (ClassificationList handednessProto : handednessesProto) {
|
for (ClassificationList handednessProto : handednessesProto) {
|
||||||
List<Category> handedness = new ArrayList<>();
|
List<Category> handedness = Category.createListFromProto(handednessProto);
|
||||||
multiHandHandednesses.add(handedness);
|
multiHandHandednesses.add(Collections.unmodifiableList(handedness));
|
||||||
for (Classification classification : handednessProto.getClassificationList()) {
|
|
||||||
handedness.add(
|
|
||||||
Category.create(
|
|
||||||
classification.getScore(),
|
|
||||||
classification.getIndex(),
|
|
||||||
classification.getLabel(),
|
|
||||||
classification.getDisplayName()));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
for (ClassificationList gestureProto : gesturesProto) {
|
for (ClassificationList gestureProto : gesturesProto) {
|
||||||
List<Category> gestures = new ArrayList<>();
|
List<Category> gestures = new ArrayList<>();
|
||||||
|
|
|
@ -16,7 +16,6 @@ package com.google.mediapipe.tasks.vision.handlandmarker;
|
||||||
|
|
||||||
import com.google.auto.value.AutoValue;
|
import com.google.auto.value.AutoValue;
|
||||||
import com.google.mediapipe.formats.proto.LandmarkProto;
|
import com.google.mediapipe.formats.proto.LandmarkProto;
|
||||||
import com.google.mediapipe.formats.proto.ClassificationProto.Classification;
|
|
||||||
import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList;
|
import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList;
|
||||||
import com.google.mediapipe.tasks.components.containers.Category;
|
import com.google.mediapipe.tasks.components.containers.Category;
|
||||||
import com.google.mediapipe.tasks.components.containers.Landmark;
|
import com.google.mediapipe.tasks.components.containers.Landmark;
|
||||||
|
@ -84,16 +83,8 @@ public abstract class HandLandmarkerResult implements TaskResult {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (ClassificationList handednessProto : handednessesProto) {
|
for (ClassificationList handednessProto : handednessesProto) {
|
||||||
List<Category> handedness = new ArrayList<>();
|
List<Category> handedness = Category.createListFromProto(handednessProto);
|
||||||
multiHandHandednesses.add(handedness);
|
multiHandHandednesses.add(Collections.unmodifiableList(handedness));
|
||||||
for (Classification classification : handednessProto.getClassificationList()) {
|
|
||||||
handedness.add(
|
|
||||||
Category.create(
|
|
||||||
classification.getScore(),
|
|
||||||
classification.getIndex(),
|
|
||||||
classification.getLabel(),
|
|
||||||
classification.getDisplayName()));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return new AutoValue_HandLandmarkerResult(
|
return new AutoValue_HandLandmarkerResult(
|
||||||
timestampMs,
|
timestampMs,
|
||||||
|
|
|
@ -0,0 +1,24 @@
|
||||||
|
<?xml version="1.0" encoding="utf-8"?>
|
||||||
|
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||||
|
package="com.google.mediapipe.tasks.components.containerstest"
|
||||||
|
android:versionCode="1"
|
||||||
|
android:versionName="1.0" >
|
||||||
|
|
||||||
|
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
|
||||||
|
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
|
||||||
|
|
||||||
|
<uses-sdk android:minSdkVersion="24"
|
||||||
|
android:targetSdkVersion="31" />
|
||||||
|
|
||||||
|
<application
|
||||||
|
android:label="utilstest"
|
||||||
|
android:name="android.support.multidex.MultiDexApplication"
|
||||||
|
android:taskAffinity="">
|
||||||
|
<uses-library android:name="android.test.runner" />
|
||||||
|
</application>
|
||||||
|
|
||||||
|
<instrumentation
|
||||||
|
android:name="com.google.android.apps.common.testing.testrunner.GoogleInstrumentationTestRunner"
|
||||||
|
android:targetPackage="com.google.mediapipe.tasks.components.containerstest" />
|
||||||
|
|
||||||
|
</manifest>
|
|
@ -0,0 +1,19 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
# TODO: Enable these tests in OSS
|
|
@ -0,0 +1,52 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package com.google.mediapipe.tasks.components.containers;
|
||||||
|
|
||||||
|
import static com.google.common.truth.Truth.assertThat;
|
||||||
|
|
||||||
|
import androidx.test.ext.junit.runners.AndroidJUnit4;
|
||||||
|
import com.google.mediapipe.formats.proto.ClassificationProto.Classification;
|
||||||
|
import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList;
|
||||||
|
import java.util.List;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.junit.runner.RunWith;
|
||||||
|
|
||||||
|
@RunWith(AndroidJUnit4.class)
|
||||||
|
public final class CategoryTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void create_succeedsWithClassificationProto() {
|
||||||
|
Classification input =
|
||||||
|
Classification.newBuilder()
|
||||||
|
.setScore(0.1f)
|
||||||
|
.setIndex(1)
|
||||||
|
.setLabel("label")
|
||||||
|
.setDisplayName("displayName")
|
||||||
|
.build();
|
||||||
|
Category output = Category.createFromProto(input);
|
||||||
|
assertThat(output.score()).isEqualTo(0.1f);
|
||||||
|
assertThat(output.index()).isEqualTo(1);
|
||||||
|
assertThat(output.categoryName()).isEqualTo("label");
|
||||||
|
assertThat(output.displayName()).isEqualTo("displayName");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void create_succeedsWithClassificationListProto() {
|
||||||
|
Classification element = Classification.newBuilder().setScore(0.1f).build();
|
||||||
|
ClassificationList input = ClassificationList.newBuilder().addClassification(element).build();
|
||||||
|
List<Category> output = Category.createListFromProto(input);
|
||||||
|
assertThat(output).containsExactly(Category.create(0.1f, 0, "", ""));
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user