diff --git a/WORKSPACE b/WORKSPACE index 6e079f142..760898185 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -239,6 +239,16 @@ http_archive( repo_mapping = {"@com_google_glog" : "@com_github_glog_glog_no_gflags"}, ) +http_archive( + name = "darts_clone", + build_file = "@//third_party:darts_clone.BUILD", + sha256 = "c97f55d05c98da6fcaf7f9ecc6a6dc6bc5b18b8564465f77abff8879d446491c", + strip_prefix = "darts-clone-e40ce4627526985a7767444b6ed6893ab6ff8983", + urls = [ + "https://github.com/s-yata/darts-clone/archive/e40ce4627526985a7767444b6ed6893ab6ff8983.zip", + ], +) + http_archive( name = "org_tensorflow_text", sha256 = "f64647276f7288d1b1fe4c89581d51404d0ce4ae97f2bcc4c19bd667549adca8", diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/BUILD b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/BUILD new file mode 100644 index 000000000..072b21f53 --- /dev/null +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/BUILD @@ -0,0 +1,82 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +filegroup( + name = "config_fbs", + srcs = ["config.fbs"], +) + +flatbuffer_cc_library( + name = "config", + srcs = [ + "config.fbs", + ], +) + +flatbuffer_cc_library( + name = "encoder_config", + srcs = [ + "encoder_config.fbs", + ], + includes = [":config_fbs"], +) + +cc_library( + name = "utils", + hdrs = [ + "utils.h", + ], +) + +cc_library( + name = "double_array_trie", + hdrs = [ + "double_array_trie.h", + ], + deps = [ + ":config", + ":utils", + ], +) + +cc_library( + name = "double_array_trie_builder", + srcs = [ + "double_array_trie_builder.cc", + ], + hdrs = [ + "double_array_trie_builder.h", + ], + deps = ["@darts_clone"], +) + +cc_test( + name = "double_array_trie_test", + srcs = [ + "double_array_trie_test.cc", + ], + deps = [ + ":double_array_trie", + ":double_array_trie_builder", + ":encoder_config", + ":utils", + "//mediapipe/framework/port:gtest_main", + ], +) diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/config.fbs b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/config.fbs new file mode 100644 index 000000000..16408ffee --- /dev/null +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/config.fbs @@ -0,0 +1,25 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +namespace mediapipe.tflite_operations.sentencepiece; + +table Trie { + nodes: [uint32]; +} + + +enum EncoderVersion: byte { + SENTENCE_PIECE = 0, +} diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie.h b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie.h new file mode 100644 index 000000000..c3b568f1c --- /dev/null +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie.h @@ -0,0 +1,111 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_ + +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/config_generated.h" +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h" + +namespace mediapipe::tflite_operations::sentencepiece { + +// A trie node specifies a node in the tree, either an intermediate node or +// a leaf node. +// A leaf node contains the id as an int of the string match. This id is encoded +// in the lower 31 bits, thus the number of distinct ids is 2^31. +// An intermediate node has an associated label and an offset to its children. +// The label is encoded in the least significant byte and must match the input +// character during matching. + +// A memory mappable trie, compatible with Darts::DoubleArray. +class DoubleArrayTrie { + public: + struct Match { + Match() {} + Match(int id, int match_length) : id(id), match_length(match_length) {} + int id = -1; + int match_length = -1; + bool empty() const { return match_length == -1; } + bool operator==(const Match& m) const { + return m.id == id && m.match_length == match_length; + } + }; + + // nodes and nodes_length specify the array of the nodes of the trie. + explicit DoubleArrayTrie(const flatbuffers::Vector* nodes) + : nodes_(nodes) {} + + // Finds matches that are prefixes of a string. + template + void IteratePrefixMatches(const utils::string_view& input, + callback update_fn) const; + + // Finds the longest prefix match of a string. + Match LongestPrefixMatch(const utils::string_view& input) const { + Match match; + IteratePrefixMatches(input, [&match](const Match& m) { match = m; }); + return match; + } + + private: + // Returns whether a node as a leaf as a child. + bool has_leaf(uint32_t i) const { return ((*nodes_)[i]) & 0x100; } + + // Returns a value associated with a node. Available when a node is a leaf. + int value(uint32_t i) const { + return static_cast(((*nodes_)[i]) & 0x7fffffff); + } + + // Returns a label associated with a node. + // A leaf node will have the MSB set and thus return an invalid label. + int32_t label(uint32_t i) const { return ((*nodes_)[i]) & 0x800000ff; } + + // Returns offset to children. + int32_t offset(uint32_t i) const { + const uint32_t node = (*nodes_)[i]; + return (node >> 10) << ((node & 0x200) >> 6); + } + + const flatbuffers::Vector* nodes_; +}; + +template +void DoubleArrayTrie::IteratePrefixMatches(const utils::string_view& input, + callback update_fn) const { + if (nodes_->size() == 0) { + return; + } + uint32_t pos = offset(0); + for (int i = 0; i < input.length(); ++i) { + pos ^= static_cast(input.at(i)); + if (pos < 0 || pos >= nodes_->size() || label(pos) != input.at(i)) { + // No match, exit. + return; + } + const bool node_has_leaf = has_leaf(pos); + pos ^= offset(pos); + if (pos < 0 || pos >= nodes_->size()) { + // We can get here only if the structure is corrupted. + return; + } + if (node_has_leaf) { + update_fn(Match(value(pos), i + 1)); + } + } +} + +} // namespace mediapipe::tflite_operations::sentencepiece + +#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_ diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.cc b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.cc new file mode 100644 index 000000000..f492b5c48 --- /dev/null +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.cc @@ -0,0 +1,75 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h" + +#include +#include + +#include "include/darts.h" + +namespace mediapipe::tflite_operations::sentencepiece { + +std::vector BuildTrie(const std::vector& data) { + std::vector ids; + ids.reserve(data.size()); + for (int i = 0; i < data.size(); ++i) { + ids.push_back(i); + } + return BuildTrie(data, ids); +} + +std::vector BuildTrie(const std::vector& data, + const std::vector& ids) { + // We make strong assumptions about binary structure of trie. + struct OneElement { + OneElement(const std::string* key_, int index_) + : key(key_), index(index_) {} + const std::string* key; + int index; + bool operator<(const OneElement& el) const { return *key < *el.key; } + }; + std::vector elements; + elements.reserve(data.size()); + auto data_iterator = std::begin(data); + auto ids_iterator = std::begin(ids); + for (; data_iterator != std::end(data) && ids_iterator != std::end(ids); + ++data_iterator, ++ids_iterator) { + elements.emplace_back(&(*data_iterator), *ids_iterator); + } + // Sort by keys. + std::sort(elements.begin(), elements.end()); + + // Create vectors to build the trie. + std::vector strings; + std::vector indexes; + strings.reserve(data.size()); + indexes.reserve(data.size()); + for (const auto& el : elements) { + strings.push_back(el.key->c_str()); + indexes.push_back(el.index); + } + auto trie = std::make_unique(); + trie->build(data.size(), const_cast(&strings[0]), nullptr, + &indexes[0]); + // We make strong assumptions about internal Darts trie structure: + // - it is a vector of 32 bit signed integers + // - the "array" is the only one structure that contains all information about + // the trie. + const uint32_t* trie_data = static_cast(trie->array()); + return std::vector(trie_data, trie_data + trie->size()); +} + +} // namespace mediapipe::tflite_operations::sentencepiece diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h new file mode 100644 index 000000000..94c50bffc --- /dev/null +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h @@ -0,0 +1,32 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_BUILDER_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_BUILDER_H_ + +#include +#include + +namespace mediapipe::tflite_operations::sentencepiece { + +std::vector BuildTrie(const std::vector& data, + const std::vector& ids); + +// A variant where ids are indexes in data. +std::vector BuildTrie(const std::vector& data); + +} // namespace mediapipe::tflite_operations::sentencepiece + +#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_BUILDER_H_ diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_test.cc b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_test.cc new file mode 100644 index 000000000..60a78e126 --- /dev/null +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_test.cc @@ -0,0 +1,73 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie.h" + +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h" +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h" +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h" + +namespace mediapipe::tflite_operations::sentencepiece { + +TEST(DoubleArrayTrieTest, Match) { + flatbuffers::FlatBufferBuilder builder(1024); + const std::vector test_strings = {"A", "AAX", "AA", "B"}; + const auto trie_vector = builder.CreateVector(BuildTrie(test_strings)); + TrieBuilder trie_builder(builder); + trie_builder.add_nodes(trie_vector); + const auto pieces = trie_builder.Finish(); + EncoderConfigBuilder ecb(builder); + ecb.add_pieces(pieces); + FinishEncoderConfigBuffer(builder, ecb.Finish()); + const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer()); + DoubleArrayTrie dat(config->pieces()->nodes()); + EXPECT_EQ(dat.LongestPrefixMatch(utils::string_view("AAL")), + DoubleArrayTrie::Match(2, 2)); + + std::vector matches; + dat.IteratePrefixMatches( + utils::string_view("AAXL"), + [&matches](const DoubleArrayTrie::Match& m) { matches.push_back(m); }); + EXPECT_THAT(matches, testing::ElementsAre(DoubleArrayTrie::Match(0, 1), + DoubleArrayTrie::Match(2, 2), + DoubleArrayTrie::Match(1, 3))); +} + +TEST(DoubleArrayTrieTest, ComplexMatch) { + flatbuffers::FlatBufferBuilder builder(1024); + const std::vector test_strings = {"\xe2\x96\x81the", ",", "s", + "\xe2\x96\x81Hello"}; + const std::vector test_ids = {0, 5, 10, 15}; + const auto trie_vector = + builder.CreateVector(BuildTrie(test_strings, test_ids)); + TrieBuilder trie_builder(builder); + trie_builder.add_nodes(trie_vector); + const auto pieces = trie_builder.Finish(); + EncoderConfigBuilder ecb(builder); + ecb.add_pieces(pieces); + FinishEncoderConfigBuffer(builder, ecb.Finish()); + const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer()); + DoubleArrayTrie dat(config->pieces()->nodes()); + + std::vector matches; + dat.IteratePrefixMatches( + utils::string_view("\xe2\x96\x81Hello"), + [&matches](const DoubleArrayTrie::Match& m) { matches.push_back(m); }); + EXPECT_THAT(matches, testing::ElementsAre(DoubleArrayTrie::Match(15, 8))); +} + +} // namespace mediapipe::tflite_operations::sentencepiece diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config.fbs b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config.fbs new file mode 100644 index 000000000..2e7836803 --- /dev/null +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config.fbs @@ -0,0 +1,52 @@ +// Copyright 2023 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +include "config.fbs"; + +namespace mediapipe.tflite_operations.sentencepiece; + +table EncoderConfig { + // Version of the encoder. + version: EncoderVersion = SENTENCE_PIECE; + start_code: int32 = 0; + end_code: int32 = 0; + + unknown_code: int32 = -1; + // Weight of "unknown code" when encoding. "Penalty" because it usually has a + // big negative weight,less than any other sentencepiece. + unknown_penalty: float = 0; + + // The offset for encoding, usually used when codes with low codes are reserved + // for some special needs. + encoding_offset: int32; + + // String pieces for encoding. + pieces: Trie; + pieces_scores: [float]; + + // Normalization related parameters. + remove_extra_whitespaces: bool; + + // Add a whitespace prefix before encoding. + add_dummy_prefix: bool; + + // Escape whitespaces during encoding so the decoder can restore them exactly as + // in the input. + escape_whitespaces: bool; + + // Normalization parameters. + normalized_prefixes: Trie; + normalized_replacements: [byte]; +} + +root_type EncoderConfig; diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h new file mode 100644 index 000000000..c1b7728cc --- /dev/null +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h @@ -0,0 +1,60 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_UTILS_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_UTILS_H_ + +#include +#include + +namespace mediapipe::tflite_operations::sentencepiece { + +// AOSP and WASM doesn't support string_view, +// we put here a minimal re-implementation. +namespace utils { + +class string_view { + public: + explicit string_view(const std::string& s) + : str_(s.data()), len_(s.length()) {} + string_view(const char* str, int len) : str_(str), len_(len) {} + // A constructor from c string. + explicit string_view(const char* s) : str_(s), len_(strlen(s)) {} + + int length() const { return len_; } + const char* data() const { return str_; } + bool empty() const { return len_ == 0; } + unsigned char at(int i) const { return str_[i]; } + + private: + const char* str_ = nullptr; + const int len_ = 0; +}; + +inline std::ostream& operator<<(std::ostream& os, const string_view& sv) { + os << std::string(sv.data(), sv.length()); + return os; +} +inline bool operator==(const string_view& view1, const string_view& view2) { + if (view1.length() != view2.length()) { + return false; + } + return memcmp(view1.data(), view2.data(), view1.length()) == 0; +} + +} // namespace utils +} // namespace mediapipe::tflite_operations::sentencepiece + +#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_UTILS_H_ diff --git a/third_party/darts_clone.BUILD b/third_party/darts_clone.BUILD new file mode 100644 index 000000000..a15c2d68d --- /dev/null +++ b/third_party/darts_clone.BUILD @@ -0,0 +1,29 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Description: +# Darts-clone is a clone of Darts (Double-ARray Trie System). + +licenses(["notice"]) + +exports_files(["LICENSE"]) + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "darts_clone", + hdrs = [ + "include/darts.h", + ], +)