diff --git a/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.mm b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.mm index 6db7d06c5..564aede88 100644 --- a/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.mm +++ b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.mm @@ -42,6 +42,11 @@ static NSString *const kNormRectTag = @"NORM_RECT"; static NSString *const kTaskGraphName = @"mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"; +#define InputPacketMap(imagePacket, normalizedRectPacket) \ + { \ + {kImageInStreamName.cppString, imagePacket}, { kNormRectName.cppString, normalizedRectPacket } \ + } + @interface MPPImageClassifier () { /** iOS Vision Task Runner */ MPPVisionTaskRunner *_visionTaskRunner; @@ -123,8 +128,7 @@ static NSString *const kTaskGraphName = Packet normalizedRectPacket = [MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value()]; - PacketMap inputPacketMap = {{kImageInStreamName.cppString, imagePacket}, - {kNormRectName.cppString, normalizedRectPacket}}; + PacketMap inputPacketMap = InputPacketMap(imagePacket, normalizedRectPacket); std::optional outputPacketMap = [_visionTaskRunner processPacketMap:inputPacketMap error:error]; @@ -137,6 +141,33 @@ static NSString *const kTaskGraphName = outputPacketMap.value()[kClassificationsStreamName.cppString]]; } +- (std::optional)inputPacketMapWithMPPImage:(MPPImage *)image + timestampMs:(NSInteger)timestampMs + regionOfInterest:(CGRect)roi + error:(NSError **)error { + std::optional rect = + [_visionTaskRunner normalizedRectFromRegionOfInterest:roi + imageOrientation:image.orientation + roiAllowed:YES + error:error]; + if (!rect.has_value()) { + return std::nullopt; + } + + Packet imagePacket = [MPPVisionPacketCreator createPacketWithMPPImage:image + timestampMs:timestampMs + error:error]; + if (imagePacket.IsEmpty()) { + return std::nullopt; + } + + Packet normalizedRectPacket = [MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value() + timestampMs:timestampMs]; + + PacketMap inputPacketMap = InputPacketMap(imagePacket, normalizedRectPacket); + return inputPacketMap; +} + - (nullable MPPImageClassifierResult *)classifyImage:(MPPImage *)image error:(NSError **)error { return [self classifyImage:image regionOfInterest:CGRectZero error:error]; } @@ -145,30 +176,17 @@ static NSString *const kTaskGraphName = timestampMs:(NSInteger)timestampMs regionOfInterest:(CGRect)roi error:(NSError **)error { - std::optional rect = - [_visionTaskRunner normalizedRectFromRegionOfInterest:roi - imageOrientation:image.orientation - roiAllowed:YES - error:error]; - if (!rect.has_value()) { + std::optional inputPacketMap = [self inputPacketMapWithMPPImage:image + timestampMs:timestampMs + regionOfInterest:roi + error:error]; + if (!inputPacketMap.has_value()) { return nil; } - Packet imagePacket = [MPPVisionPacketCreator createPacketWithMPPImage:image - timestampMs:timestampMs - error:error]; - if (imagePacket.IsEmpty()) { - return nil; - } - - Packet normalizedRectPacket = [MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value() - timestampMs:timestampMs]; - - PacketMap inputPacketMap = {{kImageInStreamName.cppString, imagePacket}, - {kNormRectName.cppString, normalizedRectPacket}}; - std::optional outputPacketMap = - [_visionTaskRunner processVideoFramePacketMap:inputPacketMap error:error]; + [_visionTaskRunner processVideoFramePacketMap:inputPacketMap.value() error:error]; + if (!outputPacketMap.has_value()) { return nil; } @@ -191,29 +209,15 @@ static NSString *const kTaskGraphName = timestampMs:(NSInteger)timestampMs regionOfInterest:(CGRect)roi error:(NSError **)error { - std::optional rect = - [_visionTaskRunner normalizedRectFromRegionOfInterest:roi - imageOrientation:image.orientation - roiAllowed:YES - error:error]; - if (!rect.has_value()) { + std::optional inputPacketMap = [self inputPacketMapWithMPPImage:image + timestampMs:timestampMs + regionOfInterest:roi + error:error]; + if (!inputPacketMap.has_value()) { return NO; } - Packet imagePacket = [MPPVisionPacketCreator createPacketWithMPPImage:image - timestampMs:timestampMs - error:error]; - if (imagePacket.IsEmpty()) { - return NO; - } - - Packet normalizedRectPacket = [MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value() - timestampMs:timestampMs]; - - PacketMap inputPacketMap = {{kImageInStreamName.cppString, imagePacket}, - {kNormRectName.cppString, normalizedRectPacket}}; - - return [_visionTaskRunner processLiveStreamPacketMap:inputPacketMap error:error]; + return [_visionTaskRunner processLiveStreamPacketMap:inputPacketMap.value() error:error]; } - (BOOL)classifyAsyncImage:(MPPImage *)image