Added iOS test for quantized embedding

This commit is contained in:
Prianka Liz Kariat 2023-02-03 13:42:32 +05:30
parent b5b10e7681
commit 3b55fb9f6a

View File

@ -23,7 +23,7 @@ static NSString *const kText1 = @"it's a charming and often affecting journey";
static NSString *const kText2 = @"what a great and fantastic trip";
static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
static const float kFloatDiffTolerance = 1e-4;
static const float kDoubleDiffTolerance = 1e-4;
static const float kSimilarityDiffTolerance = 1e-4;
#define AssertEqualErrors(error, expectedError) \
XCTAssertNotNil(error); \
@ -38,13 +38,24 @@ static const float kDoubleDiffTolerance = 1e-4;
XCTAssertNotNil(textEmbedderResult.embeddingResult); \
XCTAssertEqual(textEmbedderResult.embeddingResult.embeddings.count, 1);
#define AssertEmbeddingIsFloat(embedding) \
#define AssertEmbeddingType(embedding, quantized) \
if (quantized) { \
XCTAssertNil(embedding.floatEmbedding); \
XCTAssertNotNil(embedding.quantizedEmbedding); \
} \
else { \
XCTAssertNotNil(embedding.floatEmbedding); \
XCTAssertNil(embedding.quantizedEmbedding);
XCTAssertNil(embedding.quantizedEmbedding);\
}
#define AssertFloatEmbeddingHasExpectedValues(floatEmbedding, expectedLength, expectedFirstValue) \
XCTAssertEqual(floatEmbedding.count, expectedLength); \
XCTAssertEqualWithAccuracy(floatEmbedding[0].floatValue, expectedFirstValue, kFloatDiffTolerance);
#define AssertEmbeddingHasExpectedValues(embedding, expectedLength, expectedFirstValue, quantize) \
XCTAssertEqual(embedding.count, expectedLength); \
if (quantize) { \
XCTAssertEqual(embedding[0].charValue, expectedFirstValue); \
} \
else { \
XCTAssertEqualWithAccuracy(embedding[0].floatValue, expectedFirstValue, kFloatDiffTolerance); \
} \
@interface MPPTextEmbedderTests : XCTestCase
@end
@ -69,15 +80,55 @@ static const float kDoubleDiffTolerance = 1e-4;
return textEmbedder;
}
- (MPPTextEmbedderOptions *)textEmbedderOptionsWithModelName:(NSString *)modelName {
NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"];
MPPTextEmbedderOptions *textEmbedderOptions = [[MPPTextEmbedderOptions alloc] init];
textEmbedderOptions.baseOptions.modelAssetPath = modelPath;
return textEmbedderOptions;
}
- (NSArray<NSNumber *> *)assertFloatEmbeddingResultsOfEmbedText:(NSString *)text
usingTextEmbedder:(MPPTextEmbedder *)textEmbedder
hasCount:(NSUInteger)embeddingCount
firstValue:(float)firstValue {
MPPTextEmbedderResult *embedderResult = [textEmbedder embedText:text error:nil];
AssertTextEmbedderResultHasOneEmbedding(embedderResult);
AssertEmbeddingIsFloat(embedderResult.embeddingResult.embeddings[0]);
AssertFloatEmbeddingHasExpectedValues(embedderResult.embeddingResult.embeddings[0].floatEmbedding,
embeddingCount, firstValue);
AssertEmbeddingType(
embedderResult.embeddingResult.embeddings[0], // embedding
NO // quantized
);
AssertEmbeddingHasExpectedValues(
embedderResult.embeddingResult.embeddings[0].floatEmbedding, // embedding
embeddingCount, // expectedLength
firstValue, // expectedFirstValue
NO // quantize
);
return embedderResult.embeddingResult.embeddings[0];
}
- (NSArray<NSNumber *>*)assertQuantizedEmbeddingResultsOfEmbedText:(NSString *)text
usingTextEmbedder:(MPPTextEmbedder *)textEmbedder
hasCount:(NSUInteger)embeddingCount
firstValue:(char)firstValue {
MPPTextEmbedderResult *embedderResult = [textEmbedder embedText:text error:nil];
AssertTextEmbedderResultHasOneEmbedding(embedderResult);
AssertEmbeddingType(
embedderResult.embeddingResult.embeddings[0], // embedding
YES // quantized
);
AssertEmbeddingHasExpectedValues(
embedderResult.embeddingResult.embeddings[0].quantizedEmbedding, // embedding
embeddingCount, // expectedLength
firstValue, // expectedFirstValue
YES // quantize
);
return embedderResult.embeddingResult.embeddings[0];
}
@ -97,7 +148,10 @@ static const float kDoubleDiffTolerance = 1e-4;
@"INVALID_ARGUMENT: ExternalFile must specify at least one of 'file_content', "
@"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."
}];
AssertEqualErrors(error, expectedError);
AssertEqualErrors(
error, // error
expectedError // expectedError
);
}
- (void)testEmbedWithBertSucceeds {
@ -116,7 +170,7 @@ static const float kDoubleDiffTolerance = 1e-4;
NSNumber *cosineSimilarity = [MPPTextEmbedder cosineSimilarityBetweenEmbedding1:embedding1
andEmbedding2:embedding2
error:nil];
XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.96386, kDoubleDiffTolerance);
XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.96386, kSimilarityDiffTolerance);
}
- (void)testEmbedWithRegexSucceeds {
@ -136,7 +190,7 @@ static const float kDoubleDiffTolerance = 1e-4;
NSNumber *cosineSimilarity = [MPPTextEmbedder cosineSimilarityBetweenEmbedding1:embedding1
andEmbedding2:embedding2
error:nil];
XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.999937f, kDoubleDiffTolerance);
XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.999937f, kSimilarityDiffTolerance);
}
- (void)testEmbedWithBertAndDifferentThemesSucceeds {
@ -163,7 +217,28 @@ static const float kDoubleDiffTolerance = 1e-4;
error:nil];
// TODO: The similarity should likely be lower
XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.963203f, kDoubleDiffTolerance);
XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.963203f, kSimilarityDiffTolerance);
}
- (void)testEmbedWithQuantizeSucceeds {
MPPTextEmbedderOptions *options =
[self textEmbedderOptionsWithModelName:kBertTextEmbedderModelName];
options.quantize = YES;
MPPTextEmbedder *textEmbedder = [[MPPTextEmbedder alloc] initWithOptions:options error:nil];
XCTAssertNotNil(textEmbedder);
MPPEmbedding *embedding1 = [self assertQuantizedEmbeddingResultsOfEmbedText:@"it's a charming and often affecting journey"
usingTextEmbedder:textEmbedder
hasCount:512
firstValue:127];
MPPEmbedding *embedding2 = [self assertQuantizedEmbeddingResultsOfEmbedText:@"what a great and fantastic trip"
usingTextEmbedder:textEmbedder
hasCount:512
firstValue:127];
NSNumber *cosineSimilarity = [MPPTextEmbedder cosineSimilarityBetweenEmbedding1:embedding1 andEmbedding2:embedding2 error:nil];
XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.864113, kSimilarityDiffTolerance);
}
@end