diff --git a/VERSION.txt b/VERSION.txt index 9c178d3..c65d728 100644 --- a/VERSION.txt +++ b/VERSION.txt @@ -1 +1 @@ -0.1.95 +0.1.96 diff --git a/python/VERSION.txt b/python/VERSION.txt index 9c178d3..c65d728 100644 --- a/python/VERSION.txt +++ b/python/VERSION.txt @@ -1 +1 @@ -0.1.95 +0.1.96 diff --git a/src/bpe_model.cc b/src/bpe_model.cc index 5d77baa..22cd115 100644 --- a/src/bpe_model.cc +++ b/src/bpe_model.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License.! -#include "bpe_model.h" - #include #include #include @@ -21,6 +19,7 @@ #include #include +#include "bpe_model.h" #include "freelist.h" #include "third_party/absl/container/flat_hash_map.h" #include "util.h" diff --git a/src/builder.cc b/src/builder.cc index 88346dd..794ce5f 100644 --- a/src/builder.cc +++ b/src/builder.cc @@ -12,12 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License.! -#include "builder.h" - #include #include #include +#include "builder.h" #include "filesystem.h" #include "third_party/absl/strings/str_join.h" #include "third_party/absl/strings/str_replace.h" @@ -368,6 +367,7 @@ util::Status Builder::BuildNmtNFKCMap(CharsMap *chars_map) { nfkc_map[{0xFEFF}] = {0x20}; // ZERO WIDTH NO-BREAK nfkc_map[{0xFFFD}] = {0x20}; // REPLACEMENT CHARACTER nfkc_map[{0x200C}] = {0x20}; // ZERO WIDTH NON-JOINER + nfkc_map[{0x200D}] = {0x20}; // ZERO WIDTH JOINER // Ascii Control characters nfkc_map[{0x0001}] = {}; diff --git a/src/builtin_pb/sentencepiece_model.pb.cc b/src/builtin_pb/sentencepiece_model.pb.cc index 4863136..e913731 100644 --- a/src/builtin_pb/sentencepiece_model.pb.cc +++ b/src/builtin_pb/sentencepiece_model.pb.cc @@ -285,22 +285,22 @@ class TrainerSpec::_Internal { (*has_bits)[0] |= 1u; } static void set_has_model_type(HasBits* has_bits) { - (*has_bits)[0] |= 262144u; + (*has_bits)[0] |= 524288u; } static void set_has_vocab_size(HasBits* has_bits) { - (*has_bits)[0] |= 524288u; + (*has_bits)[0] |= 1048576u; } static void set_has_self_test_sample_size(HasBits* has_bits) { (*has_bits)[0] |= 256u; } static void set_has_character_coverage(HasBits* has_bits) { - (*has_bits)[0] |= 1048576u; + (*has_bits)[0] |= 2097152u; } static void set_has_input_sentence_size(HasBits* has_bits) { (*has_bits)[0] |= 1024u; } static void set_has_shuffle_input_sentence(HasBits* has_bits) { - (*has_bits)[0] |= 134217728u; + (*has_bits)[0] |= 268435456u; } static void set_has_mining_sentence_size(HasBits* has_bits) { (*has_bits)[0] |= 512u; @@ -309,65 +309,68 @@ class TrainerSpec::_Internal { (*has_bits)[0] |= 2048u; } static void set_has_seed_sentencepiece_size(HasBits* has_bits) { - (*has_bits)[0] |= 2097152u; - } - static void set_has_shrinking_factor(HasBits* has_bits) { (*has_bits)[0] |= 4194304u; } - static void set_has_max_sentence_length(HasBits* has_bits) { - (*has_bits)[0] |= 33554432u; - } - static void set_has_num_threads(HasBits* has_bits) { + static void set_has_shrinking_factor(HasBits* has_bits) { (*has_bits)[0] |= 8388608u; } - static void set_has_num_sub_iterations(HasBits* has_bits) { - (*has_bits)[0] |= 16777216u; - } - static void set_has_max_sentencepiece_length(HasBits* has_bits) { + static void set_has_max_sentence_length(HasBits* has_bits) { (*has_bits)[0] |= 67108864u; } - static void set_has_split_by_unicode_script(HasBits* has_bits) { - (*has_bits)[0] |= 268435456u; + static void set_has_num_threads(HasBits* has_bits) { + (*has_bits)[0] |= 16777216u; } - static void 
set_has_split_by_number(HasBits* has_bits) { + static void set_has_num_sub_iterations(HasBits* has_bits) { + (*has_bits)[0] |= 33554432u; + } + static void set_has_max_sentencepiece_length(HasBits* has_bits) { + (*has_bits)[0] |= 134217728u; + } + static void set_has_split_by_unicode_script(HasBits* has_bits) { (*has_bits)[0] |= 536870912u; } - static void set_has_split_by_whitespace(HasBits* has_bits) { + static void set_has_split_by_number(HasBits* has_bits) { (*has_bits)[0] |= 1073741824u; } + static void set_has_split_by_whitespace(HasBits* has_bits) { + (*has_bits)[0] |= 2147483648u; + } static void set_has_treat_whitespace_as_suffix(HasBits* has_bits) { (*has_bits)[0] |= 4096u; } - static void set_has_split_digits(HasBits* has_bits) { + static void set_has_allow_whitespace_only_pieces(HasBits* has_bits) { (*has_bits)[0] |= 8192u; } + static void set_has_split_digits(HasBits* has_bits) { + (*has_bits)[0] |= 16384u; + } static void set_has_required_chars(HasBits* has_bits) { (*has_bits)[0] |= 4u; } static void set_has_byte_fallback(HasBits* has_bits) { - (*has_bits)[0] |= 16384u; + (*has_bits)[0] |= 32768u; } static void set_has_vocabulary_output_piece_score(HasBits* has_bits) { - (*has_bits)[0] |= 2147483648u; - } - static void set_has_hard_vocab_limit(HasBits* has_bits) { (*has_bits)[1] |= 1u; } + static void set_has_hard_vocab_limit(HasBits* has_bits) { + (*has_bits)[1] |= 2u; + } static void set_has_use_all_vocab(HasBits* has_bits) { - (*has_bits)[0] |= 32768u; + (*has_bits)[0] |= 131072u; } static void set_has_unk_id(HasBits* has_bits) { (*has_bits)[0] |= 65536u; } static void set_has_bos_id(HasBits* has_bits) { - (*has_bits)[1] |= 2u; - } - static void set_has_eos_id(HasBits* has_bits) { (*has_bits)[1] |= 4u; } - static void set_has_pad_id(HasBits* has_bits) { + static void set_has_eos_id(HasBits* has_bits) { (*has_bits)[1] |= 8u; } + static void set_has_pad_id(HasBits* has_bits) { + (*has_bits)[1] |= 16u; + } static void set_has_unk_piece(HasBits* has_bits) { (*has_bits)[0] |= 16u; } @@ -384,7 +387,7 @@ class TrainerSpec::_Internal { (*has_bits)[0] |= 8u; } static void set_has_train_extremely_large_corpus(HasBits* has_bits) { - (*has_bits)[0] |= 131072u; + (*has_bits)[0] |= 262144u; } }; @@ -566,8 +569,8 @@ void TrainerSpec::Clear() { } if (cached_has_bits & 0x0000ff00u) { ::memset(&self_test_sample_size_, 0, static_cast( - reinterpret_cast(&use_all_vocab_) - - reinterpret_cast(&self_test_sample_size_)) + sizeof(use_all_vocab_)); + reinterpret_cast(&byte_fallback_) - + reinterpret_cast(&self_test_sample_size_)) + sizeof(byte_fallback_)); } if (cached_has_bits & 0x00ff0000u) { ::memset(&unk_id_, 0, static_cast( @@ -578,9 +581,9 @@ void TrainerSpec::Clear() { character_coverage_ = 0.9995f; seed_sentencepiece_size_ = 1000000; shrinking_factor_ = 0.75f; - num_threads_ = 16; } if (cached_has_bits & 0xff000000u) { + num_threads_ = 16; num_sub_iterations_ = 2; max_sentence_length_ = 4192; max_sentencepiece_length_ = 16; @@ -588,10 +591,10 @@ void TrainerSpec::Clear() { split_by_unicode_script_ = true; split_by_number_ = true; split_by_whitespace_ = true; - vocabulary_output_piece_score_ = true; } cached_has_bits = _has_bits_[1]; - if (cached_has_bits & 0x0000000fu) { + if (cached_has_bits & 0x0000001fu) { + vocabulary_output_piece_score_ = true; hard_vocab_limit_ = true; bos_id_ = 1; eos_id_ = 2; @@ -806,6 +809,14 @@ const char* TrainerSpec::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID CHK_(ptr); } else goto handle_unusual; continue; + // optional bool 
allow_whitespace_only_pieces = 26 [default = false]; + case 26: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 208)) { + _Internal::set_has_allow_whitespace_only_pieces(&_has_bits_); + allow_whitespace_only_pieces_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; // repeated string control_symbols = 30; case 30: if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 242)) { @@ -1000,14 +1011,14 @@ failure: } // optional .sentencepiece.TrainerSpec.ModelType model_type = 3 [default = UNIGRAM]; - if (cached_has_bits & 0x00040000u) { + if (cached_has_bits & 0x00080000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteEnumToArray( 3, this->_internal_model_type(), target); } // optional int32 vocab_size = 4 [default = 8000]; - if (cached_has_bits & 0x00080000u) { + if (cached_has_bits & 0x00100000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(4, this->_internal_vocab_size(), target); } @@ -1031,7 +1042,7 @@ failure: } // optional float character_coverage = 10 [default = 0.9995]; - if (cached_has_bits & 0x00100000u) { + if (cached_has_bits & 0x00200000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteFloatToArray(10, this->_internal_character_coverage(), target); } @@ -1055,61 +1066,61 @@ failure: } // optional int32 seed_sentencepiece_size = 14 [default = 1000000]; - if (cached_has_bits & 0x00200000u) { + if (cached_has_bits & 0x00400000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(14, this->_internal_seed_sentencepiece_size(), target); } // optional float shrinking_factor = 15 [default = 0.75]; - if (cached_has_bits & 0x00400000u) { + if (cached_has_bits & 0x00800000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteFloatToArray(15, this->_internal_shrinking_factor(), target); } // optional int32 num_threads = 16 [default = 16]; - if (cached_has_bits & 0x00800000u) { + if (cached_has_bits & 0x01000000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(16, this->_internal_num_threads(), target); } // optional int32 num_sub_iterations = 17 [default = 2]; - if (cached_has_bits & 0x01000000u) { + if (cached_has_bits & 0x02000000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(17, this->_internal_num_sub_iterations(), target); } // optional int32 max_sentence_length = 18 [default = 4192]; - if (cached_has_bits & 0x02000000u) { + if (cached_has_bits & 0x04000000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(18, this->_internal_max_sentence_length(), target); } // optional bool shuffle_input_sentence = 19 [default = true]; - if (cached_has_bits & 0x08000000u) { + if (cached_has_bits & 0x10000000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(19, this->_internal_shuffle_input_sentence(), target); } // optional int32 max_sentencepiece_length = 20 [default = 16]; - if (cached_has_bits & 0x04000000u) { + if (cached_has_bits & 0x08000000u) { target = stream->EnsureSpace(target); target = 
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(20, this->_internal_max_sentencepiece_length(), target); } // optional bool split_by_unicode_script = 21 [default = true]; - if (cached_has_bits & 0x10000000u) { + if (cached_has_bits & 0x20000000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(21, this->_internal_split_by_unicode_script(), target); } // optional bool split_by_whitespace = 22 [default = true]; - if (cached_has_bits & 0x40000000u) { + if (cached_has_bits & 0x80000000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(22, this->_internal_split_by_whitespace(), target); } // optional bool split_by_number = 23 [default = true]; - if (cached_has_bits & 0x20000000u) { + if (cached_has_bits & 0x40000000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(23, this->_internal_split_by_number(), target); } @@ -1121,11 +1132,17 @@ failure: } // optional bool split_digits = 25 [default = false]; - if (cached_has_bits & 0x00002000u) { + if (cached_has_bits & 0x00004000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(25, this->_internal_split_digits(), target); } + // optional bool allow_whitespace_only_pieces = 26 [default = false]; + if (cached_has_bits & 0x00002000u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(26, this->_internal_allow_whitespace_only_pieces(), target); + } + // repeated string control_symbols = 30; for (int i = 0, n = this->_internal_control_symbols_size(); i < n; i++) { const auto& s = this->_internal_control_symbols(i); @@ -1138,28 +1155,28 @@ failure: target = stream->WriteString(31, s, target); } + cached_has_bits = _has_bits_[1]; // optional bool vocabulary_output_piece_score = 32 [default = true]; - if (cached_has_bits & 0x80000000u) { + if (cached_has_bits & 0x00000001u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(32, this->_internal_vocabulary_output_piece_score(), target); } - cached_has_bits = _has_bits_[1]; // optional bool hard_vocab_limit = 33 [default = true]; - if (cached_has_bits & 0x00000001u) { + if (cached_has_bits & 0x00000002u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(33, this->_internal_hard_vocab_limit(), target); } cached_has_bits = _has_bits_[0]; // optional bool use_all_vocab = 34 [default = false]; - if (cached_has_bits & 0x00008000u) { + if (cached_has_bits & 0x00020000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(34, this->_internal_use_all_vocab(), target); } // optional bool byte_fallback = 35 [default = false]; - if (cached_has_bits & 0x00004000u) { + if (cached_has_bits & 0x00008000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(35, this->_internal_byte_fallback(), target); } @@ -1178,19 +1195,19 @@ failure: cached_has_bits = _has_bits_[1]; // optional int32 bos_id = 41 [default = 1]; - if (cached_has_bits & 0x00000002u) { + if (cached_has_bits & 0x00000004u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(41, 
this->_internal_bos_id(), target); } // optional int32 eos_id = 42 [default = 2]; - if (cached_has_bits & 0x00000004u) { + if (cached_has_bits & 0x00000008u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(42, this->_internal_eos_id(), target); } // optional int32 pad_id = 43 [default = -1]; - if (cached_has_bits & 0x00000008u) { + if (cached_has_bits & 0x00000010u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(43, this->_internal_pad_id(), target); } @@ -1227,7 +1244,7 @@ failure: } // optional bool train_extremely_large_corpus = 49 [default = false]; - if (cached_has_bits & 0x00020000u) { + if (cached_has_bits & 0x00040000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(49, this->_internal_train_extremely_large_corpus(), target); } @@ -1379,17 +1396,17 @@ size_t TrainerSpec::ByteSizeLong() const { total_size += 2 + 1; } - // optional bool split_digits = 25 [default = false]; + // optional bool allow_whitespace_only_pieces = 26 [default = false]; if (cached_has_bits & 0x00002000u) { total_size += 2 + 1; } - // optional bool byte_fallback = 35 [default = false]; + // optional bool split_digits = 25 [default = false]; if (cached_has_bits & 0x00004000u) { total_size += 2 + 1; } - // optional bool use_all_vocab = 34 [default = false]; + // optional bool byte_fallback = 35 [default = false]; if (cached_has_bits & 0x00008000u) { total_size += 2 + 1; } @@ -1403,120 +1420,125 @@ size_t TrainerSpec::ByteSizeLong() const { this->_internal_unk_id()); } - // optional bool train_extremely_large_corpus = 49 [default = false]; + // optional bool use_all_vocab = 34 [default = false]; if (cached_has_bits & 0x00020000u) { total_size += 2 + 1; } - // optional .sentencepiece.TrainerSpec.ModelType model_type = 3 [default = UNIGRAM]; + // optional bool train_extremely_large_corpus = 49 [default = false]; if (cached_has_bits & 0x00040000u) { + total_size += 2 + 1; + } + + // optional .sentencepiece.TrainerSpec.ModelType model_type = 3 [default = UNIGRAM]; + if (cached_has_bits & 0x00080000u) { total_size += 1 + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::EnumSize(this->_internal_model_type()); } // optional int32 vocab_size = 4 [default = 8000]; - if (cached_has_bits & 0x00080000u) { + if (cached_has_bits & 0x00100000u) { total_size += 1 + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( this->_internal_vocab_size()); } // optional float character_coverage = 10 [default = 0.9995]; - if (cached_has_bits & 0x00100000u) { + if (cached_has_bits & 0x00200000u) { total_size += 1 + 4; } // optional int32 seed_sentencepiece_size = 14 [default = 1000000]; - if (cached_has_bits & 0x00200000u) { + if (cached_has_bits & 0x00400000u) { total_size += 1 + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( this->_internal_seed_sentencepiece_size()); } // optional float shrinking_factor = 15 [default = 0.75]; - if (cached_has_bits & 0x00400000u) { + if (cached_has_bits & 0x00800000u) { total_size += 1 + 4; } + } + if (cached_has_bits & 0xff000000u) { // optional int32 num_threads = 16 [default = 16]; - if (cached_has_bits & 0x00800000u) { + if (cached_has_bits & 0x01000000u) { total_size += 2 + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( this->_internal_num_threads()); } - } - if (cached_has_bits & 0xff000000u) { // optional int32 num_sub_iterations = 17 [default = 2]; - 
if (cached_has_bits & 0x01000000u) { + if (cached_has_bits & 0x02000000u) { total_size += 2 + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( this->_internal_num_sub_iterations()); } // optional int32 max_sentence_length = 18 [default = 4192]; - if (cached_has_bits & 0x02000000u) { + if (cached_has_bits & 0x04000000u) { total_size += 2 + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( this->_internal_max_sentence_length()); } // optional int32 max_sentencepiece_length = 20 [default = 16]; - if (cached_has_bits & 0x04000000u) { + if (cached_has_bits & 0x08000000u) { total_size += 2 + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( this->_internal_max_sentencepiece_length()); } // optional bool shuffle_input_sentence = 19 [default = true]; - if (cached_has_bits & 0x08000000u) { - total_size += 2 + 1; - } - - // optional bool split_by_unicode_script = 21 [default = true]; if (cached_has_bits & 0x10000000u) { total_size += 2 + 1; } - // optional bool split_by_number = 23 [default = true]; + // optional bool split_by_unicode_script = 21 [default = true]; if (cached_has_bits & 0x20000000u) { total_size += 2 + 1; } - // optional bool split_by_whitespace = 22 [default = true]; + // optional bool split_by_number = 23 [default = true]; if (cached_has_bits & 0x40000000u) { total_size += 2 + 1; } - // optional bool vocabulary_output_piece_score = 32 [default = true]; + // optional bool split_by_whitespace = 22 [default = true]; if (cached_has_bits & 0x80000000u) { total_size += 2 + 1; } } cached_has_bits = _has_bits_[1]; - if (cached_has_bits & 0x0000000fu) { - // optional bool hard_vocab_limit = 33 [default = true]; + if (cached_has_bits & 0x0000001fu) { + // optional bool vocabulary_output_piece_score = 32 [default = true]; if (cached_has_bits & 0x00000001u) { total_size += 2 + 1; } - // optional int32 bos_id = 41 [default = 1]; + // optional bool hard_vocab_limit = 33 [default = true]; if (cached_has_bits & 0x00000002u) { + total_size += 2 + 1; + } + + // optional int32 bos_id = 41 [default = 1]; + if (cached_has_bits & 0x00000004u) { total_size += 2 + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( this->_internal_bos_id()); } // optional int32 eos_id = 42 [default = 2]; - if (cached_has_bits & 0x00000004u) { + if (cached_has_bits & 0x00000008u) { total_size += 2 + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( this->_internal_eos_id()); } // optional int32 pad_id = 43 [default = -1]; - if (cached_has_bits & 0x00000008u) { + if (cached_has_bits & 0x00000010u) { total_size += 2 + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( this->_internal_pad_id()); @@ -1593,13 +1615,13 @@ void TrainerSpec::MergeFrom(const TrainerSpec& from) { treat_whitespace_as_suffix_ = from.treat_whitespace_as_suffix_; } if (cached_has_bits & 0x00002000u) { - split_digits_ = from.split_digits_; + allow_whitespace_only_pieces_ = from.allow_whitespace_only_pieces_; } if (cached_has_bits & 0x00004000u) { - byte_fallback_ = from.byte_fallback_; + split_digits_ = from.split_digits_; } if (cached_has_bits & 0x00008000u) { - use_all_vocab_ = from.use_all_vocab_; + byte_fallback_ = from.byte_fallback_; } _has_bits_[0] |= cached_has_bits; } @@ -1608,67 +1630,70 @@ void TrainerSpec::MergeFrom(const TrainerSpec& from) { unk_id_ = from.unk_id_; } if (cached_has_bits & 0x00020000u) { - train_extremely_large_corpus_ = from.train_extremely_large_corpus_; + use_all_vocab_ = from.use_all_vocab_; } if (cached_has_bits & 0x00040000u) { - model_type_ = 
from.model_type_; + train_extremely_large_corpus_ = from.train_extremely_large_corpus_; } if (cached_has_bits & 0x00080000u) { - vocab_size_ = from.vocab_size_; + model_type_ = from.model_type_; } if (cached_has_bits & 0x00100000u) { - character_coverage_ = from.character_coverage_; + vocab_size_ = from.vocab_size_; } if (cached_has_bits & 0x00200000u) { - seed_sentencepiece_size_ = from.seed_sentencepiece_size_; + character_coverage_ = from.character_coverage_; } if (cached_has_bits & 0x00400000u) { - shrinking_factor_ = from.shrinking_factor_; + seed_sentencepiece_size_ = from.seed_sentencepiece_size_; } if (cached_has_bits & 0x00800000u) { - num_threads_ = from.num_threads_; + shrinking_factor_ = from.shrinking_factor_; } _has_bits_[0] |= cached_has_bits; } if (cached_has_bits & 0xff000000u) { if (cached_has_bits & 0x01000000u) { - num_sub_iterations_ = from.num_sub_iterations_; + num_threads_ = from.num_threads_; } if (cached_has_bits & 0x02000000u) { - max_sentence_length_ = from.max_sentence_length_; + num_sub_iterations_ = from.num_sub_iterations_; } if (cached_has_bits & 0x04000000u) { - max_sentencepiece_length_ = from.max_sentencepiece_length_; + max_sentence_length_ = from.max_sentence_length_; } if (cached_has_bits & 0x08000000u) { - shuffle_input_sentence_ = from.shuffle_input_sentence_; + max_sentencepiece_length_ = from.max_sentencepiece_length_; } if (cached_has_bits & 0x10000000u) { - split_by_unicode_script_ = from.split_by_unicode_script_; + shuffle_input_sentence_ = from.shuffle_input_sentence_; } if (cached_has_bits & 0x20000000u) { - split_by_number_ = from.split_by_number_; + split_by_unicode_script_ = from.split_by_unicode_script_; } if (cached_has_bits & 0x40000000u) { - split_by_whitespace_ = from.split_by_whitespace_; + split_by_number_ = from.split_by_number_; } if (cached_has_bits & 0x80000000u) { - vocabulary_output_piece_score_ = from.vocabulary_output_piece_score_; + split_by_whitespace_ = from.split_by_whitespace_; } _has_bits_[0] |= cached_has_bits; } cached_has_bits = from._has_bits_[1]; - if (cached_has_bits & 0x0000000fu) { + if (cached_has_bits & 0x0000001fu) { if (cached_has_bits & 0x00000001u) { - hard_vocab_limit_ = from.hard_vocab_limit_; + vocabulary_output_piece_score_ = from.vocabulary_output_piece_score_; } if (cached_has_bits & 0x00000002u) { - bos_id_ = from.bos_id_; + hard_vocab_limit_ = from.hard_vocab_limit_; } if (cached_has_bits & 0x00000004u) { - eos_id_ = from.eos_id_; + bos_id_ = from.bos_id_; } if (cached_has_bits & 0x00000008u) { + eos_id_ = from.eos_id_; + } + if (cached_has_bits & 0x00000010u) { pad_id_ = from.pad_id_; } _has_bits_[1] |= cached_has_bits; diff --git a/src/builtin_pb/sentencepiece_model.pb.h b/src/builtin_pb/sentencepiece_model.pb.h index 31dc65b..f527aa7 100644 --- a/src/builtin_pb/sentencepiece_model.pb.h +++ b/src/builtin_pb/sentencepiece_model.pb.h @@ -278,10 +278,11 @@ class TrainerSpec PROTOBUF_FINAL : kInputSentenceSizeFieldNumber = 11, kTrainingSentenceSizeFieldNumber = 13, kTreatWhitespaceAsSuffixFieldNumber = 24, + kAllowWhitespaceOnlyPiecesFieldNumber = 26, kSplitDigitsFieldNumber = 25, kByteFallbackFieldNumber = 35, - kUseAllVocabFieldNumber = 34, kUnkIdFieldNumber = 40, + kUseAllVocabFieldNumber = 34, kTrainExtremelyLargeCorpusFieldNumber = 49, kModelTypeFieldNumber = 3, kVocabSizeFieldNumber = 4, @@ -623,6 +624,19 @@ class TrainerSpec PROTOBUF_FINAL : void _internal_set_treat_whitespace_as_suffix(bool value); public: + // optional bool allow_whitespace_only_pieces = 26 [default = false]; + bool 
has_allow_whitespace_only_pieces() const; + private: + bool _internal_has_allow_whitespace_only_pieces() const; + public: + void clear_allow_whitespace_only_pieces(); + bool allow_whitespace_only_pieces() const; + void set_allow_whitespace_only_pieces(bool value); + private: + bool _internal_allow_whitespace_only_pieces() const; + void _internal_set_allow_whitespace_only_pieces(bool value); + public: + // optional bool split_digits = 25 [default = false]; bool has_split_digits() const; private: @@ -649,19 +663,6 @@ class TrainerSpec PROTOBUF_FINAL : void _internal_set_byte_fallback(bool value); public: - // optional bool use_all_vocab = 34 [default = false]; - bool has_use_all_vocab() const; - private: - bool _internal_has_use_all_vocab() const; - public: - void clear_use_all_vocab(); - bool use_all_vocab() const; - void set_use_all_vocab(bool value); - private: - bool _internal_use_all_vocab() const; - void _internal_set_use_all_vocab(bool value); - public: - // optional int32 unk_id = 40 [default = 0]; bool has_unk_id() const; private: @@ -675,6 +676,19 @@ class TrainerSpec PROTOBUF_FINAL : void _internal_set_unk_id(::PROTOBUF_NAMESPACE_ID::int32 value); public: + // optional bool use_all_vocab = 34 [default = false]; + bool has_use_all_vocab() const; + private: + bool _internal_has_use_all_vocab() const; + public: + void clear_use_all_vocab(); + bool use_all_vocab() const; + void set_use_all_vocab(bool value); + private: + bool _internal_use_all_vocab() const; + void _internal_set_use_all_vocab(bool value); + public: + // optional bool train_extremely_large_corpus = 49 [default = false]; bool has_train_extremely_large_corpus() const; private: @@ -956,10 +970,11 @@ class TrainerSpec PROTOBUF_FINAL : ::PROTOBUF_NAMESPACE_ID::uint64 input_sentence_size_; ::PROTOBUF_NAMESPACE_ID::int32 training_sentence_size_; bool treat_whitespace_as_suffix_; + bool allow_whitespace_only_pieces_; bool split_digits_; bool byte_fallback_; - bool use_all_vocab_; ::PROTOBUF_NAMESPACE_ID::int32 unk_id_; + bool use_all_vocab_; bool train_extremely_large_corpus_; int model_type_; ::PROTOBUF_NAMESPACE_ID::int32 vocab_size_; @@ -2180,7 +2195,7 @@ inline void TrainerSpec::set_allocated_model_prefix(std::string* model_prefix) { // optional .sentencepiece.TrainerSpec.ModelType model_type = 3 [default = UNIGRAM]; inline bool TrainerSpec::_internal_has_model_type() const { - bool value = (_has_bits_[0] & 0x00040000u) != 0; + bool value = (_has_bits_[0] & 0x00080000u) != 0; return value; } inline bool TrainerSpec::has_model_type() const { @@ -2188,7 +2203,7 @@ inline bool TrainerSpec::has_model_type() const { } inline void TrainerSpec::clear_model_type() { model_type_ = 1; - _has_bits_[0] &= ~0x00040000u; + _has_bits_[0] &= ~0x00080000u; } inline ::sentencepiece::TrainerSpec_ModelType TrainerSpec::_internal_model_type() const { return static_cast< ::sentencepiece::TrainerSpec_ModelType >(model_type_); @@ -2199,7 +2214,7 @@ inline ::sentencepiece::TrainerSpec_ModelType TrainerSpec::model_type() const { } inline void TrainerSpec::_internal_set_model_type(::sentencepiece::TrainerSpec_ModelType value) { assert(::sentencepiece::TrainerSpec_ModelType_IsValid(value)); - _has_bits_[0] |= 0x00040000u; + _has_bits_[0] |= 0x00080000u; model_type_ = value; } inline void TrainerSpec::set_model_type(::sentencepiece::TrainerSpec_ModelType value) { @@ -2209,7 +2224,7 @@ inline void TrainerSpec::set_model_type(::sentencepiece::TrainerSpec_ModelType v // optional int32 vocab_size = 4 [default = 8000]; inline bool 
TrainerSpec::_internal_has_vocab_size() const { - bool value = (_has_bits_[0] & 0x00080000u) != 0; + bool value = (_has_bits_[0] & 0x00100000u) != 0; return value; } inline bool TrainerSpec::has_vocab_size() const { @@ -2217,7 +2232,7 @@ inline bool TrainerSpec::has_vocab_size() const { } inline void TrainerSpec::clear_vocab_size() { vocab_size_ = 8000; - _has_bits_[0] &= ~0x00080000u; + _has_bits_[0] &= ~0x00100000u; } inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_vocab_size() const { return vocab_size_; @@ -2227,7 +2242,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::vocab_size() const { return _internal_vocab_size(); } inline void TrainerSpec::_internal_set_vocab_size(::PROTOBUF_NAMESPACE_ID::int32 value) { - _has_bits_[0] |= 0x00080000u; + _has_bits_[0] |= 0x00100000u; vocab_size_ = value; } inline void TrainerSpec::set_vocab_size(::PROTOBUF_NAMESPACE_ID::int32 value) { @@ -2339,7 +2354,7 @@ inline void TrainerSpec::set_self_test_sample_size(::PROTOBUF_NAMESPACE_ID::int3 // optional float character_coverage = 10 [default = 0.9995]; inline bool TrainerSpec::_internal_has_character_coverage() const { - bool value = (_has_bits_[0] & 0x00100000u) != 0; + bool value = (_has_bits_[0] & 0x00200000u) != 0; return value; } inline bool TrainerSpec::has_character_coverage() const { @@ -2347,7 +2362,7 @@ inline bool TrainerSpec::has_character_coverage() const { } inline void TrainerSpec::clear_character_coverage() { character_coverage_ = 0.9995f; - _has_bits_[0] &= ~0x00100000u; + _has_bits_[0] &= ~0x00200000u; } inline float TrainerSpec::_internal_character_coverage() const { return character_coverage_; @@ -2357,7 +2372,7 @@ inline float TrainerSpec::character_coverage() const { return _internal_character_coverage(); } inline void TrainerSpec::_internal_set_character_coverage(float value) { - _has_bits_[0] |= 0x00100000u; + _has_bits_[0] |= 0x00200000u; character_coverage_ = value; } inline void TrainerSpec::set_character_coverage(float value) { @@ -2395,7 +2410,7 @@ inline void TrainerSpec::set_input_sentence_size(::PROTOBUF_NAMESPACE_ID::uint64 // optional bool shuffle_input_sentence = 19 [default = true]; inline bool TrainerSpec::_internal_has_shuffle_input_sentence() const { - bool value = (_has_bits_[0] & 0x08000000u) != 0; + bool value = (_has_bits_[0] & 0x10000000u) != 0; return value; } inline bool TrainerSpec::has_shuffle_input_sentence() const { @@ -2403,7 +2418,7 @@ inline bool TrainerSpec::has_shuffle_input_sentence() const { } inline void TrainerSpec::clear_shuffle_input_sentence() { shuffle_input_sentence_ = true; - _has_bits_[0] &= ~0x08000000u; + _has_bits_[0] &= ~0x10000000u; } inline bool TrainerSpec::_internal_shuffle_input_sentence() const { return shuffle_input_sentence_; @@ -2413,7 +2428,7 @@ inline bool TrainerSpec::shuffle_input_sentence() const { return _internal_shuffle_input_sentence(); } inline void TrainerSpec::_internal_set_shuffle_input_sentence(bool value) { - _has_bits_[0] |= 0x08000000u; + _has_bits_[0] |= 0x10000000u; shuffle_input_sentence_ = value; } inline void TrainerSpec::set_shuffle_input_sentence(bool value) { @@ -2479,7 +2494,7 @@ inline void TrainerSpec::set_training_sentence_size(::PROTOBUF_NAMESPACE_ID::int // optional int32 seed_sentencepiece_size = 14 [default = 1000000]; inline bool TrainerSpec::_internal_has_seed_sentencepiece_size() const { - bool value = (_has_bits_[0] & 0x00200000u) != 0; + bool value = (_has_bits_[0] & 0x00400000u) != 0; return value; } inline bool TrainerSpec::has_seed_sentencepiece_size() const { @@ 
-2487,7 +2502,7 @@ inline bool TrainerSpec::has_seed_sentencepiece_size() const { } inline void TrainerSpec::clear_seed_sentencepiece_size() { seed_sentencepiece_size_ = 1000000; - _has_bits_[0] &= ~0x00200000u; + _has_bits_[0] &= ~0x00400000u; } inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_seed_sentencepiece_size() const { return seed_sentencepiece_size_; @@ -2497,7 +2512,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::seed_sentencepiece_size() con return _internal_seed_sentencepiece_size(); } inline void TrainerSpec::_internal_set_seed_sentencepiece_size(::PROTOBUF_NAMESPACE_ID::int32 value) { - _has_bits_[0] |= 0x00200000u; + _has_bits_[0] |= 0x00400000u; seed_sentencepiece_size_ = value; } inline void TrainerSpec::set_seed_sentencepiece_size(::PROTOBUF_NAMESPACE_ID::int32 value) { @@ -2507,7 +2522,7 @@ inline void TrainerSpec::set_seed_sentencepiece_size(::PROTOBUF_NAMESPACE_ID::in // optional float shrinking_factor = 15 [default = 0.75]; inline bool TrainerSpec::_internal_has_shrinking_factor() const { - bool value = (_has_bits_[0] & 0x00400000u) != 0; + bool value = (_has_bits_[0] & 0x00800000u) != 0; return value; } inline bool TrainerSpec::has_shrinking_factor() const { @@ -2515,7 +2530,7 @@ inline bool TrainerSpec::has_shrinking_factor() const { } inline void TrainerSpec::clear_shrinking_factor() { shrinking_factor_ = 0.75f; - _has_bits_[0] &= ~0x00400000u; + _has_bits_[0] &= ~0x00800000u; } inline float TrainerSpec::_internal_shrinking_factor() const { return shrinking_factor_; @@ -2525,7 +2540,7 @@ inline float TrainerSpec::shrinking_factor() const { return _internal_shrinking_factor(); } inline void TrainerSpec::_internal_set_shrinking_factor(float value) { - _has_bits_[0] |= 0x00400000u; + _has_bits_[0] |= 0x00800000u; shrinking_factor_ = value; } inline void TrainerSpec::set_shrinking_factor(float value) { @@ -2535,7 +2550,7 @@ inline void TrainerSpec::set_shrinking_factor(float value) { // optional int32 max_sentence_length = 18 [default = 4192]; inline bool TrainerSpec::_internal_has_max_sentence_length() const { - bool value = (_has_bits_[0] & 0x02000000u) != 0; + bool value = (_has_bits_[0] & 0x04000000u) != 0; return value; } inline bool TrainerSpec::has_max_sentence_length() const { @@ -2543,7 +2558,7 @@ inline bool TrainerSpec::has_max_sentence_length() const { } inline void TrainerSpec::clear_max_sentence_length() { max_sentence_length_ = 4192; - _has_bits_[0] &= ~0x02000000u; + _has_bits_[0] &= ~0x04000000u; } inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_max_sentence_length() const { return max_sentence_length_; @@ -2553,7 +2568,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::max_sentence_length() const { return _internal_max_sentence_length(); } inline void TrainerSpec::_internal_set_max_sentence_length(::PROTOBUF_NAMESPACE_ID::int32 value) { - _has_bits_[0] |= 0x02000000u; + _has_bits_[0] |= 0x04000000u; max_sentence_length_ = value; } inline void TrainerSpec::set_max_sentence_length(::PROTOBUF_NAMESPACE_ID::int32 value) { @@ -2563,7 +2578,7 @@ inline void TrainerSpec::set_max_sentence_length(::PROTOBUF_NAMESPACE_ID::int32 // optional int32 num_threads = 16 [default = 16]; inline bool TrainerSpec::_internal_has_num_threads() const { - bool value = (_has_bits_[0] & 0x00800000u) != 0; + bool value = (_has_bits_[0] & 0x01000000u) != 0; return value; } inline bool TrainerSpec::has_num_threads() const { @@ -2571,7 +2586,7 @@ inline bool TrainerSpec::has_num_threads() const { } inline void TrainerSpec::clear_num_threads() { 
num_threads_ = 16; - _has_bits_[0] &= ~0x00800000u; + _has_bits_[0] &= ~0x01000000u; } inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_num_threads() const { return num_threads_; @@ -2581,7 +2596,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::num_threads() const { return _internal_num_threads(); } inline void TrainerSpec::_internal_set_num_threads(::PROTOBUF_NAMESPACE_ID::int32 value) { - _has_bits_[0] |= 0x00800000u; + _has_bits_[0] |= 0x01000000u; num_threads_ = value; } inline void TrainerSpec::set_num_threads(::PROTOBUF_NAMESPACE_ID::int32 value) { @@ -2591,7 +2606,7 @@ inline void TrainerSpec::set_num_threads(::PROTOBUF_NAMESPACE_ID::int32 value) { // optional int32 num_sub_iterations = 17 [default = 2]; inline bool TrainerSpec::_internal_has_num_sub_iterations() const { - bool value = (_has_bits_[0] & 0x01000000u) != 0; + bool value = (_has_bits_[0] & 0x02000000u) != 0; return value; } inline bool TrainerSpec::has_num_sub_iterations() const { @@ -2599,7 +2614,7 @@ inline bool TrainerSpec::has_num_sub_iterations() const { } inline void TrainerSpec::clear_num_sub_iterations() { num_sub_iterations_ = 2; - _has_bits_[0] &= ~0x01000000u; + _has_bits_[0] &= ~0x02000000u; } inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_num_sub_iterations() const { return num_sub_iterations_; @@ -2609,7 +2624,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::num_sub_iterations() const { return _internal_num_sub_iterations(); } inline void TrainerSpec::_internal_set_num_sub_iterations(::PROTOBUF_NAMESPACE_ID::int32 value) { - _has_bits_[0] |= 0x01000000u; + _has_bits_[0] |= 0x02000000u; num_sub_iterations_ = value; } inline void TrainerSpec::set_num_sub_iterations(::PROTOBUF_NAMESPACE_ID::int32 value) { @@ -2619,7 +2634,7 @@ inline void TrainerSpec::set_num_sub_iterations(::PROTOBUF_NAMESPACE_ID::int32 v // optional int32 max_sentencepiece_length = 20 [default = 16]; inline bool TrainerSpec::_internal_has_max_sentencepiece_length() const { - bool value = (_has_bits_[0] & 0x04000000u) != 0; + bool value = (_has_bits_[0] & 0x08000000u) != 0; return value; } inline bool TrainerSpec::has_max_sentencepiece_length() const { @@ -2627,7 +2642,7 @@ inline bool TrainerSpec::has_max_sentencepiece_length() const { } inline void TrainerSpec::clear_max_sentencepiece_length() { max_sentencepiece_length_ = 16; - _has_bits_[0] &= ~0x04000000u; + _has_bits_[0] &= ~0x08000000u; } inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_max_sentencepiece_length() const { return max_sentencepiece_length_; @@ -2637,7 +2652,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::max_sentencepiece_length() co return _internal_max_sentencepiece_length(); } inline void TrainerSpec::_internal_set_max_sentencepiece_length(::PROTOBUF_NAMESPACE_ID::int32 value) { - _has_bits_[0] |= 0x04000000u; + _has_bits_[0] |= 0x08000000u; max_sentencepiece_length_ = value; } inline void TrainerSpec::set_max_sentencepiece_length(::PROTOBUF_NAMESPACE_ID::int32 value) { @@ -2647,7 +2662,7 @@ inline void TrainerSpec::set_max_sentencepiece_length(::PROTOBUF_NAMESPACE_ID::i // optional bool split_by_unicode_script = 21 [default = true]; inline bool TrainerSpec::_internal_has_split_by_unicode_script() const { - bool value = (_has_bits_[0] & 0x10000000u) != 0; + bool value = (_has_bits_[0] & 0x20000000u) != 0; return value; } inline bool TrainerSpec::has_split_by_unicode_script() const { @@ -2655,7 +2670,7 @@ inline bool TrainerSpec::has_split_by_unicode_script() const { } inline void 
TrainerSpec::clear_split_by_unicode_script() { split_by_unicode_script_ = true; - _has_bits_[0] &= ~0x10000000u; + _has_bits_[0] &= ~0x20000000u; } inline bool TrainerSpec::_internal_split_by_unicode_script() const { return split_by_unicode_script_; @@ -2665,7 +2680,7 @@ inline bool TrainerSpec::split_by_unicode_script() const { return _internal_split_by_unicode_script(); } inline void TrainerSpec::_internal_set_split_by_unicode_script(bool value) { - _has_bits_[0] |= 0x10000000u; + _has_bits_[0] |= 0x20000000u; split_by_unicode_script_ = value; } inline void TrainerSpec::set_split_by_unicode_script(bool value) { @@ -2675,7 +2690,7 @@ inline void TrainerSpec::set_split_by_unicode_script(bool value) { // optional bool split_by_number = 23 [default = true]; inline bool TrainerSpec::_internal_has_split_by_number() const { - bool value = (_has_bits_[0] & 0x20000000u) != 0; + bool value = (_has_bits_[0] & 0x40000000u) != 0; return value; } inline bool TrainerSpec::has_split_by_number() const { @@ -2683,7 +2698,7 @@ inline bool TrainerSpec::has_split_by_number() const { } inline void TrainerSpec::clear_split_by_number() { split_by_number_ = true; - _has_bits_[0] &= ~0x20000000u; + _has_bits_[0] &= ~0x40000000u; } inline bool TrainerSpec::_internal_split_by_number() const { return split_by_number_; @@ -2693,7 +2708,7 @@ inline bool TrainerSpec::split_by_number() const { return _internal_split_by_number(); } inline void TrainerSpec::_internal_set_split_by_number(bool value) { - _has_bits_[0] |= 0x20000000u; + _has_bits_[0] |= 0x40000000u; split_by_number_ = value; } inline void TrainerSpec::set_split_by_number(bool value) { @@ -2703,7 +2718,7 @@ inline void TrainerSpec::set_split_by_number(bool value) { // optional bool split_by_whitespace = 22 [default = true]; inline bool TrainerSpec::_internal_has_split_by_whitespace() const { - bool value = (_has_bits_[0] & 0x40000000u) != 0; + bool value = (_has_bits_[0] & 0x80000000u) != 0; return value; } inline bool TrainerSpec::has_split_by_whitespace() const { @@ -2711,7 +2726,7 @@ inline bool TrainerSpec::has_split_by_whitespace() const { } inline void TrainerSpec::clear_split_by_whitespace() { split_by_whitespace_ = true; - _has_bits_[0] &= ~0x40000000u; + _has_bits_[0] &= ~0x80000000u; } inline bool TrainerSpec::_internal_split_by_whitespace() const { return split_by_whitespace_; @@ -2721,7 +2736,7 @@ inline bool TrainerSpec::split_by_whitespace() const { return _internal_split_by_whitespace(); } inline void TrainerSpec::_internal_set_split_by_whitespace(bool value) { - _has_bits_[0] |= 0x40000000u; + _has_bits_[0] |= 0x80000000u; split_by_whitespace_ = value; } inline void TrainerSpec::set_split_by_whitespace(bool value) { @@ -2757,9 +2772,37 @@ inline void TrainerSpec::set_treat_whitespace_as_suffix(bool value) { // @@protoc_insertion_point(field_set:sentencepiece.TrainerSpec.treat_whitespace_as_suffix) } +// optional bool allow_whitespace_only_pieces = 26 [default = false]; +inline bool TrainerSpec::_internal_has_allow_whitespace_only_pieces() const { + bool value = (_has_bits_[0] & 0x00002000u) != 0; + return value; +} +inline bool TrainerSpec::has_allow_whitespace_only_pieces() const { + return _internal_has_allow_whitespace_only_pieces(); +} +inline void TrainerSpec::clear_allow_whitespace_only_pieces() { + allow_whitespace_only_pieces_ = false; + _has_bits_[0] &= ~0x00002000u; +} +inline bool TrainerSpec::_internal_allow_whitespace_only_pieces() const { + return allow_whitespace_only_pieces_; +} +inline bool 
TrainerSpec::allow_whitespace_only_pieces() const { + // @@protoc_insertion_point(field_get:sentencepiece.TrainerSpec.allow_whitespace_only_pieces) + return _internal_allow_whitespace_only_pieces(); +} +inline void TrainerSpec::_internal_set_allow_whitespace_only_pieces(bool value) { + _has_bits_[0] |= 0x00002000u; + allow_whitespace_only_pieces_ = value; +} +inline void TrainerSpec::set_allow_whitespace_only_pieces(bool value) { + _internal_set_allow_whitespace_only_pieces(value); + // @@protoc_insertion_point(field_set:sentencepiece.TrainerSpec.allow_whitespace_only_pieces) +} + // optional bool split_digits = 25 [default = false]; inline bool TrainerSpec::_internal_has_split_digits() const { - bool value = (_has_bits_[0] & 0x00002000u) != 0; + bool value = (_has_bits_[0] & 0x00004000u) != 0; return value; } inline bool TrainerSpec::has_split_digits() const { @@ -2767,7 +2810,7 @@ inline bool TrainerSpec::has_split_digits() const { } inline void TrainerSpec::clear_split_digits() { split_digits_ = false; - _has_bits_[0] &= ~0x00002000u; + _has_bits_[0] &= ~0x00004000u; } inline bool TrainerSpec::_internal_split_digits() const { return split_digits_; @@ -2777,7 +2820,7 @@ inline bool TrainerSpec::split_digits() const { return _internal_split_digits(); } inline void TrainerSpec::_internal_set_split_digits(bool value) { - _has_bits_[0] |= 0x00002000u; + _has_bits_[0] |= 0x00004000u; split_digits_ = value; } inline void TrainerSpec::set_split_digits(bool value) { @@ -3008,7 +3051,7 @@ inline void TrainerSpec::set_allocated_required_chars(std::string* required_char // optional bool byte_fallback = 35 [default = false]; inline bool TrainerSpec::_internal_has_byte_fallback() const { - bool value = (_has_bits_[0] & 0x00004000u) != 0; + bool value = (_has_bits_[0] & 0x00008000u) != 0; return value; } inline bool TrainerSpec::has_byte_fallback() const { @@ -3016,7 +3059,7 @@ inline bool TrainerSpec::has_byte_fallback() const { } inline void TrainerSpec::clear_byte_fallback() { byte_fallback_ = false; - _has_bits_[0] &= ~0x00004000u; + _has_bits_[0] &= ~0x00008000u; } inline bool TrainerSpec::_internal_byte_fallback() const { return byte_fallback_; @@ -3026,7 +3069,7 @@ inline bool TrainerSpec::byte_fallback() const { return _internal_byte_fallback(); } inline void TrainerSpec::_internal_set_byte_fallback(bool value) { - _has_bits_[0] |= 0x00004000u; + _has_bits_[0] |= 0x00008000u; byte_fallback_ = value; } inline void TrainerSpec::set_byte_fallback(bool value) { @@ -3036,7 +3079,7 @@ inline void TrainerSpec::set_byte_fallback(bool value) { // optional bool vocabulary_output_piece_score = 32 [default = true]; inline bool TrainerSpec::_internal_has_vocabulary_output_piece_score() const { - bool value = (_has_bits_[0] & 0x80000000u) != 0; + bool value = (_has_bits_[1] & 0x00000001u) != 0; return value; } inline bool TrainerSpec::has_vocabulary_output_piece_score() const { @@ -3044,7 +3087,7 @@ inline bool TrainerSpec::has_vocabulary_output_piece_score() const { } inline void TrainerSpec::clear_vocabulary_output_piece_score() { vocabulary_output_piece_score_ = true; - _has_bits_[0] &= ~0x80000000u; + _has_bits_[1] &= ~0x00000001u; } inline bool TrainerSpec::_internal_vocabulary_output_piece_score() const { return vocabulary_output_piece_score_; @@ -3054,7 +3097,7 @@ inline bool TrainerSpec::vocabulary_output_piece_score() const { return _internal_vocabulary_output_piece_score(); } inline void TrainerSpec::_internal_set_vocabulary_output_piece_score(bool value) { - _has_bits_[0] |= 0x80000000u; + 
_has_bits_[1] |= 0x00000001u; vocabulary_output_piece_score_ = value; } inline void TrainerSpec::set_vocabulary_output_piece_score(bool value) { @@ -3064,7 +3107,7 @@ inline void TrainerSpec::set_vocabulary_output_piece_score(bool value) { // optional bool hard_vocab_limit = 33 [default = true]; inline bool TrainerSpec::_internal_has_hard_vocab_limit() const { - bool value = (_has_bits_[1] & 0x00000001u) != 0; + bool value = (_has_bits_[1] & 0x00000002u) != 0; return value; } inline bool TrainerSpec::has_hard_vocab_limit() const { @@ -3072,7 +3115,7 @@ inline bool TrainerSpec::has_hard_vocab_limit() const { } inline void TrainerSpec::clear_hard_vocab_limit() { hard_vocab_limit_ = true; - _has_bits_[1] &= ~0x00000001u; + _has_bits_[1] &= ~0x00000002u; } inline bool TrainerSpec::_internal_hard_vocab_limit() const { return hard_vocab_limit_; @@ -3082,7 +3125,7 @@ inline bool TrainerSpec::hard_vocab_limit() const { return _internal_hard_vocab_limit(); } inline void TrainerSpec::_internal_set_hard_vocab_limit(bool value) { - _has_bits_[1] |= 0x00000001u; + _has_bits_[1] |= 0x00000002u; hard_vocab_limit_ = value; } inline void TrainerSpec::set_hard_vocab_limit(bool value) { @@ -3092,7 +3135,7 @@ inline void TrainerSpec::set_hard_vocab_limit(bool value) { // optional bool use_all_vocab = 34 [default = false]; inline bool TrainerSpec::_internal_has_use_all_vocab() const { - bool value = (_has_bits_[0] & 0x00008000u) != 0; + bool value = (_has_bits_[0] & 0x00020000u) != 0; return value; } inline bool TrainerSpec::has_use_all_vocab() const { @@ -3100,7 +3143,7 @@ inline bool TrainerSpec::has_use_all_vocab() const { } inline void TrainerSpec::clear_use_all_vocab() { use_all_vocab_ = false; - _has_bits_[0] &= ~0x00008000u; + _has_bits_[0] &= ~0x00020000u; } inline bool TrainerSpec::_internal_use_all_vocab() const { return use_all_vocab_; @@ -3110,7 +3153,7 @@ inline bool TrainerSpec::use_all_vocab() const { return _internal_use_all_vocab(); } inline void TrainerSpec::_internal_set_use_all_vocab(bool value) { - _has_bits_[0] |= 0x00008000u; + _has_bits_[0] |= 0x00020000u; use_all_vocab_ = value; } inline void TrainerSpec::set_use_all_vocab(bool value) { @@ -3148,7 +3191,7 @@ inline void TrainerSpec::set_unk_id(::PROTOBUF_NAMESPACE_ID::int32 value) { // optional int32 bos_id = 41 [default = 1]; inline bool TrainerSpec::_internal_has_bos_id() const { - bool value = (_has_bits_[1] & 0x00000002u) != 0; + bool value = (_has_bits_[1] & 0x00000004u) != 0; return value; } inline bool TrainerSpec::has_bos_id() const { @@ -3156,7 +3199,7 @@ inline bool TrainerSpec::has_bos_id() const { } inline void TrainerSpec::clear_bos_id() { bos_id_ = 1; - _has_bits_[1] &= ~0x00000002u; + _has_bits_[1] &= ~0x00000004u; } inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_bos_id() const { return bos_id_; @@ -3166,7 +3209,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::bos_id() const { return _internal_bos_id(); } inline void TrainerSpec::_internal_set_bos_id(::PROTOBUF_NAMESPACE_ID::int32 value) { - _has_bits_[1] |= 0x00000002u; + _has_bits_[1] |= 0x00000004u; bos_id_ = value; } inline void TrainerSpec::set_bos_id(::PROTOBUF_NAMESPACE_ID::int32 value) { @@ -3176,7 +3219,7 @@ inline void TrainerSpec::set_bos_id(::PROTOBUF_NAMESPACE_ID::int32 value) { // optional int32 eos_id = 42 [default = 2]; inline bool TrainerSpec::_internal_has_eos_id() const { - bool value = (_has_bits_[1] & 0x00000004u) != 0; + bool value = (_has_bits_[1] & 0x00000008u) != 0; return value; } inline bool TrainerSpec::has_eos_id() const 
{ @@ -3184,7 +3227,7 @@ inline bool TrainerSpec::has_eos_id() const { } inline void TrainerSpec::clear_eos_id() { eos_id_ = 2; - _has_bits_[1] &= ~0x00000004u; + _has_bits_[1] &= ~0x00000008u; } inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_eos_id() const { return eos_id_; @@ -3194,7 +3237,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::eos_id() const { return _internal_eos_id(); } inline void TrainerSpec::_internal_set_eos_id(::PROTOBUF_NAMESPACE_ID::int32 value) { - _has_bits_[1] |= 0x00000004u; + _has_bits_[1] |= 0x00000008u; eos_id_ = value; } inline void TrainerSpec::set_eos_id(::PROTOBUF_NAMESPACE_ID::int32 value) { @@ -3204,7 +3247,7 @@ inline void TrainerSpec::set_eos_id(::PROTOBUF_NAMESPACE_ID::int32 value) { // optional int32 pad_id = 43 [default = -1]; inline bool TrainerSpec::_internal_has_pad_id() const { - bool value = (_has_bits_[1] & 0x00000008u) != 0; + bool value = (_has_bits_[1] & 0x00000010u) != 0; return value; } inline bool TrainerSpec::has_pad_id() const { @@ -3212,7 +3255,7 @@ inline bool TrainerSpec::has_pad_id() const { } inline void TrainerSpec::clear_pad_id() { pad_id_ = -1; - _has_bits_[1] &= ~0x00000008u; + _has_bits_[1] &= ~0x00000010u; } inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_pad_id() const { return pad_id_; @@ -3222,7 +3265,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::pad_id() const { return _internal_pad_id(); } inline void TrainerSpec::_internal_set_pad_id(::PROTOBUF_NAMESPACE_ID::int32 value) { - _has_bits_[1] |= 0x00000008u; + _has_bits_[1] |= 0x00000010u; pad_id_ = value; } inline void TrainerSpec::set_pad_id(::PROTOBUF_NAMESPACE_ID::int32 value) { @@ -3602,7 +3645,7 @@ inline void TrainerSpec::set_allocated_unk_surface(std::string* unk_surface) { // optional bool train_extremely_large_corpus = 49 [default = false]; inline bool TrainerSpec::_internal_has_train_extremely_large_corpus() const { - bool value = (_has_bits_[0] & 0x00020000u) != 0; + bool value = (_has_bits_[0] & 0x00040000u) != 0; return value; } inline bool TrainerSpec::has_train_extremely_large_corpus() const { @@ -3610,7 +3653,7 @@ inline bool TrainerSpec::has_train_extremely_large_corpus() const { } inline void TrainerSpec::clear_train_extremely_large_corpus() { train_extremely_large_corpus_ = false; - _has_bits_[0] &= ~0x00020000u; + _has_bits_[0] &= ~0x00040000u; } inline bool TrainerSpec::_internal_train_extremely_large_corpus() const { return train_extremely_large_corpus_; @@ -3620,7 +3663,7 @@ inline bool TrainerSpec::train_extremely_large_corpus() const { return _internal_train_extremely_large_corpus(); } inline void TrainerSpec::_internal_set_train_extremely_large_corpus(bool value) { - _has_bits_[0] |= 0x00020000u; + _has_bits_[0] |= 0x00040000u; train_extremely_large_corpus_ = value; } inline void TrainerSpec::set_train_extremely_large_corpus(bool value) { diff --git a/src/model_interface.cc b/src/model_interface.cc index ea5d0e7..c49be1e 100644 --- a/src/model_interface.cc +++ b/src/model_interface.cc @@ -134,32 +134,53 @@ void ModelInterface::InitializePieces() { } std::vector SplitIntoWords(absl::string_view text, - bool treat_whitespace_as_suffix) { + bool treat_ws_as_suffix, + bool allow_ws_only_pieces) { const char *begin = text.data(); const char *end = text.data() + text.size(); // Space symbol (U+2581) const absl::string_view kSpaceSymbol = "\xe2\x96\x81"; + bool in_ws_sequence = false; std::vector result; - if (treat_whitespace_as_suffix) { + if (treat_ws_as_suffix) { // put ws tokens at the end of non-ws 
sequences. if (begin < end) result.emplace_back(begin, 0); while (begin < end) { const int mblen = std::min(string_util::OneCharLen(begin), end - begin); const bool is_ws = absl::string_view(begin, mblen) == kSpaceSymbol; + + if (is_ws) { // keep track of sequences consecutive ws tokens. + in_ws_sequence = true; + } else if (in_ws_sequence) { + if (allow_ws_only_pieces) result.emplace_back(begin, 0); + + in_ws_sequence = false; + } + result.back() = absl::string_view(result.back().data(), result.back().size() + mblen); begin += mblen; - if (begin < end && is_ws) result.emplace_back(begin, 0); + + if (begin < end && is_ws && !allow_ws_only_pieces) + result.emplace_back(begin, 0); } } else { while (begin < end) { const int mblen = std::min(string_util::OneCharLen(begin), end - begin); + bool is_ws = absl::string_view(begin, mblen) == kSpaceSymbol; + + // if is whitespace (and not in sequence if allow_ws_only_pieces is True) if (begin == text.data() || - absl::string_view(begin, mblen) == kSpaceSymbol) + (is_ws && (!in_ws_sequence || !allow_ws_only_pieces))) { result.emplace_back(begin, 0); // add empty string piece. + in_ws_sequence = true; + } + + if (in_ws_sequence && !is_ws) in_ws_sequence = false; + result.back() = absl::string_view(result.back().data(), result.back().size() + mblen); begin += mblen; diff --git a/src/model_interface.h b/src/model_interface.h index 75cbb23..06b3a65 100644 --- a/src/model_interface.h +++ b/src/model_interface.h @@ -33,8 +33,9 @@ namespace sentencepiece { // "_this_is_a_pen" => ["_this", "_is", "_a", "_pen"] -std::vector SplitIntoWords(absl::string_view text, - bool add_ws_as_suffix = false); +std::vector SplitIntoWords( + absl::string_view text, bool treat_ws_as_suffix = false, + bool allow_ws_only_pieces = false); // Converts byte (0-255) to piece (e.g., 58 -> "<0x3A>"). std::string ByteToPiece(unsigned char c); @@ -106,12 +107,42 @@ class ModelInterface { return EncodeResult(); } + // Sample `samples` many tokenisations from the segmentation lattice + // If `wor` is true, the samples are taken without replacement, and the scores + // are the inclusion probabilities of the elements in the sample; otherwise + // the samples are taken with replacement and the scores are the log-probs of + // sample elements + // If `include_best` is true, the best tokenisation is always included in the + // sample, and the remaining elements are sampled excluding the best. + virtual NBestEncodeResult SampleEncodeAndScore(absl::string_view normalized, + float alpha, int samples, + bool wor, + bool include_best) const { + LOG(ERROR) << "Not implemented."; + return {{EncodeResult(), 0.0}}; + } + + // Calculates the entropy of the segmentation lattice with inverse temperature + // `theta`. + // Uses a novel dynamic program to calculate the entropy. + virtual float CalculateEntropy(absl::string_view normalized, + float theta) const { + LOG(ERROR) << "Not implemented."; + return 0.0; + } + // Return true if SampleEncode returns a valid result. virtual bool IsSampleEncodeAvailable() const { return false; } // Return true if NBestEncode returns a valid result. virtual bool IsNBestEncodeAvailable() const { return false; } + // Return true if SampleEncodeAndScore returns a valid result. + virtual bool IsSampleEncodeAndScoreAvailable() const { return false; } + + // Return true if CalculateEntropy returns a valid result. + virtual bool IsCalculateEntropyAvailable() const { return false; } + // Returns the vocab id of `piece`. 
// Returns UNK(0) if `piece` is unknown virtual int PieceToId(absl::string_view piece) const; @@ -124,7 +155,10 @@ class ModelInterface { // Returns the size of sentence pieces, which is the same // as the size of vocabulary for NMT. - virtual int GetPieceSize() const { return model_proto_->pieces_size(); } + virtual int GetPieceSize() const { + if (!model_proto_) return 0; + return model_proto_->pieces_size(); + } // Returns the score of `id`. // Score represents a log probability of the piece. diff --git a/src/model_interface_test.cc b/src/model_interface_test.cc index f5ee492..69ee4e6 100644 --- a/src/model_interface_test.cc +++ b/src/model_interface_test.cc @@ -412,6 +412,50 @@ TEST(ModelInterfaceTest, SplitIntoWordsSuffixTest) { } } +TEST(ModelInterfaceTest, SplitIntoWordsWhiteSpaceOnly) { + { + const auto v = + SplitIntoWords("this" WS "is" WS "a" WS "pen" WS, true, true); + EXPECT_EQ(4, v.size()); + EXPECT_EQ("this" WS, v[0]); + EXPECT_EQ("is" WS, v[1]); + EXPECT_EQ("a" WS, v[2]); + EXPECT_EQ("pen" WS, v[3]); + } + + { + const auto v = SplitIntoWords(WS WS WS "a", false, true); + EXPECT_EQ(1, v.size()); + EXPECT_EQ(WS WS WS "a", v[0]); + } + + { + const auto v = SplitIntoWords("a" WS WS WS, true, true); + EXPECT_EQ(1, v.size()); + EXPECT_EQ("a" WS WS WS, v[0]); + } + + { + const auto v = SplitIntoWords(WS WS, true, true); + EXPECT_EQ(1, v.size()); + EXPECT_EQ(WS WS, v[0]); + } + + { + const auto v = SplitIntoWords(WS WS "a" WS, true, true); + EXPECT_EQ(2, v.size()); + EXPECT_EQ(WS WS, v[0]); + EXPECT_EQ("a" WS, v[1]); + } + + { + const auto v = SplitIntoWords(WS WS "a" WS, false, true); + EXPECT_EQ(2, v.size()); + EXPECT_EQ(WS WS "a", v[0]); + EXPECT_EQ(WS, v[1]); + } +} + TEST(ModelInterfaceTest, ByteToPieceTest) { EXPECT_EQ(ByteToPiece(0), "<0x00>"); EXPECT_EQ(ByteToPiece(1), "<0x01>"); diff --git a/src/normalizer.cc b/src/normalizer.cc index 3fe919b..d87f89b 100644 --- a/src/normalizer.cc +++ b/src/normalizer.cc @@ -12,12 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License.! -#include "normalizer.h" - #include #include #include "common.h" +#include "normalizer.h" #include "third_party/absl/memory/memory.h" #include "third_party/absl/strings/match.h" #include "third_party/absl/strings/string_view.h" @@ -278,11 +277,11 @@ util::Status Normalizer::DecodePrecompiledCharsMap( absl::string_view blob, absl::string_view *trie_blob, absl::string_view *normalized, std::string *buffer) { uint32 trie_blob_size = 0; - if (blob.size() <= sizeof(trie_blob_size) || !string_util::DecodePOD( absl::string_view(blob.data(), sizeof(trie_blob_size)), - &trie_blob_size)) { + &trie_blob_size) || + trie_blob_size >= blob.size()) { return util::InternalError("Blob for normalization rule is broken."); } diff --git a/src/normalizer.h b/src/normalizer.h index 37fdb8a..c79813c 100644 --- a/src/normalizer.h +++ b/src/normalizer.h @@ -22,7 +22,6 @@ #include #include "common.h" -#include "util.h" #include "sentencepiece_model.pb.h" #include "sentencepiece_processor.h" #include "third_party/absl/strings/string_view.h" diff --git a/src/normalizer_test.cc b/src/normalizer_test.cc index 585e8f4..6c402bf 100644 --- a/src/normalizer_test.cc +++ b/src/normalizer_test.cc @@ -12,11 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License.! 
-#include "normalizer.h" - #include #include "builder.h" +#include "normalizer.h" #include "sentencepiece_trainer.h" #include "testharness.h" #include "util.h" diff --git a/src/sentencepiece_model.proto b/src/sentencepiece_model.proto index e735527..ee8e877 100644 --- a/src/sentencepiece_model.proto +++ b/src/sentencepiece_model.proto @@ -139,6 +139,10 @@ message TrainerSpec { // of sentence. optional bool treat_whitespace_as_suffix = 24 [default = false]; + // Allows pieces that only contain whitespaces instead of appearing only as + // prefix or suffix of other pieces. + optional bool allow_whitespace_only_pieces = 26 [default = false]; + // Split all digits (0-9) into separate pieces. optional bool split_digits = 25 [default = false]; diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc index e4e9d4a..1e4e7a0 100644 --- a/src/sentencepiece_processor.cc +++ b/src/sentencepiece_processor.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License.! -#include "sentencepiece_processor.h" - #include #include #include @@ -24,6 +22,7 @@ #include "model_interface.h" #include "normalizer.h" #include "sentencepiece.pb.h" +#include "sentencepiece_processor.h" #include "third_party/absl/memory/memory.h" #include "third_party/absl/strings/numbers.h" #include "third_party/absl/strings/str_cat.h" @@ -504,6 +503,43 @@ util::Status SentencePieceProcessor::SampleEncode( return util::OkStatus(); } +util::Status SentencePieceProcessor::SampleEncodeAndScore( + absl::string_view input, int samples, float theta, bool wor, + bool include_best, NBestSentencePieceText *samples_spt) const { + CHECK_OR_RETURN(model_->IsSampleEncodeAndScoreAvailable()) + << "SampleEncodeAndScore is not available for the current model."; + std::string normalized; + std::vector norm_to_orig; + RETURN_IF_ERROR(normalizer_->Normalize(input, &normalized, &norm_to_orig)); + + const auto results = model_->SampleEncodeAndScore(normalized, theta, samples, + wor, include_best); + CHECK_OR_RETURN(!results.empty()) + << "SampleEncodeAndScore returns empty result."; + + for (const auto &result : results) { + auto *spt = samples_spt->add_nbests(); + spt->set_score(result.second); + RETURN_IF_ERROR(PopulateSentencePieceText(input, normalized, norm_to_orig, + result.first, spt)); + } + + return util::OkStatus(); +} + +util::Status SentencePieceProcessor::CalculateEntropy(absl::string_view input, + float theta, + float *entropy) const { + CHECK_OR_RETURN(model_->IsCalculateEntropyAvailable()) + << "CalculateEntropy is not available for the current model."; + std::string normalized; + std::vector norm_to_orig; + RETURN_IF_ERROR(normalizer_->Normalize(input, &normalized, &norm_to_orig)); + + *entropy = model_->CalculateEntropy(normalized, theta); + return util::OkStatus(); +} + util::Status SentencePieceProcessor::Decode( const std::vector &pieces, SentencePieceText *spt) const { CHECK_OR_RETURN_STATUS_PROTO(spt); @@ -833,6 +869,12 @@ std::string SentencePieceProcessor::serialized_model_proto() const { return model_proto_ ? model_proto_->SerializeAsString() : ""; } +// Set seed value of random generator. +// Do not set static_cast(-1), +// as this seed is reserved for initializing from +// std::random_device. 
+void SetRandomGeneratorSeed(unsigned int seed); + namespace io { util::Status LoadModelProto(absl::string_view filename, diff --git a/src/sentencepiece_processor.h b/src/sentencepiece_processor.h index 7227920..7c75838 100644 --- a/src/sentencepiece_processor.h +++ b/src/sentencepiece_processor.h @@ -315,6 +315,15 @@ class SentencePieceProcessor { virtual util::Status SampleEncode(absl::string_view input, int nbest_size, float alpha, SentencePieceText *spt) const; + // Samples N segmentation and returns the scores as well + virtual util::Status SampleEncodeAndScore( + absl::string_view input, int samples, float theta, bool wor, + bool include_best, NBestSentencePieceText *samples_spt) const; + + // Calculate entropy of possible tokenisations + virtual util::Status CalculateEntropy(absl::string_view input, float theta, + float *entropy) const; + // Given a sequence of pieces, decodes it into SentencePieceText. virtual util::Status Decode(const std::vector &pieces, SentencePieceText *spt) const; diff --git a/src/sentencepiece_processor_test.cc b/src/sentencepiece_processor_test.cc index e10a47c..373e73e 100644 --- a/src/sentencepiece_processor_test.cc +++ b/src/sentencepiece_processor_test.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License.! -#include "sentencepiece_processor.h" - #include #include "builder.h" @@ -22,6 +20,7 @@ #include "normalizer.h" #include "sentencepiece.pb.h" #include "sentencepiece_model.pb.h" +#include "sentencepiece_processor.h" #include "sentencepiece_trainer.h" #include "testharness.h" #include "third_party/absl/container/flat_hash_map.h" @@ -1139,13 +1138,6 @@ TEST(SentencePieceProcessorTest, EndToEndTest) { EXPECT_EQ("cba", output); } - // Out of range - { - std::string output; - const std::vector ids = {3, 4, 127}; - EXPECT_FALSE(sp.Decode(ids, &output).ok()); - } - { EXPECT_TRUE(sp.SetDecodeExtraOptions("bos:eos").ok()); @@ -1172,6 +1164,13 @@ TEST(SentencePieceProcessorTest, EndToEndTest) { EXPECT_EQ("cba", output); } + // Out of range + { + std::string output; + const std::vector ids = {3, 4, 127}; + EXPECT_FALSE(sp.Decode(ids, &output).ok()); + } + { EXPECT_TRUE(sp.SetDecodeExtraOptions("bos:eos:reverse").ok()); diff --git a/src/sentencepiece_trainer.cc b/src/sentencepiece_trainer.cc index 429d0f4..888f05e 100644 --- a/src/sentencepiece_trainer.cc +++ b/src/sentencepiece_trainer.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License.! 
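A minimal sketch of how the processor-level entry points declared above might be called from user code; the model filename, input string, sample count and `theta` are placeholders rather than values from the patch:

```cpp
#include <iostream>

#include "sentencepiece.pb.h"
#include "sentencepiece_processor.h"

int main() {
  sentencepiece::SentencePieceProcessor sp;
  if (!sp.Load("m.model").ok()) return 1;  // hypothetical model file

  // Draw 5 segmentations without replacement (wor = true); the returned
  // scores are log inclusion probabilities. Only the unigram model
  // implements this; other model types yield a non-ok status.
  sentencepiece::NBestSentencePieceText samples;
  if (sp.SampleEncodeAndScore("New York", /*samples=*/5, /*theta=*/1.0,
                              /*wor=*/true, /*include_best=*/false, &samples)
          .ok()) {
    for (const auto &spt : samples.nbests()) {
      std::cout << spt.score() << "\t";
      for (const auto &piece : spt.pieces()) std::cout << piece.piece() << " ";
      std::cout << "\n";
    }
  }

  // Entropy of the segmentation lattice at the same inverse temperature.
  float entropy = 0.0;
  if (sp.CalculateEntropy("New York", /*theta=*/1.0, &entropy).ok()) {
    std::cout << "entropy: " << entropy << "\n";
  }
  return 0;
}
```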
-#include "sentencepiece_trainer.h" - #include #include @@ -22,7 +20,9 @@ #include "normalizer.h" #include "sentencepiece.pb.h" #include "sentencepiece_model.pb.h" +#include "sentencepiece_trainer.h" #include "spec_parser.h" +#include "third_party/absl/flags/flag.h" #include "third_party/absl/strings/numbers.h" #include "third_party/absl/strings/str_cat.h" #include "third_party/absl/strings/str_split.h" @@ -31,6 +31,8 @@ #include "trainer_factory.h" #include "util.h" +ABSL_DECLARE_FLAG(int, minloglevel); + namespace sentencepiece { namespace { static constexpr char kDefaultNormalizerName[] = "nmt_nfkc"; @@ -110,7 +112,7 @@ util::Status SentencePieceTrainer::MergeSpecsFromArgs( for (auto arg : absl::StrSplit(args, " ")) { absl::ConsumePrefix(&arg, "--"); std::string key, value; - const auto pos = arg.find("="); + const auto pos = arg.find('='); if (pos == absl::string_view::npos) { key = std::string(arg); } else { @@ -149,7 +151,7 @@ util::Status SentencePieceTrainer::MergeSpecsFromArgs( } else if (key == "minloglevel") { int v = 0; CHECK_OR_RETURN(absl::SimpleAtoi(value, &v)); - logging::SetMinLogLevel(v); + absl::SetFlag(&FLAGS_minloglevel, v); continue; } diff --git a/src/spec_parser.h b/src/spec_parser.h index a168322..2c5a95b 100644 --- a/src/spec_parser.h +++ b/src/spec_parser.h @@ -145,6 +145,7 @@ inline std::string PrintProto(const TrainerSpec &message, PRINT_PARAM(split_by_whitespace); PRINT_PARAM(split_digits); PRINT_PARAM(treat_whitespace_as_suffix); + PRINT_PARAM(allow_whitespace_only_pieces); PRINT_REPEATED_STRING(control_symbols); PRINT_REPEATED_STRING(user_defined_symbols); PRINT_PARAM(required_chars); @@ -219,6 +220,7 @@ util::Status SentencePieceTrainer::SetProtoField(const std::string &name, PARSE_BOOL(split_by_whitespace); PARSE_BOOL(split_digits); PARSE_BOOL(treat_whitespace_as_suffix); + PARSE_BOOL(allow_whitespace_only_pieces); PARSE_REPEATED_STRING(control_symbols); PARSE_REPEATED_STRING(user_defined_symbols); PARSE_STRING(required_chars); diff --git a/src/spm_decode_main.cc b/src/spm_decode_main.cc index 32cb382..3382ddc 100644 --- a/src/spm_decode_main.cc +++ b/src/spm_decode_main.cc @@ -64,6 +64,7 @@ int main(int argc, char *argv[]) { auto ToIds = [&](const std::vector &pieces) { std::vector ids; + ids.reserve(pieces.size()); for (const auto &s : pieces) { ids.push_back(atoi(s.c_str())); } diff --git a/src/spm_encode_main.cc b/src/spm_encode_main.cc index f151ecf..4d12a38 100644 --- a/src/spm_encode_main.cc +++ b/src/spm_encode_main.cc @@ -28,16 +28,17 @@ #include "trainer_interface.h" ABSL_FLAG(std::string, model, "", "model file name"); -ABSL_FLAG(std::string, output_format, "piece", - "choose from piece, id, proto, nbest_piece, nbest_id, nbest_proto, " - "sample_piece, sample_id or sample_proto."); +ABSL_FLAG( + std::string, output_format, "piece", + "choose from piece, id, proto, nbest_piece, nbest_id, or nbest_proto"); ABSL_FLAG(std::string, input, "", "input filename"); ABSL_FLAG(std::string, output, "", "output filename"); ABSL_FLAG(std::string, extra_options, "", "':' separated encoder extra options, e.g., \"reverse:bos:eos\""); ABSL_FLAG(int32, nbest_size, 10, "NBest size"); ABSL_FLAG(double, alpha, 0.5, "Smoothing parameter for sampling mode."); -ABSL_FLAG(int32, random_seed, -1, "Seed value for random generator."); +ABSL_FLAG(uint32, random_seed, static_cast(-1), + "Seed value for random generator."); // Piece restriction with vocabulary file. 
// https://github.com/rsennrich/subword-nmt#best-practice-advice-for-byte-pair-encoding-in-nmt @@ -61,8 +62,9 @@ int main(int argc, char *argv[]) { rest_args.push_back(absl::GetFlag(FLAGS_input)); } - if (absl::GetFlag(FLAGS_random_seed) != -1) + if (absl::GetFlag(FLAGS_random_seed) != -1) { sentencepiece::SetRandomGeneratorSeed(absl::GetFlag(FLAGS_random_seed)); + } if (rest_args.empty()) rest_args.push_back(""); // empty means that reading from stdin. diff --git a/src/spm_train_main.cc b/src/spm_train_main.cc index a21fb8b..baf8dbf 100644 --- a/src/spm_train_main.cc +++ b/src/spm_train_main.cc @@ -80,6 +80,9 @@ ABSL_FLAG(bool, split_digits, kDefaultTrainerSpec.split_digits(), ABSL_FLAG(bool, treat_whitespace_as_suffix, kDefaultTrainerSpec.treat_whitespace_as_suffix(), "treat whitespace marker as suffix instead of prefix."); +ABSL_FLAG(bool, allow_whitespace_only_pieces, + kDefaultTrainerSpec.allow_whitespace_only_pieces(), + "allow pieces that only contain (consecutive) whitespace tokens"); ABSL_FLAG(std::string, control_symbols, "", "comma separated list of control symbols"); ABSL_FLAG(std::string, control_symbols_file, "", @@ -138,7 +141,8 @@ ABSL_FLAG(std::string, unk_surface, kDefaultTrainerSpec.unk_surface(), ABSL_FLAG(bool, train_extremely_large_corpus, kDefaultTrainerSpec.train_extremely_large_corpus(), "Increase bit depth for unigram tokenization."); -ABSL_FLAG(int32, random_seed, -1, "Seed value for random generator."); +ABSL_FLAG(uint32, random_seed, static_cast(-1), + "Seed value for random generator."); int main(int argc, char *argv[]) { sentencepiece::ParseCommandLineFlags(argv[0], &argc, &argv, true); @@ -150,8 +154,9 @@ int main(int argc, char *argv[]) { CHECK(!absl::GetFlag(FLAGS_input).empty()); CHECK(!absl::GetFlag(FLAGS_model_prefix).empty()); - if (absl::GetFlag(FLAGS_random_seed) != -1) + if (absl::GetFlag(FLAGS_random_seed) != -1) { sentencepiece::SetRandomGeneratorSeed(absl::GetFlag(FLAGS_random_seed)); + } auto load_lines = [](absl::string_view filename) { std::vector lines; @@ -211,6 +216,7 @@ int main(int argc, char *argv[]) { SetTrainerSpecFromFlag(split_digits); SetTrainerSpecFromFlag(byte_fallback); SetTrainerSpecFromFlag(treat_whitespace_as_suffix); + SetTrainerSpecFromFlag(allow_whitespace_only_pieces); SetTrainerSpecFromFlag(hard_vocab_limit); SetTrainerSpecFromFlag(use_all_vocab); SetTrainerSpecFromFlag(unk_id); diff --git a/src/trainer_interface.cc b/src/trainer_interface.cc index 53edc7b..a3a4b74 100644 --- a/src/trainer_interface.cc +++ b/src/trainer_interface.cc @@ -12,8 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License.! 
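The `--allow_whitespace_only_pieces` flag registered above maps onto the new `allow_whitespace_only_pieces` field of `TrainerSpec`, so the option can also be enabled through the programmatic trainer. A hedged sketch (corpus path, model prefix and vocabulary size are illustrative only):

```cpp
#include "sentencepiece_trainer.h"

// Train a unigram model that may keep pieces consisting solely of the
// whitespace meta symbol (e.g. "▁▁▁"), which can be useful for indented text.
void TrainWithWhitespaceOnlyPieces() {
  const auto status = sentencepiece::SentencePieceTrainer::Train(
      "--input=corpus.txt --model_prefix=m --vocab_size=8000 "
      "--allow_whitespace_only_pieces=true");
  if (!status.ok()) {
    // Handle or log the error status here.
  }
}
```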
-#include "trainer_interface.h" - +#include #include #include #include @@ -34,6 +33,7 @@ #include "third_party/absl/strings/str_format.h" #include "third_party/absl/strings/str_join.h" #include "third_party/absl/strings/str_split.h" +#include "trainer_interface.h" #include "unicode_script.h" #include "util.h" @@ -86,6 +86,10 @@ util::Status VerifySpec(const TrainerSpec &trainer_spec) { return util::OkStatus(); } +bool is_unicode_decimal_number(char32 c) { + return (c >= 0x30 && c <= 0x39) || (c >= 0xff10 && c <= 0xff19); +} + class SentenceSelector { public: using Sampler = random::ReservoirSampler; @@ -210,9 +214,10 @@ bool TrainerInterface::IsValidSentencePiece( constexpr unicode_script::ScriptType kAnyType = static_cast(-1); - auto is_number = [](char32 c) { return (c >= 0x30 && c <= 0x39); }; - unicode_script::ScriptType prev_script = kAnyType; + bool all_whitespace_piece = + std::all_of(sentencepiece.begin(), sentencepiece.end(), + [](char32 c) { return c == kWSChar; }); for (size_t pos = 0; pos < sentencepiece.size(); ++pos) { const char32 c = sentencepiece[pos]; @@ -235,25 +240,30 @@ bool TrainerInterface::IsValidSentencePiece( } if (c == kWSChar) { - // Only allows whitespace to appear as a prefix of piece. + // Only allows whitespace to appear as a prefix of piece unless + // allow_whitespace_only_pieces is True. // When split_by_whitespace is false, we allow whitespaces to // appear in the middle, "foo_bar", but do not allow them // to appear as suffix, "foo_bar_". // Regardless of the setting of split_by_whitespace, // whitespace is treated as a prefix/infix of symbol or - // independent symbol. - if (trainer_spec_.treat_whitespace_as_suffix()) { - if ((trainer_spec_.split_by_whitespace() && - pos < sentencepiece.size() - 1) || - (!trainer_spec_.split_by_whitespace() && - pos < sentencepiece.size() - 1 && pos == 0)) { - return false; - } - } else { - if ((trainer_spec_.split_by_whitespace() && pos > 0) || - (!trainer_spec_.split_by_whitespace() && pos > 0 && - pos == sentencepiece.size() - 1)) { - return false; + // independent symbol, unless allow_whitespace_only_pieces() is true, + // in which case whitespace only pieces can occur. 
+ if (!trainer_spec_.allow_whitespace_only_pieces() || + !all_whitespace_piece) { + if (trainer_spec_.treat_whitespace_as_suffix()) { + if ((trainer_spec_.split_by_whitespace() && + pos < sentencepiece.size() - 1) || + (!trainer_spec_.split_by_whitespace() && + pos < sentencepiece.size() - 1 && pos == 0)) { + return false; + } + } else { + if ((trainer_spec_.split_by_whitespace() && pos > 0) || + (!trainer_spec_.split_by_whitespace() && pos > 0 && + pos == sentencepiece.size() - 1)) { + return false; + } } } } else { @@ -265,11 +275,11 @@ bool TrainerInterface::IsValidSentencePiece( s = unicode_script::U_Han; } - if (!trainer_spec_.split_by_number() && is_number(c)) { + if (!trainer_spec_.split_by_number() && is_unicode_decimal_number(c)) { s = kAnyType; } - if (trainer_spec_.split_digits() && is_number(c)) { + if (trainer_spec_.split_digits() && is_unicode_decimal_number(c)) { if (sentencepiece.size() > 1) return false; } @@ -518,7 +528,8 @@ void TrainerInterface::SplitSentencesByWhitespace() { absl::flat_hash_map tokens; for (const auto &s : sentences_) { for (const auto &w : - SplitIntoWords(s.first, trainer_spec_.treat_whitespace_as_suffix())) { + SplitIntoWords(s.first, trainer_spec_.treat_whitespace_as_suffix(), + trainer_spec_.allow_whitespace_only_pieces())) { tokens[std::string(w)] += s.second; } } diff --git a/src/trainer_interface_test.cc b/src/trainer_interface_test.cc index c61c7ce..70a51ad 100644 --- a/src/trainer_interface_test.cc +++ b/src/trainer_interface_test.cc @@ -81,6 +81,7 @@ TEST(TrainerInterfaceTest, IsValidSentencePieceTest) { trainer_spec.set_split_by_whitespace(false); EXPECT_TRUE(IsValid(WS)); + EXPECT_TRUE(IsValid(WS WS WS "a")); EXPECT_TRUE(IsValid(WS "a")); EXPECT_FALSE(IsValid("a" WS)); EXPECT_FALSE(IsValid(WS "a" WS)); @@ -88,7 +89,17 @@ TEST(TrainerInterfaceTest, IsValidSentencePieceTest) { EXPECT_TRUE(IsValid(WS "a" WS "b")); EXPECT_TRUE(IsValid(WS "a" WS "b" WS "c")); EXPECT_FALSE(IsValid("a" WS "b" WS)); + EXPECT_FALSE(IsValid(WS WS)); + EXPECT_FALSE(IsValid(WS WS WS)); + trainer_spec.set_allow_whitespace_only_pieces(true); + EXPECT_TRUE(IsValid(WS)); + EXPECT_TRUE(IsValid(WS WS)); + EXPECT_TRUE(IsValid(WS WS WS)); + EXPECT_TRUE(IsValid(WS WS "a")); + EXPECT_FALSE(IsValid("a" WS WS)); // suffix whitespace illegal without flag + + trainer_spec.set_allow_whitespace_only_pieces(false); trainer_spec.set_split_by_unicode_script(false); EXPECT_TRUE(IsValid("あいう")); EXPECT_TRUE(IsValid("グーグル")); @@ -124,6 +135,15 @@ TEST(TrainerInterfaceTest, IsValidSentencePieceTest) { EXPECT_FALSE(IsValid(WS "a" WS "b")); EXPECT_FALSE(IsValid("a" WS "b" WS)); + trainer_spec.set_allow_whitespace_only_pieces(true); + EXPECT_TRUE(IsValid(WS)); + EXPECT_TRUE(IsValid(WS WS)); + EXPECT_FALSE(IsValid(WS "a" WS)); + EXPECT_FALSE(IsValid("a" WS "b")); + EXPECT_FALSE(IsValid(WS "a" WS "b")); + EXPECT_FALSE(IsValid("a" WS "b" WS)); + + trainer_spec.set_allow_whitespace_only_pieces(false); trainer_spec.set_split_by_whitespace(false); EXPECT_TRUE(IsValid(WS)); EXPECT_FALSE(IsValid(WS "a")); @@ -146,6 +166,12 @@ TEST(TrainerInterfaceTest, IsValidSentencePieceTest) { EXPECT_FALSE(IsValid("2007")); EXPECT_FALSE(IsValid("x1")); EXPECT_FALSE(IsValid("2x")); + // Fullwidth digits. 
+ EXPECT_TRUE(IsValid("1")); + EXPECT_FALSE(IsValid("59")); + EXPECT_FALSE(IsValid("2007")); + EXPECT_FALSE(IsValid("*1")); + EXPECT_FALSE(IsValid("2*")); } TEST(TrainerInterfaceTest, OverrideSpecialPiecesTest) { diff --git a/src/unigram_model.cc b/src/unigram_model.cc index bd2d99b..3b99060 100644 --- a/src/unigram_model.cc +++ b/src/unigram_model.cc @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include @@ -55,6 +56,17 @@ inline float LogSumExp(float x, float y, bool init_mode) { return vmax + log(std::exp(static_cast(vmin - vmax)) + 1.0); } } + +// Returns a sample from a standard Gumbel distribution. +// If U ~ U[0, 1], -log(-log U) ~ G(0,1) +inline float Gumbel() { + const float kEpsilon = 1e-7; + auto *mt = random::GetRandomGenerator(); + std::uniform_real_distribution dis(0.0, 1.0); + float noise = -std::log(-(std::log(dis(*mt) + kEpsilon))); + + return noise; +} } // namespace Lattice::Lattice() : node_allocator_(kPreallocateLatticeNodeSize) {} @@ -145,7 +157,7 @@ Lattice::Node *Lattice::Insert(int pos, int length) { return node; } -std::vector Lattice::Viterbi() { +Lattice::LatticePathWithScore Lattice::Viterbi() { const int len = size(); for (int pos = 0; pos <= len; ++pos) { @@ -171,6 +183,7 @@ std::vector Lattice::Viterbi() { // backtrace std::vector results; + float score = begin_nodes(len)[0]->backtrace_score; for (Node *node = begin_nodes_[len][0]->prev; node->prev != nullptr; node = node->prev) { results.push_back(node); @@ -178,30 +191,32 @@ std::vector Lattice::Viterbi() { std::reverse(results.begin(), results.end()); - return results; + LatticePathWithScore retval = {results, score}; + + return retval; } -float Lattice::PopulateMarginal(float freq, - std::vector *expected) const { - if (expected == nullptr) return 0.0; - +std::vector Lattice::ForwardAlgorithm(float theta) const { const int len = size(); - - // alpha and beta (accumulative log prob) in Forward Backward. - // the index of alpha/beta is Node::node_id. std::vector alpha(node_allocator_.size(), 0.0); - std::vector beta(node_allocator_.size(), 0.0); for (int pos = 0; pos <= len; ++pos) { for (Node *rnode : begin_nodes_[pos]) { for (Node *lnode : end_nodes_[pos]) { - alpha[rnode->node_id] = LogSumExp(alpha[rnode->node_id], - lnode->score + alpha[lnode->node_id], - lnode == end_nodes_[pos][0]); + alpha[rnode->node_id] = LogSumExp( + alpha[rnode->node_id], theta * lnode->score + alpha[lnode->node_id], + lnode == end_nodes_[pos][0]); } } } + return alpha; +} + +std::vector Lattice::BackwardAlgorithm(float theta) const { + const int len = size(); + std::vector beta(node_allocator_.size(), 0.0); + for (int pos = len; pos >= 0; --pos) { for (Node *lnode : end_nodes_[pos]) { for (Node *rnode : begin_nodes_[pos]) { @@ -212,6 +227,21 @@ float Lattice::PopulateMarginal(float freq, } } + return beta; +} + +float Lattice::PopulateMarginal(float freq, + std::vector *expected) const { + if (expected == nullptr) return 0.0; + + const int len = size(); + + // alpha and beta (accumulative log prob) in Forward Backward. + // the index of alpha/beta is Node::node_id. 
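The `Gumbel()` helper added above is the building block of the stochastic n-best search that follows: adding independent standard Gumbel noise to unnormalized log-scores and taking the arg-max draws an exact sample from the corresponding softmax distribution (the Gumbel-max trick). A self-contained sketch of that property, deliberately using a plain `std::mt19937` instead of the library's shared generator:

```cpp
#include <cmath>
#include <cstdio>
#include <random>
#include <vector>

// argmax_i (score_i + G_i), with G_i i.i.d. standard Gumbel noise, is
// distributed as exp(score_i) / sum_j exp(score_j).
int GumbelMaxSample(const std::vector<double> &scores, std::mt19937 *rng) {
  // If U ~ U(0, 1), then -log(-log U) ~ Gumbel(0, 1); the interval is
  // clipped away from 0 and 1 for numerical safety.
  std::uniform_real_distribution<double> uniform(1e-12, 1.0 - 1e-12);
  int best = 0;
  double best_value = -1e30;
  for (size_t i = 0; i < scores.size(); ++i) {
    const double gumbel = -std::log(-std::log(uniform(*rng)));
    const double perturbed = scores[i] + gumbel;
    if (perturbed > best_value) {
      best_value = perturbed;
      best = static_cast<int>(i);
    }
  }
  return best;
}

int main() {
  std::mt19937 rng(0);
  const std::vector<double> scores = {1.0, 0.3, 0.5};  // unnormalized log-probs
  std::vector<int> counts(scores.size(), 0);
  for (int i = 0; i < 100000; ++i) counts[GumbelMaxSample(scores, &rng)]++;
  // Empirical frequencies should approach softmax(scores).
  for (size_t i = 0; i < counts.size(); ++i)
    std::printf("i=%zu freq=%.3f\n", i, counts[i] / 100000.0);
  return 0;
}
```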
+ + const auto alpha = ForwardAlgorithm(1.0); + const auto beta = BackwardAlgorithm(1.0); + const float Z = alpha[begin_nodes_[len][0]->node_id]; for (int pos = 0; pos < len; ++pos) { for (Node *node : begin_nodes_[pos]) { @@ -228,13 +258,46 @@ float Lattice::PopulateMarginal(float freq, return freq * Z; } -std::vector> Lattice::NBest(size_t nbest_size) { +float Lattice::CalculateEntropy(float theta) const { + const int len = size(); + + // alpha[node_id] is the marginal prob of sequence up to start of node + // H is entropy of sequence + // the index of alpha/H is Node::node_id. + std::vector alpha(node_allocator_.size(), 0.0); + std::vector H(node_allocator_.size(), 0.0); + + // Populate the forward marginals to get the normalising constant + alpha = ForwardAlgorithm(theta); + + // Now populate the forward entropies + for (int pos = 0; pos <= len; ++pos) { + for (Node *rnode : begin_nodes_[pos]) { + for (Node *lnode : end_nodes_[pos]) { + // Contribution each lnode makes = p(lnode) * (H(lnode) + log p(lnode)) + + // We have to normalise p(lnode) by the marginal contribution it makes + const float lnode_transition_prob = + ((theta * lnode->score) + alpha[lnode->node_id] - + alpha[rnode->node_id]); + H[rnode->node_id] += std::exp(lnode_transition_prob) * + (H[lnode->node_id] + lnode_transition_prob); + } + } + } + + return -H[begin_nodes_[len][0]->node_id]; +} + +std::vector Lattice::NBest(size_t nbest_size, + bool sample, + float theta) { if (nbest_size < 1) { LOG(WARNING) << "nbest_size >= 1. Returns empty result."; return {}; } - if (nbest_size == 1) { + if (nbest_size == 1 && !sample) { return {Viterbi()}; } @@ -243,6 +306,7 @@ std::vector> Lattice::NBest(size_t nbest_size) { // At each partial path x, compute f(x) as follows // f(x) = g(x) + h(x). // g(x): the sum of scores from EOS to the left-most node in x. + // for a complete hypothesis, g(hyp) is the score of the hypothesis. // h(x): a heuristic that estimates the largest score from x to BOS. // f(x): the priority to pop a new hypothesis from the priority queue. // @@ -268,17 +332,26 @@ std::vector> Lattice::NBest(size_t nbest_size) { model::FreeList hypothesis_allocator(kPreallocatedHypothesisSize); Agenda agenda; - std::vector> results; + std::vector results; auto *eos = hypothesis_allocator.Allocate(); eos->node = eos_node(); eos->next = nullptr; - eos->fx = eos->node->score; - eos->gx = eos->node->score; - agenda.push(eos); + eos->gx = 0.0; - // Run Viterbi first to fill backtrace score. - Viterbi(); + std::vector alpha(node_allocator_.size(), 0.0); + + if (sample) { + // Run forwards algorithm to get normalising constants + alpha = ForwardAlgorithm(theta); + // f(eos) = Gumbel(0), as it is the perturbed score of the entire lattice. + eos->fx = Gumbel(); + } else { + // Run Viterbi first to fill backtrace score. 
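The body of `Lattice::CalculateEntropy` above can be read as a forward dynamic program. As I read the code, with `alpha` the forward log-marginals from `ForwardAlgorithm(theta)` and `s_l` the score of a node `l` that ends where node `r` begins, it accumulates

```latex
p(\ell \to r) = \exp\left(\theta\, s_\ell + \alpha_\ell - \alpha_r\right),
\qquad
H_r = \sum_{\ell} p(\ell \to r)\,\bigl(H_\ell + \log p(\ell \to r)\bigr),
\qquad
H(\theta) = -H_{\mathrm{EOS}}
```

The `CalculateEntropyTest` added later in this patch checks the result against a brute-force `-sum p log p` over the four segmentations of its toy lattice.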
+ Viterbi(); + eos->fx = eos->node->backtrace_score; + } + agenda.push(eos); while (!agenda.empty()) { auto *top = agenda.top(); @@ -289,21 +362,56 @@ std::vector> Lattice::NBest(size_t nbest_size) { if (node == bos_node()) { results.resize(results.size() + 1); for (auto *n = top->next; n->next != nullptr; n = n->next) { - results.back().push_back(n->node); + results.back().first.push_back(n->node); } + results.back().second = top->fx; if (results.size() == nbest_size) { break; } continue; } + const int end_nodes_size = end_nodes(node->pos).size(); + std::vector probs(end_nodes_size, 0.0); + std::vector perturbed_probs(end_nodes_size, 0.0); + std::vector adjusted_probs(end_nodes_size, 0.0); + const float Z = alpha[node->node_id]; + if (sample) { + float max_score = -1e8; + // Calculate the marginal and perturbed scores for stochastic search + for (int i = 0; i < end_nodes(node->pos).size(); i++) { + Node *lnode = end_nodes(node->pos)[i]; + // Calculate backwards transition score + probs[i] = top->gx + alpha[lnode->node_id] + (theta * lnode->score) - Z; + perturbed_probs[i] = probs[i] + Gumbel(); + if (perturbed_probs[i] > max_score) { + max_score = perturbed_probs[i]; + } + } + // Now constrain the sampled continuations to match the score of parent + for (int i = 0; i < adjusted_probs.size(); i++) { + // Use numerically stable version of truncated Gumbel: + // https://arxiv.org/pdf/1903.06059.pdf appendix B.3 + const float v = top->fx - perturbed_probs[i] + + std::log1p(-std::exp(perturbed_probs[i] - max_score)); + adjusted_probs[i] = top->fx - std::max(static_cast(0.0), v) - + std::log1p(std::exp(-std::abs(v))); + } + } + // Expands new node ending at node->pos - for (Node *lnode : end_nodes(node->pos)) { + for (int i = 0; i < end_nodes(node->pos).size(); i++) { + Node *lnode = end_nodes(node->pos)[i]; auto *hyp = hypothesis_allocator.Allocate(); hyp->node = lnode; - hyp->gx = lnode->score + top->gx; // just adds node->score - hyp->fx = - lnode->backtrace_score + top->gx; // backtrace_score is h(node). + if (sample) { + hyp->gx = probs[i]; + hyp->fx = adjusted_probs[i]; + } else { + hyp->gx = lnode->score + top->gx; // just adds node->score + hyp->fx = + lnode->backtrace_score + top->gx; // backtrace_score is h(node). 
+ } hyp->next = top; agenda.push(hyp); } @@ -335,15 +443,7 @@ std::vector Lattice::Sample(float theta) { std::vector alpha(node_allocator_.size(), 0.0); - for (int pos = 0; pos <= len; ++pos) { - for (Node *rnode : begin_nodes_[pos]) { - for (Node *lnode : end_nodes_[pos]) { - alpha[rnode->node_id] = LogSumExp( - alpha[rnode->node_id], theta * lnode->score + alpha[lnode->node_id], - lnode == end_nodes_[pos][0]); - } - } - } + alpha = ForwardAlgorithm(theta); auto *mt = random::GetRandomGenerator(); @@ -514,7 +614,7 @@ EncodeResult Model::Encode(absl::string_view normalized) const { PopulateNodes(&lattice); EncodeResult results; - for (const auto *node : lattice.Viterbi()) { + for (const auto *node : lattice.Viterbi().first) { results.emplace_back(node->piece, node->id); } @@ -534,14 +634,12 @@ NBestEncodeResult Model::NBestEncode(absl::string_view normalized, PopulateNodes(&lattice); NBestEncodeResult nbest_results; - for (const auto &nbest : lattice.NBest(nbest_size)) { + for (const auto &nbest : lattice.NBest(nbest_size, false, 0.0)) { EncodeResult results; - float score = 0.0; - for (const auto *node : nbest) { - score += node->score; + for (const auto *node : nbest.first) { results.emplace_back(node->piece, node->id); } - nbest_results.emplace_back(results, score); + nbest_results.emplace_back(results, nbest.second); } return nbest_results; @@ -565,6 +663,123 @@ EncodeResult Model::SampleEncode(absl::string_view normalized, return results; } +NBestEncodeResult Model::SampleEncodeAndScore(absl::string_view normalized, + float theta, int samples, + bool wor, + bool include_best) const { + if (!status().ok() || normalized.empty()) { + return {}; + } + NBestEncodeResult results; + Lattice lattice; + lattice.SetSentence(normalized); + PopulateNodes(&lattice); + + std::vector alpha = lattice.ForwardAlgorithm(theta); + float marginal = alpha[lattice.eos_node()->node_id]; + + if (include_best) { + if (!wor) { + LOG(FATAL) << "include_best not supported for wor false"; + } + EncodeResult result; + Lattice::LatticePathWithScore best_path = lattice.Viterbi(); + + for (const auto *node : best_path.first) { + result.emplace_back(node->piece, node->id); + } + + // Inclusion probability if we always include the best is 1. + results.emplace_back(result, 0.0); + } + + if (wor) { + // Draw k+1 samples as we need perturbed score of k+1th element + std::vector nbest_samples = + lattice.NBest(samples + 1, true, theta); + + if (include_best) { + std::vector> nbest_paths( + nbest_samples.size()); + for (int i = 0; i < nbest_samples.size(); i++) { + nbest_paths[i] = nbest_samples[i].first; + } + // Remove the best result from the samples if necessary + Lattice::LatticePathWithScore best_path = lattice.Viterbi(); + + const int index_of_best = + (std::find(nbest_paths.begin(), nbest_paths.end(), best_path.first) - + nbest_paths.begin()); + + if (index_of_best != nbest_samples.size()) { + LOG(INFO) << "removing best path from samples"; + nbest_samples.erase(nbest_samples.begin() + index_of_best); + } else { + nbest_samples.pop_back(); + } + } + // We use the perturbed score of the k+1th element to calculate the + // inclusion probability. 
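Two formulas are packed into the sampling branch above and the inclusion-probability block that follows; spelling them out here as I read the code and the paper it cites (Kool et al. 2019, arXiv:1903.06059). When a hypothesis with perturbed score `T` is expanded, each child's independently perturbed score is truncated so that the maximum over children equals `T`, which is what makes the best-first search equivalent to sampling whole paths without replacement:

```latex
\tilde{g}_i = -\log\!\left(e^{-T} - e^{-\hat{Z}} + e^{-\hat{g}_i}\right),
\qquad \hat{Z} = \max_i \hat{g}_i,
\qquad\text{evaluated stably as}\quad
v = T - \hat{g}_i + \operatorname{log1p}\!\left(-e^{\hat{g}_i - \hat{Z}}\right),\;\;
\tilde{g}_i = T - \max(0, v) - \operatorname{log1p}\!\left(e^{-|v|}\right)
```

The scores returned for each sample under `wor` are then log inclusion probabilities, computed from `kappa`, the (k+1)-th largest perturbed score:

```latex
\log \Pr[y \in S] = \log\!\left(1 - e^{-e^{a}}\right),
\qquad a = \log p_\theta(y) - \kappa
```

which the code evaluates as `log(-expm1(-exp(a)))`, falling back for `a <= -10` to the series `a - e^a/2 + e^{2a}/24 - e^{4a}/2880`. The unigram test later checks these scores via the unbiasedness identity `E[1{y in S} / Pr[y in S]] = 1` (appendix D of the same paper).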
+ const double kappa = static_cast(nbest_samples.back().second); + // Discard the last sample + nbest_samples.pop_back(); + for (const auto &nbest : nbest_samples) { + EncodeResult result; + float score = 0.0; + + for (const auto *node : nbest.first) { + score += (theta * node->score); + result.emplace_back(node->piece, node->id); + } + + results.emplace_back(result, score - marginal); + } + + // Now calculate the inclusion probability + for (auto &it : results) { + // Only modify non best sample inclusion probabilities. + if (it.second != 0.0) { + double x = it.second - kappa; + double y = std::exp(x); + double inclusion_prob; + if (x <= -10) { + // Series expansion of the log Gumbel survival function up to eps. + inclusion_prob = + x - (y / 2) + (std::pow(y, 2) / 24) - std::pow(y, 4) / 2880; + } else { + inclusion_prob = std::log(-std::expm1(-y)); + } + it.second = static_cast(inclusion_prob); + } + } + } else { + while (results.size() < samples) { + Lattice lattice; + lattice.SetSentence(normalized); + PopulateNodes(&lattice); + + float score = 0.0; + EncodeResult result; + std::vector sample = lattice.Sample(theta); + for (const auto *node : sample) { + result.emplace_back(node->piece, node->id); + score += (theta * node->score); + } + results.emplace_back(result, score - marginal); + } + } + + return results; +} + +float Model::CalculateEntropy(absl::string_view normalized, float theta) const { + Lattice lattice; + lattice.SetSentence(normalized); + PopulateNodes(&lattice); + + return lattice.CalculateEntropy(theta); +} + bool Model::VerifyOutputsEquivalent(absl::string_view expected, absl::string_view actual) const { auto compute_unigram_model_score = diff --git a/src/unigram_model.h b/src/unigram_model.h index 2f66a5f..448e489 100644 --- a/src/unigram_model.h +++ b/src/unigram_model.h @@ -82,17 +82,28 @@ class Lattice { // After calling this method, The caller must set Node::score and Node::id. Node *Insert(int pos, int length); + using LatticePathWithScore = std::pair, float>; + // Returns Viterbi path. All nodes must be populated in advance. - std::vector Viterbi(); + LatticePathWithScore Viterbi(); + + // Runs forwards/backwards algorithm, returns vector with normalised + // transition probs. + std::vector ForwardAlgorithm(float theta) const; + std::vector BackwardAlgorithm(float theta) const; // Returns n-best results. - std::vector> NBest(size_t nbest_size); + std::vector NBest(size_t nbest_size, bool sample, + float theta); // Samples one path from the lattice according to the // generation probability (Product of piece probabilities). // `theta` is a smoothing parameter. std::vector Sample(float theta); + // Calculates the entropy of the lattice. + float CalculateEntropy(float theta) const; + // Populates marginal probability of every node in this lattice. // |freq| is the frequency of the sentence. 
// for (auto *node : all_nodes_) { @@ -127,8 +138,19 @@ class Model : public ModelInterface { EncodeResult SampleEncode(absl::string_view normalized, float theta) const override; + NBestEncodeResult SampleEncodeAndScore(absl::string_view normalized, + float theta, int samples, bool wor, + bool include_best) const override; + + float CalculateEntropy(absl::string_view normalized, + float theta) const override; + bool IsSampleEncodeAvailable() const override { return true; } + bool IsSampleEncodeAndScoreAvailable() const override { return true; } + + bool IsCalculateEntropyAvailable() const override { return true; } + bool IsNBestEncodeAvailable() const override { return true; } // Returns the minimum score in sentence pieces. diff --git a/src/unigram_model_test.cc b/src/unigram_model_test.cc index dacec38..5c292cb 100644 --- a/src/unigram_model_test.cc +++ b/src/unigram_model_test.cc @@ -161,11 +161,11 @@ TEST(LatticeTest, InsertTest) { TEST(LatticeTest, ViterbiFromIncompleteLatticeTest) { Lattice lattice; lattice.SetSentence("ABC"); - EXPECT_TRUE(lattice.Viterbi().empty()); + EXPECT_TRUE(lattice.Viterbi().first.empty()); // Still incomplete lattice.Insert(0, 1); - EXPECT_TRUE(lattice.Viterbi().empty()); + EXPECT_TRUE(lattice.Viterbi().first.empty()); lattice.Insert(1, 1); lattice.Insert(2, 1); @@ -198,16 +198,16 @@ TEST(LatticeTest, ViterbiTest) { InsertWithScore(&lattice, 0, 1, 0.0); // A InsertWithScore(&lattice, 1, 1, 0.0); // B InsertWithScore(&lattice, 2, 1, 0.0); // C - EXPECT_EQ("A B C", GetTokenized(lattice.Viterbi())); + EXPECT_EQ("A B C", GetTokenized(lattice.Viterbi().first)); InsertWithScore(&lattice, 0, 2, 2.0); // AB - EXPECT_EQ("AB C", GetTokenized(lattice.Viterbi())); + EXPECT_EQ("AB C", GetTokenized(lattice.Viterbi().first)); InsertWithScore(&lattice, 1, 2, 5.0); // BC - EXPECT_EQ("A BC", GetTokenized(lattice.Viterbi())); + EXPECT_EQ("A BC", GetTokenized(lattice.Viterbi().first)); InsertWithScore(&lattice, 0, 3, 10.0); // ABC - EXPECT_EQ("ABC", GetTokenized(lattice.Viterbi())); + EXPECT_EQ("ABC", GetTokenized(lattice.Viterbi().first)); } TEST(LatticeTest, NBestTest) { @@ -221,21 +221,174 @@ TEST(LatticeTest, NBestTest) { InsertWithScore(&lattice, 1, 2, 5.0); // BC InsertWithScore(&lattice, 0, 3, 10.0); // ABC - auto nbests = lattice.NBest(10); + auto nbests = lattice.NBest(10, false, 0.0); EXPECT_EQ(4, nbests.size()); - EXPECT_EQ("ABC", GetTokenized(nbests[0])); - EXPECT_EQ("A BC", GetTokenized(nbests[1])); - EXPECT_EQ("AB C", GetTokenized(nbests[2])); - EXPECT_EQ("A B C", GetTokenized(nbests[3])); + EXPECT_EQ("ABC", GetTokenized(nbests[0].first)); + EXPECT_EQ("A BC", GetTokenized(nbests[1].first)); + EXPECT_EQ("AB C", GetTokenized(nbests[2].first)); + EXPECT_EQ("A B C", GetTokenized(nbests[3].first)); - auto nbests0 = lattice.NBest(0); + auto nbests0 = lattice.NBest(0, false, 0.0); EXPECT_TRUE(nbests0.empty()); - auto nbests1 = lattice.NBest(1); + auto nbests1 = lattice.NBest(1, false, 0.0); EXPECT_EQ(nbests1.size(), 1); } +TEST(LatticeTest, NBestSampleTest) { + Lattice lattice; + lattice.SetSentence("ABC"); + + InsertWithScore(&lattice, 0, 1, 0.0); // A + InsertWithScore(&lattice, 1, 1, 0.0); // B + InsertWithScore(&lattice, 2, 1, 0.1); // C + InsertWithScore(&lattice, 0, 2, 0.2); // AB + InsertWithScore(&lattice, 1, 2, 0.5); // BC + InsertWithScore(&lattice, 0, 3, 1.0); // ABC + + // Calculate expected probabilities of each path + // Note that sampling without replacement affects the expected frequencies! 
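The expected values that `NBestSampleTest` constructs above follow from elementary sampling-without-replacement algebra; for two draws the test computes (restating the code rather than adding assumptions)

```latex
\Pr[\,i \text{ first},\, j \text{ second}\,] = p_i \cdot \frac{p_j}{1 - p_i},
\qquad
\pi_i = \Pr[\,i \in S\,] = p_i + \sum_{j \neq i} p_j \cdot \frac{p_i}{1 - p_j}
```

The division by two when filling `inclusion_probs` matches the assertion, which normalises observed counts by `kTrials * num_samples`; in other words the test compares `pi_i / k` against the per-draw frequency of each segmentation.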
+ const std::vector kTheta = {0.0, 0.01, 0.5, 0.7, 1.0}; + for (const auto theta : kTheta) { + std::vector strings = {"ABC", "AB C", "A BC", "A B C"}; + std::map probs; + probs["ABC"] = std::exp(theta * 1.0); + probs["AB C"] = std::exp(theta * (0.2 + 0.1)); + probs["A BC"] = std::exp(theta * (0.0 + 0.5)); + probs["A B C"] = std::exp(theta * (0.0 + 0.0 + 0.1)); + + for (const auto &it : strings) { + EXPECT_EQ(1, probs.count(it)); + } + + double Z = 0.0; + for (const auto &it : probs) Z += it.second; + for (auto &it : probs) it.second /= Z; + + std::map, float> pair_probs; + for (const auto first : strings) { + for (const auto second : strings) { + if (first == second) { + pair_probs[std::make_pair(first, second)] = 0; + } else { + float first_prob = probs[first]; + float second_prob = probs[second] / (1 - first_prob); + pair_probs[std::make_pair(first, second)] = first_prob * second_prob; + } + } + } + + std::map inclusion_probs; + for (const auto string : strings) { + float inclusion_prob = 0.0; + for (const auto other_string : strings) { + inclusion_prob += pair_probs[std::make_pair(string, other_string)]; + } + for (const auto other_string : strings) { + inclusion_prob += pair_probs[std::make_pair(other_string, string)]; + } + inclusion_probs[string] = inclusion_prob / 2; + } + + int kTrials = 10000; + + std::vector kNumSamples = {1, 2}; + + for (const auto num_samples : kNumSamples) { + std::map counts; + for (int i = 0; i < kTrials; i++) { + auto nbests = lattice.NBest(num_samples, true, theta); + for (const auto nbest : nbests) { + counts[GetTokenized(nbest.first)]++; + } + } + + EXPECT_EQ(inclusion_probs.size(), counts.size()); + // If we take multiple samples WOR, we have to use corrected probs. + std::map probs_to_use = + (num_samples == 1 ? 
probs : inclusion_probs); + + for (const auto &it : probs_to_use) { + EXPECT_NEAR(it.second, 1.0 * counts[it.first] / (kTrials * num_samples), + 0.02); + } + } + } +} + +TEST(LatticeTest, CalculateEntropyTest) { + Lattice lattice; + lattice.SetSentence("ABC"); + + InsertWithScore(&lattice, 0, 1, 0.0); // A + InsertWithScore(&lattice, 1, 1, 0.0); // B + InsertWithScore(&lattice, 2, 1, 0.1); // C + InsertWithScore(&lattice, 0, 2, 0.2); // AB + InsertWithScore(&lattice, 1, 2, 0.5); // BC + InsertWithScore(&lattice, 0, 3, 1.0); // ABC + + // Calculate expected probabilities of each path + const std::vector kTheta = {0.0, 0.01, 0.5, 0.7, 1.0}; + for (const auto theta : kTheta) { + std::vector strings = {"ABC", "AB C", "A BC", "A B C"}; + std::map probs; + probs["ABC"] = std::exp(theta * 1.0); + probs["AB C"] = std::exp(theta * (0.2 + 0.1)); + probs["A BC"] = std::exp(theta * (0.0 + 0.5)); + probs["A B C"] = std::exp(theta * (0.0 + 0.0 + 0.1)); + + double Z = 0.0; + for (const auto &it : probs) Z += it.second; + for (auto &it : probs) it.second /= Z; + + for (const auto &it : strings) { + EXPECT_EQ(1, probs.count(it)); + } + float entropy = 0.0; + for (const auto &it : probs) { + entropy += (it.second * std::log(it.second)); + } + EXPECT_NEAR(-entropy, lattice.CalculateEntropy(theta), 0.02); + } +} + +TEST(LatticeTest, ForwardAlgorithmTest) { + Lattice lattice; + lattice.SetSentence("ABC"); + + InsertWithScore(&lattice, 0, 1, 0.0); // A + InsertWithScore(&lattice, 1, 1, 0.0); // B + InsertWithScore(&lattice, 2, 1, 0.1); // C + InsertWithScore(&lattice, 0, 2, 0.2); // AB + InsertWithScore(&lattice, 1, 2, 0.5); // BC + InsertWithScore(&lattice, 0, 3, 1.0); // ABC + + const std::vector kTheta = {0.0, 0.01, 0.5, 0.7, 1.0}; + for (const auto theta : kTheta) { + std::vector alpha = lattice.ForwardAlgorithm(theta); + EXPECT_EQ(alpha.size(), 8); // 6 nodes, plus BOS, EOS + // only alpha[C], alpha[EOS] have non-zero alpha + for (int i : {0, 1, 2, 3}) { + for (const auto &node : lattice.begin_nodes(i)) { + if (i < 2) { + EXPECT_EQ(alpha[node->node_id], 0.0); + } else if (i == 2) { + float Z = + std::log(std::exp(theta * (0.0 + 0.0)) + std::exp(theta * 0.2)); + EXPECT_EQ(alpha[node->node_id], Z); + } else if (i == 3) { + float Z = std::log(std::exp(theta * (0.0 + 0.0 + 0.1)) + // A + B + C + std::exp(theta * (0.2 + 0.1)) + // AB + C + std::exp(theta * (0.0 + 0.5)) + // A + BC + std::exp(theta * 1.0)); // ABC + EXPECT_EQ(Z, alpha[node->node_id]); + } + } + } + } +} + TEST(LatticeTest, PopulateMarginalTest) { Lattice lattice; lattice.SetSentence("ABC"); @@ -361,6 +514,102 @@ TEST(UnigramModelTest, SetUnigramModelTest) { model.model_proto().SerializeAsString()); } +TEST(UnigramModelTest, SampleEncodeAndScoreTest) { + // Test whether inclusion probabilities are correct + ModelProto model_proto = MakeBaseModelProto(); + AddPiece(&model_proto, "A", 0.0); // 3 + AddPiece(&model_proto, "B", 0.0); // 4 + AddPiece(&model_proto, "C", 0.1); // 5 + AddPiece(&model_proto, "AB", 0.2); // 6 + AddPiece(&model_proto, "BC", 0.5); // 7 + AddPiece(&model_proto, "ABC", 1.0); // 8 + + Model model(model_proto); + + Lattice lattice; + lattice.SetSentence("ABC"); + model.PopulateNodes(&lattice); + + std::vector kTheta = {0.0, 1.0}; + + for (const auto theta : kTheta) { + std::vector strings = {"ABC", "AB C", "A BC", "A B C"}; + std::map probs; + probs["ABC"] = std::exp(theta * 1.0); + probs["AB C"] = std::exp(theta * (0.2 + 0.1)); + probs["A BC"] = std::exp(theta * (0.0 + 0.5)); + probs["A B C"] = std::exp(theta * (0.0 + 0.0 + 
0.1)); + + for (const auto &it : strings) { + EXPECT_EQ(1, probs.count(it)); + } + + double Z = 0.0; + for (const auto &it : probs) Z += it.second; + for (auto &it : probs) it.second /= Z; + + std::map, float> pair_probs; + for (const auto first : strings) { + for (const auto second : strings) { + if (first == second) { + pair_probs[std::make_pair(first, second)] = 0; + } else { + float first_prob = probs[first]; + float second_prob = probs[second] / (1 - first_prob); + pair_probs[std::make_pair(first, second)] = first_prob * second_prob; + } + } + } + + std::map inclusion_probs; + for (const auto string : strings) { + float inclusion_prob = 0.0; + for (const auto other_string : strings) { + inclusion_prob += pair_probs[std::make_pair(string, other_string)]; + } + for (const auto other_string : strings) { + inclusion_prob += pair_probs[std::make_pair(other_string, string)]; + } + inclusion_probs[string] = inclusion_prob / 2; + } + std::vector kNumSamples = {1, 2}; + + for (const auto num_samples : kNumSamples) { + std::map counts; + std::map scores; + int kTrials = 50000; + for (int i = 0; i < kTrials; i++) { + NBestEncodeResult sample = + model.SampleEncodeAndScore("ABC", theta, num_samples, true, false); + + for (const auto &it : sample) { + std::vector tokens; + for (const auto &inner_it : it.first) { + tokens.push_back(std::string(inner_it.first)); + } + std::string sample_string = absl::StrJoin(tokens, " "); + counts[sample_string] += 1; + // use the fact that E(1_{i in sample} / score of i) = 1 + // see https://arxiv.org/pdf/1903.06059.pdf appendix D + scores[sample_string] += std::exp(-it.second); + } + } + + // Check that counts and probs are correct + std::map probs_to_use = + (num_samples == 1 ? probs : inclusion_probs); + + for (const auto &it : scores) Z += it.second; + for (const auto &it : probs_to_use) { + EXPECT_NEAR(it.second, 1.0 * counts[it.first] / (kTrials * num_samples), + 0.02); + // The expectation is quite loose, use a higher tolerance + EXPECT_NEAR(1.0, scores[it.first] / kTrials, 0.05); + } + } + } +} + TEST_P(UnigramModelTest, PieceToIdTest) { ModelProto model_proto = MakeBaseModelProto(); diff --git a/src/unigram_model_trainer.cc b/src/unigram_model_trainer.cc index f2afc32..9615040 100644 --- a/src/unigram_model_trainer.cc +++ b/src/unigram_model_trainer.cc @@ -223,7 +223,7 @@ std::vector Trainer::RunEStep(const TrainerModel &model, float *obj, lattice.SetSentence(w); model.PopulateNodes(&lattice); const float Z = lattice.PopulateMarginal(freq, &expected[n]); - ntokens[n] += lattice.Viterbi().size(); + ntokens[n] += lattice.Viterbi().first.size(); CHECK(!std::isnan(Z)) << "likelihood is NAN. Input sentence may be too long"; objs[n] -= Z / all_sentence_freq; @@ -297,17 +297,17 @@ TrainerModel::SentencePieces Trainer::PruneSentencePieces( const auto &w = sentencepieces[i]; lattice.SetSentence(w.first); model.PopulateNodes(&lattice); - const auto nbests = lattice.NBest(2); + const auto nbests = lattice.NBest(2, false, 0.0); if (nbests.size() == 1) { // No second-best result is found. always keep this sentencepiece. always_keep[i] = true; continue; - } else if (nbests[0].size() >= 2) { + } else if (nbests[0].first.size() >= 2) { // Can safely remove this sentencepiece if its Viterbi path is split. 
always_keep[i] = false; - } else if (nbests[0].size() == 1) { + } else if (nbests[0].first.size() == 1) { always_keep[i] = true; - for (const auto *node : nbests[1]) { + for (const auto *node : nbests[1].first) { alternatives[i].push_back(node->id); } } @@ -339,7 +339,7 @@ TrainerModel::SentencePieces Trainer::PruneSentencePieces( lattice.SetSentence(w.first); model.PopulateNodes(&lattice); vsums[n] += w.second; - for (const auto *node : lattice.Viterbi()) { + for (const auto *node : lattice.Viterbi().first) { if (node->id >= 0) { freqs[n][node->id] += w.second; inverteds[n][node->id].push_back(i); diff --git a/src/util.cc b/src/util.cc index 9120673..8424448 100644 --- a/src/util.cc +++ b/src/util.cc @@ -12,11 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License.! -#include "util.h" - #include +#include "util.h" + namespace sentencepiece { + namespace { constexpr unsigned int kDefaultSeed = static_cast(-1); static unsigned int g_seed = kDefaultSeed; diff --git a/third_party/absl/flags/flag.cc b/third_party/absl/flags/flag.cc index 09ff78f..e7ac841 100644 --- a/third_party/absl/flags/flag.cc +++ b/third_party/absl/flags/flag.cc @@ -171,6 +171,7 @@ void Flag::set_value_as_str(const std::string &value_as_str) { template class Flag; template class Flag; +template class Flag; template class Flag; template class Flag; template class Flag;
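The `Flag` instantiation added at the very end presumably backs the retyped `--random_seed` flags in `spm_encode` and `spm_train` (now unsigned, with the all-ones default reserved for seeding from `std::random_device`). The same control is available in library code through `SetRandomGeneratorSeed`, which this patch also declares in the public header; a minimal sketch with a placeholder model path:

```cpp
#include <iostream>
#include <string>
#include <vector>

#include "sentencepiece_processor.h"

int main() {
  // A fixed seed makes subword-regularization sampling reproducible.
  sentencepiece::SetRandomGeneratorSeed(42);

  sentencepiece::SentencePieceProcessor sp;
  if (!sp.Load("m.model").ok()) return 1;  // hypothetical model file

  std::vector<std::string> pieces;
  if (sp.SampleEncode("New York", /*nbest_size=*/-1, /*alpha=*/0.1, &pieces)
          .ok()) {
    for (const auto &piece : pieces) std::cout << piece << " ";
    std::cout << "\n";
  }
  return 0;
}
```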