mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-12-27 14:05:29 +03:00
277 lines
9.0 KiB
C++
277 lines
9.0 KiB
C++
#include "lm/builder/initial_probabilities.hh"
|
|
|
|
#include "lm/builder/discount.hh"
|
|
#include "lm/builder/ngram_stream.hh"
|
|
#include "lm/builder/sort.hh"
|
|
#include "lm/builder/hash_gamma.hh"
|
|
#include "util/murmur_hash.hh"
|
|
#include "util/file.hh"
|
|
#include "util/stream/chain.hh"
|
|
#include "util/stream/io.hh"
|
|
#include "util/stream/stream.hh"
|
|
|
|
#include <vector>
|
|
|
|
namespace lm { namespace builder {
|
|
|
|
namespace {
|
|
// One record per context, written by AddRight into the gamma chain.
// Member order is load-bearing: OnlyGamma walks a block of these as a flat
// float array and keeps every other float, so gamma must remain the first
// member and the struct must stay exactly two floats.
struct BufferEntry {
  // Interpolation weight of the context (Chen and Goodman, page 20).
  float gamma;
  // Total of a(c w) taken over every word w that follows context c.
  float denominator;
};

// Variant of BufferEntry emitted when pruning is enabled: additionally carries
// a hash of the context words so contexts can later be joined with backoffs.
struct HashBufferEntry : BufferEntry {
  uint64_t hash_value;
};
|
|
|
|
// Reads all entries in order like NGramStream does.
|
|
// But deletes any entries that have CutoffCount below or equal to pruning
|
|
// threshold.
|
|
class PruneNGramStream {
|
|
public:
|
|
PruneNGramStream(const util::stream::ChainPosition &position) :
|
|
current_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())),
|
|
dest_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())),
|
|
currentCount_(0),
|
|
block_(position)
|
|
{
|
|
StartBlock();
|
|
}
|
|
|
|
NGram &operator*() { return current_; }
|
|
NGram *operator->() { return ¤t_; }
|
|
|
|
operator bool() const {
|
|
return block_;
|
|
}
|
|
|
|
PruneNGramStream &operator++() {
|
|
assert(block_);
|
|
|
|
if (current_.Order() > 1) {
|
|
if(currentCount_ > 0) {
|
|
if(dest_.Base() < current_.Base()) {
|
|
memcpy(dest_.Base(), current_.Base(), current_.TotalSize());
|
|
}
|
|
dest_.NextInMemory();
|
|
}
|
|
} else {
|
|
dest_.NextInMemory();
|
|
}
|
|
|
|
current_.NextInMemory();
|
|
|
|
uint8_t *block_base = static_cast<uint8_t*>(block_->Get());
|
|
if (current_.Base() == block_base + block_->ValidSize()) {
|
|
block_->SetValidSize(dest_.Base() - block_base);
|
|
++block_;
|
|
StartBlock();
|
|
}
|
|
|
|
currentCount_ = current_.CutoffCount();
|
|
|
|
return *this;
|
|
}
|
|
|
|
private:
|
|
void StartBlock() {
|
|
for (; ; ++block_) {
|
|
if (!block_) return;
|
|
if (block_->ValidSize()) break;
|
|
}
|
|
current_.ReBase(block_->Get());
|
|
currentCount_ = current_.CutoffCount();
|
|
|
|
dest_.ReBase(block_->Get());
|
|
}
|
|
|
|
NGram current_; // input iterator
|
|
NGram dest_; // output iterator
|
|
|
|
uint64_t currentCount_;
|
|
|
|
util::stream::Link block_;
|
|
};
|
|
|
|
// Extract an array of HashedGamma from an array of BufferEntry.
|
|
class OnlyGamma {
|
|
public:
|
|
OnlyGamma(bool pruning) : pruning_(pruning) {}
|
|
|
|
void Run(const util::stream::ChainPosition &position) {
|
|
for (util::stream::Link block_it(position); block_it; ++block_it) {
|
|
if(pruning_) {
|
|
const HashBufferEntry *in = static_cast<const HashBufferEntry*>(block_it->Get());
|
|
const HashBufferEntry *end = static_cast<const HashBufferEntry*>(block_it->ValidEnd());
|
|
|
|
// Just make it point to the beginning of the stream so it can be overwritten
|
|
// With HashGamma values. Do not attempt to interpret the values until set below.
|
|
HashGamma *out = static_cast<HashGamma*>(block_it->Get());
|
|
for (; in < end; out += 1, in += 1) {
|
|
// buffering, otherwise might overwrite values too early
|
|
float gamma_buf = in->gamma;
|
|
uint64_t hash_buf = in->hash_value;
|
|
|
|
out->gamma = gamma_buf;
|
|
out->hash_value = hash_buf;
|
|
}
|
|
block_it->SetValidSize((block_it->ValidSize() * sizeof(HashGamma)) / sizeof(HashBufferEntry));
|
|
}
|
|
else {
|
|
float *out = static_cast<float*>(block_it->Get());
|
|
const float *in = out;
|
|
const float *end = static_cast<const float*>(block_it->ValidEnd());
|
|
for (out += 1, in += 2; in < end; out += 1, in += 2) {
|
|
*out = *in;
|
|
}
|
|
block_it->SetValidSize(block_it->ValidSize() / 2);
|
|
}
|
|
}
|
|
}
|
|
|
|
private:
|
|
bool pruning_;
|
|
};
|
|
|
|
class AddRight {
|
|
public:
|
|
AddRight(const Discount &discount, const util::stream::ChainPosition &input, bool pruning)
|
|
: discount_(discount), input_(input), pruning_(pruning) {}
|
|
|
|
void Run(const util::stream::ChainPosition &output) {
|
|
NGramStream in(input_);
|
|
util::stream::Stream out(output);
|
|
|
|
std::vector<WordIndex> previous(in->Order() - 1);
|
|
// Silly windows requires this workaround to just get an invalid pointer when empty.
|
|
void *const previous_raw = previous.empty() ? NULL : static_cast<void*>(&previous[0]);
|
|
const std::size_t size = sizeof(WordIndex) * previous.size();
|
|
|
|
for(; in; ++out) {
|
|
memcpy(previous_raw, in->begin(), size);
|
|
uint64_t denominator = 0;
|
|
uint64_t normalizer = 0;
|
|
|
|
uint64_t counts[4];
|
|
memset(counts, 0, sizeof(counts));
|
|
do {
|
|
denominator += in->UnmarkedCount();
|
|
|
|
// Collect unused probability mass from pruning.
|
|
// Becomes 0 for unpruned ngrams.
|
|
normalizer += in->UnmarkedCount() - in->CutoffCount();
|
|
|
|
// Chen&Goodman do not mention counting based on cutoffs, but
|
|
// backoff becomes larger than 1 otherwise, so probably needs
|
|
// to count cutoffs. Counts normally without pruning.
|
|
if(in->CutoffCount() > 0)
|
|
++counts[std::min(in->CutoffCount(), static_cast<uint64_t>(3))];
|
|
|
|
} while (++in && !memcmp(previous_raw, in->begin(), size));
|
|
|
|
BufferEntry &entry = *reinterpret_cast<BufferEntry*>(out.Get());
|
|
entry.denominator = static_cast<float>(denominator);
|
|
entry.gamma = 0.0;
|
|
for (unsigned i = 1; i <= 3; ++i) {
|
|
entry.gamma += discount_.Get(i) * static_cast<float>(counts[i]);
|
|
}
|
|
|
|
// Makes model sum to 1 with pruning (I hope).
|
|
entry.gamma += normalizer;
|
|
|
|
entry.gamma /= entry.denominator;
|
|
|
|
if(pruning_) {
|
|
// If pruning is enabled the stream actually contains HashBufferEntry, see InitialProbabilities(...),
|
|
// so add a hash value that identifies the current ngram.
|
|
static_cast<HashBufferEntry*>(&entry)->hash_value = util::MurmurHashNative(previous_raw, size);
|
|
}
|
|
}
|
|
out.Poison();
|
|
}
|
|
|
|
private:
|
|
const Discount &discount_;
|
|
const util::stream::ChainPosition input_;
|
|
bool pruning_;
|
|
};
|
|
|
|
class MergeRight {
|
|
public:
|
|
MergeRight(bool interpolate_unigrams, const util::stream::ChainPosition &from_adder, const Discount &discount)
|
|
: interpolate_unigrams_(interpolate_unigrams), from_adder_(from_adder), discount_(discount) {}
|
|
|
|
// calculate the initial probability of each n-gram (before order-interpolation)
|
|
// Run() gets invoked once for each order
|
|
void Run(const util::stream::ChainPosition &primary) {
|
|
util::stream::Stream summed(from_adder_);
|
|
|
|
PruneNGramStream grams(primary);
|
|
|
|
// Without interpolation, the interpolation weight goes to <unk>.
|
|
if (grams->Order() == 1 && !interpolate_unigrams_) {
|
|
BufferEntry sums(*static_cast<const BufferEntry*>(summed.Get()));
|
|
assert(*grams->begin() == kUNK);
|
|
grams->Value().uninterp.prob = sums.gamma;
|
|
grams->Value().uninterp.gamma = 0.0;
|
|
while (++grams) {
|
|
grams->Value().uninterp.prob = discount_.Apply(grams->Count()) / sums.denominator;
|
|
grams->Value().uninterp.gamma = 0.0;
|
|
}
|
|
++summed;
|
|
return;
|
|
}
|
|
|
|
std::vector<WordIndex> previous(grams->Order() - 1);
|
|
const std::size_t size = sizeof(WordIndex) * previous.size();
|
|
for (; grams; ++summed) {
|
|
memcpy(&previous[0], grams->begin(), size);
|
|
const BufferEntry &sums = *static_cast<const BufferEntry*>(summed.Get());
|
|
|
|
do {
|
|
Payload &pay = grams->Value();
|
|
pay.uninterp.prob = discount_.Apply(grams->UnmarkedCount()) / sums.denominator;
|
|
pay.uninterp.gamma = sums.gamma;
|
|
} while (++grams && !memcmp(&previous[0], grams->begin(), size));
|
|
}
|
|
}
|
|
|
|
private:
|
|
bool interpolate_unigrams_;
|
|
util::stream::ChainPosition from_adder_;
|
|
Discount discount_;
|
|
};
|
|
|
|
} // namespace
|
|
|
|
void InitialProbabilities(
|
|
const InitialProbabilitiesConfig &config,
|
|
const std::vector<Discount> &discounts,
|
|
util::stream::Chains &primary,
|
|
util::stream::Chains &second_in,
|
|
util::stream::Chains &gamma_out,
|
|
const std::vector<uint64_t> &prune_thresholds) {
|
|
for (size_t i = 0; i < primary.size(); ++i) {
|
|
util::stream::ChainConfig gamma_config = config.adder_out;
|
|
if(prune_thresholds[i] > 0)
|
|
gamma_config.entry_size = sizeof(HashBufferEntry);
|
|
else
|
|
gamma_config.entry_size = sizeof(BufferEntry);
|
|
|
|
util::stream::ChainPosition second(second_in[i].Add());
|
|
second_in[i] >> util::stream::kRecycle;
|
|
gamma_out.push_back(gamma_config);
|
|
gamma_out[i] >> AddRight(discounts[i], second, prune_thresholds[i] > 0);
|
|
|
|
primary[i] >> MergeRight(config.interpolate_unigrams, gamma_out[i].Add(), discounts[i]);
|
|
|
|
// Don't bother with the OnlyGamma thread for something to discard.
|
|
if (i) gamma_out[i] >> OnlyGamma(prune_thresholds[i] > 0);
|
|
}
|
|
}
|
|
|
|
}} // namespaces
|