diff --git a/mediapipe/tasks/cc/core/BUILD b/mediapipe/tasks/cc/core/BUILD index 7b2b97783..5aa9c9729 100644 --- a/mediapipe/tasks/cc/core/BUILD +++ b/mediapipe/tasks/cc/core/BUILD @@ -78,6 +78,7 @@ cc_library( hdrs = ["mediapipe_builtin_op_resolver.h"], deps = [ "//mediapipe/tasks/cc/text/custom_ops/ragged:ragged_tensor_to_tensor_tflite", + "//mediapipe/tasks/cc/text/custom_ops/sentencepiece:sentencepiece_tokenizer_tflite", "//mediapipe/tasks/cc/text/language_detector/custom_ops:kmeans_embedding_lookup", "//mediapipe/tasks/cc/text/language_detector/custom_ops:ngram_hash", "//mediapipe/util/tflite/operations:landmarks_to_transform_matrix", diff --git a/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.cc b/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.cc index ae64e33ef..80097fd09 100644 --- a/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.cc +++ b/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.cc @@ -16,6 +16,7 @@ limitations under the License. #include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" #include "mediapipe/tasks/cc/text/custom_ops/ragged/ragged_tensor_to_tensor_tflite.h" +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h" #include "mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h" #include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h" #include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h" @@ -51,6 +52,8 @@ MediaPipeBuiltinOpResolver::MediaPipeBuiltinOpResolver() { AddCustom("KmeansEmbeddingLookup", mediapipe::tflite_operations::Register_KmeansEmbeddingLookup()); // For the UniversalSentenceEncoder model. + AddCustom("TFSentencepieceTokenizeOp", + mediapipe::tflite_operations::Register_SENTENCEPIECE_TOKENIZER()); AddCustom("RaggedTensorToTensor", mediapipe::tflite_operations::Register_RAGGED_TENSOR_TO_TENSOR()); } diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/BUILD b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/BUILD index a1833ac54..19f843c4e 100644 --- a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/BUILD +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/BUILD @@ -127,6 +127,26 @@ cc_library( ], ) +cc_library( + name = "sentencepiece_tokenizer_tflite", + srcs = ["sentencepiece_tokenizer_tflite.cc"], + hdrs = ["sentencepiece_tokenizer_tflite.h"], + visibility = [ + "//visibility:public", + ], + deps = + [ + ":optimized_encoder", + "@flatbuffers", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite:string_util", + "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + "@org_tensorflow//tensorflow/lite/kernels:kernel_util", + "@org_tensorflow//tensorflow/lite/kernels/internal:tensor", + ], +) + cc_test( name = "optimized_encoder_test", srcs = [ diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.cc b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.cc new file mode 100644 index 000000000..468a3a54f --- /dev/null +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.cc @@ -0,0 +1,129 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h" + +#include "flatbuffers/flexbuffers.h" +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/context.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/string_util.h" + +namespace mediapipe::tflite_operations { +namespace sentencepiece::tokenizer { +namespace { + +using ::tflite::SetTensorToDynamic; + +constexpr int kSPModelIndex = 0; +constexpr int kInputIndex = 1; +constexpr int kAddBOSInput = 4; +constexpr int kAddEOSInput = 5; +constexpr int kReverseInput = 6; + +constexpr int kOutputValuesInd = 0; +constexpr int kOutputSplitsInd = 1; + +TfLiteIntArray* CreateSizeArray(const std::initializer_list& sizes) { + TfLiteIntArray* array_size = TfLiteIntArrayCreate(sizes.size()); + int index = 0; + for (const int size : sizes) { + array_size->data[index++] = size; + } + return array_size; +} +} // namespace + +// Initializes text encoder object from serialized parameters. +void* Initialize(TfLiteContext* /*context*/, const char* /*buffer*/, + size_t /*length*/) { + return nullptr; +} +void Free(TfLiteContext* /*context*/, void* /*buffer*/) {} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + // TODO: Add checks for input and output tensors. + TfLiteTensor& output_values = + context->tensors[node->outputs->data[kOutputValuesInd]]; + SetTensorToDynamic(&output_values); + + TfLiteTensor& output_splits = + context->tensors[node->outputs->data[kOutputSplitsInd]]; + SetTensorToDynamic(&output_splits); + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor& model_tensor = + context->tensors[node->inputs->data[kSPModelIndex]]; + const auto model_buffer_data = model_tensor.data.data; + const TfLiteTensor& input_text = + context->tensors[node->inputs->data[kInputIndex]]; + + const TfLiteTensor add_bos_tensor = + context->tensors[node->inputs->data[kAddBOSInput]]; + const bool add_bos = add_bos_tensor.data.b[0]; + const TfLiteTensor add_eos_tensor = + context->tensors[node->inputs->data[kAddEOSInput]]; + const bool add_eos = add_eos_tensor.data.b[0]; + const TfLiteTensor reverse_tensor = + context->tensors[node->inputs->data[kReverseInput]]; + const bool reverse = reverse_tensor.data.b[0]; + + std::vector encoded; + std::vector splits; + const int num_strings = tflite::GetStringCount(&input_text); + for (int i = 0; i < num_strings; ++i) { + const auto strref = tflite::GetString(&input_text, i); + const auto res = EncodeString(std::string(strref.str, strref.len), + model_buffer_data, add_bos, add_eos, reverse); + TF_LITE_ENSURE_MSG(context, res.type == EncoderResultType::SUCCESS, + "Sentencepiece conversion failed"); + std::copy(res.codes.begin(), res.codes.end(), std::back_inserter(encoded)); + splits.emplace_back(encoded.size()); + } + + TfLiteTensor& output_values = + context->tensors[node->outputs->data[kOutputValuesInd]]; + TF_LITE_ENSURE_OK(context, + context->ResizeTensor( + context, &output_values, + CreateSizeArray({static_cast(encoded.size())}))); + int32_t* output_values_flat = output_values.data.i32; + std::copy(encoded.begin(), encoded.end(), output_values_flat); + TfLiteTensor& output_splits = + context->tensors[node->outputs->data[kOutputSplitsInd]]; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor( + context, &output_splits, + CreateSizeArray({static_cast(splits.size() + 1)}))); + int32_t* output_splits_flat = output_splits.data.i32; + *output_splits_flat = 0; + std::copy(splits.begin(), splits.end(), output_splits_flat + 1); + return kTfLiteOk; +} +} // namespace sentencepiece::tokenizer + +TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER() { + static TfLiteRegistration r = { + sentencepiece::tokenizer::Initialize, sentencepiece::tokenizer::Free, + sentencepiece::tokenizer::Prepare, sentencepiece::tokenizer::Eval}; + return &r; +} + +} // namespace mediapipe::tflite_operations diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h new file mode 100644 index 000000000..8a9fa8aef --- /dev/null +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h @@ -0,0 +1,27 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_TFLITE_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_TFLITE_H_ + +#include "tensorflow/lite/kernels/register.h" + +namespace mediapipe::tflite_operations { + +TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER(); + +} // namespace mediapipe::tflite_operations + +#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_TFLITE_H_ diff --git a/mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc b/mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc index 5e0be5578..474f0ca35 100644 --- a/mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc +++ b/mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc @@ -39,6 +39,8 @@ constexpr char kMobileBert[] = "mobilebert_embedding_with_metadata.tflite"; // Embedding model with regex preprocessing. constexpr char kRegexOneEmbeddingModel[] = "regex_one_embedding_with_metadata.tflite"; +constexpr char kUniversalSentenceEncoderModel[] = + "universal_sentence_encoder_qa_with_metadata.tflite"; // Tolerance for embedding vector coordinate values. constexpr float kEpsilon = 1e-4; @@ -147,6 +149,35 @@ TEST_F(EmbedderTest, SucceedsWithQuantization) { MP_ASSERT_OK(text_embedder->Close()); } +TEST(EmbedTest, SucceedsWithUniversalSentenceEncoderModel) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kUniversalSentenceEncoderModel); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr text_embedder, + TextEmbedder::Create(std::move(options))); + + MP_ASSERT_OK_AND_ASSIGN( + auto result0, + text_embedder->Embed("it's a charming and often affecting journey")); + ASSERT_EQ(result0.embeddings.size(), 1); + ASSERT_EQ(result0.embeddings[0].float_embedding.size(), 100); + ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 1.422951f, kEpsilon); + + MP_ASSERT_OK_AND_ASSIGN( + auto result1, text_embedder->Embed("what a great and fantastic trip")); + ASSERT_EQ(result1.embeddings.size(), 1); + ASSERT_EQ(result1.embeddings[0].float_embedding.size(), 100); + ASSERT_NEAR(result1.embeddings[0].float_embedding[0], 1.404664f, kEpsilon); + + // Check cosine similarity. + MP_ASSERT_OK_AND_ASSIGN( + double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0], + result1.embeddings[0])); + ASSERT_NEAR(similarity, 0.851961, kSimilarityTolerancy); + + MP_ASSERT_OK(text_embedder->Close()); +} + TEST_F(EmbedderTest, SucceedsWithMobileBertAndDifferentThemes) { auto options = std::make_unique(); options->base_options.model_asset_path = @@ -178,5 +209,31 @@ TEST_F(EmbedderTest, SucceedsWithMobileBertAndDifferentThemes) { MP_ASSERT_OK(text_embedder->Close()); } +TEST_F(EmbedderTest, SucceedsWithUSEAndDifferentThemes) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kUniversalSentenceEncoderModel); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr text_embedder, + TextEmbedder::Create(std::move(options))); + + MP_ASSERT_OK_AND_ASSIGN( + TextEmbedderResult result0, + text_embedder->Embed("When you go to this restaurant, they hold the " + "pancake upside-down before they hand it " + "to you. It's a great gimmick.")); + MP_ASSERT_OK_AND_ASSIGN( + TextEmbedderResult result1, + text_embedder->Embed( + "Let's make a plan to steal the declaration of independence.")); + + // Check cosine similarity. + MP_ASSERT_OK_AND_ASSIGN( + double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0], + result1.embeddings[0])); + EXPECT_NEAR(similarity, 0.780334, kSimilarityTolerancy); + + MP_ASSERT_OK(text_embedder->Close()); +} + } // namespace } // namespace mediapipe::tasks::text::text_embedder