Internal MediaPipe Tasks change.

PiperOrigin-RevId: 524942203
This commit is contained in:
MediaPipe Team 2023-04-17 13:56:11 -07:00 committed by Copybara-Service
parent b147002b7e
commit 2564fec44c
6 changed files with 237 additions and 0 deletions

View File

@ -78,6 +78,7 @@ cc_library(
hdrs = ["mediapipe_builtin_op_resolver.h"],
deps = [
"//mediapipe/tasks/cc/text/custom_ops/ragged:ragged_tensor_to_tensor_tflite",
"//mediapipe/tasks/cc/text/custom_ops/sentencepiece:sentencepiece_tokenizer_tflite",
"//mediapipe/tasks/cc/text/language_detector/custom_ops:kmeans_embedding_lookup",
"//mediapipe/tasks/cc/text/language_detector/custom_ops:ngram_hash",
"//mediapipe/util/tflite/operations:landmarks_to_transform_matrix",

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
#include "mediapipe/tasks/cc/text/custom_ops/ragged/ragged_tensor_to_tensor_tflite.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h"
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h"
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h"
#include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h"
@ -51,6 +52,8 @@ MediaPipeBuiltinOpResolver::MediaPipeBuiltinOpResolver() {
AddCustom("KmeansEmbeddingLookup",
mediapipe::tflite_operations::Register_KmeansEmbeddingLookup());
// For the UniversalSentenceEncoder model.
AddCustom("TFSentencepieceTokenizeOp",
mediapipe::tflite_operations::Register_SENTENCEPIECE_TOKENIZER());
AddCustom("RaggedTensorToTensor",
mediapipe::tflite_operations::Register_RAGGED_TENSOR_TO_TENSOR());
}

View File

@ -127,6 +127,26 @@ cc_library(
],
)
cc_library(
name = "sentencepiece_tokenizer_tflite",
srcs = ["sentencepiece_tokenizer_tflite.cc"],
hdrs = ["sentencepiece_tokenizer_tflite.h"],
visibility = [
"//visibility:public",
],
deps =
[
":optimized_encoder",
"@flatbuffers",
"@org_tensorflow//tensorflow/lite:framework",
"@org_tensorflow//tensorflow/lite:string_util",
"@org_tensorflow//tensorflow/lite/c:common",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
"@org_tensorflow//tensorflow/lite/kernels:kernel_util",
"@org_tensorflow//tensorflow/lite/kernels/internal:tensor",
],
)
cc_test(
name = "optimized_encoder_test",
srcs = [

View File

@ -0,0 +1,129 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h"
#include "flatbuffers/flexbuffers.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/string_util.h"
namespace mediapipe::tflite_operations {
namespace sentencepiece::tokenizer {
namespace {
using ::tflite::SetTensorToDynamic;
constexpr int kSPModelIndex = 0;
constexpr int kInputIndex = 1;
constexpr int kAddBOSInput = 4;
constexpr int kAddEOSInput = 5;
constexpr int kReverseInput = 6;
constexpr int kOutputValuesInd = 0;
constexpr int kOutputSplitsInd = 1;
TfLiteIntArray* CreateSizeArray(const std::initializer_list<int>& sizes) {
TfLiteIntArray* array_size = TfLiteIntArrayCreate(sizes.size());
int index = 0;
for (const int size : sizes) {
array_size->data[index++] = size;
}
return array_size;
}
} // namespace
// Initializes text encoder object from serialized parameters.
void* Initialize(TfLiteContext* /*context*/, const char* /*buffer*/,
size_t /*length*/) {
return nullptr;
}
void Free(TfLiteContext* /*context*/, void* /*buffer*/) {}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// TODO: Add checks for input and output tensors.
TfLiteTensor& output_values =
context->tensors[node->outputs->data[kOutputValuesInd]];
SetTensorToDynamic(&output_values);
TfLiteTensor& output_splits =
context->tensors[node->outputs->data[kOutputSplitsInd]];
SetTensorToDynamic(&output_splits);
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor& model_tensor =
context->tensors[node->inputs->data[kSPModelIndex]];
const auto model_buffer_data = model_tensor.data.data;
const TfLiteTensor& input_text =
context->tensors[node->inputs->data[kInputIndex]];
const TfLiteTensor add_bos_tensor =
context->tensors[node->inputs->data[kAddBOSInput]];
const bool add_bos = add_bos_tensor.data.b[0];
const TfLiteTensor add_eos_tensor =
context->tensors[node->inputs->data[kAddEOSInput]];
const bool add_eos = add_eos_tensor.data.b[0];
const TfLiteTensor reverse_tensor =
context->tensors[node->inputs->data[kReverseInput]];
const bool reverse = reverse_tensor.data.b[0];
std::vector<int32> encoded;
std::vector<int32> splits;
const int num_strings = tflite::GetStringCount(&input_text);
for (int i = 0; i < num_strings; ++i) {
const auto strref = tflite::GetString(&input_text, i);
const auto res = EncodeString(std::string(strref.str, strref.len),
model_buffer_data, add_bos, add_eos, reverse);
TF_LITE_ENSURE_MSG(context, res.type == EncoderResultType::SUCCESS,
"Sentencepiece conversion failed");
std::copy(res.codes.begin(), res.codes.end(), std::back_inserter(encoded));
splits.emplace_back(encoded.size());
}
TfLiteTensor& output_values =
context->tensors[node->outputs->data[kOutputValuesInd]];
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(
context, &output_values,
CreateSizeArray({static_cast<int>(encoded.size())})));
int32_t* output_values_flat = output_values.data.i32;
std::copy(encoded.begin(), encoded.end(), output_values_flat);
TfLiteTensor& output_splits =
context->tensors[node->outputs->data[kOutputSplitsInd]];
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(
context, &output_splits,
CreateSizeArray({static_cast<int>(splits.size() + 1)})));
int32_t* output_splits_flat = output_splits.data.i32;
*output_splits_flat = 0;
std::copy(splits.begin(), splits.end(), output_splits_flat + 1);
return kTfLiteOk;
}
} // namespace sentencepiece::tokenizer
TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER() {
static TfLiteRegistration r = {
sentencepiece::tokenizer::Initialize, sentencepiece::tokenizer::Free,
sentencepiece::tokenizer::Prepare, sentencepiece::tokenizer::Eval};
return &r;
}
} // namespace mediapipe::tflite_operations

View File

@ -0,0 +1,27 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_TFLITE_H_
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_TFLITE_H_
#include "tensorflow/lite/kernels/register.h"
namespace mediapipe::tflite_operations {
TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER();
} // namespace mediapipe::tflite_operations
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_TFLITE_H_

View File

@ -39,6 +39,8 @@ constexpr char kMobileBert[] = "mobilebert_embedding_with_metadata.tflite";
// Embedding model with regex preprocessing.
constexpr char kRegexOneEmbeddingModel[] =
"regex_one_embedding_with_metadata.tflite";
constexpr char kUniversalSentenceEncoderModel[] =
"universal_sentence_encoder_qa_with_metadata.tflite";
// Tolerance for embedding vector coordinate values.
constexpr float kEpsilon = 1e-4;
@ -147,6 +149,35 @@ TEST_F(EmbedderTest, SucceedsWithQuantization) {
MP_ASSERT_OK(text_embedder->Close());
}
TEST(EmbedTest, SucceedsWithUniversalSentenceEncoderModel) {
auto options = std::make_unique<TextEmbedderOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kUniversalSentenceEncoderModel);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder,
TextEmbedder::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(
auto result0,
text_embedder->Embed("it's a charming and often affecting journey"));
ASSERT_EQ(result0.embeddings.size(), 1);
ASSERT_EQ(result0.embeddings[0].float_embedding.size(), 100);
ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 1.422951f, kEpsilon);
MP_ASSERT_OK_AND_ASSIGN(
auto result1, text_embedder->Embed("what a great and fantastic trip"));
ASSERT_EQ(result1.embeddings.size(), 1);
ASSERT_EQ(result1.embeddings[0].float_embedding.size(), 100);
ASSERT_NEAR(result1.embeddings[0].float_embedding[0], 1.404664f, kEpsilon);
// Check cosine similarity.
MP_ASSERT_OK_AND_ASSIGN(
double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0],
result1.embeddings[0]));
ASSERT_NEAR(similarity, 0.851961, kSimilarityTolerancy);
MP_ASSERT_OK(text_embedder->Close());
}
TEST_F(EmbedderTest, SucceedsWithMobileBertAndDifferentThemes) {
auto options = std::make_unique<TextEmbedderOptions>();
options->base_options.model_asset_path =
@ -178,5 +209,31 @@ TEST_F(EmbedderTest, SucceedsWithMobileBertAndDifferentThemes) {
MP_ASSERT_OK(text_embedder->Close());
}
TEST_F(EmbedderTest, SucceedsWithUSEAndDifferentThemes) {
auto options = std::make_unique<TextEmbedderOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kUniversalSentenceEncoderModel);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder,
TextEmbedder::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(
TextEmbedderResult result0,
text_embedder->Embed("When you go to this restaurant, they hold the "
"pancake upside-down before they hand it "
"to you. It's a great gimmick."));
MP_ASSERT_OK_AND_ASSIGN(
TextEmbedderResult result1,
text_embedder->Embed(
"Let's make a plan to steal the declaration of independence."));
// Check cosine similarity.
MP_ASSERT_OK_AND_ASSIGN(
double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0],
result1.embeddings[0]));
EXPECT_NEAR(similarity, 0.780334, kSimilarityTolerancy);
MP_ASSERT_OK(text_embedder->Close());
}
} // namespace
} // namespace mediapipe::tasks::text::text_embedder