Added methods for common functionality in MPPImageClassifier

This commit is contained in:
Prianka Liz Kariat 2023-03-03 12:01:45 +05:30
parent 8aaabe4a02
commit 289b3b20de

View File

@ -42,6 +42,11 @@ static NSString *const kNormRectTag = @"NORM_RECT";
static NSString *const kTaskGraphName = static NSString *const kTaskGraphName =
@"mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"; @"mediapipe.tasks.vision.image_classifier.ImageClassifierGraph";
#define InputPacketMap(imagePacket, normalizedRectPacket) \
{ \
{kImageInStreamName.cppString, imagePacket}, { kNormRectName.cppString, normalizedRectPacket } \
}
@interface MPPImageClassifier () { @interface MPPImageClassifier () {
/** iOS Vision Task Runner */ /** iOS Vision Task Runner */
MPPVisionTaskRunner *_visionTaskRunner; MPPVisionTaskRunner *_visionTaskRunner;
@ -123,8 +128,7 @@ static NSString *const kTaskGraphName =
Packet normalizedRectPacket = Packet normalizedRectPacket =
[MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value()]; [MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value()];
PacketMap inputPacketMap = {{kImageInStreamName.cppString, imagePacket}, PacketMap inputPacketMap = InputPacketMap(imagePacket, normalizedRectPacket);
{kNormRectName.cppString, normalizedRectPacket}};
std::optional<PacketMap> outputPacketMap = [_visionTaskRunner processPacketMap:inputPacketMap std::optional<PacketMap> outputPacketMap = [_visionTaskRunner processPacketMap:inputPacketMap
error:error]; error:error];
@ -137,6 +141,33 @@ static NSString *const kTaskGraphName =
outputPacketMap.value()[kClassificationsStreamName.cppString]]; outputPacketMap.value()[kClassificationsStreamName.cppString]];
} }
- (std::optional<PacketMap>)inputPacketMapWithMPPImage:(MPPImage *)image
timestampMs:(NSInteger)timestampMs
regionOfInterest:(CGRect)roi
error:(NSError **)error {
std::optional<NormalizedRect> rect =
[_visionTaskRunner normalizedRectFromRegionOfInterest:roi
imageOrientation:image.orientation
roiAllowed:YES
error:error];
if (!rect.has_value()) {
return std::nullopt;
}
Packet imagePacket = [MPPVisionPacketCreator createPacketWithMPPImage:image
timestampMs:timestampMs
error:error];
if (imagePacket.IsEmpty()) {
return std::nullopt;
}
Packet normalizedRectPacket = [MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value()
timestampMs:timestampMs];
PacketMap inputPacketMap = InputPacketMap(imagePacket, normalizedRectPacket);
return inputPacketMap;
}
- (nullable MPPImageClassifierResult *)classifyImage:(MPPImage *)image error:(NSError **)error { - (nullable MPPImageClassifierResult *)classifyImage:(MPPImage *)image error:(NSError **)error {
return [self classifyImage:image regionOfInterest:CGRectZero error:error]; return [self classifyImage:image regionOfInterest:CGRectZero error:error];
} }
@ -145,30 +176,17 @@ static NSString *const kTaskGraphName =
timestampMs:(NSInteger)timestampMs timestampMs:(NSInteger)timestampMs
regionOfInterest:(CGRect)roi regionOfInterest:(CGRect)roi
error:(NSError **)error { error:(NSError **)error {
std::optional<NormalizedRect> rect = std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image
[_visionTaskRunner normalizedRectFromRegionOfInterest:roi timestampMs:timestampMs
imageOrientation:image.orientation regionOfInterest:roi
roiAllowed:YES error:error];
error:error]; if (!inputPacketMap.has_value()) {
if (!rect.has_value()) {
return nil; return nil;
} }
Packet imagePacket = [MPPVisionPacketCreator createPacketWithMPPImage:image
timestampMs:timestampMs
error:error];
if (imagePacket.IsEmpty()) {
return nil;
}
Packet normalizedRectPacket = [MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value()
timestampMs:timestampMs];
PacketMap inputPacketMap = {{kImageInStreamName.cppString, imagePacket},
{kNormRectName.cppString, normalizedRectPacket}};
std::optional<PacketMap> outputPacketMap = std::optional<PacketMap> outputPacketMap =
[_visionTaskRunner processVideoFramePacketMap:inputPacketMap error:error]; [_visionTaskRunner processVideoFramePacketMap:inputPacketMap.value() error:error];
if (!outputPacketMap.has_value()) { if (!outputPacketMap.has_value()) {
return nil; return nil;
} }
@ -191,29 +209,15 @@ static NSString *const kTaskGraphName =
timestampMs:(NSInteger)timestampMs timestampMs:(NSInteger)timestampMs
regionOfInterest:(CGRect)roi regionOfInterest:(CGRect)roi
error:(NSError **)error { error:(NSError **)error {
std::optional<NormalizedRect> rect = std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image
[_visionTaskRunner normalizedRectFromRegionOfInterest:roi timestampMs:timestampMs
imageOrientation:image.orientation regionOfInterest:roi
roiAllowed:YES error:error];
error:error]; if (!inputPacketMap.has_value()) {
if (!rect.has_value()) {
return NO; return NO;
} }
Packet imagePacket = [MPPVisionPacketCreator createPacketWithMPPImage:image return [_visionTaskRunner processLiveStreamPacketMap:inputPacketMap.value() error:error];
timestampMs:timestampMs
error:error];
if (imagePacket.IsEmpty()) {
return NO;
}
Packet normalizedRectPacket = [MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value()
timestampMs:timestampMs];
PacketMap inputPacketMap = {{kImageInStreamName.cppString, imagePacket},
{kNormRectName.cppString, normalizedRectPacket}};
return [_visionTaskRunner processLiveStreamPacketMap:inputPacketMap error:error];
} }
- (BOOL)classifyAsyncImage:(MPPImage *)image - (BOOL)classifyAsyncImage:(MPPImage *)image