Object detector handles an empty detections packet when no object is detected.

PiperOrigin-RevId: 529919638
This commit is contained in:
MediaPipe Team 2023-05-06 01:08:50 -07:00 committed by Copybara-Service
parent cc8847def5
commit 800a7b4a27
6 changed files with 81 additions and 3 deletions

View File

@ -129,9 +129,17 @@ absl::StatusOr<std::unique_ptr<ObjectDetector>> ObjectDetector::Create(
if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) { if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) {
return; return;
} }
Packet image_packet = status_or_packets.value()[kImageOutStreamName];
Packet detections_packet = Packet detections_packet =
status_or_packets.value()[kDetectionsOutStreamName]; status_or_packets.value()[kDetectionsOutStreamName];
Packet image_packet = status_or_packets.value()[kImageOutStreamName]; if (detections_packet.IsEmpty()) {
Packet empty_packet =
status_or_packets.value()[kDetectionsOutStreamName];
result_callback(
{ConvertToDetectionResult({})}, image_packet.Get<Image>(),
empty_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond);
return;
}
result_callback(ConvertToDetectionResult( result_callback(ConvertToDetectionResult(
detections_packet.Get<std::vector<Detection>>()), detections_packet.Get<std::vector<Detection>>()),
image_packet.Get<Image>(), image_packet.Get<Image>(),
@ -165,6 +173,9 @@ absl::StatusOr<ObjectDetectorResult> ObjectDetector::Detect(
ProcessImageData( ProcessImageData(
{{kImageInStreamName, MakePacket<Image>(std::move(image))}, {{kImageInStreamName, MakePacket<Image>(std::move(image))},
{kNormRectName, MakePacket<NormalizedRect>(std::move(norm_rect))}})); {kNormRectName, MakePacket<NormalizedRect>(std::move(norm_rect))}}));
if (output_packets[kDetectionsOutStreamName].IsEmpty()) {
return {ConvertToDetectionResult({})};
}
return ConvertToDetectionResult( return ConvertToDetectionResult(
output_packets[kDetectionsOutStreamName].Get<std::vector<Detection>>()); output_packets[kDetectionsOutStreamName].Get<std::vector<Detection>>());
} }
@ -190,6 +201,9 @@ absl::StatusOr<ObjectDetectorResult> ObjectDetector::DetectForVideo(
{kNormRectName, {kNormRectName,
MakePacket<NormalizedRect>(std::move(norm_rect)) MakePacket<NormalizedRect>(std::move(norm_rect))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
if (output_packets[kDetectionsOutStreamName].IsEmpty()) {
return {ConvertToDetectionResult({})};
}
return ConvertToDetectionResult( return ConvertToDetectionResult(
output_packets[kDetectionsOutStreamName].Get<std::vector<Detection>>()); output_packets[kDetectionsOutStreamName].Get<std::vector<Detection>>());
} }

View File

@ -499,6 +499,22 @@ TEST_F(ImageModeTest, SucceedsEfficientDetNoNmsModel) {
})pb")})); })pb")}));
} }
// Runs image-mode detection with the model referenced by
// kEfficientDetWithoutNms and a score threshold of 1.0, and expects Detect()
// to succeed with an empty list of detections.
TEST_F(ImageModeTest, SucceedsNoObjectDetected) {
  // Configure the detector first; the 1.0 threshold is what is expected to
  // leave no detection in the result.
  auto detector_options = std::make_unique<ObjectDetectorOptions>();
  detector_options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kEfficientDetWithoutNms);
  detector_options->max_results = 4;
  detector_options->score_threshold = 1.0f;

  MP_ASSERT_OK_AND_ASSIGN(
      Image input_image,
      DecodeImageFromFile(
          JoinPath("./", kTestDataDirectory, "cats_and_dogs.jpg")));
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> detector,
                          ObjectDetector::Create(std::move(detector_options)));
  MP_ASSERT_OK_AND_ASSIGN(auto detection_result, detector->Detect(input_image));
  MP_ASSERT_OK(detector->Close());

  EXPECT_THAT(detection_result.detections, testing::IsEmpty());
}
TEST_F(ImageModeTest, SucceedsWithoutImageResizing) { TEST_F(ImageModeTest, SucceedsWithoutImageResizing) {
MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath( MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath(
"./", kTestDataDirectory, "./", kTestDataDirectory,

View File

@ -39,6 +39,7 @@ import com.google.mediapipe.formats.proto.DetectionProto.Detection;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
@ -170,6 +171,13 @@ public final class ObjectDetector extends BaseVisionTaskApi {
new OutputHandler.OutputPacketConverter<ObjectDetectionResult, MPImage>() { new OutputHandler.OutputPacketConverter<ObjectDetectionResult, MPImage>() {
@Override @Override
public ObjectDetectionResult convertToTaskResult(List<Packet> packets) { public ObjectDetectionResult convertToTaskResult(List<Packet> packets) {
// If there is no object detected in the image, just returns empty lists.
if (packets.get(DETECTIONS_OUT_STREAM_INDEX).isEmpty()) {
return ObjectDetectionResult.create(
new ArrayList<>(),
BaseVisionTaskApi.generateResultTimestampMs(
detectorOptions.runningMode(), packets.get(DETECTIONS_OUT_STREAM_INDEX)));
}
return ObjectDetectionResult.create( return ObjectDetectionResult.create(
PacketGetter.getProtoVector( PacketGetter.getProtoVector(
packets.get(DETECTIONS_OUT_STREAM_INDEX), Detection.parser()), packets.get(DETECTIONS_OUT_STREAM_INDEX), Detection.parser()),

View File

@ -45,6 +45,7 @@ import org.junit.runners.Suite.SuiteClasses;
@SuiteClasses({ObjectDetectorTest.General.class, ObjectDetectorTest.RunningModeTest.class}) @SuiteClasses({ObjectDetectorTest.General.class, ObjectDetectorTest.RunningModeTest.class})
public class ObjectDetectorTest { public class ObjectDetectorTest {
private static final String MODEL_FILE = "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite"; private static final String MODEL_FILE = "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite";
private static final String NO_NMS_MODEL_FILE = "efficientdet_lite0_fp16_no_nms.tflite";
private static final String CAT_AND_DOG_IMAGE = "cats_and_dogs.jpg"; private static final String CAT_AND_DOG_IMAGE = "cats_and_dogs.jpg";
private static final String CAT_AND_DOG_ROTATED_IMAGE = "cats_and_dogs_rotated.jpg"; private static final String CAT_AND_DOG_ROTATED_IMAGE = "cats_and_dogs_rotated.jpg";
private static final int IMAGE_WIDTH = 1200; private static final int IMAGE_WIDTH = 1200;
@ -109,6 +110,20 @@ public class ObjectDetectorTest {
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE); assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
} }
@Test
public void detect_succeedsWithNoObjectDetected() throws Exception {
  // Build the detector on the model named by NO_NMS_MODEL_FILE with a score threshold of 1.0,
  // which is expected to reject every candidate detection.
  BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(NO_NMS_MODEL_FILE).build();
  ObjectDetectorOptions detectorOptions =
      ObjectDetectorOptions.builder()
          .setBaseOptions(baseOptions)
          .setScoreThreshold(1.0f)
          .build();
  ObjectDetector detector =
      ObjectDetector.createFromOptions(
          ApplicationProvider.getApplicationContext(), detectorOptions);
  ObjectDetectionResult detectionResult = detector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
  // Detection must still succeed and report an empty list.
  assertThat(detectionResult.detections()).isEmpty();
}
@Test @Test
public void detect_succeedsWithAllowListOption() throws Exception { public void detect_succeedsWithAllowListOption() throws Exception {
ObjectDetectorOptions options = ObjectDetectorOptions options =

View File

@ -44,6 +44,7 @@ _ObjectDetectorOptions = object_detector.ObjectDetectorOptions
_RUNNING_MODE = running_mode_module.VisionTaskRunningMode _RUNNING_MODE = running_mode_module.VisionTaskRunningMode
_MODEL_FILE = 'coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite' _MODEL_FILE = 'coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite'
_NO_NMS_MODEL_FILE = 'efficientdet_lite0_fp16_no_nms.tflite'
_IMAGE_FILE = 'cats_and_dogs.jpg' _IMAGE_FILE = 'cats_and_dogs.jpg'
_EXPECTED_DETECTION_RESULT = _DetectionResult( _EXPECTED_DETECTION_RESULT = _DetectionResult(
detections=[ detections=[
@ -304,7 +305,7 @@ class ObjectDetectorTest(parameterized.TestCase):
with _ObjectDetector.create_from_options(options) as unused_detector: with _ObjectDetector.create_from_options(options) as unused_detector:
pass pass
def test_empty_detection_outputs(self): def test_empty_detection_outputs_with_in_model_nms(self):
options = _ObjectDetectorOptions( options = _ObjectDetectorOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
score_threshold=1, score_threshold=1,
@ -314,6 +315,18 @@ class ObjectDetectorTest(parameterized.TestCase):
detection_result = detector.detect(self.test_image) detection_result = detector.detect(self.test_image)
self.assertEmpty(detection_result.detections) self.assertEmpty(detection_result.detections)
def test_empty_detection_outputs_without_in_model_nms(self):
  """Checks detect() yields no detections with score_threshold=1 on the no-NMS model file."""
  no_nms_model_path = test_utils.get_test_data_path(
      os.path.join(_TEST_DATA_DIR, _NO_NMS_MODEL_FILE)
  )
  detector_options = _ObjectDetectorOptions(
      base_options=_BaseOptions(model_asset_path=no_nms_model_path),
      score_threshold=1,
  )
  with _ObjectDetector.create_from_options(detector_options) as detector:
    # Run detection on the shared test image; the result must be empty.
    result = detector.detect(self.test_image)
    self.assertEmpty(result.detections)
def test_missing_result_callback(self): def test_missing_result_callback(self):
options = _ObjectDetectorOptions( options = _ObjectDetectorOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),

View File

@ -198,6 +198,15 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi):
def packets_callback(output_packets: Mapping[str, packet_module.Packet]): def packets_callback(output_packets: Mapping[str, packet_module.Packet]):
if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty(): if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
return return
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
if output_packets[_DETECTIONS_OUT_STREAM_NAME].is_empty():
empty_packet = output_packets[_DETECTIONS_OUT_STREAM_NAME]
options.result_callback(
ObjectDetectorResult([]),
image,
empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
)
return
detection_proto_list = packet_getter.get_proto_list( detection_proto_list = packet_getter.get_proto_list(
output_packets[_DETECTIONS_OUT_STREAM_NAME] output_packets[_DETECTIONS_OUT_STREAM_NAME]
) )
@ -207,7 +216,6 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi):
for result in detection_proto_list for result in detection_proto_list
] ]
) )
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
options.result_callback(detection_result, image, timestamp) options.result_callback(detection_result, image, timestamp)
@ -266,6 +274,8 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi):
normalized_rect.to_pb2() normalized_rect.to_pb2()
), ),
}) })
if output_packets[_DETECTIONS_OUT_STREAM_NAME].is_empty():
return ObjectDetectorResult([])
detection_proto_list = packet_getter.get_proto_list( detection_proto_list = packet_getter.get_proto_list(
output_packets[_DETECTIONS_OUT_STREAM_NAME] output_packets[_DETECTIONS_OUT_STREAM_NAME]
) )
@ -315,6 +325,8 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi):
normalized_rect.to_pb2() normalized_rect.to_pb2()
).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
}) })
if output_packets[_DETECTIONS_OUT_STREAM_NAME].is_empty():
return ObjectDetectorResult([])
detection_proto_list = packet_getter.get_proto_list( detection_proto_list = packet_getter.get_proto_list(
output_packets[_DETECTIONS_OUT_STREAM_NAME] output_packets[_DETECTIONS_OUT_STREAM_NAME]
) )