diff --git a/mediapipe/tasks/ios/test/text/text_embedder/MPPTextEmbedderTests.m b/mediapipe/tasks/ios/test/text/text_embedder/MPPTextEmbedderTests.m index c58c52298..36e0ef8c0 100644 --- a/mediapipe/tasks/ios/test/text/text_embedder/MPPTextEmbedderTests.m +++ b/mediapipe/tasks/ios/test/text/text_embedder/MPPTextEmbedderTests.m @@ -139,4 +139,31 @@ static const float kDoubleDiffTolerance = 1e-4; XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.999937f, kDoubleDiffTolerance); } +- (void)testEmbedWithBertAndDifferentThemesSucceeds { + MPPTextEmbedder *textEmbedder = + [self textEmbedderFromModelFileWithName:kBertTextEmbedderModelName]; + + MPPEmbedding *embedding1 = + [self assertFloatEmbeddingResultsOfEmbedText: + @"When you go to this restaurant, they hold the pancake upside-down before they " + @"hand it to you. It's a great gimmick." + usingTextEmbedder:textEmbedder + hasCount:512 + firstValue:42.0832]; + + MPPEmbedding *embedding2 = + [self assertFloatEmbeddingResultsOfEmbedText: + @"Let's make a plan to steal the declaration of independence." + usingTextEmbedder:textEmbedder + hasCount:512 + firstValue:50.8856]; + + NSNumber *cosineSimilarity = [MPPTextEmbedder cosineSimilarityBetweenEmbedding1:embedding1 + andEmbedding2:embedding2 + error:nil]; + + // TODO: The similarity should likely be lower + XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.963203f, kDoubleDiffTolerance); +} + @end