Added swift tests for text embedder
This commit is contained in:
parent
474e994a5f
commit
d625918995
|
@ -53,3 +53,28 @@ ios_unit_test(
|
|||
":MPPTextEmbedderObjcTestLibrary",
|
||||
],
|
||||
)
|
||||
|
||||
swift_library(
|
||||
name = "MPPTextEmbedderSwiftTestLibrary",
|
||||
testonly = 1,
|
||||
srcs = ["TextEmbedderTests.swift"],
|
||||
data = [
|
||||
"//mediapipe/tasks/testdata/text:mobilebert_embedding_model",
|
||||
"//mediapipe/tasks/testdata/text:regex_embedding_with_metadata",
|
||||
],
|
||||
tags = TFL_DEFAULT_TAGS,
|
||||
deps = [
|
||||
"//mediapipe/tasks/ios/common:MPPCommon",
|
||||
"//mediapipe/tasks/ios/text/text_embedder:MPPTextEmbedder",
|
||||
],
|
||||
)
|
||||
|
||||
ios_unit_test(
|
||||
name = "MPPTextEmbedderSwiftTest",
|
||||
minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION,
|
||||
runner = tflite_ios_lab_runner("IOS_LATEST"),
|
||||
tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
|
||||
deps = [
|
||||
":MPPTextEmbedderSwiftTestLibrary",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -0,0 +1,114 @@
|
|||
// Copyright 2023 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import MPPCommon
|
||||
import XCTest
|
||||
|
||||
@testable import MPPTextEmbedder
|
||||
|
||||
class TextEmbedderTests: XCTestCase {
|
||||
|
||||
static let bundle = Bundle(for: TextEmbedderTests.self)
|
||||
|
||||
static let bertModelPath = bundle.path(
|
||||
forResource: "mobilebert_embedding_with_metadata",
|
||||
ofType: "tflite")
|
||||
|
||||
static let text1 = "it's a charming and often affecting journey"
|
||||
|
||||
static let text2 = "what a great and fantastic trip"
|
||||
|
||||
static let floatDiffTolerance: Float = 1e-4
|
||||
|
||||
static let doubleDiffTolerance: Double = 1e-4
|
||||
|
||||
func assertEqualErrorDescriptions(
|
||||
_ error: Error, expectedLocalizedDescription: String
|
||||
) {
|
||||
XCTAssertEqual(
|
||||
error.localizedDescription,
|
||||
expectedLocalizedDescription)
|
||||
}
|
||||
|
||||
func assertTextEmbedderResultHasOneEmbedding(
|
||||
_ textEmbedderResult: TextEmbedderResult
|
||||
) {
|
||||
XCTAssertEqual(textEmbedderResult.embeddingResult.embeddings.count, 1)
|
||||
}
|
||||
|
||||
func assertEmbeddingIsFloat(
|
||||
_ embedding: Embedding
|
||||
) {
|
||||
XCTAssertNil(embedding.quantizedEmbedding)
|
||||
XCTAssertNotNil(embedding.floatEmbedding)
|
||||
}
|
||||
|
||||
func assertEmbedding(
|
||||
_ floatEmbedding: [NSNumber],
|
||||
hasCount embeddingCount: Int,
|
||||
hasFirstValue firstValue: Float
|
||||
) {
|
||||
XCTAssertEqual(floatEmbedding.count, embeddingCount);
|
||||
XCTAssertEqual(
|
||||
floatEmbedding[0].floatValue,
|
||||
firstValue, accuracy:
|
||||
TextEmbedderTests.floatDiffTolerance);
|
||||
}
|
||||
|
||||
func assertFloatEmbeddingResultsForEmbed(
|
||||
text: String,
|
||||
using textEmbedder: TextEmbedder,
|
||||
hasCount embeddingCount: Int,
|
||||
hasFirstValue firstValue: Float
|
||||
) throws -> Embedding {
|
||||
let textEmbedderResult =
|
||||
try XCTUnwrap(
|
||||
textEmbedder.embed(text: text))
|
||||
assertTextEmbedderResultHasOneEmbedding(textEmbedderResult)
|
||||
assertEmbeddingIsFloat(textEmbedderResult.embeddingResult.embeddings[0])
|
||||
assertEmbedding(
|
||||
textEmbedderResult.embeddingResult.embeddings[0].floatEmbedding!,
|
||||
hasCount: embeddingCount,
|
||||
hasFirstValue: firstValue)
|
||||
|
||||
return textEmbedderResult.embeddingResult.embeddings[0]
|
||||
}
|
||||
|
||||
func testEmbedWithBertSucceeds() throws {
|
||||
|
||||
let modelPath = try XCTUnwrap(TextEmbedderTests.bertModelPath)
|
||||
let textEmbedder = try XCTUnwrap(TextEmbedder(modelPath: modelPath))
|
||||
|
||||
let embedding1 = try assertFloatEmbeddingResultsForEmbed(
|
||||
text: TextEmbedderTests.text1,
|
||||
using: textEmbedder,
|
||||
hasCount: 512,
|
||||
hasFirstValue: 20.057026)
|
||||
|
||||
let embedding2 = try assertFloatEmbeddingResultsForEmbed(
|
||||
text: TextEmbedderTests.text2,
|
||||
using: textEmbedder,
|
||||
hasCount: 512,
|
||||
hasFirstValue: 21.254150)
|
||||
|
||||
let cosineSimilarity = try XCTUnwrap(TextEmbedder.cosineSimilarity(
|
||||
embedding1: embedding1,
|
||||
embedding2: embedding2))
|
||||
|
||||
XCTAssertEqual(
|
||||
cosineSimilarity.doubleValue,
|
||||
0.96386,
|
||||
accuracy: TextEmbedderTests.doubleDiffTolerance)
|
||||
}
|
||||
}
|
|
@ -100,8 +100,8 @@ NS_SWIFT_NAME(TextEmbedder)
|
|||
*/
|
||||
+ (nullable NSNumber *)cosineSimilarityBetweenEmbedding1:(MPPEmbedding *)embedding1
|
||||
andEmbedding2:(MPPEmbedding *)embedding2
|
||||
error:(NSError **)error
|
||||
NS_SWIFT_NAME(cosineSimilarity(embedding1: embedding2:));
|
||||
error:(NSError **)error NS_SWIFT_NAME(cosineSimilarity(embedding1:embedding2:));
|
||||
// NS_SWIFT_NAME(cosineSimilarity(embedding1: embedding2:));
|
||||
|
||||
+ (instancetype)new NS_UNAVAILABLE;
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user