Rename ObjectDetctionResult to ObjectDetectorResult
PiperOrigin-RevId: 534858600
This commit is contained in:
parent
2017fcc9ab
commit
d9f316e12a
|
@ -71,6 +71,7 @@ android_library(
|
||||||
srcs = [
|
srcs = [
|
||||||
"objectdetector/ObjectDetectionResult.java",
|
"objectdetector/ObjectDetectionResult.java",
|
||||||
"objectdetector/ObjectDetector.java",
|
"objectdetector/ObjectDetector.java",
|
||||||
|
"objectdetector/ObjectDetectorResult.java",
|
||||||
],
|
],
|
||||||
javacopts = [
|
javacopts = [
|
||||||
"-Xep:AndroidJdkLibsChecker:OFF",
|
"-Xep:AndroidJdkLibsChecker:OFF",
|
||||||
|
|
|
@ -14,15 +14,16 @@
|
||||||
|
|
||||||
package com.google.mediapipe.tasks.vision.objectdetector;
|
package com.google.mediapipe.tasks.vision.objectdetector;
|
||||||
|
|
||||||
import com.google.auto.value.AutoValue;
|
|
||||||
import com.google.mediapipe.tasks.core.TaskResult;
|
import com.google.mediapipe.tasks.core.TaskResult;
|
||||||
import com.google.mediapipe.formats.proto.DetectionProto.Detection;
|
import com.google.mediapipe.formats.proto.DetectionProto.Detection;
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
/** Represents the detection results generated by {@link ObjectDetector}. */
|
/**
|
||||||
@AutoValue
|
* Represents the detection results generated by {@link ObjectDetector}.
|
||||||
|
*
|
||||||
|
* @deprecated Use {@link ObjectDetectorResult} instead.
|
||||||
|
*/
|
||||||
|
@Deprecated
|
||||||
public abstract class ObjectDetectionResult implements TaskResult {
|
public abstract class ObjectDetectionResult implements TaskResult {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -36,15 +37,10 @@ public abstract class ObjectDetectionResult implements TaskResult {
|
||||||
*
|
*
|
||||||
* @param detectionList a list of {@link DetectionOuterClass.Detection} protobuf messages.
|
* @param detectionList a list of {@link DetectionOuterClass.Detection} protobuf messages.
|
||||||
* @param timestampMs a timestamp for this result.
|
* @param timestampMs a timestamp for this result.
|
||||||
|
* @deprecated Use {@link ObjectDetectorResult#create} instead.
|
||||||
*/
|
*/
|
||||||
|
@Deprecated
|
||||||
public static ObjectDetectionResult create(List<Detection> detectionList, long timestampMs) {
|
public static ObjectDetectionResult create(List<Detection> detectionList, long timestampMs) {
|
||||||
List<com.google.mediapipe.tasks.components.containers.Detection> detections = new ArrayList<>();
|
return ObjectDetectorResult.create(detectionList, timestampMs);
|
||||||
for (Detection detectionProto : detectionList) {
|
|
||||||
detections.add(
|
|
||||||
com.google.mediapipe.tasks.components.containers.Detection.createFromProto(
|
|
||||||
detectionProto));
|
|
||||||
}
|
|
||||||
return new AutoValue_ObjectDetectionResult(
|
|
||||||
timestampMs, Collections.unmodifiableList(detections));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -99,11 +99,16 @@ public final class ObjectDetector extends BaseVisionTaskApi {
|
||||||
private static final String TAG = ObjectDetector.class.getSimpleName();
|
private static final String TAG = ObjectDetector.class.getSimpleName();
|
||||||
private static final String IMAGE_IN_STREAM_NAME = "image_in";
|
private static final String IMAGE_IN_STREAM_NAME = "image_in";
|
||||||
private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in";
|
private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in";
|
||||||
|
|
||||||
|
@SuppressWarnings("ConstantCaseForConstants")
|
||||||
private static final List<String> INPUT_STREAMS =
|
private static final List<String> INPUT_STREAMS =
|
||||||
Collections.unmodifiableList(
|
Collections.unmodifiableList(
|
||||||
Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
|
Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
|
||||||
|
|
||||||
|
@SuppressWarnings("ConstantCaseForConstants")
|
||||||
private static final List<String> OUTPUT_STREAMS =
|
private static final List<String> OUTPUT_STREAMS =
|
||||||
Collections.unmodifiableList(Arrays.asList("DETECTIONS:detections_out", "IMAGE:image_out"));
|
Collections.unmodifiableList(Arrays.asList("DETECTIONS:detections_out", "IMAGE:image_out"));
|
||||||
|
|
||||||
private static final int DETECTIONS_OUT_STREAM_INDEX = 0;
|
private static final int DETECTIONS_OUT_STREAM_INDEX = 0;
|
||||||
private static final int IMAGE_OUT_STREAM_INDEX = 1;
|
private static final int IMAGE_OUT_STREAM_INDEX = 1;
|
||||||
private static final String TASK_GRAPH_NAME = "mediapipe.tasks.vision.ObjectDetectorGraph";
|
private static final String TASK_GRAPH_NAME = "mediapipe.tasks.vision.ObjectDetectorGraph";
|
||||||
|
@ -166,19 +171,19 @@ public final class ObjectDetector extends BaseVisionTaskApi {
|
||||||
public static ObjectDetector createFromOptions(
|
public static ObjectDetector createFromOptions(
|
||||||
Context context, ObjectDetectorOptions detectorOptions) {
|
Context context, ObjectDetectorOptions detectorOptions) {
|
||||||
// TODO: Consolidate OutputHandler and TaskRunner.
|
// TODO: Consolidate OutputHandler and TaskRunner.
|
||||||
OutputHandler<ObjectDetectionResult, MPImage> handler = new OutputHandler<>();
|
OutputHandler<ObjectDetectorResult, MPImage> handler = new OutputHandler<>();
|
||||||
handler.setOutputPacketConverter(
|
handler.setOutputPacketConverter(
|
||||||
new OutputHandler.OutputPacketConverter<ObjectDetectionResult, MPImage>() {
|
new OutputHandler.OutputPacketConverter<ObjectDetectorResult, MPImage>() {
|
||||||
@Override
|
@Override
|
||||||
public ObjectDetectionResult convertToTaskResult(List<Packet> packets) {
|
public ObjectDetectorResult convertToTaskResult(List<Packet> packets) {
|
||||||
// If there is no object detected in the image, just returns empty lists.
|
// If there is no object detected in the image, just returns empty lists.
|
||||||
if (packets.get(DETECTIONS_OUT_STREAM_INDEX).isEmpty()) {
|
if (packets.get(DETECTIONS_OUT_STREAM_INDEX).isEmpty()) {
|
||||||
return ObjectDetectionResult.create(
|
return ObjectDetectorResult.create(
|
||||||
new ArrayList<>(),
|
new ArrayList<>(),
|
||||||
BaseVisionTaskApi.generateResultTimestampMs(
|
BaseVisionTaskApi.generateResultTimestampMs(
|
||||||
detectorOptions.runningMode(), packets.get(DETECTIONS_OUT_STREAM_INDEX)));
|
detectorOptions.runningMode(), packets.get(DETECTIONS_OUT_STREAM_INDEX)));
|
||||||
}
|
}
|
||||||
return ObjectDetectionResult.create(
|
return ObjectDetectorResult.create(
|
||||||
PacketGetter.getProtoVector(
|
PacketGetter.getProtoVector(
|
||||||
packets.get(DETECTIONS_OUT_STREAM_INDEX), Detection.parser()),
|
packets.get(DETECTIONS_OUT_STREAM_INDEX), Detection.parser()),
|
||||||
BaseVisionTaskApi.generateResultTimestampMs(
|
BaseVisionTaskApi.generateResultTimestampMs(
|
||||||
|
@ -235,7 +240,7 @@ public final class ObjectDetector extends BaseVisionTaskApi {
|
||||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||||
* @throws MediaPipeException if there is an internal error.
|
* @throws MediaPipeException if there is an internal error.
|
||||||
*/
|
*/
|
||||||
public ObjectDetectionResult detect(MPImage image) {
|
public ObjectDetectorResult detect(MPImage image) {
|
||||||
return detect(image, ImageProcessingOptions.builder().build());
|
return detect(image, ImageProcessingOptions.builder().build());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -258,10 +263,9 @@ public final class ObjectDetector extends BaseVisionTaskApi {
|
||||||
* region-of-interest.
|
* region-of-interest.
|
||||||
* @throws MediaPipeException if there is an internal error.
|
* @throws MediaPipeException if there is an internal error.
|
||||||
*/
|
*/
|
||||||
public ObjectDetectionResult detect(
|
public ObjectDetectorResult detect(MPImage image, ImageProcessingOptions imageProcessingOptions) {
|
||||||
MPImage image, ImageProcessingOptions imageProcessingOptions) {
|
|
||||||
validateImageProcessingOptions(imageProcessingOptions);
|
validateImageProcessingOptions(imageProcessingOptions);
|
||||||
return (ObjectDetectionResult) processImageData(image, imageProcessingOptions);
|
return (ObjectDetectorResult) processImageData(image, imageProcessingOptions);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -282,7 +286,7 @@ public final class ObjectDetector extends BaseVisionTaskApi {
|
||||||
* @param timestampMs the input timestamp (in milliseconds).
|
* @param timestampMs the input timestamp (in milliseconds).
|
||||||
* @throws MediaPipeException if there is an internal error.
|
* @throws MediaPipeException if there is an internal error.
|
||||||
*/
|
*/
|
||||||
public ObjectDetectionResult detectForVideo(MPImage image, long timestampMs) {
|
public ObjectDetectorResult detectForVideo(MPImage image, long timestampMs) {
|
||||||
return detectForVideo(image, ImageProcessingOptions.builder().build(), timestampMs);
|
return detectForVideo(image, ImageProcessingOptions.builder().build(), timestampMs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -309,10 +313,10 @@ public final class ObjectDetector extends BaseVisionTaskApi {
|
||||||
* region-of-interest.
|
* region-of-interest.
|
||||||
* @throws MediaPipeException if there is an internal error.
|
* @throws MediaPipeException if there is an internal error.
|
||||||
*/
|
*/
|
||||||
public ObjectDetectionResult detectForVideo(
|
public ObjectDetectorResult detectForVideo(
|
||||||
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
|
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
|
||||||
validateImageProcessingOptions(imageProcessingOptions);
|
validateImageProcessingOptions(imageProcessingOptions);
|
||||||
return (ObjectDetectionResult) processVideoData(image, imageProcessingOptions, timestampMs);
|
return (ObjectDetectorResult) processVideoData(image, imageProcessingOptions, timestampMs);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -435,7 +439,7 @@ public final class ObjectDetector extends BaseVisionTaskApi {
|
||||||
* object detector is in the live stream mode.
|
* object detector is in the live stream mode.
|
||||||
*/
|
*/
|
||||||
public abstract Builder setResultListener(
|
public abstract Builder setResultListener(
|
||||||
ResultListener<ObjectDetectionResult, MPImage> value);
|
ResultListener<ObjectDetectorResult, MPImage> value);
|
||||||
|
|
||||||
/** Sets an optional {@link ErrorListener}}. */
|
/** Sets an optional {@link ErrorListener}}. */
|
||||||
public abstract Builder setErrorListener(ErrorListener value);
|
public abstract Builder setErrorListener(ErrorListener value);
|
||||||
|
@ -476,11 +480,13 @@ public final class ObjectDetector extends BaseVisionTaskApi {
|
||||||
|
|
||||||
abstract Optional<Float> scoreThreshold();
|
abstract Optional<Float> scoreThreshold();
|
||||||
|
|
||||||
|
@SuppressWarnings("AutoValueImmutableFields")
|
||||||
abstract List<String> categoryAllowlist();
|
abstract List<String> categoryAllowlist();
|
||||||
|
|
||||||
|
@SuppressWarnings("AutoValueImmutableFields")
|
||||||
abstract List<String> categoryDenylist();
|
abstract List<String> categoryDenylist();
|
||||||
|
|
||||||
abstract Optional<ResultListener<ObjectDetectionResult, MPImage>> resultListener();
|
abstract Optional<ResultListener<ObjectDetectorResult, MPImage>> resultListener();
|
||||||
|
|
||||||
abstract Optional<ErrorListener> errorListener();
|
abstract Optional<ErrorListener> errorListener();
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,44 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package com.google.mediapipe.tasks.vision.objectdetector;
|
||||||
|
|
||||||
|
import com.google.auto.value.AutoValue;
|
||||||
|
import com.google.mediapipe.formats.proto.DetectionProto.Detection;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/** Represents the detection results generated by {@link ObjectDetector}. */
|
||||||
|
@AutoValue
|
||||||
|
@SuppressWarnings("deprecation")
|
||||||
|
public abstract class ObjectDetectorResult extends ObjectDetectionResult {
|
||||||
|
/**
|
||||||
|
* Creates an {@link ObjectDetectorResult} instance from a list of {@link Detection} protobuf
|
||||||
|
* messages.
|
||||||
|
*
|
||||||
|
* @param detectionList a list of {@link DetectionOuterClass.Detection} protobuf messages.
|
||||||
|
* @param timestampMs a timestamp for this result.
|
||||||
|
*/
|
||||||
|
public static ObjectDetectorResult create(List<Detection> detectionList, long timestampMs) {
|
||||||
|
List<com.google.mediapipe.tasks.components.containers.Detection> detections = new ArrayList<>();
|
||||||
|
for (Detection detectionProto : detectionList) {
|
||||||
|
detections.add(
|
||||||
|
com.google.mediapipe.tasks.components.containers.Detection.createFromProto(
|
||||||
|
detectionProto));
|
||||||
|
}
|
||||||
|
return new AutoValue_ObjectDetectorResult(
|
||||||
|
timestampMs, Collections.unmodifiableList(detections));
|
||||||
|
}
|
||||||
|
}
|
|
@ -69,7 +69,7 @@ public class ObjectDetectorTest {
|
||||||
.build();
|
.build();
|
||||||
ObjectDetector objectDetector =
|
ObjectDetector objectDetector =
|
||||||
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||||
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
|
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -77,7 +77,7 @@ public class ObjectDetectorTest {
|
||||||
public void detect_successWithNoOptions() throws Exception {
|
public void detect_successWithNoOptions() throws Exception {
|
||||||
ObjectDetector objectDetector =
|
ObjectDetector objectDetector =
|
||||||
ObjectDetector.createFromFile(ApplicationProvider.getApplicationContext(), MODEL_FILE);
|
ObjectDetector.createFromFile(ApplicationProvider.getApplicationContext(), MODEL_FILE);
|
||||||
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||||
// Check if the object with the highest score is cat.
|
// Check if the object with the highest score is cat.
|
||||||
assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE);
|
assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE);
|
||||||
}
|
}
|
||||||
|
@ -91,7 +91,7 @@ public class ObjectDetectorTest {
|
||||||
.build();
|
.build();
|
||||||
ObjectDetector objectDetector =
|
ObjectDetector objectDetector =
|
||||||
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||||
// results should have 8 detected objects because maxResults was set to 8.
|
// results should have 8 detected objects because maxResults was set to 8.
|
||||||
assertThat(results.detections()).hasSize(8);
|
assertThat(results.detections()).hasSize(8);
|
||||||
}
|
}
|
||||||
|
@ -105,7 +105,7 @@ public class ObjectDetectorTest {
|
||||||
.build();
|
.build();
|
||||||
ObjectDetector objectDetector =
|
ObjectDetector objectDetector =
|
||||||
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||||
// The score threshold should block all other other objects, except cat.
|
// The score threshold should block all other other objects, except cat.
|
||||||
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
|
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
|
||||||
}
|
}
|
||||||
|
@ -119,7 +119,7 @@ public class ObjectDetectorTest {
|
||||||
.build();
|
.build();
|
||||||
ObjectDetector objectDetector =
|
ObjectDetector objectDetector =
|
||||||
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||||
// The score threshold should block objects.
|
// The score threshold should block objects.
|
||||||
assertThat(results.detections()).isEmpty();
|
assertThat(results.detections()).isEmpty();
|
||||||
}
|
}
|
||||||
|
@ -133,7 +133,7 @@ public class ObjectDetectorTest {
|
||||||
.build();
|
.build();
|
||||||
ObjectDetector objectDetector =
|
ObjectDetector objectDetector =
|
||||||
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||||
// Because of the allowlist, results should only contain cat, and there are 6 detected
|
// Because of the allowlist, results should only contain cat, and there are 6 detected
|
||||||
// bounding boxes of cats in CAT_AND_DOG_IMAGE.
|
// bounding boxes of cats in CAT_AND_DOG_IMAGE.
|
||||||
assertThat(results.detections()).hasSize(5);
|
assertThat(results.detections()).hasSize(5);
|
||||||
|
@ -148,7 +148,7 @@ public class ObjectDetectorTest {
|
||||||
.build();
|
.build();
|
||||||
ObjectDetector objectDetector =
|
ObjectDetector objectDetector =
|
||||||
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||||
// Because of the denylist, the highest result is not cat anymore.
|
// Because of the denylist, the highest result is not cat anymore.
|
||||||
assertThat(results.detections().get(0).categories().get(0).categoryName())
|
assertThat(results.detections().get(0).categories().get(0).categoryName())
|
||||||
.isNotEqualTo("cat");
|
.isNotEqualTo("cat");
|
||||||
|
@ -160,7 +160,7 @@ public class ObjectDetectorTest {
|
||||||
ObjectDetector.createFromFile(
|
ObjectDetector.createFromFile(
|
||||||
ApplicationProvider.getApplicationContext(),
|
ApplicationProvider.getApplicationContext(),
|
||||||
TestUtils.loadFile(ApplicationProvider.getApplicationContext(), MODEL_FILE));
|
TestUtils.loadFile(ApplicationProvider.getApplicationContext(), MODEL_FILE));
|
||||||
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||||
// Check if the object with the highest score is cat.
|
// Check if the object with the highest score is cat.
|
||||||
assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE);
|
assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE);
|
||||||
}
|
}
|
||||||
|
@ -172,7 +172,7 @@ public class ObjectDetectorTest {
|
||||||
ApplicationProvider.getApplicationContext(),
|
ApplicationProvider.getApplicationContext(),
|
||||||
TestUtils.loadToDirectByteBuffer(
|
TestUtils.loadToDirectByteBuffer(
|
||||||
ApplicationProvider.getApplicationContext(), MODEL_FILE));
|
ApplicationProvider.getApplicationContext(), MODEL_FILE));
|
||||||
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||||
// Check if the object with the highest score is cat.
|
// Check if the object with the highest score is cat.
|
||||||
assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE);
|
assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE);
|
||||||
}
|
}
|
||||||
|
@ -191,7 +191,7 @@ public class ObjectDetectorTest {
|
||||||
.build();
|
.build();
|
||||||
ObjectDetector objectDetector =
|
ObjectDetector objectDetector =
|
||||||
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||||
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
|
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -256,7 +256,7 @@ public class ObjectDetectorTest {
|
||||||
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
ImageProcessingOptions imageProcessingOptions =
|
ImageProcessingOptions imageProcessingOptions =
|
||||||
ImageProcessingOptions.builder().setRotationDegrees(-90).build();
|
ImageProcessingOptions.builder().setRotationDegrees(-90).build();
|
||||||
ObjectDetectionResult results =
|
ObjectDetectorResult results =
|
||||||
objectDetector.detect(
|
objectDetector.detect(
|
||||||
getImageFromAsset(CAT_AND_DOG_ROTATED_IMAGE), imageProcessingOptions);
|
getImageFromAsset(CAT_AND_DOG_ROTATED_IMAGE), imageProcessingOptions);
|
||||||
|
|
||||||
|
@ -302,7 +302,7 @@ public class ObjectDetectorTest {
|
||||||
ObjectDetectorOptions.builder()
|
ObjectDetectorOptions.builder()
|
||||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
|
||||||
.setRunningMode(mode)
|
.setRunningMode(mode)
|
||||||
.setResultListener((objectDetectionResult, inputImage) -> {})
|
.setResultListener((ObjectDetectorResult, inputImage) -> {})
|
||||||
.build());
|
.build());
|
||||||
assertThat(exception)
|
assertThat(exception)
|
||||||
.hasMessageThat()
|
.hasMessageThat()
|
||||||
|
@ -381,7 +381,7 @@ public class ObjectDetectorTest {
|
||||||
ObjectDetectorOptions.builder()
|
ObjectDetectorOptions.builder()
|
||||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
|
||||||
.setRunningMode(RunningMode.LIVE_STREAM)
|
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||||
.setResultListener((objectDetectionResult, inputImage) -> {})
|
.setResultListener((ObjectDetectorResult, inputImage) -> {})
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
ObjectDetector objectDetector =
|
ObjectDetector objectDetector =
|
||||||
|
@ -411,7 +411,7 @@ public class ObjectDetectorTest {
|
||||||
.build();
|
.build();
|
||||||
ObjectDetector objectDetector =
|
ObjectDetector objectDetector =
|
||||||
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||||
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
|
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -426,7 +426,7 @@ public class ObjectDetectorTest {
|
||||||
ObjectDetector objectDetector =
|
ObjectDetector objectDetector =
|
||||||
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
for (int i = 0; i < 3; i++) {
|
for (int i = 0; i < 3; i++) {
|
||||||
ObjectDetectionResult results =
|
ObjectDetectorResult results =
|
||||||
objectDetector.detectForVideo(
|
objectDetector.detectForVideo(
|
||||||
getImageFromAsset(CAT_AND_DOG_IMAGE), /* timestampsMs= */ i);
|
getImageFromAsset(CAT_AND_DOG_IMAGE), /* timestampsMs= */ i);
|
||||||
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
|
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
|
||||||
|
@ -441,8 +441,8 @@ public class ObjectDetectorTest {
|
||||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
|
||||||
.setRunningMode(RunningMode.LIVE_STREAM)
|
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||||
.setResultListener(
|
.setResultListener(
|
||||||
(objectDetectionResult, inputImage) -> {
|
(ObjectDetectorResult, inputImage) -> {
|
||||||
assertContainsOnlyCat(objectDetectionResult, CAT_BOUNDING_BOX, CAT_SCORE);
|
assertContainsOnlyCat(ObjectDetectorResult, CAT_BOUNDING_BOX, CAT_SCORE);
|
||||||
assertImageSizeIsExpected(inputImage);
|
assertImageSizeIsExpected(inputImage);
|
||||||
})
|
})
|
||||||
.setMaxResults(1)
|
.setMaxResults(1)
|
||||||
|
@ -468,8 +468,8 @@ public class ObjectDetectorTest {
|
||||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
|
||||||
.setRunningMode(RunningMode.LIVE_STREAM)
|
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||||
.setResultListener(
|
.setResultListener(
|
||||||
(objectDetectionResult, inputImage) -> {
|
(ObjectDetectorResult, inputImage) -> {
|
||||||
assertContainsOnlyCat(objectDetectionResult, CAT_BOUNDING_BOX, CAT_SCORE);
|
assertContainsOnlyCat(ObjectDetectorResult, CAT_BOUNDING_BOX, CAT_SCORE);
|
||||||
assertImageSizeIsExpected(inputImage);
|
assertImageSizeIsExpected(inputImage);
|
||||||
})
|
})
|
||||||
.setMaxResults(1)
|
.setMaxResults(1)
|
||||||
|
@ -483,6 +483,16 @@ public class ObjectDetectorTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@SuppressWarnings("deprecation")
|
||||||
|
public void detect_canUseDeprecatedApi() throws Exception {
|
||||||
|
ObjectDetector objectDetector =
|
||||||
|
ObjectDetector.createFromFile(ApplicationProvider.getApplicationContext(), MODEL_FILE);
|
||||||
|
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||||
|
// Check if the object with the highest score is cat.
|
||||||
|
assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE);
|
||||||
|
}
|
||||||
|
|
||||||
private static MPImage getImageFromAsset(String filePath) throws Exception {
|
private static MPImage getImageFromAsset(String filePath) throws Exception {
|
||||||
AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
|
AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
|
||||||
InputStream istr = assetManager.open(filePath);
|
InputStream istr = assetManager.open(filePath);
|
||||||
|
@ -491,7 +501,7 @@ public class ObjectDetectorTest {
|
||||||
|
|
||||||
// Checks if results has one and only detection result, which is a cat.
|
// Checks if results has one and only detection result, which is a cat.
|
||||||
private static void assertContainsOnlyCat(
|
private static void assertContainsOnlyCat(
|
||||||
ObjectDetectionResult result, RectF expectedBoundingBox, float expectedScore) {
|
ObjectDetectorResult result, RectF expectedBoundingBox, float expectedScore) {
|
||||||
assertThat(result.detections()).hasSize(1);
|
assertThat(result.detections()).hasSize(1);
|
||||||
Detection catResult = result.detections().get(0);
|
Detection catResult = result.detections().get(0);
|
||||||
assertApproximatelyEqualBoundingBoxes(catResult.boundingBox(), expectedBoundingBox);
|
assertApproximatelyEqualBoundingBoxes(catResult.boundingBox(), expectedBoundingBox);
|
||||||
|
|
Loading…
Reference in New Issue
Block a user