mirror of
https://github.com/google/sentencepiece.git
synced 2024-10-26 11:38:45 +03:00
parent
05db0894d8
commit
3a5bc5815b
@ -1 +1 @@
|
||||
0.1.96
|
||||
0.1.95
|
||||
|
@ -1 +1 @@
|
||||
0.1.96
|
||||
0.1.95
|
||||
|
@ -12,6 +12,8 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.!
|
||||
|
||||
#include "bpe_model.h"
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
@ -19,7 +21,6 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "bpe_model.h"
|
||||
#include "freelist.h"
|
||||
#include "third_party/absl/container/flat_hash_map.h"
|
||||
#include "util.h"
|
||||
|
@ -12,11 +12,12 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.!
|
||||
|
||||
#include "builder.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <utility>
|
||||
|
||||
#include "builder.h"
|
||||
#include "filesystem.h"
|
||||
#include "third_party/absl/strings/str_join.h"
|
||||
#include "third_party/absl/strings/str_replace.h"
|
||||
@ -367,7 +368,6 @@ util::Status Builder::BuildNmtNFKCMap(CharsMap *chars_map) {
|
||||
nfkc_map[{0xFEFF}] = {0x20}; // ZERO WIDTH NO-BREAK
|
||||
nfkc_map[{0xFFFD}] = {0x20}; // REPLACEMENT CHARACTER
|
||||
nfkc_map[{0x200C}] = {0x20}; // ZERO WIDTH NON-JOINER
|
||||
nfkc_map[{0x200D}] = {0x20}; // ZERO WIDTH JOINER
|
||||
|
||||
// Ascii Control characters
|
||||
nfkc_map[{0x0001}] = {};
|
||||
|
@ -285,22 +285,22 @@ class TrainerSpec::_Internal {
|
||||
(*has_bits)[0] |= 1u;
|
||||
}
|
||||
static void set_has_model_type(HasBits* has_bits) {
|
||||
(*has_bits)[0] |= 524288u;
|
||||
(*has_bits)[0] |= 262144u;
|
||||
}
|
||||
static void set_has_vocab_size(HasBits* has_bits) {
|
||||
(*has_bits)[0] |= 1048576u;
|
||||
(*has_bits)[0] |= 524288u;
|
||||
}
|
||||
static void set_has_self_test_sample_size(HasBits* has_bits) {
|
||||
(*has_bits)[0] |= 256u;
|
||||
}
|
||||
static void set_has_character_coverage(HasBits* has_bits) {
|
||||
(*has_bits)[0] |= 2097152u;
|
||||
(*has_bits)[0] |= 1048576u;
|
||||
}
|
||||
static void set_has_input_sentence_size(HasBits* has_bits) {
|
||||
(*has_bits)[0] |= 1024u;
|
||||
}
|
||||
static void set_has_shuffle_input_sentence(HasBits* has_bits) {
|
||||
(*has_bits)[0] |= 268435456u;
|
||||
(*has_bits)[0] |= 134217728u;
|
||||
}
|
||||
static void set_has_mining_sentence_size(HasBits* has_bits) {
|
||||
(*has_bits)[0] |= 512u;
|
||||
@ -309,67 +309,64 @@ class TrainerSpec::_Internal {
|
||||
(*has_bits)[0] |= 2048u;
|
||||
}
|
||||
static void set_has_seed_sentencepiece_size(HasBits* has_bits) {
|
||||
(*has_bits)[0] |= 4194304u;
|
||||
(*has_bits)[0] |= 2097152u;
|
||||
}
|
||||
static void set_has_shrinking_factor(HasBits* has_bits) {
|
||||
(*has_bits)[0] |= 8388608u;
|
||||
(*has_bits)[0] |= 4194304u;
|
||||
}
|
||||
static void set_has_max_sentence_length(HasBits* has_bits) {
|
||||
(*has_bits)[0] |= 67108864u;
|
||||
}
|
||||
static void set_has_num_threads(HasBits* has_bits) {
|
||||
(*has_bits)[0] |= 16777216u;
|
||||
}
|
||||
static void set_has_num_sub_iterations(HasBits* has_bits) {
|
||||
(*has_bits)[0] |= 33554432u;
|
||||
}
|
||||
static void set_has_num_threads(HasBits* has_bits) {
|
||||
(*has_bits)[0] |= 8388608u;
|
||||
}
|
||||
static void set_has_num_sub_iterations(HasBits* has_bits) {
|
||||
(*has_bits)[0] |= 16777216u;
|
||||
}
|
||||
static void set_has_max_sentencepiece_length(HasBits* has_bits) {
|
||||
(*has_bits)[0] |= 134217728u;
|
||||
(*has_bits)[0] |= 67108864u;
|
||||
}
|
||||
static void set_has_split_by_unicode_script(HasBits* has_bits) {
|
||||
(*has_bits)[0] |= 536870912u;
|
||||
(*has_bits)[0] |= 268435456u;
|
||||
}
|
||||
static void set_has_split_by_number(HasBits* has_bits) {
|
||||
(*has_bits)[0] |= 1073741824u;
|
||||
(*has_bits)[0] |= 536870912u;
|
||||
}
|
||||
static void set_has_split_by_whitespace(HasBits* has_bits) {
|
||||
(*has_bits)[0] |= 2147483648u;
|
||||
(*has_bits)[0] |= 1073741824u;
|
||||
}
|
||||
static void set_has_treat_whitespace_as_suffix(HasBits* has_bits) {
|
||||
(*has_bits)[0] |= 4096u;
|
||||
}
|
||||
static void set_has_allow_whitespace_only_pieces(HasBits* has_bits) {
|
||||
(*has_bits)[0] |= 8192u;
|
||||
}
|
||||
static void set_has_split_digits(HasBits* has_bits) {
|
||||
(*has_bits)[0] |= 16384u;
|
||||
(*has_bits)[0] |= 8192u;
|
||||
}
|
||||
static void set_has_required_chars(HasBits* has_bits) {
|
||||
(*has_bits)[0] |= 4u;
|
||||
}
|
||||
static void set_has_byte_fallback(HasBits* has_bits) {
|
||||
(*has_bits)[0] |= 32768u;
|
||||
(*has_bits)[0] |= 16384u;
|
||||
}
|
||||
static void set_has_vocabulary_output_piece_score(HasBits* has_bits) {
|
||||
(*has_bits)[1] |= 1u;
|
||||
(*has_bits)[0] |= 2147483648u;
|
||||
}
|
||||
static void set_has_hard_vocab_limit(HasBits* has_bits) {
|
||||
(*has_bits)[1] |= 2u;
|
||||
(*has_bits)[1] |= 1u;
|
||||
}
|
||||
static void set_has_use_all_vocab(HasBits* has_bits) {
|
||||
(*has_bits)[0] |= 131072u;
|
||||
(*has_bits)[0] |= 32768u;
|
||||
}
|
||||
static void set_has_unk_id(HasBits* has_bits) {
|
||||
(*has_bits)[0] |= 65536u;
|
||||
}
|
||||
static void set_has_bos_id(HasBits* has_bits) {
|
||||
(*has_bits)[1] |= 4u;
|
||||
(*has_bits)[1] |= 2u;
|
||||
}
|
||||
static void set_has_eos_id(HasBits* has_bits) {
|
||||
(*has_bits)[1] |= 8u;
|
||||
(*has_bits)[1] |= 4u;
|
||||
}
|
||||
static void set_has_pad_id(HasBits* has_bits) {
|
||||
(*has_bits)[1] |= 16u;
|
||||
(*has_bits)[1] |= 8u;
|
||||
}
|
||||
static void set_has_unk_piece(HasBits* has_bits) {
|
||||
(*has_bits)[0] |= 16u;
|
||||
@ -387,7 +384,7 @@ class TrainerSpec::_Internal {
|
||||
(*has_bits)[0] |= 8u;
|
||||
}
|
||||
static void set_has_train_extremely_large_corpus(HasBits* has_bits) {
|
||||
(*has_bits)[0] |= 262144u;
|
||||
(*has_bits)[0] |= 131072u;
|
||||
}
|
||||
};
|
||||
|
||||
@ -569,8 +566,8 @@ void TrainerSpec::Clear() {
|
||||
}
|
||||
if (cached_has_bits & 0x0000ff00u) {
|
||||
::memset(&self_test_sample_size_, 0, static_cast<size_t>(
|
||||
reinterpret_cast<char*>(&byte_fallback_) -
|
||||
reinterpret_cast<char*>(&self_test_sample_size_)) + sizeof(byte_fallback_));
|
||||
reinterpret_cast<char*>(&use_all_vocab_) -
|
||||
reinterpret_cast<char*>(&self_test_sample_size_)) + sizeof(use_all_vocab_));
|
||||
}
|
||||
if (cached_has_bits & 0x00ff0000u) {
|
||||
::memset(&unk_id_, 0, static_cast<size_t>(
|
||||
@ -581,9 +578,9 @@ void TrainerSpec::Clear() {
|
||||
character_coverage_ = 0.9995f;
|
||||
seed_sentencepiece_size_ = 1000000;
|
||||
shrinking_factor_ = 0.75f;
|
||||
num_threads_ = 16;
|
||||
}
|
||||
if (cached_has_bits & 0xff000000u) {
|
||||
num_threads_ = 16;
|
||||
num_sub_iterations_ = 2;
|
||||
max_sentence_length_ = 4192;
|
||||
max_sentencepiece_length_ = 16;
|
||||
@ -591,10 +588,10 @@ void TrainerSpec::Clear() {
|
||||
split_by_unicode_script_ = true;
|
||||
split_by_number_ = true;
|
||||
split_by_whitespace_ = true;
|
||||
vocabulary_output_piece_score_ = true;
|
||||
}
|
||||
cached_has_bits = _has_bits_[1];
|
||||
if (cached_has_bits & 0x0000001fu) {
|
||||
vocabulary_output_piece_score_ = true;
|
||||
if (cached_has_bits & 0x0000000fu) {
|
||||
hard_vocab_limit_ = true;
|
||||
bos_id_ = 1;
|
||||
eos_id_ = 2;
|
||||
@ -809,14 +806,6 @@ const char* TrainerSpec::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID
|
||||
CHK_(ptr);
|
||||
} else goto handle_unusual;
|
||||
continue;
|
||||
// optional bool allow_whitespace_only_pieces = 26 [default = false];
|
||||
case 26:
|
||||
if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 208)) {
|
||||
_Internal::set_has_allow_whitespace_only_pieces(&_has_bits_);
|
||||
allow_whitespace_only_pieces_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr);
|
||||
CHK_(ptr);
|
||||
} else goto handle_unusual;
|
||||
continue;
|
||||
// repeated string control_symbols = 30;
|
||||
case 30:
|
||||
if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 242)) {
|
||||
@ -1011,14 +1000,14 @@ failure:
|
||||
}
|
||||
|
||||
// optional .sentencepiece.TrainerSpec.ModelType model_type = 3 [default = UNIGRAM];
|
||||
if (cached_has_bits & 0x00080000u) {
|
||||
if (cached_has_bits & 0x00040000u) {
|
||||
target = stream->EnsureSpace(target);
|
||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteEnumToArray(
|
||||
3, this->_internal_model_type(), target);
|
||||
}
|
||||
|
||||
// optional int32 vocab_size = 4 [default = 8000];
|
||||
if (cached_has_bits & 0x00100000u) {
|
||||
if (cached_has_bits & 0x00080000u) {
|
||||
target = stream->EnsureSpace(target);
|
||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(4, this->_internal_vocab_size(), target);
|
||||
}
|
||||
@ -1042,7 +1031,7 @@ failure:
|
||||
}
|
||||
|
||||
// optional float character_coverage = 10 [default = 0.9995];
|
||||
if (cached_has_bits & 0x00200000u) {
|
||||
if (cached_has_bits & 0x00100000u) {
|
||||
target = stream->EnsureSpace(target);
|
||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteFloatToArray(10, this->_internal_character_coverage(), target);
|
||||
}
|
||||
@ -1066,61 +1055,61 @@ failure:
|
||||
}
|
||||
|
||||
// optional int32 seed_sentencepiece_size = 14 [default = 1000000];
|
||||
if (cached_has_bits & 0x00400000u) {
|
||||
if (cached_has_bits & 0x00200000u) {
|
||||
target = stream->EnsureSpace(target);
|
||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(14, this->_internal_seed_sentencepiece_size(), target);
|
||||
}
|
||||
|
||||
// optional float shrinking_factor = 15 [default = 0.75];
|
||||
if (cached_has_bits & 0x00800000u) {
|
||||
if (cached_has_bits & 0x00400000u) {
|
||||
target = stream->EnsureSpace(target);
|
||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteFloatToArray(15, this->_internal_shrinking_factor(), target);
|
||||
}
|
||||
|
||||
// optional int32 num_threads = 16 [default = 16];
|
||||
if (cached_has_bits & 0x01000000u) {
|
||||
if (cached_has_bits & 0x00800000u) {
|
||||
target = stream->EnsureSpace(target);
|
||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(16, this->_internal_num_threads(), target);
|
||||
}
|
||||
|
||||
// optional int32 num_sub_iterations = 17 [default = 2];
|
||||
if (cached_has_bits & 0x02000000u) {
|
||||
if (cached_has_bits & 0x01000000u) {
|
||||
target = stream->EnsureSpace(target);
|
||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(17, this->_internal_num_sub_iterations(), target);
|
||||
}
|
||||
|
||||
// optional int32 max_sentence_length = 18 [default = 4192];
|
||||
if (cached_has_bits & 0x04000000u) {
|
||||
if (cached_has_bits & 0x02000000u) {
|
||||
target = stream->EnsureSpace(target);
|
||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(18, this->_internal_max_sentence_length(), target);
|
||||
}
|
||||
|
||||
// optional bool shuffle_input_sentence = 19 [default = true];
|
||||
if (cached_has_bits & 0x10000000u) {
|
||||
if (cached_has_bits & 0x08000000u) {
|
||||
target = stream->EnsureSpace(target);
|
||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(19, this->_internal_shuffle_input_sentence(), target);
|
||||
}
|
||||
|
||||
// optional int32 max_sentencepiece_length = 20 [default = 16];
|
||||
if (cached_has_bits & 0x08000000u) {
|
||||
if (cached_has_bits & 0x04000000u) {
|
||||
target = stream->EnsureSpace(target);
|
||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(20, this->_internal_max_sentencepiece_length(), target);
|
||||
}
|
||||
|
||||
// optional bool split_by_unicode_script = 21 [default = true];
|
||||
if (cached_has_bits & 0x20000000u) {
|
||||
if (cached_has_bits & 0x10000000u) {
|
||||
target = stream->EnsureSpace(target);
|
||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(21, this->_internal_split_by_unicode_script(), target);
|
||||
}
|
||||
|
||||
// optional bool split_by_whitespace = 22 [default = true];
|
||||
if (cached_has_bits & 0x80000000u) {
|
||||
if (cached_has_bits & 0x40000000u) {
|
||||
target = stream->EnsureSpace(target);
|
||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(22, this->_internal_split_by_whitespace(), target);
|
||||
}
|
||||
|
||||
// optional bool split_by_number = 23 [default = true];
|
||||
if (cached_has_bits & 0x40000000u) {
|
||||
if (cached_has_bits & 0x20000000u) {
|
||||
target = stream->EnsureSpace(target);
|
||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(23, this->_internal_split_by_number(), target);
|
||||
}
|
||||
@ -1132,15 +1121,9 @@ failure:
|
||||
}
|
||||
|
||||
// optional bool split_digits = 25 [default = false];
|
||||
if (cached_has_bits & 0x00004000u) {
|
||||
target = stream->EnsureSpace(target);
|
||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(25, this->_internal_split_digits(), target);
|
||||
}
|
||||
|
||||
// optional bool allow_whitespace_only_pieces = 26 [default = false];
|
||||
if (cached_has_bits & 0x00002000u) {
|
||||
target = stream->EnsureSpace(target);
|
||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(26, this->_internal_allow_whitespace_only_pieces(), target);
|
||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(25, this->_internal_split_digits(), target);
|
||||
}
|
||||
|
||||
// repeated string control_symbols = 30;
|
||||
@ -1155,28 +1138,28 @@ failure:
|
||||
target = stream->WriteString(31, s, target);
|
||||
}
|
||||
|
||||
cached_has_bits = _has_bits_[1];
|
||||
// optional bool vocabulary_output_piece_score = 32 [default = true];
|
||||
if (cached_has_bits & 0x00000001u) {
|
||||
if (cached_has_bits & 0x80000000u) {
|
||||
target = stream->EnsureSpace(target);
|
||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(32, this->_internal_vocabulary_output_piece_score(), target);
|
||||
}
|
||||
|
||||
cached_has_bits = _has_bits_[1];
|
||||
// optional bool hard_vocab_limit = 33 [default = true];
|
||||
if (cached_has_bits & 0x00000002u) {
|
||||
if (cached_has_bits & 0x00000001u) {
|
||||
target = stream->EnsureSpace(target);
|
||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(33, this->_internal_hard_vocab_limit(), target);
|
||||
}
|
||||
|
||||
cached_has_bits = _has_bits_[0];
|
||||
// optional bool use_all_vocab = 34 [default = false];
|
||||
if (cached_has_bits & 0x00020000u) {
|
||||
if (cached_has_bits & 0x00008000u) {
|
||||
target = stream->EnsureSpace(target);
|
||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(34, this->_internal_use_all_vocab(), target);
|
||||
}
|
||||
|
||||
// optional bool byte_fallback = 35 [default = false];
|
||||
if (cached_has_bits & 0x00008000u) {
|
||||
if (cached_has_bits & 0x00004000u) {
|
||||
target = stream->EnsureSpace(target);
|
||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(35, this->_internal_byte_fallback(), target);
|
||||
}
|
||||
@ -1195,19 +1178,19 @@ failure:
|
||||
|
||||
cached_has_bits = _has_bits_[1];
|
||||
// optional int32 bos_id = 41 [default = 1];
|
||||
if (cached_has_bits & 0x00000004u) {
|
||||
if (cached_has_bits & 0x00000002u) {
|
||||
target = stream->EnsureSpace(target);
|
||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(41, this->_internal_bos_id(), target);
|
||||
}
|
||||
|
||||
// optional int32 eos_id = 42 [default = 2];
|
||||
if (cached_has_bits & 0x00000008u) {
|
||||
if (cached_has_bits & 0x00000004u) {
|
||||
target = stream->EnsureSpace(target);
|
||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(42, this->_internal_eos_id(), target);
|
||||
}
|
||||
|
||||
// optional int32 pad_id = 43 [default = -1];
|
||||
if (cached_has_bits & 0x00000010u) {
|
||||
if (cached_has_bits & 0x00000008u) {
|
||||
target = stream->EnsureSpace(target);
|
||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(43, this->_internal_pad_id(), target);
|
||||
}
|
||||
@ -1244,7 +1227,7 @@ failure:
|
||||
}
|
||||
|
||||
// optional bool train_extremely_large_corpus = 49 [default = false];
|
||||
if (cached_has_bits & 0x00040000u) {
|
||||
if (cached_has_bits & 0x00020000u) {
|
||||
target = stream->EnsureSpace(target);
|
||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(49, this->_internal_train_extremely_large_corpus(), target);
|
||||
}
|
||||
@ -1396,17 +1379,17 @@ size_t TrainerSpec::ByteSizeLong() const {
|
||||
total_size += 2 + 1;
|
||||
}
|
||||
|
||||
// optional bool allow_whitespace_only_pieces = 26 [default = false];
|
||||
// optional bool split_digits = 25 [default = false];
|
||||
if (cached_has_bits & 0x00002000u) {
|
||||
total_size += 2 + 1;
|
||||
}
|
||||
|
||||
// optional bool split_digits = 25 [default = false];
|
||||
// optional bool byte_fallback = 35 [default = false];
|
||||
if (cached_has_bits & 0x00004000u) {
|
||||
total_size += 2 + 1;
|
||||
}
|
||||
|
||||
// optional bool byte_fallback = 35 [default = false];
|
||||
// optional bool use_all_vocab = 34 [default = false];
|
||||
if (cached_has_bits & 0x00008000u) {
|
||||
total_size += 2 + 1;
|
||||
}
|
||||
@ -1420,125 +1403,120 @@ size_t TrainerSpec::ByteSizeLong() const {
|
||||
this->_internal_unk_id());
|
||||
}
|
||||
|
||||
// optional bool use_all_vocab = 34 [default = false];
|
||||
// optional bool train_extremely_large_corpus = 49 [default = false];
|
||||
if (cached_has_bits & 0x00020000u) {
|
||||
total_size += 2 + 1;
|
||||
}
|
||||
|
||||
// optional bool train_extremely_large_corpus = 49 [default = false];
|
||||
if (cached_has_bits & 0x00040000u) {
|
||||
total_size += 2 + 1;
|
||||
}
|
||||
|
||||
// optional .sentencepiece.TrainerSpec.ModelType model_type = 3 [default = UNIGRAM];
|
||||
if (cached_has_bits & 0x00080000u) {
|
||||
if (cached_has_bits & 0x00040000u) {
|
||||
total_size += 1 +
|
||||
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::EnumSize(this->_internal_model_type());
|
||||
}
|
||||
|
||||
// optional int32 vocab_size = 4 [default = 8000];
|
||||
if (cached_has_bits & 0x00100000u) {
|
||||
if (cached_has_bits & 0x00080000u) {
|
||||
total_size += 1 +
|
||||
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
|
||||
this->_internal_vocab_size());
|
||||
}
|
||||
|
||||
// optional float character_coverage = 10 [default = 0.9995];
|
||||
if (cached_has_bits & 0x00200000u) {
|
||||
if (cached_has_bits & 0x00100000u) {
|
||||
total_size += 1 + 4;
|
||||
}
|
||||
|
||||
// optional int32 seed_sentencepiece_size = 14 [default = 1000000];
|
||||
if (cached_has_bits & 0x00400000u) {
|
||||
if (cached_has_bits & 0x00200000u) {
|
||||
total_size += 1 +
|
||||
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
|
||||
this->_internal_seed_sentencepiece_size());
|
||||
}
|
||||
|
||||
// optional float shrinking_factor = 15 [default = 0.75];
|
||||
if (cached_has_bits & 0x00800000u) {
|
||||
if (cached_has_bits & 0x00400000u) {
|
||||
total_size += 1 + 4;
|
||||
}
|
||||
|
||||
}
|
||||
if (cached_has_bits & 0xff000000u) {
|
||||
// optional int32 num_threads = 16 [default = 16];
|
||||
if (cached_has_bits & 0x01000000u) {
|
||||
if (cached_has_bits & 0x00800000u) {
|
||||
total_size += 2 +
|
||||
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
|
||||
this->_internal_num_threads());
|
||||
}
|
||||
|
||||
}
|
||||
if (cached_has_bits & 0xff000000u) {
|
||||
// optional int32 num_sub_iterations = 17 [default = 2];
|
||||
if (cached_has_bits & 0x02000000u) {
|
||||
if (cached_has_bits & 0x01000000u) {
|
||||
total_size += 2 +
|
||||
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
|
||||
this->_internal_num_sub_iterations());
|
||||
}
|
||||
|
||||
// optional int32 max_sentence_length = 18 [default = 4192];
|
||||
if (cached_has_bits & 0x04000000u) {
|
||||
if (cached_has_bits & 0x02000000u) {
|
||||
total_size += 2 +
|
||||
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
|
||||
this->_internal_max_sentence_length());
|
||||
}
|
||||
|
||||
// optional int32 max_sentencepiece_length = 20 [default = 16];
|
||||
if (cached_has_bits & 0x08000000u) {
|
||||
if (cached_has_bits & 0x04000000u) {
|
||||
total_size += 2 +
|
||||
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
|
||||
this->_internal_max_sentencepiece_length());
|
||||
}
|
||||
|
||||
// optional bool shuffle_input_sentence = 19 [default = true];
|
||||
if (cached_has_bits & 0x10000000u) {
|
||||
if (cached_has_bits & 0x08000000u) {
|
||||
total_size += 2 + 1;
|
||||
}
|
||||
|
||||
// optional bool split_by_unicode_script = 21 [default = true];
|
||||
if (cached_has_bits & 0x20000000u) {
|
||||
if (cached_has_bits & 0x10000000u) {
|
||||
total_size += 2 + 1;
|
||||
}
|
||||
|
||||
// optional bool split_by_number = 23 [default = true];
|
||||
if (cached_has_bits & 0x40000000u) {
|
||||
if (cached_has_bits & 0x20000000u) {
|
||||
total_size += 2 + 1;
|
||||
}
|
||||
|
||||
// optional bool split_by_whitespace = 22 [default = true];
|
||||
if (cached_has_bits & 0x40000000u) {
|
||||
total_size += 2 + 1;
|
||||
}
|
||||
|
||||
// optional bool vocabulary_output_piece_score = 32 [default = true];
|
||||
if (cached_has_bits & 0x80000000u) {
|
||||
total_size += 2 + 1;
|
||||
}
|
||||
|
||||
}
|
||||
cached_has_bits = _has_bits_[1];
|
||||
if (cached_has_bits & 0x0000001fu) {
|
||||
// optional bool vocabulary_output_piece_score = 32 [default = true];
|
||||
if (cached_has_bits & 0x0000000fu) {
|
||||
// optional bool hard_vocab_limit = 33 [default = true];
|
||||
if (cached_has_bits & 0x00000001u) {
|
||||
total_size += 2 + 1;
|
||||
}
|
||||
|
||||
// optional bool hard_vocab_limit = 33 [default = true];
|
||||
if (cached_has_bits & 0x00000002u) {
|
||||
total_size += 2 + 1;
|
||||
}
|
||||
|
||||
// optional int32 bos_id = 41 [default = 1];
|
||||
if (cached_has_bits & 0x00000004u) {
|
||||
if (cached_has_bits & 0x00000002u) {
|
||||
total_size += 2 +
|
||||
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
|
||||
this->_internal_bos_id());
|
||||
}
|
||||
|
||||
// optional int32 eos_id = 42 [default = 2];
|
||||
if (cached_has_bits & 0x00000008u) {
|
||||
if (cached_has_bits & 0x00000004u) {
|
||||
total_size += 2 +
|
||||
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
|
||||
this->_internal_eos_id());
|
||||
}
|
||||
|
||||
// optional int32 pad_id = 43 [default = -1];
|
||||
if (cached_has_bits & 0x00000010u) {
|
||||
if (cached_has_bits & 0x00000008u) {
|
||||
total_size += 2 +
|
||||
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
|
||||
this->_internal_pad_id());
|
||||
@ -1615,14 +1593,14 @@ void TrainerSpec::MergeFrom(const TrainerSpec& from) {
|
||||
treat_whitespace_as_suffix_ = from.treat_whitespace_as_suffix_;
|
||||
}
|
||||
if (cached_has_bits & 0x00002000u) {
|
||||
allow_whitespace_only_pieces_ = from.allow_whitespace_only_pieces_;
|
||||
}
|
||||
if (cached_has_bits & 0x00004000u) {
|
||||
split_digits_ = from.split_digits_;
|
||||
}
|
||||
if (cached_has_bits & 0x00008000u) {
|
||||
if (cached_has_bits & 0x00004000u) {
|
||||
byte_fallback_ = from.byte_fallback_;
|
||||
}
|
||||
if (cached_has_bits & 0x00008000u) {
|
||||
use_all_vocab_ = from.use_all_vocab_;
|
||||
}
|
||||
_has_bits_[0] |= cached_has_bits;
|
||||
}
|
||||
if (cached_has_bits & 0x00ff0000u) {
|
||||
@ -1630,70 +1608,67 @@ void TrainerSpec::MergeFrom(const TrainerSpec& from) {
|
||||
unk_id_ = from.unk_id_;
|
||||
}
|
||||
if (cached_has_bits & 0x00020000u) {
|
||||
use_all_vocab_ = from.use_all_vocab_;
|
||||
}
|
||||
if (cached_has_bits & 0x00040000u) {
|
||||
train_extremely_large_corpus_ = from.train_extremely_large_corpus_;
|
||||
}
|
||||
if (cached_has_bits & 0x00080000u) {
|
||||
if (cached_has_bits & 0x00040000u) {
|
||||
model_type_ = from.model_type_;
|
||||
}
|
||||
if (cached_has_bits & 0x00100000u) {
|
||||
if (cached_has_bits & 0x00080000u) {
|
||||
vocab_size_ = from.vocab_size_;
|
||||
}
|
||||
if (cached_has_bits & 0x00200000u) {
|
||||
if (cached_has_bits & 0x00100000u) {
|
||||
character_coverage_ = from.character_coverage_;
|
||||
}
|
||||
if (cached_has_bits & 0x00400000u) {
|
||||
if (cached_has_bits & 0x00200000u) {
|
||||
seed_sentencepiece_size_ = from.seed_sentencepiece_size_;
|
||||
}
|
||||
if (cached_has_bits & 0x00800000u) {
|
||||
if (cached_has_bits & 0x00400000u) {
|
||||
shrinking_factor_ = from.shrinking_factor_;
|
||||
}
|
||||
if (cached_has_bits & 0x00800000u) {
|
||||
num_threads_ = from.num_threads_;
|
||||
}
|
||||
_has_bits_[0] |= cached_has_bits;
|
||||
}
|
||||
if (cached_has_bits & 0xff000000u) {
|
||||
if (cached_has_bits & 0x01000000u) {
|
||||
num_threads_ = from.num_threads_;
|
||||
}
|
||||
if (cached_has_bits & 0x02000000u) {
|
||||
num_sub_iterations_ = from.num_sub_iterations_;
|
||||
}
|
||||
if (cached_has_bits & 0x04000000u) {
|
||||
if (cached_has_bits & 0x02000000u) {
|
||||
max_sentence_length_ = from.max_sentence_length_;
|
||||
}
|
||||
if (cached_has_bits & 0x08000000u) {
|
||||
if (cached_has_bits & 0x04000000u) {
|
||||
max_sentencepiece_length_ = from.max_sentencepiece_length_;
|
||||
}
|
||||
if (cached_has_bits & 0x10000000u) {
|
||||
if (cached_has_bits & 0x08000000u) {
|
||||
shuffle_input_sentence_ = from.shuffle_input_sentence_;
|
||||
}
|
||||
if (cached_has_bits & 0x20000000u) {
|
||||
if (cached_has_bits & 0x10000000u) {
|
||||
split_by_unicode_script_ = from.split_by_unicode_script_;
|
||||
}
|
||||
if (cached_has_bits & 0x40000000u) {
|
||||
if (cached_has_bits & 0x20000000u) {
|
||||
split_by_number_ = from.split_by_number_;
|
||||
}
|
||||
if (cached_has_bits & 0x80000000u) {
|
||||
if (cached_has_bits & 0x40000000u) {
|
||||
split_by_whitespace_ = from.split_by_whitespace_;
|
||||
}
|
||||
if (cached_has_bits & 0x80000000u) {
|
||||
vocabulary_output_piece_score_ = from.vocabulary_output_piece_score_;
|
||||
}
|
||||
_has_bits_[0] |= cached_has_bits;
|
||||
}
|
||||
cached_has_bits = from._has_bits_[1];
|
||||
if (cached_has_bits & 0x0000001fu) {
|
||||
if (cached_has_bits & 0x0000000fu) {
|
||||
if (cached_has_bits & 0x00000001u) {
|
||||
vocabulary_output_piece_score_ = from.vocabulary_output_piece_score_;
|
||||
}
|
||||
if (cached_has_bits & 0x00000002u) {
|
||||
hard_vocab_limit_ = from.hard_vocab_limit_;
|
||||
}
|
||||
if (cached_has_bits & 0x00000004u) {
|
||||
if (cached_has_bits & 0x00000002u) {
|
||||
bos_id_ = from.bos_id_;
|
||||
}
|
||||
if (cached_has_bits & 0x00000008u) {
|
||||
if (cached_has_bits & 0x00000004u) {
|
||||
eos_id_ = from.eos_id_;
|
||||
}
|
||||
if (cached_has_bits & 0x00000010u) {
|
||||
if (cached_has_bits & 0x00000008u) {
|
||||
pad_id_ = from.pad_id_;
|
||||
}
|
||||
_has_bits_[1] |= cached_has_bits;
|
||||
|
@ -278,11 +278,10 @@ class TrainerSpec PROTOBUF_FINAL :
|
||||
kInputSentenceSizeFieldNumber = 11,
|
||||
kTrainingSentenceSizeFieldNumber = 13,
|
||||
kTreatWhitespaceAsSuffixFieldNumber = 24,
|
||||
kAllowWhitespaceOnlyPiecesFieldNumber = 26,
|
||||
kSplitDigitsFieldNumber = 25,
|
||||
kByteFallbackFieldNumber = 35,
|
||||
kUnkIdFieldNumber = 40,
|
||||
kUseAllVocabFieldNumber = 34,
|
||||
kUnkIdFieldNumber = 40,
|
||||
kTrainExtremelyLargeCorpusFieldNumber = 49,
|
||||
kModelTypeFieldNumber = 3,
|
||||
kVocabSizeFieldNumber = 4,
|
||||
@ -624,19 +623,6 @@ class TrainerSpec PROTOBUF_FINAL :
|
||||
void _internal_set_treat_whitespace_as_suffix(bool value);
|
||||
public:
|
||||
|
||||
// optional bool allow_whitespace_only_pieces = 26 [default = false];
|
||||
bool has_allow_whitespace_only_pieces() const;
|
||||
private:
|
||||
bool _internal_has_allow_whitespace_only_pieces() const;
|
||||
public:
|
||||
void clear_allow_whitespace_only_pieces();
|
||||
bool allow_whitespace_only_pieces() const;
|
||||
void set_allow_whitespace_only_pieces(bool value);
|
||||
private:
|
||||
bool _internal_allow_whitespace_only_pieces() const;
|
||||
void _internal_set_allow_whitespace_only_pieces(bool value);
|
||||
public:
|
||||
|
||||
// optional bool split_digits = 25 [default = false];
|
||||
bool has_split_digits() const;
|
||||
private:
|
||||
@ -663,19 +649,6 @@ class TrainerSpec PROTOBUF_FINAL :
|
||||
void _internal_set_byte_fallback(bool value);
|
||||
public:
|
||||
|
||||
// optional int32 unk_id = 40 [default = 0];
|
||||
bool has_unk_id() const;
|
||||
private:
|
||||
bool _internal_has_unk_id() const;
|
||||
public:
|
||||
void clear_unk_id();
|
||||
::PROTOBUF_NAMESPACE_ID::int32 unk_id() const;
|
||||
void set_unk_id(::PROTOBUF_NAMESPACE_ID::int32 value);
|
||||
private:
|
||||
::PROTOBUF_NAMESPACE_ID::int32 _internal_unk_id() const;
|
||||
void _internal_set_unk_id(::PROTOBUF_NAMESPACE_ID::int32 value);
|
||||
public:
|
||||
|
||||
// optional bool use_all_vocab = 34 [default = false];
|
||||
bool has_use_all_vocab() const;
|
||||
private:
|
||||
@ -689,6 +662,19 @@ class TrainerSpec PROTOBUF_FINAL :
|
||||
void _internal_set_use_all_vocab(bool value);
|
||||
public:
|
||||
|
||||
// optional int32 unk_id = 40 [default = 0];
|
||||
bool has_unk_id() const;
|
||||
private:
|
||||
bool _internal_has_unk_id() const;
|
||||
public:
|
||||
void clear_unk_id();
|
||||
::PROTOBUF_NAMESPACE_ID::int32 unk_id() const;
|
||||
void set_unk_id(::PROTOBUF_NAMESPACE_ID::int32 value);
|
||||
private:
|
||||
::PROTOBUF_NAMESPACE_ID::int32 _internal_unk_id() const;
|
||||
void _internal_set_unk_id(::PROTOBUF_NAMESPACE_ID::int32 value);
|
||||
public:
|
||||
|
||||
// optional bool train_extremely_large_corpus = 49 [default = false];
|
||||
bool has_train_extremely_large_corpus() const;
|
||||
private:
|
||||
@ -970,11 +956,10 @@ class TrainerSpec PROTOBUF_FINAL :
|
||||
::PROTOBUF_NAMESPACE_ID::uint64 input_sentence_size_;
|
||||
::PROTOBUF_NAMESPACE_ID::int32 training_sentence_size_;
|
||||
bool treat_whitespace_as_suffix_;
|
||||
bool allow_whitespace_only_pieces_;
|
||||
bool split_digits_;
|
||||
bool byte_fallback_;
|
||||
::PROTOBUF_NAMESPACE_ID::int32 unk_id_;
|
||||
bool use_all_vocab_;
|
||||
::PROTOBUF_NAMESPACE_ID::int32 unk_id_;
|
||||
bool train_extremely_large_corpus_;
|
||||
int model_type_;
|
||||
::PROTOBUF_NAMESPACE_ID::int32 vocab_size_;
|
||||
@ -2195,7 +2180,7 @@ inline void TrainerSpec::set_allocated_model_prefix(std::string* model_prefix) {
|
||||
|
||||
// optional .sentencepiece.TrainerSpec.ModelType model_type = 3 [default = UNIGRAM];
|
||||
inline bool TrainerSpec::_internal_has_model_type() const {
|
||||
bool value = (_has_bits_[0] & 0x00080000u) != 0;
|
||||
bool value = (_has_bits_[0] & 0x00040000u) != 0;
|
||||
return value;
|
||||
}
|
||||
inline bool TrainerSpec::has_model_type() const {
|
||||
@ -2203,7 +2188,7 @@ inline bool TrainerSpec::has_model_type() const {
|
||||
}
|
||||
inline void TrainerSpec::clear_model_type() {
|
||||
model_type_ = 1;
|
||||
_has_bits_[0] &= ~0x00080000u;
|
||||
_has_bits_[0] &= ~0x00040000u;
|
||||
}
|
||||
inline ::sentencepiece::TrainerSpec_ModelType TrainerSpec::_internal_model_type() const {
|
||||
return static_cast< ::sentencepiece::TrainerSpec_ModelType >(model_type_);
|
||||
@ -2214,7 +2199,7 @@ inline ::sentencepiece::TrainerSpec_ModelType TrainerSpec::model_type() const {
|
||||
}
|
||||
inline void TrainerSpec::_internal_set_model_type(::sentencepiece::TrainerSpec_ModelType value) {
|
||||
assert(::sentencepiece::TrainerSpec_ModelType_IsValid(value));
|
||||
_has_bits_[0] |= 0x00080000u;
|
||||
_has_bits_[0] |= 0x00040000u;
|
||||
model_type_ = value;
|
||||
}
|
||||
inline void TrainerSpec::set_model_type(::sentencepiece::TrainerSpec_ModelType value) {
|
||||
@ -2224,7 +2209,7 @@ inline void TrainerSpec::set_model_type(::sentencepiece::TrainerSpec_ModelType v
|
||||
|
||||
// optional int32 vocab_size = 4 [default = 8000];
|
||||
inline bool TrainerSpec::_internal_has_vocab_size() const {
|
||||
bool value = (_has_bits_[0] & 0x00100000u) != 0;
|
||||
bool value = (_has_bits_[0] & 0x00080000u) != 0;
|
||||
return value;
|
||||
}
|
||||
inline bool TrainerSpec::has_vocab_size() const {
|
||||
@ -2232,7 +2217,7 @@ inline bool TrainerSpec::has_vocab_size() const {
|
||||
}
|
||||
inline void TrainerSpec::clear_vocab_size() {
|
||||
vocab_size_ = 8000;
|
||||
_has_bits_[0] &= ~0x00100000u;
|
||||
_has_bits_[0] &= ~0x00080000u;
|
||||
}
|
||||
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_vocab_size() const {
|
||||
return vocab_size_;
|
||||
@ -2242,7 +2227,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::vocab_size() const {
|
||||
return _internal_vocab_size();
|
||||
}
|
||||
inline void TrainerSpec::_internal_set_vocab_size(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
||||
_has_bits_[0] |= 0x00100000u;
|
||||
_has_bits_[0] |= 0x00080000u;
|
||||
vocab_size_ = value;
|
||||
}
|
||||
inline void TrainerSpec::set_vocab_size(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
||||
@ -2354,7 +2339,7 @@ inline void TrainerSpec::set_self_test_sample_size(::PROTOBUF_NAMESPACE_ID::int3
|
||||
|
||||
// optional float character_coverage = 10 [default = 0.9995];
|
||||
inline bool TrainerSpec::_internal_has_character_coverage() const {
|
||||
bool value = (_has_bits_[0] & 0x00200000u) != 0;
|
||||
bool value = (_has_bits_[0] & 0x00100000u) != 0;
|
||||
return value;
|
||||
}
|
||||
inline bool TrainerSpec::has_character_coverage() const {
|
||||
@ -2362,7 +2347,7 @@ inline bool TrainerSpec::has_character_coverage() const {
|
||||
}
|
||||
inline void TrainerSpec::clear_character_coverage() {
|
||||
character_coverage_ = 0.9995f;
|
||||
_has_bits_[0] &= ~0x00200000u;
|
||||
_has_bits_[0] &= ~0x00100000u;
|
||||
}
|
||||
inline float TrainerSpec::_internal_character_coverage() const {
|
||||
return character_coverage_;
|
||||
@ -2372,7 +2357,7 @@ inline float TrainerSpec::character_coverage() const {
|
||||
return _internal_character_coverage();
|
||||
}
|
||||
inline void TrainerSpec::_internal_set_character_coverage(float value) {
|
||||
_has_bits_[0] |= 0x00200000u;
|
||||
_has_bits_[0] |= 0x00100000u;
|
||||
character_coverage_ = value;
|
||||
}
|
||||
inline void TrainerSpec::set_character_coverage(float value) {
|
||||
@ -2410,7 +2395,7 @@ inline void TrainerSpec::set_input_sentence_size(::PROTOBUF_NAMESPACE_ID::uint64
|
||||
|
||||
// optional bool shuffle_input_sentence = 19 [default = true];
|
||||
inline bool TrainerSpec::_internal_has_shuffle_input_sentence() const {
|
||||
bool value = (_has_bits_[0] & 0x10000000u) != 0;
|
||||
bool value = (_has_bits_[0] & 0x08000000u) != 0;
|
||||
return value;
|
||||
}
|
||||
inline bool TrainerSpec::has_shuffle_input_sentence() const {
|
||||
@ -2418,7 +2403,7 @@ inline bool TrainerSpec::has_shuffle_input_sentence() const {
|
||||
}
|
||||
inline void TrainerSpec::clear_shuffle_input_sentence() {
|
||||
shuffle_input_sentence_ = true;
|
||||
_has_bits_[0] &= ~0x10000000u;
|
||||
_has_bits_[0] &= ~0x08000000u;
|
||||
}
|
||||
inline bool TrainerSpec::_internal_shuffle_input_sentence() const {
|
||||
return shuffle_input_sentence_;
|
||||
@ -2428,7 +2413,7 @@ inline bool TrainerSpec::shuffle_input_sentence() const {
|
||||
return _internal_shuffle_input_sentence();
|
||||
}
|
||||
inline void TrainerSpec::_internal_set_shuffle_input_sentence(bool value) {
|
||||
_has_bits_[0] |= 0x10000000u;
|
||||
_has_bits_[0] |= 0x08000000u;
|
||||
shuffle_input_sentence_ = value;
|
||||
}
|
||||
inline void TrainerSpec::set_shuffle_input_sentence(bool value) {
|
||||
@ -2494,7 +2479,7 @@ inline void TrainerSpec::set_training_sentence_size(::PROTOBUF_NAMESPACE_ID::int
|
||||
|
||||
// optional int32 seed_sentencepiece_size = 14 [default = 1000000];
|
||||
inline bool TrainerSpec::_internal_has_seed_sentencepiece_size() const {
|
||||
bool value = (_has_bits_[0] & 0x00400000u) != 0;
|
||||
bool value = (_has_bits_[0] & 0x00200000u) != 0;
|
||||
return value;
|
||||
}
|
||||
inline bool TrainerSpec::has_seed_sentencepiece_size() const {
|
||||
@ -2502,7 +2487,7 @@ inline bool TrainerSpec::has_seed_sentencepiece_size() const {
|
||||
}
|
||||
inline void TrainerSpec::clear_seed_sentencepiece_size() {
|
||||
seed_sentencepiece_size_ = 1000000;
|
||||
_has_bits_[0] &= ~0x00400000u;
|
||||
_has_bits_[0] &= ~0x00200000u;
|
||||
}
|
||||
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_seed_sentencepiece_size() const {
|
||||
return seed_sentencepiece_size_;
|
||||
@ -2512,7 +2497,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::seed_sentencepiece_size() con
|
||||
return _internal_seed_sentencepiece_size();
|
||||
}
|
||||
inline void TrainerSpec::_internal_set_seed_sentencepiece_size(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
||||
_has_bits_[0] |= 0x00400000u;
|
||||
_has_bits_[0] |= 0x00200000u;
|
||||
seed_sentencepiece_size_ = value;
|
||||
}
|
||||
inline void TrainerSpec::set_seed_sentencepiece_size(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
||||
@ -2522,7 +2507,7 @@ inline void TrainerSpec::set_seed_sentencepiece_size(::PROTOBUF_NAMESPACE_ID::in
|
||||
|
||||
// optional float shrinking_factor = 15 [default = 0.75];
|
||||
inline bool TrainerSpec::_internal_has_shrinking_factor() const {
|
||||
bool value = (_has_bits_[0] & 0x00800000u) != 0;
|
||||
bool value = (_has_bits_[0] & 0x00400000u) != 0;
|
||||
return value;
|
||||
}
|
||||
inline bool TrainerSpec::has_shrinking_factor() const {
|
||||
@ -2530,7 +2515,7 @@ inline bool TrainerSpec::has_shrinking_factor() const {
|
||||
}
|
||||
inline void TrainerSpec::clear_shrinking_factor() {
|
||||
shrinking_factor_ = 0.75f;
|
||||
_has_bits_[0] &= ~0x00800000u;
|
||||
_has_bits_[0] &= ~0x00400000u;
|
||||
}
|
||||
inline float TrainerSpec::_internal_shrinking_factor() const {
|
||||
return shrinking_factor_;
|
||||
@ -2540,7 +2525,7 @@ inline float TrainerSpec::shrinking_factor() const {
|
||||
return _internal_shrinking_factor();
|
||||
}
|
||||
inline void TrainerSpec::_internal_set_shrinking_factor(float value) {
|
||||
_has_bits_[0] |= 0x00800000u;
|
||||
_has_bits_[0] |= 0x00400000u;
|
||||
shrinking_factor_ = value;
|
||||
}
|
||||
inline void TrainerSpec::set_shrinking_factor(float value) {
|
||||
@ -2550,7 +2535,7 @@ inline void TrainerSpec::set_shrinking_factor(float value) {
|
||||
|
||||
// optional int32 max_sentence_length = 18 [default = 4192];
|
||||
inline bool TrainerSpec::_internal_has_max_sentence_length() const {
|
||||
bool value = (_has_bits_[0] & 0x04000000u) != 0;
|
||||
bool value = (_has_bits_[0] & 0x02000000u) != 0;
|
||||
return value;
|
||||
}
|
||||
inline bool TrainerSpec::has_max_sentence_length() const {
|
||||
@ -2558,7 +2543,7 @@ inline bool TrainerSpec::has_max_sentence_length() const {
|
||||
}
|
||||
inline void TrainerSpec::clear_max_sentence_length() {
|
||||
max_sentence_length_ = 4192;
|
||||
_has_bits_[0] &= ~0x04000000u;
|
||||
_has_bits_[0] &= ~0x02000000u;
|
||||
}
|
||||
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_max_sentence_length() const {
|
||||
return max_sentence_length_;
|
||||
@ -2568,7 +2553,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::max_sentence_length() const {
|
||||
return _internal_max_sentence_length();
|
||||
}
|
||||
inline void TrainerSpec::_internal_set_max_sentence_length(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
||||
_has_bits_[0] |= 0x04000000u;
|
||||
_has_bits_[0] |= 0x02000000u;
|
||||
max_sentence_length_ = value;
|
||||
}
|
||||
inline void TrainerSpec::set_max_sentence_length(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
||||
@ -2578,7 +2563,7 @@ inline void TrainerSpec::set_max_sentence_length(::PROTOBUF_NAMESPACE_ID::int32
|
||||
|
||||
// optional int32 num_threads = 16 [default = 16];
|
||||
inline bool TrainerSpec::_internal_has_num_threads() const {
|
||||
bool value = (_has_bits_[0] & 0x01000000u) != 0;
|
||||
bool value = (_has_bits_[0] & 0x00800000u) != 0;
|
||||
return value;
|
||||
}
|
||||
inline bool TrainerSpec::has_num_threads() const {
|
||||
@ -2586,7 +2571,7 @@ inline bool TrainerSpec::has_num_threads() const {
|
||||
}
|
||||
inline void TrainerSpec::clear_num_threads() {
|
||||
num_threads_ = 16;
|
||||
_has_bits_[0] &= ~0x01000000u;
|
||||
_has_bits_[0] &= ~0x00800000u;
|
||||
}
|
||||
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_num_threads() const {
|
||||
return num_threads_;
|
||||
@ -2596,7 +2581,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::num_threads() const {
|
||||
return _internal_num_threads();
|
||||
}
|
||||
inline void TrainerSpec::_internal_set_num_threads(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
||||
_has_bits_[0] |= 0x01000000u;
|
||||
_has_bits_[0] |= 0x00800000u;
|
||||
num_threads_ = value;
|
||||
}
|
||||
inline void TrainerSpec::set_num_threads(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
||||
@ -2606,7 +2591,7 @@ inline void TrainerSpec::set_num_threads(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
||||
|
||||
// optional int32 num_sub_iterations = 17 [default = 2];
|
||||
inline bool TrainerSpec::_internal_has_num_sub_iterations() const {
|
||||
bool value = (_has_bits_[0] & 0x02000000u) != 0;
|
||||
bool value = (_has_bits_[0] & 0x01000000u) != 0;
|
||||
return value;
|
||||
}
|
||||
inline bool TrainerSpec::has_num_sub_iterations() const {
|
||||
@ -2614,7 +2599,7 @@ inline bool TrainerSpec::has_num_sub_iterations() const {
|
||||
}
|
||||
inline void TrainerSpec::clear_num_sub_iterations() {
|
||||
num_sub_iterations_ = 2;
|
||||
_has_bits_[0] &= ~0x02000000u;
|
||||
_has_bits_[0] &= ~0x01000000u;
|
||||
}
|
||||
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_num_sub_iterations() const {
|
||||
return num_sub_iterations_;
|
||||
@ -2624,7 +2609,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::num_sub_iterations() const {
|
||||
return _internal_num_sub_iterations();
|
||||
}
|
||||
inline void TrainerSpec::_internal_set_num_sub_iterations(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
||||
_has_bits_[0] |= 0x02000000u;
|
||||
_has_bits_[0] |= 0x01000000u;
|
||||
num_sub_iterations_ = value;
|
||||
}
|
||||
inline void TrainerSpec::set_num_sub_iterations(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
||||
@ -2634,7 +2619,7 @@ inline void TrainerSpec::set_num_sub_iterations(::PROTOBUF_NAMESPACE_ID::int32 v
|
||||
|
||||
// optional int32 max_sentencepiece_length = 20 [default = 16];
|
||||
inline bool TrainerSpec::_internal_has_max_sentencepiece_length() const {
|
||||
bool value = (_has_bits_[0] & 0x08000000u) != 0;
|
||||
bool value = (_has_bits_[0] & 0x04000000u) != 0;
|
||||
return value;
|
||||
}
|
||||
inline bool TrainerSpec::has_max_sentencepiece_length() const {
|
||||
@ -2642,7 +2627,7 @@ inline bool TrainerSpec::has_max_sentencepiece_length() const {
|
||||
}
|
||||
inline void TrainerSpec::clear_max_sentencepiece_length() {
|
||||
max_sentencepiece_length_ = 16;
|
||||
_has_bits_[0] &= ~0x08000000u;
|
||||
_has_bits_[0] &= ~0x04000000u;
|
||||
}
|
||||
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_max_sentencepiece_length() const {
|
||||
return max_sentencepiece_length_;
|
||||
@ -2652,7 +2637,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::max_sentencepiece_length() co
|
||||
return _internal_max_sentencepiece_length();
|
||||
}
|
||||
inline void TrainerSpec::_internal_set_max_sentencepiece_length(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
||||
_has_bits_[0] |= 0x08000000u;
|
||||
_has_bits_[0] |= 0x04000000u;
|
||||
max_sentencepiece_length_ = value;
|
||||
}
|
||||
inline void TrainerSpec::set_max_sentencepiece_length(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
||||
@ -2662,7 +2647,7 @@ inline void TrainerSpec::set_max_sentencepiece_length(::PROTOBUF_NAMESPACE_ID::i
|
||||
|
||||
// optional bool split_by_unicode_script = 21 [default = true];
|
||||
inline bool TrainerSpec::_internal_has_split_by_unicode_script() const {
|
||||
bool value = (_has_bits_[0] & 0x20000000u) != 0;
|
||||
bool value = (_has_bits_[0] & 0x10000000u) != 0;
|
||||
return value;
|
||||
}
|
||||
inline bool TrainerSpec::has_split_by_unicode_script() const {
|
||||
@ -2670,7 +2655,7 @@ inline bool TrainerSpec::has_split_by_unicode_script() const {
|
||||
}
|
||||
inline void TrainerSpec::clear_split_by_unicode_script() {
|
||||
split_by_unicode_script_ = true;
|
||||
_has_bits_[0] &= ~0x20000000u;
|
||||
_has_bits_[0] &= ~0x10000000u;
|
||||
}
|
||||
inline bool TrainerSpec::_internal_split_by_unicode_script() const {
|
||||
return split_by_unicode_script_;
|
||||
@ -2680,7 +2665,7 @@ inline bool TrainerSpec::split_by_unicode_script() const {
|
||||
return _internal_split_by_unicode_script();
|
||||
}
|
||||
inline void TrainerSpec::_internal_set_split_by_unicode_script(bool value) {
|
||||
_has_bits_[0] |= 0x20000000u;
|
||||
_has_bits_[0] |= 0x10000000u;
|
||||
split_by_unicode_script_ = value;
|
||||
}
|
||||
inline void TrainerSpec::set_split_by_unicode_script(bool value) {
|
||||
@ -2690,7 +2675,7 @@ inline void TrainerSpec::set_split_by_unicode_script(bool value) {
|
||||
|
||||
// optional bool split_by_number = 23 [default = true];
|
||||
inline bool TrainerSpec::_internal_has_split_by_number() const {
|
||||
bool value = (_has_bits_[0] & 0x40000000u) != 0;
|
||||
bool value = (_has_bits_[0] & 0x20000000u) != 0;
|
||||
return value;
|
||||
}
|
||||
inline bool TrainerSpec::has_split_by_number() const {
|
||||
@ -2698,7 +2683,7 @@ inline bool TrainerSpec::has_split_by_number() const {
|
||||
}
|
||||
inline void TrainerSpec::clear_split_by_number() {
|
||||
split_by_number_ = true;
|
||||
_has_bits_[0] &= ~0x40000000u;
|
||||
_has_bits_[0] &= ~0x20000000u;
|
||||
}
|
||||
inline bool TrainerSpec::_internal_split_by_number() const {
|
||||
return split_by_number_;
|
||||
@ -2708,7 +2693,7 @@ inline bool TrainerSpec::split_by_number() const {
|
||||
return _internal_split_by_number();
|
||||
}
|
||||
inline void TrainerSpec::_internal_set_split_by_number(bool value) {
|
||||
_has_bits_[0] |= 0x40000000u;
|
||||
_has_bits_[0] |= 0x20000000u;
|
||||
split_by_number_ = value;
|
||||
}
|
||||
inline void TrainerSpec::set_split_by_number(bool value) {
|
||||
@ -2718,7 +2703,7 @@ inline void TrainerSpec::set_split_by_number(bool value) {
|
||||
|
||||
// optional bool split_by_whitespace = 22 [default = true];
|
||||
inline bool TrainerSpec::_internal_has_split_by_whitespace() const {
|
||||
bool value = (_has_bits_[0] & 0x80000000u) != 0;
|
||||
bool value = (_has_bits_[0] & 0x40000000u) != 0;
|
||||
return value;
|
||||
}
|
||||
inline bool TrainerSpec::has_split_by_whitespace() const {
|
||||
@ -2726,7 +2711,7 @@ inline bool TrainerSpec::has_split_by_whitespace() const {
|
||||
}
|
||||
inline void TrainerSpec::clear_split_by_whitespace() {
|
||||
split_by_whitespace_ = true;
|
||||
_has_bits_[0] &= ~0x80000000u;
|
||||
_has_bits_[0] &= ~0x40000000u;
|
||||
}
|
||||
inline bool TrainerSpec::_internal_split_by_whitespace() const {
|
||||
return split_by_whitespace_;
|
||||
@ -2736,7 +2721,7 @@ inline bool TrainerSpec::split_by_whitespace() const {
|
||||
return _internal_split_by_whitespace();
|
||||
}
|
||||
inline void TrainerSpec::_internal_set_split_by_whitespace(bool value) {
|
||||
_has_bits_[0] |= 0x80000000u;
|
||||
_has_bits_[0] |= 0x40000000u;
|
||||
split_by_whitespace_ = value;
|
||||
}
|
||||
inline void TrainerSpec::set_split_by_whitespace(bool value) {
|
||||
@ -2772,37 +2757,9 @@ inline void TrainerSpec::set_treat_whitespace_as_suffix(bool value) {
|
||||
// @@protoc_insertion_point(field_set:sentencepiece.TrainerSpec.treat_whitespace_as_suffix)
|
||||
}
|
||||
|
||||
// optional bool allow_whitespace_only_pieces = 26 [default = false];
|
||||
inline bool TrainerSpec::_internal_has_allow_whitespace_only_pieces() const {
|
||||
bool value = (_has_bits_[0] & 0x00002000u) != 0;
|
||||
return value;
|
||||
}
|
||||
inline bool TrainerSpec::has_allow_whitespace_only_pieces() const {
|
||||
return _internal_has_allow_whitespace_only_pieces();
|
||||
}
|
||||
inline void TrainerSpec::clear_allow_whitespace_only_pieces() {
|
||||
allow_whitespace_only_pieces_ = false;
|
||||
_has_bits_[0] &= ~0x00002000u;
|
||||
}
|
||||
inline bool TrainerSpec::_internal_allow_whitespace_only_pieces() const {
|
||||
return allow_whitespace_only_pieces_;
|
||||
}
|
||||
inline bool TrainerSpec::allow_whitespace_only_pieces() const {
|
||||
// @@protoc_insertion_point(field_get:sentencepiece.TrainerSpec.allow_whitespace_only_pieces)
|
||||
return _internal_allow_whitespace_only_pieces();
|
||||
}
|
||||
inline void TrainerSpec::_internal_set_allow_whitespace_only_pieces(bool value) {
|
||||
_has_bits_[0] |= 0x00002000u;
|
||||
allow_whitespace_only_pieces_ = value;
|
||||
}
|
||||
inline void TrainerSpec::set_allow_whitespace_only_pieces(bool value) {
|
||||
_internal_set_allow_whitespace_only_pieces(value);
|
||||
// @@protoc_insertion_point(field_set:sentencepiece.TrainerSpec.allow_whitespace_only_pieces)
|
||||
}
|
||||
|
||||
// optional bool split_digits = 25 [default = false];
|
||||
inline bool TrainerSpec::_internal_has_split_digits() const {
|
||||
bool value = (_has_bits_[0] & 0x00004000u) != 0;
|
||||
bool value = (_has_bits_[0] & 0x00002000u) != 0;
|
||||
return value;
|
||||
}
|
||||
inline bool TrainerSpec::has_split_digits() const {
|
||||
@ -2810,7 +2767,7 @@ inline bool TrainerSpec::has_split_digits() const {
|
||||
}
|
||||
inline void TrainerSpec::clear_split_digits() {
|
||||
split_digits_ = false;
|
||||
_has_bits_[0] &= ~0x00004000u;
|
||||
_has_bits_[0] &= ~0x00002000u;
|
||||
}
|
||||
inline bool TrainerSpec::_internal_split_digits() const {
|
||||
return split_digits_;
|
||||
@ -2820,7 +2777,7 @@ inline bool TrainerSpec::split_digits() const {
|
||||
return _internal_split_digits();
|
||||
}
|
||||
inline void TrainerSpec::_internal_set_split_digits(bool value) {
|
||||
_has_bits_[0] |= 0x00004000u;
|
||||
_has_bits_[0] |= 0x00002000u;
|
||||
split_digits_ = value;
|
||||
}
|
||||
inline void TrainerSpec::set_split_digits(bool value) {
|
||||
@ -3051,7 +3008,7 @@ inline void TrainerSpec::set_allocated_required_chars(std::string* required_char
|
||||
|
||||
// optional bool byte_fallback = 35 [default = false];
|
||||
inline bool TrainerSpec::_internal_has_byte_fallback() const {
|
||||
bool value = (_has_bits_[0] & 0x00008000u) != 0;
|
||||
bool value = (_has_bits_[0] & 0x00004000u) != 0;
|
||||
return value;
|
||||
}
|
||||
inline bool TrainerSpec::has_byte_fallback() const {
|
||||
@ -3059,7 +3016,7 @@ inline bool TrainerSpec::has_byte_fallback() const {
|
||||
}
|
||||
inline void TrainerSpec::clear_byte_fallback() {
|
||||
byte_fallback_ = false;
|
||||
_has_bits_[0] &= ~0x00008000u;
|
||||
_has_bits_[0] &= ~0x00004000u;
|
||||
}
|
||||
inline bool TrainerSpec::_internal_byte_fallback() const {
|
||||
return byte_fallback_;
|
||||
@ -3069,7 +3026,7 @@ inline bool TrainerSpec::byte_fallback() const {
|
||||
return _internal_byte_fallback();
|
||||
}
|
||||
inline void TrainerSpec::_internal_set_byte_fallback(bool value) {
|
||||
_has_bits_[0] |= 0x00008000u;
|
||||
_has_bits_[0] |= 0x00004000u;
|
||||
byte_fallback_ = value;
|
||||
}
|
||||
inline void TrainerSpec::set_byte_fallback(bool value) {
|
||||
@ -3079,7 +3036,7 @@ inline void TrainerSpec::set_byte_fallback(bool value) {
|
||||
|
||||
// optional bool vocabulary_output_piece_score = 32 [default = true];
|
||||
inline bool TrainerSpec::_internal_has_vocabulary_output_piece_score() const {
|
||||
bool value = (_has_bits_[1] & 0x00000001u) != 0;
|
||||
bool value = (_has_bits_[0] & 0x80000000u) != 0;
|
||||
return value;
|
||||
}
|
||||
inline bool TrainerSpec::has_vocabulary_output_piece_score() const {
|
||||
@ -3087,7 +3044,7 @@ inline bool TrainerSpec::has_vocabulary_output_piece_score() const {
|
||||
}
|
||||
inline void TrainerSpec::clear_vocabulary_output_piece_score() {
|
||||
vocabulary_output_piece_score_ = true;
|
||||
_has_bits_[1] &= ~0x00000001u;
|
||||
_has_bits_[0] &= ~0x80000000u;
|
||||
}
|
||||
inline bool TrainerSpec::_internal_vocabulary_output_piece_score() const {
|
||||
return vocabulary_output_piece_score_;
|
||||
@ -3097,7 +3054,7 @@ inline bool TrainerSpec::vocabulary_output_piece_score() const {
|
||||
return _internal_vocabulary_output_piece_score();
|
||||
}
|
||||
inline void TrainerSpec::_internal_set_vocabulary_output_piece_score(bool value) {
|
||||
_has_bits_[1] |= 0x00000001u;
|
||||
_has_bits_[0] |= 0x80000000u;
|
||||
vocabulary_output_piece_score_ = value;
|
||||
}
|
||||
inline void TrainerSpec::set_vocabulary_output_piece_score(bool value) {
|
||||
@ -3107,7 +3064,7 @@ inline void TrainerSpec::set_vocabulary_output_piece_score(bool value) {
|
||||
|
||||
// optional bool hard_vocab_limit = 33 [default = true];
|
||||
inline bool TrainerSpec::_internal_has_hard_vocab_limit() const {
|
||||
bool value = (_has_bits_[1] & 0x00000002u) != 0;
|
||||
bool value = (_has_bits_[1] & 0x00000001u) != 0;
|
||||
return value;
|
||||
}
|
||||
inline bool TrainerSpec::has_hard_vocab_limit() const {
|
||||
@ -3115,7 +3072,7 @@ inline bool TrainerSpec::has_hard_vocab_limit() const {
|
||||
}
|
||||
inline void TrainerSpec::clear_hard_vocab_limit() {
|
||||
hard_vocab_limit_ = true;
|
||||
_has_bits_[1] &= ~0x00000002u;
|
||||
_has_bits_[1] &= ~0x00000001u;
|
||||
}
|
||||
inline bool TrainerSpec::_internal_hard_vocab_limit() const {
|
||||
return hard_vocab_limit_;
|
||||
@ -3125,7 +3082,7 @@ inline bool TrainerSpec::hard_vocab_limit() const {
|
||||
return _internal_hard_vocab_limit();
|
||||
}
|
||||
inline void TrainerSpec::_internal_set_hard_vocab_limit(bool value) {
|
||||
_has_bits_[1] |= 0x00000002u;
|
||||
_has_bits_[1] |= 0x00000001u;
|
||||
hard_vocab_limit_ = value;
|
||||
}
|
||||
inline void TrainerSpec::set_hard_vocab_limit(bool value) {
|
||||
@ -3135,7 +3092,7 @@ inline void TrainerSpec::set_hard_vocab_limit(bool value) {
|
||||
|
||||
// optional bool use_all_vocab = 34 [default = false];
|
||||
inline bool TrainerSpec::_internal_has_use_all_vocab() const {
|
||||
bool value = (_has_bits_[0] & 0x00020000u) != 0;
|
||||
bool value = (_has_bits_[0] & 0x00008000u) != 0;
|
||||
return value;
|
||||
}
|
||||
inline bool TrainerSpec::has_use_all_vocab() const {
|
||||
@ -3143,7 +3100,7 @@ inline bool TrainerSpec::has_use_all_vocab() const {
|
||||
}
|
||||
inline void TrainerSpec::clear_use_all_vocab() {
|
||||
use_all_vocab_ = false;
|
||||
_has_bits_[0] &= ~0x00020000u;
|
||||
_has_bits_[0] &= ~0x00008000u;
|
||||
}
|
||||
inline bool TrainerSpec::_internal_use_all_vocab() const {
|
||||
return use_all_vocab_;
|
||||
@ -3153,7 +3110,7 @@ inline bool TrainerSpec::use_all_vocab() const {
|
||||
return _internal_use_all_vocab();
|
||||
}
|
||||
inline void TrainerSpec::_internal_set_use_all_vocab(bool value) {
|
||||
_has_bits_[0] |= 0x00020000u;
|
||||
_has_bits_[0] |= 0x00008000u;
|
||||
use_all_vocab_ = value;
|
||||
}
|
||||
inline void TrainerSpec::set_use_all_vocab(bool value) {
|
||||
@ -3191,7 +3148,7 @@ inline void TrainerSpec::set_unk_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
||||
|
||||
// optional int32 bos_id = 41 [default = 1];
|
||||
inline bool TrainerSpec::_internal_has_bos_id() const {
|
||||
bool value = (_has_bits_[1] & 0x00000004u) != 0;
|
||||
bool value = (_has_bits_[1] & 0x00000002u) != 0;
|
||||
return value;
|
||||
}
|
||||
inline bool TrainerSpec::has_bos_id() const {
|
||||
@ -3199,7 +3156,7 @@ inline bool TrainerSpec::has_bos_id() const {
|
||||
}
|
||||
inline void TrainerSpec::clear_bos_id() {
|
||||
bos_id_ = 1;
|
||||
_has_bits_[1] &= ~0x00000004u;
|
||||
_has_bits_[1] &= ~0x00000002u;
|
||||
}
|
||||
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_bos_id() const {
|
||||
return bos_id_;
|
||||
@ -3209,7 +3166,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::bos_id() const {
|
||||
return _internal_bos_id();
|
||||
}
|
||||
inline void TrainerSpec::_internal_set_bos_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
||||
_has_bits_[1] |= 0x00000004u;
|
||||
_has_bits_[1] |= 0x00000002u;
|
||||
bos_id_ = value;
|
||||
}
|
||||
inline void TrainerSpec::set_bos_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
||||
@ -3219,7 +3176,7 @@ inline void TrainerSpec::set_bos_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
||||
|
||||
// optional int32 eos_id = 42 [default = 2];
|
||||
inline bool TrainerSpec::_internal_has_eos_id() const {
|
||||
bool value = (_has_bits_[1] & 0x00000008u) != 0;
|
||||
bool value = (_has_bits_[1] & 0x00000004u) != 0;
|
||||
return value;
|
||||
}
|
||||
inline bool TrainerSpec::has_eos_id() const {
|
||||
@ -3227,7 +3184,7 @@ inline bool TrainerSpec::has_eos_id() const {
|
||||
}
|
||||
inline void TrainerSpec::clear_eos_id() {
|
||||
eos_id_ = 2;
|
||||
_has_bits_[1] &= ~0x00000008u;
|
||||
_has_bits_[1] &= ~0x00000004u;
|
||||
}
|
||||
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_eos_id() const {
|
||||
return eos_id_;
|
||||
@ -3237,7 +3194,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::eos_id() const {
|
||||
return _internal_eos_id();
|
||||
}
|
||||
inline void TrainerSpec::_internal_set_eos_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
||||
_has_bits_[1] |= 0x00000008u;
|
||||
_has_bits_[1] |= 0x00000004u;
|
||||
eos_id_ = value;
|
||||
}
|
||||
inline void TrainerSpec::set_eos_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
||||
@ -3247,7 +3204,7 @@ inline void TrainerSpec::set_eos_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
||||
|
||||
// optional int32 pad_id = 43 [default = -1];
|
||||
inline bool TrainerSpec::_internal_has_pad_id() const {
|
||||
bool value = (_has_bits_[1] & 0x00000010u) != 0;
|
||||
bool value = (_has_bits_[1] & 0x00000008u) != 0;
|
||||
return value;
|
||||
}
|
||||
inline bool TrainerSpec::has_pad_id() const {
|
||||
@ -3255,7 +3212,7 @@ inline bool TrainerSpec::has_pad_id() const {
|
||||
}
|
||||
inline void TrainerSpec::clear_pad_id() {
|
||||
pad_id_ = -1;
|
||||
_has_bits_[1] &= ~0x00000010u;
|
||||
_has_bits_[1] &= ~0x00000008u;
|
||||
}
|
||||
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_pad_id() const {
|
||||
return pad_id_;
|
||||
@ -3265,7 +3222,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::pad_id() const {
|
||||
return _internal_pad_id();
|
||||
}
|
||||
inline void TrainerSpec::_internal_set_pad_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
||||
_has_bits_[1] |= 0x00000010u;
|
||||
_has_bits_[1] |= 0x00000008u;
|
||||
pad_id_ = value;
|
||||
}
|
||||
inline void TrainerSpec::set_pad_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
||||
@ -3645,7 +3602,7 @@ inline void TrainerSpec::set_allocated_unk_surface(std::string* unk_surface) {
|
||||
|
||||
// optional bool train_extremely_large_corpus = 49 [default = false];
|
||||
inline bool TrainerSpec::_internal_has_train_extremely_large_corpus() const {
|
||||
bool value = (_has_bits_[0] & 0x00040000u) != 0;
|
||||
bool value = (_has_bits_[0] & 0x00020000u) != 0;
|
||||
return value;
|
||||
}
|
||||
inline bool TrainerSpec::has_train_extremely_large_corpus() const {
|
||||
@ -3653,7 +3610,7 @@ inline bool TrainerSpec::has_train_extremely_large_corpus() const {
|
||||
}
|
||||
inline void TrainerSpec::clear_train_extremely_large_corpus() {
|
||||
train_extremely_large_corpus_ = false;
|
||||
_has_bits_[0] &= ~0x00040000u;
|
||||
_has_bits_[0] &= ~0x00020000u;
|
||||
}
|
||||
inline bool TrainerSpec::_internal_train_extremely_large_corpus() const {
|
||||
return train_extremely_large_corpus_;
|
||||
@ -3663,7 +3620,7 @@ inline bool TrainerSpec::train_extremely_large_corpus() const {
|
||||
return _internal_train_extremely_large_corpus();
|
||||
}
|
||||
inline void TrainerSpec::_internal_set_train_extremely_large_corpus(bool value) {
|
||||
_has_bits_[0] |= 0x00040000u;
|
||||
_has_bits_[0] |= 0x00020000u;
|
||||
train_extremely_large_corpus_ = value;
|
||||
}
|
||||
inline void TrainerSpec::set_train_extremely_large_corpus(bool value) {
|
||||
|
@ -134,53 +134,32 @@ void ModelInterface::InitializePieces() {
|
||||
}
|
||||
|
||||
std::vector<absl::string_view> SplitIntoWords(absl::string_view text,
|
||||
bool treat_ws_as_suffix,
|
||||
bool allow_ws_only_pieces) {
|
||||
bool treat_whitespace_as_suffix) {
|
||||
const char *begin = text.data();
|
||||
const char *end = text.data() + text.size();
|
||||
|
||||
// Space symbol (U+2581)
|
||||
const absl::string_view kSpaceSymbol = "\xe2\x96\x81";
|
||||
bool in_ws_sequence = false;
|
||||
|
||||
std::vector<absl::string_view> result;
|
||||
if (treat_ws_as_suffix) { // put ws tokens at the end of non-ws sequences.
|
||||
if (treat_whitespace_as_suffix) {
|
||||
if (begin < end) result.emplace_back(begin, 0);
|
||||
while (begin < end) {
|
||||
const int mblen =
|
||||
std::min<int>(string_util::OneCharLen(begin), end - begin);
|
||||
const bool is_ws = absl::string_view(begin, mblen) == kSpaceSymbol;
|
||||
|
||||
if (is_ws) { // keep track of sequences consecutive ws tokens.
|
||||
in_ws_sequence = true;
|
||||
} else if (in_ws_sequence) {
|
||||
if (allow_ws_only_pieces) result.emplace_back(begin, 0);
|
||||
|
||||
in_ws_sequence = false;
|
||||
}
|
||||
|
||||
result.back() =
|
||||
absl::string_view(result.back().data(), result.back().size() + mblen);
|
||||
begin += mblen;
|
||||
|
||||
if (begin < end && is_ws && !allow_ws_only_pieces)
|
||||
result.emplace_back(begin, 0);
|
||||
if (begin < end && is_ws) result.emplace_back(begin, 0);
|
||||
}
|
||||
} else {
|
||||
while (begin < end) {
|
||||
const int mblen =
|
||||
std::min<int>(string_util::OneCharLen(begin), end - begin);
|
||||
bool is_ws = absl::string_view(begin, mblen) == kSpaceSymbol;
|
||||
|
||||
// if is whitespace (and not in sequence if allow_ws_only_pieces is True)
|
||||
if (begin == text.data() ||
|
||||
(is_ws && (!in_ws_sequence || !allow_ws_only_pieces))) {
|
||||
absl::string_view(begin, mblen) == kSpaceSymbol)
|
||||
result.emplace_back(begin, 0); // add empty string piece.
|
||||
in_ws_sequence = true;
|
||||
}
|
||||
|
||||
if (in_ws_sequence && !is_ws) in_ws_sequence = false;
|
||||
|
||||
result.back() =
|
||||
absl::string_view(result.back().data(), result.back().size() + mblen);
|
||||
begin += mblen;
|
||||
|
@ -33,9 +33,8 @@
|
||||
namespace sentencepiece {
|
||||
|
||||
// "_this_is_a_pen" => ["_this", "_is", "_a", "_pen"]
|
||||
std::vector<absl::string_view> SplitIntoWords(
|
||||
absl::string_view text, bool treat_ws_as_suffix = false,
|
||||
bool allow_ws_only_pieces = false);
|
||||
std::vector<absl::string_view> SplitIntoWords(absl::string_view text,
|
||||
bool add_ws_as_suffix = false);
|
||||
|
||||
// Converts byte (0-255) to piece (e.g., 58 -> "<0x3A>").
|
||||
std::string ByteToPiece(unsigned char c);
|
||||
@ -107,42 +106,12 @@ class ModelInterface {
|
||||
return EncodeResult();
|
||||
}
|
||||
|
||||
// Sample `samples` many tokenisations from the segmentation lattice
|
||||
// If `wor` is true, the samples are taken without replacement, and the scores
|
||||
// are the inclusion probabilities of the elements in the sample; otherwise
|
||||
// the samples are taken with replacement and the scores are the log-probs of
|
||||
// sample elements
|
||||
// If `include_best` is true, the best tokenisation is always included in the
|
||||
// sample, and the remaining elements are sampled excluding the best.
|
||||
virtual NBestEncodeResult SampleEncodeAndScore(absl::string_view normalized,
|
||||
float alpha, int samples,
|
||||
bool wor,
|
||||
bool include_best) const {
|
||||
LOG(ERROR) << "Not implemented.";
|
||||
return {{EncodeResult(), 0.0}};
|
||||
}
|
||||
|
||||
// Calculates the entropy of the segmentation lattice with inverse temperature
|
||||
// `theta`.
|
||||
// Uses a novel dynamic program to calculate the entropy.
|
||||
virtual float CalculateEntropy(absl::string_view normalized,
|
||||
float theta) const {
|
||||
LOG(ERROR) << "Not implemented.";
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
// Return true if SampleEncode returns a valid result.
|
||||
virtual bool IsSampleEncodeAvailable() const { return false; }
|
||||
|
||||
// Return true if NBestEncode returns a valid result.
|
||||
virtual bool IsNBestEncodeAvailable() const { return false; }
|
||||
|
||||
// Return true if SampleEncodeAndScore returns a valid result.
|
||||
virtual bool IsSampleEncodeAndScoreAvailable() const { return false; }
|
||||
|
||||
// Return true if CalculateEntropy returns a valid result.
|
||||
virtual bool IsCalculateEntropyAvailable() const { return false; }
|
||||
|
||||
// Returns the vocab id of `piece`.
|
||||
// Returns UNK(0) if `piece` is unknown
|
||||
virtual int PieceToId(absl::string_view piece) const;
|
||||
@ -155,10 +124,7 @@ class ModelInterface {
|
||||
|
||||
// Returns the size of sentence pieces, which is the same
|
||||
// as the size of vocabulary for NMT.
|
||||
virtual int GetPieceSize() const {
|
||||
if (!model_proto_) return 0;
|
||||
return model_proto_->pieces_size();
|
||||
}
|
||||
virtual int GetPieceSize() const { return model_proto_->pieces_size(); }
|
||||
|
||||
// Returns the score of `id`.
|
||||
// Score represents a log probability of the piece.
|
||||
|
@ -412,50 +412,6 @@ TEST(ModelInterfaceTest, SplitIntoWordsSuffixTest) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ModelInterfaceTest, SplitIntoWordsWhiteSpaceOnly) {
|
||||
{
|
||||
const auto v =
|
||||
SplitIntoWords("this" WS "is" WS "a" WS "pen" WS, true, true);
|
||||
EXPECT_EQ(4, v.size());
|
||||
EXPECT_EQ("this" WS, v[0]);
|
||||
EXPECT_EQ("is" WS, v[1]);
|
||||
EXPECT_EQ("a" WS, v[2]);
|
||||
EXPECT_EQ("pen" WS, v[3]);
|
||||
}
|
||||
|
||||
{
|
||||
const auto v = SplitIntoWords(WS WS WS "a", false, true);
|
||||
EXPECT_EQ(1, v.size());
|
||||
EXPECT_EQ(WS WS WS "a", v[0]);
|
||||
}
|
||||
|
||||
{
|
||||
const auto v = SplitIntoWords("a" WS WS WS, true, true);
|
||||
EXPECT_EQ(1, v.size());
|
||||
EXPECT_EQ("a" WS WS WS, v[0]);
|
||||
}
|
||||
|
||||
{
|
||||
const auto v = SplitIntoWords(WS WS, true, true);
|
||||
EXPECT_EQ(1, v.size());
|
||||
EXPECT_EQ(WS WS, v[0]);
|
||||
}
|
||||
|
||||
{
|
||||
const auto v = SplitIntoWords(WS WS "a" WS, true, true);
|
||||
EXPECT_EQ(2, v.size());
|
||||
EXPECT_EQ(WS WS, v[0]);
|
||||
EXPECT_EQ("a" WS, v[1]);
|
||||
}
|
||||
|
||||
{
|
||||
const auto v = SplitIntoWords(WS WS "a" WS, false, true);
|
||||
EXPECT_EQ(2, v.size());
|
||||
EXPECT_EQ(WS WS "a", v[0]);
|
||||
EXPECT_EQ(WS, v[1]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ModelInterfaceTest, ByteToPieceTest) {
|
||||
EXPECT_EQ(ByteToPiece(0), "<0x00>");
|
||||
EXPECT_EQ(ByteToPiece(1), "<0x01>");
|
||||
|
@ -12,11 +12,12 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.!
|
||||
|
||||
#include "normalizer.h"
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "common.h"
|
||||
#include "normalizer.h"
|
||||
#include "third_party/absl/memory/memory.h"
|
||||
#include "third_party/absl/strings/match.h"
|
||||
#include "third_party/absl/strings/string_view.h"
|
||||
@ -277,11 +278,11 @@ util::Status Normalizer::DecodePrecompiledCharsMap(
|
||||
absl::string_view blob, absl::string_view *trie_blob,
|
||||
absl::string_view *normalized, std::string *buffer) {
|
||||
uint32 trie_blob_size = 0;
|
||||
|
||||
if (blob.size() <= sizeof(trie_blob_size) ||
|
||||
!string_util::DecodePOD<uint32>(
|
||||
absl::string_view(blob.data(), sizeof(trie_blob_size)),
|
||||
&trie_blob_size) ||
|
||||
trie_blob_size >= blob.size()) {
|
||||
&trie_blob_size)) {
|
||||
return util::InternalError("Blob for normalization rule is broken.");
|
||||
}
|
||||
|
||||
|
@ -22,6 +22,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "common.h"
|
||||
#include "util.h"
|
||||
#include "sentencepiece_model.pb.h"
|
||||
#include "sentencepiece_processor.h"
|
||||
#include "third_party/absl/strings/string_view.h"
|
||||
|
@ -12,10 +12,11 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.!
|
||||
|
||||
#include "normalizer.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "builder.h"
|
||||
#include "normalizer.h"
|
||||
#include "sentencepiece_trainer.h"
|
||||
#include "testharness.h"
|
||||
#include "util.h"
|
||||
|
@ -139,10 +139,6 @@ message TrainerSpec {
|
||||
// of sentence.
|
||||
optional bool treat_whitespace_as_suffix = 24 [default = false];
|
||||
|
||||
// Allows pieces that only contain whitespaces instead of appearing only as
|
||||
// prefix or suffix of other pieces.
|
||||
optional bool allow_whitespace_only_pieces = 26 [default = false];
|
||||
|
||||
// Split all digits (0-9) into separate pieces.
|
||||
optional bool split_digits = 25 [default = false];
|
||||
|
||||
|
@ -12,6 +12,8 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.!
|
||||
|
||||
#include "sentencepiece_processor.h"
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <utility>
|
||||
@ -22,7 +24,6 @@
|
||||
#include "model_interface.h"
|
||||
#include "normalizer.h"
|
||||
#include "sentencepiece.pb.h"
|
||||
#include "sentencepiece_processor.h"
|
||||
#include "third_party/absl/memory/memory.h"
|
||||
#include "third_party/absl/strings/numbers.h"
|
||||
#include "third_party/absl/strings/str_cat.h"
|
||||
@ -503,43 +504,6 @@ util::Status SentencePieceProcessor::SampleEncode(
|
||||
return util::OkStatus();
|
||||
}
|
||||
|
||||
util::Status SentencePieceProcessor::SampleEncodeAndScore(
|
||||
absl::string_view input, int samples, float theta, bool wor,
|
||||
bool include_best, NBestSentencePieceText *samples_spt) const {
|
||||
CHECK_OR_RETURN(model_->IsSampleEncodeAndScoreAvailable())
|
||||
<< "SampleEncodeAndScore is not available for the current model.";
|
||||
std::string normalized;
|
||||
std::vector<size_t> norm_to_orig;
|
||||
RETURN_IF_ERROR(normalizer_->Normalize(input, &normalized, &norm_to_orig));
|
||||
|
||||
const auto results = model_->SampleEncodeAndScore(normalized, theta, samples,
|
||||
wor, include_best);
|
||||
CHECK_OR_RETURN(!results.empty())
|
||||
<< "SampleEncodeAndScore returns empty result.";
|
||||
|
||||
for (const auto &result : results) {
|
||||
auto *spt = samples_spt->add_nbests();
|
||||
spt->set_score(result.second);
|
||||
RETURN_IF_ERROR(PopulateSentencePieceText(input, normalized, norm_to_orig,
|
||||
result.first, spt));
|
||||
}
|
||||
|
||||
return util::OkStatus();
|
||||
}
|
||||
|
||||
util::Status SentencePieceProcessor::CalculateEntropy(absl::string_view input,
|
||||
float theta,
|
||||
float *entropy) const {
|
||||
CHECK_OR_RETURN(model_->IsCalculateEntropyAvailable())
|
||||
<< "CalculateEntropy is not available for the current model.";
|
||||
std::string normalized;
|
||||
std::vector<size_t> norm_to_orig;
|
||||
RETURN_IF_ERROR(normalizer_->Normalize(input, &normalized, &norm_to_orig));
|
||||
|
||||
*entropy = model_->CalculateEntropy(normalized, theta);
|
||||
return util::OkStatus();
|
||||
}
|
||||
|
||||
util::Status SentencePieceProcessor::Decode(
|
||||
const std::vector<std::string> &pieces, SentencePieceText *spt) const {
|
||||
CHECK_OR_RETURN_STATUS_PROTO(spt);
|
||||
@ -869,12 +833,6 @@ std::string SentencePieceProcessor::serialized_model_proto() const {
|
||||
return model_proto_ ? model_proto_->SerializeAsString() : "";
|
||||
}
|
||||
|
||||
// Set seed value of random generator.
|
||||
// Do not set static_cast<unique_int>(-1),
|
||||
// as this seed is reserved for initializing from
|
||||
// std::random_device.
|
||||
void SetRandomGeneratorSeed(unsigned int seed);
|
||||
|
||||
namespace io {
|
||||
|
||||
util::Status LoadModelProto(absl::string_view filename,
|
||||
|
@ -315,15 +315,6 @@ class SentencePieceProcessor {
|
||||
virtual util::Status SampleEncode(absl::string_view input, int nbest_size,
|
||||
float alpha, SentencePieceText *spt) const;
|
||||
|
||||
// Samples N segmentation and returns the scores as well
|
||||
virtual util::Status SampleEncodeAndScore(
|
||||
absl::string_view input, int samples, float theta, bool wor,
|
||||
bool include_best, NBestSentencePieceText *samples_spt) const;
|
||||
|
||||
// Calculate entropy of possible tokenisations
|
||||
virtual util::Status CalculateEntropy(absl::string_view input, float theta,
|
||||
float *entropy) const;
|
||||
|
||||
// Given a sequence of pieces, decodes it into SentencePieceText.
|
||||
virtual util::Status Decode(const std::vector<std::string> &pieces,
|
||||
SentencePieceText *spt) const;
|
||||
|
@ -12,6 +12,8 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.!
|
||||
|
||||
#include "sentencepiece_processor.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "builder.h"
|
||||
@ -20,7 +22,6 @@
|
||||
#include "normalizer.h"
|
||||
#include "sentencepiece.pb.h"
|
||||
#include "sentencepiece_model.pb.h"
|
||||
#include "sentencepiece_processor.h"
|
||||
#include "sentencepiece_trainer.h"
|
||||
#include "testharness.h"
|
||||
#include "third_party/absl/container/flat_hash_map.h"
|
||||
@ -1138,6 +1139,13 @@ TEST(SentencePieceProcessorTest, EndToEndTest) {
|
||||
EXPECT_EQ("cba", output);
|
||||
}
|
||||
|
||||
// Out of range
|
||||
{
|
||||
std::string output;
|
||||
const std::vector<int> ids = {3, 4, 127};
|
||||
EXPECT_FALSE(sp.Decode(ids, &output).ok());
|
||||
}
|
||||
|
||||
{
|
||||
EXPECT_TRUE(sp.SetDecodeExtraOptions("bos:eos").ok());
|
||||
|
||||
@ -1164,13 +1172,6 @@ TEST(SentencePieceProcessorTest, EndToEndTest) {
|
||||
EXPECT_EQ("cba", output);
|
||||
}
|
||||
|
||||
// Out of range
|
||||
{
|
||||
std::string output;
|
||||
const std::vector<int> ids = {3, 4, 127};
|
||||
EXPECT_FALSE(sp.Decode(ids, &output).ok());
|
||||
}
|
||||
|
||||
{
|
||||
EXPECT_TRUE(sp.SetDecodeExtraOptions("bos:eos:reverse").ok());
|
||||
|
||||
|
@ -12,6 +12,8 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.!
|
||||
|
||||
#include "sentencepiece_trainer.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
@ -20,9 +22,7 @@
|
||||
#include "normalizer.h"
|
||||
#include "sentencepiece.pb.h"
|
||||
#include "sentencepiece_model.pb.h"
|
||||
#include "sentencepiece_trainer.h"
|
||||
#include "spec_parser.h"
|
||||
#include "third_party/absl/flags/flag.h"
|
||||
#include "third_party/absl/strings/numbers.h"
|
||||
#include "third_party/absl/strings/str_cat.h"
|
||||
#include "third_party/absl/strings/str_split.h"
|
||||
@ -31,8 +31,6 @@
|
||||
#include "trainer_factory.h"
|
||||
#include "util.h"
|
||||
|
||||
ABSL_DECLARE_FLAG(int, minloglevel);
|
||||
|
||||
namespace sentencepiece {
|
||||
namespace {
|
||||
static constexpr char kDefaultNormalizerName[] = "nmt_nfkc";
|
||||
@ -112,7 +110,7 @@ util::Status SentencePieceTrainer::MergeSpecsFromArgs(
|
||||
for (auto arg : absl::StrSplit(args, " ")) {
|
||||
absl::ConsumePrefix(&arg, "--");
|
||||
std::string key, value;
|
||||
const auto pos = arg.find('=');
|
||||
const auto pos = arg.find("=");
|
||||
if (pos == absl::string_view::npos) {
|
||||
key = std::string(arg);
|
||||
} else {
|
||||
@ -151,7 +149,7 @@ util::Status SentencePieceTrainer::MergeSpecsFromArgs(
|
||||
} else if (key == "minloglevel") {
|
||||
int v = 0;
|
||||
CHECK_OR_RETURN(absl::SimpleAtoi(value, &v));
|
||||
absl::SetFlag(&FLAGS_minloglevel, v);
|
||||
logging::SetMinLogLevel(v);
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -145,7 +145,6 @@ inline std::string PrintProto(const TrainerSpec &message,
|
||||
PRINT_PARAM(split_by_whitespace);
|
||||
PRINT_PARAM(split_digits);
|
||||
PRINT_PARAM(treat_whitespace_as_suffix);
|
||||
PRINT_PARAM(allow_whitespace_only_pieces);
|
||||
PRINT_REPEATED_STRING(control_symbols);
|
||||
PRINT_REPEATED_STRING(user_defined_symbols);
|
||||
PRINT_PARAM(required_chars);
|
||||
@ -220,7 +219,6 @@ util::Status SentencePieceTrainer::SetProtoField(const std::string &name,
|
||||
PARSE_BOOL(split_by_whitespace);
|
||||
PARSE_BOOL(split_digits);
|
||||
PARSE_BOOL(treat_whitespace_as_suffix);
|
||||
PARSE_BOOL(allow_whitespace_only_pieces);
|
||||
PARSE_REPEATED_STRING(control_symbols);
|
||||
PARSE_REPEATED_STRING(user_defined_symbols);
|
||||
PARSE_STRING(required_chars);
|
||||
|
@ -64,7 +64,6 @@ int main(int argc, char *argv[]) {
|
||||
|
||||
auto ToIds = [&](const std::vector<std::string> &pieces) {
|
||||
std::vector<int> ids;
|
||||
ids.reserve(pieces.size());
|
||||
for (const auto &s : pieces) {
|
||||
ids.push_back(atoi(s.c_str()));
|
||||
}
|
||||
|
@ -28,17 +28,16 @@
|
||||
#include "trainer_interface.h"
|
||||
|
||||
ABSL_FLAG(std::string, model, "", "model file name");
|
||||
ABSL_FLAG(
|
||||
std::string, output_format, "piece",
|
||||
"choose from piece, id, proto, nbest_piece, nbest_id, or nbest_proto");
|
||||
ABSL_FLAG(std::string, output_format, "piece",
|
||||
"choose from piece, id, proto, nbest_piece, nbest_id, nbest_proto, "
|
||||
"sample_piece, sample_id or sample_proto.");
|
||||
ABSL_FLAG(std::string, input, "", "input filename");
|
||||
ABSL_FLAG(std::string, output, "", "output filename");
|
||||
ABSL_FLAG(std::string, extra_options, "",
|
||||
"':' separated encoder extra options, e.g., \"reverse:bos:eos\"");
|
||||
ABSL_FLAG(int32, nbest_size, 10, "NBest size");
|
||||
ABSL_FLAG(double, alpha, 0.5, "Smoothing parameter for sampling mode.");
|
||||
ABSL_FLAG(uint32, random_seed, static_cast<uint32>(-1),
|
||||
"Seed value for random generator.");
|
||||
ABSL_FLAG(int32, random_seed, -1, "Seed value for random generator.");
|
||||
|
||||
// Piece restriction with vocabulary file.
|
||||
// https://github.com/rsennrich/subword-nmt#best-practice-advice-for-byte-pair-encoding-in-nmt
|
||||
@ -62,9 +61,8 @@ int main(int argc, char *argv[]) {
|
||||
rest_args.push_back(absl::GetFlag(FLAGS_input));
|
||||
}
|
||||
|
||||
if (absl::GetFlag(FLAGS_random_seed) != -1) {
|
||||
if (absl::GetFlag(FLAGS_random_seed) != -1)
|
||||
sentencepiece::SetRandomGeneratorSeed(absl::GetFlag(FLAGS_random_seed));
|
||||
}
|
||||
|
||||
if (rest_args.empty())
|
||||
rest_args.push_back(""); // empty means that reading from stdin.
|
||||
|
@ -80,9 +80,6 @@ ABSL_FLAG(bool, split_digits, kDefaultTrainerSpec.split_digits(),
|
||||
ABSL_FLAG(bool, treat_whitespace_as_suffix,
|
||||
kDefaultTrainerSpec.treat_whitespace_as_suffix(),
|
||||
"treat whitespace marker as suffix instead of prefix.");
|
||||
ABSL_FLAG(bool, allow_whitespace_only_pieces,
|
||||
kDefaultTrainerSpec.allow_whitespace_only_pieces(),
|
||||
"allow pieces that only contain (consecutive) whitespace tokens");
|
||||
ABSL_FLAG(std::string, control_symbols, "",
|
||||
"comma separated list of control symbols");
|
||||
ABSL_FLAG(std::string, control_symbols_file, "",
|
||||
@ -141,8 +138,7 @@ ABSL_FLAG(std::string, unk_surface, kDefaultTrainerSpec.unk_surface(),
|
||||
ABSL_FLAG(bool, train_extremely_large_corpus,
|
||||
kDefaultTrainerSpec.train_extremely_large_corpus(),
|
||||
"Increase bit depth for unigram tokenization.");
|
||||
ABSL_FLAG(uint32, random_seed, static_cast<uint32>(-1),
|
||||
"Seed value for random generator.");
|
||||
ABSL_FLAG(int32, random_seed, -1, "Seed value for random generator.");
|
||||
|
||||
int main(int argc, char *argv[]) {
|
||||
sentencepiece::ParseCommandLineFlags(argv[0], &argc, &argv, true);
|
||||
@ -154,9 +150,8 @@ int main(int argc, char *argv[]) {
|
||||
CHECK(!absl::GetFlag(FLAGS_input).empty());
|
||||
CHECK(!absl::GetFlag(FLAGS_model_prefix).empty());
|
||||
|
||||
if (absl::GetFlag(FLAGS_random_seed) != -1) {
|
||||
if (absl::GetFlag(FLAGS_random_seed) != -1)
|
||||
sentencepiece::SetRandomGeneratorSeed(absl::GetFlag(FLAGS_random_seed));
|
||||
}
|
||||
|
||||
auto load_lines = [](absl::string_view filename) {
|
||||
std::vector<std::string> lines;
|
||||
@ -216,7 +211,6 @@ int main(int argc, char *argv[]) {
|
||||
SetTrainerSpecFromFlag(split_digits);
|
||||
SetTrainerSpecFromFlag(byte_fallback);
|
||||
SetTrainerSpecFromFlag(treat_whitespace_as_suffix);
|
||||
SetTrainerSpecFromFlag(allow_whitespace_only_pieces);
|
||||
SetTrainerSpecFromFlag(hard_vocab_limit);
|
||||
SetTrainerSpecFromFlag(use_all_vocab);
|
||||
SetTrainerSpecFromFlag(unk_id);
|
||||
|
@ -12,7 +12,8 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.!
|
||||
|
||||
#include <algorithm>
|
||||
#include "trainer_interface.h"
|
||||
|
||||
#include <cstdlib>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
@ -33,7 +34,6 @@
|
||||
#include "third_party/absl/strings/str_format.h"
|
||||
#include "third_party/absl/strings/str_join.h"
|
||||
#include "third_party/absl/strings/str_split.h"
|
||||
#include "trainer_interface.h"
|
||||
#include "unicode_script.h"
|
||||
#include "util.h"
|
||||
|
||||
@ -86,10 +86,6 @@ util::Status VerifySpec(const TrainerSpec &trainer_spec) {
|
||||
return util::OkStatus();
|
||||
}
|
||||
|
||||
bool is_unicode_decimal_number(char32 c) {
|
||||
return (c >= 0x30 && c <= 0x39) || (c >= 0xff10 && c <= 0xff19);
|
||||
}
|
||||
|
||||
class SentenceSelector {
|
||||
public:
|
||||
using Sampler = random::ReservoirSampler<TrainerInterface::Sentence>;
|
||||
@ -214,10 +210,9 @@ bool TrainerInterface::IsValidSentencePiece(
|
||||
constexpr unicode_script::ScriptType kAnyType =
|
||||
static_cast<unicode_script::ScriptType>(-1);
|
||||
|
||||
auto is_number = [](char32 c) { return (c >= 0x30 && c <= 0x39); };
|
||||
|
||||
unicode_script::ScriptType prev_script = kAnyType;
|
||||
bool all_whitespace_piece =
|
||||
std::all_of(sentencepiece.begin(), sentencepiece.end(),
|
||||
[](char32 c) { return c == kWSChar; });
|
||||
|
||||
for (size_t pos = 0; pos < sentencepiece.size(); ++pos) {
|
||||
const char32 c = sentencepiece[pos];
|
||||
@ -240,30 +235,25 @@ bool TrainerInterface::IsValidSentencePiece(
|
||||
}
|
||||
|
||||
if (c == kWSChar) {
|
||||
// Only allows whitespace to appear as a prefix of piece unless
|
||||
// allow_whitespace_only_pieces is True.
|
||||
// Only allows whitespace to appear as a prefix of piece.
|
||||
// When split_by_whitespace is false, we allow whitespaces to
|
||||
// appear in the middle, "foo_bar", but do not allow them
|
||||
// to appear as suffix, "foo_bar_".
|
||||
// Regardless of the setting of split_by_whitespace,
|
||||
// whitespace is treated as a prefix/infix of symbol or
|
||||
// independent symbol, unless allow_whitespace_only_pieces() is true,
|
||||
// in which case whitespace only pieces can occur.
|
||||
if (!trainer_spec_.allow_whitespace_only_pieces() ||
|
||||
!all_whitespace_piece) {
|
||||
if (trainer_spec_.treat_whitespace_as_suffix()) {
|
||||
if ((trainer_spec_.split_by_whitespace() &&
|
||||
pos < sentencepiece.size() - 1) ||
|
||||
(!trainer_spec_.split_by_whitespace() &&
|
||||
pos < sentencepiece.size() - 1 && pos == 0)) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if ((trainer_spec_.split_by_whitespace() && pos > 0) ||
|
||||
(!trainer_spec_.split_by_whitespace() && pos > 0 &&
|
||||
pos == sentencepiece.size() - 1)) {
|
||||
return false;
|
||||
}
|
||||
// independent symbol.
|
||||
if (trainer_spec_.treat_whitespace_as_suffix()) {
|
||||
if ((trainer_spec_.split_by_whitespace() &&
|
||||
pos < sentencepiece.size() - 1) ||
|
||||
(!trainer_spec_.split_by_whitespace() &&
|
||||
pos < sentencepiece.size() - 1 && pos == 0)) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if ((trainer_spec_.split_by_whitespace() && pos > 0) ||
|
||||
(!trainer_spec_.split_by_whitespace() && pos > 0 &&
|
||||
pos == sentencepiece.size() - 1)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@ -275,11 +265,11 @@ bool TrainerInterface::IsValidSentencePiece(
|
||||
s = unicode_script::U_Han;
|
||||
}
|
||||
|
||||
if (!trainer_spec_.split_by_number() && is_unicode_decimal_number(c)) {
|
||||
if (!trainer_spec_.split_by_number() && is_number(c)) {
|
||||
s = kAnyType;
|
||||
}
|
||||
|
||||
if (trainer_spec_.split_digits() && is_unicode_decimal_number(c)) {
|
||||
if (trainer_spec_.split_digits() && is_number(c)) {
|
||||
if (sentencepiece.size() > 1) return false;
|
||||
}
|
||||
|
||||
@ -528,8 +518,7 @@ void TrainerInterface::SplitSentencesByWhitespace() {
|
||||
absl::flat_hash_map<std::string, int64> tokens;
|
||||
for (const auto &s : sentences_) {
|
||||
for (const auto &w :
|
||||
SplitIntoWords(s.first, trainer_spec_.treat_whitespace_as_suffix(),
|
||||
trainer_spec_.allow_whitespace_only_pieces())) {
|
||||
SplitIntoWords(s.first, trainer_spec_.treat_whitespace_as_suffix())) {
|
||||
tokens[std::string(w)] += s.second;
|
||||
}
|
||||
}
|
||||
|
@ -81,7 +81,6 @@ TEST(TrainerInterfaceTest, IsValidSentencePieceTest) {
|
||||
|
||||
trainer_spec.set_split_by_whitespace(false);
|
||||
EXPECT_TRUE(IsValid(WS));
|
||||
EXPECT_TRUE(IsValid(WS WS WS "a"));
|
||||
EXPECT_TRUE(IsValid(WS "a"));
|
||||
EXPECT_FALSE(IsValid("a" WS));
|
||||
EXPECT_FALSE(IsValid(WS "a" WS));
|
||||
@ -89,17 +88,7 @@ TEST(TrainerInterfaceTest, IsValidSentencePieceTest) {
|
||||
EXPECT_TRUE(IsValid(WS "a" WS "b"));
|
||||
EXPECT_TRUE(IsValid(WS "a" WS "b" WS "c"));
|
||||
EXPECT_FALSE(IsValid("a" WS "b" WS));
|
||||
EXPECT_FALSE(IsValid(WS WS));
|
||||
EXPECT_FALSE(IsValid(WS WS WS));
|
||||
|
||||
trainer_spec.set_allow_whitespace_only_pieces(true);
|
||||
EXPECT_TRUE(IsValid(WS));
|
||||
EXPECT_TRUE(IsValid(WS WS));
|
||||
EXPECT_TRUE(IsValid(WS WS WS));
|
||||
EXPECT_TRUE(IsValid(WS WS "a"));
|
||||
EXPECT_FALSE(IsValid("a" WS WS)); // suffix whitespace illegal without flag
|
||||
|
||||
trainer_spec.set_allow_whitespace_only_pieces(false);
|
||||
trainer_spec.set_split_by_unicode_script(false);
|
||||
EXPECT_TRUE(IsValid("あいう"));
|
||||
EXPECT_TRUE(IsValid("グーグル"));
|
||||
@ -135,15 +124,6 @@ TEST(TrainerInterfaceTest, IsValidSentencePieceTest) {
|
||||
EXPECT_FALSE(IsValid(WS "a" WS "b"));
|
||||
EXPECT_FALSE(IsValid("a" WS "b" WS));
|
||||
|
||||
trainer_spec.set_allow_whitespace_only_pieces(true);
|
||||
EXPECT_TRUE(IsValid(WS));
|
||||
EXPECT_TRUE(IsValid(WS WS));
|
||||
EXPECT_FALSE(IsValid(WS "a" WS));
|
||||
EXPECT_FALSE(IsValid("a" WS "b"));
|
||||
EXPECT_FALSE(IsValid(WS "a" WS "b"));
|
||||
EXPECT_FALSE(IsValid("a" WS "b" WS));
|
||||
|
||||
trainer_spec.set_allow_whitespace_only_pieces(false);
|
||||
trainer_spec.set_split_by_whitespace(false);
|
||||
EXPECT_TRUE(IsValid(WS));
|
||||
EXPECT_FALSE(IsValid(WS "a"));
|
||||
@ -166,12 +146,6 @@ TEST(TrainerInterfaceTest, IsValidSentencePieceTest) {
|
||||
EXPECT_FALSE(IsValid("2007"));
|
||||
EXPECT_FALSE(IsValid("x1"));
|
||||
EXPECT_FALSE(IsValid("2x"));
|
||||
// Fullwidth digits.
|
||||
EXPECT_TRUE(IsValid("1"));
|
||||
EXPECT_FALSE(IsValid("59"));
|
||||
EXPECT_FALSE(IsValid("2007"));
|
||||
EXPECT_FALSE(IsValid("*1"));
|
||||
EXPECT_FALSE(IsValid("2*"));
|
||||
}
|
||||
|
||||
TEST(TrainerInterfaceTest, OverrideSpecialPiecesTest) {
|
||||
|
@ -15,7 +15,6 @@
|
||||
#include <algorithm>
|
||||
#include <cfloat>
|
||||
#include <cmath>
|
||||
#include <complex>
|
||||
#include <map>
|
||||
#include <queue>
|
||||
#include <string>
|
||||
@ -56,17 +55,6 @@ inline float LogSumExp(float x, float y, bool init_mode) {
|
||||
return vmax + log(std::exp(static_cast<double>(vmin - vmax)) + 1.0);
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a sample from a standard Gumbel distribution.
|
||||
// If U ~ U[0, 1], -log(-log U) ~ G(0,1)
|
||||
inline float Gumbel() {
|
||||
const float kEpsilon = 1e-7;
|
||||
auto *mt = random::GetRandomGenerator();
|
||||
std::uniform_real_distribution<float> dis(0.0, 1.0);
|
||||
float noise = -std::log(-(std::log(dis(*mt) + kEpsilon)));
|
||||
|
||||
return noise;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
Lattice::Lattice() : node_allocator_(kPreallocateLatticeNodeSize) {}
|
||||
@ -157,7 +145,7 @@ Lattice::Node *Lattice::Insert(int pos, int length) {
|
||||
return node;
|
||||
}
|
||||
|
||||
Lattice::LatticePathWithScore Lattice::Viterbi() {
|
||||
std::vector<Lattice::Node *> Lattice::Viterbi() {
|
||||
const int len = size();
|
||||
|
||||
for (int pos = 0; pos <= len; ++pos) {
|
||||
@ -183,7 +171,6 @@ Lattice::LatticePathWithScore Lattice::Viterbi() {
|
||||
|
||||
// backtrace
|
||||
std::vector<Node *> results;
|
||||
float score = begin_nodes(len)[0]->backtrace_score;
|
||||
for (Node *node = begin_nodes_[len][0]->prev; node->prev != nullptr;
|
||||
node = node->prev) {
|
||||
results.push_back(node);
|
||||
@ -191,43 +178,7 @@ Lattice::LatticePathWithScore Lattice::Viterbi() {
|
||||
|
||||
std::reverse(results.begin(), results.end());
|
||||
|
||||
LatticePathWithScore retval = {results, score};
|
||||
|
||||
return retval;
|
||||
}
|
||||
|
||||
std::vector<float> Lattice::ForwardAlgorithm(float theta) const {
|
||||
const int len = size();
|
||||
std::vector<float> alpha(node_allocator_.size(), 0.0);
|
||||
|
||||
for (int pos = 0; pos <= len; ++pos) {
|
||||
for (Node *rnode : begin_nodes_[pos]) {
|
||||
for (Node *lnode : end_nodes_[pos]) {
|
||||
alpha[rnode->node_id] = LogSumExp(
|
||||
alpha[rnode->node_id], theta * lnode->score + alpha[lnode->node_id],
|
||||
lnode == end_nodes_[pos][0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return alpha;
|
||||
}
|
||||
|
||||
std::vector<float> Lattice::BackwardAlgorithm(float theta) const {
|
||||
const int len = size();
|
||||
std::vector<float> beta(node_allocator_.size(), 0.0);
|
||||
|
||||
for (int pos = len; pos >= 0; --pos) {
|
||||
for (Node *lnode : end_nodes_[pos]) {
|
||||
for (Node *rnode : begin_nodes_[pos]) {
|
||||
beta[lnode->node_id] =
|
||||
LogSumExp(beta[lnode->node_id], rnode->score + beta[rnode->node_id],
|
||||
rnode == begin_nodes_[pos][0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return beta;
|
||||
return results;
|
||||
}
|
||||
|
||||
float Lattice::PopulateMarginal(float freq,
|
||||
@ -238,9 +189,28 @@ float Lattice::PopulateMarginal(float freq,
|
||||
|
||||
// alpha and beta (accumulative log prob) in Forward Backward.
|
||||
// the index of alpha/beta is Node::node_id.
|
||||
std::vector<float> alpha(node_allocator_.size(), 0.0);
|
||||
std::vector<float> beta(node_allocator_.size(), 0.0);
|
||||
|
||||
const auto alpha = ForwardAlgorithm(1.0);
|
||||
const auto beta = BackwardAlgorithm(1.0);
|
||||
for (int pos = 0; pos <= len; ++pos) {
|
||||
for (Node *rnode : begin_nodes_[pos]) {
|
||||
for (Node *lnode : end_nodes_[pos]) {
|
||||
alpha[rnode->node_id] = LogSumExp(alpha[rnode->node_id],
|
||||
lnode->score + alpha[lnode->node_id],
|
||||
lnode == end_nodes_[pos][0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int pos = len; pos >= 0; --pos) {
|
||||
for (Node *lnode : end_nodes_[pos]) {
|
||||
for (Node *rnode : begin_nodes_[pos]) {
|
||||
beta[lnode->node_id] =
|
||||
LogSumExp(beta[lnode->node_id], rnode->score + beta[rnode->node_id],
|
||||
rnode == begin_nodes_[pos][0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const float Z = alpha[begin_nodes_[len][0]->node_id];
|
||||
for (int pos = 0; pos < len; ++pos) {
|
||||
@ -258,46 +228,13 @@ float Lattice::PopulateMarginal(float freq,
|
||||
return freq * Z;
|
||||
}
|
||||
|
||||
float Lattice::CalculateEntropy(float theta) const {
|
||||
const int len = size();
|
||||
|
||||
// alpha[node_id] is the marginal prob of sequence up to start of node
|
||||
// H is entropy of sequence
|
||||
// the index of alpha/H is Node::node_id.
|
||||
std::vector<float> alpha(node_allocator_.size(), 0.0);
|
||||
std::vector<float> H(node_allocator_.size(), 0.0);
|
||||
|
||||
// Populate the forward marginals to get the normalising constant
|
||||
alpha = ForwardAlgorithm(theta);
|
||||
|
||||
// Now populate the forward entropies
|
||||
for (int pos = 0; pos <= len; ++pos) {
|
||||
for (Node *rnode : begin_nodes_[pos]) {
|
||||
for (Node *lnode : end_nodes_[pos]) {
|
||||
// Contribution each lnode makes = p(lnode) * (H(lnode) + log p(lnode))
|
||||
|
||||
// We have to normalise p(lnode) by the marginal contribution it makes
|
||||
const float lnode_transition_prob =
|
||||
((theta * lnode->score) + alpha[lnode->node_id] -
|
||||
alpha[rnode->node_id]);
|
||||
H[rnode->node_id] += std::exp(lnode_transition_prob) *
|
||||
(H[lnode->node_id] + lnode_transition_prob);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return -H[begin_nodes_[len][0]->node_id];
|
||||
}
|
||||
|
||||
std::vector<Lattice::LatticePathWithScore> Lattice::NBest(size_t nbest_size,
|
||||
bool sample,
|
||||
float theta) {
|
||||
std::vector<std::vector<Lattice::Node *>> Lattice::NBest(size_t nbest_size) {
|
||||
if (nbest_size < 1) {
|
||||
LOG(WARNING) << "nbest_size >= 1. Returns empty result.";
|
||||
return {};
|
||||
}
|
||||
|
||||
if (nbest_size == 1 && !sample) {
|
||||
if (nbest_size == 1) {
|
||||
return {Viterbi()};
|
||||
}
|
||||
|
||||
@ -306,7 +243,6 @@ std::vector<Lattice::LatticePathWithScore> Lattice::NBest(size_t nbest_size,
|
||||
// At each partial path x, compute f(x) as follows
|
||||
// f(x) = g(x) + h(x).
|
||||
// g(x): the sum of scores from EOS to the left-most node in x.
|
||||
// for a complete hypothesis, g(hyp) is the score of the hypothesis.
|
||||
// h(x): a heuristic that estimates the largest score from x to BOS.
|
||||
// f(x): the priority to pop a new hypothesis from the priority queue.
|
||||
//
|
||||
@ -332,27 +268,18 @@ std::vector<Lattice::LatticePathWithScore> Lattice::NBest(size_t nbest_size,
|
||||
model::FreeList<Hypothesis> hypothesis_allocator(kPreallocatedHypothesisSize);
|
||||
|
||||
Agenda agenda;
|
||||
std::vector<Lattice::LatticePathWithScore> results;
|
||||
std::vector<std::vector<Node *>> results;
|
||||
|
||||
auto *eos = hypothesis_allocator.Allocate();
|
||||
eos->node = eos_node();
|
||||
eos->next = nullptr;
|
||||
eos->gx = 0.0;
|
||||
|
||||
std::vector<float> alpha(node_allocator_.size(), 0.0);
|
||||
|
||||
if (sample) {
|
||||
// Run forwards algorithm to get normalising constants
|
||||
alpha = ForwardAlgorithm(theta);
|
||||
// f(eos) = Gumbel(0), as it is the perturbed score of the entire lattice.
|
||||
eos->fx = Gumbel();
|
||||
} else {
|
||||
// Run Viterbi first to fill backtrace score.
|
||||
Viterbi();
|
||||
eos->fx = eos->node->backtrace_score;
|
||||
}
|
||||
eos->fx = eos->node->score;
|
||||
eos->gx = eos->node->score;
|
||||
agenda.push(eos);
|
||||
|
||||
// Run Viterbi first to fill backtrace score.
|
||||
Viterbi();
|
||||
|
||||
while (!agenda.empty()) {
|
||||
auto *top = agenda.top();
|
||||
agenda.pop();
|
||||
@ -362,56 +289,21 @@ std::vector<Lattice::LatticePathWithScore> Lattice::NBest(size_t nbest_size,
|
||||
if (node == bos_node()) {
|
||||
results.resize(results.size() + 1);
|
||||
for (auto *n = top->next; n->next != nullptr; n = n->next) {
|
||||
results.back().first.push_back(n->node);
|
||||
results.back().push_back(n->node);
|
||||
}
|
||||
results.back().second = top->fx;
|
||||
if (results.size() == nbest_size) {
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
const int end_nodes_size = end_nodes(node->pos).size();
|
||||
std::vector<float> probs(end_nodes_size, 0.0);
|
||||
std::vector<float> perturbed_probs(end_nodes_size, 0.0);
|
||||
std::vector<double> adjusted_probs(end_nodes_size, 0.0);
|
||||
const float Z = alpha[node->node_id];
|
||||
if (sample) {
|
||||
float max_score = -1e8;
|
||||
// Calculate the marginal and perturbed scores for stochastic search
|
||||
for (int i = 0; i < end_nodes(node->pos).size(); i++) {
|
||||
Node *lnode = end_nodes(node->pos)[i];
|
||||
// Calculate backwards transition score
|
||||
probs[i] = top->gx + alpha[lnode->node_id] + (theta * lnode->score) - Z;
|
||||
perturbed_probs[i] = probs[i] + Gumbel();
|
||||
if (perturbed_probs[i] > max_score) {
|
||||
max_score = perturbed_probs[i];
|
||||
}
|
||||
}
|
||||
// Now constrain the sampled continuations to match the score of parent
|
||||
for (int i = 0; i < adjusted_probs.size(); i++) {
|
||||
// Use numerically stable version of truncated Gumbel:
|
||||
// https://arxiv.org/pdf/1903.06059.pdf appendix B.3
|
||||
const float v = top->fx - perturbed_probs[i] +
|
||||
std::log1p(-std::exp(perturbed_probs[i] - max_score));
|
||||
adjusted_probs[i] = top->fx - std::max(static_cast<float>(0.0), v) -
|
||||
std::log1p(std::exp(-std::abs(v)));
|
||||
}
|
||||
}
|
||||
|
||||
// Expands new node ending at node->pos
|
||||
for (int i = 0; i < end_nodes(node->pos).size(); i++) {
|
||||
Node *lnode = end_nodes(node->pos)[i];
|
||||
for (Node *lnode : end_nodes(node->pos)) {
|
||||
auto *hyp = hypothesis_allocator.Allocate();
|
||||
hyp->node = lnode;
|
||||
if (sample) {
|
||||
hyp->gx = probs[i];
|
||||
hyp->fx = adjusted_probs[i];
|
||||
} else {
|
||||
hyp->gx = lnode->score + top->gx; // just adds node->score
|
||||
hyp->fx =
|
||||
lnode->backtrace_score + top->gx; // backtrace_score is h(node).
|
||||
}
|
||||
hyp->gx = lnode->score + top->gx; // just adds node->score
|
||||
hyp->fx =
|
||||
lnode->backtrace_score + top->gx; // backtrace_score is h(node).
|
||||
hyp->next = top;
|
||||
agenda.push(hyp);
|
||||
}
|
||||
@ -443,7 +335,15 @@ std::vector<Lattice::Node *> Lattice::Sample(float theta) {
|
||||
|
||||
std::vector<float> alpha(node_allocator_.size(), 0.0);
|
||||
|
||||
alpha = ForwardAlgorithm(theta);
|
||||
for (int pos = 0; pos <= len; ++pos) {
|
||||
for (Node *rnode : begin_nodes_[pos]) {
|
||||
for (Node *lnode : end_nodes_[pos]) {
|
||||
alpha[rnode->node_id] = LogSumExp(
|
||||
alpha[rnode->node_id], theta * lnode->score + alpha[lnode->node_id],
|
||||
lnode == end_nodes_[pos][0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto *mt = random::GetRandomGenerator();
|
||||
|
||||
@ -614,7 +514,7 @@ EncodeResult Model::Encode(absl::string_view normalized) const {
|
||||
PopulateNodes(&lattice);
|
||||
|
||||
EncodeResult results;
|
||||
for (const auto *node : lattice.Viterbi().first) {
|
||||
for (const auto *node : lattice.Viterbi()) {
|
||||
results.emplace_back(node->piece, node->id);
|
||||
}
|
||||
|
||||
@ -634,12 +534,14 @@ NBestEncodeResult Model::NBestEncode(absl::string_view normalized,
|
||||
PopulateNodes(&lattice);
|
||||
|
||||
NBestEncodeResult nbest_results;
|
||||
for (const auto &nbest : lattice.NBest(nbest_size, false, 0.0)) {
|
||||
for (const auto &nbest : lattice.NBest(nbest_size)) {
|
||||
EncodeResult results;
|
||||
for (const auto *node : nbest.first) {
|
||||
float score = 0.0;
|
||||
for (const auto *node : nbest) {
|
||||
score += node->score;
|
||||
results.emplace_back(node->piece, node->id);
|
||||
}
|
||||
nbest_results.emplace_back(results, nbest.second);
|
||||
nbest_results.emplace_back(results, score);
|
||||
}
|
||||
|
||||
return nbest_results;
|
||||
@ -663,123 +565,6 @@ EncodeResult Model::SampleEncode(absl::string_view normalized,
|
||||
return results;
|
||||
}
|
||||
|
||||
NBestEncodeResult Model::SampleEncodeAndScore(absl::string_view normalized,
|
||||
float theta, int samples,
|
||||
bool wor,
|
||||
bool include_best) const {
|
||||
if (!status().ok() || normalized.empty()) {
|
||||
return {};
|
||||
}
|
||||
NBestEncodeResult results;
|
||||
Lattice lattice;
|
||||
lattice.SetSentence(normalized);
|
||||
PopulateNodes(&lattice);
|
||||
|
||||
std::vector<float> alpha = lattice.ForwardAlgorithm(theta);
|
||||
float marginal = alpha[lattice.eos_node()->node_id];
|
||||
|
||||
if (include_best) {
|
||||
if (!wor) {
|
||||
LOG(FATAL) << "include_best not supported for wor false";
|
||||
}
|
||||
EncodeResult result;
|
||||
Lattice::LatticePathWithScore best_path = lattice.Viterbi();
|
||||
|
||||
for (const auto *node : best_path.first) {
|
||||
result.emplace_back(node->piece, node->id);
|
||||
}
|
||||
|
||||
// Inclusion probability if we always include the best is 1.
|
||||
results.emplace_back(result, 0.0);
|
||||
}
|
||||
|
||||
if (wor) {
|
||||
// Draw k+1 samples as we need perturbed score of k+1th element
|
||||
std::vector<Lattice::LatticePathWithScore> nbest_samples =
|
||||
lattice.NBest(samples + 1, true, theta);
|
||||
|
||||
if (include_best) {
|
||||
std::vector<std::vector<Lattice::Node *>> nbest_paths(
|
||||
nbest_samples.size());
|
||||
for (int i = 0; i < nbest_samples.size(); i++) {
|
||||
nbest_paths[i] = nbest_samples[i].first;
|
||||
}
|
||||
// Remove the best result from the samples if necessary
|
||||
Lattice::LatticePathWithScore best_path = lattice.Viterbi();
|
||||
|
||||
const int index_of_best =
|
||||
(std::find(nbest_paths.begin(), nbest_paths.end(), best_path.first) -
|
||||
nbest_paths.begin());
|
||||
|
||||
if (index_of_best != nbest_samples.size()) {
|
||||
LOG(INFO) << "removing best path from samples";
|
||||
nbest_samples.erase(nbest_samples.begin() + index_of_best);
|
||||
} else {
|
||||
nbest_samples.pop_back();
|
||||
}
|
||||
}
|
||||
// We use the perturbed score of the k+1th element to calculate the
|
||||
// inclusion probability.
|
||||
const double kappa = static_cast<double>(nbest_samples.back().second);
|
||||
// Discard the last sample
|
||||
nbest_samples.pop_back();
|
||||
for (const auto &nbest : nbest_samples) {
|
||||
EncodeResult result;
|
||||
float score = 0.0;
|
||||
|
||||
for (const auto *node : nbest.first) {
|
||||
score += (theta * node->score);
|
||||
result.emplace_back(node->piece, node->id);
|
||||
}
|
||||
|
||||
results.emplace_back(result, score - marginal);
|
||||
}
|
||||
|
||||
// Now calculate the inclusion probability
|
||||
for (auto &it : results) {
|
||||
// Only modify non best sample inclusion probabilities.
|
||||
if (it.second != 0.0) {
|
||||
double x = it.second - kappa;
|
||||
double y = std::exp(x);
|
||||
double inclusion_prob;
|
||||
if (x <= -10) {
|
||||
// Series expansion of the log Gumbel survival function up to eps.
|
||||
inclusion_prob =
|
||||
x - (y / 2) + (std::pow(y, 2) / 24) - std::pow(y, 4) / 2880;
|
||||
} else {
|
||||
inclusion_prob = std::log(-std::expm1(-y));
|
||||
}
|
||||
it.second = static_cast<float>(inclusion_prob);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
while (results.size() < samples) {
|
||||
Lattice lattice;
|
||||
lattice.SetSentence(normalized);
|
||||
PopulateNodes(&lattice);
|
||||
|
||||
float score = 0.0;
|
||||
EncodeResult result;
|
||||
std::vector<Lattice::Node *> sample = lattice.Sample(theta);
|
||||
for (const auto *node : sample) {
|
||||
result.emplace_back(node->piece, node->id);
|
||||
score += (theta * node->score);
|
||||
}
|
||||
results.emplace_back(result, score - marginal);
|
||||
}
|
||||
}
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
float Model::CalculateEntropy(absl::string_view normalized, float theta) const {
|
||||
Lattice lattice;
|
||||
lattice.SetSentence(normalized);
|
||||
PopulateNodes(&lattice);
|
||||
|
||||
return lattice.CalculateEntropy(theta);
|
||||
}
|
||||
|
||||
bool Model::VerifyOutputsEquivalent(absl::string_view expected,
|
||||
absl::string_view actual) const {
|
||||
auto compute_unigram_model_score =
|
||||
|
@ -82,28 +82,17 @@ class Lattice {
|
||||
// After calling this method, The caller must set Node::score and Node::id.
|
||||
Node *Insert(int pos, int length);
|
||||
|
||||
using LatticePathWithScore = std::pair<std::vector<Node *>, float>;
|
||||
|
||||
// Returns Viterbi path. All nodes must be populated in advance.
|
||||
LatticePathWithScore Viterbi();
|
||||
|
||||
// Runs forwards/backwards algorithm, returns vector with normalised
|
||||
// transition probs.
|
||||
std::vector<float> ForwardAlgorithm(float theta) const;
|
||||
std::vector<float> BackwardAlgorithm(float theta) const;
|
||||
std::vector<Node *> Viterbi();
|
||||
|
||||
// Returns n-best results.
|
||||
std::vector<LatticePathWithScore> NBest(size_t nbest_size, bool sample,
|
||||
float theta);
|
||||
std::vector<std::vector<Node *>> NBest(size_t nbest_size);
|
||||
|
||||
// Samples one path from the lattice according to the
|
||||
// generation probability (Product of piece probabilities).
|
||||
// `theta` is a smoothing parameter.
|
||||
std::vector<Node *> Sample(float theta);
|
||||
|
||||
// Calculates the entropy of the lattice.
|
||||
float CalculateEntropy(float theta) const;
|
||||
|
||||
// Populates marginal probability of every node in this lattice.
|
||||
// |freq| is the frequency of the sentence.
|
||||
// for (auto *node : all_nodes_) {
|
||||
@ -138,19 +127,8 @@ class Model : public ModelInterface {
|
||||
EncodeResult SampleEncode(absl::string_view normalized,
|
||||
float theta) const override;
|
||||
|
||||
NBestEncodeResult SampleEncodeAndScore(absl::string_view normalized,
|
||||
float theta, int samples, bool wor,
|
||||
bool include_best) const override;
|
||||
|
||||
float CalculateEntropy(absl::string_view normalized,
|
||||
float theta) const override;
|
||||
|
||||
bool IsSampleEncodeAvailable() const override { return true; }
|
||||
|
||||
bool IsSampleEncodeAndScoreAvailable() const override { return true; }
|
||||
|
||||
bool IsCalculateEntropyAvailable() const override { return true; }
|
||||
|
||||
bool IsNBestEncodeAvailable() const override { return true; }
|
||||
|
||||
// Returns the minimum score in sentence pieces.
|
||||
|
@ -161,11 +161,11 @@ TEST(LatticeTest, InsertTest) {
|
||||
TEST(LatticeTest, ViterbiFromIncompleteLatticeTest) {
|
||||
Lattice lattice;
|
||||
lattice.SetSentence("ABC");
|
||||
EXPECT_TRUE(lattice.Viterbi().first.empty());
|
||||
EXPECT_TRUE(lattice.Viterbi().empty());
|
||||
|
||||
// Still incomplete
|
||||
lattice.Insert(0, 1);
|
||||
EXPECT_TRUE(lattice.Viterbi().first.empty());
|
||||
EXPECT_TRUE(lattice.Viterbi().empty());
|
||||
|
||||
lattice.Insert(1, 1);
|
||||
lattice.Insert(2, 1);
|
||||
@ -198,16 +198,16 @@ TEST(LatticeTest, ViterbiTest) {
|
||||
InsertWithScore(&lattice, 0, 1, 0.0); // A
|
||||
InsertWithScore(&lattice, 1, 1, 0.0); // B
|
||||
InsertWithScore(&lattice, 2, 1, 0.0); // C
|
||||
EXPECT_EQ("A B C", GetTokenized(lattice.Viterbi().first));
|
||||
EXPECT_EQ("A B C", GetTokenized(lattice.Viterbi()));
|
||||
|
||||
InsertWithScore(&lattice, 0, 2, 2.0); // AB
|
||||
EXPECT_EQ("AB C", GetTokenized(lattice.Viterbi().first));
|
||||
EXPECT_EQ("AB C", GetTokenized(lattice.Viterbi()));
|
||||
|
||||
InsertWithScore(&lattice, 1, 2, 5.0); // BC
|
||||
EXPECT_EQ("A BC", GetTokenized(lattice.Viterbi().first));
|
||||
EXPECT_EQ("A BC", GetTokenized(lattice.Viterbi()));
|
||||
|
||||
InsertWithScore(&lattice, 0, 3, 10.0); // ABC
|
||||
EXPECT_EQ("ABC", GetTokenized(lattice.Viterbi().first));
|
||||
EXPECT_EQ("ABC", GetTokenized(lattice.Viterbi()));
|
||||
}
|
||||
|
||||
TEST(LatticeTest, NBestTest) {
|
||||
@ -221,174 +221,21 @@ TEST(LatticeTest, NBestTest) {
|
||||
InsertWithScore(&lattice, 1, 2, 5.0); // BC
|
||||
InsertWithScore(&lattice, 0, 3, 10.0); // ABC
|
||||
|
||||
auto nbests = lattice.NBest(10, false, 0.0);
|
||||
auto nbests = lattice.NBest(10);
|
||||
EXPECT_EQ(4, nbests.size());
|
||||
|
||||
EXPECT_EQ("ABC", GetTokenized(nbests[0].first));
|
||||
EXPECT_EQ("A BC", GetTokenized(nbests[1].first));
|
||||
EXPECT_EQ("AB C", GetTokenized(nbests[2].first));
|
||||
EXPECT_EQ("A B C", GetTokenized(nbests[3].first));
|
||||
EXPECT_EQ("ABC", GetTokenized(nbests[0]));
|
||||
EXPECT_EQ("A BC", GetTokenized(nbests[1]));
|
||||
EXPECT_EQ("AB C", GetTokenized(nbests[2]));
|
||||
EXPECT_EQ("A B C", GetTokenized(nbests[3]));
|
||||
|
||||
auto nbests0 = lattice.NBest(0, false, 0.0);
|
||||
auto nbests0 = lattice.NBest(0);
|
||||
EXPECT_TRUE(nbests0.empty());
|
||||
|
||||
auto nbests1 = lattice.NBest(1, false, 0.0);
|
||||
auto nbests1 = lattice.NBest(1);
|
||||
EXPECT_EQ(nbests1.size(), 1);
|
||||
}
|
||||
|
||||
TEST(LatticeTest, NBestSampleTest) {
|
||||
Lattice lattice;
|
||||
lattice.SetSentence("ABC");
|
||||
|
||||
InsertWithScore(&lattice, 0, 1, 0.0); // A
|
||||
InsertWithScore(&lattice, 1, 1, 0.0); // B
|
||||
InsertWithScore(&lattice, 2, 1, 0.1); // C
|
||||
InsertWithScore(&lattice, 0, 2, 0.2); // AB
|
||||
InsertWithScore(&lattice, 1, 2, 0.5); // BC
|
||||
InsertWithScore(&lattice, 0, 3, 1.0); // ABC
|
||||
|
||||
// Calculate expected probabilities of each path
|
||||
// Note that sampling without replacement affects the expected frequencies!
|
||||
const std::vector<double> kTheta = {0.0, 0.01, 0.5, 0.7, 1.0};
|
||||
for (const auto theta : kTheta) {
|
||||
std::vector<std::string> strings = {"ABC", "AB C", "A BC", "A B C"};
|
||||
std::map<std::string, float> probs;
|
||||
probs["ABC"] = std::exp(theta * 1.0);
|
||||
probs["AB C"] = std::exp(theta * (0.2 + 0.1));
|
||||
probs["A BC"] = std::exp(theta * (0.0 + 0.5));
|
||||
probs["A B C"] = std::exp(theta * (0.0 + 0.0 + 0.1));
|
||||
|
||||
for (const auto &it : strings) {
|
||||
EXPECT_EQ(1, probs.count(it));
|
||||
}
|
||||
|
||||
double Z = 0.0;
|
||||
for (const auto &it : probs) Z += it.second;
|
||||
for (auto &it : probs) it.second /= Z;
|
||||
|
||||
std::map<std::pair<std::string, std::string>, float> pair_probs;
|
||||
for (const auto first : strings) {
|
||||
for (const auto second : strings) {
|
||||
if (first == second) {
|
||||
pair_probs[std::make_pair(first, second)] = 0;
|
||||
} else {
|
||||
float first_prob = probs[first];
|
||||
float second_prob = probs[second] / (1 - first_prob);
|
||||
pair_probs[std::make_pair(first, second)] = first_prob * second_prob;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::map<std::string, float> inclusion_probs;
|
||||
for (const auto string : strings) {
|
||||
float inclusion_prob = 0.0;
|
||||
for (const auto other_string : strings) {
|
||||
inclusion_prob += pair_probs[std::make_pair(string, other_string)];
|
||||
}
|
||||
for (const auto other_string : strings) {
|
||||
inclusion_prob += pair_probs[std::make_pair(other_string, string)];
|
||||
}
|
||||
inclusion_probs[string] = inclusion_prob / 2;
|
||||
}
|
||||
|
||||
int kTrials = 10000;
|
||||
|
||||
std::vector<int> kNumSamples = {1, 2};
|
||||
|
||||
for (const auto num_samples : kNumSamples) {
|
||||
std::map<std::string, int> counts;
|
||||
for (int i = 0; i < kTrials; i++) {
|
||||
auto nbests = lattice.NBest(num_samples, true, theta);
|
||||
for (const auto nbest : nbests) {
|
||||
counts[GetTokenized(nbest.first)]++;
|
||||
}
|
||||
}
|
||||
|
||||
EXPECT_EQ(inclusion_probs.size(), counts.size());
|
||||
// If we take multiple samples WOR, we have to use corrected probs.
|
||||
std::map<std::string, float> probs_to_use =
|
||||
(num_samples == 1 ? probs : inclusion_probs);
|
||||
|
||||
for (const auto &it : probs_to_use) {
|
||||
EXPECT_NEAR(it.second, 1.0 * counts[it.first] / (kTrials * num_samples),
|
||||
0.02);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(LatticeTest, CalculateEntropyTest) {
|
||||
Lattice lattice;
|
||||
lattice.SetSentence("ABC");
|
||||
|
||||
InsertWithScore(&lattice, 0, 1, 0.0); // A
|
||||
InsertWithScore(&lattice, 1, 1, 0.0); // B
|
||||
InsertWithScore(&lattice, 2, 1, 0.1); // C
|
||||
InsertWithScore(&lattice, 0, 2, 0.2); // AB
|
||||
InsertWithScore(&lattice, 1, 2, 0.5); // BC
|
||||
InsertWithScore(&lattice, 0, 3, 1.0); // ABC
|
||||
|
||||
// Calculate expected probabilities of each path
|
||||
const std::vector<double> kTheta = {0.0, 0.01, 0.5, 0.7, 1.0};
|
||||
for (const auto theta : kTheta) {
|
||||
std::vector<std::string> strings = {"ABC", "AB C", "A BC", "A B C"};
|
||||
std::map<std::string, float> probs;
|
||||
probs["ABC"] = std::exp(theta * 1.0);
|
||||
probs["AB C"] = std::exp(theta * (0.2 + 0.1));
|
||||
probs["A BC"] = std::exp(theta * (0.0 + 0.5));
|
||||
probs["A B C"] = std::exp(theta * (0.0 + 0.0 + 0.1));
|
||||
|
||||
double Z = 0.0;
|
||||
for (const auto &it : probs) Z += it.second;
|
||||
for (auto &it : probs) it.second /= Z;
|
||||
|
||||
for (const auto &it : strings) {
|
||||
EXPECT_EQ(1, probs.count(it));
|
||||
}
|
||||
float entropy = 0.0;
|
||||
for (const auto &it : probs) {
|
||||
entropy += (it.second * std::log(it.second));
|
||||
}
|
||||
EXPECT_NEAR(-entropy, lattice.CalculateEntropy(theta), 0.02);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(LatticeTest, ForwardAlgorithmTest) {
|
||||
Lattice lattice;
|
||||
lattice.SetSentence("ABC");
|
||||
|
||||
InsertWithScore(&lattice, 0, 1, 0.0); // A
|
||||
InsertWithScore(&lattice, 1, 1, 0.0); // B
|
||||
InsertWithScore(&lattice, 2, 1, 0.1); // C
|
||||
InsertWithScore(&lattice, 0, 2, 0.2); // AB
|
||||
InsertWithScore(&lattice, 1, 2, 0.5); // BC
|
||||
InsertWithScore(&lattice, 0, 3, 1.0); // ABC
|
||||
|
||||
const std::vector<float> kTheta = {0.0, 0.01, 0.5, 0.7, 1.0};
|
||||
for (const auto theta : kTheta) {
|
||||
std::vector<float> alpha = lattice.ForwardAlgorithm(theta);
|
||||
EXPECT_EQ(alpha.size(), 8); // 6 nodes, plus BOS, EOS
|
||||
// only alpha[C], alpha[EOS] have non-zero alpha
|
||||
for (int i : {0, 1, 2, 3}) {
|
||||
for (const auto &node : lattice.begin_nodes(i)) {
|
||||
if (i < 2) {
|
||||
EXPECT_EQ(alpha[node->node_id], 0.0);
|
||||
} else if (i == 2) {
|
||||
float Z =
|
||||
std::log(std::exp(theta * (0.0 + 0.0)) + std::exp(theta * 0.2));
|
||||
EXPECT_EQ(alpha[node->node_id], Z);
|
||||
} else if (i == 3) {
|
||||
float Z = std::log(std::exp(theta * (0.0 + 0.0 + 0.1)) + // A + B + C
|
||||
std::exp(theta * (0.2 + 0.1)) + // AB + C
|
||||
std::exp(theta * (0.0 + 0.5)) + // A + BC
|
||||
std::exp(theta * 1.0)); // ABC
|
||||
EXPECT_EQ(Z, alpha[node->node_id]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(LatticeTest, PopulateMarginalTest) {
|
||||
Lattice lattice;
|
||||
lattice.SetSentence("ABC");
|
||||
@ -514,102 +361,6 @@ TEST(UnigramModelTest, SetUnigramModelTest) {
|
||||
model.model_proto().SerializeAsString());
|
||||
}
|
||||
|
||||
TEST(UnigramModelTest, SampleEncodeAndScoreTest) {
|
||||
// Test whether inclusion probabilities are correct
|
||||
ModelProto model_proto = MakeBaseModelProto();
|
||||
AddPiece(&model_proto, "A", 0.0); // 3
|
||||
AddPiece(&model_proto, "B", 0.0); // 4
|
||||
AddPiece(&model_proto, "C", 0.1); // 5
|
||||
AddPiece(&model_proto, "AB", 0.2); // 6
|
||||
AddPiece(&model_proto, "BC", 0.5); // 7
|
||||
AddPiece(&model_proto, "ABC", 1.0); // 8
|
||||
|
||||
Model model(model_proto);
|
||||
|
||||
Lattice lattice;
|
||||
lattice.SetSentence("ABC");
|
||||
model.PopulateNodes(&lattice);
|
||||
|
||||
std::vector<float> kTheta = {0.0, 1.0};
|
||||
|
||||
for (const auto theta : kTheta) {
|
||||
std::vector<std::string> strings = {"ABC", "AB C", "A BC", "A B C"};
|
||||
std::map<std::string, float> probs;
|
||||
probs["ABC"] = std::exp(theta * 1.0);
|
||||
probs["AB C"] = std::exp(theta * (0.2 + 0.1));
|
||||
probs["A BC"] = std::exp(theta * (0.0 + 0.5));
|
||||
probs["A B C"] = std::exp(theta * (0.0 + 0.0 + 0.1));
|
||||
|
||||
for (const auto &it : strings) {
|
||||
EXPECT_EQ(1, probs.count(it));
|
||||
}
|
||||
|
||||
double Z = 0.0;
|
||||
for (const auto &it : probs) Z += it.second;
|
||||
for (auto &it : probs) it.second /= Z;
|
||||
|
||||
std::map<std::pair<std::string, std::string>, float> pair_probs;
|
||||
for (const auto first : strings) {
|
||||
for (const auto second : strings) {
|
||||
if (first == second) {
|
||||
pair_probs[std::make_pair(first, second)] = 0;
|
||||
} else {
|
||||
float first_prob = probs[first];
|
||||
float second_prob = probs[second] / (1 - first_prob);
|
||||
pair_probs[std::make_pair(first, second)] = first_prob * second_prob;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::map<std::string, float> inclusion_probs;
|
||||
for (const auto string : strings) {
|
||||
float inclusion_prob = 0.0;
|
||||
for (const auto other_string : strings) {
|
||||
inclusion_prob += pair_probs[std::make_pair(string, other_string)];
|
||||
}
|
||||
for (const auto other_string : strings) {
|
||||
inclusion_prob += pair_probs[std::make_pair(other_string, string)];
|
||||
}
|
||||
inclusion_probs[string] = inclusion_prob / 2;
|
||||
}
|
||||
std::vector<int> kNumSamples = {1, 2};
|
||||
|
||||
for (const auto num_samples : kNumSamples) {
|
||||
std::map<std::string, int> counts;
|
||||
std::map<std::string, float> scores;
|
||||
int kTrials = 50000;
|
||||
for (int i = 0; i < kTrials; i++) {
|
||||
NBestEncodeResult sample =
|
||||
model.SampleEncodeAndScore("ABC", theta, num_samples, true, false);
|
||||
|
||||
for (const auto &it : sample) {
|
||||
std::vector<std::string> tokens;
|
||||
for (const auto &inner_it : it.first) {
|
||||
tokens.push_back(std::string(inner_it.first));
|
||||
}
|
||||
std::string sample_string = absl::StrJoin(tokens, " ");
|
||||
counts[sample_string] += 1;
|
||||
// use the fact that E(1_{i in sample} / score of i) = 1
|
||||
// see https://arxiv.org/pdf/1903.06059.pdf appendix D
|
||||
scores[sample_string] += std::exp(-it.second);
|
||||
}
|
||||
}
|
||||
|
||||
// Check that counts and probs are correct
|
||||
std::map<std::string, float> probs_to_use =
|
||||
(num_samples == 1 ? probs : inclusion_probs);
|
||||
|
||||
for (const auto &it : scores) Z += it.second;
|
||||
for (const auto &it : probs_to_use) {
|
||||
EXPECT_NEAR(it.second, 1.0 * counts[it.first] / (kTrials * num_samples),
|
||||
0.02);
|
||||
// The expectation is quite loose, use a higher tolerance
|
||||
EXPECT_NEAR(1.0, scores[it.first] / kTrials, 0.05);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(UnigramModelTest, PieceToIdTest) {
|
||||
ModelProto model_proto = MakeBaseModelProto();
|
||||
|
||||
|
@ -223,7 +223,7 @@ std::vector<float> Trainer::RunEStep(const TrainerModel &model, float *obj,
|
||||
lattice.SetSentence(w);
|
||||
model.PopulateNodes(&lattice);
|
||||
const float Z = lattice.PopulateMarginal(freq, &expected[n]);
|
||||
ntokens[n] += lattice.Viterbi().first.size();
|
||||
ntokens[n] += lattice.Viterbi().size();
|
||||
CHECK(!std::isnan(Z))
|
||||
<< "likelihood is NAN. Input sentence may be too long";
|
||||
objs[n] -= Z / all_sentence_freq;
|
||||
@ -297,17 +297,17 @@ TrainerModel::SentencePieces Trainer::PruneSentencePieces(
|
||||
const auto &w = sentencepieces[i];
|
||||
lattice.SetSentence(w.first);
|
||||
model.PopulateNodes(&lattice);
|
||||
const auto nbests = lattice.NBest(2, false, 0.0);
|
||||
const auto nbests = lattice.NBest(2);
|
||||
if (nbests.size() == 1) {
|
||||
// No second-best result is found. always keep this sentencepiece.
|
||||
always_keep[i] = true;
|
||||
continue;
|
||||
} else if (nbests[0].first.size() >= 2) {
|
||||
} else if (nbests[0].size() >= 2) {
|
||||
// Can safely remove this sentencepiece if its Viterbi path is split.
|
||||
always_keep[i] = false;
|
||||
} else if (nbests[0].first.size() == 1) {
|
||||
} else if (nbests[0].size() == 1) {
|
||||
always_keep[i] = true;
|
||||
for (const auto *node : nbests[1].first) {
|
||||
for (const auto *node : nbests[1]) {
|
||||
alternatives[i].push_back(node->id);
|
||||
}
|
||||
}
|
||||
@ -339,7 +339,7 @@ TrainerModel::SentencePieces Trainer::PruneSentencePieces(
|
||||
lattice.SetSentence(w.first);
|
||||
model.PopulateNodes(&lattice);
|
||||
vsums[n] += w.second;
|
||||
for (const auto *node : lattice.Viterbi().first) {
|
||||
for (const auto *node : lattice.Viterbi()) {
|
||||
if (node->id >= 0) {
|
||||
freqs[n][node->id] += w.second;
|
||||
inverteds[n][node->id].push_back(i);
|
||||
|
@ -12,12 +12,11 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.!
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "util.h"
|
||||
|
||||
namespace sentencepiece {
|
||||
#include <iostream>
|
||||
|
||||
namespace sentencepiece {
|
||||
namespace {
|
||||
constexpr unsigned int kDefaultSeed = static_cast<unsigned int>(-1);
|
||||
static unsigned int g_seed = kDefaultSeed;
|
||||
|
1
third_party/absl/flags/flag.cc
vendored
1
third_party/absl/flags/flag.cc
vendored
@ -171,7 +171,6 @@ void Flag<bool>::set_value_as_str(const std::string &value_as_str) {
|
||||
|
||||
template class Flag<std::string>;
|
||||
template class Flag<int32>;
|
||||
template class Flag<uint32>;
|
||||
template class Flag<double>;
|
||||
template class Flag<bool>;
|
||||
template class Flag<int64>;
|
||||
|
Loading…
Reference in New Issue
Block a user