120 lines
3.7 KiB
Swift
120 lines
3.7 KiB
Swift
// Copyright 2023 The MediaPipe Authors.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
import MPPCommon
|
|
import XCTest
|
|
|
|
@testable import MPPTextEmbedder
|
|
|
|
/** These tests only validate the Swift function signatures of the TextEmbedder.
 * The Objective-C tests of the TextEmbedder provide broader coverage, with unit tests for
 * different models and text embedder options. They can be found here:
 * /mediapipe/tasks/ios/test/text/text_embedder/MPPTextEmbedderTests.m
 */
|
|
class TextEmbedderTests: XCTestCase {

  /// Bundle containing this test class; used to locate bundled model resources.
  static let bundle = Bundle(for: TextEmbedderTests.self)

  /// Path of the bundled MobileBERT embedding model, or `nil` if the resource is not bundled.
  static let bertModelPath = bundle.path(
    forResource: "mobilebert_embedding_with_metadata",
    ofType: "tflite")

  static let text1 = "it's a charming and often affecting journey"

  static let text2 = "what a great and fantastic trip"

  /// Tolerance used when comparing individual `Float` embedding values.
  static let floatDiffTolerance: Float = 1e-4

  /// Tolerance used when comparing `Double` cosine-similarity values.
  static let doubleDiffTolerance: Double = 1e-4

  /// Asserts that `error` reports exactly `expectedLocalizedDescription`.
  ///
  /// - Parameters:
  ///   - error: The error produced by the API under test.
  ///   - expectedLocalizedDescription: The description the error is expected to carry.
  ///   - file: File reported on failure; defaults to the caller's file.
  ///   - line: Line reported on failure; defaults to the caller's line.
  func assertEqualErrorDescriptions(
    _ error: Error, expectedLocalizedDescription: String,
    file: StaticString = #filePath, line: UInt = #line
  ) {
    XCTAssertEqual(
      error.localizedDescription,
      expectedLocalizedDescription,
      file: file, line: line)
  }

  /// Asserts that `textEmbedderResult` contains exactly one embedding.
  ///
  /// - Parameters:
  ///   - textEmbedderResult: The result returned by `TextEmbedder.embed(text:)`.
  ///   - file: File reported on failure; defaults to the caller's file.
  ///   - line: Line reported on failure; defaults to the caller's line.
  func assertTextEmbedderResultHasOneEmbedding(
    _ textEmbedderResult: TextEmbedderResult,
    file: StaticString = #filePath, line: UInt = #line
  ) {
    XCTAssertEqual(
      textEmbedderResult.embeddingResult.embeddings.count, 1,
      file: file, line: line)
  }

  /// Asserts that `embedding` carries a float embedding and no quantized embedding.
  ///
  /// - Parameters:
  ///   - embedding: The embedding to inspect.
  ///   - file: File reported on failure; defaults to the caller's file.
  ///   - line: Line reported on failure; defaults to the caller's line.
  func assertEmbeddingIsFloat(
    _ embedding: Embedding,
    file: StaticString = #filePath, line: UInt = #line
  ) {
    XCTAssertNil(embedding.quantizedEmbedding, file: file, line: line)
    XCTAssertNotNil(embedding.floatEmbedding, file: file, line: line)
  }

  /// Asserts the length of `floatEmbedding` and the value of its first element.
  ///
  /// - Parameters:
  ///   - floatEmbedding: The float embedding values.
  ///   - embeddingCount: The expected number of values.
  ///   - firstValue: The expected first value, compared within `floatDiffTolerance`.
  ///   - file: File reported on failure; defaults to the caller's file.
  ///   - line: Line reported on failure; defaults to the caller's line.
  func assertEmbedding(
    _ floatEmbedding: [NSNumber],
    hasCount embeddingCount: Int,
    hasFirstValue firstValue: Float,
    file: StaticString = #filePath, line: UInt = #line
  ) {
    XCTAssertEqual(floatEmbedding.count, embeddingCount, file: file, line: line)

    // `XCTAssertEqual` does not stop execution, so subscripting `[0]` on an empty
    // array would trap and crash the whole test run. Fail the test instead.
    guard let first = floatEmbedding.first else {
      XCTFail("Expected a non-empty float embedding.", file: file, line: line)
      return
    }
    XCTAssertEqual(
      first.floatValue,
      firstValue,
      accuracy: TextEmbedderTests.floatDiffTolerance,
      file: file, line: line)
  }

  /// Embeds `text` and checks that the result holds exactly one float embedding
  /// with the expected length and first value.
  ///
  /// - Parameters:
  ///   - text: The text to embed.
  ///   - textEmbedder: The embedder under test.
  ///   - embeddingCount: The expected number of float values in the embedding.
  ///   - firstValue: The expected first float value.
  ///   - file: File reported on failure; defaults to the caller's file.
  ///   - line: Line reported on failure; defaults to the caller's line.
  /// - Returns: The first embedding of the result, for further comparisons.
  /// - Throws: Any error thrown by `embed(text:)`, or an `XCTUnwrap` failure when the
  ///   result, its first embedding, or that embedding's float values are missing.
  func assertFloatEmbeddingResultsForEmbed(
    text: String,
    using textEmbedder: TextEmbedder,
    hasCount embeddingCount: Int,
    hasFirstValue firstValue: Float,
    file: StaticString = #filePath, line: UInt = #line
  ) throws -> Embedding {
    let textEmbedderResult = try XCTUnwrap(
      textEmbedder.embed(text: text), file: file, line: line)
    assertTextEmbedderResultHasOneEmbedding(textEmbedderResult, file: file, line: line)

    // Unwrap rather than subscript / force-unwrap so that a missing embedding fails
    // the test with a message instead of crashing the test process.
    let embedding = try XCTUnwrap(
      textEmbedderResult.embeddingResult.embeddings.first, file: file, line: line)
    assertEmbeddingIsFloat(embedding, file: file, line: line)

    let floatEmbedding = try XCTUnwrap(embedding.floatEmbedding, file: file, line: line)
    assertEmbedding(
      floatEmbedding,
      hasCount: embeddingCount,
      hasFirstValue: firstValue,
      file: file, line: line)

    return embedding
  }

  /// Embeds two sentences with the MobileBERT model and checks each embedding and
  /// the cosine similarity between them.
  func testEmbedWithBertSucceeds() throws {
    let modelPath = try XCTUnwrap(TextEmbedderTests.bertModelPath)
    let textEmbedder = try XCTUnwrap(TextEmbedder(modelPath: modelPath))

    let embedding1 = try assertFloatEmbeddingResultsForEmbed(
      text: TextEmbedderTests.text1,
      using: textEmbedder,
      hasCount: 512,
      hasFirstValue: 20.057026)

    let embedding2 = try assertFloatEmbeddingResultsForEmbed(
      text: TextEmbedderTests.text2,
      using: textEmbedder,
      hasCount: 512,
      hasFirstValue: 21.254150)

    let cosineSimilarity = try XCTUnwrap(
      TextEmbedder.cosineSimilarity(
        embedding1: embedding1,
        embedding2: embedding2))

    XCTAssertEqual(
      cosineSimilarity.doubleValue,
      0.96386,
      accuracy: TextEmbedderTests.doubleDiffTolerance)
  }
}
|