Added swift tests for text embedder

This commit is contained in:
Prianka Liz Kariat 2023-02-02 18:36:55 +05:30
parent 474e994a5f
commit d625918995
3 changed files with 141 additions and 2 deletions

View File

@ -53,3 +53,28 @@ ios_unit_test(
":MPPTextEmbedderObjcTestLibrary", ":MPPTextEmbedderObjcTestLibrary",
], ],
) )
swift_library(
name = "MPPTextEmbedderSwiftTestLibrary",
testonly = 1,
srcs = ["TextEmbedderTests.swift"],
data = [
"//mediapipe/tasks/testdata/text:mobilebert_embedding_model",
"//mediapipe/tasks/testdata/text:regex_embedding_with_metadata",
],
tags = TFL_DEFAULT_TAGS,
deps = [
"//mediapipe/tasks/ios/common:MPPCommon",
"//mediapipe/tasks/ios/text/text_embedder:MPPTextEmbedder",
],
)
ios_unit_test(
name = "MPPTextEmbedderSwiftTest",
minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION,
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
deps = [
":MPPTextEmbedderSwiftTestLibrary",
],
)

View File

@ -0,0 +1,114 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import MPPCommon
import XCTest
@testable import MPPTextEmbedder
class TextEmbedderTests: XCTestCase {
static let bundle = Bundle(for: TextEmbedderTests.self)
static let bertModelPath = bundle.path(
forResource: "mobilebert_embedding_with_metadata",
ofType: "tflite")
static let text1 = "it's a charming and often affecting journey"
static let text2 = "what a great and fantastic trip"
static let floatDiffTolerance: Float = 1e-4
static let doubleDiffTolerance: Double = 1e-4
func assertEqualErrorDescriptions(
_ error: Error, expectedLocalizedDescription: String
) {
XCTAssertEqual(
error.localizedDescription,
expectedLocalizedDescription)
}
func assertTextEmbedderResultHasOneEmbedding(
_ textEmbedderResult: TextEmbedderResult
) {
XCTAssertEqual(textEmbedderResult.embeddingResult.embeddings.count, 1)
}
func assertEmbeddingIsFloat(
_ embedding: Embedding
) {
XCTAssertNil(embedding.quantizedEmbedding)
XCTAssertNotNil(embedding.floatEmbedding)
}
func assertEmbedding(
_ floatEmbedding: [NSNumber],
hasCount embeddingCount: Int,
hasFirstValue firstValue: Float
) {
XCTAssertEqual(floatEmbedding.count, embeddingCount);
XCTAssertEqual(
floatEmbedding[0].floatValue,
firstValue, accuracy:
TextEmbedderTests.floatDiffTolerance);
}
func assertFloatEmbeddingResultsForEmbed(
text: String,
using textEmbedder: TextEmbedder,
hasCount embeddingCount: Int,
hasFirstValue firstValue: Float
) throws -> Embedding {
let textEmbedderResult =
try XCTUnwrap(
textEmbedder.embed(text: text))
assertTextEmbedderResultHasOneEmbedding(textEmbedderResult)
assertEmbeddingIsFloat(textEmbedderResult.embeddingResult.embeddings[0])
assertEmbedding(
textEmbedderResult.embeddingResult.embeddings[0].floatEmbedding!,
hasCount: embeddingCount,
hasFirstValue: firstValue)
return textEmbedderResult.embeddingResult.embeddings[0]
}
func testEmbedWithBertSucceeds() throws {
let modelPath = try XCTUnwrap(TextEmbedderTests.bertModelPath)
let textEmbedder = try XCTUnwrap(TextEmbedder(modelPath: modelPath))
let embedding1 = try assertFloatEmbeddingResultsForEmbed(
text: TextEmbedderTests.text1,
using: textEmbedder,
hasCount: 512,
hasFirstValue: 20.057026)
let embedding2 = try assertFloatEmbeddingResultsForEmbed(
text: TextEmbedderTests.text2,
using: textEmbedder,
hasCount: 512,
hasFirstValue: 21.254150)
let cosineSimilarity = try XCTUnwrap(TextEmbedder.cosineSimilarity(
embedding1: embedding1,
embedding2: embedding2))
XCTAssertEqual(
cosineSimilarity.doubleValue,
0.96386,
accuracy: TextEmbedderTests.doubleDiffTolerance)
}
}

View File

@ -100,8 +100,8 @@ NS_SWIFT_NAME(TextEmbedder)
*/ */
+ (nullable NSNumber *)cosineSimilarityBetweenEmbedding1:(MPPEmbedding *)embedding1 + (nullable NSNumber *)cosineSimilarityBetweenEmbedding1:(MPPEmbedding *)embedding1
andEmbedding2:(MPPEmbedding *)embedding2 andEmbedding2:(MPPEmbedding *)embedding2
error:(NSError **)error error:(NSError **)error NS_SWIFT_NAME(cosineSimilarity(embedding1:embedding2:));
NS_SWIFT_NAME(cosineSimilarity(embedding1: embedding2:)); // NS_SWIFT_NAME(cosineSimilarity(embedding1: embedding2:));
+ (instancetype)new NS_UNAVAILABLE; + (instancetype)new NS_UNAVAILABLE;