// mosesdecoder/kenlm/lm/model.cc

#include "lm/model.hh"
#include "lm/lm_exception.hh"
#include "lm/search_hashed.hh"
#include "lm/search_trie.hh"
#include "lm/read_arpa.hh"
#include "util/murmur_hash.hh"
#include <algorithm>
#include <functional>
#include <numeric>
#include <cmath>

namespace lm {
namespace ngram {
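
// Hash only the used portion of the state: the valid_length_ words of
// history.  States carrying the same usable context therefore hash identically.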
size_t hash_value(const State &state) {
  return util::MurmurHashNative(state.history_, sizeof(WordIndex) * state.valid_length_);
}

namespace detail {
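
// Bytes of memory a model with these n-gram counts will need: the vocabulary,
// sized from the unigram count, plus the search structures.  The order and
// configuration sanity checks live here too.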
template <class Search, class VocabularyT> size_t GenericModel<Search, VocabularyT>::Size(const std::vector<uint64_t> &counts, const Config &config) {
  if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit ngram.hh's kMaxOrder to at least this value and recompile.");
  if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model.");
  if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0");
  return VocabularyT::Size(counts[0], config) + Search::Size(counts, config);
}

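// Carve one contiguous allocation into the vocabulary followed by the search
// structures, and verify that the layout matches what Size() promised.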
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::SetupMemory(void *base, const std::vector<uint64_t> &counts, const Config &config) {
  uint8_t *start = static_cast<uint8_t*>(base);
  size_t allocated = VocabularyT::Size(counts[0], config);
  vocab_.SetupMemory(start, allocated, counts[0], config);
  start += allocated;
  start = search_.SetupMemory(start, counts, config);
  if (static_cast<std::size_t>(start - static_cast<uint8_t*>(base)) != Size(counts, config)) UTIL_THROW(FormatLoadException, "The data structures took " << (start - static_cast<uint8_t*>(base)) << " but Size says they should take " << Size(counts, config));
}

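// LoadLM reads the file, dispatching to InitializeFromBinary or
// InitializeFromARPA as appropriate; the constructor then precomputes the
// begin-sentence and null-context states and hands them to the base class.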
template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &config) {
  LoadLM(file, config, *this);

  // g++ prints warnings unless these are fully initialized.
  State begin_sentence = State();
  begin_sentence.valid_length_ = 1;
  begin_sentence.history_[0] = vocab_.BeginSentence();
  begin_sentence.backoff_[0] = search_.unigram.Lookup(begin_sentence.history_[0]).backoff;
  State null_context = State();
  null_context.valid_length_ = 0;
  P::Init(begin_sentence, null_context, vocab_, search_.middle.size() + 2);
}

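// A binary file stores the structures in exactly the mapped layout, so
// loading reduces to letting each structure fix up its pointers into the map.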
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromBinary(void *start, const Parameters &params, const Config &config, int fd) {
  SetupMemory(start, params.counts, config);
  vocab_.LoadedBinary(fd, config.enumerate_vocab);
  search_.unigram.LoadedBinary();
  for (typename std::vector<Middle>::iterator i = search_.middle.begin(); i != search_.middle.end(); ++i) {
    i->LoadedBinary();
  }
  search_.longest.LoadedBinary();
}

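// Build every structure by reading a text ARPA file.  When write_mmap is set,
// the vocabulary words are also recorded into the binary file being written.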
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, util::FilePiece &f, void *start, const Parameters &params, const Config &config) {
  SetupMemory(start, params.counts, config);
  if (config.write_mmap) {
    WriteWordsWrapper wrap(config.enumerate_vocab, backing_.file.get());
    vocab_.ConfigureEnumerate(&wrap, params.counts[0]);
    search_.InitializeFromARPA(file, f, params.counts, config, vocab_);
  } else {
    vocab_.ConfigureEnumerate(config.enumerate_vocab, params.counts[0]);
    search_.InitializeFromARPA(file, f, params.counts, config, vocab_);
  }

  // TODO: fail faster?
  if (!vocab_.SawUnk()) {
    switch(config.unknown_missing) {
      case Config::THROW_UP:
        {
          SpecialWordMissingException e("<unk>");
          e << " and the configuration was set to throw if <unk> is missing";
          throw e;
        }
      case Config::COMPLAIN:
        if (config.messages) *config.messages << "Language model is missing <unk>. Substituting probability " << config.unknown_missing_prob << "." << std::endl;
        // Deliberately no break: COMPLAIN falls through to SILENT.
      case Config::SILENT:
        // Default probabilities for the unknown word.
        search_.unigram.Unknown().backoff = 0.0;
        search_.unigram.Unknown().prob = config.unknown_missing_prob;
        break;
    }
  }
  if (std::fabs(search_.unigram.Unknown().backoff) > 0.0000001) UTIL_THROW(FormatLoadException, "Backoff for the unknown word should be zero, but was given as " << search_.unigram.Unknown().backoff);
}

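// Score new_word in the context held by in_state.  ScoreExceptBackoff finds
// the longest matching n-gram; if that n-gram used fewer history words than
// in_state provides, the leftover backoff weights (log values, so they simply
// sum) are folded in here with std::accumulate.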
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const {
  unsigned char backoff_start;
  FullScoreReturn ret = ScoreExceptBackoff(in_state.history_, in_state.history_ + in_state.valid_length_, new_word, backoff_start, out_state);
  if (backoff_start - 1 < in_state.valid_length_) {
    ret.prob = std::accumulate(in_state.backoff_ + backoff_start - 1, in_state.backoff_ + in_state.valid_length_, ret.prob);
  }
  return ret;
}

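// Like FullScore, but the caller supplies a raw reversed context (truncated
// here to at most P::Order() - 1 words) instead of a State, so the backoff
// weights have to be recomputed with SlowBackoffLookup.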
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const {
  unsigned char backoff_start;
  context_rend = std::min(context_rend, context_rbegin + P::Order() - 1);
  FullScoreReturn ret = ScoreExceptBackoff(context_rbegin, context_rend, new_word, backoff_start, out_state);
  ret.prob += SlowBackoffLookup(context_rbegin, context_rend, backoff_start);
  return ret;
}

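// Build the State that a query over this reversed context would produce,
// without scoring anything: record each context word's backoff weight and
// truncate the history at the first order whose n-gram is missing.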
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const {
  context_rend = std::min(context_rend, context_rbegin + P::Order() - 1);
  if (context_rend == context_rbegin || *context_rbegin == 0) {
    out_state.valid_length_ = 0;
    return;
  }
  float ignored_prob;
  typename Search::Node node;
  search_.LookupUnigram(*context_rbegin, ignored_prob, out_state.backoff_[0], node);
  float *backoff_out = out_state.backoff_ + 1;
  const WordIndex *i = context_rbegin + 1;
  for (; i < context_rend; ++i, ++backoff_out) {
    if (!search_.LookupMiddleNoProb(search_.middle[i - context_rbegin - 1], *i, *backoff_out, node)) {
      out_state.valid_length_ = i - context_rbegin;
      std::copy(context_rbegin, i, out_state.history_);
      return;
    }
  }
  std::copy(context_rbegin, context_rend, out_state.history_);
  out_state.valid_length_ = static_cast<unsigned char>(context_rend - context_rbegin);
}

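// Recompute the backoff weights that FullScore would normally read straight
// out of the previous State.  Slow because it redoes the middle lookups.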
template <class Search, class VocabularyT> float GenericModel<Search, VocabularyT>::SlowBackoffLookup(
    const WordIndex *const context_rbegin, const WordIndex *const context_rend, unsigned char start) const {
  // Add the backoff weights for n-grams of order start to (context_rend - context_rbegin).
  if (context_rend - context_rbegin < static_cast<std::ptrdiff_t>(start)) return 0.0;
  float ret = 0.0;
  if (start == 1) {
    ret += search_.unigram.Lookup(*context_rbegin).backoff;
    start = 2;
  }
  typename Search::Node node;
  if (!search_.FastMakeNode(context_rbegin, context_rbegin + start - 1, node)) {
    return 0.0;
  }
  float backoff;
  // At each step, the backoff added is that of the (i - context_rbegin + 1)-gram
  // whose last word is *i.
  for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i) {
    if (!search_.LookupMiddleNoProb(search_.middle[i - context_rbegin - 1], *i, backoff, node)) break;
    ret += backoff;
  }
  return ret;
}

/* Ugly optimized function.  Produce a score excluding backoff.
 * The search goes in increasing order of n-gram length.
 * The context is reversed, so context_rbegin points at the word immediately
 * preceding new_word.
 */
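/* On return, backoff_start is the lowest context order whose backoff weight
 * still applies: FullScore sums in_state.backoff_ from index
 * backoff_start - 1 up to valid_length_ - 1.  Setting it to P::Order()
 * means no backoff applies at all.
 */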
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::ScoreExceptBackoff(
    const WordIndex *context_rbegin,
    const WordIndex *context_rend,
    const WordIndex new_word,
    unsigned char &backoff_start,
    State &out_state) const {
  FullScoreReturn ret;
  typename Search::Node node;
  float *backoff_out(out_state.backoff_);
  search_.LookupUnigram(new_word, ret.prob, *backoff_out, node);
  if (new_word == 0) {
    ret.ngram_length = out_state.valid_length_ = 0;
    // All of the backoff applies: the unknown word provides no history.
    backoff_start = 1;
    return ret;
  }
  out_state.history_[0] = new_word;
  if (context_rbegin == context_rend) {
    ret.ngram_length = out_state.valid_length_ = 1;
    // No backoff because we don't have the history for it.
    backoff_start = P::Order();
    return ret;
  }
  ++backoff_out;

  // Ok, now we know that the bigram contains known words.  Start by looking it up.
  const WordIndex *hist_iter = context_rbegin;
  typename std::vector<Middle>::const_iterator mid_iter = search_.middle.begin();
  for (; ; ++mid_iter, ++hist_iter, ++backoff_out) {
    if (hist_iter == context_rend) {
      // Ran out of history.  No backoff.
      backoff_start = P::Order();
      std::copy(context_rbegin, context_rend, out_state.history_ + 1);
      ret.ngram_length = out_state.valid_length_ = (context_rend - context_rbegin) + 1;
      // ret.prob was already set.
      return ret;
    }
    if (mid_iter == search_.middle.end()) break;
    if (!search_.LookupMiddle(*mid_iter, *hist_iter, ret.prob, *backoff_out, node)) {
      // Didn't find an n-gram using hist_iter.
      // The history used in the found n-gram is [context_rbegin, hist_iter).
      std::copy(context_rbegin, hist_iter, out_state.history_ + 1);
      // Therefore, we found a (hist_iter - context_rbegin + 1)-gram including the last word.
      ret.ngram_length = out_state.valid_length_ = (hist_iter - context_rbegin) + 1;
      backoff_start = mid_iter - search_.middle.begin() + 1;
      // ret.prob was already set.
      return ret;
    }
  }

  // It passed every lookup in search_.middle, so it's at least a
  // (P::Order() - 1)-gram.  All that's left is to check search_.longest.
  if (!search_.LookupLongest(*hist_iter, ret.prob, node)) {
    // It's a (P::Order() - 1)-gram.
    std::copy(context_rbegin, context_rbegin + P::Order() - 2, out_state.history_ + 1);
    ret.ngram_length = out_state.valid_length_ = P::Order() - 1;
    backoff_start = P::Order() - 1;
    // ret.prob was already set.
    return ret;
  }

  // It's a P::Order()-gram.  out_state.valid_length_ is still P::Order() - 1
  // because the next lookup will only need that much history.
  std::copy(context_rbegin, context_rbegin + P::Order() - 2, out_state.history_ + 1);
  out_state.valid_length_ = P::Order() - 1;
  ret.ngram_length = P::Order();
  backoff_start = P::Order();
  return ret;
}

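// Explicit instantiations for the three search strategies; these back the
// model typedefs (e.g. ProbingModel, TrieModel) that model.hh is expected to
// expose.  A minimal usage sketch, assuming the default Model typedef from
// model.hh (the ARPA filename is illustrative):
//
//   lm::ngram::Model model("file.arpa");
//   lm::ngram::State state(model.BeginSentenceState()), out_state;
//   float score = model.FullScore(
//       state, model.GetVocabulary().Index("the"), out_state).prob;
//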
template class GenericModel<ProbingHashedSearch, ProbingVocabulary>;
template class GenericModel<SortedHashedSearch, SortedVocabulary>;
template class GenericModel<trie::TrieSearch, SortedVocabulary>;

} // namespace detail
} // namespace ngram
} // namespace lm