Merge pull request #4141 from priankakariatyml:ios-image-classifier-impl-files

PiperOrigin-RevId: 518138209
This commit is contained in:
Copybara-Service 2023-03-20 19:05:37 -07:00
commit 6e0542c16a
21 changed files with 1062 additions and 35 deletions

View File

@ -17,6 +17,8 @@
#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h" #include "mediapipe/tasks/cc/core/task_runner.h"
#include <optional>
NS_ASSUME_NONNULL_BEGIN NS_ASSUME_NONNULL_BEGIN
/** /**
@ -62,24 +64,57 @@ NS_ASSUME_NONNULL_BEGIN
error:(NSError **)error NS_DESIGNATED_INITIALIZER; error:(NSError **)error NS_DESIGNATED_INITIALIZER;
/** /**
* A synchronous method for processing batch data or offline streaming data. This method is designed * A synchronous method for invoking the C++ task runner for processing batch data or offline
* for processing either batch data such as unrelated images and texts or offline streaming data * streaming data. This method is designed for processing either batch data such as unrelated images
* such as the decoded frames from a video file or audio file. The call blocks the current * and texts or offline streaming data such as the decoded frames from a video file or audio file.
* thread until a failure status or a successful result is returned. If the input packets have no * The call blocks the current thread until a failure status or a successful result is returned. If
* timestamp, an internal timestamp will be assigned per invocation. Otherwise, when the timestamp * the input packets have no timestamp, an internal timestamp will be assigned per invocation.
* is set in the input packets, the caller must ensure that the input packet timestamps are greater * Otherwise, when the timestamp is set in the input packets, the caller must ensure that the input
* than the timestamps of the previous invocation. This method is thread-unsafe and it is the * packet timestamps are greater than the timestamps of the previous invocation. This method is
* caller's responsibility to synchronize access to this method across multiple threads and to * thread-unsafe and it is the caller's responsibility to synchronize access to this method across
* ensure that the input packet timestamps are in order. * multiple threads and to ensure that the input packet timestamps are in order.
*
* @param packetMap A `PacketMap` containing pairs of input stream name and data packet which are to
* be sent to the C++ task runner for processing synchronously.
* @param error Pointer to the memory location where errors if any should be saved. If @c NULL, no
* error will be saved.
*
* @return An optional output `PacketMap` containing pairs of output stream name and data packet
* which holds the results of processing the input packet map, if there are no errors.
*/ */
- (absl::StatusOr<mediapipe::tasks::core::PacketMap>)process: - (std::optional<mediapipe::tasks::core::PacketMap>)
(const mediapipe::tasks::core::PacketMap &)packetMap; processPacketMap:(const mediapipe::tasks::core::PacketMap &)packetMap
error:(NSError **)error;
/**
* An asynchronous method that is designed for handling live streaming data such as live camera. A
* user-defined PacketsCallback function must be provided in the constructor to receive the output
* packets. The caller must ensure that the input packet timestamps are monotonically increasing.
* This method is thread-unsafe and it is the caller's responsibility to synchronize access to this
* method across multiple threads and to ensure that the input packet timestamps are in order.
*
* @param packetMap A `PacketMap` containing pairs of input stream name and data packet that are to
* be sent to the C++ task runner for processing asynchronously.
* @param error Pointer to the memory location where errors if any should be saved. If @c NULL, no
* error will be saved.
*
* @return A `BOOL` indicating if the live stream data was sent to the C++ task runner successfully.
* Please note that any errors during processing of the live stream packet map will only be
* available in the user-defined `packetsCallback` that was provided during initialization of the
* `MPPVisionTaskRunner`.
*/
- (BOOL)sendPacketMap:(const mediapipe::tasks::core::PacketMap &)packetMap error:(NSError **)error;
/** /**
* Shuts down the C++ task runner. After the runner is closed, any calls that send input data to the * Shuts down the C++ task runner. After the runner is closed, any calls that send input data to the
* runner are illegal and will receive errors. * runner are illegal and will receive errors.
*
* @param error Pointer to the memory location where errors if any should be saved. If @c NULL, no
* error will be saved.
*
* @return A `BOOL` indicating if the C++ task runner was shutdown successfully.
*/ */
- (absl::Status)close; - (BOOL)closeWithError:(NSError **)error;
- (instancetype)init NS_UNAVAILABLE; - (instancetype)init NS_UNAVAILABLE;

View File

@ -50,12 +50,22 @@ using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner;
return self; return self;
} }
- (absl::StatusOr<PacketMap>)process:(const PacketMap &)packetMap { - (std::optional<PacketMap>)processPacketMap:(const PacketMap &)packetMap error:(NSError **)error {
return _cppTaskRunner->Process(packetMap); absl::StatusOr<PacketMap> resultPacketMap = _cppTaskRunner->Process(packetMap);
if (![MPPCommonUtils checkCppError:resultPacketMap.status() toError:error]) {
return std::nullopt;
}
return resultPacketMap.value();
} }
- (absl::Status)close { - (BOOL)sendPacketMap:(const PacketMap &)packetMap error:(NSError **)error {
return _cppTaskRunner->Close(); absl::Status sendStatus = _cppTaskRunner->Send(packetMap);
return [MPPCommonUtils checkCppError:sendStatus toError:error];
}
- (BOOL)closeWithError:(NSError **)error {
absl::Status closeStatus = _cppTaskRunner->Close();
return [MPPCommonUtils checkCppError:closeStatus toError:error];
} }
@end @end

View File

@ -58,6 +58,5 @@ objc_library(
"//mediapipe/tasks/ios/text/core:MPPTextTaskRunner", "//mediapipe/tasks/ios/text/core:MPPTextTaskRunner",
"//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierOptionsHelpers", "//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierOptionsHelpers",
"//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierResultHelpers", "//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierResultHelpers",
"@com_google_absl//absl/status:statusor",
], ],
) )

View File

@ -22,7 +22,6 @@
#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h" #import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h"
#import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h" #import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h"
#include "absl/status/statusor.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
namespace { namespace {
@ -83,15 +82,16 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T
Packet packet = [MPPTextPacketCreator createWithText:text]; Packet packet = [MPPTextPacketCreator createWithText:text];
std::map<std::string, Packet> packetMap = {{kTextInStreamName.cppString, packet}}; std::map<std::string, Packet> packetMap = {{kTextInStreamName.cppString, packet}};
absl::StatusOr<PacketMap> statusOrOutputPacketMap = [_textTaskRunner process:packetMap]; std::optional<PacketMap> outputPacketMap = [_textTaskRunner processPacketMap:packetMap
error:error];
if (![MPPCommonUtils checkCppError:statusOrOutputPacketMap.status() toError:error]) { if (!outputPacketMap.has_value()) {
return nil; return nil;
} }
return [MPPTextClassifierResult return
textClassifierResultWithClassificationsPacket:statusOrOutputPacketMap.value() [MPPTextClassifierResult textClassifierResultWithClassificationsPacket:
[kClassificationsStreamName.cppString]]; outputPacketMap.value()[kClassificationsStreamName.cppString]];
} }
@end @end

View File

@ -58,6 +58,5 @@ objc_library(
"//mediapipe/tasks/ios/text/core:MPPTextTaskRunner", "//mediapipe/tasks/ios/text/core:MPPTextTaskRunner",
"//mediapipe/tasks/ios/text/text_embedder/utils:MPPTextEmbedderOptionsHelpers", "//mediapipe/tasks/ios/text/text_embedder/utils:MPPTextEmbedderOptionsHelpers",
"//mediapipe/tasks/ios/text/text_embedder/utils:MPPTextEmbedderResultHelpers", "//mediapipe/tasks/ios/text/text_embedder/utils:MPPTextEmbedderResultHelpers",
"@com_google_absl//absl/status:statusor",
], ],
) )

View File

@ -23,8 +23,6 @@
#import "mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderOptions+Helpers.h" #import "mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderOptions+Helpers.h"
#import "mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderResult+Helpers.h" #import "mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderResult+Helpers.h"
#include "absl/status/statusor.h"
namespace { namespace {
using ::mediapipe::Packet; using ::mediapipe::Packet;
using ::mediapipe::tasks::core::PacketMap; using ::mediapipe::tasks::core::PacketMap;
@ -83,14 +81,15 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_embedder.Tex
Packet packet = [MPPTextPacketCreator createWithText:text]; Packet packet = [MPPTextPacketCreator createWithText:text];
std::map<std::string, Packet> packetMap = {{kTextInStreamName.cppString, packet}}; std::map<std::string, Packet> packetMap = {{kTextInStreamName.cppString, packet}};
absl::StatusOr<PacketMap> statusOrOutputPacketMap = [_textTaskRunner process:packetMap];
if (![MPPCommonUtils checkCppError:statusOrOutputPacketMap.status() toError:error]) { std::optional<PacketMap> outputPacketMap = [_textTaskRunner processPacketMap:packetMap
error:error];
if (!outputPacketMap.has_value()) {
return nil; return nil;
} }
return [MPPTextEmbedderResult return [MPPTextEmbedderResult
textEmbedderResultWithOutputPacket:statusOrOutputPacketMap textEmbedderResultWithOutputPacket:outputPacketMap
.value()[kEmbeddingsOutStreamName.cppString]]; .value()[kEmbeddingsOutStreamName.cppString]];
} }

View File

@ -26,6 +26,24 @@ objc_library(
module_name = "MPPRunningMode", module_name = "MPPRunningMode",
) )
objc_library(
name = "MPPVisionPacketCreator",
srcs = ["sources/MPPVisionPacketCreator.mm"],
hdrs = ["sources/MPPVisionPacketCreator.h"],
copts = [
"-ObjC++",
"-std=c++17",
],
deps = [
":MPPImage",
"//mediapipe/framework:packet",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/tasks/ios/vision/core/utils:MPPImageUtils",
],
)
objc_library( objc_library(
name = "MPPVisionTaskRunner", name = "MPPVisionTaskRunner",
srcs = ["sources/MPPVisionTaskRunner.mm"], srcs = ["sources/MPPVisionTaskRunner.mm"],
@ -36,8 +54,11 @@ objc_library(
], ],
deps = [ deps = [
":MPPRunningMode", ":MPPRunningMode",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/tasks/ios/common:MPPCommon", "//mediapipe/tasks/ios/common:MPPCommon",
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils", "//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
"//mediapipe/tasks/ios/core:MPPTaskRunner", "//mediapipe/tasks/ios/core:MPPTaskRunner",
"//third_party/apple_frameworks:UIKit",
"@com_google_absl//absl/status:statusor",
], ],
) )

View File

@ -38,4 +38,17 @@ typedef NS_ENUM(NSUInteger, MPPRunningMode) {
} NS_SWIFT_NAME(RunningMode); } NS_SWIFT_NAME(RunningMode);
NS_INLINE NSString *MPPRunningModeDisplayName(MPPRunningMode runningMode) {
if (runningMode > MPPRunningModeLiveStream) {
return nil;
}
NSString *displayNameMap[MPPRunningModeLiveStream + 1] = {
[MPPRunningModeImage] = @"#MPPRunningModeImage",
[MPPRunningModeVideo] = @ "#MPPRunningModeVideo",
[MPPRunningModeLiveStream] = @ "#MPPRunningModeLiveStream"};
return displayNameMap[runningMode];
}
NS_ASSUME_NONNULL_END NS_ASSUME_NONNULL_END

View File

@ -14,14 +14,63 @@
#import <Foundation/Foundation.h> #import <Foundation/Foundation.h>
#include "mediapipe/framework/packet.h"
#import "mediapipe/tasks/ios/vision/core/sources/MPPImage.h" #import "mediapipe/tasks/ios/vision/core/sources/MPPImage.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/packet.h"
/** /**
* This class helps create various kinds of packets for Mediapipe Vision Tasks. * This class helps create various kinds of packets for Mediapipe Vision Tasks.
*/ */
@interface MPPVisionPacketCreator : NSObject @interface MPPVisionPacketCreator : NSObject
/**
* Creates a MediapPipe Packet wrapping an `MPPImage` that can be send to a graph.
*
* @param image The image to send to the MediaPipe graph.
* @param error Pointer to the memory location where errors if any should be saved. If @c NULL, no
* error will be saved.
*
* @return The MediaPipe packet containing the image. An empty packet is returned if an error
* occurred during the conversion.
*/
+ (mediapipe::Packet)createPacketWithMPPImage:(MPPImage *)image error:(NSError **)error; + (mediapipe::Packet)createPacketWithMPPImage:(MPPImage *)image error:(NSError **)error;
/**
* Creates a MediapPipe Packet wrapping an `MPPImage` that can be send to a graph at the specified
* timestamp.
*
* @param image The image to send to the MediaPipe graph.
* @param timestampMs The timestamp (in milliseconds) to assign to the packet.
* @param error Pointer to the memory location where errors if any should be saved. If @c NULL, no
* error will be saved.
*
* @return The MediaPipe packet containing the image. An empty packet is returned if an error
* occurred during the conversion.
*/
+ (mediapipe::Packet)createPacketWithMPPImage:(MPPImage *)image
timestampMs:(NSInteger)timestampMs
error:(NSError **)error;
/**
* Creates a MediapPipe Packet wrapping a `NormalizedRect` that can be send to a graph.
*
* @param image The `NormalizedRect` to send to the MediaPipe graph.
*
* @return The MediaPipe packet containing the normalized rect.
*/
+ (mediapipe::Packet)createPacketWithNormalizedRect:(mediapipe::NormalizedRect &)normalizedRect;
/**
* Creates a MediapPipe Packet wrapping a `NormalizedRect` that can be send to a graph at the
* specified timestamp.
*
* @param image The `NormalizedRect` to send to the MediaPipe graph.
* @param timestampMs The timestamp (in milliseconds) to assign to the packet.
*
* @return The MediaPipe packet containing the normalized rect.
*/
+ (mediapipe::Packet)createPacketWithNormalizedRect:(mediapipe::NormalizedRect &)normalizedRect
timestampMs:(NSInteger)timestampMs;
@end @end

View File

@ -16,18 +16,19 @@
#import "mediapipe/tasks/ios/vision/core/utils/sources/MPPImage+Utils.h" #import "mediapipe/tasks/ios/vision/core/utils/sources/MPPImage+Utils.h"
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/timestamp.h"
static const NSUInteger kMicroSecondsPerMilliSecond = 1000;
namespace { namespace {
using ::mediapipe::Image; using ::mediapipe::Image;
using ::mediapipe::ImageFrame; using ::mediapipe::ImageFrame;
using ::mediapipe::MakePacket; using ::mediapipe::MakePacket;
using ::mediapipe::NormalizedRect;
using ::mediapipe::Packet; using ::mediapipe::Packet;
using ::mediapipe::Timestamp;
} // namespace } // namespace
struct freeDeleter {
void operator()(void *ptr) { free(ptr); }
};
@implementation MPPVisionPacketCreator @implementation MPPVisionPacketCreator
+ (Packet)createPacketWithMPPImage:(MPPImage *)image error:(NSError **)error { + (Packet)createPacketWithMPPImage:(MPPImage *)image error:(NSError **)error {
@ -40,4 +41,27 @@ struct freeDeleter {
return MakePacket<Image>(std::move(imageFrame)); return MakePacket<Image>(std::move(imageFrame));
} }
+ (Packet)createPacketWithMPPImage:(MPPImage *)image
timestampMs:(NSInteger)timestampMs
error:(NSError **)error {
std::unique_ptr<ImageFrame> imageFrame = [image imageFrameWithError:error];
if (!imageFrame) {
return Packet();
}
return MakePacket<Image>(std::move(imageFrame))
.At(Timestamp(int64(timestampMs * kMicroSecondsPerMilliSecond)));
}
+ (Packet)createPacketWithNormalizedRect:(NormalizedRect &)normalizedRect {
return MakePacket<NormalizedRect>(std::move(normalizedRect));
}
+ (Packet)createPacketWithNormalizedRect:(NormalizedRect &)normalizedRect
timestampMs:(NSInteger)timestampMs {
return MakePacket<NormalizedRect>(std::move(normalizedRect))
.At(Timestamp(int64(timestampMs * kMicroSecondsPerMilliSecond)));
}
@end @end

View File

@ -13,10 +13,13 @@
// limitations under the License. // limitations under the License.
#import <Foundation/Foundation.h> #import <Foundation/Foundation.h>
#import <UIKit/UIKit.h>
#import "mediapipe/tasks/ios/core/sources/MPPTaskRunner.h" #import "mediapipe/tasks/ios/core/sources/MPPTaskRunner.h"
#import "mediapipe/tasks/ios/vision/core/sources/MPPRunningMode.h" #import "mediapipe/tasks/ios/vision/core/sources/MPPRunningMode.h"
#include "mediapipe/framework/formats/rect.pb.h"
NS_ASSUME_NONNULL_BEGIN NS_ASSUME_NONNULL_BEGIN
/** /**
@ -54,6 +57,82 @@ NS_ASSUME_NONNULL_BEGIN
(mediapipe::tasks::core::PacketsCallback)packetsCallback (mediapipe::tasks::core::PacketsCallback)packetsCallback
error:(NSError **)error NS_DESIGNATED_INITIALIZER; error:(NSError **)error NS_DESIGNATED_INITIALIZER;
/**
* Creates a `NormalizedRect` from a region of interest and an image orientation, performing
* sanity checks on-the-fly.
* If the input region of interest equals `CGRectZero`, returns a default `NormalizedRect` covering
* the whole image with rotation set according `imageOrientation`. If `ROIAllowed` is NO, an error
* will be returned if the input region of interest is not equal to `CGRectZero`. Mirrored
* orientations (`UIImageOrientationUpMirrored`,`UIImageOrientationDownMirrored`,
* `UIImageOrientationLeftMirrored`,`UIImageOrientationRightMirrored`) are not supported. An error
* will be returned if `imageOrientation` is equal to any one of them.
*
* @param roi A `CGRect` specifying the region of interest. If the input region of interest equals
* `CGRectZero`, the returned `NormalizedRect` covers the whole image. Make sure that `roi` equals
* `CGRectZero` if `ROIAllowed` is NO. Otherwise, an error will be returned.
* @param imageOrientation A `UIImageOrientation` indicating the rotation to be applied to the
* image. The resulting `NormalizedRect` will convert the `imageOrientation` to degrees clockwise.
* Mirrored orientations (`UIImageOrientationUpMirrored`, `UIImageOrientationDownMirrored`,
* `UIImageOrientationLeftMirrored`, `UIImageOrientationRightMirrored`) are not supported. An error
* will be returned if `imageOrientation` is equal to any one of them.
* @param ROIAllowed Indicates if the `roi` field is allowed to be a value other than `CGRectZero`.
* @param error Pointer to the memory location where errors if any should be saved. If @c NULL, no
* error will be saved.
*
* @return An optional `NormalizedRect` from the given region of interest and image orientation.
*/
- (std::optional<mediapipe::NormalizedRect>)
normalizedRectFromRegionOfInterest:(CGRect)roi
imageOrientation:(UIImageOrientation)imageOrientation
ROIAllowed:(BOOL)ROIAllowed
error:(NSError **)error;
/**
* A synchronous method to invoke the C++ task runner to process single image inputs. The call
* blocks the current thread until a failure status or a successful result is returned.
*
* @param packetMap A `PackeMap` containing pairs of input stream name and data packet.
* @param error Pointer to the memory location where errors if any should be
* saved. If @c NULL, no error will be saved.
*
* @return An optional `PacketMap` containing pairs of output stream name and data packet.
*/
- (std::optional<mediapipe::tasks::core::PacketMap>)
processImagePacketMap:(const mediapipe::tasks::core::PacketMap &)packetMap
error:(NSError **)error;
/**
* A synchronous method to invoke the C++ task runner to process continuous video frames. The call
* blocks the current thread until a failure status or a successful result is returned.
*
* @param packetMap A `PackeMap` containing pairs of input stream name and data packet.
* @param error Pointer to the memory location where errors if any should be saved. If @c NULL, no
* error will be saved.
*
* @return An optional `PacketMap` containing pairs of output stream name and data packet.
*/
- (std::optional<mediapipe::tasks::core::PacketMap>)
processVideoFramePacketMap:(const mediapipe::tasks::core::PacketMap &)packetMap
error:(NSError **)error;
/**
* An asynchronous method to send live stream data to the C++ task runner. The call blocks the
* current thread until a failure status or a successful result is returned. The results will be
* available in the user-defined `packetsCallback` that was provided during initialization of the
* `MPPVisionTaskRunner`.
*
* @param packetMap A `PackeMap` containing pairs of input stream name and data packet.
* @param error Pointer to the memory location where errors if any should be saved. If @c NULL, no
* error will be saved.
*
* @return A `BOOL` indicating if the live stream data was sent to the C++ task runner successfully.
* Please note that any errors during processing of the live stream packet map will only be
* available in the user-defined `packetsCallback` that was provided during initialization of the
* `MPPVisionTaskRunner`.
*/
- (BOOL)processLiveStreamPacketMap:(const mediapipe::tasks::core::PacketMap &)packetMap
error:(NSError **)error;
- (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig - (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig
packetsCallback: packetsCallback:
(mediapipe::tasks::core::PacketsCallback)packetsCallback (mediapipe::tasks::core::PacketsCallback)packetsCallback

View File

@ -17,11 +17,26 @@
#import "mediapipe/tasks/ios/common/sources/MPPCommon.h" #import "mediapipe/tasks/ios/common/sources/MPPCommon.h"
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" #import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
#include "absl/status/statusor.h"
#include <optional>
namespace { namespace {
using ::mediapipe::CalculatorGraphConfig; using ::mediapipe::CalculatorGraphConfig;
using ::mediapipe::NormalizedRect;
using ::mediapipe::tasks::core::PacketMap;
using ::mediapipe::tasks::core::PacketsCallback; using ::mediapipe::tasks::core::PacketsCallback;
} // namespace } // namespace
/** Rotation degress for a 90 degree rotation to the right. */
static const NSInteger kMPPOrientationDegreesRight = -90;
/** Rotation degress for a 180 degree rotation. */
static const NSInteger kMPPOrientationDegreesDown = -180;
/** Rotation degress for a 90 degree rotation to the left. */
static const NSInteger kMPPOrientationDegreesLeft = -270;
@interface MPPVisionTaskRunner () { @interface MPPVisionTaskRunner () {
MPPRunningMode _runningMode; MPPRunningMode _runningMode;
} }
@ -70,4 +85,100 @@ using ::mediapipe::tasks::core::PacketsCallback;
return self; return self;
} }
- (std::optional<NormalizedRect>)normalizedRectFromRegionOfInterest:(CGRect)roi
imageOrientation:
(UIImageOrientation)imageOrientation
ROIAllowed:(BOOL)ROIAllowed
error:(NSError **)error {
if (CGRectEqualToRect(roi, CGRectZero) && !ROIAllowed) {
[MPPCommonUtils createCustomError:error
withCode:MPPTasksErrorCodeInvalidArgumentError
description:@"This task doesn't support region-of-interest."];
return std::nullopt;
}
CGRect calculatedRoi = CGRectEqualToRect(roi, CGRectZero) ? roi : CGRectMake(0.0, 0.0, 1.0, 1.0);
NormalizedRect normalizedRect;
normalizedRect.set_x_center(CGRectGetMidX(calculatedRoi));
normalizedRect.set_y_center(CGRectGetMidY(calculatedRoi));
normalizedRect.set_width(CGRectGetWidth(calculatedRoi));
normalizedRect.set_height(CGRectGetHeight(calculatedRoi));
int rotationDegrees = 0;
switch (imageOrientation) {
case UIImageOrientationUp:
break;
case UIImageOrientationRight: {
rotationDegrees = kMPPOrientationDegreesRight;
break;
}
case UIImageOrientationDown: {
rotationDegrees = kMPPOrientationDegreesDown;
break;
}
case UIImageOrientationLeft: {
rotationDegrees = kMPPOrientationDegreesLeft;
break;
}
default:
[MPPCommonUtils
createCustomError:error
withCode:MPPTasksErrorCodeInvalidArgumentError
description:
@"Unsupported UIImageOrientation. `imageOrientation` cannot be equal to "
@"any of the mirrored orientations "
@"(`UIImageOrientationUpMirrored`,`UIImageOrientationDownMirrored`,`"
@"UIImageOrientationLeftMirrored`,`UIImageOrientationRightMirrored`)"];
}
normalizedRect.set_rotation(rotationDegrees * M_PI / kMPPOrientationDegreesDown);
return normalizedRect;
}
- (std::optional<PacketMap>)processImagePacketMap:(const PacketMap &)packetMap
error:(NSError **)error {
if (_runningMode != MPPRunningModeImage) {
[MPPCommonUtils
createCustomError:error
withCode:MPPTasksErrorCodeInvalidArgumentError
description:[NSString stringWithFormat:@"The vision task is not initialized with "
@"image mode. Current Running Mode: %@",
MPPRunningModeDisplayName(_runningMode)]];
return std::nullopt;
}
return [self processPacketMap:packetMap error:error];
}
- (std::optional<PacketMap>)processVideoFramePacketMap:(const PacketMap &)packetMap
error:(NSError **)error {
if (_runningMode != MPPRunningModeVideo) {
[MPPCommonUtils
createCustomError:error
withCode:MPPTasksErrorCodeInvalidArgumentError
description:[NSString stringWithFormat:@"The vision task is not initialized with "
@"video mode. Current Running Mode: %@",
MPPRunningModeDisplayName(_runningMode)]];
return std::nullopt;
}
return [self processPacketMap:packetMap error:error];
}
- (BOOL)processLiveStreamPacketMap:(const PacketMap &)packetMap error:(NSError **)error {
if (_runningMode != MPPRunningModeLiveStream) {
[MPPCommonUtils
createCustomError:error
withCode:MPPTasksErrorCodeInvalidArgumentError
description:[NSString stringWithFormat:@"The vision task is not initialized with "
@"live stream mode. Current Running Mode: %@",
MPPRunningModeDisplayName(_runningMode)]];
return NO;
}
return [self sendPacketMap:packetMap error:error];
}
@end @end

View File

@ -36,3 +36,30 @@ objc_library(
"//mediapipe/tasks/ios/vision/core:MPPRunningMode", "//mediapipe/tasks/ios/vision/core:MPPRunningMode",
], ],
) )
objc_library(
name = "MPPImageClassifier",
srcs = ["sources/MPPImageClassifier.mm"],
hdrs = ["sources/MPPImageClassifier.h"],
copts = [
"-ObjC++",
"-std=c++17",
"-x objective-c++",
],
module_name = "MPPImageClassifier",
deps = [
":MPPImageClassifierOptions",
":MPPImageClassifierResult",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
"//mediapipe/tasks/ios/core:MPPTaskInfo",
"//mediapipe/tasks/ios/core:MPPTaskOptions",
"//mediapipe/tasks/ios/vision/core:MPPImage",
"//mediapipe/tasks/ios/vision/core:MPPVisionPacketCreator",
"//mediapipe/tasks/ios/vision/core:MPPVisionTaskRunner",
"//mediapipe/tasks/ios/vision/image_classifier/utils:MPPImageClassifierOptionsHelpers",
"//mediapipe/tasks/ios/vision/image_classifier/utils:MPPImageClassifierResultHelpers",
],
)

View File

@ -0,0 +1,219 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <Foundation/Foundation.h>
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h"
#import "mediapipe/tasks/ios/vision/core/sources/MPPImage.h"
#import "mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierOptions.h"
#import "mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierResult.h"
NS_ASSUME_NONNULL_BEGIN
/**
* @brief Performs classification on images.
*
* The API expects a TFLite model with optional, but strongly recommended,
* [TFLite Model Metadata.](https://www.tensorflow.org/lite/convert/metadata").
*
* The API supports models with one image input tensor and one or more output tensors. To be more
* specific, here are the requirements.
*
* Input tensor
* (kTfLiteUInt8/kTfLiteFloat32)
* - image input of size `[batch x height x width x channels]`.
* - batch inference is not supported (`batch` is required to be 1).
* - only RGB inputs are supported (`channels` is required to be 3).
* - if type is kTfLiteFloat32, NormalizationOptions are required to be attached to the metadata
* for input normalization.
*
* At least one output tensor with:
* (kTfLiteUInt8/kTfLiteFloat32)
* - `N `classes and either 2 or 4 dimensions, i.e. `[1 x N]` or `[1 x 1 x 1 x N]`
* - optional (but recommended) label map(s) as AssociatedFiles with type TENSOR_AXIS_LABELS,
* containing one label per line. The first such AssociatedFile (if any) is used to fill the
* `class_name` field of the results. The `display_name` field is filled from the AssociatedFile
* (if any) whose locale matches the `display_names_locale` field of the `ImageClassifierOptions`
* used at creation time ("en" by default, i.e. English). If none of these are available, only
* the `index` field of the results will be filled.
* - optional score calibration can be attached using ScoreCalibrationOptions and an AssociatedFile
* with type TENSOR_AXIS_SCORE_CALIBRATION. See metadata_schema.fbs [1] for more details.
*/
NS_SWIFT_NAME(ImageClassifier)
@interface MPPImageClassifier : NSObject
/**
* Creates a new instance of `MPPImageClassifier` from an absolute path to a TensorFlow Lite model
* file stored locally on the device and the default `MPPImageClassifierOptions`.
*
* @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device.
* @param error An optional error parameter populated when there is an error in initializing the
* image classifier.
*
* @return A new instance of `MPPImageClassifier` with the given model path. `nil` if there is an
* error in initializing the image classifier.
*/
- (nullable instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error;
/**
* Creates a new instance of `MPPImageClassifier` from the given `MPPImageClassifierOptions`.
*
* @param options The options of type `MPPImageClassifierOptions` to use for configuring the
* `MPPImageClassifier`.
* @param error An optional error parameter populated when there is an error in initializing the
* image classifier.
*
* @return A new instance of `MPPImageClassifier` with the given options. `nil` if there is an error
* in initializing the image classifier.
*/
- (nullable instancetype)initWithOptions:(MPPImageClassifierOptions *)options
error:(NSError **)error NS_DESIGNATED_INITIALIZER;
/**
* Performs image classification on the provided MPPImage using the whole image as region of
* interest. Rotation will be applied according to the `orientation` property of the provided
* `MPPImage`. Only use this method when the `MPPImageClassifier` is created with
* `MPPRunningModeImage`.
*
* @param image The `MPPImage` on which image classification is to be performed.
* @param error An optional error parameter populated when there is an error in performing image
* classification on the input image.
*
* @return An `MPPImageClassifierResult` object that contains a list of image classifications.
*/
- (nullable MPPImageClassifierResult *)classifyImage:(MPPImage *)image
error:(NSError **)error
NS_SWIFT_NAME(classify(image:));
/**
* Performs image classification on the provided `MPPImage` cropped to the specified region of
* interest. Rotation will be applied on the cropped image according to the `orientation` property
* of the provided `MPPImage`. Only use this method when the `MPPImageClassifier` is created with
* `MPPRunningModeImage`.
*
* @param image The `MPPImage` on which image classification is to be performed.
* @param roi A `CGRect` specifying the region of interest within the given `MPPImage`, on which
* image classification should be performed.
* @param error An optional error parameter populated when there is an error in performing image
* classification on the input image.
*
* @return An `MPPImageClassifierResult` object that contains a list of image classifications.
*/
- (nullable MPPImageClassifierResult *)classifyImage:(MPPImage *)image
regionOfInterest:(CGRect)roi
error:(NSError **)error
NS_SWIFT_NAME(classify(image:regionOfInterest:));
/**
* Performs image classification on the provided video frame of type `MPPImage` using the whole
* image as region of interest. Rotation will be applied according to the `orientation` property of
* the provided `MPPImage`. Only use this method when the `MPPImageClassifier` is created with
* `MPPRunningModeVideo`.
*
* @param image The `MPPImage` on which image classification is to be performed.
* @param timestampMs The video frame's timestamp (in milliseconds). The input timestamps must be
* monotonically increasing.
* @param error An optional error parameter populated when there is an error in performing image
* classification on the input video frame.
*
* @return An `MPPImageClassifierResult` object that contains a list of image classifications.
*/
- (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image
timestampMs:(NSInteger)timestampMs
error:(NSError **)error
NS_SWIFT_NAME(classify(videoFrame:timestampMs:));
/**
* Performs image classification on the provided video frame of type `MPPImage` cropped to the
* specified region of interest. Rotation will be applied according to the `orientation` property of
* the provided `MPPImage`. Only use this method when the `MPPImageClassifier` is created with
* `MPPRunningModeVideo`.
*
* It's required to provide the video frame's timestamp (in milliseconds). The input timestamps must
* be monotonically increasing.
*
* @param image A live stream image data of type `MPPImage` on which image classification is to be
* performed.
* @param timestampMs The video frame's timestamp (in milliseconds). The input timestamps must be
* monotonically increasing.
* @param roi A `CGRect` specifying the region of interest within the video frame of type
* `MPPImage`, on which image classification should be performed.
* @param error An optional error parameter populated when there is an error in performing image
* classification on the input video frame.
*
* @return An `MPPImageClassifierResult` object that contains a list of image classifications.
*/
- (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image
timestampMs:(NSInteger)timestampMs
regionOfInterest:(CGRect)roi
error:(NSError **)error
NS_SWIFT_NAME(classify(videoFrame:timestampMs:regionOfInterest:));
/**
* Sends live stream image data of type `MPPImage` to perform image classification using the whole
* image as region of interest. Rotation will be applied according to the `orientation` property of
* the provided `MPPImage`. Only use this method when the `MPPImageClassifier` is created with
* `MPPRunningModeLiveStream`. Results are provided asynchronously via the `completion` callback
* provided in the `MPPImageClassifierOptions`.
*
* It's required to provide a timestamp (in milliseconds) to indicate when the input image is sent
* to the image classifier. The input timestamps must be monotonically increasing.
*
* @param image A live stream image data of type `MPPImage` on which image classification is to be
* performed.
* @param timestampMs The timestamp (in milliseconds) which indicates when the input image is sent
* to the image classifier. The input timestamps must be monotonically increasing.
* @param error An optional error parameter populated when there is an error in performing image
* classification on the input live stream image data.
*
* @return `YES` if the image was sent to the task successfully, otherwise `NO`.
*/
- (BOOL)classifyAsyncImage:(MPPImage *)image
timestampMs:(NSInteger)timestampMs
error:(NSError **)error NS_SWIFT_NAME(classifyAsync(image:timestampMs:));
/**
* Sends live stream image data of type `MPPImage` to perform image classification, cropped to the
* specified region of interest.. Rotation will be applied according to the `orientation` property
* of the provided `MPPImage`. Only use this method when the `MPPImageClassifier` is created with
* `MPPRunningModeLiveStream`. Results are provided asynchronously via the `completion` callback
* provided in the `MPPImageClassifierOptions`.
*
* It's required to provide a timestamp (in milliseconds) to indicate when the input image is sent
* to the image classifier. The input timestamps must be monotonically increasing.
*
* @param image A live stream image data of type `MPPImage` on which image classification is to be
* performed.
* @param timestampMs The timestamp (in milliseconds) which indicates when the input image is sent
* to the image classifier. The input timestamps must be monotonically increasing.
* @param roi A `CGRect` specifying the region of interest within the given live stream image data
* of type `MPPImage`, on which image classification should be performed.
* @param error An optional error parameter populated when there is an error in performing image
* classification on the input live stream image data.
*
* @return `YES` if the image was sent to the task successfully, otherwise `NO`.
*/
- (BOOL)classifyAsyncImage:(MPPImage *)image
timestampMs:(NSInteger)timestampMs
regionOfInterest:(CGRect)roi
error:(NSError **)error
NS_SWIFT_NAME(classifyAsync(image:timestampMs:regionOfInterest:));
- (instancetype)init NS_UNAVAILABLE;
+ (instancetype)new NS_UNAVAILABLE;
@end
NS_ASSUME_NONNULL_END

View File

@ -0,0 +1,232 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.h"
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h"
#import "mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.h"
#import "mediapipe/tasks/ios/vision/core/sources/MPPVisionTaskRunner.h"
#import "mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierOptions+Helpers.h"
#import "mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierResult+Helpers.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
namespace {
using ::mediapipe::NormalizedRect;
using ::mediapipe::Packet;
using ::mediapipe::tasks::core::PacketMap;
using ::mediapipe::tasks::core::PacketsCallback;
} // namespace
static NSString *const kClassificationsStreamName = @"classifications_out";
static NSString *const kClassificationsTag = @"CLASSIFICATIONS";
static NSString *const kImageInStreamName = @"image_in";
static NSString *const kImageOutStreamName = @"image_out";
static NSString *const kImageTag = @"IMAGE";
static NSString *const kNormRectName = @"norm_rect_in";
static NSString *const kNormRectTag = @"NORM_RECT";
static NSString *const kTaskGraphName =
@"mediapipe.tasks.vision.image_classifier.ImageClassifierGraph";
#define InputPacketMap(imagePacket, normalizedRectPacket) \
{ \
{kImageInStreamName.cppString, imagePacket}, { kNormRectName.cppString, normalizedRectPacket } \
}
@interface MPPImageClassifier () {
/** iOS Vision Task Runner */
MPPVisionTaskRunner *_visionTaskRunner;
}
@end
@implementation MPPImageClassifier
- (instancetype)initWithOptions:(MPPImageClassifierOptions *)options error:(NSError **)error {
self = [super init];
if (self) {
MPPTaskInfo *taskInfo = [[MPPTaskInfo alloc]
initWithTaskGraphName:kTaskGraphName
inputStreams:@[ [NSString
stringWithFormat:@"%@:%@", kImageTag, kImageInStreamName] ]
outputStreams:@[ [NSString stringWithFormat:@"%@:%@", kClassificationsTag,
kClassificationsStreamName] ]
taskOptions:options
enableFlowLimiting:NO
error:error];
if (!taskInfo) {
return nil;
}
PacketsCallback packetsCallback = nullptr;
if (options.completion) {
packetsCallback = [=](absl::StatusOr<PacketMap> status_or_packets) {
NSError *callbackError = nil;
MPPImageClassifierResult *result;
if ([MPPCommonUtils checkCppError:status_or_packets.status() toError:&callbackError]) {
result = [MPPImageClassifierResult
imageClassifierResultWithClassificationsPacket:
status_or_packets.value()[kClassificationsStreamName.cppString]];
}
options.completion(result, callbackError);
};
}
_visionTaskRunner =
[[MPPVisionTaskRunner alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig]
runningMode:options.runningMode
packetsCallback:std::move(packetsCallback)
error:error];
if (!_visionTaskRunner) {
return nil;
}
}
return self;
}
- (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error {
MPPImageClassifierOptions *options = [[MPPImageClassifierOptions alloc] init];
options.baseOptions.modelAssetPath = modelPath;
return [self initWithOptions:options error:error];
}
- (nullable MPPImageClassifierResult *)classifyImage:(MPPImage *)image
regionOfInterest:(CGRect)roi
error:(NSError **)error {
std::optional<NormalizedRect> rect =
[_visionTaskRunner normalizedRectFromRegionOfInterest:roi
imageOrientation:image.orientation
ROIAllowed:YES
error:error];
if (!rect.has_value()) {
return nil;
}
Packet imagePacket = [MPPVisionPacketCreator createPacketWithMPPImage:image error:error];
if (imagePacket.IsEmpty()) {
return nil;
}
Packet normalizedRectPacket =
[MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value()];
PacketMap inputPacketMap = InputPacketMap(imagePacket, normalizedRectPacket);
std::optional<PacketMap> outputPacketMap = [_visionTaskRunner processPacketMap:inputPacketMap
error:error];
if (!outputPacketMap.has_value()) {
return nil;
}
return
[MPPImageClassifierResult imageClassifierResultWithClassificationsPacket:
outputPacketMap.value()[kClassificationsStreamName.cppString]];
}
- (std::optional<PacketMap>)inputPacketMapWithMPPImage:(MPPImage *)image
timestampMs:(NSInteger)timestampMs
regionOfInterest:(CGRect)roi
error:(NSError **)error {
std::optional<NormalizedRect> rect =
[_visionTaskRunner normalizedRectFromRegionOfInterest:roi
imageOrientation:image.orientation
ROIAllowed:YES
error:error];
if (!rect.has_value()) {
return std::nullopt;
}
Packet imagePacket = [MPPVisionPacketCreator createPacketWithMPPImage:image
timestampMs:timestampMs
error:error];
if (imagePacket.IsEmpty()) {
return std::nullopt;
}
Packet normalizedRectPacket = [MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value()
timestampMs:timestampMs];
PacketMap inputPacketMap = InputPacketMap(imagePacket, normalizedRectPacket);
return inputPacketMap;
}
- (nullable MPPImageClassifierResult *)classifyImage:(MPPImage *)image error:(NSError **)error {
return [self classifyImage:image regionOfInterest:CGRectZero error:error];
}
- (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image
timestampMs:(NSInteger)timestampMs
regionOfInterest:(CGRect)roi
error:(NSError **)error {
std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image
timestampMs:timestampMs
regionOfInterest:roi
error:error];
if (!inputPacketMap.has_value()) {
return nil;
}
std::optional<PacketMap> outputPacketMap =
[_visionTaskRunner processVideoFramePacketMap:inputPacketMap.value() error:error];
if (!outputPacketMap.has_value()) {
return nil;
}
return
[MPPImageClassifierResult imageClassifierResultWithClassificationsPacket:
outputPacketMap.value()[kClassificationsStreamName.cppString]];
}
- (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image
timestampMs:(NSInteger)timestampMs
error:(NSError **)error {
return [self classifyVideoFrame:image
timestampMs:timestampMs
regionOfInterest:CGRectZero
error:error];
}
- (BOOL)classifyAsyncImage:(MPPImage *)image
timestampMs:(NSInteger)timestampMs
regionOfInterest:(CGRect)roi
error:(NSError **)error {
std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image
timestampMs:timestampMs
regionOfInterest:roi
error:error];
if (!inputPacketMap.has_value()) {
return NO;
}
return [_visionTaskRunner processLiveStreamPacketMap:inputPacketMap.value() error:error];
}
- (BOOL)classifyAsyncImage:(MPPImage *)image
timestampMs:(NSInteger)timestampMs
error:(NSError **)error {
return [self classifyAsyncImage:image
timestampMs:timestampMs
regionOfInterest:CGRectZero
error:error];
}
@end

View File

@ -31,6 +31,7 @@ NS_SWIFT_NAME(ImageClassifierOptions)
/** /**
* The user-defined result callback for processing live stream data. The result callback should only * The user-defined result callback for processing live stream data. The result callback should only
* be specified when the running mode is set to the live stream mode. * be specified when the running mode is set to the live stream mode.
* TODO: Add parameter `MPPImage` in the callback.
*/ */
@property(nonatomic, copy) void (^completion)(MPPImageClassifierResult *result, NSError *error); @property(nonatomic, copy) void (^completion)(MPPImageClassifierResult *result, NSError *error);

View File

@ -0,0 +1,44 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
objc_library(
name = "MPPImageClassifierOptionsHelpers",
srcs = ["sources/MPPImageClassifierOptions+Helpers.mm"],
hdrs = ["sources/MPPImageClassifierOptions+Helpers.h"],
deps = [
"//mediapipe/framework:calculator_options_cc_proto",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto",
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
"//mediapipe/tasks/ios/core:MPPTaskOptionsProtocol",
"//mediapipe/tasks/ios/core/utils:MPPBaseOptionsHelpers",
"//mediapipe/tasks/ios/vision/image_classifier:MPPImageClassifierOptions",
],
)
objc_library(
name = "MPPImageClassifierResultHelpers",
srcs = ["sources/MPPImageClassifierResult+Helpers.mm"],
hdrs = ["sources/MPPImageClassifierResult+Helpers.h"],
deps = [
"//mediapipe/framework:packet",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"//mediapipe/tasks/ios/components/containers/utils:MPPClassificationResultHelpers",
"//mediapipe/tasks/ios/vision/image_classifier:MPPImageClassifierResult",
],
)

View File

@ -0,0 +1,32 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/calculator_options.pb.h"
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h"
#import "mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierOptions.h"
NS_ASSUME_NONNULL_BEGIN
@interface MPPImageClassifierOptions (Helpers) <MPPTaskOptionsProtocol>
/**
* Populates the provided `CalculatorOptions` proto container with the current settings.
*
* @param optionsProto The `CalculatorOptions` proto object to copy the settings to.
*/
- (void)copyToProto:(::mediapipe::CalculatorOptions *)optionsProto;
@end
NS_ASSUME_NONNULL_END

View File

@ -0,0 +1,56 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierOptions+Helpers.h"
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
#import "mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.pb.h"
namespace {
using CalculatorOptionsProto = ::mediapipe::CalculatorOptions;
using ImageClassifierGraphOptionsProto =
::mediapipe::tasks::vision::image_classifier::proto::ImageClassifierGraphOptions;
using ClassifierOptionsProto = ::mediapipe::tasks::components::processors::proto::ClassifierOptions;
} // namespace
@implementation MPPImageClassifierOptions (Helpers)
- (void)copyToProto:(CalculatorOptionsProto *)optionsProto {
ImageClassifierGraphOptionsProto *graphOptions =
optionsProto->MutableExtension(ImageClassifierGraphOptionsProto::ext);
[self.baseOptions copyToProto:graphOptions->mutable_base_options()];
ClassifierOptionsProto *classifierOptionsProto = graphOptions->mutable_classifier_options();
classifierOptionsProto->Clear();
if (self.displayNamesLocale) {
classifierOptionsProto->set_display_names_locale(self.displayNamesLocale.cppString);
}
classifierOptionsProto->set_max_results((int)self.maxResults);
classifierOptionsProto->set_score_threshold(self.scoreThreshold);
for (NSString *category in self.categoryAllowlist) {
classifierOptionsProto->add_category_allowlist(category.cppString);
}
for (NSString *category in self.categoryDenylist) {
classifierOptionsProto->add_category_denylist(category.cppString);
}
}
@end

View File

@ -0,0 +1,36 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierResult.h"
#include "mediapipe/framework/packet.h"
NS_ASSUME_NONNULL_BEGIN
@interface MPPImageClassifierResult (Helpers)
/**
* Creates an `MPPImageClassifierResult` from a MediaPipe packet containing an
* `ClassificationResultProto`.
*
* @param packet a MediaPipe packet wrapping a ClassificationResultProto.
*
* @return An `MPPImageClassifierResult` object that contains a list of image classifications.
*/
+ (MPPImageClassifierResult *)imageClassifierResultWithClassificationsPacket:
(const mediapipe::Packet &)packet;
@end
NS_ASSUME_NONNULL_END

View File

@ -0,0 +1,41 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h"
#import "mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierResult+Helpers.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
static const int kMicroSecondsPerMilliSecond = 1000;
namespace {
using ClassificationResultProto =
::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::Packet;
} // namespace
@implementation MPPImageClassifierResult (Helpers)
+ (MPPImageClassifierResult *)imageClassifierResultWithClassificationsPacket:
(const Packet &)packet {
MPPClassificationResult *classificationResult = [MPPClassificationResult
classificationResultWithProto:packet.Get<ClassificationResultProto>()];
return [[MPPImageClassifierResult alloc]
initWithClassificationResult:classificationResult
timestampMs:(NSInteger)(packet.Timestamp().Value() /
kMicroSecondsPerMilliSecond)];
}
@end