Add FaceDetectorResult
PiperOrigin-RevId: 515104977
This commit is contained in:
		
							parent
							
								
									09f63cbbe0
								
							
						
					
					
						commit
						b8917ad31f
					
				| 
						 | 
				
			
			@ -44,6 +44,8 @@ android_library(
 | 
			
		|||
    deps = [
 | 
			
		||||
        ":category",
 | 
			
		||||
        ":normalizedkeypoint",
 | 
			
		||||
        "//mediapipe/framework/formats:detection_java_proto_lite",
 | 
			
		||||
        "//mediapipe/framework/formats:location_data_java_proto_lite",
 | 
			
		||||
        "//third_party:autovalue",
 | 
			
		||||
        "@maven//:com_google_guava_guava",
 | 
			
		||||
    ],
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -16,6 +16,9 @@ package com.google.mediapipe.tasks.components.containers;
 | 
			
		|||
 | 
			
		||||
import android.graphics.RectF;
 | 
			
		||||
import com.google.auto.value.AutoValue;
 | 
			
		||||
import com.google.mediapipe.formats.proto.LocationDataProto.LocationData.BoundingBox;
 | 
			
		||||
import com.google.mediapipe.formats.proto.LocationDataProto.LocationData.RelativeKeypoint;
 | 
			
		||||
import java.util.ArrayList;
 | 
			
		||||
import java.util.Collections;
 | 
			
		||||
import java.util.List;
 | 
			
		||||
import java.util.Optional;
 | 
			
		||||
| 
						 | 
				
			
			@ -27,6 +30,8 @@ import java.util.Optional;
 | 
			
		|||
@AutoValue
 | 
			
		||||
public abstract class Detection {
 | 
			
		||||
 | 
			
		||||
  private static final int DEFAULT_CATEGORY_INDEX = -1;
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * Creates a {@link Detection} instance from a list of {@link Category} and a bounding box.
 | 
			
		||||
   *
 | 
			
		||||
| 
						 | 
				
			
			@ -58,6 +63,59 @@ public abstract class Detection {
 | 
			
		|||
        Collections.unmodifiableList(categories), boundingBox, keypoints);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * Creates a {@link Detection} instance from a {@link
 | 
			
		||||
   * com.google.mediapipe.formats.proto.DetectionProto.Detection} protobuf message.
 | 
			
		||||
   *
 | 
			
		||||
   * @param detectionProto a {@link com.google.mediapipe.formats.proto.DetectionProto.Detection}
 | 
			
		||||
   *     protobuf message.
 | 
			
		||||
   */
 | 
			
		||||
  public static Detection createFromProto(
 | 
			
		||||
      com.google.mediapipe.formats.proto.DetectionProto.Detection detectionProto) {
 | 
			
		||||
    List<Category> categories = new ArrayList<>();
 | 
			
		||||
    for (int idx = 0; idx < detectionProto.getScoreCount(); ++idx) {
 | 
			
		||||
      categories.add(
 | 
			
		||||
          Category.create(
 | 
			
		||||
              detectionProto.getScore(idx),
 | 
			
		||||
              detectionProto.getLabelIdCount() > idx
 | 
			
		||||
                  ? detectionProto.getLabelId(idx)
 | 
			
		||||
                  : DEFAULT_CATEGORY_INDEX,
 | 
			
		||||
              detectionProto.getLabelCount() > idx ? detectionProto.getLabel(idx) : "",
 | 
			
		||||
              detectionProto.getDisplayNameCount() > idx
 | 
			
		||||
                  ? detectionProto.getDisplayName(idx)
 | 
			
		||||
                  : ""));
 | 
			
		||||
    }
 | 
			
		||||
    RectF boundingBox = new RectF();
 | 
			
		||||
    if (detectionProto.getLocationData().hasBoundingBox()) {
 | 
			
		||||
      BoundingBox boundingBoxProto = detectionProto.getLocationData().getBoundingBox();
 | 
			
		||||
      boundingBox.set(
 | 
			
		||||
          /* left= */ boundingBoxProto.getXmin(),
 | 
			
		||||
          /* top= */ boundingBoxProto.getYmin(),
 | 
			
		||||
          /* right= */ boundingBoxProto.getXmin() + boundingBoxProto.getWidth(),
 | 
			
		||||
          /* bottom= */ boundingBoxProto.getYmin() + boundingBoxProto.getHeight());
 | 
			
		||||
    }
 | 
			
		||||
    Optional<List<NormalizedKeypoint>> keypoints = Optional.empty();
 | 
			
		||||
    if (!detectionProto.getLocationData().getRelativeKeypointsList().isEmpty()) {
 | 
			
		||||
      keypoints = Optional.of(new ArrayList<>());
 | 
			
		||||
      for (RelativeKeypoint relativeKeypoint :
 | 
			
		||||
          detectionProto.getLocationData().getRelativeKeypointsList()) {
 | 
			
		||||
        keypoints
 | 
			
		||||
            .get()
 | 
			
		||||
            .add(
 | 
			
		||||
                NormalizedKeypoint.create(
 | 
			
		||||
                    relativeKeypoint.getX(),
 | 
			
		||||
                    relativeKeypoint.getY(),
 | 
			
		||||
                    relativeKeypoint.hasKeypointLabel()
 | 
			
		||||
                        ? Optional.of(relativeKeypoint.getKeypointLabel())
 | 
			
		||||
                        : Optional.empty(),
 | 
			
		||||
                    relativeKeypoint.hasScore()
 | 
			
		||||
                        ? Optional.of(relativeKeypoint.getScore())
 | 
			
		||||
                        : Optional.empty()));
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
    return create(categories, boundingBox, keypoints);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /** A list of {@link Category} objects. */
 | 
			
		||||
  public abstract List<Category> categories();
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -75,12 +75,10 @@ android_library(
 | 
			
		|||
        ":core",
 | 
			
		||||
        "//mediapipe/framework:calculator_options_java_proto_lite",
 | 
			
		||||
        "//mediapipe/framework/formats:detection_java_proto_lite",
 | 
			
		||||
        "//mediapipe/framework/formats:location_data_java_proto_lite",
 | 
			
		||||
        "//mediapipe/java/com/google/mediapipe/framework:android_framework",
 | 
			
		||||
        "//mediapipe/java/com/google/mediapipe/framework/image",
 | 
			
		||||
        "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
 | 
			
		||||
        "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_java_proto_lite",
 | 
			
		||||
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
 | 
			
		||||
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection",
 | 
			
		||||
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
 | 
			
		||||
        "//third_party:autovalue",
 | 
			
		||||
| 
						 | 
				
			
			@ -234,6 +232,27 @@ android_library(
 | 
			
		|||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
android_library(
 | 
			
		||||
    name = "facedetector",
 | 
			
		||||
    srcs = [
 | 
			
		||||
        "facedetector/FaceDetectorResult.java",
 | 
			
		||||
    ],
 | 
			
		||||
    javacopts = [
 | 
			
		||||
        "-Xep:AndroidJdkLibsChecker:OFF",
 | 
			
		||||
    ],
 | 
			
		||||
    manifest = "facedetector/AndroidManifest.xml",
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":core",
 | 
			
		||||
        "//mediapipe/framework:calculator_options_java_proto_lite",
 | 
			
		||||
        "//mediapipe/framework/formats:detection_java_proto_lite",
 | 
			
		||||
        "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
 | 
			
		||||
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection",
 | 
			
		||||
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
 | 
			
		||||
        "//third_party:autovalue",
 | 
			
		||||
        "@maven//:com_google_guava_guava",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_vision_aar")
 | 
			
		||||
 | 
			
		||||
mediapipe_tasks_vision_aar(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,8 @@
 | 
			
		|||
<?xml version="1.0" encoding="utf-8"?>
 | 
			
		||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
 | 
			
		||||
    package="com.google.mediapipe.tasks.vision.facedetector">
 | 
			
		||||
 | 
			
		||||
    <uses-sdk android:minSdkVersion="24"
 | 
			
		||||
        android:targetSdkVersion="30" />
 | 
			
		||||
 | 
			
		||||
</manifest>
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,49 @@
 | 
			
		|||
// Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | 
			
		||||
//
 | 
			
		||||
// Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
// you may not use this file except in compliance with the License.
 | 
			
		||||
// You may obtain a copy of the License at
 | 
			
		||||
//
 | 
			
		||||
//      http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
//
 | 
			
		||||
// Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
// distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
// See the License for the specific language governing permissions and
 | 
			
		||||
// limitations under the License.
 | 
			
		||||
 | 
			
		||||
package com.google.mediapipe.tasks.vision.facedetector;
 | 
			
		||||
 | 
			
		||||
import com.google.auto.value.AutoValue;
 | 
			
		||||
import com.google.mediapipe.tasks.core.TaskResult;
 | 
			
		||||
import com.google.mediapipe.formats.proto.DetectionProto.Detection;
 | 
			
		||||
import java.util.ArrayList;
 | 
			
		||||
import java.util.Collections;
 | 
			
		||||
import java.util.List;
 | 
			
		||||
 | 
			
		||||
/** Represents the detection results generated by {@link FaceDetector}. */
 | 
			
		||||
@AutoValue
 | 
			
		||||
public abstract class FaceDetectorResult implements TaskResult {
 | 
			
		||||
 | 
			
		||||
  @Override
 | 
			
		||||
  public abstract long timestampMs();
 | 
			
		||||
 | 
			
		||||
  public abstract List<com.google.mediapipe.tasks.components.containers.Detection> detections();
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * Creates an {@link FaceDetectorResult} instance from a list of {@link Detection} protobuf
 | 
			
		||||
   * messages.
 | 
			
		||||
   *
 | 
			
		||||
   * @param detectionList a list of {@link DetectionOuterClass.Detection} protobuf messages.
 | 
			
		||||
   * @param timestampMs a timestamp for this result.
 | 
			
		||||
   */
 | 
			
		||||
  public static FaceDetectorResult create(List<Detection> detectionList, long timestampMs) {
 | 
			
		||||
    List<com.google.mediapipe.tasks.components.containers.Detection> detections = new ArrayList<>();
 | 
			
		||||
    for (Detection detectionProto : detectionList) {
 | 
			
		||||
      detections.add(
 | 
			
		||||
          com.google.mediapipe.tasks.components.containers.Detection.createFromProto(
 | 
			
		||||
              detectionProto));
 | 
			
		||||
    }
 | 
			
		||||
    return new AutoValue_FaceDetectorResult(timestampMs, Collections.unmodifiableList(detections));
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -14,12 +14,9 @@
 | 
			
		|||
 | 
			
		||||
package com.google.mediapipe.tasks.vision.objectdetector;
 | 
			
		||||
 | 
			
		||||
import android.graphics.RectF;
 | 
			
		||||
import com.google.auto.value.AutoValue;
 | 
			
		||||
import com.google.mediapipe.tasks.components.containers.Category;
 | 
			
		||||
import com.google.mediapipe.tasks.core.TaskResult;
 | 
			
		||||
import com.google.mediapipe.formats.proto.DetectionProto.Detection;
 | 
			
		||||
import com.google.mediapipe.formats.proto.LocationDataProto.LocationData.BoundingBox;
 | 
			
		||||
import java.util.ArrayList;
 | 
			
		||||
import java.util.Collections;
 | 
			
		||||
import java.util.List;
 | 
			
		||||
| 
						 | 
				
			
			@ -27,7 +24,6 @@ import java.util.List;
 | 
			
		|||
/** Represents the detection results generated by {@link ObjectDetector}. */
 | 
			
		||||
@AutoValue
 | 
			
		||||
public abstract class ObjectDetectionResult implements TaskResult {
 | 
			
		||||
  private static final int DEFAULT_CATEGORY_INDEX = -1;
 | 
			
		||||
 | 
			
		||||
  @Override
 | 
			
		||||
  public abstract long timestampMs();
 | 
			
		||||
| 
						 | 
				
			
			@ -41,34 +37,12 @@ public abstract class ObjectDetectionResult implements TaskResult {
 | 
			
		|||
   * @param detectionList a list of {@link DetectionOuterClass.Detection} protobuf messages.
 | 
			
		||||
   * @param timestampMs a timestamp for this result.
 | 
			
		||||
   */
 | 
			
		||||
  static ObjectDetectionResult create(List<Detection> detectionList, long timestampMs) {
 | 
			
		||||
  public static ObjectDetectionResult create(List<Detection> detectionList, long timestampMs) {
 | 
			
		||||
    List<com.google.mediapipe.tasks.components.containers.Detection> detections = new ArrayList<>();
 | 
			
		||||
    for (Detection detectionProto : detectionList) {
 | 
			
		||||
      List<Category> categories = new ArrayList<>();
 | 
			
		||||
      for (int idx = 0; idx < detectionProto.getScoreCount(); ++idx) {
 | 
			
		||||
        categories.add(
 | 
			
		||||
            Category.create(
 | 
			
		||||
                detectionProto.getScore(idx),
 | 
			
		||||
                detectionProto.getLabelIdCount() > idx
 | 
			
		||||
                    ? detectionProto.getLabelId(idx)
 | 
			
		||||
                    : DEFAULT_CATEGORY_INDEX,
 | 
			
		||||
                detectionProto.getLabelCount() > idx ? detectionProto.getLabel(idx) : "",
 | 
			
		||||
                detectionProto.getDisplayNameCount() > idx
 | 
			
		||||
                    ? detectionProto.getDisplayName(idx)
 | 
			
		||||
                    : ""));
 | 
			
		||||
      }
 | 
			
		||||
      RectF boundingBox = new RectF();
 | 
			
		||||
      if (detectionProto.getLocationData().hasBoundingBox()) {
 | 
			
		||||
        BoundingBox boundingBoxProto = detectionProto.getLocationData().getBoundingBox();
 | 
			
		||||
        boundingBox.set(
 | 
			
		||||
            /*left=*/ boundingBoxProto.getXmin(),
 | 
			
		||||
            /*top=*/ boundingBoxProto.getYmin(),
 | 
			
		||||
            /*right=*/ boundingBoxProto.getXmin() + boundingBoxProto.getWidth(),
 | 
			
		||||
            /*bottom=*/ boundingBoxProto.getYmin() + boundingBoxProto.getHeight());
 | 
			
		||||
      }
 | 
			
		||||
      detections.add(
 | 
			
		||||
          com.google.mediapipe.tasks.components.containers.Detection.create(
 | 
			
		||||
              categories, boundingBox));
 | 
			
		||||
          com.google.mediapipe.tasks.components.containers.Detection.createFromProto(
 | 
			
		||||
              detectionProto));
 | 
			
		||||
    }
 | 
			
		||||
    return new AutoValue_ObjectDetectionResult(
 | 
			
		||||
        timestampMs, Collections.unmodifiableList(detections));
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user