Internal MediaPipe Tasks change.
PiperOrigin-RevId: 516881879
This commit is contained in:
		
							parent
							
								
									ce3cd94f45
								
							
						
					
					
						commit
						18d88c531a
					
				|  | @ -42,3 +42,36 @@ cc_test( | |||
|         "@org_tensorflow//tensorflow/lite/kernels:test_util", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
# TFLite custom op that hashes character ngrams of an input string into
# per-ngram vocabulary ids (implementation in ngram_hash.cc).
cc_library(
    name = "ngram_hash",
    srcs = ["ngram_hash.cc"],
    hdrs = ["ngram_hash.h"],
    copts = tflite_copts(),
    deps = [
        "//mediapipe/tasks/cc/text/language_detector/custom_ops/utils:ngram_hash_ops_utils",
        "//mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash:murmur",
        "@flatbuffers",
        "@org_tensorflow//tensorflow/lite:string_util",
        "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
        "@org_tensorflow//tensorflow/lite/kernels:kernel_util",
    ],
    # alwayslink so the op's registration symbol is retained by the linker
    # even when nothing references it directly.
    alwayslink = 1,
)
| 
 | ||||
# Unit tests for the NGramHash custom op.
cc_test(
    name = "ngram_hash_test",
    srcs = ["ngram_hash_test.cc"],
    deps = [
        ":ngram_hash",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash:murmur",
        "@com_google_absl//absl/types:optional",
        "@flatbuffers",
        "@org_tensorflow//tensorflow/lite:framework",
        "@org_tensorflow//tensorflow/lite:string_util",
        "@org_tensorflow//tensorflow/lite/c:common",
        "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
        "@org_tensorflow//tensorflow/lite/kernels:test_util",
    ],
)
|  |  | |||
|  | @ -0,0 +1,264 @@ | |||
| /* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| #include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h" | ||||
| 
 | ||||
| #include <cstdint> | ||||
| #include <string> | ||||
| #include <vector> | ||||
| 
 | ||||
| #include "flatbuffers/flexbuffers.h" | ||||
| #include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h" | ||||
| #include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h" | ||||
| #include "tensorflow/lite/kernels/kernel_util.h" | ||||
| #include "tensorflow/lite/string_util.h" | ||||
| 
 | ||||
| namespace tflite::ops::custom { | ||||
| 
 | ||||
| namespace ngram_op { | ||||
| 
 | ||||
| namespace { | ||||
| 
 | ||||
| using ::flexbuffers::GetRoot; | ||||
| using ::flexbuffers::Map; | ||||
| using ::flexbuffers::TypedVector; | ||||
| using ::mediapipe::tasks::text::language_detector::custom_ops:: | ||||
|     LowercaseUnicodeStr; | ||||
| using ::mediapipe::tasks::text::language_detector::custom_ops::Tokenize; | ||||
| using ::mediapipe::tasks::text::language_detector::custom_ops::TokenizedOutput; | ||||
| using ::mediapipe::tasks::text::language_detector::custom_ops::hash:: | ||||
|     MurmurHash64WithSeed; | ||||
| using ::tflite::GetString; | ||||
| using ::tflite::StringRef; | ||||
| 
 | ||||
| constexpr int kInputMessage = 0; | ||||
| constexpr int kOutputLabel = 0; | ||||
| constexpr int kDefaultMaxSplits = 128; | ||||
| 
 | ||||
| // This op takes in a string, finds the character ngrams for it and then
 | ||||
| // maps each of these ngrams to an index using the specified vocabulary sizes.
 | ||||
| 
 | ||||
| // Input(s):
 | ||||
| // - input: Input string.
 | ||||
| // - seeds: Seed for the random number generator.
 | ||||
| // - ngram_lengths: Lengths of each of the ngrams. For example [1, 2, 3] would
 | ||||
| //   be interpreted as generating unigrams, bigrams, and trigrams.
 | ||||
| // - vocab_sizes: Size of the vocabulary for each of the ngram features
 | ||||
| //   respectively. The op would generate vocab ids to be less than or equal to
 | ||||
| //   the vocab size. The index 0 implies an invalid ngram.
 | ||||
| // - max_splits: Maximum number of tokens in the output. If this is unset, the
 | ||||
| //   limit is `kDefaultMaxSplits`.
 | ||||
// - lowercase_input: If this is set to true, the input string would be
//   lower-cased before any processing.
| 
 | ||||
| // Output(s):
 | ||||
| // - output: A tensor of size [number of ngrams, number of tokens + 2],
 | ||||
| //   where 2 tokens are reserved for the padding. If `max_splits` is set, this
 | ||||
| //   length is <= max_splits, otherwise it is <= `kDefaultMaxSplits`.
 | ||||
| 
 | ||||
// Helper class used for pre-processing the input.
class NGramHashParams {
 public:
  // `seed`: seed for the Murmur hash applied to each ngram.
  // `ngram_lengths` / `vocab_sizes`: parallel arrays (one vocab size per
  //   ngram length); validated in PreprocessInput().
  // `max_splits`: cap on the number of tokens produced by tokenization;
  //   must be > 0 (validated in PreprocessInput()).
  // `lower_case_input`: if true, the input is lower-cased before tokenizing.
  NGramHashParams(const uint64_t seed, const std::vector<int>& ngram_lengths,
                  const std::vector<int>& vocab_sizes, int max_splits,
                  bool lower_case_input)
      : seed_(seed),
        ngram_lengths_(ngram_lengths),
        vocab_sizes_(vocab_sizes),
        max_splits_(max_splits),
        lower_case_input_(lower_case_input) {}

  // Validates the op attributes, then tokenizes the string stored in
  // `input_t` into `tokenized_output_`. On invalid attributes or empty
  // input, reports an error via `context` and returns kTfLiteError.
  TfLiteStatus PreprocessInput(const TfLiteTensor* input_t,
                               TfLiteContext* context) {
    if (input_t->bytes == 0) {
      context->ReportError(context, "Empty input not supported.");
      return kTfLiteError;
    }

    // Do sanity checks on the input.
    if (ngram_lengths_.empty()) {
      context->ReportError(context, "`ngram_lengths` must be non-empty.");
      return kTfLiteError;
    }

    if (vocab_sizes_.empty()) {
      context->ReportError(context, "`vocab_sizes` must be non-empty.");
      return kTfLiteError;
    }

    if (ngram_lengths_.size() != vocab_sizes_.size()) {
      context->ReportError(
          context,
          "Sizes of `ngram_lengths` and `vocab_sizes` must be the same.");
      return kTfLiteError;
    }

    if (max_splits_ <= 0) {
      context->ReportError(context, "`max_splits` must be > 0.");
      return kTfLiteError;
    }

    // Obtain and tokenize the input.
    StringRef inputref = GetString(input_t, /*string_index=*/0);
    if (lower_case_input_) {
      std::string lower_cased_str;
      LowercaseUnicodeStr(inputref.str, inputref.len, &lower_cased_str);

      // NOTE(review): passes the ORIGINAL byte length (inputref.len), not
      // lower_cased_str.size(); this assumes Unicode lower-casing never
      // changes the UTF-8 byte length — confirm for multi-byte code points.
      tokenized_output_ =
          Tokenize(lower_cased_str.c_str(), inputref.len, max_splits_,
                   /*exclude_nonalphaspace_tokens=*/true);
    } else {
      tokenized_output_ = Tokenize(inputref.str, inputref.len, max_splits_,
                                   /*exclude_nonalphaspace_tokens=*/true);
    }
    return kTfLiteOk;
  }
  uint64_t GetSeed() const { return seed_; }

  // Number of tokens produced by the most recent PreprocessInput() call.
  int GetNumTokens() const { return tokenized_output_.tokens.size(); }

  // Number of distinct ngram lengths (equals the number of vocab sizes).
  int GetNumNGrams() const { return ngram_lengths_.size(); }

  std::vector<int> GetNGramLengths() const { return ngram_lengths_; }

  std::vector<int> GetVocabSizes() const { return vocab_sizes_; }

  const TokenizedOutput& GetTokenizedOutput() const {
    return tokenized_output_;
  }

  // Result of the most recent PreprocessInput() call. Public (in addition to
  // the accessor above); kept as-is in case callers touch it directly.
  TokenizedOutput tokenized_output_;

 private:
  const uint64_t seed_;
  std::vector<int> ngram_lengths_;
  std::vector<int> vocab_sizes_;
  const int max_splits_;
  const bool lower_case_input_;
};
| 
 | ||||
| // Convert the TypedVector into a regular std::vector.
 | ||||
| std::vector<int> GetIntVector(TypedVector typed_vec) { | ||||
|   std::vector<int> vec(typed_vec.size()); | ||||
|   for (int j = 0; j < typed_vec.size(); j++) { | ||||
|     vec[j] = typed_vec[j].AsInt32(); | ||||
|   } | ||||
|   return vec; | ||||
| } | ||||
| 
 | ||||
| void GetNGramHashIndices(NGramHashParams* params, int32_t* data) { | ||||
|   const int max_unicode_length = params->GetNumTokens(); | ||||
|   const auto ngram_lengths = params->GetNGramLengths(); | ||||
|   const auto vocab_sizes = params->GetVocabSizes(); | ||||
|   const auto& tokenized_output = params->GetTokenizedOutput(); | ||||
|   const auto seed = params->GetSeed(); | ||||
| 
 | ||||
|   // Compute for each ngram.
 | ||||
|   for (int ngram = 0; ngram < ngram_lengths.size(); ngram++) { | ||||
|     const int vocab_size = vocab_sizes[ngram]; | ||||
|     const int ngram_length = ngram_lengths[ngram]; | ||||
| 
 | ||||
|     // Compute for each token within the input.
 | ||||
|     for (int start = 0; start < tokenized_output.tokens.size(); start++) { | ||||
|       // Compute the number of bytes for the ngram starting at the given
 | ||||
|       // token.
 | ||||
|       int num_bytes = 0; | ||||
|       for (int i = start; | ||||
|            i < tokenized_output.tokens.size() && i < (start + ngram_length); | ||||
|            i++) { | ||||
|         num_bytes += tokenized_output.tokens[i].second; | ||||
|       } | ||||
| 
 | ||||
|       // Compute the hash for the ngram starting at the token.
 | ||||
|       const auto str_hash = MurmurHash64WithSeed( | ||||
|           tokenized_output.str.c_str() + tokenized_output.tokens[start].first, | ||||
|           num_bytes, seed); | ||||
| 
 | ||||
|       // Map the hash to an index in the vocab.
 | ||||
|       data[ngram * max_unicode_length + start] = (str_hash % vocab_size) + 1; | ||||
|     } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| }  // namespace
 | ||||
| 
 | ||||
| void* Init(TfLiteContext* context, const char* buffer, size_t length) { | ||||
|   const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer); | ||||
|   const Map& m = GetRoot(buffer_t, length).AsMap(); | ||||
| 
 | ||||
|   const uint64_t seed = m["seed"].AsUInt64(); | ||||
|   const std::vector<int> ngram_lengths = | ||||
|       GetIntVector(m["ngram_lengths"].AsTypedVector()); | ||||
|   const std::vector<int> vocab_sizes = | ||||
|       GetIntVector(m["vocab_sizes"].AsTypedVector()); | ||||
|   const int max_splits = | ||||
|       m["max_splits"].IsNull() ? kDefaultMaxSplits : m["max_splits"].AsInt32(); | ||||
|   const bool lowercase_input = | ||||
|       m["lowercase_input"].IsNull() ? true : m["lowercase_input"].AsBool(); | ||||
| 
 | ||||
|   return new NGramHashParams(seed, ngram_lengths, vocab_sizes, max_splits, | ||||
|                              lowercase_input); | ||||
| } | ||||
| 
 | ||||
| void Free(TfLiteContext* context, void* buffer) { | ||||
|   delete reinterpret_cast<NGramHashParams*>(buffer); | ||||
| } | ||||
| 
 | ||||
| TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) { | ||||
|   TfLiteTensor* output = GetOutput(context, node, kOutputLabel); | ||||
|   TF_LITE_ENSURE(context, output != nullptr); | ||||
|   SetTensorToDynamic(output); | ||||
|   return kTfLiteOk; | ||||
| } | ||||
| 
 | ||||
| TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { | ||||
|   NGramHashParams* params = reinterpret_cast<NGramHashParams*>(node->user_data); | ||||
|   TF_LITE_ENSURE_OK( | ||||
|       context, | ||||
|       params->PreprocessInput(GetInput(context, node, kInputMessage), context)); | ||||
| 
 | ||||
|   TfLiteTensor* output = GetOutput(context, node, kOutputLabel); | ||||
|   TF_LITE_ENSURE(context, output != nullptr); | ||||
|   if (IsDynamicTensor(output)) { | ||||
|     TfLiteIntArray* output_size = TfLiteIntArrayCreate(3); | ||||
|     output_size->data[0] = 1; | ||||
|     output_size->data[1] = params->GetNumNGrams(); | ||||
|     output_size->data[2] = params->GetNumTokens(); | ||||
|     TF_LITE_ENSURE_OK(context, | ||||
|                       context->ResizeTensor(context, output, output_size)); | ||||
|   } else { | ||||
|     context->ReportError(context, "Output must by dynamic."); | ||||
|     return kTfLiteError; | ||||
|   } | ||||
| 
 | ||||
|   if (output->type == kTfLiteInt32) { | ||||
|     GetNGramHashIndices(params, output->data.i32); | ||||
|   } else { | ||||
|     context->ReportError(context, "Output type must be Int32."); | ||||
|     return kTfLiteError; | ||||
|   } | ||||
| 
 | ||||
|   return kTfLiteOk; | ||||
| } | ||||
| 
 | ||||
| }  // namespace ngram_op
 | ||||
| 
 | ||||
// Returns the process-lifetime registration for the "NGramHash" custom op.
TfLiteRegistration* Register_NGRAM_HASH() {
  // Positional aggregate init: {init, free, prepare, invoke} — presumably
  // matches TfLiteRegistration's leading field order; confirm against
  // tensorflow/lite/c/common.h if that struct ever changes.
  static TfLiteRegistration r = {ngram_op::Init, ngram_op::Free,
                                 ngram_op::Resize, ngram_op::Eval};
  return &r;
}
| 
 | ||||
| }  // namespace tflite::ops::custom
 | ||||
|  | @ -0,0 +1,27 @@ | |||
| /* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| #ifndef MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_ | ||||
| #define MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_ | ||||
| 
 | ||||
| #include "tensorflow/lite/kernels/register.h" | ||||
| 
 | ||||
namespace tflite::ops::custom {

// Returns the TfLiteRegistration for the "NGramHash" custom op, suitable for
// passing to a TFLite op resolver's AddCustom().
TfLiteRegistration* Register_NGRAM_HASH();

}  // namespace tflite::ops::custom
 | ||||
| 
 | ||||
| #endif  // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_
 | ||||
|  | @ -0,0 +1,313 @@ | |||
| /* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| #include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h" | ||||
| 
 | ||||
| #include <cstdint> | ||||
| #include <optional> | ||||
| #include <string> | ||||
| #include <vector> | ||||
| 
 | ||||
| #include "absl/types/optional.h" | ||||
| #include "flatbuffers/flexbuffers.h" | ||||
| #include "mediapipe/framework/port/gmock.h" | ||||
| #include "mediapipe/framework/port/gtest.h" | ||||
| #include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h" | ||||
| #include "tensorflow/lite/c/common.h" | ||||
| #include "tensorflow/lite/interpreter.h" | ||||
| #include "tensorflow/lite/kernels/register.h" | ||||
| #include "tensorflow/lite/kernels/test_util.h" | ||||
| #include "tensorflow/lite/model.h" | ||||
| #include "tensorflow/lite/string_util.h" | ||||
| 
 | ||||
| namespace tflite::ops::custom { | ||||
| namespace { | ||||
| 
 | ||||
| using ::flexbuffers::Builder; | ||||
| using ::mediapipe::tasks::text::language_detector::custom_ops::hash:: | ||||
|     MurmurHash64WithSeed; | ||||
| using ::testing::ElementsAreArray; | ||||
| using ::testing::Message; | ||||
| 
 | ||||
// Helper class for testing the op.
class NGramHashModel : public SingleOpModel {
 public:
  // Builds a single-op model whose custom options (seed, ngram_lengths,
  // vocab_sizes, and optionally max_splits) are flexbuffer-encoded in the
  // same layout that the op's Init() expects.
  // NOTE(review): mixes absl::optional with std::nullopt — works when
  // absl::optional aliases std::optional; confirm, or use one consistently.
  explicit NGramHashModel(const uint64_t seed,
                          const std::vector<int>& ngram_lengths,
                          const std::vector<int>& vocab_sizes,
                          const absl::optional<int> max_splits = std::nullopt) {
    // Setup the model inputs.
    Builder fbb;
    size_t start = fbb.StartMap();
    fbb.UInt("seed", seed);
    {
      size_t start = fbb.StartVector("ngram_lengths");
      for (const int& ngram_len : ngram_lengths) {
        fbb.Int(ngram_len);
      }
      fbb.EndVector(start, /*typed=*/true, /*fixed=*/false);
    }
    {
      size_t start = fbb.StartVector("vocab_sizes");
      for (const int& vocab_size : vocab_sizes) {
        fbb.Int(vocab_size);
      }
      fbb.EndVector(start, /*typed=*/true, /*fixed=*/false);
    }
    // `max_splits` is omitted from the options map when unset, exercising
    // the op's kDefaultMaxSplits fallback.
    if (max_splits) {
      fbb.Int("max_splits", *max_splits);
    }
    fbb.EndMap(start);
    fbb.Finish();
    output_ = AddOutput({TensorType_INT32, {}});
    SetCustomOp("NGramHash", fbb.GetBuffer(), Register_NGRAM_HASH);
    BuildInterpreter({GetShape(input_)});
  }

  // Populates the string input tensor and allocates tensors.
  void SetupInputTensor(const std::string& input) {
    PopulateStringTensor(input_, {input});
    CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
        << "Cannot allocate tensors";
  }

  // Runs the op on `input`, CHECK-failing on any error.
  void Invoke(const std::string& input) {
    SetupInputTensor(input);
    CHECK_EQ(SingleOpModel::Invoke(), kTfLiteOk);
  }

  // Runs the op on `input` and returns the status (for error-path tests).
  TfLiteStatus InvokeUnchecked(const std::string& input) {
    SetupInputTensor(input);
    return SingleOpModel::Invoke();
  }

  template <typename T>
  std::vector<T> GetOutput() {
    return ExtractVector<T>(output_);
  }

  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }

 private:
  int input_ = AddInput(TensorType_STRING);
  int output_;
};
| 
 | ||||
TEST(NGramHashTest, ReturnsExpectedValueWhenInputIsSane) {
  // Checks that the op returns the expected value when the input is sane.
  // Also checks that when `max_splits` is not specified, the entire string is
  // tokenized. Expected values are derived from the same MurmurHash the op
  // uses, with "^"/"$" as the start/end padding tokens.
  const uint64_t kSeed = 123;
  const std::vector<int> vocab_sizes({100, 200});
  std::vector<int> ngram_lengths({1, 2});
  const std::vector<std::string> testcase_inputs({
      "hi",
      "wow",
      "!",
      "HI",
  });

  // A hash function that maps the given string to an index in the embedding
  // table denoted by `vocab_idx`.
  auto hash = [vocab_sizes](std::string str, const int vocab_idx) {
    const auto hash_value =
        MurmurHash64WithSeed(str.c_str(), str.size(), kSeed);
    return static_cast<int>((hash_value % vocab_sizes[vocab_idx]) + 1);
  };
  const std::vector<std::vector<int>> expected_testcase_outputs(
      {{
           // Unigram & Bigram output for "hi".
           hash("^", 0),
           hash("h", 0),
           hash("i", 0),
           hash("$", 0),
           hash("^h", 1),
           hash("hi", 1),
           hash("i$", 1),
           hash("$", 1),
       },
       {
           // Unigram & Bigram output for "wow".
           hash("^", 0),
           hash("w", 0),
           hash("o", 0),
           hash("w", 0),
           hash("$", 0),
           hash("^w", 1),
           hash("wo", 1),
           hash("ow", 1),
           hash("w$", 1),
           hash("$", 1),
       },
       {
           // Unigram & Bigram output for "!" (which will get replaced by " ").
           hash("^", 0),
           hash(" ", 0),
           hash("$", 0),
           hash("^ ", 1),
           hash(" $", 1),
           hash("$", 1),
       },
       {
           // Unigram & Bigram output for "HI" (which will get lower-cased).
           hash("^", 0),
           hash("h", 0),
           hash("i", 0),
           hash("$", 0),
           hash("^h", 1),
           hash("hi", 1),
           hash("i$", 1),
           hash("$", 1),
       }});

  NGramHashModel m(kSeed, ngram_lengths, vocab_sizes);
  for (int test_idx = 0; test_idx < testcase_inputs.size(); test_idx++) {
    // NOTE(review): unqualified `string` — presumably resolved via a
    // using-declaration pulled in by the TFLite test headers; std::string
    // would be more explicit.
    const string& testcase_input = testcase_inputs[test_idx];
    m.Invoke(testcase_input);
    SCOPED_TRACE(Message() << "Where the testcases' input is: "
                           << testcase_input);
    EXPECT_THAT(m.GetOutput<int>(),
                ElementsAreArray(expected_testcase_outputs[test_idx]));
    EXPECT_THAT(m.GetOutputShape(),
                ElementsAreArray(
                    {/*batch_size=*/1, static_cast<int>(ngram_lengths.size()),
                     static_cast<int>(testcase_input.size()) + /*padding*/ 2}));
  }
}
| 
 | ||||
TEST(NGramHashTest, ReturnsExpectedValueWhenMaxSplitsIsSpecified) {
  // Checks that the op returns the expected value when the input is correct
  // when `max_splits` is specified. Each case re-builds the model with a
  // different `max_splits` and verifies both the values and the clamped
  // output shape.
  const uint64_t kSeed = 123;
  const std::vector<int> vocab_sizes({100, 200});
  std::vector<int> ngram_lengths({1, 2});

  const std::string testcase_input = "wow";
  const std::vector<int> max_splits({2, 3, 4, 5, 6});

  // A hash function that maps the given string to an index in the embedding
  // table denoted by `vocab_idx`.
  auto hash = [vocab_sizes](std::string str, const int vocab_idx) {
    const auto hash_value =
        MurmurHash64WithSeed(str.c_str(), str.size(), kSeed);
    return static_cast<int>((hash_value % vocab_sizes[vocab_idx]) + 1);
  };

  const std::vector<std::vector<int>> expected_testcase_outputs(
      {{
           // Unigram & Bigram output for "wow", when `max_splits` == 2.
           // We cannot include any of the actual tokens, since `max_splits`
           // only allows enough space for the delimiters.
           hash("^", 0),
           hash("$", 0),
           hash("^$", 1),
           hash("$", 1),
       },
       {
           // Unigram & Bigram output for "wow", when `max_splits` == 3.
           // We can start to include some tokens from the input string.
           hash("^", 0),
           hash("w", 0),
           hash("$", 0),
           hash("^w", 1),
           hash("w$", 1),
           hash("$", 1),
       },
       {
           // Unigram & Bigram output for "wow", when `max_splits` == 4.
           hash("^", 0),
           hash("w", 0),
           hash("o", 0),
           hash("$", 0),
           hash("^w", 1),
           hash("wo", 1),
           hash("o$", 1),
           hash("$", 1),
       },
       {
           // Unigram & Bigram output for "wow", when `max_splits` == 5.
           // We can include the full input string.
           hash("^", 0),
           hash("w", 0),
           hash("o", 0),
           hash("w", 0),
           hash("$", 0),
           hash("^w", 1),
           hash("wo", 1),
           hash("ow", 1),
           hash("w$", 1),
           hash("$", 1),
       },
       {
           // Unigram & Bigram output for "wow", when `max_splits` == 6.
           // `max_splits` is more than the full input string.
           hash("^", 0),
           hash("w", 0),
           hash("o", 0),
           hash("w", 0),
           hash("$", 0),
           hash("^w", 1),
           hash("wo", 1),
           hash("ow", 1),
           hash("w$", 1),
           hash("$", 1),
       }});

  for (int test_idx = 0; test_idx < max_splits.size(); test_idx++) {
    const int testcase_max_splits = max_splits[test_idx];
    NGramHashModel m(kSeed, ngram_lengths, vocab_sizes, testcase_max_splits);
    m.Invoke(testcase_input);
    SCOPED_TRACE(Message() << "Where `max_splits` is: " << testcase_max_splits);
    EXPECT_THAT(m.GetOutput<int>(),
                ElementsAreArray(expected_testcase_outputs[test_idx]));
    EXPECT_THAT(
        m.GetOutputShape(),
        ElementsAreArray(
            {/*batch_size=*/1, static_cast<int>(ngram_lengths.size()),
             std::min(
                 // Longest possible tokenization when using the entire
                 // input.
                 static_cast<int>(testcase_input.size()) + /*padding*/ 2,
                 // Longest possible string when the `max_splits` value
                 // is < testcase_input.size() + 2 for padding.
                 testcase_max_splits)}));
  }
}
| 
 | ||||
TEST(NGramHashTest, InvalidMaxSplitsValue) {
  // Check that the op errors out when given an invalid max splits value
  // (PreprocessInput requires `max_splits` > 0).
  const std::vector<int> invalid_max_splits({0, -1, -5, -100});
  for (const int max_splits : invalid_max_splits) {
    NGramHashModel m(/*seed=*/123, /*ngram_lengths=*/{100, 200},
                     /*vocab_sizes=*/{1, 2}, /*max_splits=*/max_splits);
    EXPECT_EQ(m.InvokeUnchecked("hi"), kTfLiteError);
  }
}
| 
 | ||||
TEST(NGramHashTest, MismatchNgramLengthsAndVocabSizes) {
  // Check that the op errors out when ngram lengths and vocab sizes mismatch
  // (PreprocessInput requires the two arrays to be the same size).
  {
    NGramHashModel m(/*seed=*/123, /*ngram_lengths=*/{100, 200, 300},
                     /*vocab_sizes=*/{1, 2});
    EXPECT_EQ(m.InvokeUnchecked("hi"), kTfLiteError);
  }
  {
    NGramHashModel m(/*seed=*/123, /*ngram_lengths=*/{100, 200},
                     /*vocab_sizes=*/{1, 2, 3});
    EXPECT_EQ(m.InvokeUnchecked("hi"), kTfLiteError);
  }
}
| 
 | ||||
| }  // namespace
 | ||||
| }  // namespace tflite::ops::custom
 | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user