diff --git a/mediapipe/tasks/ios/test/vision/object_detector/MPPObjectDetectorTests.m b/mediapipe/tasks/ios/test/vision/object_detector/MPPObjectDetectorTests.m index cb76de88e..67551c984 100644 --- a/mediapipe/tasks/ios/test/vision/object_detector/MPPObjectDetectorTests.m +++ b/mediapipe/tasks/ios/test/vision/object_detector/MPPObjectDetectorTests.m @@ -23,9 +23,8 @@ static NSDictionary *const kCatsAndDogsImage = @{@"name" : @"cats_and_dogs", @"t static NSDictionary *const kCatsAndDogsRotatedImage = @{@"name" : @"cats_and_dogs_rotated", @"type" : @"jpg"}; static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; - -#define PixelDifferenceTolerance 5.0f -#define ScoreDifferenceTolerance 1e-2f +static const float pixelDifferenceTolerance = 5.0f; +static const float scoreDifferenceTolerance = 1e-2f; #define AssertEqualErrors(error, expectedError) \ XCTAssertNotNil(error); \ @@ -35,59 +34,29 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; [error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \ NSNotFound) -#define AssertEqualCategoryArrays(categories, expectedCategories, detectionIndex) \ - XCTAssertEqual(categories.count, expectedCategories.count); \ - for (int j = 0; j < categories.count; j++) { \ - XCTAssertEqual(categories[j].index, expectedCategories[j].index, \ - @"detection Index = %d category array index j = %d", detectionIndex, j); \ - XCTAssertEqualWithAccuracy( \ - categories[j].score, expectedCategories[j].score, ScoreDifferenceTolerance, \ - @"detection Index = %d, category array index j = %d", detectionIndex, j); \ - XCTAssertEqualObjects(categories[j].categoryName, expectedCategories[j].categoryName, \ - @"detection Index = %d, category array index j = %d", detectionIndex, \ - j); \ - XCTAssertEqualObjects(categories[j].displayName, expectedCategories[j].displayName, \ - @"detection Index = %d, category array index j = %d", detectionIndex, \ - j); \ - \ - \ - } +#define AssertEqualCategories(category, expectedCategory, detectionIndex, categoryIndex) \ + XCTAssertEqual(category.index, expectedCategory.index, \ + @"detection Index = %d category array index j = %d", detectionIndex, \ + categoryIndex); \ + XCTAssertEqualWithAccuracy(category.score, expectedCategory.score, scoreDifferenceTolerance, \ + @"detection Index = %d, category array index j = %d", detectionIndex, \ + categoryIndex); \ + XCTAssertEqualObjects(category.categoryName, expectedCategory.categoryName, \ + @"detection Index = %d, category array index j = %d", detectionIndex, \ + categoryIndex); \ + XCTAssertEqualObjects(category.displayName, expectedCategory.displayName, \ + @"detection Index = %d, category array index j = %d", detectionIndex, \ + categoryIndex); #define AssertApproximatelyEqualBoundingBoxes(boundingBox, expectedBoundingBox, idx) \ XCTAssertEqualWithAccuracy(boundingBox.origin.x, expectedBoundingBox.origin.x, \ - PixelDifferenceTolerance, @"index i = %d", idx); \ + pixelDifferenceTolerance, @"index i = %d", idx); \ XCTAssertEqualWithAccuracy(boundingBox.origin.y, expectedBoundingBox.origin.y, \ - PixelDifferenceTolerance, @"index i = %d", idx); \ + pixelDifferenceTolerance, @"index i = %d", idx); \ XCTAssertEqualWithAccuracy(boundingBox.size.width, expectedBoundingBox.size.width, \ - PixelDifferenceTolerance, @"index i = %d", idx); \ + pixelDifferenceTolerance, @"index i = %d", idx); \ XCTAssertEqualWithAccuracy(boundingBox.size.height, expectedBoundingBox.size.height, \ - PixelDifferenceTolerance, @"index i = %d", idx); - -#define AssertEqualDetections(detection, expectedDetection, idx) \ - XCTAssertNotNil(detection); \ - AssertEqualCategoryArrays(detection.categories, expectedDetection.categories, idx); \ - AssertApproximatelyEqualBoundingBoxes(detection.boundingBox, expectedDetection.boundingBox, idx); - -#define AssertEqualDetectionArrays(detections, expectedDetections) \ - XCTAssertEqual(detections.count, expectedDetections.count); \ - for (int i = 0; i < detections.count; i++) { \ - AssertEqualDetections(detections[i], expectedDetections[i], i); \ - } - -#define AssertEqualObjectDetectionResults(objectDetectionResult, expectedObjectDetectionResult, \ - expectedDetectionCount) \ - XCTAssertNotNil(objectDetectionResult); \ - NSArray *detectionsSubsetToCompare; \ - XCTAssertEqual(objectDetectionResult.detections.count, expectedDetectionCount); \ - if (objectDetectionResult.detections.count > expectedObjectDetectionResult.detections.count) { \ - detectionsSubsetToCompare = [objectDetectionResult.detections \ - subarrayWithRange:NSMakeRange(0, expectedObjectDetectionResult.detections.count)]; \ - } else { \ - detectionsSubsetToCompare = objectDetectionResult.detections; \ - } \ - \ - AssertEqualDetectionArrays(detectionsSubsetToCompare, expectedObjectDetectionResult.detections); \ - XCTAssertEqual(objectDetectionResult.timestampMs, expectedObjectDetectionResult.timestampMs); + pixelDifferenceTolerance, @"index i = %d", idx); @interface MPPObjectDetectorTests : XCTestCase @end @@ -124,6 +93,39 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; return [[MPPObjectDetectionResult alloc] initWithDetections:detections timestampMs:timestampMs]; } +- (void)assertDetections:(NSArray *)detections + isEqualToExpectedDetections:(NSArray *)expectedDetections { + for (int i = 0; i < detections.count; i++) { + MPPDetection *detection = detections[i]; + XCTAssertNotNil(detection); + for (int j = 0; j < detection.categories.count; j++) { + AssertEqualCategories(detection.categories[j], expectedDetections[i].categories[j], i, j); + } + AssertApproximatelyEqualBoundingBoxes(detection.boundingBox, expectedDetections[i].boundingBox, + i); + } +} + +- (void)assertObjectDetectionResult:(MPPObjectDetectionResult *)objectDetectionResult + isEqualToExpectedResult:(MPPObjectDetectionResult *)expectedObjectDetectionResult + expectedDetectionsCount:(NSInteger)expectedDetectionsCount { + XCTAssertNotNil(objectDetectionResult); + + NSArray *detectionsSubsetToCompare; + XCTAssertEqual(objectDetectionResult.detections.count, expectedDetectionsCount); + if (objectDetectionResult.detections.count > expectedObjectDetectionResult.detections.count) { + detectionsSubsetToCompare = [objectDetectionResult.detections + subarrayWithRange:NSMakeRange(0, expectedObjectDetectionResult.detections.count)]; + } else { + detectionsSubsetToCompare = objectDetectionResult.detections; + } + + [self assertDetections:detectionsSubsetToCompare + isEqualToExpectedDetections:expectedObjectDetectionResult.detections]; + + XCTAssertEqual(objectDetectionResult.timestampMs, expectedObjectDetectionResult.timestampMs); +} + #pragma mark File - (NSString *)filePathWithName:(NSString *)fileName extension:(NSString *)extension { @@ -189,9 +191,11 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; equalsObjectDetectionResult:(MPPObjectDetectionResult *)expectedObjectDetectionResult { MPPObjectDetectionResult *objectDetectionResult = [objectDetector detectInImage:mppImage error:nil]; - AssertEqualObjectDetectionResults( - objectDetectionResult, expectedObjectDetectionResult, - maxResults > 0 ? maxResults : objectDetectionResult.detections.count); + + [self assertObjectDetectionResult:objectDetectionResult + isEqualToExpectedResult:expectedObjectDetectionResult + expectedDetectionsCount:maxResults > 0 ? maxResults + : objectDetectionResult.detections.count]; } - (void)assertResultsOfDetectInImageWithFileInfo:(NSDictionary *)fileInfo @@ -604,10 +608,12 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; MPPObjectDetectionResult *objectDetectionResult = [objectDetector detectInVideoFrame:image timestampMs:i error:nil]; - AssertEqualObjectDetectionResults( - objectDetectionResult, - [MPPObjectDetectorTests expectedDetectionResultForCatsAndDogsImageWithTimestampMs:i], - maxResults); + + [self + assertObjectDetectionResult:objectDetectionResult + isEqualToExpectedResult:[MPPObjectDetectorTests + expectedDetectionResultForCatsAndDogsImageWithTimestampMs:i] + expectedDetectionsCount:maxResults]; } } @@ -624,10 +630,11 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; expectation.expectedFulfillmentCount = 1; options.completion = ^(MPPObjectDetectionResult *result, NSInteger timestampMs, NSError *error) { - AssertEqualObjectDetectionResults( - result, - [MPPObjectDetectorTests expectedDetectionResultForCatsAndDogsImageWithTimestampMs:1], - maxResults); + [self assertObjectDetectionResult:result + isEqualToExpectedResult: + [MPPObjectDetectorTests + expectedDetectionResultForCatsAndDogsImageWithTimestampMs:timestampMs] + expectedDetectionsCount:maxResults]; [expectation fulfill]; }; @@ -678,11 +685,11 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; expectation.inverted = YES; options.completion = ^(MPPObjectDetectionResult *result, NSInteger timestampMs, NSError *error) { - AssertEqualObjectDetectionResults( - result, - [MPPObjectDetectorTests - expectedDetectionResultForCatsAndDogsImageWithTimestampMs:timestampMs], - maxResults); + [self assertObjectDetectionResult:result + isEqualToExpectedResult: + [MPPObjectDetectorTests + expectedDetectionResultForCatsAndDogsImageWithTimestampMs:timestampMs] + expectedDetectionsCount:maxResults]; [expectation fulfill]; };