From 867520af1c0c56d3a02987e110a733f6aaeca263 Mon Sep 17 00:00:00 2001
From: Prianka Liz Kariat
Date: Thu, 2 Feb 2023 17:29:51 +0530
Subject: [PATCH] Added cosine similarity to MPPTextEmbedder

---
 mediapipe/tasks/ios/text/text_embedder/BUILD |  1 +
 .../text_embedder/sources/MPPTextEmbedder.h  | 21 +++++++++++++++++--
 .../text_embedder/sources/MPPTextEmbedder.mm |  9 ++++++++
 3 files changed, 29 insertions(+), 2 deletions(-)

diff --git a/mediapipe/tasks/ios/text/text_embedder/BUILD b/mediapipe/tasks/ios/text/text_embedder/BUILD
index 21226b012..b02b1a9b5 100644
--- a/mediapipe/tasks/ios/text/text_embedder/BUILD
+++ b/mediapipe/tasks/ios/text/text_embedder/BUILD
@@ -49,6 +49,7 @@ objc_library(
         "//mediapipe/tasks/cc/text/text_embedder:text_embedder_graph",
         "//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
         "//mediapipe/tasks/ios/common/utils:NSStringHelpers",
+        "//mediapipe/tasks/ios/components/utils:MPPCosineSimilarity",
         "//mediapipe/tasks/ios/core:MPPTaskInfo",
         "//mediapipe/tasks/ios/core:MPPTaskOptions",
         "//mediapipe/tasks/ios/core:MPPTextPacketCreator",
diff --git a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h
index a45ab6747..ba5958a72 100644
--- a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h
+++ b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.h
@@ -29,7 +29,7 @@ NS_ASSUME_NONNULL_BEGIN
  * Metadata is required for models with int32 input tensors because it contains the input process
  * unit for the model's Tokenizer. No metadata is required for models with string input tensors.
  *
- * Input tensors
+ * Input tensors:
  * - Three input tensors `kTfLiteInt32` of shape `[batch_size x bert_max_seq_len]`
  * representing the input ids, mask ids, and segment ids. This input signature requires
  * a Bert Tokenizer process unit in the model metadata.
@@ -62,7 +62,7 @@ NS_SWIFT_NAME(TextEmbedder)
  * Creates a new instance of `MPPTextEmbedder` from the given `MPPTextEmbedderOptions`.
  *
  * @param options The options of type `MPPTextEmbedderOptions` to use for configuring the
- * `MPPTextEmbedder.
+ * `MPPTextEmbedder`.
  * @param error An optional error parameter populated when there is an error in initializing the
  * text embedder.
  *
@@ -86,6 +86,23 @@ NS_SWIFT_NAME(TextEmbedder)
 
 - (instancetype)init NS_UNAVAILABLE;
 
+/** Utility function to compute [cosine similarity](https://en.wikipedia.org/wiki/Cosine_similarity)
+ * between two `MPPEmbedding` objects.
+ *
+ * @param embedding1 One of the two `MPPEmbedding`s between which cosine similarity is to be
+ * computed.
+ * @param embedding2 One of the two `MPPEmbedding`s between which cosine similarity is to be
+ * computed.
+ * @param error An optional error parameter populated when there is an error in calculating cosine
+ * similarity between two embeddings.
+ *
+ * @return An `NSNumber` which holds the cosine similarity of type `double`.
+ */
++ (nullable NSNumber *)cosineSimilarityBetweenEmbedding1:(MPPEmbedding *)embedding1
+                                           andEmbedding2:(MPPEmbedding *)embedding2
+                                                   error:(NSError **)error
+    NS_SWIFT_NAME(cosineSimilarity(embedding1:embedding2:));
+
 + (instancetype)new NS_UNAVAILABLE;
 
 @end
diff --git a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.mm b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.mm
index a9c811cdb..62eb882d3 100644
--- a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.mm
+++ b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.mm
@@ -16,6 +16,7 @@
 
 #import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
 #import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
+#import "mediapipe/tasks/ios/components/utils/sources/MPPCosineSimilarity.h"
 #import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h"
 #import "mediapipe/tasks/ios/core/sources/MPPTextPacketCreator.h"
 #import "mediapipe/tasks/ios/text/core/sources/MPPTextTaskRunner.h"
@@ -93,4 +94,12 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_embedder.Tex
                                           .value()[kEmbeddingsOutStreamName.cppString]];
 }
 
++ (nullable NSNumber *)cosineSimilarityBetweenEmbedding1:(MPPEmbedding *)embedding1
+                                           andEmbedding2:(MPPEmbedding *)embedding2
+                                                   error:(NSError **)error {
+  return [MPPCosineSimilarity computeBetweenEmbedding1:embedding1
+                                         andEmbedding2:embedding2
+                                                 error:error];
+}
+
 @end