Internal MediaPipe Tasks change.
PiperOrigin-RevId: 516881879
This commit is contained in:
parent
ce3cd94f45
commit
18d88c531a
|
@ -42,3 +42,36 @@ cc_test(
|
||||||
"@org_tensorflow//tensorflow/lite/kernels:test_util",
|
"@org_tensorflow//tensorflow/lite/kernels:test_util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "ngram_hash",
|
||||||
|
srcs = ["ngram_hash.cc"],
|
||||||
|
hdrs = ["ngram_hash.h"],
|
||||||
|
copts = tflite_copts(),
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/tasks/cc/text/language_detector/custom_ops/utils:ngram_hash_ops_utils",
|
||||||
|
"//mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash:murmur",
|
||||||
|
"@flatbuffers",
|
||||||
|
"@org_tensorflow//tensorflow/lite:string_util",
|
||||||
|
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||||
|
"@org_tensorflow//tensorflow/lite/kernels:kernel_util",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "ngram_hash_test",
|
||||||
|
srcs = ["ngram_hash_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":ngram_hash",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"//mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash:murmur",
|
||||||
|
"@com_google_absl//absl/types:optional",
|
||||||
|
"@flatbuffers",
|
||||||
|
"@org_tensorflow//tensorflow/lite:framework",
|
||||||
|
"@org_tensorflow//tensorflow/lite:string_util",
|
||||||
|
"@org_tensorflow//tensorflow/lite/c:common",
|
||||||
|
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||||
|
"@org_tensorflow//tensorflow/lite/kernels:test_util",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
|
@ -0,0 +1,264 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h"
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "flatbuffers/flexbuffers.h"
|
||||||
|
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h"
|
||||||
|
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h"
|
||||||
|
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||||
|
#include "tensorflow/lite/string_util.h"
|
||||||
|
|
||||||
|
namespace tflite::ops::custom {
|
||||||
|
|
||||||
|
namespace ngram_op {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::flexbuffers::GetRoot;
|
||||||
|
using ::flexbuffers::Map;
|
||||||
|
using ::flexbuffers::TypedVector;
|
||||||
|
using ::mediapipe::tasks::text::language_detector::custom_ops::
|
||||||
|
LowercaseUnicodeStr;
|
||||||
|
using ::mediapipe::tasks::text::language_detector::custom_ops::Tokenize;
|
||||||
|
using ::mediapipe::tasks::text::language_detector::custom_ops::TokenizedOutput;
|
||||||
|
using ::mediapipe::tasks::text::language_detector::custom_ops::hash::
|
||||||
|
MurmurHash64WithSeed;
|
||||||
|
using ::tflite::GetString;
|
||||||
|
using ::tflite::StringRef;
|
||||||
|
|
||||||
|
constexpr int kInputMessage = 0;
|
||||||
|
constexpr int kOutputLabel = 0;
|
||||||
|
constexpr int kDefaultMaxSplits = 128;
|
||||||
|
|
||||||
|
// This op takes in a string, finds the character ngrams for it and then
|
||||||
|
// maps each of these ngrams to an index using the specified vocabulary sizes.
|
||||||
|
|
||||||
|
// Input(s):
|
||||||
|
// - input: Input string.
|
||||||
|
// - seeds: Seed for the random number generator.
|
||||||
|
// - ngram_lengths: Lengths of each of the ngrams. For example [1, 2, 3] would
|
||||||
|
// be interpreted as generating unigrams, bigrams, and trigrams.
|
||||||
|
// - vocab_sizes: Size of the vocabulary for each of the ngram features
|
||||||
|
// respectively. The op would generate vocab ids to be less than or equal to
|
||||||
|
// the vocab size. The index 0 implies an invalid ngram.
|
||||||
|
// - max_splits: Maximum number of tokens in the output. If this is unset, the
|
||||||
|
// limit is `kDefaultMaxSplits`.
|
||||||
|
// - lower_case_input: If this is set to true, the input string would be
|
||||||
|
// lower-cased before any processing.
|
||||||
|
|
||||||
|
// Output(s):
|
||||||
|
// - output: A tensor of size [number of ngrams, number of tokens + 2],
|
||||||
|
// where 2 tokens are reserved for the padding. If `max_splits` is set, this
|
||||||
|
// length is <= max_splits, otherwise it is <= `kDefaultMaxSplits`.
|
||||||
|
|
||||||
|
// Helper class used for pre-processing the input.
|
||||||
|
class NGramHashParams {
|
||||||
|
public:
|
||||||
|
NGramHashParams(const uint64_t seed, const std::vector<int>& ngram_lengths,
|
||||||
|
const std::vector<int>& vocab_sizes, int max_splits,
|
||||||
|
bool lower_case_input)
|
||||||
|
: seed_(seed),
|
||||||
|
ngram_lengths_(ngram_lengths),
|
||||||
|
vocab_sizes_(vocab_sizes),
|
||||||
|
max_splits_(max_splits),
|
||||||
|
lower_case_input_(lower_case_input) {}
|
||||||
|
|
||||||
|
TfLiteStatus PreprocessInput(const TfLiteTensor* input_t,
|
||||||
|
TfLiteContext* context) {
|
||||||
|
if (input_t->bytes == 0) {
|
||||||
|
context->ReportError(context, "Empty input not supported.");
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do sanity checks on the input.
|
||||||
|
if (ngram_lengths_.empty()) {
|
||||||
|
context->ReportError(context, "`ngram_lengths` must be non-empty.");
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (vocab_sizes_.empty()) {
|
||||||
|
context->ReportError(context, "`vocab_sizes` must be non-empty.");
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ngram_lengths_.size() != vocab_sizes_.size()) {
|
||||||
|
context->ReportError(
|
||||||
|
context,
|
||||||
|
"Sizes of `ngram_lengths` and `vocab_sizes` must be the same.");
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (max_splits_ <= 0) {
|
||||||
|
context->ReportError(context, "`max_splits` must be > 0.");
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Obtain and tokenize the input.
|
||||||
|
StringRef inputref = GetString(input_t, /*string_index=*/0);
|
||||||
|
if (lower_case_input_) {
|
||||||
|
std::string lower_cased_str;
|
||||||
|
LowercaseUnicodeStr(inputref.str, inputref.len, &lower_cased_str);
|
||||||
|
|
||||||
|
tokenized_output_ =
|
||||||
|
Tokenize(lower_cased_str.c_str(), inputref.len, max_splits_,
|
||||||
|
/*exclude_nonalphaspace_tokens=*/true);
|
||||||
|
} else {
|
||||||
|
tokenized_output_ = Tokenize(inputref.str, inputref.len, max_splits_,
|
||||||
|
/*exclude_nonalphaspace_tokens=*/true);
|
||||||
|
}
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
uint64_t GetSeed() const { return seed_; }
|
||||||
|
|
||||||
|
int GetNumTokens() const { return tokenized_output_.tokens.size(); }
|
||||||
|
|
||||||
|
int GetNumNGrams() const { return ngram_lengths_.size(); }
|
||||||
|
|
||||||
|
std::vector<int> GetNGramLengths() const { return ngram_lengths_; }
|
||||||
|
|
||||||
|
std::vector<int> GetVocabSizes() const { return vocab_sizes_; }
|
||||||
|
|
||||||
|
const TokenizedOutput& GetTokenizedOutput() const {
|
||||||
|
return tokenized_output_;
|
||||||
|
}
|
||||||
|
|
||||||
|
TokenizedOutput tokenized_output_;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const uint64_t seed_;
|
||||||
|
std::vector<int> ngram_lengths_;
|
||||||
|
std::vector<int> vocab_sizes_;
|
||||||
|
const int max_splits_;
|
||||||
|
const bool lower_case_input_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Convert the TypedVector into a regular std::vector.
|
||||||
|
std::vector<int> GetIntVector(TypedVector typed_vec) {
|
||||||
|
std::vector<int> vec(typed_vec.size());
|
||||||
|
for (int j = 0; j < typed_vec.size(); j++) {
|
||||||
|
vec[j] = typed_vec[j].AsInt32();
|
||||||
|
}
|
||||||
|
return vec;
|
||||||
|
}
|
||||||
|
|
||||||
|
void GetNGramHashIndices(NGramHashParams* params, int32_t* data) {
|
||||||
|
const int max_unicode_length = params->GetNumTokens();
|
||||||
|
const auto ngram_lengths = params->GetNGramLengths();
|
||||||
|
const auto vocab_sizes = params->GetVocabSizes();
|
||||||
|
const auto& tokenized_output = params->GetTokenizedOutput();
|
||||||
|
const auto seed = params->GetSeed();
|
||||||
|
|
||||||
|
// Compute for each ngram.
|
||||||
|
for (int ngram = 0; ngram < ngram_lengths.size(); ngram++) {
|
||||||
|
const int vocab_size = vocab_sizes[ngram];
|
||||||
|
const int ngram_length = ngram_lengths[ngram];
|
||||||
|
|
||||||
|
// Compute for each token within the input.
|
||||||
|
for (int start = 0; start < tokenized_output.tokens.size(); start++) {
|
||||||
|
// Compute the number of bytes for the ngram starting at the given
|
||||||
|
// token.
|
||||||
|
int num_bytes = 0;
|
||||||
|
for (int i = start;
|
||||||
|
i < tokenized_output.tokens.size() && i < (start + ngram_length);
|
||||||
|
i++) {
|
||||||
|
num_bytes += tokenized_output.tokens[i].second;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the hash for the ngram starting at the token.
|
||||||
|
const auto str_hash = MurmurHash64WithSeed(
|
||||||
|
tokenized_output.str.c_str() + tokenized_output.tokens[start].first,
|
||||||
|
num_bytes, seed);
|
||||||
|
|
||||||
|
// Map the hash to an index in the vocab.
|
||||||
|
data[ngram * max_unicode_length + start] = (str_hash % vocab_size) + 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||||
|
const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
|
||||||
|
const Map& m = GetRoot(buffer_t, length).AsMap();
|
||||||
|
|
||||||
|
const uint64_t seed = m["seed"].AsUInt64();
|
||||||
|
const std::vector<int> ngram_lengths =
|
||||||
|
GetIntVector(m["ngram_lengths"].AsTypedVector());
|
||||||
|
const std::vector<int> vocab_sizes =
|
||||||
|
GetIntVector(m["vocab_sizes"].AsTypedVector());
|
||||||
|
const int max_splits =
|
||||||
|
m["max_splits"].IsNull() ? kDefaultMaxSplits : m["max_splits"].AsInt32();
|
||||||
|
const bool lowercase_input =
|
||||||
|
m["lowercase_input"].IsNull() ? true : m["lowercase_input"].AsBool();
|
||||||
|
|
||||||
|
return new NGramHashParams(seed, ngram_lengths, vocab_sizes, max_splits,
|
||||||
|
lowercase_input);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Free(TfLiteContext* context, void* buffer) {
|
||||||
|
delete reinterpret_cast<NGramHashParams*>(buffer);
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) {
|
||||||
|
TfLiteTensor* output = GetOutput(context, node, kOutputLabel);
|
||||||
|
TF_LITE_ENSURE(context, output != nullptr);
|
||||||
|
SetTensorToDynamic(output);
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||||
|
NGramHashParams* params = reinterpret_cast<NGramHashParams*>(node->user_data);
|
||||||
|
TF_LITE_ENSURE_OK(
|
||||||
|
context,
|
||||||
|
params->PreprocessInput(GetInput(context, node, kInputMessage), context));
|
||||||
|
|
||||||
|
TfLiteTensor* output = GetOutput(context, node, kOutputLabel);
|
||||||
|
TF_LITE_ENSURE(context, output != nullptr);
|
||||||
|
if (IsDynamicTensor(output)) {
|
||||||
|
TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
|
||||||
|
output_size->data[0] = 1;
|
||||||
|
output_size->data[1] = params->GetNumNGrams();
|
||||||
|
output_size->data[2] = params->GetNumTokens();
|
||||||
|
TF_LITE_ENSURE_OK(context,
|
||||||
|
context->ResizeTensor(context, output, output_size));
|
||||||
|
} else {
|
||||||
|
context->ReportError(context, "Output must by dynamic.");
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (output->type == kTfLiteInt32) {
|
||||||
|
GetNGramHashIndices(params, output->data.i32);
|
||||||
|
} else {
|
||||||
|
context->ReportError(context, "Output type must be Int32.");
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace ngram_op
|
||||||
|
|
||||||
|
TfLiteRegistration* Register_NGRAM_HASH() {
|
||||||
|
static TfLiteRegistration r = {ngram_op::Init, ngram_op::Free,
|
||||||
|
ngram_op::Resize, ngram_op::Eval};
|
||||||
|
return &r;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tflite::ops::custom
|
|
@ -0,0 +1,27 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_
|
||||||
|
#define MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_
|
||||||
|
|
||||||
|
#include "tensorflow/lite/kernels/register.h"
|
||||||
|
|
||||||
|
namespace tflite::ops::custom {
|
||||||
|
|
||||||
|
TfLiteRegistration* Register_NGRAM_HASH();
|
||||||
|
|
||||||
|
} // namespace tflite::ops::custom
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_
|
|
@ -0,0 +1,313 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h"
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <optional>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/types/optional.h"
|
||||||
|
#include "flatbuffers/flexbuffers.h"
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h"
|
||||||
|
#include "tensorflow/lite/c/common.h"
|
||||||
|
#include "tensorflow/lite/interpreter.h"
|
||||||
|
#include "tensorflow/lite/kernels/register.h"
|
||||||
|
#include "tensorflow/lite/kernels/test_util.h"
|
||||||
|
#include "tensorflow/lite/model.h"
|
||||||
|
#include "tensorflow/lite/string_util.h"
|
||||||
|
|
||||||
|
namespace tflite::ops::custom {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::flexbuffers::Builder;
|
||||||
|
using ::mediapipe::tasks::text::language_detector::custom_ops::hash::
|
||||||
|
MurmurHash64WithSeed;
|
||||||
|
using ::testing::ElementsAreArray;
|
||||||
|
using ::testing::Message;
|
||||||
|
|
||||||
|
// Helper class for testing the op.
|
||||||
|
class NGramHashModel : public SingleOpModel {
|
||||||
|
public:
|
||||||
|
explicit NGramHashModel(const uint64_t seed,
|
||||||
|
const std::vector<int>& ngram_lengths,
|
||||||
|
const std::vector<int>& vocab_sizes,
|
||||||
|
const absl::optional<int> max_splits = std::nullopt) {
|
||||||
|
// Setup the model inputs.
|
||||||
|
Builder fbb;
|
||||||
|
size_t start = fbb.StartMap();
|
||||||
|
fbb.UInt("seed", seed);
|
||||||
|
{
|
||||||
|
size_t start = fbb.StartVector("ngram_lengths");
|
||||||
|
for (const int& ngram_len : ngram_lengths) {
|
||||||
|
fbb.Int(ngram_len);
|
||||||
|
}
|
||||||
|
fbb.EndVector(start, /*typed=*/true, /*fixed=*/false);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
size_t start = fbb.StartVector("vocab_sizes");
|
||||||
|
for (const int& vocab_size : vocab_sizes) {
|
||||||
|
fbb.Int(vocab_size);
|
||||||
|
}
|
||||||
|
fbb.EndVector(start, /*typed=*/true, /*fixed=*/false);
|
||||||
|
}
|
||||||
|
if (max_splits) {
|
||||||
|
fbb.Int("max_splits", *max_splits);
|
||||||
|
}
|
||||||
|
fbb.EndMap(start);
|
||||||
|
fbb.Finish();
|
||||||
|
output_ = AddOutput({TensorType_INT32, {}});
|
||||||
|
SetCustomOp("NGramHash", fbb.GetBuffer(), Register_NGRAM_HASH);
|
||||||
|
BuildInterpreter({GetShape(input_)});
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetupInputTensor(const std::string& input) {
|
||||||
|
PopulateStringTensor(input_, {input});
|
||||||
|
CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
|
||||||
|
<< "Cannot allocate tensors";
|
||||||
|
}
|
||||||
|
|
||||||
|
void Invoke(const std::string& input) {
|
||||||
|
SetupInputTensor(input);
|
||||||
|
CHECK_EQ(SingleOpModel::Invoke(), kTfLiteOk);
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteStatus InvokeUnchecked(const std::string& input) {
|
||||||
|
SetupInputTensor(input);
|
||||||
|
return SingleOpModel::Invoke();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::vector<T> GetOutput() {
|
||||||
|
return ExtractVector<T>(output_);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
int input_ = AddInput(TensorType_STRING);
|
||||||
|
int output_;
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST(NGramHashTest, ReturnsExpectedValueWhenInputIsSane) {
|
||||||
|
// Checks that the op returns the expected value when the input is sane.
|
||||||
|
// Also checks that when `max_splits` is not specified, the entire string is
|
||||||
|
// tokenized.
|
||||||
|
const uint64_t kSeed = 123;
|
||||||
|
const std::vector<int> vocab_sizes({100, 200});
|
||||||
|
std::vector<int> ngram_lengths({1, 2});
|
||||||
|
const std::vector<std::string> testcase_inputs({
|
||||||
|
"hi",
|
||||||
|
"wow",
|
||||||
|
"!",
|
||||||
|
"HI",
|
||||||
|
});
|
||||||
|
|
||||||
|
// A hash function that maps the given string to an index in the embedding
|
||||||
|
// table denoted by `vocab_idx`.
|
||||||
|
auto hash = [vocab_sizes](std::string str, const int vocab_idx) {
|
||||||
|
const auto hash_value =
|
||||||
|
MurmurHash64WithSeed(str.c_str(), str.size(), kSeed);
|
||||||
|
return static_cast<int>((hash_value % vocab_sizes[vocab_idx]) + 1);
|
||||||
|
};
|
||||||
|
const std::vector<std::vector<int>> expected_testcase_outputs(
|
||||||
|
{{
|
||||||
|
// Unigram & Bigram output for "hi".
|
||||||
|
hash("^", 0),
|
||||||
|
hash("h", 0),
|
||||||
|
hash("i", 0),
|
||||||
|
hash("$", 0),
|
||||||
|
hash("^h", 1),
|
||||||
|
hash("hi", 1),
|
||||||
|
hash("i$", 1),
|
||||||
|
hash("$", 1),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Unigram & Bigram output for "wow".
|
||||||
|
hash("^", 0),
|
||||||
|
hash("w", 0),
|
||||||
|
hash("o", 0),
|
||||||
|
hash("w", 0),
|
||||||
|
hash("$", 0),
|
||||||
|
hash("^w", 1),
|
||||||
|
hash("wo", 1),
|
||||||
|
hash("ow", 1),
|
||||||
|
hash("w$", 1),
|
||||||
|
hash("$", 1),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Unigram & Bigram output for "!" (which will get replaced by " ").
|
||||||
|
hash("^", 0),
|
||||||
|
hash(" ", 0),
|
||||||
|
hash("$", 0),
|
||||||
|
hash("^ ", 1),
|
||||||
|
hash(" $", 1),
|
||||||
|
hash("$", 1),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Unigram & Bigram output for "HI" (which will get lower-cased).
|
||||||
|
hash("^", 0),
|
||||||
|
hash("h", 0),
|
||||||
|
hash("i", 0),
|
||||||
|
hash("$", 0),
|
||||||
|
hash("^h", 1),
|
||||||
|
hash("hi", 1),
|
||||||
|
hash("i$", 1),
|
||||||
|
hash("$", 1),
|
||||||
|
}});
|
||||||
|
|
||||||
|
NGramHashModel m(kSeed, ngram_lengths, vocab_sizes);
|
||||||
|
for (int test_idx = 0; test_idx < testcase_inputs.size(); test_idx++) {
|
||||||
|
const string& testcase_input = testcase_inputs[test_idx];
|
||||||
|
m.Invoke(testcase_input);
|
||||||
|
SCOPED_TRACE(Message() << "Where the testcases' input is: "
|
||||||
|
<< testcase_input);
|
||||||
|
EXPECT_THAT(m.GetOutput<int>(),
|
||||||
|
ElementsAreArray(expected_testcase_outputs[test_idx]));
|
||||||
|
EXPECT_THAT(m.GetOutputShape(),
|
||||||
|
ElementsAreArray(
|
||||||
|
{/*batch_size=*/1, static_cast<int>(ngram_lengths.size()),
|
||||||
|
static_cast<int>(testcase_input.size()) + /*padding*/ 2}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(NGramHashTest, ReturnsExpectedValueWhenMaxSplitsIsSpecified) {
|
||||||
|
// Checks that the op returns the expected value when the input is correct
|
||||||
|
// when `max_splits` is specified.
|
||||||
|
const uint64_t kSeed = 123;
|
||||||
|
const std::vector<int> vocab_sizes({100, 200});
|
||||||
|
std::vector<int> ngram_lengths({1, 2});
|
||||||
|
|
||||||
|
const std::string testcase_input = "wow";
|
||||||
|
const std::vector<int> max_splits({2, 3, 4, 5, 6});
|
||||||
|
|
||||||
|
// A hash function that maps the given string to an index in the embedding
|
||||||
|
// table denoted by `vocab_idx`.
|
||||||
|
auto hash = [vocab_sizes](std::string str, const int vocab_idx) {
|
||||||
|
const auto hash_value =
|
||||||
|
MurmurHash64WithSeed(str.c_str(), str.size(), kSeed);
|
||||||
|
return static_cast<int>((hash_value % vocab_sizes[vocab_idx]) + 1);
|
||||||
|
};
|
||||||
|
|
||||||
|
const std::vector<std::vector<int>> expected_testcase_outputs(
|
||||||
|
{{
|
||||||
|
// Unigram & Bigram output for "wow", when `max_splits` == 2.
|
||||||
|
// We cannot include any of the actual tokens, since `max_splits`
|
||||||
|
// only allows enough space for the delimiters.
|
||||||
|
hash("^", 0),
|
||||||
|
hash("$", 0),
|
||||||
|
hash("^$", 1),
|
||||||
|
hash("$", 1),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Unigram & Bigram output for "wow", when `max_splits` == 3.
|
||||||
|
// We can start to include some tokens from the input string.
|
||||||
|
hash("^", 0),
|
||||||
|
hash("w", 0),
|
||||||
|
hash("$", 0),
|
||||||
|
hash("^w", 1),
|
||||||
|
hash("w$", 1),
|
||||||
|
hash("$", 1),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Unigram & Bigram output for "wow", when `max_splits` == 4.
|
||||||
|
hash("^", 0),
|
||||||
|
hash("w", 0),
|
||||||
|
hash("o", 0),
|
||||||
|
hash("$", 0),
|
||||||
|
hash("^w", 1),
|
||||||
|
hash("wo", 1),
|
||||||
|
hash("o$", 1),
|
||||||
|
hash("$", 1),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Unigram & Bigram output for "wow", when `max_splits` == 5.
|
||||||
|
// We can include the full input string.
|
||||||
|
hash("^", 0),
|
||||||
|
hash("w", 0),
|
||||||
|
hash("o", 0),
|
||||||
|
hash("w", 0),
|
||||||
|
hash("$", 0),
|
||||||
|
hash("^w", 1),
|
||||||
|
hash("wo", 1),
|
||||||
|
hash("ow", 1),
|
||||||
|
hash("w$", 1),
|
||||||
|
hash("$", 1),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Unigram & Bigram output for "wow", when `max_splits` == 6.
|
||||||
|
// `max_splits` is more than the full input string.
|
||||||
|
hash("^", 0),
|
||||||
|
hash("w", 0),
|
||||||
|
hash("o", 0),
|
||||||
|
hash("w", 0),
|
||||||
|
hash("$", 0),
|
||||||
|
hash("^w", 1),
|
||||||
|
hash("wo", 1),
|
||||||
|
hash("ow", 1),
|
||||||
|
hash("w$", 1),
|
||||||
|
hash("$", 1),
|
||||||
|
}});
|
||||||
|
|
||||||
|
for (int test_idx = 0; test_idx < max_splits.size(); test_idx++) {
|
||||||
|
const int testcase_max_splits = max_splits[test_idx];
|
||||||
|
NGramHashModel m(kSeed, ngram_lengths, vocab_sizes, testcase_max_splits);
|
||||||
|
m.Invoke(testcase_input);
|
||||||
|
SCOPED_TRACE(Message() << "Where `max_splits` is: " << testcase_max_splits);
|
||||||
|
EXPECT_THAT(m.GetOutput<int>(),
|
||||||
|
ElementsAreArray(expected_testcase_outputs[test_idx]));
|
||||||
|
EXPECT_THAT(
|
||||||
|
m.GetOutputShape(),
|
||||||
|
ElementsAreArray(
|
||||||
|
{/*batch_size=*/1, static_cast<int>(ngram_lengths.size()),
|
||||||
|
std::min(
|
||||||
|
// Longest possible tokenization when using the entire
|
||||||
|
// input.
|
||||||
|
static_cast<int>(testcase_input.size()) + /*padding*/ 2,
|
||||||
|
// Longest possible string when the `max_splits` value
|
||||||
|
// is < testcase_input.size() + 2 for padding.
|
||||||
|
testcase_max_splits)}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(NGramHashTest, InvalidMaxSplitsValue) {
|
||||||
|
// Check that the op errors out when given an invalid max splits value.
|
||||||
|
const std::vector<int> invalid_max_splits({0, -1, -5, -100});
|
||||||
|
for (const int max_splits : invalid_max_splits) {
|
||||||
|
NGramHashModel m(/*seed=*/123, /*ngram_lengths=*/{100, 200},
|
||||||
|
/*vocab_sizes=*/{1, 2}, /*max_splits=*/max_splits);
|
||||||
|
EXPECT_EQ(m.InvokeUnchecked("hi"), kTfLiteError);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(NGramHashTest, MismatchNgramLengthsAndVocabSizes) {
|
||||||
|
// Check that the op errors out when ngram lengths and vocab sizes mistmatch.
|
||||||
|
{
|
||||||
|
NGramHashModel m(/*seed=*/123, /*ngram_lengths=*/{100, 200, 300},
|
||||||
|
/*vocab_sizes=*/{1, 2});
|
||||||
|
EXPECT_EQ(m.InvokeUnchecked("hi"), kTfLiteError);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
NGramHashModel m(/*seed=*/123, /*ngram_lengths=*/{100, 200},
|
||||||
|
/*vocab_sizes=*/{1, 2, 3});
|
||||||
|
EXPECT_EQ(m.InvokeUnchecked("hi"), kTfLiteError);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace tflite::ops::custom
|
Loading…
Reference in New Issue
Block a user