From 3b55fb9f6a86d7dfbf9c2119543f2addc00a751f Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Fri, 3 Feb 2023 13:42:32 +0530
Subject: [PATCH] Added iOS test for quantized embedding

---
Review note: the hunks below were hand-edited after export. Hunk line counts
and the diffstat are updated; the post-image blob id on the index line is not.
Context-line whitespace was reconstructed, so apply with
`git apply --ignore-space-change`.

 .../text/text_embedder/MPPTextEmbedderTests.m | 121 ++++++++++++++++---
 1 file changed, 107 insertions(+), 14 deletions(-)

diff --git a/mediapipe/tasks/ios/test/text/text_embedder/MPPTextEmbedderTests.m b/mediapipe/tasks/ios/test/text/text_embedder/MPPTextEmbedderTests.m
index 36e0ef8c0..0468c9b81 100644
--- a/mediapipe/tasks/ios/test/text/text_embedder/MPPTextEmbedderTests.m
+++ b/mediapipe/tasks/ios/test/text/text_embedder/MPPTextEmbedderTests.m
@@ -23,7 +23,7 @@ static NSString *const kText1 = @"it's a charming and often affecting journey";
 static NSString *const kText2 = @"what a great and fantastic trip";
 static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
 static const float kFloatDiffTolerance = 1e-4;
-static const float kDoubleDiffTolerance = 1e-4;
+static const float kSimilarityDiffTolerance = 1e-4;
 
 #define AssertEqualErrors(error, expectedError) \
   XCTAssertNotNil(error); \
@@ -38,13 +38,32 @@ static const float kDoubleDiffTolerance = 1e-4;
   XCTAssertNotNil(textEmbedderResult.embeddingResult); \
   XCTAssertEqual(textEmbedderResult.embeddingResult.embeddings.count, 1);
 
-#define AssertEmbeddingIsFloat(embedding) \
-  XCTAssertNotNil(embedding.floatEmbedding); \
-  XCTAssertNil(embedding.quantizedEmbedding);
+// Asserts that exactly one of the embedding's value arrays is set: the
+// quantized one when `quantized` is true, the float one otherwise.
+#define AssertEmbeddingType(embedding, quantized) \
+  do { \
+    if (quantized) { \
+      XCTAssertNil(embedding.floatEmbedding); \
+      XCTAssertNotNil(embedding.quantizedEmbedding); \
+    } else { \
+      XCTAssertNotNil(embedding.floatEmbedding); \
+      XCTAssertNil(embedding.quantizedEmbedding); \
+    } \
+  } while (0)
 
-#define AssertFloatEmbeddingHasExpectedValues(floatEmbedding, expectedLength, expectedFirstValue) \
-  XCTAssertEqual(floatEmbedding.count, expectedLength); \
-  XCTAssertEqualWithAccuracy(floatEmbedding[0].floatValue, expectedFirstValue, kFloatDiffTolerance);
+// Asserts the element count of `embedding` and its first element. The first
+// element is compared exactly via `charValue` when `quantize` is true, and via
+// `floatValue` within kFloatDiffTolerance otherwise.
+#define AssertEmbeddingHasExpectedValues(embedding, expectedLength, expectedFirstValue, quantize) \
+  do { \
+    XCTAssertEqual(embedding.count, expectedLength); \
+    if (quantize) { \
+      XCTAssertEqual(embedding[0].charValue, expectedFirstValue); \
+    } else { \
+      XCTAssertEqualWithAccuracy(embedding[0].floatValue, expectedFirstValue, \
+                                 kFloatDiffTolerance); \
+    } \
+  } while (0)
 
 @interface MPPTextEmbedderTests : XCTestCase
 @end
@@ -69,15 +88,60 @@ static const float kDoubleDiffTolerance = 1e-4;
   return textEmbedder;
 }
 
+// Returns text embedder options whose model asset path is the `.tflite` file
+// that -filePathWithName:extension: resolves for `modelName`.
+- (MPPTextEmbedderOptions *)textEmbedderOptionsWithModelName:(NSString *)modelName {
+  NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"];
+  MPPTextEmbedderOptions *textEmbedderOptions = [[MPPTextEmbedderOptions alloc] init];
+  textEmbedderOptions.baseOptions.modelAssetPath = modelPath;
+
+  return textEmbedderOptions;
+}
+
 - (NSArray *)assertFloatEmbeddingResultsOfEmbedText:(NSString *)text
                                   usingTextEmbedder:(MPPTextEmbedder *)textEmbedder
                                            hasCount:(NSUInteger)embeddingCount
                                          firstValue:(float)firstValue {
   MPPTextEmbedderResult *embedderResult = [textEmbedder embedText:text error:nil];
   AssertTextEmbedderResultHasOneEmbedding(embedderResult);
-  AssertEmbeddingIsFloat(embedderResult.embeddingResult.embeddings[0]);
-  AssertFloatEmbeddingHasExpectedValues(embedderResult.embeddingResult.embeddings[0].floatEmbedding,
-                                        embeddingCount, firstValue);
+
+  AssertEmbeddingType(
+      embedderResult.embeddingResult.embeddings[0],  // embedding
+      NO                                             // quantized
+  );
+
+  AssertEmbeddingHasExpectedValues(
+      embedderResult.embeddingResult.embeddings[0].floatEmbedding,  // embedding
+      embeddingCount,                                               // expectedLength
+      firstValue,                                                   // expectedFirstValue
+      NO                                                            // quantize
+  );
+
+  return embedderResult.embeddingResult.embeddings[0];
+}
+
+// Embeds `text`, asserts that the single resulting embedding is quantized with
+// `embeddingCount` elements and first element `firstValue`, and returns that
+// embedding.
+- (MPPEmbedding *)assertQuantizedEmbeddingResultsOfEmbedText:(NSString *)text
+                                           usingTextEmbedder:(MPPTextEmbedder *)textEmbedder
+                                                    hasCount:(NSUInteger)embeddingCount
+                                                  firstValue:(char)firstValue {
+  MPPTextEmbedderResult *embedderResult = [textEmbedder embedText:text error:nil];
+  AssertTextEmbedderResultHasOneEmbedding(embedderResult);
+
+  AssertEmbeddingType(
+      embedderResult.embeddingResult.embeddings[0],  // embedding
+      YES                                            // quantized
+  );
+
+  AssertEmbeddingHasExpectedValues(
+      embedderResult.embeddingResult.embeddings[0].quantizedEmbedding,  // embedding
+      embeddingCount,                                                   // expectedLength
+      firstValue,                                                       // expectedFirstValue
+      YES                                                               // quantize
+  );
+
   return embedderResult.embeddingResult.embeddings[0];
 }
 
@@ -97,7 +161,10 @@ static const float kDoubleDiffTolerance = 1e-4;
                       @"INVALID_ARGUMENT: ExternalFile must specify at least one of 'file_content', "
                       @"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."
                 }];
-  AssertEqualErrors(error, expectedError);
+  AssertEqualErrors(
+      error,         // error
+      expectedError  // expectedError
+  );
 }
 
 - (void)testEmbedWithBertSucceeds {
@@ -116,7 +183,7 @@ static const float kDoubleDiffTolerance = 1e-4;
   NSNumber *cosineSimilarity = [MPPTextEmbedder cosineSimilarityBetweenEmbedding1:embedding1
                                                                     andEmbedding2:embedding2
                                                                             error:nil];
-  XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.96386, kDoubleDiffTolerance);
+  XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.96386, kSimilarityDiffTolerance);
 }
 
 - (void)testEmbedWithRegexSucceeds {
@@ -136,7 +203,7 @@ static const float kDoubleDiffTolerance = 1e-4;
   NSNumber *cosineSimilarity = [MPPTextEmbedder cosineSimilarityBetweenEmbedding1:embedding1
                                                                     andEmbedding2:embedding2
                                                                             error:nil];
-  XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.999937f, kDoubleDiffTolerance);
+  XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.999937f, kSimilarityDiffTolerance);
 }
 
 - (void)testEmbedWithBertAndDifferentThemesSucceeds {
@@ -163,7 +230,33 @@ static const float kDoubleDiffTolerance = 1e-4;
                                                                             error:nil];
 
   // TODO: The similarity should likely be lower
-  XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.963203f, kDoubleDiffTolerance);
+  XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.963203f, kSimilarityDiffTolerance);
+}
+
+// Checks that a text embedder created with `quantize` enabled returns quantized
+// embeddings with the expected size, first value and cosine similarity.
+- (void)testEmbedWithQuantizeSucceeds {
+  MPPTextEmbedderOptions *options =
+      [self textEmbedderOptionsWithModelName:kBertTextEmbedderModelName];
+  options.quantize = YES;
+
+  MPPTextEmbedder *textEmbedder = [[MPPTextEmbedder alloc] initWithOptions:options error:nil];
+  XCTAssertNotNil(textEmbedder);
+
+  MPPEmbedding *embedding1 = [self assertQuantizedEmbeddingResultsOfEmbedText:kText1
+                                                            usingTextEmbedder:textEmbedder
+                                                                     hasCount:512
+                                                                   firstValue:127];
+
+  MPPEmbedding *embedding2 = [self assertQuantizedEmbeddingResultsOfEmbedText:kText2
+                                                            usingTextEmbedder:textEmbedder
+                                                                     hasCount:512
+                                                                   firstValue:127];
+
+  NSNumber *cosineSimilarity = [MPPTextEmbedder cosineSimilarityBetweenEmbedding1:embedding1
+                                                                    andEmbedding2:embedding2
+                                                                            error:nil];
+  XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.864113, kSimilarityDiffTolerance);
 }
 
 @end