Add FaceDetectorResult
PiperOrigin-RevId: 515104977
This commit is contained in:
		
							parent
							
								
									09f63cbbe0
								
							
						
					
					
						commit
						b8917ad31f
					
				|  | @ -44,6 +44,8 @@ android_library( | ||||||
|     deps = [ |     deps = [ | ||||||
|         ":category", |         ":category", | ||||||
|         ":normalizedkeypoint", |         ":normalizedkeypoint", | ||||||
|  |         "//mediapipe/framework/formats:detection_java_proto_lite", | ||||||
|  |         "//mediapipe/framework/formats:location_data_java_proto_lite", | ||||||
|         "//third_party:autovalue", |         "//third_party:autovalue", | ||||||
|         "@maven//:com_google_guava_guava", |         "@maven//:com_google_guava_guava", | ||||||
|     ], |     ], | ||||||
|  |  | ||||||
|  | @ -16,6 +16,9 @@ package com.google.mediapipe.tasks.components.containers; | ||||||
| 
 | 
 | ||||||
| import android.graphics.RectF; | import android.graphics.RectF; | ||||||
| import com.google.auto.value.AutoValue; | import com.google.auto.value.AutoValue; | ||||||
|  | import com.google.mediapipe.formats.proto.LocationDataProto.LocationData.BoundingBox; | ||||||
|  | import com.google.mediapipe.formats.proto.LocationDataProto.LocationData.RelativeKeypoint; | ||||||
|  | import java.util.ArrayList; | ||||||
| import java.util.Collections; | import java.util.Collections; | ||||||
| import java.util.List; | import java.util.List; | ||||||
| import java.util.Optional; | import java.util.Optional; | ||||||
|  | @ -27,6 +30,8 @@ import java.util.Optional; | ||||||
| @AutoValue | @AutoValue | ||||||
| public abstract class Detection { | public abstract class Detection { | ||||||
| 
 | 
 | ||||||
|  |   private static final int DEFAULT_CATEGORY_INDEX = -1; | ||||||
|  | 
 | ||||||
|   /** |   /** | ||||||
|    * Creates a {@link Detection} instance from a list of {@link Category} and a bounding box. |    * Creates a {@link Detection} instance from a list of {@link Category} and a bounding box. | ||||||
|    * |    * | ||||||
|  | @ -58,6 +63,59 @@ public abstract class Detection { | ||||||
|         Collections.unmodifiableList(categories), boundingBox, keypoints); |         Collections.unmodifiableList(categories), boundingBox, keypoints); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |   /** | ||||||
|  |    * Creates a {@link Detection} instance from a {@link | ||||||
|  |    * com.google.mediapipe.formats.proto.DetectionProto.Detection} protobuf message. | ||||||
|  |    * | ||||||
|  |    * @param detectionProto a {@link com.google.mediapipe.formats.proto.DetectionProto.Detection} | ||||||
|  |    *     protobuf message. | ||||||
|  |    */ | ||||||
|  |   public static Detection createFromProto( | ||||||
|  |       com.google.mediapipe.formats.proto.DetectionProto.Detection detectionProto) { | ||||||
|  |     List<Category> categories = new ArrayList<>(); | ||||||
|  |     for (int idx = 0; idx < detectionProto.getScoreCount(); ++idx) { | ||||||
|  |       categories.add( | ||||||
|  |           Category.create( | ||||||
|  |               detectionProto.getScore(idx), | ||||||
|  |               detectionProto.getLabelIdCount() > idx | ||||||
|  |                   ? detectionProto.getLabelId(idx) | ||||||
|  |                   : DEFAULT_CATEGORY_INDEX, | ||||||
|  |               detectionProto.getLabelCount() > idx ? detectionProto.getLabel(idx) : "", | ||||||
|  |               detectionProto.getDisplayNameCount() > idx | ||||||
|  |                   ? detectionProto.getDisplayName(idx) | ||||||
|  |                   : "")); | ||||||
|  |     } | ||||||
|  |     RectF boundingBox = new RectF(); | ||||||
|  |     if (detectionProto.getLocationData().hasBoundingBox()) { | ||||||
|  |       BoundingBox boundingBoxProto = detectionProto.getLocationData().getBoundingBox(); | ||||||
|  |       boundingBox.set( | ||||||
|  |           /* left= */ boundingBoxProto.getXmin(), | ||||||
|  |           /* top= */ boundingBoxProto.getYmin(), | ||||||
|  |           /* right= */ boundingBoxProto.getXmin() + boundingBoxProto.getWidth(), | ||||||
|  |           /* bottom= */ boundingBoxProto.getYmin() + boundingBoxProto.getHeight()); | ||||||
|  |     } | ||||||
|  |     Optional<List<NormalizedKeypoint>> keypoints = Optional.empty(); | ||||||
|  |     if (!detectionProto.getLocationData().getRelativeKeypointsList().isEmpty()) { | ||||||
|  |       keypoints = Optional.of(new ArrayList<>()); | ||||||
|  |       for (RelativeKeypoint relativeKeypoint : | ||||||
|  |           detectionProto.getLocationData().getRelativeKeypointsList()) { | ||||||
|  |         keypoints | ||||||
|  |             .get() | ||||||
|  |             .add( | ||||||
|  |                 NormalizedKeypoint.create( | ||||||
|  |                     relativeKeypoint.getX(), | ||||||
|  |                     relativeKeypoint.getY(), | ||||||
|  |                     relativeKeypoint.hasKeypointLabel() | ||||||
|  |                         ? Optional.of(relativeKeypoint.getKeypointLabel()) | ||||||
|  |                         : Optional.empty(), | ||||||
|  |                     relativeKeypoint.hasScore() | ||||||
|  |                         ? Optional.of(relativeKeypoint.getScore()) | ||||||
|  |                         : Optional.empty())); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |     return create(categories, boundingBox, keypoints); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|   /** A list of {@link Category} objects. */ |   /** A list of {@link Category} objects. */ | ||||||
|   public abstract List<Category> categories(); |   public abstract List<Category> categories(); | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -75,12 +75,10 @@ android_library( | ||||||
|         ":core", |         ":core", | ||||||
|         "//mediapipe/framework:calculator_options_java_proto_lite", |         "//mediapipe/framework:calculator_options_java_proto_lite", | ||||||
|         "//mediapipe/framework/formats:detection_java_proto_lite", |         "//mediapipe/framework/formats:detection_java_proto_lite", | ||||||
|         "//mediapipe/framework/formats:location_data_java_proto_lite", |  | ||||||
|         "//mediapipe/java/com/google/mediapipe/framework:android_framework", |         "//mediapipe/java/com/google/mediapipe/framework:android_framework", | ||||||
|         "//mediapipe/java/com/google/mediapipe/framework/image", |         "//mediapipe/java/com/google/mediapipe/framework/image", | ||||||
|         "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", |         "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", | ||||||
|         "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_java_proto_lite", |         "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_java_proto_lite", | ||||||
|         "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", |  | ||||||
|         "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection", |         "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection", | ||||||
|         "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", |         "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", | ||||||
|         "//third_party:autovalue", |         "//third_party:autovalue", | ||||||
|  | @ -234,6 +232,27 @@ android_library( | ||||||
|     ], |     ], | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | android_library( | ||||||
|  |     name = "facedetector", | ||||||
|  |     srcs = [ | ||||||
|  |         "facedetector/FaceDetectorResult.java", | ||||||
|  |     ], | ||||||
|  |     javacopts = [ | ||||||
|  |         "-Xep:AndroidJdkLibsChecker:OFF", | ||||||
|  |     ], | ||||||
|  |     manifest = "facedetector/AndroidManifest.xml", | ||||||
|  |     deps = [ | ||||||
|  |         ":core", | ||||||
|  |         "//mediapipe/framework:calculator_options_java_proto_lite", | ||||||
|  |         "//mediapipe/framework/formats:detection_java_proto_lite", | ||||||
|  |         "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", | ||||||
|  |         "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection", | ||||||
|  |         "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", | ||||||
|  |         "//third_party:autovalue", | ||||||
|  |         "@maven//:com_google_guava_guava", | ||||||
|  |     ], | ||||||
|  | ) | ||||||
|  | 
 | ||||||
| load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_vision_aar") | load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_vision_aar") | ||||||
| 
 | 
 | ||||||
| mediapipe_tasks_vision_aar( | mediapipe_tasks_vision_aar( | ||||||
|  |  | ||||||
|  | @ -0,0 +1,8 @@ | ||||||
|  | <?xml version="1.0" encoding="utf-8"?> | ||||||
|  | <manifest xmlns:android="http://schemas.android.com/apk/res/android" | ||||||
|  |     package="com.google.mediapipe.tasks.vision.facedetector"> | ||||||
|  | 
 | ||||||
|  |     <uses-sdk android:minSdkVersion="24" | ||||||
|  |         android:targetSdkVersion="30" /> | ||||||
|  | 
 | ||||||
|  | </manifest> | ||||||
|  | @ -0,0 +1,49 @@ | ||||||
|  | // Copyright 2023 The MediaPipe Authors. All Rights Reserved. | ||||||
|  | // | ||||||
|  | // Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | // you may not use this file except in compliance with the License. | ||||||
|  | // You may obtain a copy of the License at | ||||||
|  | // | ||||||
|  | //      http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | // | ||||||
|  | // Unless required by applicable law or agreed to in writing, software | ||||||
|  | // distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | // See the License for the specific language governing permissions and | ||||||
|  | // limitations under the License. | ||||||
|  | 
 | ||||||
|  | package com.google.mediapipe.tasks.vision.facedetector; | ||||||
|  | 
 | ||||||
|  | import com.google.auto.value.AutoValue; | ||||||
|  | import com.google.mediapipe.tasks.core.TaskResult; | ||||||
|  | import com.google.mediapipe.formats.proto.DetectionProto.Detection; | ||||||
|  | import java.util.ArrayList; | ||||||
|  | import java.util.Collections; | ||||||
|  | import java.util.List; | ||||||
|  | 
 | ||||||
|  | /** Represents the detection results generated by {@link FaceDetector}. */ | ||||||
|  | @AutoValue | ||||||
|  | public abstract class FaceDetectorResult implements TaskResult { | ||||||
|  | 
 | ||||||
|  |   @Override | ||||||
|  |   public abstract long timestampMs(); | ||||||
|  | 
 | ||||||
|  |   public abstract List<com.google.mediapipe.tasks.components.containers.Detection> detections(); | ||||||
|  | 
 | ||||||
|  |   /** | ||||||
|  |    * Creates an {@link FaceDetectorResult} instance from a list of {@link Detection} protobuf | ||||||
|  |    * messages. | ||||||
|  |    * | ||||||
|  |    * @param detectionList a list of {@link DetectionOuterClass.Detection} protobuf messages. | ||||||
|  |    * @param timestampMs a timestamp for this result. | ||||||
|  |    */ | ||||||
|  |   public static FaceDetectorResult create(List<Detection> detectionList, long timestampMs) { | ||||||
|  |     List<com.google.mediapipe.tasks.components.containers.Detection> detections = new ArrayList<>(); | ||||||
|  |     for (Detection detectionProto : detectionList) { | ||||||
|  |       detections.add( | ||||||
|  |           com.google.mediapipe.tasks.components.containers.Detection.createFromProto( | ||||||
|  |               detectionProto)); | ||||||
|  |     } | ||||||
|  |     return new AutoValue_FaceDetectorResult(timestampMs, Collections.unmodifiableList(detections)); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | @ -14,12 +14,9 @@ | ||||||
| 
 | 
 | ||||||
| package com.google.mediapipe.tasks.vision.objectdetector; | package com.google.mediapipe.tasks.vision.objectdetector; | ||||||
| 
 | 
 | ||||||
| import android.graphics.RectF; |  | ||||||
| import com.google.auto.value.AutoValue; | import com.google.auto.value.AutoValue; | ||||||
| import com.google.mediapipe.tasks.components.containers.Category; |  | ||||||
| import com.google.mediapipe.tasks.core.TaskResult; | import com.google.mediapipe.tasks.core.TaskResult; | ||||||
| import com.google.mediapipe.formats.proto.DetectionProto.Detection; | import com.google.mediapipe.formats.proto.DetectionProto.Detection; | ||||||
| import com.google.mediapipe.formats.proto.LocationDataProto.LocationData.BoundingBox; |  | ||||||
| import java.util.ArrayList; | import java.util.ArrayList; | ||||||
| import java.util.Collections; | import java.util.Collections; | ||||||
| import java.util.List; | import java.util.List; | ||||||
|  | @ -27,7 +24,6 @@ import java.util.List; | ||||||
| /** Represents the detection results generated by {@link ObjectDetector}. */ | /** Represents the detection results generated by {@link ObjectDetector}. */ | ||||||
| @AutoValue | @AutoValue | ||||||
| public abstract class ObjectDetectionResult implements TaskResult { | public abstract class ObjectDetectionResult implements TaskResult { | ||||||
|   private static final int DEFAULT_CATEGORY_INDEX = -1; |  | ||||||
| 
 | 
 | ||||||
|   @Override |   @Override | ||||||
|   public abstract long timestampMs(); |   public abstract long timestampMs(); | ||||||
|  | @ -41,34 +37,12 @@ public abstract class ObjectDetectionResult implements TaskResult { | ||||||
|    * @param detectionList a list of {@link DetectionOuterClass.Detection} protobuf messages. |    * @param detectionList a list of {@link DetectionOuterClass.Detection} protobuf messages. | ||||||
|    * @param timestampMs a timestamp for this result. |    * @param timestampMs a timestamp for this result. | ||||||
|    */ |    */ | ||||||
|   static ObjectDetectionResult create(List<Detection> detectionList, long timestampMs) { |   public static ObjectDetectionResult create(List<Detection> detectionList, long timestampMs) { | ||||||
|     List<com.google.mediapipe.tasks.components.containers.Detection> detections = new ArrayList<>(); |     List<com.google.mediapipe.tasks.components.containers.Detection> detections = new ArrayList<>(); | ||||||
|     for (Detection detectionProto : detectionList) { |     for (Detection detectionProto : detectionList) { | ||||||
|       List<Category> categories = new ArrayList<>(); |  | ||||||
|       for (int idx = 0; idx < detectionProto.getScoreCount(); ++idx) { |  | ||||||
|         categories.add( |  | ||||||
|             Category.create( |  | ||||||
|                 detectionProto.getScore(idx), |  | ||||||
|                 detectionProto.getLabelIdCount() > idx |  | ||||||
|                     ? detectionProto.getLabelId(idx) |  | ||||||
|                     : DEFAULT_CATEGORY_INDEX, |  | ||||||
|                 detectionProto.getLabelCount() > idx ? detectionProto.getLabel(idx) : "", |  | ||||||
|                 detectionProto.getDisplayNameCount() > idx |  | ||||||
|                     ? detectionProto.getDisplayName(idx) |  | ||||||
|                     : "")); |  | ||||||
|       } |  | ||||||
|       RectF boundingBox = new RectF(); |  | ||||||
|       if (detectionProto.getLocationData().hasBoundingBox()) { |  | ||||||
|         BoundingBox boundingBoxProto = detectionProto.getLocationData().getBoundingBox(); |  | ||||||
|         boundingBox.set( |  | ||||||
|             /*left=*/ boundingBoxProto.getXmin(), |  | ||||||
|             /*top=*/ boundingBoxProto.getYmin(), |  | ||||||
|             /*right=*/ boundingBoxProto.getXmin() + boundingBoxProto.getWidth(), |  | ||||||
|             /*bottom=*/ boundingBoxProto.getYmin() + boundingBoxProto.getHeight()); |  | ||||||
|       } |  | ||||||
|       detections.add( |       detections.add( | ||||||
|           com.google.mediapipe.tasks.components.containers.Detection.create( |           com.google.mediapipe.tasks.components.containers.Detection.createFromProto( | ||||||
|               categories, boundingBox)); |               detectionProto)); | ||||||
|     } |     } | ||||||
|     return new AutoValue_ObjectDetectionResult( |     return new AutoValue_ObjectDetectionResult( | ||||||
|         timestampMs, Collections.unmodifiableList(detections)); |         timestampMs, Collections.unmodifiableList(detections)); | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user