Updated documentation of iOS text embedder tests

This commit is contained in:
Prianka Liz Kariat 2023-02-03 18:06:05 +05:30
parent e290f9cf30
commit eeaa011998

View File

@ -38,24 +38,22 @@ static const float kSimilarityDiffTolerance = 1e-4;
XCTAssertNotNil(textEmbedderResult.embeddingResult); \ XCTAssertNotNil(textEmbedderResult.embeddingResult); \
XCTAssertEqual(textEmbedderResult.embeddingResult.embeddings.count, 1); XCTAssertEqual(textEmbedderResult.embeddingResult.embeddings.count, 1);
#define AssertEmbeddingType(embedding, quantized) \ #define AssertEmbeddingType(embedding, quantized) \
if (quantized) { \ if (quantized) { \
XCTAssertNil(embedding.floatEmbedding); \ XCTAssertNil(embedding.floatEmbedding); \
XCTAssertNotNil(embedding.quantizedEmbedding); \ XCTAssertNotNil(embedding.quantizedEmbedding); \
} \ } else { \
else { \ XCTAssertNotNil(embedding.floatEmbedding); \
XCTAssertNotNil(embedding.floatEmbedding); \ XCTAssertNil(embedding.quantizedEmbedding); \
XCTAssertNil(embedding.quantizedEmbedding);\
} }
#define AssertEmbeddingHasExpectedValues(embedding, expectedLength, expectedFirstValue, quantize) \ #define AssertEmbeddingHasExpectedValues(embedding, expectedLength, expectedFirstValue, quantize) \
XCTAssertEqual(embedding.count, expectedLength); \ XCTAssertEqual(embedding.count, expectedLength); \
if (quantize) { \ if (quantize) { \
XCTAssertEqual(embedding[0].charValue, expectedFirstValue); \ XCTAssertEqual(embedding[0].charValue, expectedFirstValue); \
} \ } else { \
else { \
XCTAssertEqualWithAccuracy(embedding[0].floatValue, expectedFirstValue, kFloatDiffTolerance); \ XCTAssertEqualWithAccuracy(embedding[0].floatValue, expectedFirstValue, kFloatDiffTolerance); \
} \ }
@interface MPPTextEmbedderTests : XCTestCase @interface MPPTextEmbedderTests : XCTestCase
@end @end
@ -95,38 +93,36 @@ static const float kSimilarityDiffTolerance = 1e-4;
MPPTextEmbedderResult *embedderResult = [textEmbedder embedText:text error:nil]; MPPTextEmbedderResult *embedderResult = [textEmbedder embedText:text error:nil];
AssertTextEmbedderResultHasOneEmbedding(embedderResult); AssertTextEmbedderResultHasOneEmbedding(embedderResult);
AssertEmbeddingType( AssertEmbeddingType(embedderResult.embeddingResult.embeddings[0], // embedding
embedderResult.embeddingResult.embeddings[0], // embedding NO // quantized
NO // quantized
); );
AssertEmbeddingHasExpectedValues( AssertEmbeddingHasExpectedValues(
embedderResult.embeddingResult.embeddings[0].floatEmbedding, // embedding embedderResult.embeddingResult.embeddings[0].floatEmbedding, // embedding
embeddingCount, // expectedLength embeddingCount, // expectedLength
firstValue, // expectedFirstValue firstValue, // expectedFirstValue
NO // quantize NO // quantize
); );
return embedderResult.embeddingResult.embeddings[0]; return embedderResult.embeddingResult.embeddings[0];
} }
- (NSArray<NSNumber *>*)assertQuantizedEmbeddingResultsOfEmbedText:(NSString *)text - (NSArray<NSNumber *> *)assertQuantizedEmbeddingResultsOfEmbedText:(NSString *)text
usingTextEmbedder:(MPPTextEmbedder *)textEmbedder usingTextEmbedder:(MPPTextEmbedder *)textEmbedder
hasCount:(NSUInteger)embeddingCount hasCount:(NSUInteger)embeddingCount
firstValue:(char)firstValue { firstValue:(char)firstValue {
MPPTextEmbedderResult *embedderResult = [textEmbedder embedText:text error:nil]; MPPTextEmbedderResult *embedderResult = [textEmbedder embedText:text error:nil];
AssertTextEmbedderResultHasOneEmbedding(embedderResult); AssertTextEmbedderResultHasOneEmbedding(embedderResult);
AssertEmbeddingType( AssertEmbeddingType(embedderResult.embeddingResult.embeddings[0], // embedding
embedderResult.embeddingResult.embeddings[0], // embedding YES // quantized
YES // quantized
); );
AssertEmbeddingHasExpectedValues( AssertEmbeddingHasExpectedValues(
embedderResult.embeddingResult.embeddings[0].quantizedEmbedding, // embedding embedderResult.embeddingResult.embeddings[0].quantizedEmbedding, // embedding
embeddingCount, // expectedLength embeddingCount, // expectedLength
firstValue, // expectedFirstValue firstValue, // expectedFirstValue
YES // quantize YES // quantize
); );
return embedderResult.embeddingResult.embeddings[0]; return embedderResult.embeddingResult.embeddings[0];
@ -148,9 +144,8 @@ static const float kSimilarityDiffTolerance = 1e-4;
@"INVALID_ARGUMENT: ExternalFile must specify at least one of 'file_content', " @"INVALID_ARGUMENT: ExternalFile must specify at least one of 'file_content', "
@"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'." @"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."
}]; }];
AssertEqualErrors( AssertEqualErrors(error, // error
error, // error expectedError // expectedError
expectedError // expectedError
); );
} }
@ -228,16 +223,20 @@ static const float kSimilarityDiffTolerance = 1e-4;
MPPTextEmbedder *textEmbedder = [[MPPTextEmbedder alloc] initWithOptions:options error:nil]; MPPTextEmbedder *textEmbedder = [[MPPTextEmbedder alloc] initWithOptions:options error:nil];
XCTAssertNotNil(textEmbedder); XCTAssertNotNil(textEmbedder);
MPPEmbedding *embedding1 = [self assertQuantizedEmbeddingResultsOfEmbedText:@"it's a charming and often affecting journey" MPPEmbedding *embedding1 = [self
usingTextEmbedder:textEmbedder assertQuantizedEmbeddingResultsOfEmbedText:@"it's a charming and often affecting journey"
hasCount:512 usingTextEmbedder:textEmbedder
firstValue:127]; hasCount:512
firstValue:127];
MPPEmbedding *embedding2 = [self assertQuantizedEmbeddingResultsOfEmbedText:@"what a great and fantastic trip" MPPEmbedding *embedding2 =
usingTextEmbedder:textEmbedder [self assertQuantizedEmbeddingResultsOfEmbedText:@"what a great and fantastic trip"
hasCount:512 usingTextEmbedder:textEmbedder
firstValue:127]; hasCount:512
NSNumber *cosineSimilarity = [MPPTextEmbedder cosineSimilarityBetweenEmbedding1:embedding1 andEmbedding2:embedding2 error:nil]; firstValue:127];
NSNumber *cosineSimilarity = [MPPTextEmbedder cosineSimilarityBetweenEmbedding1:embedding1
andEmbedding2:embedding2
error:nil];
XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.864113, kSimilarityDiffTolerance); XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.864113, kSimilarityDiffTolerance);
} }