diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/BUILD b/mediapipe/tasks/cc/text/language_detector/custom_ops/BUILD index 5e7c5afa5..090f528ef 100644 --- a/mediapipe/tasks/cc/text/language_detector/custom_ops/BUILD +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/BUILD @@ -42,3 +42,36 @@ cc_test( "@org_tensorflow//tensorflow/lite/kernels:test_util", ], ) + +cc_library( + name = "ngram_hash", + srcs = ["ngram_hash.cc"], + hdrs = ["ngram_hash.h"], + copts = tflite_copts(), + deps = [ + "//mediapipe/tasks/cc/text/language_detector/custom_ops/utils:ngram_hash_ops_utils", + "//mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash:murmur", + "@flatbuffers", + "@org_tensorflow//tensorflow/lite:string_util", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + "@org_tensorflow//tensorflow/lite/kernels:kernel_util", + ], + alwayslink = 1, +) + +cc_test( + name = "ngram_hash_test", + srcs = ["ngram_hash_test.cc"], + deps = [ + ":ngram_hash", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash:murmur", + "@com_google_absl//absl/types:optional", + "@flatbuffers", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite:string_util", + "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + "@org_tensorflow//tensorflow/lite/kernels:test_util", + ], +) diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.cc b/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.cc new file mode 100644 index 000000000..738fa1128 --- /dev/null +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.cc @@ -0,0 +1,264 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h" + +#include +#include +#include + +#include "flatbuffers/flexbuffers.h" +#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h" +#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/string_util.h" + +namespace tflite::ops::custom { + +namespace ngram_op { + +namespace { + +using ::flexbuffers::GetRoot; +using ::flexbuffers::Map; +using ::flexbuffers::TypedVector; +using ::mediapipe::tasks::text::language_detector::custom_ops:: + LowercaseUnicodeStr; +using ::mediapipe::tasks::text::language_detector::custom_ops::Tokenize; +using ::mediapipe::tasks::text::language_detector::custom_ops::TokenizedOutput; +using ::mediapipe::tasks::text::language_detector::custom_ops::hash:: + MurmurHash64WithSeed; +using ::tflite::GetString; +using ::tflite::StringRef; + +constexpr int kInputMessage = 0; +constexpr int kOutputLabel = 0; +constexpr int kDefaultMaxSplits = 128; + +// This op takes in a string, finds the character ngrams for it and then +// maps each of these ngrams to an index using the specified vocabulary sizes. + +// Input(s): +// - input: Input string. +// - seeds: Seed for the random number generator. +// - ngram_lengths: Lengths of each of the ngrams. For example [1, 2, 3] would +// be interpreted as generating unigrams, bigrams, and trigrams. +// - vocab_sizes: Size of the vocabulary for each of the ngram features +// respectively. The op would generate vocab ids to be less than or equal to +// the vocab size. The index 0 implies an invalid ngram. +// - max_splits: Maximum number of tokens in the output. If this is unset, the +// limit is `kDefaultMaxSplits`. +// - lower_case_input: If this is set to true, the input string would be +// lower-cased before any processing. + +// Output(s): +// - output: A tensor of size [number of ngrams, number of tokens + 2], +// where 2 tokens are reserved for the padding. If `max_splits` is set, this +// length is <= max_splits, otherwise it is <= `kDefaultMaxSplits`. + +// Helper class used for pre-processing the input. +class NGramHashParams { + public: + NGramHashParams(const uint64_t seed, const std::vector& ngram_lengths, + const std::vector& vocab_sizes, int max_splits, + bool lower_case_input) + : seed_(seed), + ngram_lengths_(ngram_lengths), + vocab_sizes_(vocab_sizes), + max_splits_(max_splits), + lower_case_input_(lower_case_input) {} + + TfLiteStatus PreprocessInput(const TfLiteTensor* input_t, + TfLiteContext* context) { + if (input_t->bytes == 0) { + context->ReportError(context, "Empty input not supported."); + return kTfLiteError; + } + + // Do sanity checks on the input. + if (ngram_lengths_.empty()) { + context->ReportError(context, "`ngram_lengths` must be non-empty."); + return kTfLiteError; + } + + if (vocab_sizes_.empty()) { + context->ReportError(context, "`vocab_sizes` must be non-empty."); + return kTfLiteError; + } + + if (ngram_lengths_.size() != vocab_sizes_.size()) { + context->ReportError( + context, + "Sizes of `ngram_lengths` and `vocab_sizes` must be the same."); + return kTfLiteError; + } + + if (max_splits_ <= 0) { + context->ReportError(context, "`max_splits` must be > 0."); + return kTfLiteError; + } + + // Obtain and tokenize the input. + StringRef inputref = GetString(input_t, /*string_index=*/0); + if (lower_case_input_) { + std::string lower_cased_str; + LowercaseUnicodeStr(inputref.str, inputref.len, &lower_cased_str); + + tokenized_output_ = + Tokenize(lower_cased_str.c_str(), inputref.len, max_splits_, + /*exclude_nonalphaspace_tokens=*/true); + } else { + tokenized_output_ = Tokenize(inputref.str, inputref.len, max_splits_, + /*exclude_nonalphaspace_tokens=*/true); + } + return kTfLiteOk; + } + uint64_t GetSeed() const { return seed_; } + + int GetNumTokens() const { return tokenized_output_.tokens.size(); } + + int GetNumNGrams() const { return ngram_lengths_.size(); } + + std::vector GetNGramLengths() const { return ngram_lengths_; } + + std::vector GetVocabSizes() const { return vocab_sizes_; } + + const TokenizedOutput& GetTokenizedOutput() const { + return tokenized_output_; + } + + TokenizedOutput tokenized_output_; + + private: + const uint64_t seed_; + std::vector ngram_lengths_; + std::vector vocab_sizes_; + const int max_splits_; + const bool lower_case_input_; +}; + +// Convert the TypedVector into a regular std::vector. +std::vector GetIntVector(TypedVector typed_vec) { + std::vector vec(typed_vec.size()); + for (int j = 0; j < typed_vec.size(); j++) { + vec[j] = typed_vec[j].AsInt32(); + } + return vec; +} + +void GetNGramHashIndices(NGramHashParams* params, int32_t* data) { + const int max_unicode_length = params->GetNumTokens(); + const auto ngram_lengths = params->GetNGramLengths(); + const auto vocab_sizes = params->GetVocabSizes(); + const auto& tokenized_output = params->GetTokenizedOutput(); + const auto seed = params->GetSeed(); + + // Compute for each ngram. + for (int ngram = 0; ngram < ngram_lengths.size(); ngram++) { + const int vocab_size = vocab_sizes[ngram]; + const int ngram_length = ngram_lengths[ngram]; + + // Compute for each token within the input. + for (int start = 0; start < tokenized_output.tokens.size(); start++) { + // Compute the number of bytes for the ngram starting at the given + // token. + int num_bytes = 0; + for (int i = start; + i < tokenized_output.tokens.size() && i < (start + ngram_length); + i++) { + num_bytes += tokenized_output.tokens[i].second; + } + + // Compute the hash for the ngram starting at the token. + const auto str_hash = MurmurHash64WithSeed( + tokenized_output.str.c_str() + tokenized_output.tokens[start].first, + num_bytes, seed); + + // Map the hash to an index in the vocab. + data[ngram * max_unicode_length + start] = (str_hash % vocab_size) + 1; + } + } +} + +} // namespace + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + const uint8_t* buffer_t = reinterpret_cast(buffer); + const Map& m = GetRoot(buffer_t, length).AsMap(); + + const uint64_t seed = m["seed"].AsUInt64(); + const std::vector ngram_lengths = + GetIntVector(m["ngram_lengths"].AsTypedVector()); + const std::vector vocab_sizes = + GetIntVector(m["vocab_sizes"].AsTypedVector()); + const int max_splits = + m["max_splits"].IsNull() ? kDefaultMaxSplits : m["max_splits"].AsInt32(); + const bool lowercase_input = + m["lowercase_input"].IsNull() ? true : m["lowercase_input"].AsBool(); + + return new NGramHashParams(seed, ngram_lengths, vocab_sizes, max_splits, + lowercase_input); +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* output = GetOutput(context, node, kOutputLabel); + TF_LITE_ENSURE(context, output != nullptr); + SetTensorToDynamic(output); + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + NGramHashParams* params = reinterpret_cast(node->user_data); + TF_LITE_ENSURE_OK( + context, + params->PreprocessInput(GetInput(context, node, kInputMessage), context)); + + TfLiteTensor* output = GetOutput(context, node, kOutputLabel); + TF_LITE_ENSURE(context, output != nullptr); + if (IsDynamicTensor(output)) { + TfLiteIntArray* output_size = TfLiteIntArrayCreate(3); + output_size->data[0] = 1; + output_size->data[1] = params->GetNumNGrams(); + output_size->data[2] = params->GetNumTokens(); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, output, output_size)); + } else { + context->ReportError(context, "Output must by dynamic."); + return kTfLiteError; + } + + if (output->type == kTfLiteInt32) { + GetNGramHashIndices(params, output->data.i32); + } else { + context->ReportError(context, "Output type must be Int32."); + return kTfLiteError; + } + + return kTfLiteOk; +} + +} // namespace ngram_op + +TfLiteRegistration* Register_NGRAM_HASH() { + static TfLiteRegistration r = {ngram_op::Init, ngram_op::Free, + ngram_op::Resize, ngram_op::Eval}; + return &r; +} + +} // namespace tflite::ops::custom diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h b/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h new file mode 100644 index 000000000..a061357bd --- /dev/null +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h @@ -0,0 +1,27 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_ + +#include "tensorflow/lite/kernels/register.h" + +namespace tflite::ops::custom { + +TfLiteRegistration* Register_NGRAM_HASH(); + +} // namespace tflite::ops::custom + +#endif // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_ diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash_test.cc b/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash_test.cc new file mode 100644 index 000000000..28d2dea6e --- /dev/null +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash_test.cc @@ -0,0 +1,313 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h" + +#include +#include +#include +#include + +#include "absl/types/optional.h" +#include "flatbuffers/flexbuffers.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/string_util.h" + +namespace tflite::ops::custom { +namespace { + +using ::flexbuffers::Builder; +using ::mediapipe::tasks::text::language_detector::custom_ops::hash:: + MurmurHash64WithSeed; +using ::testing::ElementsAreArray; +using ::testing::Message; + +// Helper class for testing the op. +class NGramHashModel : public SingleOpModel { + public: + explicit NGramHashModel(const uint64_t seed, + const std::vector& ngram_lengths, + const std::vector& vocab_sizes, + const absl::optional max_splits = std::nullopt) { + // Setup the model inputs. + Builder fbb; + size_t start = fbb.StartMap(); + fbb.UInt("seed", seed); + { + size_t start = fbb.StartVector("ngram_lengths"); + for (const int& ngram_len : ngram_lengths) { + fbb.Int(ngram_len); + } + fbb.EndVector(start, /*typed=*/true, /*fixed=*/false); + } + { + size_t start = fbb.StartVector("vocab_sizes"); + for (const int& vocab_size : vocab_sizes) { + fbb.Int(vocab_size); + } + fbb.EndVector(start, /*typed=*/true, /*fixed=*/false); + } + if (max_splits) { + fbb.Int("max_splits", *max_splits); + } + fbb.EndMap(start); + fbb.Finish(); + output_ = AddOutput({TensorType_INT32, {}}); + SetCustomOp("NGramHash", fbb.GetBuffer(), Register_NGRAM_HASH); + BuildInterpreter({GetShape(input_)}); + } + + void SetupInputTensor(const std::string& input) { + PopulateStringTensor(input_, {input}); + CHECK(interpreter_->AllocateTensors() == kTfLiteOk) + << "Cannot allocate tensors"; + } + + void Invoke(const std::string& input) { + SetupInputTensor(input); + CHECK_EQ(SingleOpModel::Invoke(), kTfLiteOk); + } + + TfLiteStatus InvokeUnchecked(const std::string& input) { + SetupInputTensor(input); + return SingleOpModel::Invoke(); + } + + template + std::vector GetOutput() { + return ExtractVector(output_); + } + + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_ = AddInput(TensorType_STRING); + int output_; +}; + +TEST(NGramHashTest, ReturnsExpectedValueWhenInputIsSane) { + // Checks that the op returns the expected value when the input is sane. + // Also checks that when `max_splits` is not specified, the entire string is + // tokenized. + const uint64_t kSeed = 123; + const std::vector vocab_sizes({100, 200}); + std::vector ngram_lengths({1, 2}); + const std::vector testcase_inputs({ + "hi", + "wow", + "!", + "HI", + }); + + // A hash function that maps the given string to an index in the embedding + // table denoted by `vocab_idx`. + auto hash = [vocab_sizes](std::string str, const int vocab_idx) { + const auto hash_value = + MurmurHash64WithSeed(str.c_str(), str.size(), kSeed); + return static_cast((hash_value % vocab_sizes[vocab_idx]) + 1); + }; + const std::vector> expected_testcase_outputs( + {{ + // Unigram & Bigram output for "hi". + hash("^", 0), + hash("h", 0), + hash("i", 0), + hash("$", 0), + hash("^h", 1), + hash("hi", 1), + hash("i$", 1), + hash("$", 1), + }, + { + // Unigram & Bigram output for "wow". + hash("^", 0), + hash("w", 0), + hash("o", 0), + hash("w", 0), + hash("$", 0), + hash("^w", 1), + hash("wo", 1), + hash("ow", 1), + hash("w$", 1), + hash("$", 1), + }, + { + // Unigram & Bigram output for "!" (which will get replaced by " "). + hash("^", 0), + hash(" ", 0), + hash("$", 0), + hash("^ ", 1), + hash(" $", 1), + hash("$", 1), + }, + { + // Unigram & Bigram output for "HI" (which will get lower-cased). + hash("^", 0), + hash("h", 0), + hash("i", 0), + hash("$", 0), + hash("^h", 1), + hash("hi", 1), + hash("i$", 1), + hash("$", 1), + }}); + + NGramHashModel m(kSeed, ngram_lengths, vocab_sizes); + for (int test_idx = 0; test_idx < testcase_inputs.size(); test_idx++) { + const string& testcase_input = testcase_inputs[test_idx]; + m.Invoke(testcase_input); + SCOPED_TRACE(Message() << "Where the testcases' input is: " + << testcase_input); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(expected_testcase_outputs[test_idx])); + EXPECT_THAT(m.GetOutputShape(), + ElementsAreArray( + {/*batch_size=*/1, static_cast(ngram_lengths.size()), + static_cast(testcase_input.size()) + /*padding*/ 2})); + } +} + +TEST(NGramHashTest, ReturnsExpectedValueWhenMaxSplitsIsSpecified) { + // Checks that the op returns the expected value when the input is correct + // when `max_splits` is specified. + const uint64_t kSeed = 123; + const std::vector vocab_sizes({100, 200}); + std::vector ngram_lengths({1, 2}); + + const std::string testcase_input = "wow"; + const std::vector max_splits({2, 3, 4, 5, 6}); + + // A hash function that maps the given string to an index in the embedding + // table denoted by `vocab_idx`. + auto hash = [vocab_sizes](std::string str, const int vocab_idx) { + const auto hash_value = + MurmurHash64WithSeed(str.c_str(), str.size(), kSeed); + return static_cast((hash_value % vocab_sizes[vocab_idx]) + 1); + }; + + const std::vector> expected_testcase_outputs( + {{ + // Unigram & Bigram output for "wow", when `max_splits` == 2. + // We cannot include any of the actual tokens, since `max_splits` + // only allows enough space for the delimiters. + hash("^", 0), + hash("$", 0), + hash("^$", 1), + hash("$", 1), + }, + { + // Unigram & Bigram output for "wow", when `max_splits` == 3. + // We can start to include some tokens from the input string. + hash("^", 0), + hash("w", 0), + hash("$", 0), + hash("^w", 1), + hash("w$", 1), + hash("$", 1), + }, + { + // Unigram & Bigram output for "wow", when `max_splits` == 4. + hash("^", 0), + hash("w", 0), + hash("o", 0), + hash("$", 0), + hash("^w", 1), + hash("wo", 1), + hash("o$", 1), + hash("$", 1), + }, + { + // Unigram & Bigram output for "wow", when `max_splits` == 5. + // We can include the full input string. + hash("^", 0), + hash("w", 0), + hash("o", 0), + hash("w", 0), + hash("$", 0), + hash("^w", 1), + hash("wo", 1), + hash("ow", 1), + hash("w$", 1), + hash("$", 1), + }, + { + // Unigram & Bigram output for "wow", when `max_splits` == 6. + // `max_splits` is more than the full input string. + hash("^", 0), + hash("w", 0), + hash("o", 0), + hash("w", 0), + hash("$", 0), + hash("^w", 1), + hash("wo", 1), + hash("ow", 1), + hash("w$", 1), + hash("$", 1), + }}); + + for (int test_idx = 0; test_idx < max_splits.size(); test_idx++) { + const int testcase_max_splits = max_splits[test_idx]; + NGramHashModel m(kSeed, ngram_lengths, vocab_sizes, testcase_max_splits); + m.Invoke(testcase_input); + SCOPED_TRACE(Message() << "Where `max_splits` is: " << testcase_max_splits); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(expected_testcase_outputs[test_idx])); + EXPECT_THAT( + m.GetOutputShape(), + ElementsAreArray( + {/*batch_size=*/1, static_cast(ngram_lengths.size()), + std::min( + // Longest possible tokenization when using the entire + // input. + static_cast(testcase_input.size()) + /*padding*/ 2, + // Longest possible string when the `max_splits` value + // is < testcase_input.size() + 2 for padding. + testcase_max_splits)})); + } +} + +TEST(NGramHashTest, InvalidMaxSplitsValue) { + // Check that the op errors out when given an invalid max splits value. + const std::vector invalid_max_splits({0, -1, -5, -100}); + for (const int max_splits : invalid_max_splits) { + NGramHashModel m(/*seed=*/123, /*ngram_lengths=*/{100, 200}, + /*vocab_sizes=*/{1, 2}, /*max_splits=*/max_splits); + EXPECT_EQ(m.InvokeUnchecked("hi"), kTfLiteError); + } +} + +TEST(NGramHashTest, MismatchNgramLengthsAndVocabSizes) { + // Check that the op errors out when ngram lengths and vocab sizes mistmatch. + { + NGramHashModel m(/*seed=*/123, /*ngram_lengths=*/{100, 200, 300}, + /*vocab_sizes=*/{1, 2}); + EXPECT_EQ(m.InvokeUnchecked("hi"), kTfLiteError); + } + { + NGramHashModel m(/*seed=*/123, /*ngram_lengths=*/{100, 200}, + /*vocab_sizes=*/{1, 2, 3}); + EXPECT_EQ(m.InvokeUnchecked("hi"), kTfLiteError); + } +} + +} // namespace +} // namespace tflite::ops::custom