Internal MediaPipe Tasks change.
PiperOrigin-RevId: 524345005
This commit is contained in:
parent
27038f534a
commit
257fa01b68
10
WORKSPACE
10
WORKSPACE
|
@ -239,6 +239,16 @@ http_archive(
|
||||||
repo_mapping = {"@com_google_glog" : "@com_github_glog_glog_no_gflags"},
|
repo_mapping = {"@com_google_glog" : "@com_github_glog_glog_no_gflags"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Darts-clone sources, fetched as a zip archive pinned to a single commit and
# built with the BUILD file kept under //third_party.
http_archive(
    name = "darts_clone",
    build_file = "@//third_party:darts_clone.BUILD",
    sha256 = "c97f55d05c98da6fcaf7f9ecc6a6dc6bc5b18b8564465f77abff8879d446491c",
    strip_prefix = "darts-clone-e40ce4627526985a7767444b6ed6893ab6ff8983",
    urls = ["https://github.com/s-yata/darts-clone/archive/e40ce4627526985a7767444b6ed6893ab6ff8983.zip"],
)
|
||||||
|
|
||||||
http_archive(
|
http_archive(
|
||||||
name = "org_tensorflow_text",
|
name = "org_tensorflow_text",
|
||||||
sha256 = "f64647276f7288d1b1fe4c89581d51404d0ce4ae97f2bcc4c19bd667549adca8",
|
sha256 = "f64647276f7288d1b1fe4c89581d51404d0ce4ae97f2bcc4c19bd667549adca8",
|
||||||
|
|
82
mediapipe/tasks/cc/text/custom_ops/sentencepiece/BUILD
Normal file
82
mediapipe/tasks/cc/text/custom_ops/sentencepiece/BUILD
Normal file
|
@ -0,0 +1,82 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library")

package(default_visibility = ["//mediapipe/tasks:internal"])

licenses(["notice"])

# Exposes config.fbs so that other schemas can reference it via `includes`.
filegroup(
    name = "config_fbs",
    srcs = ["config.fbs"],
)

# C++ bindings generated from config.fbs.
flatbuffer_cc_library(
    name = "config",
    srcs = ["config.fbs"],
)

# C++ bindings generated from encoder_config.fbs, which includes config.fbs.
flatbuffer_cc_library(
    name = "encoder_config",
    srcs = ["encoder_config.fbs"],
    includes = [":config_fbs"],
)

# Header-only helpers (minimal string_view replacement).
cc_library(
    name = "utils",
    hdrs = ["utils.h"],
)

# Header-only trie reader operating on the serialized node array.
cc_library(
    name = "double_array_trie",
    hdrs = ["double_array_trie.h"],
    deps = [
        ":config",
        ":utils",
    ],
)

# Builds the serialized node array with darts_clone.
cc_library(
    name = "double_array_trie_builder",
    srcs = ["double_array_trie_builder.cc"],
    hdrs = ["double_array_trie_builder.h"],
    deps = ["@darts_clone"],
)

cc_test(
    name = "double_array_trie_test",
    srcs = ["double_array_trie_test.cc"],
    deps = [
        ":double_array_trie",
        ":double_array_trie_builder",
        ":encoder_config",
        ":utils",
        "//mediapipe/framework/port:gtest_main",
    ],
)
|
25
mediapipe/tasks/cc/text/custom_ops/sentencepiece/config.fbs
Normal file
25
mediapipe/tasks/cc/text/custom_ops/sentencepiece/config.fbs
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
namespace mediapipe.tflite_operations.sentencepiece;

// Serialized trie: a flat array of 32-bit node units.
table Trie {
  nodes: [uint32];
}

// Encoder version tag; SENTENCE_PIECE is the only value defined here.
enum EncoderVersion : byte { SENTENCE_PIECE = 0 }
|
|
@ -0,0 +1,111 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_
|
||||||
|
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/config_generated.h"
|
||||||
|
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tflite_operations::sentencepiece {
|
||||||
|
|
||||||
|
// A trie node specifies a node in the tree, either an intermediate node or
|
||||||
|
// a leaf node.
|
||||||
|
// A leaf node contains the id as an int of the string match. This id is encoded
|
||||||
|
// in the lower 31 bits, thus the number of distinct ids is 2^31.
|
||||||
|
// An intermediate node has an associated label and an offset to its children.
|
||||||
|
// The label is encoded in the least significant byte and must match the input
|
||||||
|
// character during matching.
|
||||||
|
|
||||||
|
// A memory mappable trie, compatible with Darts::DoubleArray.
|
||||||
|
class DoubleArrayTrie {
|
||||||
|
public:
|
||||||
|
struct Match {
|
||||||
|
Match() {}
|
||||||
|
Match(int id, int match_length) : id(id), match_length(match_length) {}
|
||||||
|
int id = -1;
|
||||||
|
int match_length = -1;
|
||||||
|
bool empty() const { return match_length == -1; }
|
||||||
|
bool operator==(const Match& m) const {
|
||||||
|
return m.id == id && m.match_length == match_length;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// nodes and nodes_length specify the array of the nodes of the trie.
|
||||||
|
explicit DoubleArrayTrie(const flatbuffers::Vector<uint32_t>* nodes)
|
||||||
|
: nodes_(nodes) {}
|
||||||
|
|
||||||
|
// Finds matches that are prefixes of a string.
|
||||||
|
template <typename callback>
|
||||||
|
void IteratePrefixMatches(const utils::string_view& input,
|
||||||
|
callback update_fn) const;
|
||||||
|
|
||||||
|
// Finds the longest prefix match of a string.
|
||||||
|
Match LongestPrefixMatch(const utils::string_view& input) const {
|
||||||
|
Match match;
|
||||||
|
IteratePrefixMatches(input, [&match](const Match& m) { match = m; });
|
||||||
|
return match;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Returns whether a node as a leaf as a child.
|
||||||
|
bool has_leaf(uint32_t i) const { return ((*nodes_)[i]) & 0x100; }
|
||||||
|
|
||||||
|
// Returns a value associated with a node. Available when a node is a leaf.
|
||||||
|
int value(uint32_t i) const {
|
||||||
|
return static_cast<int>(((*nodes_)[i]) & 0x7fffffff);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns a label associated with a node.
|
||||||
|
// A leaf node will have the MSB set and thus return an invalid label.
|
||||||
|
int32_t label(uint32_t i) const { return ((*nodes_)[i]) & 0x800000ff; }
|
||||||
|
|
||||||
|
// Returns offset to children.
|
||||||
|
int32_t offset(uint32_t i) const {
|
||||||
|
const uint32_t node = (*nodes_)[i];
|
||||||
|
return (node >> 10) << ((node & 0x200) >> 6);
|
||||||
|
}
|
||||||
|
|
||||||
|
const flatbuffers::Vector<uint32_t>* nodes_;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename callback>
|
||||||
|
void DoubleArrayTrie::IteratePrefixMatches(const utils::string_view& input,
|
||||||
|
callback update_fn) const {
|
||||||
|
if (nodes_->size() == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
uint32_t pos = offset(0);
|
||||||
|
for (int i = 0; i < input.length(); ++i) {
|
||||||
|
pos ^= static_cast<unsigned char>(input.at(i));
|
||||||
|
if (pos < 0 || pos >= nodes_->size() || label(pos) != input.at(i)) {
|
||||||
|
// No match, exit.
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const bool node_has_leaf = has_leaf(pos);
|
||||||
|
pos ^= offset(pos);
|
||||||
|
if (pos < 0 || pos >= nodes_->size()) {
|
||||||
|
// We can get here only if the structure is corrupted.
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (node_has_leaf) {
|
||||||
|
update_fn(Match(value(pos), i + 1));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mediapipe::tflite_operations::sentencepiece
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_
|
|
@ -0,0 +1,75 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "include/darts.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tflite_operations::sentencepiece {
|
||||||
|
|
||||||
|
std::vector<uint32_t> BuildTrie(const std::vector<std::string>& data) {
|
||||||
|
std::vector<int> ids;
|
||||||
|
ids.reserve(data.size());
|
||||||
|
for (int i = 0; i < data.size(); ++i) {
|
||||||
|
ids.push_back(i);
|
||||||
|
}
|
||||||
|
return BuildTrie(data, ids);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<uint32_t> BuildTrie(const std::vector<std::string>& data,
|
||||||
|
const std::vector<int>& ids) {
|
||||||
|
// We make strong assumptions about binary structure of trie.
|
||||||
|
struct OneElement {
|
||||||
|
OneElement(const std::string* key_, int index_)
|
||||||
|
: key(key_), index(index_) {}
|
||||||
|
const std::string* key;
|
||||||
|
int index;
|
||||||
|
bool operator<(const OneElement& el) const { return *key < *el.key; }
|
||||||
|
};
|
||||||
|
std::vector<OneElement> elements;
|
||||||
|
elements.reserve(data.size());
|
||||||
|
auto data_iterator = std::begin(data);
|
||||||
|
auto ids_iterator = std::begin(ids);
|
||||||
|
for (; data_iterator != std::end(data) && ids_iterator != std::end(ids);
|
||||||
|
++data_iterator, ++ids_iterator) {
|
||||||
|
elements.emplace_back(&(*data_iterator), *ids_iterator);
|
||||||
|
}
|
||||||
|
// Sort by keys.
|
||||||
|
std::sort(elements.begin(), elements.end());
|
||||||
|
|
||||||
|
// Create vectors to build the trie.
|
||||||
|
std::vector<const char*> strings;
|
||||||
|
std::vector<int32_t> indexes;
|
||||||
|
strings.reserve(data.size());
|
||||||
|
indexes.reserve(data.size());
|
||||||
|
for (const auto& el : elements) {
|
||||||
|
strings.push_back(el.key->c_str());
|
||||||
|
indexes.push_back(el.index);
|
||||||
|
}
|
||||||
|
auto trie = std::make_unique<Darts::DoubleArray>();
|
||||||
|
trie->build(data.size(), const_cast<char**>(&strings[0]), nullptr,
|
||||||
|
&indexes[0]);
|
||||||
|
// We make strong assumptions about internal Darts trie structure:
|
||||||
|
// - it is a vector of 32 bit signed integers
|
||||||
|
// - the "array" is the only one structure that contains all information about
|
||||||
|
// the trie.
|
||||||
|
const uint32_t* trie_data = static_cast<const uint32_t*>(trie->array());
|
||||||
|
return std::vector<uint32_t>(trie_data, trie_data + trie->size());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mediapipe::tflite_operations::sentencepiece
|
|
@ -0,0 +1,32 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_BUILDER_H_
|
||||||
|
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_BUILDER_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace mediapipe::tflite_operations::sentencepiece {
|
||||||
|
|
||||||
|
// Builds a trie over `data` in which the id of each key is its index in
// `data`. Returns the trie as a flat vector of 32-bit units.
std::vector<uint32_t> BuildTrie(const std::vector<std::string>& data);

// Builds a trie over `data` in which data[i] is associated with ids[i].
// Returns the trie as a flat vector of 32-bit units.
std::vector<uint32_t> BuildTrie(const std::vector<std::string>& data,
                                const std::vector<int>& ids);
|
||||||
|
|
||||||
|
} // namespace mediapipe::tflite_operations::sentencepiece
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_BUILDER_H_
|
|
@ -0,0 +1,73 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie.h"
|
||||||
|
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h"
|
||||||
|
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
|
||||||
|
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tflite_operations::sentencepiece {
|
||||||
|
|
||||||
|
// Builds a trie whose ids are the key positions, round-trips it through an
// EncoderConfig flatbuffer and checks longest-prefix and all-prefix lookups.
TEST(DoubleArrayTrieTest, Match) {
  using Match = DoubleArrayTrie::Match;

  const std::vector<std::string> keys = {"A", "AAX", "AA", "B"};
  flatbuffers::FlatBufferBuilder fbb(1024);
  const auto nodes_offset = fbb.CreateVector(BuildTrie(keys));
  TrieBuilder trie_builder(fbb);
  trie_builder.add_nodes(nodes_offset);
  const auto trie_offset = trie_builder.Finish();
  EncoderConfigBuilder config_builder(fbb);
  config_builder.add_pieces(trie_offset);
  FinishEncoderConfigBuffer(fbb, config_builder.Finish());

  const EncoderConfig* config = GetEncoderConfig(fbb.GetBufferPointer());
  DoubleArrayTrie trie(config->pieces()->nodes());

  // "AA" (id 2) is the longest stored key that prefixes "AAL".
  const Match longest = trie.LongestPrefixMatch(utils::string_view("AAL"));
  EXPECT_EQ(longest, Match(2, 2));

  // "A", "AA" and "AAX" all prefix "AAXL" and are reported shortest first.
  std::vector<Match> found;
  trie.IteratePrefixMatches(utils::string_view("AAXL"),
                            [&found](const Match& m) { found.push_back(m); });
  EXPECT_THAT(found,
              testing::ElementsAre(Match(0, 1), Match(2, 2), Match(1, 3)));
}
|
||||||
|
|
||||||
|
// Same round trip as above, but with multi-byte keys and explicit,
// non-contiguous ids.
TEST(DoubleArrayTrieTest, ComplexMatch) {
  using Match = DoubleArrayTrie::Match;

  const std::vector<std::string> keys = {"\xe2\x96\x81the", ",", "s",
                                         "\xe2\x96\x81Hello"};
  const std::vector<int> key_ids = {0, 5, 10, 15};
  flatbuffers::FlatBufferBuilder fbb(1024);
  const auto nodes_offset = fbb.CreateVector(BuildTrie(keys, key_ids));
  TrieBuilder trie_builder(fbb);
  trie_builder.add_nodes(nodes_offset);
  const auto trie_offset = trie_builder.Finish();
  EncoderConfigBuilder config_builder(fbb);
  config_builder.add_pieces(trie_offset);
  FinishEncoderConfigBuffer(fbb, config_builder.Finish());

  const EncoderConfig* config = GetEncoderConfig(fbb.GetBufferPointer());
  DoubleArrayTrie trie(config->pieces()->nodes());

  // The whole 8-byte input is a stored key with id 15; nothing shorter is.
  std::vector<Match> found;
  trie.IteratePrefixMatches(utils::string_view("\xe2\x96\x81Hello"),
                            [&found](const Match& m) { found.push_back(m); });
  EXPECT_THAT(found, testing::ElementsAre(Match(15, 8)));
}
|
||||||
|
|
||||||
|
} // namespace mediapipe::tflite_operations::sentencepiece
|
|
@ -0,0 +1,52 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
include "config.fbs";

namespace mediapipe.tflite_operations.sentencepiece;

table EncoderConfig {
  // Version of the encoder.
  version: EncoderVersion = SENTENCE_PIECE;
  start_code: int32 = 0;
  end_code: int32 = 0;

  unknown_code: int32 = -1;
  // Weight of the "unknown code" when encoding. Called a "penalty" because it
  // usually has a large negative weight, lower than any other sentencepiece.
  unknown_penalty: float = 0;

  // Offset applied when encoding; usually used when the low codes are
  // reserved for some special needs.
  encoding_offset: int32 = 0;

  // String pieces for encoding.
  pieces: Trie;
  pieces_scores: [float];

  // Normalization related parameters.
  remove_extra_whitespaces: bool = false;

  // Add a whitespace prefix before encoding.
  add_dummy_prefix: bool = false;

  // Escape whitespaces during encoding so the decoder can restore them
  // exactly as in the input.
  escape_whitespaces: bool = false;

  // Normalization parameters.
  normalized_prefixes: Trie;
  normalized_replacements: [byte];
}

root_type EncoderConfig;
|
60
mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h
Normal file
60
mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h
Normal file
|
@ -0,0 +1,60 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_UTILS_H_
|
||||||
|
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_UTILS_H_
|
||||||
|
|
||||||
|
#include <cstring>
#include <ostream>
#include <string>
|
||||||
|
|
||||||
|
namespace mediapipe::tflite_operations::sentencepiece {
|
||||||
|
|
||||||
|
// AOSP and WASM don't support string_view,
// so we put a minimal re-implementation here.
|
||||||
|
namespace utils {

// Minimal non-owning view over a contiguous run of chars. The view stores
// only a pointer and a length; the viewed buffer is not copied and must stay
// alive for as long as the view is used.
class string_view {
 public:
  // Views the contents of `s`.
  explicit string_view(const std::string& s)
      : str_(s.data()), len_(static_cast<int>(s.length())) {}
  // Views `len` chars starting at `str`.
  string_view(const char* str, int len) : str_(str), len_(len) {}
  // A constructor from c string. A null pointer yields an empty view rather
  // than being passed to strlen.
  explicit string_view(const char* s)
      : str_(s), len_(s == nullptr ? 0 : static_cast<int>(std::strlen(s))) {}

  int length() const { return len_; }
  const char* data() const { return str_; }
  bool empty() const { return len_ == 0; }
  // Returns the char at index `i` as an unsigned value. No bounds checking is
  // performed.
  unsigned char at(int i) const { return str_[i]; }

 private:
  const char* str_ = nullptr;
  const int len_ = 0;
};

// Streams the viewed characters (exactly length() of them).
inline std::ostream& operator<<(std::ostream& os, const string_view& sv) {
  // An empty view may carry a null data pointer; never hand that to
  // std::string.
  os << (sv.empty() ? std::string() : std::string(sv.data(), sv.length()));
  return os;
}

// Two views are equal when they have the same length and the same bytes.
inline bool operator==(const string_view& view1, const string_view& view2) {
  if (view1.length() != view2.length()) {
    return false;
  }
  // Empty views compare equal without touching their (possibly null) data.
  if (view1.empty()) {
    return true;
  }
  return std::memcmp(view1.data(), view2.data(), view1.length()) == 0;
}

}  // namespace utils
|
||||||
|
} // namespace mediapipe::tflite_operations::sentencepiece
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_UTILS_H_
|
29
third_party/darts_clone.BUILD
vendored
Normal file
29
third_party/darts_clone.BUILD
vendored
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# Description:
|
||||||
|
# Darts-clone is a clone of Darts (Double-ARray Trie System).
|
||||||
|
|
||||||
|
package(default_visibility = ["//visibility:public"])

licenses(["notice"])

exports_files(["LICENSE"])

# Header-only library: the whole of Darts-clone lives in include/darts.h.
cc_library(
    name = "darts_clone",
    hdrs = ["include/darts.h"],
)
|
Loading…
Reference in New Issue
Block a user