diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h index ed57d2df2..41515571a 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h @@ -17,6 +17,8 @@ #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/tasks/cc/core/task_runner.h" +#include + NS_ASSUME_NONNULL_BEGIN /** @@ -62,24 +64,57 @@ NS_ASSUME_NONNULL_BEGIN error:(NSError **)error NS_DESIGNATED_INITIALIZER; /** - * A synchronous method for processing batch data or offline streaming data. This method is designed - * for processing either batch data such as unrelated images and texts or offline streaming data - * such as the decoded frames from a video file or audio file. The call blocks the current - * thread until a failure status or a successful result is returned. If the input packets have no - * timestamp, an internal timestamp will be assigned per invocation. Otherwise, when the timestamp - * is set in the input packets, the caller must ensure that the input packet timestamps are greater - * than the timestamps of the previous invocation. This method is thread-unsafe and it is the - * caller's responsibility to synchronize access to this method across multiple threads and to - * ensure that the input packet timestamps are in order. + * A synchronous method for invoking the C++ task runner for processing batch data or offline + * streaming data. This method is designed for processing either batch data such as unrelated images + * and texts or offline streaming data such as the decoded frames from a video file or audio file. + * The call blocks the current thread until a failure status or a successful result is returned. If + * the input packets have no timestamp, an internal timestamp will be assigned per invocation. + * Otherwise, when the timestamp is set in the input packets, the caller must ensure that the input + * packet timestamps are greater than the timestamps of the previous invocation. This method is + * thread-unsafe and it is the caller's responsibility to synchronize access to this method across + * multiple threads and to ensure that the input packet timestamps are in order. + * + * @param packetMap A `PacketMap` containing pairs of input stream name and data packet which are to + * be sent to the C++ task runner for processing synchronously. + * @param error Pointer to the memory location where errors if any should be saved. If @c NULL, no + * error will be saved. + * + * @return An optional output `PacketMap` containing pairs of output stream name and data packet + * which holds the results of processing the input packet map, if there are no errors. */ -- (absl::StatusOr)process: - (const mediapipe::tasks::core::PacketMap &)packetMap; +- (std::optional) + processPacketMap:(const mediapipe::tasks::core::PacketMap &)packetMap + error:(NSError **)error; + +/** + * An asynchronous method that is designed for handling live streaming data such as live camera. A + * user-defined PacketsCallback function must be provided in the constructor to receive the output + * packets. The caller must ensure that the input packet timestamps are monotonically increasing. + * This method is thread-unsafe and it is the caller's responsibility to synchronize access to this + * method across multiple threads and to ensure that the input packet timestamps are in order. + * + * @param packetMap A `PacketMap` containing pairs of input stream name and data packet that are to + * be sent to the C++ task runner for processing asynchronously. + * @param error Pointer to the memory location where errors if any should be saved. If @c NULL, no + * error will be saved. + * + * @return A `BOOL` indicating if the live stream data was sent to the C++ task runner successfully. + * Please note that any errors during processing of the live stream packet map will only be + * available in the user-defined `packetsCallback` that was provided during initialization of the + * `MPPVisionTaskRunner`. + */ +- (BOOL)sendPacketMap:(const mediapipe::tasks::core::PacketMap &)packetMap error:(NSError **)error; /** * Shuts down the C++ task runner. After the runner is closed, any calls that send input data to the * runner are illegal and will receive errors. + * + * @param error Pointer to the memory location where errors if any should be saved. If @c NULL, no + * error will be saved. + * + * @return A `BOOL` indicating if the C++ task runner was shutdown successfully. */ -- (absl::Status)close; +- (BOOL)closeWithError:(NSError **)error; - (instancetype)init NS_UNAVAILABLE; diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm index eb777679a..0813760c2 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm +++ b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm @@ -50,12 +50,22 @@ using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner; return self; } -- (absl::StatusOr)process:(const PacketMap &)packetMap { - return _cppTaskRunner->Process(packetMap); +- (std::optional)processPacketMap:(const PacketMap &)packetMap error:(NSError **)error { + absl::StatusOr resultPacketMap = _cppTaskRunner->Process(packetMap); + if (![MPPCommonUtils checkCppError:resultPacketMap.status() toError:error]) { + return std::nullopt; + } + return resultPacketMap.value(); } -- (absl::Status)close { - return _cppTaskRunner->Close(); +- (BOOL)sendPacketMap:(const PacketMap &)packetMap error:(NSError **)error { + absl::Status sendStatus = _cppTaskRunner->Send(packetMap); + return [MPPCommonUtils checkCppError:sendStatus toError:error]; +} + +- (BOOL)closeWithError:(NSError **)error { + absl::Status closeStatus = _cppTaskRunner->Close(); + return [MPPCommonUtils checkCppError:closeStatus toError:error]; } @end diff --git a/mediapipe/tasks/ios/text/text_classifier/BUILD b/mediapipe/tasks/ios/text/text_classifier/BUILD index c56d51e5f..7913340ac 100644 --- a/mediapipe/tasks/ios/text/text_classifier/BUILD +++ b/mediapipe/tasks/ios/text/text_classifier/BUILD @@ -58,6 +58,5 @@ objc_library( "//mediapipe/tasks/ios/text/core:MPPTextTaskRunner", "//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierOptionsHelpers", "//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierResultHelpers", - "@com_google_absl//absl/status:statusor", ], ) diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm index 52e4d92ac..f0e1e4152 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm @@ -22,7 +22,6 @@ #import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h" #import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h" -#include "absl/status/statusor.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" namespace { @@ -83,15 +82,16 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T Packet packet = [MPPTextPacketCreator createWithText:text]; std::map packetMap = {{kTextInStreamName.cppString, packet}}; - absl::StatusOr statusOrOutputPacketMap = [_textTaskRunner process:packetMap]; + std::optional outputPacketMap = [_textTaskRunner processPacketMap:packetMap + error:error]; - if (![MPPCommonUtils checkCppError:statusOrOutputPacketMap.status() toError:error]) { + if (!outputPacketMap.has_value()) { return nil; } - return [MPPTextClassifierResult - textClassifierResultWithClassificationsPacket:statusOrOutputPacketMap.value() - [kClassificationsStreamName.cppString]]; + return + [MPPTextClassifierResult textClassifierResultWithClassificationsPacket: + outputPacketMap.value()[kClassificationsStreamName.cppString]]; } @end diff --git a/mediapipe/tasks/ios/text/text_embedder/BUILD b/mediapipe/tasks/ios/text/text_embedder/BUILD index 74aefdf77..a600d5366 100644 --- a/mediapipe/tasks/ios/text/text_embedder/BUILD +++ b/mediapipe/tasks/ios/text/text_embedder/BUILD @@ -58,6 +58,5 @@ objc_library( "//mediapipe/tasks/ios/text/core:MPPTextTaskRunner", "//mediapipe/tasks/ios/text/text_embedder/utils:MPPTextEmbedderOptionsHelpers", "//mediapipe/tasks/ios/text/text_embedder/utils:MPPTextEmbedderResultHelpers", - "@com_google_absl//absl/status:statusor", ], ) diff --git a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.mm b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.mm index 62eb882d3..e0f0d549d 100644 --- a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.mm +++ b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.mm @@ -23,8 +23,6 @@ #import "mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderOptions+Helpers.h" #import "mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderResult+Helpers.h" -#include "absl/status/statusor.h" - namespace { using ::mediapipe::Packet; using ::mediapipe::tasks::core::PacketMap; @@ -83,14 +81,15 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_embedder.Tex Packet packet = [MPPTextPacketCreator createWithText:text]; std::map packetMap = {{kTextInStreamName.cppString, packet}}; - absl::StatusOr statusOrOutputPacketMap = [_textTaskRunner process:packetMap]; - if (![MPPCommonUtils checkCppError:statusOrOutputPacketMap.status() toError:error]) { + std::optional outputPacketMap = [_textTaskRunner processPacketMap:packetMap + error:error]; + + if (!outputPacketMap.has_value()) { return nil; } - return [MPPTextEmbedderResult - textEmbedderResultWithOutputPacket:statusOrOutputPacketMap + textEmbedderResultWithOutputPacket:outputPacketMap .value()[kEmbeddingsOutStreamName.cppString]]; } diff --git a/mediapipe/tasks/ios/vision/core/BUILD b/mediapipe/tasks/ios/vision/core/BUILD index 1961ca6b0..a8164d674 100644 --- a/mediapipe/tasks/ios/vision/core/BUILD +++ b/mediapipe/tasks/ios/vision/core/BUILD @@ -26,6 +26,24 @@ objc_library( module_name = "MPPRunningMode", ) +objc_library( + name = "MPPVisionPacketCreator", + srcs = ["sources/MPPVisionPacketCreator.mm"], + hdrs = ["sources/MPPVisionPacketCreator.h"], + copts = [ + "-ObjC++", + "-std=c++17", + ], + deps = [ + ":MPPImage", + "//mediapipe/framework:packet", + "//mediapipe/framework:timestamp", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/ios/vision/core/utils:MPPImageUtils", + ], +) + objc_library( name = "MPPVisionTaskRunner", srcs = ["sources/MPPVisionTaskRunner.mm"], @@ -36,8 +54,11 @@ objc_library( ], deps = [ ":MPPRunningMode", + "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/tasks/ios/common:MPPCommon", "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", "//mediapipe/tasks/ios/core:MPPTaskRunner", + "//third_party/apple_frameworks:UIKit", + "@com_google_absl//absl/status:statusor", ], ) diff --git a/mediapipe/tasks/ios/vision/core/sources/MPPRunningMode.h b/mediapipe/tasks/ios/vision/core/sources/MPPRunningMode.h index 5cc57b88a..ab76546df 100644 --- a/mediapipe/tasks/ios/vision/core/sources/MPPRunningMode.h +++ b/mediapipe/tasks/ios/vision/core/sources/MPPRunningMode.h @@ -38,4 +38,17 @@ typedef NS_ENUM(NSUInteger, MPPRunningMode) { } NS_SWIFT_NAME(RunningMode); +NS_INLINE NSString *MPPRunningModeDisplayName(MPPRunningMode runningMode) { + if (runningMode > MPPRunningModeLiveStream) { + return nil; + } + + NSString *displayNameMap[MPPRunningModeLiveStream + 1] = { + [MPPRunningModeImage] = @"#MPPRunningModeImage", + [MPPRunningModeVideo] = @ "#MPPRunningModeVideo", + [MPPRunningModeLiveStream] = @ "#MPPRunningModeLiveStream"}; + + return displayNameMap[runningMode]; +} + NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.h b/mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.h index cf597ec24..eaf059ad2 100644 --- a/mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.h +++ b/mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.h @@ -14,14 +14,63 @@ #import -#include "mediapipe/framework/packet.h" #import "mediapipe/tasks/ios/vision/core/sources/MPPImage.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/packet.h" + /** * This class helps create various kinds of packets for Mediapipe Vision Tasks. */ @interface MPPVisionPacketCreator : NSObject +/** + * Creates a MediapPipe Packet wrapping an `MPPImage` that can be send to a graph. + * + * @param image The image to send to the MediaPipe graph. + * @param error Pointer to the memory location where errors if any should be saved. If @c NULL, no + * error will be saved. + * + * @return The MediaPipe packet containing the image. An empty packet is returned if an error + * occurred during the conversion. + */ + (mediapipe::Packet)createPacketWithMPPImage:(MPPImage *)image error:(NSError **)error; +/** + * Creates a MediapPipe Packet wrapping an `MPPImage` that can be send to a graph at the specified + * timestamp. + * + * @param image The image to send to the MediaPipe graph. + * @param timestampMs The timestamp (in milliseconds) to assign to the packet. + * @param error Pointer to the memory location where errors if any should be saved. If @c NULL, no + * error will be saved. + * + * @return The MediaPipe packet containing the image. An empty packet is returned if an error + * occurred during the conversion. + */ ++ (mediapipe::Packet)createPacketWithMPPImage:(MPPImage *)image + timestampMs:(NSInteger)timestampMs + error:(NSError **)error; + +/** + * Creates a MediapPipe Packet wrapping a `NormalizedRect` that can be send to a graph. + * + * @param image The `NormalizedRect` to send to the MediaPipe graph. + * + * @return The MediaPipe packet containing the normalized rect. + */ ++ (mediapipe::Packet)createPacketWithNormalizedRect:(mediapipe::NormalizedRect &)normalizedRect; + +/** + * Creates a MediapPipe Packet wrapping a `NormalizedRect` that can be send to a graph at the + * specified timestamp. + * + * @param image The `NormalizedRect` to send to the MediaPipe graph. + * @param timestampMs The timestamp (in milliseconds) to assign to the packet. + * + * @return The MediaPipe packet containing the normalized rect. + */ ++ (mediapipe::Packet)createPacketWithNormalizedRect:(mediapipe::NormalizedRect &)normalizedRect + timestampMs:(NSInteger)timestampMs; + @end diff --git a/mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.mm b/mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.mm index 01e583e62..bf136a759 100644 --- a/mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.mm +++ b/mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.mm @@ -16,18 +16,19 @@ #import "mediapipe/tasks/ios/vision/core/utils/sources/MPPImage+Utils.h" #include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/timestamp.h" + +static const NSUInteger kMicroSecondsPerMilliSecond = 1000; namespace { using ::mediapipe::Image; using ::mediapipe::ImageFrame; using ::mediapipe::MakePacket; +using ::mediapipe::NormalizedRect; using ::mediapipe::Packet; +using ::mediapipe::Timestamp; } // namespace -struct freeDeleter { - void operator()(void *ptr) { free(ptr); } -}; - @implementation MPPVisionPacketCreator + (Packet)createPacketWithMPPImage:(MPPImage *)image error:(NSError **)error { @@ -40,4 +41,27 @@ struct freeDeleter { return MakePacket(std::move(imageFrame)); } ++ (Packet)createPacketWithMPPImage:(MPPImage *)image + timestampMs:(NSInteger)timestampMs + error:(NSError **)error { + std::unique_ptr imageFrame = [image imageFrameWithError:error]; + + if (!imageFrame) { + return Packet(); + } + + return MakePacket(std::move(imageFrame)) + .At(Timestamp(int64(timestampMs * kMicroSecondsPerMilliSecond))); +} + ++ (Packet)createPacketWithNormalizedRect:(NormalizedRect &)normalizedRect { + return MakePacket(std::move(normalizedRect)); +} + ++ (Packet)createPacketWithNormalizedRect:(NormalizedRect &)normalizedRect + timestampMs:(NSInteger)timestampMs { + return MakePacket(std::move(normalizedRect)) + .At(Timestamp(int64(timestampMs * kMicroSecondsPerMilliSecond))); +} + @end diff --git a/mediapipe/tasks/ios/vision/core/sources/MPPVisionTaskRunner.h b/mediapipe/tasks/ios/vision/core/sources/MPPVisionTaskRunner.h index 84b657305..f19e4ca75 100644 --- a/mediapipe/tasks/ios/vision/core/sources/MPPVisionTaskRunner.h +++ b/mediapipe/tasks/ios/vision/core/sources/MPPVisionTaskRunner.h @@ -13,10 +13,13 @@ // limitations under the License. #import +#import #import "mediapipe/tasks/ios/core/sources/MPPTaskRunner.h" #import "mediapipe/tasks/ios/vision/core/sources/MPPRunningMode.h" +#include "mediapipe/framework/formats/rect.pb.h" + NS_ASSUME_NONNULL_BEGIN /** @@ -54,6 +57,82 @@ NS_ASSUME_NONNULL_BEGIN (mediapipe::tasks::core::PacketsCallback)packetsCallback error:(NSError **)error NS_DESIGNATED_INITIALIZER; +/** + * Creates a `NormalizedRect` from a region of interest and an image orientation, performing + * sanity checks on-the-fly. + * If the input region of interest equals `CGRectZero`, returns a default `NormalizedRect` covering + * the whole image with rotation set according `imageOrientation`. If `ROIAllowed` is NO, an error + * will be returned if the input region of interest is not equal to `CGRectZero`. Mirrored + * orientations (`UIImageOrientationUpMirrored`,`UIImageOrientationDownMirrored`, + * `UIImageOrientationLeftMirrored`,`UIImageOrientationRightMirrored`) are not supported. An error + * will be returned if `imageOrientation` is equal to any one of them. + * + * @param roi A `CGRect` specifying the region of interest. If the input region of interest equals + * `CGRectZero`, the returned `NormalizedRect` covers the whole image. Make sure that `roi` equals + * `CGRectZero` if `ROIAllowed` is NO. Otherwise, an error will be returned. + * @param imageOrientation A `UIImageOrientation` indicating the rotation to be applied to the + * image. The resulting `NormalizedRect` will convert the `imageOrientation` to degrees clockwise. + * Mirrored orientations (`UIImageOrientationUpMirrored`, `UIImageOrientationDownMirrored`, + * `UIImageOrientationLeftMirrored`, `UIImageOrientationRightMirrored`) are not supported. An error + * will be returned if `imageOrientation` is equal to any one of them. + * @param ROIAllowed Indicates if the `roi` field is allowed to be a value other than `CGRectZero`. + * @param error Pointer to the memory location where errors if any should be saved. If @c NULL, no + * error will be saved. + * + * @return An optional `NormalizedRect` from the given region of interest and image orientation. + */ +- (std::optional) + normalizedRectFromRegionOfInterest:(CGRect)roi + imageOrientation:(UIImageOrientation)imageOrientation + ROIAllowed:(BOOL)ROIAllowed + error:(NSError **)error; + +/** + * A synchronous method to invoke the C++ task runner to process single image inputs. The call + * blocks the current thread until a failure status or a successful result is returned. + * + * @param packetMap A `PackeMap` containing pairs of input stream name and data packet. + * @param error Pointer to the memory location where errors if any should be + * saved. If @c NULL, no error will be saved. + * + * @return An optional `PacketMap` containing pairs of output stream name and data packet. + */ +- (std::optional) + processImagePacketMap:(const mediapipe::tasks::core::PacketMap &)packetMap + error:(NSError **)error; + +/** + * A synchronous method to invoke the C++ task runner to process continuous video frames. The call + * blocks the current thread until a failure status or a successful result is returned. + * + * @param packetMap A `PackeMap` containing pairs of input stream name and data packet. + * @param error Pointer to the memory location where errors if any should be saved. If @c NULL, no + * error will be saved. + * + * @return An optional `PacketMap` containing pairs of output stream name and data packet. + */ +- (std::optional) + processVideoFramePacketMap:(const mediapipe::tasks::core::PacketMap &)packetMap + error:(NSError **)error; + +/** + * An asynchronous method to send live stream data to the C++ task runner. The call blocks the + * current thread until a failure status or a successful result is returned. The results will be + * available in the user-defined `packetsCallback` that was provided during initialization of the + * `MPPVisionTaskRunner`. + * + * @param packetMap A `PackeMap` containing pairs of input stream name and data packet. + * @param error Pointer to the memory location where errors if any should be saved. If @c NULL, no + * error will be saved. + * + * @return A `BOOL` indicating if the live stream data was sent to the C++ task runner successfully. + * Please note that any errors during processing of the live stream packet map will only be + * available in the user-defined `packetsCallback` that was provided during initialization of the + * `MPPVisionTaskRunner`. + */ +- (BOOL)processLiveStreamPacketMap:(const mediapipe::tasks::core::PacketMap &)packetMap + error:(NSError **)error; + - (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig packetsCallback: (mediapipe::tasks::core::PacketsCallback)packetsCallback diff --git a/mediapipe/tasks/ios/vision/core/sources/MPPVisionTaskRunner.mm b/mediapipe/tasks/ios/vision/core/sources/MPPVisionTaskRunner.mm index bfa9e34e5..492d29a8b 100644 --- a/mediapipe/tasks/ios/vision/core/sources/MPPVisionTaskRunner.mm +++ b/mediapipe/tasks/ios/vision/core/sources/MPPVisionTaskRunner.mm @@ -17,11 +17,26 @@ #import "mediapipe/tasks/ios/common/sources/MPPCommon.h" #import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" +#include "absl/status/statusor.h" + +#include + namespace { using ::mediapipe::CalculatorGraphConfig; +using ::mediapipe::NormalizedRect; +using ::mediapipe::tasks::core::PacketMap; using ::mediapipe::tasks::core::PacketsCallback; } // namespace +/** Rotation degress for a 90 degree rotation to the right. */ +static const NSInteger kMPPOrientationDegreesRight = -90; + +/** Rotation degress for a 180 degree rotation. */ +static const NSInteger kMPPOrientationDegreesDown = -180; + +/** Rotation degress for a 90 degree rotation to the left. */ +static const NSInteger kMPPOrientationDegreesLeft = -270; + @interface MPPVisionTaskRunner () { MPPRunningMode _runningMode; } @@ -70,4 +85,100 @@ using ::mediapipe::tasks::core::PacketsCallback; return self; } +- (std::optional)normalizedRectFromRegionOfInterest:(CGRect)roi + imageOrientation: + (UIImageOrientation)imageOrientation + ROIAllowed:(BOOL)ROIAllowed + error:(NSError **)error { + if (CGRectEqualToRect(roi, CGRectZero) && !ROIAllowed) { + [MPPCommonUtils createCustomError:error + withCode:MPPTasksErrorCodeInvalidArgumentError + description:@"This task doesn't support region-of-interest."]; + return std::nullopt; + } + + CGRect calculatedRoi = CGRectEqualToRect(roi, CGRectZero) ? roi : CGRectMake(0.0, 0.0, 1.0, 1.0); + + NormalizedRect normalizedRect; + normalizedRect.set_x_center(CGRectGetMidX(calculatedRoi)); + normalizedRect.set_y_center(CGRectGetMidY(calculatedRoi)); + normalizedRect.set_width(CGRectGetWidth(calculatedRoi)); + normalizedRect.set_height(CGRectGetHeight(calculatedRoi)); + + int rotationDegrees = 0; + switch (imageOrientation) { + case UIImageOrientationUp: + break; + case UIImageOrientationRight: { + rotationDegrees = kMPPOrientationDegreesRight; + break; + } + case UIImageOrientationDown: { + rotationDegrees = kMPPOrientationDegreesDown; + break; + } + case UIImageOrientationLeft: { + rotationDegrees = kMPPOrientationDegreesLeft; + break; + } + default: + [MPPCommonUtils + createCustomError:error + withCode:MPPTasksErrorCodeInvalidArgumentError + description: + @"Unsupported UIImageOrientation. `imageOrientation` cannot be equal to " + @"any of the mirrored orientations " + @"(`UIImageOrientationUpMirrored`,`UIImageOrientationDownMirrored`,`" + @"UIImageOrientationLeftMirrored`,`UIImageOrientationRightMirrored`)"]; + } + + normalizedRect.set_rotation(rotationDegrees * M_PI / kMPPOrientationDegreesDown); + + return normalizedRect; +} + +- (std::optional)processImagePacketMap:(const PacketMap &)packetMap + error:(NSError **)error { + if (_runningMode != MPPRunningModeImage) { + [MPPCommonUtils + createCustomError:error + withCode:MPPTasksErrorCodeInvalidArgumentError + description:[NSString stringWithFormat:@"The vision task is not initialized with " + @"image mode. Current Running Mode: %@", + MPPRunningModeDisplayName(_runningMode)]]; + return std::nullopt; + } + + return [self processPacketMap:packetMap error:error]; +} + +- (std::optional)processVideoFramePacketMap:(const PacketMap &)packetMap + error:(NSError **)error { + if (_runningMode != MPPRunningModeVideo) { + [MPPCommonUtils + createCustomError:error + withCode:MPPTasksErrorCodeInvalidArgumentError + description:[NSString stringWithFormat:@"The vision task is not initialized with " + @"video mode. Current Running Mode: %@", + MPPRunningModeDisplayName(_runningMode)]]; + return std::nullopt; + } + + return [self processPacketMap:packetMap error:error]; +} + +- (BOOL)processLiveStreamPacketMap:(const PacketMap &)packetMap error:(NSError **)error { + if (_runningMode != MPPRunningModeLiveStream) { + [MPPCommonUtils + createCustomError:error + withCode:MPPTasksErrorCodeInvalidArgumentError + description:[NSString stringWithFormat:@"The vision task is not initialized with " + @"live stream mode. Current Running Mode: %@", + MPPRunningModeDisplayName(_runningMode)]]; + return NO; + } + + return [self sendPacketMap:packetMap error:error]; +} + @end diff --git a/mediapipe/tasks/ios/vision/image_classifier/BUILD b/mediapipe/tasks/ios/vision/image_classifier/BUILD index 45e6e2156..4ebcd2b29 100644 --- a/mediapipe/tasks/ios/vision/image_classifier/BUILD +++ b/mediapipe/tasks/ios/vision/image_classifier/BUILD @@ -36,3 +36,30 @@ objc_library( "//mediapipe/tasks/ios/vision/core:MPPRunningMode", ], ) + +objc_library( + name = "MPPImageClassifier", + srcs = ["sources/MPPImageClassifier.mm"], + hdrs = ["sources/MPPImageClassifier.h"], + copts = [ + "-ObjC++", + "-std=c++17", + "-x objective-c++", + ], + module_name = "MPPImageClassifier", + deps = [ + ":MPPImageClassifierOptions", + ":MPPImageClassifierResult", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph", + "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + "//mediapipe/tasks/ios/core:MPPTaskInfo", + "//mediapipe/tasks/ios/core:MPPTaskOptions", + "//mediapipe/tasks/ios/vision/core:MPPImage", + "//mediapipe/tasks/ios/vision/core:MPPVisionPacketCreator", + "//mediapipe/tasks/ios/vision/core:MPPVisionTaskRunner", + "//mediapipe/tasks/ios/vision/image_classifier/utils:MPPImageClassifierOptionsHelpers", + "//mediapipe/tasks/ios/vision/image_classifier/utils:MPPImageClassifierResultHelpers", + ], +) diff --git a/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.h b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.h new file mode 100644 index 000000000..581c8d95b --- /dev/null +++ b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.h @@ -0,0 +1,219 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" +#import "mediapipe/tasks/ios/vision/core/sources/MPPImage.h" +#import "mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierOptions.h" +#import "mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierResult.h" + +NS_ASSUME_NONNULL_BEGIN + +/** + * @brief Performs classification on images. + * + * The API expects a TFLite model with optional, but strongly recommended, + * [TFLite Model Metadata.](https://www.tensorflow.org/lite/convert/metadata"). + * + * The API supports models with one image input tensor and one or more output tensors. To be more + * specific, here are the requirements. + * + * Input tensor + * (kTfLiteUInt8/kTfLiteFloat32) + * - image input of size `[batch x height x width x channels]`. + * - batch inference is not supported (`batch` is required to be 1). + * - only RGB inputs are supported (`channels` is required to be 3). + * - if type is kTfLiteFloat32, NormalizationOptions are required to be attached to the metadata + * for input normalization. + * + * At least one output tensor with: + * (kTfLiteUInt8/kTfLiteFloat32) + * - `N `classes and either 2 or 4 dimensions, i.e. `[1 x N]` or `[1 x 1 x 1 x N]` + * - optional (but recommended) label map(s) as AssociatedFiles with type TENSOR_AXIS_LABELS, + * containing one label per line. The first such AssociatedFile (if any) is used to fill the + * `class_name` field of the results. The `display_name` field is filled from the AssociatedFile + * (if any) whose locale matches the `display_names_locale` field of the `ImageClassifierOptions` + * used at creation time ("en" by default, i.e. English). If none of these are available, only + * the `index` field of the results will be filled. + * - optional score calibration can be attached using ScoreCalibrationOptions and an AssociatedFile + * with type TENSOR_AXIS_SCORE_CALIBRATION. See metadata_schema.fbs [1] for more details. + */ +NS_SWIFT_NAME(ImageClassifier) +@interface MPPImageClassifier : NSObject + +/** + * Creates a new instance of `MPPImageClassifier` from an absolute path to a TensorFlow Lite model + * file stored locally on the device and the default `MPPImageClassifierOptions`. + * + * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device. + * @param error An optional error parameter populated when there is an error in initializing the + * image classifier. + * + * @return A new instance of `MPPImageClassifier` with the given model path. `nil` if there is an + * error in initializing the image classifier. + */ +- (nullable instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error; + +/** + * Creates a new instance of `MPPImageClassifier` from the given `MPPImageClassifierOptions`. + * + * @param options The options of type `MPPImageClassifierOptions` to use for configuring the + * `MPPImageClassifier`. + * @param error An optional error parameter populated when there is an error in initializing the + * image classifier. + * + * @return A new instance of `MPPImageClassifier` with the given options. `nil` if there is an error + * in initializing the image classifier. + */ +- (nullable instancetype)initWithOptions:(MPPImageClassifierOptions *)options + error:(NSError **)error NS_DESIGNATED_INITIALIZER; + +/** + * Performs image classification on the provided MPPImage using the whole image as region of + * interest. Rotation will be applied according to the `orientation` property of the provided + * `MPPImage`. Only use this method when the `MPPImageClassifier` is created with + * `MPPRunningModeImage`. + * + * @param image The `MPPImage` on which image classification is to be performed. + * @param error An optional error parameter populated when there is an error in performing image + * classification on the input image. + * + * @return An `MPPImageClassifierResult` object that contains a list of image classifications. + */ +- (nullable MPPImageClassifierResult *)classifyImage:(MPPImage *)image + error:(NSError **)error + NS_SWIFT_NAME(classify(image:)); + +/** + * Performs image classification on the provided `MPPImage` cropped to the specified region of + * interest. Rotation will be applied on the cropped image according to the `orientation` property + * of the provided `MPPImage`. Only use this method when the `MPPImageClassifier` is created with + * `MPPRunningModeImage`. + * + * @param image The `MPPImage` on which image classification is to be performed. + * @param roi A `CGRect` specifying the region of interest within the given `MPPImage`, on which + * image classification should be performed. + * @param error An optional error parameter populated when there is an error in performing image + * classification on the input image. + * + * @return An `MPPImageClassifierResult` object that contains a list of image classifications. + */ +- (nullable MPPImageClassifierResult *)classifyImage:(MPPImage *)image + regionOfInterest:(CGRect)roi + error:(NSError **)error + NS_SWIFT_NAME(classify(image:regionOfInterest:)); + +/** + * Performs image classification on the provided video frame of type `MPPImage` using the whole + * image as region of interest. Rotation will be applied according to the `orientation` property of + * the provided `MPPImage`. Only use this method when the `MPPImageClassifier` is created with + * `MPPRunningModeVideo`. + * + * @param image The `MPPImage` on which image classification is to be performed. + * @param timestampMs The video frame's timestamp (in milliseconds). The input timestamps must be + * monotonically increasing. + * @param error An optional error parameter populated when there is an error in performing image + * classification on the input video frame. + * + * @return An `MPPImageClassifierResult` object that contains a list of image classifications. + */ +- (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image + timestampMs:(NSInteger)timestampMs + error:(NSError **)error + NS_SWIFT_NAME(classify(videoFrame:timestampMs:)); + +/** + * Performs image classification on the provided video frame of type `MPPImage` cropped to the + * specified region of interest. Rotation will be applied according to the `orientation` property of + * the provided `MPPImage`. Only use this method when the `MPPImageClassifier` is created with + * `MPPRunningModeVideo`. + * + * It's required to provide the video frame's timestamp (in milliseconds). The input timestamps must + * be monotonically increasing. + * + * @param image A live stream image data of type `MPPImage` on which image classification is to be + * performed. + * @param timestampMs The video frame's timestamp (in milliseconds). The input timestamps must be + * monotonically increasing. + * @param roi A `CGRect` specifying the region of interest within the video frame of type + * `MPPImage`, on which image classification should be performed. + * @param error An optional error parameter populated when there is an error in performing image + * classification on the input video frame. + * + * @return An `MPPImageClassifierResult` object that contains a list of image classifications. + */ +- (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image + timestampMs:(NSInteger)timestampMs + regionOfInterest:(CGRect)roi + error:(NSError **)error + NS_SWIFT_NAME(classify(videoFrame:timestampMs:regionOfInterest:)); + +/** + * Sends live stream image data of type `MPPImage` to perform image classification using the whole + * image as region of interest. Rotation will be applied according to the `orientation` property of + * the provided `MPPImage`. Only use this method when the `MPPImageClassifier` is created with + * `MPPRunningModeLiveStream`. Results are provided asynchronously via the `completion` callback + * provided in the `MPPImageClassifierOptions`. + * + * It's required to provide a timestamp (in milliseconds) to indicate when the input image is sent + * to the image classifier. The input timestamps must be monotonically increasing. + * + * @param image A live stream image data of type `MPPImage` on which image classification is to be + * performed. + * @param timestampMs The timestamp (in milliseconds) which indicates when the input image is sent + * to the image classifier. The input timestamps must be monotonically increasing. + * @param error An optional error parameter populated when there is an error in performing image + * classification on the input live stream image data. + * + * @return `YES` if the image was sent to the task successfully, otherwise `NO`. + */ +- (BOOL)classifyAsyncImage:(MPPImage *)image + timestampMs:(NSInteger)timestampMs + error:(NSError **)error NS_SWIFT_NAME(classifyAsync(image:timestampMs:)); + +/** + * Sends live stream image data of type `MPPImage` to perform image classification, cropped to the + * specified region of interest.. Rotation will be applied according to the `orientation` property + * of the provided `MPPImage`. Only use this method when the `MPPImageClassifier` is created with + * `MPPRunningModeLiveStream`. Results are provided asynchronously via the `completion` callback + * provided in the `MPPImageClassifierOptions`. + * + * It's required to provide a timestamp (in milliseconds) to indicate when the input image is sent + * to the image classifier. The input timestamps must be monotonically increasing. + * + * @param image A live stream image data of type `MPPImage` on which image classification is to be + * performed. + * @param timestampMs The timestamp (in milliseconds) which indicates when the input image is sent + * to the image classifier. The input timestamps must be monotonically increasing. + * @param roi A `CGRect` specifying the region of interest within the given live stream image data + * of type `MPPImage`, on which image classification should be performed. + * @param error An optional error parameter populated when there is an error in performing image + * classification on the input live stream image data. + * + * @return `YES` if the image was sent to the task successfully, otherwise `NO`. + */ +- (BOOL)classifyAsyncImage:(MPPImage *)image + timestampMs:(NSInteger)timestampMs + regionOfInterest:(CGRect)roi + error:(NSError **)error + NS_SWIFT_NAME(classifyAsync(image:timestampMs:regionOfInterest:)); + +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.mm b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.mm new file mode 100644 index 000000000..0ad79003f --- /dev/null +++ b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.mm @@ -0,0 +1,232 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.h" + +#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" +#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h" +#import "mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.h" +#import "mediapipe/tasks/ios/vision/core/sources/MPPVisionTaskRunner.h" +#import "mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierOptions+Helpers.h" +#import "mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierResult+Helpers.h" + +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" + +namespace { +using ::mediapipe::NormalizedRect; +using ::mediapipe::Packet; +using ::mediapipe::tasks::core::PacketMap; +using ::mediapipe::tasks::core::PacketsCallback; +} // namespace + +static NSString *const kClassificationsStreamName = @"classifications_out"; +static NSString *const kClassificationsTag = @"CLASSIFICATIONS"; +static NSString *const kImageInStreamName = @"image_in"; +static NSString *const kImageOutStreamName = @"image_out"; +static NSString *const kImageTag = @"IMAGE"; +static NSString *const kNormRectName = @"norm_rect_in"; +static NSString *const kNormRectTag = @"NORM_RECT"; + +static NSString *const kTaskGraphName = + @"mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"; + +#define InputPacketMap(imagePacket, normalizedRectPacket) \ + { \ + {kImageInStreamName.cppString, imagePacket}, { kNormRectName.cppString, normalizedRectPacket } \ + } + +@interface MPPImageClassifier () { + /** iOS Vision Task Runner */ + MPPVisionTaskRunner *_visionTaskRunner; +} +@end + +@implementation MPPImageClassifier + +- (instancetype)initWithOptions:(MPPImageClassifierOptions *)options error:(NSError **)error { + self = [super init]; + if (self) { + MPPTaskInfo *taskInfo = [[MPPTaskInfo alloc] + initWithTaskGraphName:kTaskGraphName + inputStreams:@[ [NSString + stringWithFormat:@"%@:%@", kImageTag, kImageInStreamName] ] + outputStreams:@[ [NSString stringWithFormat:@"%@:%@", kClassificationsTag, + kClassificationsStreamName] ] + taskOptions:options + enableFlowLimiting:NO + error:error]; + + if (!taskInfo) { + return nil; + } + + PacketsCallback packetsCallback = nullptr; + + if (options.completion) { + packetsCallback = [=](absl::StatusOr status_or_packets) { + NSError *callbackError = nil; + MPPImageClassifierResult *result; + if ([MPPCommonUtils checkCppError:status_or_packets.status() toError:&callbackError]) { + result = [MPPImageClassifierResult + imageClassifierResultWithClassificationsPacket: + status_or_packets.value()[kClassificationsStreamName.cppString]]; + } + options.completion(result, callbackError); + }; + } + + _visionTaskRunner = + [[MPPVisionTaskRunner alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig] + runningMode:options.runningMode + packetsCallback:std::move(packetsCallback) + error:error]; + + if (!_visionTaskRunner) { + return nil; + } + } + return self; +} + +- (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error { + MPPImageClassifierOptions *options = [[MPPImageClassifierOptions alloc] init]; + + options.baseOptions.modelAssetPath = modelPath; + + return [self initWithOptions:options error:error]; +} + +- (nullable MPPImageClassifierResult *)classifyImage:(MPPImage *)image + regionOfInterest:(CGRect)roi + error:(NSError **)error { + std::optional rect = + [_visionTaskRunner normalizedRectFromRegionOfInterest:roi + imageOrientation:image.orientation + ROIAllowed:YES + error:error]; + if (!rect.has_value()) { + return nil; + } + + Packet imagePacket = [MPPVisionPacketCreator createPacketWithMPPImage:image error:error]; + if (imagePacket.IsEmpty()) { + return nil; + } + + Packet normalizedRectPacket = + [MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value()]; + + PacketMap inputPacketMap = InputPacketMap(imagePacket, normalizedRectPacket); + + std::optional outputPacketMap = [_visionTaskRunner processPacketMap:inputPacketMap + error:error]; + if (!outputPacketMap.has_value()) { + return nil; + } + + return + [MPPImageClassifierResult imageClassifierResultWithClassificationsPacket: + outputPacketMap.value()[kClassificationsStreamName.cppString]]; +} + +- (std::optional)inputPacketMapWithMPPImage:(MPPImage *)image + timestampMs:(NSInteger)timestampMs + regionOfInterest:(CGRect)roi + error:(NSError **)error { + std::optional rect = + [_visionTaskRunner normalizedRectFromRegionOfInterest:roi + imageOrientation:image.orientation + ROIAllowed:YES + error:error]; + if (!rect.has_value()) { + return std::nullopt; + } + + Packet imagePacket = [MPPVisionPacketCreator createPacketWithMPPImage:image + timestampMs:timestampMs + error:error]; + if (imagePacket.IsEmpty()) { + return std::nullopt; + } + + Packet normalizedRectPacket = [MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value() + timestampMs:timestampMs]; + + PacketMap inputPacketMap = InputPacketMap(imagePacket, normalizedRectPacket); + return inputPacketMap; +} + +- (nullable MPPImageClassifierResult *)classifyImage:(MPPImage *)image error:(NSError **)error { + return [self classifyImage:image regionOfInterest:CGRectZero error:error]; +} + +- (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image + timestampMs:(NSInteger)timestampMs + regionOfInterest:(CGRect)roi + error:(NSError **)error { + std::optional inputPacketMap = [self inputPacketMapWithMPPImage:image + timestampMs:timestampMs + regionOfInterest:roi + error:error]; + if (!inputPacketMap.has_value()) { + return nil; + } + + std::optional outputPacketMap = + [_visionTaskRunner processVideoFramePacketMap:inputPacketMap.value() error:error]; + + if (!outputPacketMap.has_value()) { + return nil; + } + + return + [MPPImageClassifierResult imageClassifierResultWithClassificationsPacket: + outputPacketMap.value()[kClassificationsStreamName.cppString]]; +} + +- (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image + timestampMs:(NSInteger)timestampMs + error:(NSError **)error { + return [self classifyVideoFrame:image + timestampMs:timestampMs + regionOfInterest:CGRectZero + error:error]; +} + +- (BOOL)classifyAsyncImage:(MPPImage *)image + timestampMs:(NSInteger)timestampMs + regionOfInterest:(CGRect)roi + error:(NSError **)error { + std::optional inputPacketMap = [self inputPacketMapWithMPPImage:image + timestampMs:timestampMs + regionOfInterest:roi + error:error]; + if (!inputPacketMap.has_value()) { + return NO; + } + + return [_visionTaskRunner processLiveStreamPacketMap:inputPacketMap.value() error:error]; +} + +- (BOOL)classifyAsyncImage:(MPPImage *)image + timestampMs:(NSInteger)timestampMs + error:(NSError **)error { + return [self classifyAsyncImage:image + timestampMs:timestampMs + regionOfInterest:CGRectZero + error:error]; +} + +@end diff --git a/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierOptions.h b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierOptions.h index f7e9a6297..2e6022041 100644 --- a/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierOptions.h +++ b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierOptions.h @@ -31,6 +31,7 @@ NS_SWIFT_NAME(ImageClassifierOptions) /** * The user-defined result callback for processing live stream data. The result callback should only * be specified when the running mode is set to the live stream mode. + * TODO: Add parameter `MPPImage` in the callback. */ @property(nonatomic, copy) void (^completion)(MPPImageClassifierResult *result, NSError *error); diff --git a/mediapipe/tasks/ios/vision/image_classifier/utils/BUILD b/mediapipe/tasks/ios/vision/image_classifier/utils/BUILD new file mode 100644 index 000000000..c1928b6ff --- /dev/null +++ b/mediapipe/tasks/ios/vision/image_classifier/utils/BUILD @@ -0,0 +1,44 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPImageClassifierOptionsHelpers", + srcs = ["sources/MPPImageClassifierOptions+Helpers.mm"], + hdrs = ["sources/MPPImageClassifierOptions+Helpers.h"], + deps = [ + "//mediapipe/framework:calculator_options_cc_proto", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + "//mediapipe/tasks/ios/core:MPPTaskOptionsProtocol", + "//mediapipe/tasks/ios/core/utils:MPPBaseOptionsHelpers", + "//mediapipe/tasks/ios/vision/image_classifier:MPPImageClassifierOptions", + ], +) + +objc_library( + name = "MPPImageClassifierResultHelpers", + srcs = ["sources/MPPImageClassifierResult+Helpers.mm"], + hdrs = ["sources/MPPImageClassifierResult+Helpers.h"], + deps = [ + "//mediapipe/framework:packet", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/ios/components/containers/utils:MPPClassificationResultHelpers", + "//mediapipe/tasks/ios/vision/image_classifier:MPPImageClassifierResult", + ], +) diff --git a/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierOptions+Helpers.h b/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierOptions+Helpers.h new file mode 100644 index 000000000..c3a3b2fec --- /dev/null +++ b/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierOptions+Helpers.h @@ -0,0 +1,32 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/calculator_options.pb.h" +#import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h" +#import "mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierOptions.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPImageClassifierOptions (Helpers) + +/** + * Populates the provided `CalculatorOptions` proto container with the current settings. + * + * @param optionsProto The `CalculatorOptions` proto object to copy the settings to. + */ +- (void)copyToProto:(::mediapipe::CalculatorOptions *)optionsProto; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierOptions+Helpers.mm b/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierOptions+Helpers.mm new file mode 100644 index 000000000..36ecf9093 --- /dev/null +++ b/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierOptions+Helpers.mm @@ -0,0 +1,56 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierOptions+Helpers.h" + +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" +#import "mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h" + +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.pb.h" + +namespace { +using CalculatorOptionsProto = ::mediapipe::CalculatorOptions; +using ImageClassifierGraphOptionsProto = + ::mediapipe::tasks::vision::image_classifier::proto::ImageClassifierGraphOptions; +using ClassifierOptionsProto = ::mediapipe::tasks::components::processors::proto::ClassifierOptions; +} // namespace + +@implementation MPPImageClassifierOptions (Helpers) + +- (void)copyToProto:(CalculatorOptionsProto *)optionsProto { + ImageClassifierGraphOptionsProto *graphOptions = + optionsProto->MutableExtension(ImageClassifierGraphOptionsProto::ext); + [self.baseOptions copyToProto:graphOptions->mutable_base_options()]; + + ClassifierOptionsProto *classifierOptionsProto = graphOptions->mutable_classifier_options(); + classifierOptionsProto->Clear(); + + if (self.displayNamesLocale) { + classifierOptionsProto->set_display_names_locale(self.displayNamesLocale.cppString); + } + + classifierOptionsProto->set_max_results((int)self.maxResults); + classifierOptionsProto->set_score_threshold(self.scoreThreshold); + + for (NSString *category in self.categoryAllowlist) { + classifierOptionsProto->add_category_allowlist(category.cppString); + } + + for (NSString *category in self.categoryDenylist) { + classifierOptionsProto->add_category_denylist(category.cppString); + } +} + +@end diff --git a/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierResult+Helpers.h b/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierResult+Helpers.h new file mode 100644 index 000000000..0375ac2a5 --- /dev/null +++ b/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierResult+Helpers.h @@ -0,0 +1,36 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierResult.h" + +#include "mediapipe/framework/packet.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPImageClassifierResult (Helpers) + +/** + * Creates an `MPPImageClassifierResult` from a MediaPipe packet containing an + * `ClassificationResultProto`. + * + * @param packet a MediaPipe packet wrapping a ClassificationResultProto. + * + * @return An `MPPImageClassifierResult` object that contains a list of image classifications. + */ ++ (MPPImageClassifierResult *)imageClassifierResultWithClassificationsPacket: + (const mediapipe::Packet &)packet; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierResult+Helpers.mm b/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierResult+Helpers.mm new file mode 100644 index 000000000..09e21b278 --- /dev/null +++ b/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierResult+Helpers.mm @@ -0,0 +1,41 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h" +#import "mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierResult+Helpers.h" + +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" + +static const int kMicroSecondsPerMilliSecond = 1000; + +namespace { +using ClassificationResultProto = + ::mediapipe::tasks::components::containers::proto::ClassificationResult; +using ::mediapipe::Packet; +} // namespace + +@implementation MPPImageClassifierResult (Helpers) + ++ (MPPImageClassifierResult *)imageClassifierResultWithClassificationsPacket: + (const Packet &)packet { + MPPClassificationResult *classificationResult = [MPPClassificationResult + classificationResultWithProto:packet.Get()]; + + return [[MPPImageClassifierResult alloc] + initWithClassificationResult:classificationResult + timestampMs:(NSInteger)(packet.Timestamp().Value() / + kMicroSecondsPerMilliSecond)]; +} + +@end