Internal MediaPipe Tasks change.
PiperOrigin-RevId: 516881879
parent ce3cd94f45
commit 18d88c531a
mediapipe/tasks/cc/text/language_detector/custom_ops/BUILD
@@ -42,3 +42,36 @@ cc_test(
        "@org_tensorflow//tensorflow/lite/kernels:test_util",
    ],
)

cc_library(
    name = "ngram_hash",
    srcs = ["ngram_hash.cc"],
    hdrs = ["ngram_hash.h"],
    copts = tflite_copts(),
    deps = [
        "//mediapipe/tasks/cc/text/language_detector/custom_ops/utils:ngram_hash_ops_utils",
        "//mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash:murmur",
        "@flatbuffers",
        "@org_tensorflow//tensorflow/lite:string_util",
        "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
        "@org_tensorflow//tensorflow/lite/kernels:kernel_util",
    ],
    alwayslink = 1,
)

cc_test(
    name = "ngram_hash_test",
    srcs = ["ngram_hash_test.cc"],
    deps = [
        ":ngram_hash",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash:murmur",
        "@com_google_absl//absl/types:optional",
        "@flatbuffers",
        "@org_tensorflow//tensorflow/lite:framework",
        "@org_tensorflow//tensorflow/lite:string_util",
        "@org_tensorflow//tensorflow/lite/c:common",
        "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
        "@org_tensorflow//tensorflow/lite/kernels:test_util",
    ],
)
mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.cc
@@ -0,0 +1,264 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h"

#include <cstdint>
#include <string>
#include <vector>

#include "flatbuffers/flexbuffers.h"
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h"
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/string_util.h"

namespace tflite::ops::custom {

namespace ngram_op {

namespace {

using ::flexbuffers::GetRoot;
using ::flexbuffers::Map;
using ::flexbuffers::TypedVector;
using ::mediapipe::tasks::text::language_detector::custom_ops::
    LowercaseUnicodeStr;
using ::mediapipe::tasks::text::language_detector::custom_ops::Tokenize;
using ::mediapipe::tasks::text::language_detector::custom_ops::TokenizedOutput;
using ::mediapipe::tasks::text::language_detector::custom_ops::hash::
    MurmurHash64WithSeed;
using ::tflite::GetString;
using ::tflite::StringRef;

constexpr int kInputMessage = 0;
constexpr int kOutputLabel = 0;
constexpr int kDefaultMaxSplits = 128;

// This op takes in a string, finds the character ngrams for it, and then
// maps each of these ngrams to an index using the specified vocabulary sizes.

// Input(s):
// - input: Input string.
// - seed: Seed for the hash function.
// - ngram_lengths: Lengths of each of the ngrams. For example [1, 2, 3] would
//   be interpreted as generating unigrams, bigrams, and trigrams.
// - vocab_sizes: Size of the vocabulary for each of the ngram features,
//   respectively. The op generates vocab ids that are less than or equal to
//   the vocab size; the index 0 implies an invalid ngram.
// - max_splits: Maximum number of tokens in the output. If this is unset, the
//   limit is `kDefaultMaxSplits`.
// - lowercase_input: If this is set to true, the input string is lower-cased
//   before any processing.

// Output(s):
// - output: A tensor of size [number of ngrams, number of tokens + 2], where
//   2 tokens are reserved for the padding. If `max_splits` is set, this
//   length is <= max_splits, otherwise it is <= `kDefaultMaxSplits`.

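// For example (illustrative; values taken from ngram_hash_test.cc), with
// input "hi" and ngram_lengths = [1, 2], the tokenized input is "^hi$"
// (with "^" and "$" as padding) and the output holds one row of vocab ids
// per ngram length:
//   row 0 (unigrams): ids for "^", "h", "i", "$"
//   row 1 (bigrams):  ids for "^h", "hi", "i$", "$"
// The concrete id values depend on `seed` and `vocab_sizes`.
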
// Helper class used for pre-processing the input.
|
||||
class NGramHashParams {
|
||||
public:
|
||||
NGramHashParams(const uint64_t seed, const std::vector<int>& ngram_lengths,
|
||||
const std::vector<int>& vocab_sizes, int max_splits,
|
||||
bool lower_case_input)
|
||||
: seed_(seed),
|
||||
ngram_lengths_(ngram_lengths),
|
||||
vocab_sizes_(vocab_sizes),
|
||||
max_splits_(max_splits),
|
||||
lower_case_input_(lower_case_input) {}
|
||||
|
||||
TfLiteStatus PreprocessInput(const TfLiteTensor* input_t,
|
||||
TfLiteContext* context) {
|
||||
if (input_t->bytes == 0) {
|
||||
context->ReportError(context, "Empty input not supported.");
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
// Do sanity checks on the input.
|
||||
if (ngram_lengths_.empty()) {
|
||||
context->ReportError(context, "`ngram_lengths` must be non-empty.");
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
if (vocab_sizes_.empty()) {
|
||||
context->ReportError(context, "`vocab_sizes` must be non-empty.");
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
if (ngram_lengths_.size() != vocab_sizes_.size()) {
|
||||
context->ReportError(
|
||||
context,
|
||||
"Sizes of `ngram_lengths` and `vocab_sizes` must be the same.");
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
if (max_splits_ <= 0) {
|
||||
context->ReportError(context, "`max_splits` must be > 0.");
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
// Obtain and tokenize the input.
|
||||
StringRef inputref = GetString(input_t, /*string_index=*/0);
|
||||
if (lower_case_input_) {
|
||||
std::string lower_cased_str;
|
||||
LowercaseUnicodeStr(inputref.str, inputref.len, &lower_cased_str);
|
||||
|
||||
tokenized_output_ =
|
||||
Tokenize(lower_cased_str.c_str(), inputref.len, max_splits_,
|
||||
/*exclude_nonalphaspace_tokens=*/true);
|
||||
} else {
|
||||
tokenized_output_ = Tokenize(inputref.str, inputref.len, max_splits_,
|
||||
/*exclude_nonalphaspace_tokens=*/true);
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
uint64_t GetSeed() const { return seed_; }
|
||||
|
||||
int GetNumTokens() const { return tokenized_output_.tokens.size(); }
|
||||
|
||||
int GetNumNGrams() const { return ngram_lengths_.size(); }
|
||||
|
||||
std::vector<int> GetNGramLengths() const { return ngram_lengths_; }
|
||||
|
||||
std::vector<int> GetVocabSizes() const { return vocab_sizes_; }
|
||||
|
||||
const TokenizedOutput& GetTokenizedOutput() const {
|
||||
return tokenized_output_;
|
||||
}
|
||||
|
||||
TokenizedOutput tokenized_output_;
|
||||
|
||||
private:
|
||||
const uint64_t seed_;
|
||||
std::vector<int> ngram_lengths_;
|
||||
std::vector<int> vocab_sizes_;
|
||||
const int max_splits_;
|
||||
const bool lower_case_input_;
|
||||
};
|
||||
|
||||
// Convert the TypedVector into a regular std::vector.
|
||||
std::vector<int> GetIntVector(TypedVector typed_vec) {
|
||||
std::vector<int> vec(typed_vec.size());
|
||||
for (int j = 0; j < typed_vec.size(); j++) {
|
||||
vec[j] = typed_vec[j].AsInt32();
|
||||
}
|
||||
return vec;
|
||||
}
|
||||
|
||||
void GetNGramHashIndices(NGramHashParams* params, int32_t* data) {
|
||||
const int max_unicode_length = params->GetNumTokens();
|
||||
const auto ngram_lengths = params->GetNGramLengths();
|
||||
const auto vocab_sizes = params->GetVocabSizes();
|
||||
const auto& tokenized_output = params->GetTokenizedOutput();
|
||||
const auto seed = params->GetSeed();
|
||||
|
||||
// Compute for each ngram.
|
||||
for (int ngram = 0; ngram < ngram_lengths.size(); ngram++) {
|
||||
const int vocab_size = vocab_sizes[ngram];
|
||||
const int ngram_length = ngram_lengths[ngram];
|
||||
|
||||
// Compute for each token within the input.
|
||||
for (int start = 0; start < tokenized_output.tokens.size(); start++) {
|
||||
// Compute the number of bytes for the ngram starting at the given
|
||||
// token.
|
||||
int num_bytes = 0;
|
||||
for (int i = start;
|
||||
i < tokenized_output.tokens.size() && i < (start + ngram_length);
|
||||
i++) {
|
||||
num_bytes += tokenized_output.tokens[i].second;
|
||||
}
|
||||
|
||||
// Compute the hash for the ngram starting at the token.
|
||||
const auto str_hash = MurmurHash64WithSeed(
|
||||
tokenized_output.str.c_str() + tokenized_output.tokens[start].first,
|
||||
num_bytes, seed);
|
||||
|
||||
// Map the hash to an index in the vocab.
|
||||
data[ngram * max_unicode_length + start] = (str_hash % vocab_size) + 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
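// Initializes the op from its custom options: a serialized flexbuffers map
// with keys "seed" (uint64), "ngram_lengths" and "vocab_sizes" (typed int
// vectors), plus the optional "max_splits" (defaults to `kDefaultMaxSplits`)
// and "lowercase_input" (defaults to true). The `NGramHashModel` helper in
// ngram_hash_test.cc shows how such a buffer is assembled with
// `flexbuffers::Builder`.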
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
  const Map& m = GetRoot(buffer_t, length).AsMap();

  const uint64_t seed = m["seed"].AsUInt64();
  const std::vector<int> ngram_lengths =
      GetIntVector(m["ngram_lengths"].AsTypedVector());
  const std::vector<int> vocab_sizes =
      GetIntVector(m["vocab_sizes"].AsTypedVector());
  const int max_splits =
      m["max_splits"].IsNull() ? kDefaultMaxSplits : m["max_splits"].AsInt32();
  const bool lowercase_input =
      m["lowercase_input"].IsNull() ? true : m["lowercase_input"].AsBool();

  return new NGramHashParams(seed, ngram_lengths, vocab_sizes, max_splits,
                             lowercase_input);
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<NGramHashParams*>(buffer);
}

TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) {
  TfLiteTensor* output = GetOutput(context, node, kOutputLabel);
  TF_LITE_ENSURE(context, output != nullptr);
  SetTensorToDynamic(output);
  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  NGramHashParams* params = reinterpret_cast<NGramHashParams*>(node->user_data);
  TF_LITE_ENSURE_OK(
      context,
      params->PreprocessInput(GetInput(context, node, kInputMessage), context));

  TfLiteTensor* output = GetOutput(context, node, kOutputLabel);
  TF_LITE_ENSURE(context, output != nullptr);
  if (IsDynamicTensor(output)) {
    TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
    output_size->data[0] = 1;
    output_size->data[1] = params->GetNumNGrams();
    output_size->data[2] = params->GetNumTokens();
    TF_LITE_ENSURE_OK(context,
                      context->ResizeTensor(context, output, output_size));
  } else {
    context->ReportError(context, "Output must be dynamic.");
    return kTfLiteError;
  }

  if (output->type == kTfLiteInt32) {
    GetNGramHashIndices(params, output->data.i32);
  } else {
    context->ReportError(context, "Output type must be Int32.");
    return kTfLiteError;
  }

  return kTfLiteOk;
}

}  // namespace ngram_op

TfLiteRegistration* Register_NGRAM_HASH() {
  static TfLiteRegistration r = {ngram_op::Init, ngram_op::Free,
                                 ngram_op::Resize, ngram_op::Eval};
  return &r;
}

}  // namespace tflite::ops::custom

mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h
@@ -0,0 +1,27 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_
#define MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_

#include "tensorflow/lite/kernels/register.h"

namespace tflite::ops::custom {

TfLiteRegistration* Register_NGRAM_HASH();

}  // namespace tflite::ops::custom

#endif  // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_

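To use the op from client code, it is registered with an op resolver before the interpreter is built. A minimal sketch, assuming a hypothetical model file "model.tflite" whose graph references the op under the name "NGramHash" (the same name the test below registers it with):

#include <memory>

#include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

std::unique_ptr<tflite::Interpreter> BuildInterpreter() {
  // Start from the builtin ops and add the custom op under the name that
  // the model's flatbuffer uses to refer to it.
  tflite::ops::builtin::BuiltinOpResolver resolver;
  resolver.AddCustom("NGramHash", tflite::ops::custom::Register_NGRAM_HASH());

  auto model = tflite::FlatBufferModel::BuildFromFile("model.tflite");
  if (model == nullptr) return nullptr;

  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(*model, resolver)(&interpreter);
  return interpreter;
}
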
mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash_test.cc
@@ -0,0 +1,313 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h"

#include <cstdint>
#include <optional>
#include <string>
#include <vector>

#include "absl/types/optional.h"
#include "flatbuffers/flexbuffers.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/string_util.h"

namespace tflite::ops::custom {
namespace {

using ::flexbuffers::Builder;
using ::mediapipe::tasks::text::language_detector::custom_ops::hash::
    MurmurHash64WithSeed;
using ::testing::ElementsAreArray;
using ::testing::Message;

// Helper class for testing the op.
class NGramHashModel : public SingleOpModel {
 public:
  explicit NGramHashModel(const uint64_t seed,
                          const std::vector<int>& ngram_lengths,
                          const std::vector<int>& vocab_sizes,
                          const absl::optional<int> max_splits = std::nullopt) {
    // Setup the model inputs.
    Builder fbb;
    size_t start = fbb.StartMap();
    fbb.UInt("seed", seed);
    {
      size_t start = fbb.StartVector("ngram_lengths");
      for (const int& ngram_len : ngram_lengths) {
        fbb.Int(ngram_len);
      }
      fbb.EndVector(start, /*typed=*/true, /*fixed=*/false);
    }
    {
      size_t start = fbb.StartVector("vocab_sizes");
      for (const int& vocab_size : vocab_sizes) {
        fbb.Int(vocab_size);
      }
      fbb.EndVector(start, /*typed=*/true, /*fixed=*/false);
    }
    if (max_splits) {
      fbb.Int("max_splits", *max_splits);
    }
    fbb.EndMap(start);
    fbb.Finish();
    output_ = AddOutput({TensorType_INT32, {}});
    SetCustomOp("NGramHash", fbb.GetBuffer(), Register_NGRAM_HASH);
    BuildInterpreter({GetShape(input_)});
  }

  void SetupInputTensor(const std::string& input) {
    PopulateStringTensor(input_, {input});
    CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
        << "Cannot allocate tensors";
  }

  void Invoke(const std::string& input) {
    SetupInputTensor(input);
    CHECK_EQ(SingleOpModel::Invoke(), kTfLiteOk);
  }

  TfLiteStatus InvokeUnchecked(const std::string& input) {
    SetupInputTensor(input);
    return SingleOpModel::Invoke();
  }

  template <typename T>
  std::vector<T> GetOutput() {
    return ExtractVector<T>(output_);
  }

  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }

 private:
  int input_ = AddInput(TensorType_STRING);
  int output_;
};

TEST(NGramHashTest, ReturnsExpectedValueWhenInputIsSane) {
  // Checks that the op returns the expected value when the input is sane.
  // Also checks that when `max_splits` is not specified, the entire string is
  // tokenized.
  const uint64_t kSeed = 123;
  const std::vector<int> vocab_sizes({100, 200});
  std::vector<int> ngram_lengths({1, 2});
  const std::vector<std::string> testcase_inputs({
      "hi",
      "wow",
      "!",
      "HI",
  });

  // A hash function that maps the given string to an index in the embedding
  // table denoted by `vocab_idx`.
  auto hash = [vocab_sizes](std::string str, const int vocab_idx) {
    const auto hash_value =
        MurmurHash64WithSeed(str.c_str(), str.size(), kSeed);
    return static_cast<int>((hash_value % vocab_sizes[vocab_idx]) + 1);
  };
  const std::vector<std::vector<int>> expected_testcase_outputs(
      {{
           // Unigram & Bigram output for "hi".
           hash("^", 0),
           hash("h", 0),
           hash("i", 0),
           hash("$", 0),
           hash("^h", 1),
           hash("hi", 1),
           hash("i$", 1),
           hash("$", 1),
       },
       {
           // Unigram & Bigram output for "wow".
           hash("^", 0),
           hash("w", 0),
           hash("o", 0),
           hash("w", 0),
           hash("$", 0),
           hash("^w", 1),
           hash("wo", 1),
           hash("ow", 1),
           hash("w$", 1),
           hash("$", 1),
       },
       {
           // Unigram & Bigram output for "!" (which will get replaced by " ").
           hash("^", 0),
           hash(" ", 0),
           hash("$", 0),
           hash("^ ", 1),
           hash(" $", 1),
           hash("$", 1),
       },
       {
           // Unigram & Bigram output for "HI" (which will get lower-cased).
           hash("^", 0),
           hash("h", 0),
           hash("i", 0),
           hash("$", 0),
           hash("^h", 1),
           hash("hi", 1),
           hash("i$", 1),
           hash("$", 1),
       }});

  NGramHashModel m(kSeed, ngram_lengths, vocab_sizes);
  for (int test_idx = 0; test_idx < testcase_inputs.size(); test_idx++) {
    const std::string& testcase_input = testcase_inputs[test_idx];
    m.Invoke(testcase_input);
    SCOPED_TRACE(Message() << "Where the test case input is: "
                           << testcase_input);
    EXPECT_THAT(m.GetOutput<int>(),
                ElementsAreArray(expected_testcase_outputs[test_idx]));
    EXPECT_THAT(m.GetOutputShape(),
                ElementsAreArray(
                    {/*batch_size=*/1, static_cast<int>(ngram_lengths.size()),
                     static_cast<int>(testcase_input.size()) + /*padding*/ 2}));
  }
}

TEST(NGramHashTest, ReturnsExpectedValueWhenMaxSplitsIsSpecified) {
  // Checks that the op returns the expected value when the input is correct
  // and `max_splits` is specified.
  const uint64_t kSeed = 123;
  const std::vector<int> vocab_sizes({100, 200});
  std::vector<int> ngram_lengths({1, 2});

  const std::string testcase_input = "wow";
  const std::vector<int> max_splits({2, 3, 4, 5, 6});

  // A hash function that maps the given string to an index in the embedding
  // table denoted by `vocab_idx`.
  auto hash = [vocab_sizes](std::string str, const int vocab_idx) {
    const auto hash_value =
        MurmurHash64WithSeed(str.c_str(), str.size(), kSeed);
    return static_cast<int>((hash_value % vocab_sizes[vocab_idx]) + 1);
  };

  const std::vector<std::vector<int>> expected_testcase_outputs(
      {{
           // Unigram & Bigram output for "wow", when `max_splits` == 2.
           // We cannot include any of the actual tokens, since `max_splits`
           // only allows enough space for the delimiters.
           hash("^", 0),
           hash("$", 0),
           hash("^$", 1),
           hash("$", 1),
       },
       {
           // Unigram & Bigram output for "wow", when `max_splits` == 3.
           // We can start to include some tokens from the input string.
           hash("^", 0),
           hash("w", 0),
           hash("$", 0),
           hash("^w", 1),
           hash("w$", 1),
           hash("$", 1),
       },
       {
           // Unigram & Bigram output for "wow", when `max_splits` == 4.
           hash("^", 0),
           hash("w", 0),
           hash("o", 0),
           hash("$", 0),
           hash("^w", 1),
           hash("wo", 1),
           hash("o$", 1),
           hash("$", 1),
       },
       {
           // Unigram & Bigram output for "wow", when `max_splits` == 5.
           // We can include the full input string.
           hash("^", 0),
           hash("w", 0),
           hash("o", 0),
           hash("w", 0),
           hash("$", 0),
           hash("^w", 1),
           hash("wo", 1),
           hash("ow", 1),
           hash("w$", 1),
           hash("$", 1),
       },
       {
           // Unigram & Bigram output for "wow", when `max_splits` == 6.
           // `max_splits` is more than the full input string.
           hash("^", 0),
           hash("w", 0),
           hash("o", 0),
           hash("w", 0),
           hash("$", 0),
           hash("^w", 1),
           hash("wo", 1),
           hash("ow", 1),
           hash("w$", 1),
           hash("$", 1),
       }});

  for (int test_idx = 0; test_idx < max_splits.size(); test_idx++) {
    const int testcase_max_splits = max_splits[test_idx];
    NGramHashModel m(kSeed, ngram_lengths, vocab_sizes, testcase_max_splits);
    m.Invoke(testcase_input);
    SCOPED_TRACE(Message() << "Where `max_splits` is: " << testcase_max_splits);
    EXPECT_THAT(m.GetOutput<int>(),
                ElementsAreArray(expected_testcase_outputs[test_idx]));
    EXPECT_THAT(
        m.GetOutputShape(),
        ElementsAreArray(
            {/*batch_size=*/1, static_cast<int>(ngram_lengths.size()),
             std::min(
                 // Longest possible tokenization when using the entire
                 // input.
                 static_cast<int>(testcase_input.size()) + /*padding*/ 2,
                 // Longest possible string when the `max_splits` value
                 // is < testcase_input.size() + 2 for padding.
                 testcase_max_splits)}));
  }
}

TEST(NGramHashTest, InvalidMaxSplitsValue) {
  // Check that the op errors out when given an invalid `max_splits` value.
  const std::vector<int> invalid_max_splits({0, -1, -5, -100});
  for (const int max_splits : invalid_max_splits) {
    NGramHashModel m(/*seed=*/123, /*ngram_lengths=*/{1, 2},
                     /*vocab_sizes=*/{100, 200}, /*max_splits=*/max_splits);
    EXPECT_EQ(m.InvokeUnchecked("hi"), kTfLiteError);
  }
}

TEST(NGramHashTest, MismatchNgramLengthsAndVocabSizes) {
  // Check that the op errors out when `ngram_lengths` and `vocab_sizes`
  // mismatch in size.
  {
    NGramHashModel m(/*seed=*/123, /*ngram_lengths=*/{100, 200, 300},
                     /*vocab_sizes=*/{1, 2});
    EXPECT_EQ(m.InvokeUnchecked("hi"), kTfLiteError);
  }
  {
    NGramHashModel m(/*seed=*/123, /*ngram_lengths=*/{100, 200},
                     /*vocab_sizes=*/{1, 2, 3});
    EXPECT_EQ(m.InvokeUnchecked("hi"), kTfLiteError);
  }
}

}  // namespace
}  // namespace tflite::ops::custom