Added iOS test for quantized embedding
This commit is contained in:
parent
b5b10e7681
commit
3b55fb9f6a
|
@ -23,7 +23,7 @@ static NSString *const kText1 = @"it's a charming and often affecting journey";
|
||||||
static NSString *const kText2 = @"what a great and fantastic trip";
|
static NSString *const kText2 = @"what a great and fantastic trip";
|
||||||
static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
|
static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
|
||||||
static const float kFloatDiffTolerance = 1e-4;
|
static const float kFloatDiffTolerance = 1e-4;
|
||||||
static const float kDoubleDiffTolerance = 1e-4;
|
static const float kSimilarityDiffTolerance = 1e-4;
|
||||||
|
|
||||||
#define AssertEqualErrors(error, expectedError) \
|
#define AssertEqualErrors(error, expectedError) \
|
||||||
XCTAssertNotNil(error); \
|
XCTAssertNotNil(error); \
|
||||||
|
@ -38,13 +38,24 @@ static const float kDoubleDiffTolerance = 1e-4;
|
||||||
XCTAssertNotNil(textEmbedderResult.embeddingResult); \
|
XCTAssertNotNil(textEmbedderResult.embeddingResult); \
|
||||||
XCTAssertEqual(textEmbedderResult.embeddingResult.embeddings.count, 1);
|
XCTAssertEqual(textEmbedderResult.embeddingResult.embeddings.count, 1);
|
||||||
|
|
||||||
#define AssertEmbeddingIsFloat(embedding) \
|
#define AssertEmbeddingType(embedding, quantized) \
|
||||||
|
if (quantized) { \
|
||||||
|
XCTAssertNil(embedding.floatEmbedding); \
|
||||||
|
XCTAssertNotNil(embedding.quantizedEmbedding); \
|
||||||
|
} \
|
||||||
|
else { \
|
||||||
XCTAssertNotNil(embedding.floatEmbedding); \
|
XCTAssertNotNil(embedding.floatEmbedding); \
|
||||||
XCTAssertNil(embedding.quantizedEmbedding);
|
XCTAssertNil(embedding.quantizedEmbedding);\
|
||||||
|
}
|
||||||
|
|
||||||
#define AssertFloatEmbeddingHasExpectedValues(floatEmbedding, expectedLength, expectedFirstValue) \
|
#define AssertEmbeddingHasExpectedValues(embedding, expectedLength, expectedFirstValue, quantize) \
|
||||||
XCTAssertEqual(floatEmbedding.count, expectedLength); \
|
XCTAssertEqual(embedding.count, expectedLength); \
|
||||||
XCTAssertEqualWithAccuracy(floatEmbedding[0].floatValue, expectedFirstValue, kFloatDiffTolerance);
|
if (quantize) { \
|
||||||
|
XCTAssertEqual(embedding[0].charValue, expectedFirstValue); \
|
||||||
|
} \
|
||||||
|
else { \
|
||||||
|
XCTAssertEqualWithAccuracy(embedding[0].floatValue, expectedFirstValue, kFloatDiffTolerance); \
|
||||||
|
} \
|
||||||
|
|
||||||
@interface MPPTextEmbedderTests : XCTestCase
|
@interface MPPTextEmbedderTests : XCTestCase
|
||||||
@end
|
@end
|
||||||
|
@ -69,15 +80,55 @@ static const float kDoubleDiffTolerance = 1e-4;
|
||||||
return textEmbedder;
|
return textEmbedder;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
- (MPPTextEmbedderOptions *)textEmbedderOptionsWithModelName:(NSString *)modelName {
|
||||||
|
NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"];
|
||||||
|
MPPTextEmbedderOptions *textEmbedderOptions = [[MPPTextEmbedderOptions alloc] init];
|
||||||
|
textEmbedderOptions.baseOptions.modelAssetPath = modelPath;
|
||||||
|
|
||||||
|
return textEmbedderOptions;
|
||||||
|
}
|
||||||
|
|
||||||
- (NSArray<NSNumber *> *)assertFloatEmbeddingResultsOfEmbedText:(NSString *)text
|
- (NSArray<NSNumber *> *)assertFloatEmbeddingResultsOfEmbedText:(NSString *)text
|
||||||
usingTextEmbedder:(MPPTextEmbedder *)textEmbedder
|
usingTextEmbedder:(MPPTextEmbedder *)textEmbedder
|
||||||
hasCount:(NSUInteger)embeddingCount
|
hasCount:(NSUInteger)embeddingCount
|
||||||
firstValue:(float)firstValue {
|
firstValue:(float)firstValue {
|
||||||
MPPTextEmbedderResult *embedderResult = [textEmbedder embedText:text error:nil];
|
MPPTextEmbedderResult *embedderResult = [textEmbedder embedText:text error:nil];
|
||||||
AssertTextEmbedderResultHasOneEmbedding(embedderResult);
|
AssertTextEmbedderResultHasOneEmbedding(embedderResult);
|
||||||
AssertEmbeddingIsFloat(embedderResult.embeddingResult.embeddings[0]);
|
|
||||||
AssertFloatEmbeddingHasExpectedValues(embedderResult.embeddingResult.embeddings[0].floatEmbedding,
|
AssertEmbeddingType(
|
||||||
embeddingCount, firstValue);
|
embedderResult.embeddingResult.embeddings[0], // embedding
|
||||||
|
NO // quantized
|
||||||
|
);
|
||||||
|
|
||||||
|
AssertEmbeddingHasExpectedValues(
|
||||||
|
embedderResult.embeddingResult.embeddings[0].floatEmbedding, // embedding
|
||||||
|
embeddingCount, // expectedLength
|
||||||
|
firstValue, // expectedFirstValue
|
||||||
|
NO // quantize
|
||||||
|
);
|
||||||
|
|
||||||
|
return embedderResult.embeddingResult.embeddings[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
- (NSArray<NSNumber *>*)assertQuantizedEmbeddingResultsOfEmbedText:(NSString *)text
|
||||||
|
usingTextEmbedder:(MPPTextEmbedder *)textEmbedder
|
||||||
|
hasCount:(NSUInteger)embeddingCount
|
||||||
|
firstValue:(char)firstValue {
|
||||||
|
MPPTextEmbedderResult *embedderResult = [textEmbedder embedText:text error:nil];
|
||||||
|
AssertTextEmbedderResultHasOneEmbedding(embedderResult);
|
||||||
|
|
||||||
|
AssertEmbeddingType(
|
||||||
|
embedderResult.embeddingResult.embeddings[0], // embedding
|
||||||
|
YES // quantized
|
||||||
|
);
|
||||||
|
|
||||||
|
AssertEmbeddingHasExpectedValues(
|
||||||
|
embedderResult.embeddingResult.embeddings[0].quantizedEmbedding, // embedding
|
||||||
|
embeddingCount, // expectedLength
|
||||||
|
firstValue, // expectedFirstValue
|
||||||
|
YES // quantize
|
||||||
|
);
|
||||||
|
|
||||||
return embedderResult.embeddingResult.embeddings[0];
|
return embedderResult.embeddingResult.embeddings[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -97,7 +148,10 @@ static const float kDoubleDiffTolerance = 1e-4;
|
||||||
@"INVALID_ARGUMENT: ExternalFile must specify at least one of 'file_content', "
|
@"INVALID_ARGUMENT: ExternalFile must specify at least one of 'file_content', "
|
||||||
@"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."
|
@"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."
|
||||||
}];
|
}];
|
||||||
AssertEqualErrors(error, expectedError);
|
AssertEqualErrors(
|
||||||
|
error, // error
|
||||||
|
expectedError // expectedError
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
- (void)testEmbedWithBertSucceeds {
|
- (void)testEmbedWithBertSucceeds {
|
||||||
|
@ -116,7 +170,7 @@ static const float kDoubleDiffTolerance = 1e-4;
|
||||||
NSNumber *cosineSimilarity = [MPPTextEmbedder cosineSimilarityBetweenEmbedding1:embedding1
|
NSNumber *cosineSimilarity = [MPPTextEmbedder cosineSimilarityBetweenEmbedding1:embedding1
|
||||||
andEmbedding2:embedding2
|
andEmbedding2:embedding2
|
||||||
error:nil];
|
error:nil];
|
||||||
XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.96386, kDoubleDiffTolerance);
|
XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.96386, kSimilarityDiffTolerance);
|
||||||
}
|
}
|
||||||
|
|
||||||
- (void)testEmbedWithRegexSucceeds {
|
- (void)testEmbedWithRegexSucceeds {
|
||||||
|
@ -136,7 +190,7 @@ static const float kDoubleDiffTolerance = 1e-4;
|
||||||
NSNumber *cosineSimilarity = [MPPTextEmbedder cosineSimilarityBetweenEmbedding1:embedding1
|
NSNumber *cosineSimilarity = [MPPTextEmbedder cosineSimilarityBetweenEmbedding1:embedding1
|
||||||
andEmbedding2:embedding2
|
andEmbedding2:embedding2
|
||||||
error:nil];
|
error:nil];
|
||||||
XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.999937f, kDoubleDiffTolerance);
|
XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.999937f, kSimilarityDiffTolerance);
|
||||||
}
|
}
|
||||||
|
|
||||||
- (void)testEmbedWithBertAndDifferentThemesSucceeds {
|
- (void)testEmbedWithBertAndDifferentThemesSucceeds {
|
||||||
|
@ -163,7 +217,28 @@ static const float kDoubleDiffTolerance = 1e-4;
|
||||||
error:nil];
|
error:nil];
|
||||||
|
|
||||||
// TODO: The similarity should likely be lower
|
// TODO: The similarity should likely be lower
|
||||||
XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.963203f, kDoubleDiffTolerance);
|
XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.963203f, kSimilarityDiffTolerance);
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testEmbedWithQuantizeSucceeds {
|
||||||
|
MPPTextEmbedderOptions *options =
|
||||||
|
[self textEmbedderOptionsWithModelName:kBertTextEmbedderModelName];
|
||||||
|
options.quantize = YES;
|
||||||
|
|
||||||
|
MPPTextEmbedder *textEmbedder = [[MPPTextEmbedder alloc] initWithOptions:options error:nil];
|
||||||
|
XCTAssertNotNil(textEmbedder);
|
||||||
|
|
||||||
|
MPPEmbedding *embedding1 = [self assertQuantizedEmbeddingResultsOfEmbedText:@"it's a charming and often affecting journey"
|
||||||
|
usingTextEmbedder:textEmbedder
|
||||||
|
hasCount:512
|
||||||
|
firstValue:127];
|
||||||
|
|
||||||
|
MPPEmbedding *embedding2 = [self assertQuantizedEmbeddingResultsOfEmbedText:@"what a great and fantastic trip"
|
||||||
|
usingTextEmbedder:textEmbedder
|
||||||
|
hasCount:512
|
||||||
|
firstValue:127];
|
||||||
|
NSNumber *cosineSimilarity = [MPPTextEmbedder cosineSimilarityBetweenEmbedding1:embedding1 andEmbedding2:embedding2 error:nil];
|
||||||
|
XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.864113, kSimilarityDiffTolerance);
|
||||||
}
|
}
|
||||||
|
|
||||||
@end
|
@end
|
||||||
|
|
Loading…
Reference in New Issue
Block a user