Added swift tests for text embedder
This commit is contained in:
parent
474e994a5f
commit
d625918995
|
@ -53,3 +53,28 @@ ios_unit_test(
|
||||||
":MPPTextEmbedderObjcTestLibrary",
|
":MPPTextEmbedderObjcTestLibrary",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
swift_library(
|
||||||
|
name = "MPPTextEmbedderSwiftTestLibrary",
|
||||||
|
testonly = 1,
|
||||||
|
srcs = ["TextEmbedderTests.swift"],
|
||||||
|
data = [
|
||||||
|
"//mediapipe/tasks/testdata/text:mobilebert_embedding_model",
|
||||||
|
"//mediapipe/tasks/testdata/text:regex_embedding_with_metadata",
|
||||||
|
],
|
||||||
|
tags = TFL_DEFAULT_TAGS,
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/tasks/ios/common:MPPCommon",
|
||||||
|
"//mediapipe/tasks/ios/text/text_embedder:MPPTextEmbedder",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
ios_unit_test(
|
||||||
|
name = "MPPTextEmbedderSwiftTest",
|
||||||
|
minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION,
|
||||||
|
runner = tflite_ios_lab_runner("IOS_LATEST"),
|
||||||
|
tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
|
||||||
|
deps = [
|
||||||
|
":MPPTextEmbedderSwiftTestLibrary",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
|
@ -0,0 +1,114 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
import MPPCommon
|
||||||
|
import XCTest
|
||||||
|
|
||||||
|
@testable import MPPTextEmbedder
|
||||||
|
|
||||||
|
class TextEmbedderTests: XCTestCase {
|
||||||
|
|
||||||
|
static let bundle = Bundle(for: TextEmbedderTests.self)
|
||||||
|
|
||||||
|
static let bertModelPath = bundle.path(
|
||||||
|
forResource: "mobilebert_embedding_with_metadata",
|
||||||
|
ofType: "tflite")
|
||||||
|
|
||||||
|
static let text1 = "it's a charming and often affecting journey"
|
||||||
|
|
||||||
|
static let text2 = "what a great and fantastic trip"
|
||||||
|
|
||||||
|
static let floatDiffTolerance: Float = 1e-4
|
||||||
|
|
||||||
|
static let doubleDiffTolerance: Double = 1e-4
|
||||||
|
|
||||||
|
func assertEqualErrorDescriptions(
|
||||||
|
_ error: Error, expectedLocalizedDescription: String
|
||||||
|
) {
|
||||||
|
XCTAssertEqual(
|
||||||
|
error.localizedDescription,
|
||||||
|
expectedLocalizedDescription)
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertTextEmbedderResultHasOneEmbedding(
|
||||||
|
_ textEmbedderResult: TextEmbedderResult
|
||||||
|
) {
|
||||||
|
XCTAssertEqual(textEmbedderResult.embeddingResult.embeddings.count, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertEmbeddingIsFloat(
|
||||||
|
_ embedding: Embedding
|
||||||
|
) {
|
||||||
|
XCTAssertNil(embedding.quantizedEmbedding)
|
||||||
|
XCTAssertNotNil(embedding.floatEmbedding)
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertEmbedding(
|
||||||
|
_ floatEmbedding: [NSNumber],
|
||||||
|
hasCount embeddingCount: Int,
|
||||||
|
hasFirstValue firstValue: Float
|
||||||
|
) {
|
||||||
|
XCTAssertEqual(floatEmbedding.count, embeddingCount);
|
||||||
|
XCTAssertEqual(
|
||||||
|
floatEmbedding[0].floatValue,
|
||||||
|
firstValue, accuracy:
|
||||||
|
TextEmbedderTests.floatDiffTolerance);
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertFloatEmbeddingResultsForEmbed(
|
||||||
|
text: String,
|
||||||
|
using textEmbedder: TextEmbedder,
|
||||||
|
hasCount embeddingCount: Int,
|
||||||
|
hasFirstValue firstValue: Float
|
||||||
|
) throws -> Embedding {
|
||||||
|
let textEmbedderResult =
|
||||||
|
try XCTUnwrap(
|
||||||
|
textEmbedder.embed(text: text))
|
||||||
|
assertTextEmbedderResultHasOneEmbedding(textEmbedderResult)
|
||||||
|
assertEmbeddingIsFloat(textEmbedderResult.embeddingResult.embeddings[0])
|
||||||
|
assertEmbedding(
|
||||||
|
textEmbedderResult.embeddingResult.embeddings[0].floatEmbedding!,
|
||||||
|
hasCount: embeddingCount,
|
||||||
|
hasFirstValue: firstValue)
|
||||||
|
|
||||||
|
return textEmbedderResult.embeddingResult.embeddings[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
func testEmbedWithBertSucceeds() throws {
|
||||||
|
|
||||||
|
let modelPath = try XCTUnwrap(TextEmbedderTests.bertModelPath)
|
||||||
|
let textEmbedder = try XCTUnwrap(TextEmbedder(modelPath: modelPath))
|
||||||
|
|
||||||
|
let embedding1 = try assertFloatEmbeddingResultsForEmbed(
|
||||||
|
text: TextEmbedderTests.text1,
|
||||||
|
using: textEmbedder,
|
||||||
|
hasCount: 512,
|
||||||
|
hasFirstValue: 20.057026)
|
||||||
|
|
||||||
|
let embedding2 = try assertFloatEmbeddingResultsForEmbed(
|
||||||
|
text: TextEmbedderTests.text2,
|
||||||
|
using: textEmbedder,
|
||||||
|
hasCount: 512,
|
||||||
|
hasFirstValue: 21.254150)
|
||||||
|
|
||||||
|
let cosineSimilarity = try XCTUnwrap(TextEmbedder.cosineSimilarity(
|
||||||
|
embedding1: embedding1,
|
||||||
|
embedding2: embedding2))
|
||||||
|
|
||||||
|
XCTAssertEqual(
|
||||||
|
cosineSimilarity.doubleValue,
|
||||||
|
0.96386,
|
||||||
|
accuracy: TextEmbedderTests.doubleDiffTolerance)
|
||||||
|
}
|
||||||
|
}
|
|
@ -100,8 +100,8 @@ NS_SWIFT_NAME(TextEmbedder)
|
||||||
*/
|
*/
|
||||||
+ (nullable NSNumber *)cosineSimilarityBetweenEmbedding1:(MPPEmbedding *)embedding1
|
+ (nullable NSNumber *)cosineSimilarityBetweenEmbedding1:(MPPEmbedding *)embedding1
|
||||||
andEmbedding2:(MPPEmbedding *)embedding2
|
andEmbedding2:(MPPEmbedding *)embedding2
|
||||||
error:(NSError **)error
|
error:(NSError **)error NS_SWIFT_NAME(cosineSimilarity(embedding1:embedding2:));
|
||||||
NS_SWIFT_NAME(cosineSimilarity(embedding1: embedding2:));
|
// NS_SWIFT_NAME(cosineSimilarity(embedding1: embedding2:));
|
||||||
|
|
||||||
+ (instancetype)new NS_UNAVAILABLE;
|
+ (instancetype)new NS_UNAVAILABLE;
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user