Revert "sync from internal"

This reverts commit 05db0894d8.
Taku Kudo 2021-06-16 14:51:52 +09:00
parent 05db0894d8
commit 3a5bc5815b
29 changed files with 324 additions and 1079 deletions

View File

@ -1 +1 @@
0.1.96
0.1.95

View File

@ -1 +1 @@
0.1.96
0.1.95

View File

@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.!
#include "bpe_model.h"
#include <functional>
#include <memory>
#include <queue>
@ -19,7 +21,6 @@
#include <utility>
#include <vector>
#include "bpe_model.h"
#include "freelist.h"
#include "third_party/absl/container/flat_hash_map.h"
#include "util.h"

View File

@ -12,11 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.!
#include "builder.h"
#include <algorithm>
#include <functional>
#include <utility>
#include "builder.h"
#include "filesystem.h"
#include "third_party/absl/strings/str_join.h"
#include "third_party/absl/strings/str_replace.h"
@ -367,7 +368,6 @@ util::Status Builder::BuildNmtNFKCMap(CharsMap *chars_map) {
nfkc_map[{0xFEFF}] = {0x20}; // ZERO WIDTH NO-BREAK
nfkc_map[{0xFFFD}] = {0x20}; // REPLACEMENT CHARACTER
nfkc_map[{0x200C}] = {0x20}; // ZERO WIDTH NON-JOINER
nfkc_map[{0x200D}] = {0x20}; // ZERO WIDTH JOINER
// Ascii Control characters
nfkc_map[{0x0001}] = {};

View File

@ -285,22 +285,22 @@ class TrainerSpec::_Internal {
(*has_bits)[0] |= 1u;
}
static void set_has_model_type(HasBits* has_bits) {
(*has_bits)[0] |= 524288u;
(*has_bits)[0] |= 262144u;
}
static void set_has_vocab_size(HasBits* has_bits) {
(*has_bits)[0] |= 1048576u;
(*has_bits)[0] |= 524288u;
}
static void set_has_self_test_sample_size(HasBits* has_bits) {
(*has_bits)[0] |= 256u;
}
static void set_has_character_coverage(HasBits* has_bits) {
(*has_bits)[0] |= 2097152u;
(*has_bits)[0] |= 1048576u;
}
static void set_has_input_sentence_size(HasBits* has_bits) {
(*has_bits)[0] |= 1024u;
}
static void set_has_shuffle_input_sentence(HasBits* has_bits) {
(*has_bits)[0] |= 268435456u;
(*has_bits)[0] |= 134217728u;
}
static void set_has_mining_sentence_size(HasBits* has_bits) {
(*has_bits)[0] |= 512u;
@ -309,67 +309,64 @@ class TrainerSpec::_Internal {
(*has_bits)[0] |= 2048u;
}
static void set_has_seed_sentencepiece_size(HasBits* has_bits) {
(*has_bits)[0] |= 4194304u;
(*has_bits)[0] |= 2097152u;
}
static void set_has_shrinking_factor(HasBits* has_bits) {
(*has_bits)[0] |= 8388608u;
(*has_bits)[0] |= 4194304u;
}
static void set_has_max_sentence_length(HasBits* has_bits) {
(*has_bits)[0] |= 67108864u;
}
static void set_has_num_threads(HasBits* has_bits) {
(*has_bits)[0] |= 16777216u;
}
static void set_has_num_sub_iterations(HasBits* has_bits) {
(*has_bits)[0] |= 33554432u;
}
static void set_has_num_threads(HasBits* has_bits) {
(*has_bits)[0] |= 8388608u;
}
static void set_has_num_sub_iterations(HasBits* has_bits) {
(*has_bits)[0] |= 16777216u;
}
static void set_has_max_sentencepiece_length(HasBits* has_bits) {
(*has_bits)[0] |= 134217728u;
(*has_bits)[0] |= 67108864u;
}
static void set_has_split_by_unicode_script(HasBits* has_bits) {
(*has_bits)[0] |= 536870912u;
(*has_bits)[0] |= 268435456u;
}
static void set_has_split_by_number(HasBits* has_bits) {
(*has_bits)[0] |= 1073741824u;
(*has_bits)[0] |= 536870912u;
}
static void set_has_split_by_whitespace(HasBits* has_bits) {
(*has_bits)[0] |= 2147483648u;
(*has_bits)[0] |= 1073741824u;
}
static void set_has_treat_whitespace_as_suffix(HasBits* has_bits) {
(*has_bits)[0] |= 4096u;
}
static void set_has_allow_whitespace_only_pieces(HasBits* has_bits) {
(*has_bits)[0] |= 8192u;
}
static void set_has_split_digits(HasBits* has_bits) {
(*has_bits)[0] |= 16384u;
(*has_bits)[0] |= 8192u;
}
static void set_has_required_chars(HasBits* has_bits) {
(*has_bits)[0] |= 4u;
}
static void set_has_byte_fallback(HasBits* has_bits) {
(*has_bits)[0] |= 32768u;
(*has_bits)[0] |= 16384u;
}
static void set_has_vocabulary_output_piece_score(HasBits* has_bits) {
(*has_bits)[1] |= 1u;
(*has_bits)[0] |= 2147483648u;
}
static void set_has_hard_vocab_limit(HasBits* has_bits) {
(*has_bits)[1] |= 2u;
(*has_bits)[1] |= 1u;
}
static void set_has_use_all_vocab(HasBits* has_bits) {
(*has_bits)[0] |= 131072u;
(*has_bits)[0] |= 32768u;
}
static void set_has_unk_id(HasBits* has_bits) {
(*has_bits)[0] |= 65536u;
}
static void set_has_bos_id(HasBits* has_bits) {
(*has_bits)[1] |= 4u;
(*has_bits)[1] |= 2u;
}
static void set_has_eos_id(HasBits* has_bits) {
(*has_bits)[1] |= 8u;
(*has_bits)[1] |= 4u;
}
static void set_has_pad_id(HasBits* has_bits) {
(*has_bits)[1] |= 16u;
(*has_bits)[1] |= 8u;
}
static void set_has_unk_piece(HasBits* has_bits) {
(*has_bits)[0] |= 16u;
@ -387,7 +384,7 @@ class TrainerSpec::_Internal {
(*has_bits)[0] |= 8u;
}
static void set_has_train_extremely_large_corpus(HasBits* has_bits) {
(*has_bits)[0] |= 262144u;
(*has_bits)[0] |= 131072u;
}
};
@ -569,8 +566,8 @@ void TrainerSpec::Clear() {
}
if (cached_has_bits & 0x0000ff00u) {
::memset(&self_test_sample_size_, 0, static_cast<size_t>(
reinterpret_cast<char*>(&byte_fallback_) -
reinterpret_cast<char*>(&self_test_sample_size_)) + sizeof(byte_fallback_));
reinterpret_cast<char*>(&use_all_vocab_) -
reinterpret_cast<char*>(&self_test_sample_size_)) + sizeof(use_all_vocab_));
}
if (cached_has_bits & 0x00ff0000u) {
::memset(&unk_id_, 0, static_cast<size_t>(
@ -581,9 +578,9 @@ void TrainerSpec::Clear() {
character_coverage_ = 0.9995f;
seed_sentencepiece_size_ = 1000000;
shrinking_factor_ = 0.75f;
num_threads_ = 16;
}
if (cached_has_bits & 0xff000000u) {
num_threads_ = 16;
num_sub_iterations_ = 2;
max_sentence_length_ = 4192;
max_sentencepiece_length_ = 16;
@ -591,10 +588,10 @@ void TrainerSpec::Clear() {
split_by_unicode_script_ = true;
split_by_number_ = true;
split_by_whitespace_ = true;
vocabulary_output_piece_score_ = true;
}
cached_has_bits = _has_bits_[1];
if (cached_has_bits & 0x0000001fu) {
vocabulary_output_piece_score_ = true;
if (cached_has_bits & 0x0000000fu) {
hard_vocab_limit_ = true;
bos_id_ = 1;
eos_id_ = 2;
@ -809,14 +806,6 @@ const char* TrainerSpec::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID
CHK_(ptr);
} else goto handle_unusual;
continue;
// optional bool allow_whitespace_only_pieces = 26 [default = false];
case 26:
if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 208)) {
_Internal::set_has_allow_whitespace_only_pieces(&_has_bits_);
allow_whitespace_only_pieces_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr);
CHK_(ptr);
} else goto handle_unusual;
continue;
// repeated string control_symbols = 30;
case 30:
if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 242)) {
@ -1011,14 +1000,14 @@ failure:
}
// optional .sentencepiece.TrainerSpec.ModelType model_type = 3 [default = UNIGRAM];
if (cached_has_bits & 0x00080000u) {
if (cached_has_bits & 0x00040000u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteEnumToArray(
3, this->_internal_model_type(), target);
}
// optional int32 vocab_size = 4 [default = 8000];
if (cached_has_bits & 0x00100000u) {
if (cached_has_bits & 0x00080000u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(4, this->_internal_vocab_size(), target);
}
@ -1042,7 +1031,7 @@ failure:
}
// optional float character_coverage = 10 [default = 0.9995];
if (cached_has_bits & 0x00200000u) {
if (cached_has_bits & 0x00100000u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteFloatToArray(10, this->_internal_character_coverage(), target);
}
@ -1066,61 +1055,61 @@ failure:
}
// optional int32 seed_sentencepiece_size = 14 [default = 1000000];
if (cached_has_bits & 0x00400000u) {
if (cached_has_bits & 0x00200000u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(14, this->_internal_seed_sentencepiece_size(), target);
}
// optional float shrinking_factor = 15 [default = 0.75];
if (cached_has_bits & 0x00800000u) {
if (cached_has_bits & 0x00400000u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteFloatToArray(15, this->_internal_shrinking_factor(), target);
}
// optional int32 num_threads = 16 [default = 16];
if (cached_has_bits & 0x01000000u) {
if (cached_has_bits & 0x00800000u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(16, this->_internal_num_threads(), target);
}
// optional int32 num_sub_iterations = 17 [default = 2];
if (cached_has_bits & 0x02000000u) {
if (cached_has_bits & 0x01000000u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(17, this->_internal_num_sub_iterations(), target);
}
// optional int32 max_sentence_length = 18 [default = 4192];
if (cached_has_bits & 0x04000000u) {
if (cached_has_bits & 0x02000000u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(18, this->_internal_max_sentence_length(), target);
}
// optional bool shuffle_input_sentence = 19 [default = true];
if (cached_has_bits & 0x10000000u) {
if (cached_has_bits & 0x08000000u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(19, this->_internal_shuffle_input_sentence(), target);
}
// optional int32 max_sentencepiece_length = 20 [default = 16];
if (cached_has_bits & 0x08000000u) {
if (cached_has_bits & 0x04000000u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(20, this->_internal_max_sentencepiece_length(), target);
}
// optional bool split_by_unicode_script = 21 [default = true];
if (cached_has_bits & 0x20000000u) {
if (cached_has_bits & 0x10000000u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(21, this->_internal_split_by_unicode_script(), target);
}
// optional bool split_by_whitespace = 22 [default = true];
if (cached_has_bits & 0x80000000u) {
if (cached_has_bits & 0x40000000u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(22, this->_internal_split_by_whitespace(), target);
}
// optional bool split_by_number = 23 [default = true];
if (cached_has_bits & 0x40000000u) {
if (cached_has_bits & 0x20000000u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(23, this->_internal_split_by_number(), target);
}
@ -1132,15 +1121,9 @@ failure:
}
// optional bool split_digits = 25 [default = false];
if (cached_has_bits & 0x00004000u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(25, this->_internal_split_digits(), target);
}
// optional bool allow_whitespace_only_pieces = 26 [default = false];
if (cached_has_bits & 0x00002000u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(26, this->_internal_allow_whitespace_only_pieces(), target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(25, this->_internal_split_digits(), target);
}
// repeated string control_symbols = 30;
@ -1155,28 +1138,28 @@ failure:
target = stream->WriteString(31, s, target);
}
cached_has_bits = _has_bits_[1];
// optional bool vocabulary_output_piece_score = 32 [default = true];
if (cached_has_bits & 0x00000001u) {
if (cached_has_bits & 0x80000000u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(32, this->_internal_vocabulary_output_piece_score(), target);
}
cached_has_bits = _has_bits_[1];
// optional bool hard_vocab_limit = 33 [default = true];
if (cached_has_bits & 0x00000002u) {
if (cached_has_bits & 0x00000001u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(33, this->_internal_hard_vocab_limit(), target);
}
cached_has_bits = _has_bits_[0];
// optional bool use_all_vocab = 34 [default = false];
if (cached_has_bits & 0x00020000u) {
if (cached_has_bits & 0x00008000u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(34, this->_internal_use_all_vocab(), target);
}
// optional bool byte_fallback = 35 [default = false];
if (cached_has_bits & 0x00008000u) {
if (cached_has_bits & 0x00004000u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(35, this->_internal_byte_fallback(), target);
}
@ -1195,19 +1178,19 @@ failure:
cached_has_bits = _has_bits_[1];
// optional int32 bos_id = 41 [default = 1];
if (cached_has_bits & 0x00000004u) {
if (cached_has_bits & 0x00000002u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(41, this->_internal_bos_id(), target);
}
// optional int32 eos_id = 42 [default = 2];
if (cached_has_bits & 0x00000008u) {
if (cached_has_bits & 0x00000004u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(42, this->_internal_eos_id(), target);
}
// optional int32 pad_id = 43 [default = -1];
if (cached_has_bits & 0x00000010u) {
if (cached_has_bits & 0x00000008u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(43, this->_internal_pad_id(), target);
}
@ -1244,7 +1227,7 @@ failure:
}
// optional bool train_extremely_large_corpus = 49 [default = false];
if (cached_has_bits & 0x00040000u) {
if (cached_has_bits & 0x00020000u) {
target = stream->EnsureSpace(target);
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(49, this->_internal_train_extremely_large_corpus(), target);
}
@ -1396,17 +1379,17 @@ size_t TrainerSpec::ByteSizeLong() const {
total_size += 2 + 1;
}
// optional bool allow_whitespace_only_pieces = 26 [default = false];
// optional bool split_digits = 25 [default = false];
if (cached_has_bits & 0x00002000u) {
total_size += 2 + 1;
}
// optional bool split_digits = 25 [default = false];
// optional bool byte_fallback = 35 [default = false];
if (cached_has_bits & 0x00004000u) {
total_size += 2 + 1;
}
// optional bool byte_fallback = 35 [default = false];
// optional bool use_all_vocab = 34 [default = false];
if (cached_has_bits & 0x00008000u) {
total_size += 2 + 1;
}
@ -1420,125 +1403,120 @@ size_t TrainerSpec::ByteSizeLong() const {
this->_internal_unk_id());
}
// optional bool use_all_vocab = 34 [default = false];
// optional bool train_extremely_large_corpus = 49 [default = false];
if (cached_has_bits & 0x00020000u) {
total_size += 2 + 1;
}
// optional bool train_extremely_large_corpus = 49 [default = false];
if (cached_has_bits & 0x00040000u) {
total_size += 2 + 1;
}
// optional .sentencepiece.TrainerSpec.ModelType model_type = 3 [default = UNIGRAM];
if (cached_has_bits & 0x00080000u) {
if (cached_has_bits & 0x00040000u) {
total_size += 1 +
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::EnumSize(this->_internal_model_type());
}
// optional int32 vocab_size = 4 [default = 8000];
if (cached_has_bits & 0x00100000u) {
if (cached_has_bits & 0x00080000u) {
total_size += 1 +
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
this->_internal_vocab_size());
}
// optional float character_coverage = 10 [default = 0.9995];
if (cached_has_bits & 0x00200000u) {
if (cached_has_bits & 0x00100000u) {
total_size += 1 + 4;
}
// optional int32 seed_sentencepiece_size = 14 [default = 1000000];
if (cached_has_bits & 0x00400000u) {
if (cached_has_bits & 0x00200000u) {
total_size += 1 +
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
this->_internal_seed_sentencepiece_size());
}
// optional float shrinking_factor = 15 [default = 0.75];
if (cached_has_bits & 0x00800000u) {
if (cached_has_bits & 0x00400000u) {
total_size += 1 + 4;
}
}
if (cached_has_bits & 0xff000000u) {
// optional int32 num_threads = 16 [default = 16];
if (cached_has_bits & 0x01000000u) {
if (cached_has_bits & 0x00800000u) {
total_size += 2 +
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
this->_internal_num_threads());
}
}
if (cached_has_bits & 0xff000000u) {
// optional int32 num_sub_iterations = 17 [default = 2];
if (cached_has_bits & 0x02000000u) {
if (cached_has_bits & 0x01000000u) {
total_size += 2 +
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
this->_internal_num_sub_iterations());
}
// optional int32 max_sentence_length = 18 [default = 4192];
if (cached_has_bits & 0x04000000u) {
if (cached_has_bits & 0x02000000u) {
total_size += 2 +
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
this->_internal_max_sentence_length());
}
// optional int32 max_sentencepiece_length = 20 [default = 16];
if (cached_has_bits & 0x08000000u) {
if (cached_has_bits & 0x04000000u) {
total_size += 2 +
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
this->_internal_max_sentencepiece_length());
}
// optional bool shuffle_input_sentence = 19 [default = true];
if (cached_has_bits & 0x10000000u) {
if (cached_has_bits & 0x08000000u) {
total_size += 2 + 1;
}
// optional bool split_by_unicode_script = 21 [default = true];
if (cached_has_bits & 0x20000000u) {
if (cached_has_bits & 0x10000000u) {
total_size += 2 + 1;
}
// optional bool split_by_number = 23 [default = true];
if (cached_has_bits & 0x40000000u) {
if (cached_has_bits & 0x20000000u) {
total_size += 2 + 1;
}
// optional bool split_by_whitespace = 22 [default = true];
if (cached_has_bits & 0x40000000u) {
total_size += 2 + 1;
}
// optional bool vocabulary_output_piece_score = 32 [default = true];
if (cached_has_bits & 0x80000000u) {
total_size += 2 + 1;
}
}
cached_has_bits = _has_bits_[1];
if (cached_has_bits & 0x0000001fu) {
// optional bool vocabulary_output_piece_score = 32 [default = true];
if (cached_has_bits & 0x0000000fu) {
// optional bool hard_vocab_limit = 33 [default = true];
if (cached_has_bits & 0x00000001u) {
total_size += 2 + 1;
}
// optional bool hard_vocab_limit = 33 [default = true];
if (cached_has_bits & 0x00000002u) {
total_size += 2 + 1;
}
// optional int32 bos_id = 41 [default = 1];
if (cached_has_bits & 0x00000004u) {
if (cached_has_bits & 0x00000002u) {
total_size += 2 +
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
this->_internal_bos_id());
}
// optional int32 eos_id = 42 [default = 2];
if (cached_has_bits & 0x00000008u) {
if (cached_has_bits & 0x00000004u) {
total_size += 2 +
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
this->_internal_eos_id());
}
// optional int32 pad_id = 43 [default = -1];
if (cached_has_bits & 0x00000010u) {
if (cached_has_bits & 0x00000008u) {
total_size += 2 +
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
this->_internal_pad_id());
@ -1615,14 +1593,14 @@ void TrainerSpec::MergeFrom(const TrainerSpec& from) {
treat_whitespace_as_suffix_ = from.treat_whitespace_as_suffix_;
}
if (cached_has_bits & 0x00002000u) {
allow_whitespace_only_pieces_ = from.allow_whitespace_only_pieces_;
}
if (cached_has_bits & 0x00004000u) {
split_digits_ = from.split_digits_;
}
if (cached_has_bits & 0x00008000u) {
if (cached_has_bits & 0x00004000u) {
byte_fallback_ = from.byte_fallback_;
}
if (cached_has_bits & 0x00008000u) {
use_all_vocab_ = from.use_all_vocab_;
}
_has_bits_[0] |= cached_has_bits;
}
if (cached_has_bits & 0x00ff0000u) {
@ -1630,70 +1608,67 @@ void TrainerSpec::MergeFrom(const TrainerSpec& from) {
unk_id_ = from.unk_id_;
}
if (cached_has_bits & 0x00020000u) {
use_all_vocab_ = from.use_all_vocab_;
}
if (cached_has_bits & 0x00040000u) {
train_extremely_large_corpus_ = from.train_extremely_large_corpus_;
}
if (cached_has_bits & 0x00080000u) {
if (cached_has_bits & 0x00040000u) {
model_type_ = from.model_type_;
}
if (cached_has_bits & 0x00100000u) {
if (cached_has_bits & 0x00080000u) {
vocab_size_ = from.vocab_size_;
}
if (cached_has_bits & 0x00200000u) {
if (cached_has_bits & 0x00100000u) {
character_coverage_ = from.character_coverage_;
}
if (cached_has_bits & 0x00400000u) {
if (cached_has_bits & 0x00200000u) {
seed_sentencepiece_size_ = from.seed_sentencepiece_size_;
}
if (cached_has_bits & 0x00800000u) {
if (cached_has_bits & 0x00400000u) {
shrinking_factor_ = from.shrinking_factor_;
}
if (cached_has_bits & 0x00800000u) {
num_threads_ = from.num_threads_;
}
_has_bits_[0] |= cached_has_bits;
}
if (cached_has_bits & 0xff000000u) {
if (cached_has_bits & 0x01000000u) {
num_threads_ = from.num_threads_;
}
if (cached_has_bits & 0x02000000u) {
num_sub_iterations_ = from.num_sub_iterations_;
}
if (cached_has_bits & 0x04000000u) {
if (cached_has_bits & 0x02000000u) {
max_sentence_length_ = from.max_sentence_length_;
}
if (cached_has_bits & 0x08000000u) {
if (cached_has_bits & 0x04000000u) {
max_sentencepiece_length_ = from.max_sentencepiece_length_;
}
if (cached_has_bits & 0x10000000u) {
if (cached_has_bits & 0x08000000u) {
shuffle_input_sentence_ = from.shuffle_input_sentence_;
}
if (cached_has_bits & 0x20000000u) {
if (cached_has_bits & 0x10000000u) {
split_by_unicode_script_ = from.split_by_unicode_script_;
}
if (cached_has_bits & 0x40000000u) {
if (cached_has_bits & 0x20000000u) {
split_by_number_ = from.split_by_number_;
}
if (cached_has_bits & 0x80000000u) {
if (cached_has_bits & 0x40000000u) {
split_by_whitespace_ = from.split_by_whitespace_;
}
if (cached_has_bits & 0x80000000u) {
vocabulary_output_piece_score_ = from.vocabulary_output_piece_score_;
}
_has_bits_[0] |= cached_has_bits;
}
cached_has_bits = from._has_bits_[1];
if (cached_has_bits & 0x0000001fu) {
if (cached_has_bits & 0x0000000fu) {
if (cached_has_bits & 0x00000001u) {
vocabulary_output_piece_score_ = from.vocabulary_output_piece_score_;
}
if (cached_has_bits & 0x00000002u) {
hard_vocab_limit_ = from.hard_vocab_limit_;
}
if (cached_has_bits & 0x00000004u) {
if (cached_has_bits & 0x00000002u) {
bos_id_ = from.bos_id_;
}
if (cached_has_bits & 0x00000008u) {
if (cached_has_bits & 0x00000004u) {
eos_id_ = from.eos_id_;
}
if (cached_has_bits & 0x00000010u) {
if (cached_has_bits & 0x00000008u) {
pad_id_ = from.pad_id_;
}
_has_bits_[1] |= cached_has_bits;
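A note on the churn in this generated file: protobuf-generated C++ packs the has-flag of every optional field into the _has_bits_ words, one bit per field. Dropping allow_whitespace_only_pieces frees its bit (8192u, i.e. bit 13 of word 0), so every later field slides down one position: each mask above simply halves (split_digits moves from 16384u to 8192u), and vocabulary_output_piece_score moves from word 1 back into the top bit of word 0. A minimal illustration of the bit arithmetic — not part of the diff; the field-to-bit assignment is inferred from the masks shown above:

#include <cassert>
#include <cstdint>

int main() {
  // With allow_whitespace_only_pieces present, split_digits sat at bit 14;
  // after the revert it sits at bit 13 of _has_bits_[0].
  const std::uint32_t kSplitDigitsBefore = 1u << 14;  // 16384u in the hunk above
  const std::uint32_t kSplitDigitsAfter = 1u << 13;   //  8192u in the hunk above
  assert(kSplitDigitsBefore == 16384u && kSplitDigitsAfter == 8192u);

  // Setting, testing, and clearing a has-bit is plain mask arithmetic:
  std::uint32_t has_bits = 0;
  has_bits |= kSplitDigitsAfter;                // set_has_split_digits
  assert((has_bits & kSplitDigitsAfter) != 0);  // _internal_has_split_digits
  has_bits &= ~kSplitDigitsAfter;               // clear_split_digits
  assert(has_bits == 0);
  return 0;
}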

View File

@ -278,11 +278,10 @@ class TrainerSpec PROTOBUF_FINAL :
kInputSentenceSizeFieldNumber = 11,
kTrainingSentenceSizeFieldNumber = 13,
kTreatWhitespaceAsSuffixFieldNumber = 24,
kAllowWhitespaceOnlyPiecesFieldNumber = 26,
kSplitDigitsFieldNumber = 25,
kByteFallbackFieldNumber = 35,
kUnkIdFieldNumber = 40,
kUseAllVocabFieldNumber = 34,
kUnkIdFieldNumber = 40,
kTrainExtremelyLargeCorpusFieldNumber = 49,
kModelTypeFieldNumber = 3,
kVocabSizeFieldNumber = 4,
@ -624,19 +623,6 @@ class TrainerSpec PROTOBUF_FINAL :
void _internal_set_treat_whitespace_as_suffix(bool value);
public:
// optional bool allow_whitespace_only_pieces = 26 [default = false];
bool has_allow_whitespace_only_pieces() const;
private:
bool _internal_has_allow_whitespace_only_pieces() const;
public:
void clear_allow_whitespace_only_pieces();
bool allow_whitespace_only_pieces() const;
void set_allow_whitespace_only_pieces(bool value);
private:
bool _internal_allow_whitespace_only_pieces() const;
void _internal_set_allow_whitespace_only_pieces(bool value);
public:
// optional bool split_digits = 25 [default = false];
bool has_split_digits() const;
private:
@ -663,19 +649,6 @@ class TrainerSpec PROTOBUF_FINAL :
void _internal_set_byte_fallback(bool value);
public:
// optional int32 unk_id = 40 [default = 0];
bool has_unk_id() const;
private:
bool _internal_has_unk_id() const;
public:
void clear_unk_id();
::PROTOBUF_NAMESPACE_ID::int32 unk_id() const;
void set_unk_id(::PROTOBUF_NAMESPACE_ID::int32 value);
private:
::PROTOBUF_NAMESPACE_ID::int32 _internal_unk_id() const;
void _internal_set_unk_id(::PROTOBUF_NAMESPACE_ID::int32 value);
public:
// optional bool use_all_vocab = 34 [default = false];
bool has_use_all_vocab() const;
private:
@ -689,6 +662,19 @@ class TrainerSpec PROTOBUF_FINAL :
void _internal_set_use_all_vocab(bool value);
public:
// optional int32 unk_id = 40 [default = 0];
bool has_unk_id() const;
private:
bool _internal_has_unk_id() const;
public:
void clear_unk_id();
::PROTOBUF_NAMESPACE_ID::int32 unk_id() const;
void set_unk_id(::PROTOBUF_NAMESPACE_ID::int32 value);
private:
::PROTOBUF_NAMESPACE_ID::int32 _internal_unk_id() const;
void _internal_set_unk_id(::PROTOBUF_NAMESPACE_ID::int32 value);
public:
// optional bool train_extremely_large_corpus = 49 [default = false];
bool has_train_extremely_large_corpus() const;
private:
@ -970,11 +956,10 @@ class TrainerSpec PROTOBUF_FINAL :
::PROTOBUF_NAMESPACE_ID::uint64 input_sentence_size_;
::PROTOBUF_NAMESPACE_ID::int32 training_sentence_size_;
bool treat_whitespace_as_suffix_;
bool allow_whitespace_only_pieces_;
bool split_digits_;
bool byte_fallback_;
::PROTOBUF_NAMESPACE_ID::int32 unk_id_;
bool use_all_vocab_;
::PROTOBUF_NAMESPACE_ID::int32 unk_id_;
bool train_extremely_large_corpus_;
int model_type_;
::PROTOBUF_NAMESPACE_ID::int32 vocab_size_;
@ -2195,7 +2180,7 @@ inline void TrainerSpec::set_allocated_model_prefix(std::string* model_prefix) {
// optional .sentencepiece.TrainerSpec.ModelType model_type = 3 [default = UNIGRAM];
inline bool TrainerSpec::_internal_has_model_type() const {
bool value = (_has_bits_[0] & 0x00080000u) != 0;
bool value = (_has_bits_[0] & 0x00040000u) != 0;
return value;
}
inline bool TrainerSpec::has_model_type() const {
@ -2203,7 +2188,7 @@ inline bool TrainerSpec::has_model_type() const {
}
inline void TrainerSpec::clear_model_type() {
model_type_ = 1;
_has_bits_[0] &= ~0x00080000u;
_has_bits_[0] &= ~0x00040000u;
}
inline ::sentencepiece::TrainerSpec_ModelType TrainerSpec::_internal_model_type() const {
return static_cast< ::sentencepiece::TrainerSpec_ModelType >(model_type_);
@ -2214,7 +2199,7 @@ inline ::sentencepiece::TrainerSpec_ModelType TrainerSpec::model_type() const {
}
inline void TrainerSpec::_internal_set_model_type(::sentencepiece::TrainerSpec_ModelType value) {
assert(::sentencepiece::TrainerSpec_ModelType_IsValid(value));
_has_bits_[0] |= 0x00080000u;
_has_bits_[0] |= 0x00040000u;
model_type_ = value;
}
inline void TrainerSpec::set_model_type(::sentencepiece::TrainerSpec_ModelType value) {
@ -2224,7 +2209,7 @@ inline void TrainerSpec::set_model_type(::sentencepiece::TrainerSpec_ModelType v
// optional int32 vocab_size = 4 [default = 8000];
inline bool TrainerSpec::_internal_has_vocab_size() const {
bool value = (_has_bits_[0] & 0x00100000u) != 0;
bool value = (_has_bits_[0] & 0x00080000u) != 0;
return value;
}
inline bool TrainerSpec::has_vocab_size() const {
@ -2232,7 +2217,7 @@ inline bool TrainerSpec::has_vocab_size() const {
}
inline void TrainerSpec::clear_vocab_size() {
vocab_size_ = 8000;
_has_bits_[0] &= ~0x00100000u;
_has_bits_[0] &= ~0x00080000u;
}
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_vocab_size() const {
return vocab_size_;
@ -2242,7 +2227,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::vocab_size() const {
return _internal_vocab_size();
}
inline void TrainerSpec::_internal_set_vocab_size(::PROTOBUF_NAMESPACE_ID::int32 value) {
_has_bits_[0] |= 0x00100000u;
_has_bits_[0] |= 0x00080000u;
vocab_size_ = value;
}
inline void TrainerSpec::set_vocab_size(::PROTOBUF_NAMESPACE_ID::int32 value) {
@ -2354,7 +2339,7 @@ inline void TrainerSpec::set_self_test_sample_size(::PROTOBUF_NAMESPACE_ID::int3
// optional float character_coverage = 10 [default = 0.9995];
inline bool TrainerSpec::_internal_has_character_coverage() const {
bool value = (_has_bits_[0] & 0x00200000u) != 0;
bool value = (_has_bits_[0] & 0x00100000u) != 0;
return value;
}
inline bool TrainerSpec::has_character_coverage() const {
@ -2362,7 +2347,7 @@ inline bool TrainerSpec::has_character_coverage() const {
}
inline void TrainerSpec::clear_character_coverage() {
character_coverage_ = 0.9995f;
_has_bits_[0] &= ~0x00200000u;
_has_bits_[0] &= ~0x00100000u;
}
inline float TrainerSpec::_internal_character_coverage() const {
return character_coverage_;
@ -2372,7 +2357,7 @@ inline float TrainerSpec::character_coverage() const {
return _internal_character_coverage();
}
inline void TrainerSpec::_internal_set_character_coverage(float value) {
_has_bits_[0] |= 0x00200000u;
_has_bits_[0] |= 0x00100000u;
character_coverage_ = value;
}
inline void TrainerSpec::set_character_coverage(float value) {
@ -2410,7 +2395,7 @@ inline void TrainerSpec::set_input_sentence_size(::PROTOBUF_NAMESPACE_ID::uint64
// optional bool shuffle_input_sentence = 19 [default = true];
inline bool TrainerSpec::_internal_has_shuffle_input_sentence() const {
bool value = (_has_bits_[0] & 0x10000000u) != 0;
bool value = (_has_bits_[0] & 0x08000000u) != 0;
return value;
}
inline bool TrainerSpec::has_shuffle_input_sentence() const {
@ -2418,7 +2403,7 @@ inline bool TrainerSpec::has_shuffle_input_sentence() const {
}
inline void TrainerSpec::clear_shuffle_input_sentence() {
shuffle_input_sentence_ = true;
_has_bits_[0] &= ~0x10000000u;
_has_bits_[0] &= ~0x08000000u;
}
inline bool TrainerSpec::_internal_shuffle_input_sentence() const {
return shuffle_input_sentence_;
@ -2428,7 +2413,7 @@ inline bool TrainerSpec::shuffle_input_sentence() const {
return _internal_shuffle_input_sentence();
}
inline void TrainerSpec::_internal_set_shuffle_input_sentence(bool value) {
_has_bits_[0] |= 0x10000000u;
_has_bits_[0] |= 0x08000000u;
shuffle_input_sentence_ = value;
}
inline void TrainerSpec::set_shuffle_input_sentence(bool value) {
@ -2494,7 +2479,7 @@ inline void TrainerSpec::set_training_sentence_size(::PROTOBUF_NAMESPACE_ID::int
// optional int32 seed_sentencepiece_size = 14 [default = 1000000];
inline bool TrainerSpec::_internal_has_seed_sentencepiece_size() const {
bool value = (_has_bits_[0] & 0x00400000u) != 0;
bool value = (_has_bits_[0] & 0x00200000u) != 0;
return value;
}
inline bool TrainerSpec::has_seed_sentencepiece_size() const {
@ -2502,7 +2487,7 @@ inline bool TrainerSpec::has_seed_sentencepiece_size() const {
}
inline void TrainerSpec::clear_seed_sentencepiece_size() {
seed_sentencepiece_size_ = 1000000;
_has_bits_[0] &= ~0x00400000u;
_has_bits_[0] &= ~0x00200000u;
}
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_seed_sentencepiece_size() const {
return seed_sentencepiece_size_;
@ -2512,7 +2497,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::seed_sentencepiece_size() con
return _internal_seed_sentencepiece_size();
}
inline void TrainerSpec::_internal_set_seed_sentencepiece_size(::PROTOBUF_NAMESPACE_ID::int32 value) {
_has_bits_[0] |= 0x00400000u;
_has_bits_[0] |= 0x00200000u;
seed_sentencepiece_size_ = value;
}
inline void TrainerSpec::set_seed_sentencepiece_size(::PROTOBUF_NAMESPACE_ID::int32 value) {
@ -2522,7 +2507,7 @@ inline void TrainerSpec::set_seed_sentencepiece_size(::PROTOBUF_NAMESPACE_ID::in
// optional float shrinking_factor = 15 [default = 0.75];
inline bool TrainerSpec::_internal_has_shrinking_factor() const {
bool value = (_has_bits_[0] & 0x00800000u) != 0;
bool value = (_has_bits_[0] & 0x00400000u) != 0;
return value;
}
inline bool TrainerSpec::has_shrinking_factor() const {
@ -2530,7 +2515,7 @@ inline bool TrainerSpec::has_shrinking_factor() const {
}
inline void TrainerSpec::clear_shrinking_factor() {
shrinking_factor_ = 0.75f;
_has_bits_[0] &= ~0x00800000u;
_has_bits_[0] &= ~0x00400000u;
}
inline float TrainerSpec::_internal_shrinking_factor() const {
return shrinking_factor_;
@ -2540,7 +2525,7 @@ inline float TrainerSpec::shrinking_factor() const {
return _internal_shrinking_factor();
}
inline void TrainerSpec::_internal_set_shrinking_factor(float value) {
_has_bits_[0] |= 0x00800000u;
_has_bits_[0] |= 0x00400000u;
shrinking_factor_ = value;
}
inline void TrainerSpec::set_shrinking_factor(float value) {
@ -2550,7 +2535,7 @@ inline void TrainerSpec::set_shrinking_factor(float value) {
// optional int32 max_sentence_length = 18 [default = 4192];
inline bool TrainerSpec::_internal_has_max_sentence_length() const {
bool value = (_has_bits_[0] & 0x04000000u) != 0;
bool value = (_has_bits_[0] & 0x02000000u) != 0;
return value;
}
inline bool TrainerSpec::has_max_sentence_length() const {
@ -2558,7 +2543,7 @@ inline bool TrainerSpec::has_max_sentence_length() const {
}
inline void TrainerSpec::clear_max_sentence_length() {
max_sentence_length_ = 4192;
_has_bits_[0] &= ~0x04000000u;
_has_bits_[0] &= ~0x02000000u;
}
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_max_sentence_length() const {
return max_sentence_length_;
@ -2568,7 +2553,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::max_sentence_length() const {
return _internal_max_sentence_length();
}
inline void TrainerSpec::_internal_set_max_sentence_length(::PROTOBUF_NAMESPACE_ID::int32 value) {
_has_bits_[0] |= 0x04000000u;
_has_bits_[0] |= 0x02000000u;
max_sentence_length_ = value;
}
inline void TrainerSpec::set_max_sentence_length(::PROTOBUF_NAMESPACE_ID::int32 value) {
@ -2578,7 +2563,7 @@ inline void TrainerSpec::set_max_sentence_length(::PROTOBUF_NAMESPACE_ID::int32
// optional int32 num_threads = 16 [default = 16];
inline bool TrainerSpec::_internal_has_num_threads() const {
bool value = (_has_bits_[0] & 0x01000000u) != 0;
bool value = (_has_bits_[0] & 0x00800000u) != 0;
return value;
}
inline bool TrainerSpec::has_num_threads() const {
@ -2586,7 +2571,7 @@ inline bool TrainerSpec::has_num_threads() const {
}
inline void TrainerSpec::clear_num_threads() {
num_threads_ = 16;
_has_bits_[0] &= ~0x01000000u;
_has_bits_[0] &= ~0x00800000u;
}
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_num_threads() const {
return num_threads_;
@ -2596,7 +2581,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::num_threads() const {
return _internal_num_threads();
}
inline void TrainerSpec::_internal_set_num_threads(::PROTOBUF_NAMESPACE_ID::int32 value) {
_has_bits_[0] |= 0x01000000u;
_has_bits_[0] |= 0x00800000u;
num_threads_ = value;
}
inline void TrainerSpec::set_num_threads(::PROTOBUF_NAMESPACE_ID::int32 value) {
@ -2606,7 +2591,7 @@ inline void TrainerSpec::set_num_threads(::PROTOBUF_NAMESPACE_ID::int32 value) {
// optional int32 num_sub_iterations = 17 [default = 2];
inline bool TrainerSpec::_internal_has_num_sub_iterations() const {
bool value = (_has_bits_[0] & 0x02000000u) != 0;
bool value = (_has_bits_[0] & 0x01000000u) != 0;
return value;
}
inline bool TrainerSpec::has_num_sub_iterations() const {
@ -2614,7 +2599,7 @@ inline bool TrainerSpec::has_num_sub_iterations() const {
}
inline void TrainerSpec::clear_num_sub_iterations() {
num_sub_iterations_ = 2;
_has_bits_[0] &= ~0x02000000u;
_has_bits_[0] &= ~0x01000000u;
}
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_num_sub_iterations() const {
return num_sub_iterations_;
@ -2624,7 +2609,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::num_sub_iterations() const {
return _internal_num_sub_iterations();
}
inline void TrainerSpec::_internal_set_num_sub_iterations(::PROTOBUF_NAMESPACE_ID::int32 value) {
_has_bits_[0] |= 0x02000000u;
_has_bits_[0] |= 0x01000000u;
num_sub_iterations_ = value;
}
inline void TrainerSpec::set_num_sub_iterations(::PROTOBUF_NAMESPACE_ID::int32 value) {
@ -2634,7 +2619,7 @@ inline void TrainerSpec::set_num_sub_iterations(::PROTOBUF_NAMESPACE_ID::int32 v
// optional int32 max_sentencepiece_length = 20 [default = 16];
inline bool TrainerSpec::_internal_has_max_sentencepiece_length() const {
bool value = (_has_bits_[0] & 0x08000000u) != 0;
bool value = (_has_bits_[0] & 0x04000000u) != 0;
return value;
}
inline bool TrainerSpec::has_max_sentencepiece_length() const {
@ -2642,7 +2627,7 @@ inline bool TrainerSpec::has_max_sentencepiece_length() const {
}
inline void TrainerSpec::clear_max_sentencepiece_length() {
max_sentencepiece_length_ = 16;
_has_bits_[0] &= ~0x08000000u;
_has_bits_[0] &= ~0x04000000u;
}
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_max_sentencepiece_length() const {
return max_sentencepiece_length_;
@ -2652,7 +2637,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::max_sentencepiece_length() co
return _internal_max_sentencepiece_length();
}
inline void TrainerSpec::_internal_set_max_sentencepiece_length(::PROTOBUF_NAMESPACE_ID::int32 value) {
_has_bits_[0] |= 0x08000000u;
_has_bits_[0] |= 0x04000000u;
max_sentencepiece_length_ = value;
}
inline void TrainerSpec::set_max_sentencepiece_length(::PROTOBUF_NAMESPACE_ID::int32 value) {
@ -2662,7 +2647,7 @@ inline void TrainerSpec::set_max_sentencepiece_length(::PROTOBUF_NAMESPACE_ID::i
// optional bool split_by_unicode_script = 21 [default = true];
inline bool TrainerSpec::_internal_has_split_by_unicode_script() const {
bool value = (_has_bits_[0] & 0x20000000u) != 0;
bool value = (_has_bits_[0] & 0x10000000u) != 0;
return value;
}
inline bool TrainerSpec::has_split_by_unicode_script() const {
@ -2670,7 +2655,7 @@ inline bool TrainerSpec::has_split_by_unicode_script() const {
}
inline void TrainerSpec::clear_split_by_unicode_script() {
split_by_unicode_script_ = true;
_has_bits_[0] &= ~0x20000000u;
_has_bits_[0] &= ~0x10000000u;
}
inline bool TrainerSpec::_internal_split_by_unicode_script() const {
return split_by_unicode_script_;
@ -2680,7 +2665,7 @@ inline bool TrainerSpec::split_by_unicode_script() const {
return _internal_split_by_unicode_script();
}
inline void TrainerSpec::_internal_set_split_by_unicode_script(bool value) {
_has_bits_[0] |= 0x20000000u;
_has_bits_[0] |= 0x10000000u;
split_by_unicode_script_ = value;
}
inline void TrainerSpec::set_split_by_unicode_script(bool value) {
@ -2690,7 +2675,7 @@ inline void TrainerSpec::set_split_by_unicode_script(bool value) {
// optional bool split_by_number = 23 [default = true];
inline bool TrainerSpec::_internal_has_split_by_number() const {
bool value = (_has_bits_[0] & 0x40000000u) != 0;
bool value = (_has_bits_[0] & 0x20000000u) != 0;
return value;
}
inline bool TrainerSpec::has_split_by_number() const {
@ -2698,7 +2683,7 @@ inline bool TrainerSpec::has_split_by_number() const {
}
inline void TrainerSpec::clear_split_by_number() {
split_by_number_ = true;
_has_bits_[0] &= ~0x40000000u;
_has_bits_[0] &= ~0x20000000u;
}
inline bool TrainerSpec::_internal_split_by_number() const {
return split_by_number_;
@ -2708,7 +2693,7 @@ inline bool TrainerSpec::split_by_number() const {
return _internal_split_by_number();
}
inline void TrainerSpec::_internal_set_split_by_number(bool value) {
_has_bits_[0] |= 0x40000000u;
_has_bits_[0] |= 0x20000000u;
split_by_number_ = value;
}
inline void TrainerSpec::set_split_by_number(bool value) {
@ -2718,7 +2703,7 @@ inline void TrainerSpec::set_split_by_number(bool value) {
// optional bool split_by_whitespace = 22 [default = true];
inline bool TrainerSpec::_internal_has_split_by_whitespace() const {
bool value = (_has_bits_[0] & 0x80000000u) != 0;
bool value = (_has_bits_[0] & 0x40000000u) != 0;
return value;
}
inline bool TrainerSpec::has_split_by_whitespace() const {
@ -2726,7 +2711,7 @@ inline bool TrainerSpec::has_split_by_whitespace() const {
}
inline void TrainerSpec::clear_split_by_whitespace() {
split_by_whitespace_ = true;
_has_bits_[0] &= ~0x80000000u;
_has_bits_[0] &= ~0x40000000u;
}
inline bool TrainerSpec::_internal_split_by_whitespace() const {
return split_by_whitespace_;
@ -2736,7 +2721,7 @@ inline bool TrainerSpec::split_by_whitespace() const {
return _internal_split_by_whitespace();
}
inline void TrainerSpec::_internal_set_split_by_whitespace(bool value) {
_has_bits_[0] |= 0x80000000u;
_has_bits_[0] |= 0x40000000u;
split_by_whitespace_ = value;
}
inline void TrainerSpec::set_split_by_whitespace(bool value) {
@ -2772,37 +2757,9 @@ inline void TrainerSpec::set_treat_whitespace_as_suffix(bool value) {
// @@protoc_insertion_point(field_set:sentencepiece.TrainerSpec.treat_whitespace_as_suffix)
}
// optional bool allow_whitespace_only_pieces = 26 [default = false];
inline bool TrainerSpec::_internal_has_allow_whitespace_only_pieces() const {
bool value = (_has_bits_[0] & 0x00002000u) != 0;
return value;
}
inline bool TrainerSpec::has_allow_whitespace_only_pieces() const {
return _internal_has_allow_whitespace_only_pieces();
}
inline void TrainerSpec::clear_allow_whitespace_only_pieces() {
allow_whitespace_only_pieces_ = false;
_has_bits_[0] &= ~0x00002000u;
}
inline bool TrainerSpec::_internal_allow_whitespace_only_pieces() const {
return allow_whitespace_only_pieces_;
}
inline bool TrainerSpec::allow_whitespace_only_pieces() const {
// @@protoc_insertion_point(field_get:sentencepiece.TrainerSpec.allow_whitespace_only_pieces)
return _internal_allow_whitespace_only_pieces();
}
inline void TrainerSpec::_internal_set_allow_whitespace_only_pieces(bool value) {
_has_bits_[0] |= 0x00002000u;
allow_whitespace_only_pieces_ = value;
}
inline void TrainerSpec::set_allow_whitespace_only_pieces(bool value) {
_internal_set_allow_whitespace_only_pieces(value);
// @@protoc_insertion_point(field_set:sentencepiece.TrainerSpec.allow_whitespace_only_pieces)
}
// optional bool split_digits = 25 [default = false];
inline bool TrainerSpec::_internal_has_split_digits() const {
bool value = (_has_bits_[0] & 0x00004000u) != 0;
bool value = (_has_bits_[0] & 0x00002000u) != 0;
return value;
}
inline bool TrainerSpec::has_split_digits() const {
@ -2810,7 +2767,7 @@ inline bool TrainerSpec::has_split_digits() const {
}
inline void TrainerSpec::clear_split_digits() {
split_digits_ = false;
_has_bits_[0] &= ~0x00004000u;
_has_bits_[0] &= ~0x00002000u;
}
inline bool TrainerSpec::_internal_split_digits() const {
return split_digits_;
@ -2820,7 +2777,7 @@ inline bool TrainerSpec::split_digits() const {
return _internal_split_digits();
}
inline void TrainerSpec::_internal_set_split_digits(bool value) {
_has_bits_[0] |= 0x00004000u;
_has_bits_[0] |= 0x00002000u;
split_digits_ = value;
}
inline void TrainerSpec::set_split_digits(bool value) {
@ -3051,7 +3008,7 @@ inline void TrainerSpec::set_allocated_required_chars(std::string* required_char
// optional bool byte_fallback = 35 [default = false];
inline bool TrainerSpec::_internal_has_byte_fallback() const {
bool value = (_has_bits_[0] & 0x00008000u) != 0;
bool value = (_has_bits_[0] & 0x00004000u) != 0;
return value;
}
inline bool TrainerSpec::has_byte_fallback() const {
@ -3059,7 +3016,7 @@ inline bool TrainerSpec::has_byte_fallback() const {
}
inline void TrainerSpec::clear_byte_fallback() {
byte_fallback_ = false;
_has_bits_[0] &= ~0x00008000u;
_has_bits_[0] &= ~0x00004000u;
}
inline bool TrainerSpec::_internal_byte_fallback() const {
return byte_fallback_;
@ -3069,7 +3026,7 @@ inline bool TrainerSpec::byte_fallback() const {
return _internal_byte_fallback();
}
inline void TrainerSpec::_internal_set_byte_fallback(bool value) {
_has_bits_[0] |= 0x00008000u;
_has_bits_[0] |= 0x00004000u;
byte_fallback_ = value;
}
inline void TrainerSpec::set_byte_fallback(bool value) {
@ -3079,7 +3036,7 @@ inline void TrainerSpec::set_byte_fallback(bool value) {
// optional bool vocabulary_output_piece_score = 32 [default = true];
inline bool TrainerSpec::_internal_has_vocabulary_output_piece_score() const {
bool value = (_has_bits_[1] & 0x00000001u) != 0;
bool value = (_has_bits_[0] & 0x80000000u) != 0;
return value;
}
inline bool TrainerSpec::has_vocabulary_output_piece_score() const {
@ -3087,7 +3044,7 @@ inline bool TrainerSpec::has_vocabulary_output_piece_score() const {
}
inline void TrainerSpec::clear_vocabulary_output_piece_score() {
vocabulary_output_piece_score_ = true;
_has_bits_[1] &= ~0x00000001u;
_has_bits_[0] &= ~0x80000000u;
}
inline bool TrainerSpec::_internal_vocabulary_output_piece_score() const {
return vocabulary_output_piece_score_;
@ -3097,7 +3054,7 @@ inline bool TrainerSpec::vocabulary_output_piece_score() const {
return _internal_vocabulary_output_piece_score();
}
inline void TrainerSpec::_internal_set_vocabulary_output_piece_score(bool value) {
_has_bits_[1] |= 0x00000001u;
_has_bits_[0] |= 0x80000000u;
vocabulary_output_piece_score_ = value;
}
inline void TrainerSpec::set_vocabulary_output_piece_score(bool value) {
@ -3107,7 +3064,7 @@ inline void TrainerSpec::set_vocabulary_output_piece_score(bool value) {
// optional bool hard_vocab_limit = 33 [default = true];
inline bool TrainerSpec::_internal_has_hard_vocab_limit() const {
bool value = (_has_bits_[1] & 0x00000002u) != 0;
bool value = (_has_bits_[1] & 0x00000001u) != 0;
return value;
}
inline bool TrainerSpec::has_hard_vocab_limit() const {
@ -3115,7 +3072,7 @@ inline bool TrainerSpec::has_hard_vocab_limit() const {
}
inline void TrainerSpec::clear_hard_vocab_limit() {
hard_vocab_limit_ = true;
_has_bits_[1] &= ~0x00000002u;
_has_bits_[1] &= ~0x00000001u;
}
inline bool TrainerSpec::_internal_hard_vocab_limit() const {
return hard_vocab_limit_;
@ -3125,7 +3082,7 @@ inline bool TrainerSpec::hard_vocab_limit() const {
return _internal_hard_vocab_limit();
}
inline void TrainerSpec::_internal_set_hard_vocab_limit(bool value) {
_has_bits_[1] |= 0x00000002u;
_has_bits_[1] |= 0x00000001u;
hard_vocab_limit_ = value;
}
inline void TrainerSpec::set_hard_vocab_limit(bool value) {
@ -3135,7 +3092,7 @@ inline void TrainerSpec::set_hard_vocab_limit(bool value) {
// optional bool use_all_vocab = 34 [default = false];
inline bool TrainerSpec::_internal_has_use_all_vocab() const {
bool value = (_has_bits_[0] & 0x00020000u) != 0;
bool value = (_has_bits_[0] & 0x00008000u) != 0;
return value;
}
inline bool TrainerSpec::has_use_all_vocab() const {
@ -3143,7 +3100,7 @@ inline bool TrainerSpec::has_use_all_vocab() const {
}
inline void TrainerSpec::clear_use_all_vocab() {
use_all_vocab_ = false;
_has_bits_[0] &= ~0x00020000u;
_has_bits_[0] &= ~0x00008000u;
}
inline bool TrainerSpec::_internal_use_all_vocab() const {
return use_all_vocab_;
@ -3153,7 +3110,7 @@ inline bool TrainerSpec::use_all_vocab() const {
return _internal_use_all_vocab();
}
inline void TrainerSpec::_internal_set_use_all_vocab(bool value) {
_has_bits_[0] |= 0x00020000u;
_has_bits_[0] |= 0x00008000u;
use_all_vocab_ = value;
}
inline void TrainerSpec::set_use_all_vocab(bool value) {
@ -3191,7 +3148,7 @@ inline void TrainerSpec::set_unk_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
// optional int32 bos_id = 41 [default = 1];
inline bool TrainerSpec::_internal_has_bos_id() const {
bool value = (_has_bits_[1] & 0x00000004u) != 0;
bool value = (_has_bits_[1] & 0x00000002u) != 0;
return value;
}
inline bool TrainerSpec::has_bos_id() const {
@ -3199,7 +3156,7 @@ inline bool TrainerSpec::has_bos_id() const {
}
inline void TrainerSpec::clear_bos_id() {
bos_id_ = 1;
_has_bits_[1] &= ~0x00000004u;
_has_bits_[1] &= ~0x00000002u;
}
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_bos_id() const {
return bos_id_;
@ -3209,7 +3166,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::bos_id() const {
return _internal_bos_id();
}
inline void TrainerSpec::_internal_set_bos_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
_has_bits_[1] |= 0x00000004u;
_has_bits_[1] |= 0x00000002u;
bos_id_ = value;
}
inline void TrainerSpec::set_bos_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
@ -3219,7 +3176,7 @@ inline void TrainerSpec::set_bos_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
// optional int32 eos_id = 42 [default = 2];
inline bool TrainerSpec::_internal_has_eos_id() const {
bool value = (_has_bits_[1] & 0x00000008u) != 0;
bool value = (_has_bits_[1] & 0x00000004u) != 0;
return value;
}
inline bool TrainerSpec::has_eos_id() const {
@ -3227,7 +3184,7 @@ inline bool TrainerSpec::has_eos_id() const {
}
inline void TrainerSpec::clear_eos_id() {
eos_id_ = 2;
_has_bits_[1] &= ~0x00000008u;
_has_bits_[1] &= ~0x00000004u;
}
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_eos_id() const {
return eos_id_;
@ -3237,7 +3194,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::eos_id() const {
return _internal_eos_id();
}
inline void TrainerSpec::_internal_set_eos_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
_has_bits_[1] |= 0x00000008u;
_has_bits_[1] |= 0x00000004u;
eos_id_ = value;
}
inline void TrainerSpec::set_eos_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
@ -3247,7 +3204,7 @@ inline void TrainerSpec::set_eos_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
// optional int32 pad_id = 43 [default = -1];
inline bool TrainerSpec::_internal_has_pad_id() const {
bool value = (_has_bits_[1] & 0x00000010u) != 0;
bool value = (_has_bits_[1] & 0x00000008u) != 0;
return value;
}
inline bool TrainerSpec::has_pad_id() const {
@ -3255,7 +3212,7 @@ inline bool TrainerSpec::has_pad_id() const {
}
inline void TrainerSpec::clear_pad_id() {
pad_id_ = -1;
_has_bits_[1] &= ~0x00000010u;
_has_bits_[1] &= ~0x00000008u;
}
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_pad_id() const {
return pad_id_;
@ -3265,7 +3222,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::pad_id() const {
return _internal_pad_id();
}
inline void TrainerSpec::_internal_set_pad_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
_has_bits_[1] |= 0x00000010u;
_has_bits_[1] |= 0x00000008u;
pad_id_ = value;
}
inline void TrainerSpec::set_pad_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
@ -3645,7 +3602,7 @@ inline void TrainerSpec::set_allocated_unk_surface(std::string* unk_surface) {
// optional bool train_extremely_large_corpus = 49 [default = false];
inline bool TrainerSpec::_internal_has_train_extremely_large_corpus() const {
bool value = (_has_bits_[0] & 0x00040000u) != 0;
bool value = (_has_bits_[0] & 0x00020000u) != 0;
return value;
}
inline bool TrainerSpec::has_train_extremely_large_corpus() const {
@ -3653,7 +3610,7 @@ inline bool TrainerSpec::has_train_extremely_large_corpus() const {
}
inline void TrainerSpec::clear_train_extremely_large_corpus() {
train_extremely_large_corpus_ = false;
_has_bits_[0] &= ~0x00040000u;
_has_bits_[0] &= ~0x00020000u;
}
inline bool TrainerSpec::_internal_train_extremely_large_corpus() const {
return train_extremely_large_corpus_;
@ -3663,7 +3620,7 @@ inline bool TrainerSpec::train_extremely_large_corpus() const {
return _internal_train_extremely_large_corpus();
}
inline void TrainerSpec::_internal_set_train_extremely_large_corpus(bool value) {
_has_bits_[0] |= 0x00040000u;
_has_bits_[0] |= 0x00020000u;
train_extremely_large_corpus_ = value;
}
inline void TrainerSpec::set_train_extremely_large_corpus(bool value) {

View File

@ -134,53 +134,32 @@ void ModelInterface::InitializePieces() {
}
std::vector<absl::string_view> SplitIntoWords(absl::string_view text,
bool treat_ws_as_suffix,
bool allow_ws_only_pieces) {
bool treat_whitespace_as_suffix) {
const char *begin = text.data();
const char *end = text.data() + text.size();
// Space symbol (U+2581)
const absl::string_view kSpaceSymbol = "\xe2\x96\x81";
bool in_ws_sequence = false;
std::vector<absl::string_view> result;
if (treat_ws_as_suffix) { // put ws tokens at the end of non-ws sequences.
if (treat_whitespace_as_suffix) {
if (begin < end) result.emplace_back(begin, 0);
while (begin < end) {
const int mblen =
std::min<int>(string_util::OneCharLen(begin), end - begin);
const bool is_ws = absl::string_view(begin, mblen) == kSpaceSymbol;
if (is_ws) { // keep track of sequences consecutive ws tokens.
in_ws_sequence = true;
} else if (in_ws_sequence) {
if (allow_ws_only_pieces) result.emplace_back(begin, 0);
in_ws_sequence = false;
}
result.back() =
absl::string_view(result.back().data(), result.back().size() + mblen);
begin += mblen;
if (begin < end && is_ws && !allow_ws_only_pieces)
result.emplace_back(begin, 0);
if (begin < end && is_ws) result.emplace_back(begin, 0);
}
} else {
while (begin < end) {
const int mblen =
std::min<int>(string_util::OneCharLen(begin), end - begin);
bool is_ws = absl::string_view(begin, mblen) == kSpaceSymbol;
// if is whitespace (and not in sequence if allow_ws_only_pieces is True)
if (begin == text.data() ||
(is_ws && (!in_ws_sequence || !allow_ws_only_pieces))) {
absl::string_view(begin, mblen) == kSpaceSymbol)
result.emplace_back(begin, 0); // add empty string piece.
in_ws_sequence = true;
}
if (in_ws_sequence && !is_ws) in_ws_sequence = false;
result.back() =
absl::string_view(result.back().data(), result.back().size() + mblen);
begin += mblen;
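For reference, a minimal usage sketch of the restored two-argument SplitIntoWords (illustrative only; it assumes this tree's model_interface.h and the UTF-8 encoding of the U+2581 space symbol, with the expected outputs following the header comment and the suffix-mode test in model_interface_test.cc):

#include <iostream>
#include <string>
#include <vector>

#include "model_interface.h"

int main() {
  const std::string kWS = "\xe2\x96\x81";  // "▁" (U+2581) in UTF-8
  const std::string prefixed = kWS + "this" + kWS + "is" + kWS + "a" + kWS + "pen";
  const std::string suffixed = "this" + kWS + "is" + kWS + "a" + kWS + "pen" + kWS;

  // Default (prefix) mode: ["▁this", "▁is", "▁a", "▁pen"].
  for (const auto &w : sentencepiece::SplitIntoWords(prefixed))
    std::cout << std::string(w.data(), w.size()) << "\n";

  // Suffix mode: whitespace attaches to the end of the preceding word,
  // giving ["this▁", "is▁", "a▁", "pen▁"].
  for (const auto &w : sentencepiece::SplitIntoWords(suffixed, true))
    std::cout << std::string(w.data(), w.size()) << "\n";
  return 0;
}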

View File

@ -33,9 +33,8 @@
namespace sentencepiece {
// "_this_is_a_pen" => ["_this", "_is", "_a", "_pen"]
std::vector<absl::string_view> SplitIntoWords(
absl::string_view text, bool treat_ws_as_suffix = false,
bool allow_ws_only_pieces = false);
std::vector<absl::string_view> SplitIntoWords(absl::string_view text,
bool add_ws_as_suffix = false);
// Converts byte (0-255) to piece (e.g., 58 -> "<0x3A>").
std::string ByteToPiece(unsigned char c);
@ -107,42 +106,12 @@ class ModelInterface {
return EncodeResult();
}
// Sample `samples` many tokenisations from the segmentation lattice
// If `wor` is true, the samples are taken without replacement, and the scores
// are the inclusion probabilities of the elements in the sample; otherwise
// the samples are taken with replacement and the scores are the log-probs of
// sample elements
// If `include_best` is true, the best tokenisation is always included in the
// sample, and the remaining elements are sampled excluding the best.
virtual NBestEncodeResult SampleEncodeAndScore(absl::string_view normalized,
float alpha, int samples,
bool wor,
bool include_best) const {
LOG(ERROR) << "Not implemented.";
return {{EncodeResult(), 0.0}};
}
// Calculates the entropy of the segmentation lattice with inverse temperature
// `theta`.
// Uses a novel dynamic program to calculate the entropy.
virtual float CalculateEntropy(absl::string_view normalized,
float theta) const {
LOG(ERROR) << "Not implemented.";
return 0.0;
}
// Return true if SampleEncode returns a valid result.
virtual bool IsSampleEncodeAvailable() const { return false; }
// Return true if NBestEncode returns a valid result.
virtual bool IsNBestEncodeAvailable() const { return false; }
// Return true if SampleEncodeAndScore returns a valid result.
virtual bool IsSampleEncodeAndScoreAvailable() const { return false; }
// Return true if CalculateEntropy returns a valid result.
virtual bool IsCalculateEntropyAvailable() const { return false; }
// Returns the vocab id of `piece`.
// Returns UNK(0) if `piece` is unknown
virtual int PieceToId(absl::string_view piece) const;
@ -155,10 +124,7 @@ class ModelInterface {
// Returns the size of sentence pieces, which is the same
// as the size of vocabulary for NMT.
virtual int GetPieceSize() const {
if (!model_proto_) return 0;
return model_proto_->pieces_size();
}
virtual int GetPieceSize() const { return model_proto_->pieces_size(); }
// Returns the score of `id`.
// Score represents a log probability of the piece.
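The lattice-sampling and entropy methods above are what this revert removes from the model interface; the plain and sampled encoding paths exposed through SentencePieceProcessor remain. A rough usage sketch of those surviving entry points (it assumes an already-trained model file, here called m.model, which is not part of this diff):

#include <iostream>
#include <string>
#include <vector>

#include "sentencepiece.pb.h"
#include "sentencepiece_processor.h"

int main() {
  sentencepiece::SentencePieceProcessor sp;
  if (!sp.Load("m.model").ok()) return 1;  // hypothetical model path

  // Deterministic (best-path) segmentation.
  std::vector<std::string> pieces;
  if (!sp.Encode("this is a pen.", &pieces).ok()) return 1;
  for (const auto &p : pieces) std::cout << p << "\n";

  // Subword regularization: sample one segmentation from the 64-best
  // candidates with smoothing parameter alpha.
  sentencepiece::SentencePieceText spt;
  if (!sp.SampleEncode("this is a pen.", /*nbest_size=*/64, /*alpha=*/0.1f, &spt).ok())
    return 1;
  for (const auto &p : spt.pieces()) std::cout << p.piece() << "\n";
  return 0;
}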

View File

@ -412,50 +412,6 @@ TEST(ModelInterfaceTest, SplitIntoWordsSuffixTest) {
}
}
TEST(ModelInterfaceTest, SplitIntoWordsWhiteSpaceOnly) {
{
const auto v =
SplitIntoWords("this" WS "is" WS "a" WS "pen" WS, true, true);
EXPECT_EQ(4, v.size());
EXPECT_EQ("this" WS, v[0]);
EXPECT_EQ("is" WS, v[1]);
EXPECT_EQ("a" WS, v[2]);
EXPECT_EQ("pen" WS, v[3]);
}
{
const auto v = SplitIntoWords(WS WS WS "a", false, true);
EXPECT_EQ(1, v.size());
EXPECT_EQ(WS WS WS "a", v[0]);
}
{
const auto v = SplitIntoWords("a" WS WS WS, true, true);
EXPECT_EQ(1, v.size());
EXPECT_EQ("a" WS WS WS, v[0]);
}
{
const auto v = SplitIntoWords(WS WS, true, true);
EXPECT_EQ(1, v.size());
EXPECT_EQ(WS WS, v[0]);
}
{
const auto v = SplitIntoWords(WS WS "a" WS, true, true);
EXPECT_EQ(2, v.size());
EXPECT_EQ(WS WS, v[0]);
EXPECT_EQ("a" WS, v[1]);
}
{
const auto v = SplitIntoWords(WS WS "a" WS, false, true);
EXPECT_EQ(2, v.size());
EXPECT_EQ(WS WS "a", v[0]);
EXPECT_EQ(WS, v[1]);
}
}
TEST(ModelInterfaceTest, ByteToPieceTest) {
EXPECT_EQ(ByteToPiece(0), "<0x00>");
EXPECT_EQ(ByteToPiece(1), "<0x01>");

View File

@ -12,11 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.!
#include "normalizer.h"
#include <utility>
#include <vector>
#include "common.h"
#include "normalizer.h"
#include "third_party/absl/memory/memory.h"
#include "third_party/absl/strings/match.h"
#include "third_party/absl/strings/string_view.h"
@ -277,11 +278,11 @@ util::Status Normalizer::DecodePrecompiledCharsMap(
absl::string_view blob, absl::string_view *trie_blob,
absl::string_view *normalized, std::string *buffer) {
uint32 trie_blob_size = 0;
if (blob.size() <= sizeof(trie_blob_size) ||
!string_util::DecodePOD<uint32>(
absl::string_view(blob.data(), sizeof(trie_blob_size)),
&trie_blob_size) ||
trie_blob_size >= blob.size()) {
&trie_blob_size)) {
return util::InternalError("Blob for normalization rule is broken.");
}
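Context for this hunk: the precompiled chars map is a single string consisting of a 4-byte size header (read with string_util::DecodePOD above), the double-array trie blob of that size, and then the concatenated normalized strings; the trie_blob_size >= blob.size() bound touched here rejects a size header that claims more bytes than the blob holds. A standalone sketch of the same split — an illustration rather than the library's code:

#include <cstdint>
#include <cstring>
#include <string>

// Splits a precompiled chars-map blob into the trie image and the
// concatenated normalized strings, with an explicit bound check on the
// size header.
bool SplitCharsMapBlob(const std::string &blob, std::string *trie_blob,
                       std::string *normalized) {
  std::uint32_t trie_blob_size = 0;
  if (blob.size() <= sizeof(trie_blob_size)) return false;
  std::memcpy(&trie_blob_size, blob.data(), sizeof(trie_blob_size));
  if (trie_blob_size > blob.size() - sizeof(trie_blob_size))
    return false;  // size header claims more bytes than the blob contains
  trie_blob->assign(blob, sizeof(trie_blob_size), trie_blob_size);
  normalized->assign(blob, sizeof(trie_blob_size) + trie_blob_size,
                     std::string::npos);
  return true;
}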

View File

@ -22,6 +22,7 @@
#include <vector>
#include "common.h"
#include "util.h"
#include "sentencepiece_model.pb.h"
#include "sentencepiece_processor.h"
#include "third_party/absl/strings/string_view.h"

View File

@@ -12,10 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.!
#include "normalizer.h"
#include <vector>
#include "builder.h"
#include "normalizer.h"
#include "sentencepiece_trainer.h"
#include "testharness.h"
#include "util.h"

View File

@@ -139,10 +139,6 @@ message TrainerSpec {
// of sentence.
optional bool treat_whitespace_as_suffix = 24 [default = false];
// Allows pieces that contain only whitespace, instead of restricting whitespace
// to appear only as a prefix or suffix of other pieces.
optional bool allow_whitespace_only_pieces = 26 [default = false];
// Split all digits (0-9) into separate pieces.
optional bool split_digits = 25 [default = false];
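As a hedged illustration (not part of this file), the same switches can be toggled programmatically through the generated protobuf setters; the setter names follow the fields above, and allow_whitespace_only_pieces is the field this revert removes.

#include "sentencepiece_model.pb.h"

// Hypothetical sketch: configure whitespace/digit handling on a TrainerSpec.
void ConfigureWhitespaceHandling(sentencepiece::TrainerSpec *spec) {
  spec->set_treat_whitespace_as_suffix(false);   // keep the whitespace marker as a prefix
  spec->set_allow_whitespace_only_pieces(true);  // field removed again by this revert
  spec->set_split_digits(true);                  // emit each digit as its own piece
}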

View File

@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.!
#include "sentencepiece_processor.h"
#include <map>
#include <set>
#include <utility>
@@ -22,7 +24,6 @@
#include "model_interface.h"
#include "normalizer.h"
#include "sentencepiece.pb.h"
#include "sentencepiece_processor.h"
#include "third_party/absl/memory/memory.h"
#include "third_party/absl/strings/numbers.h"
#include "third_party/absl/strings/str_cat.h"
@@ -503,43 +504,6 @@ util::Status SentencePieceProcessor::SampleEncode(
return util::OkStatus();
}
util::Status SentencePieceProcessor::SampleEncodeAndScore(
absl::string_view input, int samples, float theta, bool wor,
bool include_best, NBestSentencePieceText *samples_spt) const {
CHECK_OR_RETURN(model_->IsSampleEncodeAndScoreAvailable())
<< "SampleEncodeAndScore is not available for the current model.";
std::string normalized;
std::vector<size_t> norm_to_orig;
RETURN_IF_ERROR(normalizer_->Normalize(input, &normalized, &norm_to_orig));
const auto results = model_->SampleEncodeAndScore(normalized, theta, samples,
wor, include_best);
CHECK_OR_RETURN(!results.empty())
<< "SampleEncodeAndScore returns empty result.";
for (const auto &result : results) {
auto *spt = samples_spt->add_nbests();
spt->set_score(result.second);
RETURN_IF_ERROR(PopulateSentencePieceText(input, normalized, norm_to_orig,
result.first, spt));
}
return util::OkStatus();
}
util::Status SentencePieceProcessor::CalculateEntropy(absl::string_view input,
float theta,
float *entropy) const {
CHECK_OR_RETURN(model_->IsCalculateEntropyAvailable())
<< "CalculateEntropy is not available for the current model.";
std::string normalized;
std::vector<size_t> norm_to_orig;
RETURN_IF_ERROR(normalizer_->Normalize(input, &normalized, &norm_to_orig));
*entropy = model_->CalculateEntropy(normalized, theta);
return util::OkStatus();
}
util::Status SentencePieceProcessor::Decode(
const std::vector<std::string> &pieces, SentencePieceText *spt) const {
CHECK_OR_RETURN_STATUS_PROTO(spt);
@@ -869,12 +833,6 @@ std::string SentencePieceProcessor::serialized_model_proto() const {
return model_proto_ ? model_proto_->SerializeAsString() : "";
}
// Set seed value of random generator.
// Do not set static_cast<unsigned int>(-1),
// as this seed is reserved for initializing from
// std::random_device.
void SetRandomGeneratorSeed(unsigned int seed);
namespace io {
util::Status LoadModelProto(absl::string_view filename,

View File

@@ -315,15 +315,6 @@ class SentencePieceProcessor {
virtual util::Status SampleEncode(absl::string_view input, int nbest_size,
float alpha, SentencePieceText *spt) const;
// Samples N segmentations and returns their scores as well.
virtual util::Status SampleEncodeAndScore(
absl::string_view input, int samples, float theta, bool wor,
bool include_best, NBestSentencePieceText *samples_spt) const;
// Calculate entropy of possible tokenisations
virtual util::Status CalculateEntropy(absl::string_view input, float theta,
float *entropy) const;
// Given a sequence of pieces, decodes it into SentencePieceText.
virtual util::Status Decode(const std::vector<std::string> &pieces,
SentencePieceText *spt) const;

View File

@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.!
#include "sentencepiece_processor.h"
#include <utility>
#include "builder.h"
@@ -20,7 +22,6 @@
#include "normalizer.h"
#include "sentencepiece.pb.h"
#include "sentencepiece_model.pb.h"
#include "sentencepiece_processor.h"
#include "sentencepiece_trainer.h"
#include "testharness.h"
#include "third_party/absl/container/flat_hash_map.h"
@@ -1138,6 +1139,13 @@ TEST(SentencePieceProcessorTest, EndToEndTest) {
EXPECT_EQ("cba", output);
}
// Out of range
{
std::string output;
const std::vector<int> ids = {3, 4, 127};
EXPECT_FALSE(sp.Decode(ids, &output).ok());
}
{
EXPECT_TRUE(sp.SetDecodeExtraOptions("bos:eos").ok());
@@ -1164,13 +1172,6 @@ TEST(SentencePieceProcessorTest, EndToEndTest) {
EXPECT_EQ("cba", output);
}
// Out of range
{
std::string output;
const std::vector<int> ids = {3, 4, 127};
EXPECT_FALSE(sp.Decode(ids, &output).ok());
}
{
EXPECT_TRUE(sp.SetDecodeExtraOptions("bos:eos:reverse").ok());

View File

@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.!
#include "sentencepiece_trainer.h"
#include <string>
#include <vector>
@@ -20,9 +22,7 @@
#include "normalizer.h"
#include "sentencepiece.pb.h"
#include "sentencepiece_model.pb.h"
#include "sentencepiece_trainer.h"
#include "spec_parser.h"
#include "third_party/absl/flags/flag.h"
#include "third_party/absl/strings/numbers.h"
#include "third_party/absl/strings/str_cat.h"
#include "third_party/absl/strings/str_split.h"
@@ -31,8 +31,6 @@
#include "trainer_factory.h"
#include "util.h"
ABSL_DECLARE_FLAG(int, minloglevel);
namespace sentencepiece {
namespace {
static constexpr char kDefaultNormalizerName[] = "nmt_nfkc";
@@ -112,7 +110,7 @@ util::Status SentencePieceTrainer::MergeSpecsFromArgs(
for (auto arg : absl::StrSplit(args, " ")) {
absl::ConsumePrefix(&arg, "--");
std::string key, value;
const auto pos = arg.find('=');
const auto pos = arg.find("=");
if (pos == absl::string_view::npos) {
key = std::string(arg);
} else {
@@ -151,7 +149,7 @@ util::Status SentencePieceTrainer::MergeSpecsFromArgs(
} else if (key == "minloglevel") {
int v = 0;
CHECK_OR_RETURN(absl::SimpleAtoi(value, &v));
absl::SetFlag(&FLAGS_minloglevel, v);
logging::SetMinLogLevel(v);
continue;
}

View File

@@ -145,7 +145,6 @@ inline std::string PrintProto(const TrainerSpec &message,
PRINT_PARAM(split_by_whitespace);
PRINT_PARAM(split_digits);
PRINT_PARAM(treat_whitespace_as_suffix);
PRINT_PARAM(allow_whitespace_only_pieces);
PRINT_REPEATED_STRING(control_symbols);
PRINT_REPEATED_STRING(user_defined_symbols);
PRINT_PARAM(required_chars);
@@ -220,7 +219,6 @@ util::Status SentencePieceTrainer::SetProtoField(const std::string &name,
PARSE_BOOL(split_by_whitespace);
PARSE_BOOL(split_digits);
PARSE_BOOL(treat_whitespace_as_suffix);
PARSE_BOOL(allow_whitespace_only_pieces);
PARSE_REPEATED_STRING(control_symbols);
PARSE_REPEATED_STRING(user_defined_symbols);
PARSE_STRING(required_chars);

View File

@@ -64,7 +64,6 @@ int main(int argc, char *argv[]) {
auto ToIds = [&](const std::vector<std::string> &pieces) {
std::vector<int> ids;
ids.reserve(pieces.size());
for (const auto &s : pieces) {
ids.push_back(atoi(s.c_str()));
}

View File

@@ -28,17 +28,16 @@
#include "trainer_interface.h"
ABSL_FLAG(std::string, model, "", "model file name");
ABSL_FLAG(
std::string, output_format, "piece",
"choose from piece, id, proto, nbest_piece, nbest_id, or nbest_proto");
ABSL_FLAG(std::string, output_format, "piece",
"choose from piece, id, proto, nbest_piece, nbest_id, nbest_proto, "
"sample_piece, sample_id or sample_proto.");
ABSL_FLAG(std::string, input, "", "input filename");
ABSL_FLAG(std::string, output, "", "output filename");
ABSL_FLAG(std::string, extra_options, "",
"':' separated encoder extra options, e.g., \"reverse:bos:eos\"");
ABSL_FLAG(int32, nbest_size, 10, "NBest size");
ABSL_FLAG(double, alpha, 0.5, "Smoothing parameter for sampling mode.");
ABSL_FLAG(uint32, random_seed, static_cast<uint32>(-1),
"Seed value for random generator.");
ABSL_FLAG(int32, random_seed, -1, "Seed value for random generator.");
// Piece restriction with vocabulary file.
// https://github.com/rsennrich/subword-nmt#best-practice-advice-for-byte-pair-encoding-in-nmt
@@ -62,9 +61,8 @@ int main(int argc, char *argv[]) {
rest_args.push_back(absl::GetFlag(FLAGS_input));
}
if (absl::GetFlag(FLAGS_random_seed) != -1) {
if (absl::GetFlag(FLAGS_random_seed) != -1)
sentencepiece::SetRandomGeneratorSeed(absl::GetFlag(FLAGS_random_seed));
}
if (rest_args.empty())
rest_args.push_back(""); // empty means that reading from stdin.

View File

@@ -80,9 +80,6 @@ ABSL_FLAG(bool, split_digits, kDefaultTrainerSpec.split_digits(),
ABSL_FLAG(bool, treat_whitespace_as_suffix,
kDefaultTrainerSpec.treat_whitespace_as_suffix(),
"treat whitespace marker as suffix instead of prefix.");
ABSL_FLAG(bool, allow_whitespace_only_pieces,
kDefaultTrainerSpec.allow_whitespace_only_pieces(),
"allow pieces that only contain (consecutive) whitespace tokens");
ABSL_FLAG(std::string, control_symbols, "",
"comma separated list of control symbols");
ABSL_FLAG(std::string, control_symbols_file, "",
@@ -141,8 +138,7 @@ ABSL_FLAG(std::string, unk_surface, kDefaultTrainerSpec.unk_surface(),
ABSL_FLAG(bool, train_extremely_large_corpus,
kDefaultTrainerSpec.train_extremely_large_corpus(),
"Increase bit depth for unigram tokenization.");
ABSL_FLAG(uint32, random_seed, static_cast<uint32>(-1),
"Seed value for random generator.");
ABSL_FLAG(int32, random_seed, -1, "Seed value for random generator.");
int main(int argc, char *argv[]) {
sentencepiece::ParseCommandLineFlags(argv[0], &argc, &argv, true);
@@ -154,9 +150,8 @@ int main(int argc, char *argv[]) {
CHECK(!absl::GetFlag(FLAGS_input).empty());
CHECK(!absl::GetFlag(FLAGS_model_prefix).empty());
if (absl::GetFlag(FLAGS_random_seed) != -1) {
if (absl::GetFlag(FLAGS_random_seed) != -1)
sentencepiece::SetRandomGeneratorSeed(absl::GetFlag(FLAGS_random_seed));
}
auto load_lines = [](absl::string_view filename) {
std::vector<std::string> lines;
@@ -216,7 +211,6 @@ int main(int argc, char *argv[]) {
SetTrainerSpecFromFlag(split_digits);
SetTrainerSpecFromFlag(byte_fallback);
SetTrainerSpecFromFlag(treat_whitespace_as_suffix);
SetTrainerSpecFromFlag(allow_whitespace_only_pieces);
SetTrainerSpecFromFlag(hard_vocab_limit);
SetTrainerSpecFromFlag(use_all_vocab);
SetTrainerSpecFromFlag(unk_id);

View File

@@ -12,7 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.!
#include <algorithm>
#include "trainer_interface.h"
#include <cstdlib>
#include <memory>
#include <set>
@@ -33,7 +34,6 @@
#include "third_party/absl/strings/str_format.h"
#include "third_party/absl/strings/str_join.h"
#include "third_party/absl/strings/str_split.h"
#include "trainer_interface.h"
#include "unicode_script.h"
#include "util.h"
@@ -86,10 +86,6 @@ util::Status VerifySpec(const TrainerSpec &trainer_spec) {
return util::OkStatus();
}
bool is_unicode_decimal_number(char32 c) {
return (c >= 0x30 && c <= 0x39) || (c >= 0xff10 && c <= 0xff19);
}
class SentenceSelector {
public:
using Sampler = random::ReservoirSampler<TrainerInterface::Sentence>;
@@ -214,10 +210,9 @@ bool TrainerInterface::IsValidSentencePiece(
constexpr unicode_script::ScriptType kAnyType =
static_cast<unicode_script::ScriptType>(-1);
auto is_number = [](char32 c) { return (c >= 0x30 && c <= 0x39); };
unicode_script::ScriptType prev_script = kAnyType;
bool all_whitespace_piece =
std::all_of(sentencepiece.begin(), sentencepiece.end(),
[](char32 c) { return c == kWSChar; });
for (size_t pos = 0; pos < sentencepiece.size(); ++pos) {
const char32 c = sentencepiece[pos];
@@ -240,30 +235,25 @@
}
if (c == kWSChar) {
// Only allows whitespace to appear as a prefix of piece unless
// allow_whitespace_only_pieces is True.
// Only allows whitespace to appear as a prefix of piece.
// When split_by_whitespace is false, we allow whitespaces to
// appear in the middle, "foo_bar", but do not allow them
// to appear as suffix, "foo_bar_".
// Regardless of the setting of split_by_whitespace,
// whitespace is treated as a prefix/infix of symbol or
// independent symbol, unless allow_whitespace_only_pieces() is true,
// in which case whitespace only pieces can occur.
if (!trainer_spec_.allow_whitespace_only_pieces() ||
!all_whitespace_piece) {
if (trainer_spec_.treat_whitespace_as_suffix()) {
if ((trainer_spec_.split_by_whitespace() &&
pos < sentencepiece.size() - 1) ||
(!trainer_spec_.split_by_whitespace() &&
pos < sentencepiece.size() - 1 && pos == 0)) {
return false;
}
} else {
if ((trainer_spec_.split_by_whitespace() && pos > 0) ||
(!trainer_spec_.split_by_whitespace() && pos > 0 &&
pos == sentencepiece.size() - 1)) {
return false;
}
// independent symbol.
if (trainer_spec_.treat_whitespace_as_suffix()) {
if ((trainer_spec_.split_by_whitespace() &&
pos < sentencepiece.size() - 1) ||
(!trainer_spec_.split_by_whitespace() &&
pos < sentencepiece.size() - 1 && pos == 0)) {
return false;
}
} else {
if ((trainer_spec_.split_by_whitespace() && pos > 0) ||
(!trainer_spec_.split_by_whitespace() && pos > 0 &&
pos == sentencepiece.size() - 1)) {
return false;
}
}
} else {
@@ -275,11 +265,11 @@ bool TrainerInterface::IsValidSentencePiece(
s = unicode_script::U_Han;
}
if (!trainer_spec_.split_by_number() && is_unicode_decimal_number(c)) {
if (!trainer_spec_.split_by_number() && is_number(c)) {
s = kAnyType;
}
if (trainer_spec_.split_digits() && is_unicode_decimal_number(c)) {
if (trainer_spec_.split_digits() && is_number(c)) {
if (sentencepiece.size() > 1) return false;
}
@@ -528,8 +518,7 @@ void TrainerInterface::SplitSentencesByWhitespace() {
absl::flat_hash_map<std::string, int64> tokens;
for (const auto &s : sentences_) {
for (const auto &w :
SplitIntoWords(s.first, trainer_spec_.treat_whitespace_as_suffix(),
trainer_spec_.allow_whitespace_only_pieces())) {
SplitIntoWords(s.first, trainer_spec_.treat_whitespace_as_suffix())) {
tokens[std::string(w)] += s.second;
}
}

View File

@@ -81,7 +81,6 @@ TEST(TrainerInterfaceTest, IsValidSentencePieceTest) {
trainer_spec.set_split_by_whitespace(false);
EXPECT_TRUE(IsValid(WS));
EXPECT_TRUE(IsValid(WS WS WS "a"));
EXPECT_TRUE(IsValid(WS "a"));
EXPECT_FALSE(IsValid("a" WS));
EXPECT_FALSE(IsValid(WS "a" WS));
@@ -89,17 +88,7 @@ TEST(TrainerInterfaceTest, IsValidSentencePieceTest) {
EXPECT_TRUE(IsValid(WS "a" WS "b"));
EXPECT_TRUE(IsValid(WS "a" WS "b" WS "c"));
EXPECT_FALSE(IsValid("a" WS "b" WS));
EXPECT_FALSE(IsValid(WS WS));
EXPECT_FALSE(IsValid(WS WS WS));
trainer_spec.set_allow_whitespace_only_pieces(true);
EXPECT_TRUE(IsValid(WS));
EXPECT_TRUE(IsValid(WS WS));
EXPECT_TRUE(IsValid(WS WS WS));
EXPECT_TRUE(IsValid(WS WS "a"));
EXPECT_FALSE(IsValid("a" WS WS)); // suffix whitespace illegal without flag
trainer_spec.set_allow_whitespace_only_pieces(false);
trainer_spec.set_split_by_unicode_script(false);
EXPECT_TRUE(IsValid("あいう"));
EXPECT_TRUE(IsValid("グーグル"));
@@ -135,15 +124,6 @@ TEST(TrainerInterfaceTest, IsValidSentencePieceTest) {
EXPECT_FALSE(IsValid(WS "a" WS "b"));
EXPECT_FALSE(IsValid("a" WS "b" WS));
trainer_spec.set_allow_whitespace_only_pieces(true);
EXPECT_TRUE(IsValid(WS));
EXPECT_TRUE(IsValid(WS WS));
EXPECT_FALSE(IsValid(WS "a" WS));
EXPECT_FALSE(IsValid("a" WS "b"));
EXPECT_FALSE(IsValid(WS "a" WS "b"));
EXPECT_FALSE(IsValid("a" WS "b" WS));
trainer_spec.set_allow_whitespace_only_pieces(false);
trainer_spec.set_split_by_whitespace(false);
EXPECT_TRUE(IsValid(WS));
EXPECT_FALSE(IsValid(WS "a"));
@@ -166,12 +146,6 @@ TEST(TrainerInterfaceTest, IsValidSentencePieceTest) {
EXPECT_FALSE(IsValid("2007"));
EXPECT_FALSE(IsValid("x1"));
EXPECT_FALSE(IsValid("2x"));
// Fullwidth digits.
EXPECT_TRUE(IsValid(""));
EXPECT_FALSE(IsValid(""));
EXPECT_FALSE(IsValid(""));
EXPECT_FALSE(IsValid(""));
EXPECT_FALSE(IsValid(""));
}
TEST(TrainerInterfaceTest, OverrideSpecialPiecesTest) {

View File

@@ -15,7 +15,6 @@
#include <algorithm>
#include <cfloat>
#include <cmath>
#include <complex>
#include <map>
#include <queue>
#include <string>
@@ -56,17 +55,6 @@ inline float LogSumExp(float x, float y, bool init_mode) {
return vmax + log(std::exp(static_cast<double>(vmin - vmax)) + 1.0);
}
}
// Returns a sample from a standard Gumbel distribution.
// If U ~ U[0, 1], -log(-log U) ~ G(0,1)
inline float Gumbel() {
const float kEpsilon = 1e-7;
auto *mt = random::GetRandomGenerator();
std::uniform_real_distribution<float> dis(0.0, 1.0);
float noise = -std::log(-(std::log(dis(*mt) + kEpsilon)));
return noise;
}
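A self-contained illustration (sketch only, not library code) of the Gumbel-max trick this helper enables: perturbing each log score with independent G(0,1) noise and taking the argmax draws index i with probability exp(score_i) / sum_j exp(score_j), which is what the sampling branch of NBest below relies on.

#include <cmath>
#include <cstddef>
#include <limits>
#include <random>
#include <vector>

// Draws i with probability proportional to exp(log_scores[i]).
inline int GumbelMaxSample(const std::vector<float> &log_scores,
                           std::mt19937 *mt) {
  std::uniform_real_distribution<float> dis(0.0, 1.0);
  int best_index = 0;
  float best_value = -std::numeric_limits<float>::infinity();
  for (std::size_t i = 0; i < log_scores.size(); ++i) {
    const float noise = -std::log(-std::log(dis(*mt) + 1e-7f));  // G(0,1) draw
    const float perturbed = log_scores[i] + noise;
    if (perturbed > best_value) {
      best_value = perturbed;
      best_index = static_cast<int>(i);
    }
  }
  return best_index;
}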
} // namespace
Lattice::Lattice() : node_allocator_(kPreallocateLatticeNodeSize) {}
@@ -157,7 +145,7 @@ Lattice::Node *Lattice::Insert(int pos, int length) {
return node;
}
Lattice::LatticePathWithScore Lattice::Viterbi() {
std::vector<Lattice::Node *> Lattice::Viterbi() {
const int len = size();
for (int pos = 0; pos <= len; ++pos) {
@@ -183,7 +171,6 @@ Lattice::LatticePathWithScore Lattice::Viterbi() {
// backtrace
std::vector<Node *> results;
float score = begin_nodes(len)[0]->backtrace_score;
for (Node *node = begin_nodes_[len][0]->prev; node->prev != nullptr;
node = node->prev) {
results.push_back(node);
@@ -191,43 +178,7 @@
std::reverse(results.begin(), results.end());
LatticePathWithScore retval = {results, score};
return retval;
}
std::vector<float> Lattice::ForwardAlgorithm(float theta) const {
const int len = size();
std::vector<float> alpha(node_allocator_.size(), 0.0);
for (int pos = 0; pos <= len; ++pos) {
for (Node *rnode : begin_nodes_[pos]) {
for (Node *lnode : end_nodes_[pos]) {
alpha[rnode->node_id] = LogSumExp(
alpha[rnode->node_id], theta * lnode->score + alpha[lnode->node_id],
lnode == end_nodes_[pos][0]);
}
}
}
return alpha;
}
std::vector<float> Lattice::BackwardAlgorithm(float theta) const {
const int len = size();
std::vector<float> beta(node_allocator_.size(), 0.0);
for (int pos = len; pos >= 0; --pos) {
for (Node *lnode : end_nodes_[pos]) {
for (Node *rnode : begin_nodes_[pos]) {
beta[lnode->node_id] =
LogSumExp(beta[lnode->node_id], rnode->score + beta[rnode->node_id],
rnode == begin_nodes_[pos][0]);
}
}
}
return beta;
return results;
}
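Written out, the ForwardAlgorithm/BackwardAlgorithm removed above compute accumulated log scores per node, exactly as in the code (note that, as written, the backward recursion does not apply theta to the scores):

  \alpha_r = \log \sum_{l:\ \mathrm{end}(l)=\mathrm{begin}(r)} \exp\bigl(\theta\, s_l + \alpha_l\bigr),
  \qquad
  \beta_l = \log \sum_{r:\ \mathrm{begin}(r)=\mathrm{end}(l)} \exp\bigl(s_r + \beta_r\bigr),

so \alpha at the EOS node is the log partition function \log Z of the lattice.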
float Lattice::PopulateMarginal(float freq,
@@ -238,9 +189,28 @@ float Lattice::PopulateMarginal(float freq,
// alpha and beta (accumulative log prob) in Forward Backward.
// the index of alpha/beta is Node::node_id.
std::vector<float> alpha(node_allocator_.size(), 0.0);
std::vector<float> beta(node_allocator_.size(), 0.0);
const auto alpha = ForwardAlgorithm(1.0);
const auto beta = BackwardAlgorithm(1.0);
for (int pos = 0; pos <= len; ++pos) {
for (Node *rnode : begin_nodes_[pos]) {
for (Node *lnode : end_nodes_[pos]) {
alpha[rnode->node_id] = LogSumExp(alpha[rnode->node_id],
lnode->score + alpha[lnode->node_id],
lnode == end_nodes_[pos][0]);
}
}
}
for (int pos = len; pos >= 0; --pos) {
for (Node *lnode : end_nodes_[pos]) {
for (Node *rnode : begin_nodes_[pos]) {
beta[lnode->node_id] =
LogSumExp(beta[lnode->node_id], rnode->score + beta[rnode->node_id],
rnode == begin_nodes_[pos][0]);
}
}
}
const float Z = alpha[begin_nodes_[len][0]->node_id];
for (int pos = 0; pos < len; ++pos) {
@@ -258,46 +228,13 @@
return freq * Z;
}
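The loop body elided above follows the standard forward-backward form; as a sketch of the usual update (an assumption here, since the body is not shown in this hunk), the marginal of a node v and the returned value are

  P(v \mid s) = \exp\bigl(\alpha_v + s_v + \beta_v - Z\bigr), \qquad Z = \alpha_{\mathrm{EOS}},

with freq * Z the frequency-weighted log likelihood of the sentence.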
float Lattice::CalculateEntropy(float theta) const {
const int len = size();
// alpha[node_id] is the marginal prob of sequence up to start of node
// H is entropy of sequence
// the index of alpha/H is Node::node_id.
std::vector<float> alpha(node_allocator_.size(), 0.0);
std::vector<float> H(node_allocator_.size(), 0.0);
// Populate the forward marginals to get the normalising constant
alpha = ForwardAlgorithm(theta);
// Now populate the forward entropies
for (int pos = 0; pos <= len; ++pos) {
for (Node *rnode : begin_nodes_[pos]) {
for (Node *lnode : end_nodes_[pos]) {
// Contribution each lnode makes = p(lnode) * (H(lnode) + log p(lnode))
// We have to normalise p(lnode) by the marginal contribution it makes
const float lnode_transition_prob =
((theta * lnode->score) + alpha[lnode->node_id] -
alpha[rnode->node_id]);
H[rnode->node_id] += std::exp(lnode_transition_prob) *
(H[lnode->node_id] + lnode_transition_prob);
}
}
}
return -H[begin_nodes_[len][0]->node_id];
}
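In formula form, the recursion above (a direct transcription of the code) accumulates forward entropies as

  H_r = \sum_{l} p(l \mid r)\,\bigl(H_l + \log p(l \mid r)\bigr),
  \qquad
  \log p(l \mid r) = \theta\, s_l + \alpha_l - \alpha_r,

and the entropy of the whole lattice is -H at the EOS node, which is the value returned.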
std::vector<Lattice::LatticePathWithScore> Lattice::NBest(size_t nbest_size,
bool sample,
float theta) {
std::vector<std::vector<Lattice::Node *>> Lattice::NBest(size_t nbest_size) {
if (nbest_size < 1) {
LOG(WARNING) << "nbest_size >= 1. Returns empty result.";
return {};
}
if (nbest_size == 1 && !sample) {
if (nbest_size == 1) {
return {Viterbi()};
}
@@ -306,7 +243,6 @@ std::vector<Lattice::LatticePathWithScore> Lattice::NBest(size_t nbest_size,
// At each partial path x, compute f(x) as follows
// f(x) = g(x) + h(x).
// g(x): the sum of scores from EOS to the left-most node in x.
// for a complete hypothesis, g(hyp) is the score of the hypothesis.
// h(x): a heuristic that estimates the largest score from x to BOS.
// f(x): the priority to pop a new hypothesis from the priority queue.
//
@@ -332,27 +268,18 @@
model::FreeList<Hypothesis> hypothesis_allocator(kPreallocatedHypothesisSize);
Agenda agenda;
std::vector<Lattice::LatticePathWithScore> results;
std::vector<std::vector<Node *>> results;
auto *eos = hypothesis_allocator.Allocate();
eos->node = eos_node();
eos->next = nullptr;
eos->gx = 0.0;
std::vector<float> alpha(node_allocator_.size(), 0.0);
if (sample) {
// Run forwards algorithm to get normalising constants
alpha = ForwardAlgorithm(theta);
// f(eos) = Gumbel(0), as it is the perturbed score of the entire lattice.
eos->fx = Gumbel();
} else {
// Run Viterbi first to fill backtrace score.
Viterbi();
eos->fx = eos->node->backtrace_score;
}
eos->fx = eos->node->score;
eos->gx = eos->node->score;
agenda.push(eos);
// Run Viterbi first to fill backtrace score.
Viterbi();
while (!agenda.empty()) {
auto *top = agenda.top();
agenda.pop();
@@ -362,56 +289,21 @@
if (node == bos_node()) {
results.resize(results.size() + 1);
for (auto *n = top->next; n->next != nullptr; n = n->next) {
results.back().first.push_back(n->node);
results.back().push_back(n->node);
}
results.back().second = top->fx;
if (results.size() == nbest_size) {
break;
}
continue;
}
const int end_nodes_size = end_nodes(node->pos).size();
std::vector<float> probs(end_nodes_size, 0.0);
std::vector<float> perturbed_probs(end_nodes_size, 0.0);
std::vector<double> adjusted_probs(end_nodes_size, 0.0);
const float Z = alpha[node->node_id];
if (sample) {
float max_score = -1e8;
// Calculate the marginal and perturbed scores for stochastic search
for (int i = 0; i < end_nodes(node->pos).size(); i++) {
Node *lnode = end_nodes(node->pos)[i];
// Calculate backwards transition score
probs[i] = top->gx + alpha[lnode->node_id] + (theta * lnode->score) - Z;
perturbed_probs[i] = probs[i] + Gumbel();
if (perturbed_probs[i] > max_score) {
max_score = perturbed_probs[i];
}
}
// Now constrain the sampled continuations to match the score of parent
for (int i = 0; i < adjusted_probs.size(); i++) {
// Use numerically stable version of truncated Gumbel:
// https://arxiv.org/pdf/1903.06059.pdf appendix B.3
const float v = top->fx - perturbed_probs[i] +
std::log1p(-std::exp(perturbed_probs[i] - max_score));
adjusted_probs[i] = top->fx - std::max(static_cast<float>(0.0), v) -
std::log1p(std::exp(-std::abs(v)));
}
}
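In symbols, with T = top->fx the parent's perturbed score, G_i the perturbed child scores and G_max their maximum, the numerically stable truncation implemented above is

  \tilde{G}_i = T - \max(0, v_i) - \log\bigl(1 + e^{-|v_i|}\bigr),
  \qquad
  v_i = T - G_i + \log\bigl(1 - e^{\,G_i - G_{\max}}\bigr),

which, following the cited paper, conditions each child's perturbed score on not exceeding the parent's.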
// Expands new node ending at node->pos
for (int i = 0; i < end_nodes(node->pos).size(); i++) {
Node *lnode = end_nodes(node->pos)[i];
for (Node *lnode : end_nodes(node->pos)) {
auto *hyp = hypothesis_allocator.Allocate();
hyp->node = lnode;
if (sample) {
hyp->gx = probs[i];
hyp->fx = adjusted_probs[i];
} else {
hyp->gx = lnode->score + top->gx; // just adds node->score
hyp->fx =
lnode->backtrace_score + top->gx; // backtrace_score is h(node).
}
hyp->gx = lnode->score + top->gx; // just adds node->score
hyp->fx =
lnode->backtrace_score + top->gx; // backtrace_score is h(node).
hyp->next = top;
agenda.push(hyp);
}
@@ -443,7 +335,15 @@ std::vector<Lattice::Node *> Lattice::Sample(float theta) {
std::vector<float> alpha(node_allocator_.size(), 0.0);
alpha = ForwardAlgorithm(theta);
for (int pos = 0; pos <= len; ++pos) {
for (Node *rnode : begin_nodes_[pos]) {
for (Node *lnode : end_nodes_[pos]) {
alpha[rnode->node_id] = LogSumExp(
alpha[rnode->node_id], theta * lnode->score + alpha[lnode->node_id],
lnode == end_nodes_[pos][0]);
}
}
}
auto *mt = random::GetRandomGenerator();
@@ -614,7 +514,7 @@ EncodeResult Model::Encode(absl::string_view normalized) const {
PopulateNodes(&lattice);
EncodeResult results;
for (const auto *node : lattice.Viterbi().first) {
for (const auto *node : lattice.Viterbi()) {
results.emplace_back(node->piece, node->id);
}
@@ -634,12 +534,14 @@ NBestEncodeResult Model::NBestEncode(absl::string_view normalized,
PopulateNodes(&lattice);
NBestEncodeResult nbest_results;
for (const auto &nbest : lattice.NBest(nbest_size, false, 0.0)) {
for (const auto &nbest : lattice.NBest(nbest_size)) {
EncodeResult results;
for (const auto *node : nbest.first) {
float score = 0.0;
for (const auto *node : nbest) {
score += node->score;
results.emplace_back(node->piece, node->id);
}
nbest_results.emplace_back(results, nbest.second);
nbest_results.emplace_back(results, score);
}
return nbest_results;
@@ -663,123 +565,6 @@ EncodeResult Model::SampleEncode(absl::string_view normalized,
return results;
}
NBestEncodeResult Model::SampleEncodeAndScore(absl::string_view normalized,
float theta, int samples,
bool wor,
bool include_best) const {
if (!status().ok() || normalized.empty()) {
return {};
}
NBestEncodeResult results;
Lattice lattice;
lattice.SetSentence(normalized);
PopulateNodes(&lattice);
std::vector<float> alpha = lattice.ForwardAlgorithm(theta);
float marginal = alpha[lattice.eos_node()->node_id];
if (include_best) {
if (!wor) {
LOG(FATAL) << "include_best not supported for wor false";
}
EncodeResult result;
Lattice::LatticePathWithScore best_path = lattice.Viterbi();
for (const auto *node : best_path.first) {
result.emplace_back(node->piece, node->id);
}
// Inclusion probability if we always include the best is 1.
results.emplace_back(result, 0.0);
}
if (wor) {
// Draw k+1 samples as we need perturbed score of k+1th element
std::vector<Lattice::LatticePathWithScore> nbest_samples =
lattice.NBest(samples + 1, true, theta);
if (include_best) {
std::vector<std::vector<Lattice::Node *>> nbest_paths(
nbest_samples.size());
for (int i = 0; i < nbest_samples.size(); i++) {
nbest_paths[i] = nbest_samples[i].first;
}
// Remove the best result from the samples if necessary
Lattice::LatticePathWithScore best_path = lattice.Viterbi();
const int index_of_best =
(std::find(nbest_paths.begin(), nbest_paths.end(), best_path.first) -
nbest_paths.begin());
if (index_of_best != nbest_samples.size()) {
LOG(INFO) << "removing best path from samples";
nbest_samples.erase(nbest_samples.begin() + index_of_best);
} else {
nbest_samples.pop_back();
}
}
// We use the perturbed score of the k+1th element to calculate the
// inclusion probability.
const double kappa = static_cast<double>(nbest_samples.back().second);
// Discard the last sample
nbest_samples.pop_back();
for (const auto &nbest : nbest_samples) {
EncodeResult result;
float score = 0.0;
for (const auto *node : nbest.first) {
score += (theta * node->score);
result.emplace_back(node->piece, node->id);
}
results.emplace_back(result, score - marginal);
}
// Now calculate the inclusion probability
for (auto &it : results) {
// Only modify non best sample inclusion probabilities.
if (it.second != 0.0) {
double x = it.second - kappa;
double y = std::exp(x);
double inclusion_prob;
if (x <= -10) {
// Series expansion of the log Gumbel survival function up to eps.
inclusion_prob =
x - (y / 2) + (std::pow(y, 2) / 24) - std::pow(y, 4) / 2880;
} else {
inclusion_prob = std::log(-std::expm1(-y));
}
it.second = static_cast<float>(inclusion_prob);
}
}
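Equivalently, writing \varphi_i for the log probability stored in it.second and \kappa for the (k+1)-th perturbed score, the loop above evaluates the Gumbel top-k inclusion probability

  \log \pi_i = \log\bigl(1 - \exp(-e^{\,\varphi_i - \kappa})\bigr),

computed as log(-expm1(-e^x)) with x = \varphi_i - \kappa, and via the series x - e^x/2 + e^{2x}/24 - e^{4x}/2880 when x <= -10 to avoid cancellation.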
} else {
while (results.size() < samples) {
Lattice lattice;
lattice.SetSentence(normalized);
PopulateNodes(&lattice);
float score = 0.0;
EncodeResult result;
std::vector<Lattice::Node *> sample = lattice.Sample(theta);
for (const auto *node : sample) {
result.emplace_back(node->piece, node->id);
score += (theta * node->score);
}
results.emplace_back(result, score - marginal);
}
}
return results;
}
float Model::CalculateEntropy(absl::string_view normalized, float theta) const {
Lattice lattice;
lattice.SetSentence(normalized);
PopulateNodes(&lattice);
return lattice.CalculateEntropy(theta);
}
bool Model::VerifyOutputsEquivalent(absl::string_view expected,
absl::string_view actual) const {
auto compute_unigram_model_score =

View File

@@ -82,28 +82,17 @@ class Lattice {
// After calling this method, The caller must set Node::score and Node::id.
Node *Insert(int pos, int length);
using LatticePathWithScore = std::pair<std::vector<Node *>, float>;
// Returns Viterbi path. All nodes must be populated in advance.
LatticePathWithScore Viterbi();
// Runs the forward/backward algorithm and returns the accumulated
// log scores (alpha/beta), indexed by Node::node_id.
std::vector<float> ForwardAlgorithm(float theta) const;
std::vector<float> BackwardAlgorithm(float theta) const;
std::vector<Node *> Viterbi();
// Returns n-best results.
std::vector<LatticePathWithScore> NBest(size_t nbest_size, bool sample,
float theta);
std::vector<std::vector<Node *>> NBest(size_t nbest_size);
// Samples one path from the lattice according to the
// generation probability (Product of piece probabilities).
// `theta` is a smoothing parameter.
std::vector<Node *> Sample(float theta);
// Calculates the entropy of the lattice.
float CalculateEntropy(float theta) const;
// Populates marginal probability of every node in this lattice.
// |freq| is the frequency of the sentence.
// for (auto *node : all_nodes_) {
@@ -138,19 +127,8 @@ class Model : public ModelInterface {
EncodeResult SampleEncode(absl::string_view normalized,
float theta) const override;
NBestEncodeResult SampleEncodeAndScore(absl::string_view normalized,
float theta, int samples, bool wor,
bool include_best) const override;
float CalculateEntropy(absl::string_view normalized,
float theta) const override;
bool IsSampleEncodeAvailable() const override { return true; }
bool IsSampleEncodeAndScoreAvailable() const override { return true; }
bool IsCalculateEntropyAvailable() const override { return true; }
bool IsNBestEncodeAvailable() const override { return true; }
// Returns the minimum score in sentence pieces.

View File

@@ -161,11 +161,11 @@ TEST(LatticeTest, InsertTest) {
TEST(LatticeTest, ViterbiFromIncompleteLatticeTest) {
Lattice lattice;
lattice.SetSentence("ABC");
EXPECT_TRUE(lattice.Viterbi().first.empty());
EXPECT_TRUE(lattice.Viterbi().empty());
// Still incomplete
lattice.Insert(0, 1);
EXPECT_TRUE(lattice.Viterbi().first.empty());
EXPECT_TRUE(lattice.Viterbi().empty());
lattice.Insert(1, 1);
lattice.Insert(2, 1);
@@ -198,16 +198,16 @@ TEST(LatticeTest, ViterbiTest) {
InsertWithScore(&lattice, 0, 1, 0.0); // A
InsertWithScore(&lattice, 1, 1, 0.0); // B
InsertWithScore(&lattice, 2, 1, 0.0); // C
EXPECT_EQ("A B C", GetTokenized(lattice.Viterbi().first));
EXPECT_EQ("A B C", GetTokenized(lattice.Viterbi()));
InsertWithScore(&lattice, 0, 2, 2.0); // AB
EXPECT_EQ("AB C", GetTokenized(lattice.Viterbi().first));
EXPECT_EQ("AB C", GetTokenized(lattice.Viterbi()));
InsertWithScore(&lattice, 1, 2, 5.0); // BC
EXPECT_EQ("A BC", GetTokenized(lattice.Viterbi().first));
EXPECT_EQ("A BC", GetTokenized(lattice.Viterbi()));
InsertWithScore(&lattice, 0, 3, 10.0); // ABC
EXPECT_EQ("ABC", GetTokenized(lattice.Viterbi().first));
EXPECT_EQ("ABC", GetTokenized(lattice.Viterbi()));
}
TEST(LatticeTest, NBestTest) {
@@ -221,174 +221,21 @@ TEST(LatticeTest, NBestTest) {
InsertWithScore(&lattice, 1, 2, 5.0); // BC
InsertWithScore(&lattice, 0, 3, 10.0); // ABC
auto nbests = lattice.NBest(10, false, 0.0);
auto nbests = lattice.NBest(10);
EXPECT_EQ(4, nbests.size());
EXPECT_EQ("ABC", GetTokenized(nbests[0].first));
EXPECT_EQ("A BC", GetTokenized(nbests[1].first));
EXPECT_EQ("AB C", GetTokenized(nbests[2].first));
EXPECT_EQ("A B C", GetTokenized(nbests[3].first));
EXPECT_EQ("ABC", GetTokenized(nbests[0]));
EXPECT_EQ("A BC", GetTokenized(nbests[1]));
EXPECT_EQ("AB C", GetTokenized(nbests[2]));
EXPECT_EQ("A B C", GetTokenized(nbests[3]));
auto nbests0 = lattice.NBest(0, false, 0.0);
auto nbests0 = lattice.NBest(0);
EXPECT_TRUE(nbests0.empty());
auto nbests1 = lattice.NBest(1, false, 0.0);
auto nbests1 = lattice.NBest(1);
EXPECT_EQ(nbests1.size(), 1);
}
TEST(LatticeTest, NBestSampleTest) {
Lattice lattice;
lattice.SetSentence("ABC");
InsertWithScore(&lattice, 0, 1, 0.0); // A
InsertWithScore(&lattice, 1, 1, 0.0); // B
InsertWithScore(&lattice, 2, 1, 0.1); // C
InsertWithScore(&lattice, 0, 2, 0.2); // AB
InsertWithScore(&lattice, 1, 2, 0.5); // BC
InsertWithScore(&lattice, 0, 3, 1.0); // ABC
// Calculate expected probabilities of each path
// Note that sampling without replacement affects the expected frequencies!
const std::vector<double> kTheta = {0.0, 0.01, 0.5, 0.7, 1.0};
for (const auto theta : kTheta) {
std::vector<std::string> strings = {"ABC", "AB C", "A BC", "A B C"};
std::map<std::string, float> probs;
probs["ABC"] = std::exp(theta * 1.0);
probs["AB C"] = std::exp(theta * (0.2 + 0.1));
probs["A BC"] = std::exp(theta * (0.0 + 0.5));
probs["A B C"] = std::exp(theta * (0.0 + 0.0 + 0.1));
for (const auto &it : strings) {
EXPECT_EQ(1, probs.count(it));
}
double Z = 0.0;
for (const auto &it : probs) Z += it.second;
for (auto &it : probs) it.second /= Z;
std::map<std::pair<std::string, std::string>, float> pair_probs;
for (const auto first : strings) {
for (const auto second : strings) {
if (first == second) {
pair_probs[std::make_pair(first, second)] = 0;
} else {
float first_prob = probs[first];
float second_prob = probs[second] / (1 - first_prob);
pair_probs[std::make_pair(first, second)] = first_prob * second_prob;
}
}
}
std::map<std::string, float> inclusion_probs;
for (const auto string : strings) {
float inclusion_prob = 0.0;
for (const auto other_string : strings) {
inclusion_prob += pair_probs[std::make_pair(string, other_string)];
}
for (const auto other_string : strings) {
inclusion_prob += pair_probs[std::make_pair(other_string, string)];
}
inclusion_probs[string] = inclusion_prob / 2;
}
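For two draws without replacement, the inclusion probability tabulated here is

  \pi_s = p_s + \sum_{t \neq s} p_t \cdot \frac{p_s}{1 - p_t},

and because the observed counts below are normalized by kTrials * num_samples, the comparison uses \pi_s / 2 when num_samples = 2.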
int kTrials = 10000;
std::vector<int> kNumSamples = {1, 2};
for (const auto num_samples : kNumSamples) {
std::map<std::string, int> counts;
for (int i = 0; i < kTrials; i++) {
auto nbests = lattice.NBest(num_samples, true, theta);
for (const auto nbest : nbests) {
counts[GetTokenized(nbest.first)]++;
}
}
EXPECT_EQ(inclusion_probs.size(), counts.size());
// If we take multiple samples WOR, we have to use corrected probs.
std::map<std::string, float> probs_to_use =
(num_samples == 1 ? probs : inclusion_probs);
for (const auto &it : probs_to_use) {
EXPECT_NEAR(it.second, 1.0 * counts[it.first] / (kTrials * num_samples),
0.02);
}
}
}
}
TEST(LatticeTest, CalculateEntropyTest) {
Lattice lattice;
lattice.SetSentence("ABC");
InsertWithScore(&lattice, 0, 1, 0.0); // A
InsertWithScore(&lattice, 1, 1, 0.0); // B
InsertWithScore(&lattice, 2, 1, 0.1); // C
InsertWithScore(&lattice, 0, 2, 0.2); // AB
InsertWithScore(&lattice, 1, 2, 0.5); // BC
InsertWithScore(&lattice, 0, 3, 1.0); // ABC
// Calculate expected probabilities of each path
const std::vector<double> kTheta = {0.0, 0.01, 0.5, 0.7, 1.0};
for (const auto theta : kTheta) {
std::vector<std::string> strings = {"ABC", "AB C", "A BC", "A B C"};
std::map<std::string, float> probs;
probs["ABC"] = std::exp(theta * 1.0);
probs["AB C"] = std::exp(theta * (0.2 + 0.1));
probs["A BC"] = std::exp(theta * (0.0 + 0.5));
probs["A B C"] = std::exp(theta * (0.0 + 0.0 + 0.1));
double Z = 0.0;
for (const auto &it : probs) Z += it.second;
for (auto &it : probs) it.second /= Z;
for (const auto &it : strings) {
EXPECT_EQ(1, probs.count(it));
}
float entropy = 0.0;
for (const auto &it : probs) {
entropy += (it.second * std::log(it.second));
}
EXPECT_NEAR(-entropy, lattice.CalculateEntropy(theta), 0.02);
}
}
TEST(LatticeTest, ForwardAlgorithmTest) {
Lattice lattice;
lattice.SetSentence("ABC");
InsertWithScore(&lattice, 0, 1, 0.0); // A
InsertWithScore(&lattice, 1, 1, 0.0); // B
InsertWithScore(&lattice, 2, 1, 0.1); // C
InsertWithScore(&lattice, 0, 2, 0.2); // AB
InsertWithScore(&lattice, 1, 2, 0.5); // BC
InsertWithScore(&lattice, 0, 3, 1.0); // ABC
const std::vector<float> kTheta = {0.0, 0.01, 0.5, 0.7, 1.0};
for (const auto theta : kTheta) {
std::vector<float> alpha = lattice.ForwardAlgorithm(theta);
EXPECT_EQ(alpha.size(), 8); // 6 nodes, plus BOS, EOS
// only alpha[C], alpha[EOS] have non-zero alpha
for (int i : {0, 1, 2, 3}) {
for (const auto &node : lattice.begin_nodes(i)) {
if (i < 2) {
EXPECT_EQ(alpha[node->node_id], 0.0);
} else if (i == 2) {
float Z =
std::log(std::exp(theta * (0.0 + 0.0)) + std::exp(theta * 0.2));
EXPECT_EQ(alpha[node->node_id], Z);
} else if (i == 3) {
float Z = std::log(std::exp(theta * (0.0 + 0.0 + 0.1)) + // A + B + C
std::exp(theta * (0.2 + 0.1)) + // AB + C
std::exp(theta * (0.0 + 0.5)) + // A + BC
std::exp(theta * 1.0)); // ABC
EXPECT_EQ(Z, alpha[node->node_id]);
}
}
}
}
}
TEST(LatticeTest, PopulateMarginalTest) {
Lattice lattice;
lattice.SetSentence("ABC");
@@ -514,102 +361,6 @@ TEST(UnigramModelTest, SetUnigramModelTest) {
model.model_proto().SerializeAsString());
}
TEST(UnigramModelTest, SampleEncodeAndScoreTest) {
// Test whether inclusion probabilities are correct
ModelProto model_proto = MakeBaseModelProto();
AddPiece(&model_proto, "A", 0.0); // 3
AddPiece(&model_proto, "B", 0.0); // 4
AddPiece(&model_proto, "C", 0.1); // 5
AddPiece(&model_proto, "AB", 0.2); // 6
AddPiece(&model_proto, "BC", 0.5); // 7
AddPiece(&model_proto, "ABC", 1.0); // 8
Model model(model_proto);
Lattice lattice;
lattice.SetSentence("ABC");
model.PopulateNodes(&lattice);
std::vector<float> kTheta = {0.0, 1.0};
for (const auto theta : kTheta) {
std::vector<std::string> strings = {"ABC", "AB C", "A BC", "A B C"};
std::map<std::string, float> probs;
probs["ABC"] = std::exp(theta * 1.0);
probs["AB C"] = std::exp(theta * (0.2 + 0.1));
probs["A BC"] = std::exp(theta * (0.0 + 0.5));
probs["A B C"] = std::exp(theta * (0.0 + 0.0 + 0.1));
for (const auto &it : strings) {
EXPECT_EQ(1, probs.count(it));
}
double Z = 0.0;
for (const auto &it : probs) Z += it.second;
for (auto &it : probs) it.second /= Z;
std::map<std::pair<std::string, std::string>, float> pair_probs;
for (const auto first : strings) {
for (const auto second : strings) {
if (first == second) {
pair_probs[std::make_pair(first, second)] = 0;
} else {
float first_prob = probs[first];
float second_prob = probs[second] / (1 - first_prob);
pair_probs[std::make_pair(first, second)] = first_prob * second_prob;
}
}
}
std::map<std::string, float> inclusion_probs;
for (const auto string : strings) {
float inclusion_prob = 0.0;
for (const auto other_string : strings) {
inclusion_prob += pair_probs[std::make_pair(string, other_string)];
}
for (const auto other_string : strings) {
inclusion_prob += pair_probs[std::make_pair(other_string, string)];
}
inclusion_probs[string] = inclusion_prob / 2;
}
std::vector<int> kNumSamples = {1, 2};
for (const auto num_samples : kNumSamples) {
std::map<std::string, int> counts;
std::map<std::string, float> scores;
int kTrials = 50000;
for (int i = 0; i < kTrials; i++) {
NBestEncodeResult sample =
model.SampleEncodeAndScore("ABC", theta, num_samples, true, false);
for (const auto &it : sample) {
std::vector<std::string> tokens;
for (const auto &inner_it : it.first) {
tokens.push_back(std::string(inner_it.first));
}
std::string sample_string = absl::StrJoin(tokens, " ");
counts[sample_string] += 1;
// use the fact that E(1_{i in sample} / score of i) = 1
// see https://arxiv.org/pdf/1903.06059.pdf appendix D
scores[sample_string] += std::exp(-it.second);
}
}
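The scores accumulated above are checked against the inverse-inclusion-probability identity referenced in the comment (a Horvitz-Thompson style estimator):

  \mathbb{E}\left[\frac{\mathbf{1}\{x_i \in S\}}{\pi_i}\right] = 1,

so scores[sample_string] / kTrials should be close to 1 for every segmentation, which is what the EXPECT_NEAR below asserts.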
// Check that counts and probs are correct
std::map<std::string, float> probs_to_use =
(num_samples == 1 ? probs : inclusion_probs);
for (const auto &it : scores) Z += it.second;
for (const auto &it : probs_to_use) {
EXPECT_NEAR(it.second, 1.0 * counts[it.first] / (kTrials * num_samples),
0.02);
// The expectation is quite loose, use a higher tolerance
EXPECT_NEAR(1.0, scores[it.first] / kTrials, 0.05);
}
}
}
}
TEST_P(UnigramModelTest, PieceToIdTest) {
ModelProto model_proto = MakeBaseModelProto();

View File

@@ -223,7 +223,7 @@ std::vector<float> Trainer::RunEStep(const TrainerModel &model, float *obj,
lattice.SetSentence(w);
model.PopulateNodes(&lattice);
const float Z = lattice.PopulateMarginal(freq, &expected[n]);
ntokens[n] += lattice.Viterbi().first.size();
ntokens[n] += lattice.Viterbi().size();
CHECK(!std::isnan(Z))
<< "likelihood is NAN. Input sentence may be too long";
objs[n] -= Z / all_sentence_freq;
@@ -297,17 +297,17 @@ TrainerModel::SentencePieces Trainer::PruneSentencePieces(
const auto &w = sentencepieces[i];
lattice.SetSentence(w.first);
model.PopulateNodes(&lattice);
const auto nbests = lattice.NBest(2, false, 0.0);
const auto nbests = lattice.NBest(2);
if (nbests.size() == 1) {
// No second-best result is found. always keep this sentencepiece.
always_keep[i] = true;
continue;
} else if (nbests[0].first.size() >= 2) {
} else if (nbests[0].size() >= 2) {
// Can safely remove this sentencepiece if its Viterbi path is split.
always_keep[i] = false;
} else if (nbests[0].first.size() == 1) {
} else if (nbests[0].size() == 1) {
always_keep[i] = true;
for (const auto *node : nbests[1].first) {
for (const auto *node : nbests[1]) {
alternatives[i].push_back(node->id);
}
}
@@ -339,7 +339,7 @@ TrainerModel::SentencePieces Trainer::PruneSentencePieces(
lattice.SetSentence(w.first);
model.PopulateNodes(&lattice);
vsums[n] += w.second;
for (const auto *node : lattice.Viterbi().first) {
for (const auto *node : lattice.Viterbi()) {
if (node->id >= 0) {
freqs[n][node->id] += w.second;
inverteds[n][node->id].push_back(i);

View File

@@ -12,12 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.!
#include <iostream>
#include "util.h"
namespace sentencepiece {
#include <iostream>
namespace sentencepiece {
namespace {
constexpr unsigned int kDefaultSeed = static_cast<unsigned int>(-1);
static unsigned int g_seed = kDefaultSeed;

View File

@@ -171,7 +171,6 @@ void Flag<bool>::set_value_as_str(const std::string &value_as_str) {
template class Flag<std::string>;
template class Flag<int32>;
template class Flag<uint32>;
template class Flag<double>;
template class Flag<bool>;
template class Flag<int64>;