Merge 4c03da59ef
into e23fa531e1
This commit is contained in:
commit
0cf5315e75
77
mediapipe/tasks/ios/test/vision/interactive_segmenter/BUILD
Normal file
77
mediapipe/tasks/ios/test/vision/interactive_segmenter/BUILD
Normal file
|
@ -0,0 +1,77 @@
|
|||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load(
|
||||
"//mediapipe/framework/tool:ios.bzl",
|
||||
"MPP_TASK_MINIMUM_OS_VERSION",
|
||||
)
|
||||
load(
|
||||
"@org_tensorflow//tensorflow/lite:special_rules.bzl",
|
||||
"tflite_ios_lab_runner",
|
||||
)
|
||||
load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test")
|
||||
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
# Default tags for filtering iOS targets. Targets are restricted to Apple platforms.
|
||||
TFL_DEFAULT_TAGS = [
|
||||
"apple",
|
||||
]
|
||||
|
||||
# The following sanitizer tests are not supported by iOS test targets.
|
||||
TFL_DISABLED_SANITIZER_TAGS = [
|
||||
"noasan",
|
||||
"nomsan",
|
||||
"notsan",
|
||||
]
|
||||
|
||||
# Test-only library holding the `MPPInteractiveSegmenter` XCTest sources. It is compiled as
# Objective-C++ (C++17) and ships the shared vision test images, models and protos as data.
objc_library(
    name = "MPPInteractiveSegmenterObjcTestLibrary",
    testonly = True,
    srcs = ["MPPInteractiveSegmenterTests.mm"],
    copts = [
        "-ObjC++",
        "-std=c++17",
        "-x objective-c++",
    ],
    data = [
        "//mediapipe/tasks/testdata/vision:test_images",
        "//mediapipe/tasks/testdata/vision:test_models",
        "//mediapipe/tasks/testdata/vision:test_protos",
    ],
    # Labels are kept in sorted order. The OpenCV dependency is selected per build configuration:
    # the source-built xcframework for the three source-build settings, the prebuilt framework
    # otherwise.
    deps = [
        "//mediapipe/tasks/ios/common:MPPCommon",
        "//mediapipe/tasks/ios/test/vision/utils:MPPImageTestUtils",
        "//mediapipe/tasks/ios/test/vision/utils:MPPMaskTestUtils",
        "//mediapipe/tasks/ios/vision/image_segmenter:MPPImageSegmenterResult",
        "//mediapipe/tasks/ios/vision/interactive_segmenter:MPPInteractiveSegmenter",
    ] + select({
        "//third_party:opencv_ios_sim_arm64_source_build": ["@ios_opencv_source//:opencv_xcframework"],
        "//third_party:opencv_ios_arm64_source_build": ["@ios_opencv_source//:opencv_xcframework"],
        "//third_party:opencv_ios_x86_64_source_build": ["@ios_opencv_source//:opencv_xcframework"],
        "//conditions:default": ["@ios_opencv//:OpencvFramework"],
    }),
)
|
||||
|
||||
ios_unit_test(
|
||||
name = "MPPInteractiveSegmenterObjcTest",
|
||||
minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION,
|
||||
runner = tflite_ios_lab_runner("IOS_LATEST"),
|
||||
tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
|
||||
deps = [
|
||||
":MPPInteractiveSegmenterObjcTestLibrary",
|
||||
],
|
||||
)
|
|
@ -0,0 +1,323 @@
|
|||
// Copyright 2023 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#import <Foundation/Foundation.h>
|
||||
#import <XCTest/XCTest.h>
|
||||
|
||||
#import "mediapipe/tasks/ios/common/sources/MPPCommon.h"
|
||||
#import "mediapipe/tasks/ios/test/vision/utils/sources/MPPImage+TestUtils.h"
|
||||
#import "mediapipe/tasks/ios/test/vision/utils/sources/MPPMask+TestUtils.h"
|
||||
#import "mediapipe/tasks/ios/vision/image_segmenter/sources/MPPImageSegmenterResult.h"
|
||||
#import "mediapipe/tasks/ios/vision/interactive_segmenter/sources/MPPInteractiveSegmenter.h"
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
static MPPFileInfo *const kCatsAndDogsImageFileInfo =
|
||||
[[MPPFileInfo alloc] initWithName:@"cats_and_dogs" type:@"jpg"];
|
||||
static MPPFileInfo *const kCatsAndDogsMaskImage1FileInfo =
|
||||
[[MPPFileInfo alloc] initWithName:@"cats_and_dogs_mask_dog1" type:@"png"];
|
||||
static MPPFileInfo *const kCatsAndDogsMaskImage2FileInfo =
|
||||
[[MPPFileInfo alloc] initWithName:@"cats_and_dogs_mask_dog2" type:@"png"];
|
||||
|
||||
static MPPFileInfo *const kDeepLabModelFileInfo =
|
||||
[[MPPFileInfo alloc] initWithName:@"ptm_512_hdt_ptm_woid" type:@"tflite"];
|
||||
|
||||
static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
|
||||
|
||||
constexpr float kSimilarityThreshold = 0.97f;
|
||||
constexpr NSInteger kMagnificationFactor = 255;
|
||||
constexpr NSInteger kExpectedDeeplabV3ConfidenceMaskCount = 21;
|
||||
constexpr NSInteger kExpected128x128SelfieSegmentationConfidenceMaskCount = 2;
|
||||
constexpr NSInteger kExpected144x256SelfieSegmentationConfidenceMaskCount = 1;
|
||||
|
||||
// Asserts that `error` is non-nil and matches `expectedError` in domain, code and localized
// description. The body is wrapped in `do { ... } while (0)` so that the macro expands to a single
// statement (e.g. under an unbraced `if`), and the arguments are parenthesized so that arbitrary
// expressions can be passed safely.
#define AssertEqualErrors(error, expectedError)                                                 \
  do {                                                                                          \
    XCTAssertNotNil(error);                                                                     \
    XCTAssertEqualObjects((error).domain, (expectedError).domain);                              \
    XCTAssertEqual((error).code, (expectedError).code);                                         \
    XCTAssertEqualObjects((error).localizedDescription, (expectedError).localizedDescription); \
  } while (0)
|
||||
|
||||
namespace {

/** Returns the sum of all elements of `mask`, accumulated in double precision. */
double sum(const std::vector<float> &mask) {
  double total = 0.0;
  for (const float &maskElement : mask) {
    total += maskElement;
  }
  return total;
}

/**
 * Returns the element-wise product of the first `size` elements of `mask1` and `mask2`.
 * Both pointers must reference at least `size` readable floats.
 */
std::vector<float> multiply(const float *mask1, const float *mask2, size_t size) {
  std::vector<float> multipliedMask;
  multipliedMask.reserve(size);

  // The index is a `size_t` so that it is compared against `size` without a signed/unsigned
  // conversion and cannot overflow for buffers larger than `INT_MAX` elements.
  for (size_t i = 0; i < size; i++) {
    multipliedMask.push_back(mask1[i] * mask2[i]);
  }

  return multipliedMask;
}

/**
 * Computes the soft intersection-over-union of two masks of `size` elements:
 * sum(m1 * m2) / (sum(m1 * m1) + sum(m2 * m2) - sum(m1 * m2)).
 * Returns 0.0 when the denominator is not positive (e.g. both masks are all zeros or `size` is 0).
 */
double softIOU(const float *mask1, const float *mask2, size_t size) {
  std::vector<float> interSectionVector = multiply(mask1, mask2, size);
  double interSectionSum = sum(interSectionVector);

  std::vector<float> m1m1Vector = multiply(mask1, mask1, size);
  double m1m1 = sum(m1m1Vector);

  std::vector<float> m2m2Vector = multiply(mask2, mask2, size);
  double m2m2 = sum(m2m2Vector);

  double unionSum = m1m1 + m2m2 - interSectionSum;

  return unionSum > 0.0 ? interSectionSum / unionSum : 0.0;
}
}  // namespace
|
||||
|
||||
@interface MPPInteractiveSegmenterTests : XCTestCase
|
||||
@end
|
||||
|
||||
@implementation MPPInteractiveSegmenterTests
|
||||
|
||||
#pragma mark General Tests
|
||||
|
||||
- (void)setUp {
  // Let `XCTestCase` perform its own per-test setup before configuring this test case.
  [super setUp];

  // When expected and actual mask sizes are not equal, iterating through the mask data results in
  // a segmentation fault. Setting this property to `NO` stops each test case from executing the
  // remaining flow after a failure. Since the expected and actual mask sizes are compared before
  // iterating through them, this prevents any illegal memory access.
  self.continueAfterFailure = NO;
}
|
||||
|
||||
/**
 * Looks up a resource in the bundle that contains this test class.
 *
 * @param fileName Name of the resource, without its extension.
 * @param extension Extension (type) of the resource.
 * @return The value returned by `-[NSBundle pathForResource:ofType:]` for the test bundle.
 */
+ (NSString *)filePathWithName:(NSString *)fileName extension:(NSString *)extension {
  NSBundle *testBundle = [NSBundle bundleForClass:[MPPInteractiveSegmenterTests class]];
  return [testBundle pathForResource:fileName ofType:extension];
}
|
||||
|
||||
#pragma mark Image Mode Tests
|
||||
|
||||
/**
 * Segments the cats-and-dogs image around a single normalized key point with only the category
 * mask enabled, and checks the mask against the expected `dog1` mask image.
 */
- (void)testSegmentWithCategoryMaskSucceeds {
  MPPInteractiveSegmenterOptions *segmenterOptions =
      [self interactiveSegmenterOptionsWithModelFileInfo:kDeepLabModelFileInfo];
  segmenterOptions.shouldOutputConfidenceMasks = NO;
  segmenterOptions.shouldOutputCategoryMask = YES;

  MPPInteractiveSegmenter *segmenter =
      [self createInteractiveSegmenterWithOptionsSucceeds:segmenterOptions];

  // The region of interest is a single key point given in normalized image coordinates.
  MPPNormalizedKeypoint *keypoint =
      [[MPPNormalizedKeypoint alloc] initWithLocation:CGPointMake(0.44, 0.7)
                                                label:nil
                                                score:0.0f];
  MPPRegionOfInterest *keypointRegion =
      [[MPPRegionOfInterest alloc] initWithNormalizedKeyPoint:keypoint];

  [self assertResultsOfSegmentImageWithFileInfo:kCatsAndDogsImageFileInfo
                                                regionOfInterest:keypointRegion
                                       usingInteractiveSegmenter:segmenter
        approximatelyEqualsExpectedCategoryMaskImageWithFileInfo:kCatsAndDogsMaskImage1FileInfo
                                     withMaskSimilarityThreshold:0.84f
                                       shouldHaveConfidenceMasks:NO];
}
|
||||
|
||||
/**
 * Segments the cats-and-dogs image around a single normalized key point using the options'
 * default mask outputs, expects two confidence masks and no category mask, and compares the
 * confidence mask at index 1 against the expected `dog1` mask image.
 */
- (void)testSegmentWithConfidenceMaskSucceeds {
  // NOTE(review): this relies on the default values of `MPPInteractiveSegmenterOptions` enabling
  // confidence masks and disabling the category mask; confirm against the options class.
  MPPInteractiveSegmenterOptions *segmenterOptions =
      [self interactiveSegmenterOptionsWithModelFileInfo:kDeepLabModelFileInfo];

  MPPInteractiveSegmenter *segmenter =
      [self createInteractiveSegmenterWithOptionsSucceeds:segmenterOptions];

  // The region of interest is a single key point given in normalized image coordinates.
  MPPNormalizedKeypoint *keypoint =
      [[MPPNormalizedKeypoint alloc] initWithLocation:CGPointMake(0.44, 0.7)
                                                label:nil
                                                score:0.0f];
  MPPRegionOfInterest *keypointRegion =
      [[MPPRegionOfInterest alloc] initWithNormalizedKeyPoint:keypoint];

  [self assertResultsOfSegmentImageWithFileInfo:kCatsAndDogsImageFileInfo
                                                  regionOfInterest:keypointRegion
                                         usingInteractiveSegmenter:segmenter
                                           hasConfidenceMasksCount:2
        approximatelyEqualsExpectedConfidenceMaskImageWithFileInfo:kCatsAndDogsMaskImage1FileInfo
                                                           atIndex:1
                                       withMaskSimilarityThreshold:0.84f
                                            shouldHaveCategoryMask:NO];
}
|
||||
|
||||
#pragma mark - Interactive Segmenter Initializers
|
||||
|
||||
/**
 * Creates default interactive segmenter options whose model asset path is the path of the given
 * file info.
 */
- (MPPInteractiveSegmenterOptions *)interactiveSegmenterOptionsWithModelFileInfo:
    (MPPFileInfo *)fileInfo {
  MPPInteractiveSegmenterOptions *segmenterOptions = [[MPPInteractiveSegmenterOptions alloc] init];
  segmenterOptions.baseOptions.modelAssetPath = fileInfo.path;

  return segmenterOptions;
}
|
||||
|
||||
/**
 * Creates an interactive segmenter from `options`, asserting that creation returns a non-nil
 * instance and leaves the error unset.
 *
 * @return The created interactive segmenter.
 */
- (MPPInteractiveSegmenter *)createInteractiveSegmenterWithOptionsSucceeds:
    (MPPInteractiveSegmenterOptions *)options {
  NSError *creationError = nil;
  MPPInteractiveSegmenter *segmenter =
      [[MPPInteractiveSegmenter alloc] initWithOptions:options error:&creationError];

  XCTAssertNotNil(segmenter);
  XCTAssertNil(creationError);

  return segmenter;
}
|
||||
|
||||
/**
 * Asserts that creating an interactive segmenter from `options` returns nil and reports an error
 * equal to `expectedError` in domain, code and localized description.
 */
- (void)assertCreateInteractiveSegmenterWithOptions:(MPPInteractiveSegmenterOptions *)options
                             failsWithExpectedError:(NSError *)expectedError {
  NSError *creationError = nil;
  MPPInteractiveSegmenter *segmenter =
      [[MPPInteractiveSegmenter alloc] initWithOptions:options error:&creationError];

  XCTAssertNil(segmenter);
  AssertEqualErrors(creationError, expectedError);
}
|
||||
|
||||
#pragma mark Assert Segmenter Results
|
||||
/**
 * Segments the image described by `imageFileInfo` for `regionOfInterest`, then asserts that the
 * result has a category mask, that confidence masks are present or absent as requested, and that
 * the category mask approximately matches the expected mask image.
 */
- (void)assertResultsOfSegmentImageWithFileInfo:(MPPFileInfo *)imageFileInfo
                                                regionOfInterest:
                                                    (MPPRegionOfInterest *)regionOfInterest
                                       usingInteractiveSegmenter:
                                           (MPPInteractiveSegmenter *)interactiveSegmenter
        approximatelyEqualsExpectedCategoryMaskImageWithFileInfo:
            (MPPFileInfo *)expectedCategoryMaskFileInfo
                                     withMaskSimilarityThreshold:
                                         (const float)maskSimilarityThreshold
                                       shouldHaveConfidenceMasks:(BOOL)shouldHaveConfidenceMasks {
  MPPImageSegmenterResult *segmentationResult =
      [self segmentImageWithFileInfo:imageFileInfo
                    regionOfInterest:regionOfInterest
           usingInteractiveSegmenter:interactiveSegmenter];

  XCTAssertNotNil(segmentationResult.categoryMask);

  if (!shouldHaveConfidenceMasks) {
    XCTAssertNil(segmentationResult.confidenceMasks);
  } else {
    XCTAssertNotNil(segmentationResult.confidenceMasks);
  }

  [self assertCategoryMask:segmentationResult.categoryMask
      approximatelyEqualsExpectedCategoryMaskImageWithFileInfo:expectedCategoryMaskFileInfo
                                   withMaskSimilarityThreshold:maskSimilarityThreshold];
}
|
||||
|
||||
/**
 * Segments the image described by `imageFileInfo` for `regionOfInterest` and forwards the result
 * to the confidence-mask assertions: expected mask count, presence of the category mask, and
 * similarity of the confidence mask at `index` to the expected mask image.
 */
- (void)assertResultsOfSegmentImageWithFileInfo:(MPPFileInfo *)imageFileInfo
                                                  regionOfInterest:
                                                      (MPPRegionOfInterest *)regionOfInterest
                                         usingInteractiveSegmenter:
                                             (MPPInteractiveSegmenter *)interactiveSegmenter
                                           hasConfidenceMasksCount:
                                               (NSUInteger)expectedConfidenceMasksCount
        approximatelyEqualsExpectedConfidenceMaskImageWithFileInfo:
            (MPPFileInfo *)expectedConfidenceMaskFileInfo
                                                           atIndex:(NSInteger)index
                                       withMaskSimilarityThreshold:
                                           (const float)maskSimilarityThreshold
                                            shouldHaveCategoryMask:(BOOL)shouldHaveCategoryMask {
  MPPImageSegmenterResult *segmentationResult =
      [self segmentImageWithFileInfo:imageFileInfo
                    regionOfInterest:regionOfInterest
           usingInteractiveSegmenter:interactiveSegmenter];

  [self assertInteractiveSegmenterResult:segmentationResult
                                         hasConfidenceMasksCount:expectedConfidenceMasksCount
      approximatelyEqualsExpectedConfidenceMaskImageWithFileInfo:expectedConfidenceMaskFileInfo
                                                         atIndex:index
                                     withMaskSimilarityThreshold:maskSimilarityThreshold
                                          shouldHaveCategoryMask:shouldHaveCategoryMask];
}
|
||||
|
||||
/**
 * Asserts that `result` carries the expected number of confidence masks, that its category mask
 * is present or absent as requested, and that the confidence mask at `index` approximately
 * matches the expected mask image.
 *
 * The index is validated (non-negative and within bounds) before it is used as a subscript. With
 * `continueAfterFailure` set to `NO` in `-setUp`, a failed validation stops the test before the
 * array is accessed.
 */
- (void)assertInteractiveSegmenterResult:(MPPImageSegmenterResult *)result
                                       hasConfidenceMasksCount:
                                           (NSUInteger)expectedConfidenceMasksCount
    approximatelyEqualsExpectedConfidenceMaskImageWithFileInfo:
        (MPPFileInfo *)expectedConfidenceMaskFileInfo
                                                       atIndex:(NSInteger)index
                                   withMaskSimilarityThreshold:
                                       (const float)maskSimilarityThreshold
                                        shouldHaveCategoryMask:(BOOL)shouldHaveCategoryMask {
  XCTAssertNotNil(result.confidenceMasks);

  XCTAssertEqual(result.confidenceMasks.count, expectedConfidenceMasksCount);

  if (shouldHaveCategoryMask) {
    XCTAssertNotNil(result.categoryMask);
  } else {
    XCTAssertNil(result.categoryMask);
  }

  // Check the sign explicitly and compare as unsigned values, instead of relying on the implicit
  // signed-to-unsigned conversion of a mixed `NSInteger`/`NSUInteger` comparison.
  XCTAssertGreaterThanOrEqual(index, 0);
  const NSUInteger maskIndex = (NSUInteger)index;
  XCTAssertLessThan(maskIndex, result.confidenceMasks.count);

  [self assertConfidenceMask:result.confidenceMasks[maskIndex]
      approximatelyEqualsExpectedConfidenceMaskImageWithFileInfo:expectedConfidenceMaskFileInfo
                                     withMaskSimilarityThreshold:maskSimilarityThreshold];
}
|
||||
|
||||
/**
 * Loads the image described by `fileInfo` and segments it for `regionOfInterest` with the given
 * interactive segmenter, asserting that the image loads, that no error is reported and that a
 * result is returned.
 *
 * @return The segmentation result.
 */
- (MPPImageSegmenterResult *)segmentImageWithFileInfo:(MPPFileInfo *)fileInfo
                                     regionOfInterest:(MPPRegionOfInterest *)regionOfInterest
                            usingInteractiveSegmenter:
                                (MPPInteractiveSegmenter *)interactiveSegmenter {
  MPPImage *inputImage = [MPPImage imageWithFileInfo:fileInfo];
  XCTAssertNotNil(inputImage);

  NSError *segmentationError = nil;
  MPPImageSegmenterResult *segmentationResult =
      [interactiveSegmenter segmentImage:inputImage
                        regionOfInterest:regionOfInterest
                                   error:&segmentationError];

  XCTAssertNil(segmentationError);
  XCTAssertNotNil(segmentationResult);

  return segmentationResult;
}
|
||||
|
||||
/**
 * Asserts that `categoryMask` has the same dimensions as the expected mask image and that the
 * fraction of matching pixels is greater than `maskSimilarityThreshold`.
 *
 * A pixel matches when the category value multiplied by `kMagnificationFactor` equals the pixel
 * of the expected mask image.
 */
- (void)assertCategoryMask:(MPPMask *)categoryMask
    approximatelyEqualsExpectedCategoryMaskImageWithFileInfo:
        (MPPFileInfo *)expectedCategoryMaskImageFileInfo
                                 withMaskSimilarityThreshold:
                                     (const float)maskSimilarityThreshold {
  MPPMask *expectedCategoryMask =
      [[MPPMask alloc] initWithImageFileInfo:expectedCategoryMaskImageFileInfo];

  // The dimensions must match before the pixel buffers are walked. `continueAfterFailure` is `NO`
  // (see `-setUp`), so a mismatch stops the test here instead of reading out of bounds.
  XCTAssertEqual(categoryMask.width, expectedCategoryMask.width);
  XCTAssertEqual(categoryMask.height, expectedCategoryMask.height);

  size_t maskSize = categoryMask.width * categoryMask.height;

  const UInt8 *categoryMaskPixelData = categoryMask.uint8Data;
  const UInt8 *expectedCategoryMaskPixelData = expectedCategoryMask.uint8Data;

  NSInteger consistentPixels = 0;

  // The index is a `size_t` so that it is compared against `maskSize` without a signed/unsigned
  // conversion and cannot overflow for large masks.
  for (size_t i = 0; i < maskSize; i++) {
    if (categoryMaskPixelData[i] * kMagnificationFactor == expectedCategoryMaskPixelData[i]) {
      consistentPixels++;
    }
  }

  XCTAssertGreaterThan((float)consistentPixels / (float)maskSize, maskSimilarityThreshold);
}
|
||||
|
||||
/**
 * Asserts that `confidenceMask` has the same dimensions as the expected mask image and that the
 * soft IOU of their float data is greater than `maskSimilarityThreshold`.
 */
- (void)assertConfidenceMask:(MPPMask *)confidenceMask
    approximatelyEqualsExpectedConfidenceMaskImageWithFileInfo:
        (MPPFileInfo *)expectedConfidenceMaskImageFileInfo
                                   withMaskSimilarityThreshold:
                                       (const float)maskSimilarityThreshold {
  MPPMask *referenceMask =
      [[MPPMask alloc] initWithImageFileInfo:expectedConfidenceMaskImageFileInfo];

  // The dimensions must match before the float buffers are compared element by element.
  XCTAssertEqual(confidenceMask.width, referenceMask.width);
  XCTAssertEqual(confidenceMask.height, referenceMask.height);

  size_t elementCount = confidenceMask.width * confidenceMask.height;

  XCTAssertGreaterThan(
      softIOU(confidenceMask.float32Data, referenceMask.float32Data, elementCount),
      maskSimilarityThreshold);
}
|
||||
|
||||
@end
|
|
@ -25,12 +25,30 @@ objc_library(
|
|||
|
||||
# Objective-C++ implementation of the interactive segmenter task. Dependency labels are kept in
# sorted order.
objc_library(
    name = "MPPInteractiveSegmenter",
    srcs = ["sources/MPPInteractiveSegmenter.mm"],
    hdrs = ["sources/MPPInteractiveSegmenter.h"],
    copts = [
        "-ObjC++",
        "-std=c++17",
        "-x objective-c++",
    ],
    module_name = "MPPInteractiveSegmenter",
    deps = [
        ":MPPInteractiveSegmenterOptions",
        "//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator",
        "//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_cc_proto",
        "//mediapipe/tasks/cc/vision/interactive_segmenter:interactive_segmenter_graph",
        "//mediapipe/tasks/ios/common:MPPCommon",
        "//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
        "//mediapipe/tasks/ios/common/utils:NSStringHelpers",
        "//mediapipe/tasks/ios/components/containers:MPPRegionOfInterest",
        "//mediapipe/tasks/ios/core:MPPTaskInfo",
        "//mediapipe/tasks/ios/vision/core:MPPImage",
        "//mediapipe/tasks/ios/vision/core:MPPVisionPacketCreator",
        "//mediapipe/tasks/ios/vision/core:MPPVisionTaskRunner",
        "//mediapipe/tasks/ios/vision/image_segmenter:MPPImageSegmenterResult",
        "//mediapipe/tasks/ios/vision/image_segmenter/utils:MPPImageSegmenterResultHelpers",
        "//mediapipe/tasks/ios/vision/interactive_segmenter/utils:MPPInteractiveSegmenterOptionsHelpers",
        "//mediapipe/util:label_map_cc_proto",
    ],
)
|
||||
|
|
|
@ -0,0 +1,257 @@
|
|||
// Copyright 2023 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#import "mediapipe/tasks/ios/vision/interactive_segmenter/sources/MPPInteractiveSegmenter.h"
|
||||
|
||||
#import "mediapipe/tasks/ios/common/sources/MPPCommon.h"
|
||||
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
|
||||
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
|
||||
#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h"
|
||||
#import "mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.h"
|
||||
#import "mediapipe/tasks/ios/vision/core/sources/MPPVisionTaskRunner.h"
|
||||
#import "mediapipe/tasks/ios/vision/image_segmenter/utils/sources/MPPImageSegmenterResult+Helpers.h"
|
||||
#import "mediapipe/tasks/ios/vision/interactive_segmenter/utils/sources/MPPInteractiveSegmenterOptions+Helpers.h"
|
||||
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h"
|
||||
#include "mediapipe/util/label_map.pb.h"
|
||||
|
||||
static constexpr int kMicrosecondsPerMillisecond = 1000;
|
||||
|
||||
// Constants for the underlying MP Tasks Graph. See
|
||||
// https://github.com/google/mediapipe/tree/master/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_graph.cc
|
||||
static NSString *const kConfidenceMasksStreamName = @"confidence_masks";
|
||||
static NSString *const kConfidenceMasksTag = @"CONFIDENCE_MASKS";
|
||||
static NSString *const kCategoryMaskStreamName = @"category_mask";
|
||||
static NSString *const kCategoryMaskTag = @"CATEGORY_MASK";
|
||||
static NSString *const kQualityScoresStreamName = @"quality_scores";
|
||||
static NSString *const kQualityScoresTag = @"QUALITY_SCORES";
|
||||
static NSString *const kImageInStreamName = @"image_in";
|
||||
static NSString *const kImageOutStreamName = @"image_out";
|
||||
static NSString *const kImageTag = @"IMAGE";
|
||||
static NSString *const kNormRectStreamName = @"norm_rect_in";
|
||||
static NSString *const kNormRectTag = @"NORM_RECT";
|
||||
static NSString *const kRoiInStreamName = @"roi_in";
|
||||
static NSString *const kRoiTag = @"ROI";
|
||||
static NSString *const kTaskGraphName =
|
||||
@"mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph";
|
||||
static NSString *const kTaskName = @"interactiveSegmenter";
|
||||
|
||||
// Expands to a brace-enclosed initializer list that pairs the image input stream name with
// `imagePacket` and the normalized rect input stream name with `normalizedRectPacket`.
#define InputPacketMap(imagePacket, normalizedRectPacket)  \
  {                                                        \
    {kImageInStreamName.cppString, imagePacket},           \
    {kNormRectStreamName.cppString, normalizedRectPacket}  \
  }
|
||||
|
||||
namespace {
|
||||
using ::mediapipe::CalculatorGraphConfig;
|
||||
using ::mediapipe::Packet;
|
||||
using ::mediapipe::Timestamp;
|
||||
using ::mediapipe::tasks::TensorsToSegmentationCalculatorOptions;
|
||||
using ::mediapipe::tasks::core::PacketMap;
|
||||
using ::mediapipe::tasks::core::PacketsCallback;
|
||||
} // anonymous namespace
|
||||
|
||||
@interface MPPInteractiveSegmenter () {
|
||||
/** iOS Vision Task Runner */
|
||||
MPPVisionTaskRunner *_visionTaskRunner;
|
||||
}
|
||||
|
||||
@end
|
||||
|
||||
@implementation MPPInteractiveSegmenter
|
||||
|
||||
#pragma mark - Public
|
||||
|
||||
/**
 * Creates an interactive segmenter from `options`.
 *
 * Builds the task info for the interactive segmenter graph (image, normalized rect and ROI input
 * streams; quality scores and image output streams, plus the confidence and category mask streams
 * when enabled in `options`), creates an image-mode vision task runner with ROI allowed, and
 * populates the labels from the runner's graph config.
 *
 * @return The initialized segmenter, or nil if the task info, the task runner or the labels could
 * not be created; the failing step receives `error`.
 */
- (instancetype)initWithOptions:(MPPInteractiveSegmenterOptions *)options error:(NSError **)error {
  self = [super init];
  if (!self) {
    return self;
  }

  // Quality scores and the output image are always requested; the mask streams are optional.
  NSMutableArray<NSString *> *outputStreams = [@[
    [NSString stringWithFormat:@"%@:%@", kQualityScoresTag, kQualityScoresStreamName],
    [NSString stringWithFormat:@"%@:%@", kImageTag, kImageOutStreamName],
  ] mutableCopy];

  if (options.shouldOutputConfidenceMasks) {
    [outputStreams addObject:[NSString stringWithFormat:@"%@:%@", kConfidenceMasksTag,
                                                        kConfidenceMasksStreamName]];
  }
  if (options.shouldOutputCategoryMask) {
    [outputStreams addObject:[NSString stringWithFormat:@"%@:%@", kCategoryMaskTag,
                                                        kCategoryMaskStreamName]];
  }

  NSArray<NSString *> *inputStreams = @[
    [NSString stringWithFormat:@"%@:%@", kImageTag, kImageInStreamName],
    [NSString stringWithFormat:@"%@:%@", kNormRectTag, kNormRectStreamName],
    [NSString stringWithFormat:@"%@:%@", kRoiTag, kRoiInStreamName],
  ];

  MPPTaskInfo *taskInfo = [[MPPTaskInfo alloc] initWithTaskGraphName:kTaskGraphName
                                                        inputStreams:inputStreams
                                                       outputStreams:outputStreams
                                                         taskOptions:options
                                                  enableFlowLimiting:NO
                                                               error:error];
  if (!taskInfo) {
    return nil;
  }

  // The runner is created in image mode with a null packets callback.
  PacketsCallback noPacketsCallback = nullptr;
  _visionTaskRunner = [[MPPVisionTaskRunner alloc] initWithTaskInfo:taskInfo
                                                        runningMode:MPPRunningModeImage
                                                         roiAllowed:YES
                                                    packetsCallback:std::move(noPacketsCallback)
                                               imageInputStreamName:kImageInStreamName
                                            normRectInputStreamName:kNormRectStreamName
                                                              error:error];
  if (!_visionTaskRunner) {
    return nil;
  }

  _labels = [MPPInteractiveSegmenter populateLabelsWithGraphConfig:_visionTaskRunner.graphConfig
                                                             error:error];
  if (!_labels) {
    return nil;
  }

  return self;
}
|
||||
|
||||
/**
 * Convenience initializer: creates default options whose model asset path is `modelPath` and
 * forwards them to `-initWithOptions:error:`.
 */
- (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error {
  MPPInteractiveSegmenterOptions *defaultOptions = [[MPPInteractiveSegmenterOptions alloc] init];
  defaultOptions.baseOptions.modelAssetPath = modelPath;

  return [self initWithOptions:defaultOptions error:error];
}
|
||||
|
||||
/**
 * Segments `image` for `regionOfInterest` and returns the result. The output mask packet data is
 * requested to be copied into the result.
 *
 * @return The segmentation result, or nil on failure (with `error` passed through to the
 * underlying call).
 */
- (nullable MPPImageSegmenterResult *)segmentImage:(MPPImage *)image
                                  regionOfInterest:(MPPRegionOfInterest *)regionOfInterest
                                             error:(NSError **)error {
  MPPImageSegmenterResult *segmentationResult = [self segmentImage:image
                                                  regionOfInterest:regionOfInterest
                                    shouldCopyOutputMaskPacketData:YES
                                                             error:error];
  return segmentationResult;
}
|
||||
|
||||
/**
 * Segments `image` for `regionOfInterest` and hands the result and error to `completionHandler`
 * before returning. The output mask packet data is not copied for this variant.
 */
- (void)segmentImage:(MPPImage *)image
         regionOfInterest:(MPPRegionOfInterest *)regionOfInterest
    withCompletionHandler:(void (^)(MPPImageSegmenterResult *_Nullable result,
                                    NSError *_Nullable error))completionHandler {
  NSError *segmentationError = nil;
  MPPImageSegmenterResult *segmentationResult = [self segmentImage:image
                                                  regionOfInterest:regionOfInterest
                                    shouldCopyOutputMaskPacketData:NO
                                                             error:&segmentationError];

  // The handler is invoked synchronously, on the calling thread, with whatever the call produced.
  completionHandler(segmentationResult, segmentationError);
}
|
||||
|
||||
#pragma mark - Private
|
||||
|
||||
/**
 * Extracts the label names from the options of the graph's
 * `mediapipe.tasks.TensorsToSegmentationCalculator` node.
 *
 * Labels are read from the node's `label_items` map for keys 0..N-1, where N is the map size.
 *
 * @param graphConfig The calculator graph config to inspect.
 * @param error Populated with a failed-precondition error when the graph has more than one
 * `TensorsToSegmentationCalculator` node or when the label map is missing an expected key.
 *
 * @return The labels in key order (empty when no labels are found), or nil on error.
 */
+ (NSArray<NSString *> *)populateLabelsWithGraphConfig:(const CalculatorGraphConfig &)graphConfig
                                                 error:(NSError **)error {
  bool foundTensorsToSegmentationCalculator = false;

  NSMutableArray<NSString *> *labels =
      [NSMutableArray arrayWithCapacity:(NSUInteger)graphConfig.node_size()];

  for (const auto &node : graphConfig.node()) {
    if (node.calculator() != "mediapipe.tasks.TensorsToSegmentationCalculator") {
      continue;
    }

    if (foundTensorsToSegmentationCalculator) {
      [MPPCommonUtils createCustomError:error
                               withCode:MPPTasksErrorCodeFailedPreconditionError
                            description:@"The graph has more than one "
                                        @"`mediapipe.tasks.TensorsToSegmentationCalculator`."];
      return nil;
    }
    foundTensorsToSegmentationCalculator = true;

    // Bind by reference to avoid copying the options proto.
    const TensorsToSegmentationCalculatorOptions &options =
        node.options().GetExtension(TensorsToSegmentationCalculatorOptions::ext);
    const auto &labelItems = options.label_items();

    // The keys are expected to be exactly 0..N-1. A single `find` per key replaces the previous
    // `contains` + `at` double lookup. An empty map simply skips the loop.
    for (int i = 0; i < options.label_items_size(); ++i) {
      const auto labelItem = labelItems.find(i);
      if (labelItem == labelItems.end()) {
        [MPPCommonUtils
            createCustomError:error
                     withCode:MPPTasksErrorCodeFailedPreconditionError
                  description:[NSString
                                  stringWithFormat:@"The label map has no expected key %d.", i]];
        return nil;
      }
      [labels addObject:[NSString stringWithCppString:labelItem->second.name()]];
    }
  }

  return labels;
}
|
||||
|
||||
/**
 * Shared implementation of the public segment methods.
 *
 * Builds the image/norm-rect input packets through the vision task runner, adds the render data
 * packet created from `regionOfInterest` on the ROI input stream, runs the graph and converts the
 * output packets into an `MPPImageSegmenterResult`.
 *
 * @param shouldCopyMaskPacketData Forwarded to the result creation.
 * @return The segmentation result, or nil when an input packet could not be created or the graph
 * produced no output packet map.
 */
- (nullable MPPImageSegmenterResult *)segmentImage:(MPPImage *)image
                                  regionOfInterest:(MPPRegionOfInterest *)regionOfInterest
                    shouldCopyOutputMaskPacketData:(BOOL)shouldCopyMaskPacketData
                                             error:(NSError **)error {
  // NOTE(review): `CGRectZero` is passed as the rectangular region of interest; presumably this
  // means "no rectangular ROI" for the task runner — confirm against `MPPVisionTaskRunner`.
  std::optional<PacketMap> graphInputs = [_visionTaskRunner inputPacketMapWithMPPImage:image
                                                                      regionOfInterest:CGRectZero
                                                                                 error:error];
  if (!graphInputs) {
    return nil;
  }

  std::optional<Packet> roiPacket =
      [MPPVisionPacketCreator createRenderDataPacketWithRegionOfInterest:regionOfInterest
                                                                   error:error];
  if (!roiPacket) {
    return nil;
  }

  graphInputs->insert({kRoiInStreamName.cppString, *roiPacket});

  std::optional<PacketMap> graphOutputs = [_visionTaskRunner processPacketMap:*graphInputs
                                                                        error:error];

  return [MPPInteractiveSegmenter
      imageSegmenterResultWithOptionalOutputPacketMap:graphOutputs
                             shouldCopyMaskPacketData:shouldCopyMaskPacketData];
}
|
||||
|
||||
/**
 * Converts the optional output packet map of the graph into an `MPPImageSegmenterResult`.
 *
 * The confidence masks, category mask and quality scores packets are looked up by stream name.
 * The timestamp in milliseconds is derived from the timestamp of the output image packet.
 *
 * @return The result, or nil when `outputPacketMap` has no value.
 */
+ (nullable MPPImageSegmenterResult *)
    imageSegmenterResultWithOptionalOutputPacketMap:(std::optional<PacketMap> &)outputPacketMap
                           shouldCopyMaskPacketData:(BOOL)shouldCopyMaskPacketData {
  if (!outputPacketMap.has_value()) {
    return nil;
  }

  PacketMap &packets = outputPacketMap.value();

  // `operator[]` default-constructs an entry for streams that were not requested as outputs.
  Packet &confidenceMasksPacket = packets[kConfidenceMasksStreamName.cppString];
  Packet &categoryMaskPacket = packets[kCategoryMaskStreamName.cppString];
  Packet &qualityScoresPacket = packets[kQualityScoresStreamName.cppString];

  // Packet timestamps are divided by `kMicrosecondsPerMillisecond` to obtain milliseconds.
  const auto timestampInMilliseconds =
      packets[kImageOutStreamName.cppString].Timestamp().Value() / kMicrosecondsPerMillisecond;

  return [MPPImageSegmenterResult
      imageSegmenterResultWithConfidenceMasksPacket:confidenceMasksPacket
                                 categoryMaskPacket:categoryMaskPacket
                                qualityScoresPacket:qualityScoresPacket
                            timestampInMilliseconds:timestampInMilliseconds
                           shouldCopyMaskPacketData:shouldCopyMaskPacketData];
}
|
||||
|
||||
@end
|
31
mediapipe/tasks/ios/vision/interactive_segmenter/utils/BUILD
Normal file
31
mediapipe/tasks/ios/vision/interactive_segmenter/utils/BUILD
Normal file
|
@ -0,0 +1,31 @@
|
|||
# Copyright 2023 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
objc_library(
|
||||
name = "MPPInteractiveSegmenterOptionsHelpers",
|
||||
srcs = ["sources/MPPInteractiveSegmenterOptions+Helpers.mm"],
|
||||
hdrs = ["sources/MPPInteractiveSegmenterOptions+Helpers.h"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
|
||||
"//mediapipe/tasks/ios/core:MPPTaskOptionsProtocol",
|
||||
"//mediapipe/tasks/ios/core/utils:MPPBaseOptionsHelpers",
|
||||
"//mediapipe/tasks/ios/vision/interactive_segmenter:MPPInteractiveSegmenterOptions",
|
||||
],
|
||||
)
|
|
@ -0,0 +1,32 @@
|
|||
// Copyright 2023 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/framework/calculator_options.pb.h"
|
||||
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h"
|
||||
#import "mediapipe/tasks/ios/vision/interactive_segmenter/sources/MPPInteractiveSegmenterOptions.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
@interface MPPInteractiveSegmenterOptions (Helpers) <MPPTaskOptionsProtocol>
|
||||
|
||||
/**
|
||||
* Populates the provided `CalculatorOptions` proto container with the current settings.
|
||||
*
|
||||
* @param optionsProto The `CalculatorOptions` proto object to copy the settings to.
|
||||
*/
|
||||
- (void)copyToProto:(::mediapipe::CalculatorOptions *)optionsProto;
|
||||
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
|
@ -0,0 +1,41 @@
|
|||
// Copyright 2023 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#import "mediapipe/tasks/ios/vision/interactive_segmenter/utils/sources/MPPInteractiveSegmenterOptions+Helpers.h"
|
||||
|
||||
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
|
||||
#import "mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h"
|
||||
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
|
||||
|
||||
namespace {
|
||||
using CalculatorOptionsProto = ::mediapipe::CalculatorOptions;
|
||||
using ImageSegmenterGraphOptionsProto =
|
||||
::mediapipe::tasks::vision::image_segmenter::proto::ImageSegmenterGraphOptions;
|
||||
using SegmenterOptionsProto = ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions;
|
||||
} // namespace
|
||||
|
||||
@implementation MPPInteractiveSegmenterOptions (Helpers)
|
||||
|
||||
/**
 * Populates the `ImageSegmenterGraphOptions` extension of `optionsProto` with the current
 * settings: the extension is cleared, the base options are copied without stream mode, and the
 * display names locale is set when one is configured.
 */
- (void)copyToProto:(CalculatorOptionsProto *)optionsProto {
  ImageSegmenterGraphOptionsProto *graphOptionsProto =
      optionsProto->MutableExtension(ImageSegmenterGraphOptionsProto::ext);
  graphOptionsProto->Clear();

  [self.baseOptions copyToProto:graphOptionsProto->mutable_base_options() withUseStreamMode:NO];

  // Only read `cppString` from a non-nil string: messaging nil would have to return a C++
  // `std::string`, which is not a safe value to rely on. When no locale is configured the proto
  // field is left unset.
  NSString *displayNamesLocale = self.displayNamesLocale;
  if (displayNamesLocale) {
    graphOptionsProto->set_display_names_locale(displayNamesLocale.cppString);
  }
}
|
||||
|
||||
@end
|
Loading…
Reference in New Issue
Block a user