Merge pull request #4052 from priankakariatyml:ios-text-embedder
PiperOrigin-RevId: 507602101
This commit is contained in:
commit
28c07430ba
33
mediapipe/tasks/ios/components/utils/BUILD
Normal file
33
mediapipe/tasks/ios/components/utils/BUILD
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
objc_library(
|
||||||
|
name = "MPPCosineSimilarity",
|
||||||
|
srcs = ["sources/MPPCosineSimilarity.mm"],
|
||||||
|
hdrs = ["sources/MPPCosineSimilarity.h"],
|
||||||
|
copts = [
|
||||||
|
"-ObjC++",
|
||||||
|
"-std=c++17",
|
||||||
|
"-x objective-c++",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/tasks/ios/common:MPPCommon",
|
||||||
|
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
|
||||||
|
"//mediapipe/tasks/ios/components/containers:MPPEmbedding",
|
||||||
|
],
|
||||||
|
)
|
|
@ -0,0 +1,48 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#import <Foundation/Foundation.h>
|
||||||
|
|
||||||
|
#import "mediapipe/tasks/ios/components/containers/sources/MPPEmbedding.h"
|
||||||
|
|
||||||
|
NS_ASSUME_NONNULL_BEGIN
|
||||||
|
|
||||||
|
/** Utility class for computing cosine similarity between `MPPEmbedding` objects. */
|
||||||
|
NS_SWIFT_NAME(CosineSimilarity)
|
||||||
|
|
||||||
|
@interface MPPCosineSimilarity : NSObject
|
||||||
|
|
||||||
|
- (instancetype)init NS_UNAVAILABLE;
|
||||||
|
|
||||||
|
+ (instancetype)new NS_UNAVAILABLE;
|
||||||
|
|
||||||
|
/** Utility function to compute[cosine similarity](https://en.wikipedia.org/wiki/Cosine_similarity)
|
||||||
|
* between two `MPPEmbedding` objects.
|
||||||
|
*
|
||||||
|
* @param embedding1 One of the two `MPPEmbedding`s between whom cosine similarity is to be
|
||||||
|
* computed.
|
||||||
|
* @param embedding2 One of the two `MPPEmbedding`s between whom cosine similarity is to be
|
||||||
|
* computed.
|
||||||
|
* @param error An optional error parameter populated when there is an error in calculating cosine
|
||||||
|
* similarity between two embeddings.
|
||||||
|
*
|
||||||
|
* @return An `NSNumber` which holds the cosine similarity of type `double`.
|
||||||
|
*/
|
||||||
|
+ (nullable NSNumber *)computeBetweenEmbedding1:(MPPEmbedding *)embedding1
|
||||||
|
andEmbedding2:(MPPEmbedding *)embedding2
|
||||||
|
error:(NSError **)error;
|
||||||
|
|
||||||
|
@end
|
||||||
|
|
||||||
|
NS_ASSUME_NONNULL_END
|
|
@ -0,0 +1,88 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#import "mediapipe/tasks/ios/components/utils/sources/MPPCosineSimilarity.h"
|
||||||
|
|
||||||
|
#include <math.h>
|
||||||
|
|
||||||
|
#import "mediapipe/tasks/ios/common/sources/MPPCommon.h"
|
||||||
|
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
|
||||||
|
|
||||||
|
@implementation MPPCosineSimilarity
|
||||||
|
|
||||||
|
+ (nullable NSNumber *)computeBetweenVector1:(NSArray<NSNumber *> *)u
|
||||||
|
andVector2:(NSArray<NSNumber *> *)v
|
||||||
|
isFloat:(BOOL)isFloat
|
||||||
|
error:(NSError **)error {
|
||||||
|
if (u.count != v.count) {
|
||||||
|
[MPPCommonUtils
|
||||||
|
createCustomError:error
|
||||||
|
withCode:MPPTasksErrorCodeInvalidArgumentError
|
||||||
|
description:[NSString stringWithFormat:@"Cannot compute cosine similarity between "
|
||||||
|
@"embeddings of different sizes (%lu vs %lu",
|
||||||
|
static_cast<u_long>(u.count),
|
||||||
|
static_cast<u_long>(v.count)]];
|
||||||
|
return nil;
|
||||||
|
}
|
||||||
|
|
||||||
|
__block double dotProduct = 0.0;
|
||||||
|
__block double normU = 0.0;
|
||||||
|
__block double normV = 0.0;
|
||||||
|
|
||||||
|
[u enumerateObjectsUsingBlock:^(NSNumber *num, NSUInteger idx, BOOL *stop) {
|
||||||
|
double uVal = 0.0;
|
||||||
|
double vVal = 0.0;
|
||||||
|
|
||||||
|
if (isFloat) {
|
||||||
|
uVal = num.floatValue;
|
||||||
|
vVal = v[idx].floatValue;
|
||||||
|
} else {
|
||||||
|
uVal = num.charValue;
|
||||||
|
vVal = v[idx].charValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
dotProduct += uVal * vVal;
|
||||||
|
normU += uVal * uVal;
|
||||||
|
normV += vVal * vVal;
|
||||||
|
}];
|
||||||
|
|
||||||
|
return [NSNumber numberWithDouble:dotProduct / sqrt(normU * normV)];
|
||||||
|
}
|
||||||
|
|
||||||
|
+ (nullable NSNumber *)computeBetweenEmbedding1:(MPPEmbedding *)embedding1
|
||||||
|
andEmbedding2:(MPPEmbedding *)embedding2
|
||||||
|
error:(NSError **)error {
|
||||||
|
if (embedding1.floatEmbedding && embedding2.floatEmbedding) {
|
||||||
|
return [MPPCosineSimilarity computeBetweenVector1:embedding1.floatEmbedding
|
||||||
|
andVector2:embedding2.floatEmbedding
|
||||||
|
isFloat:YES
|
||||||
|
error:error];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (embedding1.quantizedEmbedding && embedding2.quantizedEmbedding) {
|
||||||
|
return [MPPCosineSimilarity computeBetweenVector1:embedding1.quantizedEmbedding
|
||||||
|
andVector2:embedding2.quantizedEmbedding
|
||||||
|
isFloat:NO
|
||||||
|
error:error];
|
||||||
|
}
|
||||||
|
|
||||||
|
[MPPCommonUtils
|
||||||
|
createCustomError:error
|
||||||
|
withCode:MPPTasksErrorCodeInvalidArgumentError
|
||||||
|
description:
|
||||||
|
@"Cannot compute cosine similarity between quantized and float embeddings."];
|
||||||
|
return nil;
|
||||||
|
}
|
||||||
|
|
||||||
|
@end
|
80
mediapipe/tasks/ios/test/text/text_embedder/BUILD
Normal file
80
mediapipe/tasks/ios/test/text/text_embedder/BUILD
Normal file
|
@ -0,0 +1,80 @@
|
||||||
|
load(
|
||||||
|
"@build_bazel_rules_apple//apple:ios.bzl",
|
||||||
|
"ios_unit_test",
|
||||||
|
)
|
||||||
|
load(
|
||||||
|
"@build_bazel_rules_swift//swift:swift.bzl",
|
||||||
|
"swift_library",
|
||||||
|
)
|
||||||
|
load(
|
||||||
|
"//mediapipe/tasks:ios/ios.bzl",
|
||||||
|
"MPP_TASK_MINIMUM_OS_VERSION",
|
||||||
|
)
|
||||||
|
load(
|
||||||
|
"@org_tensorflow//tensorflow/lite:special_rules.bzl",
|
||||||
|
"tflite_ios_lab_runner",
|
||||||
|
)
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
# Default tags for filtering iOS targets. Targets are restricted to Apple platforms.
|
||||||
|
TFL_DEFAULT_TAGS = [
|
||||||
|
"apple",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Following sanitizer tests are not supported by iOS test targets.
|
||||||
|
TFL_DISABLED_SANITIZER_TAGS = [
|
||||||
|
"noasan",
|
||||||
|
"nomsan",
|
||||||
|
"notsan",
|
||||||
|
]
|
||||||
|
|
||||||
|
objc_library(
|
||||||
|
name = "MPPTextEmbedderObjcTestLibrary",
|
||||||
|
testonly = 1,
|
||||||
|
srcs = ["MPPTextEmbedderTests.m"],
|
||||||
|
data = [
|
||||||
|
"//mediapipe/tasks/testdata/text:mobilebert_embedding_model",
|
||||||
|
"//mediapipe/tasks/testdata/text:regex_embedding_with_metadata",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/tasks/ios/common:MPPCommon",
|
||||||
|
"//mediapipe/tasks/ios/text/text_embedder:MPPTextEmbedder",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
ios_unit_test(
|
||||||
|
name = "MPPTextEmbedderObjcTest",
|
||||||
|
minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION,
|
||||||
|
runner = tflite_ios_lab_runner("IOS_LATEST"),
|
||||||
|
deps = [
|
||||||
|
":MPPTextEmbedderObjcTestLibrary",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
swift_library(
|
||||||
|
name = "MPPTextEmbedderSwiftTestLibrary",
|
||||||
|
testonly = 1,
|
||||||
|
srcs = ["TextEmbedderTests.swift"],
|
||||||
|
data = [
|
||||||
|
"//mediapipe/tasks/testdata/text:mobilebert_embedding_model",
|
||||||
|
"//mediapipe/tasks/testdata/text:regex_embedding_with_metadata",
|
||||||
|
],
|
||||||
|
tags = TFL_DEFAULT_TAGS,
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/tasks/ios/common:MPPCommon",
|
||||||
|
"//mediapipe/tasks/ios/text/text_embedder:MPPTextEmbedder",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
ios_unit_test(
|
||||||
|
name = "MPPTextEmbedderSwiftTest",
|
||||||
|
minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION,
|
||||||
|
runner = tflite_ios_lab_runner("IOS_LATEST"),
|
||||||
|
tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
|
||||||
|
deps = [
|
||||||
|
":MPPTextEmbedderSwiftTestLibrary",
|
||||||
|
],
|
||||||
|
)
|
|
@ -0,0 +1,246 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#import <XCTest/XCTest.h>
|
||||||
|
|
||||||
|
#import "mediapipe/tasks/ios/common/sources/MPPCommon.h"
|
||||||
|
#import "mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h"
|
||||||
|
|
||||||
|
static NSString *const kBertTextEmbedderModelName = @"mobilebert_embedding_with_metadata";
|
||||||
|
static NSString *const kRegexTextEmbedderModelName = @"regex_one_embedding_with_metadata";
|
||||||
|
static NSString *const kText1 = @"it's a charming and often affecting journey";
|
||||||
|
static NSString *const kText2 = @"what a great and fantastic trip";
|
||||||
|
static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
|
||||||
|
static const float kFloatDiffTolerance = 1e-4;
|
||||||
|
static const float kSimilarityDiffTolerance = 1e-4;
|
||||||
|
|
||||||
|
#define AssertEqualErrors(error, expectedError) \
|
||||||
|
XCTAssertNotNil(error); \
|
||||||
|
XCTAssertEqualObjects(error.domain, expectedError.domain); \
|
||||||
|
XCTAssertEqual(error.code, expectedError.code); \
|
||||||
|
XCTAssertNotEqual( \
|
||||||
|
[error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \
|
||||||
|
NSNotFound)
|
||||||
|
|
||||||
|
#define AssertTextEmbedderResultHasOneEmbedding(textEmbedderResult) \
|
||||||
|
XCTAssertNotNil(textEmbedderResult); \
|
||||||
|
XCTAssertNotNil(textEmbedderResult.embeddingResult); \
|
||||||
|
XCTAssertEqual(textEmbedderResult.embeddingResult.embeddings.count, 1);
|
||||||
|
|
||||||
|
#define AssertEmbeddingType(embedding, quantized) \
|
||||||
|
if (quantized) { \
|
||||||
|
XCTAssertNil(embedding.floatEmbedding); \
|
||||||
|
XCTAssertNotNil(embedding.quantizedEmbedding); \
|
||||||
|
} else { \
|
||||||
|
XCTAssertNotNil(embedding.floatEmbedding); \
|
||||||
|
XCTAssertNil(embedding.quantizedEmbedding); \
|
||||||
|
}
|
||||||
|
|
||||||
|
#define AssertEmbeddingHasExpectedValues(embedding, expectedLength, expectedFirstValue, quantize) \
|
||||||
|
XCTAssertEqual(embedding.count, expectedLength); \
|
||||||
|
if (quantize) { \
|
||||||
|
XCTAssertEqual(embedding[0].charValue, expectedFirstValue); \
|
||||||
|
} else { \
|
||||||
|
XCTAssertEqualWithAccuracy(embedding[0].floatValue, expectedFirstValue, kFloatDiffTolerance); \
|
||||||
|
}
|
||||||
|
|
||||||
|
@interface MPPTextEmbedderTests : XCTestCase
|
||||||
|
@end
|
||||||
|
|
||||||
|
@implementation MPPTextEmbedderTests
|
||||||
|
|
||||||
|
- (NSString *)filePathWithName:(NSString *)fileName extension:(NSString *)extension {
|
||||||
|
return [[NSBundle bundleForClass:self.class] pathForResource:fileName
|
||||||
|
ofType:extension];
|
||||||
|
}
|
||||||
|
|
||||||
|
- (MPPTextEmbedder *)textEmbedderFromModelFileWithName:(NSString *)modelName {
|
||||||
|
NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"];
|
||||||
|
|
||||||
|
NSError *error = nil;
|
||||||
|
MPPTextEmbedder *textEmbedder = [[MPPTextEmbedder alloc] initWithModelPath:modelPath
|
||||||
|
error:&error];
|
||||||
|
|
||||||
|
XCTAssertNotNil(textEmbedder);
|
||||||
|
|
||||||
|
return textEmbedder;
|
||||||
|
}
|
||||||
|
|
||||||
|
- (MPPTextEmbedderOptions *)textEmbedderOptionsWithModelName:(NSString *)modelName {
|
||||||
|
NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"];
|
||||||
|
MPPTextEmbedderOptions *textEmbedderOptions = [[MPPTextEmbedderOptions alloc] init];
|
||||||
|
textEmbedderOptions.baseOptions.modelAssetPath = modelPath;
|
||||||
|
|
||||||
|
return textEmbedderOptions;
|
||||||
|
}
|
||||||
|
|
||||||
|
- (MPPEmbedding *)assertFloatEmbeddingResultsOfEmbedText:(NSString *)text
|
||||||
|
usingTextEmbedder:(MPPTextEmbedder *)textEmbedder
|
||||||
|
hasCount:(NSUInteger)embeddingCount
|
||||||
|
firstValue:(float)firstValue {
|
||||||
|
MPPTextEmbedderResult *embedderResult = [textEmbedder embedText:text error:nil];
|
||||||
|
AssertTextEmbedderResultHasOneEmbedding(embedderResult);
|
||||||
|
|
||||||
|
AssertEmbeddingType(embedderResult.embeddingResult.embeddings[0], // embedding
|
||||||
|
NO // quantized
|
||||||
|
);
|
||||||
|
|
||||||
|
AssertEmbeddingHasExpectedValues(
|
||||||
|
embedderResult.embeddingResult.embeddings[0].floatEmbedding, // embedding
|
||||||
|
embeddingCount, // expectedLength
|
||||||
|
firstValue, // expectedFirstValue
|
||||||
|
NO // quantize
|
||||||
|
);
|
||||||
|
|
||||||
|
return embedderResult.embeddingResult.embeddings[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
- (MPPEmbedding *)assertQuantizedEmbeddingResultsOfEmbedText:(NSString *)text
|
||||||
|
usingTextEmbedder:(MPPTextEmbedder *)textEmbedder
|
||||||
|
hasCount:(NSUInteger)embeddingCount
|
||||||
|
firstValue:(char)firstValue {
|
||||||
|
MPPTextEmbedderResult *embedderResult = [textEmbedder embedText:text error:nil];
|
||||||
|
AssertTextEmbedderResultHasOneEmbedding(embedderResult);
|
||||||
|
|
||||||
|
AssertEmbeddingType(embedderResult.embeddingResult.embeddings[0], // embedding
|
||||||
|
YES // quantized
|
||||||
|
);
|
||||||
|
|
||||||
|
AssertEmbeddingHasExpectedValues(
|
||||||
|
embedderResult.embeddingResult.embeddings[0].quantizedEmbedding, // embedding
|
||||||
|
embeddingCount, // expectedLength
|
||||||
|
firstValue, // expectedFirstValue
|
||||||
|
YES // quantize
|
||||||
|
);
|
||||||
|
|
||||||
|
return embedderResult.embeddingResult.embeddings[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testCreateTextEmbedderFailsWithMissingModelPath {
|
||||||
|
NSString *modelPath = [self filePathWithName:@"" extension:@""];
|
||||||
|
|
||||||
|
NSError *error = nil;
|
||||||
|
MPPTextEmbedder *textEmbedder = [[MPPTextEmbedder alloc] initWithModelPath:modelPath
|
||||||
|
error:&error];
|
||||||
|
XCTAssertNil(textEmbedder);
|
||||||
|
|
||||||
|
NSError *expectedError = [NSError
|
||||||
|
errorWithDomain:kExpectedErrorDomain
|
||||||
|
code:MPPTasksErrorCodeInvalidArgumentError
|
||||||
|
userInfo:@{
|
||||||
|
NSLocalizedDescriptionKey :
|
||||||
|
@"INVALID_ARGUMENT: ExternalFile must specify at least one of 'file_content', "
|
||||||
|
@"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."
|
||||||
|
}];
|
||||||
|
AssertEqualErrors(error, // error
|
||||||
|
expectedError // expectedError
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testEmbedWithBertSucceeds {
|
||||||
|
MPPTextEmbedder *textEmbedder =
|
||||||
|
[self textEmbedderFromModelFileWithName:kBertTextEmbedderModelName];
|
||||||
|
|
||||||
|
MPPEmbedding *embedding1 =
|
||||||
|
[self assertFloatEmbeddingResultsOfEmbedText:kText1
|
||||||
|
usingTextEmbedder:textEmbedder
|
||||||
|
hasCount:512
|
||||||
|
firstValue:21.214869f];
|
||||||
|
|
||||||
|
MPPEmbedding *embedding2 = [self assertFloatEmbeddingResultsOfEmbedText:kText2
|
||||||
|
usingTextEmbedder:textEmbedder
|
||||||
|
hasCount:512
|
||||||
|
firstValue:22.626251f];
|
||||||
|
NSNumber *cosineSimilarity = [MPPTextEmbedder cosineSimilarityBetweenEmbedding1:embedding1
|
||||||
|
andEmbedding2:embedding2
|
||||||
|
error:nil];
|
||||||
|
|
||||||
|
XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.971417490189,
|
||||||
|
kSimilarityDiffTolerance);
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testEmbedWithRegexSucceeds {
|
||||||
|
MPPTextEmbedder *textEmbedder =
|
||||||
|
[self textEmbedderFromModelFileWithName:kRegexTextEmbedderModelName];
|
||||||
|
|
||||||
|
MPPEmbedding *embedding1 = [self assertFloatEmbeddingResultsOfEmbedText:kText1
|
||||||
|
usingTextEmbedder:textEmbedder
|
||||||
|
hasCount:16
|
||||||
|
firstValue:0.030935612f];
|
||||||
|
|
||||||
|
MPPEmbedding *embedding2 = [self assertFloatEmbeddingResultsOfEmbedText:kText2
|
||||||
|
usingTextEmbedder:textEmbedder
|
||||||
|
hasCount:16
|
||||||
|
firstValue:0.0312863f];
|
||||||
|
|
||||||
|
NSNumber *cosineSimilarity = [MPPTextEmbedder cosineSimilarityBetweenEmbedding1:embedding1
|
||||||
|
andEmbedding2:embedding2
|
||||||
|
error:nil];
|
||||||
|
|
||||||
|
XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.999937f, kSimilarityDiffTolerance);
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testEmbedWithBertAndDifferentThemesSucceeds {
|
||||||
|
MPPTextEmbedder *textEmbedder =
|
||||||
|
[self textEmbedderFromModelFileWithName:kBertTextEmbedderModelName];
|
||||||
|
|
||||||
|
MPPEmbedding *embedding1 =
|
||||||
|
[self assertFloatEmbeddingResultsOfEmbedText:
|
||||||
|
@"When you go to this restaurant, they hold the pancake upside-down before they "
|
||||||
|
@"hand it to you. It's a great gimmick."
|
||||||
|
usingTextEmbedder:textEmbedder
|
||||||
|
hasCount:512
|
||||||
|
firstValue:43.1663];
|
||||||
|
|
||||||
|
MPPEmbedding *embedding2 =
|
||||||
|
[self assertFloatEmbeddingResultsOfEmbedText:
|
||||||
|
@"Let's make a plan to steal the declaration of independence."
|
||||||
|
usingTextEmbedder:textEmbedder
|
||||||
|
hasCount:512
|
||||||
|
firstValue:48.0898];
|
||||||
|
|
||||||
|
NSNumber *cosineSimilarity = [MPPTextEmbedder cosineSimilarityBetweenEmbedding1:embedding1
|
||||||
|
andEmbedding2:embedding2
|
||||||
|
error:nil];
|
||||||
|
|
||||||
|
// TODO: The similarity should likely be lower
|
||||||
|
XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.98151f, kSimilarityDiffTolerance);
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testEmbedWithQuantizeSucceeds {
|
||||||
|
MPPTextEmbedderOptions *options =
|
||||||
|
[self textEmbedderOptionsWithModelName:kBertTextEmbedderModelName];
|
||||||
|
options.quantize = YES;
|
||||||
|
|
||||||
|
MPPTextEmbedder *textEmbedder = [[MPPTextEmbedder alloc] initWithOptions:options error:nil];
|
||||||
|
XCTAssertNotNil(textEmbedder);
|
||||||
|
|
||||||
|
MPPEmbedding *embedding1 = [self
|
||||||
|
assertQuantizedEmbeddingResultsOfEmbedText:@"it's a charming and often affecting journey"
|
||||||
|
usingTextEmbedder:textEmbedder
|
||||||
|
hasCount:512
|
||||||
|
firstValue:127];
|
||||||
|
|
||||||
|
MPPEmbedding *embedding2 =
|
||||||
|
[self assertQuantizedEmbeddingResultsOfEmbedText:@"what a great and fantastic trip"
|
||||||
|
usingTextEmbedder:textEmbedder
|
||||||
|
hasCount:512
|
||||||
|
firstValue:127];
|
||||||
|
NSNumber *cosineSimilarity = [MPPTextEmbedder cosineSimilarityBetweenEmbedding1:embedding1
|
||||||
|
andEmbedding2:embedding2
|
||||||
|
error:nil];
|
||||||
|
XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.88164f, kSimilarityDiffTolerance);
|
||||||
|
}
|
||||||
|
|
||||||
|
@end
|
|
@ -0,0 +1,121 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
import MPPCommon
|
||||||
|
import XCTest
|
||||||
|
|
||||||
|
@testable import MPPTextEmbedder
|
||||||
|
|
||||||
|
/// These tests are only for validating the Swift function signatures of the TextEmbedder.
|
||||||
|
/// Objective C tests of the TextEmbedder provide more coverage with unit tests for
|
||||||
|
/// different models and text embedder options. They can be found here:
|
||||||
|
/// /mediapipe/tasks/ios/test/text/text_embedder/MPPTextEmbedderTests.m
|
||||||
|
|
||||||
|
class TextEmbedderTests: XCTestCase {
|
||||||
|
|
||||||
|
static let bundle = Bundle(for: TextEmbedderTests.self)
|
||||||
|
|
||||||
|
static let bertModelPath = bundle.path(
|
||||||
|
forResource: "mobilebert_embedding_with_metadata",
|
||||||
|
ofType: "tflite")
|
||||||
|
|
||||||
|
static let text1 = "it's a charming and often affecting journey"
|
||||||
|
|
||||||
|
static let text2 = "what a great and fantastic trip"
|
||||||
|
|
||||||
|
static let floatDiffTolerance: Float = 1e-4
|
||||||
|
|
||||||
|
static let doubleDiffTolerance: Double = 1e-4
|
||||||
|
|
||||||
|
func assertEqualErrorDescriptions(
|
||||||
|
_ error: Error, expectedLocalizedDescription: String
|
||||||
|
) {
|
||||||
|
XCTAssertEqual(
|
||||||
|
error.localizedDescription,
|
||||||
|
expectedLocalizedDescription)
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertTextEmbedderResultHasOneEmbedding(
|
||||||
|
_ textEmbedderResult: TextEmbedderResult
|
||||||
|
) {
|
||||||
|
XCTAssertEqual(textEmbedderResult.embeddingResult.embeddings.count, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertEmbeddingIsFloat(
|
||||||
|
_ embedding: Embedding
|
||||||
|
) {
|
||||||
|
XCTAssertNil(embedding.quantizedEmbedding)
|
||||||
|
XCTAssertNotNil(embedding.floatEmbedding)
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertEmbedding(
|
||||||
|
_ floatEmbedding: [NSNumber],
|
||||||
|
hasCount embeddingCount: Int,
|
||||||
|
hasFirstValue firstValue: Float
|
||||||
|
) {
|
||||||
|
XCTAssertEqual(floatEmbedding.count, embeddingCount)
|
||||||
|
XCTAssertEqual(
|
||||||
|
floatEmbedding[0].floatValue,
|
||||||
|
firstValue,
|
||||||
|
accuracy:
|
||||||
|
TextEmbedderTests.floatDiffTolerance)
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertFloatEmbeddingResultsForEmbed(
|
||||||
|
text: String,
|
||||||
|
using textEmbedder: TextEmbedder,
|
||||||
|
hasCount embeddingCount: Int,
|
||||||
|
hasFirstValue firstValue: Float
|
||||||
|
) throws -> Embedding {
|
||||||
|
let textEmbedderResult =
|
||||||
|
try XCTUnwrap(
|
||||||
|
textEmbedder.embed(text: text))
|
||||||
|
assertTextEmbedderResultHasOneEmbedding(textEmbedderResult)
|
||||||
|
assertEmbeddingIsFloat(textEmbedderResult.embeddingResult.embeddings[0])
|
||||||
|
assertEmbedding(
|
||||||
|
textEmbedderResult.embeddingResult.embeddings[0].floatEmbedding!,
|
||||||
|
hasCount: embeddingCount,
|
||||||
|
hasFirstValue: firstValue)
|
||||||
|
|
||||||
|
return textEmbedderResult.embeddingResult.embeddings[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
func testEmbedWithBertSucceeds() throws {
|
||||||
|
|
||||||
|
let modelPath = try XCTUnwrap(TextEmbedderTests.bertModelPath)
|
||||||
|
let textEmbedder = try XCTUnwrap(TextEmbedder(modelPath: modelPath))
|
||||||
|
|
||||||
|
let embedding1 = try assertFloatEmbeddingResultsForEmbed(
|
||||||
|
text: TextEmbedderTests.text1,
|
||||||
|
using: textEmbedder,
|
||||||
|
hasCount: 512,
|
||||||
|
hasFirstValue: 21.214869)
|
||||||
|
|
||||||
|
let embedding2 = try assertFloatEmbeddingResultsForEmbed(
|
||||||
|
text: TextEmbedderTests.text2,
|
||||||
|
using: textEmbedder,
|
||||||
|
hasCount: 512,
|
||||||
|
hasFirstValue: 22.626251)
|
||||||
|
|
||||||
|
let cosineSimilarity = try XCTUnwrap(
|
||||||
|
TextEmbedder.cosineSimilarity(
|
||||||
|
embedding1: embedding1,
|
||||||
|
embedding2: embedding2))
|
||||||
|
|
||||||
|
XCTAssertEqual(
|
||||||
|
cosineSimilarity.doubleValue,
|
||||||
|
0.97141,
|
||||||
|
accuracy: TextEmbedderTests.doubleDiffTolerance)
|
||||||
|
}
|
||||||
|
}
|
|
@ -49,6 +49,7 @@ objc_library(
|
||||||
"//mediapipe/tasks/cc/text/text_embedder:text_embedder_graph",
|
"//mediapipe/tasks/cc/text/text_embedder:text_embedder_graph",
|
||||||
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
|
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
|
||||||
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
|
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
|
||||||
|
"//mediapipe/tasks/ios/components/utils:MPPCosineSimilarity",
|
||||||
"//mediapipe/tasks/ios/core:MPPTaskInfo",
|
"//mediapipe/tasks/ios/core:MPPTaskInfo",
|
||||||
"//mediapipe/tasks/ios/core:MPPTaskOptions",
|
"//mediapipe/tasks/ios/core:MPPTaskOptions",
|
||||||
"//mediapipe/tasks/ios/core:MPPTextPacketCreator",
|
"//mediapipe/tasks/ios/core:MPPTextPacketCreator",
|
||||||
|
|
|
@ -86,6 +86,24 @@ NS_SWIFT_NAME(TextEmbedder)
|
||||||
|
|
||||||
- (instancetype)init NS_UNAVAILABLE;
|
- (instancetype)init NS_UNAVAILABLE;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Utility function to compute[cosine similarity](https://en.wikipedia.org/wiki/Cosine_similarity)
|
||||||
|
* between two `MPPEmbedding` objects.
|
||||||
|
*
|
||||||
|
* @param embedding1 One of the two `MPPEmbedding`s between whom cosine similarity is to be
|
||||||
|
* computed.
|
||||||
|
* @param embedding2 One of the two `MPPEmbedding`s between whom cosine similarity is to be
|
||||||
|
* computed.
|
||||||
|
* @param error An optional error parameter populated when there is an error in calculating cosine
|
||||||
|
* similarity between two embeddings.
|
||||||
|
*
|
||||||
|
* @return An `NSNumber` which holds the cosine similarity of type `double`.
|
||||||
|
*/
|
||||||
|
+ (nullable NSNumber *)cosineSimilarityBetweenEmbedding1:(MPPEmbedding *)embedding1
|
||||||
|
andEmbedding2:(MPPEmbedding *)embedding2
|
||||||
|
error:(NSError **)error
|
||||||
|
NS_SWIFT_NAME(cosineSimilarity(embedding1:embedding2:));
|
||||||
|
|
||||||
+ (instancetype)new NS_UNAVAILABLE;
|
+ (instancetype)new NS_UNAVAILABLE;
|
||||||
|
|
||||||
@end
|
@end
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
|
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
|
||||||
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
|
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
|
||||||
|
#import "mediapipe/tasks/ios/components/utils/sources/MPPCosineSimilarity.h"
|
||||||
#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h"
|
#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h"
|
||||||
#import "mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h"
|
#import "mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h"
|
||||||
#import "mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h"
|
#import "mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h"
|
||||||
|
@ -93,4 +94,12 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_embedder.Tex
|
||||||
.value()[kEmbeddingsOutStreamName.cppString]];
|
.value()[kEmbeddingsOutStreamName.cppString]];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
+ (nullable NSNumber *)cosineSimilarityBetweenEmbedding1:(MPPEmbedding *)embedding1
|
||||||
|
andEmbedding2:(MPPEmbedding *)embedding2
|
||||||
|
error:(NSError **)error {
|
||||||
|
return [MPPCosineSimilarity computeBetweenEmbedding1:embedding1
|
||||||
|
andEmbedding2:embedding2
|
||||||
|
error:error];
|
||||||
|
}
|
||||||
|
|
||||||
@end
|
@end
|
||||||
|
|
Loading…
Reference in New Issue
Block a user