Fix the ULM training bugs

Taku Kudo 2023-04-27 17:32:57 +00:00
parent ba44ab1ca0
commit bb0b610fae
4 changed files with 45 additions and 90 deletions

View File

@ -461,7 +461,7 @@ std::vector<Lattice::LatticePathWithScore> Lattice::NBest(size_t nbest_size,
     } else {
       hyp->gx = lnode->score + top->gx;  // just adds node->score
       hyp->fx =
-          lnode->backtrace_score + hyp->gx;  // backtrace_score is h(node).
+          lnode->backtrace_score + top->gx;  // backtrace_score is h(node).
     }
     hyp->next = top;
     agenda.push(hyp);
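Why the one-token change is a fix: the n-best search ranks a hypothesis by f = g + h. Here hyp->gx = lnode->score + top->gx is g of the extended path, and lnode->backtrace_score is h(lnode), the best score from lnode to the end of the lattice, which already includes lnode->score. The old expression therefore evaluated to h(lnode) + lnode->score + top->gx, counting lnode->score twice. With assumed toy scores lnode->score = -1, top->gx = -3, and h(lnode) = -2, the buggy fx is -6, while the intended f = g + h estimate, h(lnode) + top->gx, is -5.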

View File

@ -28,7 +28,6 @@
#include "pretokenizer_for_training.h"
#include "sentencepiece_trainer.h"
#include "third_party/absl/container/flat_hash_map.h"
#include "third_party/absl/container/flat_hash_set.h"
#include "third_party/absl/memory/memory.h"
#include "third_party/absl/strings/str_replace.h"
#include "third_party/absl/strings/str_split.h"
@ -41,7 +40,6 @@ namespace unigram {
 namespace {

 constexpr char32 kSentenceBoundary = 0x0000;
-constexpr char32 kWsMarker = 0x2581;

 double Digamma(double x) {
   double result = 0.0;
@ -67,34 +65,16 @@ void ToLogProb(IT begin, IT end) {
   }
 }

-template <typename T>
-std::vector<std::pair<const T *, const T *>> SplitBySentenceBoundary(
-    const T *begin, const T *end) {
-  std::vector<std::pair<const T *, const T *>> result;
-  while (begin < end) {
-    const auto *p = std::find(begin, end, static_cast<T>(kSentenceBoundary));
-    if (p != end) {
-      result.emplace_back(begin, p);
-      begin = p + 1;
-    } else {
-      result.emplace_back(begin, end);
-      break;
-    }
-  }
-  return result;
-}

 template <class T>
 class BoundedPriorityQueue {
  public:
   explicit BoundedPriorityQueue(size_t size) : size_(size) {}
   ~BoundedPriorityQueue() = default;

-  void push(const T &elem, int64 score) {
+  void push(T elem, int64 score) {
     if (queue_.size() > 4 * size_) resize();
-    if (queue_.size() >= size_ && queue_[size_ - 1].second > score) return;
+    if (sorted && queue_.size() >= size_ && queue_[size_ - 1].second > score)
+      return;
     queue_.emplace_back(elem, score);
   }
@ -109,16 +89,11 @@ class BoundedPriorityQueue {
                 return (p1.second > p2.second ||
                         (p1.second == p2.second && p1.first < p2.first));
               });
-    absl::flat_hash_set<absl::string_view> dup;
-    std::vector<std::pair<T, int64>> new_queue;
-    for (auto &p : queue_) {
-      if (dup.insert(p.first).second) new_queue.emplace_back(std::move(p));
-      if (new_queue.size() == size_) break;
-    }
-    queue_ = std::move(new_queue);
+    sorted = true;
+    if (queue_.size() > size_) queue_.resize(size_);
   }

+  bool sorted = false;
   size_t size_ = 0;
   std::vector<std::pair<T, int64>> queue_;
 };
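For context, a compilable sketch of the whole class after this fix; the get() accessor is assumed (it is not shown in the hunks), everything else follows the diff. The point of the new sorted flag: push() prunes against queue_[size_ - 1].second, which is the true score cutoff only after resize() has sorted the buffer. Before the fix the comparison also ran on an unsorted buffer, so candidates could be dropped against an arbitrary element's score. Once sorted, the cutoff stays valid between resizes, because every element that survives push() has score >= the cutoff and is appended after position size_ - 1.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

using int64 = int64_t;  // stands in for the repo's int64 typedef

template <class T>
class BoundedPriorityQueue {
 public:
  explicit BoundedPriorityQueue(size_t size) : size_(size) {}
  ~BoundedPriorityQueue() = default;

  void push(T elem, int64 score) {
    if (queue_.size() > 4 * size_) resize();
    // Prune against the cutoff only when the prefix is known to be sorted.
    if (sorted && queue_.size() >= size_ && queue_[size_ - 1].second > score)
      return;
    queue_.emplace_back(elem, score);
  }

  // Assumed accessor: sort once more and expose the final top-size_ pairs.
  const std::vector<std::pair<T, int64>> &get() {
    resize();
    return queue_;
  }

 private:
  void resize() {
    // Sort by descending score, breaking ties by ascending key, then truncate.
    std::sort(queue_.begin(), queue_.end(),
              [](const std::pair<T, int64> &p1, const std::pair<T, int64> &p2) {
                return (p1.second > p2.second ||
                        (p1.second == p2.second && p1.first < p2.first));
              });
    sorted = true;
    if (queue_.size() > size_) queue_.resize(size_);
  }

  bool sorted = false;
  size_t size_ = 0;
  std::vector<std::pair<T, int64>> queue_;
};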
@ -249,8 +224,7 @@ TrainerModel::SentencePieces Trainer::MakeSeedSentencePiecesInternal() {
       n, kAlphabetSize, node_num));
   LOG(INFO) << "Extracting frequent sub strings... node_num=" << node_num;

-  BoundedPriorityQueue<std::string> queue(
+  BoundedPriorityQueue<node_int_type> queue(
       static_cast<size_t>(trainer_spec_.seed_sentencepiece_size()));

   for (node_int_type i = 0; i < node_num; ++i) {
@ -259,31 +233,21 @@ TrainerModel::SentencePieces Trainer::MakeSeedSentencePiecesInternal() {
     if (len <= 1) {
       continue;
     }

-    for (const auto &p :
-         SplitBySentenceBoundary(&array[offset], &array[offset + len])) {
-      if (p.first == p.second) continue;
-      const auto [begin, end] = NormalizeRange(p.first, p.second);
+    const char32 *begin = &array[offset];
+    const char32 *end = &array[offset + len];

+    // Skips if a substring contains a sentence boundary.
+    if (std::find(begin, end, kSentenceBoundary) != end) {
+      continue;
+    }

     const UnicodeText uw(begin, end);
-    if (uw.size() <= 1 || !IsValidSentencePiece(uw)) {
+    if (!IsValidSentencePiece(uw)) {
       continue;
     }

     // character-wise coverage is the default score.
     const node_int_type freq = R[i] - L[i];
     const node_int_type score = freq * len;
-      const auto w = string_util::UnicodeTextToUTF8(uw);
-      queue.push(w, score);
-      const auto subpieces =
-          SplitIntoWords(w, trainer_spec_.treat_whitespace_as_suffix(),
-                         trainer_spec_.allow_whitespace_only_pieces());
-      if (subpieces.size() > 1) {
-        for (const auto &s : subpieces) queue.push(std::string(s), score);
-      }
-    }
+    queue.push(i, score);
   }

   // all_chars must be included in the seed sentencepieces.
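A worked example of the new candidate loop, under the assumption that sentences are concatenated with kSentenceBoundary (0x0000) between them: for the corpus {"abc", "abd"} the character array is a b c \0 a b d. The suffix-array node for the repeated substring "ab" covers the occurrence range [L[i], R[i]) with freq = R[i] - L[i] = 2 and len = D[i] = 2, so it is pushed with score = freq * len = 4, keyed only by the node index i. Any candidate whose characters straddle the \0 is now rejected wholesale by the std::find check; the old code instead carved it up with SplitBySentenceBoundary/NormalizeRange and additionally re-pushed whitespace-split subpieces, which is how string keys (and the duplicates they create) got into the queue in the first place.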
@ -293,7 +257,16 @@ TrainerModel::SentencePieces Trainer::MakeSeedSentencePiecesInternal() {
   }

   for (const auto &p : queue.get()) {
-    seed_sentencepieces.emplace_back(p);
+    const node_int_type offset = SA[L[p.first]];
+    const node_int_type len = D[p.first];
+    CHECK_GT(len, 0);
+    const char32 *begin = &array[offset];
+    const char32 *end = &array[offset + len];
+    const UnicodeText uw(begin, end);
+    const std::string w = string_util::UnicodeTextToUTF8(uw);
+    CHECK(IsValidSentencePiece(uw));  // just in case.
+    CHECK(!port::ContainsKey(all_chars, w));
+    seed_sentencepieces.emplace_back(w, p.second);
   }

   ToLogProb(seed_sentencepieces.begin(), seed_sentencepieces.end());
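Design note on the materialization above: the queue now stores plain (node index, score) pairs, and only the winning seed_sentencepiece_size entries from queue.get() are converted to UTF-8 through SA, L, and D, rather than building one std::string per pushed candidate. Since each suffix-array node denotes a distinct substring, the integer keys are duplicate-free by construction, which is presumably what made the absl::flat_hash_set dedup pass in resize() (and its include) removable.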
@ -496,10 +469,10 @@ TrainerModel::SentencePieces Trainer::PruneSentencePieces(
   // After removing the sentencepiece[i], its frequency freq[i] is
   // re-assigned to alternatives.
-  // new_sum = current_sum - freq[i] + freq[i] * alternatives.size()
-  //         = current_sum + freq[i] (alternatives - 1)
+  // new_sum = current_sum - freq[i] + freq[i] * alternatives[i].size()
+  //         = current_sum + freq[i] * (alternatives[i].size() - 1)
   const float logsum_alt = std::log(
-      static_cast<double>(sum + freq[i] * (alternatives.size() - 1)));
+      static_cast<double>(sum + freq[i] * (alternatives[i].size() - 1)));

   // The frequencies of alternatives are increased by freq[i].
   float logprob_alt = 0.0;
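A quick numeric check of the corrected formula, with assumed values sum = 100, freq[i] = 10, and alternatives[i].size() = 3: removing piece i re-assigns its frequency to each of its alternatives, so new_sum = 100 - 10 + 10 * 3 = 120 = sum + freq[i] * (alternatives[i].size() - 1), which is exactly what the fixed logsum_alt computes inside std::log. The old line read alternatives.size(), the number of entries in the outer vector over all pieces, a quantity unrelated to piece i.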
@ -530,22 +503,6 @@ TrainerModel::SentencePieces Trainer::PruneSentencePieces(
   return new_sentencepieces;
 }

-std::pair<const char32 *, const char32 *> Trainer::NormalizeRange(
-    const char32 *begin, const char32 *end) const {
-  if (trainer_spec_.treat_whitespace_as_suffix()) {
-    while ((*begin == kSentenceBoundary || *begin == kWsMarker) &&
-           begin + 1 < end)
-      ++begin;
-    while (*(end - 1) == kSentenceBoundary && begin + 1 < end) --end;
-  } else {
-    while (*begin == kSentenceBoundary && begin + 1 < end) ++begin;
-    while ((*(end - 1) == kSentenceBoundary || *(end - 1) == kWsMarker) &&
-           begin + 1 < end)
-      --end;
-  }
-  return std::make_pair(begin, end);
-}

 TrainerModel::SentencePieces Trainer::FinalizeSentencePieces(
     const TrainerModel &model) const {
   const auto &sentencepieces = model.GetSentencePieces();

View File

@ -105,9 +105,6 @@ class Trainer : public TrainerInterface {
   TrainerModel::SentencePieces FinalizeSentencePieces(
       const TrainerModel &model) const;

-  std::pair<const char32 *, const char32 *> NormalizeRange(
-      const char32 *begin, const char32 *end) const;

   // When the size of SentencePieces becomes less than desired_vocab_size_,
   // break the main training loop. desired_vocab_size_ = 1.1 * vocab_size_
   // for now.

View File

@ -117,12 +117,12 @@ TEST(UnigramTrainerTest, BasicTest) {
       30);

   // Check seed pieces.
-  EXPECT_EQ(63, res.seed_pieces_and_probs.size());
+  EXPECT_EQ(27, res.seed_pieces_and_probs.size());

   LOG(INFO) << "[" << res.sentence_pieces << "]";

   // Check final pieces.
-  EXPECT_EQ(
-      "Overly Pineapple magnanimity Available ▁an ▁ a b A t g r P O v m y p n "
-      "l h d e i",
+  EXPECT_EQ("i a n y m l e apple ve O P r g t an v ▁ b A le ▁an p d h",
       res.sentence_pieces);
 }
@ -134,7 +134,8 @@ TEST(UnigramTrainerTest, BasicDPTest) {
"Overly \t 6", "Available \t 5"},
22, true /*use_dp*/, 0 /*dp_noise*/, 4 /*dp_clipping*/);
EXPECT_EQ(49, res.seed_pieces_and_probs.size());
// Got 16 instead of 27 seeds.
EXPECT_EQ(16, res.seed_pieces_and_probs.size());
// And they are equiv to if the last sentence was not there.
const auto& res_nodp = RunTrainer(
@ -195,9 +196,9 @@ TEST(UnigramTrainerTest, EndToEndTest) {
LOG(INFO) << "[" << absl::StrJoin(tok, " ") << std::endl;
EXPECT_EQ(
WS
" 吾輩 《 わ が は い 》 は猫である 。 名前はまだ 無 い 。 どこ で 生 "
"れた か とん と 見当 《 けん とう 》 が つか ぬ 。 何でも 薄 暗 い じめ "
"じめ した 所で ニャーニャー 泣 い ていた 事 だけは 記憶 している 。",
" 吾輩 《 わが はい 》 は猫である 。 名前はまだ 無 い 。 どこ で 生 れた "
"か とん と 見当 《 けん とう 》 が つか ぬ 。 何でも 薄 暗 い じめ じめ "
"した 所で ニャーニャー 泣 い ていた 事 だけは 記憶 している 。",
absl::StrJoin(tok, " "));
#endif
}