diff --git a/mediapipe/tasks/ios/test/text/text_embedder/MPPTextEmbedderTests.m b/mediapipe/tasks/ios/test/text/text_embedder/MPPTextEmbedderTests.m index 0468c9b81..2fa0f58f4 100644 --- a/mediapipe/tasks/ios/test/text/text_embedder/MPPTextEmbedderTests.m +++ b/mediapipe/tasks/ios/test/text/text_embedder/MPPTextEmbedderTests.m @@ -38,24 +38,22 @@ static const float kSimilarityDiffTolerance = 1e-4; XCTAssertNotNil(textEmbedderResult.embeddingResult); \ XCTAssertEqual(textEmbedderResult.embeddingResult.embeddings.count, 1); -#define AssertEmbeddingType(embedding, quantized) \ - if (quantized) { \ - XCTAssertNil(embedding.floatEmbedding); \ - XCTAssertNotNil(embedding.quantizedEmbedding); \ - } \ - else { \ - XCTAssertNotNil(embedding.floatEmbedding); \ - XCTAssertNil(embedding.quantizedEmbedding);\ - } +#define AssertEmbeddingType(embedding, quantized) \ + if (quantized) { \ + XCTAssertNil(embedding.floatEmbedding); \ + XCTAssertNotNil(embedding.quantizedEmbedding); \ + } else { \ + XCTAssertNotNil(embedding.floatEmbedding); \ + XCTAssertNil(embedding.quantizedEmbedding); \ + } #define AssertEmbeddingHasExpectedValues(embedding, expectedLength, expectedFirstValue, quantize) \ - XCTAssertEqual(embedding.count, expectedLength); \ - if (quantize) { \ - XCTAssertEqual(embedding[0].charValue, expectedFirstValue); \ - } \ - else { \ + XCTAssertEqual(embedding.count, expectedLength); \ + if (quantize) { \ + XCTAssertEqual(embedding[0].charValue, expectedFirstValue); \ + } else { \ XCTAssertEqualWithAccuracy(embedding[0].floatValue, expectedFirstValue, kFloatDiffTolerance); \ - } \ + } @interface MPPTextEmbedderTests : XCTestCase @end @@ -94,41 +92,39 @@ static const float kSimilarityDiffTolerance = 1e-4; firstValue:(float)firstValue { MPPTextEmbedderResult *embedderResult = [textEmbedder embedText:text error:nil]; AssertTextEmbedderResultHasOneEmbedding(embedderResult); - - AssertEmbeddingType( - embedderResult.embeddingResult.embeddings[0], // embedding - NO // quantized + + AssertEmbeddingType(embedderResult.embeddingResult.embeddings[0], // embedding + NO // quantized ); - + AssertEmbeddingHasExpectedValues( - embedderResult.embeddingResult.embeddings[0].floatEmbedding, // embedding - embeddingCount, // expectedLength - firstValue, // expectedFirstValue - NO // quantize + embedderResult.embeddingResult.embeddings[0].floatEmbedding, // embedding + embeddingCount, // expectedLength + firstValue, // expectedFirstValue + NO // quantize ); - + return embedderResult.embeddingResult.embeddings[0]; } -- (NSArray*)assertQuantizedEmbeddingResultsOfEmbedText:(NSString *)text - usingTextEmbedder:(MPPTextEmbedder *)textEmbedder - hasCount:(NSUInteger)embeddingCount - firstValue:(char)firstValue { +- (NSArray *)assertQuantizedEmbeddingResultsOfEmbedText:(NSString *)text + usingTextEmbedder:(MPPTextEmbedder *)textEmbedder + hasCount:(NSUInteger)embeddingCount + firstValue:(char)firstValue { MPPTextEmbedderResult *embedderResult = [textEmbedder embedText:text error:nil]; AssertTextEmbedderResultHasOneEmbedding(embedderResult); - - AssertEmbeddingType( - embedderResult.embeddingResult.embeddings[0], // embedding - YES // quantized + + AssertEmbeddingType(embedderResult.embeddingResult.embeddings[0], // embedding + YES // quantized ); - + AssertEmbeddingHasExpectedValues( - embedderResult.embeddingResult.embeddings[0].quantizedEmbedding, // embedding - embeddingCount, // expectedLength - firstValue, // expectedFirstValue - YES // quantize + embedderResult.embeddingResult.embeddings[0].quantizedEmbedding, // embedding + embeddingCount, // expectedLength + firstValue, // expectedFirstValue + YES // quantize ); - + return embedderResult.embeddingResult.embeddings[0]; } @@ -148,9 +144,8 @@ static const float kSimilarityDiffTolerance = 1e-4; @"INVALID_ARGUMENT: ExternalFile must specify at least one of 'file_content', " @"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'." }]; - AssertEqualErrors( - error, // error - expectedError // expectedError + AssertEqualErrors(error, // error + expectedError // expectedError ); } @@ -228,17 +223,21 @@ static const float kSimilarityDiffTolerance = 1e-4; MPPTextEmbedder *textEmbedder = [[MPPTextEmbedder alloc] initWithOptions:options error:nil]; XCTAssertNotNil(textEmbedder); - MPPEmbedding *embedding1 = [self assertQuantizedEmbeddingResultsOfEmbedText:@"it's a charming and often affecting journey" - usingTextEmbedder:textEmbedder - hasCount:512 - firstValue:127]; + MPPEmbedding *embedding1 = [self + assertQuantizedEmbeddingResultsOfEmbedText:@"it's a charming and often affecting journey" + usingTextEmbedder:textEmbedder + hasCount:512 + firstValue:127]; - MPPEmbedding *embedding2 = [self assertQuantizedEmbeddingResultsOfEmbedText:@"what a great and fantastic trip" - usingTextEmbedder:textEmbedder - hasCount:512 - firstValue:127]; - NSNumber *cosineSimilarity = [MPPTextEmbedder cosineSimilarityBetweenEmbedding1:embedding1 andEmbedding2:embedding2 error:nil]; - XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.864113, kSimilarityDiffTolerance); + MPPEmbedding *embedding2 = + [self assertQuantizedEmbeddingResultsOfEmbedText:@"what a great and fantastic trip" + usingTextEmbedder:textEmbedder + hasCount:512 + firstValue:127]; + NSNumber *cosineSimilarity = [MPPTextEmbedder cosineSimilarityBetweenEmbedding1:embedding1 + andEmbedding2:embedding2 + error:nil]; + XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.864113, kSimilarityDiffTolerance); } @end