Merge pull request #4052 from priankakariatyml:ios-text-embedder
PiperOrigin-RevId: 507602101
This commit is contained in:
		
						commit
						28c07430ba
					
				
							
								
								
									
										33
									
								
								mediapipe/tasks/ios/components/utils/BUILD
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										33
									
								
								mediapipe/tasks/ios/components/utils/BUILD
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,33 @@
 | 
				
			||||||
 | 
					# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#      http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package(default_visibility = ["//mediapipe/tasks:internal"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					licenses(["notice"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					objc_library(
 | 
				
			||||||
 | 
					    name = "MPPCosineSimilarity",
 | 
				
			||||||
 | 
					    srcs = ["sources/MPPCosineSimilarity.mm"],
 | 
				
			||||||
 | 
					    hdrs = ["sources/MPPCosineSimilarity.h"],
 | 
				
			||||||
 | 
					    copts = [
 | 
				
			||||||
 | 
					        "-ObjC++",
 | 
				
			||||||
 | 
					        "-std=c++17",
 | 
				
			||||||
 | 
					        "-x objective-c++",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					    deps = [
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/ios/common:MPPCommon",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/ios/components/containers:MPPEmbedding",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,48 @@
 | 
				
			||||||
 | 
					// Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					// you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					// You may obtain a copy of the License at
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//      http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					// distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					// See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					// limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#import <Foundation/Foundation.h>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#import "mediapipe/tasks/ios/components/containers/sources/MPPEmbedding.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					NS_ASSUME_NONNULL_BEGIN
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/** Utility class for computing cosine similarity between `MPPEmbedding` objects. */
 | 
				
			||||||
 | 
					NS_SWIFT_NAME(CosineSimilarity)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@interface MPPCosineSimilarity : NSObject
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					- (instancetype)init NS_UNAVAILABLE;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					+ (instancetype)new NS_UNAVAILABLE;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/** Utility function to compute[cosine similarity](https://en.wikipedia.org/wiki/Cosine_similarity)
 | 
				
			||||||
 | 
					 * between two `MPPEmbedding` objects.
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 * @param embedding1 One of the two `MPPEmbedding`s between whom cosine similarity is to be
 | 
				
			||||||
 | 
					 * computed.
 | 
				
			||||||
 | 
					 * @param embedding2 One of the two `MPPEmbedding`s between whom cosine similarity is to be
 | 
				
			||||||
 | 
					 * computed.
 | 
				
			||||||
 | 
					 * @param error An optional error parameter populated when there is an error in calculating cosine
 | 
				
			||||||
 | 
					 * similarity between two embeddings.
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 * @return An `NSNumber` which holds the cosine similarity of type `double`.
 | 
				
			||||||
 | 
					 */
 | 
				
			||||||
 | 
					+ (nullable NSNumber *)computeBetweenEmbedding1:(MPPEmbedding *)embedding1
 | 
				
			||||||
 | 
					                                  andEmbedding2:(MPPEmbedding *)embedding2
 | 
				
			||||||
 | 
					                                          error:(NSError **)error;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					NS_ASSUME_NONNULL_END
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,88 @@
 | 
				
			||||||
 | 
					// Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					// you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					// You may obtain a copy of the License at
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//      http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					// distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					// See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					// limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#import "mediapipe/tasks/ios/components/utils/sources/MPPCosineSimilarity.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <math.h>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#import "mediapipe/tasks/ios/common/sources/MPPCommon.h"
 | 
				
			||||||
 | 
					#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@implementation MPPCosineSimilarity
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					+ (nullable NSNumber *)computeBetweenVector1:(NSArray<NSNumber *> *)u
 | 
				
			||||||
 | 
					                                  andVector2:(NSArray<NSNumber *> *)v
 | 
				
			||||||
 | 
					                                     isFloat:(BOOL)isFloat
 | 
				
			||||||
 | 
					                                       error:(NSError **)error {
 | 
				
			||||||
 | 
					  if (u.count != v.count) {
 | 
				
			||||||
 | 
					    [MPPCommonUtils
 | 
				
			||||||
 | 
					        createCustomError:error
 | 
				
			||||||
 | 
					                 withCode:MPPTasksErrorCodeInvalidArgumentError
 | 
				
			||||||
 | 
					              description:[NSString stringWithFormat:@"Cannot compute cosine similarity between "
 | 
				
			||||||
 | 
					                                                     @"embeddings of different sizes (%lu vs %lu",
 | 
				
			||||||
 | 
					                                                     static_cast<u_long>(u.count),
 | 
				
			||||||
 | 
					                                                     static_cast<u_long>(v.count)]];
 | 
				
			||||||
 | 
					    return nil;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  __block double dotProduct = 0.0;
 | 
				
			||||||
 | 
					  __block double normU = 0.0;
 | 
				
			||||||
 | 
					  __block double normV = 0.0;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  [u enumerateObjectsUsingBlock:^(NSNumber *num, NSUInteger idx, BOOL *stop) {
 | 
				
			||||||
 | 
					    double uVal = 0.0;
 | 
				
			||||||
 | 
					    double vVal = 0.0;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (isFloat) {
 | 
				
			||||||
 | 
					      uVal = num.floatValue;
 | 
				
			||||||
 | 
					      vVal = v[idx].floatValue;
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					      uVal = num.charValue;
 | 
				
			||||||
 | 
					      vVal = v[idx].charValue;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    dotProduct += uVal * vVal;
 | 
				
			||||||
 | 
					    normU += uVal * uVal;
 | 
				
			||||||
 | 
					    normV += vVal * vVal;
 | 
				
			||||||
 | 
					  }];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return [NSNumber numberWithDouble:dotProduct / sqrt(normU * normV)];
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					+ (nullable NSNumber *)computeBetweenEmbedding1:(MPPEmbedding *)embedding1
 | 
				
			||||||
 | 
					                                  andEmbedding2:(MPPEmbedding *)embedding2
 | 
				
			||||||
 | 
					                                          error:(NSError **)error {
 | 
				
			||||||
 | 
					  if (embedding1.floatEmbedding && embedding2.floatEmbedding) {
 | 
				
			||||||
 | 
					    return [MPPCosineSimilarity computeBetweenVector1:embedding1.floatEmbedding
 | 
				
			||||||
 | 
					                                           andVector2:embedding2.floatEmbedding
 | 
				
			||||||
 | 
					                                              isFloat:YES
 | 
				
			||||||
 | 
					                                                error:error];
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  if (embedding1.quantizedEmbedding && embedding2.quantizedEmbedding) {
 | 
				
			||||||
 | 
					    return [MPPCosineSimilarity computeBetweenVector1:embedding1.quantizedEmbedding
 | 
				
			||||||
 | 
					                                           andVector2:embedding2.quantizedEmbedding
 | 
				
			||||||
 | 
					                                              isFloat:NO
 | 
				
			||||||
 | 
					                                                error:error];
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  [MPPCommonUtils
 | 
				
			||||||
 | 
					      createCustomError:error
 | 
				
			||||||
 | 
					               withCode:MPPTasksErrorCodeInvalidArgumentError
 | 
				
			||||||
 | 
					            description:
 | 
				
			||||||
 | 
					                @"Cannot compute cosine similarity between quantized and float embeddings."];
 | 
				
			||||||
 | 
					  return nil;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@end
 | 
				
			||||||
							
								
								
									
										80
									
								
								mediapipe/tasks/ios/test/text/text_embedder/BUILD
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										80
									
								
								mediapipe/tasks/ios/test/text/text_embedder/BUILD
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,80 @@
 | 
				
			||||||
 | 
					load(
 | 
				
			||||||
 | 
					    "@build_bazel_rules_apple//apple:ios.bzl",
 | 
				
			||||||
 | 
					    "ios_unit_test",
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					load(
 | 
				
			||||||
 | 
					    "@build_bazel_rules_swift//swift:swift.bzl",
 | 
				
			||||||
 | 
					    "swift_library",
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					load(
 | 
				
			||||||
 | 
					    "//mediapipe/tasks:ios/ios.bzl",
 | 
				
			||||||
 | 
					    "MPP_TASK_MINIMUM_OS_VERSION",
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					load(
 | 
				
			||||||
 | 
					    "@org_tensorflow//tensorflow/lite:special_rules.bzl",
 | 
				
			||||||
 | 
					    "tflite_ios_lab_runner",
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package(default_visibility = ["//mediapipe/tasks:internal"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					licenses(["notice"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Default tags for filtering iOS targets. Targets are restricted to Apple platforms.
 | 
				
			||||||
 | 
					TFL_DEFAULT_TAGS = [
 | 
				
			||||||
 | 
					    "apple",
 | 
				
			||||||
 | 
					]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Following sanitizer tests are not supported by iOS test targets.
 | 
				
			||||||
 | 
					TFL_DISABLED_SANITIZER_TAGS = [
 | 
				
			||||||
 | 
					    "noasan",
 | 
				
			||||||
 | 
					    "nomsan",
 | 
				
			||||||
 | 
					    "notsan",
 | 
				
			||||||
 | 
					]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					objc_library(
 | 
				
			||||||
 | 
					    name = "MPPTextEmbedderObjcTestLibrary",
 | 
				
			||||||
 | 
					    testonly = 1,
 | 
				
			||||||
 | 
					    srcs = ["MPPTextEmbedderTests.m"],
 | 
				
			||||||
 | 
					    data = [
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/testdata/text:mobilebert_embedding_model",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/testdata/text:regex_embedding_with_metadata",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					    deps = [
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/ios/common:MPPCommon",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/ios/text/text_embedder:MPPTextEmbedder",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					ios_unit_test(
 | 
				
			||||||
 | 
					    name = "MPPTextEmbedderObjcTest",
 | 
				
			||||||
 | 
					    minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION,
 | 
				
			||||||
 | 
					    runner = tflite_ios_lab_runner("IOS_LATEST"),
 | 
				
			||||||
 | 
					    deps = [
 | 
				
			||||||
 | 
					        ":MPPTextEmbedderObjcTestLibrary",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					swift_library(
 | 
				
			||||||
 | 
					    name = "MPPTextEmbedderSwiftTestLibrary",
 | 
				
			||||||
 | 
					    testonly = 1,
 | 
				
			||||||
 | 
					    srcs = ["TextEmbedderTests.swift"],
 | 
				
			||||||
 | 
					    data = [
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/testdata/text:mobilebert_embedding_model",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/testdata/text:regex_embedding_with_metadata",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					    tags = TFL_DEFAULT_TAGS,
 | 
				
			||||||
 | 
					    deps = [
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/ios/common:MPPCommon",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/ios/text/text_embedder:MPPTextEmbedder",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					ios_unit_test(
 | 
				
			||||||
 | 
					    name = "MPPTextEmbedderSwiftTest",
 | 
				
			||||||
 | 
					    minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION,
 | 
				
			||||||
 | 
					    runner = tflite_ios_lab_runner("IOS_LATEST"),
 | 
				
			||||||
 | 
					    tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
 | 
				
			||||||
 | 
					    deps = [
 | 
				
			||||||
 | 
					        ":MPPTextEmbedderSwiftTestLibrary",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,246 @@
 | 
				
			||||||
 | 
					// Copyright 2023 The MediaPipe Authors.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					// you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					// You may obtain a copy of the License at
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//      http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					// distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					// See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					// limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#import <XCTest/XCTest.h>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#import "mediapipe/tasks/ios/common/sources/MPPCommon.h"
 | 
				
			||||||
 | 
					#import "mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					static NSString *const kBertTextEmbedderModelName = @"mobilebert_embedding_with_metadata";
 | 
				
			||||||
 | 
					static NSString *const kRegexTextEmbedderModelName = @"regex_one_embedding_with_metadata";
 | 
				
			||||||
 | 
					static NSString *const kText1 = @"it's a charming and often affecting journey";
 | 
				
			||||||
 | 
					static NSString *const kText2 = @"what a great and fantastic trip";
 | 
				
			||||||
 | 
					static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
 | 
				
			||||||
 | 
					static const float kFloatDiffTolerance = 1e-4;
 | 
				
			||||||
 | 
					static const float kSimilarityDiffTolerance = 1e-4;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#define AssertEqualErrors(error, expectedError)                                               \
 | 
				
			||||||
 | 
					  XCTAssertNotNil(error);                                                                     \
 | 
				
			||||||
 | 
					  XCTAssertEqualObjects(error.domain, expectedError.domain);                                  \
 | 
				
			||||||
 | 
					  XCTAssertEqual(error.code, expectedError.code);                                             \
 | 
				
			||||||
 | 
					  XCTAssertNotEqual(                                                                          \
 | 
				
			||||||
 | 
					      [error.localizedDescription rangeOfString:expectedError.localizedDescription].location, \
 | 
				
			||||||
 | 
					      NSNotFound)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#define AssertTextEmbedderResultHasOneEmbedding(textEmbedderResult) \
 | 
				
			||||||
 | 
					  XCTAssertNotNil(textEmbedderResult);                              \
 | 
				
			||||||
 | 
					  XCTAssertNotNil(textEmbedderResult.embeddingResult);              \
 | 
				
			||||||
 | 
					  XCTAssertEqual(textEmbedderResult.embeddingResult.embeddings.count, 1);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#define AssertEmbeddingType(embedding, quantized)  \
 | 
				
			||||||
 | 
					  if (quantized) {                                 \
 | 
				
			||||||
 | 
					    XCTAssertNil(embedding.floatEmbedding);        \
 | 
				
			||||||
 | 
					    XCTAssertNotNil(embedding.quantizedEmbedding); \
 | 
				
			||||||
 | 
					  } else {                                         \
 | 
				
			||||||
 | 
					    XCTAssertNotNil(embedding.floatEmbedding);     \
 | 
				
			||||||
 | 
					    XCTAssertNil(embedding.quantizedEmbedding);    \
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#define AssertEmbeddingHasExpectedValues(embedding, expectedLength, expectedFirstValue, quantize) \
 | 
				
			||||||
 | 
					  XCTAssertEqual(embedding.count, expectedLength);                                                \
 | 
				
			||||||
 | 
					  if (quantize) {                                                                                 \
 | 
				
			||||||
 | 
					    XCTAssertEqual(embedding[0].charValue, expectedFirstValue);                                   \
 | 
				
			||||||
 | 
					  } else {                                                                                        \
 | 
				
			||||||
 | 
					    XCTAssertEqualWithAccuracy(embedding[0].floatValue, expectedFirstValue, kFloatDiffTolerance); \
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@interface MPPTextEmbedderTests : XCTestCase
 | 
				
			||||||
 | 
					@end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@implementation MPPTextEmbedderTests
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					- (NSString *)filePathWithName:(NSString *)fileName extension:(NSString *)extension {
 | 
				
			||||||
 | 
					  return [[NSBundle bundleForClass:self.class] pathForResource:fileName
 | 
				
			||||||
 | 
					                                                                      ofType:extension];
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					- (MPPTextEmbedder *)textEmbedderFromModelFileWithName:(NSString *)modelName {
 | 
				
			||||||
 | 
					  NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  NSError *error = nil;
 | 
				
			||||||
 | 
					  MPPTextEmbedder *textEmbedder = [[MPPTextEmbedder alloc] initWithModelPath:modelPath
 | 
				
			||||||
 | 
					                                                                       error:&error];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  XCTAssertNotNil(textEmbedder);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return textEmbedder;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					- (MPPTextEmbedderOptions *)textEmbedderOptionsWithModelName:(NSString *)modelName {
 | 
				
			||||||
 | 
					  NSString *modelPath = [self filePathWithName:modelName extension:@"tflite"];
 | 
				
			||||||
 | 
					  MPPTextEmbedderOptions *textEmbedderOptions = [[MPPTextEmbedderOptions alloc] init];
 | 
				
			||||||
 | 
					  textEmbedderOptions.baseOptions.modelAssetPath = modelPath;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return textEmbedderOptions;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					- (MPPEmbedding *)assertFloatEmbeddingResultsOfEmbedText:(NSString *)text
 | 
				
			||||||
 | 
					                                       usingTextEmbedder:(MPPTextEmbedder *)textEmbedder
 | 
				
			||||||
 | 
					                                                hasCount:(NSUInteger)embeddingCount
 | 
				
			||||||
 | 
					                                              firstValue:(float)firstValue {
 | 
				
			||||||
 | 
					  MPPTextEmbedderResult *embedderResult = [textEmbedder embedText:text error:nil];
 | 
				
			||||||
 | 
					  AssertTextEmbedderResultHasOneEmbedding(embedderResult);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  AssertEmbeddingType(embedderResult.embeddingResult.embeddings[0],  // embedding
 | 
				
			||||||
 | 
					                      NO                                             // quantized
 | 
				
			||||||
 | 
					  );
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  AssertEmbeddingHasExpectedValues(
 | 
				
			||||||
 | 
					      embedderResult.embeddingResult.embeddings[0].floatEmbedding,  // embedding
 | 
				
			||||||
 | 
					      embeddingCount,                                               // expectedLength
 | 
				
			||||||
 | 
					      firstValue,                                                   // expectedFirstValue
 | 
				
			||||||
 | 
					      NO                                                            // quantize
 | 
				
			||||||
 | 
					  );
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return embedderResult.embeddingResult.embeddings[0];
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					- (MPPEmbedding *)assertQuantizedEmbeddingResultsOfEmbedText:(NSString *)text
 | 
				
			||||||
 | 
					                                           usingTextEmbedder:(MPPTextEmbedder *)textEmbedder
 | 
				
			||||||
 | 
					                                                    hasCount:(NSUInteger)embeddingCount
 | 
				
			||||||
 | 
					                                                  firstValue:(char)firstValue {
 | 
				
			||||||
 | 
					  MPPTextEmbedderResult *embedderResult = [textEmbedder embedText:text error:nil];
 | 
				
			||||||
 | 
					  AssertTextEmbedderResultHasOneEmbedding(embedderResult);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  AssertEmbeddingType(embedderResult.embeddingResult.embeddings[0],  // embedding
 | 
				
			||||||
 | 
					                      YES                                            // quantized
 | 
				
			||||||
 | 
					  );
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  AssertEmbeddingHasExpectedValues(
 | 
				
			||||||
 | 
					      embedderResult.embeddingResult.embeddings[0].quantizedEmbedding,  // embedding
 | 
				
			||||||
 | 
					      embeddingCount,                                                   // expectedLength
 | 
				
			||||||
 | 
					      firstValue,                                                       // expectedFirstValue
 | 
				
			||||||
 | 
					      YES                                                               // quantize
 | 
				
			||||||
 | 
					  );
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return embedderResult.embeddingResult.embeddings[0];
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					- (void)testCreateTextEmbedderFailsWithMissingModelPath {
 | 
				
			||||||
 | 
					  NSString *modelPath = [self filePathWithName:@"" extension:@""];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  NSError *error = nil;
 | 
				
			||||||
 | 
					  MPPTextEmbedder *textEmbedder = [[MPPTextEmbedder alloc] initWithModelPath:modelPath
 | 
				
			||||||
 | 
					                                                                       error:&error];
 | 
				
			||||||
 | 
					  XCTAssertNil(textEmbedder);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  NSError *expectedError = [NSError
 | 
				
			||||||
 | 
					      errorWithDomain:kExpectedErrorDomain
 | 
				
			||||||
 | 
					                 code:MPPTasksErrorCodeInvalidArgumentError
 | 
				
			||||||
 | 
					             userInfo:@{
 | 
				
			||||||
 | 
					               NSLocalizedDescriptionKey :
 | 
				
			||||||
 | 
					                   @"INVALID_ARGUMENT: ExternalFile must specify at least one of 'file_content', "
 | 
				
			||||||
 | 
					                   @"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."
 | 
				
			||||||
 | 
					             }];
 | 
				
			||||||
 | 
					  AssertEqualErrors(error,         // error
 | 
				
			||||||
 | 
					                    expectedError  // expectedError
 | 
				
			||||||
 | 
					  );
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					- (void)testEmbedWithBertSucceeds {
 | 
				
			||||||
 | 
					  MPPTextEmbedder *textEmbedder =
 | 
				
			||||||
 | 
					      [self textEmbedderFromModelFileWithName:kBertTextEmbedderModelName];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  MPPEmbedding *embedding1 =
 | 
				
			||||||
 | 
					      [self assertFloatEmbeddingResultsOfEmbedText:kText1
 | 
				
			||||||
 | 
					                                 usingTextEmbedder:textEmbedder
 | 
				
			||||||
 | 
					                                          hasCount:512
 | 
				
			||||||
 | 
					                                                               firstValue:21.214869f];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  MPPEmbedding *embedding2 = [self assertFloatEmbeddingResultsOfEmbedText:kText2
 | 
				
			||||||
 | 
					                                                        usingTextEmbedder:textEmbedder
 | 
				
			||||||
 | 
					                                                                 hasCount:512
 | 
				
			||||||
 | 
					                                                               firstValue:22.626251f];
 | 
				
			||||||
 | 
					  NSNumber *cosineSimilarity = [MPPTextEmbedder cosineSimilarityBetweenEmbedding1:embedding1
 | 
				
			||||||
 | 
					                                                                    andEmbedding2:embedding2
 | 
				
			||||||
 | 
					                                                                            error:nil];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.971417490189,
 | 
				
			||||||
 | 
					  kSimilarityDiffTolerance);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					- (void)testEmbedWithRegexSucceeds {
 | 
				
			||||||
 | 
					  MPPTextEmbedder *textEmbedder =
 | 
				
			||||||
 | 
					      [self textEmbedderFromModelFileWithName:kRegexTextEmbedderModelName];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  MPPEmbedding *embedding1 = [self assertFloatEmbeddingResultsOfEmbedText:kText1
 | 
				
			||||||
 | 
					                                                        usingTextEmbedder:textEmbedder
 | 
				
			||||||
 | 
					                                                                 hasCount:16
 | 
				
			||||||
 | 
					                                                               firstValue:0.030935612f];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  MPPEmbedding *embedding2 = [self assertFloatEmbeddingResultsOfEmbedText:kText2
 | 
				
			||||||
 | 
					                                                        usingTextEmbedder:textEmbedder
 | 
				
			||||||
 | 
					                                                                 hasCount:16
 | 
				
			||||||
 | 
					                                                               firstValue:0.0312863f];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  NSNumber *cosineSimilarity = [MPPTextEmbedder cosineSimilarityBetweenEmbedding1:embedding1
 | 
				
			||||||
 | 
					                                                                    andEmbedding2:embedding2
 | 
				
			||||||
 | 
					                                                                            error:nil];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.999937f, kSimilarityDiffTolerance);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					- (void)testEmbedWithBertAndDifferentThemesSucceeds {
 | 
				
			||||||
 | 
					  MPPTextEmbedder *textEmbedder =
 | 
				
			||||||
 | 
					      [self textEmbedderFromModelFileWithName:kBertTextEmbedderModelName];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  MPPEmbedding *embedding1 =
 | 
				
			||||||
 | 
					      [self assertFloatEmbeddingResultsOfEmbedText:
 | 
				
			||||||
 | 
					                @"When you go to this restaurant, they hold the pancake upside-down before they "
 | 
				
			||||||
 | 
					                @"hand it to you. It's a great gimmick."
 | 
				
			||||||
 | 
					                                 usingTextEmbedder:textEmbedder
 | 
				
			||||||
 | 
					                                          hasCount:512
 | 
				
			||||||
 | 
					                                        firstValue:43.1663];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  MPPEmbedding *embedding2 =
 | 
				
			||||||
 | 
					      [self assertFloatEmbeddingResultsOfEmbedText:
 | 
				
			||||||
 | 
					                @"Let's make a plan to steal the declaration of independence."
 | 
				
			||||||
 | 
					                                 usingTextEmbedder:textEmbedder
 | 
				
			||||||
 | 
					                                          hasCount:512
 | 
				
			||||||
 | 
					                                        firstValue:48.0898];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  NSNumber *cosineSimilarity = [MPPTextEmbedder cosineSimilarityBetweenEmbedding1:embedding1
 | 
				
			||||||
 | 
					                                                                    andEmbedding2:embedding2
 | 
				
			||||||
 | 
					                                                                            error:nil];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // TODO: The similarity should likely be lower
 | 
				
			||||||
 | 
					  XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.98151f, kSimilarityDiffTolerance);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					- (void)testEmbedWithQuantizeSucceeds {
 | 
				
			||||||
 | 
					  MPPTextEmbedderOptions *options =
 | 
				
			||||||
 | 
					      [self textEmbedderOptionsWithModelName:kBertTextEmbedderModelName];
 | 
				
			||||||
 | 
					  options.quantize = YES;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  MPPTextEmbedder *textEmbedder = [[MPPTextEmbedder alloc] initWithOptions:options error:nil];
 | 
				
			||||||
 | 
					  XCTAssertNotNil(textEmbedder);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  MPPEmbedding *embedding1 = [self
 | 
				
			||||||
 | 
					      assertQuantizedEmbeddingResultsOfEmbedText:@"it's a charming and often affecting journey"
 | 
				
			||||||
 | 
					                               usingTextEmbedder:textEmbedder
 | 
				
			||||||
 | 
					                                        hasCount:512
 | 
				
			||||||
 | 
					                                      firstValue:127];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  MPPEmbedding *embedding2 =
 | 
				
			||||||
 | 
					      [self assertQuantizedEmbeddingResultsOfEmbedText:@"what a great and fantastic trip"
 | 
				
			||||||
 | 
					                                     usingTextEmbedder:textEmbedder
 | 
				
			||||||
 | 
					                                              hasCount:512
 | 
				
			||||||
 | 
					                                            firstValue:127];
 | 
				
			||||||
 | 
					  NSNumber *cosineSimilarity = [MPPTextEmbedder cosineSimilarityBetweenEmbedding1:embedding1
 | 
				
			||||||
 | 
					                                                                    andEmbedding2:embedding2
 | 
				
			||||||
 | 
					                                                                            error:nil];
 | 
				
			||||||
 | 
					  XCTAssertEqualWithAccuracy(cosineSimilarity.doubleValue, 0.88164f, kSimilarityDiffTolerance);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@end
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,121 @@
 | 
				
			||||||
 | 
					// Copyright 2023 The MediaPipe Authors.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					// you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					// You may obtain a copy of the License at
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//      http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					// distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					// See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					// limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import MPPCommon
 | 
				
			||||||
 | 
					import XCTest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@testable import MPPTextEmbedder
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/// These tests are only for validating the Swift function signatures of the TextEmbedder.
 | 
				
			||||||
 | 
					/// Objective C tests of the TextEmbedder provide more coverage with unit tests for
 | 
				
			||||||
 | 
					/// different models and text embedder options. They can be found here:
 | 
				
			||||||
 | 
					/// /mediapipe/tasks/ios/test/text/text_embedder/MPPTextEmbedderTests.m
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TextEmbedderTests: XCTestCase {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  static let bundle = Bundle(for: TextEmbedderTests.self)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  static let bertModelPath = bundle.path(
 | 
				
			||||||
 | 
					    forResource: "mobilebert_embedding_with_metadata",
 | 
				
			||||||
 | 
					    ofType: "tflite")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  static let text1 = "it's a charming and often affecting journey"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  static let text2 = "what a great and fantastic trip"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  static let floatDiffTolerance: Float = 1e-4
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  static let doubleDiffTolerance: Double = 1e-4
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  func assertEqualErrorDescriptions(
 | 
				
			||||||
 | 
					    _ error: Error, expectedLocalizedDescription: String
 | 
				
			||||||
 | 
					  ) {
 | 
				
			||||||
 | 
					    XCTAssertEqual(
 | 
				
			||||||
 | 
					      error.localizedDescription,
 | 
				
			||||||
 | 
					      expectedLocalizedDescription)
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  func assertTextEmbedderResultHasOneEmbedding(
 | 
				
			||||||
 | 
					    _ textEmbedderResult: TextEmbedderResult
 | 
				
			||||||
 | 
					  ) {
 | 
				
			||||||
 | 
					    XCTAssertEqual(textEmbedderResult.embeddingResult.embeddings.count, 1)
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  func assertEmbeddingIsFloat(
 | 
				
			||||||
 | 
					    _ embedding: Embedding
 | 
				
			||||||
 | 
					  ) {
 | 
				
			||||||
 | 
					    XCTAssertNil(embedding.quantizedEmbedding)
 | 
				
			||||||
 | 
					    XCTAssertNotNil(embedding.floatEmbedding)
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  func assertEmbedding(
 | 
				
			||||||
 | 
					    _ floatEmbedding: [NSNumber],
 | 
				
			||||||
 | 
					    hasCount embeddingCount: Int,
 | 
				
			||||||
 | 
					    hasFirstValue firstValue: Float
 | 
				
			||||||
 | 
					  ) {
 | 
				
			||||||
 | 
					    XCTAssertEqual(floatEmbedding.count, embeddingCount)
 | 
				
			||||||
 | 
					    XCTAssertEqual(
 | 
				
			||||||
 | 
					      floatEmbedding[0].floatValue,
 | 
				
			||||||
 | 
					      firstValue,
 | 
				
			||||||
 | 
					      accuracy:
 | 
				
			||||||
 | 
					        TextEmbedderTests.floatDiffTolerance)
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  func assertFloatEmbeddingResultsForEmbed(
 | 
				
			||||||
 | 
					    text: String,
 | 
				
			||||||
 | 
					    using textEmbedder: TextEmbedder,
 | 
				
			||||||
 | 
					    hasCount embeddingCount: Int,
 | 
				
			||||||
 | 
					    hasFirstValue firstValue: Float
 | 
				
			||||||
 | 
					  ) throws -> Embedding {
 | 
				
			||||||
 | 
					    let textEmbedderResult =
 | 
				
			||||||
 | 
					      try XCTUnwrap(
 | 
				
			||||||
 | 
					        textEmbedder.embed(text: text))
 | 
				
			||||||
 | 
					    assertTextEmbedderResultHasOneEmbedding(textEmbedderResult)
 | 
				
			||||||
 | 
					    assertEmbeddingIsFloat(textEmbedderResult.embeddingResult.embeddings[0])
 | 
				
			||||||
 | 
					    assertEmbedding(
 | 
				
			||||||
 | 
					      textEmbedderResult.embeddingResult.embeddings[0].floatEmbedding!,
 | 
				
			||||||
 | 
					      hasCount: embeddingCount,
 | 
				
			||||||
 | 
					      hasFirstValue: firstValue)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return textEmbedderResult.embeddingResult.embeddings[0]
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  func testEmbedWithBertSucceeds() throws {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    let modelPath = try XCTUnwrap(TextEmbedderTests.bertModelPath)
 | 
				
			||||||
 | 
					    let textEmbedder = try XCTUnwrap(TextEmbedder(modelPath: modelPath))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    let embedding1 = try assertFloatEmbeddingResultsForEmbed(
 | 
				
			||||||
 | 
					      text: TextEmbedderTests.text1,
 | 
				
			||||||
 | 
					      using: textEmbedder,
 | 
				
			||||||
 | 
					      hasCount: 512,
 | 
				
			||||||
 | 
					      hasFirstValue: 21.214869)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    let embedding2 = try assertFloatEmbeddingResultsForEmbed(
 | 
				
			||||||
 | 
					      text: TextEmbedderTests.text2,
 | 
				
			||||||
 | 
					      using: textEmbedder,
 | 
				
			||||||
 | 
					      hasCount: 512,
 | 
				
			||||||
 | 
					      hasFirstValue: 22.626251)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    let cosineSimilarity = try XCTUnwrap(
 | 
				
			||||||
 | 
					      TextEmbedder.cosineSimilarity(
 | 
				
			||||||
 | 
					        embedding1: embedding1,
 | 
				
			||||||
 | 
					        embedding2: embedding2))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    XCTAssertEqual(
 | 
				
			||||||
 | 
					      cosineSimilarity.doubleValue,
 | 
				
			||||||
 | 
					      0.97141,
 | 
				
			||||||
 | 
					      accuracy: TextEmbedderTests.doubleDiffTolerance)
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -49,6 +49,7 @@ objc_library(
 | 
				
			||||||
        "//mediapipe/tasks/cc/text/text_embedder:text_embedder_graph",
 | 
					        "//mediapipe/tasks/cc/text/text_embedder:text_embedder_graph",
 | 
				
			||||||
        "//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
 | 
					        "//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
 | 
				
			||||||
        "//mediapipe/tasks/ios/common/utils:NSStringHelpers",
 | 
					        "//mediapipe/tasks/ios/common/utils:NSStringHelpers",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/ios/components/utils:MPPCosineSimilarity",
 | 
				
			||||||
        "//mediapipe/tasks/ios/core:MPPTaskInfo",
 | 
					        "//mediapipe/tasks/ios/core:MPPTaskInfo",
 | 
				
			||||||
        "//mediapipe/tasks/ios/core:MPPTaskOptions",
 | 
					        "//mediapipe/tasks/ios/core:MPPTaskOptions",
 | 
				
			||||||
        "//mediapipe/tasks/ios/core:MPPTextPacketCreator",
 | 
					        "//mediapipe/tasks/ios/core:MPPTextPacketCreator",
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -86,6 +86,24 @@ NS_SWIFT_NAME(TextEmbedder)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
- (instancetype)init NS_UNAVAILABLE;
 | 
					- (instancetype)init NS_UNAVAILABLE;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/**
 | 
				
			||||||
 | 
					 * Utility function to compute[cosine similarity](https://en.wikipedia.org/wiki/Cosine_similarity)
 | 
				
			||||||
 | 
					 * between two `MPPEmbedding` objects.
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 * @param embedding1 One of the two `MPPEmbedding`s between whom cosine similarity is to be
 | 
				
			||||||
 | 
					 * computed.
 | 
				
			||||||
 | 
					 * @param embedding2 One of the two `MPPEmbedding`s between whom cosine similarity is to be
 | 
				
			||||||
 | 
					 * computed.
 | 
				
			||||||
 | 
					 * @param error An optional error parameter populated when there is an error in calculating cosine
 | 
				
			||||||
 | 
					 * similarity between two embeddings.
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 * @return An `NSNumber` which holds the cosine similarity of type `double`.
 | 
				
			||||||
 | 
					 */
 | 
				
			||||||
 | 
					+ (nullable NSNumber *)cosineSimilarityBetweenEmbedding1:(MPPEmbedding *)embedding1
 | 
				
			||||||
 | 
					                                           andEmbedding2:(MPPEmbedding *)embedding2
 | 
				
			||||||
 | 
					                                                   error:(NSError **)error
 | 
				
			||||||
 | 
					    NS_SWIFT_NAME(cosineSimilarity(embedding1:embedding2:));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
+ (instancetype)new NS_UNAVAILABLE;
 | 
					+ (instancetype)new NS_UNAVAILABLE;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@end
 | 
					@end
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -16,6 +16,7 @@
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
 | 
					#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
 | 
				
			||||||
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
 | 
					#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
 | 
				
			||||||
 | 
					#import "mediapipe/tasks/ios/components/utils/sources/MPPCosineSimilarity.h"
 | 
				
			||||||
#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h"
 | 
					#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h"
 | 
				
			||||||
#import "mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h"
 | 
					#import "mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h"
 | 
				
			||||||
#import "mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h"
 | 
					#import "mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h"
 | 
				
			||||||
| 
						 | 
					@ -93,4 +94,12 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_embedder.Tex
 | 
				
			||||||
                                             .value()[kEmbeddingsOutStreamName.cppString]];
 | 
					                                             .value()[kEmbeddingsOutStreamName.cppString]];
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					+ (nullable NSNumber *)cosineSimilarityBetweenEmbedding1:(MPPEmbedding *)embedding1
 | 
				
			||||||
 | 
					                                           andEmbedding2:(MPPEmbedding *)embedding2
 | 
				
			||||||
 | 
					                                                   error:(NSError **)error {
 | 
				
			||||||
 | 
					  return [MPPCosineSimilarity computeBetweenEmbedding1:embedding1
 | 
				
			||||||
 | 
					                                         andEmbedding2:embedding2
 | 
				
			||||||
 | 
					                                                 error:error];
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@end
 | 
					@end
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user