Sync internal changes to GitHub. Differential privacy (DP) related features are added.

This commit is contained in:
Taku Kudo 2022-05-25 14:03:45 +09:00
parent 9fe52f1aa6
commit 9f3ed99f5c
24 changed files with 898 additions and 297 deletions

View File

@ -1 +1 @@
0.1.96
0.1.97

View File

@ -1 +1 @@
0.1.96
0.1.97

View File

@ -15,6 +15,8 @@
#ifndef BPE_MODEL_TRAINER_H_
#define BPE_MODEL_TRAINER_H_
#include <cstdint>
#include <limits>
#include <set>
#include <string>
#include <vector>
@ -44,12 +46,12 @@ class Trainer : public TrainerInterface {
const Symbol *right; // right symbol in bigram
string_util::UnicodeText chars; // all flattened character sequence
bool is_unk; // true if this symbol is unknown.
uint64 fp; // fingerprint of this symbol.
uint64 freq; // frequency of this symbol.
uint64_t fp; // fingerprint of this symbol.
uint64_t freq; // frequency of this symbol.
// Position list. Use set so that we can keep the order of occurrence.
// See EncodePos/DecodePos.
std::set<uint64> positions;
std::set<uint64_t> positions;
bool IsBigram() const { return left != nullptr && right != nullptr; }
std::string ToString() const;
@ -62,19 +64,19 @@ class Trainer : public TrainerInterface {
int right; // right symbol index
};
// Encodes sid, left and right bigram index into uint64.
// Encodes sid, left and right bigram index into uint64_t.
// Encoded value keeps the order of sid, left and right.
static uint64 EncodePos(int sid, int l, int r) {
static uint64_t EncodePos(int sid, int l, int r) {
CHECK_GE(l, 0);
CHECK_GE(r, 0);
CHECK_LE(l, kuint16max);
CHECK_LE(r, kuint16max);
const uint64 n = (static_cast<uint64>(sid) << 32 | (l << 16 | r));
CHECK_LE(l, std::numeric_limits<uint16_t>::max());
CHECK_LE(r, std::numeric_limits<uint16_t>::max());
const uint64_t n = (static_cast<uint64_t>(sid) << 32 | (l << 16 | r));
return n;
}
// Decodes sid, left and right bigram index from uint64.
static Position DecodePos(uint64 n) {
// Decodes sid, left and right bigram index from uint64_t.
static Position DecodePos(uint64_t n) {
Position p;
p.sid = n >> 32;
p.left = (n >> 16) & 0xffff;
@ -111,7 +113,7 @@ class Trainer : public TrainerInterface {
void UpdateActiveSymbols();
// All unique symbols. Key is a fingerprint of Symbol.
absl::flat_hash_map<uint64, Symbol *> symbols_cache_;
absl::flat_hash_map<uint64_t, Symbol *> symbols_cache_;
// Set of symbols from which we find the best symbol in each iteration.
std::set<Symbol *> active_symbols_;

View File

@ -367,7 +367,7 @@ util::Status Builder::BuildNmtNFKCMap(CharsMap *chars_map) {
nfkc_map[{0xFEFF}] = {0x20}; // ZERO WIDTH NO-BREAK
nfkc_map[{0xFFFD}] = {0x20}; // REPLACEMENT CHARACTER
nfkc_map[{0x200C}] = {0x20}; // ZERO WIDTH NON-JOINER
// nfkc_map[{0x200D}] = {0x20}; // ZERO WIDTH JOINER
// nfkc_map[{0x200D}] = {0x20}; // ZERO WIDTH JOINER
// Ascii Control characters
nfkc_map[{0x0001}] = {};

View File

@ -285,22 +285,31 @@ class TrainerSpec::_Internal {
(*has_bits)[0] |= 1u;
}
static void set_has_model_type(HasBits* has_bits) {
(*has_bits)[0] |= 524288u;
(*has_bits)[0] |= 4194304u;
}
static void set_has_vocab_size(HasBits* has_bits) {
(*has_bits)[0] |= 1048576u;
(*has_bits)[0] |= 8388608u;
}
static void set_has_self_test_sample_size(HasBits* has_bits) {
(*has_bits)[0] |= 256u;
}
static void set_has_character_coverage(HasBits* has_bits) {
static void set_has_enable_differential_privacy(HasBits* has_bits) {
(*has_bits)[0] |= 4096u;
}
static void set_has_differential_privacy_noise_level(HasBits* has_bits) {
(*has_bits)[0] |= 1048576u;
}
static void set_has_differential_privacy_clipping_threshold(HasBits* has_bits) {
(*has_bits)[0] |= 2097152u;
}
static void set_has_character_coverage(HasBits* has_bits) {
(*has_bits)[0] |= 16777216u;
}
static void set_has_input_sentence_size(HasBits* has_bits) {
(*has_bits)[0] |= 1024u;
}
static void set_has_shuffle_input_sentence(HasBits* has_bits) {
(*has_bits)[0] |= 268435456u;
(*has_bits)[0] |= 2147483648u;
}
static void set_has_mining_sentence_size(HasBits* has_bits) {
(*has_bits)[0] |= 512u;
@ -309,67 +318,67 @@ class TrainerSpec::_Internal {
(*has_bits)[0] |= 2048u;
}
static void set_has_seed_sentencepiece_size(HasBits* has_bits) {
(*has_bits)[0] |= 4194304u;
}
static void set_has_shrinking_factor(HasBits* has_bits) {
(*has_bits)[0] |= 8388608u;
}
static void set_has_max_sentence_length(HasBits* has_bits) {
(*has_bits)[0] |= 67108864u;
}
static void set_has_num_threads(HasBits* has_bits) {
(*has_bits)[0] |= 16777216u;
}
static void set_has_num_sub_iterations(HasBits* has_bits) {
(*has_bits)[0] |= 33554432u;
}
static void set_has_max_sentencepiece_length(HasBits* has_bits) {
(*has_bits)[0] |= 134217728u;
static void set_has_shrinking_factor(HasBits* has_bits) {
(*has_bits)[0] |= 67108864u;
}
static void set_has_split_by_unicode_script(HasBits* has_bits) {
static void set_has_max_sentence_length(HasBits* has_bits) {
(*has_bits)[0] |= 536870912u;
}
static void set_has_split_by_number(HasBits* has_bits) {
static void set_has_num_threads(HasBits* has_bits) {
(*has_bits)[0] |= 134217728u;
}
static void set_has_num_sub_iterations(HasBits* has_bits) {
(*has_bits)[0] |= 268435456u;
}
static void set_has_max_sentencepiece_length(HasBits* has_bits) {
(*has_bits)[0] |= 1073741824u;
}
static void set_has_split_by_unicode_script(HasBits* has_bits) {
(*has_bits)[1] |= 1u;
}
static void set_has_split_by_number(HasBits* has_bits) {
(*has_bits)[1] |= 2u;
}
static void set_has_split_by_whitespace(HasBits* has_bits) {
(*has_bits)[0] |= 2147483648u;
(*has_bits)[1] |= 4u;
}
static void set_has_treat_whitespace_as_suffix(HasBits* has_bits) {
(*has_bits)[0] |= 4096u;
}
static void set_has_allow_whitespace_only_pieces(HasBits* has_bits) {
(*has_bits)[0] |= 8192u;
}
static void set_has_split_digits(HasBits* has_bits) {
static void set_has_allow_whitespace_only_pieces(HasBits* has_bits) {
(*has_bits)[0] |= 16384u;
}
static void set_has_split_digits(HasBits* has_bits) {
(*has_bits)[0] |= 32768u;
}
static void set_has_required_chars(HasBits* has_bits) {
(*has_bits)[0] |= 4u;
}
static void set_has_byte_fallback(HasBits* has_bits) {
(*has_bits)[0] |= 32768u;
(*has_bits)[0] |= 65536u;
}
static void set_has_vocabulary_output_piece_score(HasBits* has_bits) {
(*has_bits)[1] |= 1u;
(*has_bits)[1] |= 8u;
}
static void set_has_hard_vocab_limit(HasBits* has_bits) {
(*has_bits)[1] |= 2u;
(*has_bits)[1] |= 16u;
}
static void set_has_use_all_vocab(HasBits* has_bits) {
(*has_bits)[0] |= 131072u;
}
static void set_has_unk_id(HasBits* has_bits) {
(*has_bits)[0] |= 65536u;
(*has_bits)[0] |= 524288u;
}
static void set_has_bos_id(HasBits* has_bits) {
(*has_bits)[1] |= 4u;
(*has_bits)[1] |= 32u;
}
static void set_has_eos_id(HasBits* has_bits) {
(*has_bits)[1] |= 8u;
(*has_bits)[1] |= 64u;
}
static void set_has_pad_id(HasBits* has_bits) {
(*has_bits)[1] |= 16u;
(*has_bits)[1] |= 128u;
}
static void set_has_unk_piece(HasBits* has_bits) {
(*has_bits)[0] |= 16u;
@ -474,8 +483,8 @@ void TrainerSpec::SharedCtor() {
pad_piece_.UnsafeSetDefault(nullptr);
::memset(reinterpret_cast<char*>(this) + static_cast<size_t>(
reinterpret_cast<char*>(&self_test_sample_size_) - reinterpret_cast<char*>(this)),
0, static_cast<size_t>(reinterpret_cast<char*>(&train_extremely_large_corpus_) -
reinterpret_cast<char*>(&self_test_sample_size_)) + sizeof(train_extremely_large_corpus_));
0, static_cast<size_t>(reinterpret_cast<char*>(&differential_privacy_clipping_threshold_) -
reinterpret_cast<char*>(&self_test_sample_size_)) + sizeof(differential_privacy_clipping_threshold_));
model_type_ = 1;
vocab_size_ = 8000;
character_coverage_ = 0.9995f;
@ -569,31 +578,31 @@ void TrainerSpec::Clear() {
}
if (cached_has_bits & 0x0000ff00u) {
::memset(&self_test_sample_size_, 0, static_cast<size_t>(
reinterpret_cast<char*>(&byte_fallback_) -
reinterpret_cast<char*>(&self_test_sample_size_)) + sizeof(byte_fallback_));
reinterpret_cast<char*>(&split_digits_) -
reinterpret_cast<char*>(&self_test_sample_size_)) + sizeof(split_digits_));
}
if (cached_has_bits & 0x00ff0000u) {
::memset(&unk_id_, 0, static_cast<size_t>(
reinterpret_cast<char*>(&train_extremely_large_corpus_) -
reinterpret_cast<char*>(&unk_id_)) + sizeof(train_extremely_large_corpus_));
::memset(&byte_fallback_, 0, static_cast<size_t>(
reinterpret_cast<char*>(&differential_privacy_clipping_threshold_) -
reinterpret_cast<char*>(&byte_fallback_)) + sizeof(differential_privacy_clipping_threshold_));
model_type_ = 1;
vocab_size_ = 8000;
}
if (cached_has_bits & 0xff000000u) {
character_coverage_ = 0.9995f;
seed_sentencepiece_size_ = 1000000;
shrinking_factor_ = 0.75f;
}
if (cached_has_bits & 0xff000000u) {
num_threads_ = 16;
num_sub_iterations_ = 2;
max_sentence_length_ = 4192;
max_sentencepiece_length_ = 16;
shuffle_input_sentence_ = true;
}
cached_has_bits = _has_bits_[1];
if (cached_has_bits & 0x000000ffu) {
split_by_unicode_script_ = true;
split_by_number_ = true;
split_by_whitespace_ = true;
}
cached_has_bits = _has_bits_[1];
if (cached_has_bits & 0x0000001fu) {
vocabulary_output_piece_score_ = true;
hard_vocab_limit_ = true;
bos_id_ = 1;
@ -963,6 +972,30 @@ const char* TrainerSpec::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID
CHK_(ptr);
} else goto handle_unusual;
continue;
// optional bool enable_differential_privacy = 50 [default = false];
case 50:
if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 144)) {
_Internal::set_has_enable_differential_privacy(&_has_bits_);
enable_differential_privacy_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr);
CHK_(ptr);
} else goto handle_unusual;
continue;
// optional float differential_privacy_noise_level = 51 [default = 0];
case 51:
if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 157)) {
_Internal::set_has_differential_privacy_noise_level(&_has_bits_);
differential_privacy_noise_level_ = ::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad<float>(ptr);
ptr += sizeof(float);
} else goto handle_unusual;
continue;
// optional uint64 differential_privacy_clipping_threshold = 52 [default = 0];
case 52:
if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 160)) {
_Internal::set_has_differential_privacy_clipping_threshold(&_has_bits_);
differential_privacy_clipping_threshold_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr);
CHK_(ptr);
} else goto handle_unusual;
continue;
default: {
handle_unusual:
if ((tag & 7) == 4 || tag == 0) {
@ -1011,14 +1044,14 @@ failure:
}
// optional .sentencepiece.TrainerSpec.ModelType model_type = 3 [default = UNIGRAM];
if (cached_has_bits & 0x00080000u) {
if (cached_has_bits & 0x00400000u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteEnumToArray(
3, this->_internal_model_type(), target);
}
// optional int32 vocab_size = 4 [default = 8000];
if (cached_has_bits & 0x00100000u) {
if (cached_has_bits & 0x00800000u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(4, this->_internal_vocab_size(), target);
}
@ -1042,7 +1075,7 @@ failure:
}
// optional float character_coverage = 10 [default = 0.9995];
if (cached_has_bits & 0x00200000u) {
if (cached_has_bits & 0x01000000u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteFloatToArray(10, this->_internal_character_coverage(), target);
}
@ -1066,79 +1099,81 @@ failure:
}
// optional int32 seed_sentencepiece_size = 14 [default = 1000000];
if (cached_has_bits & 0x00400000u) {
if (cached_has_bits & 0x02000000u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(14, this->_internal_seed_sentencepiece_size(), target);
}
// optional float shrinking_factor = 15 [default = 0.75];
if (cached_has_bits & 0x00800000u) {
if (cached_has_bits & 0x04000000u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteFloatToArray(15, this->_internal_shrinking_factor(), target);
}
// optional int32 num_threads = 16 [default = 16];
if (cached_has_bits & 0x01000000u) {
if (cached_has_bits & 0x08000000u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(16, this->_internal_num_threads(), target);
}
// optional int32 num_sub_iterations = 17 [default = 2];
if (cached_has_bits & 0x02000000u) {
if (cached_has_bits & 0x10000000u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(17, this->_internal_num_sub_iterations(), target);
}
// optional int32 max_sentence_length = 18 [default = 4192];
if (cached_has_bits & 0x04000000u) {
if (cached_has_bits & 0x20000000u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(18, this->_internal_max_sentence_length(), target);
}
// optional bool shuffle_input_sentence = 19 [default = true];
if (cached_has_bits & 0x10000000u) {
if (cached_has_bits & 0x80000000u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(19, this->_internal_shuffle_input_sentence(), target);
}
// optional int32 max_sentencepiece_length = 20 [default = 16];
if (cached_has_bits & 0x08000000u) {
if (cached_has_bits & 0x40000000u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(20, this->_internal_max_sentencepiece_length(), target);
}
cached_has_bits = _has_bits_[1];
// optional bool split_by_unicode_script = 21 [default = true];
if (cached_has_bits & 0x20000000u) {
if (cached_has_bits & 0x00000001u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(21, this->_internal_split_by_unicode_script(), target);
}
// optional bool split_by_whitespace = 22 [default = true];
if (cached_has_bits & 0x80000000u) {
if (cached_has_bits & 0x00000004u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(22, this->_internal_split_by_whitespace(), target);
}
// optional bool split_by_number = 23 [default = true];
if (cached_has_bits & 0x40000000u) {
if (cached_has_bits & 0x00000002u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(23, this->_internal_split_by_number(), target);
}
cached_has_bits = _has_bits_[0];
// optional bool treat_whitespace_as_suffix = 24 [default = false];
if (cached_has_bits & 0x00001000u) {
if (cached_has_bits & 0x00002000u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(24, this->_internal_treat_whitespace_as_suffix(), target);
}
// optional bool split_digits = 25 [default = false];
if (cached_has_bits & 0x00004000u) {
if (cached_has_bits & 0x00008000u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(25, this->_internal_split_digits(), target);
}
// optional bool allow_whitespace_only_pieces = 26 [default = false];
if (cached_has_bits & 0x00002000u) {
if (cached_has_bits & 0x00004000u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(26, this->_internal_allow_whitespace_only_pieces(), target);
}
@ -1157,13 +1192,13 @@ failure:
cached_has_bits = _has_bits_[1];
// optional bool vocabulary_output_piece_score = 32 [default = true];
if (cached_has_bits & 0x00000001u) {
if (cached_has_bits & 0x00000008u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(32, this->_internal_vocabulary_output_piece_score(), target);
}
// optional bool hard_vocab_limit = 33 [default = true];
if (cached_has_bits & 0x00000002u) {
if (cached_has_bits & 0x00000010u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(33, this->_internal_hard_vocab_limit(), target);
}
@ -1176,7 +1211,7 @@ failure:
}
// optional bool byte_fallback = 35 [default = false];
if (cached_has_bits & 0x00008000u) {
if (cached_has_bits & 0x00010000u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(35, this->_internal_byte_fallback(), target);
}
@ -1188,26 +1223,26 @@ failure:
}
// optional int32 unk_id = 40 [default = 0];
if (cached_has_bits & 0x00010000u) {
if (cached_has_bits & 0x00080000u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(40, this->_internal_unk_id(), target);
}
cached_has_bits = _has_bits_[1];
// optional int32 bos_id = 41 [default = 1];
if (cached_has_bits & 0x00000004u) {
if (cached_has_bits & 0x00000020u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(41, this->_internal_bos_id(), target);
}
// optional int32 eos_id = 42 [default = 2];
if (cached_has_bits & 0x00000008u) {
if (cached_has_bits & 0x00000040u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(42, this->_internal_eos_id(), target);
}
// optional int32 pad_id = 43 [default = -1];
if (cached_has_bits & 0x00000010u) {
if (cached_has_bits & 0x00000080u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(43, this->_internal_pad_id(), target);
}
@ -1249,6 +1284,24 @@ failure:
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(49, this->_internal_train_extremely_large_corpus(), target);
}
// optional bool enable_differential_privacy = 50 [default = false];
if (cached_has_bits & 0x00001000u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(50, this->_internal_enable_differential_privacy(), target);
}
// optional float differential_privacy_noise_level = 51 [default = 0];
if (cached_has_bits & 0x00100000u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteFloatToArray(51, this->_internal_differential_privacy_noise_level(), target);
}
// optional uint64 differential_privacy_clipping_threshold = 52 [default = 0];
if (cached_has_bits & 0x00200000u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteUInt64ToArray(52, this->_internal_differential_privacy_clipping_threshold(), target);
}
// Extension range [200, 536870912)
target = _extensions_._InternalSerialize(
200, 536870912, target, stream);
@ -1391,33 +1444,31 @@ size_t TrainerSpec::ByteSizeLong() const {
this->_internal_training_sentence_size());
}
// optional bool treat_whitespace_as_suffix = 24 [default = false];
// optional bool enable_differential_privacy = 50 [default = false];
if (cached_has_bits & 0x00001000u) {
total_size += 2 + 1;
}
// optional bool allow_whitespace_only_pieces = 26 [default = false];
// optional bool treat_whitespace_as_suffix = 24 [default = false];
if (cached_has_bits & 0x00002000u) {
total_size += 2 + 1;
}
// optional bool split_digits = 25 [default = false];
// optional bool allow_whitespace_only_pieces = 26 [default = false];
if (cached_has_bits & 0x00004000u) {
total_size += 2 + 1;
}
// optional bool byte_fallback = 35 [default = false];
// optional bool split_digits = 25 [default = false];
if (cached_has_bits & 0x00008000u) {
total_size += 2 + 1;
}
}
if (cached_has_bits & 0x00ff0000u) {
// optional int32 unk_id = 40 [default = 0];
// optional bool byte_fallback = 35 [default = false];
if (cached_has_bits & 0x00010000u) {
total_size += 2 +
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
this->_internal_unk_id());
total_size += 2 + 1;
}
// optional bool use_all_vocab = 34 [default = false];
@ -1430,115 +1481,134 @@ size_t TrainerSpec::ByteSizeLong() const {
total_size += 2 + 1;
}
// optional .sentencepiece.TrainerSpec.ModelType model_type = 3 [default = UNIGRAM];
// optional int32 unk_id = 40 [default = 0];
if (cached_has_bits & 0x00080000u) {
total_size += 2 +
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
this->_internal_unk_id());
}
// optional float differential_privacy_noise_level = 51 [default = 0];
if (cached_has_bits & 0x00100000u) {
total_size += 2 + 4;
}
// optional uint64 differential_privacy_clipping_threshold = 52 [default = 0];
if (cached_has_bits & 0x00200000u) {
total_size += 2 +
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::UInt64Size(
this->_internal_differential_privacy_clipping_threshold());
}
// optional .sentencepiece.TrainerSpec.ModelType model_type = 3 [default = UNIGRAM];
if (cached_has_bits & 0x00400000u) {
total_size += 1 +
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::EnumSize(this->_internal_model_type());
}
// optional int32 vocab_size = 4 [default = 8000];
if (cached_has_bits & 0x00100000u) {
if (cached_has_bits & 0x00800000u) {
total_size += 1 +
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
this->_internal_vocab_size());
}
}
if (cached_has_bits & 0xff000000u) {
// optional float character_coverage = 10 [default = 0.9995];
if (cached_has_bits & 0x00200000u) {
if (cached_has_bits & 0x01000000u) {
total_size += 1 + 4;
}
// optional int32 seed_sentencepiece_size = 14 [default = 1000000];
if (cached_has_bits & 0x00400000u) {
if (cached_has_bits & 0x02000000u) {
total_size += 1 +
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
this->_internal_seed_sentencepiece_size());
}
// optional float shrinking_factor = 15 [default = 0.75];
if (cached_has_bits & 0x00800000u) {
if (cached_has_bits & 0x04000000u) {
total_size += 1 + 4;
}
}
if (cached_has_bits & 0xff000000u) {
// optional int32 num_threads = 16 [default = 16];
if (cached_has_bits & 0x01000000u) {
if (cached_has_bits & 0x08000000u) {
total_size += 2 +
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
this->_internal_num_threads());
}
// optional int32 num_sub_iterations = 17 [default = 2];
if (cached_has_bits & 0x02000000u) {
if (cached_has_bits & 0x10000000u) {
total_size += 2 +
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
this->_internal_num_sub_iterations());
}
// optional int32 max_sentence_length = 18 [default = 4192];
if (cached_has_bits & 0x04000000u) {
if (cached_has_bits & 0x20000000u) {
total_size += 2 +
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
this->_internal_max_sentence_length());
}
// optional int32 max_sentencepiece_length = 20 [default = 16];
if (cached_has_bits & 0x08000000u) {
if (cached_has_bits & 0x40000000u) {
total_size += 2 +
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
this->_internal_max_sentencepiece_length());
}
// optional bool shuffle_input_sentence = 19 [default = true];
if (cached_has_bits & 0x10000000u) {
total_size += 2 + 1;
}
// optional bool split_by_unicode_script = 21 [default = true];
if (cached_has_bits & 0x20000000u) {
total_size += 2 + 1;
}
// optional bool split_by_number = 23 [default = true];
if (cached_has_bits & 0x40000000u) {
total_size += 2 + 1;
}
// optional bool split_by_whitespace = 22 [default = true];
if (cached_has_bits & 0x80000000u) {
total_size += 2 + 1;
}
}
cached_has_bits = _has_bits_[1];
if (cached_has_bits & 0x0000001fu) {
// optional bool vocabulary_output_piece_score = 32 [default = true];
if (cached_has_bits & 0x000000ffu) {
// optional bool split_by_unicode_script = 21 [default = true];
if (cached_has_bits & 0x00000001u) {
total_size += 2 + 1;
}
// optional bool hard_vocab_limit = 33 [default = true];
// optional bool split_by_number = 23 [default = true];
if (cached_has_bits & 0x00000002u) {
total_size += 2 + 1;
}
// optional int32 bos_id = 41 [default = 1];
// optional bool split_by_whitespace = 22 [default = true];
if (cached_has_bits & 0x00000004u) {
total_size += 2 + 1;
}
// optional bool vocabulary_output_piece_score = 32 [default = true];
if (cached_has_bits & 0x00000008u) {
total_size += 2 + 1;
}
// optional bool hard_vocab_limit = 33 [default = true];
if (cached_has_bits & 0x00000010u) {
total_size += 2 + 1;
}
// optional int32 bos_id = 41 [default = 1];
if (cached_has_bits & 0x00000020u) {
total_size += 2 +
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
this->_internal_bos_id());
}
// optional int32 eos_id = 42 [default = 2];
if (cached_has_bits & 0x00000008u) {
if (cached_has_bits & 0x00000040u) {
total_size += 2 +
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
this->_internal_eos_id());
}
// optional int32 pad_id = 43 [default = -1];
if (cached_has_bits & 0x00000010u) {
if (cached_has_bits & 0x00000080u) {
total_size += 2 +
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
this->_internal_pad_id());
@ -1612,22 +1682,22 @@ void TrainerSpec::MergeFrom(const TrainerSpec& from) {
training_sentence_size_ = from.training_sentence_size_;
}
if (cached_has_bits & 0x00001000u) {
treat_whitespace_as_suffix_ = from.treat_whitespace_as_suffix_;
enable_differential_privacy_ = from.enable_differential_privacy_;
}
if (cached_has_bits & 0x00002000u) {
allow_whitespace_only_pieces_ = from.allow_whitespace_only_pieces_;
treat_whitespace_as_suffix_ = from.treat_whitespace_as_suffix_;
}
if (cached_has_bits & 0x00004000u) {
split_digits_ = from.split_digits_;
allow_whitespace_only_pieces_ = from.allow_whitespace_only_pieces_;
}
if (cached_has_bits & 0x00008000u) {
byte_fallback_ = from.byte_fallback_;
split_digits_ = from.split_digits_;
}
_has_bits_[0] |= cached_has_bits;
}
if (cached_has_bits & 0x00ff0000u) {
if (cached_has_bits & 0x00010000u) {
unk_id_ = from.unk_id_;
byte_fallback_ = from.byte_fallback_;
}
if (cached_has_bits & 0x00020000u) {
use_all_vocab_ = from.use_all_vocab_;
@ -1636,64 +1706,73 @@ void TrainerSpec::MergeFrom(const TrainerSpec& from) {
train_extremely_large_corpus_ = from.train_extremely_large_corpus_;
}
if (cached_has_bits & 0x00080000u) {
model_type_ = from.model_type_;
unk_id_ = from.unk_id_;
}
if (cached_has_bits & 0x00100000u) {
vocab_size_ = from.vocab_size_;
differential_privacy_noise_level_ = from.differential_privacy_noise_level_;
}
if (cached_has_bits & 0x00200000u) {
character_coverage_ = from.character_coverage_;
differential_privacy_clipping_threshold_ = from.differential_privacy_clipping_threshold_;
}
if (cached_has_bits & 0x00400000u) {
seed_sentencepiece_size_ = from.seed_sentencepiece_size_;
model_type_ = from.model_type_;
}
if (cached_has_bits & 0x00800000u) {
shrinking_factor_ = from.shrinking_factor_;
vocab_size_ = from.vocab_size_;
}
_has_bits_[0] |= cached_has_bits;
}
if (cached_has_bits & 0xff000000u) {
if (cached_has_bits & 0x01000000u) {
num_threads_ = from.num_threads_;
character_coverage_ = from.character_coverage_;
}
if (cached_has_bits & 0x02000000u) {
num_sub_iterations_ = from.num_sub_iterations_;
seed_sentencepiece_size_ = from.seed_sentencepiece_size_;
}
if (cached_has_bits & 0x04000000u) {
max_sentence_length_ = from.max_sentence_length_;
shrinking_factor_ = from.shrinking_factor_;
}
if (cached_has_bits & 0x08000000u) {
max_sentencepiece_length_ = from.max_sentencepiece_length_;
num_threads_ = from.num_threads_;
}
if (cached_has_bits & 0x10000000u) {
shuffle_input_sentence_ = from.shuffle_input_sentence_;
num_sub_iterations_ = from.num_sub_iterations_;
}
if (cached_has_bits & 0x20000000u) {
split_by_unicode_script_ = from.split_by_unicode_script_;
max_sentence_length_ = from.max_sentence_length_;
}
if (cached_has_bits & 0x40000000u) {
split_by_number_ = from.split_by_number_;
max_sentencepiece_length_ = from.max_sentencepiece_length_;
}
if (cached_has_bits & 0x80000000u) {
split_by_whitespace_ = from.split_by_whitespace_;
shuffle_input_sentence_ = from.shuffle_input_sentence_;
}
_has_bits_[0] |= cached_has_bits;
}
cached_has_bits = from._has_bits_[1];
if (cached_has_bits & 0x0000001fu) {
if (cached_has_bits & 0x000000ffu) {
if (cached_has_bits & 0x00000001u) {
vocabulary_output_piece_score_ = from.vocabulary_output_piece_score_;
split_by_unicode_script_ = from.split_by_unicode_script_;
}
if (cached_has_bits & 0x00000002u) {
hard_vocab_limit_ = from.hard_vocab_limit_;
split_by_number_ = from.split_by_number_;
}
if (cached_has_bits & 0x00000004u) {
bos_id_ = from.bos_id_;
split_by_whitespace_ = from.split_by_whitespace_;
}
if (cached_has_bits & 0x00000008u) {
eos_id_ = from.eos_id_;
vocabulary_output_piece_score_ = from.vocabulary_output_piece_score_;
}
if (cached_has_bits & 0x00000010u) {
hard_vocab_limit_ = from.hard_vocab_limit_;
}
if (cached_has_bits & 0x00000020u) {
bos_id_ = from.bos_id_;
}
if (cached_has_bits & 0x00000040u) {
eos_id_ = from.eos_id_;
}
if (cached_has_bits & 0x00000080u) {
pad_id_ = from.pad_id_;
}
_has_bits_[1] |= cached_has_bits;
@ -1734,8 +1813,8 @@ void TrainerSpec::InternalSwap(TrainerSpec* other) {
eos_piece_.Swap(&other->eos_piece_, nullptr, GetArena());
pad_piece_.Swap(&other->pad_piece_, nullptr, GetArena());
::PROTOBUF_NAMESPACE_ID::internal::memswap<
PROTOBUF_FIELD_OFFSET(TrainerSpec, train_extremely_large_corpus_)
+ sizeof(TrainerSpec::train_extremely_large_corpus_)
PROTOBUF_FIELD_OFFSET(TrainerSpec, differential_privacy_clipping_threshold_)
+ sizeof(TrainerSpec::differential_privacy_clipping_threshold_)
- PROTOBUF_FIELD_OFFSET(TrainerSpec, self_test_sample_size_)>(
reinterpret_cast<char*>(&self_test_sample_size_),
reinterpret_cast<char*>(&other->self_test_sample_size_));

View File

@ -277,13 +277,16 @@ class TrainerSpec PROTOBUF_FINAL :
kMiningSentenceSizeFieldNumber = 12,
kInputSentenceSizeFieldNumber = 11,
kTrainingSentenceSizeFieldNumber = 13,
kEnableDifferentialPrivacyFieldNumber = 50,
kTreatWhitespaceAsSuffixFieldNumber = 24,
kAllowWhitespaceOnlyPiecesFieldNumber = 26,
kSplitDigitsFieldNumber = 25,
kByteFallbackFieldNumber = 35,
kUnkIdFieldNumber = 40,
kUseAllVocabFieldNumber = 34,
kTrainExtremelyLargeCorpusFieldNumber = 49,
kUnkIdFieldNumber = 40,
kDifferentialPrivacyNoiseLevelFieldNumber = 51,
kDifferentialPrivacyClippingThresholdFieldNumber = 52,
kModelTypeFieldNumber = 3,
kVocabSizeFieldNumber = 4,
kCharacterCoverageFieldNumber = 10,
@ -611,6 +614,19 @@ class TrainerSpec PROTOBUF_FINAL :
void _internal_set_training_sentence_size(::PROTOBUF_NAMESPACE_ID::int32 value);
public:
// optional bool enable_differential_privacy = 50 [default = false];
bool has_enable_differential_privacy() const;
private:
bool _internal_has_enable_differential_privacy() const;
public:
void clear_enable_differential_privacy();
bool enable_differential_privacy() const;
void set_enable_differential_privacy(bool value);
private:
bool _internal_enable_differential_privacy() const;
void _internal_set_enable_differential_privacy(bool value);
public:
// optional bool treat_whitespace_as_suffix = 24 [default = false];
bool has_treat_whitespace_as_suffix() const;
private:
@ -663,19 +679,6 @@ class TrainerSpec PROTOBUF_FINAL :
void _internal_set_byte_fallback(bool value);
public:
// optional int32 unk_id = 40 [default = 0];
bool has_unk_id() const;
private:
bool _internal_has_unk_id() const;
public:
void clear_unk_id();
::PROTOBUF_NAMESPACE_ID::int32 unk_id() const;
void set_unk_id(::PROTOBUF_NAMESPACE_ID::int32 value);
private:
::PROTOBUF_NAMESPACE_ID::int32 _internal_unk_id() const;
void _internal_set_unk_id(::PROTOBUF_NAMESPACE_ID::int32 value);
public:
// optional bool use_all_vocab = 34 [default = false];
bool has_use_all_vocab() const;
private:
@ -702,6 +705,45 @@ class TrainerSpec PROTOBUF_FINAL :
void _internal_set_train_extremely_large_corpus(bool value);
public:
// optional int32 unk_id = 40 [default = 0];
bool has_unk_id() const;
private:
bool _internal_has_unk_id() const;
public:
void clear_unk_id();
::PROTOBUF_NAMESPACE_ID::int32 unk_id() const;
void set_unk_id(::PROTOBUF_NAMESPACE_ID::int32 value);
private:
::PROTOBUF_NAMESPACE_ID::int32 _internal_unk_id() const;
void _internal_set_unk_id(::PROTOBUF_NAMESPACE_ID::int32 value);
public:
// optional float differential_privacy_noise_level = 51 [default = 0];
// Presence check: true when the field's has-bit is set.
bool has_differential_privacy_noise_level() const;
private:
// Internal has-bit test used by has_differential_privacy_noise_level().
bool _internal_has_differential_privacy_noise_level() const;
public:
// Resets the field to its default (0) and clears its has-bit.
void clear_differential_privacy_noise_level();
// Returns the stored value of the field.
float differential_privacy_noise_level() const;
// Stores `value` and marks the field as present.
void set_differential_privacy_noise_level(float value);
private:
// Internal getter/setter that the public accessors forward to.
float _internal_differential_privacy_noise_level() const;
void _internal_set_differential_privacy_noise_level(float value);
public:
// optional uint64 differential_privacy_clipping_threshold = 52 [default = 0];
// Presence check: true when the field's has-bit is set.
bool has_differential_privacy_clipping_threshold() const;
private:
// Internal has-bit test used by has_differential_privacy_clipping_threshold().
bool _internal_has_differential_privacy_clipping_threshold() const;
public:
// Resets the field to its default (0) and clears its has-bit.
void clear_differential_privacy_clipping_threshold();
// Returns the stored value of the field.
::PROTOBUF_NAMESPACE_ID::uint64 differential_privacy_clipping_threshold() const;
// Stores `value` and marks the field as present.
void set_differential_privacy_clipping_threshold(::PROTOBUF_NAMESPACE_ID::uint64 value);
private:
// Internal getter/setter that the public accessors forward to.
::PROTOBUF_NAMESPACE_ID::uint64 _internal_differential_privacy_clipping_threshold() const;
void _internal_set_differential_privacy_clipping_threshold(::PROTOBUF_NAMESPACE_ID::uint64 value);
public:
// optional .sentencepiece.TrainerSpec.ModelType model_type = 3 [default = UNIGRAM];
bool has_model_type() const;
private:
@ -969,13 +1011,16 @@ class TrainerSpec PROTOBUF_FINAL :
::PROTOBUF_NAMESPACE_ID::int32 mining_sentence_size_;
::PROTOBUF_NAMESPACE_ID::uint64 input_sentence_size_;
::PROTOBUF_NAMESPACE_ID::int32 training_sentence_size_;
bool enable_differential_privacy_;
bool treat_whitespace_as_suffix_;
bool allow_whitespace_only_pieces_;
bool split_digits_;
bool byte_fallback_;
::PROTOBUF_NAMESPACE_ID::int32 unk_id_;
bool use_all_vocab_;
bool train_extremely_large_corpus_;
::PROTOBUF_NAMESPACE_ID::int32 unk_id_;
float differential_privacy_noise_level_;
::PROTOBUF_NAMESPACE_ID::uint64 differential_privacy_clipping_threshold_;
int model_type_;
::PROTOBUF_NAMESPACE_ID::int32 vocab_size_;
float character_coverage_;
@ -2195,7 +2240,7 @@ inline void TrainerSpec::set_allocated_model_prefix(std::string* model_prefix) {
// optional .sentencepiece.TrainerSpec.ModelType model_type = 3 [default = UNIGRAM];
inline bool TrainerSpec::_internal_has_model_type() const {
bool value = (_has_bits_[0] & 0x00080000u) != 0;
bool value = (_has_bits_[0] & 0x00400000u) != 0;
return value;
}
inline bool TrainerSpec::has_model_type() const {
@ -2203,7 +2248,7 @@ inline bool TrainerSpec::has_model_type() const {
}
inline void TrainerSpec::clear_model_type() {
model_type_ = 1;
_has_bits_[0] &= ~0x00080000u;
_has_bits_[0] &= ~0x00400000u;
}
inline ::sentencepiece::TrainerSpec_ModelType TrainerSpec::_internal_model_type() const {
return static_cast< ::sentencepiece::TrainerSpec_ModelType >(model_type_);
@ -2214,7 +2259,7 @@ inline ::sentencepiece::TrainerSpec_ModelType TrainerSpec::model_type() const {
}
inline void TrainerSpec::_internal_set_model_type(::sentencepiece::TrainerSpec_ModelType value) {
assert(::sentencepiece::TrainerSpec_ModelType_IsValid(value));
_has_bits_[0] |= 0x00080000u;
_has_bits_[0] |= 0x00400000u;
model_type_ = value;
}
inline void TrainerSpec::set_model_type(::sentencepiece::TrainerSpec_ModelType value) {
@ -2224,7 +2269,7 @@ inline void TrainerSpec::set_model_type(::sentencepiece::TrainerSpec_ModelType v
// optional int32 vocab_size = 4 [default = 8000];
inline bool TrainerSpec::_internal_has_vocab_size() const {
bool value = (_has_bits_[0] & 0x00100000u) != 0;
bool value = (_has_bits_[0] & 0x00800000u) != 0;
return value;
}
inline bool TrainerSpec::has_vocab_size() const {
@ -2232,7 +2277,7 @@ inline bool TrainerSpec::has_vocab_size() const {
}
inline void TrainerSpec::clear_vocab_size() {
vocab_size_ = 8000;
_has_bits_[0] &= ~0x00100000u;
_has_bits_[0] &= ~0x00800000u;
}
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_vocab_size() const {
return vocab_size_;
@ -2242,7 +2287,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::vocab_size() const {
return _internal_vocab_size();
}
inline void TrainerSpec::_internal_set_vocab_size(::PROTOBUF_NAMESPACE_ID::int32 value) {
_has_bits_[0] |= 0x00100000u;
_has_bits_[0] |= 0x00800000u;
vocab_size_ = value;
}
inline void TrainerSpec::set_vocab_size(::PROTOBUF_NAMESPACE_ID::int32 value) {
@ -2352,9 +2397,93 @@ inline void TrainerSpec::set_self_test_sample_size(::PROTOBUF_NAMESPACE_ID::int3
// @@protoc_insertion_point(field_set:sentencepiece.TrainerSpec.self_test_sample_size)
}
// optional bool enable_differential_privacy = 50 [default = false];
// Tests the presence bit of enable_differential_privacy
// (mask 0x00001000 in word 0 of _has_bits_).
inline bool TrainerSpec::_internal_has_enable_differential_privacy() const {
  return (_has_bits_[0] & 0x00001000u) != 0;
}
inline bool TrainerSpec::has_enable_differential_privacy() const {
return _internal_has_enable_differential_privacy();
}
inline void TrainerSpec::clear_enable_differential_privacy() {
enable_differential_privacy_ = false;
_has_bits_[0] &= ~0x00001000u;
}
inline bool TrainerSpec::_internal_enable_differential_privacy() const {
return enable_differential_privacy_;
}
inline bool TrainerSpec::enable_differential_privacy() const {
// @@protoc_insertion_point(field_get:sentencepiece.TrainerSpec.enable_differential_privacy)
return _internal_enable_differential_privacy();
}
inline void TrainerSpec::_internal_set_enable_differential_privacy(bool value) {
_has_bits_[0] |= 0x00001000u;
enable_differential_privacy_ = value;
}
inline void TrainerSpec::set_enable_differential_privacy(bool value) {
_internal_set_enable_differential_privacy(value);
// @@protoc_insertion_point(field_set:sentencepiece.TrainerSpec.enable_differential_privacy)
}
// optional float differential_privacy_noise_level = 51 [default = 0];
// Tests the presence bit of differential_privacy_noise_level
// (mask 0x00100000 in word 0 of _has_bits_).
inline bool TrainerSpec::_internal_has_differential_privacy_noise_level() const {
  return (_has_bits_[0] & 0x00100000u) != 0;
}
inline bool TrainerSpec::has_differential_privacy_noise_level() const {
return _internal_has_differential_privacy_noise_level();
}
inline void TrainerSpec::clear_differential_privacy_noise_level() {
differential_privacy_noise_level_ = 0;
_has_bits_[0] &= ~0x00100000u;
}
inline float TrainerSpec::_internal_differential_privacy_noise_level() const {
return differential_privacy_noise_level_;
}
inline float TrainerSpec::differential_privacy_noise_level() const {
// @@protoc_insertion_point(field_get:sentencepiece.TrainerSpec.differential_privacy_noise_level)
return _internal_differential_privacy_noise_level();
}
inline void TrainerSpec::_internal_set_differential_privacy_noise_level(float value) {
_has_bits_[0] |= 0x00100000u;
differential_privacy_noise_level_ = value;
}
inline void TrainerSpec::set_differential_privacy_noise_level(float value) {
_internal_set_differential_privacy_noise_level(value);
// @@protoc_insertion_point(field_set:sentencepiece.TrainerSpec.differential_privacy_noise_level)
}
// optional uint64 differential_privacy_clipping_threshold = 52 [default = 0];
// Tests the presence bit of differential_privacy_clipping_threshold
// (mask 0x00200000 in word 0 of _has_bits_).
inline bool TrainerSpec::_internal_has_differential_privacy_clipping_threshold() const {
  return (_has_bits_[0] & 0x00200000u) != 0;
}
inline bool TrainerSpec::has_differential_privacy_clipping_threshold() const {
return _internal_has_differential_privacy_clipping_threshold();
}
inline void TrainerSpec::clear_differential_privacy_clipping_threshold() {
differential_privacy_clipping_threshold_ = PROTOBUF_ULONGLONG(0);
_has_bits_[0] &= ~0x00200000u;
}
// Raw read of the stored value; does not consult the presence bit.
inline ::PROTOBUF_NAMESPACE_ID::uint64 TrainerSpec::_internal_differential_privacy_clipping_threshold() const {
  return this->differential_privacy_clipping_threshold_;
}
// Public getter; returns the value held by the internal accessor.
inline ::PROTOBUF_NAMESPACE_ID::uint64 TrainerSpec::differential_privacy_clipping_threshold() const {
  // @@protoc_insertion_point(field_get:sentencepiece.TrainerSpec.differential_privacy_clipping_threshold)
  const ::PROTOBUF_NAMESPACE_ID::uint64 current =
      this->_internal_differential_privacy_clipping_threshold();
  return current;
}
// Stores `value` and raises the field's presence bit.
inline void TrainerSpec::_internal_set_differential_privacy_clipping_threshold(::PROTOBUF_NAMESPACE_ID::uint64 value) {
  this->differential_privacy_clipping_threshold_ = value;
  _has_bits_[0] |= 0x00200000u;
}
// Public setter; forwards to the internal setter, which also marks the
// field as present.
inline void TrainerSpec::set_differential_privacy_clipping_threshold(::PROTOBUF_NAMESPACE_ID::uint64 value) {
  this->_internal_set_differential_privacy_clipping_threshold(value);
  // @@protoc_insertion_point(field_set:sentencepiece.TrainerSpec.differential_privacy_clipping_threshold)
}
// optional float character_coverage = 10 [default = 0.9995];
inline bool TrainerSpec::_internal_has_character_coverage() const {
bool value = (_has_bits_[0] & 0x00200000u) != 0;
bool value = (_has_bits_[0] & 0x01000000u) != 0;
return value;
}
inline bool TrainerSpec::has_character_coverage() const {
@ -2362,7 +2491,7 @@ inline bool TrainerSpec::has_character_coverage() const {
}
inline void TrainerSpec::clear_character_coverage() {
character_coverage_ = 0.9995f;
_has_bits_[0] &= ~0x00200000u;
_has_bits_[0] &= ~0x01000000u;
}
inline float TrainerSpec::_internal_character_coverage() const {
return character_coverage_;
@ -2372,7 +2501,7 @@ inline float TrainerSpec::character_coverage() const {
return _internal_character_coverage();
}
inline void TrainerSpec::_internal_set_character_coverage(float value) {
_has_bits_[0] |= 0x00200000u;
_has_bits_[0] |= 0x01000000u;
character_coverage_ = value;
}
inline void TrainerSpec::set_character_coverage(float value) {
@ -2410,7 +2539,7 @@ inline void TrainerSpec::set_input_sentence_size(::PROTOBUF_NAMESPACE_ID::uint64
// optional bool shuffle_input_sentence = 19 [default = true];
inline bool TrainerSpec::_internal_has_shuffle_input_sentence() const {
bool value = (_has_bits_[0] & 0x10000000u) != 0;
bool value = (_has_bits_[0] & 0x80000000u) != 0;
return value;
}
inline bool TrainerSpec::has_shuffle_input_sentence() const {
@ -2418,7 +2547,7 @@ inline bool TrainerSpec::has_shuffle_input_sentence() const {
}
inline void TrainerSpec::clear_shuffle_input_sentence() {
shuffle_input_sentence_ = true;
_has_bits_[0] &= ~0x10000000u;
_has_bits_[0] &= ~0x80000000u;
}
inline bool TrainerSpec::_internal_shuffle_input_sentence() const {
return shuffle_input_sentence_;
@ -2428,7 +2557,7 @@ inline bool TrainerSpec::shuffle_input_sentence() const {
return _internal_shuffle_input_sentence();
}
inline void TrainerSpec::_internal_set_shuffle_input_sentence(bool value) {
_has_bits_[0] |= 0x10000000u;
_has_bits_[0] |= 0x80000000u;
shuffle_input_sentence_ = value;
}
inline void TrainerSpec::set_shuffle_input_sentence(bool value) {
@ -2494,7 +2623,7 @@ inline void TrainerSpec::set_training_sentence_size(::PROTOBUF_NAMESPACE_ID::int
// optional int32 seed_sentencepiece_size = 14 [default = 1000000];
inline bool TrainerSpec::_internal_has_seed_sentencepiece_size() const {
bool value = (_has_bits_[0] & 0x00400000u) != 0;
bool value = (_has_bits_[0] & 0x02000000u) != 0;
return value;
}
inline bool TrainerSpec::has_seed_sentencepiece_size() const {
@ -2502,7 +2631,7 @@ inline bool TrainerSpec::has_seed_sentencepiece_size() const {
}
inline void TrainerSpec::clear_seed_sentencepiece_size() {
seed_sentencepiece_size_ = 1000000;
_has_bits_[0] &= ~0x00400000u;
_has_bits_[0] &= ~0x02000000u;
}
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_seed_sentencepiece_size() const {
return seed_sentencepiece_size_;
@ -2512,7 +2641,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::seed_sentencepiece_size() con
return _internal_seed_sentencepiece_size();
}
inline void TrainerSpec::_internal_set_seed_sentencepiece_size(::PROTOBUF_NAMESPACE_ID::int32 value) {
_has_bits_[0] |= 0x00400000u;
_has_bits_[0] |= 0x02000000u;
seed_sentencepiece_size_ = value;
}
inline void TrainerSpec::set_seed_sentencepiece_size(::PROTOBUF_NAMESPACE_ID::int32 value) {
@ -2522,7 +2651,7 @@ inline void TrainerSpec::set_seed_sentencepiece_size(::PROTOBUF_NAMESPACE_ID::in
// optional float shrinking_factor = 15 [default = 0.75];
inline bool TrainerSpec::_internal_has_shrinking_factor() const {
bool value = (_has_bits_[0] & 0x00800000u) != 0;
bool value = (_has_bits_[0] & 0x04000000u) != 0;
return value;
}
inline bool TrainerSpec::has_shrinking_factor() const {
@ -2530,7 +2659,7 @@ inline bool TrainerSpec::has_shrinking_factor() const {
}
inline void TrainerSpec::clear_shrinking_factor() {
shrinking_factor_ = 0.75f;
_has_bits_[0] &= ~0x00800000u;
_has_bits_[0] &= ~0x04000000u;
}
inline float TrainerSpec::_internal_shrinking_factor() const {
return shrinking_factor_;
@ -2540,7 +2669,7 @@ inline float TrainerSpec::shrinking_factor() const {
return _internal_shrinking_factor();
}
inline void TrainerSpec::_internal_set_shrinking_factor(float value) {
_has_bits_[0] |= 0x00800000u;
_has_bits_[0] |= 0x04000000u;
shrinking_factor_ = value;
}
inline void TrainerSpec::set_shrinking_factor(float value) {
@ -2550,7 +2679,7 @@ inline void TrainerSpec::set_shrinking_factor(float value) {
// optional int32 max_sentence_length = 18 [default = 4192];
inline bool TrainerSpec::_internal_has_max_sentence_length() const {
bool value = (_has_bits_[0] & 0x04000000u) != 0;
bool value = (_has_bits_[0] & 0x20000000u) != 0;
return value;
}
inline bool TrainerSpec::has_max_sentence_length() const {
@ -2558,7 +2687,7 @@ inline bool TrainerSpec::has_max_sentence_length() const {
}
inline void TrainerSpec::clear_max_sentence_length() {
max_sentence_length_ = 4192;
_has_bits_[0] &= ~0x04000000u;
_has_bits_[0] &= ~0x20000000u;
}
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_max_sentence_length() const {
return max_sentence_length_;
@ -2568,7 +2697,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::max_sentence_length() const {
return _internal_max_sentence_length();
}
inline void TrainerSpec::_internal_set_max_sentence_length(::PROTOBUF_NAMESPACE_ID::int32 value) {
_has_bits_[0] |= 0x04000000u;
_has_bits_[0] |= 0x20000000u;
max_sentence_length_ = value;
}
inline void TrainerSpec::set_max_sentence_length(::PROTOBUF_NAMESPACE_ID::int32 value) {
@ -2578,7 +2707,7 @@ inline void TrainerSpec::set_max_sentence_length(::PROTOBUF_NAMESPACE_ID::int32
// optional int32 num_threads = 16 [default = 16];
inline bool TrainerSpec::_internal_has_num_threads() const {
bool value = (_has_bits_[0] & 0x01000000u) != 0;
bool value = (_has_bits_[0] & 0x08000000u) != 0;
return value;
}
inline bool TrainerSpec::has_num_threads() const {
@ -2586,7 +2715,7 @@ inline bool TrainerSpec::has_num_threads() const {
}
inline void TrainerSpec::clear_num_threads() {
num_threads_ = 16;
_has_bits_[0] &= ~0x01000000u;
_has_bits_[0] &= ~0x08000000u;
}
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_num_threads() const {
return num_threads_;
@ -2596,7 +2725,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::num_threads() const {
return _internal_num_threads();
}
inline void TrainerSpec::_internal_set_num_threads(::PROTOBUF_NAMESPACE_ID::int32 value) {
_has_bits_[0] |= 0x01000000u;
_has_bits_[0] |= 0x08000000u;
num_threads_ = value;
}
inline void TrainerSpec::set_num_threads(::PROTOBUF_NAMESPACE_ID::int32 value) {
@ -2606,7 +2735,7 @@ inline void TrainerSpec::set_num_threads(::PROTOBUF_NAMESPACE_ID::int32 value) {
// optional int32 num_sub_iterations = 17 [default = 2];
inline bool TrainerSpec::_internal_has_num_sub_iterations() const {
bool value = (_has_bits_[0] & 0x02000000u) != 0;
bool value = (_has_bits_[0] & 0x10000000u) != 0;
return value;
}
inline bool TrainerSpec::has_num_sub_iterations() const {
@ -2614,7 +2743,7 @@ inline bool TrainerSpec::has_num_sub_iterations() const {
}
inline void TrainerSpec::clear_num_sub_iterations() {
num_sub_iterations_ = 2;
_has_bits_[0] &= ~0x02000000u;
_has_bits_[0] &= ~0x10000000u;
}
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_num_sub_iterations() const {
return num_sub_iterations_;
@ -2624,7 +2753,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::num_sub_iterations() const {
return _internal_num_sub_iterations();
}
inline void TrainerSpec::_internal_set_num_sub_iterations(::PROTOBUF_NAMESPACE_ID::int32 value) {
_has_bits_[0] |= 0x02000000u;
_has_bits_[0] |= 0x10000000u;
num_sub_iterations_ = value;
}
inline void TrainerSpec::set_num_sub_iterations(::PROTOBUF_NAMESPACE_ID::int32 value) {
@ -2634,7 +2763,7 @@ inline void TrainerSpec::set_num_sub_iterations(::PROTOBUF_NAMESPACE_ID::int32 v
// optional int32 max_sentencepiece_length = 20 [default = 16];
inline bool TrainerSpec::_internal_has_max_sentencepiece_length() const {
bool value = (_has_bits_[0] & 0x08000000u) != 0;
bool value = (_has_bits_[0] & 0x40000000u) != 0;
return value;
}
inline bool TrainerSpec::has_max_sentencepiece_length() const {
@ -2642,7 +2771,7 @@ inline bool TrainerSpec::has_max_sentencepiece_length() const {
}
inline void TrainerSpec::clear_max_sentencepiece_length() {
max_sentencepiece_length_ = 16;
_has_bits_[0] &= ~0x08000000u;
_has_bits_[0] &= ~0x40000000u;
}
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_max_sentencepiece_length() const {
return max_sentencepiece_length_;
@ -2652,7 +2781,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::max_sentencepiece_length() co
return _internal_max_sentencepiece_length();
}
inline void TrainerSpec::_internal_set_max_sentencepiece_length(::PROTOBUF_NAMESPACE_ID::int32 value) {
_has_bits_[0] |= 0x08000000u;
_has_bits_[0] |= 0x40000000u;
max_sentencepiece_length_ = value;
}
inline void TrainerSpec::set_max_sentencepiece_length(::PROTOBUF_NAMESPACE_ID::int32 value) {
@ -2662,7 +2791,7 @@ inline void TrainerSpec::set_max_sentencepiece_length(::PROTOBUF_NAMESPACE_ID::i
// optional bool split_by_unicode_script = 21 [default = true];
inline bool TrainerSpec::_internal_has_split_by_unicode_script() const {
bool value = (_has_bits_[0] & 0x20000000u) != 0;
bool value = (_has_bits_[1] & 0x00000001u) != 0;
return value;
}
inline bool TrainerSpec::has_split_by_unicode_script() const {
@ -2670,7 +2799,7 @@ inline bool TrainerSpec::has_split_by_unicode_script() const {
}
inline void TrainerSpec::clear_split_by_unicode_script() {
split_by_unicode_script_ = true;
_has_bits_[0] &= ~0x20000000u;
_has_bits_[1] &= ~0x00000001u;
}
inline bool TrainerSpec::_internal_split_by_unicode_script() const {
return split_by_unicode_script_;
@ -2680,7 +2809,7 @@ inline bool TrainerSpec::split_by_unicode_script() const {
return _internal_split_by_unicode_script();
}
inline void TrainerSpec::_internal_set_split_by_unicode_script(bool value) {
_has_bits_[0] |= 0x20000000u;
_has_bits_[1] |= 0x00000001u;
split_by_unicode_script_ = value;
}
inline void TrainerSpec::set_split_by_unicode_script(bool value) {
@ -2690,7 +2819,7 @@ inline void TrainerSpec::set_split_by_unicode_script(bool value) {
// optional bool split_by_number = 23 [default = true];
inline bool TrainerSpec::_internal_has_split_by_number() const {
bool value = (_has_bits_[0] & 0x40000000u) != 0;
bool value = (_has_bits_[1] & 0x00000002u) != 0;
return value;
}
inline bool TrainerSpec::has_split_by_number() const {
@ -2698,7 +2827,7 @@ inline bool TrainerSpec::has_split_by_number() const {
}
inline void TrainerSpec::clear_split_by_number() {
split_by_number_ = true;
_has_bits_[0] &= ~0x40000000u;
_has_bits_[1] &= ~0x00000002u;
}
inline bool TrainerSpec::_internal_split_by_number() const {
return split_by_number_;
@ -2708,7 +2837,7 @@ inline bool TrainerSpec::split_by_number() const {
return _internal_split_by_number();
}
inline void TrainerSpec::_internal_set_split_by_number(bool value) {
_has_bits_[0] |= 0x40000000u;
_has_bits_[1] |= 0x00000002u;
split_by_number_ = value;
}
inline void TrainerSpec::set_split_by_number(bool value) {
@ -2718,7 +2847,7 @@ inline void TrainerSpec::set_split_by_number(bool value) {
// optional bool split_by_whitespace = 22 [default = true];
inline bool TrainerSpec::_internal_has_split_by_whitespace() const {
bool value = (_has_bits_[0] & 0x80000000u) != 0;
bool value = (_has_bits_[1] & 0x00000004u) != 0;
return value;
}
inline bool TrainerSpec::has_split_by_whitespace() const {
@ -2726,7 +2855,7 @@ inline bool TrainerSpec::has_split_by_whitespace() const {
}
inline void TrainerSpec::clear_split_by_whitespace() {
split_by_whitespace_ = true;
_has_bits_[0] &= ~0x80000000u;
_has_bits_[1] &= ~0x00000004u;
}
inline bool TrainerSpec::_internal_split_by_whitespace() const {
return split_by_whitespace_;
@ -2736,7 +2865,7 @@ inline bool TrainerSpec::split_by_whitespace() const {
return _internal_split_by_whitespace();
}
inline void TrainerSpec::_internal_set_split_by_whitespace(bool value) {
_has_bits_[0] |= 0x80000000u;
_has_bits_[1] |= 0x00000004u;
split_by_whitespace_ = value;
}
inline void TrainerSpec::set_split_by_whitespace(bool value) {
@ -2746,7 +2875,7 @@ inline void TrainerSpec::set_split_by_whitespace(bool value) {
// optional bool treat_whitespace_as_suffix = 24 [default = false];
inline bool TrainerSpec::_internal_has_treat_whitespace_as_suffix() const {
bool value = (_has_bits_[0] & 0x00001000u) != 0;
bool value = (_has_bits_[0] & 0x00002000u) != 0;
return value;
}
inline bool TrainerSpec::has_treat_whitespace_as_suffix() const {
@ -2754,7 +2883,7 @@ inline bool TrainerSpec::has_treat_whitespace_as_suffix() const {
}
inline void TrainerSpec::clear_treat_whitespace_as_suffix() {
treat_whitespace_as_suffix_ = false;
_has_bits_[0] &= ~0x00001000u;
_has_bits_[0] &= ~0x00002000u;
}
inline bool TrainerSpec::_internal_treat_whitespace_as_suffix() const {
return treat_whitespace_as_suffix_;
@ -2764,7 +2893,7 @@ inline bool TrainerSpec::treat_whitespace_as_suffix() const {
return _internal_treat_whitespace_as_suffix();
}
inline void TrainerSpec::_internal_set_treat_whitespace_as_suffix(bool value) {
_has_bits_[0] |= 0x00001000u;
_has_bits_[0] |= 0x00002000u;
treat_whitespace_as_suffix_ = value;
}
inline void TrainerSpec::set_treat_whitespace_as_suffix(bool value) {
@ -2774,7 +2903,7 @@ inline void TrainerSpec::set_treat_whitespace_as_suffix(bool value) {
// optional bool allow_whitespace_only_pieces = 26 [default = false];
inline bool TrainerSpec::_internal_has_allow_whitespace_only_pieces() const {
bool value = (_has_bits_[0] & 0x00002000u) != 0;
bool value = (_has_bits_[0] & 0x00004000u) != 0;
return value;
}
inline bool TrainerSpec::has_allow_whitespace_only_pieces() const {
@ -2782,7 +2911,7 @@ inline bool TrainerSpec::has_allow_whitespace_only_pieces() const {
}
inline void TrainerSpec::clear_allow_whitespace_only_pieces() {
allow_whitespace_only_pieces_ = false;
_has_bits_[0] &= ~0x00002000u;
_has_bits_[0] &= ~0x00004000u;
}
inline bool TrainerSpec::_internal_allow_whitespace_only_pieces() const {
return allow_whitespace_only_pieces_;
@ -2792,7 +2921,7 @@ inline bool TrainerSpec::allow_whitespace_only_pieces() const {
return _internal_allow_whitespace_only_pieces();
}
inline void TrainerSpec::_internal_set_allow_whitespace_only_pieces(bool value) {
_has_bits_[0] |= 0x00002000u;
_has_bits_[0] |= 0x00004000u;
allow_whitespace_only_pieces_ = value;
}
inline void TrainerSpec::set_allow_whitespace_only_pieces(bool value) {
@ -2802,7 +2931,7 @@ inline void TrainerSpec::set_allow_whitespace_only_pieces(bool value) {
// optional bool split_digits = 25 [default = false];
inline bool TrainerSpec::_internal_has_split_digits() const {
bool value = (_has_bits_[0] & 0x00004000u) != 0;
bool value = (_has_bits_[0] & 0x00008000u) != 0;
return value;
}
inline bool TrainerSpec::has_split_digits() const {
@ -2810,7 +2939,7 @@ inline bool TrainerSpec::has_split_digits() const {
}
inline void TrainerSpec::clear_split_digits() {
split_digits_ = false;
_has_bits_[0] &= ~0x00004000u;
_has_bits_[0] &= ~0x00008000u;
}
inline bool TrainerSpec::_internal_split_digits() const {
return split_digits_;
@ -2820,7 +2949,7 @@ inline bool TrainerSpec::split_digits() const {
return _internal_split_digits();
}
inline void TrainerSpec::_internal_set_split_digits(bool value) {
_has_bits_[0] |= 0x00004000u;
_has_bits_[0] |= 0x00008000u;
split_digits_ = value;
}
inline void TrainerSpec::set_split_digits(bool value) {
@ -3051,7 +3180,7 @@ inline void TrainerSpec::set_allocated_required_chars(std::string* required_char
// optional bool byte_fallback = 35 [default = false];
inline bool TrainerSpec::_internal_has_byte_fallback() const {
bool value = (_has_bits_[0] & 0x00008000u) != 0;
bool value = (_has_bits_[0] & 0x00010000u) != 0;
return value;
}
inline bool TrainerSpec::has_byte_fallback() const {
@ -3059,7 +3188,7 @@ inline bool TrainerSpec::has_byte_fallback() const {
}
inline void TrainerSpec::clear_byte_fallback() {
byte_fallback_ = false;
_has_bits_[0] &= ~0x00008000u;
_has_bits_[0] &= ~0x00010000u;
}
inline bool TrainerSpec::_internal_byte_fallback() const {
return byte_fallback_;
@ -3069,7 +3198,7 @@ inline bool TrainerSpec::byte_fallback() const {
return _internal_byte_fallback();
}
inline void TrainerSpec::_internal_set_byte_fallback(bool value) {
_has_bits_[0] |= 0x00008000u;
_has_bits_[0] |= 0x00010000u;
byte_fallback_ = value;
}
inline void TrainerSpec::set_byte_fallback(bool value) {
@ -3079,7 +3208,7 @@ inline void TrainerSpec::set_byte_fallback(bool value) {
// optional bool vocabulary_output_piece_score = 32 [default = true];
inline bool TrainerSpec::_internal_has_vocabulary_output_piece_score() const {
bool value = (_has_bits_[1] & 0x00000001u) != 0;
bool value = (_has_bits_[1] & 0x00000008u) != 0;
return value;
}
inline bool TrainerSpec::has_vocabulary_output_piece_score() const {
@ -3087,7 +3216,7 @@ inline bool TrainerSpec::has_vocabulary_output_piece_score() const {
}
inline void TrainerSpec::clear_vocabulary_output_piece_score() {
vocabulary_output_piece_score_ = true;
_has_bits_[1] &= ~0x00000001u;
_has_bits_[1] &= ~0x00000008u;
}
inline bool TrainerSpec::_internal_vocabulary_output_piece_score() const {
return vocabulary_output_piece_score_;
@ -3097,7 +3226,7 @@ inline bool TrainerSpec::vocabulary_output_piece_score() const {
return _internal_vocabulary_output_piece_score();
}
inline void TrainerSpec::_internal_set_vocabulary_output_piece_score(bool value) {
_has_bits_[1] |= 0x00000001u;
_has_bits_[1] |= 0x00000008u;
vocabulary_output_piece_score_ = value;
}
inline void TrainerSpec::set_vocabulary_output_piece_score(bool value) {
@ -3107,7 +3236,7 @@ inline void TrainerSpec::set_vocabulary_output_piece_score(bool value) {
// optional bool hard_vocab_limit = 33 [default = true];
inline bool TrainerSpec::_internal_has_hard_vocab_limit() const {
bool value = (_has_bits_[1] & 0x00000002u) != 0;
bool value = (_has_bits_[1] & 0x00000010u) != 0;
return value;
}
inline bool TrainerSpec::has_hard_vocab_limit() const {
@ -3115,7 +3244,7 @@ inline bool TrainerSpec::has_hard_vocab_limit() const {
}
inline void TrainerSpec::clear_hard_vocab_limit() {
hard_vocab_limit_ = true;
_has_bits_[1] &= ~0x00000002u;
_has_bits_[1] &= ~0x00000010u;
}
inline bool TrainerSpec::_internal_hard_vocab_limit() const {
return hard_vocab_limit_;
@ -3125,7 +3254,7 @@ inline bool TrainerSpec::hard_vocab_limit() const {
return _internal_hard_vocab_limit();
}
inline void TrainerSpec::_internal_set_hard_vocab_limit(bool value) {
_has_bits_[1] |= 0x00000002u;
_has_bits_[1] |= 0x00000010u;
hard_vocab_limit_ = value;
}
inline void TrainerSpec::set_hard_vocab_limit(bool value) {
@ -3163,7 +3292,7 @@ inline void TrainerSpec::set_use_all_vocab(bool value) {
// optional int32 unk_id = 40 [default = 0];
inline bool TrainerSpec::_internal_has_unk_id() const {
bool value = (_has_bits_[0] & 0x00010000u) != 0;
bool value = (_has_bits_[0] & 0x00080000u) != 0;
return value;
}
inline bool TrainerSpec::has_unk_id() const {
@ -3171,7 +3300,7 @@ inline bool TrainerSpec::has_unk_id() const {
}
inline void TrainerSpec::clear_unk_id() {
unk_id_ = 0;
_has_bits_[0] &= ~0x00010000u;
_has_bits_[0] &= ~0x00080000u;
}
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_unk_id() const {
return unk_id_;
@ -3181,7 +3310,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::unk_id() const {
return _internal_unk_id();
}
inline void TrainerSpec::_internal_set_unk_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
_has_bits_[0] |= 0x00010000u;
_has_bits_[0] |= 0x00080000u;
unk_id_ = value;
}
inline void TrainerSpec::set_unk_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
@ -3191,7 +3320,7 @@ inline void TrainerSpec::set_unk_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
// optional int32 bos_id = 41 [default = 1];
inline bool TrainerSpec::_internal_has_bos_id() const {
bool value = (_has_bits_[1] & 0x00000004u) != 0;
bool value = (_has_bits_[1] & 0x00000020u) != 0;
return value;
}
inline bool TrainerSpec::has_bos_id() const {
@ -3199,7 +3328,7 @@ inline bool TrainerSpec::has_bos_id() const {
}
inline void TrainerSpec::clear_bos_id() {
bos_id_ = 1;
_has_bits_[1] &= ~0x00000004u;
_has_bits_[1] &= ~0x00000020u;
}
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_bos_id() const {
return bos_id_;
@ -3209,7 +3338,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::bos_id() const {
return _internal_bos_id();
}
inline void TrainerSpec::_internal_set_bos_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
_has_bits_[1] |= 0x00000004u;
_has_bits_[1] |= 0x00000020u;
bos_id_ = value;
}
inline void TrainerSpec::set_bos_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
@ -3219,7 +3348,7 @@ inline void TrainerSpec::set_bos_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
// optional int32 eos_id = 42 [default = 2];
inline bool TrainerSpec::_internal_has_eos_id() const {
bool value = (_has_bits_[1] & 0x00000008u) != 0;
bool value = (_has_bits_[1] & 0x00000040u) != 0;
return value;
}
inline bool TrainerSpec::has_eos_id() const {
@ -3227,7 +3356,7 @@ inline bool TrainerSpec::has_eos_id() const {
}
inline void TrainerSpec::clear_eos_id() {
eos_id_ = 2;
_has_bits_[1] &= ~0x00000008u;
_has_bits_[1] &= ~0x00000040u;
}
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_eos_id() const {
return eos_id_;
@ -3237,7 +3366,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::eos_id() const {
return _internal_eos_id();
}
inline void TrainerSpec::_internal_set_eos_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
_has_bits_[1] |= 0x00000008u;
_has_bits_[1] |= 0x00000040u;
eos_id_ = value;
}
inline void TrainerSpec::set_eos_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
@ -3247,7 +3376,7 @@ inline void TrainerSpec::set_eos_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
// optional int32 pad_id = 43 [default = -1];
inline bool TrainerSpec::_internal_has_pad_id() const {
bool value = (_has_bits_[1] & 0x00000010u) != 0;
bool value = (_has_bits_[1] & 0x00000080u) != 0;
return value;
}
inline bool TrainerSpec::has_pad_id() const {
@ -3255,7 +3384,7 @@ inline bool TrainerSpec::has_pad_id() const {
}
inline void TrainerSpec::clear_pad_id() {
pad_id_ = -1;
_has_bits_[1] &= ~0x00000010u;
_has_bits_[1] &= ~0x00000080u;
}
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_pad_id() const {
return pad_id_;
@ -3265,7 +3394,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::pad_id() const {
return _internal_pad_id();
}
inline void TrainerSpec::_internal_set_pad_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
_has_bits_[1] |= 0x00000010u;
_has_bits_[1] |= 0x00000080u;
pad_id_ = value;
}
inline void TrainerSpec::set_pad_id(::PROTOBUF_NAMESPACE_ID::int32 value) {

View File

@ -46,6 +46,13 @@ class FreeList {
// Returns the number of allocated elements.
size_t size() const { return chunk_size_ * chunk_index_ + element_index_; }
// Exchanges the entire state of this list with `other`: the chunk size,
// both cursor indices and the chunk table itself.
void swap(FreeList<T>& other) {
  // The four member swaps are independent of one another.
  std::swap(chunk_size_, other.chunk_size_);
  std::swap(chunk_index_, other.chunk_index_);
  std::swap(element_index_, other.element_index_);
  std::swap(freelist_, other.freelist_);
}
// Returns the element as an array.
T* operator[](size_t index) const {
return freelist_[index / chunk_size_] + index % chunk_size_;
@ -76,7 +83,7 @@ class FreeList {
// The last element is stored at freelist_[chunk_index_][element_index_]
size_t element_index_ = 0;
size_t chunk_index_ = 0;
const size_t chunk_size_ = 0;
size_t chunk_size_ = 0; // Do not modify except in swap()
};
} // namespace model
} // namespace sentencepiece

View File

@ -30,17 +30,20 @@ TEST(FreeListTest, BasicTest) {
*n = i;
}
EXPECT_EQ(kSize, l.size());
FreeList<int> l2(3); // Test swap()
l.swap(l2);
EXPECT_EQ(kSize, l2.size());
for (size_t i = 0; i < kSize; ++i) {
EXPECT_EQ(i, *l[i]);
EXPECT_EQ(i, *l2[i]);
}
l.Free();
EXPECT_EQ(0, l.size());
l2.Free();
EXPECT_EQ(0, l2.size());
// Zero-initialized after `Free`.
for (size_t i = 0; i < kSize; ++i) {
int *n = l.Allocate();
int *n = l2.Allocate();
EXPECT_EQ(0, *n);
}
}

View File

@ -107,12 +107,12 @@ class ModelInterface {
return EncodeResult();
}
// Sample `samples` many tokenizations from the segmentation lattice
// Sample `samples` many tokenisations from the segmentation lattice
// If `wor` is true, the samples are taken without replacement, and the scores
// are the inclusion probabilities of the elements in the sample; otherwise
// the samples are taken with replacement and the scores are the log-probs of
// sample elements
// If `include_best` is true, the best tokenization is always included in the
// If `include_best` is true, the best tokenisation is always included in the
// sample, and the remaining elements are sampled excluding the best.
virtual NBestEncodeResult SampleEncodeAndScore(absl::string_view normalized,
float alpha, int samples,

View File

@ -12,12 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.!
#include "normalizer.h"
#include <utility>
#include <vector>
#include "common.h"
#include "normalizer.h"
#include "third_party/absl/memory/memory.h"
#include "third_party/absl/strings/match.h"
#include "third_party/absl/strings/string_view.h"
@ -281,7 +280,8 @@ util::Status Normalizer::DecodePrecompiledCharsMap(
if (blob.size() <= sizeof(trie_blob_size) ||
!string_util::DecodePOD<uint32>(
absl::string_view(blob.data(), sizeof(trie_blob_size)),
&trie_blob_size)) {
&trie_blob_size) ||
trie_blob_size >= blob.size()) {
return util::InternalError("Blob for normalization rule is broken.");
}

View File

@ -26,7 +26,6 @@
#include "sentencepiece_processor.h"
#include "third_party/absl/strings/string_view.h"
#include "third_party/darts_clone/darts.h"
#include "util.h"
namespace sentencepiece {
namespace normalizer {

View File

@ -20,6 +20,7 @@ option optimize_for = LITE_RUNTIME;
package sentencepiece;
// TrainerSpec encodes various parameters for SentencePiece training.
// Next id: 53
message TrainerSpec {
///////////////////////////////////////////////////////////////////
// General parameters
@ -62,6 +63,16 @@ message TrainerSpec {
// Size of self-test samples, which are encoded in the model file.
optional int32 self_test_sample_size = 6 [default = 0];
// Whether to use DP version of sentencepiece. Use it with TSV input format
// (requires precomputed word tab counts to work).
optional bool enable_differential_privacy = 50 [default = false];
// Set these parameters if you need DP version of sentencepiece.
// std of noise to add.
optional float differential_privacy_noise_level = 51 [default = 0.0];
// Clipping threshold to apply after adding noise. All the words with
// frequency less than this value are dropped.
optional uint64 differential_privacy_clipping_threshold = 52 [default = 0];
///////////////////////////////////////////////////////////////////
// Training parameters.
//

View File

@ -548,18 +548,24 @@ util::Status SentencePieceProcessor::Decode(
if (model_proto_ && model_proto_->trainer_spec().has_unk_surface())
unk_surface = model_proto_->trainer_spec().unk_surface().c_str();
auto DecodeSentencePiece = [&](absl::string_view piece, int id,
bool is_bos_ws) -> std::string {
if (IsControl(id)) { // <s>, </s>
return ""; // invisible symbol.
// Returns decoded piece and a boolean indicating if the function has consumed
// a bos whitespace token (a piece starting with a kSpaceSymbol). This is used
// to strip only the first whitespace token from the decoded sequence for
// add_dummy_prefix.
auto DecodeSentencePiece =
[&](absl::string_view piece, int id,
bool is_bos_ws) -> std::pair<std::string, bool> {
if (IsControl(id)) { // <s>, </s>
return std::make_pair("", false); // invisible symbol.
} else if (IsUnknown(id)) {
if (IdToPiece(id) == piece) { // <unk>
return unk_surface;
return std::make_pair(unk_surface, false);
} else { // return piece when piece is not <unk>.
return std::string(piece);
return std::make_pair(std::string(piece), false);
}
}
bool has_bos_ws = false; // whether the token starts with a kSpaceSymbol
if (is_bos_ws &&
(!model_proto_ ||
(model_proto_ &&
@ -567,10 +573,17 @@ util::Status SentencePieceProcessor::Decode(
model_proto_->normalizer_spec().remove_extra_whitespaces())))) {
// Consume if the current position is bos and
// piece starts with kSpaceSymbol.
absl::ConsumePrefix(&piece, kSpaceSymbol);
has_bos_ws = absl::ConsumePrefix(&piece, kSpaceSymbol);
if (model_proto_ &&
model_proto_->normalizer_spec().remove_extra_whitespaces()) {
// if we are removing extra whitespace, we remove all leading whitespace
has_bos_ws = false;
}
}
return absl::StrReplaceAll(piece, {{kSpaceSymbol, " "}});
return std::make_pair(absl::StrReplaceAll(piece, {{kSpaceSymbol, " "}}),
has_bos_ws);
};
for (const std::string &w : pieces) {
@ -644,12 +657,23 @@ util::Status SentencePieceProcessor::Decode(
};
int byte_start = 0;
bool is_bos_ws = true; // whether we expect a bos ws token to consume.
bool bos_ws_seen = false;
std::string decoded;
for (int i = 0; i < spt->pieces_size(); ++i) {
const auto &sp = spt->pieces(i);
if (!IsByte(sp.id())) {
RETURN_IF_ERROR(ProcessBytePieces(byte_start, i));
// if we have seen a bos_ws token or any non-empty token
if (bos_ws_seen || !text->empty()) is_bos_ws = false;
byte_start = i + 1;
SetSurface(i, DecodeSentencePiece(sp.piece(), sp.id(), text->empty()));
std::tie(decoded, bos_ws_seen) =
DecodeSentencePiece(sp.piece(), sp.id(), is_bos_ws);
SetSurface(i, decoded);
}
}
RETURN_IF_ERROR(ProcessBytePieces(byte_start, spt->pieces_size()));

View File

@ -163,6 +163,7 @@ namespace normalizer {
class Normalizer;
} // namespace normalizer
#ifndef SWIG
// Defines the multiple versions of encoder within each model. Currently only
// the Unigram model has an optimized encoder.
enum class EncoderVersion {
@ -170,13 +171,16 @@ enum class EncoderVersion {
kOriginal // The original encoder (user may choose to fall back to this
// just in case).
};
#endif
#ifndef SWIGGO
namespace util {
// Redefine std::string for serialized_proto interface as Python's string is
// a Unicode string. We can enforce the return value to be raw byte sequence
// with SWIG's typemap.
using bytes = std::string;
} // namespace util
#endif
class SentencePieceProcessor {
public:
@ -250,6 +254,7 @@ class SentencePieceProcessor {
virtual util::Status Decode(const std::vector<int> &ids,
std::string *detokenized) const;
#ifndef SWIG
// Sets the encoder version. Normally users do not need to call this function.
// But they can call this function just in case if they want to fall back to
// the original encoder.
@ -257,6 +262,7 @@ class SentencePieceProcessor {
// Returns the current encoder version in use.
virtual EncoderVersion GetEncoderVersion() const;
#endif
//////////////////////////////////////////////////////////////
// NBest API.
@ -315,20 +321,12 @@ class SentencePieceProcessor {
virtual util::Status SampleEncode(absl::string_view input, int nbest_size,
float alpha, SentencePieceText *spt) const;
// Sample `samples` segmentations from the segmentation lattice.
// If `wor` is true, the samples are taken without replacement, and the scores
// are the inclusion probabilities of the elements in the sample; otherwise
// the samples are taken with replacement and the scores are the log-probes of
// sample elements.
// If `include_best` is true, the best tokenization is always included in the
// sample, and the remaining elements are sampled excluding the best.
// This method is only available in Unigram mode.
// Samples N segmentation and returns the scores as well
virtual util::Status SampleEncodeAndScore(
absl::string_view input, int samples, float theta, bool wor,
bool include_best, NBestSentencePieceText *samples_spt) const;
// Calculate entropy of possible tokenization.
// Only available in unigram mode.
// Calculate entropy of possible tokenisations
virtual util::Status CalculateEntropy(absl::string_view input, float theta,
float *entropy) const;

View File

@ -709,6 +709,87 @@ TEST(SentencepieceProcessorTest, DecodeTest) {
}
}
// Verifies how Decode() treats leading whitespace pieces when
// add_dummy_prefix is enabled:
//  * remove_extra_whitespaces == false: only the first whitespace piece is
//    stripped, so the decoded text keeps one leading space.
//  * remove_extra_whitespaces == true: all leading whitespace is stripped.
TEST(SentencepieceProcessorTest, DummyPrefixDecodeTest) {
  // Minimal model exposing a fixed 8-entry vocabulary (ids 0..7).
  class DecodeMockModel : public ModelInterface {
   public:
    EncodeResult Encode(absl::string_view normalized) const override {
      return {};
    }

    // Must agree with the number of entries in the tables below; the bare
    // WS piece occupies id 7.
    int GetPieceSize() const override { return 8; }

    // Unknown pieces (e.g. "I") map to id 0, i.e. <unk>.
    int PieceToId(absl::string_view piece) const override {
      static absl::flat_hash_map<absl::string_view, int,
                                 string_util::string_view_hash>
          kMap = {{"<unk>", 0}, {"<s>", 1},  {"</s>", 2},     {WS "ABC", 3},
                  {WS "DE", 4}, {"F", 5},    {"G" WS "H", 6}, {WS, 7}};
      return port::FindWithDefault(kMap, piece, 0);
    }

    const std::string &IdToPiece(int id) const override {
      static std::vector<std::string> kMap = {
          "<unk>", "<s>", "</s>", WS "ABC", WS "DE", "F", "G" WS "H", WS};
      return kMap[id];
    }

    bool IsUnknown(int id) const override { return (id == 0); }

    bool IsControl(int id) const override { return (id == 1 || id == 2); }

    bool IsByte(int id) const override { return false; }

    float GetScore(int id) const override { return 0.0; }
  };

  // start the sequence with a whitespace token
  const std::vector<std::string> input = {
      "<s>", WS, WS "ABC", "<unk>", WS "DE", "F", "G" WS "H", "I", "</s>"};

  // Case 1: extra whitespace is kept, so only the first whitespace piece is
  // consumed and the text starts with a single space.
  {
    SentencePieceProcessor sp;
    auto proto = absl::make_unique<ModelProto>();
    proto->mutable_trainer_spec()->set_unk_surface("");
    proto->mutable_normalizer_spec()->set_add_dummy_prefix(true);
    proto->mutable_normalizer_spec()->set_remove_extra_whitespaces(false);
    sp.Load(std::move(proto)).IgnoreError();

    auto mock = absl::make_unique<DecodeMockModel>();
    sp.SetModel(std::move(mock));

    const auto normalization_spec = MakeDefaultNormalizerSpec();
    sp.SetNormalizer(
        absl::make_unique<normalizer::Normalizer>(normalization_spec));

    SentencePieceText spt;
    EXPECT_TRUE(sp.Decode(input, &spt).ok());
    EXPECT_EQ(" ABC DEFG HI", spt.text());
    EXPECT_EQ(9, spt.pieces_size());
  }

  // Case 2: extra whitespace is removed, so every leading whitespace piece
  // is consumed and the text starts directly with "ABC".
  {
    SentencePieceProcessor sp;
    auto proto = absl::make_unique<ModelProto>();
    proto->mutable_trainer_spec()->set_unk_surface("");
    proto->mutable_normalizer_spec()->set_add_dummy_prefix(true);
    proto->mutable_normalizer_spec()->set_remove_extra_whitespaces(true);
    sp.Load(std::move(proto)).IgnoreError();

    auto mock = absl::make_unique<DecodeMockModel>();
    sp.SetModel(std::move(mock));

    const auto normalization_spec = MakeDefaultNormalizerSpec();
    sp.SetNormalizer(
        absl::make_unique<normalizer::Normalizer>(normalization_spec));

    SentencePieceText spt;
    EXPECT_TRUE(sp.Decode(input, &spt).ok());
    EXPECT_EQ("ABC DEFG HI", spt.text());
    EXPECT_EQ(9, spt.pieces_size());
  }
}
TEST(SentencepieceProcessorTest, ByteFallbackDecodeTest) {
class ByteFallbackDecodeMockModel : public ModelInterface {
public:

View File

@ -144,6 +144,18 @@ ABSL_FLAG(bool, train_extremely_large_corpus,
ABSL_FLAG(uint32, random_seed, static_cast<uint32>(-1),
"Seed value for random generator.");
// DP (differential privacy) related flags. Each one is copied into the
// TrainerSpec field of the same name in main().
// Master switch for the DP code path.
ABSL_FLAG(bool, enable_differential_privacy, false,
"Whether to add DP while training. Currently supported only by "
"UNIGRAM model.");
// Noise level handed to the trainer (TrainerSpec documents it as the std of
// the noise to add).
ABSL_FLAG(float, differential_privacy_noise_level, 0.0f,
"Amount of noise to add for"
" DP");
// Clipping threshold for the (noised) counts; TrainerSpec documents that words
// with a frequency below this value are dropped.
ABSL_FLAG(std::uint64_t, differential_privacy_clipping_threshold, 0,
"Threshold for"
" clipping the counts for DP");
int main(int argc, char *argv[]) {
sentencepiece::ParseCommandLineFlags(argv[0], &argc, &argv, true);
@ -235,6 +247,10 @@ int main(int argc, char *argv[]) {
SetRepeatedTrainerSpecFromFlag(control_symbols);
SetRepeatedTrainerSpecFromFlag(user_defined_symbols);
SetTrainerSpecFromFlag(train_extremely_large_corpus);
// DP related.
SetTrainerSpecFromFlag(enable_differential_privacy);
SetTrainerSpecFromFlag(differential_privacy_noise_level);
SetTrainerSpecFromFlag(differential_privacy_clipping_threshold);
SetRepeatedTrainerSpecFromFile(control_symbols);
SetRepeatedTrainerSpecFromFile(user_defined_symbols);

View File

@ -28,6 +28,8 @@
#include "sentencepiece_trainer.h"
#include "third_party/absl/container/flat_hash_map.h"
#include "third_party/absl/memory/memory.h"
#include "third_party/absl/random/distributions.h"
#include "third_party/absl/random/random.h"
#include "third_party/absl/strings/numbers.h"
#include "third_party/absl/strings/str_cat.h"
#include "third_party/absl/strings/str_format.h"
@ -273,8 +275,7 @@ bool TrainerInterface::IsValidSentencePiece(
if (s == unicode_script::U_Hiragana || s == unicode_script::U_Katakana ||
c == 0x30FC) { // long vowel sound (Katakana) should be Katakana
s = unicode_script::U_Han;
}
else if (s == unicode_script::U_Inherited) {
} else if (s == unicode_script::U_Inherited) {
s = prev_script;
}
@ -299,6 +300,22 @@ bool TrainerInterface::IsValidSentencePiece(
return true;
}
// Applies the differential-privacy transformation to the count `*to_update`.
//
// 1. If the configured noise level is positive, a Gaussian sample with mean 0
//    and that standard deviation is added to the count; the sum is floored at
//    zero and rounded to the nearest integer.
// 2. Any count that is non-positive or below the clipping threshold is reset
//    to 0.
//
// trainer_spec: supplies differential_privacy_noise_level() and
//               differential_privacy_clipping_threshold().
// generator:    random bit source used to draw the Gaussian sample.
// to_update:    count to modify in place; must not be null.
template <typename T>
void AddDPNoise(const TrainerSpec &trainer_spec, absl::SharedBitGen &generator,
                T *to_update) {
  if (trainer_spec.differential_privacy_noise_level() > 0) {
    // NOTE: the arithmetic is carried out in float, so very large counts are
    // only represented approximately here.
    const float random_num = absl::Gaussian<float>(
        generator, 0, trainer_spec.differential_privacy_noise_level());
    *to_update =
        std::round(std::max(0.f, random_num + static_cast<float>(*to_update)));
  }

  // Clip anything below the clipping threshold to 0.
  // The threshold is unsigned while T may be signed: a negative count would
  // be converted to a huge unsigned value by a direct comparison and escape
  // clipping, so non-positive counts are handled before the comparison.
  const auto threshold = trainer_spec.differential_privacy_clipping_threshold();
  if (*to_update <= 0 ||
      static_cast<decltype(threshold)>(*to_update) < threshold) {
    *to_update = 0;
  }
}
util::Status TrainerInterface::LoadSentences() {
RETURN_IF_ERROR(status());
CHECK_OR_RETURN(sentences_.empty());
@ -390,6 +407,7 @@ END:
LOG(INFO) << "Sampled " << sentences_.size() << " sentences from "
<< selector.total_size() << " sentences.";
}
if (too_long_lines > 0)
LOG(INFO) << "Skipped " << too_long_lines << " too long sentences.";
if (self_test_samples_.size() > 0)
@ -433,6 +451,54 @@ END:
}
}
// If DP is required, add the noise/clip the input.
if (trainer_spec_.enable_differential_privacy()) {
  if (trainer_spec_.input_format() != "tsv") {
    LOG(ERROR)
        << "Dp version will not work correctly with text input format.";
  }
  if (trainer_spec_.differential_privacy_noise_level() <= 0) {
    LOG(WARNING) << "Private version with <=0 noise level will give "
                    "infinity epsilon gurantees.";
  }
  if (trainer_spec_.differential_privacy_clipping_threshold() <= 0) {
    LOG(WARNING) << "Private version with <=0 clipping threshold will give "
                    "infinity epsilon gurantees.";
  }

  // Nothing to noise or clip for an empty corpus. Skipping also avoids the
  // unsigned wrap-around of `size() - 1` and a 0/0 in the log message below.
  const auto before_size = sentences_.size();
  if (before_size > 0) {
    // Add noise to all the sentences via threadpool.

    // The upper bound is mainly for tests with small num of sentences.
    // At least one worker is always used: otherwise a single-sentence corpus
    // would yield zero workers and its count would never be noised/clipped.
    const uint64 num_workers = std::max<uint64>(
        1, std::min<uint64>(trainer_spec_.num_threads(), before_size - 1));

    {
      // The pool lives in its own scope so that it is destroyed before the
      // counts are inspected below (presumably its destructor waits for the
      // scheduled tasks -- confirm against ThreadPool).
      auto pool = absl::make_unique<ThreadPool>(num_workers);
      pool->StartWorkers();
      for (uint64 n = 0; n < num_workers; ++n) {
        pool->Schedule([&, n]() {
          // One per thread generator.
          absl::SharedBitGen generator;
          // Worker n handles sentences n, n + num_workers, ... so that no
          // two workers touch the same element.
          for (size_t i = n; i < sentences_.size(); i += num_workers) {
            AddDPNoise<int64>(trainer_spec_, generator,
                              &(sentences_[i].second));
          }
        });
      }
    }

    // Remove zero freq elements.
    auto it = std::remove_if(sentences_.begin(), sentences_.end(),
                             [](const Sentence &s) { return s.second <= 0; });
    const auto new_size = std::distance(sentences_.begin(), it);
    const int num_erased = before_size - new_size;
    sentences_.erase(it, sentences_.end());

    LOG(INFO) << "DP noise resulted in " << 1.0 * num_erased / before_size
              << " fraction of sentences removed.";
  }
}
// Count character frequencies.
int64 all_chars_count = 0;
// A map from a character to {is_required_char, character count}.
@ -617,6 +683,7 @@ util::Status TrainerInterface::Serialize(ModelProto *model_proto) const {
util::Status TrainerInterface::SaveModel(absl::string_view filename) const {
LOG(INFO) << "Saving model: " << filename;
ModelProto model_proto;
RETURN_IF_ERROR(Serialize(&model_proto));
auto output = filesystem::NewWritableFile(filename.data(), true);

View File

@ -107,16 +107,16 @@ class TrainerInterface {
FRIEND_TEST(TrainerInterfaceTest, SerializeTest);
FRIEND_TEST(TrainerInterfaceTest, CharactersTest);
// Loads all sentences from spec.input() or SentenceIterator.
// It loads at most input_sentence_size sentences.
util::Status LoadSentences();
protected:
// Returns true if |piece| is valid sentence piece.
// The result is affected by
// max_sentencepiece_length, split_by_whitespace, split_by_unicode_script.
bool IsValidSentencePiece(const string_util::UnicodeText &piece) const;
// Loads all sentences from spec.input() or SentenceIterator.
// It loads at most input_sentence_size sentences.
util::Status LoadSentences();
// Splits all sentences by whitespaces and
// replace the |sentences_| with tokenized string.
// e.g.,

View File

@ -78,8 +78,10 @@ TEST(TrainerInterfaceTest, IsValidSentencePieceTest) {
EXPECT_FALSE(IsValid("ab cd"));
EXPECT_FALSE(IsValid("\0\0"));
EXPECT_FALSE(IsValid("\0"));
EXPECT_TRUE(IsValid("proteïni")); // Combining Diaeresis should inherit script from base character.
EXPECT_TRUE(IsValid("ثَبَّتَ")); // Arabic Fatha and Shadda should inherit script from base character.
EXPECT_TRUE(IsValid("proteïni")); // Combining Diaeresis should inherit
// script from base character.
EXPECT_TRUE(IsValid("ثَبَّتَ")); // Arabic Fatha and Shadda should inherit script
// from base character.
trainer_spec.set_split_by_whitespace(false);
EXPECT_TRUE(IsValid(WS));

View File

@ -22,6 +22,7 @@
#include <utility>
#include <vector>
#include "third_party/absl/container/flat_hash_map.h"
#include "third_party/absl/memory/memory.h"
#include "third_party/absl/strings/str_split.h"
#include "third_party/absl/strings/string_view.h"
@ -289,6 +290,58 @@ float Lattice::CalculateEntropy(float theta) const {
return -H[begin_nodes_[len][0]->node_id];
}
namespace {
// The node structure to support A* algorithm in Lattice::NBest().
// Hypotheses are chained through `next`; CloneHypAndDependents() walks that
// chain until it reaches nullptr.
struct Hypothesis {
Lattice::Node *node;  // lattice node this hypothesis refers to.
Hypothesis *next;  // following hypothesis on the chain, or nullptr at its end.
float fx; // the priority to pop a new hypothesis from the priority queue.
float gx; // the sum of scores from EOS to the left-most node in x.
};
// Deep-copies a Hypothesis together with every Hypothesis reachable through
// its "next" pointers, keeping the sharing structure of the original graph.
//
// to_clone:  head of the chain to copy (may be nullptr).
// clone_map: old pointer -> new pointer for everything copied so far. A chain
//            that runs into an already-copied Hypothesis is linked to the
//            existing copy instead of being duplicated again.
// allocator: provides and owns the storage of all copies.
//
// Returns the copy of "to_clone" (nullptr when "to_clone" is nullptr). Every
// newly created copy is recorded in "clone_map".
Hypothesis *CloneHypAndDependents(
    const Hypothesis *to_clone,
    absl::flat_hash_map<const Hypothesis *, Hypothesis *> *clone_map,
    model::FreeList<Hypothesis> *allocator) {
  Hypothesis *head = nullptr;
  // Slot that receives the pointer to the next copy: first the result itself,
  // afterwards the "next" field of the most recently created copy.
  Hypothesis **link = &head;

  for (const Hypothesis *src = to_clone; src != nullptr; src = src->next) {
    // An existing copy already carries its whole tail, so we can stop here.
    const auto found = clone_map->find(src);
    if (found != clone_map->end()) {
      *link = found->second;
      return head;
    }

    // Field-wise copy; the copied "next" still points at the old chain and is
    // fixed up on the following iteration (or stays nullptr at the chain end).
    Hypothesis *copy = allocator->Allocate();
    *copy = *src;
    clone_map->insert({src, copy});

    *link = copy;
    link = &(copy->next);
  }

  return head;
}
} // namespace
std::vector<Lattice::LatticePathWithScore> Lattice::NBest(size_t nbest_size,
bool sample,
float theta) {
@ -312,12 +365,6 @@ std::vector<Lattice::LatticePathWithScore> Lattice::NBest(size_t nbest_size,
//
// As left-to-right Viterbi search can tell the *exact* value of h(x),
// we can obtain the exact n-best results with A*.
struct Hypothesis {
Node *node;
Hypothesis *next;
float fx;
float gx;
};
class HypothesisComparator {
public:
@ -353,6 +400,8 @@ std::vector<Lattice::LatticePathWithScore> Lattice::NBest(size_t nbest_size,
}
agenda.push(eos);
int shrink_count = 0; // Number of times agenda has shrunk. For logging only.
bool printed_memory_warning = false; // For logging only.
while (!agenda.empty()) {
auto *top = agenda.top();
agenda.pop();
@ -416,21 +465,42 @@ std::vector<Lattice::LatticePathWithScore> Lattice::NBest(size_t nbest_size,
agenda.push(hyp);
}
static constexpr int kOneBillion = 1000000000; // 10^9.
if (hypothesis_allocator.size() >= kOneBillion) {
if (!printed_memory_warning) {
printed_memory_warning = true;
LOG(WARNING) << "Allocator size exceeds " << kOneBillion
<< " with an example of length " << this->size();
}
}
// When the input is too long or contains duplicated phrases,
// `agenda` will get extremely big. Here we avoid this case by
// dynamically shrinking the agenda.
constexpr int kMaxAgendaSize = 100000;
constexpr int kMaxAgendaSize = 10000;
constexpr int kMinAgendaSize = 512;
if (agenda.size() >= kMaxAgendaSize) {
LOG(WARNING) << "Too big agenda. shrinking";
// Keeps the top `kMinAgendaSize` hypothesis.
Agenda new_agenda;
// Keeps the top hypothesis and the ones on their "next" paths.
model::FreeList<Hypothesis> new_allocator(kPreallocatedHypothesisSize);
// Map between old Hypothesis* and new Hypothesis*.
absl::flat_hash_map<const Hypothesis *, Hypothesis *> clone_map;
const int size = std::min<int>(kMinAgendaSize, nbest_size * 10);
shrink_count++;
LOG(WARNING) << "Too big agenda size " << agenda.size()
<< ". Shrinking (round " << shrink_count << ") down to "
<< size << ".";
for (int i = 0; i < size; ++i) {
new_agenda.push(agenda.top());
const Hypothesis *top_hyp = agenda.top();
Hypothesis *cloned_hyp =
CloneHypAndDependents(top_hyp, &clone_map, &new_allocator);
new_agenda.push(cloned_hyp);
agenda.pop();
}
agenda = std::move(new_agenda);
hypothesis_allocator.swap(new_allocator);
}
}

View File

@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.!
#include "unigram_model.h"
#include <cmath>
#include <map>
#include <string>
@ -24,6 +22,7 @@
#include "testharness.h"
#include "third_party/absl/strings/str_cat.h"
#include "third_party/absl/strings/str_join.h"
#include "unigram_model.h"
#include "util.h"
namespace sentencepiece {
@ -268,8 +267,8 @@ TEST(LatticeTest, NBestSampleTest) {
for (auto &it : probs) it.second /= Z;
std::map<std::pair<std::string, std::string>, float> pair_probs;
for (const auto first : strings) {
for (const auto second : strings) {
for (const auto &first : strings) {
for (const auto &second : strings) {
if (first == second) {
pair_probs[std::make_pair(first, second)] = 0;
} else {
@ -281,12 +280,12 @@ TEST(LatticeTest, NBestSampleTest) {
}
std::map<std::string, float> inclusion_probs;
for (const auto string : strings) {
for (const auto &string : strings) {
float inclusion_prob = 0.0;
for (const auto other_string : strings) {
for (const auto &other_string : strings) {
inclusion_prob += pair_probs[std::make_pair(string, other_string)];
}
for (const auto other_string : strings) {
for (const auto &other_string : strings) {
inclusion_prob += pair_probs[std::make_pair(other_string, string)];
}
inclusion_probs[string] = inclusion_prob / 2;
@ -300,7 +299,7 @@ TEST(LatticeTest, NBestSampleTest) {
std::map<std::string, int> counts;
for (int i = 0; i < kTrials; i++) {
auto nbests = lattice.NBest(num_samples, true, theta);
for (const auto nbest : nbests) {
for (const auto &nbest : nbests) {
counts[GetTokenized(nbest.first)]++;
}
}
@ -550,25 +549,25 @@ TEST(UnigramModelTest, SampleEncodeAndScoreTest) {
for (auto &it : probs) it.second /= Z;
std::map<std::pair<std::string, std::string>, float> pair_probs;
for (const auto first : strings) {
for (const auto second : strings) {
for (const auto &first : strings) {
for (const auto &second : strings) {
if (first == second) {
pair_probs[std::make_pair(first, second)] = 0;
} else {
float first_prob = probs[first];
float second_prob = probs[second] / (1 - first_prob);
const float first_prob = probs[first];
const float second_prob = probs[second] / (1 - first_prob);
pair_probs[std::make_pair(first, second)] = first_prob * second_prob;
}
}
}
std::map<std::string, float> inclusion_probs;
for (const auto string : strings) {
for (const auto &string : strings) {
float inclusion_prob = 0.0;
for (const auto other_string : strings) {
for (const auto &other_string : strings) {
inclusion_prob += pair_probs[std::make_pair(string, other_string)];
}
for (const auto other_string : strings) {
for (const auto &other_string : strings) {
inclusion_prob += pair_probs[std::make_pair(other_string, string)];
}
inclusion_probs[string] = inclusion_prob / 2;

View File

@ -70,9 +70,6 @@ class Trainer : public TrainerInterface {
util::Status Train() override;
private:
FRIEND_TEST(TrainerTest, IsValidSentencePieceTest);
// Makes seed pieces from the training corpus.
// The size of seed pieces is determined by seed_sentencepiece_size.
// node_int_type should be of integer type (int32 or int64),
@ -80,6 +77,9 @@ class Trainer : public TrainerInterface {
template <typename node_int_type>
TrainerModel::SentencePieces MakeSeedSentencePieces() const;
private:
FRIEND_TEST(TrainerTest, IsValidSentencePieceTest);
// Executes the E step of EM and returns expected count.
// The index of return array is the vocab id.
// |objective| is a negative likelihood of the current model.

View File

@ -12,6 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.!
#include <string>
#include <vector>
#include "filesystem.h"
#include "sentencepiece_model.pb.h"
#include "sentencepiece_processor.h"
#include "sentencepiece_trainer.h"
@ -23,7 +27,6 @@
namespace sentencepiece {
namespace unigram {
namespace {
// Space symbol
#define WS "\xe2\x96\x81"
@ -35,6 +38,116 @@ TEST(UnigramTrainerTest, TrainerModelTest) {
EXPECT_EQ(EncodeResult(), model.Encode("test"));
}
// Bundles what RunTrainer() returns for inspection by the tests.
struct TrainerResult {
// Pieces of the trained model joined with single spaces.
std::string sentence_pieces;
// (piece, score) pairs returned by Trainer::MakeSeedSentencePieces().
std::vector<std::pair<std::string, float>> seed_pieces_and_probs;
};
// Trains a UNIGRAM model on `input` and reports the seed pieces and the final
// vocabulary.
//
// input:    lines written verbatim to a temporary file that is then read with
//           the "tsv" input format.
// size:     the vocab size passed to the trainer is `size - 3`.
// use_dp:   value for TrainerSpec::enable_differential_privacy.
// dp_noise: value for TrainerSpec::differential_privacy_noise_level.
// dp_clip:  value for TrainerSpec::differential_privacy_clipping_threshold.
//           NOTE(review): declared uint32 here although the proto field is
//           uint64 -- fine for the small values used in tests.
//
// Returns the seed pieces of a freshly loaded trainer and the pieces of the
// trained model (without the first three entries), joined by spaces.
TrainerResult RunTrainer(const std::vector<std::string>& input, int size,
const bool use_dp = false, const float dp_noise = 0.0,
const uint32 dp_clip = 0) {
// Both the corpus and the model live under the test temp directory.
const std::string input_file =
util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "input");
const std::string model_prefix =
util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "model");
// Write the corpus; the scope ends the lifetime of `output` before the file
// is read back (presumably this flushes/closes it -- confirm).
{
auto output = filesystem::NewWritableFile(input_file);
for (const auto& line : input) {
output->WriteLine(line);
}
}
TrainerSpec trainer_spec;
trainer_spec.set_input_format("tsv");
trainer_spec.set_model_type(TrainerSpec::UNIGRAM);
trainer_spec.add_input(input_file);
trainer_spec.set_vocab_size(size - 3); // remove <unk>, <s>, </s>
trainer_spec.set_model_prefix(model_prefix);
trainer_spec.set_enable_differential_privacy(use_dp);
trainer_spec.set_differential_privacy_noise_level(dp_noise);
trainer_spec.set_differential_privacy_clipping_threshold(dp_clip);
NormalizerSpec normalizer_spec;
normalizer_spec.set_name("identity");
normalizer_spec.set_add_dummy_prefix(false);
NormalizerSpec denormalizer_spec;
// Pass 1: only load the sentences and collect the seed pieces.
std::vector<std::pair<std::string, float>> seed_pieces;
{
Trainer trainer(trainer_spec, normalizer_spec, denormalizer_spec);
EXPECT_OK(trainer.LoadSentences());
TrainerModel::SentencePieces res = trainer.MakeSeedSentencePieces<int32>();
for (const auto& piece : res) {
seed_pieces.emplace_back(piece.first, piece.second);
}
}
// Pass 2: run the full training with a fresh trainer and read the resulting
// model file back.
std::vector<std::string> pieces;
{
Trainer trainer(trainer_spec, normalizer_spec, denormalizer_spec);
EXPECT_TRUE(trainer.Train().ok());
SentencePieceProcessor processor;
EXPECT_TRUE(processor.Load(model_prefix + ".model").ok());
const auto& model = processor.model_proto();
// remove <unk>, <s>, </s>
for (int i = 3; i < model.pieces_size(); ++i) {
pieces.emplace_back(model.pieces(i).piece());
}
}
TrainerResult res;
res.seed_pieces_and_probs = seed_pieces;
res.sentence_pieces = absl::StrJoin(pieces, " ");
return res;
}
// Trains on a tiny five-line TSV corpus without DP and pins both the number
// of seed pieces and the exact final vocabulary.
TEST(UnigramTrainerTest, BasicTest) {
  const std::vector<std::string> corpus = {
      "magnanimity \t 5", "Pineapple \t 6", "i have an apple and a pen \t 1",
      "Overly \t 6", "Available \t 3"};
  const TrainerResult result = RunTrainer(corpus, 30);

  // Check seed pieces.
  EXPECT_EQ(27, result.seed_pieces_and_probs.size());

  // Check final pieces.
  const std::string expected_pieces =
      "i a n y m l e apple ve O P r t g an v ▁ A b le ▁an p d h";
  EXPECT_EQ(expected_pieces, result.sentence_pieces);
}
// With DP enabled, zero noise and a clipping threshold of 4, the single
// line whose count is 1 must be dropped: the outcome has to match a non-DP
// run on the corpus without that line.
TEST(UnigramTrainerTest, BasicDPTest) {
  // no noise, clipping.
  {
    const std::vector<std::string> full_corpus = {
        "magnanimity \t 5", "Pineapple \t 6", "i have an apple and a pen \t 1",
        "Overly \t 6", "Available \t 5"};
    const TrainerResult dp_result = RunTrainer(
        full_corpus, 22, true /*use_dp*/, 0 /*dp_noise*/, 4 /*dp_clipping*/);

    // Got 16 instead of 27 seeds.
    EXPECT_EQ(16, dp_result.seed_pieces_and_probs.size());

    // And they are equiv to if the last sentence was not there.
    const std::vector<std::string> clipped_corpus = {
        "magnanimity \t 5", "Pineapple \t 6", "Overly \t 6", "Available \t 5"};
    const TrainerResult plain_result = RunTrainer(clipped_corpus, 22);

    EXPECT_EQ(dp_result.seed_pieces_and_probs,
              plain_result.seed_pieces_and_probs);

    // Check final pieces.
    EXPECT_EQ(dp_result.sentence_pieces, plain_result.sentence_pieces);
  }
}
namespace {
static constexpr char kTestInputData[] = "wagahaiwa_nekodearu.txt";
TEST(UnigramTrainerTest, EndToEndTest) {

View File

@ -173,6 +173,7 @@ template class Flag<std::string>;
template class Flag<int32>;
template class Flag<uint32>;
template class Flag<double>;
template class Flag<float>;
template class Flag<bool>;
template class Flag<int64>;
template class Flag<uint64>;