Added methods for common functionality in MPPImageClassifier

This commit is contained in:
Prianka Liz Kariat 2023-03-03 12:01:45 +05:30
parent 8aaabe4a02
commit 289b3b20de

View File

@ -42,6 +42,11 @@ static NSString *const kNormRectTag = @"NORM_RECT";
static NSString *const kTaskGraphName =
@"mediapipe.tasks.vision.image_classifier.ImageClassifierGraph";
#define InputPacketMap(imagePacket, normalizedRectPacket) \
{ \
{kImageInStreamName.cppString, imagePacket}, { kNormRectName.cppString, normalizedRectPacket } \
}
@interface MPPImageClassifier () {
/** iOS Vision Task Runner */
MPPVisionTaskRunner *_visionTaskRunner;
@ -123,8 +128,7 @@ static NSString *const kTaskGraphName =
Packet normalizedRectPacket =
[MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value()];
PacketMap inputPacketMap = {{kImageInStreamName.cppString, imagePacket},
{kNormRectName.cppString, normalizedRectPacket}};
PacketMap inputPacketMap = InputPacketMap(imagePacket, normalizedRectPacket);
std::optional<PacketMap> outputPacketMap = [_visionTaskRunner processPacketMap:inputPacketMap
error:error];
@ -137,11 +141,7 @@ static NSString *const kTaskGraphName =
outputPacketMap.value()[kClassificationsStreamName.cppString]];
}
- (nullable MPPImageClassifierResult *)classifyImage:(MPPImage *)image error:(NSError **)error {
return [self classifyImage:image regionOfInterest:CGRectZero error:error];
}
- (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image
- (std::optional<PacketMap>)inputPacketMapWithMPPImage:(MPPImage *)image
timestampMs:(NSInteger)timestampMs
regionOfInterest:(CGRect)roi
error:(NSError **)error {
@ -151,24 +151,42 @@ static NSString *const kTaskGraphName =
roiAllowed:YES
error:error];
if (!rect.has_value()) {
return nil;
return std::nullopt;
}
Packet imagePacket = [MPPVisionPacketCreator createPacketWithMPPImage:image
timestampMs:timestampMs
error:error];
if (imagePacket.IsEmpty()) {
return nil;
return std::nullopt;
}
Packet normalizedRectPacket = [MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value()
timestampMs:timestampMs];
PacketMap inputPacketMap = {{kImageInStreamName.cppString, imagePacket},
{kNormRectName.cppString, normalizedRectPacket}};
PacketMap inputPacketMap = InputPacketMap(imagePacket, normalizedRectPacket);
return inputPacketMap;
}
- (nullable MPPImageClassifierResult *)classifyImage:(MPPImage *)image error:(NSError **)error {
return [self classifyImage:image regionOfInterest:CGRectZero error:error];
}
- (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image
timestampMs:(NSInteger)timestampMs
regionOfInterest:(CGRect)roi
error:(NSError **)error {
std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image
timestampMs:timestampMs
regionOfInterest:roi
error:error];
if (!inputPacketMap.has_value()) {
return nil;
}
std::optional<PacketMap> outputPacketMap =
[_visionTaskRunner processVideoFramePacketMap:inputPacketMap error:error];
[_visionTaskRunner processVideoFramePacketMap:inputPacketMap.value() error:error];
if (!outputPacketMap.has_value()) {
return nil;
}
@ -191,29 +209,15 @@ static NSString *const kTaskGraphName =
timestampMs:(NSInteger)timestampMs
regionOfInterest:(CGRect)roi
error:(NSError **)error {
std::optional<NormalizedRect> rect =
[_visionTaskRunner normalizedRectFromRegionOfInterest:roi
imageOrientation:image.orientation
roiAllowed:YES
error:error];
if (!rect.has_value()) {
return NO;
}
Packet imagePacket = [MPPVisionPacketCreator createPacketWithMPPImage:image
std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image
timestampMs:timestampMs
regionOfInterest:roi
error:error];
if (imagePacket.IsEmpty()) {
if (!inputPacketMap.has_value()) {
return NO;
}
Packet normalizedRectPacket = [MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value()
timestampMs:timestampMs];
PacketMap inputPacketMap = {{kImageInStreamName.cppString, imagePacket},
{kNormRectName.cppString, normalizedRectPacket}};
return [_visionTaskRunner processLiveStreamPacketMap:inputPacketMap error:error];
return [_visionTaskRunner processLiveStreamPacketMap:inputPacketMap.value() error:error];
}
- (BOOL)classifyAsyncImage:(MPPImage *)image