Internal MediaPipe Tasks change.

PiperOrigin-RevId: 524345005
This commit is contained in:
MediaPipe Team 2023-04-14 11:37:50 -07:00 committed by Copybara-Service
parent 27038f534a
commit 257fa01b68
10 changed files with 549 additions and 0 deletions

View File

@ -239,6 +239,16 @@ http_archive(
repo_mapping = {"@com_google_glog" : "@com_github_glog_glog_no_gflags"},
)
http_archive(
name = "darts_clone",
build_file = "@//third_party:darts_clone.BUILD",
sha256 = "c97f55d05c98da6fcaf7f9ecc6a6dc6bc5b18b8564465f77abff8879d446491c",
strip_prefix = "darts-clone-e40ce4627526985a7767444b6ed6893ab6ff8983",
urls = [
"https://github.com/s-yata/darts-clone/archive/e40ce4627526985a7767444b6ed6893ab6ff8983.zip",
],
)
http_archive(
name = "org_tensorflow_text",
sha256 = "f64647276f7288d1b1fe4c89581d51404d0ce4ae97f2bcc4c19bd667549adca8",

View File

@ -0,0 +1,82 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library")
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
filegroup(
name = "config_fbs",
srcs = ["config.fbs"],
)
flatbuffer_cc_library(
name = "config",
srcs = [
"config.fbs",
],
)
flatbuffer_cc_library(
name = "encoder_config",
srcs = [
"encoder_config.fbs",
],
includes = [":config_fbs"],
)
cc_library(
name = "utils",
hdrs = [
"utils.h",
],
)
cc_library(
name = "double_array_trie",
hdrs = [
"double_array_trie.h",
],
deps = [
":config",
":utils",
],
)
cc_library(
name = "double_array_trie_builder",
srcs = [
"double_array_trie_builder.cc",
],
hdrs = [
"double_array_trie_builder.h",
],
deps = ["@darts_clone"],
)
cc_test(
name = "double_array_trie_test",
srcs = [
"double_array_trie_test.cc",
],
deps = [
":double_array_trie",
":double_array_trie_builder",
":encoder_config",
":utils",
"//mediapipe/framework/port:gtest_main",
],
)

View File

@ -0,0 +1,25 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
namespace mediapipe.tflite_operations.sentencepiece;
table Trie {
nodes: [uint32];
}
enum EncoderVersion: byte {
SENTENCE_PIECE = 0,
}

View File

@ -0,0 +1,111 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/config_generated.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h"
namespace mediapipe::tflite_operations::sentencepiece {
// A trie node specifies a node in the tree, either an intermediate node or
// a leaf node.
// A leaf node contains the id as an int of the string match. This id is encoded
// in the lower 31 bits, thus the number of distinct ids is 2^31.
// An intermediate node has an associated label and an offset to its children.
// The label is encoded in the least significant byte and must match the input
// character during matching.
// A memory mappable trie, compatible with Darts::DoubleArray.
class DoubleArrayTrie {
public:
struct Match {
Match() {}
Match(int id, int match_length) : id(id), match_length(match_length) {}
int id = -1;
int match_length = -1;
bool empty() const { return match_length == -1; }
bool operator==(const Match& m) const {
return m.id == id && m.match_length == match_length;
}
};
// nodes and nodes_length specify the array of the nodes of the trie.
explicit DoubleArrayTrie(const flatbuffers::Vector<uint32_t>* nodes)
: nodes_(nodes) {}
// Finds matches that are prefixes of a string.
template <typename callback>
void IteratePrefixMatches(const utils::string_view& input,
callback update_fn) const;
// Finds the longest prefix match of a string.
Match LongestPrefixMatch(const utils::string_view& input) const {
Match match;
IteratePrefixMatches(input, [&match](const Match& m) { match = m; });
return match;
}
private:
// Returns whether a node as a leaf as a child.
bool has_leaf(uint32_t i) const { return ((*nodes_)[i]) & 0x100; }
// Returns a value associated with a node. Available when a node is a leaf.
int value(uint32_t i) const {
return static_cast<int>(((*nodes_)[i]) & 0x7fffffff);
}
// Returns a label associated with a node.
// A leaf node will have the MSB set and thus return an invalid label.
int32_t label(uint32_t i) const { return ((*nodes_)[i]) & 0x800000ff; }
// Returns offset to children.
int32_t offset(uint32_t i) const {
const uint32_t node = (*nodes_)[i];
return (node >> 10) << ((node & 0x200) >> 6);
}
const flatbuffers::Vector<uint32_t>* nodes_;
};
template <typename callback>
void DoubleArrayTrie::IteratePrefixMatches(const utils::string_view& input,
callback update_fn) const {
if (nodes_->size() == 0) {
return;
}
uint32_t pos = offset(0);
for (int i = 0; i < input.length(); ++i) {
pos ^= static_cast<unsigned char>(input.at(i));
if (pos < 0 || pos >= nodes_->size() || label(pos) != input.at(i)) {
// No match, exit.
return;
}
const bool node_has_leaf = has_leaf(pos);
pos ^= offset(pos);
if (pos < 0 || pos >= nodes_->size()) {
// We can get here only if the structure is corrupted.
return;
}
if (node_has_leaf) {
update_fn(Match(value(pos), i + 1));
}
}
}
} // namespace mediapipe::tflite_operations::sentencepiece
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_

View File

@ -0,0 +1,75 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h"
#include <algorithm>
#include <memory>
#include "include/darts.h"
namespace mediapipe::tflite_operations::sentencepiece {
std::vector<uint32_t> BuildTrie(const std::vector<std::string>& data) {
std::vector<int> ids;
ids.reserve(data.size());
for (int i = 0; i < data.size(); ++i) {
ids.push_back(i);
}
return BuildTrie(data, ids);
}
std::vector<uint32_t> BuildTrie(const std::vector<std::string>& data,
const std::vector<int>& ids) {
// We make strong assumptions about binary structure of trie.
struct OneElement {
OneElement(const std::string* key_, int index_)
: key(key_), index(index_) {}
const std::string* key;
int index;
bool operator<(const OneElement& el) const { return *key < *el.key; }
};
std::vector<OneElement> elements;
elements.reserve(data.size());
auto data_iterator = std::begin(data);
auto ids_iterator = std::begin(ids);
for (; data_iterator != std::end(data) && ids_iterator != std::end(ids);
++data_iterator, ++ids_iterator) {
elements.emplace_back(&(*data_iterator), *ids_iterator);
}
// Sort by keys.
std::sort(elements.begin(), elements.end());
// Create vectors to build the trie.
std::vector<const char*> strings;
std::vector<int32_t> indexes;
strings.reserve(data.size());
indexes.reserve(data.size());
for (const auto& el : elements) {
strings.push_back(el.key->c_str());
indexes.push_back(el.index);
}
auto trie = std::make_unique<Darts::DoubleArray>();
trie->build(data.size(), const_cast<char**>(&strings[0]), nullptr,
&indexes[0]);
// We make strong assumptions about internal Darts trie structure:
// - it is a vector of 32 bit signed integers
// - the "array" is the only one structure that contains all information about
// the trie.
const uint32_t* trie_data = static_cast<const uint32_t*>(trie->array());
return std::vector<uint32_t>(trie_data, trie_data + trie->size());
}
} // namespace mediapipe::tflite_operations::sentencepiece

View File

@ -0,0 +1,32 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_BUILDER_H_
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_BUILDER_H_
#include <string>
#include <vector>
namespace mediapipe::tflite_operations::sentencepiece {
std::vector<uint32_t> BuildTrie(const std::vector<std::string>& data,
const std::vector<int>& ids);
// A variant where ids are indexes in data.
std::vector<uint32_t> BuildTrie(const std::vector<std::string>& data);
} // namespace mediapipe::tflite_operations::sentencepiece
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_BUILDER_H_

View File

@ -0,0 +1,73 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h"
namespace mediapipe::tflite_operations::sentencepiece {
TEST(DoubleArrayTrieTest, Match) {
flatbuffers::FlatBufferBuilder builder(1024);
const std::vector<std::string> test_strings = {"A", "AAX", "AA", "B"};
const auto trie_vector = builder.CreateVector(BuildTrie(test_strings));
TrieBuilder trie_builder(builder);
trie_builder.add_nodes(trie_vector);
const auto pieces = trie_builder.Finish();
EncoderConfigBuilder ecb(builder);
ecb.add_pieces(pieces);
FinishEncoderConfigBuffer(builder, ecb.Finish());
const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer());
DoubleArrayTrie dat(config->pieces()->nodes());
EXPECT_EQ(dat.LongestPrefixMatch(utils::string_view("AAL")),
DoubleArrayTrie::Match(2, 2));
std::vector<DoubleArrayTrie::Match> matches;
dat.IteratePrefixMatches(
utils::string_view("AAXL"),
[&matches](const DoubleArrayTrie::Match& m) { matches.push_back(m); });
EXPECT_THAT(matches, testing::ElementsAre(DoubleArrayTrie::Match(0, 1),
DoubleArrayTrie::Match(2, 2),
DoubleArrayTrie::Match(1, 3)));
}
TEST(DoubleArrayTrieTest, ComplexMatch) {
flatbuffers::FlatBufferBuilder builder(1024);
const std::vector<std::string> test_strings = {"\xe2\x96\x81the", ",", "s",
"\xe2\x96\x81Hello"};
const std::vector<int> test_ids = {0, 5, 10, 15};
const auto trie_vector =
builder.CreateVector(BuildTrie(test_strings, test_ids));
TrieBuilder trie_builder(builder);
trie_builder.add_nodes(trie_vector);
const auto pieces = trie_builder.Finish();
EncoderConfigBuilder ecb(builder);
ecb.add_pieces(pieces);
FinishEncoderConfigBuffer(builder, ecb.Finish());
const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer());
DoubleArrayTrie dat(config->pieces()->nodes());
std::vector<DoubleArrayTrie::Match> matches;
dat.IteratePrefixMatches(
utils::string_view("\xe2\x96\x81Hello"),
[&matches](const DoubleArrayTrie::Match& m) { matches.push_back(m); });
EXPECT_THAT(matches, testing::ElementsAre(DoubleArrayTrie::Match(15, 8)));
}
} // namespace mediapipe::tflite_operations::sentencepiece

View File

@ -0,0 +1,52 @@
// Copyright 2023 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
include "config.fbs";
namespace mediapipe.tflite_operations.sentencepiece;
table EncoderConfig {
// Version of the encoder.
version: EncoderVersion = SENTENCE_PIECE;
start_code: int32 = 0;
end_code: int32 = 0;
unknown_code: int32 = -1;
// Weight of "unknown code" when encoding. "Penalty" because it usually has a
// big negative weight,less than any other sentencepiece.
unknown_penalty: float = 0;
// The offset for encoding, usually used when codes with low codes are reserved
// for some special needs.
encoding_offset: int32;
// String pieces for encoding.
pieces: Trie;
pieces_scores: [float];
// Normalization related parameters.
remove_extra_whitespaces: bool;
// Add a whitespace prefix before encoding.
add_dummy_prefix: bool;
// Escape whitespaces during encoding so the decoder can restore them exactly as
// in the input.
escape_whitespaces: bool;
// Normalization parameters.
normalized_prefixes: Trie;
normalized_replacements: [byte];
}
root_type EncoderConfig;

View File

@ -0,0 +1,60 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_UTILS_H_
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_UTILS_H_
#include <ostream>
#include <string>
namespace mediapipe::tflite_operations::sentencepiece {
// AOSP and WASM doesn't support string_view,
// we put here a minimal re-implementation.
namespace utils {
class string_view {
public:
explicit string_view(const std::string& s)
: str_(s.data()), len_(s.length()) {}
string_view(const char* str, int len) : str_(str), len_(len) {}
// A constructor from c string.
explicit string_view(const char* s) : str_(s), len_(strlen(s)) {}
int length() const { return len_; }
const char* data() const { return str_; }
bool empty() const { return len_ == 0; }
unsigned char at(int i) const { return str_[i]; }
private:
const char* str_ = nullptr;
const int len_ = 0;
};
inline std::ostream& operator<<(std::ostream& os, const string_view& sv) {
os << std::string(sv.data(), sv.length());
return os;
}
inline bool operator==(const string_view& view1, const string_view& view2) {
if (view1.length() != view2.length()) {
return false;
}
return memcmp(view1.data(), view2.data(), view1.length()) == 0;
}
} // namespace utils
} // namespace mediapipe::tflite_operations::sentencepiece
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_UTILS_H_

29
third_party/darts_clone.BUILD vendored Normal file
View File

@ -0,0 +1,29 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# Darts-clone is a clone of Darts (Double-ARray Trie System).
licenses(["notice"])
exports_files(["LICENSE"])
package(default_visibility = ["//visibility:public"])
cc_library(
name = "darts_clone",
hdrs = [
"include/darts.h",
],
)