Rename ObjectDetctionResult to ObjectDetectorResult
PiperOrigin-RevId: 534858600
This commit is contained in:
parent
2017fcc9ab
commit
d9f316e12a
|
@ -71,6 +71,7 @@ android_library(
|
|||
srcs = [
|
||||
"objectdetector/ObjectDetectionResult.java",
|
||||
"objectdetector/ObjectDetector.java",
|
||||
"objectdetector/ObjectDetectorResult.java",
|
||||
],
|
||||
javacopts = [
|
||||
"-Xep:AndroidJdkLibsChecker:OFF",
|
||||
|
|
|
@ -14,15 +14,16 @@
|
|||
|
||||
package com.google.mediapipe.tasks.vision.objectdetector;
|
||||
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.tasks.core.TaskResult;
|
||||
import com.google.mediapipe.formats.proto.DetectionProto.Detection;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
/** Represents the detection results generated by {@link ObjectDetector}. */
|
||||
@AutoValue
|
||||
/**
|
||||
* Represents the detection results generated by {@link ObjectDetector}.
|
||||
*
|
||||
* @deprecated Use {@link ObjectDetectorResult} instead.
|
||||
*/
|
||||
@Deprecated
|
||||
public abstract class ObjectDetectionResult implements TaskResult {
|
||||
|
||||
@Override
|
||||
|
@ -36,15 +37,10 @@ public abstract class ObjectDetectionResult implements TaskResult {
|
|||
*
|
||||
* @param detectionList a list of {@link DetectionOuterClass.Detection} protobuf messages.
|
||||
* @param timestampMs a timestamp for this result.
|
||||
* @deprecated Use {@link ObjectDetectorResult#create} instead.
|
||||
*/
|
||||
@Deprecated
|
||||
public static ObjectDetectionResult create(List<Detection> detectionList, long timestampMs) {
|
||||
List<com.google.mediapipe.tasks.components.containers.Detection> detections = new ArrayList<>();
|
||||
for (Detection detectionProto : detectionList) {
|
||||
detections.add(
|
||||
com.google.mediapipe.tasks.components.containers.Detection.createFromProto(
|
||||
detectionProto));
|
||||
}
|
||||
return new AutoValue_ObjectDetectionResult(
|
||||
timestampMs, Collections.unmodifiableList(detections));
|
||||
return ObjectDetectorResult.create(detectionList, timestampMs);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -99,11 +99,16 @@ public final class ObjectDetector extends BaseVisionTaskApi {
|
|||
private static final String TAG = ObjectDetector.class.getSimpleName();
|
||||
private static final String IMAGE_IN_STREAM_NAME = "image_in";
|
||||
private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in";
|
||||
|
||||
@SuppressWarnings("ConstantCaseForConstants")
|
||||
private static final List<String> INPUT_STREAMS =
|
||||
Collections.unmodifiableList(
|
||||
Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
|
||||
|
||||
@SuppressWarnings("ConstantCaseForConstants")
|
||||
private static final List<String> OUTPUT_STREAMS =
|
||||
Collections.unmodifiableList(Arrays.asList("DETECTIONS:detections_out", "IMAGE:image_out"));
|
||||
|
||||
private static final int DETECTIONS_OUT_STREAM_INDEX = 0;
|
||||
private static final int IMAGE_OUT_STREAM_INDEX = 1;
|
||||
private static final String TASK_GRAPH_NAME = "mediapipe.tasks.vision.ObjectDetectorGraph";
|
||||
|
@ -166,19 +171,19 @@ public final class ObjectDetector extends BaseVisionTaskApi {
|
|||
public static ObjectDetector createFromOptions(
|
||||
Context context, ObjectDetectorOptions detectorOptions) {
|
||||
// TODO: Consolidate OutputHandler and TaskRunner.
|
||||
OutputHandler<ObjectDetectionResult, MPImage> handler = new OutputHandler<>();
|
||||
OutputHandler<ObjectDetectorResult, MPImage> handler = new OutputHandler<>();
|
||||
handler.setOutputPacketConverter(
|
||||
new OutputHandler.OutputPacketConverter<ObjectDetectionResult, MPImage>() {
|
||||
new OutputHandler.OutputPacketConverter<ObjectDetectorResult, MPImage>() {
|
||||
@Override
|
||||
public ObjectDetectionResult convertToTaskResult(List<Packet> packets) {
|
||||
public ObjectDetectorResult convertToTaskResult(List<Packet> packets) {
|
||||
// If there is no object detected in the image, just returns empty lists.
|
||||
if (packets.get(DETECTIONS_OUT_STREAM_INDEX).isEmpty()) {
|
||||
return ObjectDetectionResult.create(
|
||||
return ObjectDetectorResult.create(
|
||||
new ArrayList<>(),
|
||||
BaseVisionTaskApi.generateResultTimestampMs(
|
||||
detectorOptions.runningMode(), packets.get(DETECTIONS_OUT_STREAM_INDEX)));
|
||||
}
|
||||
return ObjectDetectionResult.create(
|
||||
return ObjectDetectorResult.create(
|
||||
PacketGetter.getProtoVector(
|
||||
packets.get(DETECTIONS_OUT_STREAM_INDEX), Detection.parser()),
|
||||
BaseVisionTaskApi.generateResultTimestampMs(
|
||||
|
@ -235,7 +240,7 @@ public final class ObjectDetector extends BaseVisionTaskApi {
|
|||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public ObjectDetectionResult detect(MPImage image) {
|
||||
public ObjectDetectorResult detect(MPImage image) {
|
||||
return detect(image, ImageProcessingOptions.builder().build());
|
||||
}
|
||||
|
||||
|
@ -258,10 +263,9 @@ public final class ObjectDetector extends BaseVisionTaskApi {
|
|||
* region-of-interest.
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public ObjectDetectionResult detect(
|
||||
MPImage image, ImageProcessingOptions imageProcessingOptions) {
|
||||
public ObjectDetectorResult detect(MPImage image, ImageProcessingOptions imageProcessingOptions) {
|
||||
validateImageProcessingOptions(imageProcessingOptions);
|
||||
return (ObjectDetectionResult) processImageData(image, imageProcessingOptions);
|
||||
return (ObjectDetectorResult) processImageData(image, imageProcessingOptions);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -282,7 +286,7 @@ public final class ObjectDetector extends BaseVisionTaskApi {
|
|||
* @param timestampMs the input timestamp (in milliseconds).
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public ObjectDetectionResult detectForVideo(MPImage image, long timestampMs) {
|
||||
public ObjectDetectorResult detectForVideo(MPImage image, long timestampMs) {
|
||||
return detectForVideo(image, ImageProcessingOptions.builder().build(), timestampMs);
|
||||
}
|
||||
|
||||
|
@ -309,10 +313,10 @@ public final class ObjectDetector extends BaseVisionTaskApi {
|
|||
* region-of-interest.
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public ObjectDetectionResult detectForVideo(
|
||||
public ObjectDetectorResult detectForVideo(
|
||||
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
|
||||
validateImageProcessingOptions(imageProcessingOptions);
|
||||
return (ObjectDetectionResult) processVideoData(image, imageProcessingOptions, timestampMs);
|
||||
return (ObjectDetectorResult) processVideoData(image, imageProcessingOptions, timestampMs);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -435,7 +439,7 @@ public final class ObjectDetector extends BaseVisionTaskApi {
|
|||
* object detector is in the live stream mode.
|
||||
*/
|
||||
public abstract Builder setResultListener(
|
||||
ResultListener<ObjectDetectionResult, MPImage> value);
|
||||
ResultListener<ObjectDetectorResult, MPImage> value);
|
||||
|
||||
/** Sets an optional {@link ErrorListener}}. */
|
||||
public abstract Builder setErrorListener(ErrorListener value);
|
||||
|
@ -476,11 +480,13 @@ public final class ObjectDetector extends BaseVisionTaskApi {
|
|||
|
||||
abstract Optional<Float> scoreThreshold();
|
||||
|
||||
@SuppressWarnings("AutoValueImmutableFields")
|
||||
abstract List<String> categoryAllowlist();
|
||||
|
||||
@SuppressWarnings("AutoValueImmutableFields")
|
||||
abstract List<String> categoryDenylist();
|
||||
|
||||
abstract Optional<ResultListener<ObjectDetectionResult, MPImage>> resultListener();
|
||||
abstract Optional<ResultListener<ObjectDetectorResult, MPImage>> resultListener();
|
||||
|
||||
abstract Optional<ErrorListener> errorListener();
|
||||
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.vision.objectdetector;
|
||||
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.formats.proto.DetectionProto.Detection;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
/** Represents the detection results generated by {@link ObjectDetector}. */
|
||||
@AutoValue
|
||||
@SuppressWarnings("deprecation")
|
||||
public abstract class ObjectDetectorResult extends ObjectDetectionResult {
|
||||
/**
|
||||
* Creates an {@link ObjectDetectorResult} instance from a list of {@link Detection} protobuf
|
||||
* messages.
|
||||
*
|
||||
* @param detectionList a list of {@link DetectionOuterClass.Detection} protobuf messages.
|
||||
* @param timestampMs a timestamp for this result.
|
||||
*/
|
||||
public static ObjectDetectorResult create(List<Detection> detectionList, long timestampMs) {
|
||||
List<com.google.mediapipe.tasks.components.containers.Detection> detections = new ArrayList<>();
|
||||
for (Detection detectionProto : detectionList) {
|
||||
detections.add(
|
||||
com.google.mediapipe.tasks.components.containers.Detection.createFromProto(
|
||||
detectionProto));
|
||||
}
|
||||
return new AutoValue_ObjectDetectorResult(
|
||||
timestampMs, Collections.unmodifiableList(detections));
|
||||
}
|
||||
}
|
|
@ -69,7 +69,7 @@ public class ObjectDetectorTest {
|
|||
.build();
|
||||
ObjectDetector objectDetector =
|
||||
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||
ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
|
||||
}
|
||||
|
||||
|
@ -77,7 +77,7 @@ public class ObjectDetectorTest {
|
|||
public void detect_successWithNoOptions() throws Exception {
|
||||
ObjectDetector objectDetector =
|
||||
ObjectDetector.createFromFile(ApplicationProvider.getApplicationContext(), MODEL_FILE);
|
||||
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||
ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||
// Check if the object with the highest score is cat.
|
||||
assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE);
|
||||
}
|
||||
|
@ -91,7 +91,7 @@ public class ObjectDetectorTest {
|
|||
.build();
|
||||
ObjectDetector objectDetector =
|
||||
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||
ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||
// results should have 8 detected objects because maxResults was set to 8.
|
||||
assertThat(results.detections()).hasSize(8);
|
||||
}
|
||||
|
@ -105,7 +105,7 @@ public class ObjectDetectorTest {
|
|||
.build();
|
||||
ObjectDetector objectDetector =
|
||||
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||
ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||
// The score threshold should block all other other objects, except cat.
|
||||
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
|
||||
}
|
||||
|
@ -119,7 +119,7 @@ public class ObjectDetectorTest {
|
|||
.build();
|
||||
ObjectDetector objectDetector =
|
||||
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||
ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||
// The score threshold should block objects.
|
||||
assertThat(results.detections()).isEmpty();
|
||||
}
|
||||
|
@ -133,7 +133,7 @@ public class ObjectDetectorTest {
|
|||
.build();
|
||||
ObjectDetector objectDetector =
|
||||
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||
ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||
// Because of the allowlist, results should only contain cat, and there are 6 detected
|
||||
// bounding boxes of cats in CAT_AND_DOG_IMAGE.
|
||||
assertThat(results.detections()).hasSize(5);
|
||||
|
@ -148,7 +148,7 @@ public class ObjectDetectorTest {
|
|||
.build();
|
||||
ObjectDetector objectDetector =
|
||||
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||
ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||
// Because of the denylist, the highest result is not cat anymore.
|
||||
assertThat(results.detections().get(0).categories().get(0).categoryName())
|
||||
.isNotEqualTo("cat");
|
||||
|
@ -160,7 +160,7 @@ public class ObjectDetectorTest {
|
|||
ObjectDetector.createFromFile(
|
||||
ApplicationProvider.getApplicationContext(),
|
||||
TestUtils.loadFile(ApplicationProvider.getApplicationContext(), MODEL_FILE));
|
||||
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||
ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||
// Check if the object with the highest score is cat.
|
||||
assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE);
|
||||
}
|
||||
|
@ -172,7 +172,7 @@ public class ObjectDetectorTest {
|
|||
ApplicationProvider.getApplicationContext(),
|
||||
TestUtils.loadToDirectByteBuffer(
|
||||
ApplicationProvider.getApplicationContext(), MODEL_FILE));
|
||||
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||
ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||
// Check if the object with the highest score is cat.
|
||||
assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE);
|
||||
}
|
||||
|
@ -191,7 +191,7 @@ public class ObjectDetectorTest {
|
|||
.build();
|
||||
ObjectDetector objectDetector =
|
||||
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||
ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
|
||||
}
|
||||
|
||||
|
@ -256,7 +256,7 @@ public class ObjectDetectorTest {
|
|||
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
ImageProcessingOptions imageProcessingOptions =
|
||||
ImageProcessingOptions.builder().setRotationDegrees(-90).build();
|
||||
ObjectDetectionResult results =
|
||||
ObjectDetectorResult results =
|
||||
objectDetector.detect(
|
||||
getImageFromAsset(CAT_AND_DOG_ROTATED_IMAGE), imageProcessingOptions);
|
||||
|
||||
|
@ -302,7 +302,7 @@ public class ObjectDetectorTest {
|
|||
ObjectDetectorOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
|
||||
.setRunningMode(mode)
|
||||
.setResultListener((objectDetectionResult, inputImage) -> {})
|
||||
.setResultListener((ObjectDetectorResult, inputImage) -> {})
|
||||
.build());
|
||||
assertThat(exception)
|
||||
.hasMessageThat()
|
||||
|
@ -381,7 +381,7 @@ public class ObjectDetectorTest {
|
|||
ObjectDetectorOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
|
||||
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||
.setResultListener((objectDetectionResult, inputImage) -> {})
|
||||
.setResultListener((ObjectDetectorResult, inputImage) -> {})
|
||||
.build();
|
||||
|
||||
ObjectDetector objectDetector =
|
||||
|
@ -411,7 +411,7 @@ public class ObjectDetectorTest {
|
|||
.build();
|
||||
ObjectDetector objectDetector =
|
||||
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||
ObjectDetectorResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
|
||||
}
|
||||
|
||||
|
@ -426,7 +426,7 @@ public class ObjectDetectorTest {
|
|||
ObjectDetector objectDetector =
|
||||
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
for (int i = 0; i < 3; i++) {
|
||||
ObjectDetectionResult results =
|
||||
ObjectDetectorResult results =
|
||||
objectDetector.detectForVideo(
|
||||
getImageFromAsset(CAT_AND_DOG_IMAGE), /* timestampsMs= */ i);
|
||||
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
|
||||
|
@ -441,8 +441,8 @@ public class ObjectDetectorTest {
|
|||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
|
||||
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||
.setResultListener(
|
||||
(objectDetectionResult, inputImage) -> {
|
||||
assertContainsOnlyCat(objectDetectionResult, CAT_BOUNDING_BOX, CAT_SCORE);
|
||||
(ObjectDetectorResult, inputImage) -> {
|
||||
assertContainsOnlyCat(ObjectDetectorResult, CAT_BOUNDING_BOX, CAT_SCORE);
|
||||
assertImageSizeIsExpected(inputImage);
|
||||
})
|
||||
.setMaxResults(1)
|
||||
|
@ -468,8 +468,8 @@ public class ObjectDetectorTest {
|
|||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
|
||||
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||
.setResultListener(
|
||||
(objectDetectionResult, inputImage) -> {
|
||||
assertContainsOnlyCat(objectDetectionResult, CAT_BOUNDING_BOX, CAT_SCORE);
|
||||
(ObjectDetectorResult, inputImage) -> {
|
||||
assertContainsOnlyCat(ObjectDetectorResult, CAT_BOUNDING_BOX, CAT_SCORE);
|
||||
assertImageSizeIsExpected(inputImage);
|
||||
})
|
||||
.setMaxResults(1)
|
||||
|
@ -483,6 +483,16 @@ public class ObjectDetectorTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@SuppressWarnings("deprecation")
|
||||
public void detect_canUseDeprecatedApi() throws Exception {
|
||||
ObjectDetector objectDetector =
|
||||
ObjectDetector.createFromFile(ApplicationProvider.getApplicationContext(), MODEL_FILE);
|
||||
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||
// Check if the object with the highest score is cat.
|
||||
assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE);
|
||||
}
|
||||
|
||||
private static MPImage getImageFromAsset(String filePath) throws Exception {
|
||||
AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
|
||||
InputStream istr = assetManager.open(filePath);
|
||||
|
@ -491,7 +501,7 @@ public class ObjectDetectorTest {
|
|||
|
||||
// Checks if results has one and only detection result, which is a cat.
|
||||
private static void assertContainsOnlyCat(
|
||||
ObjectDetectionResult result, RectF expectedBoundingBox, float expectedScore) {
|
||||
ObjectDetectorResult result, RectF expectedBoundingBox, float expectedScore) {
|
||||
assertThat(result.detections()).hasSize(1);
|
||||
Detection catResult = result.detections().get(0);
|
||||
assertApproximatelyEqualBoundingBoxes(catResult.boundingBox(), expectedBoundingBox);
|
||||
|
|
Loading…
Reference in New Issue
Block a user