Added iOS interactive segmenter implementation
This commit is contained in:
parent
ecd87649a7
commit
d1237787e2
|
@ -25,12 +25,30 @@ objc_library(
|
||||||
|
|
||||||
objc_library(
|
objc_library(
|
||||||
name = "MPPInteractiveSegmenter",
|
name = "MPPInteractiveSegmenter",
|
||||||
|
srcs = ["sources/MPPInteractiveSegmenter.mm"],
|
||||||
hdrs = ["sources/MPPInteractiveSegmenter.h"],
|
hdrs = ["sources/MPPInteractiveSegmenter.h"],
|
||||||
|
copts = [
|
||||||
|
"-ObjC++",
|
||||||
|
"-std=c++17",
|
||||||
|
"-x objective-c++",
|
||||||
|
],
|
||||||
module_name = "MPPInteractiveSegmenter",
|
module_name = "MPPInteractiveSegmenter",
|
||||||
deps = [
|
deps = [
|
||||||
":MPPInteractiveSegmenterOptions",
|
":MPPInteractiveSegmenterOptions",
|
||||||
|
"//mediapipe/tasks/cc/vision/interactive_segmenter:interactive_segmenter_graph",
|
||||||
|
"//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator",
|
||||||
|
"//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_cc_proto",
|
||||||
|
"//mediapipe/tasks/ios/common:MPPCommon",
|
||||||
|
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
|
||||||
|
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
|
||||||
"//mediapipe/tasks/ios/components/containers:MPPRegionOfInterest",
|
"//mediapipe/tasks/ios/components/containers:MPPRegionOfInterest",
|
||||||
|
"//mediapipe/tasks/ios/core:MPPTaskInfo",
|
||||||
"//mediapipe/tasks/ios/vision/core:MPPImage",
|
"//mediapipe/tasks/ios/vision/core:MPPImage",
|
||||||
|
"//mediapipe/tasks/ios/vision/core:MPPVisionPacketCreator",
|
||||||
|
"//mediapipe/tasks/ios/vision/core:MPPVisionTaskRunner",
|
||||||
|
"//mediapipe/tasks/ios/vision/interactive_segmenter/utils:MPPInteractiveSegmenterOptionsHelpers",
|
||||||
"//mediapipe/tasks/ios/vision/image_segmenter:MPPImageSegmenterResult",
|
"//mediapipe/tasks/ios/vision/image_segmenter:MPPImageSegmenterResult",
|
||||||
|
"//mediapipe/tasks/ios/vision/image_segmenter/utils:MPPImageSegmenterResultHelpers",
|
||||||
|
"//mediapipe/util:label_map_cc_proto",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -0,0 +1,257 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#import "mediapipe/tasks/ios/vision/interactive_segmenter/sources/MPPInteractiveSegmenter.h"
|
||||||
|
|
||||||
|
#import "mediapipe/tasks/ios/common/sources/MPPCommon.h"
|
||||||
|
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
|
||||||
|
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
|
||||||
|
#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h"
|
||||||
|
#import "mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.h"
|
||||||
|
#import "mediapipe/tasks/ios/vision/core/sources/MPPVisionTaskRunner.h"
|
||||||
|
#import "mediapipe/tasks/ios/vision/image_segmenter/utils/sources/MPPImageSegmenterResult+Helpers.h"
|
||||||
|
#import "mediapipe/tasks/ios/vision/interactive_segmenter/utils/sources/MPPInteractiveSegmenterOptions+Helpers.h"
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h"
|
||||||
|
#include "mediapipe/util/label_map.pb.h"
|
||||||
|
|
||||||
|
static constexpr int kMicrosecondsPerMillisecond = 1000;
|
||||||
|
|
||||||
|
// Constants for the underlying MP Tasks Graph. See
|
||||||
|
// https://github.com/google/mediapipe/tree/master/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc
|
||||||
|
static NSString *const kConfidenceMasksStreamName = @"confidence_masks";
|
||||||
|
static NSString *const kConfidenceMasksTag = @"CONFIDENCE_MASKS";
|
||||||
|
static NSString *const kCategoryMaskStreamName = @"category_mask";
|
||||||
|
static NSString *const kCategoryMaskTag = @"CATEGORY_MASK";
|
||||||
|
static NSString *const kQualityScoresStreamName = @"quality_scores";
|
||||||
|
static NSString *const kQualityScoresTag = @"QUALITY_SCORES";
|
||||||
|
static NSString *const kImageInStreamName = @"image_in";
|
||||||
|
static NSString *const kImageOutStreamName = @"image_out";
|
||||||
|
static NSString *const kImageTag = @"IMAGE";
|
||||||
|
static NSString *const kNormRectStreamName = @"norm_rect_in";
|
||||||
|
static NSString *const kNormRectTag = @"NORM_RECT";
|
||||||
|
static NSString *const kRoiInStreamName = @"roi_in";
|
||||||
|
static NSString *const kRoiTag = @"ROI";
|
||||||
|
static NSString *const kTaskGraphName =
|
||||||
|
@"mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph";
|
||||||
|
static NSString *const kTaskName = @"interactiveSegmenter";
|
||||||
|
|
||||||
|
#define InputPacketMap(imagePacket, normalizedRectPacket) \
|
||||||
|
{ \
|
||||||
|
{kImageInStreamName.cppString, imagePacket}, { \
|
||||||
|
kNormRectStreamName.cppString, normalizedRectPacket \
|
||||||
|
} \
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
using ::mediapipe::CalculatorGraphConfig;
|
||||||
|
using ::mediapipe::Packet;
|
||||||
|
using ::mediapipe::Timestamp;
|
||||||
|
using ::mediapipe::tasks::TensorsToSegmentationCalculatorOptions;
|
||||||
|
using ::mediapipe::tasks::core::PacketMap;
|
||||||
|
using ::mediapipe::tasks::core::PacketsCallback;
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
|
@interface MPPInteractiveSegmenter () {
|
||||||
|
/** iOS Vision Task Runner */
|
||||||
|
MPPVisionTaskRunner *_visionTaskRunner;
|
||||||
|
}
|
||||||
|
|
||||||
|
@end
|
||||||
|
|
||||||
|
@implementation MPPInteractiveSegmenter
|
||||||
|
|
||||||
|
#pragma mark - Public
|
||||||
|
|
||||||
|
- (instancetype)initWithOptions:(MPPInteractiveSegmenterOptions *)options error:(NSError **)error {
|
||||||
|
self = [super init];
|
||||||
|
if (self) {
|
||||||
|
NSMutableArray<NSString *> *outputStreams = [NSMutableArray
|
||||||
|
arrayWithObjects:[NSString stringWithFormat:@"%@:%@", kQualityScoresTag,
|
||||||
|
kQualityScoresStreamName],
|
||||||
|
[NSString stringWithFormat:@"%@:%@", kImageTag, kImageOutStreamName], nil];
|
||||||
|
if (options.shouldOutputConfidenceMasks) {
|
||||||
|
[outputStreams addObject:[NSString stringWithFormat:@"%@:%@", kConfidenceMasksTag,
|
||||||
|
kConfidenceMasksStreamName]];
|
||||||
|
}
|
||||||
|
if (options.shouldOutputCategoryMask) {
|
||||||
|
[outputStreams addObject:[NSString stringWithFormat:@"%@:%@", kCategoryMaskTag,
|
||||||
|
kCategoryMaskStreamName]];
|
||||||
|
}
|
||||||
|
|
||||||
|
MPPTaskInfo *taskInfo = [[MPPTaskInfo alloc]
|
||||||
|
initWithTaskGraphName:kTaskGraphName
|
||||||
|
inputStreams:@[
|
||||||
|
[NSString stringWithFormat:@"%@:%@", kImageTag, kImageInStreamName],
|
||||||
|
[NSString stringWithFormat:@"%@:%@", kNormRectTag, kNormRectStreamName],
|
||||||
|
[NSString stringWithFormat:@"%@:%@", kRoiTag, kRoiInStreamName],
|
||||||
|
]
|
||||||
|
outputStreams:outputStreams
|
||||||
|
taskOptions:options
|
||||||
|
enableFlowLimiting:NO
|
||||||
|
error:error];
|
||||||
|
|
||||||
|
if (!taskInfo) {
|
||||||
|
return nil;
|
||||||
|
}
|
||||||
|
|
||||||
|
PacketsCallback packetsCallback = nullptr;
|
||||||
|
|
||||||
|
_visionTaskRunner = [[MPPVisionTaskRunner alloc] initWithTaskInfo:taskInfo
|
||||||
|
runningMode:MPPRunningModeImage
|
||||||
|
roiAllowed:YES
|
||||||
|
packetsCallback:std::move(packetsCallback)
|
||||||
|
imageInputStreamName:kImageInStreamName
|
||||||
|
normRectInputStreamName:kNormRectStreamName
|
||||||
|
error:error];
|
||||||
|
if (!_visionTaskRunner) {
|
||||||
|
return nil;
|
||||||
|
}
|
||||||
|
|
||||||
|
_labels = [MPPInteractiveSegmenter populateLabelsWithGraphConfig:_visionTaskRunner.graphConfig
|
||||||
|
error:error];
|
||||||
|
if (!_labels) {
|
||||||
|
return nil;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return self;
|
||||||
|
}
|
||||||
|
|
||||||
|
- (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error {
|
||||||
|
MPPInteractiveSegmenterOptions *options = [[MPPInteractiveSegmenterOptions alloc] init];
|
||||||
|
|
||||||
|
options.baseOptions.modelAssetPath = modelPath;
|
||||||
|
|
||||||
|
return [self initWithOptions:options error:error];
|
||||||
|
}
|
||||||
|
|
||||||
|
- (nullable MPPImageSegmenterResult *)segmentImage:(MPPImage *)image
|
||||||
|
regionOfInterest:(MPPRegionOfInterest *)regionOfInterest
|
||||||
|
error:(NSError **)error {
|
||||||
|
return [self segmentImage:image
|
||||||
|
regionOfInterest:regionOfInterest
|
||||||
|
shouldCopyOutputMaskPacketData:YES
|
||||||
|
error:error];
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)segmentImage:(MPPImage *)image
|
||||||
|
regionOfInterest:(MPPRegionOfInterest *)regionOfInterest
|
||||||
|
withCompletionHandler:(void (^)(MPPImageSegmenterResult *_Nullable result,
|
||||||
|
NSError *_Nullable error))completionHandler {
|
||||||
|
NSError *error = nil;
|
||||||
|
MPPImageSegmenterResult *result = [self segmentImage:image
|
||||||
|
regionOfInterest:regionOfInterest
|
||||||
|
shouldCopyOutputMaskPacketData:NO
|
||||||
|
error:&error];
|
||||||
|
completionHandler(result, error);
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma mark - Private
|
||||||
|
|
||||||
|
+ (NSArray<NSString *> *)populateLabelsWithGraphConfig:(const CalculatorGraphConfig &)graphConfig
|
||||||
|
error:(NSError **)error {
|
||||||
|
bool found_tensor_to_segmentation_calculator = false;
|
||||||
|
|
||||||
|
NSMutableArray<NSString *> *labels =
|
||||||
|
[NSMutableArray arrayWithCapacity:(NSUInteger)graphConfig.node_size()];
|
||||||
|
for (const auto &node : graphConfig.node()) {
|
||||||
|
if (node.calculator() == "mediapipe.tasks.TensorsToSegmentationCalculator") {
|
||||||
|
if (!found_tensor_to_segmentation_calculator) {
|
||||||
|
found_tensor_to_segmentation_calculator = true;
|
||||||
|
} else {
|
||||||
|
[MPPCommonUtils createCustomError:error
|
||||||
|
withCode:MPPTasksErrorCodeFailedPreconditionError
|
||||||
|
description:@"The graph has more than one "
|
||||||
|
@"`mediapipe.tasks.TensorsToSegmentationCalculator`."];
|
||||||
|
return nil;
|
||||||
|
}
|
||||||
|
TensorsToSegmentationCalculatorOptions options =
|
||||||
|
node.options().GetExtension(TensorsToSegmentationCalculatorOptions::ext);
|
||||||
|
if (!options.label_items().empty()) {
|
||||||
|
for (int i = 0; i < options.label_items_size(); ++i) {
|
||||||
|
if (!options.label_items().contains(i)) {
|
||||||
|
[MPPCommonUtils
|
||||||
|
createCustomError:error
|
||||||
|
withCode:MPPTasksErrorCodeFailedPreconditionError
|
||||||
|
description:[NSString
|
||||||
|
stringWithFormat:@"The lablemap has no expected key %d.", i]];
|
||||||
|
|
||||||
|
return nil;
|
||||||
|
}
|
||||||
|
[labels addObject:[NSString stringWithCppString:options.label_items().at(i).name()]];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return labels;
|
||||||
|
}
|
||||||
|
|
||||||
|
- (nullable MPPImageSegmenterResult *)segmentImage:(MPPImage *)image
|
||||||
|
regionOfInterest:(MPPRegionOfInterest *)regionOfInterest
|
||||||
|
shouldCopyOutputMaskPacketData:(BOOL)shouldCopyMaskPacketData
|
||||||
|
error:(NSError **)error {
|
||||||
|
std::optional<PacketMap> inputPacketMap = [_visionTaskRunner inputPacketMapWithMPPImage:image
|
||||||
|
regionOfInterest:CGRectZero
|
||||||
|
error:error];
|
||||||
|
|
||||||
|
if (!inputPacketMap.has_value()) {
|
||||||
|
return nil;
|
||||||
|
}
|
||||||
|
std::optional<Packet> renderDataPacket =
|
||||||
|
[MPPVisionPacketCreator createRenderDataPacketWithRegionOfInterest:regionOfInterest
|
||||||
|
error:error];
|
||||||
|
|
||||||
|
if (!renderDataPacket.has_value()) {
|
||||||
|
return nil;
|
||||||
|
}
|
||||||
|
|
||||||
|
inputPacketMap->insert({kRoiInStreamName.cppString, renderDataPacket.value()});
|
||||||
|
|
||||||
|
// inputPacketMap.value().insert(std::pair<std::string, Packet>(kRoiInStreamName.cppString,
|
||||||
|
// renderDataPacket.value()));
|
||||||
|
|
||||||
|
std::optional<PacketMap> outputPacketMap =
|
||||||
|
[_visionTaskRunner processPacketMap:inputPacketMap.value() error:error];
|
||||||
|
|
||||||
|
return [MPPInteractiveSegmenter
|
||||||
|
imageSegmenterResultWithOptionalOutputPacketMap:outputPacketMap
|
||||||
|
shouldCopyMaskPacketData:shouldCopyMaskPacketData];
|
||||||
|
}
|
||||||
|
|
||||||
|
+ (nullable MPPImageSegmenterResult *)
|
||||||
|
imageSegmenterResultWithOptionalOutputPacketMap:(std::optional<PacketMap> &)outputPacketMap
|
||||||
|
shouldCopyMaskPacketData:(BOOL)shouldCopyMaskPacketData {
|
||||||
|
if (!outputPacketMap.has_value()) {
|
||||||
|
return nil;
|
||||||
|
}
|
||||||
|
|
||||||
|
PacketMap &outputPacketMapValue = outputPacketMap.value();
|
||||||
|
|
||||||
|
return [MPPImageSegmenterResult
|
||||||
|
imageSegmenterResultWithConfidenceMasksPacket:outputPacketMapValue[kConfidenceMasksStreamName
|
||||||
|
.cppString]
|
||||||
|
categoryMaskPacket:outputPacketMapValue[kCategoryMaskStreamName
|
||||||
|
.cppString]
|
||||||
|
qualityScoresPacket:outputPacketMapValue[kQualityScoresStreamName
|
||||||
|
.cppString]
|
||||||
|
timestampInMilliseconds:outputPacketMapValue[kImageOutStreamName
|
||||||
|
.cppString]
|
||||||
|
.Timestamp()
|
||||||
|
.Value() /
|
||||||
|
kMicrosecondsPerMillisecond
|
||||||
|
shouldCopyMaskPacketData:shouldCopyMaskPacketData];
|
||||||
|
}
|
||||||
|
|
||||||
|
@end
|
Loading…
Reference in New Issue
Block a user