mirror of
https://github.com/google/sentencepiece.git
synced 2024-09-11 10:55:42 +03:00
parent
05db0894d8
commit
3a5bc5815b
@ -1 +1 @@
|
|||||||
0.1.96
|
0.1.95
|
||||||
|
@ -1 +1 @@
|
|||||||
0.1.96
|
0.1.95
|
||||||
|
@ -12,6 +12,8 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.!
|
// limitations under the License.!
|
||||||
|
|
||||||
|
#include "bpe_model.h"
|
||||||
|
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <queue>
|
#include <queue>
|
||||||
@ -19,7 +21,6 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "bpe_model.h"
|
|
||||||
#include "freelist.h"
|
#include "freelist.h"
|
||||||
#include "third_party/absl/container/flat_hash_map.h"
|
#include "third_party/absl/container/flat_hash_map.h"
|
||||||
#include "util.h"
|
#include "util.h"
|
||||||
|
@ -12,11 +12,12 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.!
|
// limitations under the License.!
|
||||||
|
|
||||||
|
#include "builder.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "builder.h"
|
|
||||||
#include "filesystem.h"
|
#include "filesystem.h"
|
||||||
#include "third_party/absl/strings/str_join.h"
|
#include "third_party/absl/strings/str_join.h"
|
||||||
#include "third_party/absl/strings/str_replace.h"
|
#include "third_party/absl/strings/str_replace.h"
|
||||||
@ -367,7 +368,6 @@ util::Status Builder::BuildNmtNFKCMap(CharsMap *chars_map) {
|
|||||||
nfkc_map[{0xFEFF}] = {0x20}; // ZERO WIDTH NO-BREAK
|
nfkc_map[{0xFEFF}] = {0x20}; // ZERO WIDTH NO-BREAK
|
||||||
nfkc_map[{0xFFFD}] = {0x20}; // REPLACEMENT CHARACTER
|
nfkc_map[{0xFFFD}] = {0x20}; // REPLACEMENT CHARACTER
|
||||||
nfkc_map[{0x200C}] = {0x20}; // ZERO WIDTH NON-JOINER
|
nfkc_map[{0x200C}] = {0x20}; // ZERO WIDTH NON-JOINER
|
||||||
nfkc_map[{0x200D}] = {0x20}; // ZERO WIDTH JOINER
|
|
||||||
|
|
||||||
// Ascii Control characters
|
// Ascii Control characters
|
||||||
nfkc_map[{0x0001}] = {};
|
nfkc_map[{0x0001}] = {};
|
||||||
|
@ -285,22 +285,22 @@ class TrainerSpec::_Internal {
|
|||||||
(*has_bits)[0] |= 1u;
|
(*has_bits)[0] |= 1u;
|
||||||
}
|
}
|
||||||
static void set_has_model_type(HasBits* has_bits) {
|
static void set_has_model_type(HasBits* has_bits) {
|
||||||
(*has_bits)[0] |= 524288u;
|
(*has_bits)[0] |= 262144u;
|
||||||
}
|
}
|
||||||
static void set_has_vocab_size(HasBits* has_bits) {
|
static void set_has_vocab_size(HasBits* has_bits) {
|
||||||
(*has_bits)[0] |= 1048576u;
|
(*has_bits)[0] |= 524288u;
|
||||||
}
|
}
|
||||||
static void set_has_self_test_sample_size(HasBits* has_bits) {
|
static void set_has_self_test_sample_size(HasBits* has_bits) {
|
||||||
(*has_bits)[0] |= 256u;
|
(*has_bits)[0] |= 256u;
|
||||||
}
|
}
|
||||||
static void set_has_character_coverage(HasBits* has_bits) {
|
static void set_has_character_coverage(HasBits* has_bits) {
|
||||||
(*has_bits)[0] |= 2097152u;
|
(*has_bits)[0] |= 1048576u;
|
||||||
}
|
}
|
||||||
static void set_has_input_sentence_size(HasBits* has_bits) {
|
static void set_has_input_sentence_size(HasBits* has_bits) {
|
||||||
(*has_bits)[0] |= 1024u;
|
(*has_bits)[0] |= 1024u;
|
||||||
}
|
}
|
||||||
static void set_has_shuffle_input_sentence(HasBits* has_bits) {
|
static void set_has_shuffle_input_sentence(HasBits* has_bits) {
|
||||||
(*has_bits)[0] |= 268435456u;
|
(*has_bits)[0] |= 134217728u;
|
||||||
}
|
}
|
||||||
static void set_has_mining_sentence_size(HasBits* has_bits) {
|
static void set_has_mining_sentence_size(HasBits* has_bits) {
|
||||||
(*has_bits)[0] |= 512u;
|
(*has_bits)[0] |= 512u;
|
||||||
@ -309,67 +309,64 @@ class TrainerSpec::_Internal {
|
|||||||
(*has_bits)[0] |= 2048u;
|
(*has_bits)[0] |= 2048u;
|
||||||
}
|
}
|
||||||
static void set_has_seed_sentencepiece_size(HasBits* has_bits) {
|
static void set_has_seed_sentencepiece_size(HasBits* has_bits) {
|
||||||
(*has_bits)[0] |= 4194304u;
|
(*has_bits)[0] |= 2097152u;
|
||||||
}
|
}
|
||||||
static void set_has_shrinking_factor(HasBits* has_bits) {
|
static void set_has_shrinking_factor(HasBits* has_bits) {
|
||||||
(*has_bits)[0] |= 8388608u;
|
(*has_bits)[0] |= 4194304u;
|
||||||
}
|
}
|
||||||
static void set_has_max_sentence_length(HasBits* has_bits) {
|
static void set_has_max_sentence_length(HasBits* has_bits) {
|
||||||
(*has_bits)[0] |= 67108864u;
|
|
||||||
}
|
|
||||||
static void set_has_num_threads(HasBits* has_bits) {
|
|
||||||
(*has_bits)[0] |= 16777216u;
|
|
||||||
}
|
|
||||||
static void set_has_num_sub_iterations(HasBits* has_bits) {
|
|
||||||
(*has_bits)[0] |= 33554432u;
|
(*has_bits)[0] |= 33554432u;
|
||||||
}
|
}
|
||||||
|
static void set_has_num_threads(HasBits* has_bits) {
|
||||||
|
(*has_bits)[0] |= 8388608u;
|
||||||
|
}
|
||||||
|
static void set_has_num_sub_iterations(HasBits* has_bits) {
|
||||||
|
(*has_bits)[0] |= 16777216u;
|
||||||
|
}
|
||||||
static void set_has_max_sentencepiece_length(HasBits* has_bits) {
|
static void set_has_max_sentencepiece_length(HasBits* has_bits) {
|
||||||
(*has_bits)[0] |= 134217728u;
|
(*has_bits)[0] |= 67108864u;
|
||||||
}
|
}
|
||||||
static void set_has_split_by_unicode_script(HasBits* has_bits) {
|
static void set_has_split_by_unicode_script(HasBits* has_bits) {
|
||||||
(*has_bits)[0] |= 536870912u;
|
(*has_bits)[0] |= 268435456u;
|
||||||
}
|
}
|
||||||
static void set_has_split_by_number(HasBits* has_bits) {
|
static void set_has_split_by_number(HasBits* has_bits) {
|
||||||
(*has_bits)[0] |= 1073741824u;
|
(*has_bits)[0] |= 536870912u;
|
||||||
}
|
}
|
||||||
static void set_has_split_by_whitespace(HasBits* has_bits) {
|
static void set_has_split_by_whitespace(HasBits* has_bits) {
|
||||||
(*has_bits)[0] |= 2147483648u;
|
(*has_bits)[0] |= 1073741824u;
|
||||||
}
|
}
|
||||||
static void set_has_treat_whitespace_as_suffix(HasBits* has_bits) {
|
static void set_has_treat_whitespace_as_suffix(HasBits* has_bits) {
|
||||||
(*has_bits)[0] |= 4096u;
|
(*has_bits)[0] |= 4096u;
|
||||||
}
|
}
|
||||||
static void set_has_allow_whitespace_only_pieces(HasBits* has_bits) {
|
|
||||||
(*has_bits)[0] |= 8192u;
|
|
||||||
}
|
|
||||||
static void set_has_split_digits(HasBits* has_bits) {
|
static void set_has_split_digits(HasBits* has_bits) {
|
||||||
(*has_bits)[0] |= 16384u;
|
(*has_bits)[0] |= 8192u;
|
||||||
}
|
}
|
||||||
static void set_has_required_chars(HasBits* has_bits) {
|
static void set_has_required_chars(HasBits* has_bits) {
|
||||||
(*has_bits)[0] |= 4u;
|
(*has_bits)[0] |= 4u;
|
||||||
}
|
}
|
||||||
static void set_has_byte_fallback(HasBits* has_bits) {
|
static void set_has_byte_fallback(HasBits* has_bits) {
|
||||||
(*has_bits)[0] |= 32768u;
|
(*has_bits)[0] |= 16384u;
|
||||||
}
|
}
|
||||||
static void set_has_vocabulary_output_piece_score(HasBits* has_bits) {
|
static void set_has_vocabulary_output_piece_score(HasBits* has_bits) {
|
||||||
(*has_bits)[1] |= 1u;
|
(*has_bits)[0] |= 2147483648u;
|
||||||
}
|
}
|
||||||
static void set_has_hard_vocab_limit(HasBits* has_bits) {
|
static void set_has_hard_vocab_limit(HasBits* has_bits) {
|
||||||
(*has_bits)[1] |= 2u;
|
(*has_bits)[1] |= 1u;
|
||||||
}
|
}
|
||||||
static void set_has_use_all_vocab(HasBits* has_bits) {
|
static void set_has_use_all_vocab(HasBits* has_bits) {
|
||||||
(*has_bits)[0] |= 131072u;
|
(*has_bits)[0] |= 32768u;
|
||||||
}
|
}
|
||||||
static void set_has_unk_id(HasBits* has_bits) {
|
static void set_has_unk_id(HasBits* has_bits) {
|
||||||
(*has_bits)[0] |= 65536u;
|
(*has_bits)[0] |= 65536u;
|
||||||
}
|
}
|
||||||
static void set_has_bos_id(HasBits* has_bits) {
|
static void set_has_bos_id(HasBits* has_bits) {
|
||||||
(*has_bits)[1] |= 4u;
|
(*has_bits)[1] |= 2u;
|
||||||
}
|
}
|
||||||
static void set_has_eos_id(HasBits* has_bits) {
|
static void set_has_eos_id(HasBits* has_bits) {
|
||||||
(*has_bits)[1] |= 8u;
|
(*has_bits)[1] |= 4u;
|
||||||
}
|
}
|
||||||
static void set_has_pad_id(HasBits* has_bits) {
|
static void set_has_pad_id(HasBits* has_bits) {
|
||||||
(*has_bits)[1] |= 16u;
|
(*has_bits)[1] |= 8u;
|
||||||
}
|
}
|
||||||
static void set_has_unk_piece(HasBits* has_bits) {
|
static void set_has_unk_piece(HasBits* has_bits) {
|
||||||
(*has_bits)[0] |= 16u;
|
(*has_bits)[0] |= 16u;
|
||||||
@ -387,7 +384,7 @@ class TrainerSpec::_Internal {
|
|||||||
(*has_bits)[0] |= 8u;
|
(*has_bits)[0] |= 8u;
|
||||||
}
|
}
|
||||||
static void set_has_train_extremely_large_corpus(HasBits* has_bits) {
|
static void set_has_train_extremely_large_corpus(HasBits* has_bits) {
|
||||||
(*has_bits)[0] |= 262144u;
|
(*has_bits)[0] |= 131072u;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -569,8 +566,8 @@ void TrainerSpec::Clear() {
|
|||||||
}
|
}
|
||||||
if (cached_has_bits & 0x0000ff00u) {
|
if (cached_has_bits & 0x0000ff00u) {
|
||||||
::memset(&self_test_sample_size_, 0, static_cast<size_t>(
|
::memset(&self_test_sample_size_, 0, static_cast<size_t>(
|
||||||
reinterpret_cast<char*>(&byte_fallback_) -
|
reinterpret_cast<char*>(&use_all_vocab_) -
|
||||||
reinterpret_cast<char*>(&self_test_sample_size_)) + sizeof(byte_fallback_));
|
reinterpret_cast<char*>(&self_test_sample_size_)) + sizeof(use_all_vocab_));
|
||||||
}
|
}
|
||||||
if (cached_has_bits & 0x00ff0000u) {
|
if (cached_has_bits & 0x00ff0000u) {
|
||||||
::memset(&unk_id_, 0, static_cast<size_t>(
|
::memset(&unk_id_, 0, static_cast<size_t>(
|
||||||
@ -581,9 +578,9 @@ void TrainerSpec::Clear() {
|
|||||||
character_coverage_ = 0.9995f;
|
character_coverage_ = 0.9995f;
|
||||||
seed_sentencepiece_size_ = 1000000;
|
seed_sentencepiece_size_ = 1000000;
|
||||||
shrinking_factor_ = 0.75f;
|
shrinking_factor_ = 0.75f;
|
||||||
|
num_threads_ = 16;
|
||||||
}
|
}
|
||||||
if (cached_has_bits & 0xff000000u) {
|
if (cached_has_bits & 0xff000000u) {
|
||||||
num_threads_ = 16;
|
|
||||||
num_sub_iterations_ = 2;
|
num_sub_iterations_ = 2;
|
||||||
max_sentence_length_ = 4192;
|
max_sentence_length_ = 4192;
|
||||||
max_sentencepiece_length_ = 16;
|
max_sentencepiece_length_ = 16;
|
||||||
@ -591,10 +588,10 @@ void TrainerSpec::Clear() {
|
|||||||
split_by_unicode_script_ = true;
|
split_by_unicode_script_ = true;
|
||||||
split_by_number_ = true;
|
split_by_number_ = true;
|
||||||
split_by_whitespace_ = true;
|
split_by_whitespace_ = true;
|
||||||
|
vocabulary_output_piece_score_ = true;
|
||||||
}
|
}
|
||||||
cached_has_bits = _has_bits_[1];
|
cached_has_bits = _has_bits_[1];
|
||||||
if (cached_has_bits & 0x0000001fu) {
|
if (cached_has_bits & 0x0000000fu) {
|
||||||
vocabulary_output_piece_score_ = true;
|
|
||||||
hard_vocab_limit_ = true;
|
hard_vocab_limit_ = true;
|
||||||
bos_id_ = 1;
|
bos_id_ = 1;
|
||||||
eos_id_ = 2;
|
eos_id_ = 2;
|
||||||
@ -809,14 +806,6 @@ const char* TrainerSpec::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID
|
|||||||
CHK_(ptr);
|
CHK_(ptr);
|
||||||
} else goto handle_unusual;
|
} else goto handle_unusual;
|
||||||
continue;
|
continue;
|
||||||
// optional bool allow_whitespace_only_pieces = 26 [default = false];
|
|
||||||
case 26:
|
|
||||||
if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 208)) {
|
|
||||||
_Internal::set_has_allow_whitespace_only_pieces(&_has_bits_);
|
|
||||||
allow_whitespace_only_pieces_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr);
|
|
||||||
CHK_(ptr);
|
|
||||||
} else goto handle_unusual;
|
|
||||||
continue;
|
|
||||||
// repeated string control_symbols = 30;
|
// repeated string control_symbols = 30;
|
||||||
case 30:
|
case 30:
|
||||||
if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 242)) {
|
if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 242)) {
|
||||||
@ -1011,14 +1000,14 @@ failure:
|
|||||||
}
|
}
|
||||||
|
|
||||||
// optional .sentencepiece.TrainerSpec.ModelType model_type = 3 [default = UNIGRAM];
|
// optional .sentencepiece.TrainerSpec.ModelType model_type = 3 [default = UNIGRAM];
|
||||||
if (cached_has_bits & 0x00080000u) {
|
if (cached_has_bits & 0x00040000u) {
|
||||||
target = stream->EnsureSpace(target);
|
target = stream->EnsureSpace(target);
|
||||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteEnumToArray(
|
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteEnumToArray(
|
||||||
3, this->_internal_model_type(), target);
|
3, this->_internal_model_type(), target);
|
||||||
}
|
}
|
||||||
|
|
||||||
// optional int32 vocab_size = 4 [default = 8000];
|
// optional int32 vocab_size = 4 [default = 8000];
|
||||||
if (cached_has_bits & 0x00100000u) {
|
if (cached_has_bits & 0x00080000u) {
|
||||||
target = stream->EnsureSpace(target);
|
target = stream->EnsureSpace(target);
|
||||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(4, this->_internal_vocab_size(), target);
|
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(4, this->_internal_vocab_size(), target);
|
||||||
}
|
}
|
||||||
@ -1042,7 +1031,7 @@ failure:
|
|||||||
}
|
}
|
||||||
|
|
||||||
// optional float character_coverage = 10 [default = 0.9995];
|
// optional float character_coverage = 10 [default = 0.9995];
|
||||||
if (cached_has_bits & 0x00200000u) {
|
if (cached_has_bits & 0x00100000u) {
|
||||||
target = stream->EnsureSpace(target);
|
target = stream->EnsureSpace(target);
|
||||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteFloatToArray(10, this->_internal_character_coverage(), target);
|
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteFloatToArray(10, this->_internal_character_coverage(), target);
|
||||||
}
|
}
|
||||||
@ -1066,61 +1055,61 @@ failure:
|
|||||||
}
|
}
|
||||||
|
|
||||||
// optional int32 seed_sentencepiece_size = 14 [default = 1000000];
|
// optional int32 seed_sentencepiece_size = 14 [default = 1000000];
|
||||||
if (cached_has_bits & 0x00400000u) {
|
if (cached_has_bits & 0x00200000u) {
|
||||||
target = stream->EnsureSpace(target);
|
target = stream->EnsureSpace(target);
|
||||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(14, this->_internal_seed_sentencepiece_size(), target);
|
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(14, this->_internal_seed_sentencepiece_size(), target);
|
||||||
}
|
}
|
||||||
|
|
||||||
// optional float shrinking_factor = 15 [default = 0.75];
|
// optional float shrinking_factor = 15 [default = 0.75];
|
||||||
if (cached_has_bits & 0x00800000u) {
|
if (cached_has_bits & 0x00400000u) {
|
||||||
target = stream->EnsureSpace(target);
|
target = stream->EnsureSpace(target);
|
||||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteFloatToArray(15, this->_internal_shrinking_factor(), target);
|
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteFloatToArray(15, this->_internal_shrinking_factor(), target);
|
||||||
}
|
}
|
||||||
|
|
||||||
// optional int32 num_threads = 16 [default = 16];
|
// optional int32 num_threads = 16 [default = 16];
|
||||||
if (cached_has_bits & 0x01000000u) {
|
if (cached_has_bits & 0x00800000u) {
|
||||||
target = stream->EnsureSpace(target);
|
target = stream->EnsureSpace(target);
|
||||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(16, this->_internal_num_threads(), target);
|
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(16, this->_internal_num_threads(), target);
|
||||||
}
|
}
|
||||||
|
|
||||||
// optional int32 num_sub_iterations = 17 [default = 2];
|
// optional int32 num_sub_iterations = 17 [default = 2];
|
||||||
if (cached_has_bits & 0x02000000u) {
|
if (cached_has_bits & 0x01000000u) {
|
||||||
target = stream->EnsureSpace(target);
|
target = stream->EnsureSpace(target);
|
||||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(17, this->_internal_num_sub_iterations(), target);
|
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(17, this->_internal_num_sub_iterations(), target);
|
||||||
}
|
}
|
||||||
|
|
||||||
// optional int32 max_sentence_length = 18 [default = 4192];
|
// optional int32 max_sentence_length = 18 [default = 4192];
|
||||||
if (cached_has_bits & 0x04000000u) {
|
if (cached_has_bits & 0x02000000u) {
|
||||||
target = stream->EnsureSpace(target);
|
target = stream->EnsureSpace(target);
|
||||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(18, this->_internal_max_sentence_length(), target);
|
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(18, this->_internal_max_sentence_length(), target);
|
||||||
}
|
}
|
||||||
|
|
||||||
// optional bool shuffle_input_sentence = 19 [default = true];
|
// optional bool shuffle_input_sentence = 19 [default = true];
|
||||||
if (cached_has_bits & 0x10000000u) {
|
if (cached_has_bits & 0x08000000u) {
|
||||||
target = stream->EnsureSpace(target);
|
target = stream->EnsureSpace(target);
|
||||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(19, this->_internal_shuffle_input_sentence(), target);
|
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(19, this->_internal_shuffle_input_sentence(), target);
|
||||||
}
|
}
|
||||||
|
|
||||||
// optional int32 max_sentencepiece_length = 20 [default = 16];
|
// optional int32 max_sentencepiece_length = 20 [default = 16];
|
||||||
if (cached_has_bits & 0x08000000u) {
|
if (cached_has_bits & 0x04000000u) {
|
||||||
target = stream->EnsureSpace(target);
|
target = stream->EnsureSpace(target);
|
||||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(20, this->_internal_max_sentencepiece_length(), target);
|
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(20, this->_internal_max_sentencepiece_length(), target);
|
||||||
}
|
}
|
||||||
|
|
||||||
// optional bool split_by_unicode_script = 21 [default = true];
|
// optional bool split_by_unicode_script = 21 [default = true];
|
||||||
if (cached_has_bits & 0x20000000u) {
|
if (cached_has_bits & 0x10000000u) {
|
||||||
target = stream->EnsureSpace(target);
|
target = stream->EnsureSpace(target);
|
||||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(21, this->_internal_split_by_unicode_script(), target);
|
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(21, this->_internal_split_by_unicode_script(), target);
|
||||||
}
|
}
|
||||||
|
|
||||||
// optional bool split_by_whitespace = 22 [default = true];
|
// optional bool split_by_whitespace = 22 [default = true];
|
||||||
if (cached_has_bits & 0x80000000u) {
|
if (cached_has_bits & 0x40000000u) {
|
||||||
target = stream->EnsureSpace(target);
|
target = stream->EnsureSpace(target);
|
||||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(22, this->_internal_split_by_whitespace(), target);
|
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(22, this->_internal_split_by_whitespace(), target);
|
||||||
}
|
}
|
||||||
|
|
||||||
// optional bool split_by_number = 23 [default = true];
|
// optional bool split_by_number = 23 [default = true];
|
||||||
if (cached_has_bits & 0x40000000u) {
|
if (cached_has_bits & 0x20000000u) {
|
||||||
target = stream->EnsureSpace(target);
|
target = stream->EnsureSpace(target);
|
||||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(23, this->_internal_split_by_number(), target);
|
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(23, this->_internal_split_by_number(), target);
|
||||||
}
|
}
|
||||||
@ -1132,15 +1121,9 @@ failure:
|
|||||||
}
|
}
|
||||||
|
|
||||||
// optional bool split_digits = 25 [default = false];
|
// optional bool split_digits = 25 [default = false];
|
||||||
if (cached_has_bits & 0x00004000u) {
|
|
||||||
target = stream->EnsureSpace(target);
|
|
||||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(25, this->_internal_split_digits(), target);
|
|
||||||
}
|
|
||||||
|
|
||||||
// optional bool allow_whitespace_only_pieces = 26 [default = false];
|
|
||||||
if (cached_has_bits & 0x00002000u) {
|
if (cached_has_bits & 0x00002000u) {
|
||||||
target = stream->EnsureSpace(target);
|
target = stream->EnsureSpace(target);
|
||||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(26, this->_internal_allow_whitespace_only_pieces(), target);
|
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(25, this->_internal_split_digits(), target);
|
||||||
}
|
}
|
||||||
|
|
||||||
// repeated string control_symbols = 30;
|
// repeated string control_symbols = 30;
|
||||||
@ -1155,28 +1138,28 @@ failure:
|
|||||||
target = stream->WriteString(31, s, target);
|
target = stream->WriteString(31, s, target);
|
||||||
}
|
}
|
||||||
|
|
||||||
cached_has_bits = _has_bits_[1];
|
|
||||||
// optional bool vocabulary_output_piece_score = 32 [default = true];
|
// optional bool vocabulary_output_piece_score = 32 [default = true];
|
||||||
if (cached_has_bits & 0x00000001u) {
|
if (cached_has_bits & 0x80000000u) {
|
||||||
target = stream->EnsureSpace(target);
|
target = stream->EnsureSpace(target);
|
||||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(32, this->_internal_vocabulary_output_piece_score(), target);
|
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(32, this->_internal_vocabulary_output_piece_score(), target);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cached_has_bits = _has_bits_[1];
|
||||||
// optional bool hard_vocab_limit = 33 [default = true];
|
// optional bool hard_vocab_limit = 33 [default = true];
|
||||||
if (cached_has_bits & 0x00000002u) {
|
if (cached_has_bits & 0x00000001u) {
|
||||||
target = stream->EnsureSpace(target);
|
target = stream->EnsureSpace(target);
|
||||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(33, this->_internal_hard_vocab_limit(), target);
|
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(33, this->_internal_hard_vocab_limit(), target);
|
||||||
}
|
}
|
||||||
|
|
||||||
cached_has_bits = _has_bits_[0];
|
cached_has_bits = _has_bits_[0];
|
||||||
// optional bool use_all_vocab = 34 [default = false];
|
// optional bool use_all_vocab = 34 [default = false];
|
||||||
if (cached_has_bits & 0x00020000u) {
|
if (cached_has_bits & 0x00008000u) {
|
||||||
target = stream->EnsureSpace(target);
|
target = stream->EnsureSpace(target);
|
||||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(34, this->_internal_use_all_vocab(), target);
|
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(34, this->_internal_use_all_vocab(), target);
|
||||||
}
|
}
|
||||||
|
|
||||||
// optional bool byte_fallback = 35 [default = false];
|
// optional bool byte_fallback = 35 [default = false];
|
||||||
if (cached_has_bits & 0x00008000u) {
|
if (cached_has_bits & 0x00004000u) {
|
||||||
target = stream->EnsureSpace(target);
|
target = stream->EnsureSpace(target);
|
||||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(35, this->_internal_byte_fallback(), target);
|
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(35, this->_internal_byte_fallback(), target);
|
||||||
}
|
}
|
||||||
@ -1195,19 +1178,19 @@ failure:
|
|||||||
|
|
||||||
cached_has_bits = _has_bits_[1];
|
cached_has_bits = _has_bits_[1];
|
||||||
// optional int32 bos_id = 41 [default = 1];
|
// optional int32 bos_id = 41 [default = 1];
|
||||||
if (cached_has_bits & 0x00000004u) {
|
if (cached_has_bits & 0x00000002u) {
|
||||||
target = stream->EnsureSpace(target);
|
target = stream->EnsureSpace(target);
|
||||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(41, this->_internal_bos_id(), target);
|
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(41, this->_internal_bos_id(), target);
|
||||||
}
|
}
|
||||||
|
|
||||||
// optional int32 eos_id = 42 [default = 2];
|
// optional int32 eos_id = 42 [default = 2];
|
||||||
if (cached_has_bits & 0x00000008u) {
|
if (cached_has_bits & 0x00000004u) {
|
||||||
target = stream->EnsureSpace(target);
|
target = stream->EnsureSpace(target);
|
||||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(42, this->_internal_eos_id(), target);
|
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(42, this->_internal_eos_id(), target);
|
||||||
}
|
}
|
||||||
|
|
||||||
// optional int32 pad_id = 43 [default = -1];
|
// optional int32 pad_id = 43 [default = -1];
|
||||||
if (cached_has_bits & 0x00000010u) {
|
if (cached_has_bits & 0x00000008u) {
|
||||||
target = stream->EnsureSpace(target);
|
target = stream->EnsureSpace(target);
|
||||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(43, this->_internal_pad_id(), target);
|
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(43, this->_internal_pad_id(), target);
|
||||||
}
|
}
|
||||||
@ -1244,7 +1227,7 @@ failure:
|
|||||||
}
|
}
|
||||||
|
|
||||||
// optional bool train_extremely_large_corpus = 49 [default = false];
|
// optional bool train_extremely_large_corpus = 49 [default = false];
|
||||||
if (cached_has_bits & 0x00040000u) {
|
if (cached_has_bits & 0x00020000u) {
|
||||||
target = stream->EnsureSpace(target);
|
target = stream->EnsureSpace(target);
|
||||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(49, this->_internal_train_extremely_large_corpus(), target);
|
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(49, this->_internal_train_extremely_large_corpus(), target);
|
||||||
}
|
}
|
||||||
@ -1396,17 +1379,17 @@ size_t TrainerSpec::ByteSizeLong() const {
|
|||||||
total_size += 2 + 1;
|
total_size += 2 + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// optional bool allow_whitespace_only_pieces = 26 [default = false];
|
// optional bool split_digits = 25 [default = false];
|
||||||
if (cached_has_bits & 0x00002000u) {
|
if (cached_has_bits & 0x00002000u) {
|
||||||
total_size += 2 + 1;
|
total_size += 2 + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// optional bool split_digits = 25 [default = false];
|
// optional bool byte_fallback = 35 [default = false];
|
||||||
if (cached_has_bits & 0x00004000u) {
|
if (cached_has_bits & 0x00004000u) {
|
||||||
total_size += 2 + 1;
|
total_size += 2 + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// optional bool byte_fallback = 35 [default = false];
|
// optional bool use_all_vocab = 34 [default = false];
|
||||||
if (cached_has_bits & 0x00008000u) {
|
if (cached_has_bits & 0x00008000u) {
|
||||||
total_size += 2 + 1;
|
total_size += 2 + 1;
|
||||||
}
|
}
|
||||||
@ -1420,125 +1403,120 @@ size_t TrainerSpec::ByteSizeLong() const {
|
|||||||
this->_internal_unk_id());
|
this->_internal_unk_id());
|
||||||
}
|
}
|
||||||
|
|
||||||
// optional bool use_all_vocab = 34 [default = false];
|
// optional bool train_extremely_large_corpus = 49 [default = false];
|
||||||
if (cached_has_bits & 0x00020000u) {
|
if (cached_has_bits & 0x00020000u) {
|
||||||
total_size += 2 + 1;
|
total_size += 2 + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// optional bool train_extremely_large_corpus = 49 [default = false];
|
|
||||||
if (cached_has_bits & 0x00040000u) {
|
|
||||||
total_size += 2 + 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// optional .sentencepiece.TrainerSpec.ModelType model_type = 3 [default = UNIGRAM];
|
// optional .sentencepiece.TrainerSpec.ModelType model_type = 3 [default = UNIGRAM];
|
||||||
if (cached_has_bits & 0x00080000u) {
|
if (cached_has_bits & 0x00040000u) {
|
||||||
total_size += 1 +
|
total_size += 1 +
|
||||||
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::EnumSize(this->_internal_model_type());
|
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::EnumSize(this->_internal_model_type());
|
||||||
}
|
}
|
||||||
|
|
||||||
// optional int32 vocab_size = 4 [default = 8000];
|
// optional int32 vocab_size = 4 [default = 8000];
|
||||||
if (cached_has_bits & 0x00100000u) {
|
if (cached_has_bits & 0x00080000u) {
|
||||||
total_size += 1 +
|
total_size += 1 +
|
||||||
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
|
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
|
||||||
this->_internal_vocab_size());
|
this->_internal_vocab_size());
|
||||||
}
|
}
|
||||||
|
|
||||||
// optional float character_coverage = 10 [default = 0.9995];
|
// optional float character_coverage = 10 [default = 0.9995];
|
||||||
if (cached_has_bits & 0x00200000u) {
|
if (cached_has_bits & 0x00100000u) {
|
||||||
total_size += 1 + 4;
|
total_size += 1 + 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
// optional int32 seed_sentencepiece_size = 14 [default = 1000000];
|
// optional int32 seed_sentencepiece_size = 14 [default = 1000000];
|
||||||
if (cached_has_bits & 0x00400000u) {
|
if (cached_has_bits & 0x00200000u) {
|
||||||
total_size += 1 +
|
total_size += 1 +
|
||||||
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
|
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
|
||||||
this->_internal_seed_sentencepiece_size());
|
this->_internal_seed_sentencepiece_size());
|
||||||
}
|
}
|
||||||
|
|
||||||
// optional float shrinking_factor = 15 [default = 0.75];
|
// optional float shrinking_factor = 15 [default = 0.75];
|
||||||
if (cached_has_bits & 0x00800000u) {
|
if (cached_has_bits & 0x00400000u) {
|
||||||
total_size += 1 + 4;
|
total_size += 1 + 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
|
||||||
if (cached_has_bits & 0xff000000u) {
|
|
||||||
// optional int32 num_threads = 16 [default = 16];
|
// optional int32 num_threads = 16 [default = 16];
|
||||||
if (cached_has_bits & 0x01000000u) {
|
if (cached_has_bits & 0x00800000u) {
|
||||||
total_size += 2 +
|
total_size += 2 +
|
||||||
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
|
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
|
||||||
this->_internal_num_threads());
|
this->_internal_num_threads());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
if (cached_has_bits & 0xff000000u) {
|
||||||
// optional int32 num_sub_iterations = 17 [default = 2];
|
// optional int32 num_sub_iterations = 17 [default = 2];
|
||||||
if (cached_has_bits & 0x02000000u) {
|
if (cached_has_bits & 0x01000000u) {
|
||||||
total_size += 2 +
|
total_size += 2 +
|
||||||
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
|
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
|
||||||
this->_internal_num_sub_iterations());
|
this->_internal_num_sub_iterations());
|
||||||
}
|
}
|
||||||
|
|
||||||
// optional int32 max_sentence_length = 18 [default = 4192];
|
// optional int32 max_sentence_length = 18 [default = 4192];
|
||||||
if (cached_has_bits & 0x04000000u) {
|
if (cached_has_bits & 0x02000000u) {
|
||||||
total_size += 2 +
|
total_size += 2 +
|
||||||
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
|
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
|
||||||
this->_internal_max_sentence_length());
|
this->_internal_max_sentence_length());
|
||||||
}
|
}
|
||||||
|
|
||||||
// optional int32 max_sentencepiece_length = 20 [default = 16];
|
// optional int32 max_sentencepiece_length = 20 [default = 16];
|
||||||
if (cached_has_bits & 0x08000000u) {
|
if (cached_has_bits & 0x04000000u) {
|
||||||
total_size += 2 +
|
total_size += 2 +
|
||||||
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
|
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
|
||||||
this->_internal_max_sentencepiece_length());
|
this->_internal_max_sentencepiece_length());
|
||||||
}
|
}
|
||||||
|
|
||||||
// optional bool shuffle_input_sentence = 19 [default = true];
|
// optional bool shuffle_input_sentence = 19 [default = true];
|
||||||
if (cached_has_bits & 0x10000000u) {
|
if (cached_has_bits & 0x08000000u) {
|
||||||
total_size += 2 + 1;
|
total_size += 2 + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// optional bool split_by_unicode_script = 21 [default = true];
|
// optional bool split_by_unicode_script = 21 [default = true];
|
||||||
if (cached_has_bits & 0x20000000u) {
|
if (cached_has_bits & 0x10000000u) {
|
||||||
total_size += 2 + 1;
|
total_size += 2 + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// optional bool split_by_number = 23 [default = true];
|
// optional bool split_by_number = 23 [default = true];
|
||||||
if (cached_has_bits & 0x40000000u) {
|
if (cached_has_bits & 0x20000000u) {
|
||||||
total_size += 2 + 1;
|
total_size += 2 + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// optional bool split_by_whitespace = 22 [default = true];
|
// optional bool split_by_whitespace = 22 [default = true];
|
||||||
|
if (cached_has_bits & 0x40000000u) {
|
||||||
|
total_size += 2 + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// optional bool vocabulary_output_piece_score = 32 [default = true];
|
||||||
if (cached_has_bits & 0x80000000u) {
|
if (cached_has_bits & 0x80000000u) {
|
||||||
total_size += 2 + 1;
|
total_size += 2 + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
cached_has_bits = _has_bits_[1];
|
cached_has_bits = _has_bits_[1];
|
||||||
if (cached_has_bits & 0x0000001fu) {
|
if (cached_has_bits & 0x0000000fu) {
|
||||||
// optional bool vocabulary_output_piece_score = 32 [default = true];
|
// optional bool hard_vocab_limit = 33 [default = true];
|
||||||
if (cached_has_bits & 0x00000001u) {
|
if (cached_has_bits & 0x00000001u) {
|
||||||
total_size += 2 + 1;
|
total_size += 2 + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// optional bool hard_vocab_limit = 33 [default = true];
|
|
||||||
if (cached_has_bits & 0x00000002u) {
|
|
||||||
total_size += 2 + 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// optional int32 bos_id = 41 [default = 1];
|
// optional int32 bos_id = 41 [default = 1];
|
||||||
if (cached_has_bits & 0x00000004u) {
|
if (cached_has_bits & 0x00000002u) {
|
||||||
total_size += 2 +
|
total_size += 2 +
|
||||||
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
|
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
|
||||||
this->_internal_bos_id());
|
this->_internal_bos_id());
|
||||||
}
|
}
|
||||||
|
|
||||||
// optional int32 eos_id = 42 [default = 2];
|
// optional int32 eos_id = 42 [default = 2];
|
||||||
if (cached_has_bits & 0x00000008u) {
|
if (cached_has_bits & 0x00000004u) {
|
||||||
total_size += 2 +
|
total_size += 2 +
|
||||||
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
|
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
|
||||||
this->_internal_eos_id());
|
this->_internal_eos_id());
|
||||||
}
|
}
|
||||||
|
|
||||||
// optional int32 pad_id = 43 [default = -1];
|
// optional int32 pad_id = 43 [default = -1];
|
||||||
if (cached_has_bits & 0x00000010u) {
|
if (cached_has_bits & 0x00000008u) {
|
||||||
total_size += 2 +
|
total_size += 2 +
|
||||||
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
|
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size(
|
||||||
this->_internal_pad_id());
|
this->_internal_pad_id());
|
||||||
@ -1615,14 +1593,14 @@ void TrainerSpec::MergeFrom(const TrainerSpec& from) {
|
|||||||
treat_whitespace_as_suffix_ = from.treat_whitespace_as_suffix_;
|
treat_whitespace_as_suffix_ = from.treat_whitespace_as_suffix_;
|
||||||
}
|
}
|
||||||
if (cached_has_bits & 0x00002000u) {
|
if (cached_has_bits & 0x00002000u) {
|
||||||
allow_whitespace_only_pieces_ = from.allow_whitespace_only_pieces_;
|
|
||||||
}
|
|
||||||
if (cached_has_bits & 0x00004000u) {
|
|
||||||
split_digits_ = from.split_digits_;
|
split_digits_ = from.split_digits_;
|
||||||
}
|
}
|
||||||
if (cached_has_bits & 0x00008000u) {
|
if (cached_has_bits & 0x00004000u) {
|
||||||
byte_fallback_ = from.byte_fallback_;
|
byte_fallback_ = from.byte_fallback_;
|
||||||
}
|
}
|
||||||
|
if (cached_has_bits & 0x00008000u) {
|
||||||
|
use_all_vocab_ = from.use_all_vocab_;
|
||||||
|
}
|
||||||
_has_bits_[0] |= cached_has_bits;
|
_has_bits_[0] |= cached_has_bits;
|
||||||
}
|
}
|
||||||
if (cached_has_bits & 0x00ff0000u) {
|
if (cached_has_bits & 0x00ff0000u) {
|
||||||
@ -1630,70 +1608,67 @@ void TrainerSpec::MergeFrom(const TrainerSpec& from) {
|
|||||||
unk_id_ = from.unk_id_;
|
unk_id_ = from.unk_id_;
|
||||||
}
|
}
|
||||||
if (cached_has_bits & 0x00020000u) {
|
if (cached_has_bits & 0x00020000u) {
|
||||||
use_all_vocab_ = from.use_all_vocab_;
|
|
||||||
}
|
|
||||||
if (cached_has_bits & 0x00040000u) {
|
|
||||||
train_extremely_large_corpus_ = from.train_extremely_large_corpus_;
|
train_extremely_large_corpus_ = from.train_extremely_large_corpus_;
|
||||||
}
|
}
|
||||||
if (cached_has_bits & 0x00080000u) {
|
if (cached_has_bits & 0x00040000u) {
|
||||||
model_type_ = from.model_type_;
|
model_type_ = from.model_type_;
|
||||||
}
|
}
|
||||||
if (cached_has_bits & 0x00100000u) {
|
if (cached_has_bits & 0x00080000u) {
|
||||||
vocab_size_ = from.vocab_size_;
|
vocab_size_ = from.vocab_size_;
|
||||||
}
|
}
|
||||||
if (cached_has_bits & 0x00200000u) {
|
if (cached_has_bits & 0x00100000u) {
|
||||||
character_coverage_ = from.character_coverage_;
|
character_coverage_ = from.character_coverage_;
|
||||||
}
|
}
|
||||||
if (cached_has_bits & 0x00400000u) {
|
if (cached_has_bits & 0x00200000u) {
|
||||||
seed_sentencepiece_size_ = from.seed_sentencepiece_size_;
|
seed_sentencepiece_size_ = from.seed_sentencepiece_size_;
|
||||||
}
|
}
|
||||||
if (cached_has_bits & 0x00800000u) {
|
if (cached_has_bits & 0x00400000u) {
|
||||||
shrinking_factor_ = from.shrinking_factor_;
|
shrinking_factor_ = from.shrinking_factor_;
|
||||||
}
|
}
|
||||||
|
if (cached_has_bits & 0x00800000u) {
|
||||||
|
num_threads_ = from.num_threads_;
|
||||||
|
}
|
||||||
_has_bits_[0] |= cached_has_bits;
|
_has_bits_[0] |= cached_has_bits;
|
||||||
}
|
}
|
||||||
if (cached_has_bits & 0xff000000u) {
|
if (cached_has_bits & 0xff000000u) {
|
||||||
if (cached_has_bits & 0x01000000u) {
|
if (cached_has_bits & 0x01000000u) {
|
||||||
num_threads_ = from.num_threads_;
|
|
||||||
}
|
|
||||||
if (cached_has_bits & 0x02000000u) {
|
|
||||||
num_sub_iterations_ = from.num_sub_iterations_;
|
num_sub_iterations_ = from.num_sub_iterations_;
|
||||||
}
|
}
|
||||||
if (cached_has_bits & 0x04000000u) {
|
if (cached_has_bits & 0x02000000u) {
|
||||||
max_sentence_length_ = from.max_sentence_length_;
|
max_sentence_length_ = from.max_sentence_length_;
|
||||||
}
|
}
|
||||||
if (cached_has_bits & 0x08000000u) {
|
if (cached_has_bits & 0x04000000u) {
|
||||||
max_sentencepiece_length_ = from.max_sentencepiece_length_;
|
max_sentencepiece_length_ = from.max_sentencepiece_length_;
|
||||||
}
|
}
|
||||||
if (cached_has_bits & 0x10000000u) {
|
if (cached_has_bits & 0x08000000u) {
|
||||||
shuffle_input_sentence_ = from.shuffle_input_sentence_;
|
shuffle_input_sentence_ = from.shuffle_input_sentence_;
|
||||||
}
|
}
|
||||||
if (cached_has_bits & 0x20000000u) {
|
if (cached_has_bits & 0x10000000u) {
|
||||||
split_by_unicode_script_ = from.split_by_unicode_script_;
|
split_by_unicode_script_ = from.split_by_unicode_script_;
|
||||||
}
|
}
|
||||||
if (cached_has_bits & 0x40000000u) {
|
if (cached_has_bits & 0x20000000u) {
|
||||||
split_by_number_ = from.split_by_number_;
|
split_by_number_ = from.split_by_number_;
|
||||||
}
|
}
|
||||||
if (cached_has_bits & 0x80000000u) {
|
if (cached_has_bits & 0x40000000u) {
|
||||||
split_by_whitespace_ = from.split_by_whitespace_;
|
split_by_whitespace_ = from.split_by_whitespace_;
|
||||||
}
|
}
|
||||||
|
if (cached_has_bits & 0x80000000u) {
|
||||||
|
vocabulary_output_piece_score_ = from.vocabulary_output_piece_score_;
|
||||||
|
}
|
||||||
_has_bits_[0] |= cached_has_bits;
|
_has_bits_[0] |= cached_has_bits;
|
||||||
}
|
}
|
||||||
cached_has_bits = from._has_bits_[1];
|
cached_has_bits = from._has_bits_[1];
|
||||||
if (cached_has_bits & 0x0000001fu) {
|
if (cached_has_bits & 0x0000000fu) {
|
||||||
if (cached_has_bits & 0x00000001u) {
|
if (cached_has_bits & 0x00000001u) {
|
||||||
vocabulary_output_piece_score_ = from.vocabulary_output_piece_score_;
|
|
||||||
}
|
|
||||||
if (cached_has_bits & 0x00000002u) {
|
|
||||||
hard_vocab_limit_ = from.hard_vocab_limit_;
|
hard_vocab_limit_ = from.hard_vocab_limit_;
|
||||||
}
|
}
|
||||||
if (cached_has_bits & 0x00000004u) {
|
if (cached_has_bits & 0x00000002u) {
|
||||||
bos_id_ = from.bos_id_;
|
bos_id_ = from.bos_id_;
|
||||||
}
|
}
|
||||||
if (cached_has_bits & 0x00000008u) {
|
if (cached_has_bits & 0x00000004u) {
|
||||||
eos_id_ = from.eos_id_;
|
eos_id_ = from.eos_id_;
|
||||||
}
|
}
|
||||||
if (cached_has_bits & 0x00000010u) {
|
if (cached_has_bits & 0x00000008u) {
|
||||||
pad_id_ = from.pad_id_;
|
pad_id_ = from.pad_id_;
|
||||||
}
|
}
|
||||||
_has_bits_[1] |= cached_has_bits;
|
_has_bits_[1] |= cached_has_bits;
|
||||||
|
@ -278,11 +278,10 @@ class TrainerSpec PROTOBUF_FINAL :
|
|||||||
kInputSentenceSizeFieldNumber = 11,
|
kInputSentenceSizeFieldNumber = 11,
|
||||||
kTrainingSentenceSizeFieldNumber = 13,
|
kTrainingSentenceSizeFieldNumber = 13,
|
||||||
kTreatWhitespaceAsSuffixFieldNumber = 24,
|
kTreatWhitespaceAsSuffixFieldNumber = 24,
|
||||||
kAllowWhitespaceOnlyPiecesFieldNumber = 26,
|
|
||||||
kSplitDigitsFieldNumber = 25,
|
kSplitDigitsFieldNumber = 25,
|
||||||
kByteFallbackFieldNumber = 35,
|
kByteFallbackFieldNumber = 35,
|
||||||
kUnkIdFieldNumber = 40,
|
|
||||||
kUseAllVocabFieldNumber = 34,
|
kUseAllVocabFieldNumber = 34,
|
||||||
|
kUnkIdFieldNumber = 40,
|
||||||
kTrainExtremelyLargeCorpusFieldNumber = 49,
|
kTrainExtremelyLargeCorpusFieldNumber = 49,
|
||||||
kModelTypeFieldNumber = 3,
|
kModelTypeFieldNumber = 3,
|
||||||
kVocabSizeFieldNumber = 4,
|
kVocabSizeFieldNumber = 4,
|
||||||
@ -624,19 +623,6 @@ class TrainerSpec PROTOBUF_FINAL :
|
|||||||
void _internal_set_treat_whitespace_as_suffix(bool value);
|
void _internal_set_treat_whitespace_as_suffix(bool value);
|
||||||
public:
|
public:
|
||||||
|
|
||||||
// optional bool allow_whitespace_only_pieces = 26 [default = false];
|
|
||||||
bool has_allow_whitespace_only_pieces() const;
|
|
||||||
private:
|
|
||||||
bool _internal_has_allow_whitespace_only_pieces() const;
|
|
||||||
public:
|
|
||||||
void clear_allow_whitespace_only_pieces();
|
|
||||||
bool allow_whitespace_only_pieces() const;
|
|
||||||
void set_allow_whitespace_only_pieces(bool value);
|
|
||||||
private:
|
|
||||||
bool _internal_allow_whitespace_only_pieces() const;
|
|
||||||
void _internal_set_allow_whitespace_only_pieces(bool value);
|
|
||||||
public:
|
|
||||||
|
|
||||||
// optional bool split_digits = 25 [default = false];
|
// optional bool split_digits = 25 [default = false];
|
||||||
bool has_split_digits() const;
|
bool has_split_digits() const;
|
||||||
private:
|
private:
|
||||||
@ -663,19 +649,6 @@ class TrainerSpec PROTOBUF_FINAL :
|
|||||||
void _internal_set_byte_fallback(bool value);
|
void _internal_set_byte_fallback(bool value);
|
||||||
public:
|
public:
|
||||||
|
|
||||||
// optional int32 unk_id = 40 [default = 0];
|
|
||||||
bool has_unk_id() const;
|
|
||||||
private:
|
|
||||||
bool _internal_has_unk_id() const;
|
|
||||||
public:
|
|
||||||
void clear_unk_id();
|
|
||||||
::PROTOBUF_NAMESPACE_ID::int32 unk_id() const;
|
|
||||||
void set_unk_id(::PROTOBUF_NAMESPACE_ID::int32 value);
|
|
||||||
private:
|
|
||||||
::PROTOBUF_NAMESPACE_ID::int32 _internal_unk_id() const;
|
|
||||||
void _internal_set_unk_id(::PROTOBUF_NAMESPACE_ID::int32 value);
|
|
||||||
public:
|
|
||||||
|
|
||||||
// optional bool use_all_vocab = 34 [default = false];
|
// optional bool use_all_vocab = 34 [default = false];
|
||||||
bool has_use_all_vocab() const;
|
bool has_use_all_vocab() const;
|
||||||
private:
|
private:
|
||||||
@ -689,6 +662,19 @@ class TrainerSpec PROTOBUF_FINAL :
|
|||||||
void _internal_set_use_all_vocab(bool value);
|
void _internal_set_use_all_vocab(bool value);
|
||||||
public:
|
public:
|
||||||
|
|
||||||
|
// optional int32 unk_id = 40 [default = 0];
|
||||||
|
bool has_unk_id() const;
|
||||||
|
private:
|
||||||
|
bool _internal_has_unk_id() const;
|
||||||
|
public:
|
||||||
|
void clear_unk_id();
|
||||||
|
::PROTOBUF_NAMESPACE_ID::int32 unk_id() const;
|
||||||
|
void set_unk_id(::PROTOBUF_NAMESPACE_ID::int32 value);
|
||||||
|
private:
|
||||||
|
::PROTOBUF_NAMESPACE_ID::int32 _internal_unk_id() const;
|
||||||
|
void _internal_set_unk_id(::PROTOBUF_NAMESPACE_ID::int32 value);
|
||||||
|
public:
|
||||||
|
|
||||||
// optional bool train_extremely_large_corpus = 49 [default = false];
|
// optional bool train_extremely_large_corpus = 49 [default = false];
|
||||||
bool has_train_extremely_large_corpus() const;
|
bool has_train_extremely_large_corpus() const;
|
||||||
private:
|
private:
|
||||||
@ -970,11 +956,10 @@ class TrainerSpec PROTOBUF_FINAL :
|
|||||||
::PROTOBUF_NAMESPACE_ID::uint64 input_sentence_size_;
|
::PROTOBUF_NAMESPACE_ID::uint64 input_sentence_size_;
|
||||||
::PROTOBUF_NAMESPACE_ID::int32 training_sentence_size_;
|
::PROTOBUF_NAMESPACE_ID::int32 training_sentence_size_;
|
||||||
bool treat_whitespace_as_suffix_;
|
bool treat_whitespace_as_suffix_;
|
||||||
bool allow_whitespace_only_pieces_;
|
|
||||||
bool split_digits_;
|
bool split_digits_;
|
||||||
bool byte_fallback_;
|
bool byte_fallback_;
|
||||||
::PROTOBUF_NAMESPACE_ID::int32 unk_id_;
|
|
||||||
bool use_all_vocab_;
|
bool use_all_vocab_;
|
||||||
|
::PROTOBUF_NAMESPACE_ID::int32 unk_id_;
|
||||||
bool train_extremely_large_corpus_;
|
bool train_extremely_large_corpus_;
|
||||||
int model_type_;
|
int model_type_;
|
||||||
::PROTOBUF_NAMESPACE_ID::int32 vocab_size_;
|
::PROTOBUF_NAMESPACE_ID::int32 vocab_size_;
|
||||||
@ -2195,7 +2180,7 @@ inline void TrainerSpec::set_allocated_model_prefix(std::string* model_prefix) {
|
|||||||
|
|
||||||
// optional .sentencepiece.TrainerSpec.ModelType model_type = 3 [default = UNIGRAM];
|
// optional .sentencepiece.TrainerSpec.ModelType model_type = 3 [default = UNIGRAM];
|
||||||
inline bool TrainerSpec::_internal_has_model_type() const {
|
inline bool TrainerSpec::_internal_has_model_type() const {
|
||||||
bool value = (_has_bits_[0] & 0x00080000u) != 0;
|
bool value = (_has_bits_[0] & 0x00040000u) != 0;
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
inline bool TrainerSpec::has_model_type() const {
|
inline bool TrainerSpec::has_model_type() const {
|
||||||
@ -2203,7 +2188,7 @@ inline bool TrainerSpec::has_model_type() const {
|
|||||||
}
|
}
|
||||||
inline void TrainerSpec::clear_model_type() {
|
inline void TrainerSpec::clear_model_type() {
|
||||||
model_type_ = 1;
|
model_type_ = 1;
|
||||||
_has_bits_[0] &= ~0x00080000u;
|
_has_bits_[0] &= ~0x00040000u;
|
||||||
}
|
}
|
||||||
inline ::sentencepiece::TrainerSpec_ModelType TrainerSpec::_internal_model_type() const {
|
inline ::sentencepiece::TrainerSpec_ModelType TrainerSpec::_internal_model_type() const {
|
||||||
return static_cast< ::sentencepiece::TrainerSpec_ModelType >(model_type_);
|
return static_cast< ::sentencepiece::TrainerSpec_ModelType >(model_type_);
|
||||||
@ -2214,7 +2199,7 @@ inline ::sentencepiece::TrainerSpec_ModelType TrainerSpec::model_type() const {
|
|||||||
}
|
}
|
||||||
inline void TrainerSpec::_internal_set_model_type(::sentencepiece::TrainerSpec_ModelType value) {
|
inline void TrainerSpec::_internal_set_model_type(::sentencepiece::TrainerSpec_ModelType value) {
|
||||||
assert(::sentencepiece::TrainerSpec_ModelType_IsValid(value));
|
assert(::sentencepiece::TrainerSpec_ModelType_IsValid(value));
|
||||||
_has_bits_[0] |= 0x00080000u;
|
_has_bits_[0] |= 0x00040000u;
|
||||||
model_type_ = value;
|
model_type_ = value;
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::set_model_type(::sentencepiece::TrainerSpec_ModelType value) {
|
inline void TrainerSpec::set_model_type(::sentencepiece::TrainerSpec_ModelType value) {
|
||||||
@ -2224,7 +2209,7 @@ inline void TrainerSpec::set_model_type(::sentencepiece::TrainerSpec_ModelType v
|
|||||||
|
|
||||||
// optional int32 vocab_size = 4 [default = 8000];
|
// optional int32 vocab_size = 4 [default = 8000];
|
||||||
inline bool TrainerSpec::_internal_has_vocab_size() const {
|
inline bool TrainerSpec::_internal_has_vocab_size() const {
|
||||||
bool value = (_has_bits_[0] & 0x00100000u) != 0;
|
bool value = (_has_bits_[0] & 0x00080000u) != 0;
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
inline bool TrainerSpec::has_vocab_size() const {
|
inline bool TrainerSpec::has_vocab_size() const {
|
||||||
@ -2232,7 +2217,7 @@ inline bool TrainerSpec::has_vocab_size() const {
|
|||||||
}
|
}
|
||||||
inline void TrainerSpec::clear_vocab_size() {
|
inline void TrainerSpec::clear_vocab_size() {
|
||||||
vocab_size_ = 8000;
|
vocab_size_ = 8000;
|
||||||
_has_bits_[0] &= ~0x00100000u;
|
_has_bits_[0] &= ~0x00080000u;
|
||||||
}
|
}
|
||||||
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_vocab_size() const {
|
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_vocab_size() const {
|
||||||
return vocab_size_;
|
return vocab_size_;
|
||||||
@ -2242,7 +2227,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::vocab_size() const {
|
|||||||
return _internal_vocab_size();
|
return _internal_vocab_size();
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::_internal_set_vocab_size(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
inline void TrainerSpec::_internal_set_vocab_size(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
||||||
_has_bits_[0] |= 0x00100000u;
|
_has_bits_[0] |= 0x00080000u;
|
||||||
vocab_size_ = value;
|
vocab_size_ = value;
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::set_vocab_size(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
inline void TrainerSpec::set_vocab_size(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
||||||
@ -2354,7 +2339,7 @@ inline void TrainerSpec::set_self_test_sample_size(::PROTOBUF_NAMESPACE_ID::int3
|
|||||||
|
|
||||||
// optional float character_coverage = 10 [default = 0.9995];
|
// optional float character_coverage = 10 [default = 0.9995];
|
||||||
inline bool TrainerSpec::_internal_has_character_coverage() const {
|
inline bool TrainerSpec::_internal_has_character_coverage() const {
|
||||||
bool value = (_has_bits_[0] & 0x00200000u) != 0;
|
bool value = (_has_bits_[0] & 0x00100000u) != 0;
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
inline bool TrainerSpec::has_character_coverage() const {
|
inline bool TrainerSpec::has_character_coverage() const {
|
||||||
@ -2362,7 +2347,7 @@ inline bool TrainerSpec::has_character_coverage() const {
|
|||||||
}
|
}
|
||||||
inline void TrainerSpec::clear_character_coverage() {
|
inline void TrainerSpec::clear_character_coverage() {
|
||||||
character_coverage_ = 0.9995f;
|
character_coverage_ = 0.9995f;
|
||||||
_has_bits_[0] &= ~0x00200000u;
|
_has_bits_[0] &= ~0x00100000u;
|
||||||
}
|
}
|
||||||
inline float TrainerSpec::_internal_character_coverage() const {
|
inline float TrainerSpec::_internal_character_coverage() const {
|
||||||
return character_coverage_;
|
return character_coverage_;
|
||||||
@ -2372,7 +2357,7 @@ inline float TrainerSpec::character_coverage() const {
|
|||||||
return _internal_character_coverage();
|
return _internal_character_coverage();
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::_internal_set_character_coverage(float value) {
|
inline void TrainerSpec::_internal_set_character_coverage(float value) {
|
||||||
_has_bits_[0] |= 0x00200000u;
|
_has_bits_[0] |= 0x00100000u;
|
||||||
character_coverage_ = value;
|
character_coverage_ = value;
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::set_character_coverage(float value) {
|
inline void TrainerSpec::set_character_coverage(float value) {
|
||||||
@ -2410,7 +2395,7 @@ inline void TrainerSpec::set_input_sentence_size(::PROTOBUF_NAMESPACE_ID::uint64
|
|||||||
|
|
||||||
// optional bool shuffle_input_sentence = 19 [default = true];
|
// optional bool shuffle_input_sentence = 19 [default = true];
|
||||||
inline bool TrainerSpec::_internal_has_shuffle_input_sentence() const {
|
inline bool TrainerSpec::_internal_has_shuffle_input_sentence() const {
|
||||||
bool value = (_has_bits_[0] & 0x10000000u) != 0;
|
bool value = (_has_bits_[0] & 0x08000000u) != 0;
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
inline bool TrainerSpec::has_shuffle_input_sentence() const {
|
inline bool TrainerSpec::has_shuffle_input_sentence() const {
|
||||||
@ -2418,7 +2403,7 @@ inline bool TrainerSpec::has_shuffle_input_sentence() const {
|
|||||||
}
|
}
|
||||||
inline void TrainerSpec::clear_shuffle_input_sentence() {
|
inline void TrainerSpec::clear_shuffle_input_sentence() {
|
||||||
shuffle_input_sentence_ = true;
|
shuffle_input_sentence_ = true;
|
||||||
_has_bits_[0] &= ~0x10000000u;
|
_has_bits_[0] &= ~0x08000000u;
|
||||||
}
|
}
|
||||||
inline bool TrainerSpec::_internal_shuffle_input_sentence() const {
|
inline bool TrainerSpec::_internal_shuffle_input_sentence() const {
|
||||||
return shuffle_input_sentence_;
|
return shuffle_input_sentence_;
|
||||||
@ -2428,7 +2413,7 @@ inline bool TrainerSpec::shuffle_input_sentence() const {
|
|||||||
return _internal_shuffle_input_sentence();
|
return _internal_shuffle_input_sentence();
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::_internal_set_shuffle_input_sentence(bool value) {
|
inline void TrainerSpec::_internal_set_shuffle_input_sentence(bool value) {
|
||||||
_has_bits_[0] |= 0x10000000u;
|
_has_bits_[0] |= 0x08000000u;
|
||||||
shuffle_input_sentence_ = value;
|
shuffle_input_sentence_ = value;
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::set_shuffle_input_sentence(bool value) {
|
inline void TrainerSpec::set_shuffle_input_sentence(bool value) {
|
||||||
@ -2494,7 +2479,7 @@ inline void TrainerSpec::set_training_sentence_size(::PROTOBUF_NAMESPACE_ID::int
|
|||||||
|
|
||||||
// optional int32 seed_sentencepiece_size = 14 [default = 1000000];
|
// optional int32 seed_sentencepiece_size = 14 [default = 1000000];
|
||||||
inline bool TrainerSpec::_internal_has_seed_sentencepiece_size() const {
|
inline bool TrainerSpec::_internal_has_seed_sentencepiece_size() const {
|
||||||
bool value = (_has_bits_[0] & 0x00400000u) != 0;
|
bool value = (_has_bits_[0] & 0x00200000u) != 0;
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
inline bool TrainerSpec::has_seed_sentencepiece_size() const {
|
inline bool TrainerSpec::has_seed_sentencepiece_size() const {
|
||||||
@ -2502,7 +2487,7 @@ inline bool TrainerSpec::has_seed_sentencepiece_size() const {
|
|||||||
}
|
}
|
||||||
inline void TrainerSpec::clear_seed_sentencepiece_size() {
|
inline void TrainerSpec::clear_seed_sentencepiece_size() {
|
||||||
seed_sentencepiece_size_ = 1000000;
|
seed_sentencepiece_size_ = 1000000;
|
||||||
_has_bits_[0] &= ~0x00400000u;
|
_has_bits_[0] &= ~0x00200000u;
|
||||||
}
|
}
|
||||||
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_seed_sentencepiece_size() const {
|
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_seed_sentencepiece_size() const {
|
||||||
return seed_sentencepiece_size_;
|
return seed_sentencepiece_size_;
|
||||||
@ -2512,7 +2497,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::seed_sentencepiece_size() con
|
|||||||
return _internal_seed_sentencepiece_size();
|
return _internal_seed_sentencepiece_size();
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::_internal_set_seed_sentencepiece_size(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
inline void TrainerSpec::_internal_set_seed_sentencepiece_size(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
||||||
_has_bits_[0] |= 0x00400000u;
|
_has_bits_[0] |= 0x00200000u;
|
||||||
seed_sentencepiece_size_ = value;
|
seed_sentencepiece_size_ = value;
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::set_seed_sentencepiece_size(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
inline void TrainerSpec::set_seed_sentencepiece_size(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
||||||
@ -2522,7 +2507,7 @@ inline void TrainerSpec::set_seed_sentencepiece_size(::PROTOBUF_NAMESPACE_ID::in
|
|||||||
|
|
||||||
// optional float shrinking_factor = 15 [default = 0.75];
|
// optional float shrinking_factor = 15 [default = 0.75];
|
||||||
inline bool TrainerSpec::_internal_has_shrinking_factor() const {
|
inline bool TrainerSpec::_internal_has_shrinking_factor() const {
|
||||||
bool value = (_has_bits_[0] & 0x00800000u) != 0;
|
bool value = (_has_bits_[0] & 0x00400000u) != 0;
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
inline bool TrainerSpec::has_shrinking_factor() const {
|
inline bool TrainerSpec::has_shrinking_factor() const {
|
||||||
@ -2530,7 +2515,7 @@ inline bool TrainerSpec::has_shrinking_factor() const {
|
|||||||
}
|
}
|
||||||
inline void TrainerSpec::clear_shrinking_factor() {
|
inline void TrainerSpec::clear_shrinking_factor() {
|
||||||
shrinking_factor_ = 0.75f;
|
shrinking_factor_ = 0.75f;
|
||||||
_has_bits_[0] &= ~0x00800000u;
|
_has_bits_[0] &= ~0x00400000u;
|
||||||
}
|
}
|
||||||
inline float TrainerSpec::_internal_shrinking_factor() const {
|
inline float TrainerSpec::_internal_shrinking_factor() const {
|
||||||
return shrinking_factor_;
|
return shrinking_factor_;
|
||||||
@ -2540,7 +2525,7 @@ inline float TrainerSpec::shrinking_factor() const {
|
|||||||
return _internal_shrinking_factor();
|
return _internal_shrinking_factor();
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::_internal_set_shrinking_factor(float value) {
|
inline void TrainerSpec::_internal_set_shrinking_factor(float value) {
|
||||||
_has_bits_[0] |= 0x00800000u;
|
_has_bits_[0] |= 0x00400000u;
|
||||||
shrinking_factor_ = value;
|
shrinking_factor_ = value;
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::set_shrinking_factor(float value) {
|
inline void TrainerSpec::set_shrinking_factor(float value) {
|
||||||
@ -2550,7 +2535,7 @@ inline void TrainerSpec::set_shrinking_factor(float value) {
|
|||||||
|
|
||||||
// optional int32 max_sentence_length = 18 [default = 4192];
|
// optional int32 max_sentence_length = 18 [default = 4192];
|
||||||
inline bool TrainerSpec::_internal_has_max_sentence_length() const {
|
inline bool TrainerSpec::_internal_has_max_sentence_length() const {
|
||||||
bool value = (_has_bits_[0] & 0x04000000u) != 0;
|
bool value = (_has_bits_[0] & 0x02000000u) != 0;
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
inline bool TrainerSpec::has_max_sentence_length() const {
|
inline bool TrainerSpec::has_max_sentence_length() const {
|
||||||
@ -2558,7 +2543,7 @@ inline bool TrainerSpec::has_max_sentence_length() const {
|
|||||||
}
|
}
|
||||||
inline void TrainerSpec::clear_max_sentence_length() {
|
inline void TrainerSpec::clear_max_sentence_length() {
|
||||||
max_sentence_length_ = 4192;
|
max_sentence_length_ = 4192;
|
||||||
_has_bits_[0] &= ~0x04000000u;
|
_has_bits_[0] &= ~0x02000000u;
|
||||||
}
|
}
|
||||||
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_max_sentence_length() const {
|
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_max_sentence_length() const {
|
||||||
return max_sentence_length_;
|
return max_sentence_length_;
|
||||||
@ -2568,7 +2553,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::max_sentence_length() const {
|
|||||||
return _internal_max_sentence_length();
|
return _internal_max_sentence_length();
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::_internal_set_max_sentence_length(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
inline void TrainerSpec::_internal_set_max_sentence_length(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
||||||
_has_bits_[0] |= 0x04000000u;
|
_has_bits_[0] |= 0x02000000u;
|
||||||
max_sentence_length_ = value;
|
max_sentence_length_ = value;
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::set_max_sentence_length(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
inline void TrainerSpec::set_max_sentence_length(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
||||||
@ -2578,7 +2563,7 @@ inline void TrainerSpec::set_max_sentence_length(::PROTOBUF_NAMESPACE_ID::int32
|
|||||||
|
|
||||||
// optional int32 num_threads = 16 [default = 16];
|
// optional int32 num_threads = 16 [default = 16];
|
||||||
inline bool TrainerSpec::_internal_has_num_threads() const {
|
inline bool TrainerSpec::_internal_has_num_threads() const {
|
||||||
bool value = (_has_bits_[0] & 0x01000000u) != 0;
|
bool value = (_has_bits_[0] & 0x00800000u) != 0;
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
inline bool TrainerSpec::has_num_threads() const {
|
inline bool TrainerSpec::has_num_threads() const {
|
||||||
@ -2586,7 +2571,7 @@ inline bool TrainerSpec::has_num_threads() const {
|
|||||||
}
|
}
|
||||||
inline void TrainerSpec::clear_num_threads() {
|
inline void TrainerSpec::clear_num_threads() {
|
||||||
num_threads_ = 16;
|
num_threads_ = 16;
|
||||||
_has_bits_[0] &= ~0x01000000u;
|
_has_bits_[0] &= ~0x00800000u;
|
||||||
}
|
}
|
||||||
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_num_threads() const {
|
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_num_threads() const {
|
||||||
return num_threads_;
|
return num_threads_;
|
||||||
@ -2596,7 +2581,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::num_threads() const {
|
|||||||
return _internal_num_threads();
|
return _internal_num_threads();
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::_internal_set_num_threads(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
inline void TrainerSpec::_internal_set_num_threads(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
||||||
_has_bits_[0] |= 0x01000000u;
|
_has_bits_[0] |= 0x00800000u;
|
||||||
num_threads_ = value;
|
num_threads_ = value;
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::set_num_threads(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
inline void TrainerSpec::set_num_threads(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
||||||
@ -2606,7 +2591,7 @@ inline void TrainerSpec::set_num_threads(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
|||||||
|
|
||||||
// optional int32 num_sub_iterations = 17 [default = 2];
|
// optional int32 num_sub_iterations = 17 [default = 2];
|
||||||
inline bool TrainerSpec::_internal_has_num_sub_iterations() const {
|
inline bool TrainerSpec::_internal_has_num_sub_iterations() const {
|
||||||
bool value = (_has_bits_[0] & 0x02000000u) != 0;
|
bool value = (_has_bits_[0] & 0x01000000u) != 0;
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
inline bool TrainerSpec::has_num_sub_iterations() const {
|
inline bool TrainerSpec::has_num_sub_iterations() const {
|
||||||
@ -2614,7 +2599,7 @@ inline bool TrainerSpec::has_num_sub_iterations() const {
|
|||||||
}
|
}
|
||||||
inline void TrainerSpec::clear_num_sub_iterations() {
|
inline void TrainerSpec::clear_num_sub_iterations() {
|
||||||
num_sub_iterations_ = 2;
|
num_sub_iterations_ = 2;
|
||||||
_has_bits_[0] &= ~0x02000000u;
|
_has_bits_[0] &= ~0x01000000u;
|
||||||
}
|
}
|
||||||
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_num_sub_iterations() const {
|
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_num_sub_iterations() const {
|
||||||
return num_sub_iterations_;
|
return num_sub_iterations_;
|
||||||
@ -2624,7 +2609,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::num_sub_iterations() const {
|
|||||||
return _internal_num_sub_iterations();
|
return _internal_num_sub_iterations();
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::_internal_set_num_sub_iterations(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
inline void TrainerSpec::_internal_set_num_sub_iterations(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
||||||
_has_bits_[0] |= 0x02000000u;
|
_has_bits_[0] |= 0x01000000u;
|
||||||
num_sub_iterations_ = value;
|
num_sub_iterations_ = value;
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::set_num_sub_iterations(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
inline void TrainerSpec::set_num_sub_iterations(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
||||||
@ -2634,7 +2619,7 @@ inline void TrainerSpec::set_num_sub_iterations(::PROTOBUF_NAMESPACE_ID::int32 v
|
|||||||
|
|
||||||
// optional int32 max_sentencepiece_length = 20 [default = 16];
|
// optional int32 max_sentencepiece_length = 20 [default = 16];
|
||||||
inline bool TrainerSpec::_internal_has_max_sentencepiece_length() const {
|
inline bool TrainerSpec::_internal_has_max_sentencepiece_length() const {
|
||||||
bool value = (_has_bits_[0] & 0x08000000u) != 0;
|
bool value = (_has_bits_[0] & 0x04000000u) != 0;
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
inline bool TrainerSpec::has_max_sentencepiece_length() const {
|
inline bool TrainerSpec::has_max_sentencepiece_length() const {
|
||||||
@ -2642,7 +2627,7 @@ inline bool TrainerSpec::has_max_sentencepiece_length() const {
|
|||||||
}
|
}
|
||||||
inline void TrainerSpec::clear_max_sentencepiece_length() {
|
inline void TrainerSpec::clear_max_sentencepiece_length() {
|
||||||
max_sentencepiece_length_ = 16;
|
max_sentencepiece_length_ = 16;
|
||||||
_has_bits_[0] &= ~0x08000000u;
|
_has_bits_[0] &= ~0x04000000u;
|
||||||
}
|
}
|
||||||
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_max_sentencepiece_length() const {
|
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_max_sentencepiece_length() const {
|
||||||
return max_sentencepiece_length_;
|
return max_sentencepiece_length_;
|
||||||
@ -2652,7 +2637,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::max_sentencepiece_length() co
|
|||||||
return _internal_max_sentencepiece_length();
|
return _internal_max_sentencepiece_length();
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::_internal_set_max_sentencepiece_length(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
inline void TrainerSpec::_internal_set_max_sentencepiece_length(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
||||||
_has_bits_[0] |= 0x08000000u;
|
_has_bits_[0] |= 0x04000000u;
|
||||||
max_sentencepiece_length_ = value;
|
max_sentencepiece_length_ = value;
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::set_max_sentencepiece_length(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
inline void TrainerSpec::set_max_sentencepiece_length(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
||||||
@ -2662,7 +2647,7 @@ inline void TrainerSpec::set_max_sentencepiece_length(::PROTOBUF_NAMESPACE_ID::i
|
|||||||
|
|
||||||
// optional bool split_by_unicode_script = 21 [default = true];
|
// optional bool split_by_unicode_script = 21 [default = true];
|
||||||
inline bool TrainerSpec::_internal_has_split_by_unicode_script() const {
|
inline bool TrainerSpec::_internal_has_split_by_unicode_script() const {
|
||||||
bool value = (_has_bits_[0] & 0x20000000u) != 0;
|
bool value = (_has_bits_[0] & 0x10000000u) != 0;
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
inline bool TrainerSpec::has_split_by_unicode_script() const {
|
inline bool TrainerSpec::has_split_by_unicode_script() const {
|
||||||
@ -2670,7 +2655,7 @@ inline bool TrainerSpec::has_split_by_unicode_script() const {
|
|||||||
}
|
}
|
||||||
inline void TrainerSpec::clear_split_by_unicode_script() {
|
inline void TrainerSpec::clear_split_by_unicode_script() {
|
||||||
split_by_unicode_script_ = true;
|
split_by_unicode_script_ = true;
|
||||||
_has_bits_[0] &= ~0x20000000u;
|
_has_bits_[0] &= ~0x10000000u;
|
||||||
}
|
}
|
||||||
inline bool TrainerSpec::_internal_split_by_unicode_script() const {
|
inline bool TrainerSpec::_internal_split_by_unicode_script() const {
|
||||||
return split_by_unicode_script_;
|
return split_by_unicode_script_;
|
||||||
@ -2680,7 +2665,7 @@ inline bool TrainerSpec::split_by_unicode_script() const {
|
|||||||
return _internal_split_by_unicode_script();
|
return _internal_split_by_unicode_script();
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::_internal_set_split_by_unicode_script(bool value) {
|
inline void TrainerSpec::_internal_set_split_by_unicode_script(bool value) {
|
||||||
_has_bits_[0] |= 0x20000000u;
|
_has_bits_[0] |= 0x10000000u;
|
||||||
split_by_unicode_script_ = value;
|
split_by_unicode_script_ = value;
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::set_split_by_unicode_script(bool value) {
|
inline void TrainerSpec::set_split_by_unicode_script(bool value) {
|
||||||
@ -2690,7 +2675,7 @@ inline void TrainerSpec::set_split_by_unicode_script(bool value) {
|
|||||||
|
|
||||||
// optional bool split_by_number = 23 [default = true];
|
// optional bool split_by_number = 23 [default = true];
|
||||||
inline bool TrainerSpec::_internal_has_split_by_number() const {
|
inline bool TrainerSpec::_internal_has_split_by_number() const {
|
||||||
bool value = (_has_bits_[0] & 0x40000000u) != 0;
|
bool value = (_has_bits_[0] & 0x20000000u) != 0;
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
inline bool TrainerSpec::has_split_by_number() const {
|
inline bool TrainerSpec::has_split_by_number() const {
|
||||||
@ -2698,7 +2683,7 @@ inline bool TrainerSpec::has_split_by_number() const {
|
|||||||
}
|
}
|
||||||
inline void TrainerSpec::clear_split_by_number() {
|
inline void TrainerSpec::clear_split_by_number() {
|
||||||
split_by_number_ = true;
|
split_by_number_ = true;
|
||||||
_has_bits_[0] &= ~0x40000000u;
|
_has_bits_[0] &= ~0x20000000u;
|
||||||
}
|
}
|
||||||
inline bool TrainerSpec::_internal_split_by_number() const {
|
inline bool TrainerSpec::_internal_split_by_number() const {
|
||||||
return split_by_number_;
|
return split_by_number_;
|
||||||
@ -2708,7 +2693,7 @@ inline bool TrainerSpec::split_by_number() const {
|
|||||||
return _internal_split_by_number();
|
return _internal_split_by_number();
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::_internal_set_split_by_number(bool value) {
|
inline void TrainerSpec::_internal_set_split_by_number(bool value) {
|
||||||
_has_bits_[0] |= 0x40000000u;
|
_has_bits_[0] |= 0x20000000u;
|
||||||
split_by_number_ = value;
|
split_by_number_ = value;
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::set_split_by_number(bool value) {
|
inline void TrainerSpec::set_split_by_number(bool value) {
|
||||||
@ -2718,7 +2703,7 @@ inline void TrainerSpec::set_split_by_number(bool value) {
|
|||||||
|
|
||||||
// optional bool split_by_whitespace = 22 [default = true];
|
// optional bool split_by_whitespace = 22 [default = true];
|
||||||
inline bool TrainerSpec::_internal_has_split_by_whitespace() const {
|
inline bool TrainerSpec::_internal_has_split_by_whitespace() const {
|
||||||
bool value = (_has_bits_[0] & 0x80000000u) != 0;
|
bool value = (_has_bits_[0] & 0x40000000u) != 0;
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
inline bool TrainerSpec::has_split_by_whitespace() const {
|
inline bool TrainerSpec::has_split_by_whitespace() const {
|
||||||
@ -2726,7 +2711,7 @@ inline bool TrainerSpec::has_split_by_whitespace() const {
|
|||||||
}
|
}
|
||||||
inline void TrainerSpec::clear_split_by_whitespace() {
|
inline void TrainerSpec::clear_split_by_whitespace() {
|
||||||
split_by_whitespace_ = true;
|
split_by_whitespace_ = true;
|
||||||
_has_bits_[0] &= ~0x80000000u;
|
_has_bits_[0] &= ~0x40000000u;
|
||||||
}
|
}
|
||||||
inline bool TrainerSpec::_internal_split_by_whitespace() const {
|
inline bool TrainerSpec::_internal_split_by_whitespace() const {
|
||||||
return split_by_whitespace_;
|
return split_by_whitespace_;
|
||||||
@ -2736,7 +2721,7 @@ inline bool TrainerSpec::split_by_whitespace() const {
|
|||||||
return _internal_split_by_whitespace();
|
return _internal_split_by_whitespace();
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::_internal_set_split_by_whitespace(bool value) {
|
inline void TrainerSpec::_internal_set_split_by_whitespace(bool value) {
|
||||||
_has_bits_[0] |= 0x80000000u;
|
_has_bits_[0] |= 0x40000000u;
|
||||||
split_by_whitespace_ = value;
|
split_by_whitespace_ = value;
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::set_split_by_whitespace(bool value) {
|
inline void TrainerSpec::set_split_by_whitespace(bool value) {
|
||||||
@ -2772,37 +2757,9 @@ inline void TrainerSpec::set_treat_whitespace_as_suffix(bool value) {
|
|||||||
// @@protoc_insertion_point(field_set:sentencepiece.TrainerSpec.treat_whitespace_as_suffix)
|
// @@protoc_insertion_point(field_set:sentencepiece.TrainerSpec.treat_whitespace_as_suffix)
|
||||||
}
|
}
|
||||||
|
|
||||||
// optional bool allow_whitespace_only_pieces = 26 [default = false];
|
|
||||||
inline bool TrainerSpec::_internal_has_allow_whitespace_only_pieces() const {
|
|
||||||
bool value = (_has_bits_[0] & 0x00002000u) != 0;
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
inline bool TrainerSpec::has_allow_whitespace_only_pieces() const {
|
|
||||||
return _internal_has_allow_whitespace_only_pieces();
|
|
||||||
}
|
|
||||||
inline void TrainerSpec::clear_allow_whitespace_only_pieces() {
|
|
||||||
allow_whitespace_only_pieces_ = false;
|
|
||||||
_has_bits_[0] &= ~0x00002000u;
|
|
||||||
}
|
|
||||||
inline bool TrainerSpec::_internal_allow_whitespace_only_pieces() const {
|
|
||||||
return allow_whitespace_only_pieces_;
|
|
||||||
}
|
|
||||||
inline bool TrainerSpec::allow_whitespace_only_pieces() const {
|
|
||||||
// @@protoc_insertion_point(field_get:sentencepiece.TrainerSpec.allow_whitespace_only_pieces)
|
|
||||||
return _internal_allow_whitespace_only_pieces();
|
|
||||||
}
|
|
||||||
inline void TrainerSpec::_internal_set_allow_whitespace_only_pieces(bool value) {
|
|
||||||
_has_bits_[0] |= 0x00002000u;
|
|
||||||
allow_whitespace_only_pieces_ = value;
|
|
||||||
}
|
|
||||||
inline void TrainerSpec::set_allow_whitespace_only_pieces(bool value) {
|
|
||||||
_internal_set_allow_whitespace_only_pieces(value);
|
|
||||||
// @@protoc_insertion_point(field_set:sentencepiece.TrainerSpec.allow_whitespace_only_pieces)
|
|
||||||
}
|
|
||||||
|
|
||||||
// optional bool split_digits = 25 [default = false];
|
// optional bool split_digits = 25 [default = false];
|
||||||
inline bool TrainerSpec::_internal_has_split_digits() const {
|
inline bool TrainerSpec::_internal_has_split_digits() const {
|
||||||
bool value = (_has_bits_[0] & 0x00004000u) != 0;
|
bool value = (_has_bits_[0] & 0x00002000u) != 0;
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
inline bool TrainerSpec::has_split_digits() const {
|
inline bool TrainerSpec::has_split_digits() const {
|
||||||
@ -2810,7 +2767,7 @@ inline bool TrainerSpec::has_split_digits() const {
|
|||||||
}
|
}
|
||||||
inline void TrainerSpec::clear_split_digits() {
|
inline void TrainerSpec::clear_split_digits() {
|
||||||
split_digits_ = false;
|
split_digits_ = false;
|
||||||
_has_bits_[0] &= ~0x00004000u;
|
_has_bits_[0] &= ~0x00002000u;
|
||||||
}
|
}
|
||||||
inline bool TrainerSpec::_internal_split_digits() const {
|
inline bool TrainerSpec::_internal_split_digits() const {
|
||||||
return split_digits_;
|
return split_digits_;
|
||||||
@ -2820,7 +2777,7 @@ inline bool TrainerSpec::split_digits() const {
|
|||||||
return _internal_split_digits();
|
return _internal_split_digits();
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::_internal_set_split_digits(bool value) {
|
inline void TrainerSpec::_internal_set_split_digits(bool value) {
|
||||||
_has_bits_[0] |= 0x00004000u;
|
_has_bits_[0] |= 0x00002000u;
|
||||||
split_digits_ = value;
|
split_digits_ = value;
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::set_split_digits(bool value) {
|
inline void TrainerSpec::set_split_digits(bool value) {
|
||||||
@ -3051,7 +3008,7 @@ inline void TrainerSpec::set_allocated_required_chars(std::string* required_char
|
|||||||
|
|
||||||
// optional bool byte_fallback = 35 [default = false];
|
// optional bool byte_fallback = 35 [default = false];
|
||||||
inline bool TrainerSpec::_internal_has_byte_fallback() const {
|
inline bool TrainerSpec::_internal_has_byte_fallback() const {
|
||||||
bool value = (_has_bits_[0] & 0x00008000u) != 0;
|
bool value = (_has_bits_[0] & 0x00004000u) != 0;
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
inline bool TrainerSpec::has_byte_fallback() const {
|
inline bool TrainerSpec::has_byte_fallback() const {
|
||||||
@ -3059,7 +3016,7 @@ inline bool TrainerSpec::has_byte_fallback() const {
|
|||||||
}
|
}
|
||||||
inline void TrainerSpec::clear_byte_fallback() {
|
inline void TrainerSpec::clear_byte_fallback() {
|
||||||
byte_fallback_ = false;
|
byte_fallback_ = false;
|
||||||
_has_bits_[0] &= ~0x00008000u;
|
_has_bits_[0] &= ~0x00004000u;
|
||||||
}
|
}
|
||||||
inline bool TrainerSpec::_internal_byte_fallback() const {
|
inline bool TrainerSpec::_internal_byte_fallback() const {
|
||||||
return byte_fallback_;
|
return byte_fallback_;
|
||||||
@ -3069,7 +3026,7 @@ inline bool TrainerSpec::byte_fallback() const {
|
|||||||
return _internal_byte_fallback();
|
return _internal_byte_fallback();
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::_internal_set_byte_fallback(bool value) {
|
inline void TrainerSpec::_internal_set_byte_fallback(bool value) {
|
||||||
_has_bits_[0] |= 0x00008000u;
|
_has_bits_[0] |= 0x00004000u;
|
||||||
byte_fallback_ = value;
|
byte_fallback_ = value;
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::set_byte_fallback(bool value) {
|
inline void TrainerSpec::set_byte_fallback(bool value) {
|
||||||
@ -3079,7 +3036,7 @@ inline void TrainerSpec::set_byte_fallback(bool value) {
|
|||||||
|
|
||||||
// optional bool vocabulary_output_piece_score = 32 [default = true];
|
// optional bool vocabulary_output_piece_score = 32 [default = true];
|
||||||
inline bool TrainerSpec::_internal_has_vocabulary_output_piece_score() const {
|
inline bool TrainerSpec::_internal_has_vocabulary_output_piece_score() const {
|
||||||
bool value = (_has_bits_[1] & 0x00000001u) != 0;
|
bool value = (_has_bits_[0] & 0x80000000u) != 0;
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
inline bool TrainerSpec::has_vocabulary_output_piece_score() const {
|
inline bool TrainerSpec::has_vocabulary_output_piece_score() const {
|
||||||
@ -3087,7 +3044,7 @@ inline bool TrainerSpec::has_vocabulary_output_piece_score() const {
|
|||||||
}
|
}
|
||||||
inline void TrainerSpec::clear_vocabulary_output_piece_score() {
|
inline void TrainerSpec::clear_vocabulary_output_piece_score() {
|
||||||
vocabulary_output_piece_score_ = true;
|
vocabulary_output_piece_score_ = true;
|
||||||
_has_bits_[1] &= ~0x00000001u;
|
_has_bits_[0] &= ~0x80000000u;
|
||||||
}
|
}
|
||||||
inline bool TrainerSpec::_internal_vocabulary_output_piece_score() const {
|
inline bool TrainerSpec::_internal_vocabulary_output_piece_score() const {
|
||||||
return vocabulary_output_piece_score_;
|
return vocabulary_output_piece_score_;
|
||||||
@ -3097,7 +3054,7 @@ inline bool TrainerSpec::vocabulary_output_piece_score() const {
|
|||||||
return _internal_vocabulary_output_piece_score();
|
return _internal_vocabulary_output_piece_score();
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::_internal_set_vocabulary_output_piece_score(bool value) {
|
inline void TrainerSpec::_internal_set_vocabulary_output_piece_score(bool value) {
|
||||||
_has_bits_[1] |= 0x00000001u;
|
_has_bits_[0] |= 0x80000000u;
|
||||||
vocabulary_output_piece_score_ = value;
|
vocabulary_output_piece_score_ = value;
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::set_vocabulary_output_piece_score(bool value) {
|
inline void TrainerSpec::set_vocabulary_output_piece_score(bool value) {
|
||||||
@ -3107,7 +3064,7 @@ inline void TrainerSpec::set_vocabulary_output_piece_score(bool value) {
|
|||||||
|
|
||||||
// optional bool hard_vocab_limit = 33 [default = true];
|
// optional bool hard_vocab_limit = 33 [default = true];
|
||||||
inline bool TrainerSpec::_internal_has_hard_vocab_limit() const {
|
inline bool TrainerSpec::_internal_has_hard_vocab_limit() const {
|
||||||
bool value = (_has_bits_[1] & 0x00000002u) != 0;
|
bool value = (_has_bits_[1] & 0x00000001u) != 0;
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
inline bool TrainerSpec::has_hard_vocab_limit() const {
|
inline bool TrainerSpec::has_hard_vocab_limit() const {
|
||||||
@ -3115,7 +3072,7 @@ inline bool TrainerSpec::has_hard_vocab_limit() const {
|
|||||||
}
|
}
|
||||||
inline void TrainerSpec::clear_hard_vocab_limit() {
|
inline void TrainerSpec::clear_hard_vocab_limit() {
|
||||||
hard_vocab_limit_ = true;
|
hard_vocab_limit_ = true;
|
||||||
_has_bits_[1] &= ~0x00000002u;
|
_has_bits_[1] &= ~0x00000001u;
|
||||||
}
|
}
|
||||||
inline bool TrainerSpec::_internal_hard_vocab_limit() const {
|
inline bool TrainerSpec::_internal_hard_vocab_limit() const {
|
||||||
return hard_vocab_limit_;
|
return hard_vocab_limit_;
|
||||||
@ -3125,7 +3082,7 @@ inline bool TrainerSpec::hard_vocab_limit() const {
|
|||||||
return _internal_hard_vocab_limit();
|
return _internal_hard_vocab_limit();
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::_internal_set_hard_vocab_limit(bool value) {
|
inline void TrainerSpec::_internal_set_hard_vocab_limit(bool value) {
|
||||||
_has_bits_[1] |= 0x00000002u;
|
_has_bits_[1] |= 0x00000001u;
|
||||||
hard_vocab_limit_ = value;
|
hard_vocab_limit_ = value;
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::set_hard_vocab_limit(bool value) {
|
inline void TrainerSpec::set_hard_vocab_limit(bool value) {
|
||||||
@ -3135,7 +3092,7 @@ inline void TrainerSpec::set_hard_vocab_limit(bool value) {
|
|||||||
|
|
||||||
// optional bool use_all_vocab = 34 [default = false];
|
// optional bool use_all_vocab = 34 [default = false];
|
||||||
inline bool TrainerSpec::_internal_has_use_all_vocab() const {
|
inline bool TrainerSpec::_internal_has_use_all_vocab() const {
|
||||||
bool value = (_has_bits_[0] & 0x00020000u) != 0;
|
bool value = (_has_bits_[0] & 0x00008000u) != 0;
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
inline bool TrainerSpec::has_use_all_vocab() const {
|
inline bool TrainerSpec::has_use_all_vocab() const {
|
||||||
@ -3143,7 +3100,7 @@ inline bool TrainerSpec::has_use_all_vocab() const {
|
|||||||
}
|
}
|
||||||
inline void TrainerSpec::clear_use_all_vocab() {
|
inline void TrainerSpec::clear_use_all_vocab() {
|
||||||
use_all_vocab_ = false;
|
use_all_vocab_ = false;
|
||||||
_has_bits_[0] &= ~0x00020000u;
|
_has_bits_[0] &= ~0x00008000u;
|
||||||
}
|
}
|
||||||
inline bool TrainerSpec::_internal_use_all_vocab() const {
|
inline bool TrainerSpec::_internal_use_all_vocab() const {
|
||||||
return use_all_vocab_;
|
return use_all_vocab_;
|
||||||
@ -3153,7 +3110,7 @@ inline bool TrainerSpec::use_all_vocab() const {
|
|||||||
return _internal_use_all_vocab();
|
return _internal_use_all_vocab();
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::_internal_set_use_all_vocab(bool value) {
|
inline void TrainerSpec::_internal_set_use_all_vocab(bool value) {
|
||||||
_has_bits_[0] |= 0x00020000u;
|
_has_bits_[0] |= 0x00008000u;
|
||||||
use_all_vocab_ = value;
|
use_all_vocab_ = value;
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::set_use_all_vocab(bool value) {
|
inline void TrainerSpec::set_use_all_vocab(bool value) {
|
||||||
@ -3191,7 +3148,7 @@ inline void TrainerSpec::set_unk_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
|||||||
|
|
||||||
// optional int32 bos_id = 41 [default = 1];
|
// optional int32 bos_id = 41 [default = 1];
|
||||||
inline bool TrainerSpec::_internal_has_bos_id() const {
|
inline bool TrainerSpec::_internal_has_bos_id() const {
|
||||||
bool value = (_has_bits_[1] & 0x00000004u) != 0;
|
bool value = (_has_bits_[1] & 0x00000002u) != 0;
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
inline bool TrainerSpec::has_bos_id() const {
|
inline bool TrainerSpec::has_bos_id() const {
|
||||||
@ -3199,7 +3156,7 @@ inline bool TrainerSpec::has_bos_id() const {
|
|||||||
}
|
}
|
||||||
inline void TrainerSpec::clear_bos_id() {
|
inline void TrainerSpec::clear_bos_id() {
|
||||||
bos_id_ = 1;
|
bos_id_ = 1;
|
||||||
_has_bits_[1] &= ~0x00000004u;
|
_has_bits_[1] &= ~0x00000002u;
|
||||||
}
|
}
|
||||||
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_bos_id() const {
|
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_bos_id() const {
|
||||||
return bos_id_;
|
return bos_id_;
|
||||||
@ -3209,7 +3166,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::bos_id() const {
|
|||||||
return _internal_bos_id();
|
return _internal_bos_id();
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::_internal_set_bos_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
inline void TrainerSpec::_internal_set_bos_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
||||||
_has_bits_[1] |= 0x00000004u;
|
_has_bits_[1] |= 0x00000002u;
|
||||||
bos_id_ = value;
|
bos_id_ = value;
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::set_bos_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
inline void TrainerSpec::set_bos_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
||||||
@ -3219,7 +3176,7 @@ inline void TrainerSpec::set_bos_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
|||||||
|
|
||||||
// optional int32 eos_id = 42 [default = 2];
|
// optional int32 eos_id = 42 [default = 2];
|
||||||
inline bool TrainerSpec::_internal_has_eos_id() const {
|
inline bool TrainerSpec::_internal_has_eos_id() const {
|
||||||
bool value = (_has_bits_[1] & 0x00000008u) != 0;
|
bool value = (_has_bits_[1] & 0x00000004u) != 0;
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
inline bool TrainerSpec::has_eos_id() const {
|
inline bool TrainerSpec::has_eos_id() const {
|
||||||
@ -3227,7 +3184,7 @@ inline bool TrainerSpec::has_eos_id() const {
|
|||||||
}
|
}
|
||||||
inline void TrainerSpec::clear_eos_id() {
|
inline void TrainerSpec::clear_eos_id() {
|
||||||
eos_id_ = 2;
|
eos_id_ = 2;
|
||||||
_has_bits_[1] &= ~0x00000008u;
|
_has_bits_[1] &= ~0x00000004u;
|
||||||
}
|
}
|
||||||
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_eos_id() const {
|
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_eos_id() const {
|
||||||
return eos_id_;
|
return eos_id_;
|
||||||
@ -3237,7 +3194,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::eos_id() const {
|
|||||||
return _internal_eos_id();
|
return _internal_eos_id();
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::_internal_set_eos_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
inline void TrainerSpec::_internal_set_eos_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
||||||
_has_bits_[1] |= 0x00000008u;
|
_has_bits_[1] |= 0x00000004u;
|
||||||
eos_id_ = value;
|
eos_id_ = value;
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::set_eos_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
inline void TrainerSpec::set_eos_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
||||||
@ -3247,7 +3204,7 @@ inline void TrainerSpec::set_eos_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
|||||||
|
|
||||||
// optional int32 pad_id = 43 [default = -1];
|
// optional int32 pad_id = 43 [default = -1];
|
||||||
inline bool TrainerSpec::_internal_has_pad_id() const {
|
inline bool TrainerSpec::_internal_has_pad_id() const {
|
||||||
bool value = (_has_bits_[1] & 0x00000010u) != 0;
|
bool value = (_has_bits_[1] & 0x00000008u) != 0;
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
inline bool TrainerSpec::has_pad_id() const {
|
inline bool TrainerSpec::has_pad_id() const {
|
||||||
@ -3255,7 +3212,7 @@ inline bool TrainerSpec::has_pad_id() const {
|
|||||||
}
|
}
|
||||||
inline void TrainerSpec::clear_pad_id() {
|
inline void TrainerSpec::clear_pad_id() {
|
||||||
pad_id_ = -1;
|
pad_id_ = -1;
|
||||||
_has_bits_[1] &= ~0x00000010u;
|
_has_bits_[1] &= ~0x00000008u;
|
||||||
}
|
}
|
||||||
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_pad_id() const {
|
inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::_internal_pad_id() const {
|
||||||
return pad_id_;
|
return pad_id_;
|
||||||
@ -3265,7 +3222,7 @@ inline ::PROTOBUF_NAMESPACE_ID::int32 TrainerSpec::pad_id() const {
|
|||||||
return _internal_pad_id();
|
return _internal_pad_id();
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::_internal_set_pad_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
inline void TrainerSpec::_internal_set_pad_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
||||||
_has_bits_[1] |= 0x00000010u;
|
_has_bits_[1] |= 0x00000008u;
|
||||||
pad_id_ = value;
|
pad_id_ = value;
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::set_pad_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
inline void TrainerSpec::set_pad_id(::PROTOBUF_NAMESPACE_ID::int32 value) {
|
||||||
@ -3645,7 +3602,7 @@ inline void TrainerSpec::set_allocated_unk_surface(std::string* unk_surface) {
|
|||||||
|
|
||||||
// optional bool train_extremely_large_corpus = 49 [default = false];
|
// optional bool train_extremely_large_corpus = 49 [default = false];
|
||||||
inline bool TrainerSpec::_internal_has_train_extremely_large_corpus() const {
|
inline bool TrainerSpec::_internal_has_train_extremely_large_corpus() const {
|
||||||
bool value = (_has_bits_[0] & 0x00040000u) != 0;
|
bool value = (_has_bits_[0] & 0x00020000u) != 0;
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
inline bool TrainerSpec::has_train_extremely_large_corpus() const {
|
inline bool TrainerSpec::has_train_extremely_large_corpus() const {
|
||||||
@ -3653,7 +3610,7 @@ inline bool TrainerSpec::has_train_extremely_large_corpus() const {
|
|||||||
}
|
}
|
||||||
inline void TrainerSpec::clear_train_extremely_large_corpus() {
|
inline void TrainerSpec::clear_train_extremely_large_corpus() {
|
||||||
train_extremely_large_corpus_ = false;
|
train_extremely_large_corpus_ = false;
|
||||||
_has_bits_[0] &= ~0x00040000u;
|
_has_bits_[0] &= ~0x00020000u;
|
||||||
}
|
}
|
||||||
inline bool TrainerSpec::_internal_train_extremely_large_corpus() const {
|
inline bool TrainerSpec::_internal_train_extremely_large_corpus() const {
|
||||||
return train_extremely_large_corpus_;
|
return train_extremely_large_corpus_;
|
||||||
@ -3663,7 +3620,7 @@ inline bool TrainerSpec::train_extremely_large_corpus() const {
|
|||||||
return _internal_train_extremely_large_corpus();
|
return _internal_train_extremely_large_corpus();
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::_internal_set_train_extremely_large_corpus(bool value) {
|
inline void TrainerSpec::_internal_set_train_extremely_large_corpus(bool value) {
|
||||||
_has_bits_[0] |= 0x00040000u;
|
_has_bits_[0] |= 0x00020000u;
|
||||||
train_extremely_large_corpus_ = value;
|
train_extremely_large_corpus_ = value;
|
||||||
}
|
}
|
||||||
inline void TrainerSpec::set_train_extremely_large_corpus(bool value) {
|
inline void TrainerSpec::set_train_extremely_large_corpus(bool value) {
|
||||||
|
@ -134,53 +134,32 @@ void ModelInterface::InitializePieces() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::vector<absl::string_view> SplitIntoWords(absl::string_view text,
|
std::vector<absl::string_view> SplitIntoWords(absl::string_view text,
|
||||||
bool treat_ws_as_suffix,
|
bool treat_whitespace_as_suffix) {
|
||||||
bool allow_ws_only_pieces) {
|
|
||||||
const char *begin = text.data();
|
const char *begin = text.data();
|
||||||
const char *end = text.data() + text.size();
|
const char *end = text.data() + text.size();
|
||||||
|
|
||||||
// Space symbol (U+2581)
|
// Space symbol (U+2581)
|
||||||
const absl::string_view kSpaceSymbol = "\xe2\x96\x81";
|
const absl::string_view kSpaceSymbol = "\xe2\x96\x81";
|
||||||
bool in_ws_sequence = false;
|
|
||||||
|
|
||||||
std::vector<absl::string_view> result;
|
std::vector<absl::string_view> result;
|
||||||
if (treat_ws_as_suffix) { // put ws tokens at the end of non-ws sequences.
|
if (treat_whitespace_as_suffix) {
|
||||||
if (begin < end) result.emplace_back(begin, 0);
|
if (begin < end) result.emplace_back(begin, 0);
|
||||||
while (begin < end) {
|
while (begin < end) {
|
||||||
const int mblen =
|
const int mblen =
|
||||||
std::min<int>(string_util::OneCharLen(begin), end - begin);
|
std::min<int>(string_util::OneCharLen(begin), end - begin);
|
||||||
const bool is_ws = absl::string_view(begin, mblen) == kSpaceSymbol;
|
const bool is_ws = absl::string_view(begin, mblen) == kSpaceSymbol;
|
||||||
|
|
||||||
if (is_ws) { // keep track of sequences consecutive ws tokens.
|
|
||||||
in_ws_sequence = true;
|
|
||||||
} else if (in_ws_sequence) {
|
|
||||||
if (allow_ws_only_pieces) result.emplace_back(begin, 0);
|
|
||||||
|
|
||||||
in_ws_sequence = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
result.back() =
|
result.back() =
|
||||||
absl::string_view(result.back().data(), result.back().size() + mblen);
|
absl::string_view(result.back().data(), result.back().size() + mblen);
|
||||||
begin += mblen;
|
begin += mblen;
|
||||||
|
if (begin < end && is_ws) result.emplace_back(begin, 0);
|
||||||
if (begin < end && is_ws && !allow_ws_only_pieces)
|
|
||||||
result.emplace_back(begin, 0);
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
while (begin < end) {
|
while (begin < end) {
|
||||||
const int mblen =
|
const int mblen =
|
||||||
std::min<int>(string_util::OneCharLen(begin), end - begin);
|
std::min<int>(string_util::OneCharLen(begin), end - begin);
|
||||||
bool is_ws = absl::string_view(begin, mblen) == kSpaceSymbol;
|
|
||||||
|
|
||||||
// if is whitespace (and not in sequence if allow_ws_only_pieces is True)
|
|
||||||
if (begin == text.data() ||
|
if (begin == text.data() ||
|
||||||
(is_ws && (!in_ws_sequence || !allow_ws_only_pieces))) {
|
absl::string_view(begin, mblen) == kSpaceSymbol)
|
||||||
result.emplace_back(begin, 0); // add empty string piece.
|
result.emplace_back(begin, 0); // add empty string piece.
|
||||||
in_ws_sequence = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (in_ws_sequence && !is_ws) in_ws_sequence = false;
|
|
||||||
|
|
||||||
result.back() =
|
result.back() =
|
||||||
absl::string_view(result.back().data(), result.back().size() + mblen);
|
absl::string_view(result.back().data(), result.back().size() + mblen);
|
||||||
begin += mblen;
|
begin += mblen;
|
||||||
|
@ -33,9 +33,8 @@
|
|||||||
namespace sentencepiece {
|
namespace sentencepiece {
|
||||||
|
|
||||||
// "_this_is_a_pen" => ["_this", "_is", "_a", "_pen"]
|
// "_this_is_a_pen" => ["_this", "_is", "_a", "_pen"]
|
||||||
std::vector<absl::string_view> SplitIntoWords(
|
std::vector<absl::string_view> SplitIntoWords(absl::string_view text,
|
||||||
absl::string_view text, bool treat_ws_as_suffix = false,
|
bool add_ws_as_suffix = false);
|
||||||
bool allow_ws_only_pieces = false);
|
|
||||||
|
|
||||||
// Converts byte (0-255) to piece (e.g., 58 -> "<0x3A>").
|
// Converts byte (0-255) to piece (e.g., 58 -> "<0x3A>").
|
||||||
std::string ByteToPiece(unsigned char c);
|
std::string ByteToPiece(unsigned char c);
|
||||||
@ -107,42 +106,12 @@ class ModelInterface {
|
|||||||
return EncodeResult();
|
return EncodeResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sample `samples` many tokenisations from the segmentation lattice
|
|
||||||
// If `wor` is true, the samples are taken without replacement, and the scores
|
|
||||||
// are the inclusion probabilities of the elements in the sample; otherwise
|
|
||||||
// the samples are taken with replacement and the scores are the log-probs of
|
|
||||||
// sample elements
|
|
||||||
// If `include_best` is true, the best tokenisation is always included in the
|
|
||||||
// sample, and the remaining elements are sampled excluding the best.
|
|
||||||
virtual NBestEncodeResult SampleEncodeAndScore(absl::string_view normalized,
|
|
||||||
float alpha, int samples,
|
|
||||||
bool wor,
|
|
||||||
bool include_best) const {
|
|
||||||
LOG(ERROR) << "Not implemented.";
|
|
||||||
return {{EncodeResult(), 0.0}};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Calculates the entropy of the segmentation lattice with inverse temperature
|
|
||||||
// `theta`.
|
|
||||||
// Uses a novel dynamic program to calculate the entropy.
|
|
||||||
virtual float CalculateEntropy(absl::string_view normalized,
|
|
||||||
float theta) const {
|
|
||||||
LOG(ERROR) << "Not implemented.";
|
|
||||||
return 0.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return true if SampleEncode returns a valid result.
|
// Return true if SampleEncode returns a valid result.
|
||||||
virtual bool IsSampleEncodeAvailable() const { return false; }
|
virtual bool IsSampleEncodeAvailable() const { return false; }
|
||||||
|
|
||||||
// Return true if NBestEncode returns a valid result.
|
// Return true if NBestEncode returns a valid result.
|
||||||
virtual bool IsNBestEncodeAvailable() const { return false; }
|
virtual bool IsNBestEncodeAvailable() const { return false; }
|
||||||
|
|
||||||
// Return true if SampleEncodeAndScore returns a valid result.
|
|
||||||
virtual bool IsSampleEncodeAndScoreAvailable() const { return false; }
|
|
||||||
|
|
||||||
// Return true if CalculateEntropy returns a valid result.
|
|
||||||
virtual bool IsCalculateEntropyAvailable() const { return false; }
|
|
||||||
|
|
||||||
// Returns the vocab id of `piece`.
|
// Returns the vocab id of `piece`.
|
||||||
// Returns UNK(0) if `piece` is unknown
|
// Returns UNK(0) if `piece` is unknown
|
||||||
virtual int PieceToId(absl::string_view piece) const;
|
virtual int PieceToId(absl::string_view piece) const;
|
||||||
@ -155,10 +124,7 @@ class ModelInterface {
|
|||||||
|
|
||||||
// Returns the size of sentence pieces, which is the same
|
// Returns the size of sentence pieces, which is the same
|
||||||
// as the size of vocabulary for NMT.
|
// as the size of vocabulary for NMT.
|
||||||
virtual int GetPieceSize() const {
|
virtual int GetPieceSize() const { return model_proto_->pieces_size(); }
|
||||||
if (!model_proto_) return 0;
|
|
||||||
return model_proto_->pieces_size();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns the score of `id`.
|
// Returns the score of `id`.
|
||||||
// Score represents a log probability of the piece.
|
// Score represents a log probability of the piece.
|
||||||
|
@ -412,50 +412,6 @@ TEST(ModelInterfaceTest, SplitIntoWordsSuffixTest) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ModelInterfaceTest, SplitIntoWordsWhiteSpaceOnly) {
|
|
||||||
{
|
|
||||||
const auto v =
|
|
||||||
SplitIntoWords("this" WS "is" WS "a" WS "pen" WS, true, true);
|
|
||||||
EXPECT_EQ(4, v.size());
|
|
||||||
EXPECT_EQ("this" WS, v[0]);
|
|
||||||
EXPECT_EQ("is" WS, v[1]);
|
|
||||||
EXPECT_EQ("a" WS, v[2]);
|
|
||||||
EXPECT_EQ("pen" WS, v[3]);
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
const auto v = SplitIntoWords(WS WS WS "a", false, true);
|
|
||||||
EXPECT_EQ(1, v.size());
|
|
||||||
EXPECT_EQ(WS WS WS "a", v[0]);
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
const auto v = SplitIntoWords("a" WS WS WS, true, true);
|
|
||||||
EXPECT_EQ(1, v.size());
|
|
||||||
EXPECT_EQ("a" WS WS WS, v[0]);
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
const auto v = SplitIntoWords(WS WS, true, true);
|
|
||||||
EXPECT_EQ(1, v.size());
|
|
||||||
EXPECT_EQ(WS WS, v[0]);
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
const auto v = SplitIntoWords(WS WS "a" WS, true, true);
|
|
||||||
EXPECT_EQ(2, v.size());
|
|
||||||
EXPECT_EQ(WS WS, v[0]);
|
|
||||||
EXPECT_EQ("a" WS, v[1]);
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
const auto v = SplitIntoWords(WS WS "a" WS, false, true);
|
|
||||||
EXPECT_EQ(2, v.size());
|
|
||||||
EXPECT_EQ(WS WS "a", v[0]);
|
|
||||||
EXPECT_EQ(WS, v[1]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(ModelInterfaceTest, ByteToPieceTest) {
|
TEST(ModelInterfaceTest, ByteToPieceTest) {
|
||||||
EXPECT_EQ(ByteToPiece(0), "<0x00>");
|
EXPECT_EQ(ByteToPiece(0), "<0x00>");
|
||||||
EXPECT_EQ(ByteToPiece(1), "<0x01>");
|
EXPECT_EQ(ByteToPiece(1), "<0x01>");
|
||||||
|
@ -12,11 +12,12 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.!
|
// limitations under the License.!
|
||||||
|
|
||||||
|
#include "normalizer.h"
|
||||||
|
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
#include "normalizer.h"
|
|
||||||
#include "third_party/absl/memory/memory.h"
|
#include "third_party/absl/memory/memory.h"
|
||||||
#include "third_party/absl/strings/match.h"
|
#include "third_party/absl/strings/match.h"
|
||||||
#include "third_party/absl/strings/string_view.h"
|
#include "third_party/absl/strings/string_view.h"
|
||||||
@ -277,11 +278,11 @@ util::Status Normalizer::DecodePrecompiledCharsMap(
|
|||||||
absl::string_view blob, absl::string_view *trie_blob,
|
absl::string_view blob, absl::string_view *trie_blob,
|
||||||
absl::string_view *normalized, std::string *buffer) {
|
absl::string_view *normalized, std::string *buffer) {
|
||||||
uint32 trie_blob_size = 0;
|
uint32 trie_blob_size = 0;
|
||||||
|
|
||||||
if (blob.size() <= sizeof(trie_blob_size) ||
|
if (blob.size() <= sizeof(trie_blob_size) ||
|
||||||
!string_util::DecodePOD<uint32>(
|
!string_util::DecodePOD<uint32>(
|
||||||
absl::string_view(blob.data(), sizeof(trie_blob_size)),
|
absl::string_view(blob.data(), sizeof(trie_blob_size)),
|
||||||
&trie_blob_size) ||
|
&trie_blob_size)) {
|
||||||
trie_blob_size >= blob.size()) {
|
|
||||||
return util::InternalError("Blob for normalization rule is broken.");
|
return util::InternalError("Blob for normalization rule is broken.");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -22,6 +22,7 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
|
#include "util.h"
|
||||||
#include "sentencepiece_model.pb.h"
|
#include "sentencepiece_model.pb.h"
|
||||||
#include "sentencepiece_processor.h"
|
#include "sentencepiece_processor.h"
|
||||||
#include "third_party/absl/strings/string_view.h"
|
#include "third_party/absl/strings/string_view.h"
|
||||||
|
@ -12,10 +12,11 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.!
|
// limitations under the License.!
|
||||||
|
|
||||||
|
#include "normalizer.h"
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "builder.h"
|
#include "builder.h"
|
||||||
#include "normalizer.h"
|
|
||||||
#include "sentencepiece_trainer.h"
|
#include "sentencepiece_trainer.h"
|
||||||
#include "testharness.h"
|
#include "testharness.h"
|
||||||
#include "util.h"
|
#include "util.h"
|
||||||
|
@ -139,10 +139,6 @@ message TrainerSpec {
|
|||||||
// of sentence.
|
// of sentence.
|
||||||
optional bool treat_whitespace_as_suffix = 24 [default = false];
|
optional bool treat_whitespace_as_suffix = 24 [default = false];
|
||||||
|
|
||||||
// Allows pieces that only contain whitespaces instead of appearing only as
|
|
||||||
// prefix or suffix of other pieces.
|
|
||||||
optional bool allow_whitespace_only_pieces = 26 [default = false];
|
|
||||||
|
|
||||||
// Split all digits (0-9) into separate pieces.
|
// Split all digits (0-9) into separate pieces.
|
||||||
optional bool split_digits = 25 [default = false];
|
optional bool split_digits = 25 [default = false];
|
||||||
|
|
||||||
|
@ -12,6 +12,8 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.!
|
// limitations under the License.!
|
||||||
|
|
||||||
|
#include "sentencepiece_processor.h"
|
||||||
|
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
@ -22,7 +24,6 @@
|
|||||||
#include "model_interface.h"
|
#include "model_interface.h"
|
||||||
#include "normalizer.h"
|
#include "normalizer.h"
|
||||||
#include "sentencepiece.pb.h"
|
#include "sentencepiece.pb.h"
|
||||||
#include "sentencepiece_processor.h"
|
|
||||||
#include "third_party/absl/memory/memory.h"
|
#include "third_party/absl/memory/memory.h"
|
||||||
#include "third_party/absl/strings/numbers.h"
|
#include "third_party/absl/strings/numbers.h"
|
||||||
#include "third_party/absl/strings/str_cat.h"
|
#include "third_party/absl/strings/str_cat.h"
|
||||||
@ -503,43 +504,6 @@ util::Status SentencePieceProcessor::SampleEncode(
|
|||||||
return util::OkStatus();
|
return util::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
util::Status SentencePieceProcessor::SampleEncodeAndScore(
|
|
||||||
absl::string_view input, int samples, float theta, bool wor,
|
|
||||||
bool include_best, NBestSentencePieceText *samples_spt) const {
|
|
||||||
CHECK_OR_RETURN(model_->IsSampleEncodeAndScoreAvailable())
|
|
||||||
<< "SampleEncodeAndScore is not available for the current model.";
|
|
||||||
std::string normalized;
|
|
||||||
std::vector<size_t> norm_to_orig;
|
|
||||||
RETURN_IF_ERROR(normalizer_->Normalize(input, &normalized, &norm_to_orig));
|
|
||||||
|
|
||||||
const auto results = model_->SampleEncodeAndScore(normalized, theta, samples,
|
|
||||||
wor, include_best);
|
|
||||||
CHECK_OR_RETURN(!results.empty())
|
|
||||||
<< "SampleEncodeAndScore returns empty result.";
|
|
||||||
|
|
||||||
for (const auto &result : results) {
|
|
||||||
auto *spt = samples_spt->add_nbests();
|
|
||||||
spt->set_score(result.second);
|
|
||||||
RETURN_IF_ERROR(PopulateSentencePieceText(input, normalized, norm_to_orig,
|
|
||||||
result.first, spt));
|
|
||||||
}
|
|
||||||
|
|
||||||
return util::OkStatus();
|
|
||||||
}
|
|
||||||
|
|
||||||
util::Status SentencePieceProcessor::CalculateEntropy(absl::string_view input,
|
|
||||||
float theta,
|
|
||||||
float *entropy) const {
|
|
||||||
CHECK_OR_RETURN(model_->IsCalculateEntropyAvailable())
|
|
||||||
<< "CalculateEntropy is not available for the current model.";
|
|
||||||
std::string normalized;
|
|
||||||
std::vector<size_t> norm_to_orig;
|
|
||||||
RETURN_IF_ERROR(normalizer_->Normalize(input, &normalized, &norm_to_orig));
|
|
||||||
|
|
||||||
*entropy = model_->CalculateEntropy(normalized, theta);
|
|
||||||
return util::OkStatus();
|
|
||||||
}
|
|
||||||
|
|
||||||
util::Status SentencePieceProcessor::Decode(
|
util::Status SentencePieceProcessor::Decode(
|
||||||
const std::vector<std::string> &pieces, SentencePieceText *spt) const {
|
const std::vector<std::string> &pieces, SentencePieceText *spt) const {
|
||||||
CHECK_OR_RETURN_STATUS_PROTO(spt);
|
CHECK_OR_RETURN_STATUS_PROTO(spt);
|
||||||
@ -869,12 +833,6 @@ std::string SentencePieceProcessor::serialized_model_proto() const {
|
|||||||
return model_proto_ ? model_proto_->SerializeAsString() : "";
|
return model_proto_ ? model_proto_->SerializeAsString() : "";
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set seed value of random generator.
|
|
||||||
// Do not set static_cast<unique_int>(-1),
|
|
||||||
// as this seed is reserved for initializing from
|
|
||||||
// std::random_device.
|
|
||||||
void SetRandomGeneratorSeed(unsigned int seed);
|
|
||||||
|
|
||||||
namespace io {
|
namespace io {
|
||||||
|
|
||||||
util::Status LoadModelProto(absl::string_view filename,
|
util::Status LoadModelProto(absl::string_view filename,
|
||||||
|
@ -315,15 +315,6 @@ class SentencePieceProcessor {
|
|||||||
virtual util::Status SampleEncode(absl::string_view input, int nbest_size,
|
virtual util::Status SampleEncode(absl::string_view input, int nbest_size,
|
||||||
float alpha, SentencePieceText *spt) const;
|
float alpha, SentencePieceText *spt) const;
|
||||||
|
|
||||||
// Samples N segmentation and returns the scores as well
|
|
||||||
virtual util::Status SampleEncodeAndScore(
|
|
||||||
absl::string_view input, int samples, float theta, bool wor,
|
|
||||||
bool include_best, NBestSentencePieceText *samples_spt) const;
|
|
||||||
|
|
||||||
// Calculate entropy of possible tokenisations
|
|
||||||
virtual util::Status CalculateEntropy(absl::string_view input, float theta,
|
|
||||||
float *entropy) const;
|
|
||||||
|
|
||||||
// Given a sequence of pieces, decodes it into SentencePieceText.
|
// Given a sequence of pieces, decodes it into SentencePieceText.
|
||||||
virtual util::Status Decode(const std::vector<std::string> &pieces,
|
virtual util::Status Decode(const std::vector<std::string> &pieces,
|
||||||
SentencePieceText *spt) const;
|
SentencePieceText *spt) const;
|
||||||
|
@ -12,6 +12,8 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.!
|
// limitations under the License.!
|
||||||
|
|
||||||
|
#include "sentencepiece_processor.h"
|
||||||
|
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "builder.h"
|
#include "builder.h"
|
||||||
@ -20,7 +22,6 @@
|
|||||||
#include "normalizer.h"
|
#include "normalizer.h"
|
||||||
#include "sentencepiece.pb.h"
|
#include "sentencepiece.pb.h"
|
||||||
#include "sentencepiece_model.pb.h"
|
#include "sentencepiece_model.pb.h"
|
||||||
#include "sentencepiece_processor.h"
|
|
||||||
#include "sentencepiece_trainer.h"
|
#include "sentencepiece_trainer.h"
|
||||||
#include "testharness.h"
|
#include "testharness.h"
|
||||||
#include "third_party/absl/container/flat_hash_map.h"
|
#include "third_party/absl/container/flat_hash_map.h"
|
||||||
@ -1138,6 +1139,13 @@ TEST(SentencePieceProcessorTest, EndToEndTest) {
|
|||||||
EXPECT_EQ("cba", output);
|
EXPECT_EQ("cba", output);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Out of range
|
||||||
|
{
|
||||||
|
std::string output;
|
||||||
|
const std::vector<int> ids = {3, 4, 127};
|
||||||
|
EXPECT_FALSE(sp.Decode(ids, &output).ok());
|
||||||
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
EXPECT_TRUE(sp.SetDecodeExtraOptions("bos:eos").ok());
|
EXPECT_TRUE(sp.SetDecodeExtraOptions("bos:eos").ok());
|
||||||
|
|
||||||
@ -1164,13 +1172,6 @@ TEST(SentencePieceProcessorTest, EndToEndTest) {
|
|||||||
EXPECT_EQ("cba", output);
|
EXPECT_EQ("cba", output);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Out of range
|
|
||||||
{
|
|
||||||
std::string output;
|
|
||||||
const std::vector<int> ids = {3, 4, 127};
|
|
||||||
EXPECT_FALSE(sp.Decode(ids, &output).ok());
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
{
|
||||||
EXPECT_TRUE(sp.SetDecodeExtraOptions("bos:eos:reverse").ok());
|
EXPECT_TRUE(sp.SetDecodeExtraOptions("bos:eos:reverse").ok());
|
||||||
|
|
||||||
|
@ -12,6 +12,8 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.!
|
// limitations under the License.!
|
||||||
|
|
||||||
|
#include "sentencepiece_trainer.h"
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
@ -20,9 +22,7 @@
|
|||||||
#include "normalizer.h"
|
#include "normalizer.h"
|
||||||
#include "sentencepiece.pb.h"
|
#include "sentencepiece.pb.h"
|
||||||
#include "sentencepiece_model.pb.h"
|
#include "sentencepiece_model.pb.h"
|
||||||
#include "sentencepiece_trainer.h"
|
|
||||||
#include "spec_parser.h"
|
#include "spec_parser.h"
|
||||||
#include "third_party/absl/flags/flag.h"
|
|
||||||
#include "third_party/absl/strings/numbers.h"
|
#include "third_party/absl/strings/numbers.h"
|
||||||
#include "third_party/absl/strings/str_cat.h"
|
#include "third_party/absl/strings/str_cat.h"
|
||||||
#include "third_party/absl/strings/str_split.h"
|
#include "third_party/absl/strings/str_split.h"
|
||||||
@ -31,8 +31,6 @@
|
|||||||
#include "trainer_factory.h"
|
#include "trainer_factory.h"
|
||||||
#include "util.h"
|
#include "util.h"
|
||||||
|
|
||||||
ABSL_DECLARE_FLAG(int, minloglevel);
|
|
||||||
|
|
||||||
namespace sentencepiece {
|
namespace sentencepiece {
|
||||||
namespace {
|
namespace {
|
||||||
static constexpr char kDefaultNormalizerName[] = "nmt_nfkc";
|
static constexpr char kDefaultNormalizerName[] = "nmt_nfkc";
|
||||||
@ -112,7 +110,7 @@ util::Status SentencePieceTrainer::MergeSpecsFromArgs(
|
|||||||
for (auto arg : absl::StrSplit(args, " ")) {
|
for (auto arg : absl::StrSplit(args, " ")) {
|
||||||
absl::ConsumePrefix(&arg, "--");
|
absl::ConsumePrefix(&arg, "--");
|
||||||
std::string key, value;
|
std::string key, value;
|
||||||
const auto pos = arg.find('=');
|
const auto pos = arg.find("=");
|
||||||
if (pos == absl::string_view::npos) {
|
if (pos == absl::string_view::npos) {
|
||||||
key = std::string(arg);
|
key = std::string(arg);
|
||||||
} else {
|
} else {
|
||||||
@ -151,7 +149,7 @@ util::Status SentencePieceTrainer::MergeSpecsFromArgs(
|
|||||||
} else if (key == "minloglevel") {
|
} else if (key == "minloglevel") {
|
||||||
int v = 0;
|
int v = 0;
|
||||||
CHECK_OR_RETURN(absl::SimpleAtoi(value, &v));
|
CHECK_OR_RETURN(absl::SimpleAtoi(value, &v));
|
||||||
absl::SetFlag(&FLAGS_minloglevel, v);
|
logging::SetMinLogLevel(v);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -145,7 +145,6 @@ inline std::string PrintProto(const TrainerSpec &message,
|
|||||||
PRINT_PARAM(split_by_whitespace);
|
PRINT_PARAM(split_by_whitespace);
|
||||||
PRINT_PARAM(split_digits);
|
PRINT_PARAM(split_digits);
|
||||||
PRINT_PARAM(treat_whitespace_as_suffix);
|
PRINT_PARAM(treat_whitespace_as_suffix);
|
||||||
PRINT_PARAM(allow_whitespace_only_pieces);
|
|
||||||
PRINT_REPEATED_STRING(control_symbols);
|
PRINT_REPEATED_STRING(control_symbols);
|
||||||
PRINT_REPEATED_STRING(user_defined_symbols);
|
PRINT_REPEATED_STRING(user_defined_symbols);
|
||||||
PRINT_PARAM(required_chars);
|
PRINT_PARAM(required_chars);
|
||||||
@ -220,7 +219,6 @@ util::Status SentencePieceTrainer::SetProtoField(const std::string &name,
|
|||||||
PARSE_BOOL(split_by_whitespace);
|
PARSE_BOOL(split_by_whitespace);
|
||||||
PARSE_BOOL(split_digits);
|
PARSE_BOOL(split_digits);
|
||||||
PARSE_BOOL(treat_whitespace_as_suffix);
|
PARSE_BOOL(treat_whitespace_as_suffix);
|
||||||
PARSE_BOOL(allow_whitespace_only_pieces);
|
|
||||||
PARSE_REPEATED_STRING(control_symbols);
|
PARSE_REPEATED_STRING(control_symbols);
|
||||||
PARSE_REPEATED_STRING(user_defined_symbols);
|
PARSE_REPEATED_STRING(user_defined_symbols);
|
||||||
PARSE_STRING(required_chars);
|
PARSE_STRING(required_chars);
|
||||||
|
@ -64,7 +64,6 @@ int main(int argc, char *argv[]) {
|
|||||||
|
|
||||||
auto ToIds = [&](const std::vector<std::string> &pieces) {
|
auto ToIds = [&](const std::vector<std::string> &pieces) {
|
||||||
std::vector<int> ids;
|
std::vector<int> ids;
|
||||||
ids.reserve(pieces.size());
|
|
||||||
for (const auto &s : pieces) {
|
for (const auto &s : pieces) {
|
||||||
ids.push_back(atoi(s.c_str()));
|
ids.push_back(atoi(s.c_str()));
|
||||||
}
|
}
|
||||||
|
@ -28,17 +28,16 @@
|
|||||||
#include "trainer_interface.h"
|
#include "trainer_interface.h"
|
||||||
|
|
||||||
ABSL_FLAG(std::string, model, "", "model file name");
|
ABSL_FLAG(std::string, model, "", "model file name");
|
||||||
ABSL_FLAG(
|
ABSL_FLAG(std::string, output_format, "piece",
|
||||||
std::string, output_format, "piece",
|
"choose from piece, id, proto, nbest_piece, nbest_id, nbest_proto, "
|
||||||
"choose from piece, id, proto, nbest_piece, nbest_id, or nbest_proto");
|
"sample_piece, sample_id or sample_proto.");
|
||||||
ABSL_FLAG(std::string, input, "", "input filename");
|
ABSL_FLAG(std::string, input, "", "input filename");
|
||||||
ABSL_FLAG(std::string, output, "", "output filename");
|
ABSL_FLAG(std::string, output, "", "output filename");
|
||||||
ABSL_FLAG(std::string, extra_options, "",
|
ABSL_FLAG(std::string, extra_options, "",
|
||||||
"':' separated encoder extra options, e.g., \"reverse:bos:eos\"");
|
"':' separated encoder extra options, e.g., \"reverse:bos:eos\"");
|
||||||
ABSL_FLAG(int32, nbest_size, 10, "NBest size");
|
ABSL_FLAG(int32, nbest_size, 10, "NBest size");
|
||||||
ABSL_FLAG(double, alpha, 0.5, "Smoothing parameter for sampling mode.");
|
ABSL_FLAG(double, alpha, 0.5, "Smoothing parameter for sampling mode.");
|
||||||
ABSL_FLAG(uint32, random_seed, static_cast<uint32>(-1),
|
ABSL_FLAG(int32, random_seed, -1, "Seed value for random generator.");
|
||||||
"Seed value for random generator.");
|
|
||||||
|
|
||||||
// Piece restriction with vocabulary file.
|
// Piece restriction with vocabulary file.
|
||||||
// https://github.com/rsennrich/subword-nmt#best-practice-advice-for-byte-pair-encoding-in-nmt
|
// https://github.com/rsennrich/subword-nmt#best-practice-advice-for-byte-pair-encoding-in-nmt
|
||||||
@ -62,9 +61,8 @@ int main(int argc, char *argv[]) {
|
|||||||
rest_args.push_back(absl::GetFlag(FLAGS_input));
|
rest_args.push_back(absl::GetFlag(FLAGS_input));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (absl::GetFlag(FLAGS_random_seed) != -1) {
|
if (absl::GetFlag(FLAGS_random_seed) != -1)
|
||||||
sentencepiece::SetRandomGeneratorSeed(absl::GetFlag(FLAGS_random_seed));
|
sentencepiece::SetRandomGeneratorSeed(absl::GetFlag(FLAGS_random_seed));
|
||||||
}
|
|
||||||
|
|
||||||
if (rest_args.empty())
|
if (rest_args.empty())
|
||||||
rest_args.push_back(""); // empty means that reading from stdin.
|
rest_args.push_back(""); // empty means that reading from stdin.
|
||||||
|
@ -80,9 +80,6 @@ ABSL_FLAG(bool, split_digits, kDefaultTrainerSpec.split_digits(),
|
|||||||
ABSL_FLAG(bool, treat_whitespace_as_suffix,
|
ABSL_FLAG(bool, treat_whitespace_as_suffix,
|
||||||
kDefaultTrainerSpec.treat_whitespace_as_suffix(),
|
kDefaultTrainerSpec.treat_whitespace_as_suffix(),
|
||||||
"treat whitespace marker as suffix instead of prefix.");
|
"treat whitespace marker as suffix instead of prefix.");
|
||||||
ABSL_FLAG(bool, allow_whitespace_only_pieces,
|
|
||||||
kDefaultTrainerSpec.allow_whitespace_only_pieces(),
|
|
||||||
"allow pieces that only contain (consecutive) whitespace tokens");
|
|
||||||
ABSL_FLAG(std::string, control_symbols, "",
|
ABSL_FLAG(std::string, control_symbols, "",
|
||||||
"comma separated list of control symbols");
|
"comma separated list of control symbols");
|
||||||
ABSL_FLAG(std::string, control_symbols_file, "",
|
ABSL_FLAG(std::string, control_symbols_file, "",
|
||||||
@ -141,8 +138,7 @@ ABSL_FLAG(std::string, unk_surface, kDefaultTrainerSpec.unk_surface(),
|
|||||||
ABSL_FLAG(bool, train_extremely_large_corpus,
|
ABSL_FLAG(bool, train_extremely_large_corpus,
|
||||||
kDefaultTrainerSpec.train_extremely_large_corpus(),
|
kDefaultTrainerSpec.train_extremely_large_corpus(),
|
||||||
"Increase bit depth for unigram tokenization.");
|
"Increase bit depth for unigram tokenization.");
|
||||||
ABSL_FLAG(uint32, random_seed, static_cast<uint32>(-1),
|
ABSL_FLAG(int32, random_seed, -1, "Seed value for random generator.");
|
||||||
"Seed value for random generator.");
|
|
||||||
|
|
||||||
int main(int argc, char *argv[]) {
|
int main(int argc, char *argv[]) {
|
||||||
sentencepiece::ParseCommandLineFlags(argv[0], &argc, &argv, true);
|
sentencepiece::ParseCommandLineFlags(argv[0], &argc, &argv, true);
|
||||||
@ -154,9 +150,8 @@ int main(int argc, char *argv[]) {
|
|||||||
CHECK(!absl::GetFlag(FLAGS_input).empty());
|
CHECK(!absl::GetFlag(FLAGS_input).empty());
|
||||||
CHECK(!absl::GetFlag(FLAGS_model_prefix).empty());
|
CHECK(!absl::GetFlag(FLAGS_model_prefix).empty());
|
||||||
|
|
||||||
if (absl::GetFlag(FLAGS_random_seed) != -1) {
|
if (absl::GetFlag(FLAGS_random_seed) != -1)
|
||||||
sentencepiece::SetRandomGeneratorSeed(absl::GetFlag(FLAGS_random_seed));
|
sentencepiece::SetRandomGeneratorSeed(absl::GetFlag(FLAGS_random_seed));
|
||||||
}
|
|
||||||
|
|
||||||
auto load_lines = [](absl::string_view filename) {
|
auto load_lines = [](absl::string_view filename) {
|
||||||
std::vector<std::string> lines;
|
std::vector<std::string> lines;
|
||||||
@ -216,7 +211,6 @@ int main(int argc, char *argv[]) {
|
|||||||
SetTrainerSpecFromFlag(split_digits);
|
SetTrainerSpecFromFlag(split_digits);
|
||||||
SetTrainerSpecFromFlag(byte_fallback);
|
SetTrainerSpecFromFlag(byte_fallback);
|
||||||
SetTrainerSpecFromFlag(treat_whitespace_as_suffix);
|
SetTrainerSpecFromFlag(treat_whitespace_as_suffix);
|
||||||
SetTrainerSpecFromFlag(allow_whitespace_only_pieces);
|
|
||||||
SetTrainerSpecFromFlag(hard_vocab_limit);
|
SetTrainerSpecFromFlag(hard_vocab_limit);
|
||||||
SetTrainerSpecFromFlag(use_all_vocab);
|
SetTrainerSpecFromFlag(use_all_vocab);
|
||||||
SetTrainerSpecFromFlag(unk_id);
|
SetTrainerSpecFromFlag(unk_id);
|
||||||
|
@ -12,7 +12,8 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.!
|
// limitations under the License.!
|
||||||
|
|
||||||
#include <algorithm>
|
#include "trainer_interface.h"
|
||||||
|
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <set>
|
#include <set>
|
||||||
@ -33,7 +34,6 @@
|
|||||||
#include "third_party/absl/strings/str_format.h"
|
#include "third_party/absl/strings/str_format.h"
|
||||||
#include "third_party/absl/strings/str_join.h"
|
#include "third_party/absl/strings/str_join.h"
|
||||||
#include "third_party/absl/strings/str_split.h"
|
#include "third_party/absl/strings/str_split.h"
|
||||||
#include "trainer_interface.h"
|
|
||||||
#include "unicode_script.h"
|
#include "unicode_script.h"
|
||||||
#include "util.h"
|
#include "util.h"
|
||||||
|
|
||||||
@ -86,10 +86,6 @@ util::Status VerifySpec(const TrainerSpec &trainer_spec) {
|
|||||||
return util::OkStatus();
|
return util::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_unicode_decimal_number(char32 c) {
|
|
||||||
return (c >= 0x30 && c <= 0x39) || (c >= 0xff10 && c <= 0xff19);
|
|
||||||
}
|
|
||||||
|
|
||||||
class SentenceSelector {
|
class SentenceSelector {
|
||||||
public:
|
public:
|
||||||
using Sampler = random::ReservoirSampler<TrainerInterface::Sentence>;
|
using Sampler = random::ReservoirSampler<TrainerInterface::Sentence>;
|
||||||
@ -214,10 +210,9 @@ bool TrainerInterface::IsValidSentencePiece(
|
|||||||
constexpr unicode_script::ScriptType kAnyType =
|
constexpr unicode_script::ScriptType kAnyType =
|
||||||
static_cast<unicode_script::ScriptType>(-1);
|
static_cast<unicode_script::ScriptType>(-1);
|
||||||
|
|
||||||
|
auto is_number = [](char32 c) { return (c >= 0x30 && c <= 0x39); };
|
||||||
|
|
||||||
unicode_script::ScriptType prev_script = kAnyType;
|
unicode_script::ScriptType prev_script = kAnyType;
|
||||||
bool all_whitespace_piece =
|
|
||||||
std::all_of(sentencepiece.begin(), sentencepiece.end(),
|
|
||||||
[](char32 c) { return c == kWSChar; });
|
|
||||||
|
|
||||||
for (size_t pos = 0; pos < sentencepiece.size(); ++pos) {
|
for (size_t pos = 0; pos < sentencepiece.size(); ++pos) {
|
||||||
const char32 c = sentencepiece[pos];
|
const char32 c = sentencepiece[pos];
|
||||||
@ -240,30 +235,25 @@ bool TrainerInterface::IsValidSentencePiece(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (c == kWSChar) {
|
if (c == kWSChar) {
|
||||||
// Only allows whitespace to appear as a prefix of piece unless
|
// Only allows whitespace to appear as a prefix of piece.
|
||||||
// allow_whitespace_only_pieces is True.
|
|
||||||
// When split_by_whitespace is false, we allow whitespaces to
|
// When split_by_whitespace is false, we allow whitespaces to
|
||||||
// appear in the middle, "foo_bar", but do not allow them
|
// appear in the middle, "foo_bar", but do not allow them
|
||||||
// to appear as suffix, "foo_bar_".
|
// to appear as suffix, "foo_bar_".
|
||||||
// Regardless of the setting of split_by_whitespace,
|
// Regardless of the setting of split_by_whitespace,
|
||||||
// whitespace is treated as a prefix/infix of symbol or
|
// whitespace is treated as a prefix/infix of symbol or
|
||||||
// independent symbol, unless allow_whitespace_only_pieces() is true,
|
// independent symbol.
|
||||||
// in which case whitespace only pieces can occur.
|
if (trainer_spec_.treat_whitespace_as_suffix()) {
|
||||||
if (!trainer_spec_.allow_whitespace_only_pieces() ||
|
if ((trainer_spec_.split_by_whitespace() &&
|
||||||
!all_whitespace_piece) {
|
pos < sentencepiece.size() - 1) ||
|
||||||
if (trainer_spec_.treat_whitespace_as_suffix()) {
|
(!trainer_spec_.split_by_whitespace() &&
|
||||||
if ((trainer_spec_.split_by_whitespace() &&
|
pos < sentencepiece.size() - 1 && pos == 0)) {
|
||||||
pos < sentencepiece.size() - 1) ||
|
return false;
|
||||||
(!trainer_spec_.split_by_whitespace() &&
|
}
|
||||||
pos < sentencepiece.size() - 1 && pos == 0)) {
|
} else {
|
||||||
return false;
|
if ((trainer_spec_.split_by_whitespace() && pos > 0) ||
|
||||||
}
|
(!trainer_spec_.split_by_whitespace() && pos > 0 &&
|
||||||
} else {
|
pos == sentencepiece.size() - 1)) {
|
||||||
if ((trainer_spec_.split_by_whitespace() && pos > 0) ||
|
return false;
|
||||||
(!trainer_spec_.split_by_whitespace() && pos > 0 &&
|
|
||||||
pos == sentencepiece.size() - 1)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -275,11 +265,11 @@ bool TrainerInterface::IsValidSentencePiece(
|
|||||||
s = unicode_script::U_Han;
|
s = unicode_script::U_Han;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!trainer_spec_.split_by_number() && is_unicode_decimal_number(c)) {
|
if (!trainer_spec_.split_by_number() && is_number(c)) {
|
||||||
s = kAnyType;
|
s = kAnyType;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (trainer_spec_.split_digits() && is_unicode_decimal_number(c)) {
|
if (trainer_spec_.split_digits() && is_number(c)) {
|
||||||
if (sentencepiece.size() > 1) return false;
|
if (sentencepiece.size() > 1) return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -528,8 +518,7 @@ void TrainerInterface::SplitSentencesByWhitespace() {
|
|||||||
absl::flat_hash_map<std::string, int64> tokens;
|
absl::flat_hash_map<std::string, int64> tokens;
|
||||||
for (const auto &s : sentences_) {
|
for (const auto &s : sentences_) {
|
||||||
for (const auto &w :
|
for (const auto &w :
|
||||||
SplitIntoWords(s.first, trainer_spec_.treat_whitespace_as_suffix(),
|
SplitIntoWords(s.first, trainer_spec_.treat_whitespace_as_suffix())) {
|
||||||
trainer_spec_.allow_whitespace_only_pieces())) {
|
|
||||||
tokens[std::string(w)] += s.second;
|
tokens[std::string(w)] += s.second;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -81,7 +81,6 @@ TEST(TrainerInterfaceTest, IsValidSentencePieceTest) {
|
|||||||
|
|
||||||
trainer_spec.set_split_by_whitespace(false);
|
trainer_spec.set_split_by_whitespace(false);
|
||||||
EXPECT_TRUE(IsValid(WS));
|
EXPECT_TRUE(IsValid(WS));
|
||||||
EXPECT_TRUE(IsValid(WS WS WS "a"));
|
|
||||||
EXPECT_TRUE(IsValid(WS "a"));
|
EXPECT_TRUE(IsValid(WS "a"));
|
||||||
EXPECT_FALSE(IsValid("a" WS));
|
EXPECT_FALSE(IsValid("a" WS));
|
||||||
EXPECT_FALSE(IsValid(WS "a" WS));
|
EXPECT_FALSE(IsValid(WS "a" WS));
|
||||||
@ -89,17 +88,7 @@ TEST(TrainerInterfaceTest, IsValidSentencePieceTest) {
|
|||||||
EXPECT_TRUE(IsValid(WS "a" WS "b"));
|
EXPECT_TRUE(IsValid(WS "a" WS "b"));
|
||||||
EXPECT_TRUE(IsValid(WS "a" WS "b" WS "c"));
|
EXPECT_TRUE(IsValid(WS "a" WS "b" WS "c"));
|
||||||
EXPECT_FALSE(IsValid("a" WS "b" WS));
|
EXPECT_FALSE(IsValid("a" WS "b" WS));
|
||||||
EXPECT_FALSE(IsValid(WS WS));
|
|
||||||
EXPECT_FALSE(IsValid(WS WS WS));
|
|
||||||
|
|
||||||
trainer_spec.set_allow_whitespace_only_pieces(true);
|
|
||||||
EXPECT_TRUE(IsValid(WS));
|
|
||||||
EXPECT_TRUE(IsValid(WS WS));
|
|
||||||
EXPECT_TRUE(IsValid(WS WS WS));
|
|
||||||
EXPECT_TRUE(IsValid(WS WS "a"));
|
|
||||||
EXPECT_FALSE(IsValid("a" WS WS)); // suffix whitespace illegal without flag
|
|
||||||
|
|
||||||
trainer_spec.set_allow_whitespace_only_pieces(false);
|
|
||||||
trainer_spec.set_split_by_unicode_script(false);
|
trainer_spec.set_split_by_unicode_script(false);
|
||||||
EXPECT_TRUE(IsValid("あいう"));
|
EXPECT_TRUE(IsValid("あいう"));
|
||||||
EXPECT_TRUE(IsValid("グーグル"));
|
EXPECT_TRUE(IsValid("グーグル"));
|
||||||
@ -135,15 +124,6 @@ TEST(TrainerInterfaceTest, IsValidSentencePieceTest) {
|
|||||||
EXPECT_FALSE(IsValid(WS "a" WS "b"));
|
EXPECT_FALSE(IsValid(WS "a" WS "b"));
|
||||||
EXPECT_FALSE(IsValid("a" WS "b" WS));
|
EXPECT_FALSE(IsValid("a" WS "b" WS));
|
||||||
|
|
||||||
trainer_spec.set_allow_whitespace_only_pieces(true);
|
|
||||||
EXPECT_TRUE(IsValid(WS));
|
|
||||||
EXPECT_TRUE(IsValid(WS WS));
|
|
||||||
EXPECT_FALSE(IsValid(WS "a" WS));
|
|
||||||
EXPECT_FALSE(IsValid("a" WS "b"));
|
|
||||||
EXPECT_FALSE(IsValid(WS "a" WS "b"));
|
|
||||||
EXPECT_FALSE(IsValid("a" WS "b" WS));
|
|
||||||
|
|
||||||
trainer_spec.set_allow_whitespace_only_pieces(false);
|
|
||||||
trainer_spec.set_split_by_whitespace(false);
|
trainer_spec.set_split_by_whitespace(false);
|
||||||
EXPECT_TRUE(IsValid(WS));
|
EXPECT_TRUE(IsValid(WS));
|
||||||
EXPECT_FALSE(IsValid(WS "a"));
|
EXPECT_FALSE(IsValid(WS "a"));
|
||||||
@ -166,12 +146,6 @@ TEST(TrainerInterfaceTest, IsValidSentencePieceTest) {
|
|||||||
EXPECT_FALSE(IsValid("2007"));
|
EXPECT_FALSE(IsValid("2007"));
|
||||||
EXPECT_FALSE(IsValid("x1"));
|
EXPECT_FALSE(IsValid("x1"));
|
||||||
EXPECT_FALSE(IsValid("2x"));
|
EXPECT_FALSE(IsValid("2x"));
|
||||||
// Fullwidth digits.
|
|
||||||
EXPECT_TRUE(IsValid("1"));
|
|
||||||
EXPECT_FALSE(IsValid("59"));
|
|
||||||
EXPECT_FALSE(IsValid("2007"));
|
|
||||||
EXPECT_FALSE(IsValid("*1"));
|
|
||||||
EXPECT_FALSE(IsValid("2*"));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(TrainerInterfaceTest, OverrideSpecialPiecesTest) {
|
TEST(TrainerInterfaceTest, OverrideSpecialPiecesTest) {
|
||||||
|
@ -15,7 +15,6 @@
|
|||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cfloat>
|
#include <cfloat>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <complex>
|
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <queue>
|
#include <queue>
|
||||||
#include <string>
|
#include <string>
|
||||||
@ -56,17 +55,6 @@ inline float LogSumExp(float x, float y, bool init_mode) {
|
|||||||
return vmax + log(std::exp(static_cast<double>(vmin - vmax)) + 1.0);
|
return vmax + log(std::exp(static_cast<double>(vmin - vmax)) + 1.0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns a sample from a standard Gumbel distribution.
|
|
||||||
// If U ~ U[0, 1], -log(-log U) ~ G(0,1)
|
|
||||||
inline float Gumbel() {
|
|
||||||
const float kEpsilon = 1e-7;
|
|
||||||
auto *mt = random::GetRandomGenerator();
|
|
||||||
std::uniform_real_distribution<float> dis(0.0, 1.0);
|
|
||||||
float noise = -std::log(-(std::log(dis(*mt) + kEpsilon)));
|
|
||||||
|
|
||||||
return noise;
|
|
||||||
}
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
Lattice::Lattice() : node_allocator_(kPreallocateLatticeNodeSize) {}
|
Lattice::Lattice() : node_allocator_(kPreallocateLatticeNodeSize) {}
|
||||||
@ -157,7 +145,7 @@ Lattice::Node *Lattice::Insert(int pos, int length) {
|
|||||||
return node;
|
return node;
|
||||||
}
|
}
|
||||||
|
|
||||||
Lattice::LatticePathWithScore Lattice::Viterbi() {
|
std::vector<Lattice::Node *> Lattice::Viterbi() {
|
||||||
const int len = size();
|
const int len = size();
|
||||||
|
|
||||||
for (int pos = 0; pos <= len; ++pos) {
|
for (int pos = 0; pos <= len; ++pos) {
|
||||||
@ -183,7 +171,6 @@ Lattice::LatticePathWithScore Lattice::Viterbi() {
|
|||||||
|
|
||||||
// backtrace
|
// backtrace
|
||||||
std::vector<Node *> results;
|
std::vector<Node *> results;
|
||||||
float score = begin_nodes(len)[0]->backtrace_score;
|
|
||||||
for (Node *node = begin_nodes_[len][0]->prev; node->prev != nullptr;
|
for (Node *node = begin_nodes_[len][0]->prev; node->prev != nullptr;
|
||||||
node = node->prev) {
|
node = node->prev) {
|
||||||
results.push_back(node);
|
results.push_back(node);
|
||||||
@ -191,43 +178,7 @@ Lattice::LatticePathWithScore Lattice::Viterbi() {
|
|||||||
|
|
||||||
std::reverse(results.begin(), results.end());
|
std::reverse(results.begin(), results.end());
|
||||||
|
|
||||||
LatticePathWithScore retval = {results, score};
|
return results;
|
||||||
|
|
||||||
return retval;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<float> Lattice::ForwardAlgorithm(float theta) const {
|
|
||||||
const int len = size();
|
|
||||||
std::vector<float> alpha(node_allocator_.size(), 0.0);
|
|
||||||
|
|
||||||
for (int pos = 0; pos <= len; ++pos) {
|
|
||||||
for (Node *rnode : begin_nodes_[pos]) {
|
|
||||||
for (Node *lnode : end_nodes_[pos]) {
|
|
||||||
alpha[rnode->node_id] = LogSumExp(
|
|
||||||
alpha[rnode->node_id], theta * lnode->score + alpha[lnode->node_id],
|
|
||||||
lnode == end_nodes_[pos][0]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return alpha;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<float> Lattice::BackwardAlgorithm(float theta) const {
|
|
||||||
const int len = size();
|
|
||||||
std::vector<float> beta(node_allocator_.size(), 0.0);
|
|
||||||
|
|
||||||
for (int pos = len; pos >= 0; --pos) {
|
|
||||||
for (Node *lnode : end_nodes_[pos]) {
|
|
||||||
for (Node *rnode : begin_nodes_[pos]) {
|
|
||||||
beta[lnode->node_id] =
|
|
||||||
LogSumExp(beta[lnode->node_id], rnode->score + beta[rnode->node_id],
|
|
||||||
rnode == begin_nodes_[pos][0]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return beta;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
float Lattice::PopulateMarginal(float freq,
|
float Lattice::PopulateMarginal(float freq,
|
||||||
@ -238,9 +189,28 @@ float Lattice::PopulateMarginal(float freq,
|
|||||||
|
|
||||||
// alpha and beta (accumulative log prob) in Forward Backward.
|
// alpha and beta (accumulative log prob) in Forward Backward.
|
||||||
// the index of alpha/beta is Node::node_id.
|
// the index of alpha/beta is Node::node_id.
|
||||||
|
std::vector<float> alpha(node_allocator_.size(), 0.0);
|
||||||
|
std::vector<float> beta(node_allocator_.size(), 0.0);
|
||||||
|
|
||||||
const auto alpha = ForwardAlgorithm(1.0);
|
for (int pos = 0; pos <= len; ++pos) {
|
||||||
const auto beta = BackwardAlgorithm(1.0);
|
for (Node *rnode : begin_nodes_[pos]) {
|
||||||
|
for (Node *lnode : end_nodes_[pos]) {
|
||||||
|
alpha[rnode->node_id] = LogSumExp(alpha[rnode->node_id],
|
||||||
|
lnode->score + alpha[lnode->node_id],
|
||||||
|
lnode == end_nodes_[pos][0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int pos = len; pos >= 0; --pos) {
|
||||||
|
for (Node *lnode : end_nodes_[pos]) {
|
||||||
|
for (Node *rnode : begin_nodes_[pos]) {
|
||||||
|
beta[lnode->node_id] =
|
||||||
|
LogSumExp(beta[lnode->node_id], rnode->score + beta[rnode->node_id],
|
||||||
|
rnode == begin_nodes_[pos][0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const float Z = alpha[begin_nodes_[len][0]->node_id];
|
const float Z = alpha[begin_nodes_[len][0]->node_id];
|
||||||
for (int pos = 0; pos < len; ++pos) {
|
for (int pos = 0; pos < len; ++pos) {
|
||||||
@ -258,46 +228,13 @@ float Lattice::PopulateMarginal(float freq,
|
|||||||
return freq * Z;
|
return freq * Z;
|
||||||
}
|
}
|
||||||
|
|
||||||
float Lattice::CalculateEntropy(float theta) const {
|
std::vector<std::vector<Lattice::Node *>> Lattice::NBest(size_t nbest_size) {
|
||||||
const int len = size();
|
|
||||||
|
|
||||||
// alpha[node_id] is the marginal prob of sequence up to start of node
|
|
||||||
// H is entropy of sequence
|
|
||||||
// the index of alpha/H is Node::node_id.
|
|
||||||
std::vector<float> alpha(node_allocator_.size(), 0.0);
|
|
||||||
std::vector<float> H(node_allocator_.size(), 0.0);
|
|
||||||
|
|
||||||
// Populate the forward marginals to get the normalising constant
|
|
||||||
alpha = ForwardAlgorithm(theta);
|
|
||||||
|
|
||||||
// Now populate the forward entropies
|
|
||||||
for (int pos = 0; pos <= len; ++pos) {
|
|
||||||
for (Node *rnode : begin_nodes_[pos]) {
|
|
||||||
for (Node *lnode : end_nodes_[pos]) {
|
|
||||||
// Contribution each lnode makes = p(lnode) * (H(lnode) + log p(lnode))
|
|
||||||
|
|
||||||
// We have to normalise p(lnode) by the marginal contribution it makes
|
|
||||||
const float lnode_transition_prob =
|
|
||||||
((theta * lnode->score) + alpha[lnode->node_id] -
|
|
||||||
alpha[rnode->node_id]);
|
|
||||||
H[rnode->node_id] += std::exp(lnode_transition_prob) *
|
|
||||||
(H[lnode->node_id] + lnode_transition_prob);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return -H[begin_nodes_[len][0]->node_id];
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<Lattice::LatticePathWithScore> Lattice::NBest(size_t nbest_size,
|
|
||||||
bool sample,
|
|
||||||
float theta) {
|
|
||||||
if (nbest_size < 1) {
|
if (nbest_size < 1) {
|
||||||
LOG(WARNING) << "nbest_size >= 1. Returns empty result.";
|
LOG(WARNING) << "nbest_size >= 1. Returns empty result.";
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
if (nbest_size == 1 && !sample) {
|
if (nbest_size == 1) {
|
||||||
return {Viterbi()};
|
return {Viterbi()};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -306,7 +243,6 @@ std::vector<Lattice::LatticePathWithScore> Lattice::NBest(size_t nbest_size,
|
|||||||
// At each partial path x, compute f(x) as follows
|
// At each partial path x, compute f(x) as follows
|
||||||
// f(x) = g(x) + h(x).
|
// f(x) = g(x) + h(x).
|
||||||
// g(x): the sum of scores from EOS to the left-most node in x.
|
// g(x): the sum of scores from EOS to the left-most node in x.
|
||||||
// for a complete hypothesis, g(hyp) is the score of the hypothesis.
|
|
||||||
// h(x): a heuristic that estimates the largest score from x to BOS.
|
// h(x): a heuristic that estimates the largest score from x to BOS.
|
||||||
// f(x): the priority to pop a new hypothesis from the priority queue.
|
// f(x): the priority to pop a new hypothesis from the priority queue.
|
||||||
//
|
//
|
||||||
@ -332,27 +268,18 @@ std::vector<Lattice::LatticePathWithScore> Lattice::NBest(size_t nbest_size,
|
|||||||
model::FreeList<Hypothesis> hypothesis_allocator(kPreallocatedHypothesisSize);
|
model::FreeList<Hypothesis> hypothesis_allocator(kPreallocatedHypothesisSize);
|
||||||
|
|
||||||
Agenda agenda;
|
Agenda agenda;
|
||||||
std::vector<Lattice::LatticePathWithScore> results;
|
std::vector<std::vector<Node *>> results;
|
||||||
|
|
||||||
auto *eos = hypothesis_allocator.Allocate();
|
auto *eos = hypothesis_allocator.Allocate();
|
||||||
eos->node = eos_node();
|
eos->node = eos_node();
|
||||||
eos->next = nullptr;
|
eos->next = nullptr;
|
||||||
eos->gx = 0.0;
|
eos->fx = eos->node->score;
|
||||||
|
eos->gx = eos->node->score;
|
||||||
std::vector<float> alpha(node_allocator_.size(), 0.0);
|
|
||||||
|
|
||||||
if (sample) {
|
|
||||||
// Run forwards algorithm to get normalising constants
|
|
||||||
alpha = ForwardAlgorithm(theta);
|
|
||||||
// f(eos) = Gumbel(0), as it is the perturbed score of the entire lattice.
|
|
||||||
eos->fx = Gumbel();
|
|
||||||
} else {
|
|
||||||
// Run Viterbi first to fill backtrace score.
|
|
||||||
Viterbi();
|
|
||||||
eos->fx = eos->node->backtrace_score;
|
|
||||||
}
|
|
||||||
agenda.push(eos);
|
agenda.push(eos);
|
||||||
|
|
||||||
|
// Run Viterbi first to fill backtrace score.
|
||||||
|
Viterbi();
|
||||||
|
|
||||||
while (!agenda.empty()) {
|
while (!agenda.empty()) {
|
||||||
auto *top = agenda.top();
|
auto *top = agenda.top();
|
||||||
agenda.pop();
|
agenda.pop();
|
||||||
@ -362,56 +289,21 @@ std::vector<Lattice::LatticePathWithScore> Lattice::NBest(size_t nbest_size,
|
|||||||
if (node == bos_node()) {
|
if (node == bos_node()) {
|
||||||
results.resize(results.size() + 1);
|
results.resize(results.size() + 1);
|
||||||
for (auto *n = top->next; n->next != nullptr; n = n->next) {
|
for (auto *n = top->next; n->next != nullptr; n = n->next) {
|
||||||
results.back().first.push_back(n->node);
|
results.back().push_back(n->node);
|
||||||
}
|
}
|
||||||
results.back().second = top->fx;
|
|
||||||
if (results.size() == nbest_size) {
|
if (results.size() == nbest_size) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int end_nodes_size = end_nodes(node->pos).size();
|
|
||||||
std::vector<float> probs(end_nodes_size, 0.0);
|
|
||||||
std::vector<float> perturbed_probs(end_nodes_size, 0.0);
|
|
||||||
std::vector<double> adjusted_probs(end_nodes_size, 0.0);
|
|
||||||
const float Z = alpha[node->node_id];
|
|
||||||
if (sample) {
|
|
||||||
float max_score = -1e8;
|
|
||||||
// Calculate the marginal and perturbed scores for stochastic search
|
|
||||||
for (int i = 0; i < end_nodes(node->pos).size(); i++) {
|
|
||||||
Node *lnode = end_nodes(node->pos)[i];
|
|
||||||
// Calculate backwards transition score
|
|
||||||
probs[i] = top->gx + alpha[lnode->node_id] + (theta * lnode->score) - Z;
|
|
||||||
perturbed_probs[i] = probs[i] + Gumbel();
|
|
||||||
if (perturbed_probs[i] > max_score) {
|
|
||||||
max_score = perturbed_probs[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Now constrain the sampled continuations to match the score of parent
|
|
||||||
for (int i = 0; i < adjusted_probs.size(); i++) {
|
|
||||||
// Use numerically stable version of truncated Gumbel:
|
|
||||||
// https://arxiv.org/pdf/1903.06059.pdf appendix B.3
|
|
||||||
const float v = top->fx - perturbed_probs[i] +
|
|
||||||
std::log1p(-std::exp(perturbed_probs[i] - max_score));
|
|
||||||
adjusted_probs[i] = top->fx - std::max(static_cast<float>(0.0), v) -
|
|
||||||
std::log1p(std::exp(-std::abs(v)));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Expands new node ending at node->pos
|
// Expands new node ending at node->pos
|
||||||
for (int i = 0; i < end_nodes(node->pos).size(); i++) {
|
for (Node *lnode : end_nodes(node->pos)) {
|
||||||
Node *lnode = end_nodes(node->pos)[i];
|
|
||||||
auto *hyp = hypothesis_allocator.Allocate();
|
auto *hyp = hypothesis_allocator.Allocate();
|
||||||
hyp->node = lnode;
|
hyp->node = lnode;
|
||||||
if (sample) {
|
hyp->gx = lnode->score + top->gx; // just adds node->score
|
||||||
hyp->gx = probs[i];
|
hyp->fx =
|
||||||
hyp->fx = adjusted_probs[i];
|
lnode->backtrace_score + top->gx; // backtrace_score is h(node).
|
||||||
} else {
|
|
||||||
hyp->gx = lnode->score + top->gx; // just adds node->score
|
|
||||||
hyp->fx =
|
|
||||||
lnode->backtrace_score + top->gx; // backtrace_score is h(node).
|
|
||||||
}
|
|
||||||
hyp->next = top;
|
hyp->next = top;
|
||||||
agenda.push(hyp);
|
agenda.push(hyp);
|
||||||
}
|
}
|
||||||
@ -443,7 +335,15 @@ std::vector<Lattice::Node *> Lattice::Sample(float theta) {
|
|||||||
|
|
||||||
std::vector<float> alpha(node_allocator_.size(), 0.0);
|
std::vector<float> alpha(node_allocator_.size(), 0.0);
|
||||||
|
|
||||||
alpha = ForwardAlgorithm(theta);
|
for (int pos = 0; pos <= len; ++pos) {
|
||||||
|
for (Node *rnode : begin_nodes_[pos]) {
|
||||||
|
for (Node *lnode : end_nodes_[pos]) {
|
||||||
|
alpha[rnode->node_id] = LogSumExp(
|
||||||
|
alpha[rnode->node_id], theta * lnode->score + alpha[lnode->node_id],
|
||||||
|
lnode == end_nodes_[pos][0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
auto *mt = random::GetRandomGenerator();
|
auto *mt = random::GetRandomGenerator();
|
||||||
|
|
||||||
@ -614,7 +514,7 @@ EncodeResult Model::Encode(absl::string_view normalized) const {
|
|||||||
PopulateNodes(&lattice);
|
PopulateNodes(&lattice);
|
||||||
|
|
||||||
EncodeResult results;
|
EncodeResult results;
|
||||||
for (const auto *node : lattice.Viterbi().first) {
|
for (const auto *node : lattice.Viterbi()) {
|
||||||
results.emplace_back(node->piece, node->id);
|
results.emplace_back(node->piece, node->id);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -634,12 +534,14 @@ NBestEncodeResult Model::NBestEncode(absl::string_view normalized,
|
|||||||
PopulateNodes(&lattice);
|
PopulateNodes(&lattice);
|
||||||
|
|
||||||
NBestEncodeResult nbest_results;
|
NBestEncodeResult nbest_results;
|
||||||
for (const auto &nbest : lattice.NBest(nbest_size, false, 0.0)) {
|
for (const auto &nbest : lattice.NBest(nbest_size)) {
|
||||||
EncodeResult results;
|
EncodeResult results;
|
||||||
for (const auto *node : nbest.first) {
|
float score = 0.0;
|
||||||
|
for (const auto *node : nbest) {
|
||||||
|
score += node->score;
|
||||||
results.emplace_back(node->piece, node->id);
|
results.emplace_back(node->piece, node->id);
|
||||||
}
|
}
|
||||||
nbest_results.emplace_back(results, nbest.second);
|
nbest_results.emplace_back(results, score);
|
||||||
}
|
}
|
||||||
|
|
||||||
return nbest_results;
|
return nbest_results;
|
||||||
@ -663,123 +565,6 @@ EncodeResult Model::SampleEncode(absl::string_view normalized,
|
|||||||
return results;
|
return results;
|
||||||
}
|
}
|
||||||
|
|
||||||
NBestEncodeResult Model::SampleEncodeAndScore(absl::string_view normalized,
|
|
||||||
float theta, int samples,
|
|
||||||
bool wor,
|
|
||||||
bool include_best) const {
|
|
||||||
if (!status().ok() || normalized.empty()) {
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
NBestEncodeResult results;
|
|
||||||
Lattice lattice;
|
|
||||||
lattice.SetSentence(normalized);
|
|
||||||
PopulateNodes(&lattice);
|
|
||||||
|
|
||||||
std::vector<float> alpha = lattice.ForwardAlgorithm(theta);
|
|
||||||
float marginal = alpha[lattice.eos_node()->node_id];
|
|
||||||
|
|
||||||
if (include_best) {
|
|
||||||
if (!wor) {
|
|
||||||
LOG(FATAL) << "include_best not supported for wor false";
|
|
||||||
}
|
|
||||||
EncodeResult result;
|
|
||||||
Lattice::LatticePathWithScore best_path = lattice.Viterbi();
|
|
||||||
|
|
||||||
for (const auto *node : best_path.first) {
|
|
||||||
result.emplace_back(node->piece, node->id);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Inclusion probability if we always include the best is 1.
|
|
||||||
results.emplace_back(result, 0.0);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (wor) {
|
|
||||||
// Draw k+1 samples as we need perturbed score of k+1th element
|
|
||||||
std::vector<Lattice::LatticePathWithScore> nbest_samples =
|
|
||||||
lattice.NBest(samples + 1, true, theta);
|
|
||||||
|
|
||||||
if (include_best) {
|
|
||||||
std::vector<std::vector<Lattice::Node *>> nbest_paths(
|
|
||||||
nbest_samples.size());
|
|
||||||
for (int i = 0; i < nbest_samples.size(); i++) {
|
|
||||||
nbest_paths[i] = nbest_samples[i].first;
|
|
||||||
}
|
|
||||||
// Remove the best result from the samples if necessary
|
|
||||||
Lattice::LatticePathWithScore best_path = lattice.Viterbi();
|
|
||||||
|
|
||||||
const int index_of_best =
|
|
||||||
(std::find(nbest_paths.begin(), nbest_paths.end(), best_path.first) -
|
|
||||||
nbest_paths.begin());
|
|
||||||
|
|
||||||
if (index_of_best != nbest_samples.size()) {
|
|
||||||
LOG(INFO) << "removing best path from samples";
|
|
||||||
nbest_samples.erase(nbest_samples.begin() + index_of_best);
|
|
||||||
} else {
|
|
||||||
nbest_samples.pop_back();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// We use the perturbed score of the k+1th element to calculate the
|
|
||||||
// inclusion probability.
|
|
||||||
const double kappa = static_cast<double>(nbest_samples.back().second);
|
|
||||||
// Discard the last sample
|
|
||||||
nbest_samples.pop_back();
|
|
||||||
for (const auto &nbest : nbest_samples) {
|
|
||||||
EncodeResult result;
|
|
||||||
float score = 0.0;
|
|
||||||
|
|
||||||
for (const auto *node : nbest.first) {
|
|
||||||
score += (theta * node->score);
|
|
||||||
result.emplace_back(node->piece, node->id);
|
|
||||||
}
|
|
||||||
|
|
||||||
results.emplace_back(result, score - marginal);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now calculate the inclusion probability
|
|
||||||
for (auto &it : results) {
|
|
||||||
// Only modify non best sample inclusion probabilities.
|
|
||||||
if (it.second != 0.0) {
|
|
||||||
double x = it.second - kappa;
|
|
||||||
double y = std::exp(x);
|
|
||||||
double inclusion_prob;
|
|
||||||
if (x <= -10) {
|
|
||||||
// Series expansion of the log Gumbel survival function up to eps.
|
|
||||||
inclusion_prob =
|
|
||||||
x - (y / 2) + (std::pow(y, 2) / 24) - std::pow(y, 4) / 2880;
|
|
||||||
} else {
|
|
||||||
inclusion_prob = std::log(-std::expm1(-y));
|
|
||||||
}
|
|
||||||
it.second = static_cast<float>(inclusion_prob);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
while (results.size() < samples) {
|
|
||||||
Lattice lattice;
|
|
||||||
lattice.SetSentence(normalized);
|
|
||||||
PopulateNodes(&lattice);
|
|
||||||
|
|
||||||
float score = 0.0;
|
|
||||||
EncodeResult result;
|
|
||||||
std::vector<Lattice::Node *> sample = lattice.Sample(theta);
|
|
||||||
for (const auto *node : sample) {
|
|
||||||
result.emplace_back(node->piece, node->id);
|
|
||||||
score += (theta * node->score);
|
|
||||||
}
|
|
||||||
results.emplace_back(result, score - marginal);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return results;
|
|
||||||
}
|
|
||||||
|
|
||||||
float Model::CalculateEntropy(absl::string_view normalized, float theta) const {
|
|
||||||
Lattice lattice;
|
|
||||||
lattice.SetSentence(normalized);
|
|
||||||
PopulateNodes(&lattice);
|
|
||||||
|
|
||||||
return lattice.CalculateEntropy(theta);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool Model::VerifyOutputsEquivalent(absl::string_view expected,
|
bool Model::VerifyOutputsEquivalent(absl::string_view expected,
|
||||||
absl::string_view actual) const {
|
absl::string_view actual) const {
|
||||||
auto compute_unigram_model_score =
|
auto compute_unigram_model_score =
|
||||||
|
@ -82,28 +82,17 @@ class Lattice {
|
|||||||
// After calling this method, The caller must set Node::score and Node::id.
|
// After calling this method, The caller must set Node::score and Node::id.
|
||||||
Node *Insert(int pos, int length);
|
Node *Insert(int pos, int length);
|
||||||
|
|
||||||
using LatticePathWithScore = std::pair<std::vector<Node *>, float>;
|
|
||||||
|
|
||||||
// Returns Viterbi path. All nodes must be populated in advance.
|
// Returns Viterbi path. All nodes must be populated in advance.
|
||||||
LatticePathWithScore Viterbi();
|
std::vector<Node *> Viterbi();
|
||||||
|
|
||||||
// Runs forwards/backwards algorithm, returns vector with normalised
|
|
||||||
// transition probs.
|
|
||||||
std::vector<float> ForwardAlgorithm(float theta) const;
|
|
||||||
std::vector<float> BackwardAlgorithm(float theta) const;
|
|
||||||
|
|
||||||
// Returns n-best results.
|
// Returns n-best results.
|
||||||
std::vector<LatticePathWithScore> NBest(size_t nbest_size, bool sample,
|
std::vector<std::vector<Node *>> NBest(size_t nbest_size);
|
||||||
float theta);
|
|
||||||
|
|
||||||
// Samples one path from the lattice according to the
|
// Samples one path from the lattice according to the
|
||||||
// generation probability (Product of piece probabilities).
|
// generation probability (Product of piece probabilities).
|
||||||
// `theta` is a smoothing parameter.
|
// `theta` is a smoothing parameter.
|
||||||
std::vector<Node *> Sample(float theta);
|
std::vector<Node *> Sample(float theta);
|
||||||
|
|
||||||
// Calculates the entropy of the lattice.
|
|
||||||
float CalculateEntropy(float theta) const;
|
|
||||||
|
|
||||||
// Populates marginal probability of every node in this lattice.
|
// Populates marginal probability of every node in this lattice.
|
||||||
// |freq| is the frequency of the sentence.
|
// |freq| is the frequency of the sentence.
|
||||||
// for (auto *node : all_nodes_) {
|
// for (auto *node : all_nodes_) {
|
||||||
@ -138,19 +127,8 @@ class Model : public ModelInterface {
|
|||||||
EncodeResult SampleEncode(absl::string_view normalized,
|
EncodeResult SampleEncode(absl::string_view normalized,
|
||||||
float theta) const override;
|
float theta) const override;
|
||||||
|
|
||||||
NBestEncodeResult SampleEncodeAndScore(absl::string_view normalized,
|
|
||||||
float theta, int samples, bool wor,
|
|
||||||
bool include_best) const override;
|
|
||||||
|
|
||||||
float CalculateEntropy(absl::string_view normalized,
|
|
||||||
float theta) const override;
|
|
||||||
|
|
||||||
bool IsSampleEncodeAvailable() const override { return true; }
|
bool IsSampleEncodeAvailable() const override { return true; }
|
||||||
|
|
||||||
bool IsSampleEncodeAndScoreAvailable() const override { return true; }
|
|
||||||
|
|
||||||
bool IsCalculateEntropyAvailable() const override { return true; }
|
|
||||||
|
|
||||||
bool IsNBestEncodeAvailable() const override { return true; }
|
bool IsNBestEncodeAvailable() const override { return true; }
|
||||||
|
|
||||||
// Returns the minimum score in sentence pieces.
|
// Returns the minimum score in sentence pieces.
|
||||||
|
@ -161,11 +161,11 @@ TEST(LatticeTest, InsertTest) {
|
|||||||
TEST(LatticeTest, ViterbiFromIncompleteLatticeTest) {
|
TEST(LatticeTest, ViterbiFromIncompleteLatticeTest) {
|
||||||
Lattice lattice;
|
Lattice lattice;
|
||||||
lattice.SetSentence("ABC");
|
lattice.SetSentence("ABC");
|
||||||
EXPECT_TRUE(lattice.Viterbi().first.empty());
|
EXPECT_TRUE(lattice.Viterbi().empty());
|
||||||
|
|
||||||
// Still incomplete
|
// Still incomplete
|
||||||
lattice.Insert(0, 1);
|
lattice.Insert(0, 1);
|
||||||
EXPECT_TRUE(lattice.Viterbi().first.empty());
|
EXPECT_TRUE(lattice.Viterbi().empty());
|
||||||
|
|
||||||
lattice.Insert(1, 1);
|
lattice.Insert(1, 1);
|
||||||
lattice.Insert(2, 1);
|
lattice.Insert(2, 1);
|
||||||
@ -198,16 +198,16 @@ TEST(LatticeTest, ViterbiTest) {
|
|||||||
InsertWithScore(&lattice, 0, 1, 0.0); // A
|
InsertWithScore(&lattice, 0, 1, 0.0); // A
|
||||||
InsertWithScore(&lattice, 1, 1, 0.0); // B
|
InsertWithScore(&lattice, 1, 1, 0.0); // B
|
||||||
InsertWithScore(&lattice, 2, 1, 0.0); // C
|
InsertWithScore(&lattice, 2, 1, 0.0); // C
|
||||||
EXPECT_EQ("A B C", GetTokenized(lattice.Viterbi().first));
|
EXPECT_EQ("A B C", GetTokenized(lattice.Viterbi()));
|
||||||
|
|
||||||
InsertWithScore(&lattice, 0, 2, 2.0); // AB
|
InsertWithScore(&lattice, 0, 2, 2.0); // AB
|
||||||
EXPECT_EQ("AB C", GetTokenized(lattice.Viterbi().first));
|
EXPECT_EQ("AB C", GetTokenized(lattice.Viterbi()));
|
||||||
|
|
||||||
InsertWithScore(&lattice, 1, 2, 5.0); // BC
|
InsertWithScore(&lattice, 1, 2, 5.0); // BC
|
||||||
EXPECT_EQ("A BC", GetTokenized(lattice.Viterbi().first));
|
EXPECT_EQ("A BC", GetTokenized(lattice.Viterbi()));
|
||||||
|
|
||||||
InsertWithScore(&lattice, 0, 3, 10.0); // ABC
|
InsertWithScore(&lattice, 0, 3, 10.0); // ABC
|
||||||
EXPECT_EQ("ABC", GetTokenized(lattice.Viterbi().first));
|
EXPECT_EQ("ABC", GetTokenized(lattice.Viterbi()));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(LatticeTest, NBestTest) {
|
TEST(LatticeTest, NBestTest) {
|
||||||
@ -221,174 +221,21 @@ TEST(LatticeTest, NBestTest) {
|
|||||||
InsertWithScore(&lattice, 1, 2, 5.0); // BC
|
InsertWithScore(&lattice, 1, 2, 5.0); // BC
|
||||||
InsertWithScore(&lattice, 0, 3, 10.0); // ABC
|
InsertWithScore(&lattice, 0, 3, 10.0); // ABC
|
||||||
|
|
||||||
auto nbests = lattice.NBest(10, false, 0.0);
|
auto nbests = lattice.NBest(10);
|
||||||
EXPECT_EQ(4, nbests.size());
|
EXPECT_EQ(4, nbests.size());
|
||||||
|
|
||||||
EXPECT_EQ("ABC", GetTokenized(nbests[0].first));
|
EXPECT_EQ("ABC", GetTokenized(nbests[0]));
|
||||||
EXPECT_EQ("A BC", GetTokenized(nbests[1].first));
|
EXPECT_EQ("A BC", GetTokenized(nbests[1]));
|
||||||
EXPECT_EQ("AB C", GetTokenized(nbests[2].first));
|
EXPECT_EQ("AB C", GetTokenized(nbests[2]));
|
||||||
EXPECT_EQ("A B C", GetTokenized(nbests[3].first));
|
EXPECT_EQ("A B C", GetTokenized(nbests[3]));
|
||||||
|
|
||||||
auto nbests0 = lattice.NBest(0, false, 0.0);
|
auto nbests0 = lattice.NBest(0);
|
||||||
EXPECT_TRUE(nbests0.empty());
|
EXPECT_TRUE(nbests0.empty());
|
||||||
|
|
||||||
auto nbests1 = lattice.NBest(1, false, 0.0);
|
auto nbests1 = lattice.NBest(1);
|
||||||
EXPECT_EQ(nbests1.size(), 1);
|
EXPECT_EQ(nbests1.size(), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(LatticeTest, NBestSampleTest) {
|
|
||||||
Lattice lattice;
|
|
||||||
lattice.SetSentence("ABC");
|
|
||||||
|
|
||||||
InsertWithScore(&lattice, 0, 1, 0.0); // A
|
|
||||||
InsertWithScore(&lattice, 1, 1, 0.0); // B
|
|
||||||
InsertWithScore(&lattice, 2, 1, 0.1); // C
|
|
||||||
InsertWithScore(&lattice, 0, 2, 0.2); // AB
|
|
||||||
InsertWithScore(&lattice, 1, 2, 0.5); // BC
|
|
||||||
InsertWithScore(&lattice, 0, 3, 1.0); // ABC
|
|
||||||
|
|
||||||
// Calculate expected probabilities of each path
|
|
||||||
// Note that sampling without replacement affects the expected frequencies!
|
|
||||||
const std::vector<double> kTheta = {0.0, 0.01, 0.5, 0.7, 1.0};
|
|
||||||
for (const auto theta : kTheta) {
|
|
||||||
std::vector<std::string> strings = {"ABC", "AB C", "A BC", "A B C"};
|
|
||||||
std::map<std::string, float> probs;
|
|
||||||
probs["ABC"] = std::exp(theta * 1.0);
|
|
||||||
probs["AB C"] = std::exp(theta * (0.2 + 0.1));
|
|
||||||
probs["A BC"] = std::exp(theta * (0.0 + 0.5));
|
|
||||||
probs["A B C"] = std::exp(theta * (0.0 + 0.0 + 0.1));
|
|
||||||
|
|
||||||
for (const auto &it : strings) {
|
|
||||||
EXPECT_EQ(1, probs.count(it));
|
|
||||||
}
|
|
||||||
|
|
||||||
double Z = 0.0;
|
|
||||||
for (const auto &it : probs) Z += it.second;
|
|
||||||
for (auto &it : probs) it.second /= Z;
|
|
||||||
|
|
||||||
std::map<std::pair<std::string, std::string>, float> pair_probs;
|
|
||||||
for (const auto first : strings) {
|
|
||||||
for (const auto second : strings) {
|
|
||||||
if (first == second) {
|
|
||||||
pair_probs[std::make_pair(first, second)] = 0;
|
|
||||||
} else {
|
|
||||||
float first_prob = probs[first];
|
|
||||||
float second_prob = probs[second] / (1 - first_prob);
|
|
||||||
pair_probs[std::make_pair(first, second)] = first_prob * second_prob;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::map<std::string, float> inclusion_probs;
|
|
||||||
for (const auto string : strings) {
|
|
||||||
float inclusion_prob = 0.0;
|
|
||||||
for (const auto other_string : strings) {
|
|
||||||
inclusion_prob += pair_probs[std::make_pair(string, other_string)];
|
|
||||||
}
|
|
||||||
for (const auto other_string : strings) {
|
|
||||||
inclusion_prob += pair_probs[std::make_pair(other_string, string)];
|
|
||||||
}
|
|
||||||
inclusion_probs[string] = inclusion_prob / 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
int kTrials = 10000;
|
|
||||||
|
|
||||||
std::vector<int> kNumSamples = {1, 2};
|
|
||||||
|
|
||||||
for (const auto num_samples : kNumSamples) {
|
|
||||||
std::map<std::string, int> counts;
|
|
||||||
for (int i = 0; i < kTrials; i++) {
|
|
||||||
auto nbests = lattice.NBest(num_samples, true, theta);
|
|
||||||
for (const auto nbest : nbests) {
|
|
||||||
counts[GetTokenized(nbest.first)]++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
EXPECT_EQ(inclusion_probs.size(), counts.size());
|
|
||||||
// If we take multiple samples WOR, we have to use corrected probs.
|
|
||||||
std::map<std::string, float> probs_to_use =
|
|
||||||
(num_samples == 1 ? probs : inclusion_probs);
|
|
||||||
|
|
||||||
for (const auto &it : probs_to_use) {
|
|
||||||
EXPECT_NEAR(it.second, 1.0 * counts[it.first] / (kTrials * num_samples),
|
|
||||||
0.02);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(LatticeTest, CalculateEntropyTest) {
|
|
||||||
Lattice lattice;
|
|
||||||
lattice.SetSentence("ABC");
|
|
||||||
|
|
||||||
InsertWithScore(&lattice, 0, 1, 0.0); // A
|
|
||||||
InsertWithScore(&lattice, 1, 1, 0.0); // B
|
|
||||||
InsertWithScore(&lattice, 2, 1, 0.1); // C
|
|
||||||
InsertWithScore(&lattice, 0, 2, 0.2); // AB
|
|
||||||
InsertWithScore(&lattice, 1, 2, 0.5); // BC
|
|
||||||
InsertWithScore(&lattice, 0, 3, 1.0); // ABC
|
|
||||||
|
|
||||||
// Calculate expected probabilities of each path
|
|
||||||
const std::vector<double> kTheta = {0.0, 0.01, 0.5, 0.7, 1.0};
|
|
||||||
for (const auto theta : kTheta) {
|
|
||||||
std::vector<std::string> strings = {"ABC", "AB C", "A BC", "A B C"};
|
|
||||||
std::map<std::string, float> probs;
|
|
||||||
probs["ABC"] = std::exp(theta * 1.0);
|
|
||||||
probs["AB C"] = std::exp(theta * (0.2 + 0.1));
|
|
||||||
probs["A BC"] = std::exp(theta * (0.0 + 0.5));
|
|
||||||
probs["A B C"] = std::exp(theta * (0.0 + 0.0 + 0.1));
|
|
||||||
|
|
||||||
double Z = 0.0;
|
|
||||||
for (const auto &it : probs) Z += it.second;
|
|
||||||
for (auto &it : probs) it.second /= Z;
|
|
||||||
|
|
||||||
for (const auto &it : strings) {
|
|
||||||
EXPECT_EQ(1, probs.count(it));
|
|
||||||
}
|
|
||||||
float entropy = 0.0;
|
|
||||||
for (const auto &it : probs) {
|
|
||||||
entropy += (it.second * std::log(it.second));
|
|
||||||
}
|
|
||||||
EXPECT_NEAR(-entropy, lattice.CalculateEntropy(theta), 0.02);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(LatticeTest, ForwardAlgorithmTest) {
|
|
||||||
Lattice lattice;
|
|
||||||
lattice.SetSentence("ABC");
|
|
||||||
|
|
||||||
InsertWithScore(&lattice, 0, 1, 0.0); // A
|
|
||||||
InsertWithScore(&lattice, 1, 1, 0.0); // B
|
|
||||||
InsertWithScore(&lattice, 2, 1, 0.1); // C
|
|
||||||
InsertWithScore(&lattice, 0, 2, 0.2); // AB
|
|
||||||
InsertWithScore(&lattice, 1, 2, 0.5); // BC
|
|
||||||
InsertWithScore(&lattice, 0, 3, 1.0); // ABC
|
|
||||||
|
|
||||||
const std::vector<float> kTheta = {0.0, 0.01, 0.5, 0.7, 1.0};
|
|
||||||
for (const auto theta : kTheta) {
|
|
||||||
std::vector<float> alpha = lattice.ForwardAlgorithm(theta);
|
|
||||||
EXPECT_EQ(alpha.size(), 8); // 6 nodes, plus BOS, EOS
|
|
||||||
// only alpha[C], alpha[EOS] have non-zero alpha
|
|
||||||
for (int i : {0, 1, 2, 3}) {
|
|
||||||
for (const auto &node : lattice.begin_nodes(i)) {
|
|
||||||
if (i < 2) {
|
|
||||||
EXPECT_EQ(alpha[node->node_id], 0.0);
|
|
||||||
} else if (i == 2) {
|
|
||||||
float Z =
|
|
||||||
std::log(std::exp(theta * (0.0 + 0.0)) + std::exp(theta * 0.2));
|
|
||||||
EXPECT_EQ(alpha[node->node_id], Z);
|
|
||||||
} else if (i == 3) {
|
|
||||||
float Z = std::log(std::exp(theta * (0.0 + 0.0 + 0.1)) + // A + B + C
|
|
||||||
std::exp(theta * (0.2 + 0.1)) + // AB + C
|
|
||||||
std::exp(theta * (0.0 + 0.5)) + // A + BC
|
|
||||||
std::exp(theta * 1.0)); // ABC
|
|
||||||
EXPECT_EQ(Z, alpha[node->node_id]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(LatticeTest, PopulateMarginalTest) {
|
TEST(LatticeTest, PopulateMarginalTest) {
|
||||||
Lattice lattice;
|
Lattice lattice;
|
||||||
lattice.SetSentence("ABC");
|
lattice.SetSentence("ABC");
|
||||||
@ -514,102 +361,6 @@ TEST(UnigramModelTest, SetUnigramModelTest) {
|
|||||||
model.model_proto().SerializeAsString());
|
model.model_proto().SerializeAsString());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(UnigramModelTest, SampleEncodeAndScoreTest) {
|
|
||||||
// Test whether inclusion probabilities are correct
|
|
||||||
ModelProto model_proto = MakeBaseModelProto();
|
|
||||||
AddPiece(&model_proto, "A", 0.0); // 3
|
|
||||||
AddPiece(&model_proto, "B", 0.0); // 4
|
|
||||||
AddPiece(&model_proto, "C", 0.1); // 5
|
|
||||||
AddPiece(&model_proto, "AB", 0.2); // 6
|
|
||||||
AddPiece(&model_proto, "BC", 0.5); // 7
|
|
||||||
AddPiece(&model_proto, "ABC", 1.0); // 8
|
|
||||||
|
|
||||||
Model model(model_proto);
|
|
||||||
|
|
||||||
Lattice lattice;
|
|
||||||
lattice.SetSentence("ABC");
|
|
||||||
model.PopulateNodes(&lattice);
|
|
||||||
|
|
||||||
std::vector<float> kTheta = {0.0, 1.0};
|
|
||||||
|
|
||||||
for (const auto theta : kTheta) {
|
|
||||||
std::vector<std::string> strings = {"ABC", "AB C", "A BC", "A B C"};
|
|
||||||
std::map<std::string, float> probs;
|
|
||||||
probs["ABC"] = std::exp(theta * 1.0);
|
|
||||||
probs["AB C"] = std::exp(theta * (0.2 + 0.1));
|
|
||||||
probs["A BC"] = std::exp(theta * (0.0 + 0.5));
|
|
||||||
probs["A B C"] = std::exp(theta * (0.0 + 0.0 + 0.1));
|
|
||||||
|
|
||||||
for (const auto &it : strings) {
|
|
||||||
EXPECT_EQ(1, probs.count(it));
|
|
||||||
}
|
|
||||||
|
|
||||||
double Z = 0.0;
|
|
||||||
for (const auto &it : probs) Z += it.second;
|
|
||||||
for (auto &it : probs) it.second /= Z;
|
|
||||||
|
|
||||||
std::map<std::pair<std::string, std::string>, float> pair_probs;
|
|
||||||
for (const auto first : strings) {
|
|
||||||
for (const auto second : strings) {
|
|
||||||
if (first == second) {
|
|
||||||
pair_probs[std::make_pair(first, second)] = 0;
|
|
||||||
} else {
|
|
||||||
float first_prob = probs[first];
|
|
||||||
float second_prob = probs[second] / (1 - first_prob);
|
|
||||||
pair_probs[std::make_pair(first, second)] = first_prob * second_prob;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::map<std::string, float> inclusion_probs;
|
|
||||||
for (const auto string : strings) {
|
|
||||||
float inclusion_prob = 0.0;
|
|
||||||
for (const auto other_string : strings) {
|
|
||||||
inclusion_prob += pair_probs[std::make_pair(string, other_string)];
|
|
||||||
}
|
|
||||||
for (const auto other_string : strings) {
|
|
||||||
inclusion_prob += pair_probs[std::make_pair(other_string, string)];
|
|
||||||
}
|
|
||||||
inclusion_probs[string] = inclusion_prob / 2;
|
|
||||||
}
|
|
||||||
std::vector<int> kNumSamples = {1, 2};
|
|
||||||
|
|
||||||
for (const auto num_samples : kNumSamples) {
|
|
||||||
std::map<std::string, int> counts;
|
|
||||||
std::map<std::string, float> scores;
|
|
||||||
int kTrials = 50000;
|
|
||||||
for (int i = 0; i < kTrials; i++) {
|
|
||||||
NBestEncodeResult sample =
|
|
||||||
model.SampleEncodeAndScore("ABC", theta, num_samples, true, false);
|
|
||||||
|
|
||||||
for (const auto &it : sample) {
|
|
||||||
std::vector<std::string> tokens;
|
|
||||||
for (const auto &inner_it : it.first) {
|
|
||||||
tokens.push_back(std::string(inner_it.first));
|
|
||||||
}
|
|
||||||
std::string sample_string = absl::StrJoin(tokens, " ");
|
|
||||||
counts[sample_string] += 1;
|
|
||||||
// use the fact that E(1_{i in sample} / score of i) = 1
|
|
||||||
// see https://arxiv.org/pdf/1903.06059.pdf appendix D
|
|
||||||
scores[sample_string] += std::exp(-it.second);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check that counts and probs are correct
|
|
||||||
std::map<std::string, float> probs_to_use =
|
|
||||||
(num_samples == 1 ? probs : inclusion_probs);
|
|
||||||
|
|
||||||
for (const auto &it : scores) Z += it.second;
|
|
||||||
for (const auto &it : probs_to_use) {
|
|
||||||
EXPECT_NEAR(it.second, 1.0 * counts[it.first] / (kTrials * num_samples),
|
|
||||||
0.02);
|
|
||||||
// The expectation is quite loose, use a higher tolerance
|
|
||||||
EXPECT_NEAR(1.0, scores[it.first] / kTrials, 0.05);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_P(UnigramModelTest, PieceToIdTest) {
|
TEST_P(UnigramModelTest, PieceToIdTest) {
|
||||||
ModelProto model_proto = MakeBaseModelProto();
|
ModelProto model_proto = MakeBaseModelProto();
|
||||||
|
|
||||||
|
@ -223,7 +223,7 @@ std::vector<float> Trainer::RunEStep(const TrainerModel &model, float *obj,
|
|||||||
lattice.SetSentence(w);
|
lattice.SetSentence(w);
|
||||||
model.PopulateNodes(&lattice);
|
model.PopulateNodes(&lattice);
|
||||||
const float Z = lattice.PopulateMarginal(freq, &expected[n]);
|
const float Z = lattice.PopulateMarginal(freq, &expected[n]);
|
||||||
ntokens[n] += lattice.Viterbi().first.size();
|
ntokens[n] += lattice.Viterbi().size();
|
||||||
CHECK(!std::isnan(Z))
|
CHECK(!std::isnan(Z))
|
||||||
<< "likelihood is NAN. Input sentence may be too long";
|
<< "likelihood is NAN. Input sentence may be too long";
|
||||||
objs[n] -= Z / all_sentence_freq;
|
objs[n] -= Z / all_sentence_freq;
|
||||||
@ -297,17 +297,17 @@ TrainerModel::SentencePieces Trainer::PruneSentencePieces(
|
|||||||
const auto &w = sentencepieces[i];
|
const auto &w = sentencepieces[i];
|
||||||
lattice.SetSentence(w.first);
|
lattice.SetSentence(w.first);
|
||||||
model.PopulateNodes(&lattice);
|
model.PopulateNodes(&lattice);
|
||||||
const auto nbests = lattice.NBest(2, false, 0.0);
|
const auto nbests = lattice.NBest(2);
|
||||||
if (nbests.size() == 1) {
|
if (nbests.size() == 1) {
|
||||||
// No second-best result is found. always keep this sentencepiece.
|
// No second-best result is found. always keep this sentencepiece.
|
||||||
always_keep[i] = true;
|
always_keep[i] = true;
|
||||||
continue;
|
continue;
|
||||||
} else if (nbests[0].first.size() >= 2) {
|
} else if (nbests[0].size() >= 2) {
|
||||||
// Can safely remove this sentencepiece if its Viterbi path is split.
|
// Can safely remove this sentencepiece if its Viterbi path is split.
|
||||||
always_keep[i] = false;
|
always_keep[i] = false;
|
||||||
} else if (nbests[0].first.size() == 1) {
|
} else if (nbests[0].size() == 1) {
|
||||||
always_keep[i] = true;
|
always_keep[i] = true;
|
||||||
for (const auto *node : nbests[1].first) {
|
for (const auto *node : nbests[1]) {
|
||||||
alternatives[i].push_back(node->id);
|
alternatives[i].push_back(node->id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -339,7 +339,7 @@ TrainerModel::SentencePieces Trainer::PruneSentencePieces(
|
|||||||
lattice.SetSentence(w.first);
|
lattice.SetSentence(w.first);
|
||||||
model.PopulateNodes(&lattice);
|
model.PopulateNodes(&lattice);
|
||||||
vsums[n] += w.second;
|
vsums[n] += w.second;
|
||||||
for (const auto *node : lattice.Viterbi().first) {
|
for (const auto *node : lattice.Viterbi()) {
|
||||||
if (node->id >= 0) {
|
if (node->id >= 0) {
|
||||||
freqs[n][node->id] += w.second;
|
freqs[n][node->id] += w.second;
|
||||||
inverteds[n][node->id].push_back(i);
|
inverteds[n][node->id].push_back(i);
|
||||||
|
@ -12,12 +12,11 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.!
|
// limitations under the License.!
|
||||||
|
|
||||||
#include <iostream>
|
|
||||||
|
|
||||||
#include "util.h"
|
#include "util.h"
|
||||||
|
|
||||||
namespace sentencepiece {
|
#include <iostream>
|
||||||
|
|
||||||
|
namespace sentencepiece {
|
||||||
namespace {
|
namespace {
|
||||||
constexpr unsigned int kDefaultSeed = static_cast<unsigned int>(-1);
|
constexpr unsigned int kDefaultSeed = static_cast<unsigned int>(-1);
|
||||||
static unsigned int g_seed = kDefaultSeed;
|
static unsigned int g_seed = kDefaultSeed;
|
||||||
|
1
third_party/absl/flags/flag.cc
vendored
1
third_party/absl/flags/flag.cc
vendored
@ -171,7 +171,6 @@ void Flag<bool>::set_value_as_str(const std::string &value_as_str) {
|
|||||||
|
|
||||||
template class Flag<std::string>;
|
template class Flag<std::string>;
|
||||||
template class Flag<int32>;
|
template class Flag<int32>;
|
||||||
template class Flag<uint32>;
|
|
||||||
template class Flag<double>;
|
template class Flag<double>;
|
||||||
template class Flag<bool>;
|
template class Flag<bool>;
|
||||||
template class Flag<int64>;
|
template class Flag<int64>;
|
||||||
|
Loading…
Reference in New Issue
Block a user