/* Efficient left and right language model state for sentence fragments.
 * Intended usage:
 * Store ChartState with every chart entry.
 * To do a rule application:
 * 1. Make a ChartState object for your new entry.
 * 2. Construct RuleScore.
 * 3. Going from left to right, call Terminal or NonTerminal.
 *    For terminals, just pass the vocab id.
 *    For non-terminals, pass that non-terminal's ChartState.
 *    If your decoder expects scores inclusive of subtree scores (i.e. you
 *    label entries with the highest-scoring path), pass the non-terminal's
 *    score as prob.
 *    If your decoder expects relative scores and will walk the chart later,
 *    pass prob = 0.0.
 *    In other words, the only effect of prob is that it gets added to the
 *    returned log probability.
 * 4. Call Finish. It returns the log probability.
 *
 * There are a couple more details:
 * Do not pass <s> to Terminal as it is formally not a word in the sentence,
 * only context. Instead, call BeginSentence. If called, it should be the
 * first call after RuleScore is constructed (since <s> is always the
 * leftmost).
 *
 * If the leftmost RHS is a non-terminal, it's faster to call BeginNonTerminal.
 *
 * Hashing and sorting comparison operators are provided. All state objects
 * are POD. If you intend to use memcmp on raw state objects, you must call
 * ZeroRemaining first, as the value of array entries beyond length is
 * otherwise undefined.
 *
 * Usage is of course not limited to chart decoding. Anything that generates
 * sentence fragments missing left context could benefit. For example, a
 * phrase-based decoder could pre-score phrases, storing ChartState with each
 * phrase, even if hypotheses are generated left-to-right.
 */
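
/* A minimal usage sketch, assuming a loaded ngram::Model named model (declared
 * in lm/model.hh) and one previously scored child entry with ChartState child
 * and inclusive subtree score child_score:
 *
 *   ChartState out;
 *   RuleScore<Model> scorer(model, out);
 *   scorer.BeginSentence();                               // rule starts with <s>
 *   scorer.Terminal(model.GetVocabulary().Index("the"));  // terminal word
 *   scorer.NonTerminal(child, child_score);               // child non-terminal
 *   float log_prob = scorer.Finish();                     // out is now ready to store
 */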

#ifndef LM_LEFT__
#define LM_LEFT__

#include "lm/max_order.hh"
#include "lm/state.hh"
#include "lm/return.hh"

#include "util/murmur_hash.hh"

#include <algorithm>

namespace lm {
namespace ngram {

template <class M> class RuleScore {
  public:
    explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(&out), left_done_(false), prob_(0.0) {
      out.left.length = 0;
      out.right.length = 0;
    }

    void BeginSentence() {
      out_->right = model_.BeginSentenceState();
      // out_->left is empty.
      left_done_ = true;
    }
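
    // Score a terminal word appended to the right of the fragment so far.  The
    // word is scored against the fragment's current right context; if its
    // probability may still change once more left context is known, an
    // extend-left pointer is recorded and the rest cost is charged for now.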
    void Terminal(WordIndex word) {
      State copy(out_->right);
      FullScoreReturn ret(model_.FullScore(copy, word, out_->right));
      if (left_done_) { prob_ += ret.prob; return; }
      if (ret.independent_left) {
        prob_ += ret.prob;
        left_done_ = true;
        return;
      }
      out_->left.pointers[out_->left.length++] = ret.extend_left;
      prob_ += ret.rest;
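      // The model truncated the history, so words added later cannot depend on
      // context to the left of this fragment.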
      if (out_->right.length != copy.length + 1)
        left_done_ = true;
    }

    // Faster version of NonTerminal for the case where the rule begins with a non-terminal.
    void BeginNonTerminal(const ChartState &in, float prob = 0.0) {
      prob_ = prob;
      *out_ = in;
      left_done_ = in.left.full;
    }
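
    // Append a previously scored fragment (a non-terminal) to the right of this
    // one.  The child's left-extendable words are rescored against the words
    // already in out_->right, and the combined right state is assembled.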
    void NonTerminal(const ChartState &in, float prob = 0.0) {
      prob_ += prob;

      if (!in.left.length) {
        if (in.left.full) {
          for (const float *i = out_->right.backoff; i < out_->right.backoff + out_->right.length; ++i) prob_ += *i;
          left_done_ = true;
          out_->right = in.right;
        }
        return;
      }
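
      // If this fragment has no right context yet, the child's right state (and,
      // where possible, its left state) carries over directly.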
      if (!out_->right.length) {
        out_->right = in.right;
        if (left_done_) {
          prob_ += model_.UnRest(in.left.pointers, in.left.pointers + in.left.length, 1);
          return;
        }
        if (out_->left.length) {
          left_done_ = true;
        } else {
          out_->left = in.left;
          left_done_ = in.left.full;
        }
        return;
      }
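
      // General case: extend the child's stored left pointers by the words
      // currently in out_->right, rescoring them with this extra left context.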
      float backoffs[KENLM_MAX_ORDER - 1], backoffs2[KENLM_MAX_ORDER - 1];
      float *back = backoffs, *back2 = backoffs2;
      unsigned char next_use = out_->right.length;

      // First word
      if (ExtendLeft(in, next_use, 1, out_->right.backoff, back)) return;

      // Words after the first, so extending a bigram to begin with
      for (unsigned char extend_length = 2; extend_length <= in.left.length; ++extend_length) {
        if (ExtendLeft(in, next_use, extend_length, back, back2)) return;
        std::swap(back, back2);
      }

      if (in.left.full) {
        for (const float *i = back; i != back + next_use; ++i) prob_ += *i;
        left_done_ = true;
        out_->right = in.right;
        return;
      }

      // Right state was minimized, so it's already independent of the new words to the left.
      if (in.right.length < in.left.length) {
        out_->right = in.right;
        return;
      }

      // Shift existing words down.
      for (WordIndex *i = out_->right.words + next_use - 1; i >= out_->right.words; --i) {
        *(i + in.right.length) = *i;
      }
      // Add words from in.right.
      std::copy(in.right.words, in.right.words + in.right.length, out_->right.words);
      // Assemble the backoffs: the new state's backoffs followed by the surviving backoffs of the existing state.
      std::copy(in.right.backoff, in.right.backoff + in.right.length, out_->right.backoff);
      std::copy(back, back + next_use, out_->right.backoff + in.right.length);
      out_->right.length = in.right.length + next_use;
    }

    float Finish() {
      // An N-1-gram might extend left and right, but we should still set full to true because it's an N-1-gram.
      out_->left.full = left_done_ || (out_->left.length == model_.Order() - 1);
      return prob_;
    }
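
    // Discard any partially scored rule and start over, reusing the same output ChartState.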
    void Reset() {
      prob_ = 0.0;
      left_done_ = false;
      out_->left.length = 0;
      out_->right.length = 0;
    }
    void Reset(ChartState &replacement) {
      out_ = &replacement;
      Reset();
    }

  private:
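    // Rescore the child's stored left pointer of length extend_length using the
    // words currently in out_->right as additional context.  Returns true when
    // the caller can stop early because the remaining pointers have been handled.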
    bool ExtendLeft(const ChartState &in, unsigned char &next_use, unsigned char extend_length, const float *back_in, float *back_out) {
      ProcessRet(model_.ExtendLeft(
            out_->right.words, out_->right.words + next_use, // Words to extend into
            back_in, // Backoffs to use
            in.left.pointers[extend_length - 1], extend_length, // Words to be extended
            back_out, // Backoffs for the next score
            next_use)); // Length of n-gram to use in next scoring.
      if (next_use != out_->right.length) {
        left_done_ = true;
        if (!next_use) {
          // Early exit.
          out_->right = in.right;
          prob_ += model_.UnRest(in.left.pointers + extend_length, in.left.pointers + in.left.length, extend_length + 1);
          return true;
        }
      }
      // Continue scoring.
      return false;
    }
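
    // Fold one FullScoreReturn into the running score: use the final probability
    // if the score cannot change with more left context, otherwise record an
    // extend-left pointer and charge the rest cost for now.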
    void ProcessRet(const FullScoreReturn &ret) {
      if (left_done_) {
        prob_ += ret.prob;
        return;
      }
      if (ret.independent_left) {
        prob_ += ret.prob;
        left_done_ = true;
        return;
      }
      out_->left.pointers[out_->left.length++] = ret.extend_left;
      prob_ += ret.rest;
    }

    const M &model_;

    ChartState *out_;

    bool left_done_;

    float prob_;
};

} // namespace ngram
} // namespace lm

#endif // LM_LEFT__