Object detector handle empty packet when no object is detected.
PiperOrigin-RevId: 529919638
This commit is contained in:
parent
cc8847def5
commit
800a7b4a27
|
@ -129,9 +129,17 @@ absl::StatusOr<std::unique_ptr<ObjectDetector>> ObjectDetector::Create(
|
||||||
if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) {
|
if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
Packet image_packet = status_or_packets.value()[kImageOutStreamName];
|
||||||
Packet detections_packet =
|
Packet detections_packet =
|
||||||
status_or_packets.value()[kDetectionsOutStreamName];
|
status_or_packets.value()[kDetectionsOutStreamName];
|
||||||
Packet image_packet = status_or_packets.value()[kImageOutStreamName];
|
if (detections_packet.IsEmpty()) {
|
||||||
|
Packet empty_packet =
|
||||||
|
status_or_packets.value()[kDetectionsOutStreamName];
|
||||||
|
result_callback(
|
||||||
|
{ConvertToDetectionResult({})}, image_packet.Get<Image>(),
|
||||||
|
empty_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond);
|
||||||
|
return;
|
||||||
|
}
|
||||||
result_callback(ConvertToDetectionResult(
|
result_callback(ConvertToDetectionResult(
|
||||||
detections_packet.Get<std::vector<Detection>>()),
|
detections_packet.Get<std::vector<Detection>>()),
|
||||||
image_packet.Get<Image>(),
|
image_packet.Get<Image>(),
|
||||||
|
@ -165,6 +173,9 @@ absl::StatusOr<ObjectDetectorResult> ObjectDetector::Detect(
|
||||||
ProcessImageData(
|
ProcessImageData(
|
||||||
{{kImageInStreamName, MakePacket<Image>(std::move(image))},
|
{{kImageInStreamName, MakePacket<Image>(std::move(image))},
|
||||||
{kNormRectName, MakePacket<NormalizedRect>(std::move(norm_rect))}}));
|
{kNormRectName, MakePacket<NormalizedRect>(std::move(norm_rect))}}));
|
||||||
|
if (output_packets[kDetectionsOutStreamName].IsEmpty()) {
|
||||||
|
return {ConvertToDetectionResult({})};
|
||||||
|
}
|
||||||
return ConvertToDetectionResult(
|
return ConvertToDetectionResult(
|
||||||
output_packets[kDetectionsOutStreamName].Get<std::vector<Detection>>());
|
output_packets[kDetectionsOutStreamName].Get<std::vector<Detection>>());
|
||||||
}
|
}
|
||||||
|
@ -190,6 +201,9 @@ absl::StatusOr<ObjectDetectorResult> ObjectDetector::DetectForVideo(
|
||||||
{kNormRectName,
|
{kNormRectName,
|
||||||
MakePacket<NormalizedRect>(std::move(norm_rect))
|
MakePacket<NormalizedRect>(std::move(norm_rect))
|
||||||
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
|
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
|
||||||
|
if (output_packets[kDetectionsOutStreamName].IsEmpty()) {
|
||||||
|
return {ConvertToDetectionResult({})};
|
||||||
|
}
|
||||||
return ConvertToDetectionResult(
|
return ConvertToDetectionResult(
|
||||||
output_packets[kDetectionsOutStreamName].Get<std::vector<Detection>>());
|
output_packets[kDetectionsOutStreamName].Get<std::vector<Detection>>());
|
||||||
}
|
}
|
||||||
|
|
|
@ -499,6 +499,22 @@ TEST_F(ImageModeTest, SucceedsEfficientDetNoNmsModel) {
|
||||||
})pb")}));
|
})pb")}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(ImageModeTest, SucceedsNoObjectDetected) {
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(Image image,
|
||||||
|
DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
|
||||||
|
"cats_and_dogs.jpg")));
|
||||||
|
auto options = std::make_unique<ObjectDetectorOptions>();
|
||||||
|
options->max_results = 4;
|
||||||
|
options->score_threshold = 1.0f;
|
||||||
|
options->base_options.model_asset_path =
|
||||||
|
JoinPath("./", kTestDataDirectory, kEfficientDetWithoutNms);
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
|
||||||
|
ObjectDetector::Create(std::move(options)));
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
|
||||||
|
MP_ASSERT_OK(object_detector->Close());
|
||||||
|
EXPECT_THAT(results.detections, testing::IsEmpty());
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(ImageModeTest, SucceedsWithoutImageResizing) {
|
TEST_F(ImageModeTest, SucceedsWithoutImageResizing) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath(
|
MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath(
|
||||||
"./", kTestDataDirectory,
|
"./", kTestDataDirectory,
|
||||||
|
|
|
@ -39,6 +39,7 @@ import com.google.mediapipe.formats.proto.DetectionProto.Detection;
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
@ -170,6 +171,13 @@ public final class ObjectDetector extends BaseVisionTaskApi {
|
||||||
new OutputHandler.OutputPacketConverter<ObjectDetectionResult, MPImage>() {
|
new OutputHandler.OutputPacketConverter<ObjectDetectionResult, MPImage>() {
|
||||||
@Override
|
@Override
|
||||||
public ObjectDetectionResult convertToTaskResult(List<Packet> packets) {
|
public ObjectDetectionResult convertToTaskResult(List<Packet> packets) {
|
||||||
|
// If there is no object detected in the image, just returns empty lists.
|
||||||
|
if (packets.get(DETECTIONS_OUT_STREAM_INDEX).isEmpty()) {
|
||||||
|
return ObjectDetectionResult.create(
|
||||||
|
new ArrayList<>(),
|
||||||
|
BaseVisionTaskApi.generateResultTimestampMs(
|
||||||
|
detectorOptions.runningMode(), packets.get(DETECTIONS_OUT_STREAM_INDEX)));
|
||||||
|
}
|
||||||
return ObjectDetectionResult.create(
|
return ObjectDetectionResult.create(
|
||||||
PacketGetter.getProtoVector(
|
PacketGetter.getProtoVector(
|
||||||
packets.get(DETECTIONS_OUT_STREAM_INDEX), Detection.parser()),
|
packets.get(DETECTIONS_OUT_STREAM_INDEX), Detection.parser()),
|
||||||
|
|
|
@ -45,6 +45,7 @@ import org.junit.runners.Suite.SuiteClasses;
|
||||||
@SuiteClasses({ObjectDetectorTest.General.class, ObjectDetectorTest.RunningModeTest.class})
|
@SuiteClasses({ObjectDetectorTest.General.class, ObjectDetectorTest.RunningModeTest.class})
|
||||||
public class ObjectDetectorTest {
|
public class ObjectDetectorTest {
|
||||||
private static final String MODEL_FILE = "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite";
|
private static final String MODEL_FILE = "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite";
|
||||||
|
private static final String NO_NMS_MODEL_FILE = "efficientdet_lite0_fp16_no_nms.tflite";
|
||||||
private static final String CAT_AND_DOG_IMAGE = "cats_and_dogs.jpg";
|
private static final String CAT_AND_DOG_IMAGE = "cats_and_dogs.jpg";
|
||||||
private static final String CAT_AND_DOG_ROTATED_IMAGE = "cats_and_dogs_rotated.jpg";
|
private static final String CAT_AND_DOG_ROTATED_IMAGE = "cats_and_dogs_rotated.jpg";
|
||||||
private static final int IMAGE_WIDTH = 1200;
|
private static final int IMAGE_WIDTH = 1200;
|
||||||
|
@ -109,6 +110,20 @@ public class ObjectDetectorTest {
|
||||||
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
|
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void detect_succeedsWithNoObjectDetected() throws Exception {
|
||||||
|
ObjectDetectorOptions options =
|
||||||
|
ObjectDetectorOptions.builder()
|
||||||
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(NO_NMS_MODEL_FILE).build())
|
||||||
|
.setScoreThreshold(1.0f)
|
||||||
|
.build();
|
||||||
|
ObjectDetector objectDetector =
|
||||||
|
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
|
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||||
|
// The score threshold should block objects.
|
||||||
|
assertThat(results.detections()).isEmpty();
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void detect_succeedsWithAllowListOption() throws Exception {
|
public void detect_succeedsWithAllowListOption() throws Exception {
|
||||||
ObjectDetectorOptions options =
|
ObjectDetectorOptions options =
|
||||||
|
|
|
@ -44,6 +44,7 @@ _ObjectDetectorOptions = object_detector.ObjectDetectorOptions
|
||||||
_RUNNING_MODE = running_mode_module.VisionTaskRunningMode
|
_RUNNING_MODE = running_mode_module.VisionTaskRunningMode
|
||||||
|
|
||||||
_MODEL_FILE = 'coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite'
|
_MODEL_FILE = 'coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite'
|
||||||
|
_NO_NMS_MODEL_FILE = 'efficientdet_lite0_fp16_no_nms.tflite'
|
||||||
_IMAGE_FILE = 'cats_and_dogs.jpg'
|
_IMAGE_FILE = 'cats_and_dogs.jpg'
|
||||||
_EXPECTED_DETECTION_RESULT = _DetectionResult(
|
_EXPECTED_DETECTION_RESULT = _DetectionResult(
|
||||||
detections=[
|
detections=[
|
||||||
|
@ -304,7 +305,7 @@ class ObjectDetectorTest(parameterized.TestCase):
|
||||||
with _ObjectDetector.create_from_options(options) as unused_detector:
|
with _ObjectDetector.create_from_options(options) as unused_detector:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_empty_detection_outputs(self):
|
def test_empty_detection_outputs_with_in_model_nms(self):
|
||||||
options = _ObjectDetectorOptions(
|
options = _ObjectDetectorOptions(
|
||||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
score_threshold=1,
|
score_threshold=1,
|
||||||
|
@ -314,6 +315,18 @@ class ObjectDetectorTest(parameterized.TestCase):
|
||||||
detection_result = detector.detect(self.test_image)
|
detection_result = detector.detect(self.test_image)
|
||||||
self.assertEmpty(detection_result.detections)
|
self.assertEmpty(detection_result.detections)
|
||||||
|
|
||||||
|
def test_empty_detection_outputs_without_in_model_nms(self):
|
||||||
|
options = _ObjectDetectorOptions(
|
||||||
|
base_options=_BaseOptions(
|
||||||
|
model_asset_path=test_utils.get_test_data_path(
|
||||||
|
os.path.join(_TEST_DATA_DIR, _NO_NMS_MODEL_FILE))),
|
||||||
|
score_threshold=1,
|
||||||
|
)
|
||||||
|
with _ObjectDetector.create_from_options(options) as detector:
|
||||||
|
# Performs object detection on the input.
|
||||||
|
detection_result = detector.detect(self.test_image)
|
||||||
|
self.assertEmpty(detection_result.detections)
|
||||||
|
|
||||||
def test_missing_result_callback(self):
|
def test_missing_result_callback(self):
|
||||||
options = _ObjectDetectorOptions(
|
options = _ObjectDetectorOptions(
|
||||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
|
|
@ -198,6 +198,15 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi):
|
||||||
def packets_callback(output_packets: Mapping[str, packet_module.Packet]):
|
def packets_callback(output_packets: Mapping[str, packet_module.Packet]):
|
||||||
if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
|
if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
|
||||||
return
|
return
|
||||||
|
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
|
||||||
|
if output_packets[_DETECTIONS_OUT_STREAM_NAME].is_empty():
|
||||||
|
empty_packet = output_packets[_DETECTIONS_OUT_STREAM_NAME]
|
||||||
|
options.result_callback(
|
||||||
|
ObjectDetectorResult([]),
|
||||||
|
image,
|
||||||
|
empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
|
||||||
|
)
|
||||||
|
return
|
||||||
detection_proto_list = packet_getter.get_proto_list(
|
detection_proto_list = packet_getter.get_proto_list(
|
||||||
output_packets[_DETECTIONS_OUT_STREAM_NAME]
|
output_packets[_DETECTIONS_OUT_STREAM_NAME]
|
||||||
)
|
)
|
||||||
|
@ -207,7 +216,6 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi):
|
||||||
for result in detection_proto_list
|
for result in detection_proto_list
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
|
|
||||||
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
|
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
|
||||||
options.result_callback(detection_result, image, timestamp)
|
options.result_callback(detection_result, image, timestamp)
|
||||||
|
|
||||||
|
@ -266,6 +274,8 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi):
|
||||||
normalized_rect.to_pb2()
|
normalized_rect.to_pb2()
|
||||||
),
|
),
|
||||||
})
|
})
|
||||||
|
if output_packets[_DETECTIONS_OUT_STREAM_NAME].is_empty():
|
||||||
|
return ObjectDetectorResult([])
|
||||||
detection_proto_list = packet_getter.get_proto_list(
|
detection_proto_list = packet_getter.get_proto_list(
|
||||||
output_packets[_DETECTIONS_OUT_STREAM_NAME]
|
output_packets[_DETECTIONS_OUT_STREAM_NAME]
|
||||||
)
|
)
|
||||||
|
@ -315,6 +325,8 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi):
|
||||||
normalized_rect.to_pb2()
|
normalized_rect.to_pb2()
|
||||||
).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
|
).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
|
||||||
})
|
})
|
||||||
|
if output_packets[_DETECTIONS_OUT_STREAM_NAME].is_empty():
|
||||||
|
return ObjectDetectorResult([])
|
||||||
detection_proto_list = packet_getter.get_proto_list(
|
detection_proto_list = packet_getter.get_proto_list(
|
||||||
output_packets[_DETECTIONS_OUT_STREAM_NAME]
|
output_packets[_DETECTIONS_OUT_STREAM_NAME]
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user