mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2025-01-02 17:09:36 +03:00
66 lines
2.1 KiB
C++
66 lines
2.1 KiB
C++
#include "lm/builder/interpolate.hh"
|
|
|
|
#include "lm/builder/joint_order.hh"
|
|
#include "lm/builder/multi_stream.hh"
|
|
#include "lm/builder/sort.hh"
|
|
#include "lm/lm_exception.hh"
|
|
|
|
#include <assert.h>
|
|
|
|
namespace lm { namespace builder {
|
|
namespace {
|
|
|
|
class Callback {
|
|
public:
|
|
Callback(float uniform_prob, const ChainPositions &backoffs) : backoffs_(backoffs.size()), probs_(backoffs.size() + 2) {
|
|
probs_[0] = uniform_prob;
|
|
for (std::size_t i = 0; i < backoffs.size(); ++i) {
|
|
backoffs_.push_back(backoffs[i]);
|
|
}
|
|
}
|
|
|
|
~Callback() {
|
|
for (std::size_t i = 0; i < backoffs_.size(); ++i) {
|
|
if (backoffs_[i]) {
|
|
std::cerr << "Backoffs do not match for order " << (i + 1) << std::endl;
|
|
abort();
|
|
}
|
|
}
|
|
}
|
|
|
|
void Enter(unsigned order_minus_1, NGram &gram) {
|
|
Payload &pay = gram.Value();
|
|
pay.complete.prob = pay.uninterp.prob + pay.uninterp.gamma * probs_[order_minus_1];
|
|
probs_[order_minus_1 + 1] = pay.complete.prob;
|
|
pay.complete.prob = log10(pay.complete.prob);
|
|
// TODO: this is a hack to skip n-grams that don't appear as context. Pruning will require some different handling.
|
|
if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != kUNK && *(gram.end() - 1) != kEOS) {
|
|
pay.complete.backoff = log10(*static_cast<const float*>(backoffs_[order_minus_1].Get()));
|
|
++backoffs_[order_minus_1];
|
|
} else {
|
|
// Not a context.
|
|
pay.complete.backoff = 0.0;
|
|
}
|
|
}
|
|
|
|
void Exit(unsigned, const NGram &) const {}
|
|
|
|
private:
|
|
FixedArray<util::stream::Stream> backoffs_;
|
|
|
|
std::vector<float> probs_;
|
|
};
|
|
} // namespace
|
|
|
|
// Prepare order-wise interpolation.
// unigram_count: number of unigram entries; must be greater than 1.  The
//   uniform base probability is 1 / (unigram_count - 1).
//   NOTE(review): the "- 1" presumably excludes a token that is never
//   predicted (e.g. <s>) -- confirm against the caller.
// backoffs: chain positions of the backoff streams; used to initialize
//   backoffs_, from which Run later builds its callback.
Interpolate::Interpolate(uint64_t unigram_count, const ChainPositions &backoffs)
  : uniform_prob_(1.0 / static_cast<float>(unigram_count - 1)), backoffs_(backoffs) {
  // unigram_count is unsigned: 0 would wrap around in the subtraction above
  // and 1 would divide by zero, both yielding a meaningless uniform
  // probability.  Catch that in debug builds instead of propagating it.
  assert(unigram_count > 1);
}
|
|
|
|
// perform order-wise interpolation
|
|
void Interpolate::Run(const ChainPositions &positions) {
|
|
assert(positions.size() == backoffs_.size() + 1);
|
|
Callback callback(uniform_prob_, backoffs_);
|
|
JointOrder<Callback, SuffixOrder>(positions, callback);
|
|
}
|
|
|
|
}} // namespaces
|