Added cosine similarity to MPPTextEmbedder
This commit is contained in:
parent
84e1c93ffb
commit
867520af1c
|
@ -49,6 +49,7 @@ objc_library(
|
||||||
"//mediapipe/tasks/cc/text/text_embedder:text_embedder_graph",
|
"//mediapipe/tasks/cc/text/text_embedder:text_embedder_graph",
|
||||||
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
|
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
|
||||||
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
|
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
|
||||||
|
"//mediapipe/tasks/ios/components/utils:MPPCosineSimilarity",
|
||||||
"//mediapipe/tasks/ios/core:MPPTaskInfo",
|
"//mediapipe/tasks/ios/core:MPPTaskInfo",
|
||||||
"//mediapipe/tasks/ios/core:MPPTaskOptions",
|
"//mediapipe/tasks/ios/core:MPPTaskOptions",
|
||||||
"//mediapipe/tasks/ios/core:MPPTextPacketCreator",
|
"//mediapipe/tasks/ios/core:MPPTextPacketCreator",
|
||||||
|
|
|
@ -29,7 +29,7 @@ NS_ASSUME_NONNULL_BEGIN
|
||||||
* Metadata is required for models with int32 input tensors because it contains the input process
|
* Metadata is required for models with int32 input tensors because it contains the input process
|
||||||
* unit for the model's Tokenizer. No metadata is required for models with string input tensors.
|
* unit for the model's Tokenizer. No metadata is required for models with string input tensors.
|
||||||
*
|
*
|
||||||
* Input tensors
|
* Input tensors:
|
||||||
* - Three input tensors `kTfLiteInt32` of shape `[batch_size x bert_max_seq_len]`
|
* - Three input tensors `kTfLiteInt32` of shape `[batch_size x bert_max_seq_len]`
|
||||||
* representing the input ids, mask ids, and segment ids. This input signature requires
|
* representing the input ids, mask ids, and segment ids. This input signature requires
|
||||||
* a Bert Tokenizer process unit in the model metadata.
|
* a Bert Tokenizer process unit in the model metadata.
|
||||||
|
@ -62,7 +62,7 @@ NS_SWIFT_NAME(TextEmbedder)
|
||||||
* Creates a new instance of `MPPTextEmbedder` from the given `MPPTextEmbedderOptions`.
|
* Creates a new instance of `MPPTextEmbedder` from the given `MPPTextEmbedderOptions`.
|
||||||
*
|
*
|
||||||
* @param options The options of type `MPPTextEmbedderOptions` to use for configuring the
|
* @param options The options of type `MPPTextEmbedderOptions` to use for configuring the
|
||||||
* `MPPTextEmbedder.
|
* `MPPTextEmbedder`.
|
||||||
* @param error An optional error parameter populated when there is an error in initializing the
|
* @param error An optional error parameter populated when there is an error in initializing the
|
||||||
* text embedder.
|
* text embedder.
|
||||||
*
|
*
|
||||||
|
@ -86,6 +86,23 @@ NS_SWIFT_NAME(TextEmbedder)
|
||||||
|
|
||||||
- (instancetype)init NS_UNAVAILABLE;
|
- (instancetype)init NS_UNAVAILABLE;
|
||||||
|
|
||||||
|
/** Utility function to compute[cosine similarity](https://en.wikipedia.org/wiki/Cosine_similarity)
|
||||||
|
* between two `MPPEmbedding` objects.
|
||||||
|
*
|
||||||
|
* @param embedding1 One of the two `MPPEmbedding`s between whom cosine similarity is to be
|
||||||
|
* computed.
|
||||||
|
* @param embedding2 One of the two `MPPEmbedding`s between whom cosine similarity is to be
|
||||||
|
* computed.
|
||||||
|
* @param error An optional error parameter populated when there is an error in calculating cosine
|
||||||
|
* similarity between two embeddings.
|
||||||
|
*
|
||||||
|
* @return An `NSNumber` which holds the cosine similarity of type `double`.
|
||||||
|
*/
|
||||||
|
+ (nullable NSNumber *)cosineSimilarityBetweenEmbedding1:(MPPEmbedding *)embedding1
|
||||||
|
andEmbedding2:(MPPEmbedding *)embedding2
|
||||||
|
error:(NSError **)error
|
||||||
|
NS_SWIFT_NAME(cosineSimilarity(embedding1: embedding2:));
|
||||||
|
|
||||||
+ (instancetype)new NS_UNAVAILABLE;
|
+ (instancetype)new NS_UNAVAILABLE;
|
||||||
|
|
||||||
@end
|
@end
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
|
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
|
||||||
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
|
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
|
||||||
|
#import "mediapipe/tasks/ios/components/utils/sources/MPPCosineSimilarity.h"
|
||||||
#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h"
|
#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h"
|
||||||
#import "mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h"
|
#import "mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h"
|
||||||
#import "mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h"
|
#import "mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h"
|
||||||
|
@ -93,4 +94,12 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_embedder.Tex
|
||||||
.value()[kEmbeddingsOutStreamName.cppString]];
|
.value()[kEmbeddingsOutStreamName.cppString]];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
+ (nullable NSNumber *)cosineSimilarityBetweenEmbedding1:(MPPEmbedding *)embedding1
|
||||||
|
andEmbedding2:(MPPEmbedding *)embedding2
|
||||||
|
error:(NSError **)error {
|
||||||
|
return [MPPCosineSimilarity computeBetweenEmbedding1:embedding1
|
||||||
|
andEmbedding2:embedding2
|
||||||
|
error:error];
|
||||||
|
}
|
||||||
|
|
||||||
@end
|
@end
|
||||||
|
|
Loading…
Reference in New Issue
Block a user