Internal MediaPipe Tasks change.

PiperOrigin-RevId: 524377097
This commit is contained in:
MediaPipe Team 2023-04-14 13:47:36 -07:00 committed by Copybara-Service
parent 92f45c98d8
commit abd6574c6d
8 changed files with 725 additions and 0 deletions

View File

@ -18,6 +18,13 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"]) licenses(["notice"])
filegroup(
name = "testdata",
srcs = glob([
"testdata/**",
]),
)
filegroup( filegroup(
name = "config_fbs", name = "config_fbs",
srcs = ["config.fbs"], srcs = ["config.fbs"],
@ -80,3 +87,66 @@ cc_test(
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
], ],
) )
cc_library(
name = "sentencepiece_constants",
hdrs = ["sentencepiece_constants.h"],
)
cc_library(
name = "model_converter",
srcs = [
"model_converter.cc",
],
hdrs = [
"model_converter.h",
],
deps = [
":config",
":double_array_trie_builder",
":encoder_config",
":sentencepiece_constants",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_sentencepiece//src:sentencepiece_model_cc_proto",
],
)
cc_library(
name = "optimized_encoder",
srcs = [
"optimized_encoder.cc",
],
hdrs = [
"optimized_encoder.h",
],
deps = [
":double_array_trie",
":encoder_config",
":utils",
],
)
cc_test(
name = "optimized_encoder_test",
srcs = [
"optimized_encoder_test.cc",
],
data = [
":testdata",
],
deps = [
":double_array_trie_builder",
":encoder_config",
":model_converter",
":optimized_encoder",
"//mediapipe/framework/deps:file_path",
"//mediapipe/framework/port:gtest_main",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:str_format",
"@com_google_sentencepiece//src:sentencepiece_cc_proto",
"@com_google_sentencepiece//src:sentencepiece_processor",
"@org_tensorflow//tensorflow/core:lib",
],
)

View File

@ -0,0 +1,131 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_constants.h"
#include "src/sentencepiece_model.pb.h"
namespace mediapipe::tflite_operations::sentencepiece {
std::tuple<std::vector<uint32_t>, std::vector<int8_t>>
DecodePrecompiledCharsmap(
const ::sentencepiece::NormalizerSpec& normalizer_spec) {
// This function "undoes" encoding done by
// sentencepiece::normalizer::Normalizer::EncodePrecompiledCharsMap.
const char* precompiled_map = normalizer_spec.precompiled_charsmap().data();
const uint32_t trie_size =
*reinterpret_cast<const uint32_t*>(precompiled_map);
const uint32_t* trie_ptr =
reinterpret_cast<const uint32_t*>(precompiled_map + sizeof(uint32_t));
const int8_t* normalized_ptr = reinterpret_cast<const int8_t*>(
precompiled_map + sizeof(uint32_t) + trie_size);
const int normalized_size = normalizer_spec.precompiled_charsmap().length() -
sizeof(uint32_t) - trie_size;
return std::make_tuple(
std::vector<uint32_t>(trie_ptr, trie_ptr + trie_size / sizeof(uint32_t)),
std::vector<int8_t>(normalized_ptr, normalized_ptr + normalized_size));
}
absl::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer(
const std::string& model_config_str, int encoding_offset) {
::sentencepiece::ModelProto model_config;
if (!model_config.ParseFromString(model_config_str)) {
return absl::InvalidArgumentError(
"Invalid configuration, can't parse SentencePiece model config " +
model_config.InitializationErrorString());
}
// Convert sentencepieces.
std::vector<std::string> pieces;
pieces.reserve(model_config.pieces_size());
std::vector<float> scores;
scores.reserve(model_config.pieces_size());
std::vector<int> ids;
ids.reserve(model_config.pieces_size());
float min_score = 0.0;
int index = 0;
for (const auto& piece : model_config.pieces()) {
switch (piece.type()) {
case ::sentencepiece::ModelProto::SentencePiece::NORMAL:
case ::sentencepiece::ModelProto::SentencePiece::USER_DEFINED:
pieces.push_back(piece.piece());
ids.push_back(index);
if (piece.score() < min_score) {
min_score = piece.score();
}
break;
case ::sentencepiece::ModelProto::SentencePiece::UNKNOWN:
case ::sentencepiece::ModelProto::SentencePiece::CONTROL:
// Ignore unknown and control codes.
break;
default:
return absl::InvalidArgumentError("Invalid SentencePiece piece type " +
piece.piece());
}
scores.push_back(piece.score());
++index;
}
flatbuffers::FlatBufferBuilder builder(1024);
const auto pieces_trie_vector = builder.CreateVector(BuildTrie(pieces, ids));
const auto pieces_score_vector = builder.CreateVector(scores);
TrieBuilder pieces_trie_builder(builder);
pieces_trie_builder.add_nodes(pieces_trie_vector);
const auto pieces_trie_fbs = pieces_trie_builder.Finish();
// Converting normalization.
const auto normalization =
DecodePrecompiledCharsmap(model_config.normalizer_spec());
const auto normalization_trie = std::get<0>(normalization);
const auto normalization_strings = std::get<1>(normalization);
const auto normalization_trie_vector =
builder.CreateVector(normalization_trie);
TrieBuilder normalization_trie_builder(builder);
normalization_trie_builder.add_nodes(normalization_trie_vector);
const auto normalization_trie_fbs = normalization_trie_builder.Finish();
const auto normalization_strings_fbs =
builder.CreateVector(normalization_strings);
EncoderConfigBuilder ecb(builder);
ecb.add_version(EncoderVersion::EncoderVersion_SENTENCE_PIECE);
ecb.add_start_code(model_config.trainer_spec().bos_id());
ecb.add_end_code(model_config.trainer_spec().eos_id());
ecb.add_unknown_code(model_config.trainer_spec().unk_id());
ecb.add_unknown_penalty(min_score - kUnkPenalty);
ecb.add_encoding_offset(encoding_offset);
ecb.add_pieces(pieces_trie_fbs);
ecb.add_pieces_scores(pieces_score_vector);
ecb.add_remove_extra_whitespaces(
model_config.normalizer_spec().remove_extra_whitespaces());
ecb.add_add_dummy_prefix(model_config.normalizer_spec().add_dummy_prefix());
ecb.add_escape_whitespaces(
model_config.normalizer_spec().escape_whitespaces());
ecb.add_normalized_prefixes(normalization_trie_fbs);
ecb.add_normalized_replacements(normalization_strings_fbs);
FinishEncoderConfigBuffer(builder, ecb.Finish());
return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
builder.GetSize());
}
std::string ConvertSentencepieceModel(const std::string& model_string) {
const auto result = ConvertSentencepieceModelToFlatBuffer(model_string);
assert(result.status().ok());
return result.value();
}
} // namespace mediapipe::tflite_operations::sentencepiece

View File

@ -0,0 +1,33 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_MODEL_CONVERTER_H_
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_MODEL_CONVERTER_H_
#include <string>
#include "absl/status/statusor.h"
namespace mediapipe::tflite_operations::sentencepiece {
// Converts Sentencepiece configuration to flatbuffer format.
// encoding_offset is used by some encoders that combine different encodings.
absl::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer(
const std::string& model_config_str, int encoding_offset = 0);
std::string ConvertSentencepieceModel(const std::string& model_string);
} // namespace mediapipe::tflite_operations::sentencepiece
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_MODEL_CONVERTER_H_

View File

@ -0,0 +1,236 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h"
#include <algorithm>
#include <tuple>
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h"
namespace mediapipe::tflite_operations::sentencepiece {
namespace {
const char kSpaceSymbol[] = "\xe2\x96\x81";
template <typename processing_callback>
std::tuple<std::string, std::vector<int>> process_string(
const std::string& input, const std::vector<int>& offsets,
const processing_callback& pc) {
std::string result_string;
result_string.reserve(input.size());
std::vector<int> result_offsets;
result_offsets.reserve(offsets.size());
for (int i = 0, j = 0; i < input.size();) {
auto result = pc(input.data() + i, input.size() - i);
auto consumed = std::get<0>(result);
auto new_string = std::get<1>(result);
if (consumed == 0) {
// Skip the current byte and move forward.
result_string.push_back(input[i]);
result_offsets.push_back(offsets[j]);
i++;
j++;
continue;
}
result_string.append(new_string.data(), new_string.length());
for (int i = 0; i < new_string.length(); ++i) {
result_offsets.push_back(offsets[j]);
}
j += consumed;
i += consumed;
}
return std::make_tuple(result_string, result_offsets);
}
inline char is_whitespace(char c) {
return c == ' ' || c == '\t' || c == '\r' || c == '\n';
}
std::tuple<int, utils::string_view> remove_extra_whitespaces(const char* data,
int len) {
if (len == 0 || !is_whitespace(*data)) {
return std::make_tuple(0, utils::string_view(nullptr, 0));
}
int num_consumed = 1;
for (; num_consumed < len && is_whitespace(data[num_consumed]);
++num_consumed) {
}
return num_consumed > 1
? std::make_tuple(num_consumed, utils::string_view(" ", 1))
: std::make_tuple(0, utils::string_view(nullptr, 0));
}
std::tuple<int, utils::string_view> find_replacement(
const char* data, int len, const DoubleArrayTrie& dat,
const flatbuffers::Vector<int8_t>& replacements) {
const auto max_match = dat.LongestPrefixMatch(utils::string_view(data, len));
if (!max_match.empty()) {
// Because flatbuffer byte is signed char which is not the same as char,
// there is the reinterpret_cast here.
const char* replaced_string_ptr =
reinterpret_cast<const char*>(replacements.data() + max_match.id);
return std::make_tuple(max_match.match_length,
utils::string_view(replaced_string_ptr));
}
return std::make_tuple(0, utils::string_view(nullptr, 0));
}
} // namespace
std::tuple<std::string, std::vector<int>> NormalizeString(
const std::string& in_string, const EncoderConfig& config) {
std::vector<int> output_offsets;
std::string result = in_string;
output_offsets.reserve(in_string.length());
for (int i = 0; i < in_string.length(); ++i) {
output_offsets.push_back(i);
}
if (in_string.empty()) {
return std::make_tuple(result, output_offsets);
}
if (config.add_dummy_prefix()) {
result.insert(result.begin(), ' ');
output_offsets.insert(output_offsets.begin(), 0);
}
// Greedely replace normalized_prefixes with normalized_replacements
if (config.normalized_prefixes() != nullptr &&
config.normalized_replacements() != nullptr) {
const DoubleArrayTrie normalized_prefixes_matcher(
config.normalized_prefixes()->nodes());
const auto norm_replace = [&config, &normalized_prefixes_matcher](
const char* data, int len) {
return find_replacement(data, len, normalized_prefixes_matcher,
*config.normalized_replacements());
};
std::tie(result, output_offsets) =
process_string(result, output_offsets, norm_replace);
}
if (config.remove_extra_whitespaces()) {
std::tie(result, output_offsets) =
process_string(result, output_offsets, remove_extra_whitespaces);
if (!result.empty() && is_whitespace(result.back())) {
result.pop_back();
output_offsets.pop_back();
}
}
if (config.escape_whitespaces()) {
const auto replace_whitespaces = [](const char* data, int len) {
if (len > 0 && is_whitespace(*data)) {
return std::make_tuple(1, utils::string_view(kSpaceSymbol));
}
return std::make_tuple(0, utils::string_view(nullptr, 0));
};
std::tie(result, output_offsets) =
process_string(result, output_offsets, replace_whitespaces);
}
return std::make_tuple(result, output_offsets);
}
EncoderResult EncodeNormalizedString(const std::string& str,
const std::vector<int>& offsets,
const EncoderConfig& config, bool add_bos,
bool add_eos, bool reverse) {
const DoubleArrayTrie piece_matcher(config.pieces()->nodes());
const flatbuffers::Vector<float>* piece_scores = config.pieces_scores();
const int unknown_code = config.unknown_code();
const float unknown_penalty = config.unknown_penalty();
struct LatticeElement {
float score = 0;
int code = -1;
int prev_position = -1;
LatticeElement(float score_, int code_, int prev_position_)
: score(score_), code(code_), prev_position(prev_position_) {}
LatticeElement() {}
};
const int length = str.length();
std::vector<LatticeElement> lattice(length + 1);
for (int i = 0; i < length; ++i) {
if (i > 0 && lattice[i].prev_position < 0) {
// This state is unreachable.
continue;
}
if (unknown_code >= 0) {
// Put unknown code.
const float penalized_score = lattice[i].score + unknown_penalty;
const int pos = i + 1;
LatticeElement& current_element = lattice[pos];
if (current_element.prev_position < 0 ||
current_element.score < penalized_score) {
current_element = LatticeElement(
penalized_score, unknown_code,
// If the current state is already reached by unknown code, merge
// states.
lattice[i].code == unknown_code ? lattice[i].prev_position : i);
}
}
auto lattice_update = [&lattice, i,
piece_scores](const DoubleArrayTrie::Match& m) {
LatticeElement& target_element = lattice[i + m.match_length];
const float score = lattice[i].score + (*piece_scores)[m.id];
if (target_element.prev_position < 0 || target_element.score < score) {
target_element = LatticeElement(score, m.id, i);
}
};
piece_matcher.IteratePrefixMatches(
utils::string_view(str.data() + i, length - i), lattice_update);
}
EncoderResult result;
if (add_eos) {
result.codes.push_back(config.end_code());
result.offsets.push_back(length);
}
if (lattice[length].prev_position >= 0) {
for (int pos = length; pos > 0;) {
auto code = lattice[pos].code;
if (code != config.unknown_code()) {
code += config.encoding_offset();
}
result.codes.push_back(code);
pos = lattice[pos].prev_position;
result.offsets.push_back(offsets[pos]);
}
}
if (add_bos) {
result.codes.push_back(config.start_code());
result.offsets.push_back(0);
}
if (!reverse) {
std::reverse(result.codes.begin(), result.codes.end());
std::reverse(result.offsets.begin(), result.offsets.end());
}
return result;
}
EncoderResult EncodeString(const std::string& string, const void* config_buffer,
bool add_bos, bool add_eos, bool reverse) {
// Get the config from the buffer.
const EncoderConfig* config = GetEncoderConfig(config_buffer);
if (config->version() != EncoderVersion::EncoderVersion_SENTENCE_PIECE) {
EncoderResult result;
result.type = EncoderResultType::WRONG_CONFIG;
return result;
}
std::string normalized_string;
std::vector<int> offsets;
std::tie(normalized_string, offsets) = NormalizeString(string, *config);
return EncodeNormalizedString(normalized_string, offsets, *config, add_bos,
add_eos, reverse);
}
} // namespace mediapipe::tflite_operations::sentencepiece

View File

@ -0,0 +1,46 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_
// Sentencepiece encoder optimized with memmapped model.
#include <string>
#include <tuple>
#include <vector>
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
namespace mediapipe::tflite_operations::sentencepiece {
enum class EncoderResultType { SUCCESS = 0, WRONG_CONFIG = 1 };
struct EncoderResult {
EncoderResultType type = EncoderResultType::SUCCESS;
std::vector<int> codes;
std::vector<int> offsets;
};
std::tuple<std::string, std::vector<int>> NormalizeString(
const std::string& in_string, const EncoderConfig& config);
// Encodes one string and returns ids and offsets. Takes the configuration as a
// type-erased buffer.
EncoderResult EncodeString(const std::string& string, const void* config_buffer,
bool add_bos, bool add_eos, bool reverse);
} // namespace mediapipe::tflite_operations::sentencepiece
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_

View File

@ -0,0 +1,171 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h"
#include <fstream>
#include "absl/flags/flag.h"
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h"
#include "src/sentencepiece.pb.h"
#include "src/sentencepiece_processor.h"
#include "tensorflow/core/platform/env.h"
namespace mediapipe::tflite_operations::sentencepiece {
namespace internal {
tensorflow::Status TFReadFileToString(const std::string& filepath,
std::string* data) {
return tensorflow::ReadFileToString(tensorflow::Env::Default(), filepath,
data);
}
absl::Status StdReadFileToString(const std::string& filepath,
std::string* data) {
std::ifstream infile(filepath);
if (!infile.is_open()) {
return absl::NotFoundError(
absl::StrFormat("Error when opening %s", filepath));
}
std::string contents((std::istreambuf_iterator<char>(infile)),
(std::istreambuf_iterator<char>()));
data->append(contents);
infile.close();
return absl::OkStatus();
}
} // namespace internal
namespace {
using ::mediapipe::file::JoinPath;
static char kConfigFilePath[] =
"/mediapipe/tasks/cc/text/custom_ops/"
"sentencepiece/testdata/sentencepiece.model";
TEST(OptimizedEncoder, NormalizeStringWhitestpaces) {
flatbuffers::FlatBufferBuilder builder(1024);
EncoderConfigBuilder ecb(builder);
ecb.add_remove_extra_whitespaces(true);
ecb.add_add_dummy_prefix(true);
ecb.add_escape_whitespaces(true);
FinishEncoderConfigBuffer(builder, ecb.Finish());
const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer());
{
const auto result = NormalizeString("x y", *config);
const auto res_string = std::get<0>(result);
const auto offsets = std::get<1>(result);
EXPECT_EQ(res_string, "\xe2\x96\x81x\xe2\x96\x81y");
EXPECT_THAT(offsets, ::testing::ElementsAre(0, 0, 0, 0, 1, 1, 1, 3));
}
{
const auto result = NormalizeString("\tx y\n", *config);
const auto res_string = std::get<0>(result);
const auto offsets = std::get<1>(result);
EXPECT_EQ(res_string, "\xe2\x96\x81x\xe2\x96\x81y");
EXPECT_THAT(offsets, ::testing::ElementsAre(0, 0, 0, 1, 2, 2, 2, 4));
}
}
TEST(OptimizedEncoder, NormalizeStringReplacement) {
flatbuffers::FlatBufferBuilder builder(1024);
const std::vector<std::string> norm_prefixes = {"A", "AA", "AAA", "AAAA"};
const char norm_replacements[] = "A1\0A2\0A3\0A4";
const auto trie_vector =
builder.CreateVector(BuildTrie(norm_prefixes, {0, 3, 6, 9}));
const auto norm_r = builder.CreateVector<int8_t>(
reinterpret_cast<const signed char*>(norm_replacements),
sizeof(norm_replacements));
TrieBuilder trie_builder(builder);
trie_builder.add_nodes(trie_vector);
const auto norm_p = trie_builder.Finish();
EncoderConfigBuilder ecb(builder);
ecb.add_remove_extra_whitespaces(false);
ecb.add_normalized_prefixes(norm_p);
ecb.add_normalized_replacements(norm_r);
FinishEncoderConfigBuffer(builder, ecb.Finish());
const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer());
{
const auto result = NormalizeString("ABAABAAABAAAA", *config);
const auto res_string = std::get<0>(result);
const auto offsets = std::get<1>(result);
EXPECT_EQ(res_string, "A1BA2BA3BA4");
EXPECT_THAT(offsets,
::testing::ElementsAre(0, 0, 1, 2, 2, 4, 5, 5, 8, 9, 9));
}
}
TEST(OptimizedEncoder, NormalizeStringWhitespacesRemove) {
flatbuffers::FlatBufferBuilder builder(1024);
const std::vector<std::string> norm_prefixes = {"A", "AA", "AAA", "AAAA",
"X"};
const char norm_replacements[] = "A1\0A2\0A3\0A4\0 ";
const auto trie_vector =
builder.CreateVector(BuildTrie(norm_prefixes, {0, 3, 6, 9, 12}));
const auto norm_r = builder.CreateVector<int8_t>(
reinterpret_cast<const signed char*>(norm_replacements),
sizeof(norm_replacements));
TrieBuilder trie_builder(builder);
trie_builder.add_nodes(trie_vector);
const auto norm_p = trie_builder.Finish();
EncoderConfigBuilder ecb(builder);
ecb.add_remove_extra_whitespaces(true);
ecb.add_normalized_prefixes(norm_p);
ecb.add_normalized_replacements(norm_r);
FinishEncoderConfigBuffer(builder, ecb.Finish());
const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer());
{
const auto result = NormalizeString("XXABAABAAABAAAA", *config);
const auto res_string = std::get<0>(result);
const auto offsets = std::get<1>(result);
EXPECT_EQ(res_string, " A1BA2BA3BA4");
EXPECT_THAT(offsets,
::testing::ElementsAre(0, 2, 2, 3, 4, 4, 6, 7, 7, 10, 11, 11));
}
}
TEST(OptimizedEncoder, ConfigConverter) {
std::string config;
auto status =
internal::TFReadFileToString(JoinPath("./", kConfigFilePath), &config);
ASSERT_TRUE(status.ok());
::sentencepiece::SentencePieceProcessor processor;
ASSERT_TRUE(processor.LoadFromSerializedProto(config).ok());
const auto converted_model = ConvertSentencepieceModel(config);
const std::string test_string("Hello world!\\xF0\\x9F\\x8D\\x95");
const auto encoded =
EncodeString(test_string, converted_model.data(), false, false, false);
ASSERT_EQ(encoded.codes.size(), encoded.offsets.size());
::sentencepiece::SentencePieceText reference_encoded;
ASSERT_TRUE(processor.Encode(test_string, &reference_encoded).ok());
EXPECT_EQ(encoded.codes.size(), reference_encoded.pieces_size());
for (int i = 0; i < encoded.codes.size(); ++i) {
EXPECT_EQ(encoded.codes[i], reference_encoded.pieces(i).id());
EXPECT_EQ(encoded.offsets[i], reference_encoded.pieces(i).begin());
}
}
} // namespace
} // namespace mediapipe::tflite_operations::sentencepiece

View File

@ -0,0 +1,38 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_
namespace mediapipe::tflite_operations::sentencepiece {
// The constant is copied from
// https://github.com/google/sentencepiece/blob/master/src/unigram_model.cc
constexpr float kUnkPenalty = 10.0;
// These constants are copied from
// https://github.com/google/sentencepiece/blob/master/src/sentencepiece_processor.cc
//
// Replaces white space with U+2581 (LOWER ONE EIGHT BLOCK).
constexpr char kSpaceSymbol[] = "\xe2\x96\x81";
// Encodes <unk> into U+2047 (DOUBLE QUESTION MARK),
// since this character can be useful both for user and
// developer. We can easily figure out that <unk> is emitted.
constexpr char kDefaultUnknownSymbol[] = " \xE2\x81\x87 ";
} // namespace mediapipe::tflite_operations::sentencepiece
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_