Added iOS test for quantized embedding
This commit is contained in:
parent
b5b10e7681
commit
3b55fb9f6a
|
@ -23,7 +23,7 @@ static NSString *const kText1 = @"it's a charming and often affecting journey";
|
|||
static NSString *const kText2 = @"what a great and fantastic trip";
|
||||
static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
|
||||
static const float kFloatDiffTolerance = 1e-4;
|
||||
static const float kDoubleDiffTolerance = 1e-4;
|
||||
static const float kSimilarityDiffTolerance = 1e-4;
|
||||
|
||||
#define AssertEqualErrors(error, expectedError) \
|
||||
XCTAssertNotNil(error); \
|
||||
|
@ -38,13 +38,24 @@ static const float kDoubleDiffTolerance = 1e-4;
|
|||
XCTAssertNotNil(textEmbedderResult.embeddingResult); \
|
||||
XCTAssertEqual(textEmbedderResult.embeddingResult.embeddings.count, 1);
|
||||
|
||||
#define AssertEmbeddingIsFloat(embedding) \
|
||||
#define AssertEmbeddingType(embedding, quantized) \
|
||||
if (quantized) { \
|
||||
XCTAssertNil(embedding.floatEmbedding); \
|
||||
XCTAssertNotNil(embedding.quantizedEmbedding); \
|
||||
} \
|
||||
else { \
|
||||
XCTAssertNotNil(embedding.floatEmbedding); \
|
||||
XCTAssertNil(embedding.quantizedEmbedding);
|
||||
XCTAssertNil(embedding.quantizedEmbedding);\
|
||||
}
|
||||
|
||||
#define AssertFloatEmbeddingHasExpectedValues(floatEmbedding, expectedLength, expectedFirstValue) \
|
||||
XCTAssertEqual(floatEmbedding.count, expectedLength); \
|
||||
XCTAssertEqualWithAccuracy(floatEmbedding[0].floatValue, expectedFirstValue, kFloatDiffTolerance);
|
||||
#define AssertEmbeddingHasExpectedValues(embedding, expectedLength, expectedFirstValue, quantize) \
|
||||
XCTAssertEqual(embedding.count, expectedLength); \
|
||||
if (quantize) { \
|
||||
XCTAssertEqual(embedding[0].charValue, expectedFirstValue); \
|
||||
} \
|
||||
else { \
|
||||
XCTAssertEqualWithAccuracy(embedding[0].floatValue, expectedFirstValue, kFloatDiffTolerance); \
|
||||
} \
|
||||
|
||||
@interface MPPTextEmbedderTests : XCTestCase
|
||||
@end
|
||||
|
@ -69,15 +80,55 @@ static const float kDoubleDiffTolerance = 1e-4;
|
|||
return textEmbedder;
|
||||
}
|
||||
|
||||
- (MPPTextEmbedderOptions *)textEmbedderOptionsWithModelName:(NSString *)modelName {
|
||||
NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"];
|
||||
MPPTextEmbedderOptions *textEmbedderOptions = [[MPPTextEmbedderOptions alloc] init];
|
||||
textEmbedderOptions.baseOptions.modelAssetPath = modelPath;
|
||||
|
||||
return textEmbedderOptions;
|
||||
}
|
||||
|
||||
- (NSArray<NSNumber *> *)assertFloatEmbeddingResultsOfEmbedText:(NSString *)text
|
||||
usingTextEmbedder:(MPPTextEmbedder *)textEmbedder
|
||||
hasCount:(NSUInteger)embeddingCount
|
||||
firstValue:(float)firstValue {
|
||||
MPPTextEmbedderResult *embedderResult = [textEmbedder embedText:text error:nil];
|
||||
AssertTextEmbedderResultHasOneEmbedding(embedderResult);
|
||||
AssertEmbeddingIsFloat(embedderResult.embeddingResult.embeddings[0]);
|
||||
AssertFloatEmbeddingHasExpectedValues(embedderResult.embeddingResult.embeddings[0].floatEmbedding,
|
||||
embeddingCount, firstValue);
|
||||
|
||||
AssertEmbeddingType(
|
||||
embedderResult.embeddingResult.embeddings[0], // embedding
|
||||
NO // quantized
|
||||
);
|
||||
|
||||
AssertEmbeddingHasExpectedValues(
|
||||
embedderResult.embeddingResult.embeddings[0].floatEmbedding, // embedding
|
||||
embeddingCount, // expectedLength
|
||||
firstValue, // expectedFirstValue
|
||||
NO // quantize
|
||||
);
|
||||
|
||||
return embedderResult.embeddingResult.embeddings[0];
|
||||
}
|
||||
|
||||
- (NSArray<NSNumber *>*)assertQuantizedEmbeddingResultsOfEmbedText:(NSString *)text
|
||||
usingTextEmbedder:(MPPTextEmbedder *)textEmbedder
|
||||
hasCount:(NSUInteger)embeddingCount
|
||||
firstValue:(char)firstValue {
|
||||
MPPTextEmbedderResult *embedderResult = [textEmbedder embedText:text error:nil];
|
||||
AssertTextEmbedderResultHasOneEmbedding(embedderResult);
|
||||
|
||||
AssertEmbeddingType(
|
||||
embedderResult.embeddingResult.embeddings[0], // embedding
|
||||
YES // quantized
|
||||
);
|
||||
|
||||
AssertEmbeddingHasExpectedValues(
|
||||
embedderResult.embeddingResult.embeddings[0].quantizedEmbedding, // embedding
|
||||
embeddingCount, // expectedLength
|
||||
firstValue, // expectedFirstValue
|
||||
YES // quantize
|
||||
);
|
||||
|
||||
return embedderResult.embeddingResult.embeddings[0];
|
||||
}
|
||||
|
||||
|
@ -97,7 +148,10 @@ static const float kDoubleDiffTolerance = 1e-4;
|
|||
@"INVALID_ARGUMENT: ExternalFile must specify at least one of 'file_content', "
|
||||
@"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."
|
||||
}];
|
||||
AssertEqualErrors(error, expectedError);
|
||||
AssertEqualErrors(
|
||||
error, // error
|
||||
expectedError // expectedError
|
||||
);
|
||||
}
|
||||
|
||||
- (void)testEmbedWithBertSucceeds {
|
||||
|
@ -116,7 +170,7 @@ static const float kDoubleDiffTolerance = 1e-4;
|
|||
NSNumber *cosineSimilarity = [MPPTextEmbedder cosineSimilarityBetweenEmbedding1:embedding1
|
||||
andEmbedding2:embedding2
|
||||
error:nil];
|
||||
XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.96386, kDoubleDiffTolerance);
|
||||
XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.96386, kSimilarityDiffTolerance);
|
||||
}
|
||||
|
||||
- (void)testEmbedWithRegexSucceeds {
|
||||
|
@ -136,7 +190,7 @@ static const float kDoubleDiffTolerance = 1e-4;
|
|||
NSNumber *cosineSimilarity = [MPPTextEmbedder cosineSimilarityBetweenEmbedding1:embedding1
|
||||
andEmbedding2:embedding2
|
||||
error:nil];
|
||||
XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.999937f, kDoubleDiffTolerance);
|
||||
XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.999937f, kSimilarityDiffTolerance);
|
||||
}
|
||||
|
||||
- (void)testEmbedWithBertAndDifferentThemesSucceeds {
|
||||
|
@ -163,7 +217,28 @@ static const float kDoubleDiffTolerance = 1e-4;
|
|||
error:nil];
|
||||
|
||||
// TODO: The similarity should likely be lower
|
||||
XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.963203f, kDoubleDiffTolerance);
|
||||
XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.963203f, kSimilarityDiffTolerance);
|
||||
}
|
||||
|
||||
- (void)testEmbedWithQuantizeSucceeds {
|
||||
MPPTextEmbedderOptions *options =
|
||||
[self textEmbedderOptionsWithModelName:kBertTextEmbedderModelName];
|
||||
options.quantize = YES;
|
||||
|
||||
MPPTextEmbedder *textEmbedder = [[MPPTextEmbedder alloc] initWithOptions:options error:nil];
|
||||
XCTAssertNotNil(textEmbedder);
|
||||
|
||||
MPPEmbedding *embedding1 = [self assertQuantizedEmbeddingResultsOfEmbedText:@"it's a charming and often affecting journey"
|
||||
usingTextEmbedder:textEmbedder
|
||||
hasCount:512
|
||||
firstValue:127];
|
||||
|
||||
MPPEmbedding *embedding2 = [self assertQuantizedEmbeddingResultsOfEmbedText:@"what a great and fantastic trip"
|
||||
usingTextEmbedder:textEmbedder
|
||||
hasCount:512
|
||||
firstValue:127];
|
||||
NSNumber *cosineSimilarity = [MPPTextEmbedder cosineSimilarityBetweenEmbedding1:embedding1 andEmbedding2:embedding2 error:nil];
|
||||
XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.864113, kSimilarityDiffTolerance);
|
||||
}
|
||||
|
||||
@end
|
||||
|
|
Loading…
Reference in New Issue
Block a user