diff --git a/mediapipe/tasks/ios/vision/image_classifier/BUILD b/mediapipe/tasks/ios/vision/image_classifier/BUILD
index 45e6e2156..130e5fe7d 100644
--- a/mediapipe/tasks/ios/vision/image_classifier/BUILD
+++ b/mediapipe/tasks/ios/vision/image_classifier/BUILD
@@ -36,3 +36,31 @@ objc_library(
         "//mediapipe/tasks/ios/vision/core:MPPRunningMode",
     ],
 )
+
+objc_library(
+    name = "MPPImageClassifier",
+    srcs = ["sources/MPPImageClassifier.mm"],
+    hdrs = ["sources/MPPImageClassifier.h"],
+    # The implementation is an Objective-C++ (.mm) source, so it is compiled as C++17.
+    copts = [
+        "-ObjC++",
+        "-std=c++17",
+        "-x objective-c++",
+    ],
+    module_name = "MPPImageClassifier",
+    deps = [
+        ":MPPImageClassifierOptions",
+        ":MPPImageClassifierResult",
+        "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
+        "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
+        "//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
+        "//mediapipe/tasks/ios/common/utils:NSStringHelpers",
+        "//mediapipe/tasks/ios/core:MPPTaskInfo",
+        "//mediapipe/tasks/ios/core:MPPTaskOptions",
+        "//mediapipe/tasks/ios/vision/core:MPPImage",
+        "//mediapipe/tasks/ios/vision/core:MPPVisionPacketCreator",
+        "//mediapipe/tasks/ios/vision/core:MPPVisionTaskRunner",
+        "//mediapipe/tasks/ios/vision/image_classifier/utils:MPPImageClassifierOptionsHelpers",
+        "//mediapipe/tasks/ios/vision/image_classifier/utils:MPPImageClassifierResultHelpers",
+    ],
+)
diff --git a/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.h b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.h
new file mode 100644
index 000000000..1914f9aea
--- /dev/null
+++ b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.h
@@ -0,0 +1,217 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import <Foundation/Foundation.h>
+
+#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h"
+#import "mediapipe/tasks/ios/vision/core/sources/MPPImage.h"
+#import "mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierOptions.h"
+#import "mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierResult.h"
+
+NS_ASSUME_NONNULL_BEGIN
+
+/**
+ * @brief Performs classification on images.
+ *
+ * The API expects a TFLite model with optional, but strongly recommended,
+ * [TFLite Model Metadata](https://www.tensorflow.org/lite/convert/metadata).
+ *
+ * The API supports models with one image input tensor and one or more output tensors. To be more
+ * specific, here are the requirements.
+ *
+ * Input tensor
+ *  (kTfLiteUInt8/kTfLiteFloat32)
+ *  - image input of size `[batch x height x width x channels]`.
+ *  - batch inference is not supported (`batch` is required to be 1).
+ *  - only RGB inputs are supported (`channels` is required to be 3).
+ *  - if type is kTfLiteFloat32, NormalizationOptions are required to be attached to the metadata
+ *    for input normalization.
+ *
+ * At least one output tensor with:
+ *  (kTfLiteUInt8/kTfLiteFloat32)
+ *  - `N` classes and either 2 or 4 dimensions, i.e. `[1 x N]` or `[1 x 1 x 1 x N]`
+ *  - optional (but recommended) label map(s) as AssociatedFiles with type TENSOR_AXIS_LABELS,
+ *    containing one label per line. The first such AssociatedFile (if any) is used to fill the
+ *    `class_name` field of the results. The `display_name` field is filled from the AssociatedFile
+ *    (if any) whose locale matches the `display_names_locale` field of the `ImageClassifierOptions`
+ *    used at creation time ("en" by default, i.e. English). If none of these are available, only
+ *    the `index` field of the results will be filled.
+ *  - optional score calibration can be attached using ScoreCalibrationOptions and an AssociatedFile
+ *    with type TENSOR_AXIS_SCORE_CALIBRATION. See metadata_schema.fbs for more details.
+ */
+NS_SWIFT_NAME(ImageClassifier)
+@interface MPPImageClassifier : NSObject
+
+/**
+ * Creates a new instance of `MPPImageClassifier` from an absolute path to a TensorFlow Lite
+ * model file stored locally on the device and the default `MPPImageClassifierOptions`.
+ *
+ * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device.
+ * @param error An optional error parameter populated when there is an error in initializing the
+ * image classifier.
+ *
+ * @return A new instance of `MPPImageClassifier` with the given model path. `nil` if there is an
+ * error in initializing the image classifier.
+ */
+- (nullable instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error;
+
+/**
+ * Creates a new instance of `MPPImageClassifier` from the given `MPPImageClassifierOptions`.
+ *
+ * @param options The options of type `MPPImageClassifierOptions` to use for configuring the
+ * `MPPImageClassifier`.
+ * @param error An optional error parameter populated when there is an error in initializing the
+ * image classifier.
+ *
+ * @return A new instance of `MPPImageClassifier` with the given options. `nil` if there is an
+ * error in initializing the image classifier.
+ */
+- (nullable instancetype)initWithOptions:(MPPImageClassifierOptions *)options
+                                   error:(NSError **)error NS_DESIGNATED_INITIALIZER;
+
+/**
+ * Performs image classification on the provided `MPPImage` using the whole image as region of
+ * interest. Rotation will be applied according to the `orientation` property of the provided
+ * `MPPImage`. Only use this method when the `MPPImageClassifier` is created with
+ * `MPPRunningModeImage`.
+ *
+ * @param image The `MPPImage` on which image classification is to be performed.
+ * @param error An optional error parameter populated when there is an error in performing
+ * image classification on the input image.
+ *
+ * @return A `MPPImageClassifierResult` object that contains a list of image classifications.
+ */
+- (nullable MPPImageClassifierResult *)classifyImage:(MPPImage *)image
+                                               error:(NSError **)error
+    NS_SWIFT_NAME(classify(image:));
+
+/**
+ * Performs image classification on the provided `MPPImage` cropped to the specified region of
+ * interest. Rotation will be applied on the cropped image according to the `orientation` property
+ * of the provided `MPPImage`. Only use this method when the `MPPImageClassifier` is created with
+ * `MPPRunningModeImage`.
+ *
+ * @param image The `MPPImage` on which image classification is to be performed.
+ * @param roi A `CGRect` specifying the region of interest within the given `MPPImage`, on which
+ * image classification should be performed.
+ * @param error An optional error parameter populated when there is an error in performing
+ * image classification on the input image.
+ *
+ * @return A `MPPImageClassifierResult` object that contains a list of image classifications.
+ */
+- (nullable MPPImageClassifierResult *)classifyImage:(MPPImage *)image
+                                    regionOfInterest:(CGRect)roi
+                                               error:(NSError **)error
+    NS_SWIFT_NAME(classify(image:regionOfInterest:));
+
+/**
+ * Performs image classification on the provided video frame of type `MPPImage` using the whole
+ * image as region of interest. Rotation will be applied according to the `orientation` property of
+ * the provided `MPPImage`. Only use this method when the `MPPImageClassifier` is created with
+ * `MPPRunningModeVideo`.
+ *
+ * @param image The `MPPImage` on which image classification is to be performed.
+ * @param timestampMs The video frame's timestamp (in milliseconds). The input timestamps must be
+ * monotonically increasing.
+ * @param error An optional error parameter populated when there is an error in performing
+ * image classification on the input video frame.
+ *
+ * @return A `MPPImageClassifierResult` object that contains a list of image classifications.
+ */
+- (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image
+                                              timestampMs:(NSInteger)timestampMs
+                                                    error:(NSError **)error
+    NS_SWIFT_NAME(classify(videoFrame:timeStampMs:));
+
+/**
+ * Performs image classification on the provided video frame of type `MPPImage` cropped to the
+ * specified region of interest. Rotation will be applied according to the `orientation` property of
+ * the provided `MPPImage`. Only use this method when the `MPPImageClassifier` is created with
+ * `MPPRunningModeVideo`.
+ *
+ * It's required to provide the video frame's timestamp (in milliseconds). The input timestamps must
+ * be monotonically increasing.
+ *
+ * @param image The video frame of type `MPPImage` on which image classification is to be
+ * performed.
+ * @param timestampMs The video frame's timestamp (in milliseconds). The input timestamps must be
+ * monotonically increasing.
+ * @param roi A `CGRect` specifying the region of interest within the video frame of type
+ * `MPPImage`, on which image classification should be performed.
+ * @param error An optional error parameter populated when there is an error in performing
+ * image classification on the input video frame.
+ *
+ * @return A `MPPImageClassifierResult` object that contains a list of image classifications.
+ */
+- (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image
+                                              timestampMs:(NSInteger)timestampMs
+                                         regionOfInterest:(CGRect)roi
+                                                    error:(NSError **)error
+    NS_SWIFT_NAME(classify(videoFrame:timeStampMs:regionOfInterest:));
+
+/**
+ * Sends live stream image data of type `MPPImage` to perform image classification using the whole
+ * image as region of interest. Rotation will be applied according to the `orientation` property of
+ * the provided `MPPImage`. Only use this method when the `MPPImageClassifier` is created with
+ * `MPPRunningModeLiveStream`.
+ *
+ * It's required to provide a timestamp (in milliseconds) to indicate when the input image is sent
+ * to the image classifier. The input timestamps must be monotonically increasing.
+ *
+ * @param image A live stream image data of type `MPPImage` on which image classification is to be
+ * performed.
+ * @param timestampMs The timestamp (in milliseconds) which indicates when the input image is sent
+ * to the image classifier. The input timestamps must be monotonically increasing.
+ * @param error An optional error parameter populated when there is an error in performing
+ * image classification on the input live stream image data.
+ *
+ * @return `YES` if the image was sent to the image classifier successfully, otherwise `NO`.
+ */
+- (BOOL)classifyAsyncImage:(MPPImage *)image
+               timestampMs:(NSInteger)timestampMs
+                     error:(NSError **)error NS_SWIFT_NAME(classifyAsync(image:timeStampMs:));
+
+/**
+ * Sends live stream image data of type `MPPImage` to perform image classification, cropped to the
+ * specified region of interest. Rotation will be applied according to the `orientation` property
+ * of the provided `MPPImage`. Only use this method when the `MPPImageClassifier` is created with
+ * `MPPRunningModeLiveStream`.
+ *
+ * It's required to provide a timestamp (in milliseconds) to indicate when the input image is sent
+ * to the image classifier. The input timestamps must be monotonically increasing.
+ *
+ * @param image A live stream image data of type `MPPImage` on which image classification is to be
+ * performed.
+ * @param timestampMs The timestamp (in milliseconds) which indicates when the input image is sent
+ * to the image classifier. The input timestamps must be monotonically increasing.
+ * @param roi A `CGRect` specifying the region of interest within the given live stream image data
+ * of type `MPPImage`, on which image classification should be performed.
+ * @param error An optional error parameter populated when there is an error in performing
+ * image classification on the input live stream image data.
+ *
+ * @return `YES` if the image was sent to the image classifier successfully, otherwise `NO`.
+ */
+- (BOOL)classifyAsyncImage:(MPPImage *)image
+               timestampMs:(NSInteger)timestampMs
+          regionOfInterest:(CGRect)roi
+                     error:(NSError **)error
+    NS_SWIFT_NAME(classifyAsync(image:timeStampMs:regionOfInterest:));
+
+- (instancetype)init NS_UNAVAILABLE;
+
++ (instancetype)new NS_UNAVAILABLE;
+
+@end
+
+NS_ASSUME_NONNULL_END
diff --git a/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.mm b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.mm
new file mode 100644
index 000000000..f4e13717b
--- /dev/null
+++ b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.mm
@@ -0,0 +1,228 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import "mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.h"
+
+#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
+#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
+#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h"
+#import "mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.h"
+#import "mediapipe/tasks/ios/vision/core/sources/MPPVisionTaskRunner.h"
+#import "mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierOptions+Helpers.h"
+#import "mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierResult+Helpers.h"
+
+#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
+
+namespace {
+using ::mediapipe::NormalizedRect;
+using ::mediapipe::Packet;
+using ::mediapipe::tasks::core::PacketMap;
+using ::mediapipe::tasks::core::PacketsCallback;
+}  // namespace
+
+static NSString *const kClassificationsStreamName = @"classifications_out";
+static NSString *const kClassificationsTag = @"CLASSIFICATIONS";
+static NSString *const kImageInStreamName = @"image_in";
+static NSString *const kImageTag = @"IMAGE";
+static NSString *const kNormRectName = @"norm_rect_in";
+static NSString *const kNormRectTag = @"NORM_RECT";
+
+static NSString *const kTaskGraphName =
+    @"mediapipe.tasks.vision.image_classifier.ImageClassifierGraph";
+
+@interface MPPImageClassifier () {
+  /** iOS Vision Task Runner */
+  MPPVisionTaskRunner *_visionTaskRunner;
+}
+@end
+
+@implementation MPPImageClassifier
+
+- (instancetype)initWithOptions:(MPPImageClassifierOptions *)options error:(NSError **)error {
+  self = [super init];
+  if (self) {
+    // NORM_RECT must be a declared input: every classify call sends a packet on that stream.
+    MPPTaskInfo *taskInfo = [[MPPTaskInfo alloc]
+        initWithTaskGraphName:kTaskGraphName
+                 inputStreams:@[
+                   [NSString stringWithFormat:@"%@:%@", kImageTag, kImageInStreamName],
+                   [NSString stringWithFormat:@"%@:%@", kNormRectTag, kNormRectName]
+                 ]
+                outputStreams:@[ [NSString stringWithFormat:@"%@:%@", kClassificationsTag,
+                                                            kClassificationsStreamName] ]
+                  taskOptions:options
+           enableFlowLimiting:NO
+                        error:error];
+
+    if (!taskInfo) {
+      return nil;
+    }
+
+    PacketsCallback packetsCallback = nullptr;
+
+    if (options.completion) {
+      packetsCallback = [=](absl::StatusOr<PacketMap> status_or_packets) {
+        NSError *callbackError = nil;
+        MPPImageClassifierResult *result = nil;
+        if ([MPPCommonUtils checkCppError:status_or_packets.status() toError:&callbackError]) {
+          result = [MPPImageClassifierResult
+              imageClassifierResultWithClassificationsPacket:
+                  status_or_packets.value()[kClassificationsStreamName.cppString]];
+        }
+        options.completion(result, callbackError);
+      };
+    }
+
+    _visionTaskRunner =
+        [[MPPVisionTaskRunner alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig]
+                                                       runningMode:options.runningMode
+                                                   packetsCallback:std::move(packetsCallback)
+                                                             error:error];
+
+    if (!_visionTaskRunner) {
+      return nil;
+    }
+  }
+  return self;
+}
+
+- (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error {
+  MPPImageClassifierOptions *options = [[MPPImageClassifierOptions alloc] init];
+  options.baseOptions.modelAssetPath = modelPath;
+  return [self initWithOptions:options error:error];
+}
+
+- (nullable MPPImageClassifierResult *)classifyImage:(MPPImage *)image
+                                    regionOfInterest:(CGRect)roi
+                                               error:(NSError **)error {
+  std::optional<NormalizedRect> rect =
+      [_visionTaskRunner normalizedRectFromRegionOfInterest:roi
+                                           imageOrientation:image.orientation
+                                                 roiAllowed:YES
+                                                      error:error];
+  if (!rect.has_value()) {
+    return nil;
+  }
+
+  Packet imagePacket = [MPPVisionPacketCreator createPacketWithMPPImage:image error:error];
+  if (imagePacket.IsEmpty()) {
+    return nil;
+  }
+
+  Packet normalizedRectPacket =
+      [MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value()];
+
+  PacketMap inputPacketMap = {{kImageInStreamName.cppString, imagePacket},
+                              {kNormRectName.cppString, normalizedRectPacket}};
+
+  std::optional<PacketMap> outputPacketMap = [_visionTaskRunner processPacketMap:inputPacketMap
+                                                                           error:error];
+  if (!outputPacketMap.has_value()) {
+    return nil;
+  }
+
+  return
+      [MPPImageClassifierResult imageClassifierResultWithClassificationsPacket:
+                                    outputPacketMap.value()[kClassificationsStreamName.cppString]];
+}
+
+- (nullable MPPImageClassifierResult *)classifyImage:(MPPImage *)image error:(NSError **)error {
+  return [self classifyImage:image regionOfInterest:CGRectZero error:error];
+}
+
+- (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image
+                                              timestampMs:(NSInteger)timestampMs
+                                         regionOfInterest:(CGRect)roi
+                                                    error:(NSError **)error {
+  std::optional<NormalizedRect> rect =
+      [_visionTaskRunner normalizedRectFromRegionOfInterest:roi
+                                           imageOrientation:image.orientation
+                                                 roiAllowed:YES
+                                                      error:error];
+  if (!rect.has_value()) {
+    return nil;
+  }
+
+  Packet imagePacket = [MPPVisionPacketCreator createPacketWithMPPImage:image
+                                                            timestampMs:timestampMs
+                                                                  error:error];
+  if (imagePacket.IsEmpty()) {
+    return nil;
+  }
+
+  Packet normalizedRectPacket = [MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value()
+                                                                           timestampMs:timestampMs];
+
+  PacketMap inputPacketMap = {{kImageInStreamName.cppString, imagePacket},
+                              {kNormRectName.cppString, normalizedRectPacket}};
+
+  std::optional<PacketMap> outputPacketMap =
+      [_visionTaskRunner processVideoFramePacketMap:inputPacketMap error:error];
+  if (!outputPacketMap.has_value()) {
+    return nil;
+  }
+
+  return
+      [MPPImageClassifierResult imageClassifierResultWithClassificationsPacket:
+                                    outputPacketMap.value()[kClassificationsStreamName.cppString]];
+}
+
+- (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image
+                                              timestampMs:(NSInteger)timestampMs
+                                                    error:(NSError **)error {
+  return [self classifyVideoFrame:image
+                      timestampMs:timestampMs
+                 regionOfInterest:CGRectZero
+                            error:error];
+}
+
+- (BOOL)classifyAsyncImage:(MPPImage *)image
+               timestampMs:(NSInteger)timestampMs
+          regionOfInterest:(CGRect)roi
+                     error:(NSError **)error {
+  std::optional<NormalizedRect> rect =
+      [_visionTaskRunner normalizedRectFromRegionOfInterest:roi
+                                           imageOrientation:image.orientation
+                                                 roiAllowed:YES
+                                                      error:error];
+  if (!rect.has_value()) {
+    return NO;
+  }
+
+  Packet imagePacket = [MPPVisionPacketCreator createPacketWithMPPImage:image
+                                                            timestampMs:timestampMs
+                                                                  error:error];
+  if (imagePacket.IsEmpty()) {
+    return NO;
+  }
+
+  Packet normalizedRectPacket = [MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value()
+                                                                           timestampMs:timestampMs];
+
+  PacketMap inputPacketMap = {{kImageInStreamName.cppString, imagePacket},
+                              {kNormRectName.cppString, normalizedRectPacket}};
+
+  return [_visionTaskRunner processLiveStreamPacketMap:inputPacketMap error:error];
+}
+
+- (BOOL)classifyAsyncImage:(MPPImage *)image
+               timestampMs:(NSInteger)timestampMs
+                     error:(NSError **)error {
+  return [self classifyAsyncImage:image
+                      timestampMs:timestampMs
+                 regionOfInterest:CGRectZero
+                            error:error];
+}
+
+@end