Add FaceDetectorResult

PiperOrigin-RevId: 515104977
This commit is contained in:
MediaPipe Team 2023-03-08 12:08:42 -08:00 committed by Copybara-Service
parent 09f63cbbe0
commit b8917ad31f
6 changed files with 141 additions and 31 deletions

View File

@ -44,6 +44,8 @@ android_library(
deps = [ deps = [
":category", ":category",
":normalizedkeypoint", ":normalizedkeypoint",
"//mediapipe/framework/formats:detection_java_proto_lite",
"//mediapipe/framework/formats:location_data_java_proto_lite",
"//third_party:autovalue", "//third_party:autovalue",
"@maven//:com_google_guava_guava", "@maven//:com_google_guava_guava",
], ],

View File

@ -16,6 +16,9 @@ package com.google.mediapipe.tasks.components.containers;
import android.graphics.RectF; import android.graphics.RectF;
import com.google.auto.value.AutoValue; import com.google.auto.value.AutoValue;
import com.google.mediapipe.formats.proto.LocationDataProto.LocationData.BoundingBox;
import com.google.mediapipe.formats.proto.LocationDataProto.LocationData.RelativeKeypoint;
import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
@ -27,6 +30,8 @@ import java.util.Optional;
@AutoValue @AutoValue
public abstract class Detection { public abstract class Detection {
private static final int DEFAULT_CATEGORY_INDEX = -1;
/** /**
* Creates a {@link Detection} instance from a list of {@link Category} and a bounding box. * Creates a {@link Detection} instance from a list of {@link Category} and a bounding box.
* *
@ -58,6 +63,59 @@ public abstract class Detection {
Collections.unmodifiableList(categories), boundingBox, keypoints); Collections.unmodifiableList(categories), boundingBox, keypoints);
} }
/**
* Creates a {@link Detection} instance from a {@link
* com.google.mediapipe.formats.proto.DetectionProto.Detection} protobuf message.
*
* @param detectionProto a {@link com.google.mediapipe.formats.proto.DetectionProto.Detection}
* protobuf message.
*/
public static Detection createFromProto(
com.google.mediapipe.formats.proto.DetectionProto.Detection detectionProto) {
List<Category> categories = new ArrayList<>();
for (int idx = 0; idx < detectionProto.getScoreCount(); ++idx) {
categories.add(
Category.create(
detectionProto.getScore(idx),
detectionProto.getLabelIdCount() > idx
? detectionProto.getLabelId(idx)
: DEFAULT_CATEGORY_INDEX,
detectionProto.getLabelCount() > idx ? detectionProto.getLabel(idx) : "",
detectionProto.getDisplayNameCount() > idx
? detectionProto.getDisplayName(idx)
: ""));
}
RectF boundingBox = new RectF();
if (detectionProto.getLocationData().hasBoundingBox()) {
BoundingBox boundingBoxProto = detectionProto.getLocationData().getBoundingBox();
boundingBox.set(
/* left= */ boundingBoxProto.getXmin(),
/* top= */ boundingBoxProto.getYmin(),
/* right= */ boundingBoxProto.getXmin() + boundingBoxProto.getWidth(),
/* bottom= */ boundingBoxProto.getYmin() + boundingBoxProto.getHeight());
}
Optional<List<NormalizedKeypoint>> keypoints = Optional.empty();
if (!detectionProto.getLocationData().getRelativeKeypointsList().isEmpty()) {
keypoints = Optional.of(new ArrayList<>());
for (RelativeKeypoint relativeKeypoint :
detectionProto.getLocationData().getRelativeKeypointsList()) {
keypoints
.get()
.add(
NormalizedKeypoint.create(
relativeKeypoint.getX(),
relativeKeypoint.getY(),
relativeKeypoint.hasKeypointLabel()
? Optional.of(relativeKeypoint.getKeypointLabel())
: Optional.empty(),
relativeKeypoint.hasScore()
? Optional.of(relativeKeypoint.getScore())
: Optional.empty()));
}
}
return create(categories, boundingBox, keypoints);
}
/** A list of {@link Category} objects. */ /** A list of {@link Category} objects. */
public abstract List<Category> categories(); public abstract List<Category> categories();

View File

@ -75,12 +75,10 @@ android_library(
":core", ":core",
"//mediapipe/framework:calculator_options_java_proto_lite", "//mediapipe/framework:calculator_options_java_proto_lite",
"//mediapipe/framework/formats:detection_java_proto_lite", "//mediapipe/framework/formats:detection_java_proto_lite",
"//mediapipe/framework/formats:location_data_java_proto_lite",
"//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/java/com/google/mediapipe/framework/image", "//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_java_proto_lite", "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//third_party:autovalue", "//third_party:autovalue",
@ -234,6 +232,27 @@ android_library(
], ],
) )
android_library(
name = "facedetector",
srcs = [
"facedetector/FaceDetectorResult.java",
],
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
manifest = "facedetector/AndroidManifest.xml",
deps = [
":core",
"//mediapipe/framework:calculator_options_java_proto_lite",
"//mediapipe/framework/formats:detection_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//third_party:autovalue",
"@maven//:com_google_guava_guava",
],
)
load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_vision_aar") load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_vision_aar")
mediapipe_tasks_vision_aar( mediapipe_tasks_vision_aar(

View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.google.mediapipe.tasks.vision.facedetector">
<uses-sdk android:minSdkVersion="24"
android:targetSdkVersion="30" />
</manifest>

View File

@ -0,0 +1,49 @@
// Copyright 2023 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.facedetector;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.tasks.core.TaskResult;
import com.google.mediapipe.formats.proto.DetectionProto.Detection;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
/** Represents the detection results generated by {@link FaceDetector}. */
@AutoValue
public abstract class FaceDetectorResult implements TaskResult {
@Override
public abstract long timestampMs();
public abstract List<com.google.mediapipe.tasks.components.containers.Detection> detections();
/**
* Creates an {@link FaceDetectorResult} instance from a list of {@link Detection} protobuf
* messages.
*
* @param detectionList a list of {@link DetectionOuterClass.Detection} protobuf messages.
* @param timestampMs a timestamp for this result.
*/
public static FaceDetectorResult create(List<Detection> detectionList, long timestampMs) {
List<com.google.mediapipe.tasks.components.containers.Detection> detections = new ArrayList<>();
for (Detection detectionProto : detectionList) {
detections.add(
com.google.mediapipe.tasks.components.containers.Detection.createFromProto(
detectionProto));
}
return new AutoValue_FaceDetectorResult(timestampMs, Collections.unmodifiableList(detections));
}
}

View File

@ -14,12 +14,9 @@
package com.google.mediapipe.tasks.vision.objectdetector; package com.google.mediapipe.tasks.vision.objectdetector;
import android.graphics.RectF;
import com.google.auto.value.AutoValue; import com.google.auto.value.AutoValue;
import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.core.TaskResult; import com.google.mediapipe.tasks.core.TaskResult;
import com.google.mediapipe.formats.proto.DetectionProto.Detection; import com.google.mediapipe.formats.proto.DetectionProto.Detection;
import com.google.mediapipe.formats.proto.LocationDataProto.LocationData.BoundingBox;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
@ -27,7 +24,6 @@ import java.util.List;
/** Represents the detection results generated by {@link ObjectDetector}. */ /** Represents the detection results generated by {@link ObjectDetector}. */
@AutoValue @AutoValue
public abstract class ObjectDetectionResult implements TaskResult { public abstract class ObjectDetectionResult implements TaskResult {
private static final int DEFAULT_CATEGORY_INDEX = -1;
@Override @Override
public abstract long timestampMs(); public abstract long timestampMs();
@ -41,34 +37,12 @@ public abstract class ObjectDetectionResult implements TaskResult {
* @param detectionList a list of {@link DetectionOuterClass.Detection} protobuf messages. * @param detectionList a list of {@link DetectionOuterClass.Detection} protobuf messages.
* @param timestampMs a timestamp for this result. * @param timestampMs a timestamp for this result.
*/ */
static ObjectDetectionResult create(List<Detection> detectionList, long timestampMs) { public static ObjectDetectionResult create(List<Detection> detectionList, long timestampMs) {
List<com.google.mediapipe.tasks.components.containers.Detection> detections = new ArrayList<>(); List<com.google.mediapipe.tasks.components.containers.Detection> detections = new ArrayList<>();
for (Detection detectionProto : detectionList) { for (Detection detectionProto : detectionList) {
List<Category> categories = new ArrayList<>();
for (int idx = 0; idx < detectionProto.getScoreCount(); ++idx) {
categories.add(
Category.create(
detectionProto.getScore(idx),
detectionProto.getLabelIdCount() > idx
? detectionProto.getLabelId(idx)
: DEFAULT_CATEGORY_INDEX,
detectionProto.getLabelCount() > idx ? detectionProto.getLabel(idx) : "",
detectionProto.getDisplayNameCount() > idx
? detectionProto.getDisplayName(idx)
: ""));
}
RectF boundingBox = new RectF();
if (detectionProto.getLocationData().hasBoundingBox()) {
BoundingBox boundingBoxProto = detectionProto.getLocationData().getBoundingBox();
boundingBox.set(
/*left=*/ boundingBoxProto.getXmin(),
/*top=*/ boundingBoxProto.getYmin(),
/*right=*/ boundingBoxProto.getXmin() + boundingBoxProto.getWidth(),
/*bottom=*/ boundingBoxProto.getYmin() + boundingBoxProto.getHeight());
}
detections.add( detections.add(
com.google.mediapipe.tasks.components.containers.Detection.create( com.google.mediapipe.tasks.components.containers.Detection.createFromProto(
categories, boundingBox)); detectionProto));
} }
return new AutoValue_ObjectDetectionResult( return new AutoValue_ObjectDetectionResult(
timestampMs, Collections.unmodifiableList(detections)); timestampMs, Collections.unmodifiableList(detections));