mosesdecoder/lm/builder/adjust_counts.cc

#include "lm/builder/adjust_counts.hh"
#include "lm/builder/multi_stream.hh"
#include "util/stream/timer.hh"
#include <algorithm>
namespace lm { namespace builder {

BadDiscountException::BadDiscountException() throw() {}
BadDiscountException::~BadDiscountException() throw() {}

namespace {
// Return a pointer to the last (rightmost) word in full that differs from
// lower_last, comparing the two n-grams backwards from their ends.
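// Example (hypothetical words): with full = [A B C] and lower_last = [B C], the
// whole of lower_last matches as a suffix and the returned pointer addresses A;
// with lower_last = [X C], only "C" matches and the pointer addresses B.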
const WordIndex* FindDifference(const NGram &full, const NGram &lower_last) {
  const WordIndex *cur_word = full.end() - 1;
  const WordIndex *pre_word = lower_last.end() - 1;
  // Find last difference.
  for (; pre_word >= lower_last.begin() && *pre_word == *cur_word; --cur_word, --pre_word) {}
  return cur_word;
}
class StatCollector {
  public:
    StatCollector(std::size_t order, std::vector<uint64_t> &counts, std::vector<Discount> &discounts)
      : orders_(order), full_(orders_.back()), counts_(counts), discounts_(discounts) {
      memset(&orders_[0], 0, sizeof(OrderStat) * order);
    }

    ~StatCollector() {}

    void CalculateDiscounts() {
      counts_.resize(orders_.size());
      discounts_.resize(orders_.size());
      for (std::size_t i = 0; i < orders_.size(); ++i) {
        const OrderStat &s = orders_[i];
        counts_[i] = s.count;

        for (unsigned j = 1; j < 4; ++j) {
          // TODO: Specialize error message for j == 3, meaning 3+
          UTIL_THROW_IF(s.n[j] == 0, BadDiscountException, "Could not calculate Kneser-Ney discounts for "
              << (i+1) << "-grams with adjusted count " << (j+1) << " because we didn't observe any "
              << (i+1) << "-grams with adjusted count " << j << "; Is this small or artificial data?");
        }

        // See equation (26) in Chen and Goodman.
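        // With Y = n_1 / (n_1 + 2 * n_2), the loop below computes, for adjusted count j,
        //   D_j = j - (j + 1) * Y * n_{j+1} / n_j,
        // where n_j is the number of (i+1)-grams whose adjusted count is exactly j.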
        discounts_[i].amount[0] = 0.0;
        float y = static_cast<float>(s.n[1]) / static_cast<float>(s.n[1] + 2.0 * s.n[2]);
        for (unsigned j = 1; j < 4; ++j) {
          discounts_[i].amount[j] = static_cast<float>(j) - static_cast<float>(j + 1) * y * static_cast<float>(s.n[j+1]) / static_cast<float>(s.n[j]);
          UTIL_THROW_IF(discounts_[i].amount[j] < 0.0 || discounts_[i].amount[j] > j, BadDiscountException, "ERROR: " << (i+1) << "-gram discount out of range for adjusted count " << j << ": " << discounts_[i].amount[j]);
        }
      }
    }

    void Add(std::size_t order_minus_1, uint64_t count) {
      OrderStat &stat = orders_[order_minus_1];
      ++stat.count;
      if (count < 5) ++stat.n[count];
    }

    void AddFull(uint64_t count) {
      ++full_.count;
      if (count < 5) ++full_.n[count];
    }

  private:
    struct OrderStat {
      // n_1 in equation 26 of Chen and Goodman etc
      uint64_t n[5];
      uint64_t count;
    };

    std::vector<OrderStat> orders_;
    OrderStat &full_;

    std::vector<uint64_t> &counts_;
    std::vector<Discount> &discounts_;
};

// Reads all entries in order, like NGramStream does.
// However, on the way out it deletes any entry that has <s> in the 1st (not
// 0th) position by copying another entry into its place. This disrupts the
// sort order, but we don't care because the data is going to be sorted again.
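// Mechanically: copy_from_ walks backwards from the end of each block over the
// surviving entries; an entry to be deleted is overwritten in place with the
// entry at copy_from_, and the block's valid size is truncated once the block
// has been read, so later stages never see n-grams with <s> in position 1.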
class CollapseStream {
  public:
    CollapseStream(const util::stream::ChainPosition &position) :
        current_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())),
        block_(position) {
      StartBlock();
    }

    const NGram &operator*() const { return current_; }
    const NGram *operator->() const { return &current_; }

    operator bool() const { return block_; }

    CollapseStream &operator++() {
      assert(block_);
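      // If the entry we are leaving behind has <s> in position 1, it must not
      // survive: overwrite it with the rearmost surviving entry (copy_from_),
      // then move copy_from_ back to the next surviving entry.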
      if (current_.begin()[1] == kBOS && current_.Base() < copy_from_) {
        memcpy(current_.Base(), copy_from_, current_.TotalSize());
        UpdateCopyFrom();
      }
      current_.NextInMemory();
      uint8_t *block_base = static_cast<uint8_t*>(block_->Get());
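      // If we just walked off the end of the block, truncate its valid size so
      // the tail (entries already copied forward, or entries being deleted) is
      // dropped, then advance to the next block.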
      if (current_.Base() == block_base + block_->ValidSize()) {
        block_->SetValidSize(copy_from_ + current_.TotalSize() - block_base);
        ++block_;
        StartBlock();
      }
      return *this;
    }

  private:
    void StartBlock() {
      for (; ; ++block_) {
        if (!block_) return;
        if (block_->ValidSize()) break;
      }
      current_.ReBase(block_->Get());
      copy_from_ = static_cast<uint8_t*>(block_->Get()) + block_->ValidSize();
      UpdateCopyFrom();
    }
    // Move copy_from_ backwards to the previous entry that does not have <s> in
    // position 1, i.e. the rearmost entry that should survive.
    void UpdateCopyFrom() {
      for (copy_from_ -= current_.TotalSize(); copy_from_ >= current_.Base(); copy_from_ -= current_.TotalSize()) {
        if (NGram(copy_from_, current_.Order()).begin()[1] != kBOS) break;
      }
    }

    NGram current_;
    // Goes backwards in the block.
    uint8_t *copy_from_;
    util::stream::Link block_;
};
} // namespace

void AdjustCounts::Run(const ChainPositions &positions) {
  UTIL_TIMER("(%w s) Adjusted counts\n");

  const std::size_t order = positions.size();
  StatCollector stats(order, counts_, discounts_);
  if (order == 1) {
    // Only unigrams. Just collect stats.
    for (NGramStream full(positions[0]); full; ++full)
      stats.AddFull(full->Count());

    stats.CalculateDiscounts();
    return;
  }

  NGramStreams streams;
  streams.Init(positions, positions.size() - 1);
  CollapseStream full(positions[positions.size() - 1]);

  // Initialization: <unk> has count 0 and so does <s>.
  NGramStream *lower_valid = streams.begin();
  streams[0]->Count() = 0;
  *streams[0]->begin() = kUNK;
  stats.Add(0, 0);
  (++streams[0])->Count() = 0;
  *streams[0]->begin() = kBOS;
  // <s> is not added to stats here because it will be put in later, once its
  // count has been accumulated.

  // Iterate over full, the stream of highest-order n-grams.
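  // Each streams[i] entry accumulates the adjusted count of a lower-order
  // n-gram: the number of distinct words observed immediately to its left.
  // N-grams that begin with <s>, like the full-order n-grams themselves, keep
  // their raw counts instead.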
  for (; full; ++full) {
    const WordIndex *different = FindDifference(*full, **lower_valid);
    std::size_t same = full->end() - 1 - different;

    // The shared suffix has just been seen with one more distinct word to its
    // left: increment its adjusted count.
    if (same) ++streams[same - 1]->Count();

    // Output all the valid ones that changed.
    for (; lower_valid >= &streams[same]; --lower_valid) {
      stats.Add(lower_valid - streams.begin(), (*lower_valid)->Count());
      ++*lower_valid;
    }

    // This is here because bos is also const WordIndex *, so copy gets
    // consistent argument types.
    const WordIndex *full_end = full->end();
    // Initialize and mark as valid up to bos.
    const WordIndex *bos;
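    // Each new, longer suffix has so far been seen with exactly one word to its
    // left, so its adjusted count starts at 1.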
    for (bos = different; (bos > full->begin()) && (*bos != kBOS); --bos) {
      ++lower_valid;
      std::copy(bos, full_end, (*lower_valid)->begin());
      (*lower_valid)->Count() = 1;
    }

    // Now bos indicates where <s> is, or is the 0th word of full.
    if (bos != full->begin()) {
      // There is an <s> beyond the 0th word.
      NGramStream &to = *++lower_valid;
      std::copy(bos, full_end, to->begin());
      to->Count() = full->Count();
    } else {
      stats.AddFull(full->Count());
    }
    assert(lower_valid >= &streams[0]);
  }

  // Output everything valid.
  for (NGramStream *s = streams.begin(); s <= lower_valid; ++s) {
    stats.Add(s - streams.begin(), (*s)->Count());
    ++*s;
  }

  // Poison everyone! Except the N-grams which were already poisoned by the input.
  for (NGramStream *s = streams.begin(); s != streams.end(); ++s)
    s->Poison();

  stats.CalculateDiscounts();

  // NOTE: See the special early-return case for unigrams near the top of this function.
}
}} // namespaces