Added text embedder objective c tests
This commit is contained in:
parent
867520af1c
commit
474e994a5f
55
mediapipe/tasks/ios/test/text/text_embedder/BUILD
Normal file
55
mediapipe/tasks/ios/test/text/text_embedder/BUILD
Normal file
|
@ -0,0 +1,55 @@
|
|||
load(
|
||||
"@build_bazel_rules_apple//apple:ios.bzl",
|
||||
"ios_unit_test",
|
||||
)
|
||||
load(
|
||||
"@build_bazel_rules_swift//swift:swift.bzl",
|
||||
"swift_library",
|
||||
)
|
||||
load(
|
||||
"//mediapipe/tasks:ios/ios.bzl",
|
||||
"MPP_TASK_MINIMUM_OS_VERSION",
|
||||
)
|
||||
load(
|
||||
"@org_tensorflow//tensorflow/lite:special_rules.bzl",
|
||||
"tflite_ios_lab_runner",
|
||||
)
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
# Default tags for filtering iOS targets. Targets are restricted to Apple platforms.
|
||||
TFL_DEFAULT_TAGS = [
|
||||
"apple",
|
||||
]
|
||||
|
||||
# Following sanitizer tests are not supported by iOS test targets.
|
||||
TFL_DISABLED_SANITIZER_TAGS = [
|
||||
"noasan",
|
||||
"nomsan",
|
||||
"notsan",
|
||||
]
|
||||
|
||||
objc_library(
|
||||
name = "MPPTextEmbedderObjcTestLibrary",
|
||||
testonly = 1,
|
||||
srcs = ["MPPTextEmbedderTests.m"],
|
||||
data = [
|
||||
"//mediapipe/tasks/testdata/text:mobilebert_embedding_model",
|
||||
"//mediapipe/tasks/testdata/text:regex_embedding_with_metadata",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/tasks/ios/common:MPPCommon",
|
||||
"//mediapipe/tasks/ios/text/text_embedder:MPPTextEmbedder",
|
||||
],
|
||||
)
|
||||
|
||||
ios_unit_test(
|
||||
name = "MPPTextEmbedderObjcTest",
|
||||
minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION,
|
||||
runner = tflite_ios_lab_runner("IOS_LATEST"),
|
||||
deps = [
|
||||
":MPPTextEmbedderObjcTestLibrary",
|
||||
],
|
||||
)
|
|
@ -0,0 +1,142 @@
|
|||
// Copyright 2023 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#import <XCTest/XCTest.h>
|
||||
|
||||
#import "mediapipe/tasks/ios/common/sources/MPPCommon.h"
|
||||
#import "mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h"
|
||||
|
||||
static NSString *const kBertTextEmbedderModelName = @"mobilebert_embedding_with_metadata";
|
||||
static NSString *const kRegexTextEmbedderModelName = @"regex_one_embedding_with_metadata";
|
||||
static NSString *const kText1 = @"it's a charming and often affecting journey";
|
||||
static NSString *const kText2 = @"what a great and fantastic trip";
|
||||
static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
|
||||
static const float kFloatDiffTolerance = 1e-4;
|
||||
static const float kDoubleDiffTolerance = 1e-4;
|
||||
|
||||
#define AssertEqualErrors(error, expectedError) \
|
||||
XCTAssertNotNil(error); \
|
||||
XCTAssertEqualObjects(error.domain, expectedError.domain); \
|
||||
XCTAssertEqual(error.code, expectedError.code); \
|
||||
XCTAssertNotEqual( \
|
||||
[error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \
|
||||
NSNotFound)
|
||||
|
||||
#define AssertTextEmbedderResultHasOneEmbedding(textEmbedderResult) \
|
||||
XCTAssertNotNil(textEmbedderResult); \
|
||||
XCTAssertNotNil(textEmbedderResult.embeddingResult); \
|
||||
XCTAssertEqual(textEmbedderResult.embeddingResult.embeddings.count, 1);
|
||||
|
||||
#define AssertEmbeddingIsFloat(embedding) \
|
||||
XCTAssertNotNil(embedding.floatEmbedding); \
|
||||
XCTAssertNil(embedding.quantizedEmbedding);
|
||||
|
||||
#define AssertFloatEmbeddingHasExpectedValues(floatEmbedding, expectedLength, expectedFirstValue) \
|
||||
XCTAssertEqual(floatEmbedding.count, expectedLength); \
|
||||
XCTAssertEqualWithAccuracy(floatEmbedding[0].floatValue, expectedFirstValue, kFloatDiffTolerance);
|
||||
|
||||
@interface MPPTextEmbedderTests : XCTestCase
|
||||
@end
|
||||
|
||||
@implementation MPPTextEmbedderTests
|
||||
|
||||
- (NSString *)filePathWithName:(NSString *)fileName extension:(NSString *)extension {
|
||||
NSString *filePath = [[NSBundle bundleForClass:self.class] pathForResource:fileName
|
||||
ofType:extension];
|
||||
return filePath;
|
||||
}
|
||||
|
||||
- (MPPTextEmbedder *)textEmbedderFromModelFileWithName:(NSString *)modelName {
|
||||
NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"];
|
||||
|
||||
NSError *error = nil;
|
||||
MPPTextEmbedder *textEmbedder = [[MPPTextEmbedder alloc] initWithModelPath:modelPath
|
||||
error:&error];
|
||||
|
||||
XCTAssertNotNil(textEmbedder);
|
||||
|
||||
return textEmbedder;
|
||||
}
|
||||
|
||||
- (NSArray<NSNumber *> *)assertFloatEmbeddingResultsOfEmbedText:(NSString *)text
|
||||
usingTextEmbedder:(MPPTextEmbedder *)textEmbedder
|
||||
hasCount:(NSUInteger)embeddingCount
|
||||
firstValue:(float)firstValue {
|
||||
MPPTextEmbedderResult *embedderResult = [textEmbedder embedText:text error:nil];
|
||||
AssertTextEmbedderResultHasOneEmbedding(embedderResult);
|
||||
AssertEmbeddingIsFloat(embedderResult.embeddingResult.embeddings[0]);
|
||||
AssertFloatEmbeddingHasExpectedValues(embedderResult.embeddingResult.embeddings[0].floatEmbedding,
|
||||
embeddingCount, firstValue);
|
||||
return embedderResult.embeddingResult.embeddings[0];
|
||||
}
|
||||
|
||||
- (void)testCreateTextEmbedderFailsWithMissingModelPath {
|
||||
NSString *modelPath = [self filePathWithName:@"" extension:@""];
|
||||
|
||||
NSError *error = nil;
|
||||
MPPTextEmbedder *textEmbedder = [[MPPTextEmbedder alloc] initWithModelPath:modelPath
|
||||
error:&error];
|
||||
XCTAssertNil(textEmbedder);
|
||||
|
||||
NSError *expectedError = [NSError
|
||||
errorWithDomain:kExpectedErrorDomain
|
||||
code:MPPTasksErrorCodeInvalidArgumentError
|
||||
userInfo:@{
|
||||
NSLocalizedDescriptionKey :
|
||||
@"INVALID_ARGUMENT: ExternalFile must specify at least one of 'file_content', "
|
||||
@"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."
|
||||
}];
|
||||
AssertEqualErrors(error, expectedError);
|
||||
}
|
||||
|
||||
- (void)testEmbedWithBertSucceeds {
|
||||
MPPTextEmbedder *textEmbedder =
|
||||
[self textEmbedderFromModelFileWithName:kBertTextEmbedderModelName];
|
||||
|
||||
MPPEmbedding *embedding1 = [self assertFloatEmbeddingResultsOfEmbedText:kText1
|
||||
usingTextEmbedder:textEmbedder
|
||||
hasCount:512
|
||||
firstValue:20.057026f];
|
||||
|
||||
MPPEmbedding *embedding2 = [self assertFloatEmbeddingResultsOfEmbedText:kText2
|
||||
usingTextEmbedder:textEmbedder
|
||||
hasCount:512
|
||||
firstValue:21.254150f];
|
||||
NSNumber *cosineSimilarity = [MPPTextEmbedder cosineSimilarityBetweenEmbedding1:embedding1
|
||||
andEmbedding2:embedding2
|
||||
error:nil];
|
||||
XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.96386, kDoubleDiffTolerance);
|
||||
}
|
||||
|
||||
- (void)testEmbedWithRegexSucceeds {
|
||||
MPPTextEmbedder *textEmbedder =
|
||||
[self textEmbedderFromModelFileWithName:kRegexTextEmbedderModelName];
|
||||
|
||||
MPPEmbedding *embedding1 = [self assertFloatEmbeddingResultsOfEmbedText:kText1
|
||||
usingTextEmbedder:textEmbedder
|
||||
hasCount:16
|
||||
firstValue:0.030935612f];
|
||||
|
||||
MPPEmbedding *embedding2 = [self assertFloatEmbeddingResultsOfEmbedText:kText2
|
||||
usingTextEmbedder:textEmbedder
|
||||
hasCount:16
|
||||
firstValue:0.0312863f];
|
||||
|
||||
NSNumber *cosineSimilarity = [MPPTextEmbedder cosineSimilarityBetweenEmbedding1:embedding1
|
||||
andEmbedding2:embedding2
|
||||
error:nil];
|
||||
XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.999937f, kDoubleDiffTolerance);
|
||||
}
|
||||
|
||||
@end
|
Loading…
Reference in New Issue
Block a user