diff --git a/mediapipe/tasks/ios/test/text/text_embedder/BUILD b/mediapipe/tasks/ios/test/text/text_embedder/BUILD new file mode 100644 index 000000000..04359cf9a --- /dev/null +++ b/mediapipe/tasks/ios/test/text/text_embedder/BUILD @@ -0,0 +1,55 @@ +load( + "@build_bazel_rules_apple//apple:ios.bzl", + "ios_unit_test", +) +load( + "@build_bazel_rules_swift//swift:swift.bzl", + "swift_library", +) +load( + "//mediapipe/tasks:ios/ios.bzl", + "MPP_TASK_MINIMUM_OS_VERSION", +) +load( + "@org_tensorflow//tensorflow/lite:special_rules.bzl", + "tflite_ios_lab_runner", +) + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +# Default tags for filtering iOS targets. Targets are restricted to Apple platforms. +TFL_DEFAULT_TAGS = [ + "apple", +] + +# Following sanitizer tests are not supported by iOS test targets. +TFL_DISABLED_SANITIZER_TAGS = [ + "noasan", + "nomsan", + "notsan", +] + +objc_library( + name = "MPPTextEmbedderObjcTestLibrary", + testonly = 1, + srcs = ["MPPTextEmbedderTests.m"], + data = [ + "//mediapipe/tasks/testdata/text:mobilebert_embedding_model", + "//mediapipe/tasks/testdata/text:regex_embedding_with_metadata", + ], + deps = [ + "//mediapipe/tasks/ios/common:MPPCommon", + "//mediapipe/tasks/ios/text/text_embedder:MPPTextEmbedder", + ], +) + +ios_unit_test( + name = "MPPTextEmbedderObjcTest", + minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION, + runner = tflite_ios_lab_runner("IOS_LATEST"), + deps = [ + ":MPPTextEmbedderObjcTestLibrary", + ], +) diff --git a/mediapipe/tasks/ios/test/text/text_embedder/MPPTextEmbedderTests.m b/mediapipe/tasks/ios/test/text/text_embedder/MPPTextEmbedderTests.m new file mode 100644 index 000000000..c58c52298 --- /dev/null +++ b/mediapipe/tasks/ios/test/text/text_embedder/MPPTextEmbedderTests.m @@ -0,0 +1,142 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +#import "mediapipe/tasks/ios/common/sources/MPPCommon.h" +#import "mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h" + +static NSString *const kBertTextEmbedderModelName = @"mobilebert_embedding_with_metadata"; +static NSString *const kRegexTextEmbedderModelName = @"regex_one_embedding_with_metadata"; +static NSString *const kText1 = @"it's a charming and often affecting journey"; +static NSString *const kText2 = @"what a great and fantastic trip"; +static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; +static const float kFloatDiffTolerance = 1e-4; +static const float kDoubleDiffTolerance = 1e-4; + +#define AssertEqualErrors(error, expectedError) \ + XCTAssertNotNil(error); \ + XCTAssertEqualObjects(error.domain, expectedError.domain); \ + XCTAssertEqual(error.code, expectedError.code); \ + XCTAssertNotEqual( \ + [error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \ + NSNotFound) + +#define AssertTextEmbedderResultHasOneEmbedding(textEmbedderResult) \ + XCTAssertNotNil(textEmbedderResult); \ + XCTAssertNotNil(textEmbedderResult.embeddingResult); \ + XCTAssertEqual(textEmbedderResult.embeddingResult.embeddings.count, 1); + +#define AssertEmbeddingIsFloat(embedding) \ + XCTAssertNotNil(embedding.floatEmbedding); \ + XCTAssertNil(embedding.quantizedEmbedding); + +#define AssertFloatEmbeddingHasExpectedValues(floatEmbedding, expectedLength, expectedFirstValue) \ + XCTAssertEqual(floatEmbedding.count, expectedLength); \ + XCTAssertEqualWithAccuracy(floatEmbedding[0].floatValue, expectedFirstValue, kFloatDiffTolerance); + +@interface MPPTextEmbedderTests : XCTestCase +@end + +@implementation MPPTextEmbedderTests + +- (NSString *)filePathWithName:(NSString *)fileName extension:(NSString *)extension { + NSString *filePath = [[NSBundle bundleForClass:self.class] pathForResource:fileName + ofType:extension]; + return filePath; +} + +- (MPPTextEmbedder *)textEmbedderFromModelFileWithName:(NSString *)modelName { + NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"]; + + NSError *error = nil; + MPPTextEmbedder *textEmbedder = [[MPPTextEmbedder alloc] initWithModelPath:modelPath + error:&error]; + + XCTAssertNotNil(textEmbedder); + + return textEmbedder; +} + +- (NSArray *)assertFloatEmbeddingResultsOfEmbedText:(NSString *)text + usingTextEmbedder:(MPPTextEmbedder *)textEmbedder + hasCount:(NSUInteger)embeddingCount + firstValue:(float)firstValue { + MPPTextEmbedderResult *embedderResult = [textEmbedder embedText:text error:nil]; + AssertTextEmbedderResultHasOneEmbedding(embedderResult); + AssertEmbeddingIsFloat(embedderResult.embeddingResult.embeddings[0]); + AssertFloatEmbeddingHasExpectedValues(embedderResult.embeddingResult.embeddings[0].floatEmbedding, + embeddingCount, firstValue); + return embedderResult.embeddingResult.embeddings[0]; +} + +- (void)testCreateTextEmbedderFailsWithMissingModelPath { + NSString *modelPath = [self filePathWithName:@"" extension:@""]; + + NSError *error = nil; + MPPTextEmbedder *textEmbedder = [[MPPTextEmbedder alloc] initWithModelPath:modelPath + error:&error]; + XCTAssertNil(textEmbedder); + + NSError *expectedError = [NSError + errorWithDomain:kExpectedErrorDomain + code:MPPTasksErrorCodeInvalidArgumentError + userInfo:@{ + NSLocalizedDescriptionKey : + @"INVALID_ARGUMENT: ExternalFile must specify at least one of 'file_content', " + @"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'." + }]; + AssertEqualErrors(error, expectedError); +} + +- (void)testEmbedWithBertSucceeds { + MPPTextEmbedder *textEmbedder = + [self textEmbedderFromModelFileWithName:kBertTextEmbedderModelName]; + + MPPEmbedding *embedding1 = [self assertFloatEmbeddingResultsOfEmbedText:kText1 + usingTextEmbedder:textEmbedder + hasCount:512 + firstValue:20.057026f]; + + MPPEmbedding *embedding2 = [self assertFloatEmbeddingResultsOfEmbedText:kText2 + usingTextEmbedder:textEmbedder + hasCount:512 + firstValue:21.254150f]; + NSNumber *cosineSimilarity = [MPPTextEmbedder cosineSimilarityBetweenEmbedding1:embedding1 + andEmbedding2:embedding2 + error:nil]; + XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.96386, kDoubleDiffTolerance); +} + +- (void)testEmbedWithRegexSucceeds { + MPPTextEmbedder *textEmbedder = + [self textEmbedderFromModelFileWithName:kRegexTextEmbedderModelName]; + + MPPEmbedding *embedding1 = [self assertFloatEmbeddingResultsOfEmbedText:kText1 + usingTextEmbedder:textEmbedder + hasCount:16 + firstValue:0.030935612f]; + + MPPEmbedding *embedding2 = [self assertFloatEmbeddingResultsOfEmbedText:kText2 + usingTextEmbedder:textEmbedder + hasCount:16 + firstValue:0.0312863f]; + + NSNumber *cosineSimilarity = [MPPTextEmbedder cosineSimilarityBetweenEmbedding1:embedding1 + andEmbedding2:embedding2 + error:nil]; + XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.999937f, kDoubleDiffTolerance); +} + +@end