Internal MediaPipe Tasks change.
PiperOrigin-RevId: 524377097
This commit is contained in:
parent
92f45c98d8
commit
abd6574c6d
|
@ -18,6 +18,13 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
|
|||
|
||||
licenses(["notice"])
|
||||
|
||||
filegroup(
|
||||
name = "testdata",
|
||||
srcs = glob([
|
||||
"testdata/**",
|
||||
]),
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "config_fbs",
|
||||
srcs = ["config.fbs"],
|
||||
|
@ -80,3 +87,66 @@ cc_test(
|
|||
"//mediapipe/framework/port:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "sentencepiece_constants",
|
||||
hdrs = ["sentencepiece_constants.h"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "model_converter",
|
||||
srcs = [
|
||||
"model_converter.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"model_converter.h",
|
||||
],
|
||||
deps = [
|
||||
":config",
|
||||
":double_array_trie_builder",
|
||||
":encoder_config",
|
||||
":sentencepiece_constants",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_sentencepiece//src:sentencepiece_model_cc_proto",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "optimized_encoder",
|
||||
srcs = [
|
||||
"optimized_encoder.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"optimized_encoder.h",
|
||||
],
|
||||
deps = [
|
||||
":double_array_trie",
|
||||
":encoder_config",
|
||||
":utils",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "optimized_encoder_test",
|
||||
srcs = [
|
||||
"optimized_encoder_test.cc",
|
||||
],
|
||||
data = [
|
||||
":testdata",
|
||||
],
|
||||
deps = [
|
||||
":double_array_trie_builder",
|
||||
":encoder_config",
|
||||
":model_converter",
|
||||
":optimized_encoder",
|
||||
"//mediapipe/framework/deps:file_path",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"@com_google_absl//absl/flags:flag",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_sentencepiece//src:sentencepiece_cc_proto",
|
||||
"@com_google_sentencepiece//src:sentencepiece_processor",
|
||||
"@org_tensorflow//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -0,0 +1,131 @@
|
|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h"
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h"
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_constants.h"
|
||||
#include "src/sentencepiece_model.pb.h"
|
||||
|
||||
namespace mediapipe::tflite_operations::sentencepiece {
|
||||
|
||||
std::tuple<std::vector<uint32_t>, std::vector<int8_t>>
|
||||
DecodePrecompiledCharsmap(
|
||||
const ::sentencepiece::NormalizerSpec& normalizer_spec) {
|
||||
// This function "undoes" encoding done by
|
||||
// sentencepiece::normalizer::Normalizer::EncodePrecompiledCharsMap.
|
||||
const char* precompiled_map = normalizer_spec.precompiled_charsmap().data();
|
||||
const uint32_t trie_size =
|
||||
*reinterpret_cast<const uint32_t*>(precompiled_map);
|
||||
const uint32_t* trie_ptr =
|
||||
reinterpret_cast<const uint32_t*>(precompiled_map + sizeof(uint32_t));
|
||||
const int8_t* normalized_ptr = reinterpret_cast<const int8_t*>(
|
||||
precompiled_map + sizeof(uint32_t) + trie_size);
|
||||
const int normalized_size = normalizer_spec.precompiled_charsmap().length() -
|
||||
sizeof(uint32_t) - trie_size;
|
||||
return std::make_tuple(
|
||||
std::vector<uint32_t>(trie_ptr, trie_ptr + trie_size / sizeof(uint32_t)),
|
||||
std::vector<int8_t>(normalized_ptr, normalized_ptr + normalized_size));
|
||||
}
|
||||
|
||||
// Converts a serialized SentencePiece ModelProto (`model_config_str`) into a
// serialized EncoderConfig flatbuffer.
//
// `encoding_offset` is stored in the config unchanged. Returns the flatbuffer
// bytes, or absl::InvalidArgumentError when the proto cannot be parsed or
// contains a piece whose type is not NORMAL, USER_DEFINED, UNKNOWN or CONTROL.
absl::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer(
    const std::string& model_config_str, int encoding_offset) {
  ::sentencepiece::ModelProto model_config;
  if (!model_config.ParseFromString(model_config_str)) {
    return absl::InvalidArgumentError(
        "Invalid configuration, can't parse SentencePiece model config " +
        model_config.InitializationErrorString());
  }
  // Convert sentencepieces.
  std::vector<std::string> pieces;
  pieces.reserve(model_config.pieces_size());
  std::vector<float> scores;
  scores.reserve(model_config.pieces_size());
  std::vector<int> ids;
  ids.reserve(model_config.pieces_size());
  float min_score = 0.0;
  int index = 0;
  for (const auto& piece : model_config.pieces()) {
    switch (piece.type()) {
      case ::sentencepiece::ModelProto::SentencePiece::NORMAL:
      case ::sentencepiece::ModelProto::SentencePiece::USER_DEFINED:
        // Only these pieces go into the trie; the trie value is the piece's
        // position in the model.
        pieces.push_back(piece.piece());
        ids.push_back(index);
        if (piece.score() < min_score) {
          min_score = piece.score();
        }
        break;
      case ::sentencepiece::ModelProto::SentencePiece::UNKNOWN:
      case ::sentencepiece::ModelProto::SentencePiece::CONTROL:
        // Ignore unknown and control codes.
        break;
      default:
        return absl::InvalidArgumentError("Invalid SentencePiece piece type " +
                                          piece.piece());
    }
    // A score is recorded for every piece (including the ignored ones) so that
    // the scores vector can be indexed by the ids stored in the trie.
    scores.push_back(piece.score());
    ++index;
  }

  // All vectors and nested tables are created before the EncoderConfig table
  // builder is started.
  flatbuffers::FlatBufferBuilder builder(1024);
  const auto pieces_trie_vector = builder.CreateVector(BuildTrie(pieces, ids));
  const auto pieces_score_vector = builder.CreateVector(scores);
  TrieBuilder pieces_trie_builder(builder);
  pieces_trie_builder.add_nodes(pieces_trie_vector);
  const auto pieces_trie_fbs = pieces_trie_builder.Finish();

  // Converting normalization.
  const auto normalization =
      DecodePrecompiledCharsmap(model_config.normalizer_spec());
  // Bind by reference: `normalization` outlives both uses, so the decoded
  // vectors do not need to be copied out of the tuple.
  const auto& normalization_trie = std::get<0>(normalization);
  const auto& normalization_strings = std::get<1>(normalization);
  const auto normalization_trie_vector =
      builder.CreateVector(normalization_trie);
  TrieBuilder normalization_trie_builder(builder);
  normalization_trie_builder.add_nodes(normalization_trie_vector);
  const auto normalization_trie_fbs = normalization_trie_builder.Finish();
  const auto normalization_strings_fbs =
      builder.CreateVector(normalization_strings);

  EncoderConfigBuilder ecb(builder);
  ecb.add_version(EncoderVersion::EncoderVersion_SENTENCE_PIECE);
  ecb.add_start_code(model_config.trainer_spec().bos_id());
  ecb.add_end_code(model_config.trainer_spec().eos_id());
  ecb.add_unknown_code(model_config.trainer_spec().unk_id());
  // The unknown penalty is kUnkPenalty below the lowest NORMAL/USER_DEFINED
  // piece score (or below 0 when no such score is negative).
  ecb.add_unknown_penalty(min_score - kUnkPenalty);
  ecb.add_encoding_offset(encoding_offset);
  ecb.add_pieces(pieces_trie_fbs);
  ecb.add_pieces_scores(pieces_score_vector);
  ecb.add_remove_extra_whitespaces(
      model_config.normalizer_spec().remove_extra_whitespaces());
  ecb.add_add_dummy_prefix(model_config.normalizer_spec().add_dummy_prefix());
  ecb.add_escape_whitespaces(
      model_config.normalizer_spec().escape_whitespaces());
  ecb.add_normalized_prefixes(normalization_trie_fbs);
  ecb.add_normalized_replacements(normalization_strings_fbs);
  FinishEncoderConfigBuffer(builder, ecb.Finish());
  return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
                     builder.GetSize());
}
|
||||
|
||||
std::string ConvertSentencepieceModel(const std::string& model_string) {
|
||||
const auto result = ConvertSentencepieceModelToFlatBuffer(model_string);
|
||||
assert(result.status().ok());
|
||||
return result.value();
|
||||
}
|
||||
|
||||
} // namespace mediapipe::tflite_operations::sentencepiece
|
|
@ -0,0 +1,33 @@
|
|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_MODEL_CONVERTER_H_
|
||||
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_MODEL_CONVERTER_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
|
||||
namespace mediapipe::tflite_operations::sentencepiece {
|
||||
|
||||
// Converts Sentencepiece configuration to flatbuffer format.
|
||||
// encoding_offset is used by some encoders that combine different encodings.
|
||||
// Returns the serialized flatbuffer bytes on success, or
// absl::InvalidArgumentError when `model_config_str` cannot be parsed as a
// SentencePiece model or contains a piece of an unsupported type.
absl::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer(
|
||||
const std::string& model_config_str, int encoding_offset = 0);
|
||||
// Convenience wrapper around ConvertSentencepieceModelToFlatBuffer (with the
// default encoding_offset) that returns the flatbuffer bytes directly.
// Conversion errors are not reported through the return value, so callers
// must pass a valid serialized SentencePiece model.
std::string ConvertSentencepieceModel(const std::string& model_string);
|
||||
|
||||
} // namespace mediapipe::tflite_operations::sentencepiece
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_MODEL_CONVERTER_H_
|
|
@ -0,0 +1,236 @@
|
|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <tuple>
|
||||
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie.h"
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h"
|
||||
|
||||
namespace mediapipe::tflite_operations::sentencepiece {
|
||||
namespace {
|
||||
|
||||
// UTF-8 bytes of U+2581 (LOWER ONE EIGHT BLOCK); NormalizeString substitutes
// this for each whitespace byte when the config enables escape_whitespaces.
const char kSpaceSymbol[] = "\xe2\x96\x81";
|
||||
|
||||
// Rewrites `input` from left to right using the callback `pc` and keeps the
// per-byte `offsets` in sync with the rewritten string.
//
// At every position `pc(data, remaining_length)` is called and must return a
// tuple (consumed, replacement):
//   * consumed == 0: the current byte and its offset are copied unchanged;
//   * consumed  > 0: `consumed` input bytes are replaced by `replacement`, and
//     every replacement byte gets the offset of the first consumed byte.
//
// `offsets` must have one entry per byte of `input`. Returns the rewritten
// string together with one offset per output byte.
template <typename processing_callback>
std::tuple<std::string, std::vector<int>> process_string(
    const std::string& input, const std::vector<int>& offsets,
    const processing_callback& pc) {
  // The outputs are built directly inside the returned tuple so that they are
  // not copied when the function returns.
  std::tuple<std::string, std::vector<int>> processed;
  std::string& out_string = std::get<0>(processed);
  std::vector<int>& out_offsets = std::get<1>(processed);
  out_string.reserve(input.size());
  out_offsets.reserve(offsets.size());

  const int input_size = input.size();
  // Input bytes and offsets advance together, so a single index serves both.
  for (int pos = 0; pos < input_size;) {
    const auto step = pc(input.data() + pos, input_size - pos);
    const auto consumed = std::get<0>(step);
    if (consumed == 0) {
      // Nothing matched here: keep the current byte and move forward.
      out_string.push_back(input[pos]);
      out_offsets.push_back(offsets[pos]);
      ++pos;
      continue;
    }
    const auto& replacement = std::get<1>(step);
    out_string.append(replacement.data(), replacement.length());
    out_offsets.insert(out_offsets.end(), replacement.length(), offsets[pos]);
    pos += consumed;
  }
  return processed;
}
|
||||
|
||||
// Returns true when `c` is a space, tab, carriage return or line feed.
inline bool is_whitespace(char c) {
  return c == ' ' || c == '\t' || c == '\r' || c == '\n';
}
|
||||
|
||||
std::tuple<int, utils::string_view> remove_extra_whitespaces(const char* data,
|
||||
int len) {
|
||||
if (len == 0 || !is_whitespace(*data)) {
|
||||
return std::make_tuple(0, utils::string_view(nullptr, 0));
|
||||
}
|
||||
int num_consumed = 1;
|
||||
for (; num_consumed < len && is_whitespace(data[num_consumed]);
|
||||
++num_consumed) {
|
||||
}
|
||||
return num_consumed > 1
|
||||
? std::make_tuple(num_consumed, utils::string_view(" ", 1))
|
||||
: std::make_tuple(0, utils::string_view(nullptr, 0));
|
||||
}
|
||||
|
||||
std::tuple<int, utils::string_view> find_replacement(
|
||||
const char* data, int len, const DoubleArrayTrie& dat,
|
||||
const flatbuffers::Vector<int8_t>& replacements) {
|
||||
const auto max_match = dat.LongestPrefixMatch(utils::string_view(data, len));
|
||||
if (!max_match.empty()) {
|
||||
// Because flatbuffer byte is signed char which is not the same as char,
|
||||
// there is the reinterpret_cast here.
|
||||
const char* replaced_string_ptr =
|
||||
reinterpret_cast<const char*>(replacements.data() + max_match.id);
|
||||
return std::make_tuple(max_match.match_length,
|
||||
utils::string_view(replaced_string_ptr));
|
||||
}
|
||||
return std::make_tuple(0, utils::string_view(nullptr, 0));
|
||||
}
|
||||
} // namespace
|
||||
|
||||
std::tuple<std::string, std::vector<int>> NormalizeString(
|
||||
const std::string& in_string, const EncoderConfig& config) {
|
||||
std::vector<int> output_offsets;
|
||||
std::string result = in_string;
|
||||
output_offsets.reserve(in_string.length());
|
||||
for (int i = 0; i < in_string.length(); ++i) {
|
||||
output_offsets.push_back(i);
|
||||
}
|
||||
if (in_string.empty()) {
|
||||
return std::make_tuple(result, output_offsets);
|
||||
}
|
||||
if (config.add_dummy_prefix()) {
|
||||
result.insert(result.begin(), ' ');
|
||||
output_offsets.insert(output_offsets.begin(), 0);
|
||||
}
|
||||
// Greedely replace normalized_prefixes with normalized_replacements
|
||||
if (config.normalized_prefixes() != nullptr &&
|
||||
config.normalized_replacements() != nullptr) {
|
||||
const DoubleArrayTrie normalized_prefixes_matcher(
|
||||
config.normalized_prefixes()->nodes());
|
||||
const auto norm_replace = [&config, &normalized_prefixes_matcher](
|
||||
const char* data, int len) {
|
||||
return find_replacement(data, len, normalized_prefixes_matcher,
|
||||
*config.normalized_replacements());
|
||||
};
|
||||
std::tie(result, output_offsets) =
|
||||
process_string(result, output_offsets, norm_replace);
|
||||
}
|
||||
if (config.remove_extra_whitespaces()) {
|
||||
std::tie(result, output_offsets) =
|
||||
process_string(result, output_offsets, remove_extra_whitespaces);
|
||||
if (!result.empty() && is_whitespace(result.back())) {
|
||||
result.pop_back();
|
||||
output_offsets.pop_back();
|
||||
}
|
||||
}
|
||||
if (config.escape_whitespaces()) {
|
||||
const auto replace_whitespaces = [](const char* data, int len) {
|
||||
if (len > 0 && is_whitespace(*data)) {
|
||||
return std::make_tuple(1, utils::string_view(kSpaceSymbol));
|
||||
}
|
||||
return std::make_tuple(0, utils::string_view(nullptr, 0));
|
||||
};
|
||||
std::tie(result, output_offsets) =
|
||||
process_string(result, output_offsets, replace_whitespaces);
|
||||
}
|
||||
|
||||
return std::make_tuple(result, output_offsets);
|
||||
}
|
||||
|
||||
// Segments the already normalized `str` into pieces by searching for the
// highest-scoring path through a lattice with one node per byte boundary.
//
// From every reachable position a one-byte step with the unknown code (scored
// with the config's unknown penalty) is tried when the config has a
// non-negative unknown code, and every piece of the trie that is a prefix of
// the remaining text is tried with its score. Consecutive unknown steps are
// merged into one element.
//
// `offsets` holds one entry per byte of `str`; each emitted code gets the
// entry of the byte where its piece starts. The config's end/start codes are
// added when `add_eos`/`add_bos` are set. Codes other than the unknown code
// are shifted by the config's encoding offset. When `reverse` is true the
// result is returned last piece first. If the end of the string cannot be
// reached, no piece codes are emitted.
EncoderResult EncodeNormalizedString(const std::string& str,
                                     const std::vector<int>& offsets,
                                     const EncoderConfig& config, bool add_bos,
                                     bool add_eos, bool reverse) {
  // Best known way to arrive at a byte position.
  struct Node {
    float score = 0;
    int code = -1;
    int prev = -1;  // Position the step started from; negative = unreached.
  };

  const DoubleArrayTrie trie(config.pieces()->nodes());
  const flatbuffers::Vector<float>* scores = config.pieces_scores();
  const int unk = config.unknown_code();
  const float unk_penalty = config.unknown_penalty();
  const int text_length = str.length();
  std::vector<Node> nodes(text_length + 1);

  for (int start = 0; start < text_length; ++start) {
    // Position 0 is always a valid start; any other position must have been
    // reached by an earlier step.
    const bool reachable = start == 0 || nodes[start].prev >= 0;
    if (!reachable) {
      continue;
    }
    if (unk >= 0) {
      // Try consuming a single byte as unknown.
      Node& next = nodes[start + 1];
      const float candidate = nodes[start].score + unk_penalty;
      if (next.prev < 0 || next.score < candidate) {
        // If this position was itself reached by an unknown step, extend that
        // step instead of starting a new one.
        const int origin =
            nodes[start].code == unk ? nodes[start].prev : start;
        next.score = candidate;
        next.code = unk;
        next.prev = origin;
      }
    }
    auto relax = [&nodes, start, scores](const DoubleArrayTrie::Match& m) {
      const float candidate = nodes[start].score + (*scores)[m.id];
      Node& target = nodes[start + m.match_length];
      if (target.prev < 0 || target.score < candidate) {
        target.score = candidate;
        target.code = m.id;
        target.prev = start;
      }
    };
    trie.IteratePrefixMatches(
        utils::string_view(str.data() + start, text_length - start), relax);
  }

  // The result is assembled back to front and flipped at the end if needed.
  EncoderResult result;
  if (add_eos) {
    result.codes.push_back(config.end_code());
    result.offsets.push_back(text_length);
  }
  if (nodes[text_length].prev >= 0) {
    int pos = text_length;
    while (pos > 0) {
      const Node& node = nodes[pos];
      int code = node.code;
      if (code != unk) {
        code += config.encoding_offset();
      }
      result.codes.push_back(code);
      pos = node.prev;
      result.offsets.push_back(offsets[pos]);
    }
  }
  if (add_bos) {
    result.codes.push_back(config.start_code());
    result.offsets.push_back(0);
  }
  if (!reverse) {
    std::reverse(result.codes.begin(), result.codes.end());
    std::reverse(result.offsets.begin(), result.offsets.end());
  }
  return result;
}
|
||||
|
||||
// Normalizes and encodes `string` using the EncoderConfig flatbuffer stored in
// `config_buffer`.
//
// Returns a result of type WRONG_CONFIG (with no codes) when the config's
// version is not EncoderVersion_SENTENCE_PIECE; otherwise returns the result
// of EncodeNormalizedString for the normalized string.
EncoderResult EncodeString(const std::string& string, const void* config_buffer,
                           bool add_bos, bool add_eos, bool reverse) {
  // Get the config from the buffer.
  const EncoderConfig& config = *GetEncoderConfig(config_buffer);
  if (config.version() == EncoderVersion::EncoderVersion_SENTENCE_PIECE) {
    const auto normalized = NormalizeString(string, config);
    return EncodeNormalizedString(std::get<0>(normalized),
                                  std::get<1>(normalized), config, add_bos,
                                  add_eos, reverse);
  }
  EncoderResult wrong_config;
  wrong_config.type = EncoderResultType::WRONG_CONFIG;
  return wrong_config;
}
|
||||
|
||||
} // namespace mediapipe::tflite_operations::sentencepiece
|
|
@ -0,0 +1,46 @@
|
|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_
|
||||
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_
|
||||
|
||||
// Sentencepiece encoder optimized with memmapped model.
|
||||
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
|
||||
|
||||
namespace mediapipe::tflite_operations::sentencepiece {
|
||||
|
||||
// Outcome of an encoding call. EncodeString reports WRONG_CONFIG when the
// version stored in the config buffer is not EncoderVersion_SENTENCE_PIECE.
enum class EncoderResultType { SUCCESS = 0, WRONG_CONFIG = 1 };
|
||||
|
||||
// Result of encoding one string. `codes` and `offsets` are parallel vectors
// with one entry per emitted code.
struct EncoderResult {
|
||||
// SUCCESS unless the encoder rejected the configuration.
EncoderResultType type = EncoderResultType::SUCCESS;
|
||||
// Encoded ids, including the start/end codes when they were requested.
std::vector<int> codes;
|
||||
// For each entry of `codes`, the offset of the byte where its piece starts.
std::vector<int> offsets;
|
||||
};
|
||||
// Normalizes `in_string` as requested by `config` (dummy prefix, prefix
// replacement, extra whitespace removal, whitespace escaping). Returns the
// normalized string and, for each of its bytes, the index of the input byte it
// originates from.
std::tuple<std::string, std::vector<int>> NormalizeString(
|
||||
const std::string& in_string, const EncoderConfig& config);
|
||||
|
||||
// Encodes one string and returns ids and offsets. Takes the configuration as a
|
||||
// type-erased buffer.
// `add_bos` / `add_eos` add the config's start / end code to the result. When
// `reverse` is true the codes and offsets are returned last piece first.
// The result has type WRONG_CONFIG when the buffer's version is not
// EncoderVersion_SENTENCE_PIECE.
|
||||
EncoderResult EncodeString(const std::string& string, const void* config_buffer,
|
||||
bool add_bos, bool add_eos, bool reverse);
|
||||
|
||||
} // namespace mediapipe::tflite_operations::sentencepiece
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_
|
|
@ -0,0 +1,171 @@
|
|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h"
|
||||
|
||||
#include <fstream>
|
||||
|
||||
#include "absl/flags/flag.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "mediapipe/framework/deps/file_path.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h"
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h"
|
||||
#include "src/sentencepiece.pb.h"
|
||||
#include "src/sentencepiece_processor.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
|
||||
namespace mediapipe::tflite_operations::sentencepiece {
|
||||
|
||||
namespace internal {
|
||||
|
||||
// Reads the file at `filepath` into `*data` through TensorFlow's default Env
// and returns the resulting status.
tensorflow::Status TFReadFileToString(const std::string& filepath,
                                      std::string* data) {
  tensorflow::Env* const env = tensorflow::Env::Default();
  return tensorflow::ReadFileToString(env, filepath, data);
}
|
||||
|
||||
// Appends the contents of the file at `filepath` to `*data` using the
// standard library only.
//
// The file is opened in binary mode so that the bytes of a serialized model
// are read without text-mode translation. Returns absl::NotFoundError when the
// file cannot be opened and OkStatus otherwise.
absl::Status StdReadFileToString(const std::string& filepath,
                                 std::string* data) {
  std::ifstream infile(filepath, std::ios::in | std::ios::binary);
  if (!infile.is_open()) {
    return absl::NotFoundError(
        absl::StrFormat("Error when opening %s", filepath));
  }
  // Append straight from the stream; the stream is closed by its destructor.
  data->append(std::istreambuf_iterator<char>(infile),
               std::istreambuf_iterator<char>());
  return absl::OkStatus();
}
|
||||
} // namespace internal
|
||||
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::file::JoinPath;
|
||||
|
||||
// Location of the SentencePiece test model; the test joins it onto "./".
// Declared constexpr because the path is never modified (it already has
// internal linkage through the enclosing anonymous namespace).
constexpr char kConfigFilePath[] =
    "/mediapipe/tasks/cc/text/custom_ops/"
    "sentencepiece/testdata/sentencepiece.model";
|
||||
|
||||
// Checks dummy-prefix insertion, whitespace collapsing and whitespace escaping
// together, including the offsets reported for every output byte.
TEST(OptimizedEncoder, NormalizeStringWhitestpaces) {
  flatbuffers::FlatBufferBuilder builder(1024);
  EncoderConfigBuilder ecb(builder);
  ecb.add_remove_extra_whitespaces(true);
  ecb.add_add_dummy_prefix(true);
  ecb.add_escape_whitespaces(true);
  FinishEncoderConfigBuffer(builder, ecb.Finish());
  const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer());
  {
    // Two spaces between 'x' and 'y': the run is collapsed into one escaped
    // space, which is why 'y' is expected at input offset 3.
    const auto result = NormalizeString("x  y", *config);
    const auto& res_string = std::get<0>(result);
    const auto& offsets = std::get<1>(result);
    EXPECT_EQ(res_string, "\xe2\x96\x81x\xe2\x96\x81y");
    EXPECT_THAT(offsets, ::testing::ElementsAre(0, 0, 0, 0, 1, 1, 1, 3));
  }
  {
    // The leading tab merges with the dummy prefix, the double space is
    // collapsed and the trailing newline is dropped; 'y' sits at offset 4.
    const auto result = NormalizeString("\tx  y\n", *config);
    const auto& res_string = std::get<0>(result);
    const auto& offsets = std::get<1>(result);
    EXPECT_EQ(res_string, "\xe2\x96\x81x\xe2\x96\x81y");
    EXPECT_THAT(offsets, ::testing::ElementsAre(0, 0, 0, 1, 2, 2, 2, 4));
  }
}
|
||||
|
||||
// Checks greedy longest-prefix replacement: runs of one to four 'A's are
// replaced by "A1".."A4" and every replacement byte keeps the offset of the
// first byte it replaces.
TEST(OptimizedEncoder, NormalizeStringReplacement) {
  flatbuffers::FlatBufferBuilder builder(1024);
  // Each trie value is the byte index of the matching NUL-terminated entry in
  // `norm_replacements`.
  const std::vector<std::string> norm_prefixes = {"A", "AA", "AAA", "AAAA"};
  const char norm_replacements[] = "A1\0A2\0A3\0A4";
  const auto prefixes_vector =
      builder.CreateVector(BuildTrie(norm_prefixes, {0, 3, 6, 9}));
  const auto replacements_fbs = builder.CreateVector<int8_t>(
      reinterpret_cast<const signed char*>(norm_replacements),
      sizeof(norm_replacements));
  TrieBuilder trie_builder(builder);
  trie_builder.add_nodes(prefixes_vector);
  const auto prefixes_fbs = trie_builder.Finish();

  EncoderConfigBuilder ecb(builder);
  ecb.add_remove_extra_whitespaces(false);
  ecb.add_normalized_prefixes(prefixes_fbs);
  ecb.add_normalized_replacements(replacements_fbs);
  FinishEncoderConfigBuffer(builder, ecb.Finish());
  const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer());

  const auto normalized = NormalizeString("ABAABAAABAAAA", *config);
  EXPECT_EQ(std::get<0>(normalized), "A1BA2BA3BA4");
  EXPECT_THAT(std::get<1>(normalized),
              ::testing::ElementsAre(0, 0, 1, 2, 2, 4, 5, 5, 8, 9, 9));
}
|
||||
|
||||
// Checks that whitespace produced by prefix replacement ("X" -> " ") is
// collapsed afterwards when remove_extra_whitespaces is enabled.
TEST(OptimizedEncoder, NormalizeStringWhitespacesRemove) {
  flatbuffers::FlatBufferBuilder builder(1024);
  // Each trie value is the byte index of the matching NUL-terminated entry in
  // `norm_replacements`.
  const std::vector<std::string> norm_prefixes = {"A", "AA", "AAA", "AAAA",
                                                  "X"};
  const char norm_replacements[] = "A1\0A2\0A3\0A4\0 ";
  const auto prefixes_vector =
      builder.CreateVector(BuildTrie(norm_prefixes, {0, 3, 6, 9, 12}));
  const auto replacements_fbs = builder.CreateVector<int8_t>(
      reinterpret_cast<const signed char*>(norm_replacements),
      sizeof(norm_replacements));
  TrieBuilder trie_builder(builder);
  trie_builder.add_nodes(prefixes_vector);
  const auto prefixes_fbs = trie_builder.Finish();

  EncoderConfigBuilder ecb(builder);
  ecb.add_remove_extra_whitespaces(true);
  ecb.add_normalized_prefixes(prefixes_fbs);
  ecb.add_normalized_replacements(replacements_fbs);
  FinishEncoderConfigBuffer(builder, ecb.Finish());
  const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer());

  const auto normalized = NormalizeString("XXABAABAAABAAAA", *config);
  EXPECT_EQ(std::get<0>(normalized), " A1BA2BA3BA4");
  EXPECT_THAT(std::get<1>(normalized),
              ::testing::ElementsAre(0, 2, 2, 3, 4, 4, 6, 7, 7, 10, 11, 11));
}
|
||||
|
||||
// Converts the test SentencePiece model and checks that the optimized encoder
// produces the same ids and offsets as the reference SentencePieceProcessor.
TEST(OptimizedEncoder, ConfigConverter) {
  std::string config;
  auto status =
      internal::TFReadFileToString(JoinPath("./", kConfigFilePath), &config);
  ASSERT_TRUE(status.ok());

  ::sentencepiece::SentencePieceProcessor processor;
  ASSERT_TRUE(processor.LoadFromSerializedProto(config).ok());
  const auto converted_model = ConvertSentencepieceModel(config);
  // NOTE(review): the doubled backslashes make the tail of this string the
  // literal characters "\xF0\x9F\x8D\x95" rather than four raw bytes; confirm
  // that this is the intended input.
  const std::string test_string("Hello world!\\xF0\\x9F\\x8D\\x95");
  const auto encoded =
      EncodeString(test_string, converted_model.data(), false, false, false);
  ASSERT_EQ(encoded.codes.size(), encoded.offsets.size());

  ::sentencepiece::SentencePieceText reference_encoded;
  ASSERT_TRUE(processor.Encode(test_string, &reference_encoded).ok());
  // Fatal on mismatch: the loop below indexes the reference pieces by the
  // positions of `encoded.codes`, so the sizes must agree before it runs.
  ASSERT_EQ(encoded.codes.size(), reference_encoded.pieces_size());
  for (int i = 0; i < encoded.codes.size(); ++i) {
    EXPECT_EQ(encoded.codes[i], reference_encoded.pieces(i).id());
    EXPECT_EQ(encoded.offsets[i], reference_encoded.pieces(i).begin());
  }
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe::tflite_operations::sentencepiece
|
|
@ -0,0 +1,38 @@
|
|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_
|
||||
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_
|
||||
|
||||
namespace mediapipe::tflite_operations::sentencepiece {
|
||||
|
||||
// The constant is copied from
|
||||
// https://github.com/google/sentencepiece/blob/master/src/unigram_model.cc
|
||||
constexpr float kUnkPenalty = 10.0;
|
||||
|
||||
// These constants are copied from
|
||||
// https://github.com/google/sentencepiece/blob/master/src/sentencepiece_processor.cc
|
||||
//
|
||||
// Replaces white space with U+2581 (LOWER ONE EIGHT BLOCK).
|
||||
constexpr char kSpaceSymbol[] = "\xe2\x96\x81";
|
||||
|
||||
// Encodes <unk> into U+2047 (DOUBLE QUESTION MARK),
|
||||
// since this character can be useful both for user and
|
||||
// developer. We can easily figure out that <unk> is emitted.
|
||||
constexpr char kDefaultUnknownSymbol[] = " \xE2\x81\x87 ";
|
||||
|
||||
} // namespace mediapipe::tflite_operations::sentencepiece
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_
|
BIN
mediapipe/tasks/cc/text/custom_ops/sentencepiece/testdata/sentencepiece.model
vendored
Normal file
BIN
mediapipe/tasks/cc/text/custom_ops/sentencepiece/testdata/sentencepiece.model
vendored
Normal file
Binary file not shown.
Loading…
Reference in New Issue
Block a user