From 9f3ed99f5c68609fc94ac83a5c0aea9eb163f6d9 Mon Sep 17 00:00:00 2001 From: Taku Kudo Date: Wed, 25 May 2022 14:03:45 +0900 Subject: [PATCH] Sync internal to github. DP related features are added. --- VERSION.txt | 2 +- python/VERSION.txt | 2 +- src/bpe_model_trainer.h | 24 +- src/builder.cc | 2 +- src/builtin_pb/sentencepiece_model.pb.cc | 343 ++++++++++++++--------- src/builtin_pb/sentencepiece_model.pb.h | 297 ++++++++++++++------ src/freelist.h | 9 +- src/freelist_test.cc | 13 +- src/model_interface.h | 4 +- src/normalizer.cc | 6 +- src/normalizer.h | 1 - src/sentencepiece_model.proto | 11 + src/sentencepiece_processor.cc | 42 ++- src/sentencepiece_processor.h | 18 +- src/sentencepiece_processor_test.cc | 81 ++++++ src/spm_train_main.cc | 16 ++ src/trainer_interface.cc | 71 ++++- src/trainer_interface.h | 8 +- src/trainer_interface_test.cc | 6 +- src/unigram_model.cc | 88 +++++- src/unigram_model_test.cc | 29 +- src/unigram_model_trainer.h | 6 +- src/unigram_model_trainer_test.cc | 115 +++++++- third_party/absl/flags/flag.cc | 1 + 24 files changed, 898 insertions(+), 297 deletions(-) diff --git a/VERSION.txt b/VERSION.txt index c65d728..7cd2918 100644 --- a/VERSION.txt +++ b/VERSION.txt @@ -1 +1 @@ -0.1.96 +0.1.97 diff --git a/python/VERSION.txt b/python/VERSION.txt index c65d728..7cd2918 100644 --- a/python/VERSION.txt +++ b/python/VERSION.txt @@ -1 +1 @@ -0.1.96 +0.1.97 diff --git a/src/bpe_model_trainer.h b/src/bpe_model_trainer.h index e011a37..2fdfb9c 100644 --- a/src/bpe_model_trainer.h +++ b/src/bpe_model_trainer.h @@ -15,6 +15,8 @@ #ifndef BPE_MODEL_TRAINER_H_ #define BPE_MODEL_TRAINER_H_ +#include +#include #include #include #include @@ -44,12 +46,12 @@ class Trainer : public TrainerInterface { const Symbol *right; // right symbol in bigram string_util::UnicodeText chars; // all flattend chracter sequence bool is_unk; // true if this symbol is unknown. - uint64 fp; // fingerprint of this symbol. - uint64 freq; // frequency of this symbol. + uint64_t fp; // fingerprint of this symbol. + uint64_t freq; // frequency of this symbol. // Position list. Use set so that we can keep the order of occurrence. // See EncodePos/DecodePos. - std::set positions; + std::set positions; bool IsBigram() const { return left != nullptr && right != nullptr; } std::string ToString() const; @@ -62,19 +64,19 @@ class Trainer : public TrainerInterface { int right; // right symbol index }; - // Encodes sid, left and right bigram index into uint64. + // Encodes sid, left and right bigram index into uint64_t. // Encoded value keeps the order of sid, left and right. - static uint64 EncodePos(int sid, int l, int r) { + static uint64_t EncodePos(int sid, int l, int r) { CHECK_GE(l, 0); CHECK_GE(r, 0); - CHECK_LE(l, kuint16max); - CHECK_LE(r, kuint16max); - const uint64 n = (static_cast(sid) << 32 | (l << 16 | r)); + CHECK_LE(l, std::numeric_limits::max()); + CHECK_LE(r, std::numeric_limits::max()); + const uint64_t n = (static_cast(sid) << 32 | (l << 16 | r)); return n; } - // Decodes sid, left and right bigram index from uint64. - static Position DecodePos(uint64 n) { + // Decodes sid, left and right bigram index from uint64_t. + static Position DecodePos(uint64_t n) { Position p; p.sid = n >> 32; p.left = (n >> 16) & 0xffff; @@ -111,7 +113,7 @@ class Trainer : public TrainerInterface { void UpdateActiveSymbols(); // All unique symbols. Key is a fingerprint of Symbol. - absl::flat_hash_map symbols_cache_; + absl::flat_hash_map symbols_cache_; // Set of symbols from which we find the best symbol in each iteration. std::set active_symbols_; diff --git a/src/builder.cc b/src/builder.cc index 378aaa0..58668f6 100644 --- a/src/builder.cc +++ b/src/builder.cc @@ -367,7 +367,7 @@ util::Status Builder::BuildNmtNFKCMap(CharsMap *chars_map) { nfkc_map[{0xFEFF}] = {0x20}; // ZERO WIDTH NO-BREAK nfkc_map[{0xFFFD}] = {0x20}; // REPLACEMENT CHARACTER nfkc_map[{0x200C}] = {0x20}; // ZERO WIDTH NON-JOINER -// nfkc_map[{0x200D}] = {0x20}; // ZERO WIDTH JOINER + // nfkc_map[{0x200D}] = {0x20}; // ZERO WIDTH JOINER // Ascii Control characters nfkc_map[{0x0001}] = {}; diff --git a/src/builtin_pb/sentencepiece_model.pb.cc b/src/builtin_pb/sentencepiece_model.pb.cc index e913731..a844938 100644 --- a/src/builtin_pb/sentencepiece_model.pb.cc +++ b/src/builtin_pb/sentencepiece_model.pb.cc @@ -285,22 +285,31 @@ class TrainerSpec::_Internal { (*has_bits)[0] |= 1u; } static void set_has_model_type(HasBits* has_bits) { - (*has_bits)[0] |= 524288u; + (*has_bits)[0] |= 4194304u; } static void set_has_vocab_size(HasBits* has_bits) { - (*has_bits)[0] |= 1048576u; + (*has_bits)[0] |= 8388608u; } static void set_has_self_test_sample_size(HasBits* has_bits) { (*has_bits)[0] |= 256u; } - static void set_has_character_coverage(HasBits* has_bits) { + static void set_has_enable_differential_privacy(HasBits* has_bits) { + (*has_bits)[0] |= 4096u; + } + static void set_has_differential_privacy_noise_level(HasBits* has_bits) { + (*has_bits)[0] |= 1048576u; + } + static void set_has_differential_privacy_clipping_threshold(HasBits* has_bits) { (*has_bits)[0] |= 2097152u; } + static void set_has_character_coverage(HasBits* has_bits) { + (*has_bits)[0] |= 16777216u; + } static void set_has_input_sentence_size(HasBits* has_bits) { (*has_bits)[0] |= 1024u; } static void set_has_shuffle_input_sentence(HasBits* has_bits) { - (*has_bits)[0] |= 268435456u; + (*has_bits)[0] |= 2147483648u; } static void set_has_mining_sentence_size(HasBits* has_bits) { (*has_bits)[0] |= 512u; @@ -309,67 +318,67 @@ class TrainerSpec::_Internal { (*has_bits)[0] |= 2048u; } static void set_has_seed_sentencepiece_size(HasBits* has_bits) { - (*has_bits)[0] |= 4194304u; - } - static void set_has_shrinking_factor(HasBits* has_bits) { - (*has_bits)[0] |= 8388608u; - } - static void set_has_max_sentence_length(HasBits* has_bits) { - (*has_bits)[0] |= 67108864u; - } - static void set_has_num_threads(HasBits* has_bits) { - (*has_bits)[0] |= 16777216u; - } - static void set_has_num_sub_iterations(HasBits* has_bits) { (*has_bits)[0] |= 33554432u; } - static void set_has_max_sentencepiece_length(HasBits* has_bits) { - (*has_bits)[0] |= 134217728u; + static void set_has_shrinking_factor(HasBits* has_bits) { + (*has_bits)[0] |= 67108864u; } - static void set_has_split_by_unicode_script(HasBits* has_bits) { + static void set_has_max_sentence_length(HasBits* has_bits) { (*has_bits)[0] |= 536870912u; } - static void set_has_split_by_number(HasBits* has_bits) { + static void set_has_num_threads(HasBits* has_bits) { + (*has_bits)[0] |= 134217728u; + } + static void set_has_num_sub_iterations(HasBits* has_bits) { + (*has_bits)[0] |= 268435456u; + } + static void set_has_max_sentencepiece_length(HasBits* has_bits) { (*has_bits)[0] |= 1073741824u; } + static void set_has_split_by_unicode_script(HasBits* has_bits) { + (*has_bits)[1] |= 1u; + } + static void set_has_split_by_number(HasBits* has_bits) { + (*has_bits)[1] |= 2u; + } static void set_has_split_by_whitespace(HasBits* has_bits) { - (*has_bits)[0] |= 2147483648u; + (*has_bits)[1] |= 4u; } static void set_has_treat_whitespace_as_suffix(HasBits* has_bits) { - (*has_bits)[0] |= 4096u; - } - static void set_has_allow_whitespace_only_pieces(HasBits* has_bits) { (*has_bits)[0] |= 8192u; } - static void set_has_split_digits(HasBits* has_bits) { + static void set_has_allow_whitespace_only_pieces(HasBits* has_bits) { (*has_bits)[0] |= 16384u; } + static void set_has_split_digits(HasBits* has_bits) { + (*has_bits)[0] |= 32768u; + } static void set_has_required_chars(HasBits* has_bits) { (*has_bits)[0] |= 4u; } static void set_has_byte_fallback(HasBits* has_bits) { - (*has_bits)[0] |= 32768u; + (*has_bits)[0] |= 65536u; } static void set_has_vocabulary_output_piece_score(HasBits* has_bits) { - (*has_bits)[1] |= 1u; + (*has_bits)[1] |= 8u; } static void set_has_hard_vocab_limit(HasBits* has_bits) { - (*has_bits)[1] |= 2u; + (*has_bits)[1] |= 16u; } static void set_has_use_all_vocab(HasBits* has_bits) { (*has_bits)[0] |= 131072u; } static void set_has_unk_id(HasBits* has_bits) { - (*has_bits)[0] |= 65536u; + (*has_bits)[0] |= 524288u; } static void set_has_bos_id(HasBits* has_bits) { - (*has_bits)[1] |= 4u; + (*has_bits)[1] |= 32u; } static void set_has_eos_id(HasBits* has_bits) { - (*has_bits)[1] |= 8u; + (*has_bits)[1] |= 64u; } static void set_has_pad_id(HasBits* has_bits) { - (*has_bits)[1] |= 16u; + (*has_bits)[1] |= 128u; } static void set_has_unk_piece(HasBits* has_bits) { (*has_bits)[0] |= 16u; @@ -474,8 +483,8 @@ void TrainerSpec::SharedCtor() { pad_piece_.UnsafeSetDefault(nullptr); ::memset(reinterpret_cast(this) + static_cast( reinterpret_cast(&self_test_sample_size_) - reinterpret_cast(this)), - 0, static_cast(reinterpret_cast(&train_extremely_large_corpus_) - - reinterpret_cast(&self_test_sample_size_)) + sizeof(train_extremely_large_corpus_)); + 0, static_cast(reinterpret_cast(&differential_privacy_clipping_threshold_) - + reinterpret_cast(&self_test_sample_size_)) + sizeof(differential_privacy_clipping_threshold_)); model_type_ = 1; vocab_size_ = 8000; character_coverage_ = 0.9995f; @@ -569,31 +578,31 @@ void TrainerSpec::Clear() { } if (cached_has_bits & 0x0000ff00u) { ::memset(&self_test_sample_size_, 0, static_cast( - reinterpret_cast(&byte_fallback_) - - reinterpret_cast(&self_test_sample_size_)) + sizeof(byte_fallback_)); + reinterpret_cast(&split_digits_) - + reinterpret_cast(&self_test_sample_size_)) + sizeof(split_digits_)); } if (cached_has_bits & 0x00ff0000u) { - ::memset(&unk_id_, 0, static_cast( - reinterpret_cast(&train_extremely_large_corpus_) - - reinterpret_cast(&unk_id_)) + sizeof(train_extremely_large_corpus_)); + ::memset(&byte_fallback_, 0, static_cast( + reinterpret_cast(&differential_privacy_clipping_threshold_) - + reinterpret_cast(&byte_fallback_)) + sizeof(differential_privacy_clipping_threshold_)); model_type_ = 1; vocab_size_ = 8000; + } + if (cached_has_bits & 0xff000000u) { character_coverage_ = 0.9995f; seed_sentencepiece_size_ = 1000000; shrinking_factor_ = 0.75f; - } - if (cached_has_bits & 0xff000000u) { num_threads_ = 16; num_sub_iterations_ = 2; max_sentence_length_ = 4192; max_sentencepiece_length_ = 16; shuffle_input_sentence_ = true; + } + cached_has_bits = _has_bits_[1]; + if (cached_has_bits & 0x000000ffu) { split_by_unicode_script_ = true; split_by_number_ = true; split_by_whitespace_ = true; - } - cached_has_bits = _has_bits_[1]; - if (cached_has_bits & 0x0000001fu) { vocabulary_output_piece_score_ = true; hard_vocab_limit_ = true; bos_id_ = 1; @@ -963,6 +972,30 @@ const char* TrainerSpec::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID CHK_(ptr); } else goto handle_unusual; continue; + // optional bool enable_differential_privacy = 50 [default = false]; + case 50: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 144)) { + _Internal::set_has_enable_differential_privacy(&_has_bits_); + enable_differential_privacy_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional float differential_privacy_noise_level = 51 [default = 0]; + case 51: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 157)) { + _Internal::set_has_differential_privacy_noise_level(&_has_bits_); + differential_privacy_noise_level_ = ::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr); + ptr += sizeof(float); + } else goto handle_unusual; + continue; + // optional uint64 differential_privacy_clipping_threshold = 52 [default = 0]; + case 52: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 160)) { + _Internal::set_has_differential_privacy_clipping_threshold(&_has_bits_); + differential_privacy_clipping_threshold_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; default: { handle_unusual: if ((tag & 7) == 4 || tag == 0) { @@ -1011,14 +1044,14 @@ failure: } // optional .sentencepiece.TrainerSpec.ModelType model_type = 3 [default = UNIGRAM]; - if (cached_has_bits & 0x00080000u) { + if (cached_has_bits & 0x00400000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteEnumToArray( 3, this->_internal_model_type(), target); } // optional int32 vocab_size = 4 [default = 8000]; - if (cached_has_bits & 0x00100000u) { + if (cached_has_bits & 0x00800000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(4, this->_internal_vocab_size(), target); } @@ -1042,7 +1075,7 @@ failure: } // optional float character_coverage = 10 [default = 0.9995]; - if (cached_has_bits & 0x00200000u) { + if (cached_has_bits & 0x01000000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteFloatToArray(10, this->_internal_character_coverage(), target); } @@ -1066,79 +1099,81 @@ failure: } // optional int32 seed_sentencepiece_size = 14 [default = 1000000]; - if (cached_has_bits & 0x00400000u) { + if (cached_has_bits & 0x02000000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(14, this->_internal_seed_sentencepiece_size(), target); } // optional float shrinking_factor = 15 [default = 0.75]; - if (cached_has_bits & 0x00800000u) { + if (cached_has_bits & 0x04000000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteFloatToArray(15, this->_internal_shrinking_factor(), target); } // optional int32 num_threads = 16 [default = 16]; - if (cached_has_bits & 0x01000000u) { + if (cached_has_bits & 0x08000000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(16, this->_internal_num_threads(), target); } // optional int32 num_sub_iterations = 17 [default = 2]; - if (cached_has_bits & 0x02000000u) { + if (cached_has_bits & 0x10000000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(17, this->_internal_num_sub_iterations(), target); } // optional int32 max_sentence_length = 18 [default = 4192]; - if (cached_has_bits & 0x04000000u) { + if (cached_has_bits & 0x20000000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(18, this->_internal_max_sentence_length(), target); } // optional bool shuffle_input_sentence = 19 [default = true]; - if (cached_has_bits & 0x10000000u) { + if (cached_has_bits & 0x80000000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(19, this->_internal_shuffle_input_sentence(), target); } // optional int32 max_sentencepiece_length = 20 [default = 16]; - if (cached_has_bits & 0x08000000u) { + if (cached_has_bits & 0x40000000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(20, this->_internal_max_sentencepiece_length(), target); } + cached_has_bits = _has_bits_[1]; // optional bool split_by_unicode_script = 21 [default = true]; - if (cached_has_bits & 0x20000000u) { + if (cached_has_bits & 0x00000001u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(21, this->_internal_split_by_unicode_script(), target); } // optional bool split_by_whitespace = 22 [default = true]; - if (cached_has_bits & 0x80000000u) { + if (cached_has_bits & 0x00000004u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(22, this->_internal_split_by_whitespace(), target); } // optional bool split_by_number = 23 [default = true]; - if (cached_has_bits & 0x40000000u) { + if (cached_has_bits & 0x00000002u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(23, this->_internal_split_by_number(), target); } + cached_has_bits = _has_bits_[0]; // optional bool treat_whitespace_as_suffix = 24 [default = false]; - if (cached_has_bits & 0x00001000u) { + if (cached_has_bits & 0x00002000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(24, this->_internal_treat_whitespace_as_suffix(), target); } // optional bool split_digits = 25 [default = false]; - if (cached_has_bits & 0x00004000u) { + if (cached_has_bits & 0x00008000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(25, this->_internal_split_digits(), target); } // optional bool allow_whitespace_only_pieces = 26 [default = false]; - if (cached_has_bits & 0x00002000u) { + if (cached_has_bits & 0x00004000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(26, this->_internal_allow_whitespace_only_pieces(), target); } @@ -1157,13 +1192,13 @@ failure: cached_has_bits = _has_bits_[1]; // optional bool vocabulary_output_piece_score = 32 [default = true]; - if (cached_has_bits & 0x00000001u) { + if (cached_has_bits & 0x00000008u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(32, this->_internal_vocabulary_output_piece_score(), target); } // optional bool hard_vocab_limit = 33 [default = true]; - if (cached_has_bits & 0x00000002u) { + if (cached_has_bits & 0x00000010u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(33, this->_internal_hard_vocab_limit(), target); } @@ -1176,7 +1211,7 @@ failure: } // optional bool byte_fallback = 35 [default = false]; - if (cached_has_bits & 0x00008000u) { + if (cached_has_bits & 0x00010000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(35, this->_internal_byte_fallback(), target); } @@ -1188,26 +1223,26 @@ failure: } // optional int32 unk_id = 40 [default = 0]; - if (cached_has_bits & 0x00010000u) { + if (cached_has_bits & 0x00080000u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(40, this->_internal_unk_id(), target); } cached_has_bits = _has_bits_[1]; // optional int32 bos_id = 41 [default = 1]; - if (cached_has_bits & 0x00000004u) { + if (cached_has_bits & 0x00000020u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(41, this->_internal_bos_id(), target); } // optional int32 eos_id = 42 [default = 2]; - if (cached_has_bits & 0x00000008u) { + if (cached_has_bits & 0x00000040u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(42, this->_internal_eos_id(), target); } // optional int32 pad_id = 43 [default = -1]; - if (cached_has_bits & 0x00000010u) { + if (cached_has_bits & 0x00000080u) { target = stream->EnsureSpace(target); target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(43, this->_internal_pad_id(), target); } @@ -1249,6 +1284,24 @@ failure: target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(49, this->_internal_train_extremely_large_corpus(), target); } + // optional bool enable_differential_privacy = 50 [default = false]; + if (cached_has_bits & 0x00001000u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(50, this->_internal_enable_differential_privacy(), target); + } + + // optional float differential_privacy_noise_level = 51 [default = 0]; + if (cached_has_bits & 0x00100000u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteFloatToArray(51, this->_internal_differential_privacy_noise_level(), target); + } + + // optional uint64 differential_privacy_clipping_threshold = 52 [default = 0]; + if (cached_has_bits & 0x00200000u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteUInt64ToArray(52, this->_internal_differential_privacy_clipping_threshold(), target); + } + // Extension range [200, 536870912) target = _extensions_._InternalSerialize( 200, 536870912, target, stream); @@ -1391,33 +1444,31 @@ size_t TrainerSpec::ByteSizeLong() const { this->_internal_training_sentence_size()); } - // optional bool treat_whitespace_as_suffix = 24 [default = false]; + // optional bool enable_differential_privacy = 50 [default = false]; if (cached_has_bits & 0x00001000u) { total_size += 2 + 1; } - // optional bool allow_whitespace_only_pieces = 26 [default = false]; + // optional bool treat_whitespace_as_suffix = 24 [default = false]; if (cached_has_bits & 0x00002000u) { total_size += 2 + 1; } - // optional bool split_digits = 25 [default = false]; + // optional bool allow_whitespace_only_pieces = 26 [default = false]; if (cached_has_bits & 0x00004000u) { total_size += 2 + 1; } - // optional bool byte_fallback = 35 [default = false]; + // optional bool split_digits = 25 [default = false]; if (cached_has_bits & 0x00008000u) { total_size += 2 + 1; } } if (cached_has_bits & 0x00ff0000u) { - // optional int32 unk_id = 40 [default = 0]; + // optional bool byte_fallback = 35 [default = false]; if (cached_has_bits & 0x00010000u) { - total_size += 2 + - ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( - this->_internal_unk_id()); + total_size += 2 + 1; } // optional bool use_all_vocab = 34 [default = false]; @@ -1430,115 +1481,134 @@ size_t TrainerSpec::ByteSizeLong() const { total_size += 2 + 1; } - // optional .sentencepiece.TrainerSpec.ModelType model_type = 3 [default = UNIGRAM]; + // optional int32 unk_id = 40 [default = 0]; if (cached_has_bits & 0x00080000u) { + total_size += 2 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( + this->_internal_unk_id()); + } + + // optional float differential_privacy_noise_level = 51 [default = 0]; + if (cached_has_bits & 0x00100000u) { + total_size += 2 + 4; + } + + // optional uint64 differential_privacy_clipping_threshold = 52 [default = 0]; + if (cached_has_bits & 0x00200000u) { + total_size += 2 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::UInt64Size( + this->_internal_differential_privacy_clipping_threshold()); + } + + // optional .sentencepiece.TrainerSpec.ModelType model_type = 3 [default = UNIGRAM]; + if (cached_has_bits & 0x00400000u) { total_size += 1 + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::EnumSize(this->_internal_model_type()); } // optional int32 vocab_size = 4 [default = 8000]; - if (cached_has_bits & 0x00100000u) { + if (cached_has_bits & 0x00800000u) { total_size += 1 + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( this->_internal_vocab_size()); } + } + if (cached_has_bits & 0xff000000u) { // optional float character_coverage = 10 [default = 0.9995]; - if (cached_has_bits & 0x00200000u) { + if (cached_has_bits & 0x01000000u) { total_size += 1 + 4; } // optional int32 seed_sentencepiece_size = 14 [default = 1000000]; - if (cached_has_bits & 0x00400000u) { + if (cached_has_bits & 0x02000000u) { total_size += 1 + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( this->_internal_seed_sentencepiece_size()); } // optional float shrinking_factor = 15 [default = 0.75]; - if (cached_has_bits & 0x00800000u) { + if (cached_has_bits & 0x04000000u) { total_size += 1 + 4; } - } - if (cached_has_bits & 0xff000000u) { // optional int32 num_threads = 16 [default = 16]; - if (cached_has_bits & 0x01000000u) { + if (cached_has_bits & 0x08000000u) { total_size += 2 + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( this->_internal_num_threads()); } // optional int32 num_sub_iterations = 17 [default = 2]; - if (cached_has_bits & 0x02000000u) { + if (cached_has_bits & 0x10000000u) { total_size += 2 + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( this->_internal_num_sub_iterations()); } // optional int32 max_sentence_length = 18 [default = 4192]; - if (cached_has_bits & 0x04000000u) { + if (cached_has_bits & 0x20000000u) { total_size += 2 + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( this->_internal_max_sentence_length()); } // optional int32 max_sentencepiece_length = 20 [default = 16]; - if (cached_has_bits & 0x08000000u) { + if (cached_has_bits & 0x40000000u) { total_size += 2 + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( this->_internal_max_sentencepiece_length()); } // optional bool shuffle_input_sentence = 19 [default = true]; - if (cached_has_bits & 0x10000000u) { - total_size += 2 + 1; - } - - // optional bool split_by_unicode_script = 21 [default = true]; - if (cached_has_bits & 0x20000000u) { - total_size += 2 + 1; - } - - // optional bool split_by_number = 23 [default = true]; - if (cached_has_bits & 0x40000000u) { - total_size += 2 + 1; - } - - // optional bool split_by_whitespace = 22 [default = true]; if (cached_has_bits & 0x80000000u) { total_size += 2 + 1; } } cached_has_bits = _has_bits_[1]; - if (cached_has_bits & 0x0000001fu) { - // optional bool vocabulary_output_piece_score = 32 [default = true]; + if (cached_has_bits & 0x000000ffu) { + // optional bool split_by_unicode_script = 21 [default = true]; if (cached_has_bits & 0x00000001u) { total_size += 2 + 1; } - // optional bool hard_vocab_limit = 33 [default = true]; + // optional bool split_by_number = 23 [default = true]; if (cached_has_bits & 0x00000002u) { total_size += 2 + 1; } - // optional int32 bos_id = 41 [default = 1]; + // optional bool split_by_whitespace = 22 [default = true]; if (cached_has_bits & 0x00000004u) { + total_size += 2 + 1; + } + + // optional bool vocabulary_output_piece_score = 32 [default = true]; + if (cached_has_bits & 0x00000008u) { + total_size += 2 + 1; + } + + // optional bool hard_vocab_limit = 33 [default = true]; + if (cached_has_bits & 0x00000010u) { + total_size += 2 + 1; + } + + // optional int32 bos_id = 41 [default = 1]; + if (cached_has_bits & 0x00000020u) { total_size += 2 + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( this->_internal_bos_id()); } // optional int32 eos_id = 42 [default = 2]; - if (cached_has_bits & 0x00000008u) { + if (cached_has_bits & 0x00000040u) { total_size += 2 + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( this->_internal_eos_id()); } // optional int32 pad_id = 43 [default = -1]; - if (cached_has_bits & 0x00000010u) { + if (cached_has_bits & 0x00000080u) { total_size += 2 + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( this->_internal_pad_id()); @@ -1612,22 +1682,22 @@ void TrainerSpec::MergeFrom(const TrainerSpec& from) { training_sentence_size_ = from.training_sentence_size_; } if (cached_has_bits & 0x00001000u) { - treat_whitespace_as_suffix_ = from.treat_whitespace_as_suffix_; + enable_differential_privacy_ = from.enable_differential_privacy_; } if (cached_has_bits & 0x00002000u) { - allow_whitespace_only_pieces_ = from.allow_whitespace_only_pieces_; + treat_whitespace_as_suffix_ = from.treat_whitespace_as_suffix_; } if (cached_has_bits & 0x00004000u) { - split_digits_ = from.split_digits_; + allow_whitespace_only_pieces_ = from.allow_whitespace_only_pieces_; } if (cached_has_bits & 0x00008000u) { - byte_fallback_ = from.byte_fallback_; + split_digits_ = from.split_digits_; } _has_bits_[0] |= cached_has_bits; } if (cached_has_bits & 0x00ff0000u) { if (cached_has_bits & 0x00010000u) { - unk_id_ = from.unk_id_; + byte_fallback_ = from.byte_fallback_; } if (cached_has_bits & 0x00020000u) { use_all_vocab_ = from.use_all_vocab_; @@ -1636,64 +1706,73 @@ void TrainerSpec::MergeFrom(const TrainerSpec& from) { train_extremely_large_corpus_ = from.train_extremely_large_corpus_; } if (cached_has_bits & 0x00080000u) { - model_type_ = from.model_type_; + unk_id_ = from.unk_id_; } if (cached_has_bits & 0x00100000u) { - vocab_size_ = from.vocab_size_; + differential_privacy_noise_level_ = from.differential_privacy_noise_level_; } if (cached_has_bits & 0x00200000u) { - character_coverage_ = from.character_coverage_; + differential_privacy_clipping_threshold_ = from.differential_privacy_clipping_threshold_; } if (cached_has_bits & 0x00400000u) { - seed_sentencepiece_size_ = from.seed_sentencepiece_size_; + model_type_ = from.model_type_; } if (cached_has_bits & 0x00800000u) { - shrinking_factor_ = from.shrinking_factor_; + vocab_size_ = from.vocab_size_; } _has_bits_[0] |= cached_has_bits; } if (cached_has_bits & 0xff000000u) { if (cached_has_bits & 0x01000000u) { - num_threads_ = from.num_threads_; + character_coverage_ = from.character_coverage_; } if (cached_has_bits & 0x02000000u) { - num_sub_iterations_ = from.num_sub_iterations_; + seed_sentencepiece_size_ = from.seed_sentencepiece_size_; } if (cached_has_bits & 0x04000000u) { - max_sentence_length_ = from.max_sentence_length_; + shrinking_factor_ = from.shrinking_factor_; } if (cached_has_bits & 0x08000000u) { - max_sentencepiece_length_ = from.max_sentencepiece_length_; + num_threads_ = from.num_threads_; } if (cached_has_bits & 0x10000000u) { - shuffle_input_sentence_ = from.shuffle_input_sentence_; + num_sub_iterations_ = from.num_sub_iterations_; } if (cached_has_bits & 0x20000000u) { - split_by_unicode_script_ = from.split_by_unicode_script_; + max_sentence_length_ = from.max_sentence_length_; } if (cached_has_bits & 0x40000000u) { - split_by_number_ = from.split_by_number_; + max_sentencepiece_length_ = from.max_sentencepiece_length_; } if (cached_has_bits & 0x80000000u) { - split_by_whitespace_ = from.split_by_whitespace_; + shuffle_input_sentence_ = from.shuffle_input_sentence_; } _has_bits_[0] |= cached_has_bits; } cached_has_bits = from._has_bits_[1]; - if (cached_has_bits & 0x0000001fu) { + if (cached_has_bits & 0x000000ffu) { if (cached_has_bits & 0x00000001u) { - vocabulary_output_piece_score_ = from.vocabulary_output_piece_score_; + split_by_unicode_script_ = from.split_by_unicode_script_; } if (cached_has_bits & 0x00000002u) { - hard_vocab_limit_ = from.hard_vocab_limit_; + split_by_number_ = from.split_by_number_; } if (cached_has_bits & 0x00000004u) { - bos_id_ = from.bos_id_; + split_by_whitespace_ = from.split_by_whitespace_; } if (cached_has_bits & 0x00000008u) { - eos_id_ = from.eos_id_; + vocabulary_output_piece_score_ = from.vocabulary_output_piece_score_; } if (cached_has_bits & 0x00000010u) { + hard_vocab_limit_ = from.hard_vocab_limit_; + } + if (cached_has_bits & 0x00000020u) { + bos_id_ = from.bos_id_; + } + if (cached_has_bits & 0x00000040u) { + eos_id_ = from.eos_id_; + } + if (cached_has_bits & 0x00000080u) { pad_id_ = from.pad_id_; } _has_bits_[1] |= cached_has_bits; @@ -1734,8 +1813,8 @@ void TrainerSpec::InternalSwap(TrainerSpec* other) { eos_piece_.Swap(&other->eos_piece_, nullptr, GetArena()); pad_piece_.Swap(&other->pad_piece_, nullptr, GetArena()); ::PROTOBUF_NAMESPACE_ID::internal::memswap< - PROTOBUF_FIELD_OFFSET(TrainerSpec, train_extremely_large_corpus_) - + sizeof(TrainerSpec::train_extremely_large_corpus_) + PROTOBUF_FIELD_OFFSET(TrainerSpec, differential_privacy_clipping_threshold_) + + sizeof(TrainerSpec::differential_privacy_clipping_threshold_) - PROTOBUF_FIELD_OFFSET(TrainerSpec, self_test_sample_size_)>( reinterpret_cast(&self_test_sample_size_), reinterpret_cast(&other->self_test_sample_size_)); diff --git a/src/builtin_pb/sentencepiece_model.pb.h b/src/builtin_pb/sentencepiece_model.pb.h index f527aa7..84450e6 100644 --- a/src/builtin_pb/sentencepiece_model.pb.h +++ b/src/builtin_pb/sentencepiece_model.pb.h @@ -277,13 +277,16 @@ class TrainerSpec PROTOBUF_FINAL : kMiningSentenceSizeFieldNumber = 12, kInputSentenceSizeFieldNumber = 11, kTrainingSentenceSizeFieldNumber = 13, + kEnableDifferentialPrivacyFieldNumber = 50, kTreatWhitespaceAsSuffixFieldNumber = 24, kAllowWhitespaceOnlyPiecesFieldNumber = 26, kSplitDigitsFieldNumber = 25, kByteFallbackFieldNumber = 35, - kUnkIdFieldNumber = 40, kUseAllVocabFieldNumber = 34, kTrainExtremelyLargeCorpusFieldNumber = 49, + kUnkIdFieldNumber = 40, + kDifferentialPrivacyNoiseLevelFieldNumber = 51, + kDifferentialPrivacyClippingThresholdFieldNumber = 52, kModelTypeFieldNumber = 3, kVocabSizeFieldNumber = 4, kCharacterCoverageFieldNumber = 10, @@ -611,6 +614,19 @@ class TrainerSpec PROTOBUF_FINAL : void _internal_set_training_sentence_size(::PROTOBUF_NAMESPACE_ID::int32 value); public: + // optional bool enable_differential_privacy = 50 [default = false]; + bool has_enable_differential_privacy() const; + private: + bool _internal_has_enable_differential_privacy() const; + public: + void clear_enable_differential_privacy(); + bool enable_differential_privacy() const; + void set_enable_differential_privacy(bool value); + private: + bool _internal_enable_differential_privacy() const; + void _internal_set_enable_differential_privacy(bool value); + public: + // optional bool treat_whitespace_as_suffix = 24 [default = false]; bool has_treat_whitespace_as_suffix() const; private: @@ -663,19 +679,6 @@ class TrainerSpec PROTOBUF_FINAL : void _internal_set_byte_fallback(bool value); public: - // optional int32 unk_id = 40 [default = 0]; - bool has_unk_id() const; - private: - bool _internal_has_unk_id() const; - public: - void clear_unk_id(); - ::PROTOBUF_NAMESPACE_ID::int32 unk_id() const; - void set_unk_id(::PROTOBUF_NAMESPACE_ID::int32 value); - private: - ::PROTOBUF_NAMESPACE_ID::int32 _internal_unk_id() const; - void _internal_set_unk_id(::PROTOBUF_NAMESPACE_ID::int32 value); - public: - // optional bool use_all_vocab = 34 [default = false]; bool has_use_all_vocab() const; private: @@ -702,6 +705,45 @@ class TrainerSpec PROTOBUF_FINAL : void _internal_set_train_extremely_large_corpus(bool value); public: + // optional int32 unk_id = 40 [default = 0]; + bool has_unk_id() const; + private: + bool _internal_has_unk_id() const; + public: + void clear_unk_id(); + ::PROTOBUF_NAMESPACE_ID::int32 unk_id() const; + void set_unk_id(::PROTOBUF_NAMESPACE_ID::int32 value); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_unk_id() const; + void _internal_set_unk_id(::PROTOBUF_NAMESPACE_ID::int32 value); + public: + + // optional float differential_privacy_noise_level = 51 [default = 0]; + bool has_differential_privacy_noise_level() const; + private: + bool _internal_has_differential_privacy_noise_level() const; + public: + void clear_differential_privacy_noise_level(); + float differential_privacy_noise_level() const; + void set_differential_privacy_noise_level(float value); + private: + float _internal_differential_privacy_noise_level() const; + void _internal_set_differential_privacy_noise_level(float value); + public: + + // optional uint64 differential_privacy_clipping_threshold = 52 [default = 0]; + bool has_differential_privacy_clipping_threshold() const; + private: + bool _internal_has_differential_privacy_clipping_threshold() const; + public: + void clear_differential_privacy_clipping_threshold(); + ::PROTOBUF_NAMESPACE_ID::uint64 differential_privacy_clipping_threshold() const; + void set_differential_privacy_clipping_threshold(::PROTOBUF_NAMESPACE_ID::uint64 value); + private: + ::PROTOBUF_NAMESPACE_ID::uint64 _internal_differential_privacy_clipping_threshold() const; + void _internal_set_differential_privacy_clipping_threshold(::PROTOBUF_NAMESPACE_ID::uint64 value); + public: + // optional .sentencepiece.TrainerSpec.ModelType model_type = 3 [default = UNIGRAM]; bool has_model_type() const; private: @@ -969,13 +1011,16 @@ class TrainerSpec PROTOBUF_FINAL : ::PROTOBUF_NAMESPACE_ID::int32 mining_sentence_size_; ::PROTOBUF_NAMESPACE_ID::uint64 input_sentence_size_; ::PROTOBUF_NAMESPACE_ID::int32 training_sentence_size_; + bool enable_differential_privacy_; bool treat_whitespace_as_suffix_; bool allow_whitespace_only_pieces_; bool split_digits_; bool byte_fallback_; - ::PROTOBUF_NAMESPACE_ID::int32 unk_id_; bool use_all_vocab_; bool train_extremely_large_corpus_; + ::PROTOBUF_NAMESPACE_ID::int32 unk_id_; + float differential_privacy_noise_level_; + ::PROTOBUF_NAMESPACE_ID::uint64 differential_privacy_clipping_threshold_; int model_type_; ::PROTOBUF_NAMESPACE_ID::int32 vocab_size_; float character_coverage_; @@ -2195,7 +2240,7 @@ inline void TrainerSpec::set_allocated_model_prefix(std::string* model_prefix) { // optional .sentencepiece.TrainerSpec.ModelType model_type = 3 [default = UNIGRAM]; inline bool TrainerSpec::_internal_has_model_type() const { - bool value = (_has_bits_[0] & 0x00080000u) != 0; + bool value = (_has_bits_[0] & 0x00400000u) != 0; return value; } inline bool TrainerSpec::has_model_type() const { @@ -2203,7 +2248,7 @@ inline bool TrainerSpec::has_model_type() const { } inline void TrainerSpec::clear_model_type() { model_type_ = 1; - _has_bits_[0] &= ~0x00080000u; + _has_bits_[0] &= ~0x00400000u; } inline ::sentencepiece::TrainerSpec_ModelType TrainerSpec::_internal_model_type() const { return static_cast< ::sentencepiece::TrainerSpec_ModelType >(model_type_); @@ -2214,7 +2259,7 @@ inline ::sentencepiece::TrainerSpec_ModelType TrainerSpec::model_type() const { } inline void TrainerSpec::_internal_set_model_type(::sentencepiece::TrainerSpec_ModelType value) { assert(::sentencepiece::TrainerSpec_ModelType_IsValid(value)); - _has_bits_[0] |= 0x00080000u; + _has_bits_[0] |= 0x00400000u; model_type_ = value; } inline void TrainerSpec::set_model_type(::sentencepiece::TrainerSpec_ModelType value) { @@ -2224,7 +2269,7 @@ inline void TrainerSpec::set_model_type(::sentencepiece::TrainerSpec_ModelType v // optional int32 vocab_size = 4 [default = 8000]; inline bool TrainerSpec::_internal_has_vocab_size() const { - bool value = (_has_bits_[0] & 0x00100000u) != 0; + bool value = (_has_bits_[0] & 0x00800000u) != 0; return value; } inline bool TrainerSpec::has_vocab_size() const { @@ -2232,7 +2277,7 @@ inline bool TrainerSpec::has_vocab_size() const { } inline void TrainerSpec::clear_vocab_size() { vocab_size_ = 8000; - _has_bits_[0] &= ~0x00100000u; + _has_bits_[0] &= ~0x00800000u; } inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_vocab_size() const { return vocab_size_; @@ -2242,7 +2287,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::vocab_size() const { return _internal_vocab_size(); } inline void TrainerSpec::_internal_set_vocab_size(::PROTOBUF_NAMESPACE_ID::int32 value) { - _has_bits_[0] |= 0x00100000u; + _has_bits_[0] |= 0x00800000u; vocab_size_ = value; } inline void TrainerSpec::set_vocab_size(::PROTOBUF_NAMESPACE_ID::int32 value) { @@ -2352,9 +2397,93 @@ inline void TrainerSpec::set_self_test_sample_size(::PROTOBUF_NAMESPACE_ID::int3 // @@protoc_insertion_point(field_set:sentencepiece.TrainerSpec.self_test_sample_size) } +// optional bool enable_differential_privacy = 50 [default = false]; +inline bool TrainerSpec::_internal_has_enable_differential_privacy() const { + bool value = (_has_bits_[0] & 0x00001000u) != 0; + return value; +} +inline bool TrainerSpec::has_enable_differential_privacy() const { + return _internal_has_enable_differential_privacy(); +} +inline void TrainerSpec::clear_enable_differential_privacy() { + enable_differential_privacy_ = false; + _has_bits_[0] &= ~0x00001000u; +} +inline bool TrainerSpec::_internal_enable_differential_privacy() const { + return enable_differential_privacy_; +} +inline bool TrainerSpec::enable_differential_privacy() const { + // @@protoc_insertion_point(field_get:sentencepiece.TrainerSpec.enable_differential_privacy) + return _internal_enable_differential_privacy(); +} +inline void TrainerSpec::_internal_set_enable_differential_privacy(bool value) { + _has_bits_[0] |= 0x00001000u; + enable_differential_privacy_ = value; +} +inline void TrainerSpec::set_enable_differential_privacy(bool value) { + _internal_set_enable_differential_privacy(value); + // @@protoc_insertion_point(field_set:sentencepiece.TrainerSpec.enable_differential_privacy) +} + +// optional float differential_privacy_noise_level = 51 [default = 0]; +inline bool TrainerSpec::_internal_has_differential_privacy_noise_level() const { + bool value = (_has_bits_[0] & 0x00100000u) != 0; + return value; +} +inline bool TrainerSpec::has_differential_privacy_noise_level() const { + return _internal_has_differential_privacy_noise_level(); +} +inline void TrainerSpec::clear_differential_privacy_noise_level() { + differential_privacy_noise_level_ = 0; + _has_bits_[0] &= ~0x00100000u; +} +inline float TrainerSpec::_internal_differential_privacy_noise_level() const { + return differential_privacy_noise_level_; +} +inline float TrainerSpec::differential_privacy_noise_level() const { + // @@protoc_insertion_point(field_get:sentencepiece.TrainerSpec.differential_privacy_noise_level) + return _internal_differential_privacy_noise_level(); +} +inline void TrainerSpec::_internal_set_differential_privacy_noise_level(float value) { + _has_bits_[0] |= 0x00100000u; + differential_privacy_noise_level_ = value; +} +inline void TrainerSpec::set_differential_privacy_noise_level(float value) { + _internal_set_differential_privacy_noise_level(value); + // @@protoc_insertion_point(field_set:sentencepiece.TrainerSpec.differential_privacy_noise_level) +} + +// optional uint64 differential_privacy_clipping_threshold = 52 [default = 0]; +inline bool TrainerSpec::_internal_has_differential_privacy_clipping_threshold() const { + bool value = (_has_bits_[0] & 0x00200000u) != 0; + return value; +} +inline bool TrainerSpec::has_differential_privacy_clipping_threshold() const { + return _internal_has_differential_privacy_clipping_threshold(); +} +inline void TrainerSpec::clear_differential_privacy_clipping_threshold() { + differential_privacy_clipping_threshold_ = PROTOBUF_ULONGLONG(0); + _has_bits_[0] &= ~0x00200000u; +} +inline ::PROTOBUF_NAMESPACE_ID::uint64 TrainerSpec::_internal_differential_privacy_clipping_threshold() const { + return differential_privacy_clipping_threshold_; +} +inline ::PROTOBUF_NAMESPACE_ID::uint64 TrainerSpec::differential_privacy_clipping_threshold() const { + // @@protoc_insertion_point(field_get:sentencepiece.TrainerSpec.differential_privacy_clipping_threshold) + return _internal_differential_privacy_clipping_threshold(); +} +inline void TrainerSpec::_internal_set_differential_privacy_clipping_threshold(::PROTOBUF_NAMESPACE_ID::uint64 value) { + _has_bits_[0] |= 0x00200000u; + differential_privacy_clipping_threshold_ = value; +} +inline void TrainerSpec::set_differential_privacy_clipping_threshold(::PROTOBUF_NAMESPACE_ID::uint64 value) { + _internal_set_differential_privacy_clipping_threshold(value); + // @@protoc_insertion_point(field_set:sentencepiece.TrainerSpec.differential_privacy_clipping_threshold) +} + // optional float character_coverage = 10 [default = 0.9995]; inline bool TrainerSpec::_internal_has_character_coverage() const { - bool value = (_has_bits_[0] & 0x00200000u) != 0; + bool value = (_has_bits_[0] & 0x01000000u) != 0; return value; } inline bool TrainerSpec::has_character_coverage() const { @@ -2362,7 +2491,7 @@ inline bool TrainerSpec::has_character_coverage() const { } inline void TrainerSpec::clear_character_coverage() { character_coverage_ = 0.9995f; - _has_bits_[0] &= ~0x00200000u; + _has_bits_[0] &= ~0x01000000u; } inline float TrainerSpec::_internal_character_coverage() const { return character_coverage_; @@ -2372,7 +2501,7 @@ inline float TrainerSpec::character_coverage() const { return _internal_character_coverage(); } inline void TrainerSpec::_internal_set_character_coverage(float value) { - _has_bits_[0] |= 0x00200000u; + _has_bits_[0] |= 0x01000000u; character_coverage_ = value; } inline void TrainerSpec::set_character_coverage(float value) { @@ -2410,7 +2539,7 @@ inline void TrainerSpec::set_input_sentence_size(::PROTOBUF_NAMESPACE_ID::uint64 // optional bool shuffle_input_sentence = 19 [default = true]; inline bool TrainerSpec::_internal_has_shuffle_input_sentence() const { - bool value = (_has_bits_[0] & 0x10000000u) != 0; + bool value = (_has_bits_[0] & 0x80000000u) != 0; return value; } inline bool TrainerSpec::has_shuffle_input_sentence() const { @@ -2418,7 +2547,7 @@ inline bool TrainerSpec::has_shuffle_input_sentence() const { } inline void TrainerSpec::clear_shuffle_input_sentence() { shuffle_input_sentence_ = true; - _has_bits_[0] &= ~0x10000000u; + _has_bits_[0] &= ~0x80000000u; } inline bool TrainerSpec::_internal_shuffle_input_sentence() const { return shuffle_input_sentence_; @@ -2428,7 +2557,7 @@ inline bool TrainerSpec::shuffle_input_sentence() const { return _internal_shuffle_input_sentence(); } inline void TrainerSpec::_internal_set_shuffle_input_sentence(bool value) { - _has_bits_[0] |= 0x10000000u; + _has_bits_[0] |= 0x80000000u; shuffle_input_sentence_ = value; } inline void TrainerSpec::set_shuffle_input_sentence(bool value) { @@ -2494,7 +2623,7 @@ inline void TrainerSpec::set_training_sentence_size(::PROTOBUF_NAMESPACE_ID::int // optional int32 seed_sentencepiece_size = 14 [default = 1000000]; inline bool TrainerSpec::_internal_has_seed_sentencepiece_size() const { - bool value = (_has_bits_[0] & 0x00400000u) != 0; + bool value = (_has_bits_[0] & 0x02000000u) != 0; return value; } inline bool TrainerSpec::has_seed_sentencepiece_size() const { @@ -2502,7 +2631,7 @@ inline bool TrainerSpec::has_seed_sentencepiece_size() const { } inline void TrainerSpec::clear_seed_sentencepiece_size() { seed_sentencepiece_size_ = 1000000; - _has_bits_[0] &= ~0x00400000u; + _has_bits_[0] &= ~0x02000000u; } inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_seed_sentencepiece_size() const { return seed_sentencepiece_size_; @@ -2512,7 +2641,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::seed_sentencepiece_size() con return _internal_seed_sentencepiece_size(); } inline void TrainerSpec::_internal_set_seed_sentencepiece_size(::PROTOBUF_NAMESPACE_ID::int32 value) { - _has_bits_[0] |= 0x00400000u; + _has_bits_[0] |= 0x02000000u; seed_sentencepiece_size_ = value; } inline void TrainerSpec::set_seed_sentencepiece_size(::PROTOBUF_NAMESPACE_ID::int32 value) { @@ -2522,7 +2651,7 @@ inline void TrainerSpec::set_seed_sentencepiece_size(::PROTOBUF_NAMESPACE_ID::in // optional float shrinking_factor = 15 [default = 0.75]; inline bool TrainerSpec::_internal_has_shrinking_factor() const { - bool value = (_has_bits_[0] & 0x00800000u) != 0; + bool value = (_has_bits_[0] & 0x04000000u) != 0; return value; } inline bool TrainerSpec::has_shrinking_factor() const { @@ -2530,7 +2659,7 @@ inline bool TrainerSpec::has_shrinking_factor() const { } inline void TrainerSpec::clear_shrinking_factor() { shrinking_factor_ = 0.75f; - _has_bits_[0] &= ~0x00800000u; + _has_bits_[0] &= ~0x04000000u; } inline float TrainerSpec::_internal_shrinking_factor() const { return shrinking_factor_; @@ -2540,7 +2669,7 @@ inline float TrainerSpec::shrinking_factor() const { return _internal_shrinking_factor(); } inline void TrainerSpec::_internal_set_shrinking_factor(float value) { - _has_bits_[0] |= 0x00800000u; + _has_bits_[0] |= 0x04000000u; shrinking_factor_ = value; } inline void TrainerSpec::set_shrinking_factor(float value) { @@ -2550,7 +2679,7 @@ inline void TrainerSpec::set_shrinking_factor(float value) { // optional int32 max_sentence_length = 18 [default = 4192]; inline bool TrainerSpec::_internal_has_max_sentence_length() const { - bool value = (_has_bits_[0] & 0x04000000u) != 0; + bool value = (_has_bits_[0] & 0x20000000u) != 0; return value; } inline bool TrainerSpec::has_max_sentence_length() const { @@ -2558,7 +2687,7 @@ inline bool TrainerSpec::has_max_sentence_length() const { } inline void TrainerSpec::clear_max_sentence_length() { max_sentence_length_ = 4192; - _has_bits_[0] &= ~0x04000000u; + _has_bits_[0] &= ~0x20000000u; } inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_max_sentence_length() const { return max_sentence_length_; @@ -2568,7 +2697,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::max_sentence_length() const { return _internal_max_sentence_length(); } inline void TrainerSpec::_internal_set_max_sentence_length(::PROTOBUF_NAMESPACE_ID::int32 value) { - _has_bits_[0] |= 0x04000000u; + _has_bits_[0] |= 0x20000000u; max_sentence_length_ = value; } inline void TrainerSpec::set_max_sentence_length(::PROTOBUF_NAMESPACE_ID::int32 value) { @@ -2578,7 +2707,7 @@ inline void TrainerSpec::set_max_sentence_length(::PROTOBUF_NAMESPACE_ID::int32 // optional int32 num_threads = 16 [default = 16]; inline bool TrainerSpec::_internal_has_num_threads() const { - bool value = (_has_bits_[0] & 0x01000000u) != 0; + bool value = (_has_bits_[0] & 0x08000000u) != 0; return value; } inline bool TrainerSpec::has_num_threads() const { @@ -2586,7 +2715,7 @@ inline bool TrainerSpec::has_num_threads() const { } inline void TrainerSpec::clear_num_threads() { num_threads_ = 16; - _has_bits_[0] &= ~0x01000000u; + _has_bits_[0] &= ~0x08000000u; } inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_num_threads() const { return num_threads_; @@ -2596,7 +2725,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::num_threads() const { return _internal_num_threads(); } inline void TrainerSpec::_internal_set_num_threads(::PROTOBUF_NAMESPACE_ID::int32 value) { - _has_bits_[0] |= 0x01000000u; + _has_bits_[0] |= 0x08000000u; num_threads_ = value; } inline void TrainerSpec::set_num_threads(::PROTOBUF_NAMESPACE_ID::int32 value) { @@ -2606,7 +2735,7 @@ inline void TrainerSpec::set_num_threads(::PROTOBUF_NAMESPACE_ID::int32 value) { // optional int32 num_sub_iterations = 17 [default = 2]; inline bool TrainerSpec::_internal_has_num_sub_iterations() const { - bool value = (_has_bits_[0] & 0x02000000u) != 0; + bool value = (_has_bits_[0] & 0x10000000u) != 0; return value; } inline bool TrainerSpec::has_num_sub_iterations() const { @@ -2614,7 +2743,7 @@ inline bool TrainerSpec::has_num_sub_iterations() const { } inline void TrainerSpec::clear_num_sub_iterations() { num_sub_iterations_ = 2; - _has_bits_[0] &= ~0x02000000u; + _has_bits_[0] &= ~0x10000000u; } inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_num_sub_iterations() const { return num_sub_iterations_; @@ -2624,7 +2753,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::num_sub_iterations() const { return _internal_num_sub_iterations(); } inline void TrainerSpec::_internal_set_num_sub_iterations(::PROTOBUF_NAMESPACE_ID::int32 value) { - _has_bits_[0] |= 0x02000000u; + _has_bits_[0] |= 0x10000000u; num_sub_iterations_ = value; } inline void TrainerSpec::set_num_sub_iterations(::PROTOBUF_NAMESPACE_ID::int32 value) { @@ -2634,7 +2763,7 @@ inline void TrainerSpec::set_num_sub_iterations(::PROTOBUF_NAMESPACE_ID::int32 v // optional int32 max_sentencepiece_length = 20 [default = 16]; inline bool TrainerSpec::_internal_has_max_sentencepiece_length() const { - bool value = (_has_bits_[0] & 0x08000000u) != 0; + bool value = (_has_bits_[0] & 0x40000000u) != 0; return value; } inline bool TrainerSpec::has_max_sentencepiece_length() const { @@ -2642,7 +2771,7 @@ inline bool TrainerSpec::has_max_sentencepiece_length() const { } inline void TrainerSpec::clear_max_sentencepiece_length() { max_sentencepiece_length_ = 16; - _has_bits_[0] &= ~0x08000000u; + _has_bits_[0] &= ~0x40000000u; } inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_max_sentencepiece_length() const { return max_sentencepiece_length_; @@ -2652,7 +2781,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::max_sentencepiece_length() co return _internal_max_sentencepiece_length(); } inline void TrainerSpec::_internal_set_max_sentencepiece_length(::PROTOBUF_NAMESPACE_ID::int32 value) { - _has_bits_[0] |= 0x08000000u; + _has_bits_[0] |= 0x40000000u; max_sentencepiece_length_ = value; } inline void TrainerSpec::set_max_sentencepiece_length(::PROTOBUF_NAMESPACE_ID::int32 value) { @@ -2662,7 +2791,7 @@ inline void TrainerSpec::set_max_sentencepiece_length(::PROTOBUF_NAMESPACE_ID::i // optional bool split_by_unicode_script = 21 [default = true]; inline bool TrainerSpec::_internal_has_split_by_unicode_script() const { - bool value = (_has_bits_[0] & 0x20000000u) != 0; + bool value = (_has_bits_[1] & 0x00000001u) != 0; return value; } inline bool TrainerSpec::has_split_by_unicode_script() const { @@ -2670,7 +2799,7 @@ inline bool TrainerSpec::has_split_by_unicode_script() const { } inline void TrainerSpec::clear_split_by_unicode_script() { split_by_unicode_script_ = true; - _has_bits_[0] &= ~0x20000000u; + _has_bits_[1] &= ~0x00000001u; } inline bool TrainerSpec::_internal_split_by_unicode_script() const { return split_by_unicode_script_; @@ -2680,7 +2809,7 @@ inline bool TrainerSpec::split_by_unicode_script() const { return _internal_split_by_unicode_script(); } inline void TrainerSpec::_internal_set_split_by_unicode_script(bool value) { - _has_bits_[0] |= 0x20000000u; + _has_bits_[1] |= 0x00000001u; split_by_unicode_script_ = value; } inline void TrainerSpec::set_split_by_unicode_script(bool value) { @@ -2690,7 +2819,7 @@ inline void TrainerSpec::set_split_by_unicode_script(bool value) { // optional bool split_by_number = 23 [default = true]; inline bool TrainerSpec::_internal_has_split_by_number() const { - bool value = (_has_bits_[0] & 0x40000000u) != 0; + bool value = (_has_bits_[1] & 0x00000002u) != 0; return value; } inline bool TrainerSpec::has_split_by_number() const { @@ -2698,7 +2827,7 @@ inline bool TrainerSpec::has_split_by_number() const { } inline void TrainerSpec::clear_split_by_number() { split_by_number_ = true; - _has_bits_[0] &= ~0x40000000u; + _has_bits_[1] &= ~0x00000002u; } inline bool TrainerSpec::_internal_split_by_number() const { return split_by_number_; @@ -2708,7 +2837,7 @@ inline bool TrainerSpec::split_by_number() const { return _internal_split_by_number(); } inline void TrainerSpec::_internal_set_split_by_number(bool value) { - _has_bits_[0] |= 0x40000000u; + _has_bits_[1] |= 0x00000002u; split_by_number_ = value; } inline void TrainerSpec::set_split_by_number(bool value) { @@ -2718,7 +2847,7 @@ inline void TrainerSpec::set_split_by_number(bool value) { // optional bool split_by_whitespace = 22 [default = true]; inline bool TrainerSpec::_internal_has_split_by_whitespace() const { - bool value = (_has_bits_[0] & 0x80000000u) != 0; + bool value = (_has_bits_[1] & 0x00000004u) != 0; return value; } inline bool TrainerSpec::has_split_by_whitespace() const { @@ -2726,7 +2855,7 @@ inline bool TrainerSpec::has_split_by_whitespace() const { } inline void TrainerSpec::clear_split_by_whitespace() { split_by_whitespace_ = true; - _has_bits_[0] &= ~0x80000000u; + _has_bits_[1] &= ~0x00000004u; } inline bool TrainerSpec::_internal_split_by_whitespace() const { return split_by_whitespace_; @@ -2736,7 +2865,7 @@ inline bool TrainerSpec::split_by_whitespace() const { return _internal_split_by_whitespace(); } inline void TrainerSpec::_internal_set_split_by_whitespace(bool value) { - _has_bits_[0] |= 0x80000000u; + _has_bits_[1] |= 0x00000004u; split_by_whitespace_ = value; } inline void TrainerSpec::set_split_by_whitespace(bool value) { @@ -2746,7 +2875,7 @@ inline void TrainerSpec::set_split_by_whitespace(bool value) { // optional bool treat_whitespace_as_suffix = 24 [default = false]; inline bool TrainerSpec::_internal_has_treat_whitespace_as_suffix() const { - bool value = (_has_bits_[0] & 0x00001000u) != 0; + bool value = (_has_bits_[0] & 0x00002000u) != 0; return value; } inline bool TrainerSpec::has_treat_whitespace_as_suffix() const { @@ -2754,7 +2883,7 @@ inline bool TrainerSpec::has_treat_whitespace_as_suffix() const { } inline void TrainerSpec::clear_treat_whitespace_as_suffix() { treat_whitespace_as_suffix_ = false; - _has_bits_[0] &= ~0x00001000u; + _has_bits_[0] &= ~0x00002000u; } inline bool TrainerSpec::_internal_treat_whitespace_as_suffix() const { return treat_whitespace_as_suffix_; @@ -2764,7 +2893,7 @@ inline bool TrainerSpec::treat_whitespace_as_suffix() const { return _internal_treat_whitespace_as_suffix(); } inline void TrainerSpec::_internal_set_treat_whitespace_as_suffix(bool value) { - _has_bits_[0] |= 0x00001000u; + _has_bits_[0] |= 0x00002000u; treat_whitespace_as_suffix_ = value; } inline void TrainerSpec::set_treat_whitespace_as_suffix(bool value) { @@ -2774,7 +2903,7 @@ inline void TrainerSpec::set_treat_whitespace_as_suffix(bool value) { // optional bool allow_whitespace_only_pieces = 26 [default = false]; inline bool TrainerSpec::_internal_has_allow_whitespace_only_pieces() const { - bool value = (_has_bits_[0] & 0x00002000u) != 0; + bool value = (_has_bits_[0] & 0x00004000u) != 0; return value; } inline bool TrainerSpec::has_allow_whitespace_only_pieces() const { @@ -2782,7 +2911,7 @@ inline bool TrainerSpec::has_allow_whitespace_only_pieces() const { } inline void TrainerSpec::clear_allow_whitespace_only_pieces() { allow_whitespace_only_pieces_ = false; - _has_bits_[0] &= ~0x00002000u; + _has_bits_[0] &= ~0x00004000u; } inline bool TrainerSpec::_internal_allow_whitespace_only_pieces() const { return allow_whitespace_only_pieces_; @@ -2792,7 +2921,7 @@ inline bool TrainerSpec::allow_whitespace_only_pieces() const { return _internal_allow_whitespace_only_pieces(); } inline void TrainerSpec::_internal_set_allow_whitespace_only_pieces(bool value) { - _has_bits_[0] |= 0x00002000u; + _has_bits_[0] |= 0x00004000u; allow_whitespace_only_pieces_ = value; } inline void TrainerSpec::set_allow_whitespace_only_pieces(bool value) { @@ -2802,7 +2931,7 @@ inline void TrainerSpec::set_allow_whitespace_only_pieces(bool value) { // optional bool split_digits = 25 [default = false]; inline bool TrainerSpec::_internal_has_split_digits() const { - bool value = (_has_bits_[0] & 0x00004000u) != 0; + bool value = (_has_bits_[0] & 0x00008000u) != 0; return value; } inline bool TrainerSpec::has_split_digits() const { @@ -2810,7 +2939,7 @@ inline bool TrainerSpec::has_split_digits() const { } inline void TrainerSpec::clear_split_digits() { split_digits_ = false; - _has_bits_[0] &= ~0x00004000u; + _has_bits_[0] &= ~0x00008000u; } inline bool TrainerSpec::_internal_split_digits() const { return split_digits_; @@ -2820,7 +2949,7 @@ inline bool TrainerSpec::split_digits() const { return _internal_split_digits(); } inline void TrainerSpec::_internal_set_split_digits(bool value) { - _has_bits_[0] |= 0x00004000u; + _has_bits_[0] |= 0x00008000u; split_digits_ = value; } inline void TrainerSpec::set_split_digits(bool value) { @@ -3051,7 +3180,7 @@ inline void TrainerSpec::set_allocated_required_chars(std::string* required_char // optional bool byte_fallback = 35 [default = false]; inline bool TrainerSpec::_internal_has_byte_fallback() const { - bool value = (_has_bits_[0] & 0x00008000u) != 0; + bool value = (_has_bits_[0] & 0x00010000u) != 0; return value; } inline bool TrainerSpec::has_byte_fallback() const { @@ -3059,7 +3188,7 @@ inline bool TrainerSpec::has_byte_fallback() const { } inline void TrainerSpec::clear_byte_fallback() { byte_fallback_ = false; - _has_bits_[0] &= ~0x00008000u; + _has_bits_[0] &= ~0x00010000u; } inline bool TrainerSpec::_internal_byte_fallback() const { return byte_fallback_; @@ -3069,7 +3198,7 @@ inline bool TrainerSpec::byte_fallback() const { return _internal_byte_fallback(); } inline void TrainerSpec::_internal_set_byte_fallback(bool value) { - _has_bits_[0] |= 0x00008000u; + _has_bits_[0] |= 0x00010000u; byte_fallback_ = value; } inline void TrainerSpec::set_byte_fallback(bool value) { @@ -3079,7 +3208,7 @@ inline void TrainerSpec::set_byte_fallback(bool value) { // optional bool vocabulary_output_piece_score = 32 [default = true]; inline bool TrainerSpec::_internal_has_vocabulary_output_piece_score() const { - bool value = (_has_bits_[1] & 0x00000001u) != 0; + bool value = (_has_bits_[1] & 0x00000008u) != 0; return value; } inline bool TrainerSpec::has_vocabulary_output_piece_score() const { @@ -3087,7 +3216,7 @@ inline bool TrainerSpec::has_vocabulary_output_piece_score() const { } inline void TrainerSpec::clear_vocabulary_output_piece_score() { vocabulary_output_piece_score_ = true; - _has_bits_[1] &= ~0x00000001u; + _has_bits_[1] &= ~0x00000008u; } inline bool TrainerSpec::_internal_vocabulary_output_piece_score() const { return vocabulary_output_piece_score_; @@ -3097,7 +3226,7 @@ inline bool TrainerSpec::vocabulary_output_piece_score() const { return _internal_vocabulary_output_piece_score(); } inline void TrainerSpec::_internal_set_vocabulary_output_piece_score(bool value) { - _has_bits_[1] |= 0x00000001u; + _has_bits_[1] |= 0x00000008u; vocabulary_output_piece_score_ = value; } inline void TrainerSpec::set_vocabulary_output_piece_score(bool value) { @@ -3107,7 +3236,7 @@ inline void TrainerSpec::set_vocabulary_output_piece_score(bool value) { // optional bool hard_vocab_limit = 33 [default = true]; inline bool TrainerSpec::_internal_has_hard_vocab_limit() const { - bool value = (_has_bits_[1] & 0x00000002u) != 0; + bool value = (_has_bits_[1] & 0x00000010u) != 0; return value; } inline bool TrainerSpec::has_hard_vocab_limit() const { @@ -3115,7 +3244,7 @@ inline bool TrainerSpec::has_hard_vocab_limit() const { } inline void TrainerSpec::clear_hard_vocab_limit() { hard_vocab_limit_ = true; - _has_bits_[1] &= ~0x00000002u; + _has_bits_[1] &= ~0x00000010u; } inline bool TrainerSpec::_internal_hard_vocab_limit() const { return hard_vocab_limit_; @@ -3125,7 +3254,7 @@ inline bool TrainerSpec::hard_vocab_limit() const { return _internal_hard_vocab_limit(); } inline void TrainerSpec::_internal_set_hard_vocab_limit(bool value) { - _has_bits_[1] |= 0x00000002u; + _has_bits_[1] |= 0x00000010u; hard_vocab_limit_ = value; } inline void TrainerSpec::set_hard_vocab_limit(bool value) { @@ -3163,7 +3292,7 @@ inline void TrainerSpec::set_use_all_vocab(bool value) { // optional int32 unk_id = 40 [default = 0]; inline bool TrainerSpec::_internal_has_unk_id() const { - bool value = (_has_bits_[0] & 0x00010000u) != 0; + bool value = (_has_bits_[0] & 0x00080000u) != 0; return value; } inline bool TrainerSpec::has_unk_id() const { @@ -3171,7 +3300,7 @@ inline bool TrainerSpec::has_unk_id() const { } inline void TrainerSpec::clear_unk_id() { unk_id_ = 0; - _has_bits_[0] &= ~0x00010000u; + _has_bits_[0] &= ~0x00080000u; } inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_unk_id() const { return unk_id_; @@ -3181,7 +3310,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::unk_id() const { return _internal_unk_id(); } inline void TrainerSpec::_internal_set_unk_id(::PROTOBUF_NAMESPACE_ID::int32 value) { - _has_bits_[0] |= 0x00010000u; + _has_bits_[0] |= 0x00080000u; unk_id_ = value; } inline void TrainerSpec::set_unk_id(::PROTOBUF_NAMESPACE_ID::int32 value) { @@ -3191,7 +3320,7 @@ inline void TrainerSpec::set_unk_id(::PROTOBUF_NAMESPACE_ID::int32 value) { // optional int32 bos_id = 41 [default = 1]; inline bool TrainerSpec::_internal_has_bos_id() const { - bool value = (_has_bits_[1] & 0x00000004u) != 0; + bool value = (_has_bits_[1] & 0x00000020u) != 0; return value; } inline bool TrainerSpec::has_bos_id() const { @@ -3199,7 +3328,7 @@ inline bool TrainerSpec::has_bos_id() const { } inline void TrainerSpec::clear_bos_id() { bos_id_ = 1; - _has_bits_[1] &= ~0x00000004u; + _has_bits_[1] &= ~0x00000020u; } inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_bos_id() const { return bos_id_; @@ -3209,7 +3338,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::bos_id() const { return _internal_bos_id(); } inline void TrainerSpec::_internal_set_bos_id(::PROTOBUF_NAMESPACE_ID::int32 value) { - _has_bits_[1] |= 0x00000004u; + _has_bits_[1] |= 0x00000020u; bos_id_ = value; } inline void TrainerSpec::set_bos_id(::PROTOBUF_NAMESPACE_ID::int32 value) { @@ -3219,7 +3348,7 @@ inline void TrainerSpec::set_bos_id(::PROTOBUF_NAMESPACE_ID::int32 value) { // optional int32 eos_id = 42 [default = 2]; inline bool TrainerSpec::_internal_has_eos_id() const { - bool value = (_has_bits_[1] & 0x00000008u) != 0; + bool value = (_has_bits_[1] & 0x00000040u) != 0; return value; } inline bool TrainerSpec::has_eos_id() const { @@ -3227,7 +3356,7 @@ inline bool TrainerSpec::has_eos_id() const { } inline void TrainerSpec::clear_eos_id() { eos_id_ = 2; - _has_bits_[1] &= ~0x00000008u; + _has_bits_[1] &= ~0x00000040u; } inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_eos_id() const { return eos_id_; @@ -3237,7 +3366,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::eos_id() const { return _internal_eos_id(); } inline void TrainerSpec::_internal_set_eos_id(::PROTOBUF_NAMESPACE_ID::int32 value) { - _has_bits_[1] |= 0x00000008u; + _has_bits_[1] |= 0x00000040u; eos_id_ = value; } inline void TrainerSpec::set_eos_id(::PROTOBUF_NAMESPACE_ID::int32 value) { @@ -3247,7 +3376,7 @@ inline void TrainerSpec::set_eos_id(::PROTOBUF_NAMESPACE_ID::int32 value) { // optional int32 pad_id = 43 [default = -1]; inline bool TrainerSpec::_internal_has_pad_id() const { - bool value = (_has_bits_[1] & 0x00000010u) != 0; + bool value = (_has_bits_[1] & 0x00000080u) != 0; return value; } inline bool TrainerSpec::has_pad_id() const { @@ -3255,7 +3384,7 @@ inline bool TrainerSpec::has_pad_id() const { } inline void TrainerSpec::clear_pad_id() { pad_id_ = -1; - _has_bits_[1] &= ~0x00000010u; + _has_bits_[1] &= ~0x00000080u; } inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_pad_id() const { return pad_id_; @@ -3265,7 +3394,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::pad_id() const { return _internal_pad_id(); } inline void TrainerSpec::_internal_set_pad_id(::PROTOBUF_NAMESPACE_ID::int32 value) { - _has_bits_[1] |= 0x00000010u; + _has_bits_[1] |= 0x00000080u; pad_id_ = value; } inline void TrainerSpec::set_pad_id(::PROTOBUF_NAMESPACE_ID::int32 value) { diff --git a/src/freelist.h b/src/freelist.h index f4461f3..8038048 100644 --- a/src/freelist.h +++ b/src/freelist.h @@ -46,6 +46,13 @@ class FreeList { // Returns the number of allocated elements. size_t size() const { return chunk_size_ * chunk_index_ + element_index_; } + void swap(FreeList& other) { + std::swap(freelist_, other.freelist_); + std::swap(element_index_, other.element_index_); + std::swap(chunk_index_, other.chunk_index_); + std::swap(chunk_size_, other.chunk_size_); + } + // Returns the element as an array. T* operator[](size_t index) const { return freelist_[index / chunk_size_] + index % chunk_size_; @@ -76,7 +83,7 @@ class FreeList { // The last element is stored at freelist_[chunk_index_][element_index_] size_t element_index_ = 0; size_t chunk_index_ = 0; - const size_t chunk_size_ = 0; + size_t chunk_size_ = 0; // Do not modify except in swap() }; } // namespace model } // namespace sentencepiece diff --git a/src/freelist_test.cc b/src/freelist_test.cc index 9eb41a0..4c6c99e 100644 --- a/src/freelist_test.cc +++ b/src/freelist_test.cc @@ -30,17 +30,20 @@ TEST(FreeListTest, BasicTest) { *n = i; } - EXPECT_EQ(kSize, l.size()); + FreeList l2(3); // Test swap() + l.swap(l2); + + EXPECT_EQ(kSize, l2.size()); for (size_t i = 0; i < kSize; ++i) { - EXPECT_EQ(i, *l[i]); + EXPECT_EQ(i, *l2[i]); } - l.Free(); - EXPECT_EQ(0, l.size()); + l2.Free(); + EXPECT_EQ(0, l2.size()); // Zero-initialized after `Free`. for (size_t i = 0; i < kSize; ++i) { - int *n = l.Allocate(); + int *n = l2.Allocate(); EXPECT_EQ(0, *n); } } diff --git a/src/model_interface.h b/src/model_interface.h index aef5b53..06b3a65 100644 --- a/src/model_interface.h +++ b/src/model_interface.h @@ -107,12 +107,12 @@ class ModelInterface { return EncodeResult(); } - // Sample `samples` many tokenizations from the segmentation lattice + // Sample `samples` many tokenisations from the segmentation lattice // If `wor` is true, the samples are taken without replacement, and the scores // are the inclusion probabilities of the elements in the sample; otherwise // the samples are taken with replacement and the scores are the log-probs of // sample elements - // If `include_best` is true, the best tokenization is always included in the + // If `include_best` is true, the best tokenisation is always included in the // sample, and the remaining elements are sampled excluding the best. virtual NBestEncodeResult SampleEncodeAndScore(absl::string_view normalized, float alpha, int samples, diff --git a/src/normalizer.cc b/src/normalizer.cc index 100b875..d87f89b 100644 --- a/src/normalizer.cc +++ b/src/normalizer.cc @@ -12,12 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License.! -#include "normalizer.h" - #include #include #include "common.h" +#include "normalizer.h" #include "third_party/absl/memory/memory.h" #include "third_party/absl/strings/match.h" #include "third_party/absl/strings/string_view.h" @@ -281,7 +280,8 @@ util::Status Normalizer::DecodePrecompiledCharsMap( if (blob.size() <= sizeof(trie_blob_size) || !string_util::DecodePOD( absl::string_view(blob.data(), sizeof(trie_blob_size)), - &trie_blob_size)) { + &trie_blob_size) || + trie_blob_size >= blob.size()) { return util::InternalError("Blob for normalization rule is broken."); } diff --git a/src/normalizer.h b/src/normalizer.h index 622bbd2..c79813c 100644 --- a/src/normalizer.h +++ b/src/normalizer.h @@ -26,7 +26,6 @@ #include "sentencepiece_processor.h" #include "third_party/absl/strings/string_view.h" #include "third_party/darts_clone/darts.h" -#include "util.h" namespace sentencepiece { namespace normalizer { diff --git a/src/sentencepiece_model.proto b/src/sentencepiece_model.proto index ee8e877..b6c1224 100644 --- a/src/sentencepiece_model.proto +++ b/src/sentencepiece_model.proto @@ -20,6 +20,7 @@ option optimize_for = LITE_RUNTIME; package sentencepiece; // TrainerSpec encodes a various parameters for SentencePiece training. +// Next id: 53 message TrainerSpec { /////////////////////////////////////////////////////////////////// // General parameters @@ -62,6 +63,16 @@ message TrainerSpec { // Size of self-test samples, which are encoded in the model file. optional int32 self_test_sample_size = 6 [default = 0]; + // Whether to use DP version of sentencepiece. Use it with TSV input format + // (requires precomputed word tab counts to work). + optional bool enable_differential_privacy = 50 [default = false]; + // Set these parameters if you need DP version of sentencepiece. + // std of noise to add. + optional float differential_privacy_noise_level = 51 [default = 0.0]; + // Clipping threshold to apply after adding noise. All the words with + // frequency less than this value are dropped. + optional uint64 differential_privacy_clipping_threshold = 52 [default = 0]; + /////////////////////////////////////////////////////////////////// // Training parameters. // diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc index 1e4e7a0..3ed1370 100644 --- a/src/sentencepiece_processor.cc +++ b/src/sentencepiece_processor.cc @@ -548,18 +548,24 @@ util::Status SentencePieceProcessor::Decode( if (model_proto_ && model_proto_->trainer_spec().has_unk_surface()) unk_surface = model_proto_->trainer_spec().unk_surface().c_str(); - auto DecodeSentencePiece = [&](absl::string_view piece, int id, - bool is_bos_ws) -> std::string { - if (IsControl(id)) { // , - return ""; // invisible symbol. + // Returns decoded piece and a boolean indicating if the function has consumed + // a bos whitespace token (a piece starting with a kSpaceSymbol). This is used + // to strip only the first whitespace token from the decoded sequence for + // add_dummy_prefix. + auto DecodeSentencePiece = + [&](absl::string_view piece, int id, + bool is_bos_ws) -> std::pair { + if (IsControl(id)) { // , + return std::make_pair("", false); // invisible symbol. } else if (IsUnknown(id)) { if (IdToPiece(id) == piece) { // - return unk_surface; + return std::make_pair(unk_surface, false); } else { // return piece when piece is not . - return std::string(piece); + return std::make_pair(std::string(piece), false); } } + bool has_bos_ws = false; // whether the token starts with a kSpaceSymbol if (is_bos_ws && (!model_proto_ || (model_proto_ && @@ -567,10 +573,17 @@ util::Status SentencePieceProcessor::Decode( model_proto_->normalizer_spec().remove_extra_whitespaces())))) { // Consume if the current position is bos and // piece starts with kSpaceSymbol. - absl::ConsumePrefix(&piece, kSpaceSymbol); + has_bos_ws = absl::ConsumePrefix(&piece, kSpaceSymbol); + + if (model_proto_ && + model_proto_->normalizer_spec().remove_extra_whitespaces()) { + // if we are removing extra whitespace, we remove all leading whitespace + has_bos_ws = false; + } } - return absl::StrReplaceAll(piece, {{kSpaceSymbol, " "}}); + return std::make_pair(absl::StrReplaceAll(piece, {{kSpaceSymbol, " "}}), + has_bos_ws); }; for (const std::string &w : pieces) { @@ -644,12 +657,23 @@ util::Status SentencePieceProcessor::Decode( }; int byte_start = 0; + bool is_bos_ws = true; // whether we expect a bos ws token to consume. + bool bos_ws_seen = false; + std::string decoded; + for (int i = 0; i < spt->pieces_size(); ++i) { const auto &sp = spt->pieces(i); if (!IsByte(sp.id())) { RETURN_IF_ERROR(ProcessBytePieces(byte_start, i)); + + // if we have seen a bos_ws token or any non-empty token + if (bos_ws_seen || !text->empty()) is_bos_ws = false; + byte_start = i + 1; - SetSurface(i, DecodeSentencePiece(sp.piece(), sp.id(), text->empty())); + std::tie(decoded, bos_ws_seen) = + DecodeSentencePiece(sp.piece(), sp.id(), is_bos_ws); + + SetSurface(i, decoded); } } RETURN_IF_ERROR(ProcessBytePieces(byte_start, spt->pieces_size())); diff --git a/src/sentencepiece_processor.h b/src/sentencepiece_processor.h index e8bd5f5..7cb5f26 100644 --- a/src/sentencepiece_processor.h +++ b/src/sentencepiece_processor.h @@ -163,6 +163,7 @@ namespace normalizer { class Normalizer; } // namespace normalizer +#ifndef SWIG // Defines the multiple versions of encoder within each model. Currently only // the Unigram model has an optimized encoder. enum class EncoderVersion { @@ -170,13 +171,16 @@ enum class EncoderVersion { kOriginal // The original encoder (user may choose to fall back to this // just in case). }; +#endif +#ifndef SWIGGO namespace util { // Redefine std::string for serialized_proto interface as Python's string is // a Unicode string. We can enforce the return value to be raw byte sequence // with SWIG's typemap. using bytes = std::string; } // namespace util +#endif class SentencePieceProcessor { public: @@ -250,6 +254,7 @@ class SentencePieceProcessor { virtual util::Status Decode(const std::vector &ids, std::string *detokenized) const; +#ifndef SWIG // Sets the encoder version. Normally users do not need to call this function. // But they can call this fucntion just in case if they want to fall back to // the original encoder. @@ -257,6 +262,7 @@ class SentencePieceProcessor { // Returns the current encoder version in use. virtual EncoderVersion GetEncoderVersion() const; +#endif ////////////////////////////////////////////////////////////// // NBest API. @@ -315,20 +321,12 @@ class SentencePieceProcessor { virtual util::Status SampleEncode(absl::string_view input, int nbest_size, float alpha, SentencePieceText *spt) const; - // Sample `samples` segmentations from the segmentation lattice. - // If `wor` is true, the samples are taken without replacement, and the scores - // are the inclusion probabilities of the elements in the sample; otherwise - // the samples are taken with replacement and the scores are the log-probes of - // sample elements. - // If `include_best` is true, the best tokenization is always included in the - // sample, and the remaining elements are sampled excluding the best. - // This method is only available in Unigram mode. + // Samples N segmentation and returns the scores as well virtual util::Status SampleEncodeAndScore( absl::string_view input, int samples, float theta, bool wor, bool include_best, NBestSentencePieceText *samples_spt) const; - // Calculate entropy of possible tokenization. - // Only available in unigram mode. + // Calculate entropy of possible tokenisations virtual util::Status CalculateEntropy(absl::string_view input, float theta, float *entropy) const; diff --git a/src/sentencepiece_processor_test.cc b/src/sentencepiece_processor_test.cc index 373e73e..d57ab5a 100644 --- a/src/sentencepiece_processor_test.cc +++ b/src/sentencepiece_processor_test.cc @@ -709,6 +709,87 @@ TEST(SentencepieceProcessorTest, DecodeTest) { } } +TEST(SentencepieceProcessorTest, DummyPrefixDecodeTest) { + class DecodeMockModel : public ModelInterface { + public: + EncodeResult Encode(absl::string_view normalized) const override { + return {}; + } + + int GetPieceSize() const override { return 7; } + + int PieceToId(absl::string_view piece) const override { + static absl::flat_hash_map + kMap = {{"", 0}, {"", 1}, {"", 2}, {WS "ABC", 3}, + {WS "DE", 4}, {"F", 5}, {"G" WS "H", 6}, {WS, 7}}; + return port::FindWithDefault(kMap, piece, 0); + } + + const std::string &IdToPiece(int id) const override { + static std::vector kMap = { + "", "", "", WS "ABC", WS "DE", "F", "G" WS "H", WS}; + return kMap[id]; + } + + bool IsUnknown(int id) const override { return (id == 0); } + + bool IsControl(int id) const override { return (id == 1 || id == 2); } + + bool IsByte(int id) const override { return false; } + + float GetScore(int id) const override { return 0.0; } + }; + + // start the sequence with a whitespace token + const std::vector input = { + "", WS, WS "ABC", "", WS "DE", "F", "G" WS "H", "I", ""}; + + { + SentencePieceProcessor sp; + auto proto = absl::make_unique(); + proto->mutable_trainer_spec()->set_unk_surface(""); + proto->mutable_normalizer_spec()->set_add_dummy_prefix(true); + proto->mutable_normalizer_spec()->set_remove_extra_whitespaces(false); + sp.Load(std::move(proto)).IgnoreError(); + + auto mock = absl::make_unique(); + sp.SetModel(std::move(mock)); + + const auto normalization_spec = MakeDefaultNormalizerSpec(); + sp.SetNormalizer( + absl::make_unique(normalization_spec)); + + SentencePieceText spt; + + EXPECT_TRUE(sp.Decode(input, &spt).ok()); + EXPECT_EQ(" ABC DEFG HI", spt.text()); + EXPECT_EQ(9, spt.pieces_size()); + } + + { + SentencePieceProcessor sp; + auto proto = absl::make_unique(); + proto->mutable_trainer_spec()->set_unk_surface(""); + proto->mutable_normalizer_spec()->set_add_dummy_prefix(true); + proto->mutable_normalizer_spec()->set_remove_extra_whitespaces(true); + sp.Load(std::move(proto)).IgnoreError(); + + auto mock = absl::make_unique(); + sp.SetModel(std::move(mock)); + + const auto normalization_spec = MakeDefaultNormalizerSpec(); + sp.SetNormalizer( + absl::make_unique(normalization_spec)); + + SentencePieceText spt; + + EXPECT_TRUE(sp.Decode(input, &spt).ok()); + EXPECT_EQ("ABC DEFG HI", spt.text()); + EXPECT_EQ(9, spt.pieces_size()); + } +} + TEST(SentencepieceProcessorTest, ByteFallbackDecodeTest) { class ByteFallbackDecodeMockModel : public ModelInterface { public: diff --git a/src/spm_train_main.cc b/src/spm_train_main.cc index baf8dbf..c34ee02 100644 --- a/src/spm_train_main.cc +++ b/src/spm_train_main.cc @@ -144,6 +144,18 @@ ABSL_FLAG(bool, train_extremely_large_corpus, ABSL_FLAG(uint32, random_seed, static_cast(-1), "Seed value for random generator."); +// DP related. +ABSL_FLAG(bool, enable_differential_privacy, false, + "Whether to add DP while training. Currently supported only by " + "UNIGRAM model."); + +ABSL_FLAG(float, differential_privacy_noise_level, 0.0f, + "Amount of noise to add for" + " DP"); +ABSL_FLAG(std::uint64_t, differential_privacy_clipping_threshold, 0, + "Threshold for" + " clipping the counts for DP"); + int main(int argc, char *argv[]) { sentencepiece::ParseCommandLineFlags(argv[0], &argc, &argv, true); @@ -235,6 +247,10 @@ int main(int argc, char *argv[]) { SetRepeatedTrainerSpecFromFlag(control_symbols); SetRepeatedTrainerSpecFromFlag(user_defined_symbols); SetTrainerSpecFromFlag(train_extremely_large_corpus); + // DP related. + SetTrainerSpecFromFlag(enable_differential_privacy); + SetTrainerSpecFromFlag(differential_privacy_noise_level); + SetTrainerSpecFromFlag(differential_privacy_clipping_threshold); SetRepeatedTrainerSpecFromFile(control_symbols); SetRepeatedTrainerSpecFromFile(user_defined_symbols); diff --git a/src/trainer_interface.cc b/src/trainer_interface.cc index b1bcd1b..ef0c370 100644 --- a/src/trainer_interface.cc +++ b/src/trainer_interface.cc @@ -28,6 +28,8 @@ #include "sentencepiece_trainer.h" #include "third_party/absl/container/flat_hash_map.h" #include "third_party/absl/memory/memory.h" +#include "third_party/absl/random/distributions.h" +#include "third_party/absl/random/random.h" #include "third_party/absl/strings/numbers.h" #include "third_party/absl/strings/str_cat.h" #include "third_party/absl/strings/str_format.h" @@ -273,8 +275,7 @@ bool TrainerInterface::IsValidSentencePiece( if (s == unicode_script::U_Hiragana || s == unicode_script::U_Katakana || c == 0x30FC) { // long vowel sound (Katakana) should be Katakana s = unicode_script::U_Han; - } - else if (s == unicode_script::U_Inherited) { + } else if (s == unicode_script::U_Inherited) { s = prev_script; } @@ -299,6 +300,22 @@ bool TrainerInterface::IsValidSentencePiece( return true; } +template +void AddDPNoise(const TrainerSpec &trainer_spec, absl::SharedBitGen &generator, + T *to_update) { + if (trainer_spec.differential_privacy_noise_level() > 0) { + float random_num = absl::Gaussian( + generator, 0, trainer_spec.differential_privacy_noise_level()); + + *to_update = + std::round(std::max(0.f, random_num + static_cast(*to_update))); + } + // Clip anything below the clipping threshold to 0. + if (*to_update < trainer_spec.differential_privacy_clipping_threshold()) { + *to_update = 0; + } +} + util::Status TrainerInterface::LoadSentences() { RETURN_IF_ERROR(status()); CHECK_OR_RETURN(sentences_.empty()); @@ -390,6 +407,7 @@ END: LOG(INFO) << "Sampled " << sentences_.size() << " sentences from " << selector.total_size() << " sentences."; } + if (too_long_lines > 0) LOG(INFO) << "Skipped " << too_long_lines << " too long sentences."; if (self_test_samples_.size() > 0) @@ -433,6 +451,54 @@ END: } } + // If DP is required, add the noise/clip the input. + if (trainer_spec_.enable_differential_privacy()) { + if (trainer_spec_.input_format() != "tsv") { + LOG(ERROR) + << "Dp version will not work correctly with text input format."; + } + if (trainer_spec_.differential_privacy_noise_level() <= 0) { + LOG(WARNING) << "Private version with <=0 noise level will give " + "infinity epsilon gurantees."; + } + if (trainer_spec_.differential_privacy_clipping_threshold() <= 0) { + LOG(WARNING) << "Private version with <=0 clipping threshold will give " + "infinity epsilon gurantees."; + } + + // Add noise to all the sentences via threadpool. + + // This line is mainly for tests with small num of sentences. + const auto num_workers = + std::min(trainer_spec_.num_threads(), sentences_.size() - 1); + + { + auto pool = absl::make_unique(num_workers); + pool->StartWorkers(); + for (int n = 0; n < num_workers; ++n) { + pool->Schedule([&, n]() { + // One per thread generator. + absl::SharedBitGen generator; + for (size_t i = n; i < sentences_.size(); i += num_workers) { + AddDPNoise(trainer_spec_, generator, + &(sentences_[i].second)); + } + }); + } + } + + // Remove zero freq elements. + const auto before_size = sentences_.size(); + auto it = std::remove_if(sentences_.begin(), sentences_.end(), + [](const Sentence &s) { return s.second <= 0; }); + const auto new_size = std::distance(sentences_.begin(), it); + const int num_erased = before_size - new_size; + sentences_.erase(it, sentences_.end()); + + LOG(INFO) << "DP noise resulted in " << 1.0 * num_erased / before_size + << " fraction of sentences removed."; + } + // Count character frequencies. int64 all_chars_count = 0; // A map from a character to {is_required_char, character count}. @@ -617,6 +683,7 @@ util::Status TrainerInterface::Serialize(ModelProto *model_proto) const { util::Status TrainerInterface::SaveModel(absl::string_view filename) const { LOG(INFO) << "Saving model: " << filename; ModelProto model_proto; + RETURN_IF_ERROR(Serialize(&model_proto)); auto output = filesystem::NewWritableFile(filename.data(), true); diff --git a/src/trainer_interface.h b/src/trainer_interface.h index f66d59a..8d625a9 100644 --- a/src/trainer_interface.h +++ b/src/trainer_interface.h @@ -107,16 +107,16 @@ class TrainerInterface { FRIEND_TEST(TrainerInterfaceTest, SerializeTest); FRIEND_TEST(TrainerInterfaceTest, CharactersTest); + // Loads all sentences from spec.input() or SentenceIterator. + // It loads at most input_sentence_size sentences. + util::Status LoadSentences(); + protected: // Returns true if |piece| is valid sentence piece. // The result is affected by // max_sentencepiece_length, split_by_whiespace, split_by_unicode_script. bool IsValidSentencePiece(const string_util::UnicodeText &piece) const; - // Loads all sentences from spec.input() or SentenceIterator. - // It loads at most input_sentence_size sentences. - util::Status LoadSentences(); - // Splits all sentencecs by whitespaces and // replace the |sentences_| with tokenized string. // e.g., diff --git a/src/trainer_interface_test.cc b/src/trainer_interface_test.cc index 4a3ab56..d6c0c78 100644 --- a/src/trainer_interface_test.cc +++ b/src/trainer_interface_test.cc @@ -78,8 +78,10 @@ TEST(TrainerInterfaceTest, IsValidSentencePieceTest) { EXPECT_FALSE(IsValid("ab cd")); EXPECT_FALSE(IsValid("\0\0")); EXPECT_FALSE(IsValid("\0")); - EXPECT_TRUE(IsValid("proteïni")); // Combining Diaeresis should inherit script from base character. - EXPECT_TRUE(IsValid("ثَبَّتَ")); // Arabic Fatha and Shadda should inherit script from base character. + EXPECT_TRUE(IsValid("proteïni")); // Combining Diaeresis should inherit + // script from base character. + EXPECT_TRUE(IsValid("ثَبَّتَ")); // Arabic Fatha and Shadda should inherit script + // from base character. trainer_spec.set_split_by_whitespace(false); EXPECT_TRUE(IsValid(WS)); diff --git a/src/unigram_model.cc b/src/unigram_model.cc index 3b99060..f87edb2 100644 --- a/src/unigram_model.cc +++ b/src/unigram_model.cc @@ -22,6 +22,7 @@ #include #include +#include "third_party/absl/container/flat_hash_map.h" #include "third_party/absl/memory/memory.h" #include "third_party/absl/strings/str_split.h" #include "third_party/absl/strings/string_view.h" @@ -289,6 +290,58 @@ float Lattice::CalculateEntropy(float theta) const { return -H[begin_nodes_[len][0]->node_id]; } +namespace { + +// The node structure to support A* algorithm in Lattice::NBest() +struct Hypothesis { + Lattice::Node *node; + Hypothesis *next; + float fx; // the priority to pop a new hypothesis from the priority queue. + float gx; // the sum of scores from EOS to the left-most node in x. +}; + +// Helper function for cloning a Hypothesis and the ones on their next paths. +// The graph structure is preserved. +// +// to_clone: the Hypothesis to clone. +// clone_map: mapping between the old pointers and the new pointers. +// allocator: allocate and own the cloned Hypothesis. +// +// Returns the cloned Hypothesis*. All Hypothesis on its "next" chain are also +// guaranteed to have been allocated via "allocator", and "clone_map" is updated +// with all new mappings. +Hypothesis *CloneHypAndDependents( + const Hypothesis *to_clone, + absl::flat_hash_map *clone_map, + model::FreeList *allocator) { + Hypothesis *cloned = nullptr; + Hypothesis **result_callback = &cloned; + + // Iteratively clone "to_clone" and its dependencies. + // The new pointer will be written back to *result_callback. + while (to_clone != nullptr) { + // If "to_clone" has already been cloned before, we just look up the result. + auto iter = clone_map->find(to_clone); + if (iter != clone_map->end()) { + *result_callback = iter->second; + break; + } + + // Allocate a new Hypothesis and copy the values. + Hypothesis *new_hyp = allocator->Allocate(); + *new_hyp = *to_clone; + *result_callback = new_hyp; + clone_map->insert({to_clone, new_hyp}); + + // Move on to clone "to_clone->next". + to_clone = to_clone->next; + result_callback = &(new_hyp->next); + } + return cloned; +} + +} // namespace + std::vector Lattice::NBest(size_t nbest_size, bool sample, float theta) { @@ -312,12 +365,6 @@ std::vector Lattice::NBest(size_t nbest_size, // // As left-to-right Viterbi search can tell the *exact* value of h(x), // we can obtain the exact n-best results with A*. - struct Hypothesis { - Node *node; - Hypothesis *next; - float fx; - float gx; - }; class HypothesisComparator { public: @@ -353,6 +400,8 @@ std::vector Lattice::NBest(size_t nbest_size, } agenda.push(eos); + int shrink_count = 0; // Number of times agenda has shrunk. For logging only. + bool printed_memory_warning = false; // For logging only. while (!agenda.empty()) { auto *top = agenda.top(); agenda.pop(); @@ -416,21 +465,42 @@ std::vector Lattice::NBest(size_t nbest_size, agenda.push(hyp); } + static constexpr int kOneBillion = 1000000000; // 10^9. + if (hypothesis_allocator.size() >= kOneBillion) { + if (!printed_memory_warning) { + printed_memory_warning = true; + LOG(WARNING) << "Allocator size exceeds " << kOneBillion + << " with an example of length " << this->size(); + } + } + // When the input is too long or contains duplicated phrases, // `agenda` will get extremely big. Here we avoid this case by // dynamically shrinking the agenda. - constexpr int kMaxAgendaSize = 100000; + constexpr int kMaxAgendaSize = 10000; constexpr int kMinAgendaSize = 512; if (agenda.size() >= kMaxAgendaSize) { - LOG(WARNING) << "Too big agenda. shrinking"; // Keeps the top `kMinAgendaSize` hypothesis. Agenda new_agenda; + // Keeps the top hypothesis and the ones on their "next" paths. + model::FreeList new_allocator(kPreallocatedHypothesisSize); + // Map between old Hypothesis* and new Hypothesis*. + absl::flat_hash_map clone_map; + const int size = std::min(kMinAgendaSize, nbest_size * 10); + shrink_count++; + LOG(WARNING) << "Too big agenda size " << agenda.size() + << ". Shrinking (round " << shrink_count << ") down to " + << size << "."; for (int i = 0; i < size; ++i) { - new_agenda.push(agenda.top()); + const Hypothesis *top_hyp = agenda.top(); + Hypothesis *cloned_hyp = + CloneHypAndDependents(top_hyp, &clone_map, &new_allocator); + new_agenda.push(cloned_hyp); agenda.pop(); } agenda = std::move(new_agenda); + hypothesis_allocator.swap(new_allocator); } } diff --git a/src/unigram_model_test.cc b/src/unigram_model_test.cc index f93b21c..8049d20 100644 --- a/src/unigram_model_test.cc +++ b/src/unigram_model_test.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License.! -#include "unigram_model.h" - #include #include #include @@ -24,6 +22,7 @@ #include "testharness.h" #include "third_party/absl/strings/str_cat.h" #include "third_party/absl/strings/str_join.h" +#include "unigram_model.h" #include "util.h" namespace sentencepiece { @@ -268,8 +267,8 @@ TEST(LatticeTest, NBestSampleTest) { for (auto &it : probs) it.second /= Z; std::map, float> pair_probs; - for (const auto first : strings) { - for (const auto second : strings) { + for (const auto &first : strings) { + for (const auto &second : strings) { if (first == second) { pair_probs[std::make_pair(first, second)] = 0; } else { @@ -281,12 +280,12 @@ TEST(LatticeTest, NBestSampleTest) { } std::map inclusion_probs; - for (const auto string : strings) { + for (const auto &string : strings) { float inclusion_prob = 0.0; - for (const auto other_string : strings) { + for (const auto &other_string : strings) { inclusion_prob += pair_probs[std::make_pair(string, other_string)]; } - for (const auto other_string : strings) { + for (const auto &other_string : strings) { inclusion_prob += pair_probs[std::make_pair(other_string, string)]; } inclusion_probs[string] = inclusion_prob / 2; @@ -300,7 +299,7 @@ TEST(LatticeTest, NBestSampleTest) { std::map counts; for (int i = 0; i < kTrials; i++) { auto nbests = lattice.NBest(num_samples, true, theta); - for (const auto nbest : nbests) { + for (const auto &nbest : nbests) { counts[GetTokenized(nbest.first)]++; } } @@ -550,25 +549,25 @@ TEST(UnigramModelTest, SampleEncodeAndScoreTest) { for (auto &it : probs) it.second /= Z; std::map, float> pair_probs; - for (const auto first : strings) { - for (const auto second : strings) { + for (const auto &first : strings) { + for (const auto &second : strings) { if (first == second) { pair_probs[std::make_pair(first, second)] = 0; } else { - float first_prob = probs[first]; - float second_prob = probs[second] / (1 - first_prob); + const float first_prob = probs[first]; + const float second_prob = probs[second] / (1 - first_prob); pair_probs[std::make_pair(first, second)] = first_prob * second_prob; } } } std::map inclusion_probs; - for (const auto string : strings) { + for (const auto &string : strings) { float inclusion_prob = 0.0; - for (const auto other_string : strings) { + for (const auto &other_string : strings) { inclusion_prob += pair_probs[std::make_pair(string, other_string)]; } - for (const auto other_string : strings) { + for (const auto &other_string : strings) { inclusion_prob += pair_probs[std::make_pair(other_string, string)]; } inclusion_probs[string] = inclusion_prob / 2; diff --git a/src/unigram_model_trainer.h b/src/unigram_model_trainer.h index 91fbeb4..5079a39 100644 --- a/src/unigram_model_trainer.h +++ b/src/unigram_model_trainer.h @@ -70,9 +70,6 @@ class Trainer : public TrainerInterface { util::Status Train() override; - private: - FRIEND_TEST(TrainerTest, IsValidSentencePieceTest); - // Makes seed pieces from the training corpus. // The size of seed pieces is determined by seed_sentencepiece_size. // node_int_type should be of integer type (int32 or int64), @@ -80,6 +77,9 @@ class Trainer : public TrainerInterface { template TrainerModel::SentencePieces MakeSeedSentencePieces() const; + private: + FRIEND_TEST(TrainerTest, IsValidSentencePieceTest); + // Executes the E step of EM and returns expected count. // The index of return array is the vocab id. // |objective| is a negative likelihood of the current model. diff --git a/src/unigram_model_trainer_test.cc b/src/unigram_model_trainer_test.cc index ffe515e..75c5fa1 100644 --- a/src/unigram_model_trainer_test.cc +++ b/src/unigram_model_trainer_test.cc @@ -12,6 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License.! +#include +#include + +#include "filesystem.h" #include "sentencepiece_model.pb.h" #include "sentencepiece_processor.h" #include "sentencepiece_trainer.h" @@ -23,7 +27,6 @@ namespace sentencepiece { namespace unigram { -namespace { // Space symbol #define WS "\xe2\x96\x81" @@ -35,6 +38,116 @@ TEST(UnigramTrainerTest, TrainerModelTest) { EXPECT_EQ(EncodeResult(), model.Encode("test")); } +struct TrainerResult { + std::string sentence_pieces; + std::vector> seed_pieces_and_probs; +}; + +TrainerResult RunTrainer(const std::vector& input, int size, + const bool use_dp = false, const float dp_noise = 0.0, + const uint32 dp_clip = 0) { + const std::string input_file = + util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "input"); + const std::string model_prefix = + util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "model"); + { + auto output = filesystem::NewWritableFile(input_file); + for (const auto& line : input) { + output->WriteLine(line); + } + } + + TrainerSpec trainer_spec; + trainer_spec.set_input_format("tsv"); + trainer_spec.set_model_type(TrainerSpec::UNIGRAM); + trainer_spec.add_input(input_file); + trainer_spec.set_vocab_size(size - 3); // remove , , + trainer_spec.set_model_prefix(model_prefix); + + trainer_spec.set_enable_differential_privacy(use_dp); + trainer_spec.set_differential_privacy_noise_level(dp_noise); + trainer_spec.set_differential_privacy_clipping_threshold(dp_clip); + + NormalizerSpec normalizer_spec; + normalizer_spec.set_name("identity"); + normalizer_spec.set_add_dummy_prefix(false); + + NormalizerSpec denormalizer_spec; + + std::vector> seed_pieces; + + { + Trainer trainer(trainer_spec, normalizer_spec, denormalizer_spec); + EXPECT_OK(trainer.LoadSentences()); + TrainerModel::SentencePieces res = trainer.MakeSeedSentencePieces(); + + for (const auto& piece : res) { + seed_pieces.emplace_back(piece.first, piece.second); + } + } + + std::vector pieces; + + { + Trainer trainer(trainer_spec, normalizer_spec, denormalizer_spec); + EXPECT_TRUE(trainer.Train().ok()); + + SentencePieceProcessor processor; + EXPECT_TRUE(processor.Load(model_prefix + ".model").ok()); + + const auto& model = processor.model_proto(); + + // remove , , + for (int i = 3; i < model.pieces_size(); ++i) { + pieces.emplace_back(model.pieces(i).piece()); + } + } + + TrainerResult res; + res.seed_pieces_and_probs = seed_pieces; + res.sentence_pieces = absl::StrJoin(pieces, " "); + return res; +} + +TEST(UnigramTrainerTest, BasicTest) { + const auto& res = RunTrainer( + {"magnanimity \t 5", "Pineapple \t 6", "i have an apple and a pen \t 1", + "Overly \t 6", "Available \t 3"}, + 30); + + // Check seed pieces. + EXPECT_EQ(27, res.seed_pieces_and_probs.size()); + + // Check final pieces. + EXPECT_EQ("i a n y m l e apple ve O P r t g an v ▁ A b le ▁an p d h", + res.sentence_pieces); +} + +TEST(UnigramTrainerTest, BasicDPTest) { + // no noise, clipping. + { + const auto& res = RunTrainer( + {"magnanimity \t 5", "Pineapple \t 6", "i have an apple and a pen \t 1", + "Overly \t 6", "Available \t 5"}, + 22, true /*use_dp*/, 0 /*dp_noise*/, 4 /*dp_clipping*/); + + // Got 16 instead of 27 seeds. + EXPECT_EQ(16, res.seed_pieces_and_probs.size()); + + // And they are equiv to if the last sentence was not there. + const auto& res_nodp = RunTrainer( + {"magnanimity \t 5", "Pineapple \t 6", "Overly \t 6", "Available \t 5"}, + 22); + + EXPECT_EQ(res.seed_pieces_and_probs, res_nodp.seed_pieces_and_probs); + + // Check final pieces. + EXPECT_EQ(res.sentence_pieces, res_nodp.sentence_pieces); + } +} + +namespace { + static constexpr char kTestInputData[] = "wagahaiwa_nekodearu.txt"; TEST(UnigramTrainerTest, EndToEndTest) { diff --git a/third_party/absl/flags/flag.cc b/third_party/absl/flags/flag.cc index e7ac841..8e99c0d 100644 --- a/third_party/absl/flags/flag.cc +++ b/third_party/absl/flags/flag.cc @@ -173,6 +173,7 @@ template class Flag; template class Flag; template class Flag; template class Flag; +template class Flag; template class Flag; template class Flag; template class Flag;