Added swift tests for text embedder
This commit is contained in:
		
							parent
							
								
									474e994a5f
								
							
						
					
					
						commit
						d625918995
					
				| 
						 | 
					@ -53,3 +53,28 @@ ios_unit_test(
 | 
				
			||||||
        ":MPPTextEmbedderObjcTestLibrary",
 | 
					        ":MPPTextEmbedderObjcTestLibrary",
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					swift_library(
 | 
				
			||||||
 | 
					    name = "MPPTextEmbedderSwiftTestLibrary",
 | 
				
			||||||
 | 
					    testonly = 1,
 | 
				
			||||||
 | 
					    srcs = ["TextEmbedderTests.swift"],
 | 
				
			||||||
 | 
					    data = [
 | 
				
			||||||
 | 
					       "//mediapipe/tasks/testdata/text:mobilebert_embedding_model",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/testdata/text:regex_embedding_with_metadata",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					    tags = TFL_DEFAULT_TAGS,
 | 
				
			||||||
 | 
					    deps = [
 | 
				
			||||||
 | 
					       "//mediapipe/tasks/ios/common:MPPCommon",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/ios/text/text_embedder:MPPTextEmbedder",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					ios_unit_test(
 | 
				
			||||||
 | 
					    name = "MPPTextEmbedderSwiftTest",
 | 
				
			||||||
 | 
					    minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION,
 | 
				
			||||||
 | 
					    runner = tflite_ios_lab_runner("IOS_LATEST"),
 | 
				
			||||||
 | 
					    tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
 | 
				
			||||||
 | 
					    deps = [
 | 
				
			||||||
 | 
					        ":MPPTextEmbedderSwiftTestLibrary",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,114 @@
 | 
				
			||||||
 | 
					// Copyright 2023 The MediaPipe Authors.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					// you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					// You may obtain a copy of the License at
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//      http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					// distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					// See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					// limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import MPPCommon
 | 
				
			||||||
 | 
					import XCTest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@testable import MPPTextEmbedder
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TextEmbedderTests: XCTestCase {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  static let bundle = Bundle(for: TextEmbedderTests.self)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  static let bertModelPath = bundle.path(
 | 
				
			||||||
 | 
					    forResource: "mobilebert_embedding_with_metadata",
 | 
				
			||||||
 | 
					    ofType: "tflite")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  static let text1 = "it's a charming and often affecting journey"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  static let text2 = "what a great and fantastic trip"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  static let floatDiffTolerance: Float = 1e-4
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  static let doubleDiffTolerance: Double = 1e-4
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  func assertEqualErrorDescriptions(
 | 
				
			||||||
 | 
					    _ error: Error, expectedLocalizedDescription: String
 | 
				
			||||||
 | 
					  ) {
 | 
				
			||||||
 | 
					    XCTAssertEqual(
 | 
				
			||||||
 | 
					      error.localizedDescription,
 | 
				
			||||||
 | 
					      expectedLocalizedDescription)
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  func assertTextEmbedderResultHasOneEmbedding(
 | 
				
			||||||
 | 
					    _ textEmbedderResult: TextEmbedderResult
 | 
				
			||||||
 | 
					  ) {
 | 
				
			||||||
 | 
					    XCTAssertEqual(textEmbedderResult.embeddingResult.embeddings.count, 1)
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  func assertEmbeddingIsFloat(
 | 
				
			||||||
 | 
					    _ embedding: Embedding
 | 
				
			||||||
 | 
					  ) {
 | 
				
			||||||
 | 
					    XCTAssertNil(embedding.quantizedEmbedding)
 | 
				
			||||||
 | 
					    XCTAssertNotNil(embedding.floatEmbedding)
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  func assertEmbedding(
 | 
				
			||||||
 | 
					    _ floatEmbedding: [NSNumber],
 | 
				
			||||||
 | 
					    hasCount embeddingCount: Int,
 | 
				
			||||||
 | 
					    hasFirstValue firstValue: Float
 | 
				
			||||||
 | 
					  ) {
 | 
				
			||||||
 | 
					    XCTAssertEqual(floatEmbedding.count, embeddingCount); 
 | 
				
			||||||
 | 
					    XCTAssertEqual(
 | 
				
			||||||
 | 
					      floatEmbedding[0].floatValue, 
 | 
				
			||||||
 | 
					      firstValue, accuracy: 
 | 
				
			||||||
 | 
					      TextEmbedderTests.floatDiffTolerance);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  func assertFloatEmbeddingResultsForEmbed(
 | 
				
			||||||
 | 
					    text: String,
 | 
				
			||||||
 | 
					    using textEmbedder: TextEmbedder,
 | 
				
			||||||
 | 
					    hasCount embeddingCount: Int,
 | 
				
			||||||
 | 
					    hasFirstValue firstValue: Float
 | 
				
			||||||
 | 
					  ) throws -> Embedding {
 | 
				
			||||||
 | 
					    let textEmbedderResult =
 | 
				
			||||||
 | 
					      try XCTUnwrap(
 | 
				
			||||||
 | 
					        textEmbedder.embed(text: text))
 | 
				
			||||||
 | 
					    assertTextEmbedderResultHasOneEmbedding(textEmbedderResult)
 | 
				
			||||||
 | 
					    assertEmbeddingIsFloat(textEmbedderResult.embeddingResult.embeddings[0])
 | 
				
			||||||
 | 
					    assertEmbedding(
 | 
				
			||||||
 | 
					      textEmbedderResult.embeddingResult.embeddings[0].floatEmbedding!,
 | 
				
			||||||
 | 
					      hasCount: embeddingCount,
 | 
				
			||||||
 | 
					      hasFirstValue: firstValue)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return textEmbedderResult.embeddingResult.embeddings[0]
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  func testEmbedWithBertSucceeds() throws {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    let modelPath = try XCTUnwrap(TextEmbedderTests.bertModelPath)
 | 
				
			||||||
 | 
					    let textEmbedder = try XCTUnwrap(TextEmbedder(modelPath: modelPath))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    let embedding1 = try assertFloatEmbeddingResultsForEmbed(
 | 
				
			||||||
 | 
					      text: TextEmbedderTests.text1,
 | 
				
			||||||
 | 
					      using: textEmbedder,
 | 
				
			||||||
 | 
					      hasCount: 512,
 | 
				
			||||||
 | 
					      hasFirstValue: 20.057026)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    let embedding2 = try assertFloatEmbeddingResultsForEmbed(
 | 
				
			||||||
 | 
					      text: TextEmbedderTests.text2,
 | 
				
			||||||
 | 
					      using: textEmbedder,
 | 
				
			||||||
 | 
					      hasCount: 512,
 | 
				
			||||||
 | 
					      hasFirstValue: 21.254150)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    let cosineSimilarity = try XCTUnwrap(TextEmbedder.cosineSimilarity(
 | 
				
			||||||
 | 
					      embedding1: embedding1,
 | 
				
			||||||
 | 
					      embedding2: embedding2))
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    XCTAssertEqual(
 | 
				
			||||||
 | 
					      cosineSimilarity.doubleValue, 
 | 
				
			||||||
 | 
					      0.96386, 
 | 
				
			||||||
 | 
					      accuracy: TextEmbedderTests.doubleDiffTolerance)
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -100,8 +100,8 @@ NS_SWIFT_NAME(TextEmbedder)
 | 
				
			||||||
 */
 | 
					 */
 | 
				
			||||||
+ (nullable NSNumber *)cosineSimilarityBetweenEmbedding1:(MPPEmbedding *)embedding1
 | 
					+ (nullable NSNumber *)cosineSimilarityBetweenEmbedding1:(MPPEmbedding *)embedding1
 | 
				
			||||||
                                           andEmbedding2:(MPPEmbedding *)embedding2
 | 
					                                           andEmbedding2:(MPPEmbedding *)embedding2
 | 
				
			||||||
                                                   error:(NSError **)error
 | 
					                                                   error:(NSError **)error NS_SWIFT_NAME(cosineSimilarity(embedding1:embedding2:));
 | 
				
			||||||
    NS_SWIFT_NAME(cosineSimilarity(embedding1: embedding2:));
 | 
					    // NS_SWIFT_NAME(cosineSimilarity(embedding1: embedding2:));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
+ (instancetype)new NS_UNAVAILABLE;
 | 
					+ (instancetype)new NS_UNAVAILABLE;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user