From d9f316e12a1e59b6d9daa45c64431d29d8700310 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Wed, 24 May 2023 08:53:19 -0700 Subject: [PATCH] Rename ObjectDetctionResult to ObjectDetectorResult PiperOrigin-RevId: 534858600 --- .../com/google/mediapipe/tasks/vision/BUILD | 1 + .../objectdetector/ObjectDetectionResult.java | 22 ++++---- .../vision/objectdetector/ObjectDetector.java | 34 +++++++------ .../objectdetector/ObjectDetectorResult.java | 44 ++++++++++++++++ .../objectdetector/ObjectDetectorTest.java | 50 +++++++++++-------- 5 files changed, 104 insertions(+), 47 deletions(-) create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectorResult.java diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index 399156da3..cbb1797e2 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -71,6 +71,7 @@ android_library( srcs = [ "objectdetector/ObjectDetectionResult.java", "objectdetector/ObjectDetector.java", + "objectdetector/ObjectDetectorResult.java", ], javacopts = [ "-Xep:AndroidJdkLibsChecker:OFF", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectionResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectionResult.java index 120cddd46..49ab0ae2b 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectionResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectionResult.java @@ -14,15 +14,16 @@ package com.google.mediapipe.tasks.vision.objectdetector; -import com.google.auto.value.AutoValue; import com.google.mediapipe.tasks.core.TaskResult; import com.google.mediapipe.formats.proto.DetectionProto.Detection; -import java.util.ArrayList; -import java.util.Collections; import java.util.List; -/** Represents the detection results generated by {@link ObjectDetector}. */ -@AutoValue +/** + * Represents the detection results generated by {@link ObjectDetector}. + * + * @deprecated Use {@link ObjectDetectorResult} instead. + */ +@Deprecated public abstract class ObjectDetectionResult implements TaskResult { @Override @@ -36,15 +37,10 @@ public abstract class ObjectDetectionResult implements TaskResult { * * @param detectionList a list of {@link DetectionOuterClass.Detection} protobuf messages. * @param timestampMs a timestamp for this result. + * @deprecated Use {@link ObjectDetectorResult#create} instead. */ + @Deprecated public static ObjectDetectionResult create(List detectionList, long timestampMs) { - List detections = new ArrayList<>(); - for (Detection detectionProto : detectionList) { - detections.add( - com.google.mediapipe.tasks.components.containers.Detection.createFromProto( - detectionProto)); - } - return new AutoValue_ObjectDetectionResult( - timestampMs, Collections.unmodifiableList(detections)); + return ObjectDetectorResult.create(detectionList, timestampMs); } } diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java index d9a36cce7..0c70a119d 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java @@ -99,11 +99,16 @@ public final class ObjectDetector extends BaseVisionTaskApi { private static final String TAG = ObjectDetector.class.getSimpleName(); private static final String IMAGE_IN_STREAM_NAME = "image_in"; private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in"; + + @SuppressWarnings("ConstantCaseForConstants") private static final List INPUT_STREAMS = Collections.unmodifiableList( Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME)); + + @SuppressWarnings("ConstantCaseForConstants") private static final List OUTPUT_STREAMS = Collections.unmodifiableList(Arrays.asList("DETECTIONS:detections_out", "IMAGE:image_out")); + private static final int DETECTIONS_OUT_STREAM_INDEX = 0; private static final int IMAGE_OUT_STREAM_INDEX = 1; private static final String TASK_GRAPH_NAME = "mediapipe.tasks.vision.ObjectDetectorGraph"; @@ -166,19 +171,19 @@ public final class ObjectDetector extends BaseVisionTaskApi { public static ObjectDetector createFromOptions( Context context, ObjectDetectorOptions detectorOptions) { // TODO: Consolidate OutputHandler and TaskRunner. - OutputHandler handler = new OutputHandler<>(); + OutputHandler handler = new OutputHandler<>(); handler.setOutputPacketConverter( - new OutputHandler.OutputPacketConverter() { + new OutputHandler.OutputPacketConverter() { @Override - public ObjectDetectionResult convertToTaskResult(List packets) { + public ObjectDetectorResult convertToTaskResult(List packets) { // If there is no object detected in the image, just returns empty lists. if (packets.get(DETECTIONS_OUT_STREAM_INDEX).isEmpty()) { - return ObjectDetectionResult.create( + return ObjectDetectorResult.create( new ArrayList<>(), BaseVisionTaskApi.generateResultTimestampMs( detectorOptions.runningMode(), packets.get(DETECTIONS_OUT_STREAM_INDEX))); } - return ObjectDetectionResult.create( + return ObjectDetectorResult.create( PacketGetter.getProtoVector( packets.get(DETECTIONS_OUT_STREAM_INDEX), Detection.parser()), BaseVisionTaskApi.generateResultTimestampMs( @@ -235,7 +240,7 @@ public final class ObjectDetector extends BaseVisionTaskApi { * @param image a MediaPipe {@link MPImage} object for processing. * @throws MediaPipeException if there is an internal error. */ - public ObjectDetectionResult detect(MPImage image) { + public ObjectDetectorResult detect(MPImage image) { return detect(image, ImageProcessingOptions.builder().build()); } @@ -258,10 +263,9 @@ public final class ObjectDetector extends BaseVisionTaskApi { * region-of-interest. * @throws MediaPipeException if there is an internal error. */ - public ObjectDetectionResult detect( - MPImage image, ImageProcessingOptions imageProcessingOptions) { + public ObjectDetectorResult detect(MPImage image, ImageProcessingOptions imageProcessingOptions) { validateImageProcessingOptions(imageProcessingOptions); - return (ObjectDetectionResult) processImageData(image, imageProcessingOptions); + return (ObjectDetectorResult) processImageData(image, imageProcessingOptions); } /** @@ -282,7 +286,7 @@ public final class ObjectDetector extends BaseVisionTaskApi { * @param timestampMs the input timestamp (in milliseconds). * @throws MediaPipeException if there is an internal error. */ - public ObjectDetectionResult detectForVideo(MPImage image, long timestampMs) { + public ObjectDetectorResult detectForVideo(MPImage image, long timestampMs) { return detectForVideo(image, ImageProcessingOptions.builder().build(), timestampMs); } @@ -309,10 +313,10 @@ public final class ObjectDetector extends BaseVisionTaskApi { * region-of-interest. * @throws MediaPipeException if there is an internal error. */ - public ObjectDetectionResult detectForVideo( + public ObjectDetectorResult detectForVideo( MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { validateImageProcessingOptions(imageProcessingOptions); - return (ObjectDetectionResult) processVideoData(image, imageProcessingOptions, timestampMs); + return (ObjectDetectorResult) processVideoData(image, imageProcessingOptions, timestampMs); } /** @@ -435,7 +439,7 @@ public final class ObjectDetector extends BaseVisionTaskApi { * object detector is in the live stream mode. */ public abstract Builder setResultListener( - ResultListener value); + ResultListener value); /** Sets an optional {@link ErrorListener}}. */ public abstract Builder setErrorListener(ErrorListener value); @@ -476,11 +480,13 @@ public final class ObjectDetector extends BaseVisionTaskApi { abstract Optional scoreThreshold(); + @SuppressWarnings("AutoValueImmutableFields") abstract List categoryAllowlist(); + @SuppressWarnings("AutoValueImmutableFields") abstract List categoryDenylist(); - abstract Optional> resultListener(); + abstract Optional> resultListener(); abstract Optional errorListener(); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectorResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectorResult.java new file mode 100644 index 000000000..a17dc90c0 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectorResult.java @@ -0,0 +1,44 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.objectdetector; + +import com.google.auto.value.AutoValue; +import com.google.mediapipe.formats.proto.DetectionProto.Detection; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** Represents the detection results generated by {@link ObjectDetector}. */ +@AutoValue +@SuppressWarnings("deprecation") +public abstract class ObjectDetectorResult extends ObjectDetectionResult { + /** + * Creates an {@link ObjectDetectorResult} instance from a list of {@link Detection} protobuf + * messages. + * + * @param detectionList a list of {@link DetectionOuterClass.Detection} protobuf messages. + * @param timestampMs a timestamp for this result. + */ + public static ObjectDetectorResult create(List detectionList, long timestampMs) { + List detections = new ArrayList<>(); + for (Detection detectionProto : detectionList) { + detections.add( + com.google.mediapipe.tasks.components.containers.Detection.createFromProto( + detectionProto)); + } + return new AutoValue_ObjectDetectorResult( + timestampMs, Collections.unmodifiableList(detections)); + } +} diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectorTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectorTest.java index 20ddfcef6..fb83723c5 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectorTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectorTest.java @@ -69,7 +69,7 @@ public class ObjectDetectorTest { .build(); ObjectDetector objectDetector = ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); - ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); + ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE); } @@ -77,7 +77,7 @@ public class ObjectDetectorTest { public void detect_successWithNoOptions() throws Exception { ObjectDetector objectDetector = ObjectDetector.createFromFile(ApplicationProvider.getApplicationContext(), MODEL_FILE); - ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); + ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); // Check if the object with the highest score is cat. assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE); } @@ -91,7 +91,7 @@ public class ObjectDetectorTest { .build(); ObjectDetector objectDetector = ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); - ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); + ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); // results should have 8 detected objects because maxResults was set to 8. assertThat(results.detections()).hasSize(8); } @@ -105,7 +105,7 @@ public class ObjectDetectorTest { .build(); ObjectDetector objectDetector = ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); - ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); + ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); // The score threshold should block all other other objects, except cat. assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE); } @@ -119,7 +119,7 @@ public class ObjectDetectorTest { .build(); ObjectDetector objectDetector = ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); - ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); + ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); // The score threshold should block objects. assertThat(results.detections()).isEmpty(); } @@ -133,7 +133,7 @@ public class ObjectDetectorTest { .build(); ObjectDetector objectDetector = ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); - ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); + ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); // Because of the allowlist, results should only contain cat, and there are 6 detected // bounding boxes of cats in CAT_AND_DOG_IMAGE. assertThat(results.detections()).hasSize(5); @@ -148,7 +148,7 @@ public class ObjectDetectorTest { .build(); ObjectDetector objectDetector = ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); - ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); + ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); // Because of the denylist, the highest result is not cat anymore. assertThat(results.detections().get(0).categories().get(0).categoryName()) .isNotEqualTo("cat"); @@ -160,7 +160,7 @@ public class ObjectDetectorTest { ObjectDetector.createFromFile( ApplicationProvider.getApplicationContext(), TestUtils.loadFile(ApplicationProvider.getApplicationContext(), MODEL_FILE)); - ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); + ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); // Check if the object with the highest score is cat. assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE); } @@ -172,7 +172,7 @@ public class ObjectDetectorTest { ApplicationProvider.getApplicationContext(), TestUtils.loadToDirectByteBuffer( ApplicationProvider.getApplicationContext(), MODEL_FILE)); - ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); + ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); // Check if the object with the highest score is cat. assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE); } @@ -191,7 +191,7 @@ public class ObjectDetectorTest { .build(); ObjectDetector objectDetector = ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); - ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); + ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE); } @@ -256,7 +256,7 @@ public class ObjectDetectorTest { ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageProcessingOptions imageProcessingOptions = ImageProcessingOptions.builder().setRotationDegrees(-90).build(); - ObjectDetectionResult results = + ObjectDetectorResult results = objectDetector.detect( getImageFromAsset(CAT_AND_DOG_ROTATED_IMAGE), imageProcessingOptions); @@ -302,7 +302,7 @@ public class ObjectDetectorTest { ObjectDetectorOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) .setRunningMode(mode) - .setResultListener((objectDetectionResult, inputImage) -> {}) + .setResultListener((ObjectDetectorResult, inputImage) -> {}) .build()); assertThat(exception) .hasMessageThat() @@ -381,7 +381,7 @@ public class ObjectDetectorTest { ObjectDetectorOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) .setRunningMode(RunningMode.LIVE_STREAM) - .setResultListener((objectDetectionResult, inputImage) -> {}) + .setResultListener((ObjectDetectorResult, inputImage) -> {}) .build(); ObjectDetector objectDetector = @@ -411,7 +411,7 @@ public class ObjectDetectorTest { .build(); ObjectDetector objectDetector = ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); - ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); + ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE); } @@ -426,7 +426,7 @@ public class ObjectDetectorTest { ObjectDetector objectDetector = ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); for (int i = 0; i < 3; i++) { - ObjectDetectionResult results = + ObjectDetectorResult results = objectDetector.detectForVideo( getImageFromAsset(CAT_AND_DOG_IMAGE), /* timestampsMs= */ i); assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE); @@ -441,8 +441,8 @@ public class ObjectDetectorTest { .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) .setRunningMode(RunningMode.LIVE_STREAM) .setResultListener( - (objectDetectionResult, inputImage) -> { - assertContainsOnlyCat(objectDetectionResult, CAT_BOUNDING_BOX, CAT_SCORE); + (ObjectDetectorResult, inputImage) -> { + assertContainsOnlyCat(ObjectDetectorResult, CAT_BOUNDING_BOX, CAT_SCORE); assertImageSizeIsExpected(inputImage); }) .setMaxResults(1) @@ -468,8 +468,8 @@ public class ObjectDetectorTest { .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) .setRunningMode(RunningMode.LIVE_STREAM) .setResultListener( - (objectDetectionResult, inputImage) -> { - assertContainsOnlyCat(objectDetectionResult, CAT_BOUNDING_BOX, CAT_SCORE); + (ObjectDetectorResult, inputImage) -> { + assertContainsOnlyCat(ObjectDetectorResult, CAT_BOUNDING_BOX, CAT_SCORE); assertImageSizeIsExpected(inputImage); }) .setMaxResults(1) @@ -483,6 +483,16 @@ public class ObjectDetectorTest { } } + @Test + @SuppressWarnings("deprecation") + public void detect_canUseDeprecatedApi() throws Exception { + ObjectDetector objectDetector = + ObjectDetector.createFromFile(ApplicationProvider.getApplicationContext(), MODEL_FILE); + ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); + // Check if the object with the highest score is cat. + assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE); + } + private static MPImage getImageFromAsset(String filePath) throws Exception { AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets(); InputStream istr = assetManager.open(filePath); @@ -491,7 +501,7 @@ public class ObjectDetectorTest { // Checks if results has one and only detection result, which is a cat. private static void assertContainsOnlyCat( - ObjectDetectionResult result, RectF expectedBoundingBox, float expectedScore) { + ObjectDetectorResult result, RectF expectedBoundingBox, float expectedScore) { assertThat(result.detections()).hasSize(1); Detection catResult = result.detections().get(0); assertApproximatelyEqualBoundingBoxes(catResult.boundingBox(), expectedBoundingBox);