From 800a7b4a27495dc2c3b052874e913c4ad993e660 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Sat, 6 May 2023 01:08:50 -0700 Subject: [PATCH] Object detector handle empty packet when no object is detected. PiperOrigin-RevId: 529919638 --- .../cc/vision/object_detector/object_detector.cc | 16 +++++++++++++++- .../object_detector/object_detector_test.cc | 16 ++++++++++++++++ .../vision/objectdetector/ObjectDetector.java | 8 ++++++++ .../objectdetector/ObjectDetectorTest.java | 15 +++++++++++++++ .../python/test/vision/object_detector_test.py | 15 ++++++++++++++- mediapipe/tasks/python/vision/object_detector.py | 14 +++++++++++++- 6 files changed, 81 insertions(+), 3 deletions(-) diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector.cc index 01fd3eb7b..152ee3273 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector.cc @@ -129,9 +129,17 @@ absl::StatusOr> ObjectDetector::Create( if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) { return; } + Packet image_packet = status_or_packets.value()[kImageOutStreamName]; Packet detections_packet = status_or_packets.value()[kDetectionsOutStreamName]; - Packet image_packet = status_or_packets.value()[kImageOutStreamName]; + if (detections_packet.IsEmpty()) { + Packet empty_packet = + status_or_packets.value()[kDetectionsOutStreamName]; + result_callback( + {ConvertToDetectionResult({})}, image_packet.Get(), + empty_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond); + return; + } result_callback(ConvertToDetectionResult( detections_packet.Get>()), image_packet.Get(), @@ -165,6 +173,9 @@ absl::StatusOr ObjectDetector::Detect( ProcessImageData( {{kImageInStreamName, MakePacket(std::move(image))}, {kNormRectName, MakePacket(std::move(norm_rect))}})); + if (output_packets[kDetectionsOutStreamName].IsEmpty()) { + return {ConvertToDetectionResult({})}; + } return ConvertToDetectionResult( output_packets[kDetectionsOutStreamName].Get>()); } @@ -190,6 +201,9 @@ absl::StatusOr ObjectDetector::DetectForVideo( {kNormRectName, MakePacket(std::move(norm_rect)) .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); + if (output_packets[kDetectionsOutStreamName].IsEmpty()) { + return {ConvertToDetectionResult({})}; + } return ConvertToDetectionResult( output_packets[kDetectionsOutStreamName].Get>()); } diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc index 8642af7c4..e66fc19bb 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc @@ -499,6 +499,22 @@ TEST_F(ImageModeTest, SucceedsEfficientDetNoNmsModel) { })pb")})); } +TEST_F(ImageModeTest, SucceedsNoObjectDetected) { + MP_ASSERT_OK_AND_ASSIGN(Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + "cats_and_dogs.jpg"))); + auto options = std::make_unique(); + options->max_results = 4; + options->score_threshold = 1.0f; + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kEfficientDetWithoutNms); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr object_detector, + ObjectDetector::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image)); + MP_ASSERT_OK(object_detector->Close()); + EXPECT_THAT(results.detections, testing::IsEmpty()); +} + TEST_F(ImageModeTest, SucceedsWithoutImageResizing) { MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath( "./", kTestDataDirectory, diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java index 5287ba325..d9a36cce7 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java @@ -39,6 +39,7 @@ import com.google.mediapipe.formats.proto.DetectionProto.Detection; import java.io.File; import java.io.IOException; import java.nio.ByteBuffer; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -170,6 +171,13 @@ public final class ObjectDetector extends BaseVisionTaskApi { new OutputHandler.OutputPacketConverter() { @Override public ObjectDetectionResult convertToTaskResult(List packets) { + // If there is no object detected in the image, just returns empty lists. + if (packets.get(DETECTIONS_OUT_STREAM_INDEX).isEmpty()) { + return ObjectDetectionResult.create( + new ArrayList<>(), + BaseVisionTaskApi.generateResultTimestampMs( + detectorOptions.runningMode(), packets.get(DETECTIONS_OUT_STREAM_INDEX))); + } return ObjectDetectionResult.create( PacketGetter.getProtoVector( packets.get(DETECTIONS_OUT_STREAM_INDEX), Detection.parser()), diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectorTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectorTest.java index 33aa025d2..20ddfcef6 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectorTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectorTest.java @@ -45,6 +45,7 @@ import org.junit.runners.Suite.SuiteClasses; @SuiteClasses({ObjectDetectorTest.General.class, ObjectDetectorTest.RunningModeTest.class}) public class ObjectDetectorTest { private static final String MODEL_FILE = "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite"; + private static final String NO_NMS_MODEL_FILE = "efficientdet_lite0_fp16_no_nms.tflite"; private static final String CAT_AND_DOG_IMAGE = "cats_and_dogs.jpg"; private static final String CAT_AND_DOG_ROTATED_IMAGE = "cats_and_dogs_rotated.jpg"; private static final int IMAGE_WIDTH = 1200; @@ -109,6 +110,20 @@ public class ObjectDetectorTest { assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE); } + @Test + public void detect_succeedsWithNoObjectDetected() throws Exception { + ObjectDetectorOptions options = + ObjectDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(NO_NMS_MODEL_FILE).build()) + .setScoreThreshold(1.0f) + .build(); + ObjectDetector objectDetector = + ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); + // The score threshold should block objects. + assertThat(results.detections()).isEmpty(); + } + @Test public void detect_succeedsWithAllowListOption() throws Exception { ObjectDetectorOptions options = diff --git a/mediapipe/tasks/python/test/vision/object_detector_test.py b/mediapipe/tasks/python/test/vision/object_detector_test.py index 7878e7f52..adeddafd7 100644 --- a/mediapipe/tasks/python/test/vision/object_detector_test.py +++ b/mediapipe/tasks/python/test/vision/object_detector_test.py @@ -44,6 +44,7 @@ _ObjectDetectorOptions = object_detector.ObjectDetectorOptions _RUNNING_MODE = running_mode_module.VisionTaskRunningMode _MODEL_FILE = 'coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite' +_NO_NMS_MODEL_FILE = 'efficientdet_lite0_fp16_no_nms.tflite' _IMAGE_FILE = 'cats_and_dogs.jpg' _EXPECTED_DETECTION_RESULT = _DetectionResult( detections=[ @@ -304,7 +305,7 @@ class ObjectDetectorTest(parameterized.TestCase): with _ObjectDetector.create_from_options(options) as unused_detector: pass - def test_empty_detection_outputs(self): + def test_empty_detection_outputs_with_in_model_nms(self): options = _ObjectDetectorOptions( base_options=_BaseOptions(model_asset_path=self.model_path), score_threshold=1, @@ -314,6 +315,18 @@ class ObjectDetectorTest(parameterized.TestCase): detection_result = detector.detect(self.test_image) self.assertEmpty(detection_result.detections) + def test_empty_detection_outputs_without_in_model_nms(self): + options = _ObjectDetectorOptions( + base_options=_BaseOptions( + model_asset_path=test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, _NO_NMS_MODEL_FILE))), + score_threshold=1, + ) + with _ObjectDetector.create_from_options(options) as detector: + # Performs object detection on the input. + detection_result = detector.detect(self.test_image) + self.assertEmpty(detection_result.detections) + def test_missing_result_callback(self): options = _ObjectDetectorOptions( base_options=_BaseOptions(model_asset_path=self.model_path), diff --git a/mediapipe/tasks/python/vision/object_detector.py b/mediapipe/tasks/python/vision/object_detector.py index 3bdd1b5de..380d57c22 100644 --- a/mediapipe/tasks/python/vision/object_detector.py +++ b/mediapipe/tasks/python/vision/object_detector.py @@ -198,6 +198,15 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi): def packets_callback(output_packets: Mapping[str, packet_module.Packet]): if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty(): return + image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) + if output_packets[_DETECTIONS_OUT_STREAM_NAME].is_empty(): + empty_packet = output_packets[_DETECTIONS_OUT_STREAM_NAME] + options.result_callback( + ObjectDetectorResult([]), + image, + empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND, + ) + return detection_proto_list = packet_getter.get_proto_list( output_packets[_DETECTIONS_OUT_STREAM_NAME] ) @@ -207,7 +216,6 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi): for result in detection_proto_list ] ) - image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp options.result_callback(detection_result, image, timestamp) @@ -266,6 +274,8 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi): normalized_rect.to_pb2() ), }) + if output_packets[_DETECTIONS_OUT_STREAM_NAME].is_empty(): + return ObjectDetectorResult([]) detection_proto_list = packet_getter.get_proto_list( output_packets[_DETECTIONS_OUT_STREAM_NAME] ) @@ -315,6 +325,8 @@ class ObjectDetector(base_vision_task_api.BaseVisionTaskApi): normalized_rect.to_pb2() ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), }) + if output_packets[_DETECTIONS_OUT_STREAM_NAME].is_empty(): + return ObjectDetectorResult([]) detection_proto_list = packet_getter.get_proto_list( output_packets[_DETECTIONS_OUT_STREAM_NAME] )