2013-01-18 19:58:54 +04:00
|
|
|
#include "lm/builder/interpolate.hh"
|
|
|
|
|
2014-06-02 21:28:02 +04:00
|
|
|
#include "lm/builder/hash_gamma.hh"
|
2013-01-18 19:58:54 +04:00
|
|
|
#include "lm/builder/joint_order.hh"
|
2014-06-02 21:28:02 +04:00
|
|
|
#include "lm/builder/ngram_stream.hh"
|
2013-01-18 19:58:54 +04:00
|
|
|
#include "lm/builder/sort.hh"
|
|
|
|
#include "lm/lm_exception.hh"
|
2014-06-02 21:28:02 +04:00
|
|
|
#include "util/fixed_array.hh"
|
|
|
|
#include "util/murmur_hash.hh"
|
2013-01-18 19:58:54 +04:00
|
|
|
|
|
|
|
#include <assert.h>
|
|
|
|
|
|
|
|
namespace lm { namespace builder {
|
|
|
|
namespace {
|
|
|
|
|
|
|
|
class Callback {
|
|
|
|
public:
|
2014-06-02 21:28:02 +04:00
|
|
|
Callback(float uniform_prob, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t> &prune_thresholds)
|
|
|
|
: backoffs_(backoffs.size()), probs_(backoffs.size() + 2), prune_thresholds_(prune_thresholds) {
|
2013-01-18 19:58:54 +04:00
|
|
|
probs_[0] = uniform_prob;
|
|
|
|
for (std::size_t i = 0; i < backoffs.size(); ++i) {
|
|
|
|
backoffs_.push_back(backoffs[i]);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
~Callback() {
|
|
|
|
for (std::size_t i = 0; i < backoffs_.size(); ++i) {
|
2014-07-27 22:35:15 +04:00
|
|
|
if(prune_thresholds_[i + 1] > 0)
|
|
|
|
while(backoffs_[i])
|
|
|
|
++backoffs_[i];
|
|
|
|
|
2013-01-18 19:58:54 +04:00
|
|
|
if (backoffs_[i]) {
|
|
|
|
std::cerr << "Backoffs do not match for order " << (i + 1) << std::endl;
|
|
|
|
abort();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
void Enter(unsigned order_minus_1, NGram &gram) {
|
|
|
|
Payload &pay = gram.Value();
|
|
|
|
pay.complete.prob = pay.uninterp.prob + pay.uninterp.gamma * probs_[order_minus_1];
|
|
|
|
probs_[order_minus_1 + 1] = pay.complete.prob;
|
|
|
|
pay.complete.prob = log10(pay.complete.prob);
|
2014-06-02 21:28:02 +04:00
|
|
|
|
2014-01-28 04:51:35 +04:00
|
|
|
if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != kUNK && *(gram.end() - 1) != kEOS) {
|
2014-06-02 21:28:02 +04:00
|
|
|
// This skips over ngrams if backoffs have been exhausted.
|
|
|
|
if(!backoffs_[order_minus_1]) {
|
|
|
|
pay.complete.backoff = 0.0;
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
|
|
|
|
if(prune_thresholds_[order_minus_1 + 1] > 0) {
|
|
|
|
//Compute hash value for current context
|
|
|
|
uint64_t current_hash = util::MurmurHashNative(gram.begin(), gram.Order() * sizeof(WordIndex));
|
|
|
|
|
|
|
|
const HashGamma *hashed_backoff = static_cast<const HashGamma*>(backoffs_[order_minus_1].Get());
|
2014-07-27 22:35:15 +04:00
|
|
|
while(current_hash != hashed_backoff->hash_value && ++backoffs_[order_minus_1])
|
2014-06-02 21:28:02 +04:00
|
|
|
hashed_backoff = static_cast<const HashGamma*>(backoffs_[order_minus_1].Get());
|
2014-07-27 22:35:15 +04:00
|
|
|
|
2014-06-02 21:28:02 +04:00
|
|
|
if(current_hash == hashed_backoff->hash_value) {
|
|
|
|
pay.complete.backoff = log10(hashed_backoff->gamma);
|
|
|
|
++backoffs_[order_minus_1];
|
|
|
|
} else {
|
|
|
|
// Has been pruned away so it is not a context anymore
|
|
|
|
pay.complete.backoff = 0.0;
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
pay.complete.backoff = log10(*static_cast<const float*>(backoffs_[order_minus_1].Get()));
|
|
|
|
++backoffs_[order_minus_1];
|
|
|
|
}
|
2013-01-18 19:58:54 +04:00
|
|
|
} else {
|
2014-06-02 21:28:02 +04:00
|
|
|
// Not a context.
|
2013-01-18 19:58:54 +04:00
|
|
|
pay.complete.backoff = 0.0;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
void Exit(unsigned, const NGram &) const {}
|
|
|
|
|
|
|
|
private:
|
2014-06-02 21:28:02 +04:00
|
|
|
util::FixedArray<util::stream::Stream> backoffs_;
|
2013-01-18 19:58:54 +04:00
|
|
|
|
|
|
|
std::vector<float> probs_;
|
2014-06-02 21:28:02 +04:00
|
|
|
const std::vector<uint64_t>& prune_thresholds_;
|
2013-01-18 19:58:54 +04:00
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2014-06-02 21:28:02 +04:00
|
|
|
// Stores the inputs for Run.  uniform_prob_ is the base probability spread
// evenly over vocab_size words; vocab_size includes <unk> but excludes <s>.
Interpolate::Interpolate(uint64_t vocab_size,
                         const util::stream::ChainPositions &backoffs,
                         const std::vector<uint64_t>& prune_thresholds)
  : uniform_prob_(1.0 / static_cast<float>(vocab_size)),
    backoffs_(backoffs),
    prune_thresholds_(prune_thresholds) {
}
|
2013-01-18 19:58:54 +04:00
|
|
|
|
|
|
|
// perform order-wise interpolation
|
2014-06-02 21:28:02 +04:00
|
|
|
// Runs every n-gram of every order through a Callback via
// JointOrder<Callback, SuffixOrder>.  The Callback's destructor, at the end of
// this scope, checks that all backoff streams were consumed.
void Interpolate::Run(const util::stream::ChainPositions &positions) {
  // One backoff stream per order except the highest.
  assert(positions.size() == backoffs_.size() + 1);
  // Callback reads prune_thresholds_[n] for every order n that has a backoff
  // stream (n up to backoffs_.size()), so a shorter vector would be read
  // out of bounds.
  assert(prune_thresholds_.size() >= positions.size());
  Callback callback(uniform_prob_, backoffs_, prune_thresholds_);
  JointOrder<Callback, SuffixOrder>(positions, callback);
}
|
|
|
|
|
|
|
|
}} // namespaces
|