#include "lm/builder/interpolate.hh"
|
|
|
|
#include "lm/builder/hash_gamma.hh"
|
|
#include "lm/builder/joint_order.hh"
|
|
#include "lm/builder/ngram_stream.hh"
|
|
#include "lm/builder/sort.hh"
|
|
#include "lm/lm_exception.hh"
|
|
#include "util/fixed_array.hh"
|
|
#include "util/murmur_hash.hh"
|
|
|
|
#include <assert.h>
|
|
#include <math.h>
|
|
|
|
namespace lm { namespace builder {
|
|
namespace {
|
|
|
|
/* Calculate q, the collapsed probability and backoff, as defined in
 * @inproceedings{Heafield-rest,
 *   author = {Kenneth Heafield and Philipp Koehn and Alon Lavie},
 *   title = {Language Model Rest Costs and Space-Efficient Storage},
 *   year = {2012},
 *   month = {July},
 *   booktitle = {Proceedings of the Joint Conference on Empirical Methods in Natural Language Processing and Computational Natural Language Learning},
 *   address = {Jeju Island, Korea},
 *   pages = {1169--1178},
 *   url = {http://kheafield.com/professional/edinburgh/rest\_paper.pdf},
 * }
 * This is particularly convenient to calculate during interpolation because
 * the needed backoff terms are already accessed at the same time.
 */
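// A sketch of the recurrence below, derived from the code rather than quoted
// from the paper: with b() denoting a backoff weight, q_delta_[n-1] unrolls to
//   prod_{i=1..n} b(w_i ... w_n) / b(w_i ... w_{n-1})
// (empty-context terms taken as 1), so out.prob * q_delta_[n-1] collapses the
// probability and its backoff terms into the single value q.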
class OutputQ {
  public:
    explicit OutputQ(std::size_t order) : q_delta_(order) {}

    void Gram(unsigned order_minus_1, float full_backoff, ProbBackoff &out) {
      float &q_del = q_delta_[order_minus_1];
      if (order_minus_1) {
        // Divide by the context's backoff (which comes in as out.backoff).
        q_del = q_delta_[order_minus_1 - 1] / out.backoff * full_backoff;
      } else {
        q_del = full_backoff;
      }
      out.prob = log10f(out.prob * q_del);
      // TODO: stop wastefully outputting this!
      out.backoff = 0.0;
    }

  private:
    // Product of backoffs in the numerator divided by backoffs in the
    // denominator. Does not include the probability itself, which is
    // multiplied in at the end of Gram().
    std::vector<float> q_delta_;
};

/* Default: output probability and backoff */
class OutputProbBackoff {
  public:
    explicit OutputProbBackoff(std::size_t /*order*/) {}

    void Gram(unsigned /*order_minus_1*/, float full_backoff, ProbBackoff &out) const {
      // Correcting for numerical precision issues.  Take that IRST.
      out.prob = std::min(0.0f, log10f(out.prob));
      out.backoff = log10f(full_backoff);
    }
};

template <class Output> class Callback {
  public:
    Callback(float uniform_prob, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t> &prune_thresholds, bool prune_vocab)
      : backoffs_(backoffs.size()), probs_(backoffs.size() + 2),
        prune_thresholds_(prune_thresholds),
        prune_vocab_(prune_vocab),
        output_(backoffs.size() + 1 /* order */) {
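      // Base case of the recursion: order 0 is the uniform distribution,
      // uniform_prob = 1 / vocab_size (see the Interpolate constructor).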
      probs_[0] = uniform_prob;
      for (std::size_t i = 0; i < backoffs.size(); ++i) {
        backoffs_.push_back(backoffs[i]);
      }
    }

    ~Callback() {
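      // Where pruning was enabled, leftover backoff entries (for n-grams that
      // were pruned away) are expected; drain them. For unpruned orders, any
      // leftover entry means the backoff and n-gram streams got out of sync.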
      for (std::size_t i = 0; i < backoffs_.size(); ++i) {
        if (prune_vocab_ || prune_thresholds_[i + 1] > 0)
          while (backoffs_[i])
            ++backoffs_[i];

        if (backoffs_[i]) {
          std::cerr << "Backoffs do not match for order " << (i + 1) << std::endl;
          abort();
        }
      }
    }

    void Enter(unsigned order_minus_1, NGram &gram) {
      Payload &pay = gram.Value();
      pay.complete.prob = pay.uninterp.prob + pay.uninterp.gamma * probs_[order_minus_1];
      probs_[order_minus_1 + 1] = pay.complete.prob;

      float out_backoff;
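      // This n-gram carries a backoff only if it can be the context of a
      // longer n-gram: it is below the maximum order, does not end in <unk>
      // or </s>, and its backoff stream still has entries.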
      if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != kUNK && *(gram.end() - 1) != kEOS && backoffs_[order_minus_1]) {
        if (prune_vocab_ || prune_thresholds_[order_minus_1 + 1] > 0) {
          // Compute the hash value for the current context.
          uint64_t current_hash = util::MurmurHashNative(gram.begin(), gram.Order() * sizeof(WordIndex));

          const HashGamma *hashed_backoff = static_cast<const HashGamma*>(backoffs_[order_minus_1].Get());
          while (current_hash != hashed_backoff->hash_value && ++backoffs_[order_minus_1])
            hashed_backoff = static_cast<const HashGamma*>(backoffs_[order_minus_1].Get());

          if (current_hash == hashed_backoff->hash_value) {
            out_backoff = hashed_backoff->gamma;
            ++backoffs_[order_minus_1];
          } else {
            // Has been pruned away, so it is not a context anymore.
            out_backoff = 1.0;
          }
        } else {
          out_backoff = *static_cast<const float*>(backoffs_[order_minus_1].Get());
          ++backoffs_[order_minus_1];
        }
      } else {
        // Not a context.
        out_backoff = 1.0;
      }

      output_.Gram(order_minus_1, out_backoff, pay.complete);
    }

    void Exit(unsigned, const NGram &) const {}

  private:
    util::FixedArray<util::stream::Stream> backoffs_;

    std::vector<float> probs_;
    const std::vector<uint64_t> &prune_thresholds_;
    bool prune_vocab_;

    Output output_;
};

} // namespace

Interpolate::Interpolate(uint64_t vocab_size, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t> &prune_thresholds, bool prune_vocab, bool output_q)
  : uniform_prob_(1.0 / static_cast<float>(vocab_size)), // Includes <unk> but excludes <s>.
    backoffs_(backoffs),
    prune_thresholds_(prune_thresholds),
    prune_vocab_(prune_vocab),
    output_q_(output_q) {}

// Perform order-wise interpolation.
void Interpolate::Run(const util::stream::ChainPositions &positions) {
  assert(positions.size() == backoffs_.size() + 1);
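  // JointOrder visits the n-grams of all orders together in suffix order, so
  // by the time Enter() runs for an n-gram, probs_ already holds the
  // interpolated probability of its lower-order suffix.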
  if (output_q_) {
    typedef Callback<OutputQ> C;
    C callback(uniform_prob_, backoffs_, prune_thresholds_, prune_vocab_);
    JointOrder<C, SuffixOrder>(positions, callback);
  } else {
    typedef Callback<OutputProbBackoff> C;
    C callback(uniform_prob_, backoffs_, prune_thresholds_, prune_vocab_);
    JointOrder<C, SuffixOrder>(positions, callback);
  }
}

}} // namespaces