From 3a5bc5815be8736cb9081226749344d8ddf19542 Mon Sep 17 00:00:00 2001 From: Taku Kudo Date: Wed, 16 Jun 2021 14:51:52 +0900 Subject: [PATCH] Revert "sync from internal" This reverts commit 05db0894d8ea44b203c3501306061cde9e42c48e. --- VERSION.txt | 2 +- python/VERSION.txt | 2 +- src/bpe_model.cc | 3 +- src/builder.cc | 4 +- src/builtin_pb/sentencepiece_model.pb.cc | 237 ++++++++--------- src/builtin_pb/sentencepiece_model.pb.h | 205 ++++++--------- src/model_interface.cc | 29 +-- src/model_interface.h | 40 +-- src/model_interface_test.cc | 44 ---- src/normalizer.cc | 7 +- src/normalizer.h | 1 + src/normalizer_test.cc | 3 +- src/sentencepiece_model.proto | 4 - src/sentencepiece_processor.cc | 46 +--- src/sentencepiece_processor.h | 9 - src/sentencepiece_processor_test.cc | 17 +- src/sentencepiece_trainer.cc | 10 +- src/spec_parser.h | 2 - src/spm_decode_main.cc | 1 - src/spm_encode_main.cc | 12 +- src/spm_train_main.cc | 10 +- src/trainer_interface.cc | 53 ++-- src/trainer_interface_test.cc | 26 -- src/unigram_model.cc | 317 ++++------------------- src/unigram_model.h | 26 +- src/unigram_model_test.cc | 275 +------------------- src/unigram_model_trainer.cc | 12 +- src/util.cc | 5 +- third_party/absl/flags/flag.cc | 1 - 29 files changed, 324 insertions(+), 1079 deletions(-) diff --git a/VERSION.txt b/VERSION.txt index c65d728..9c178d3 100644 --- a/VERSION.txt +++ b/VERSION.txt @@ -1 +1 @@ -0.1.96 +0.1.95 diff --git a/python/VERSION.txt b/python/VERSION.txt index c65d728..9c178d3 100644 --- a/python/VERSION.txt +++ b/python/VERSION.txt @@ -1 +1 @@ -0.1.96 +0.1.95 diff --git a/src/bpe_model.cc b/src/bpe_model.cc index 22cd115..5d77baa 100644 --- a/src/bpe_model.cc +++ b/src/bpe_model.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License.! +#include "bpe_model.h" + #include #include #include @@ -19,7 +21,6 @@ #include #include -#include "bpe_model.h" #include "freelist.h" #include "third_party/absl/container/flat_hash_map.h" #include "util.h" diff --git a/src/builder.cc b/src/builder.cc index 794ce5f..88346dd 100644 --- a/src/builder.cc +++ b/src/builder.cc @@ -12,11 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License.! 
+#include "builder.h" + #include #include #include -#include "builder.h" #include "filesystem.h" #include "third_party/absl/strings/str_join.h" #include "third_party/absl/strings/str_replace.h" @@ -367,7 +368,6 @@ util::Status Builder::BuildNmtNFKCMap(CharsMap *chars_map) { nfkc_map[{0xFEFF}] = {0x20}; // ZERO WIDTH NO-BREAK nfkc_map[{0xFFFD}] = {0x20}; // REPLACEMENT CHARACTER nfkc_map[{0x200C}] = {0x20}; // ZERO WIDTH NON-JOINER - nfkc_map[{0x200D}] = {0x20}; // ZERO WIDTH JOINER // Ascii Control characters nfkc_map[{0x0001}] = {}; diff --git a/src/builtin_pb/sentencepiece_model.pb.cc b/src/builtin_pb/sentencepiece_model.pb.cc index e913731..4863136 100644 --- a/src/builtin_pb/sentencepiece_model.pb.cc +++ b/src/builtin_pb/sentencepiece_model.pb.cc @@ -285,22 +285,22 @@ class TrainerSpec::_Internal { (*has_bits)[0] |= 1u; } static void set_has_model_type(HasBits* has_bits) { - (*has_bits)[0] |= 524288u; + (*has_bits)[0] |= 262144u; } static void set_has_vocab_size(HasBits* has_bits) { - (*has_bits)[0] |= 1048576u; + (*has_bits)[0] |= 524288u; } static void set_has_self_test_sample_size(HasBits* has_bits) { (*has_bits)[0] |= 256u; } static void set_has_character_coverage(HasBits* has_bits) { - (*has_bits)[0] |= 2097152u; + (*has_bits)[0] |= 1048576u; } static void set_has_input_sentence_size(HasBits* has_bits) { (*has_bits)[0] |= 1024u; } static void set_has_shuffle_input_sentence(HasBits* has_bits) { - (*has_bits)[0] |= 268435456u; + (*has_bits)[0] |= 134217728u; } static void set_has_mining_sentence_size(HasBits* has_bits) { (*has_bits)[0] |= 512u; @@ -309,67 +309,64 @@ class TrainerSpec::_Internal { (*has_bits)[0] |= 2048u; } static void set_has_seed_sentencepiece_size(HasBits* has_bits) { - (*has_bits)[0] |= 4194304u; + (*has_bits)[0] |= 2097152u; } static void set_has_shrinking_factor(HasBits* has_bits) { - (*has_bits)[0] |= 8388608u; + (*has_bits)[0] |= 4194304u; } static void set_has_max_sentence_length(HasBits* has_bits) { - (*has_bits)[0] |= 67108864u; - } - static void set_has_num_threads(HasBits* has_bits) { - (*has_bits)[0] |= 16777216u; - } - static void set_has_num_sub_iterations(HasBits* has_bits) { (*has_bits)[0] |= 33554432u; } + static void set_has_num_threads(HasBits* has_bits) { + (*has_bits)[0] |= 8388608u; + } + static void set_has_num_sub_iterations(HasBits* has_bits) { + (*has_bits)[0] |= 16777216u; + } static void set_has_max_sentencepiece_length(HasBits* has_bits) { - (*has_bits)[0] |= 134217728u; + (*has_bits)[0] |= 67108864u; } static void set_has_split_by_unicode_script(HasBits* has_bits) { - (*has_bits)[0] |= 536870912u; + (*has_bits)[0] |= 268435456u; } static void set_has_split_by_number(HasBits* has_bits) { - (*has_bits)[0] |= 1073741824u; + (*has_bits)[0] |= 536870912u; } static void set_has_split_by_whitespace(HasBits* has_bits) { - (*has_bits)[0] |= 2147483648u; + (*has_bits)[0] |= 1073741824u; } static void set_has_treat_whitespace_as_suffix(HasBits* has_bits) { (*has_bits)[0] |= 4096u; } - static void set_has_allow_whitespace_only_pieces(HasBits* has_bits) { - (*has_bits)[0] |= 8192u; - } static void set_has_split_digits(HasBits* has_bits) { - (*has_bits)[0] |= 16384u; + (*has_bits)[0] |= 8192u; } static void set_has_required_chars(HasBits* has_bits) { (*has_bits)[0] |= 4u; } static void set_has_byte_fallback(HasBits* has_bits) { - (*has_bits)[0] |= 32768u; + (*has_bits)[0] |= 16384u; } static void set_has_vocabulary_output_piece_score(HasBits* has_bits) { - (*has_bits)[1] |= 1u; + (*has_bits)[0] |= 2147483648u; } static void 
set_has_hard_vocab_limit(HasBits* has_bits) { - (*has_bits)[1] |= 2u; + (*has_bits)[1] |= 1u; } static void set_has_use_all_vocab(HasBits* has_bits) { - (*has_bits)[0] |= 131072u; + (*has_bits)[0] |= 32768u; } static void set_has_unk_id(HasBits* has_bits) { (*has_bits)[0] |= 65536u; } static void set_has_bos_id(HasBits* has_bits) { - (*has_bits)[1] |= 4u; + (*has_bits)[1] |= 2u; } static void set_has_eos_id(HasBits* has_bits) { - (*has_bits)[1] |= 8u; + (*has_bits)[1] |= 4u; } static void set_has_pad_id(HasBits* has_bits) { - (*has_bits)[1] |= 16u; + (*has_bits)[1] |= 8u; } static void set_has_unk_piece(HasBits* has_bits) { (*has_bits)[0] |= 16u; @@ -387,7 +384,7 @@ class TrainerSpec::_Internal { (*has_bits)[0] |= 8u; } static void set_has_train_extremely_large_corpus(HasBits* has_bits) { - (*has_bits)[0] |= 262144u; + (*has_bits)[0] |= 131072u; } }; @@ -569,8 +566,8 @@ void TrainerSpec::Clear() { } if (cached_has_bits & 0x0000ff00u) { ::memset(&self_test_sample_size_, 0, static_cast( - reinterpret_cast(&byte_fallback_) - - reinterpret_cast(&self_test_sample_size_)) + sizeof(byte_fallback_)); + reinterpret_cast(&use_all_vocab_) - + reinterpret_cast(&self_test_sample_size_)) + sizeof(use_all_vocab_)); } if (cached_has_bits & 0x00ff0000u) { ::memset(&unk_id_, 0, static_cast( @@ -581,9 +578,9 @@ void TrainerSpec::Clear() { character_coverage_ = 0.9995f; seed_sentencepiece_size_ = 1000000; shrinking_factor_ = 0.75f; + num_threads_ = 16; } if (cached_has_bits & 0xff000000u) { - num_threads_ = 16; num_sub_iterations_ = 2; max_sentence_length_ = 4192; max_sentencepiece_length_ = 16; @@ -591,10 +588,10 @@ void TrainerSpec::Clear() { split_by_unicode_script_ = true; split_by_number_ = true; split_by_whitespace_ = true; + vocabulary_output_piece_score_ = true; } cached_has_bits = _has_bits_[1]; - if (cached_has_bits & 0x0000001fu) { - vocabulary_output_piece_score_ = true; + if (cached_has_bits & 0x0000000fu) { hard_vocab_limit_ = true; bos_id_ = 1; eos_id_ = 2; @@ -809,14 +806,6 @@ const char* TrainerSpec::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID CHK_(ptr); } else goto handle_unusual; continue; - // optional bool allow_whitespace_only_pieces = 26 [default = false]; - case 26: - if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 208)) { - _Internal::set_has_allow_whitespace_only_pieces(&_has_bits_); - allow_whitespace_only_pieces_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); - CHK_(ptr); - } else goto handle_unusual; - continue; // repeated string control_symbols = 30; case 30: if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 242)) { @@ -1011,14 +1000,14 @@ failure: } // optional .sentencepiece.TrainerSpec.ModelType model_type = 3 [default = UNIGRAM]; - if (cached_has_bits & 0x00080000u) { + if (cached_has_bits & 0x00040000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteEnumToArray( 3, this->_internal_model_type(), target); } // optional int32 vocab_size = 4 [default = 8000]; - if (cached_has_bits & 0x00100000u) { + if (cached_has_bits & 0x00080000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(4, this->_internal_vocab_size(), target); } @@ -1042,7 +1031,7 @@ failure: } // optional float character_coverage = 10 [default = 0.9995]; - if (cached_has_bits & 0x00200000u) { + if (cached_has_bits & 0x00100000u) { target = stream->EnsureSpace(target); target = 
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteFloatToArray(10, this->_internal_character_coverage(), target); } @@ -1066,61 +1055,61 @@ failure: } // optional int32 seed_sentencepiece_size = 14 [default = 1000000]; - if (cached_has_bits & 0x00400000u) { + if (cached_has_bits & 0x00200000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(14, this->_internal_seed_sentencepiece_size(), target); } // optional float shrinking_factor = 15 [default = 0.75]; - if (cached_has_bits & 0x00800000u) { + if (cached_has_bits & 0x00400000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteFloatToArray(15, this->_internal_shrinking_factor(), target); } // optional int32 num_threads = 16 [default = 16]; - if (cached_has_bits & 0x01000000u) { + if (cached_has_bits & 0x00800000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(16, this->_internal_num_threads(), target); } // optional int32 num_sub_iterations = 17 [default = 2]; - if (cached_has_bits & 0x02000000u) { + if (cached_has_bits & 0x01000000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(17, this->_internal_num_sub_iterations(), target); } // optional int32 max_sentence_length = 18 [default = 4192]; - if (cached_has_bits & 0x04000000u) { + if (cached_has_bits & 0x02000000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(18, this->_internal_max_sentence_length(), target); } // optional bool shuffle_input_sentence = 19 [default = true]; - if (cached_has_bits & 0x10000000u) { + if (cached_has_bits & 0x08000000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(19, this->_internal_shuffle_input_sentence(), target); } // optional int32 max_sentencepiece_length = 20 [default = 16]; - if (cached_has_bits & 0x08000000u) { + if (cached_has_bits & 0x04000000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(20, this->_internal_max_sentencepiece_length(), target); } // optional bool split_by_unicode_script = 21 [default = true]; - if (cached_has_bits & 0x20000000u) { + if (cached_has_bits & 0x10000000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(21, this->_internal_split_by_unicode_script(), target); } // optional bool split_by_whitespace = 22 [default = true]; - if (cached_has_bits & 0x80000000u) { + if (cached_has_bits & 0x40000000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(22, this->_internal_split_by_whitespace(), target); } // optional bool split_by_number = 23 [default = true]; - if (cached_has_bits & 0x40000000u) { + if (cached_has_bits & 0x20000000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(23, this->_internal_split_by_number(), target); } @@ -1132,15 +1121,9 @@ failure: } // optional bool split_digits = 25 [default = false]; - if (cached_has_bits & 0x00004000u) { - target = stream->EnsureSpace(target); - target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(25, this->_internal_split_digits(), target); - } - - // optional bool 
allow_whitespace_only_pieces = 26 [default = false]; if (cached_has_bits & 0x00002000u) { target = stream->EnsureSpace(target); - target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(26, this->_internal_allow_whitespace_only_pieces(), target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(25, this->_internal_split_digits(), target); } // repeated string control_symbols = 30; @@ -1155,28 +1138,28 @@ failure: target = stream->WriteString(31, s, target); } - cached_has_bits = _has_bits_[1]; // optional bool vocabulary_output_piece_score = 32 [default = true]; - if (cached_has_bits & 0x00000001u) { + if (cached_has_bits & 0x80000000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(32, this->_internal_vocabulary_output_piece_score(), target); } + cached_has_bits = _has_bits_[1]; // optional bool hard_vocab_limit = 33 [default = true]; - if (cached_has_bits & 0x00000002u) { + if (cached_has_bits & 0x00000001u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(33, this->_internal_hard_vocab_limit(), target); } cached_has_bits = _has_bits_[0]; // optional bool use_all_vocab = 34 [default = false]; - if (cached_has_bits & 0x00020000u) { + if (cached_has_bits & 0x00008000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(34, this->_internal_use_all_vocab(), target); } // optional bool byte_fallback = 35 [default = false]; - if (cached_has_bits & 0x00008000u) { + if (cached_has_bits & 0x00004000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(35, this->_internal_byte_fallback(), target); } @@ -1195,19 +1178,19 @@ failure: cached_has_bits = _has_bits_[1]; // optional int32 bos_id = 41 [default = 1]; - if (cached_has_bits & 0x00000004u) { + if (cached_has_bits & 0x00000002u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(41, this->_internal_bos_id(), target); } // optional int32 eos_id = 42 [default = 2]; - if (cached_has_bits & 0x00000008u) { + if (cached_has_bits & 0x00000004u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(42, this->_internal_eos_id(), target); } // optional int32 pad_id = 43 [default = -1]; - if (cached_has_bits & 0x00000010u) { + if (cached_has_bits & 0x00000008u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(43, this->_internal_pad_id(), target); } @@ -1244,7 +1227,7 @@ failure: } // optional bool train_extremely_large_corpus = 49 [default = false]; - if (cached_has_bits & 0x00040000u) { + if (cached_has_bits & 0x00020000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(49, this->_internal_train_extremely_large_corpus(), target); } @@ -1396,17 +1379,17 @@ size_t TrainerSpec::ByteSizeLong() const { total_size += 2 + 1; } - // optional bool allow_whitespace_only_pieces = 26 [default = false]; + // optional bool split_digits = 25 [default = false]; if (cached_has_bits & 0x00002000u) { total_size += 2 + 1; } - // optional bool split_digits = 25 [default = false]; + // optional bool byte_fallback = 35 [default = false]; if (cached_has_bits & 0x00004000u) { total_size += 2 + 1; } 
- // optional bool byte_fallback = 35 [default = false]; + // optional bool use_all_vocab = 34 [default = false]; if (cached_has_bits & 0x00008000u) { total_size += 2 + 1; } @@ -1420,125 +1403,120 @@ size_t TrainerSpec::ByteSizeLong() const { this->_internal_unk_id()); } - // optional bool use_all_vocab = 34 [default = false]; + // optional bool train_extremely_large_corpus = 49 [default = false]; if (cached_has_bits & 0x00020000u) { total_size += 2 + 1; } - // optional bool train_extremely_large_corpus = 49 [default = false]; - if (cached_has_bits & 0x00040000u) { - total_size += 2 + 1; - } - // optional .sentencepiece.TrainerSpec.ModelType model_type = 3 [default = UNIGRAM]; - if (cached_has_bits & 0x00080000u) { + if (cached_has_bits & 0x00040000u) { total_size += 1 + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::EnumSize(this->_internal_model_type()); } // optional int32 vocab_size = 4 [default = 8000]; - if (cached_has_bits & 0x00100000u) { + if (cached_has_bits & 0x00080000u) { total_size += 1 + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( this->_internal_vocab_size()); } // optional float character_coverage = 10 [default = 0.9995]; - if (cached_has_bits & 0x00200000u) { + if (cached_has_bits & 0x00100000u) { total_size += 1 + 4; } // optional int32 seed_sentencepiece_size = 14 [default = 1000000]; - if (cached_has_bits & 0x00400000u) { + if (cached_has_bits & 0x00200000u) { total_size += 1 + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( this->_internal_seed_sentencepiece_size()); } // optional float shrinking_factor = 15 [default = 0.75]; - if (cached_has_bits & 0x00800000u) { + if (cached_has_bits & 0x00400000u) { total_size += 1 + 4; } - } - if (cached_has_bits & 0xff000000u) { // optional int32 num_threads = 16 [default = 16]; - if (cached_has_bits & 0x01000000u) { + if (cached_has_bits & 0x00800000u) { total_size += 2 + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( this->_internal_num_threads()); } + } + if (cached_has_bits & 0xff000000u) { // optional int32 num_sub_iterations = 17 [default = 2]; - if (cached_has_bits & 0x02000000u) { + if (cached_has_bits & 0x01000000u) { total_size += 2 + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( this->_internal_num_sub_iterations()); } // optional int32 max_sentence_length = 18 [default = 4192]; - if (cached_has_bits & 0x04000000u) { + if (cached_has_bits & 0x02000000u) { total_size += 2 + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( this->_internal_max_sentence_length()); } // optional int32 max_sentencepiece_length = 20 [default = 16]; - if (cached_has_bits & 0x08000000u) { + if (cached_has_bits & 0x04000000u) { total_size += 2 + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( this->_internal_max_sentencepiece_length()); } // optional bool shuffle_input_sentence = 19 [default = true]; - if (cached_has_bits & 0x10000000u) { + if (cached_has_bits & 0x08000000u) { total_size += 2 + 1; } // optional bool split_by_unicode_script = 21 [default = true]; - if (cached_has_bits & 0x20000000u) { + if (cached_has_bits & 0x10000000u) { total_size += 2 + 1; } // optional bool split_by_number = 23 [default = true]; - if (cached_has_bits & 0x40000000u) { + if (cached_has_bits & 0x20000000u) { total_size += 2 + 1; } // optional bool split_by_whitespace = 22 [default = true]; + if (cached_has_bits & 0x40000000u) { + total_size += 2 + 1; + } + + // optional bool vocabulary_output_piece_score = 32 [default = true]; if (cached_has_bits & 0x80000000u) { 
total_size += 2 + 1; } } cached_has_bits = _has_bits_[1]; - if (cached_has_bits & 0x0000001fu) { - // optional bool vocabulary_output_piece_score = 32 [default = true]; + if (cached_has_bits & 0x0000000fu) { + // optional bool hard_vocab_limit = 33 [default = true]; if (cached_has_bits & 0x00000001u) { total_size += 2 + 1; } - // optional bool hard_vocab_limit = 33 [default = true]; - if (cached_has_bits & 0x00000002u) { - total_size += 2 + 1; - } - // optional int32 bos_id = 41 [default = 1]; - if (cached_has_bits & 0x00000004u) { + if (cached_has_bits & 0x00000002u) { total_size += 2 + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( this->_internal_bos_id()); } // optional int32 eos_id = 42 [default = 2]; - if (cached_has_bits & 0x00000008u) { + if (cached_has_bits & 0x00000004u) { total_size += 2 + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( this->_internal_eos_id()); } // optional int32 pad_id = 43 [default = -1]; - if (cached_has_bits & 0x00000010u) { + if (cached_has_bits & 0x00000008u) { total_size += 2 + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( this->_internal_pad_id()); @@ -1615,14 +1593,14 @@ void TrainerSpec::MergeFrom(const TrainerSpec& from) { treat_whitespace_as_suffix_ = from.treat_whitespace_as_suffix_; } if (cached_has_bits & 0x00002000u) { - allow_whitespace_only_pieces_ = from.allow_whitespace_only_pieces_; - } - if (cached_has_bits & 0x00004000u) { split_digits_ = from.split_digits_; } - if (cached_has_bits & 0x00008000u) { + if (cached_has_bits & 0x00004000u) { byte_fallback_ = from.byte_fallback_; } + if (cached_has_bits & 0x00008000u) { + use_all_vocab_ = from.use_all_vocab_; + } _has_bits_[0] |= cached_has_bits; } if (cached_has_bits & 0x00ff0000u) { @@ -1630,70 +1608,67 @@ void TrainerSpec::MergeFrom(const TrainerSpec& from) { unk_id_ = from.unk_id_; } if (cached_has_bits & 0x00020000u) { - use_all_vocab_ = from.use_all_vocab_; - } - if (cached_has_bits & 0x00040000u) { train_extremely_large_corpus_ = from.train_extremely_large_corpus_; } - if (cached_has_bits & 0x00080000u) { + if (cached_has_bits & 0x00040000u) { model_type_ = from.model_type_; } - if (cached_has_bits & 0x00100000u) { + if (cached_has_bits & 0x00080000u) { vocab_size_ = from.vocab_size_; } - if (cached_has_bits & 0x00200000u) { + if (cached_has_bits & 0x00100000u) { character_coverage_ = from.character_coverage_; } - if (cached_has_bits & 0x00400000u) { + if (cached_has_bits & 0x00200000u) { seed_sentencepiece_size_ = from.seed_sentencepiece_size_; } - if (cached_has_bits & 0x00800000u) { + if (cached_has_bits & 0x00400000u) { shrinking_factor_ = from.shrinking_factor_; } + if (cached_has_bits & 0x00800000u) { + num_threads_ = from.num_threads_; + } _has_bits_[0] |= cached_has_bits; } if (cached_has_bits & 0xff000000u) { if (cached_has_bits & 0x01000000u) { - num_threads_ = from.num_threads_; - } - if (cached_has_bits & 0x02000000u) { num_sub_iterations_ = from.num_sub_iterations_; } - if (cached_has_bits & 0x04000000u) { + if (cached_has_bits & 0x02000000u) { max_sentence_length_ = from.max_sentence_length_; } - if (cached_has_bits & 0x08000000u) { + if (cached_has_bits & 0x04000000u) { max_sentencepiece_length_ = from.max_sentencepiece_length_; } - if (cached_has_bits & 0x10000000u) { + if (cached_has_bits & 0x08000000u) { shuffle_input_sentence_ = from.shuffle_input_sentence_; } - if (cached_has_bits & 0x20000000u) { + if (cached_has_bits & 0x10000000u) { split_by_unicode_script_ = from.split_by_unicode_script_; } - if (cached_has_bits & 
0x40000000u) { + if (cached_has_bits & 0x20000000u) { split_by_number_ = from.split_by_number_; } - if (cached_has_bits & 0x80000000u) { + if (cached_has_bits & 0x40000000u) { split_by_whitespace_ = from.split_by_whitespace_; } + if (cached_has_bits & 0x80000000u) { + vocabulary_output_piece_score_ = from.vocabulary_output_piece_score_; + } _has_bits_[0] |= cached_has_bits; } cached_has_bits = from._has_bits_[1]; - if (cached_has_bits & 0x0000001fu) { + if (cached_has_bits & 0x0000000fu) { if (cached_has_bits & 0x00000001u) { - vocabulary_output_piece_score_ = from.vocabulary_output_piece_score_; - } - if (cached_has_bits & 0x00000002u) { hard_vocab_limit_ = from.hard_vocab_limit_; } - if (cached_has_bits & 0x00000004u) { + if (cached_has_bits & 0x00000002u) { bos_id_ = from.bos_id_; } - if (cached_has_bits & 0x00000008u) { + if (cached_has_bits & 0x00000004u) { eos_id_ = from.eos_id_; } - if (cached_has_bits & 0x00000010u) { + if (cached_has_bits & 0x00000008u) { pad_id_ = from.pad_id_; } _has_bits_[1] |= cached_has_bits; diff --git a/src/builtin_pb/sentencepiece_model.pb.h b/src/builtin_pb/sentencepiece_model.pb.h index f527aa7..31dc65b 100644 --- a/src/builtin_pb/sentencepiece_model.pb.h +++ b/src/builtin_pb/sentencepiece_model.pb.h @@ -278,11 +278,10 @@ class TrainerSpec PROTOBUF_FINAL : kInputSentenceSizeFieldNumber = 11, kTrainingSentenceSizeFieldNumber = 13, kTreatWhitespaceAsSuffixFieldNumber = 24, - kAllowWhitespaceOnlyPiecesFieldNumber = 26, kSplitDigitsFieldNumber = 25, kByteFallbackFieldNumber = 35, - kUnkIdFieldNumber = 40, kUseAllVocabFieldNumber = 34, + kUnkIdFieldNumber = 40, kTrainExtremelyLargeCorpusFieldNumber = 49, kModelTypeFieldNumber = 3, kVocabSizeFieldNumber = 4, @@ -624,19 +623,6 @@ class TrainerSpec PROTOBUF_FINAL : void _internal_set_treat_whitespace_as_suffix(bool value); public: - // optional bool allow_whitespace_only_pieces = 26 [default = false]; - bool has_allow_whitespace_only_pieces() const; - private: - bool _internal_has_allow_whitespace_only_pieces() const; - public: - void clear_allow_whitespace_only_pieces(); - bool allow_whitespace_only_pieces() const; - void set_allow_whitespace_only_pieces(bool value); - private: - bool _internal_allow_whitespace_only_pieces() const; - void _internal_set_allow_whitespace_only_pieces(bool value); - public: - // optional bool split_digits = 25 [default = false]; bool has_split_digits() const; private: @@ -663,19 +649,6 @@ class TrainerSpec PROTOBUF_FINAL : void _internal_set_byte_fallback(bool value); public: - // optional int32 unk_id = 40 [default = 0]; - bool has_unk_id() const; - private: - bool _internal_has_unk_id() const; - public: - void clear_unk_id(); - ::PROTOBUF_NAMESPACE_ID::int32 unk_id() const; - void set_unk_id(::PROTOBUF_NAMESPACE_ID::int32 value); - private: - ::PROTOBUF_NAMESPACE_ID::int32 _internal_unk_id() const; - void _internal_set_unk_id(::PROTOBUF_NAMESPACE_ID::int32 value); - public: - // optional bool use_all_vocab = 34 [default = false]; bool has_use_all_vocab() const; private: @@ -689,6 +662,19 @@ class TrainerSpec PROTOBUF_FINAL : void _internal_set_use_all_vocab(bool value); public: + // optional int32 unk_id = 40 [default = 0]; + bool has_unk_id() const; + private: + bool _internal_has_unk_id() const; + public: + void clear_unk_id(); + ::PROTOBUF_NAMESPACE_ID::int32 unk_id() const; + void set_unk_id(::PROTOBUF_NAMESPACE_ID::int32 value); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_unk_id() const; + void _internal_set_unk_id(::PROTOBUF_NAMESPACE_ID::int32 value); + public: + 
// optional bool train_extremely_large_corpus = 49 [default = false]; bool has_train_extremely_large_corpus() const; private: @@ -970,11 +956,10 @@ class TrainerSpec PROTOBUF_FINAL : ::PROTOBUF_NAMESPACE_ID::uint64 input_sentence_size_; ::PROTOBUF_NAMESPACE_ID::int32 training_sentence_size_; bool treat_whitespace_as_suffix_; - bool allow_whitespace_only_pieces_; bool split_digits_; bool byte_fallback_; - ::PROTOBUF_NAMESPACE_ID::int32 unk_id_; bool use_all_vocab_; + ::PROTOBUF_NAMESPACE_ID::int32 unk_id_; bool train_extremely_large_corpus_; int model_type_; ::PROTOBUF_NAMESPACE_ID::int32 vocab_size_; @@ -2195,7 +2180,7 @@ inline void TrainerSpec::set_allocated_model_prefix(std::string* model_prefix) { // optional .sentencepiece.TrainerSpec.ModelType model_type = 3 [default = UNIGRAM]; inline bool TrainerSpec::_internal_has_model_type() const { - bool value = (_has_bits_[0] & 0x00080000u) != 0; + bool value = (_has_bits_[0] & 0x00040000u) != 0; return value; } inline bool TrainerSpec::has_model_type() const { @@ -2203,7 +2188,7 @@ inline bool TrainerSpec::has_model_type() const { } inline void TrainerSpec::clear_model_type() { model_type_ = 1; - _has_bits_[0] &= ~0x00080000u; + _has_bits_[0] &= ~0x00040000u; } inline ::sentencepiece::TrainerSpec_ModelType TrainerSpec::_internal_model_type() const { return static_cast< ::sentencepiece::TrainerSpec_ModelType >(model_type_); @@ -2214,7 +2199,7 @@ inline ::sentencepiece::TrainerSpec_ModelType TrainerSpec::model_type() const { } inline void TrainerSpec::_internal_set_model_type(::sentencepiece::TrainerSpec_ModelType value) { assert(::sentencepiece::TrainerSpec_ModelType_IsValid(value)); - _has_bits_[0] |= 0x00080000u; + _has_bits_[0] |= 0x00040000u; model_type_ = value; } inline void TrainerSpec::set_model_type(::sentencepiece::TrainerSpec_ModelType value) { @@ -2224,7 +2209,7 @@ inline void TrainerSpec::set_model_type(::sentencepiece::TrainerSpec_ModelType v // optional int32 vocab_size = 4 [default = 8000]; inline bool TrainerSpec::_internal_has_vocab_size() const { - bool value = (_has_bits_[0] & 0x00100000u) != 0; + bool value = (_has_bits_[0] & 0x00080000u) != 0; return value; } inline bool TrainerSpec::has_vocab_size() const { @@ -2232,7 +2217,7 @@ inline bool TrainerSpec::has_vocab_size() const { } inline void TrainerSpec::clear_vocab_size() { vocab_size_ = 8000; - _has_bits_[0] &= ~0x00100000u; + _has_bits_[0] &= ~0x00080000u; } inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_vocab_size() const { return vocab_size_; @@ -2242,7 +2227,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::vocab_size() const { return _internal_vocab_size(); } inline void TrainerSpec::_internal_set_vocab_size(::PROTOBUF_NAMESPACE_ID::int32 value) { - _has_bits_[0] |= 0x00100000u; + _has_bits_[0] |= 0x00080000u; vocab_size_ = value; } inline void TrainerSpec::set_vocab_size(::PROTOBUF_NAMESPACE_ID::int32 value) { @@ -2354,7 +2339,7 @@ inline void TrainerSpec::set_self_test_sample_size(::PROTOBUF_NAMESPACE_ID::int3 // optional float character_coverage = 10 [default = 0.9995]; inline bool TrainerSpec::_internal_has_character_coverage() const { - bool value = (_has_bits_[0] & 0x00200000u) != 0; + bool value = (_has_bits_[0] & 0x00100000u) != 0; return value; } inline bool TrainerSpec::has_character_coverage() const { @@ -2362,7 +2347,7 @@ inline bool TrainerSpec::has_character_coverage() const { } inline void TrainerSpec::clear_character_coverage() { character_coverage_ = 0.9995f; - _has_bits_[0] &= ~0x00200000u; + _has_bits_[0] &= ~0x00100000u; } 
inline float TrainerSpec::_internal_character_coverage() const { return character_coverage_; @@ -2372,7 +2357,7 @@ inline float TrainerSpec::character_coverage() const { return _internal_character_coverage(); } inline void TrainerSpec::_internal_set_character_coverage(float value) { - _has_bits_[0] |= 0x00200000u; + _has_bits_[0] |= 0x00100000u; character_coverage_ = value; } inline void TrainerSpec::set_character_coverage(float value) { @@ -2410,7 +2395,7 @@ inline void TrainerSpec::set_input_sentence_size(::PROTOBUF_NAMESPACE_ID::uint64 // optional bool shuffle_input_sentence = 19 [default = true]; inline bool TrainerSpec::_internal_has_shuffle_input_sentence() const { - bool value = (_has_bits_[0] & 0x10000000u) != 0; + bool value = (_has_bits_[0] & 0x08000000u) != 0; return value; } inline bool TrainerSpec::has_shuffle_input_sentence() const { @@ -2418,7 +2403,7 @@ inline bool TrainerSpec::has_shuffle_input_sentence() const { } inline void TrainerSpec::clear_shuffle_input_sentence() { shuffle_input_sentence_ = true; - _has_bits_[0] &= ~0x10000000u; + _has_bits_[0] &= ~0x08000000u; } inline bool TrainerSpec::_internal_shuffle_input_sentence() const { return shuffle_input_sentence_; @@ -2428,7 +2413,7 @@ inline bool TrainerSpec::shuffle_input_sentence() const { return _internal_shuffle_input_sentence(); } inline void TrainerSpec::_internal_set_shuffle_input_sentence(bool value) { - _has_bits_[0] |= 0x10000000u; + _has_bits_[0] |= 0x08000000u; shuffle_input_sentence_ = value; } inline void TrainerSpec::set_shuffle_input_sentence(bool value) { @@ -2494,7 +2479,7 @@ inline void TrainerSpec::set_training_sentence_size(::PROTOBUF_NAMESPACE_ID::int // optional int32 seed_sentencepiece_size = 14 [default = 1000000]; inline bool TrainerSpec::_internal_has_seed_sentencepiece_size() const { - bool value = (_has_bits_[0] & 0x00400000u) != 0; + bool value = (_has_bits_[0] & 0x00200000u) != 0; return value; } inline bool TrainerSpec::has_seed_sentencepiece_size() const { @@ -2502,7 +2487,7 @@ inline bool TrainerSpec::has_seed_sentencepiece_size() const { } inline void TrainerSpec::clear_seed_sentencepiece_size() { seed_sentencepiece_size_ = 1000000; - _has_bits_[0] &= ~0x00400000u; + _has_bits_[0] &= ~0x00200000u; } inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_seed_sentencepiece_size() const { return seed_sentencepiece_size_; @@ -2512,7 +2497,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::seed_sentencepiece_size() con return _internal_seed_sentencepiece_size(); } inline void TrainerSpec::_internal_set_seed_sentencepiece_size(::PROTOBUF_NAMESPACE_ID::int32 value) { - _has_bits_[0] |= 0x00400000u; + _has_bits_[0] |= 0x00200000u; seed_sentencepiece_size_ = value; } inline void TrainerSpec::set_seed_sentencepiece_size(::PROTOBUF_NAMESPACE_ID::int32 value) { @@ -2522,7 +2507,7 @@ inline void TrainerSpec::set_seed_sentencepiece_size(::PROTOBUF_NAMESPACE_ID::in // optional float shrinking_factor = 15 [default = 0.75]; inline bool TrainerSpec::_internal_has_shrinking_factor() const { - bool value = (_has_bits_[0] & 0x00800000u) != 0; + bool value = (_has_bits_[0] & 0x00400000u) != 0; return value; } inline bool TrainerSpec::has_shrinking_factor() const { @@ -2530,7 +2515,7 @@ inline bool TrainerSpec::has_shrinking_factor() const { } inline void TrainerSpec::clear_shrinking_factor() { shrinking_factor_ = 0.75f; - _has_bits_[0] &= ~0x00800000u; + _has_bits_[0] &= ~0x00400000u; } inline float TrainerSpec::_internal_shrinking_factor() const { return shrinking_factor_; @@ -2540,7 +2525,7 
@@ inline float TrainerSpec::shrinking_factor() const { return _internal_shrinking_factor(); } inline void TrainerSpec::_internal_set_shrinking_factor(float value) { - _has_bits_[0] |= 0x00800000u; + _has_bits_[0] |= 0x00400000u; shrinking_factor_ = value; } inline void TrainerSpec::set_shrinking_factor(float value) { @@ -2550,7 +2535,7 @@ inline void TrainerSpec::set_shrinking_factor(float value) { // optional int32 max_sentence_length = 18 [default = 4192]; inline bool TrainerSpec::_internal_has_max_sentence_length() const { - bool value = (_has_bits_[0] & 0x04000000u) != 0; + bool value = (_has_bits_[0] & 0x02000000u) != 0; return value; } inline bool TrainerSpec::has_max_sentence_length() const { @@ -2558,7 +2543,7 @@ inline bool TrainerSpec::has_max_sentence_length() const { } inline void TrainerSpec::clear_max_sentence_length() { max_sentence_length_ = 4192; - _has_bits_[0] &= ~0x04000000u; + _has_bits_[0] &= ~0x02000000u; } inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_max_sentence_length() const { return max_sentence_length_; @@ -2568,7 +2553,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::max_sentence_length() const { return _internal_max_sentence_length(); } inline void TrainerSpec::_internal_set_max_sentence_length(::PROTOBUF_NAMESPACE_ID::int32 value) { - _has_bits_[0] |= 0x04000000u; + _has_bits_[0] |= 0x02000000u; max_sentence_length_ = value; } inline void TrainerSpec::set_max_sentence_length(::PROTOBUF_NAMESPACE_ID::int32 value) { @@ -2578,7 +2563,7 @@ inline void TrainerSpec::set_max_sentence_length(::PROTOBUF_NAMESPACE_ID::int32 // optional int32 num_threads = 16 [default = 16]; inline bool TrainerSpec::_internal_has_num_threads() const { - bool value = (_has_bits_[0] & 0x01000000u) != 0; + bool value = (_has_bits_[0] & 0x00800000u) != 0; return value; } inline bool TrainerSpec::has_num_threads() const { @@ -2586,7 +2571,7 @@ inline bool TrainerSpec::has_num_threads() const { } inline void TrainerSpec::clear_num_threads() { num_threads_ = 16; - _has_bits_[0] &= ~0x01000000u; + _has_bits_[0] &= ~0x00800000u; } inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_num_threads() const { return num_threads_; @@ -2596,7 +2581,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::num_threads() const { return _internal_num_threads(); } inline void TrainerSpec::_internal_set_num_threads(::PROTOBUF_NAMESPACE_ID::int32 value) { - _has_bits_[0] |= 0x01000000u; + _has_bits_[0] |= 0x00800000u; num_threads_ = value; } inline void TrainerSpec::set_num_threads(::PROTOBUF_NAMESPACE_ID::int32 value) { @@ -2606,7 +2591,7 @@ inline void TrainerSpec::set_num_threads(::PROTOBUF_NAMESPACE_ID::int32 value) { // optional int32 num_sub_iterations = 17 [default = 2]; inline bool TrainerSpec::_internal_has_num_sub_iterations() const { - bool value = (_has_bits_[0] & 0x02000000u) != 0; + bool value = (_has_bits_[0] & 0x01000000u) != 0; return value; } inline bool TrainerSpec::has_num_sub_iterations() const { @@ -2614,7 +2599,7 @@ inline bool TrainerSpec::has_num_sub_iterations() const { } inline void TrainerSpec::clear_num_sub_iterations() { num_sub_iterations_ = 2; - _has_bits_[0] &= ~0x02000000u; + _has_bits_[0] &= ~0x01000000u; } inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_num_sub_iterations() const { return num_sub_iterations_; @@ -2624,7 +2609,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::num_sub_iterations() const { return _internal_num_sub_iterations(); } inline void 
TrainerSpec::_internal_set_num_sub_iterations(::PROTOBUF_NAMESPACE_ID::int32 value) { - _has_bits_[0] |= 0x02000000u; + _has_bits_[0] |= 0x01000000u; num_sub_iterations_ = value; } inline void TrainerSpec::set_num_sub_iterations(::PROTOBUF_NAMESPACE_ID::int32 value) { @@ -2634,7 +2619,7 @@ inline void TrainerSpec::set_num_sub_iterations(::PROTOBUF_NAMESPACE_ID::int32 v // optional int32 max_sentencepiece_length = 20 [default = 16]; inline bool TrainerSpec::_internal_has_max_sentencepiece_length() const { - bool value = (_has_bits_[0] & 0x08000000u) != 0; + bool value = (_has_bits_[0] & 0x04000000u) != 0; return value; } inline bool TrainerSpec::has_max_sentencepiece_length() const { @@ -2642,7 +2627,7 @@ inline bool TrainerSpec::has_max_sentencepiece_length() const { } inline void TrainerSpec::clear_max_sentencepiece_length() { max_sentencepiece_length_ = 16; - _has_bits_[0] &= ~0x08000000u; + _has_bits_[0] &= ~0x04000000u; } inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_max_sentencepiece_length() const { return max_sentencepiece_length_; @@ -2652,7 +2637,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::max_sentencepiece_length() co return _internal_max_sentencepiece_length(); } inline void TrainerSpec::_internal_set_max_sentencepiece_length(::PROTOBUF_NAMESPACE_ID::int32 value) { - _has_bits_[0] |= 0x08000000u; + _has_bits_[0] |= 0x04000000u; max_sentencepiece_length_ = value; } inline void TrainerSpec::set_max_sentencepiece_length(::PROTOBUF_NAMESPACE_ID::int32 value) { @@ -2662,7 +2647,7 @@ inline void TrainerSpec::set_max_sentencepiece_length(::PROTOBUF_NAMESPACE_ID::i // optional bool split_by_unicode_script = 21 [default = true]; inline bool TrainerSpec::_internal_has_split_by_unicode_script() const { - bool value = (_has_bits_[0] & 0x20000000u) != 0; + bool value = (_has_bits_[0] & 0x10000000u) != 0; return value; } inline bool TrainerSpec::has_split_by_unicode_script() const { @@ -2670,7 +2655,7 @@ inline bool TrainerSpec::has_split_by_unicode_script() const { } inline void TrainerSpec::clear_split_by_unicode_script() { split_by_unicode_script_ = true; - _has_bits_[0] &= ~0x20000000u; + _has_bits_[0] &= ~0x10000000u; } inline bool TrainerSpec::_internal_split_by_unicode_script() const { return split_by_unicode_script_; @@ -2680,7 +2665,7 @@ inline bool TrainerSpec::split_by_unicode_script() const { return _internal_split_by_unicode_script(); } inline void TrainerSpec::_internal_set_split_by_unicode_script(bool value) { - _has_bits_[0] |= 0x20000000u; + _has_bits_[0] |= 0x10000000u; split_by_unicode_script_ = value; } inline void TrainerSpec::set_split_by_unicode_script(bool value) { @@ -2690,7 +2675,7 @@ inline void TrainerSpec::set_split_by_unicode_script(bool value) { // optional bool split_by_number = 23 [default = true]; inline bool TrainerSpec::_internal_has_split_by_number() const { - bool value = (_has_bits_[0] & 0x40000000u) != 0; + bool value = (_has_bits_[0] & 0x20000000u) != 0; return value; } inline bool TrainerSpec::has_split_by_number() const { @@ -2698,7 +2683,7 @@ inline bool TrainerSpec::has_split_by_number() const { } inline void TrainerSpec::clear_split_by_number() { split_by_number_ = true; - _has_bits_[0] &= ~0x40000000u; + _has_bits_[0] &= ~0x20000000u; } inline bool TrainerSpec::_internal_split_by_number() const { return split_by_number_; @@ -2708,7 +2693,7 @@ inline bool TrainerSpec::split_by_number() const { return _internal_split_by_number(); } inline void TrainerSpec::_internal_set_split_by_number(bool value) { - _has_bits_[0] |= 
0x40000000u; + _has_bits_[0] |= 0x20000000u; split_by_number_ = value; } inline void TrainerSpec::set_split_by_number(bool value) { @@ -2718,7 +2703,7 @@ inline void TrainerSpec::set_split_by_number(bool value) { // optional bool split_by_whitespace = 22 [default = true]; inline bool TrainerSpec::_internal_has_split_by_whitespace() const { - bool value = (_has_bits_[0] & 0x80000000u) != 0; + bool value = (_has_bits_[0] & 0x40000000u) != 0; return value; } inline bool TrainerSpec::has_split_by_whitespace() const { @@ -2726,7 +2711,7 @@ inline bool TrainerSpec::has_split_by_whitespace() const { } inline void TrainerSpec::clear_split_by_whitespace() { split_by_whitespace_ = true; - _has_bits_[0] &= ~0x80000000u; + _has_bits_[0] &= ~0x40000000u; } inline bool TrainerSpec::_internal_split_by_whitespace() const { return split_by_whitespace_; @@ -2736,7 +2721,7 @@ inline bool TrainerSpec::split_by_whitespace() const { return _internal_split_by_whitespace(); } inline void TrainerSpec::_internal_set_split_by_whitespace(bool value) { - _has_bits_[0] |= 0x80000000u; + _has_bits_[0] |= 0x40000000u; split_by_whitespace_ = value; } inline void TrainerSpec::set_split_by_whitespace(bool value) { @@ -2772,37 +2757,9 @@ inline void TrainerSpec::set_treat_whitespace_as_suffix(bool value) { // @@protoc_insertion_point(field_set:sentencepiece.TrainerSpec.treat_whitespace_as_suffix) } -// optional bool allow_whitespace_only_pieces = 26 [default = false]; -inline bool TrainerSpec::_internal_has_allow_whitespace_only_pieces() const { - bool value = (_has_bits_[0] & 0x00002000u) != 0; - return value; -} -inline bool TrainerSpec::has_allow_whitespace_only_pieces() const { - return _internal_has_allow_whitespace_only_pieces(); -} -inline void TrainerSpec::clear_allow_whitespace_only_pieces() { - allow_whitespace_only_pieces_ = false; - _has_bits_[0] &= ~0x00002000u; -} -inline bool TrainerSpec::_internal_allow_whitespace_only_pieces() const { - return allow_whitespace_only_pieces_; -} -inline bool TrainerSpec::allow_whitespace_only_pieces() const { - // @@protoc_insertion_point(field_get:sentencepiece.TrainerSpec.allow_whitespace_only_pieces) - return _internal_allow_whitespace_only_pieces(); -} -inline void TrainerSpec::_internal_set_allow_whitespace_only_pieces(bool value) { - _has_bits_[0] |= 0x00002000u; - allow_whitespace_only_pieces_ = value; -} -inline void TrainerSpec::set_allow_whitespace_only_pieces(bool value) { - _internal_set_allow_whitespace_only_pieces(value); - // @@protoc_insertion_point(field_set:sentencepiece.TrainerSpec.allow_whitespace_only_pieces) -} - // optional bool split_digits = 25 [default = false]; inline bool TrainerSpec::_internal_has_split_digits() const { - bool value = (_has_bits_[0] & 0x00004000u) != 0; + bool value = (_has_bits_[0] & 0x00002000u) != 0; return value; } inline bool TrainerSpec::has_split_digits() const { @@ -2810,7 +2767,7 @@ inline bool TrainerSpec::has_split_digits() const { } inline void TrainerSpec::clear_split_digits() { split_digits_ = false; - _has_bits_[0] &= ~0x00004000u; + _has_bits_[0] &= ~0x00002000u; } inline bool TrainerSpec::_internal_split_digits() const { return split_digits_; @@ -2820,7 +2777,7 @@ inline bool TrainerSpec::split_digits() const { return _internal_split_digits(); } inline void TrainerSpec::_internal_set_split_digits(bool value) { - _has_bits_[0] |= 0x00004000u; + _has_bits_[0] |= 0x00002000u; split_digits_ = value; } inline void TrainerSpec::set_split_digits(bool value) { @@ -3051,7 +3008,7 @@ inline void 
TrainerSpec::set_allocated_required_chars(std::string* required_char // optional bool byte_fallback = 35 [default = false]; inline bool TrainerSpec::_internal_has_byte_fallback() const { - bool value = (_has_bits_[0] & 0x00008000u) != 0; + bool value = (_has_bits_[0] & 0x00004000u) != 0; return value; } inline bool TrainerSpec::has_byte_fallback() const { @@ -3059,7 +3016,7 @@ inline bool TrainerSpec::has_byte_fallback() const { } inline void TrainerSpec::clear_byte_fallback() { byte_fallback_ = false; - _has_bits_[0] &= ~0x00008000u; + _has_bits_[0] &= ~0x00004000u; } inline bool TrainerSpec::_internal_byte_fallback() const { return byte_fallback_; @@ -3069,7 +3026,7 @@ inline bool TrainerSpec::byte_fallback() const { return _internal_byte_fallback(); } inline void TrainerSpec::_internal_set_byte_fallback(bool value) { - _has_bits_[0] |= 0x00008000u; + _has_bits_[0] |= 0x00004000u; byte_fallback_ = value; } inline void TrainerSpec::set_byte_fallback(bool value) { @@ -3079,7 +3036,7 @@ inline void TrainerSpec::set_byte_fallback(bool value) { // optional bool vocabulary_output_piece_score = 32 [default = true]; inline bool TrainerSpec::_internal_has_vocabulary_output_piece_score() const { - bool value = (_has_bits_[1] & 0x00000001u) != 0; + bool value = (_has_bits_[0] & 0x80000000u) != 0; return value; } inline bool TrainerSpec::has_vocabulary_output_piece_score() const { @@ -3087,7 +3044,7 @@ inline bool TrainerSpec::has_vocabulary_output_piece_score() const { } inline void TrainerSpec::clear_vocabulary_output_piece_score() { vocabulary_output_piece_score_ = true; - _has_bits_[1] &= ~0x00000001u; + _has_bits_[0] &= ~0x80000000u; } inline bool TrainerSpec::_internal_vocabulary_output_piece_score() const { return vocabulary_output_piece_score_; @@ -3097,7 +3054,7 @@ inline bool TrainerSpec::vocabulary_output_piece_score() const { return _internal_vocabulary_output_piece_score(); } inline void TrainerSpec::_internal_set_vocabulary_output_piece_score(bool value) { - _has_bits_[1] |= 0x00000001u; + _has_bits_[0] |= 0x80000000u; vocabulary_output_piece_score_ = value; } inline void TrainerSpec::set_vocabulary_output_piece_score(bool value) { @@ -3107,7 +3064,7 @@ inline void TrainerSpec::set_vocabulary_output_piece_score(bool value) { // optional bool hard_vocab_limit = 33 [default = true]; inline bool TrainerSpec::_internal_has_hard_vocab_limit() const { - bool value = (_has_bits_[1] & 0x00000002u) != 0; + bool value = (_has_bits_[1] & 0x00000001u) != 0; return value; } inline bool TrainerSpec::has_hard_vocab_limit() const { @@ -3115,7 +3072,7 @@ inline bool TrainerSpec::has_hard_vocab_limit() const { } inline void TrainerSpec::clear_hard_vocab_limit() { hard_vocab_limit_ = true; - _has_bits_[1] &= ~0x00000002u; + _has_bits_[1] &= ~0x00000001u; } inline bool TrainerSpec::_internal_hard_vocab_limit() const { return hard_vocab_limit_; @@ -3125,7 +3082,7 @@ inline bool TrainerSpec::hard_vocab_limit() const { return _internal_hard_vocab_limit(); } inline void TrainerSpec::_internal_set_hard_vocab_limit(bool value) { - _has_bits_[1] |= 0x00000002u; + _has_bits_[1] |= 0x00000001u; hard_vocab_limit_ = value; } inline void TrainerSpec::set_hard_vocab_limit(bool value) { @@ -3135,7 +3092,7 @@ inline void TrainerSpec::set_hard_vocab_limit(bool value) { // optional bool use_all_vocab = 34 [default = false]; inline bool TrainerSpec::_internal_has_use_all_vocab() const { - bool value = (_has_bits_[0] & 0x00020000u) != 0; + bool value = (_has_bits_[0] & 0x00008000u) != 0; return value; } inline bool 
TrainerSpec::has_use_all_vocab() const { @@ -3143,7 +3100,7 @@ inline bool TrainerSpec::has_use_all_vocab() const { } inline void TrainerSpec::clear_use_all_vocab() { use_all_vocab_ = false; - _has_bits_[0] &= ~0x00020000u; + _has_bits_[0] &= ~0x00008000u; } inline bool TrainerSpec::_internal_use_all_vocab() const { return use_all_vocab_; @@ -3153,7 +3110,7 @@ inline bool TrainerSpec::use_all_vocab() const { return _internal_use_all_vocab(); } inline void TrainerSpec::_internal_set_use_all_vocab(bool value) { - _has_bits_[0] |= 0x00020000u; + _has_bits_[0] |= 0x00008000u; use_all_vocab_ = value; } inline void TrainerSpec::set_use_all_vocab(bool value) { @@ -3191,7 +3148,7 @@ inline void TrainerSpec::set_unk_id(::PROTOBUF_NAMESPACE_ID::int32 value) { // optional int32 bos_id = 41 [default = 1]; inline bool TrainerSpec::_internal_has_bos_id() const { - bool value = (_has_bits_[1] & 0x00000004u) != 0; + bool value = (_has_bits_[1] & 0x00000002u) != 0; return value; } inline bool TrainerSpec::has_bos_id() const { @@ -3199,7 +3156,7 @@ inline bool TrainerSpec::has_bos_id() const { } inline void TrainerSpec::clear_bos_id() { bos_id_ = 1; - _has_bits_[1] &= ~0x00000004u; + _has_bits_[1] &= ~0x00000002u; } inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_bos_id() const { return bos_id_; @@ -3209,7 +3166,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::bos_id() const { return _internal_bos_id(); } inline void TrainerSpec::_internal_set_bos_id(::PROTOBUF_NAMESPACE_ID::int32 value) { - _has_bits_[1] |= 0x00000004u; + _has_bits_[1] |= 0x00000002u; bos_id_ = value; } inline void TrainerSpec::set_bos_id(::PROTOBUF_NAMESPACE_ID::int32 value) { @@ -3219,7 +3176,7 @@ inline void TrainerSpec::set_bos_id(::PROTOBUF_NAMESPACE_ID::int32 value) { // optional int32 eos_id = 42 [default = 2]; inline bool TrainerSpec::_internal_has_eos_id() const { - bool value = (_has_bits_[1] & 0x00000008u) != 0; + bool value = (_has_bits_[1] & 0x00000004u) != 0; return value; } inline bool TrainerSpec::has_eos_id() const { @@ -3227,7 +3184,7 @@ inline bool TrainerSpec::has_eos_id() const { } inline void TrainerSpec::clear_eos_id() { eos_id_ = 2; - _has_bits_[1] &= ~0x00000008u; + _has_bits_[1] &= ~0x00000004u; } inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_eos_id() const { return eos_id_; @@ -3237,7 +3194,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::eos_id() const { return _internal_eos_id(); } inline void TrainerSpec::_internal_set_eos_id(::PROTOBUF_NAMESPACE_ID::int32 value) { - _has_bits_[1] |= 0x00000008u; + _has_bits_[1] |= 0x00000004u; eos_id_ = value; } inline void TrainerSpec::set_eos_id(::PROTOBUF_NAMESPACE_ID::int32 value) { @@ -3247,7 +3204,7 @@ inline void TrainerSpec::set_eos_id(::PROTOBUF_NAMESPACE_ID::int32 value) { // optional int32 pad_id = 43 [default = -1]; inline bool TrainerSpec::_internal_has_pad_id() const { - bool value = (_has_bits_[1] & 0x00000010u) != 0; + bool value = (_has_bits_[1] & 0x00000008u) != 0; return value; } inline bool TrainerSpec::has_pad_id() const { @@ -3255,7 +3212,7 @@ inline bool TrainerSpec::has_pad_id() const { } inline void TrainerSpec::clear_pad_id() { pad_id_ = -1; - _has_bits_[1] &= ~0x00000010u; + _has_bits_[1] &= ~0x00000008u; } inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_pad_id() const { return pad_id_; @@ -3265,7 +3222,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::pad_id() const { return _internal_pad_id(); } inline void TrainerSpec::_internal_set_pad_id(::PROTOBUF_NAMESPACE_ID::int32 value) { - 
_has_bits_[1] |= 0x00000010u; + _has_bits_[1] |= 0x00000008u; pad_id_ = value; } inline void TrainerSpec::set_pad_id(::PROTOBUF_NAMESPACE_ID::int32 value) { @@ -3645,7 +3602,7 @@ inline void TrainerSpec::set_allocated_unk_surface(std::string* unk_surface) { // optional bool train_extremely_large_corpus = 49 [default = false]; inline bool TrainerSpec::_internal_has_train_extremely_large_corpus() const { - bool value = (_has_bits_[0] & 0x00040000u) != 0; + bool value = (_has_bits_[0] & 0x00020000u) != 0; return value; } inline bool TrainerSpec::has_train_extremely_large_corpus() const { @@ -3653,7 +3610,7 @@ inline bool TrainerSpec::has_train_extremely_large_corpus() const { } inline void TrainerSpec::clear_train_extremely_large_corpus() { train_extremely_large_corpus_ = false; - _has_bits_[0] &= ~0x00040000u; + _has_bits_[0] &= ~0x00020000u; } inline bool TrainerSpec::_internal_train_extremely_large_corpus() const { return train_extremely_large_corpus_; @@ -3663,7 +3620,7 @@ inline bool TrainerSpec::train_extremely_large_corpus() const { return _internal_train_extremely_large_corpus(); } inline void TrainerSpec::_internal_set_train_extremely_large_corpus(bool value) { - _has_bits_[0] |= 0x00040000u; + _has_bits_[0] |= 0x00020000u; train_extremely_large_corpus_ = value; } inline void TrainerSpec::set_train_extremely_large_corpus(bool value) { diff --git a/src/model_interface.cc b/src/model_interface.cc index c49be1e..ea5d0e7 100644 --- a/src/model_interface.cc +++ b/src/model_interface.cc @@ -134,53 +134,32 @@ void ModelInterface::InitializePieces() { } std::vector SplitIntoWords(absl::string_view text, - bool treat_ws_as_suffix, - bool allow_ws_only_pieces) { + bool treat_whitespace_as_suffix) { const char *begin = text.data(); const char *end = text.data() + text.size(); // Space symbol (U+2581) const absl::string_view kSpaceSymbol = "\xe2\x96\x81"; - bool in_ws_sequence = false; std::vector result; - if (treat_ws_as_suffix) { // put ws tokens at the end of non-ws sequences. + if (treat_whitespace_as_suffix) { if (begin < end) result.emplace_back(begin, 0); while (begin < end) { const int mblen = std::min(string_util::OneCharLen(begin), end - begin); const bool is_ws = absl::string_view(begin, mblen) == kSpaceSymbol; - - if (is_ws) { // keep track of sequences consecutive ws tokens. - in_ws_sequence = true; - } else if (in_ws_sequence) { - if (allow_ws_only_pieces) result.emplace_back(begin, 0); - - in_ws_sequence = false; - } - result.back() = absl::string_view(result.back().data(), result.back().size() + mblen); begin += mblen; - - if (begin < end && is_ws && !allow_ws_only_pieces) - result.emplace_back(begin, 0); + if (begin < end && is_ws) result.emplace_back(begin, 0); } } else { while (begin < end) { const int mblen = std::min(string_util::OneCharLen(begin), end - begin); - bool is_ws = absl::string_view(begin, mblen) == kSpaceSymbol; - - // if is whitespace (and not in sequence if allow_ws_only_pieces is True) if (begin == text.data() || - (is_ws && (!in_ws_sequence || !allow_ws_only_pieces))) { + absl::string_view(begin, mblen) == kSpaceSymbol) result.emplace_back(begin, 0); // add empty string piece. 
- in_ws_sequence = true; - } - - if (in_ws_sequence && !is_ws) in_ws_sequence = false; - result.back() = absl::string_view(result.back().data(), result.back().size() + mblen); begin += mblen; diff --git a/src/model_interface.h b/src/model_interface.h index 06b3a65..75cbb23 100644 --- a/src/model_interface.h +++ b/src/model_interface.h @@ -33,9 +33,8 @@ namespace sentencepiece { // "_this_is_a_pen" => ["_this", "_is", "_a", "_pen"] -std::vector SplitIntoWords( - absl::string_view text, bool treat_ws_as_suffix = false, - bool allow_ws_only_pieces = false); +std::vector SplitIntoWords(absl::string_view text, + bool add_ws_as_suffix = false); // Converts byte (0-255) to piece (e.g., 58 -> "<0x3A>"). std::string ByteToPiece(unsigned char c); @@ -107,42 +106,12 @@ class ModelInterface { return EncodeResult(); } - // Sample `samples` many tokenisations from the segmentation lattice - // If `wor` is true, the samples are taken without replacement, and the scores - // are the inclusion probabilities of the elements in the sample; otherwise - // the samples are taken with replacement and the scores are the log-probs of - // sample elements - // If `include_best` is true, the best tokenisation is always included in the - // sample, and the remaining elements are sampled excluding the best. - virtual NBestEncodeResult SampleEncodeAndScore(absl::string_view normalized, - float alpha, int samples, - bool wor, - bool include_best) const { - LOG(ERROR) << "Not implemented."; - return {{EncodeResult(), 0.0}}; - } - - // Calculates the entropy of the segmentation lattice with inverse temperature - // `theta`. - // Uses a novel dynamic program to calculate the entropy. - virtual float CalculateEntropy(absl::string_view normalized, - float theta) const { - LOG(ERROR) << "Not implemented."; - return 0.0; - } - // Return true if SampleEncode returns a valid result. virtual bool IsSampleEncodeAvailable() const { return false; } // Return true if NBestEncode returns a valid result. virtual bool IsNBestEncodeAvailable() const { return false; } - // Return true if SampleEncodeAndScore returns a valid result. - virtual bool IsSampleEncodeAndScoreAvailable() const { return false; } - - // Return true if CalculateEntropy returns a valid result. - virtual bool IsCalculateEntropyAvailable() const { return false; } - // Returns the vocab id of `piece`. // Returns UNK(0) if `piece` is unknown virtual int PieceToId(absl::string_view piece) const; @@ -155,10 +124,7 @@ class ModelInterface { // Returns the size of sentence pieces, which is the same // as the size of vocabulary for NMT. - virtual int GetPieceSize() const { - if (!model_proto_) return 0; - return model_proto_->pieces_size(); - } + virtual int GetPieceSize() const { return model_proto_->pieces_size(); } // Returns the score of `id`. // Score represents a log probability of the piece. 
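
The model_interface hunks above revert SplitIntoWords to its two-argument form: text is cut at every occurrence of the U+2581 meta symbol, which stays attached as a prefix of the following word by default, or as a suffix of the preceding word when treat_whitespace_as_suffix is set. The following standalone sketch is my own illustration of that restored rule (the name SplitIntoWordsSketch and the std::string_view signature are not the library's), not the library function itself:

// Standalone illustration (not the library function) of the two-argument
// splitting rule restored by this revert.
#include <iostream>
#include <string>
#include <string_view>
#include <vector>

std::vector<std::string_view> SplitIntoWordsSketch(std::string_view text,
                                                   bool ws_as_suffix = false) {
  constexpr std::string_view kSpaceSymbol = "\xe2\x96\x81";  // U+2581
  std::vector<std::string_view> result;
  size_t start = 0, pos = 0;
  while (pos < text.size()) {
    const bool is_ws =
        text.compare(pos, kSpaceSymbol.size(), kSpaceSymbol) == 0;
    if (!ws_as_suffix && is_ws && pos != start) {  // prefix mode: a new word starts at U+2581
      result.push_back(text.substr(start, pos - start));
      start = pos;
    }
    pos += is_ws ? kSpaceSymbol.size() : 1;  // byte-wise scan is enough to detect U+2581
    if (ws_as_suffix && is_ws && pos < text.size()) {  // suffix mode: a word ends right after U+2581
      result.push_back(text.substr(start, pos - start));
      start = pos;
    }
  }
  if (start < text.size()) result.push_back(text.substr(start));
  return result;
}

int main() {
  const std::string ws = "\xe2\x96\x81";
  const std::string prefixed = ws + "this" + ws + "is" + ws + "a" + ws + "pen";
  const std::string suffixed = "this" + ws + "is" + ws + "a" + ws + "pen" + ws;
  for (const auto &w : SplitIntoWordsSketch(prefixed)) std::cout << "[" << w << "] ";
  std::cout << "\n";  // [▁this] [▁is] [▁a] [▁pen]
  for (const auto &w : SplitIntoWordsSketch(suffixed, /*ws_as_suffix=*/true))
    std::cout << "[" << w << "] ";
  std::cout << "\n";  // [this▁] [is▁] [a▁] [pen▁]
  return 0;
}

The library version shown in the model_interface.cc hunk above applies the same rule per UTF-8 character via string_util::OneCharLen and returns views over the original text.
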
diff --git a/src/model_interface_test.cc b/src/model_interface_test.cc index 69ee4e6..f5ee492 100644 --- a/src/model_interface_test.cc +++ b/src/model_interface_test.cc @@ -412,50 +412,6 @@ TEST(ModelInterfaceTest, SplitIntoWordsSuffixTest) { } } -TEST(ModelInterfaceTest, SplitIntoWordsWhiteSpaceOnly) { - { - const auto v = - SplitIntoWords("this" WS "is" WS "a" WS "pen" WS, true, true); - EXPECT_EQ(4, v.size()); - EXPECT_EQ("this" WS, v[0]); - EXPECT_EQ("is" WS, v[1]); - EXPECT_EQ("a" WS, v[2]); - EXPECT_EQ("pen" WS, v[3]); - } - - { - const auto v = SplitIntoWords(WS WS WS "a", false, true); - EXPECT_EQ(1, v.size()); - EXPECT_EQ(WS WS WS "a", v[0]); - } - - { - const auto v = SplitIntoWords("a" WS WS WS, true, true); - EXPECT_EQ(1, v.size()); - EXPECT_EQ("a" WS WS WS, v[0]); - } - - { - const auto v = SplitIntoWords(WS WS, true, true); - EXPECT_EQ(1, v.size()); - EXPECT_EQ(WS WS, v[0]); - } - - { - const auto v = SplitIntoWords(WS WS "a" WS, true, true); - EXPECT_EQ(2, v.size()); - EXPECT_EQ(WS WS, v[0]); - EXPECT_EQ("a" WS, v[1]); - } - - { - const auto v = SplitIntoWords(WS WS "a" WS, false, true); - EXPECT_EQ(2, v.size()); - EXPECT_EQ(WS WS "a", v[0]); - EXPECT_EQ(WS, v[1]); - } -} - TEST(ModelInterfaceTest, ByteToPieceTest) { EXPECT_EQ(ByteToPiece(0), "<0x00>"); EXPECT_EQ(ByteToPiece(1), "<0x01>"); diff --git a/src/normalizer.cc b/src/normalizer.cc index d87f89b..3fe919b 100644 --- a/src/normalizer.cc +++ b/src/normalizer.cc @@ -12,11 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License.! +#include "normalizer.h" + #include #include #include "common.h" -#include "normalizer.h" #include "third_party/absl/memory/memory.h" #include "third_party/absl/strings/match.h" #include "third_party/absl/strings/string_view.h" @@ -277,11 +278,11 @@ util::Status Normalizer::DecodePrecompiledCharsMap( absl::string_view blob, absl::string_view *trie_blob, absl::string_view *normalized, std::string *buffer) { uint32 trie_blob_size = 0; + if (blob.size() <= sizeof(trie_blob_size) || !string_util::DecodePOD( absl::string_view(blob.data(), sizeof(trie_blob_size)), - &trie_blob_size) || - trie_blob_size >= blob.size()) { + &trie_blob_size)) { return util::InternalError("Blob for normalization rule is broken."); } diff --git a/src/normalizer.h b/src/normalizer.h index c79813c..37fdb8a 100644 --- a/src/normalizer.h +++ b/src/normalizer.h @@ -22,6 +22,7 @@ #include #include "common.h" +#include "util.h" #include "sentencepiece_model.pb.h" #include "sentencepiece_processor.h" #include "third_party/absl/strings/string_view.h" diff --git a/src/normalizer_test.cc b/src/normalizer_test.cc index 6c402bf..585e8f4 100644 --- a/src/normalizer_test.cc +++ b/src/normalizer_test.cc @@ -12,10 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License.! +#include "normalizer.h" + #include #include "builder.h" -#include "normalizer.h" #include "sentencepiece_trainer.h" #include "testharness.h" #include "util.h" diff --git a/src/sentencepiece_model.proto b/src/sentencepiece_model.proto index ee8e877..e735527 100644 --- a/src/sentencepiece_model.proto +++ b/src/sentencepiece_model.proto @@ -139,10 +139,6 @@ message TrainerSpec { // of sentence. optional bool treat_whitespace_as_suffix = 24 [default = false]; - // Allows pieces that only contain whitespaces instead of appearing only as - // prefix or suffix of other pieces. 
- optional bool allow_whitespace_only_pieces = 26 [default = false]; - // Split all digits (0-9) into separate pieces. optional bool split_digits = 25 [default = false]; diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc index 1e4e7a0..e4e9d4a 100644 --- a/src/sentencepiece_processor.cc +++ b/src/sentencepiece_processor.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License.! +#include "sentencepiece_processor.h" + #include #include #include @@ -22,7 +24,6 @@ #include "model_interface.h" #include "normalizer.h" #include "sentencepiece.pb.h" -#include "sentencepiece_processor.h" #include "third_party/absl/memory/memory.h" #include "third_party/absl/strings/numbers.h" #include "third_party/absl/strings/str_cat.h" @@ -503,43 +504,6 @@ util::Status SentencePieceProcessor::SampleEncode( return util::OkStatus(); } -util::Status SentencePieceProcessor::SampleEncodeAndScore( - absl::string_view input, int samples, float theta, bool wor, - bool include_best, NBestSentencePieceText *samples_spt) const { - CHECK_OR_RETURN(model_->IsSampleEncodeAndScoreAvailable()) - << "SampleEncodeAndScore is not available for the current model."; - std::string normalized; - std::vector norm_to_orig; - RETURN_IF_ERROR(normalizer_->Normalize(input, &normalized, &norm_to_orig)); - - const auto results = model_->SampleEncodeAndScore(normalized, theta, samples, - wor, include_best); - CHECK_OR_RETURN(!results.empty()) - << "SampleEncodeAndScore returns empty result."; - - for (const auto &result : results) { - auto *spt = samples_spt->add_nbests(); - spt->set_score(result.second); - RETURN_IF_ERROR(PopulateSentencePieceText(input, normalized, norm_to_orig, - result.first, spt)); - } - - return util::OkStatus(); -} - -util::Status SentencePieceProcessor::CalculateEntropy(absl::string_view input, - float theta, - float *entropy) const { - CHECK_OR_RETURN(model_->IsCalculateEntropyAvailable()) - << "CalculateEntropy is not available for the current model."; - std::string normalized; - std::vector norm_to_orig; - RETURN_IF_ERROR(normalizer_->Normalize(input, &normalized, &norm_to_orig)); - - *entropy = model_->CalculateEntropy(normalized, theta); - return util::OkStatus(); -} - util::Status SentencePieceProcessor::Decode( const std::vector &pieces, SentencePieceText *spt) const { CHECK_OR_RETURN_STATUS_PROTO(spt); @@ -869,12 +833,6 @@ std::string SentencePieceProcessor::serialized_model_proto() const { return model_proto_ ? model_proto_->SerializeAsString() : ""; } -// Set seed value of random generator. -// Do not set static_cast(-1), -// as this seed is reserved for initializing from -// std::random_device. 
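The removed comment above documents the seeding convention: the all-ones value is a sentinel meaning "draw the seed from std::random_device". A standalone sketch of that convention, with illustrative names rather than the library's internals:

#include <iostream>
#include <random>

namespace {
// All-ones acts as "not set": fall back to std::random_device.
constexpr unsigned int kAutoSeed = static_cast<unsigned int>(-1);
unsigned int g_seed = kAutoSeed;
}  // namespace

void SetSeedSketch(unsigned int seed) { g_seed = seed; }

std::mt19937 MakeGeneratorSketch() {
  const unsigned int seed =
      g_seed == kAutoSeed ? std::random_device{}() : g_seed;
  return std::mt19937(seed);
}

int main() {
  SetSeedSketch(42);  // explicit seed: reproducible stream
  auto rng = MakeGeneratorSketch();
  std::cout << rng() << "\n";
}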
-void SetRandomGeneratorSeed(unsigned int seed); - namespace io { util::Status LoadModelProto(absl::string_view filename, diff --git a/src/sentencepiece_processor.h b/src/sentencepiece_processor.h index 7c75838..7227920 100644 --- a/src/sentencepiece_processor.h +++ b/src/sentencepiece_processor.h @@ -315,15 +315,6 @@ class SentencePieceProcessor { virtual util::Status SampleEncode(absl::string_view input, int nbest_size, float alpha, SentencePieceText *spt) const; - // Samples N segmentation and returns the scores as well - virtual util::Status SampleEncodeAndScore( - absl::string_view input, int samples, float theta, bool wor, - bool include_best, NBestSentencePieceText *samples_spt) const; - - // Calculate entropy of possible tokenisations - virtual util::Status CalculateEntropy(absl::string_view input, float theta, - float *entropy) const; - // Given a sequence of pieces, decodes it into SentencePieceText. virtual util::Status Decode(const std::vector &pieces, SentencePieceText *spt) const; diff --git a/src/sentencepiece_processor_test.cc b/src/sentencepiece_processor_test.cc index 373e73e..e10a47c 100644 --- a/src/sentencepiece_processor_test.cc +++ b/src/sentencepiece_processor_test.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License.! +#include "sentencepiece_processor.h" + #include #include "builder.h" @@ -20,7 +22,6 @@ #include "normalizer.h" #include "sentencepiece.pb.h" #include "sentencepiece_model.pb.h" -#include "sentencepiece_processor.h" #include "sentencepiece_trainer.h" #include "testharness.h" #include "third_party/absl/container/flat_hash_map.h" @@ -1138,6 +1139,13 @@ TEST(SentencePieceProcessorTest, EndToEndTest) { EXPECT_EQ("cba", output); } + // Out of range + { + std::string output; + const std::vector ids = {3, 4, 127}; + EXPECT_FALSE(sp.Decode(ids, &output).ok()); + } + { EXPECT_TRUE(sp.SetDecodeExtraOptions("bos:eos").ok()); @@ -1164,13 +1172,6 @@ TEST(SentencePieceProcessorTest, EndToEndTest) { EXPECT_EQ("cba", output); } - // Out of range - { - std::string output; - const std::vector ids = {3, 4, 127}; - EXPECT_FALSE(sp.Decode(ids, &output).ok()); - } - { EXPECT_TRUE(sp.SetDecodeExtraOptions("bos:eos:reverse").ok()); diff --git a/src/sentencepiece_trainer.cc b/src/sentencepiece_trainer.cc index 888f05e..429d0f4 100644 --- a/src/sentencepiece_trainer.cc +++ b/src/sentencepiece_trainer.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License.! 
+#include "sentencepiece_trainer.h" + #include #include @@ -20,9 +22,7 @@ #include "normalizer.h" #include "sentencepiece.pb.h" #include "sentencepiece_model.pb.h" -#include "sentencepiece_trainer.h" #include "spec_parser.h" -#include "third_party/absl/flags/flag.h" #include "third_party/absl/strings/numbers.h" #include "third_party/absl/strings/str_cat.h" #include "third_party/absl/strings/str_split.h" @@ -31,8 +31,6 @@ #include "trainer_factory.h" #include "util.h" -ABSL_DECLARE_FLAG(int, minloglevel); - namespace sentencepiece { namespace { static constexpr char kDefaultNormalizerName[] = "nmt_nfkc"; @@ -112,7 +110,7 @@ util::Status SentencePieceTrainer::MergeSpecsFromArgs( for (auto arg : absl::StrSplit(args, " ")) { absl::ConsumePrefix(&arg, "--"); std::string key, value; - const auto pos = arg.find('='); + const auto pos = arg.find("="); if (pos == absl::string_view::npos) { key = std::string(arg); } else { @@ -151,7 +149,7 @@ util::Status SentencePieceTrainer::MergeSpecsFromArgs( } else if (key == "minloglevel") { int v = 0; CHECK_OR_RETURN(absl::SimpleAtoi(value, &v)); - absl::SetFlag(&FLAGS_minloglevel, v); + logging::SetMinLogLevel(v); continue; } diff --git a/src/spec_parser.h b/src/spec_parser.h index 2c5a95b..a168322 100644 --- a/src/spec_parser.h +++ b/src/spec_parser.h @@ -145,7 +145,6 @@ inline std::string PrintProto(const TrainerSpec &message, PRINT_PARAM(split_by_whitespace); PRINT_PARAM(split_digits); PRINT_PARAM(treat_whitespace_as_suffix); - PRINT_PARAM(allow_whitespace_only_pieces); PRINT_REPEATED_STRING(control_symbols); PRINT_REPEATED_STRING(user_defined_symbols); PRINT_PARAM(required_chars); @@ -220,7 +219,6 @@ util::Status SentencePieceTrainer::SetProtoField(const std::string &name, PARSE_BOOL(split_by_whitespace); PARSE_BOOL(split_digits); PARSE_BOOL(treat_whitespace_as_suffix); - PARSE_BOOL(allow_whitespace_only_pieces); PARSE_REPEATED_STRING(control_symbols); PARSE_REPEATED_STRING(user_defined_symbols); PARSE_STRING(required_chars); diff --git a/src/spm_decode_main.cc b/src/spm_decode_main.cc index 3382ddc..32cb382 100644 --- a/src/spm_decode_main.cc +++ b/src/spm_decode_main.cc @@ -64,7 +64,6 @@ int main(int argc, char *argv[]) { auto ToIds = [&](const std::vector &pieces) { std::vector ids; - ids.reserve(pieces.size()); for (const auto &s : pieces) { ids.push_back(atoi(s.c_str())); } diff --git a/src/spm_encode_main.cc b/src/spm_encode_main.cc index 4d12a38..f151ecf 100644 --- a/src/spm_encode_main.cc +++ b/src/spm_encode_main.cc @@ -28,17 +28,16 @@ #include "trainer_interface.h" ABSL_FLAG(std::string, model, "", "model file name"); -ABSL_FLAG( - std::string, output_format, "piece", - "choose from piece, id, proto, nbest_piece, nbest_id, or nbest_proto"); +ABSL_FLAG(std::string, output_format, "piece", + "choose from piece, id, proto, nbest_piece, nbest_id, nbest_proto, " + "sample_piece, sample_id or sample_proto."); ABSL_FLAG(std::string, input, "", "input filename"); ABSL_FLAG(std::string, output, "", "output filename"); ABSL_FLAG(std::string, extra_options, "", "':' separated encoder extra options, e.g., \"reverse:bos:eos\""); ABSL_FLAG(int32, nbest_size, 10, "NBest size"); ABSL_FLAG(double, alpha, 0.5, "Smoothing parameter for sampling mode."); -ABSL_FLAG(uint32, random_seed, static_cast(-1), - "Seed value for random generator."); +ABSL_FLAG(int32, random_seed, -1, "Seed value for random generator."); // Piece restriction with vocabulary file. 
// https://github.com/rsennrich/subword-nmt#best-practice-advice-for-byte-pair-encoding-in-nmt @@ -62,9 +61,8 @@ int main(int argc, char *argv[]) { rest_args.push_back(absl::GetFlag(FLAGS_input)); } - if (absl::GetFlag(FLAGS_random_seed) != -1) { + if (absl::GetFlag(FLAGS_random_seed) != -1) sentencepiece::SetRandomGeneratorSeed(absl::GetFlag(FLAGS_random_seed)); - } if (rest_args.empty()) rest_args.push_back(""); // empty means that reading from stdin. diff --git a/src/spm_train_main.cc b/src/spm_train_main.cc index baf8dbf..a21fb8b 100644 --- a/src/spm_train_main.cc +++ b/src/spm_train_main.cc @@ -80,9 +80,6 @@ ABSL_FLAG(bool, split_digits, kDefaultTrainerSpec.split_digits(), ABSL_FLAG(bool, treat_whitespace_as_suffix, kDefaultTrainerSpec.treat_whitespace_as_suffix(), "treat whitespace marker as suffix instead of prefix."); -ABSL_FLAG(bool, allow_whitespace_only_pieces, - kDefaultTrainerSpec.allow_whitespace_only_pieces(), - "allow pieces that only contain (consecutive) whitespace tokens"); ABSL_FLAG(std::string, control_symbols, "", "comma separated list of control symbols"); ABSL_FLAG(std::string, control_symbols_file, "", @@ -141,8 +138,7 @@ ABSL_FLAG(std::string, unk_surface, kDefaultTrainerSpec.unk_surface(), ABSL_FLAG(bool, train_extremely_large_corpus, kDefaultTrainerSpec.train_extremely_large_corpus(), "Increase bit depth for unigram tokenization."); -ABSL_FLAG(uint32, random_seed, static_cast(-1), - "Seed value for random generator."); +ABSL_FLAG(int32, random_seed, -1, "Seed value for random generator."); int main(int argc, char *argv[]) { sentencepiece::ParseCommandLineFlags(argv[0], &argc, &argv, true); @@ -154,9 +150,8 @@ int main(int argc, char *argv[]) { CHECK(!absl::GetFlag(FLAGS_input).empty()); CHECK(!absl::GetFlag(FLAGS_model_prefix).empty()); - if (absl::GetFlag(FLAGS_random_seed) != -1) { + if (absl::GetFlag(FLAGS_random_seed) != -1) sentencepiece::SetRandomGeneratorSeed(absl::GetFlag(FLAGS_random_seed)); - } auto load_lines = [](absl::string_view filename) { std::vector lines; @@ -216,7 +211,6 @@ int main(int argc, char *argv[]) { SetTrainerSpecFromFlag(split_digits); SetTrainerSpecFromFlag(byte_fallback); SetTrainerSpecFromFlag(treat_whitespace_as_suffix); - SetTrainerSpecFromFlag(allow_whitespace_only_pieces); SetTrainerSpecFromFlag(hard_vocab_limit); SetTrainerSpecFromFlag(use_all_vocab); SetTrainerSpecFromFlag(unk_id); diff --git a/src/trainer_interface.cc b/src/trainer_interface.cc index a3a4b74..53edc7b 100644 --- a/src/trainer_interface.cc +++ b/src/trainer_interface.cc @@ -12,7 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License.! 
-#include +#include "trainer_interface.h" + #include #include #include @@ -33,7 +34,6 @@ #include "third_party/absl/strings/str_format.h" #include "third_party/absl/strings/str_join.h" #include "third_party/absl/strings/str_split.h" -#include "trainer_interface.h" #include "unicode_script.h" #include "util.h" @@ -86,10 +86,6 @@ util::Status VerifySpec(const TrainerSpec &trainer_spec) { return util::OkStatus(); } -bool is_unicode_decimal_number(char32 c) { - return (c >= 0x30 && c <= 0x39) || (c >= 0xff10 && c <= 0xff19); -} - class SentenceSelector { public: using Sampler = random::ReservoirSampler; @@ -214,10 +210,9 @@ bool TrainerInterface::IsValidSentencePiece( constexpr unicode_script::ScriptType kAnyType = static_cast(-1); + auto is_number = [](char32 c) { return (c >= 0x30 && c <= 0x39); }; + unicode_script::ScriptType prev_script = kAnyType; - bool all_whitespace_piece = - std::all_of(sentencepiece.begin(), sentencepiece.end(), - [](char32 c) { return c == kWSChar; }); for (size_t pos = 0; pos < sentencepiece.size(); ++pos) { const char32 c = sentencepiece[pos]; @@ -240,30 +235,25 @@ bool TrainerInterface::IsValidSentencePiece( } if (c == kWSChar) { - // Only allows whitespace to appear as a prefix of piece unless - // allow_whitespace_only_pieces is True. + // Only allows whitespace to appear as a prefix of piece. // When split_by_whitespace is false, we allow whitespaces to // appear in the middle, "foo_bar", but do not allow them // to appear as suffix, "foo_bar_". // Regardless of the setting of split_by_whitespace, // whitespace is treated as a prefix/infix of symbol or - // independent symbol, unless allow_whitespace_only_pieces() is true, - // in which case whitespace only pieces can occur. - if (!trainer_spec_.allow_whitespace_only_pieces() || - !all_whitespace_piece) { - if (trainer_spec_.treat_whitespace_as_suffix()) { - if ((trainer_spec_.split_by_whitespace() && - pos < sentencepiece.size() - 1) || - (!trainer_spec_.split_by_whitespace() && - pos < sentencepiece.size() - 1 && pos == 0)) { - return false; - } - } else { - if ((trainer_spec_.split_by_whitespace() && pos > 0) || - (!trainer_spec_.split_by_whitespace() && pos > 0 && - pos == sentencepiece.size() - 1)) { - return false; - } + // independent symbol. 
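A standalone sketch of just the whitespace-position rule restored above, for prefix mode (treat_whitespace_as_suffix == false); the real IsValidSentencePiece applies further script and digit checks not shown here, and the helper below is hypothetical:

#include <iostream>
#include <vector>

constexpr char32_t kWS = 0x2581;  // the "▁" marker

bool WhitespacePositionsOk(const std::vector<char32_t> &piece,
                           bool split_by_whitespace) {
  for (size_t pos = 0; pos < piece.size(); ++pos) {
    if (piece[pos] != kWS) continue;
    if (split_by_whitespace && pos > 0) return false;  // only a prefix is allowed
    if (!split_by_whitespace && pos > 0 && pos + 1 == piece.size())
      return false;  // infix is fine ("foo▁bar"), suffix is not ("foo▁bar▁")
  }
  return true;
}

int main() {
  const std::vector<char32_t> foo_bar = {kWS, U'f', U'o', U'o', kWS, U'b', U'a', U'r'};
  const std::vector<char32_t> foo_bar_ws = {U'f', U'o', U'o', kWS, U'b', U'a', U'r', kWS};
  std::cout << WhitespacePositionsOk(foo_bar, /*split_by_whitespace=*/false)      // 1
            << WhitespacePositionsOk(foo_bar_ws, /*split_by_whitespace=*/false)   // 0
            << "\n";
}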
+ if (trainer_spec_.treat_whitespace_as_suffix()) { + if ((trainer_spec_.split_by_whitespace() && + pos < sentencepiece.size() - 1) || + (!trainer_spec_.split_by_whitespace() && + pos < sentencepiece.size() - 1 && pos == 0)) { + return false; + } + } else { + if ((trainer_spec_.split_by_whitespace() && pos > 0) || + (!trainer_spec_.split_by_whitespace() && pos > 0 && + pos == sentencepiece.size() - 1)) { + return false; } } } else { @@ -275,11 +265,11 @@ bool TrainerInterface::IsValidSentencePiece( s = unicode_script::U_Han; } - if (!trainer_spec_.split_by_number() && is_unicode_decimal_number(c)) { + if (!trainer_spec_.split_by_number() && is_number(c)) { s = kAnyType; } - if (trainer_spec_.split_digits() && is_unicode_decimal_number(c)) { + if (trainer_spec_.split_digits() && is_number(c)) { if (sentencepiece.size() > 1) return false; } @@ -528,8 +518,7 @@ void TrainerInterface::SplitSentencesByWhitespace() { absl::flat_hash_map tokens; for (const auto &s : sentences_) { for (const auto &w : - SplitIntoWords(s.first, trainer_spec_.treat_whitespace_as_suffix(), - trainer_spec_.allow_whitespace_only_pieces())) { + SplitIntoWords(s.first, trainer_spec_.treat_whitespace_as_suffix())) { tokens[std::string(w)] += s.second; } } diff --git a/src/trainer_interface_test.cc b/src/trainer_interface_test.cc index 70a51ad..c61c7ce 100644 --- a/src/trainer_interface_test.cc +++ b/src/trainer_interface_test.cc @@ -81,7 +81,6 @@ TEST(TrainerInterfaceTest, IsValidSentencePieceTest) { trainer_spec.set_split_by_whitespace(false); EXPECT_TRUE(IsValid(WS)); - EXPECT_TRUE(IsValid(WS WS WS "a")); EXPECT_TRUE(IsValid(WS "a")); EXPECT_FALSE(IsValid("a" WS)); EXPECT_FALSE(IsValid(WS "a" WS)); @@ -89,17 +88,7 @@ TEST(TrainerInterfaceTest, IsValidSentencePieceTest) { EXPECT_TRUE(IsValid(WS "a" WS "b")); EXPECT_TRUE(IsValid(WS "a" WS "b" WS "c")); EXPECT_FALSE(IsValid("a" WS "b" WS)); - EXPECT_FALSE(IsValid(WS WS)); - EXPECT_FALSE(IsValid(WS WS WS)); - trainer_spec.set_allow_whitespace_only_pieces(true); - EXPECT_TRUE(IsValid(WS)); - EXPECT_TRUE(IsValid(WS WS)); - EXPECT_TRUE(IsValid(WS WS WS)); - EXPECT_TRUE(IsValid(WS WS "a")); - EXPECT_FALSE(IsValid("a" WS WS)); // suffix whitespace illegal without flag - - trainer_spec.set_allow_whitespace_only_pieces(false); trainer_spec.set_split_by_unicode_script(false); EXPECT_TRUE(IsValid("あいう")); EXPECT_TRUE(IsValid("グーグル")); @@ -135,15 +124,6 @@ TEST(TrainerInterfaceTest, IsValidSentencePieceTest) { EXPECT_FALSE(IsValid(WS "a" WS "b")); EXPECT_FALSE(IsValid("a" WS "b" WS)); - trainer_spec.set_allow_whitespace_only_pieces(true); - EXPECT_TRUE(IsValid(WS)); - EXPECT_TRUE(IsValid(WS WS)); - EXPECT_FALSE(IsValid(WS "a" WS)); - EXPECT_FALSE(IsValid("a" WS "b")); - EXPECT_FALSE(IsValid(WS "a" WS "b")); - EXPECT_FALSE(IsValid("a" WS "b" WS)); - - trainer_spec.set_allow_whitespace_only_pieces(false); trainer_spec.set_split_by_whitespace(false); EXPECT_TRUE(IsValid(WS)); EXPECT_FALSE(IsValid(WS "a")); @@ -166,12 +146,6 @@ TEST(TrainerInterfaceTest, IsValidSentencePieceTest) { EXPECT_FALSE(IsValid("2007")); EXPECT_FALSE(IsValid("x1")); EXPECT_FALSE(IsValid("2x")); - // Fullwidth digits. 
- EXPECT_TRUE(IsValid("1")); - EXPECT_FALSE(IsValid("59")); - EXPECT_FALSE(IsValid("2007")); - EXPECT_FALSE(IsValid("*1")); - EXPECT_FALSE(IsValid("2*")); } TEST(TrainerInterfaceTest, OverrideSpecialPiecesTest) { diff --git a/src/unigram_model.cc b/src/unigram_model.cc index 3b99060..bd2d99b 100644 --- a/src/unigram_model.cc +++ b/src/unigram_model.cc @@ -15,7 +15,6 @@ #include #include #include -#include #include #include #include @@ -56,17 +55,6 @@ inline float LogSumExp(float x, float y, bool init_mode) { return vmax + log(std::exp(static_cast(vmin - vmax)) + 1.0); } } - -// Returns a sample from a standard Gumbel distribution. -// If U ~ U[0, 1], -log(-log U) ~ G(0,1) -inline float Gumbel() { - const float kEpsilon = 1e-7; - auto *mt = random::GetRandomGenerator(); - std::uniform_real_distribution dis(0.0, 1.0); - float noise = -std::log(-(std::log(dis(*mt) + kEpsilon))); - - return noise; -} } // namespace Lattice::Lattice() : node_allocator_(kPreallocateLatticeNodeSize) {} @@ -157,7 +145,7 @@ Lattice::Node *Lattice::Insert(int pos, int length) { return node; } -Lattice::LatticePathWithScore Lattice::Viterbi() { +std::vector Lattice::Viterbi() { const int len = size(); for (int pos = 0; pos <= len; ++pos) { @@ -183,7 +171,6 @@ Lattice::LatticePathWithScore Lattice::Viterbi() { // backtrace std::vector results; - float score = begin_nodes(len)[0]->backtrace_score; for (Node *node = begin_nodes_[len][0]->prev; node->prev != nullptr; node = node->prev) { results.push_back(node); @@ -191,43 +178,7 @@ Lattice::LatticePathWithScore Lattice::Viterbi() { std::reverse(results.begin(), results.end()); - LatticePathWithScore retval = {results, score}; - - return retval; -} - -std::vector Lattice::ForwardAlgorithm(float theta) const { - const int len = size(); - std::vector alpha(node_allocator_.size(), 0.0); - - for (int pos = 0; pos <= len; ++pos) { - for (Node *rnode : begin_nodes_[pos]) { - for (Node *lnode : end_nodes_[pos]) { - alpha[rnode->node_id] = LogSumExp( - alpha[rnode->node_id], theta * lnode->score + alpha[lnode->node_id], - lnode == end_nodes_[pos][0]); - } - } - } - - return alpha; -} - -std::vector Lattice::BackwardAlgorithm(float theta) const { - const int len = size(); - std::vector beta(node_allocator_.size(), 0.0); - - for (int pos = len; pos >= 0; --pos) { - for (Node *lnode : end_nodes_[pos]) { - for (Node *rnode : begin_nodes_[pos]) { - beta[lnode->node_id] = - LogSumExp(beta[lnode->node_id], rnode->score + beta[rnode->node_id], - rnode == begin_nodes_[pos][0]); - } - } - } - - return beta; + return results; } float Lattice::PopulateMarginal(float freq, @@ -238,9 +189,28 @@ float Lattice::PopulateMarginal(float freq, // alpha and beta (accumulative log prob) in Forward Backward. // the index of alpha/beta is Node::node_id. 
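The comment above introduces the forward/backward accumulators. A standalone sketch of the log-domain accumulation they rely on, using a simplified version of the LogSumExp helper defined earlier in unigram_model.cc (the early-exit threshold is omitted) and a two-path toy example standing in for a lattice:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// log(exp(x) + exp(y)); init_mode returns y so the first term seeds the sum.
inline float LogSumExp(float x, float y, bool init_mode) {
  if (init_mode) return y;
  const float vmin = std::min(x, y);
  const float vmax = std::max(x, y);
  return vmax + std::log(std::exp(vmin - vmax) + 1.0f);
}

int main() {
  // Two segmentations of the same span with log scores -1.0 and -2.5:
  // the normalizing constant Z is log(exp(-1.0) + exp(-2.5)).
  const std::vector<float> path_scores = {-1.0f, -2.5f};
  float alpha = 0.0f;
  for (size_t i = 0; i < path_scores.size(); ++i)
    alpha = LogSumExp(alpha, path_scores[i], /*init_mode=*/i == 0);
  // A path's marginal probability is exp(path_score - Z).
  std::printf("Z = %f, P(best path) = %f\n", alpha,
              std::exp(path_scores[0] - alpha));
}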
+ std::vector alpha(node_allocator_.size(), 0.0); + std::vector beta(node_allocator_.size(), 0.0); - const auto alpha = ForwardAlgorithm(1.0); - const auto beta = BackwardAlgorithm(1.0); + for (int pos = 0; pos <= len; ++pos) { + for (Node *rnode : begin_nodes_[pos]) { + for (Node *lnode : end_nodes_[pos]) { + alpha[rnode->node_id] = LogSumExp(alpha[rnode->node_id], + lnode->score + alpha[lnode->node_id], + lnode == end_nodes_[pos][0]); + } + } + } + + for (int pos = len; pos >= 0; --pos) { + for (Node *lnode : end_nodes_[pos]) { + for (Node *rnode : begin_nodes_[pos]) { + beta[lnode->node_id] = + LogSumExp(beta[lnode->node_id], rnode->score + beta[rnode->node_id], + rnode == begin_nodes_[pos][0]); + } + } + } const float Z = alpha[begin_nodes_[len][0]->node_id]; for (int pos = 0; pos < len; ++pos) { @@ -258,46 +228,13 @@ float Lattice::PopulateMarginal(float freq, return freq * Z; } -float Lattice::CalculateEntropy(float theta) const { - const int len = size(); - - // alpha[node_id] is the marginal prob of sequence up to start of node - // H is entropy of sequence - // the index of alpha/H is Node::node_id. - std::vector alpha(node_allocator_.size(), 0.0); - std::vector H(node_allocator_.size(), 0.0); - - // Populate the forward marginals to get the normalising constant - alpha = ForwardAlgorithm(theta); - - // Now populate the forward entropies - for (int pos = 0; pos <= len; ++pos) { - for (Node *rnode : begin_nodes_[pos]) { - for (Node *lnode : end_nodes_[pos]) { - // Contribution each lnode makes = p(lnode) * (H(lnode) + log p(lnode)) - - // We have to normalise p(lnode) by the marginal contribution it makes - const float lnode_transition_prob = - ((theta * lnode->score) + alpha[lnode->node_id] - - alpha[rnode->node_id]); - H[rnode->node_id] += std::exp(lnode_transition_prob) * - (H[lnode->node_id] + lnode_transition_prob); - } - } - } - - return -H[begin_nodes_[len][0]->node_id]; -} - -std::vector Lattice::NBest(size_t nbest_size, - bool sample, - float theta) { +std::vector> Lattice::NBest(size_t nbest_size) { if (nbest_size < 1) { LOG(WARNING) << "nbest_size >= 1. Returns empty result."; return {}; } - if (nbest_size == 1 && !sample) { + if (nbest_size == 1) { return {Viterbi()}; } @@ -306,7 +243,6 @@ std::vector Lattice::NBest(size_t nbest_size, // At each partial path x, compute f(x) as follows // f(x) = g(x) + h(x). // g(x): the sum of scores from EOS to the left-most node in x. - // for a complete hypothesis, g(hyp) is the score of the hypothesis. // h(x): a heuristic that estimates the largest score from x to BOS. // f(x): the priority to pop a new hypothesis from the priority queue. // @@ -332,27 +268,18 @@ std::vector Lattice::NBest(size_t nbest_size, model::FreeList hypothesis_allocator(kPreallocatedHypothesisSize); Agenda agenda; - std::vector results; + std::vector> results; auto *eos = hypothesis_allocator.Allocate(); eos->node = eos_node(); eos->next = nullptr; - eos->gx = 0.0; - - std::vector alpha(node_allocator_.size(), 0.0); - - if (sample) { - // Run forwards algorithm to get normalising constants - alpha = ForwardAlgorithm(theta); - // f(eos) = Gumbel(0), as it is the perturbed score of the entire lattice. - eos->fx = Gumbel(); - } else { - // Run Viterbi first to fill backtrace score. - Viterbi(); - eos->fx = eos->node->backtrace_score; - } + eos->fx = eos->node->score; + eos->gx = eos->node->score; agenda.push(eos); + // Run Viterbi first to fill backtrace score. 
+ Viterbi(); + while (!agenda.empty()) { auto *top = agenda.top(); agenda.pop(); @@ -362,56 +289,21 @@ std::vector Lattice::NBest(size_t nbest_size, if (node == bos_node()) { results.resize(results.size() + 1); for (auto *n = top->next; n->next != nullptr; n = n->next) { - results.back().first.push_back(n->node); + results.back().push_back(n->node); } - results.back().second = top->fx; if (results.size() == nbest_size) { break; } continue; } - const int end_nodes_size = end_nodes(node->pos).size(); - std::vector probs(end_nodes_size, 0.0); - std::vector perturbed_probs(end_nodes_size, 0.0); - std::vector adjusted_probs(end_nodes_size, 0.0); - const float Z = alpha[node->node_id]; - if (sample) { - float max_score = -1e8; - // Calculate the marginal and perturbed scores for stochastic search - for (int i = 0; i < end_nodes(node->pos).size(); i++) { - Node *lnode = end_nodes(node->pos)[i]; - // Calculate backwards transition score - probs[i] = top->gx + alpha[lnode->node_id] + (theta * lnode->score) - Z; - perturbed_probs[i] = probs[i] + Gumbel(); - if (perturbed_probs[i] > max_score) { - max_score = perturbed_probs[i]; - } - } - // Now constrain the sampled continuations to match the score of parent - for (int i = 0; i < adjusted_probs.size(); i++) { - // Use numerically stable version of truncated Gumbel: - // https://arxiv.org/pdf/1903.06059.pdf appendix B.3 - const float v = top->fx - perturbed_probs[i] + - std::log1p(-std::exp(perturbed_probs[i] - max_score)); - adjusted_probs[i] = top->fx - std::max(static_cast(0.0), v) - - std::log1p(std::exp(-std::abs(v))); - } - } - // Expands new node ending at node->pos - for (int i = 0; i < end_nodes(node->pos).size(); i++) { - Node *lnode = end_nodes(node->pos)[i]; + for (Node *lnode : end_nodes(node->pos)) { auto *hyp = hypothesis_allocator.Allocate(); hyp->node = lnode; - if (sample) { - hyp->gx = probs[i]; - hyp->fx = adjusted_probs[i]; - } else { - hyp->gx = lnode->score + top->gx; // just adds node->score - hyp->fx = - lnode->backtrace_score + top->gx; // backtrace_score is h(node). - } + hyp->gx = lnode->score + top->gx; // just adds node->score + hyp->fx = + lnode->backtrace_score + top->gx; // backtrace_score is h(node). 
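The f(x) = g(x) + h(x) comments above describe an A*-style n-best search. A standalone sketch of that search on the three-character toy lattice used in the tests; it runs left-to-right for brevity (the library searches from EOS toward BOS and uses backtrace_score as h), and the Edge/Hyp names are illustrative:

#include <algorithm>
#include <iostream>
#include <queue>
#include <string>
#include <vector>

// Each edge is a candidate piece covering [pos, pos + len).
struct Edge { int pos, len; float score; std::string piece; };

struct Hyp {
  float f, g;  // f = g + h(pos); g = exact score accumulated so far
  int pos;
  std::vector<std::string> pieces;
  bool operator<(const Hyp &o) const { return f < o.f; }  // max-heap on f
};

int main() {
  const int len = 3;  // "ABC"
  const std::vector<Edge> edges = {{0, 1, 0, "A"},  {1, 1, 0, "B"},
                                   {2, 1, 0, "C"},  {0, 2, 2, "AB"},
                                   {1, 2, 5, "BC"}, {0, 3, 10, "ABC"}};

  // h(pos): best (Viterbi) score from pos to the end; this plays the role of
  // backtrace_score in the library's search.
  std::vector<float> h(len + 1, -1e30f);
  h[len] = 0.0f;
  for (int pos = len - 1; pos >= 0; --pos)
    for (const auto &e : edges)
      if (e.pos == pos) h[pos] = std::max(h[pos], e.score + h[pos + e.len]);

  const size_t nbest_size = 4;
  size_t printed = 0;
  std::priority_queue<Hyp> agenda;
  agenda.push({h[0], 0.0f, 0, {}});
  while (!agenda.empty() && printed < nbest_size) {
    Hyp top = agenda.top();
    agenda.pop();
    if (top.pos == len) {  // complete hypothesis: f == g is the exact path score
      for (const auto &p : top.pieces) std::cout << p << ' ';
      std::cout << "(score " << top.g << ")\n";
      ++printed;
      continue;
    }
    for (const auto &e : edges) {
      if (e.pos != top.pos) continue;
      Hyp next = top;
      next.pos = top.pos + e.len;
      next.g = top.g + e.score;
      next.f = next.g + h[next.pos];  // priority: exact-so-far + best-to-go
      next.pieces.push_back(e.piece);
      agenda.push(next);
    }
  }
}

This prints ABC, A BC, AB C, A B C in descending score order, matching the expectations in the NBestTest above.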
hyp->next = top; agenda.push(hyp); } @@ -443,7 +335,15 @@ std::vector Lattice::Sample(float theta) { std::vector alpha(node_allocator_.size(), 0.0); - alpha = ForwardAlgorithm(theta); + for (int pos = 0; pos <= len; ++pos) { + for (Node *rnode : begin_nodes_[pos]) { + for (Node *lnode : end_nodes_[pos]) { + alpha[rnode->node_id] = LogSumExp( + alpha[rnode->node_id], theta * lnode->score + alpha[lnode->node_id], + lnode == end_nodes_[pos][0]); + } + } + } auto *mt = random::GetRandomGenerator(); @@ -614,7 +514,7 @@ EncodeResult Model::Encode(absl::string_view normalized) const { PopulateNodes(&lattice); EncodeResult results; - for (const auto *node : lattice.Viterbi().first) { + for (const auto *node : lattice.Viterbi()) { results.emplace_back(node->piece, node->id); } @@ -634,12 +534,14 @@ NBestEncodeResult Model::NBestEncode(absl::string_view normalized, PopulateNodes(&lattice); NBestEncodeResult nbest_results; - for (const auto &nbest : lattice.NBest(nbest_size, false, 0.0)) { + for (const auto &nbest : lattice.NBest(nbest_size)) { EncodeResult results; - for (const auto *node : nbest.first) { + float score = 0.0; + for (const auto *node : nbest) { + score += node->score; results.emplace_back(node->piece, node->id); } - nbest_results.emplace_back(results, nbest.second); + nbest_results.emplace_back(results, score); } return nbest_results; @@ -663,123 +565,6 @@ EncodeResult Model::SampleEncode(absl::string_view normalized, return results; } -NBestEncodeResult Model::SampleEncodeAndScore(absl::string_view normalized, - float theta, int samples, - bool wor, - bool include_best) const { - if (!status().ok() || normalized.empty()) { - return {}; - } - NBestEncodeResult results; - Lattice lattice; - lattice.SetSentence(normalized); - PopulateNodes(&lattice); - - std::vector alpha = lattice.ForwardAlgorithm(theta); - float marginal = alpha[lattice.eos_node()->node_id]; - - if (include_best) { - if (!wor) { - LOG(FATAL) << "include_best not supported for wor false"; - } - EncodeResult result; - Lattice::LatticePathWithScore best_path = lattice.Viterbi(); - - for (const auto *node : best_path.first) { - result.emplace_back(node->piece, node->id); - } - - // Inclusion probability if we always include the best is 1. - results.emplace_back(result, 0.0); - } - - if (wor) { - // Draw k+1 samples as we need perturbed score of k+1th element - std::vector nbest_samples = - lattice.NBest(samples + 1, true, theta); - - if (include_best) { - std::vector> nbest_paths( - nbest_samples.size()); - for (int i = 0; i < nbest_samples.size(); i++) { - nbest_paths[i] = nbest_samples[i].first; - } - // Remove the best result from the samples if necessary - Lattice::LatticePathWithScore best_path = lattice.Viterbi(); - - const int index_of_best = - (std::find(nbest_paths.begin(), nbest_paths.end(), best_path.first) - - nbest_paths.begin()); - - if (index_of_best != nbest_samples.size()) { - LOG(INFO) << "removing best path from samples"; - nbest_samples.erase(nbest_samples.begin() + index_of_best); - } else { - nbest_samples.pop_back(); - } - } - // We use the perturbed score of the k+1th element to calculate the - // inclusion probability. 
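The removed sampling code above perturbs scores with Gumbel noise and truncates them against the parent's perturbed score (the numerically stable form from arXiv:1903.06059, appendix B.3). A standalone sketch of those two formulas, with illustrative inputs in main:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <random>

// Standard Gumbel(0,1) draw: if U ~ U[0,1], then -log(-log U) ~ Gumbel(0,1).
float SampleGumbel(std::mt19937 &rng) {
  const float kEpsilon = 1e-7f;  // guards against log(0)
  std::uniform_real_distribution<float> dis(0.0f, 1.0f);
  return -std::log(-std::log(dis(rng) + kEpsilon));
}

// Truncates a perturbed score so it never exceeds `bound` (the parent's
// perturbed score), given the max over the sibling set, so the max of the
// children matches the parent's perturbed score in distribution.
float TruncateGumbel(float perturbed, float max_perturbed, float bound) {
  const float v =
      bound - perturbed + std::log1p(-std::exp(perturbed - max_perturbed));
  return bound - std::max(0.0f, v) - std::log1p(std::exp(-std::fabs(v)));
}

int main() {
  std::mt19937 rng(0);
  const float bound = 1.0f;  // parent's perturbed score
  const float g1 = -0.5f + SampleGumbel(rng);
  const float g2 = -1.5f + SampleGumbel(rng);
  const float max_g = std::max(g1, g2);
  std::printf("%f %f\n", TruncateGumbel(g1, max_g, bound),
              TruncateGumbel(g2, max_g, bound));
}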
- const double kappa = static_cast(nbest_samples.back().second); - // Discard the last sample - nbest_samples.pop_back(); - for (const auto &nbest : nbest_samples) { - EncodeResult result; - float score = 0.0; - - for (const auto *node : nbest.first) { - score += (theta * node->score); - result.emplace_back(node->piece, node->id); - } - - results.emplace_back(result, score - marginal); - } - - // Now calculate the inclusion probability - for (auto &it : results) { - // Only modify non best sample inclusion probabilities. - if (it.second != 0.0) { - double x = it.second - kappa; - double y = std::exp(x); - double inclusion_prob; - if (x <= -10) { - // Series expansion of the log Gumbel survival function up to eps. - inclusion_prob = - x - (y / 2) + (std::pow(y, 2) / 24) - std::pow(y, 4) / 2880; - } else { - inclusion_prob = std::log(-std::expm1(-y)); - } - it.second = static_cast(inclusion_prob); - } - } - } else { - while (results.size() < samples) { - Lattice lattice; - lattice.SetSentence(normalized); - PopulateNodes(&lattice); - - float score = 0.0; - EncodeResult result; - std::vector sample = lattice.Sample(theta); - for (const auto *node : sample) { - result.emplace_back(node->piece, node->id); - score += (theta * node->score); - } - results.emplace_back(result, score - marginal); - } - } - - return results; -} - -float Model::CalculateEntropy(absl::string_view normalized, float theta) const { - Lattice lattice; - lattice.SetSentence(normalized); - PopulateNodes(&lattice); - - return lattice.CalculateEntropy(theta); -} - bool Model::VerifyOutputsEquivalent(absl::string_view expected, absl::string_view actual) const { auto compute_unigram_model_score = diff --git a/src/unigram_model.h b/src/unigram_model.h index 448e489..2f66a5f 100644 --- a/src/unigram_model.h +++ b/src/unigram_model.h @@ -82,28 +82,17 @@ class Lattice { // After calling this method, The caller must set Node::score and Node::id. Node *Insert(int pos, int length); - using LatticePathWithScore = std::pair, float>; - // Returns Viterbi path. All nodes must be populated in advance. - LatticePathWithScore Viterbi(); - - // Runs forwards/backwards algorithm, returns vector with normalised - // transition probs. - std::vector ForwardAlgorithm(float theta) const; - std::vector BackwardAlgorithm(float theta) const; + std::vector Viterbi(); // Returns n-best results. - std::vector NBest(size_t nbest_size, bool sample, - float theta); + std::vector> NBest(size_t nbest_size); // Samples one path from the lattice according to the // generation probability (Product of piece probabilities). // `theta` is a smoothing parameter. std::vector Sample(float theta); - // Calculates the entropy of the lattice. - float CalculateEntropy(float theta) const; - // Populates marginal probability of every node in this lattice. // |freq| is the frequency of the sentence. 
// for (auto *node : all_nodes_) { @@ -138,19 +127,8 @@ class Model : public ModelInterface { EncodeResult SampleEncode(absl::string_view normalized, float theta) const override; - NBestEncodeResult SampleEncodeAndScore(absl::string_view normalized, - float theta, int samples, bool wor, - bool include_best) const override; - - float CalculateEntropy(absl::string_view normalized, - float theta) const override; - bool IsSampleEncodeAvailable() const override { return true; } - bool IsSampleEncodeAndScoreAvailable() const override { return true; } - - bool IsCalculateEntropyAvailable() const override { return true; } - bool IsNBestEncodeAvailable() const override { return true; } // Returns the minimum score in sentence pieces. diff --git a/src/unigram_model_test.cc b/src/unigram_model_test.cc index 5c292cb..dacec38 100644 --- a/src/unigram_model_test.cc +++ b/src/unigram_model_test.cc @@ -161,11 +161,11 @@ TEST(LatticeTest, InsertTest) { TEST(LatticeTest, ViterbiFromIncompleteLatticeTest) { Lattice lattice; lattice.SetSentence("ABC"); - EXPECT_TRUE(lattice.Viterbi().first.empty()); + EXPECT_TRUE(lattice.Viterbi().empty()); // Still incomplete lattice.Insert(0, 1); - EXPECT_TRUE(lattice.Viterbi().first.empty()); + EXPECT_TRUE(lattice.Viterbi().empty()); lattice.Insert(1, 1); lattice.Insert(2, 1); @@ -198,16 +198,16 @@ TEST(LatticeTest, ViterbiTest) { InsertWithScore(&lattice, 0, 1, 0.0); // A InsertWithScore(&lattice, 1, 1, 0.0); // B InsertWithScore(&lattice, 2, 1, 0.0); // C - EXPECT_EQ("A B C", GetTokenized(lattice.Viterbi().first)); + EXPECT_EQ("A B C", GetTokenized(lattice.Viterbi())); InsertWithScore(&lattice, 0, 2, 2.0); // AB - EXPECT_EQ("AB C", GetTokenized(lattice.Viterbi().first)); + EXPECT_EQ("AB C", GetTokenized(lattice.Viterbi())); InsertWithScore(&lattice, 1, 2, 5.0); // BC - EXPECT_EQ("A BC", GetTokenized(lattice.Viterbi().first)); + EXPECT_EQ("A BC", GetTokenized(lattice.Viterbi())); InsertWithScore(&lattice, 0, 3, 10.0); // ABC - EXPECT_EQ("ABC", GetTokenized(lattice.Viterbi().first)); + EXPECT_EQ("ABC", GetTokenized(lattice.Viterbi())); } TEST(LatticeTest, NBestTest) { @@ -221,174 +221,21 @@ TEST(LatticeTest, NBestTest) { InsertWithScore(&lattice, 1, 2, 5.0); // BC InsertWithScore(&lattice, 0, 3, 10.0); // ABC - auto nbests = lattice.NBest(10, false, 0.0); + auto nbests = lattice.NBest(10); EXPECT_EQ(4, nbests.size()); - EXPECT_EQ("ABC", GetTokenized(nbests[0].first)); - EXPECT_EQ("A BC", GetTokenized(nbests[1].first)); - EXPECT_EQ("AB C", GetTokenized(nbests[2].first)); - EXPECT_EQ("A B C", GetTokenized(nbests[3].first)); + EXPECT_EQ("ABC", GetTokenized(nbests[0])); + EXPECT_EQ("A BC", GetTokenized(nbests[1])); + EXPECT_EQ("AB C", GetTokenized(nbests[2])); + EXPECT_EQ("A B C", GetTokenized(nbests[3])); - auto nbests0 = lattice.NBest(0, false, 0.0); + auto nbests0 = lattice.NBest(0); EXPECT_TRUE(nbests0.empty()); - auto nbests1 = lattice.NBest(1, false, 0.0); + auto nbests1 = lattice.NBest(1); EXPECT_EQ(nbests1.size(), 1); } -TEST(LatticeTest, NBestSampleTest) { - Lattice lattice; - lattice.SetSentence("ABC"); - - InsertWithScore(&lattice, 0, 1, 0.0); // A - InsertWithScore(&lattice, 1, 1, 0.0); // B - InsertWithScore(&lattice, 2, 1, 0.1); // C - InsertWithScore(&lattice, 0, 2, 0.2); // AB - InsertWithScore(&lattice, 1, 2, 0.5); // BC - InsertWithScore(&lattice, 0, 3, 1.0); // ABC - - // Calculate expected probabilities of each path - // Note that sampling without replacement affects the expected frequencies! 
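As the note below says, sampling without replacement changes the expected frequencies. A toy numeric illustration with three hypothetical path probabilities, using the same first-draw/second-draw decomposition as the test:

#include <cstdio>
#include <vector>

// For 2 sequential draws without replacement, the expected share of drawn
// items equal to item s is (p_s + sum_{a != s} p_a * p_s / (1 - p_a)) / 2,
// i.e. half of s's inclusion probability; with replacement it is simply p_s.
int main() {
  const std::vector<double> p = {0.5, 0.3, 0.2};
  for (size_t s = 0; s < p.size(); ++s) {
    double inclusion = p[s];  // item s drawn first
    for (size_t a = 0; a < p.size(); ++a)
      if (a != s) inclusion += p[a] * p[s] / (1.0 - p[a]);  // drawn second
    std::printf("item %zu: with replacement %.3f, without replacement %.3f\n",
                s, p[s], inclusion / 2.0);
  }
}

Note how the high-probability item's expected share drops and the low-probability items' shares rise once replacement is disallowed.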
- const std::vector kTheta = {0.0, 0.01, 0.5, 0.7, 1.0}; - for (const auto theta : kTheta) { - std::vector strings = {"ABC", "AB C", "A BC", "A B C"}; - std::map probs; - probs["ABC"] = std::exp(theta * 1.0); - probs["AB C"] = std::exp(theta * (0.2 + 0.1)); - probs["A BC"] = std::exp(theta * (0.0 + 0.5)); - probs["A B C"] = std::exp(theta * (0.0 + 0.0 + 0.1)); - - for (const auto &it : strings) { - EXPECT_EQ(1, probs.count(it)); - } - - double Z = 0.0; - for (const auto &it : probs) Z += it.second; - for (auto &it : probs) it.second /= Z; - - std::map, float> pair_probs; - for (const auto first : strings) { - for (const auto second : strings) { - if (first == second) { - pair_probs[std::make_pair(first, second)] = 0; - } else { - float first_prob = probs[first]; - float second_prob = probs[second] / (1 - first_prob); - pair_probs[std::make_pair(first, second)] = first_prob * second_prob; - } - } - } - - std::map inclusion_probs; - for (const auto string : strings) { - float inclusion_prob = 0.0; - for (const auto other_string : strings) { - inclusion_prob += pair_probs[std::make_pair(string, other_string)]; - } - for (const auto other_string : strings) { - inclusion_prob += pair_probs[std::make_pair(other_string, string)]; - } - inclusion_probs[string] = inclusion_prob / 2; - } - - int kTrials = 10000; - - std::vector kNumSamples = {1, 2}; - - for (const auto num_samples : kNumSamples) { - std::map counts; - for (int i = 0; i < kTrials; i++) { - auto nbests = lattice.NBest(num_samples, true, theta); - for (const auto nbest : nbests) { - counts[GetTokenized(nbest.first)]++; - } - } - - EXPECT_EQ(inclusion_probs.size(), counts.size()); - // If we take multiple samples WOR, we have to use corrected probs. - std::map probs_to_use = - (num_samples == 1 ? 
probs : inclusion_probs); - - for (const auto &it : probs_to_use) { - EXPECT_NEAR(it.second, 1.0 * counts[it.first] / (kTrials * num_samples), - 0.02); - } - } - } -} - -TEST(LatticeTest, CalculateEntropyTest) { - Lattice lattice; - lattice.SetSentence("ABC"); - - InsertWithScore(&lattice, 0, 1, 0.0); // A - InsertWithScore(&lattice, 1, 1, 0.0); // B - InsertWithScore(&lattice, 2, 1, 0.1); // C - InsertWithScore(&lattice, 0, 2, 0.2); // AB - InsertWithScore(&lattice, 1, 2, 0.5); // BC - InsertWithScore(&lattice, 0, 3, 1.0); // ABC - - // Calculate expected probabilities of each path - const std::vector kTheta = {0.0, 0.01, 0.5, 0.7, 1.0}; - for (const auto theta : kTheta) { - std::vector strings = {"ABC", "AB C", "A BC", "A B C"}; - std::map probs; - probs["ABC"] = std::exp(theta * 1.0); - probs["AB C"] = std::exp(theta * (0.2 + 0.1)); - probs["A BC"] = std::exp(theta * (0.0 + 0.5)); - probs["A B C"] = std::exp(theta * (0.0 + 0.0 + 0.1)); - - double Z = 0.0; - for (const auto &it : probs) Z += it.second; - for (auto &it : probs) it.second /= Z; - - for (const auto &it : strings) { - EXPECT_EQ(1, probs.count(it)); - } - float entropy = 0.0; - for (const auto &it : probs) { - entropy += (it.second * std::log(it.second)); - } - EXPECT_NEAR(-entropy, lattice.CalculateEntropy(theta), 0.02); - } -} - -TEST(LatticeTest, ForwardAlgorithmTest) { - Lattice lattice; - lattice.SetSentence("ABC"); - - InsertWithScore(&lattice, 0, 1, 0.0); // A - InsertWithScore(&lattice, 1, 1, 0.0); // B - InsertWithScore(&lattice, 2, 1, 0.1); // C - InsertWithScore(&lattice, 0, 2, 0.2); // AB - InsertWithScore(&lattice, 1, 2, 0.5); // BC - InsertWithScore(&lattice, 0, 3, 1.0); // ABC - - const std::vector kTheta = {0.0, 0.01, 0.5, 0.7, 1.0}; - for (const auto theta : kTheta) { - std::vector alpha = lattice.ForwardAlgorithm(theta); - EXPECT_EQ(alpha.size(), 8); // 6 nodes, plus BOS, EOS - // only alpha[C], alpha[EOS] have non-zero alpha - for (int i : {0, 1, 2, 3}) { - for (const auto &node : lattice.begin_nodes(i)) { - if (i < 2) { - EXPECT_EQ(alpha[node->node_id], 0.0); - } else if (i == 2) { - float Z = - std::log(std::exp(theta * (0.0 + 0.0)) + std::exp(theta * 0.2)); - EXPECT_EQ(alpha[node->node_id], Z); - } else if (i == 3) { - float Z = std::log(std::exp(theta * (0.0 + 0.0 + 0.1)) + // A + B + C - std::exp(theta * (0.2 + 0.1)) + // AB + C - std::exp(theta * (0.0 + 0.5)) + // A + BC - std::exp(theta * 1.0)); // ABC - EXPECT_EQ(Z, alpha[node->node_id]); - } - } - } - } -} - TEST(LatticeTest, PopulateMarginalTest) { Lattice lattice; lattice.SetSentence("ABC"); @@ -514,102 +361,6 @@ TEST(UnigramModelTest, SetUnigramModelTest) { model.model_proto().SerializeAsString()); } -TEST(UnigramModelTest, SampleEncodeAndScoreTest) { - // Test whether inclusion probabilities are correct - ModelProto model_proto = MakeBaseModelProto(); - AddPiece(&model_proto, "A", 0.0); // 3 - AddPiece(&model_proto, "B", 0.0); // 4 - AddPiece(&model_proto, "C", 0.1); // 5 - AddPiece(&model_proto, "AB", 0.2); // 6 - AddPiece(&model_proto, "BC", 0.5); // 7 - AddPiece(&model_proto, "ABC", 1.0); // 8 - - Model model(model_proto); - - Lattice lattice; - lattice.SetSentence("ABC"); - model.PopulateNodes(&lattice); - - std::vector kTheta = {0.0, 1.0}; - - for (const auto theta : kTheta) { - std::vector strings = {"ABC", "AB C", "A BC", "A B C"}; - std::map probs; - probs["ABC"] = std::exp(theta * 1.0); - probs["AB C"] = std::exp(theta * (0.2 + 0.1)); - probs["A BC"] = std::exp(theta * (0.0 + 0.5)); - probs["A B C"] = std::exp(theta * (0.0 + 0.0 + 
0.1)); - - for (const auto &it : strings) { - EXPECT_EQ(1, probs.count(it)); - } - - double Z = 0.0; - for (const auto &it : probs) Z += it.second; - for (auto &it : probs) it.second /= Z; - - std::map, float> pair_probs; - for (const auto first : strings) { - for (const auto second : strings) { - if (first == second) { - pair_probs[std::make_pair(first, second)] = 0; - } else { - float first_prob = probs[first]; - float second_prob = probs[second] / (1 - first_prob); - pair_probs[std::make_pair(first, second)] = first_prob * second_prob; - } - } - } - - std::map inclusion_probs; - for (const auto string : strings) { - float inclusion_prob = 0.0; - for (const auto other_string : strings) { - inclusion_prob += pair_probs[std::make_pair(string, other_string)]; - } - for (const auto other_string : strings) { - inclusion_prob += pair_probs[std::make_pair(other_string, string)]; - } - inclusion_probs[string] = inclusion_prob / 2; - } - std::vector kNumSamples = {1, 2}; - - for (const auto num_samples : kNumSamples) { - std::map counts; - std::map scores; - int kTrials = 50000; - for (int i = 0; i < kTrials; i++) { - NBestEncodeResult sample = - model.SampleEncodeAndScore("ABC", theta, num_samples, true, false); - - for (const auto &it : sample) { - std::vector tokens; - for (const auto &inner_it : it.first) { - tokens.push_back(std::string(inner_it.first)); - } - std::string sample_string = absl::StrJoin(tokens, " "); - counts[sample_string] += 1; - // use the fact that E(1_{i in sample} / score of i) = 1 - // see https://arxiv.org/pdf/1903.06059.pdf appendix D - scores[sample_string] += std::exp(-it.second); - } - } - - // Check that counts and probs are correct - std::map probs_to_use = - (num_samples == 1 ? probs : inclusion_probs); - - for (const auto &it : scores) Z += it.second; - for (const auto &it : probs_to_use) { - EXPECT_NEAR(it.second, 1.0 * counts[it.first] / (kTrials * num_samples), - 0.02); - // The expectation is quite loose, use a higher tolerance - EXPECT_NEAR(1.0, scores[it.first] / kTrials, 0.05); - } - } - } -} - TEST_P(UnigramModelTest, PieceToIdTest) { ModelProto model_proto = MakeBaseModelProto(); diff --git a/src/unigram_model_trainer.cc b/src/unigram_model_trainer.cc index 9615040..f2afc32 100644 --- a/src/unigram_model_trainer.cc +++ b/src/unigram_model_trainer.cc @@ -223,7 +223,7 @@ std::vector Trainer::RunEStep(const TrainerModel &model, float *obj, lattice.SetSentence(w); model.PopulateNodes(&lattice); const float Z = lattice.PopulateMarginal(freq, &expected[n]); - ntokens[n] += lattice.Viterbi().first.size(); + ntokens[n] += lattice.Viterbi().size(); CHECK(!std::isnan(Z)) << "likelihood is NAN. Input sentence may be too long"; objs[n] -= Z / all_sentence_freq; @@ -297,17 +297,17 @@ TrainerModel::SentencePieces Trainer::PruneSentencePieces( const auto &w = sentencepieces[i]; lattice.SetSentence(w.first); model.PopulateNodes(&lattice); - const auto nbests = lattice.NBest(2, false, 0.0); + const auto nbests = lattice.NBest(2); if (nbests.size() == 1) { // No second-best result is found. always keep this sentencepiece. always_keep[i] = true; continue; - } else if (nbests[0].first.size() >= 2) { + } else if (nbests[0].size() >= 2) { // Can safely remove this sentencepiece if its Viterbi path is split. 
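A compact sketch of the pruning decision described above: keep a piece when it has no alternative segmentation, mark it removable when its own Viterbi path already splits, otherwise keep it and record the second-best path as its alternative. The enum and helper are illustrative, not the trainer's code:

#include <iostream>
#include <vector>

enum class PruneDecision { kAlwaysKeep, kRemovable, kKeepWithAlternatives };

// nbests holds the 1- or 2-best segmentations of the candidate piece's own
// surface string, each as a list of piece ids.
PruneDecision Decide(const std::vector<std::vector<int>> &nbests) {
  if (nbests.size() == 1) return PruneDecision::kAlwaysKeep;    // no second best exists
  if (nbests[0].size() >= 2) return PruneDecision::kRemovable;  // Viterbi path already splits
  return PruneDecision::kKeepWithAlternatives;  // single-piece Viterbi: keep, record nbests[1]
}

int main() {
  std::cout << (Decide({{7}}) == PruneDecision::kAlwaysKeep)                     // 1
            << (Decide({{3, 4}, {7}}) == PruneDecision::kRemovable)              // 1
            << (Decide({{7}, {3, 4}}) == PruneDecision::kKeepWithAlternatives)   // 1
            << "\n";
}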
always_keep[i] = false; - } else if (nbests[0].first.size() == 1) { + } else if (nbests[0].size() == 1) { always_keep[i] = true; - for (const auto *node : nbests[1].first) { + for (const auto *node : nbests[1]) { alternatives[i].push_back(node->id); } } @@ -339,7 +339,7 @@ TrainerModel::SentencePieces Trainer::PruneSentencePieces( lattice.SetSentence(w.first); model.PopulateNodes(&lattice); vsums[n] += w.second; - for (const auto *node : lattice.Viterbi().first) { + for (const auto *node : lattice.Viterbi()) { if (node->id >= 0) { freqs[n][node->id] += w.second; inverteds[n][node->id].push_back(i); diff --git a/src/util.cc b/src/util.cc index 8424448..9120673 100644 --- a/src/util.cc +++ b/src/util.cc @@ -12,12 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License.! -#include - #include "util.h" -namespace sentencepiece { +#include +namespace sentencepiece { namespace { constexpr unsigned int kDefaultSeed = static_cast(-1); static unsigned int g_seed = kDefaultSeed; diff --git a/third_party/absl/flags/flag.cc b/third_party/absl/flags/flag.cc index e7ac841..09ff78f 100644 --- a/third_party/absl/flags/flag.cc +++ b/third_party/absl/flags/flag.cc @@ -171,7 +171,6 @@ void Flag::set_value_as_str(const std::string &value_as_str) { template class Flag; template class Flag; -template class Flag; template class Flag; template class Flag; template class Flag;