mirror of
https://github.com/google/sentencepiece.git
synced 2024-10-26 11:38:45 +03:00
Fix the ULM training bugs
This commit is contained in:
parent
ba44ab1ca0
commit
bb0b610fae
@ -461,7 +461,7 @@ std::vector<Lattice::LatticePathWithScore> Lattice::NBest(size_t nbest_size,
|
||||
} else {
|
||||
hyp->gx = lnode->score + top->gx; // just adds node->score
|
||||
hyp->fx =
|
||||
lnode->backtrace_score + hyp->gx; // backtrace_score is h(node).
|
||||
lnode->backtrace_score + top->gx; // backtrace_score is h(node).
|
||||
}
|
||||
hyp->next = top;
|
||||
agenda.push(hyp);
|
||||
|
@ -28,7 +28,6 @@
|
||||
#include "pretokenizer_for_training.h"
|
||||
#include "sentencepiece_trainer.h"
|
||||
#include "third_party/absl/container/flat_hash_map.h"
|
||||
#include "third_party/absl/container/flat_hash_set.h"
|
||||
#include "third_party/absl/memory/memory.h"
|
||||
#include "third_party/absl/strings/str_replace.h"
|
||||
#include "third_party/absl/strings/str_split.h"
|
||||
@ -41,7 +40,6 @@ namespace unigram {
|
||||
namespace {
|
||||
|
||||
constexpr char32 kSentenceBoundary = 0x0000;
|
||||
constexpr char32 kWsMarker = 0x2581;
|
||||
|
||||
double Digamma(double x) {
|
||||
double result = 0.0;
|
||||
@ -67,34 +65,16 @@ void ToLogProb(IT begin, IT end) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<std::pair<const T *, const T *>> SplitBySentenceBoundary(
|
||||
const T *begin, const T *end) {
|
||||
std::vector<std::pair<const T *, const T *>> result;
|
||||
|
||||
while (begin < end) {
|
||||
const auto *p = std::find(begin, end, static_cast<T>(kSentenceBoundary));
|
||||
if (p != end) {
|
||||
result.emplace_back(begin, p);
|
||||
begin = p + 1;
|
||||
} else {
|
||||
result.emplace_back(begin, end);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
class BoundedPriorityQueue {
|
||||
public:
|
||||
explicit BoundedPriorityQueue(size_t size) : size_(size) {}
|
||||
~BoundedPriorityQueue() = default;
|
||||
|
||||
void push(const T &elem, int64 score) {
|
||||
void push(T elem, int64 score) {
|
||||
if (queue_.size() > 4 * size_) resize();
|
||||
if (queue_.size() >= size_ && queue_[size_ - 1].second > score) return;
|
||||
if (sorted && queue_.size() >= size_ && queue_[size_ - 1].second > score)
|
||||
return;
|
||||
queue_.emplace_back(elem, score);
|
||||
}
|
||||
|
||||
@ -109,16 +89,11 @@ class BoundedPriorityQueue {
|
||||
return (p1.second > p2.second ||
|
||||
(p1.second == p2.second && p1.first < p2.first));
|
||||
});
|
||||
|
||||
absl::flat_hash_set<absl::string_view> dup;
|
||||
std::vector<std::pair<T, int64>> new_queue;
|
||||
for (auto &p : queue_) {
|
||||
if (dup.insert(p.first).second) new_queue.emplace_back(std::move(p));
|
||||
if (new_queue.size() == size_) break;
|
||||
}
|
||||
queue_ = std::move(new_queue);
|
||||
sorted = true;
|
||||
if (queue_.size() > size_) queue_.resize(size_);
|
||||
}
|
||||
|
||||
bool sorted = false;
|
||||
size_t size_ = 0;
|
||||
std::vector<std::pair<T, int64>> queue_;
|
||||
};
|
||||
@ -249,8 +224,7 @@ TrainerModel::SentencePieces Trainer::MakeSeedSentencePiecesInternal() {
|
||||
n, kAlphabetSize, node_num));
|
||||
|
||||
LOG(INFO) << "Extracting frequent sub strings... node_num=" << node_num;
|
||||
|
||||
BoundedPriorityQueue<std::string> queue(
|
||||
BoundedPriorityQueue<node_int_type> queue(
|
||||
static_cast<size_t>(trainer_spec_.seed_sentencepiece_size()));
|
||||
|
||||
for (node_int_type i = 0; i < node_num; ++i) {
|
||||
@ -259,31 +233,21 @@ TrainerModel::SentencePieces Trainer::MakeSeedSentencePiecesInternal() {
|
||||
if (len <= 1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (const auto &p :
|
||||
SplitBySentenceBoundary(&array[offset], &array[offset + len])) {
|
||||
if (p.first == p.second) continue;
|
||||
const auto [begin, end] = NormalizeRange(p.first, p.second);
|
||||
|
||||
const char32 *begin = &array[offset];
|
||||
const char32 *end = &array[offset + len];
|
||||
// Skips if a substring contains a sentence boundary.
|
||||
if (std::find(begin, end, kSentenceBoundary) != end) {
|
||||
continue;
|
||||
}
|
||||
const UnicodeText uw(begin, end);
|
||||
if (uw.size() <= 1 || !IsValidSentencePiece(uw)) {
|
||||
if (!IsValidSentencePiece(uw)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// character-wise coverage is the default score.
|
||||
const node_int_type freq = R[i] - L[i];
|
||||
const node_int_type score = freq * len;
|
||||
|
||||
const auto w = string_util::UnicodeTextToUTF8(uw);
|
||||
queue.push(w, score);
|
||||
|
||||
const auto subpieces =
|
||||
SplitIntoWords(w, trainer_spec_.treat_whitespace_as_suffix(),
|
||||
trainer_spec_.allow_whitespace_only_pieces());
|
||||
if (subpieces.size() > 1) {
|
||||
for (const auto &s : subpieces) queue.push(std::string(s), score);
|
||||
}
|
||||
}
|
||||
queue.push(i, score);
|
||||
}
|
||||
|
||||
// all_chars must be included in the seed sentencepieces.
|
||||
@ -293,7 +257,16 @@ TrainerModel::SentencePieces Trainer::MakeSeedSentencePiecesInternal() {
|
||||
}
|
||||
|
||||
for (const auto &p : queue.get()) {
|
||||
seed_sentencepieces.emplace_back(p);
|
||||
const node_int_type offset = SA[L[p.first]];
|
||||
const node_int_type len = D[p.first];
|
||||
CHECK_GT(len, 0);
|
||||
const char32 *begin = &array[offset];
|
||||
const char32 *end = &array[offset + len];
|
||||
const UnicodeText uw(begin, end);
|
||||
const std::string w = string_util::UnicodeTextToUTF8(uw);
|
||||
CHECK(IsValidSentencePiece(uw)); // just in case.
|
||||
CHECK(!port::ContainsKey(all_chars, w));
|
||||
seed_sentencepieces.emplace_back(w, p.second);
|
||||
}
|
||||
|
||||
ToLogProb(seed_sentencepieces.begin(), seed_sentencepieces.end());
|
||||
@ -496,10 +469,10 @@ TrainerModel::SentencePieces Trainer::PruneSentencePieces(
|
||||
|
||||
// After removing the sentencepiece[i], its frequency freq[i] is
|
||||
// re-assigned to alternatives.
|
||||
// new_sum = current_sum - freq[i] + freq[i] * alternatives.size()
|
||||
// = current_sum + freq[i] (alternatives - 1)
|
||||
// new_sum = current_sum - freq[i] + freq[i] * alternatives[i].size()
|
||||
// = current_sum + freq[i] * (alternatives[i].size() - 1)
|
||||
const float logsum_alt = std::log(
|
||||
static_cast<double>(sum + freq[i] * (alternatives.size() - 1)));
|
||||
static_cast<double>(sum + freq[i] * (alternatives[i].size() - 1)));
|
||||
|
||||
// The frequencies of alternatives are increased by freq[i].
|
||||
float logprob_alt = 0.0;
|
||||
@ -530,22 +503,6 @@ TrainerModel::SentencePieces Trainer::PruneSentencePieces(
|
||||
return new_sentencepieces;
|
||||
}
|
||||
|
||||
std::pair<const char32 *, const char32 *> Trainer::NormalizeRange(
|
||||
const char32 *begin, const char32 *end) const {
|
||||
if (trainer_spec_.treat_whitespace_as_suffix()) {
|
||||
while ((*begin == kSentenceBoundary || *begin == kWsMarker) &&
|
||||
begin + 1 < end)
|
||||
++begin;
|
||||
while (*(end - 1) == kSentenceBoundary && begin + 1 < end) --end;
|
||||
} else {
|
||||
while (*begin == kSentenceBoundary && begin + 1 < end) ++begin;
|
||||
while ((*(end - 1) == kSentenceBoundary || *(end - 1) == kWsMarker) &&
|
||||
begin + 1 < end)
|
||||
--end;
|
||||
}
|
||||
return std::make_pair(begin, end);
|
||||
}
|
||||
|
||||
TrainerModel::SentencePieces Trainer::FinalizeSentencePieces(
|
||||
const TrainerModel &model) const {
|
||||
const auto &sentencepieces = model.GetSentencePieces();
|
||||
|
@ -105,9 +105,6 @@ class Trainer : public TrainerInterface {
|
||||
TrainerModel::SentencePieces FinalizeSentencePieces(
|
||||
const TrainerModel &model) const;
|
||||
|
||||
std::pair<const char32 *, const char32 *> NormalizeRange(
|
||||
const char32 *begin, const char32 *end) const;
|
||||
|
||||
// When the size of SentencePieces becomes less than desired_vocab_size_,
|
||||
// break the main training loop. desired_vocab_size_ = 1.1 * vocab_size_
|
||||
// for now.
|
||||
|
@ -117,12 +117,12 @@ TEST(UnigramTrainerTest, BasicTest) {
|
||||
30);
|
||||
|
||||
// Check seed pieces.
|
||||
EXPECT_EQ(63, res.seed_pieces_and_probs.size());
|
||||
EXPECT_EQ(27, res.seed_pieces_and_probs.size());
|
||||
|
||||
LOG(INFO) << "[" << res.sentence_pieces << "]";
|
||||
|
||||
// Check final pieces.
|
||||
EXPECT_EQ(
|
||||
"Overly Pineapple magnanimity Available ▁an ▁ a b A t g r P O v m y p n "
|
||||
"l h d e i",
|
||||
EXPECT_EQ("i a n y m l e apple ve O P r g t an v ▁ b A le ▁an p d h",
|
||||
res.sentence_pieces);
|
||||
}
|
||||
|
||||
@ -134,7 +134,8 @@ TEST(UnigramTrainerTest, BasicDPTest) {
|
||||
"Overly \t 6", "Available \t 5"},
|
||||
22, true /*use_dp*/, 0 /*dp_noise*/, 4 /*dp_clipping*/);
|
||||
|
||||
EXPECT_EQ(49, res.seed_pieces_and_probs.size());
|
||||
// Got 16 instead of 27 seeds.
|
||||
EXPECT_EQ(16, res.seed_pieces_and_probs.size());
|
||||
|
||||
// And they are equivalent to what we would get if the last sentence was not there.
|
||||
const auto& res_nodp = RunTrainer(
|
||||
@ -195,9 +196,9 @@ TEST(UnigramTrainerTest, EndToEndTest) {
|
||||
LOG(INFO) << "[" << absl::StrJoin(tok, " ") << std::endl;
|
||||
EXPECT_EQ(
|
||||
WS
|
||||
" 吾輩 《 わ が は い 》 は猫である 。 名前はまだ 無 い 。 どこ で 生 "
|
||||
"れた か とん と 見当 《 けん とう 》 が つか ぬ 。 何でも 薄 暗 い じめ "
|
||||
"じめ した 所で ニャーニャー 泣 い ていた 事 だけは 記憶 している 。",
|
||||
" 吾輩 《 わが はい 》 は猫である 。 名前はまだ 無 い 。 どこ で 生 れた "
|
||||
"か とん と 見当 《 けん とう 》 が つか ぬ 。 何でも 薄 暗 い じめ じめ "
|
||||
"した 所で ニャーニャー 泣 い ていた 事 だけは 記憶 している 。",
|
||||
absl::StrJoin(tok, " "));
|
||||
#endif
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user