Rename ObjectDetctionResult to ObjectDetectorResult

PiperOrigin-RevId: 534858600
This commit is contained in:
Sebastian Schmidt 2023-05-24 08:53:19 -07:00 committed by Copybara-Service
parent 2017fcc9ab
commit d9f316e12a
5 changed files with 104 additions and 47 deletions

View File

@ -71,6 +71,7 @@ android_library(
srcs = [ srcs = [
"objectdetector/ObjectDetectionResult.java", "objectdetector/ObjectDetectionResult.java",
"objectdetector/ObjectDetector.java", "objectdetector/ObjectDetector.java",
"objectdetector/ObjectDetectorResult.java",
], ],
javacopts = [ javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF", "-Xep:AndroidJdkLibsChecker:OFF",

View File

@ -14,15 +14,16 @@
package com.google.mediapipe.tasks.vision.objectdetector; package com.google.mediapipe.tasks.vision.objectdetector;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.tasks.core.TaskResult; import com.google.mediapipe.tasks.core.TaskResult;
import com.google.mediapipe.formats.proto.DetectionProto.Detection; import com.google.mediapipe.formats.proto.DetectionProto.Detection;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List; import java.util.List;
/** Represents the detection results generated by {@link ObjectDetector}. */ /**
@AutoValue * Represents the detection results generated by {@link ObjectDetector}.
*
* @deprecated Use {@link ObjectDetectorResult} instead.
*/
@Deprecated
public abstract class ObjectDetectionResult implements TaskResult { public abstract class ObjectDetectionResult implements TaskResult {
@Override @Override
@ -36,15 +37,10 @@ public abstract class ObjectDetectionResult implements TaskResult {
* *
* @param detectionList a list of {@link DetectionOuterClass.Detection} protobuf messages. * @param detectionList a list of {@link DetectionOuterClass.Detection} protobuf messages.
* @param timestampMs a timestamp for this result. * @param timestampMs a timestamp for this result.
* @deprecated Use {@link ObjectDetectorResult#create} instead.
*/ */
@Deprecated
public static ObjectDetectionResult create(List<Detection> detectionList, long timestampMs) { public static ObjectDetectionResult create(List<Detection> detectionList, long timestampMs) {
List<com.google.mediapipe.tasks.components.containers.Detection> detections = new ArrayList<>(); return ObjectDetectorResult.create(detectionList, timestampMs);
for (Detection detectionProto : detectionList) {
detections.add(
com.google.mediapipe.tasks.components.containers.Detection.createFromProto(
detectionProto));
}
return new AutoValue_ObjectDetectionResult(
timestampMs, Collections.unmodifiableList(detections));
} }
} }

View File

@ -99,11 +99,16 @@ public final class ObjectDetector extends BaseVisionTaskApi {
private static final String TAG = ObjectDetector.class.getSimpleName(); private static final String TAG = ObjectDetector.class.getSimpleName();
private static final String IMAGE_IN_STREAM_NAME = "image_in"; private static final String IMAGE_IN_STREAM_NAME = "image_in";
private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in"; private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in";
@SuppressWarnings("ConstantCaseForConstants")
private static final List<String> INPUT_STREAMS = private static final List<String> INPUT_STREAMS =
Collections.unmodifiableList( Collections.unmodifiableList(
Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME)); Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
@SuppressWarnings("ConstantCaseForConstants")
private static final List<String> OUTPUT_STREAMS = private static final List<String> OUTPUT_STREAMS =
Collections.unmodifiableList(Arrays.asList("DETECTIONS:detections_out", "IMAGE:image_out")); Collections.unmodifiableList(Arrays.asList("DETECTIONS:detections_out", "IMAGE:image_out"));
private static final int DETECTIONS_OUT_STREAM_INDEX = 0; private static final int DETECTIONS_OUT_STREAM_INDEX = 0;
private static final int IMAGE_OUT_STREAM_INDEX = 1; private static final int IMAGE_OUT_STREAM_INDEX = 1;
private static final String TASK_GRAPH_NAME = "mediapipe.tasks.vision.ObjectDetectorGraph"; private static final String TASK_GRAPH_NAME = "mediapipe.tasks.vision.ObjectDetectorGraph";
@ -166,19 +171,19 @@ public final class ObjectDetector extends BaseVisionTaskApi {
public static ObjectDetector createFromOptions( public static ObjectDetector createFromOptions(
Context context, ObjectDetectorOptions detectorOptions) { Context context, ObjectDetectorOptions detectorOptions) {
// TODO: Consolidate OutputHandler and TaskRunner. // TODO: Consolidate OutputHandler and TaskRunner.
OutputHandler<ObjectDetectionResult, MPImage> handler = new OutputHandler<>(); OutputHandler<ObjectDetectorResult, MPImage> handler = new OutputHandler<>();
handler.setOutputPacketConverter( handler.setOutputPacketConverter(
new OutputHandler.OutputPacketConverter<ObjectDetectionResult, MPImage>() { new OutputHandler.OutputPacketConverter<ObjectDetectorResult, MPImage>() {
@Override @Override
public ObjectDetectionResult convertToTaskResult(List<Packet> packets) { public ObjectDetectorResult convertToTaskResult(List<Packet> packets) {
// If there is no object detected in the image, just returns empty lists. // If there is no object detected in the image, just returns empty lists.
if (packets.get(DETECTIONS_OUT_STREAM_INDEX).isEmpty()) { if (packets.get(DETECTIONS_OUT_STREAM_INDEX).isEmpty()) {
return ObjectDetectionResult.create( return ObjectDetectorResult.create(
new ArrayList<>(), new ArrayList<>(),
BaseVisionTaskApi.generateResultTimestampMs( BaseVisionTaskApi.generateResultTimestampMs(
detectorOptions.runningMode(), packets.get(DETECTIONS_OUT_STREAM_INDEX))); detectorOptions.runningMode(), packets.get(DETECTIONS_OUT_STREAM_INDEX)));
} }
return ObjectDetectionResult.create( return ObjectDetectorResult.create(
PacketGetter.getProtoVector( PacketGetter.getProtoVector(
packets.get(DETECTIONS_OUT_STREAM_INDEX), Detection.parser()), packets.get(DETECTIONS_OUT_STREAM_INDEX), Detection.parser()),
BaseVisionTaskApi.generateResultTimestampMs( BaseVisionTaskApi.generateResultTimestampMs(
@ -235,7 +240,7 @@ public final class ObjectDetector extends BaseVisionTaskApi {
* @param image a MediaPipe {@link MPImage} object for processing. * @param image a MediaPipe {@link MPImage} object for processing.
* @throws MediaPipeException if there is an internal error. * @throws MediaPipeException if there is an internal error.
*/ */
public ObjectDetectionResult detect(MPImage image) { public ObjectDetectorResult detect(MPImage image) {
return detect(image, ImageProcessingOptions.builder().build()); return detect(image, ImageProcessingOptions.builder().build());
} }
@ -258,10 +263,9 @@ public final class ObjectDetector extends BaseVisionTaskApi {
* region-of-interest. * region-of-interest.
* @throws MediaPipeException if there is an internal error. * @throws MediaPipeException if there is an internal error.
*/ */
public ObjectDetectionResult detect( public ObjectDetectorResult detect(MPImage image, ImageProcessingOptions imageProcessingOptions) {
MPImage image, ImageProcessingOptions imageProcessingOptions) {
validateImageProcessingOptions(imageProcessingOptions); validateImageProcessingOptions(imageProcessingOptions);
return (ObjectDetectionResult) processImageData(image, imageProcessingOptions); return (ObjectDetectorResult) processImageData(image, imageProcessingOptions);
} }
/** /**
@ -282,7 +286,7 @@ public final class ObjectDetector extends BaseVisionTaskApi {
* @param timestampMs the input timestamp (in milliseconds). * @param timestampMs the input timestamp (in milliseconds).
* @throws MediaPipeException if there is an internal error. * @throws MediaPipeException if there is an internal error.
*/ */
public ObjectDetectionResult detectForVideo(MPImage image, long timestampMs) { public ObjectDetectorResult detectForVideo(MPImage image, long timestampMs) {
return detectForVideo(image, ImageProcessingOptions.builder().build(), timestampMs); return detectForVideo(image, ImageProcessingOptions.builder().build(), timestampMs);
} }
@ -309,10 +313,10 @@ public final class ObjectDetector extends BaseVisionTaskApi {
* region-of-interest. * region-of-interest.
* @throws MediaPipeException if there is an internal error. * @throws MediaPipeException if there is an internal error.
*/ */
public ObjectDetectionResult detectForVideo( public ObjectDetectorResult detectForVideo(
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
validateImageProcessingOptions(imageProcessingOptions); validateImageProcessingOptions(imageProcessingOptions);
return (ObjectDetectionResult) processVideoData(image, imageProcessingOptions, timestampMs); return (ObjectDetectorResult) processVideoData(image, imageProcessingOptions, timestampMs);
} }
/** /**
@ -435,7 +439,7 @@ public final class ObjectDetector extends BaseVisionTaskApi {
* object detector is in the live stream mode. * object detector is in the live stream mode.
*/ */
public abstract Builder setResultListener( public abstract Builder setResultListener(
ResultListener<ObjectDetectionResult, MPImage> value); ResultListener<ObjectDetectorResult, MPImage> value);
/** Sets an optional {@link ErrorListener}}. */ /** Sets an optional {@link ErrorListener}}. */
public abstract Builder setErrorListener(ErrorListener value); public abstract Builder setErrorListener(ErrorListener value);
@ -476,11 +480,13 @@ public final class ObjectDetector extends BaseVisionTaskApi {
abstract Optional<Float> scoreThreshold(); abstract Optional<Float> scoreThreshold();
@SuppressWarnings("AutoValueImmutableFields")
abstract List<String> categoryAllowlist(); abstract List<String> categoryAllowlist();
@SuppressWarnings("AutoValueImmutableFields")
abstract List<String> categoryDenylist(); abstract List<String> categoryDenylist();
abstract Optional<ResultListener<ObjectDetectionResult, MPImage>> resultListener(); abstract Optional<ResultListener<ObjectDetectorResult, MPImage>> resultListener();
abstract Optional<ErrorListener> errorListener(); abstract Optional<ErrorListener> errorListener();

View File

@ -0,0 +1,44 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.objectdetector;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.formats.proto.DetectionProto.Detection;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
/** Represents the detection results generated by {@link ObjectDetector}. */
@AutoValue
@SuppressWarnings("deprecation")
public abstract class ObjectDetectorResult extends ObjectDetectionResult {
/**
* Creates an {@link ObjectDetectorResult} instance from a list of {@link Detection} protobuf
* messages.
*
* @param detectionList a list of {@link DetectionOuterClass.Detection} protobuf messages.
* @param timestampMs a timestamp for this result.
*/
public static ObjectDetectorResult create(List<Detection> detectionList, long timestampMs) {
List<com.google.mediapipe.tasks.components.containers.Detection> detections = new ArrayList<>();
for (Detection detectionProto : detectionList) {
detections.add(
com.google.mediapipe.tasks.components.containers.Detection.createFromProto(
detectionProto));
}
return new AutoValue_ObjectDetectorResult(
timestampMs, Collections.unmodifiableList(detections));
}
}

View File

@ -69,7 +69,7 @@ public class ObjectDetectorTest {
.build(); .build();
ObjectDetector objectDetector = ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE); assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
} }
@ -77,7 +77,7 @@ public class ObjectDetectorTest {
public void detect_successWithNoOptions() throws Exception { public void detect_successWithNoOptions() throws Exception {
ObjectDetector objectDetector = ObjectDetector objectDetector =
ObjectDetector.createFromFile(ApplicationProvider.getApplicationContext(), MODEL_FILE); ObjectDetector.createFromFile(ApplicationProvider.getApplicationContext(), MODEL_FILE);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
// Check if the object with the highest score is cat. // Check if the object with the highest score is cat.
assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE); assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE);
} }
@ -91,7 +91,7 @@ public class ObjectDetectorTest {
.build(); .build();
ObjectDetector objectDetector = ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
// results should have 8 detected objects because maxResults was set to 8. // results should have 8 detected objects because maxResults was set to 8.
assertThat(results.detections()).hasSize(8); assertThat(results.detections()).hasSize(8);
} }
@ -105,7 +105,7 @@ public class ObjectDetectorTest {
.build(); .build();
ObjectDetector objectDetector = ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
// The score threshold should block all other other objects, except cat. // The score threshold should block all other other objects, except cat.
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE); assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
} }
@ -119,7 +119,7 @@ public class ObjectDetectorTest {
.build(); .build();
ObjectDetector objectDetector = ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
// The score threshold should block objects. // The score threshold should block objects.
assertThat(results.detections()).isEmpty(); assertThat(results.detections()).isEmpty();
} }
@ -133,7 +133,7 @@ public class ObjectDetectorTest {
.build(); .build();
ObjectDetector objectDetector = ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
// Because of the allowlist, results should only contain cat, and there are 6 detected // Because of the allowlist, results should only contain cat, and there are 6 detected
// bounding boxes of cats in CAT_AND_DOG_IMAGE. // bounding boxes of cats in CAT_AND_DOG_IMAGE.
assertThat(results.detections()).hasSize(5); assertThat(results.detections()).hasSize(5);
@ -148,7 +148,7 @@ public class ObjectDetectorTest {
.build(); .build();
ObjectDetector objectDetector = ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
// Because of the denylist, the highest result is not cat anymore. // Because of the denylist, the highest result is not cat anymore.
assertThat(results.detections().get(0).categories().get(0).categoryName()) assertThat(results.detections().get(0).categories().get(0).categoryName())
.isNotEqualTo("cat"); .isNotEqualTo("cat");
@ -160,7 +160,7 @@ public class ObjectDetectorTest {
ObjectDetector.createFromFile( ObjectDetector.createFromFile(
ApplicationProvider.getApplicationContext(), ApplicationProvider.getApplicationContext(),
TestUtils.loadFile(ApplicationProvider.getApplicationContext(), MODEL_FILE)); TestUtils.loadFile(ApplicationProvider.getApplicationContext(), MODEL_FILE));
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
// Check if the object with the highest score is cat. // Check if the object with the highest score is cat.
assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE); assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE);
} }
@ -172,7 +172,7 @@ public class ObjectDetectorTest {
ApplicationProvider.getApplicationContext(), ApplicationProvider.getApplicationContext(),
TestUtils.loadToDirectByteBuffer( TestUtils.loadToDirectByteBuffer(
ApplicationProvider.getApplicationContext(), MODEL_FILE)); ApplicationProvider.getApplicationContext(), MODEL_FILE));
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
// Check if the object with the highest score is cat. // Check if the object with the highest score is cat.
assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE); assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE);
} }
@ -191,7 +191,7 @@ public class ObjectDetectorTest {
.build(); .build();
ObjectDetector objectDetector = ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE); assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
} }
@ -256,7 +256,7 @@ public class ObjectDetectorTest {
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ImageProcessingOptions imageProcessingOptions = ImageProcessingOptions imageProcessingOptions =
ImageProcessingOptions.builder().setRotationDegrees(-90).build(); ImageProcessingOptions.builder().setRotationDegrees(-90).build();
ObjectDetectionResult results = ObjectDetectorResult results =
objectDetector.detect( objectDetector.detect(
getImageFromAsset(CAT_AND_DOG_ROTATED_IMAGE), imageProcessingOptions); getImageFromAsset(CAT_AND_DOG_ROTATED_IMAGE), imageProcessingOptions);
@ -302,7 +302,7 @@ public class ObjectDetectorTest {
ObjectDetectorOptions.builder() ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setRunningMode(mode) .setRunningMode(mode)
.setResultListener((objectDetectionResult, inputImage) -> {}) .setResultListener((ObjectDetectorResult, inputImage) -> {})
.build()); .build());
assertThat(exception) assertThat(exception)
.hasMessageThat() .hasMessageThat()
@ -381,7 +381,7 @@ public class ObjectDetectorTest {
ObjectDetectorOptions.builder() ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setRunningMode(RunningMode.LIVE_STREAM) .setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener((objectDetectionResult, inputImage) -> {}) .setResultListener((ObjectDetectorResult, inputImage) -> {})
.build(); .build();
ObjectDetector objectDetector = ObjectDetector objectDetector =
@ -411,7 +411,7 @@ public class ObjectDetectorTest {
.build(); .build();
ObjectDetector objectDetector = ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE); assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
} }
@ -426,7 +426,7 @@ public class ObjectDetectorTest {
ObjectDetector objectDetector = ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
ObjectDetectionResult results = ObjectDetectorResult results =
objectDetector.detectForVideo( objectDetector.detectForVideo(
getImageFromAsset(CAT_AND_DOG_IMAGE), /* timestampsMs= */ i); getImageFromAsset(CAT_AND_DOG_IMAGE), /* timestampsMs= */ i);
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE); assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
@ -441,8 +441,8 @@ public class ObjectDetectorTest {
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setRunningMode(RunningMode.LIVE_STREAM) .setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener( .setResultListener(
(objectDetectionResult, inputImage) -> { (ObjectDetectorResult, inputImage) -> {
assertContainsOnlyCat(objectDetectionResult, CAT_BOUNDING_BOX, CAT_SCORE); assertContainsOnlyCat(ObjectDetectorResult, CAT_BOUNDING_BOX, CAT_SCORE);
assertImageSizeIsExpected(inputImage); assertImageSizeIsExpected(inputImage);
}) })
.setMaxResults(1) .setMaxResults(1)
@ -468,8 +468,8 @@ public class ObjectDetectorTest {
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setRunningMode(RunningMode.LIVE_STREAM) .setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener( .setResultListener(
(objectDetectionResult, inputImage) -> { (ObjectDetectorResult, inputImage) -> {
assertContainsOnlyCat(objectDetectionResult, CAT_BOUNDING_BOX, CAT_SCORE); assertContainsOnlyCat(ObjectDetectorResult, CAT_BOUNDING_BOX, CAT_SCORE);
assertImageSizeIsExpected(inputImage); assertImageSizeIsExpected(inputImage);
}) })
.setMaxResults(1) .setMaxResults(1)
@ -483,6 +483,16 @@ public class ObjectDetectorTest {
} }
} }
@Test
@SuppressWarnings("deprecation")
public void detect_canUseDeprecatedApi() throws Exception {
ObjectDetector objectDetector =
ObjectDetector.createFromFile(ApplicationProvider.getApplicationContext(), MODEL_FILE);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
// Check if the object with the highest score is cat.
assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE);
}
private static MPImage getImageFromAsset(String filePath) throws Exception { private static MPImage getImageFromAsset(String filePath) throws Exception {
AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets(); AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
InputStream istr = assetManager.open(filePath); InputStream istr = assetManager.open(filePath);
@ -491,7 +501,7 @@ public class ObjectDetectorTest {
// Checks if results has one and only detection result, which is a cat. // Checks if results has one and only detection result, which is a cat.
private static void assertContainsOnlyCat( private static void assertContainsOnlyCat(
ObjectDetectionResult result, RectF expectedBoundingBox, float expectedScore) { ObjectDetectorResult result, RectF expectedBoundingBox, float expectedScore) {
assertThat(result.detections()).hasSize(1); assertThat(result.detections()).hasSize(1);
Detection catResult = result.detections().get(0); Detection catResult = result.detections().get(0);
assertApproximatelyEqualBoundingBoxes(catResult.boundingBox(), expectedBoundingBox); assertApproximatelyEqualBoundingBoxes(catResult.boundingBox(), expectedBoundingBox);