Rename ObjectDetctionResult to ObjectDetectorResult

PiperOrigin-RevId: 534858600
This commit is contained in:
Sebastian Schmidt 2023-05-24 08:53:19 -07:00 committed by Copybara-Service
parent 2017fcc9ab
commit d9f316e12a
5 changed files with 104 additions and 47 deletions

View File

@ -71,6 +71,7 @@ android_library(
srcs = [
"objectdetector/ObjectDetectionResult.java",
"objectdetector/ObjectDetector.java",
"objectdetector/ObjectDetectorResult.java",
],
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",

View File

@ -14,15 +14,16 @@
package com.google.mediapipe.tasks.vision.objectdetector;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.tasks.core.TaskResult;
import com.google.mediapipe.formats.proto.DetectionProto.Detection;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
/** Represents the detection results generated by {@link ObjectDetector}. */
@AutoValue
/**
* Represents the detection results generated by {@link ObjectDetector}.
*
* @deprecated Use {@link ObjectDetectorResult} instead.
*/
@Deprecated
public abstract class ObjectDetectionResult implements TaskResult {
@Override
@ -36,15 +37,10 @@ public abstract class ObjectDetectionResult implements TaskResult {
*
* @param detectionList a list of {@link DetectionOuterClass.Detection} protobuf messages.
* @param timestampMs a timestamp for this result.
* @deprecated Use {@link ObjectDetectorResult#create} instead.
*/
@Deprecated
public static ObjectDetectionResult create(List<Detection> detectionList, long timestampMs) {
List<com.google.mediapipe.tasks.components.containers.Detection> detections = new ArrayList<>();
for (Detection detectionProto : detectionList) {
detections.add(
com.google.mediapipe.tasks.components.containers.Detection.createFromProto(
detectionProto));
}
return new AutoValue_ObjectDetectionResult(
timestampMs, Collections.unmodifiableList(detections));
return ObjectDetectorResult.create(detectionList, timestampMs);
}
}

View File

@ -99,11 +99,16 @@ public final class ObjectDetector extends BaseVisionTaskApi {
private static final String TAG = ObjectDetector.class.getSimpleName();
private static final String IMAGE_IN_STREAM_NAME = "image_in";
private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in";
@SuppressWarnings("ConstantCaseForConstants")
private static final List<String> INPUT_STREAMS =
Collections.unmodifiableList(
Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
@SuppressWarnings("ConstantCaseForConstants")
private static final List<String> OUTPUT_STREAMS =
Collections.unmodifiableList(Arrays.asList("DETECTIONS:detections_out", "IMAGE:image_out"));
private static final int DETECTIONS_OUT_STREAM_INDEX = 0;
private static final int IMAGE_OUT_STREAM_INDEX = 1;
private static final String TASK_GRAPH_NAME = "mediapipe.tasks.vision.ObjectDetectorGraph";
@ -166,19 +171,19 @@ public final class ObjectDetector extends BaseVisionTaskApi {
public static ObjectDetector createFromOptions(
Context context, ObjectDetectorOptions detectorOptions) {
// TODO: Consolidate OutputHandler and TaskRunner.
OutputHandler<ObjectDetectionResult, MPImage> handler = new OutputHandler<>();
OutputHandler<ObjectDetectorResult, MPImage> handler = new OutputHandler<>();
handler.setOutputPacketConverter(
new OutputHandler.OutputPacketConverter<ObjectDetectionResult, MPImage>() {
new OutputHandler.OutputPacketConverter<ObjectDetectorResult, MPImage>() {
@Override
public ObjectDetectionResult convertToTaskResult(List<Packet> packets) {
public ObjectDetectorResult convertToTaskResult(List<Packet> packets) {
// If there is no object detected in the image, just returns empty lists.
if (packets.get(DETECTIONS_OUT_STREAM_INDEX).isEmpty()) {
return ObjectDetectionResult.create(
return ObjectDetectorResult.create(
new ArrayList<>(),
BaseVisionTaskApi.generateResultTimestampMs(
detectorOptions.runningMode(), packets.get(DETECTIONS_OUT_STREAM_INDEX)));
}
return ObjectDetectionResult.create(
return ObjectDetectorResult.create(
PacketGetter.getProtoVector(
packets.get(DETECTIONS_OUT_STREAM_INDEX), Detection.parser()),
BaseVisionTaskApi.generateResultTimestampMs(
@ -235,7 +240,7 @@ public final class ObjectDetector extends BaseVisionTaskApi {
* @param image a MediaPipe {@link MPImage} object for processing.
* @throws MediaPipeException if there is an internal error.
*/
public ObjectDetectionResult detect(MPImage image) {
public ObjectDetectorResult detect(MPImage image) {
return detect(image, ImageProcessingOptions.builder().build());
}
@ -258,10 +263,9 @@ public final class ObjectDetector extends BaseVisionTaskApi {
* region-of-interest.
* @throws MediaPipeException if there is an internal error.
*/
public ObjectDetectionResult detect(
MPImage image, ImageProcessingOptions imageProcessingOptions) {
public ObjectDetectorResult detect(MPImage image, ImageProcessingOptions imageProcessingOptions) {
validateImageProcessingOptions(imageProcessingOptions);
return (ObjectDetectionResult) processImageData(image, imageProcessingOptions);
return (ObjectDetectorResult) processImageData(image, imageProcessingOptions);
}
/**
@ -282,7 +286,7 @@ public final class ObjectDetector extends BaseVisionTaskApi {
* @param timestampMs the input timestamp (in milliseconds).
* @throws MediaPipeException if there is an internal error.
*/
public ObjectDetectionResult detectForVideo(MPImage image, long timestampMs) {
public ObjectDetectorResult detectForVideo(MPImage image, long timestampMs) {
return detectForVideo(image, ImageProcessingOptions.builder().build(), timestampMs);
}
@ -309,10 +313,10 @@ public final class ObjectDetector extends BaseVisionTaskApi {
* region-of-interest.
* @throws MediaPipeException if there is an internal error.
*/
public ObjectDetectionResult detectForVideo(
public ObjectDetectorResult detectForVideo(
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
validateImageProcessingOptions(imageProcessingOptions);
return (ObjectDetectionResult) processVideoData(image, imageProcessingOptions, timestampMs);
return (ObjectDetectorResult) processVideoData(image, imageProcessingOptions, timestampMs);
}
/**
@ -435,7 +439,7 @@ public final class ObjectDetector extends BaseVisionTaskApi {
* object detector is in the live stream mode.
*/
public abstract Builder setResultListener(
ResultListener<ObjectDetectionResult, MPImage> value);
ResultListener<ObjectDetectorResult, MPImage> value);
/** Sets an optional {@link ErrorListener}}. */
public abstract Builder setErrorListener(ErrorListener value);
@ -476,11 +480,13 @@ public final class ObjectDetector extends BaseVisionTaskApi {
abstract Optional<Float> scoreThreshold();
@SuppressWarnings("AutoValueImmutableFields")
abstract List<String> categoryAllowlist();
@SuppressWarnings("AutoValueImmutableFields")
abstract List<String> categoryDenylist();
abstract Optional<ResultListener<ObjectDetectionResult, MPImage>> resultListener();
abstract Optional<ResultListener<ObjectDetectorResult, MPImage>> resultListener();
abstract Optional<ErrorListener> errorListener();

View File

@ -0,0 +1,44 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.objectdetector;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.formats.proto.DetectionProto.Detection;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
/** Represents the detection results generated by {@link ObjectDetector}. */
@AutoValue
@SuppressWarnings("deprecation")
public abstract class ObjectDetectorResult extends ObjectDetectionResult {
/**
* Creates an {@link ObjectDetectorResult} instance from a list of {@link Detection} protobuf
* messages.
*
* @param detectionList a list of {@link DetectionOuterClass.Detection} protobuf messages.
* @param timestampMs a timestamp for this result.
*/
public static ObjectDetectorResult create(List<Detection> detectionList, long timestampMs) {
List<com.google.mediapipe.tasks.components.containers.Detection> detections = new ArrayList<>();
for (Detection detectionProto : detectionList) {
detections.add(
com.google.mediapipe.tasks.components.containers.Detection.createFromProto(
detectionProto));
}
return new AutoValue_ObjectDetectorResult(
timestampMs, Collections.unmodifiableList(detections));
}
}

View File

@ -69,7 +69,7 @@ public class ObjectDetectorTest {
.build();
ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
}
@ -77,7 +77,7 @@ public class ObjectDetectorTest {
public void detect_successWithNoOptions() throws Exception {
ObjectDetector objectDetector =
ObjectDetector.createFromFile(ApplicationProvider.getApplicationContext(), MODEL_FILE);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
// Check if the object with the highest score is cat.
assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE);
}
@ -91,7 +91,7 @@ public class ObjectDetectorTest {
.build();
ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
// results should have 8 detected objects because maxResults was set to 8.
assertThat(results.detections()).hasSize(8);
}
@ -105,7 +105,7 @@ public class ObjectDetectorTest {
.build();
ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
// The score threshold should block all other other objects, except cat.
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
}
@ -119,7 +119,7 @@ public class ObjectDetectorTest {
.build();
ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
// The score threshold should block objects.
assertThat(results.detections()).isEmpty();
}
@ -133,7 +133,7 @@ public class ObjectDetectorTest {
.build();
ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
// Because of the allowlist, results should only contain cat, and there are 6 detected
// bounding boxes of cats in CAT_AND_DOG_IMAGE.
assertThat(results.detections()).hasSize(5);
@ -148,7 +148,7 @@ public class ObjectDetectorTest {
.build();
ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
// Because of the denylist, the highest result is not cat anymore.
assertThat(results.detections().get(0).categories().get(0).categoryName())
.isNotEqualTo("cat");
@ -160,7 +160,7 @@ public class ObjectDetectorTest {
ObjectDetector.createFromFile(
ApplicationProvider.getApplicationContext(),
TestUtils.loadFile(ApplicationProvider.getApplicationContext(), MODEL_FILE));
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
// Check if the object with the highest score is cat.
assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE);
}
@ -172,7 +172,7 @@ public class ObjectDetectorTest {
ApplicationProvider.getApplicationContext(),
TestUtils.loadToDirectByteBuffer(
ApplicationProvider.getApplicationContext(), MODEL_FILE));
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
// Check if the object with the highest score is cat.
assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE);
}
@ -191,7 +191,7 @@ public class ObjectDetectorTest {
.build();
ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
}
@ -256,7 +256,7 @@ public class ObjectDetectorTest {
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ImageProcessingOptions imageProcessingOptions =
ImageProcessingOptions.builder().setRotationDegrees(-90).build();
ObjectDetectionResult results =
ObjectDetectorResult results =
objectDetector.detect(
getImageFromAsset(CAT_AND_DOG_ROTATED_IMAGE), imageProcessingOptions);
@ -302,7 +302,7 @@ public class ObjectDetectorTest {
ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setRunningMode(mode)
.setResultListener((objectDetectionResult, inputImage) -> {})
.setResultListener((ObjectDetectorResult, inputImage) -> {})
.build());
assertThat(exception)
.hasMessageThat()
@ -381,7 +381,7 @@ public class ObjectDetectorTest {
ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener((objectDetectionResult, inputImage) -> {})
.setResultListener((ObjectDetectorResult, inputImage) -> {})
.build();
ObjectDetector objectDetector =
@ -411,7 +411,7 @@ public class ObjectDetectorTest {
.build();
ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
}
@ -426,7 +426,7 @@ public class ObjectDetectorTest {
ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
for (int i = 0; i < 3; i++) {
ObjectDetectionResult results =
ObjectDetectorResult results =
objectDetector.detectForVideo(
getImageFromAsset(CAT_AND_DOG_IMAGE), /* timestampsMs= */ i);
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
@ -441,8 +441,8 @@ public class ObjectDetectorTest {
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener(
(objectDetectionResult, inputImage) -> {
assertContainsOnlyCat(objectDetectionResult, CAT_BOUNDING_BOX, CAT_SCORE);
(ObjectDetectorResult, inputImage) -> {
assertContainsOnlyCat(ObjectDetectorResult, CAT_BOUNDING_BOX, CAT_SCORE);
assertImageSizeIsExpected(inputImage);
})
.setMaxResults(1)
@ -468,8 +468,8 @@ public class ObjectDetectorTest {
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener(
(objectDetectionResult, inputImage) -> {
assertContainsOnlyCat(objectDetectionResult, CAT_BOUNDING_BOX, CAT_SCORE);
(ObjectDetectorResult, inputImage) -> {
assertContainsOnlyCat(ObjectDetectorResult, CAT_BOUNDING_BOX, CAT_SCORE);
assertImageSizeIsExpected(inputImage);
})
.setMaxResults(1)
@ -483,6 +483,16 @@ public class ObjectDetectorTest {
}
}
@Test
@SuppressWarnings("deprecation")
public void detect_canUseDeprecatedApi() throws Exception {
ObjectDetector objectDetector =
ObjectDetector.createFromFile(ApplicationProvider.getApplicationContext(), MODEL_FILE);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
// Check if the object with the highest score is cat.
assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE);
}
private static MPImage getImageFromAsset(String filePath) throws Exception {
AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
InputStream istr = assetManager.open(filePath);
@ -491,7 +501,7 @@ public class ObjectDetectorTest {
// Checks if results has one and only detection result, which is a cat.
private static void assertContainsOnlyCat(
ObjectDetectionResult result, RectF expectedBoundingBox, float expectedScore) {
ObjectDetectorResult result, RectF expectedBoundingBox, float expectedScore) {
assertThat(result.detections()).hasSize(1);
Detection catResult = result.detections().get(0);
assertApproximatelyEqualBoundingBoxes(catResult.boundingBox(), expectedBoundingBox);