diff --git a/mediapipe/tasks/ios/test/vision/image_segmenter/MPPImageSegmenterTests.mm b/mediapipe/tasks/ios/test/vision/image_segmenter/MPPImageSegmenterTests.mm index aa42100c4..76822152e 100644 --- a/mediapipe/tasks/ios/test/vision/image_segmenter/MPPImageSegmenterTests.mm +++ b/mediapipe/tasks/ios/test/vision/image_segmenter/MPPImageSegmenterTests.mm @@ -20,6 +20,9 @@ #import "mediapipe/tasks/ios/vision/image_segmenter/sources/MPPImageSegmenter.h" #import "mediapipe/tasks/ios/vision/image_segmenter/sources/MPPImageSegmenterResult.h" +#include +#include + static MPPFileInfo *const kCatImageFileInfo = [[MPPFileInfo alloc] initWithName:@"cat" type:@"jpg"]; static MPPFileInfo *const kCatGoldenImageFileInfo = [[MPPFileInfo alloc] initWithName:@"cat_mask" @@ -34,39 +37,36 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; constexpr float kSimilarityThreshold = 0.96f; constexpr NSInteger kMagnificationFactor = 10; -double sum(const float *mask, size_t size) { +double sum(const std::vector& mask) { double sum = 0.0; - for (int i = 0; i < size; i++) { - sum += mask[i]; + for (const float &maskElement : mask) { + sum += maskElement; } return sum; } -float *multiply(const float *mask1, const float *mask2, size_t size) { +std::vector multiply(const float *mask1, const float *mask2, size_t size) { double sum = 0.0; - float *multipliedMask = (float *)malloc(size * sizeof(float)); - if (!multipliedMask) { - exit(-1); - } + + std::vector multipliedMask; + multipliedMask.reserve(size); + for (int i = 0; i < size; i++) { - multipliedMask[i] = mask1[i] * mask2[i]; + multipliedMask.push_back(mask1[i] * mask2[i]); } - + return multipliedMask; } double softIOU(const float *mask1, const float *mask2, size_t size) { - float *interSectionVector = multiply(mask1, mask2, size); - double interSectionSum = sum(interSectionVector, size); - free(interSectionVector); + std::vector interSectionVector = multiply(mask1, mask2, size); + double interSectionSum = sum(interSectionVector); - float *m1m1Vector = multiply(mask1, mask1, size); - double m1m1 = sum(m1m1Vector, size); - free(m1m1Vector); + std::vector m1m1Vector = multiply(mask1, mask1, size); + double m1m1 = sum(m1m1Vector); - float *m2m2Vector = multiply(mask2, mask2, size); - double m2m2 = sum(m2m2Vector, size); - free(m2m2Vector); + std::vector m2m2Vector = multiply(mask2, mask2, size); + double m2m2 = sum(m2m2Vector); double unionSum = m1m1 + m2m2 - interSectionSum;