diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/BUILD b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/BUILD index 072b21f53..a1833ac54 100644 --- a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/BUILD +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/BUILD @@ -18,6 +18,13 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +filegroup( + name = "testdata", + srcs = glob([ + "testdata/**", + ]), +) + filegroup( name = "config_fbs", srcs = ["config.fbs"], @@ -80,3 +87,66 @@ cc_test( "//mediapipe/framework/port:gtest_main", ], ) + +cc_library( + name = "sentencepiece_constants", + hdrs = ["sentencepiece_constants.h"], +) + +cc_library( + name = "model_converter", + srcs = [ + "model_converter.cc", + ], + hdrs = [ + "model_converter.h", + ], + deps = [ + ":config", + ":double_array_trie_builder", + ":encoder_config", + ":sentencepiece_constants", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_sentencepiece//src:sentencepiece_model_cc_proto", + ], +) + +cc_library( + name = "optimized_encoder", + srcs = [ + "optimized_encoder.cc", + ], + hdrs = [ + "optimized_encoder.h", + ], + deps = [ + ":double_array_trie", + ":encoder_config", + ":utils", + ], +) + +cc_test( + name = "optimized_encoder_test", + srcs = [ + "optimized_encoder_test.cc", + ], + data = [ + ":testdata", + ], + deps = [ + ":double_array_trie_builder", + ":encoder_config", + ":model_converter", + ":optimized_encoder", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:gtest_main", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + "@com_google_sentencepiece//src:sentencepiece_cc_proto", + "@com_google_sentencepiece//src:sentencepiece_processor", + "@org_tensorflow//tensorflow/core:lib", + ], +) diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.cc 
b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.cc new file mode 100644 index 000000000..3a831f3d7 --- /dev/null +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.cc @@ -0,0 +1,131 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h" + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h" +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h" +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_constants.h" +#include "src/sentencepiece_model.pb.h" + +namespace mediapipe::tflite_operations::sentencepiece { + +std::tuple, std::vector> +DecodePrecompiledCharsmap( + const ::sentencepiece::NormalizerSpec& normalizer_spec) { + // This function "undoes" encoding done by + // sentencepiece::normalizer::Normalizer::EncodePrecompiledCharsMap. 
+ const char* precompiled_map = normalizer_spec.precompiled_charsmap().data(); + const uint32_t trie_size = + *reinterpret_cast(precompiled_map); + const uint32_t* trie_ptr = + reinterpret_cast(precompiled_map + sizeof(uint32_t)); + const int8_t* normalized_ptr = reinterpret_cast( + precompiled_map + sizeof(uint32_t) + trie_size); + const int normalized_size = normalizer_spec.precompiled_charsmap().length() - + sizeof(uint32_t) - trie_size; + return std::make_tuple( + std::vector(trie_ptr, trie_ptr + trie_size / sizeof(uint32_t)), + std::vector(normalized_ptr, normalized_ptr + normalized_size)); +} + +absl::StatusOr ConvertSentencepieceModelToFlatBuffer( + const std::string& model_config_str, int encoding_offset) { + ::sentencepiece::ModelProto model_config; + if (!model_config.ParseFromString(model_config_str)) { + return absl::InvalidArgumentError( + "Invalid configuration, can't parse SentencePiece model config " + + model_config.InitializationErrorString()); + } + // Convert sentencepieces. + std::vector pieces; + pieces.reserve(model_config.pieces_size()); + std::vector scores; + scores.reserve(model_config.pieces_size()); + std::vector ids; + ids.reserve(model_config.pieces_size()); + float min_score = 0.0; + int index = 0; + for (const auto& piece : model_config.pieces()) { + switch (piece.type()) { + case ::sentencepiece::ModelProto::SentencePiece::NORMAL: + case ::sentencepiece::ModelProto::SentencePiece::USER_DEFINED: + pieces.push_back(piece.piece()); + ids.push_back(index); + if (piece.score() < min_score) { + min_score = piece.score(); + } + break; + case ::sentencepiece::ModelProto::SentencePiece::UNKNOWN: + case ::sentencepiece::ModelProto::SentencePiece::CONTROL: + // Ignore unknown and control codes. 
+ break; + default: + return absl::InvalidArgumentError("Invalid SentencePiece piece type " + + piece.piece()); + } + scores.push_back(piece.score()); + ++index; + } + flatbuffers::FlatBufferBuilder builder(1024); + const auto pieces_trie_vector = builder.CreateVector(BuildTrie(pieces, ids)); + const auto pieces_score_vector = builder.CreateVector(scores); + TrieBuilder pieces_trie_builder(builder); + pieces_trie_builder.add_nodes(pieces_trie_vector); + const auto pieces_trie_fbs = pieces_trie_builder.Finish(); + + // Converting normalization. + const auto normalization = + DecodePrecompiledCharsmap(model_config.normalizer_spec()); + const auto normalization_trie = std::get<0>(normalization); + const auto normalization_strings = std::get<1>(normalization); + const auto normalization_trie_vector = + builder.CreateVector(normalization_trie); + TrieBuilder normalization_trie_builder(builder); + normalization_trie_builder.add_nodes(normalization_trie_vector); + const auto normalization_trie_fbs = normalization_trie_builder.Finish(); + const auto normalization_strings_fbs = + builder.CreateVector(normalization_strings); + + EncoderConfigBuilder ecb(builder); + ecb.add_version(EncoderVersion::EncoderVersion_SENTENCE_PIECE); + ecb.add_start_code(model_config.trainer_spec().bos_id()); + ecb.add_end_code(model_config.trainer_spec().eos_id()); + ecb.add_unknown_code(model_config.trainer_spec().unk_id()); + ecb.add_unknown_penalty(min_score - kUnkPenalty); + ecb.add_encoding_offset(encoding_offset); + ecb.add_pieces(pieces_trie_fbs); + ecb.add_pieces_scores(pieces_score_vector); + ecb.add_remove_extra_whitespaces( + model_config.normalizer_spec().remove_extra_whitespaces()); + ecb.add_add_dummy_prefix(model_config.normalizer_spec().add_dummy_prefix()); + ecb.add_escape_whitespaces( + model_config.normalizer_spec().escape_whitespaces()); + ecb.add_normalized_prefixes(normalization_trie_fbs); + ecb.add_normalized_replacements(normalization_strings_fbs); + 
FinishEncoderConfigBuffer(builder, ecb.Finish()); + return std::string(reinterpret_cast(builder.GetBufferPointer()), + builder.GetSize()); +} + +std::string ConvertSentencepieceModel(const std::string& model_string) { + const auto result = ConvertSentencepieceModelToFlatBuffer(model_string); + assert(result.status().ok()); + return result.value(); +} + +} // namespace mediapipe::tflite_operations::sentencepiece diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h new file mode 100644 index 000000000..828db16da --- /dev/null +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h @@ -0,0 +1,33 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_MODEL_CONVERTER_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_MODEL_CONVERTER_H_ + +#include + +#include "absl/status/statusor.h" + +namespace mediapipe::tflite_operations::sentencepiece { + +// Converts Sentencepiece configuration to flatbuffer format. +// encoding_offset is used by some encoders that combine different encodings. 
+absl::StatusOr ConvertSentencepieceModelToFlatBuffer( + const std::string& model_config_str, int encoding_offset = 0); +std::string ConvertSentencepieceModel(const std::string& model_string); + +} // namespace mediapipe::tflite_operations::sentencepiece + +#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_MODEL_CONVERTER_H_ diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.cc b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.cc new file mode 100644 index 000000000..365b1a5ad --- /dev/null +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.cc @@ -0,0 +1,236 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
// Rewrites `input` from left to right using the callback `pc` while keeping
// `offsets` (one entry per byte of `input`) aligned with the rewritten text.
//
// At every position `pc(data, remaining_length)` is called and must return a
// tuple (consumed, replacement):
//   * consumed == 0: no rule applies; the current byte and its offset are
//     copied through unchanged and processing moves on by one byte.
//   * consumed  > 0: `consumed` input bytes are replaced by `replacement`;
//     every byte of the replacement receives the offset of the first
//     consumed byte.
//
// Returns the rewritten string together with its per-byte offsets.
template <typename processing_callback>
std::tuple<std::string, std::vector<int>> process_string(
    const std::string& input, const std::vector<int>& offsets,
    const processing_callback& pc) {
  std::string result_string;
  result_string.reserve(input.size());
  std::vector<int> result_offsets;
  result_offsets.reserve(offsets.size());
  // The input cursor and the offsets cursor always advance together, so one
  // index serves both.
  const int input_size = static_cast<int>(input.size());
  for (int pos = 0; pos < input_size;) {
    const auto [consumed, replacement] =
        pc(input.data() + pos, input_size - pos);
    if (consumed == 0) {
      // Skip the current byte and move forward.
      result_string.push_back(input[pos]);
      result_offsets.push_back(offsets[pos]);
      ++pos;
      continue;
    }
    const size_t replacement_length = replacement.length();
    if (replacement_length > 0) {
      result_string.append(replacement.data(), replacement_length);
      result_offsets.insert(result_offsets.end(), replacement_length,
                            offsets[pos]);
    }
    pos += consumed;
  }
  return std::make_tuple(result_string, result_offsets);
}
std::make_tuple(num_consumed, utils::string_view(" ", 1)) + : std::make_tuple(0, utils::string_view(nullptr, 0)); +} + +std::tuple find_replacement( + const char* data, int len, const DoubleArrayTrie& dat, + const flatbuffers::Vector& replacements) { + const auto max_match = dat.LongestPrefixMatch(utils::string_view(data, len)); + if (!max_match.empty()) { + // Because flatbuffer byte is signed char which is not the same as char, + // there is the reinterpret_cast here. + const char* replaced_string_ptr = + reinterpret_cast(replacements.data() + max_match.id); + return std::make_tuple(max_match.match_length, + utils::string_view(replaced_string_ptr)); + } + return std::make_tuple(0, utils::string_view(nullptr, 0)); +} +} // namespace + +std::tuple> NormalizeString( + const std::string& in_string, const EncoderConfig& config) { + std::vector output_offsets; + std::string result = in_string; + output_offsets.reserve(in_string.length()); + for (int i = 0; i < in_string.length(); ++i) { + output_offsets.push_back(i); + } + if (in_string.empty()) { + return std::make_tuple(result, output_offsets); + } + if (config.add_dummy_prefix()) { + result.insert(result.begin(), ' '); + output_offsets.insert(output_offsets.begin(), 0); + } + // Greedely replace normalized_prefixes with normalized_replacements + if (config.normalized_prefixes() != nullptr && + config.normalized_replacements() != nullptr) { + const DoubleArrayTrie normalized_prefixes_matcher( + config.normalized_prefixes()->nodes()); + const auto norm_replace = [&config, &normalized_prefixes_matcher]( + const char* data, int len) { + return find_replacement(data, len, normalized_prefixes_matcher, + *config.normalized_replacements()); + }; + std::tie(result, output_offsets) = + process_string(result, output_offsets, norm_replace); + } + if (config.remove_extra_whitespaces()) { + std::tie(result, output_offsets) = + process_string(result, output_offsets, remove_extra_whitespaces); + if (!result.empty() && 
is_whitespace(result.back())) { + result.pop_back(); + output_offsets.pop_back(); + } + } + if (config.escape_whitespaces()) { + const auto replace_whitespaces = [](const char* data, int len) { + if (len > 0 && is_whitespace(*data)) { + return std::make_tuple(1, utils::string_view(kSpaceSymbol)); + } + return std::make_tuple(0, utils::string_view(nullptr, 0)); + }; + std::tie(result, output_offsets) = + process_string(result, output_offsets, replace_whitespaces); + } + + return std::make_tuple(result, output_offsets); +} + +EncoderResult EncodeNormalizedString(const std::string& str, + const std::vector& offsets, + const EncoderConfig& config, bool add_bos, + bool add_eos, bool reverse) { + const DoubleArrayTrie piece_matcher(config.pieces()->nodes()); + const flatbuffers::Vector* piece_scores = config.pieces_scores(); + const int unknown_code = config.unknown_code(); + const float unknown_penalty = config.unknown_penalty(); + struct LatticeElement { + float score = 0; + int code = -1; + int prev_position = -1; + LatticeElement(float score_, int code_, int prev_position_) + : score(score_), code(code_), prev_position(prev_position_) {} + LatticeElement() {} + }; + const int length = str.length(); + std::vector lattice(length + 1); + for (int i = 0; i < length; ++i) { + if (i > 0 && lattice[i].prev_position < 0) { + // This state is unreachable. + continue; + } + if (unknown_code >= 0) { + // Put unknown code. + const float penalized_score = lattice[i].score + unknown_penalty; + const int pos = i + 1; + LatticeElement& current_element = lattice[pos]; + if (current_element.prev_position < 0 || + current_element.score < penalized_score) { + current_element = LatticeElement( + penalized_score, unknown_code, + // If the current state is already reached by unknown code, merge + // states. + lattice[i].code == unknown_code ? 
lattice[i].prev_position : i); + } + } + auto lattice_update = [&lattice, i, + piece_scores](const DoubleArrayTrie::Match& m) { + LatticeElement& target_element = lattice[i + m.match_length]; + const float score = lattice[i].score + (*piece_scores)[m.id]; + if (target_element.prev_position < 0 || target_element.score < score) { + target_element = LatticeElement(score, m.id, i); + } + }; + piece_matcher.IteratePrefixMatches( + utils::string_view(str.data() + i, length - i), lattice_update); + } + + EncoderResult result; + if (add_eos) { + result.codes.push_back(config.end_code()); + result.offsets.push_back(length); + } + if (lattice[length].prev_position >= 0) { + for (int pos = length; pos > 0;) { + auto code = lattice[pos].code; + if (code != config.unknown_code()) { + code += config.encoding_offset(); + } + result.codes.push_back(code); + pos = lattice[pos].prev_position; + result.offsets.push_back(offsets[pos]); + } + } + if (add_bos) { + result.codes.push_back(config.start_code()); + result.offsets.push_back(0); + } + if (!reverse) { + std::reverse(result.codes.begin(), result.codes.end()); + std::reverse(result.offsets.begin(), result.offsets.end()); + } + return result; +} + +EncoderResult EncodeString(const std::string& string, const void* config_buffer, + bool add_bos, bool add_eos, bool reverse) { + // Get the config from the buffer. 
+ const EncoderConfig* config = GetEncoderConfig(config_buffer); + if (config->version() != EncoderVersion::EncoderVersion_SENTENCE_PIECE) { + EncoderResult result; + result.type = EncoderResultType::WRONG_CONFIG; + return result; + } + std::string normalized_string; + std::vector offsets; + std::tie(normalized_string, offsets) = NormalizeString(string, *config); + return EncodeNormalizedString(normalized_string, offsets, *config, add_bos, + add_eos, reverse); +} + +} // namespace mediapipe::tflite_operations::sentencepiece diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h new file mode 100644 index 000000000..849a47849 --- /dev/null +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h @@ -0,0 +1,46 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_ + +// Sentencepiece encoder optimized with memmapped model. 
// Outcome of an encoding request.
//   SUCCESS       the request was processed.
//   WRONG_CONFIG  the supplied configuration cannot be used; EncodeString
//                 reports this when the config version is not
//                 EncoderVersion_SENTENCE_PIECE.
enum class EncoderResultType { SUCCESS = 0, WRONG_CONFIG = 1 };

// Result of encoding one string.
struct EncoderResult {
  // Whether encoding succeeded; defaults to SUCCESS.
  EncoderResultType type = EncoderResultType::SUCCESS;
  // Piece ids produced for the string, including the start/end codes when
  // they were requested.
  std::vector<int> codes;
  // One offset per entry of `codes`.
  // NOTE(review): presumably a byte position in the original (pre-normalization)
  // input; confirm for the end-code entry.
  std::vector<int> offsets;
};
+==============================================================================*/ + +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h" + +#include + +#include "absl/flags/flag.h" +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h" +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h" +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h" +#include "src/sentencepiece.pb.h" +#include "src/sentencepiece_processor.h" +#include "tensorflow/core/platform/env.h" + +namespace mediapipe::tflite_operations::sentencepiece { + +namespace internal { + +tensorflow::Status TFReadFileToString(const std::string& filepath, + std::string* data) { + return tensorflow::ReadFileToString(tensorflow::Env::Default(), filepath, + data); +} + +absl::Status StdReadFileToString(const std::string& filepath, + std::string* data) { + std::ifstream infile(filepath); + if (!infile.is_open()) { + return absl::NotFoundError( + absl::StrFormat("Error when opening %s", filepath)); + } + std::string contents((std::istreambuf_iterator(infile)), + (std::istreambuf_iterator())); + data->append(contents); + infile.close(); + return absl::OkStatus(); +} +} // namespace internal + +namespace { + +using ::mediapipe::file::JoinPath; + +static char kConfigFilePath[] = + "/mediapipe/tasks/cc/text/custom_ops/" + "sentencepiece/testdata/sentencepiece.model"; + +TEST(OptimizedEncoder, NormalizeStringWhitestpaces) { + flatbuffers::FlatBufferBuilder builder(1024); + EncoderConfigBuilder ecb(builder); + ecb.add_remove_extra_whitespaces(true); + ecb.add_add_dummy_prefix(true); + ecb.add_escape_whitespaces(true); + FinishEncoderConfigBuffer(builder, ecb.Finish()); + const EncoderConfig* 
config = GetEncoderConfig(builder.GetBufferPointer()); + { + const auto result = NormalizeString("x y", *config); + const auto res_string = std::get<0>(result); + const auto offsets = std::get<1>(result); + EXPECT_EQ(res_string, "\xe2\x96\x81x\xe2\x96\x81y"); + EXPECT_THAT(offsets, ::testing::ElementsAre(0, 0, 0, 0, 1, 1, 1, 3)); + } + { + const auto result = NormalizeString("\tx y\n", *config); + const auto res_string = std::get<0>(result); + const auto offsets = std::get<1>(result); + EXPECT_EQ(res_string, "\xe2\x96\x81x\xe2\x96\x81y"); + EXPECT_THAT(offsets, ::testing::ElementsAre(0, 0, 0, 1, 2, 2, 2, 4)); + } +} + +TEST(OptimizedEncoder, NormalizeStringReplacement) { + flatbuffers::FlatBufferBuilder builder(1024); + const std::vector norm_prefixes = {"A", "AA", "AAA", "AAAA"}; + const char norm_replacements[] = "A1\0A2\0A3\0A4"; + const auto trie_vector = + builder.CreateVector(BuildTrie(norm_prefixes, {0, 3, 6, 9})); + const auto norm_r = builder.CreateVector( + reinterpret_cast(norm_replacements), + sizeof(norm_replacements)); + TrieBuilder trie_builder(builder); + trie_builder.add_nodes(trie_vector); + const auto norm_p = trie_builder.Finish(); + EncoderConfigBuilder ecb(builder); + ecb.add_remove_extra_whitespaces(false); + ecb.add_normalized_prefixes(norm_p); + ecb.add_normalized_replacements(norm_r); + FinishEncoderConfigBuffer(builder, ecb.Finish()); + const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer()); + { + const auto result = NormalizeString("ABAABAAABAAAA", *config); + const auto res_string = std::get<0>(result); + const auto offsets = std::get<1>(result); + EXPECT_EQ(res_string, "A1BA2BA3BA4"); + EXPECT_THAT(offsets, + ::testing::ElementsAre(0, 0, 1, 2, 2, 4, 5, 5, 8, 9, 9)); + } +} + +TEST(OptimizedEncoder, NormalizeStringWhitespacesRemove) { + flatbuffers::FlatBufferBuilder builder(1024); + const std::vector norm_prefixes = {"A", "AA", "AAA", "AAAA", + "X"}; + const char norm_replacements[] = "A1\0A2\0A3\0A4\0 "; + const 
auto trie_vector = + builder.CreateVector(BuildTrie(norm_prefixes, {0, 3, 6, 9, 12})); + const auto norm_r = builder.CreateVector( + reinterpret_cast(norm_replacements), + sizeof(norm_replacements)); + TrieBuilder trie_builder(builder); + trie_builder.add_nodes(trie_vector); + const auto norm_p = trie_builder.Finish(); + EncoderConfigBuilder ecb(builder); + ecb.add_remove_extra_whitespaces(true); + ecb.add_normalized_prefixes(norm_p); + ecb.add_normalized_replacements(norm_r); + FinishEncoderConfigBuffer(builder, ecb.Finish()); + const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer()); + { + const auto result = NormalizeString("XXABAABAAABAAAA", *config); + const auto res_string = std::get<0>(result); + const auto offsets = std::get<1>(result); + EXPECT_EQ(res_string, " A1BA2BA3BA4"); + EXPECT_THAT(offsets, + ::testing::ElementsAre(0, 2, 2, 3, 4, 4, 6, 7, 7, 10, 11, 11)); + } +} + +TEST(OptimizedEncoder, ConfigConverter) { + std::string config; + auto status = + internal::TFReadFileToString(JoinPath("./", kConfigFilePath), &config); + ASSERT_TRUE(status.ok()); + + ::sentencepiece::SentencePieceProcessor processor; + ASSERT_TRUE(processor.LoadFromSerializedProto(config).ok()); + const auto converted_model = ConvertSentencepieceModel(config); + const std::string test_string("Hello world!\\xF0\\x9F\\x8D\\x95"); + const auto encoded = + EncodeString(test_string, converted_model.data(), false, false, false); + ASSERT_EQ(encoded.codes.size(), encoded.offsets.size()); + + ::sentencepiece::SentencePieceText reference_encoded; + ASSERT_TRUE(processor.Encode(test_string, &reference_encoded).ok()); + EXPECT_EQ(encoded.codes.size(), reference_encoded.pieces_size()); + for (int i = 0; i < encoded.codes.size(); ++i) { + EXPECT_EQ(encoded.codes[i], reference_encoded.pieces(i).id()); + EXPECT_EQ(encoded.offsets[i], reference_encoded.pieces(i).begin()); + } +} + +} // namespace +} // namespace mediapipe::tflite_operations::sentencepiece diff --git 
// Amount subtracted from the lowest piece score to obtain the penalty applied
// to the unknown code (see ConvertSentencepieceModelToFlatBuffer).
// The constant is copied from
// https://github.com/google/sentencepiece/blob/master/src/unigram_model.cc
constexpr float kUnkPenalty = 10.0;

// These constants are copied from
// https://github.com/google/sentencepiece/blob/master/src/sentencepiece_processor.cc
//
// Replaces white space with U+2581 (LOWER ONE EIGHT BLOCK).
constexpr char kSpaceSymbol[] = "\xe2\x96\x81";

// Encodes <unk> into U+2047 (DOUBLE QUESTION MARK),
// since this character can be useful both for user and
// developer. We can easily figure out that <unk> is emitted.
constexpr char kDefaultUnknownSymbol[] = " \xE2\x81\x87 ";