From 84e1c93ffbd6c43c84d229e5235e81159c5e8a25 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 2 Feb 2023 17:22:56 +0530 Subject: [PATCH 01/11] Added MPPCosineSimilarity --- mediapipe/tasks/ios/components/utils/BUILD | 33 +++++++ .../utils/sources/MPPCosineSimilarity.h | 48 ++++++++++ .../utils/sources/MPPCosineSimilarity.mm | 89 +++++++++++++++++++ 3 files changed, 170 insertions(+) create mode 100644 mediapipe/tasks/ios/components/utils/BUILD create mode 100644 mediapipe/tasks/ios/components/utils/sources/MPPCosineSimilarity.h create mode 100644 mediapipe/tasks/ios/components/utils/sources/MPPCosineSimilarity.mm diff --git a/mediapipe/tasks/ios/components/utils/BUILD b/mediapipe/tasks/ios/components/utils/BUILD new file mode 100644 index 000000000..c9f82d1d1 --- /dev/null +++ b/mediapipe/tasks/ios/components/utils/BUILD @@ -0,0 +1,33 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPCosineSimilarity", + srcs = ["sources/MPPCosineSimilarity.mm"], + hdrs = ["sources/MPPCosineSimilarity.h"], + copts = [ + "-ObjC++", + "-std=c++17", + "-x objective-c++", + ], + deps = [ + "//mediapipe/tasks/ios/common:MPPCommon", + "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", + "//mediapipe/tasks/ios/components/containers:MPPEmbedding", + ] +) diff --git a/mediapipe/tasks/ios/components/utils/sources/MPPCosineSimilarity.h b/mediapipe/tasks/ios/components/utils/sources/MPPCosineSimilarity.h new file mode 100644 index 000000000..864baf169 --- /dev/null +++ b/mediapipe/tasks/ios/components/utils/sources/MPPCosineSimilarity.h @@ -0,0 +1,48 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +#import "mediapipe/tasks/ios/components/containers/sources/MPPEmbedding.h" + +NS_ASSUME_NONNULL_BEGIN + +/** Utility class for computing cosine similarity between `MPPEmbedding` objects. */ +NS_SWIFT_NAME(CosineSimilarity) + +@interface MPPCosineSimilarity : NSObject + +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; + +/** Utility function to compute[cosine similarity](https://en.wikipedia.org/wiki/Cosine_similarity) + * between two `MPPEmbedding` objects. + * + * @param embedding1 One of the two `MPPEmbedding`s between whom cosine similarity is to be + * computed. + * @param embedding2 One of the two `MPPEmbedding`s between whom cosine similarity is to be + * computed. + * @param error An optional error parameter populated when there is an error in calculating cosine + * similarity between two embeddings. + * + * @return An `NSNumber` which holds the cosine similarity of type `double`. + */ ++ (nullable NSNumber *)computeBetweenEmbedding1:(MPPEmbedding *)embedding1 + andEmbedding2:(MPPEmbedding *)embedding2 + error:(NSError **)error; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/components/utils/sources/MPPCosineSimilarity.mm b/mediapipe/tasks/ios/components/utils/sources/MPPCosineSimilarity.mm new file mode 100644 index 000000000..dfbc54e01 --- /dev/null +++ b/mediapipe/tasks/ios/components/utils/sources/MPPCosineSimilarity.mm @@ -0,0 +1,89 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/components/utils/sources/MPPCosineSimilarity.h" + +#import "mediapipe/tasks/ios/common/sources/MPPCommon.h" +#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" + +#include + +@implementation MPPCosineSimilarity + ++ (nullable NSNumber *)computeBetweenVector1:(NSArray *)u + andVector2:(NSArray *)v + isFloat:(BOOL)isFloat + error:(NSError **)error { + if (u.count != v.count) { + [MPPCommonUtils + createCustomError:error + withCode:MPPTasksErrorCodeInvalidArgumentError + description:[NSString stringWithFormat:@"Cannot compute cosine similarity between " + @"embeddings of different sizes (%d vs %d)", + u.count, v.count]]; + return nil; + } + + __block double dotProduct = 0.0; + __block double normU = 0.0; + __block double normV = 0.0; + + [u enumerateObjectsUsingBlock:^(NSNumber *num, NSUInteger idx, BOOL *stop) { + double uVal = 0.0; + double vVal = 0.0; + + if (isFloat) { + uVal = num.floatValue; + vVal = v[idx].floatValue; + } else { + uVal = num.charValue; + vVal = v[idx].charValue; + } + + dotProduct += uVal * vVal; + normU += uVal * uVal; + normV += vVal * vVal; + }]; + + return [NSNumber numberWithDouble:dotProduct / sqrt(normU * normV)]; +} + ++ (nullable NSNumber *)computeBetweenEmbedding1:(MPPEmbedding *)embedding1 + andEmbedding2:(MPPEmbedding *)embedding2 + error:(NSError **)error { + BOOL isFloat; + + if (embedding1.floatEmbedding && embedding2.floatEmbedding) { + return [MPPCosineSimilarity computeBetweenVector1:embedding1.floatEmbedding + andVector2:embedding2.floatEmbedding + isFloat:YES + error:error]; + } + + if (embedding1.quantizedEmbedding && embedding2.quantizedEmbedding) { + return [MPPCosineSimilarity computeBetweenVector1:embedding1.quantizedEmbedding + andVector2:embedding2.quantizedEmbedding + isFloat:NO + error:error]; + } + + [MPPCommonUtils + createCustomError:error + withCode:MPPTasksErrorCodeInvalidArgumentError + description: + @"Cannot compute cosine similarity between quantized and float embeddings."]; + return nil; +} + +@end From 867520af1c0c56d3a02987e110a733f6aaeca263 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 2 Feb 2023 17:29:51 +0530 Subject: [PATCH 02/11] Added cosine similarity to MPPTextEmbedder --- mediapipe/tasks/ios/text/text_embedder/BUILD | 1 + .../text_embedder/sources/MPPTextEmbedder.h | 21 +++++++++++++++++-- .../text_embedder/sources/MPPTextEmbedder.mm | 9 ++++++++ 3 files changed, 29 insertions(+), 2 deletions(-) diff --git a/mediapipe/tasks/ios/text/text_embedder/BUILD b/mediapipe/tasks/ios/text/text_embedder/BUILD index 21226b012..b02b1a9b5 100644 --- a/mediapipe/tasks/ios/text/text_embedder/BUILD +++ b/mediapipe/tasks/ios/text/text_embedder/BUILD @@ -49,6 +49,7 @@ objc_library( "//mediapipe/tasks/cc/text/text_embedder:text_embedder_graph", "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + "//mediapipe/tasks/ios/components/utils:MPPCosineSimilarity", "//mediapipe/tasks/ios/core:MPPTaskInfo", "//mediapipe/tasks/ios/core:MPPTaskOptions", "//mediapipe/tasks/ios/core:MPPTextPacketCreator", diff --git a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h index a45ab6747..ba5958a72 100644 --- a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h +++ b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h @@ -29,7 +29,7 @@ NS_ASSUME_NONNULL_BEGIN * Metadata is required for models with int32 input tensors because it contains the input process * unit for the model's Tokenizer. No metadata is required for models with string input tensors. * - * Input tensors + * Input tensors: * - Three input tensors `kTfLiteInt32` of shape `[batch_size x bert_max_seq_len]` * representing the input ids, mask ids, and segment ids. This input signature requires * a Bert Tokenizer process unit in the model metadata. @@ -62,7 +62,7 @@ NS_SWIFT_NAME(TextEmbedder) * Creates a new instance of `MPPTextEmbedder` from the given `MPPTextEmbedderOptions`. * * @param options The options of type `MPPTextEmbedderOptions` to use for configuring the - * `MPPTextEmbedder. + * `MPPTextEmbedder`. * @param error An optional error parameter populated when there is an error in initializing the * text embedder. * @@ -86,6 +86,23 @@ NS_SWIFT_NAME(TextEmbedder) - (instancetype)init NS_UNAVAILABLE; +/** Utility function to compute[cosine similarity](https://en.wikipedia.org/wiki/Cosine_similarity) + * between two `MPPEmbedding` objects. + * + * @param embedding1 One of the two `MPPEmbedding`s between whom cosine similarity is to be + * computed. + * @param embedding2 One of the two `MPPEmbedding`s between whom cosine similarity is to be + * computed. + * @param error An optional error parameter populated when there is an error in calculating cosine + * similarity between two embeddings. + * + * @return An `NSNumber` which holds the cosine similarity of type `double`. + */ ++ (nullable NSNumber *)cosineSimilarityBetweenEmbedding1:(MPPEmbedding *)embedding1 + andEmbedding2:(MPPEmbedding *)embedding2 + error:(NSError **)error + NS_SWIFT_NAME(cosineSimilarity(embedding1: embedding2:)); + + (instancetype)new NS_UNAVAILABLE; @end diff --git a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.mm b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.mm index a9c811cdb..62eb882d3 100644 --- a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.mm +++ b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.mm @@ -16,6 +16,7 @@ #import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" #import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" +#import "mediapipe/tasks/ios/components/utils/sources/MPPCosineSimilarity.h" #import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h" #import "mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h" #import "mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h" @@ -93,4 +94,12 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_embedder.Tex .value()[kEmbeddingsOutStreamName.cppString]]; } ++ (nullable NSNumber *)cosineSimilarityBetweenEmbedding1:(MPPEmbedding *)embedding1 + andEmbedding2:(MPPEmbedding *)embedding2 + error:(NSError **)error { + return [MPPCosineSimilarity computeBetweenEmbedding1:embedding1 + andEmbedding2:embedding2 + error:error]; +} + @end From 474e994a5f95e8190ab1be93c20eec493a81edbe Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 2 Feb 2023 17:30:05 +0530 Subject: [PATCH 03/11] Added text embedder objective c tests --- .../tasks/ios/test/text/text_embedder/BUILD | 55 +++++++ .../text/text_embedder/MPPTextEmbedderTests.m | 142 ++++++++++++++++++ 2 files changed, 197 insertions(+) create mode 100644 mediapipe/tasks/ios/test/text/text_embedder/BUILD create mode 100644 mediapipe/tasks/ios/test/text/text_embedder/MPPTextEmbedderTests.m diff --git a/mediapipe/tasks/ios/test/text/text_embedder/BUILD b/mediapipe/tasks/ios/test/text/text_embedder/BUILD new file mode 100644 index 000000000..04359cf9a --- /dev/null +++ b/mediapipe/tasks/ios/test/text/text_embedder/BUILD @@ -0,0 +1,55 @@ +load( + "@build_bazel_rules_apple//apple:ios.bzl", + "ios_unit_test", +) +load( + "@build_bazel_rules_swift//swift:swift.bzl", + "swift_library", +) +load( + "//mediapipe/tasks:ios/ios.bzl", + "MPP_TASK_MINIMUM_OS_VERSION", +) +load( + "@org_tensorflow//tensorflow/lite:special_rules.bzl", + "tflite_ios_lab_runner", +) + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +# Default tags for filtering iOS targets. Targets are restricted to Apple platforms. +TFL_DEFAULT_TAGS = [ + "apple", +] + +# Following sanitizer tests are not supported by iOS test targets. +TFL_DISABLED_SANITIZER_TAGS = [ + "noasan", + "nomsan", + "notsan", +] + +objc_library( + name = "MPPTextEmbedderObjcTestLibrary", + testonly = 1, + srcs = ["MPPTextEmbedderTests.m"], + data = [ + "//mediapipe/tasks/testdata/text:mobilebert_embedding_model", + "//mediapipe/tasks/testdata/text:regex_embedding_with_metadata", + ], + deps = [ + "//mediapipe/tasks/ios/common:MPPCommon", + "//mediapipe/tasks/ios/text/text_embedder:MPPTextEmbedder", + ], +) + +ios_unit_test( + name = "MPPTextEmbedderObjcTest", + minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION, + runner = tflite_ios_lab_runner("IOS_LATEST"), + deps = [ + ":MPPTextEmbedderObjcTestLibrary", + ], +) diff --git a/mediapipe/tasks/ios/test/text/text_embedder/MPPTextEmbedderTests.m b/mediapipe/tasks/ios/test/text/text_embedder/MPPTextEmbedderTests.m new file mode 100644 index 000000000..c58c52298 --- /dev/null +++ b/mediapipe/tasks/ios/test/text/text_embedder/MPPTextEmbedderTests.m @@ -0,0 +1,142 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +#import "mediapipe/tasks/ios/common/sources/MPPCommon.h" +#import "mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h" + +static NSString *const kBertTextEmbedderModelName = @"mobilebert_embedding_with_metadata"; +static NSString *const kRegexTextEmbedderModelName = @"regex_one_embedding_with_metadata"; +static NSString *const kText1 = @"it's a charming and often affecting journey"; +static NSString *const kText2 = @"what a great and fantastic trip"; +static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; +static const float kFloatDiffTolerance = 1e-4; +static const float kDoubleDiffTolerance = 1e-4; + +#define AssertEqualErrors(error, expectedError) \ + XCTAssertNotNil(error); \ + XCTAssertEqualObjects(error.domain, expectedError.domain); \ + XCTAssertEqual(error.code, expectedError.code); \ + XCTAssertNotEqual( \ + [error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \ + NSNotFound) + +#define AssertTextEmbedderResultHasOneEmbedding(textEmbedderResult) \ + XCTAssertNotNil(textEmbedderResult); \ + XCTAssertNotNil(textEmbedderResult.embeddingResult); \ + XCTAssertEqual(textEmbedderResult.embeddingResult.embeddings.count, 1); + +#define AssertEmbeddingIsFloat(embedding) \ + XCTAssertNotNil(embedding.floatEmbedding); \ + XCTAssertNil(embedding.quantizedEmbedding); + +#define AssertFloatEmbeddingHasExpectedValues(floatEmbedding, expectedLength, expectedFirstValue) \ + XCTAssertEqual(floatEmbedding.count, expectedLength); \ + XCTAssertEqualWithAccuracy(floatEmbedding[0].floatValue, expectedFirstValue, kFloatDiffTolerance); + +@interface MPPTextEmbedderTests : XCTestCase +@end + +@implementation MPPTextEmbedderTests + +- (NSString *)filePathWithName:(NSString *)fileName extension:(NSString *)extension { + NSString *filePath = [[NSBundle bundleForClass:self.class] pathForResource:fileName + ofType:extension]; + return filePath; +} + +- (MPPTextEmbedder *)textEmbedderFromModelFileWithName:(NSString *)modelName { + NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"]; + + NSError *error = nil; + MPPTextEmbedder *textEmbedder = [[MPPTextEmbedder alloc] initWithModelPath:modelPath + error:&error]; + + XCTAssertNotNil(textEmbedder); + + return textEmbedder; +} + +- (NSArray *)assertFloatEmbeddingResultsOfEmbedText:(NSString *)text + usingTextEmbedder:(MPPTextEmbedder *)textEmbedder + hasCount:(NSUInteger)embeddingCount + firstValue:(float)firstValue { + MPPTextEmbedderResult *embedderResult = [textEmbedder embedText:text error:nil]; + AssertTextEmbedderResultHasOneEmbedding(embedderResult); + AssertEmbeddingIsFloat(embedderResult.embeddingResult.embeddings[0]); + AssertFloatEmbeddingHasExpectedValues(embedderResult.embeddingResult.embeddings[0].floatEmbedding, + embeddingCount, firstValue); + return embedderResult.embeddingResult.embeddings[0]; +} + +- (void)testCreateTextEmbedderFailsWithMissingModelPath { + NSString *modelPath = [self filePathWithName:@"" extension:@""]; + + NSError *error = nil; + MPPTextEmbedder *textEmbedder = [[MPPTextEmbedder alloc] initWithModelPath:modelPath + error:&error]; + XCTAssertNil(textEmbedder); + + NSError *expectedError = [NSError + errorWithDomain:kExpectedErrorDomain + code:MPPTasksErrorCodeInvalidArgumentError + userInfo:@{ + NSLocalizedDescriptionKey : + @"INVALID_ARGUMENT: ExternalFile must specify at least one of 'file_content', " + @"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'." + }]; + AssertEqualErrors(error, expectedError); +} + +- (void)testEmbedWithBertSucceeds { + MPPTextEmbedder *textEmbedder = + [self textEmbedderFromModelFileWithName:kBertTextEmbedderModelName]; + + MPPEmbedding *embedding1 = [self assertFloatEmbeddingResultsOfEmbedText:kText1 + usingTextEmbedder:textEmbedder + hasCount:512 + firstValue:20.057026f]; + + MPPEmbedding *embedding2 = [self assertFloatEmbeddingResultsOfEmbedText:kText2 + usingTextEmbedder:textEmbedder + hasCount:512 + firstValue:21.254150f]; + NSNumber *cosineSimilarity = [MPPTextEmbedder cosineSimilarityBetweenEmbedding1:embedding1 + andEmbedding2:embedding2 + error:nil]; + XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.96386, kDoubleDiffTolerance); +} + +- (void)testEmbedWithRegexSucceeds { + MPPTextEmbedder *textEmbedder = + [self textEmbedderFromModelFileWithName:kRegexTextEmbedderModelName]; + + MPPEmbedding *embedding1 = [self assertFloatEmbeddingResultsOfEmbedText:kText1 + usingTextEmbedder:textEmbedder + hasCount:16 + firstValue:0.030935612f]; + + MPPEmbedding *embedding2 = [self assertFloatEmbeddingResultsOfEmbedText:kText2 + usingTextEmbedder:textEmbedder + hasCount:16 + firstValue:0.0312863f]; + + NSNumber *cosineSimilarity = [MPPTextEmbedder cosineSimilarityBetweenEmbedding1:embedding1 + andEmbedding2:embedding2 + error:nil]; + XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.999937f, kDoubleDiffTolerance); +} + +@end From d6259189954f86869e3f666e750e2f18c2d7e818 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 2 Feb 2023 18:36:55 +0530 Subject: [PATCH 04/11] Added swift tests for text embedder --- .../tasks/ios/test/text/text_embedder/BUILD | 25 ++++ .../text_embedder/TextEmbedderTests.swift | 114 ++++++++++++++++++ .../text_embedder/sources/MPPTextEmbedder.h | 4 +- 3 files changed, 141 insertions(+), 2 deletions(-) create mode 100644 mediapipe/tasks/ios/test/text/text_embedder/TextEmbedderTests.swift diff --git a/mediapipe/tasks/ios/test/text/text_embedder/BUILD b/mediapipe/tasks/ios/test/text/text_embedder/BUILD index 04359cf9a..d4b0ac6d7 100644 --- a/mediapipe/tasks/ios/test/text/text_embedder/BUILD +++ b/mediapipe/tasks/ios/test/text/text_embedder/BUILD @@ -53,3 +53,28 @@ ios_unit_test( ":MPPTextEmbedderObjcTestLibrary", ], ) + +swift_library( + name = "MPPTextEmbedderSwiftTestLibrary", + testonly = 1, + srcs = ["TextEmbedderTests.swift"], + data = [ + "//mediapipe/tasks/testdata/text:mobilebert_embedding_model", + "//mediapipe/tasks/testdata/text:regex_embedding_with_metadata", + ], + tags = TFL_DEFAULT_TAGS, + deps = [ + "//mediapipe/tasks/ios/common:MPPCommon", + "//mediapipe/tasks/ios/text/text_embedder:MPPTextEmbedder", + ], +) + +ios_unit_test( + name = "MPPTextEmbedderSwiftTest", + minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION, + runner = tflite_ios_lab_runner("IOS_LATEST"), + tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS, + deps = [ + ":MPPTextEmbedderSwiftTestLibrary", + ], +) diff --git a/mediapipe/tasks/ios/test/text/text_embedder/TextEmbedderTests.swift b/mediapipe/tasks/ios/test/text/text_embedder/TextEmbedderTests.swift new file mode 100644 index 000000000..bd7f6d5db --- /dev/null +++ b/mediapipe/tasks/ios/test/text/text_embedder/TextEmbedderTests.swift @@ -0,0 +1,114 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import MPPCommon +import XCTest + +@testable import MPPTextEmbedder + +class TextEmbedderTests: XCTestCase { + + static let bundle = Bundle(for: TextEmbedderTests.self) + + static let bertModelPath = bundle.path( + forResource: "mobilebert_embedding_with_metadata", + ofType: "tflite") + + static let text1 = "it's a charming and often affecting journey" + + static let text2 = "what a great and fantastic trip" + + static let floatDiffTolerance: Float = 1e-4 + + static let doubleDiffTolerance: Double = 1e-4 + + func assertEqualErrorDescriptions( + _ error: Error, expectedLocalizedDescription: String + ) { + XCTAssertEqual( + error.localizedDescription, + expectedLocalizedDescription) + } + + func assertTextEmbedderResultHasOneEmbedding( + _ textEmbedderResult: TextEmbedderResult + ) { + XCTAssertEqual(textEmbedderResult.embeddingResult.embeddings.count, 1) + } + + func assertEmbeddingIsFloat( + _ embedding: Embedding + ) { + XCTAssertNil(embedding.quantizedEmbedding) + XCTAssertNotNil(embedding.floatEmbedding) + } + + func assertEmbedding( + _ floatEmbedding: [NSNumber], + hasCount embeddingCount: Int, + hasFirstValue firstValue: Float + ) { + XCTAssertEqual(floatEmbedding.count, embeddingCount); + XCTAssertEqual( + floatEmbedding[0].floatValue, + firstValue, accuracy: + TextEmbedderTests.floatDiffTolerance); + } + + func assertFloatEmbeddingResultsForEmbed( + text: String, + using textEmbedder: TextEmbedder, + hasCount embeddingCount: Int, + hasFirstValue firstValue: Float + ) throws -> Embedding { + let textEmbedderResult = + try XCTUnwrap( + textEmbedder.embed(text: text)) + assertTextEmbedderResultHasOneEmbedding(textEmbedderResult) + assertEmbeddingIsFloat(textEmbedderResult.embeddingResult.embeddings[0]) + assertEmbedding( + textEmbedderResult.embeddingResult.embeddings[0].floatEmbedding!, + hasCount: embeddingCount, + hasFirstValue: firstValue) + + return textEmbedderResult.embeddingResult.embeddings[0] + } + + func testEmbedWithBertSucceeds() throws { + + let modelPath = try XCTUnwrap(TextEmbedderTests.bertModelPath) + let textEmbedder = try XCTUnwrap(TextEmbedder(modelPath: modelPath)) + + let embedding1 = try assertFloatEmbeddingResultsForEmbed( + text: TextEmbedderTests.text1, + using: textEmbedder, + hasCount: 512, + hasFirstValue: 20.057026) + + let embedding2 = try assertFloatEmbeddingResultsForEmbed( + text: TextEmbedderTests.text2, + using: textEmbedder, + hasCount: 512, + hasFirstValue: 21.254150) + + let cosineSimilarity = try XCTUnwrap(TextEmbedder.cosineSimilarity( + embedding1: embedding1, + embedding2: embedding2)) + + XCTAssertEqual( + cosineSimilarity.doubleValue, + 0.96386, + accuracy: TextEmbedderTests.doubleDiffTolerance) + } +} diff --git a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h index ba5958a72..3eecd686f 100644 --- a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h +++ b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h @@ -100,8 +100,8 @@ NS_SWIFT_NAME(TextEmbedder) */ + (nullable NSNumber *)cosineSimilarityBetweenEmbedding1:(MPPEmbedding *)embedding1 andEmbedding2:(MPPEmbedding *)embedding2 - error:(NSError **)error - NS_SWIFT_NAME(cosineSimilarity(embedding1: embedding2:)); + error:(NSError **)error NS_SWIFT_NAME(cosineSimilarity(embedding1:embedding2:)); + // NS_SWIFT_NAME(cosineSimilarity(embedding1: embedding2:)); + (instancetype)new NS_UNAVAILABLE; From 20002f191a78378ee0711efdf419558110ea0724 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Thu, 2 Feb 2023 18:38:19 +0530 Subject: [PATCH 05/11] Changed documentation --- .../tasks/ios/components/utils/sources/MPPCosineSimilarity.h | 2 +- .../tasks/ios/components/utils/sources/MPPCosineSimilarity.mm | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mediapipe/tasks/ios/components/utils/sources/MPPCosineSimilarity.h b/mediapipe/tasks/ios/components/utils/sources/MPPCosineSimilarity.h index 864baf169..9e47960c7 100644 --- a/mediapipe/tasks/ios/components/utils/sources/MPPCosineSimilarity.h +++ b/mediapipe/tasks/ios/components/utils/sources/MPPCosineSimilarity.h @@ -1,4 +1,4 @@ -// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// Copyright 2023 The MediaPipe Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/mediapipe/tasks/ios/components/utils/sources/MPPCosineSimilarity.mm b/mediapipe/tasks/ios/components/utils/sources/MPPCosineSimilarity.mm index dfbc54e01..bc90ce95e 100644 --- a/mediapipe/tasks/ios/components/utils/sources/MPPCosineSimilarity.mm +++ b/mediapipe/tasks/ios/components/utils/sources/MPPCosineSimilarity.mm @@ -1,4 +1,4 @@ -// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// Copyright 2023 The MediaPipe Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. From 6ca1efdd55415e330f1d6b4f3303a0494814caee Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 3 Feb 2023 12:48:06 +0530 Subject: [PATCH 06/11] Updated MPPTextEmbedder Documentation --- .../tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h index 3eecd686f..61a7dd4d2 100644 --- a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h +++ b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h @@ -86,7 +86,8 @@ NS_SWIFT_NAME(TextEmbedder) - (instancetype)init NS_UNAVAILABLE; -/** Utility function to compute[cosine similarity](https://en.wikipedia.org/wiki/Cosine_similarity) +/** + * Utility function to compute[cosine similarity](https://en.wikipedia.org/wiki/Cosine_similarity) * between two `MPPEmbedding` objects. * * @param embedding1 One of the two `MPPEmbedding`s between whom cosine similarity is to be From a512e6b5f511bc56653543a2a03bb2515dbf92cf Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 3 Feb 2023 12:49:00 +0530 Subject: [PATCH 07/11] Updated MPPTextEmbedder Documentation --- mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h | 1 - 1 file changed, 1 deletion(-) diff --git a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h index 61a7dd4d2..f60e88ba4 100644 --- a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h +++ b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h @@ -102,7 +102,6 @@ NS_SWIFT_NAME(TextEmbedder) + (nullable NSNumber *)cosineSimilarityBetweenEmbedding1:(MPPEmbedding *)embedding1 andEmbedding2:(MPPEmbedding *)embedding2 error:(NSError **)error NS_SWIFT_NAME(cosineSimilarity(embedding1:embedding2:)); - // NS_SWIFT_NAME(cosineSimilarity(embedding1: embedding2:)); + (instancetype)new NS_UNAVAILABLE; From b5b10e7681a80bd1ee1860e8adad667c80c15b0f Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 3 Feb 2023 13:10:13 +0530 Subject: [PATCH 08/11] Added iOS test for different themes in text embedder --- .../text/text_embedder/MPPTextEmbedderTests.m | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/mediapipe/tasks/ios/test/text/text_embedder/MPPTextEmbedderTests.m b/mediapipe/tasks/ios/test/text/text_embedder/MPPTextEmbedderTests.m index c58c52298..36e0ef8c0 100644 --- a/mediapipe/tasks/ios/test/text/text_embedder/MPPTextEmbedderTests.m +++ b/mediapipe/tasks/ios/test/text/text_embedder/MPPTextEmbedderTests.m @@ -139,4 +139,31 @@ static const float kDoubleDiffTolerance = 1e-4; XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.999937f, kDoubleDiffTolerance); } +- (void)testEmbedWithBertAndDifferentThemesSucceeds { + MPPTextEmbedder *textEmbedder = + [self textEmbedderFromModelFileWithName:kBertTextEmbedderModelName]; + + MPPEmbedding *embedding1 = + [self assertFloatEmbeddingResultsOfEmbedText: + @"When you go to this restaurant, they hold the pancake upside-down before they " + @"hand it to you. It's a great gimmick." + usingTextEmbedder:textEmbedder + hasCount:512 + firstValue:42.0832]; + + MPPEmbedding *embedding2 = + [self assertFloatEmbeddingResultsOfEmbedText: + @"Let's make a plan to steal the declaration of independence." + usingTextEmbedder:textEmbedder + hasCount:512 + firstValue:50.8856]; + + NSNumber *cosineSimilarity = [MPPTextEmbedder cosineSimilarityBetweenEmbedding1:embedding1 + andEmbedding2:embedding2 + error:nil]; + + // TODO: The similarity should likely be lower + XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.963203f, kDoubleDiffTolerance); +} + @end From 3b55fb9f6a86d7dfbf9c2119543f2addc00a751f Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 3 Feb 2023 13:42:32 +0530 Subject: [PATCH 09/11] Added iOS test for quantized embedding --- .../text/text_embedder/MPPTextEmbedderTests.m | 103 +++++++++++++++--- 1 file changed, 89 insertions(+), 14 deletions(-) diff --git a/mediapipe/tasks/ios/test/text/text_embedder/MPPTextEmbedderTests.m b/mediapipe/tasks/ios/test/text/text_embedder/MPPTextEmbedderTests.m index 36e0ef8c0..0468c9b81 100644 --- a/mediapipe/tasks/ios/test/text/text_embedder/MPPTextEmbedderTests.m +++ b/mediapipe/tasks/ios/test/text/text_embedder/MPPTextEmbedderTests.m @@ -23,7 +23,7 @@ static NSString *const kText1 = @"it's a charming and often affecting journey"; static NSString *const kText2 = @"what a great and fantastic trip"; static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; static const float kFloatDiffTolerance = 1e-4; -static const float kDoubleDiffTolerance = 1e-4; +static const float kSimilarityDiffTolerance = 1e-4; #define AssertEqualErrors(error, expectedError) \ XCTAssertNotNil(error); \ @@ -38,13 +38,24 @@ static const float kDoubleDiffTolerance = 1e-4; XCTAssertNotNil(textEmbedderResult.embeddingResult); \ XCTAssertEqual(textEmbedderResult.embeddingResult.embeddings.count, 1); -#define AssertEmbeddingIsFloat(embedding) \ - XCTAssertNotNil(embedding.floatEmbedding); \ - XCTAssertNil(embedding.quantizedEmbedding); +#define AssertEmbeddingType(embedding, quantized) \ + if (quantized) { \ + XCTAssertNil(embedding.floatEmbedding); \ + XCTAssertNotNil(embedding.quantizedEmbedding); \ + } \ + else { \ + XCTAssertNotNil(embedding.floatEmbedding); \ + XCTAssertNil(embedding.quantizedEmbedding);\ + } -#define AssertFloatEmbeddingHasExpectedValues(floatEmbedding, expectedLength, expectedFirstValue) \ - XCTAssertEqual(floatEmbedding.count, expectedLength); \ - XCTAssertEqualWithAccuracy(floatEmbedding[0].floatValue, expectedFirstValue, kFloatDiffTolerance); +#define AssertEmbeddingHasExpectedValues(embedding, expectedLength, expectedFirstValue, quantize) \ + XCTAssertEqual(embedding.count, expectedLength); \ + if (quantize) { \ + XCTAssertEqual(embedding[0].charValue, expectedFirstValue); \ + } \ + else { \ + XCTAssertEqualWithAccuracy(embedding[0].floatValue, expectedFirstValue, kFloatDiffTolerance); \ + } \ @interface MPPTextEmbedderTests : XCTestCase @end @@ -69,15 +80,55 @@ static const float kDoubleDiffTolerance = 1e-4; return textEmbedder; } +- (MPPTextEmbedderOptions *)textEmbedderOptionsWithModelName:(NSString *)modelName { + NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"]; + MPPTextEmbedderOptions *textEmbedderOptions = [[MPPTextEmbedderOptions alloc] init]; + textEmbedderOptions.baseOptions.modelAssetPath = modelPath; + + return textEmbedderOptions; +} + - (NSArray *)assertFloatEmbeddingResultsOfEmbedText:(NSString *)text usingTextEmbedder:(MPPTextEmbedder *)textEmbedder hasCount:(NSUInteger)embeddingCount firstValue:(float)firstValue { MPPTextEmbedderResult *embedderResult = [textEmbedder embedText:text error:nil]; AssertTextEmbedderResultHasOneEmbedding(embedderResult); - AssertEmbeddingIsFloat(embedderResult.embeddingResult.embeddings[0]); - AssertFloatEmbeddingHasExpectedValues(embedderResult.embeddingResult.embeddings[0].floatEmbedding, - embeddingCount, firstValue); + + AssertEmbeddingType( + embedderResult.embeddingResult.embeddings[0], // embedding + NO // quantized + ); + + AssertEmbeddingHasExpectedValues( + embedderResult.embeddingResult.embeddings[0].floatEmbedding, // embedding + embeddingCount, // expectedLength + firstValue, // expectedFirstValue + NO // quantize + ); + + return embedderResult.embeddingResult.embeddings[0]; +} + +- (NSArray*)assertQuantizedEmbeddingResultsOfEmbedText:(NSString *)text + usingTextEmbedder:(MPPTextEmbedder *)textEmbedder + hasCount:(NSUInteger)embeddingCount + firstValue:(char)firstValue { + MPPTextEmbedderResult *embedderResult = [textEmbedder embedText:text error:nil]; + AssertTextEmbedderResultHasOneEmbedding(embedderResult); + + AssertEmbeddingType( + embedderResult.embeddingResult.embeddings[0], // embedding + YES // quantized + ); + + AssertEmbeddingHasExpectedValues( + embedderResult.embeddingResult.embeddings[0].quantizedEmbedding, // embedding + embeddingCount, // expectedLength + firstValue, // expectedFirstValue + YES // quantize + ); + return embedderResult.embeddingResult.embeddings[0]; } @@ -97,7 +148,10 @@ static const float kDoubleDiffTolerance = 1e-4; @"INVALID_ARGUMENT: ExternalFile must specify at least one of 'file_content', " @"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'." }]; - AssertEqualErrors(error, expectedError); + AssertEqualErrors( + error, // error + expectedError // expectedError + ); } - (void)testEmbedWithBertSucceeds { @@ -116,7 +170,7 @@ static const float kDoubleDiffTolerance = 1e-4; NSNumber *cosineSimilarity = [MPPTextEmbedder cosineSimilarityBetweenEmbedding1:embedding1 andEmbedding2:embedding2 error:nil]; - XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.96386, kDoubleDiffTolerance); + XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.96386, kSimilarityDiffTolerance); } - (void)testEmbedWithRegexSucceeds { @@ -136,7 +190,7 @@ static const float kDoubleDiffTolerance = 1e-4; NSNumber *cosineSimilarity = [MPPTextEmbedder cosineSimilarityBetweenEmbedding1:embedding1 andEmbedding2:embedding2 error:nil]; - XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.999937f, kDoubleDiffTolerance); + XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.999937f, kSimilarityDiffTolerance); } - (void)testEmbedWithBertAndDifferentThemesSucceeds { @@ -163,7 +217,28 @@ static const float kDoubleDiffTolerance = 1e-4; error:nil]; // TODO: The similarity should likely be lower - XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.963203f, kDoubleDiffTolerance); + XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.963203f, kSimilarityDiffTolerance); +} + +- (void)testEmbedWithQuantizeSucceeds { + MPPTextEmbedderOptions *options = + [self textEmbedderOptionsWithModelName:kBertTextEmbedderModelName]; + options.quantize = YES; + + MPPTextEmbedder *textEmbedder = [[MPPTextEmbedder alloc] initWithOptions:options error:nil]; + XCTAssertNotNil(textEmbedder); + + MPPEmbedding *embedding1 = [self assertQuantizedEmbeddingResultsOfEmbedText:@"it's a charming and often affecting journey" + usingTextEmbedder:textEmbedder + hasCount:512 + firstValue:127]; + + MPPEmbedding *embedding2 = [self assertQuantizedEmbeddingResultsOfEmbedText:@"what a great and fantastic trip" + usingTextEmbedder:textEmbedder + hasCount:512 + firstValue:127]; + NSNumber *cosineSimilarity = [MPPTextEmbedder cosineSimilarityBetweenEmbedding1:embedding1 andEmbedding2:embedding2 error:nil]; + XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.864113, kSimilarityDiffTolerance); } @end From e290f9cf30a6e47857eeb905ea39fcb82c236ff4 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 3 Feb 2023 18:05:49 +0530 Subject: [PATCH 10/11] Added a note about swift test coverage in iOS text embedder tests --- .../ios/test/text/text_embedder/TextEmbedderTests.swift | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mediapipe/tasks/ios/test/text/text_embedder/TextEmbedderTests.swift b/mediapipe/tasks/ios/test/text/text_embedder/TextEmbedderTests.swift index bd7f6d5db..98d83a981 100644 --- a/mediapipe/tasks/ios/test/text/text_embedder/TextEmbedderTests.swift +++ b/mediapipe/tasks/ios/test/text/text_embedder/TextEmbedderTests.swift @@ -17,6 +17,11 @@ import XCTest @testable import MPPTextEmbedder +/** These tests are only for validating the Swift function signatures of the TextEmbedder. + * Objective C tests of the TextEmbedder provide more coverage with unit tests for + * different models and text embedder options. They can be found here: + * /mediapipe/tasks/ios/test/text/text_embedder/MPPTextEmbedderTests.m + */ class TextEmbedderTests: XCTestCase { static let bundle = Bundle(for: TextEmbedderTests.self) From eeaa011998c5c0133369153c5f92bdd3739580c7 Mon Sep 17 00:00:00 2001 From: Prianka Liz Kariat Date: Fri, 3 Feb 2023 18:06:05 +0530 Subject: [PATCH 11/11] Updated documentation of iOS text embedder tests --- .../text/text_embedder/MPPTextEmbedderTests.m | 103 +++++++++--------- 1 file changed, 51 insertions(+), 52 deletions(-) diff --git a/mediapipe/tasks/ios/test/text/text_embedder/MPPTextEmbedderTests.m b/mediapipe/tasks/ios/test/text/text_embedder/MPPTextEmbedderTests.m index 0468c9b81..2fa0f58f4 100644 --- a/mediapipe/tasks/ios/test/text/text_embedder/MPPTextEmbedderTests.m +++ b/mediapipe/tasks/ios/test/text/text_embedder/MPPTextEmbedderTests.m @@ -38,24 +38,22 @@ static const float kSimilarityDiffTolerance = 1e-4; XCTAssertNotNil(textEmbedderResult.embeddingResult); \ XCTAssertEqual(textEmbedderResult.embeddingResult.embeddings.count, 1); -#define AssertEmbeddingType(embedding, quantized) \ - if (quantized) { \ - XCTAssertNil(embedding.floatEmbedding); \ - XCTAssertNotNil(embedding.quantizedEmbedding); \ - } \ - else { \ - XCTAssertNotNil(embedding.floatEmbedding); \ - XCTAssertNil(embedding.quantizedEmbedding);\ - } +#define AssertEmbeddingType(embedding, quantized) \ + if (quantized) { \ + XCTAssertNil(embedding.floatEmbedding); \ + XCTAssertNotNil(embedding.quantizedEmbedding); \ + } else { \ + XCTAssertNotNil(embedding.floatEmbedding); \ + XCTAssertNil(embedding.quantizedEmbedding); \ + } #define AssertEmbeddingHasExpectedValues(embedding, expectedLength, expectedFirstValue, quantize) \ - XCTAssertEqual(embedding.count, expectedLength); \ - if (quantize) { \ - XCTAssertEqual(embedding[0].charValue, expectedFirstValue); \ - } \ - else { \ + XCTAssertEqual(embedding.count, expectedLength); \ + if (quantize) { \ + XCTAssertEqual(embedding[0].charValue, expectedFirstValue); \ + } else { \ XCTAssertEqualWithAccuracy(embedding[0].floatValue, expectedFirstValue, kFloatDiffTolerance); \ - } \ + } @interface MPPTextEmbedderTests : XCTestCase @end @@ -94,41 +92,39 @@ static const float kSimilarityDiffTolerance = 1e-4; firstValue:(float)firstValue { MPPTextEmbedderResult *embedderResult = [textEmbedder embedText:text error:nil]; AssertTextEmbedderResultHasOneEmbedding(embedderResult); - - AssertEmbeddingType( - embedderResult.embeddingResult.embeddings[0], // embedding - NO // quantized + + AssertEmbeddingType(embedderResult.embeddingResult.embeddings[0], // embedding + NO // quantized ); - + AssertEmbeddingHasExpectedValues( - embedderResult.embeddingResult.embeddings[0].floatEmbedding, // embedding - embeddingCount, // expectedLength - firstValue, // expectedFirstValue - NO // quantize + embedderResult.embeddingResult.embeddings[0].floatEmbedding, // embedding + embeddingCount, // expectedLength + firstValue, // expectedFirstValue + NO // quantize ); - + return embedderResult.embeddingResult.embeddings[0]; } -- (NSArray*)assertQuantizedEmbeddingResultsOfEmbedText:(NSString *)text - usingTextEmbedder:(MPPTextEmbedder *)textEmbedder - hasCount:(NSUInteger)embeddingCount - firstValue:(char)firstValue { +- (NSArray *)assertQuantizedEmbeddingResultsOfEmbedText:(NSString *)text + usingTextEmbedder:(MPPTextEmbedder *)textEmbedder + hasCount:(NSUInteger)embeddingCount + firstValue:(char)firstValue { MPPTextEmbedderResult *embedderResult = [textEmbedder embedText:text error:nil]; AssertTextEmbedderResultHasOneEmbedding(embedderResult); - - AssertEmbeddingType( - embedderResult.embeddingResult.embeddings[0], // embedding - YES // quantized + + AssertEmbeddingType(embedderResult.embeddingResult.embeddings[0], // embedding + YES // quantized ); - + AssertEmbeddingHasExpectedValues( - embedderResult.embeddingResult.embeddings[0].quantizedEmbedding, // embedding - embeddingCount, // expectedLength - firstValue, // expectedFirstValue - YES // quantize + embedderResult.embeddingResult.embeddings[0].quantizedEmbedding, // embedding + embeddingCount, // expectedLength + firstValue, // expectedFirstValue + YES // quantize ); - + return embedderResult.embeddingResult.embeddings[0]; } @@ -148,9 +144,8 @@ static const float kSimilarityDiffTolerance = 1e-4; @"INVALID_ARGUMENT: ExternalFile must specify at least one of 'file_content', " @"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'." }]; - AssertEqualErrors( - error, // error - expectedError // expectedError + AssertEqualErrors(error, // error + expectedError // expectedError ); } @@ -228,17 +223,21 @@ static const float kSimilarityDiffTolerance = 1e-4; MPPTextEmbedder *textEmbedder = [[MPPTextEmbedder alloc] initWithOptions:options error:nil]; XCTAssertNotNil(textEmbedder); - MPPEmbedding *embedding1 = [self assertQuantizedEmbeddingResultsOfEmbedText:@"it's a charming and often affecting journey" - usingTextEmbedder:textEmbedder - hasCount:512 - firstValue:127]; + MPPEmbedding *embedding1 = [self + assertQuantizedEmbeddingResultsOfEmbedText:@"it's a charming and often affecting journey" + usingTextEmbedder:textEmbedder + hasCount:512 + firstValue:127]; - MPPEmbedding *embedding2 = [self assertQuantizedEmbeddingResultsOfEmbedText:@"what a great and fantastic trip" - usingTextEmbedder:textEmbedder - hasCount:512 - firstValue:127]; - NSNumber *cosineSimilarity = [MPPTextEmbedder cosineSimilarityBetweenEmbedding1:embedding1 andEmbedding2:embedding2 error:nil]; - XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.864113, kSimilarityDiffTolerance); + MPPEmbedding *embedding2 = + [self assertQuantizedEmbeddingResultsOfEmbedText:@"what a great and fantastic trip" + usingTextEmbedder:textEmbedder + hasCount:512 + firstValue:127]; + NSNumber *cosineSimilarity = [MPPTextEmbedder cosineSimilarityBetweenEmbedding1:embedding1 + andEmbedding2:embedding2 + error:nil]; + XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.864113, kSimilarityDiffTolerance); } @end