mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-12-27 22:14:57 +03:00
309 lines
14 KiB
C++
309 lines
14 KiB
C++
#include "lm/model.hh"
|
|
|
|
#include "lm/blank.hh"
|
|
#include "lm/lm_exception.hh"
|
|
#include "lm/search_hashed.hh"
|
|
#include "lm/search_trie.hh"
|
|
#include "lm/read_arpa.hh"
|
|
#include "util/have.hh"
|
|
#include "util/murmur_hash.hh"
|
|
|
|
#include <algorithm>
|
|
#include <functional>
|
|
#include <numeric>
|
|
#include <cmath>
|
|
#include <limits>
|
|
|
|
namespace lm {
|
|
namespace ngram {
|
|
namespace detail {
|
|
|
|
// The model-type tag of a GenericModel is taken directly from its Search
// strategy (the instantiations at the bottom of this file pair hashed and trie
// searches with different vocabularies).
template <class Search, class VocabularyT> const ModelType GenericModel<Search, VocabularyT>::kModelType = Search::kModelType;
|
|
|
|
// Bytes required for the whole model: the vocabulary table (sized from the
// unigram count, counts[0]) followed by the search structure for all orders.
template <class Search, class VocabularyT> uint64_t GenericModel<Search, VocabularyT>::Size(const std::vector<uint64_t> &counts, const Config &config) {
  const uint64_t vocab_bytes = VocabularyT::Size(counts[0], config);
  const uint64_t search_bytes = Search::Size(counts, config);
  return vocab_bytes + search_bytes;
}
|
|
|
|
// Lay out the vocabulary and search structures back to back in the memory at
// `base`. Throws FormatLoadException if the bytes actually consumed differ from
// what Size() predicted for the same counts/config.
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::SetupMemory(void *base, const std::vector<uint64_t> &counts, const Config &config) {
  const size_t expected_bytes = util::CheckOverflow(Size(counts, config));
  uint8_t *const region_begin = static_cast<uint8_t*>(base);

  // The vocabulary lookup table occupies the front of the region.
  const size_t vocab_bytes = VocabularyT::Size(counts[0], config);
  vocab_.SetupMemory(region_begin, vocab_bytes, counts[0], config);

  // Search structures follow immediately; SetupMemory returns the first byte past them.
  uint8_t *const region_end = search_.SetupMemory(region_begin + vocab_bytes, counts, config);

  if (static_cast<std::size_t>(region_end - region_begin) != expected_bytes) {
    UTIL_THROW(FormatLoadException, "The data structures took " << (region_end - region_begin) << " but Size says they should take " << expected_bytes);
  }
}
|
|
|
|
// Load a model from `file` (via LoadLM), then hand the base class the two
// canonical starting states: the begin-of-sentence state, whose single word is
// <s> carrying its unigram backoff, and the empty null-context state.
template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &config) {
  LoadLM(file, config, *this);

  // g++ prints warnings unless these are fully initialized.
  State null_context = State();
  null_context.length = 0;

  State begin_sentence = State();
  const WordIndex bos = vocab_.BeginSentence();
  begin_sentence.length = 1;
  begin_sentence.words[0] = bos;

  // Only the backoff of <s> is wanted; the node and left-state outputs are discarded.
  typename Search::Node unused_node;
  bool unused_independent_left;
  uint64_t unused_extend_left;
  begin_sentence.backoff[0] = search_.LookupUnigram(bos, unused_node, unused_independent_left, unused_extend_left).Backoff();

  P::Init(begin_sentence, null_context, vocab_, search_.Order());
}
|
|
|
|
namespace {
|
|
void CheckCounts(const std::vector<uint64_t> &counts) {
|
|
UTIL_THROW_IF(counts.size() > KENLM_MAX_ORDER, FormatLoadException, "This model has order " << counts.size() << " but KenLM was compiled to support up to " << KENLM_MAX_ORDER << ". " << KENLM_ORDER_MESSAGE);
|
|
if (sizeof(uint64_t) > sizeof(std::size_t)) {
|
|
for (std::vector<uint64_t>::const_iterator i = counts.begin(); i != counts.end(); ++i) {
|
|
UTIL_THROW_IF(*i > static_cast<uint64_t>(std::numeric_limits<size_t>::max()), util::OverflowException, "This model has " << *i << " " << (i - counts.begin() + 1) << "-grams which is too many for 32-bit machines.");
|
|
}
|
|
}
|
|
}
|
|
} // namespace
|
|
|
|
// Attach this model to a binary image already present at `start`.
// Validates the stored counts, lays out the vocabulary and search structures
// inside the image, then lets each component finish loading.
//   params : header parameters read from the binary (counts, fixed flags).
//   config : load-time configuration (e.g. the enumerate_vocab callback).
//   fd     : passed to vocab_.LoadedBinary; NOTE(review): presumably used to
//            read the stored vocabulary words when has_vocabulary is set — confirm.
// Fix: the parameter was mis-encoded as "¶ms" (HTML &para entity), which does
// not compile; it is restored to "&params".
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromBinary(void *start, const Parameters &params, const Config &config, int fd) {
  CheckCounts(params.counts);
  SetupMemory(start, params.counts, config);
  vocab_.LoadedBinary(params.fixed.has_vocabulary, fd, config.enumerate_vocab);
  search_.LoadedBinary();
}
|
|
|
|
// Build the model by parsing an ARPA text file.
// Steps: read and validate the n-gram counts from the ARPA header, set up the
// vocabulary table in memory obtained from SetupJustVocab, let search_ load the
// n-grams, install a default <unk> unigram if the file did not contain one, and
// finish the output via FinishFile.
// Throws FormatLoadException for models below bigram order and ConfigException
// for probing_multiplier <= 1.0. Any util::Exception raised while loading gets
// " Byte: <offset in the ARPA file>" appended before it is rethrown.
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, const Config &config) {
  // Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any.
  util::FilePiece f(backing_.file.release(), file, config.ProgressMessages());
  try {
    std::vector<uint64_t> counts;
    // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_.
    ReadARPACounts(f, counts);
    CheckCounts(counts);
    if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model.");
    if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0");

    std::size_t vocab_size = util::CheckOverflow(VocabularyT::Size(counts[0], config));
    // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs.
    vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config);

    if (config.write_mmap) {
      // Wrap the user's enumerate callback so the vocabulary words are also
      // captured during loading; wrap.Write then emits them to the backing file.
      // NOTE(review): the second argument looks like the offset just past the
      // vocab, <unk> padding and search data — confirm against WriteWordsWrapper.
      WriteWordsWrapper wrap(config.enumerate_vocab);
      vocab_.ConfigureEnumerate(&wrap, counts[0]);
      search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
      wrap.Write(backing_.file.get(), backing_.vocab.size() + vocab_.UnkCountChangePadding() + Search::Size(counts, config));
    } else {
      vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]);
      search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
    }

    if (!vocab_.SawUnk()) {
      // A missing <unk> must not be configured to throw at this point;
      // presumably an earlier stage already threw in that case — TODO confirm.
      assert(config.unknown_missing != THROW_UP);
      // Default probabilities for unknown.
      search_.UnknownUnigram().backoff = 0.0;
      search_.UnknownUnigram().prob = config.unknown_missing_logprob;
    }
    FinishFile(config, kModelType, kVersion, counts, vocab_.UnkCountChangePadding(), backing_);
  } catch (util::Exception &e) {
    // Annotate with the position in the ARPA file to make format errors locatable.
    e << " Byte: " << f.Offset();
    throw;
  }
}
|
|
|
|
// Called while reading a binary file positioned at the start of the vocabulary:
// skip past the vocabulary table, then let the Search type update `config` from
// the data that follows.
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) {
  const uint64_t vocab_bytes = VocabularyT::Size(counts[0], config);
  util::AdvanceOrThrow(fd, vocab_bytes);
  Search::UpdateConfigFromBinary(fd, counts, config);
}
|
|
|
|
// Score new_word after the context in in_state, filling out_state.
// ScoreExceptBackoff finds the longest matching n-gram; the backoffs of the
// context n-grams of order ret.ngram_length .. in_state.length are then added,
// in order, to the returned log probability.
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const {
  FullScoreReturn ret = ScoreExceptBackoff(in_state.words, in_state.words + in_state.length, new_word, out_state);
  const std::ptrdiff_t context_length = in_state.length;
  for (std::ptrdiff_t k = static_cast<std::ptrdiff_t>(ret.ngram_length) - 1; k < context_length; ++k) {
    ret.prob += in_state.backoff[k];
  }
  return ret;
}
|
|
|
|
// Like FullScore, but the caller supplies the context as a reversed word array
// [context_rbegin, context_rend) (context_rbegin is the most recent word)
// instead of a State. The context is truncated to Order() - 1 words.
// After scoring, the backoffs of context n-grams of order ret.ngram_length up
// to the context length are added, stopping at the first such n-gram that is
// not found.
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const {
  context_rend = std::min(context_rend, context_rbegin + P::Order() - 1);
  FullScoreReturn ret = ScoreExceptBackoff(context_rbegin, context_rend, new_word, out_state);

  // Add the backoff weights for n-grams of order start to (context_rend - context_rbegin).
  unsigned char start = ret.ngram_length;
  // Context too short to have any backoff to charge.
  if (context_rend - context_rbegin < static_cast<std::ptrdiff_t>(start)) return ret;

  bool independent_left;
  uint64_t extend_left;
  typename Search::Node node;
  if (start <= 1) {
    // Order-1 backoff comes from the unigram table; continue from order 2.
    ret.prob += search_.LookupUnigram(*context_rbegin, node, independent_left, extend_left).Backoff();
    start = 2;
  } else if (!search_.FastMakeNode(context_rbegin, context_rbegin + start - 1, node)) {
    // The (start - 1)-word context prefix could not be positioned; no backoffs are charged.
    return ret;
  }
  // i is the order of the backoff we're looking for.
  unsigned char order_minus_2 = start - 2;
  for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i, ++order_minus_2) {
    typename Search::MiddlePointer p(search_.LookupMiddle(order_minus_2, *i, node, independent_left, extend_left));
    if (!p.Found()) break;
    ret.prob += p.Backoff();
  }
  return ret;
}
|
|
|
|
// Build out_state from a reversed context [context_rbegin, context_rend)
// (context_rbegin is the most recent word), truncated to Order() - 1 words.
// Backoffs are looked up for successively longer context n-grams and written
// to out_state.backoff; the walk stops at the first n-gram not found.
// out_state.length ends up as the length of the longest n-gram found whose
// backoff HasExtension() (0 if none), and that many context words are copied
// into out_state.words.
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const {
  // Generate a state from context.
  context_rend = std::min(context_rend, context_rbegin + P::Order() - 1);
  if (context_rend == context_rbegin) {
    out_state.length = 0;
    return;
  }
  typename Search::Node node;
  bool independent_left;
  uint64_t extend_left;
  out_state.backoff[0] = search_.LookupUnigram(*context_rbegin, node, independent_left, extend_left).Backoff();
  out_state.length = HasExtension(out_state.backoff[0]) ? 1 : 0;
  float *backoff_out = out_state.backoff + 1;
  unsigned char order_minus_2 = 0;
  for (const WordIndex *i = context_rbegin + 1; i < context_rend; ++i, ++backoff_out, ++order_minus_2) {
    typename Search::MiddlePointer p(search_.LookupMiddle(order_minus_2, *i, node, independent_left, extend_left));
    if (!p.Found()) {
      // Longer n-grams cannot exist either; keep the length found so far.
      std::copy(context_rbegin, context_rbegin + out_state.length, out_state.words);
      return;
    }
    *backoff_out = p.Backoff();
    if (HasExtension(*backoff_out)) out_state.length = i - context_rbegin + 1;
  }
  std::copy(context_rbegin, context_rbegin + out_state.length, out_state.words);
}
|
|
|
|
// Re-score an already matched n-gram after more words become known on its left.
//   extend_pointer/extend_length identify the n-gram matched previously: a
//     WordIndex when extend_length == 1, otherwise a value decoded by
//     search_.Unpack.
//   [add_rbegin, add_rend) are the additional left words, nearest first.
//   backoff_in are backoffs charged for the additional words; those not
//     covered by the longer match are added to the result.
//   backoff_out receives backoffs of newly matched n-grams; next_use is
//     returned relative to extend_length.
// The returned prob and rest are deltas: the previous n-gram's rest is
// subtracted from both.
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::ExtendLeft(
    const WordIndex *add_rbegin, const WordIndex *add_rend,
    const float *backoff_in,
    uint64_t extend_pointer,
    unsigned char extend_length,
    float *backoff_out,
    unsigned char &next_use) const {
  FullScoreReturn ret;
  typename Search::Node node;
  if (extend_length == 1) {
    typename Search::UnigramPointer ptr(search_.LookupUnigram(static_cast<WordIndex>(extend_pointer), node, ret.independent_left, ret.extend_left));
    ret.rest = ptr.Rest();
    ret.prob = ptr.Prob();
    assert(!ret.independent_left);
  } else {
    typename Search::MiddlePointer ptr(search_.Unpack(extend_pointer, extend_length, node));
    ret.rest = ptr.Rest();
    ret.prob = ptr.Prob();
    ret.extend_left = extend_pointer;
    // If this function is called, then it does depend on left words.
    ret.independent_left = false;
  }
  // The rest score that was already charged for the shorter n-gram.
  float subtract_me = ret.rest;
  ret.ngram_length = extend_length;
  next_use = extend_length;
  ResumeScore(add_rbegin, add_rend, extend_length - 1, node, backoff_out, next_use, ret);
  next_use -= extend_length;
  // Charge backoffs.
  for (const float *b = backoff_in + ret.ngram_length - extend_length; b < backoff_in + (add_rend - add_rbegin); ++b) ret.prob += *b;
  ret.prob -= subtract_me;
  ret.rest -= subtract_me;
  return ret;
}
|
|
|
|
namespace {
|
|
// Do a paraonoid copy of history, assuming new_word has already been copied
|
|
// (hence the -1). out_state.length could be zero so I avoided using
|
|
// std::copy.
|
|
void CopyRemainingHistory(const WordIndex *from, State &out_state) {
|
|
WordIndex *out = out_state.words + 1;
|
|
const WordIndex *in_end = from + static_cast<ptrdiff_t>(out_state.length) - 1;
|
|
for (const WordIndex *in = from; in < in_end; ++in, ++out) *out = *in;
|
|
}
|
|
} // namespace
|
|
|
|
/* Ugly optimized function. Produce a score excluding backoff.
 * The search goes in increasing order of ngram length.
 * Context goes backward, so context_rbegin is the word immediately preceding
 * new_word.
 *
 * Fills out_state: words[0] = new_word, backoff[...] for matched n-grams, and
 * length = length of the longest matched n-gram whose backoff HasExtension()
 * (0 if none). The remaining history words are copied after new_word.
 * Returns prob/rest of the longest matched n-gram and its length in
 * ret.ngram_length; backoffs are NOT added (callers such as FullScore do that).
 */
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::ScoreExceptBackoff(
    const WordIndex *const context_rbegin,
    const WordIndex *const context_rend,
    const WordIndex new_word,
    State &out_state) const {
  assert(new_word < vocab_.Bound());
  FullScoreReturn ret;
  // ret.ngram_length contains the last known non-blank ngram length.
  ret.ngram_length = 1;

  typename Search::Node node;
  typename Search::UnigramPointer uni(search_.LookupUnigram(new_word, node, ret.independent_left, ret.extend_left));
  out_state.backoff[0] = uni.Backoff();
  ret.prob = uni.Prob();
  ret.rest = uni.Rest();

  // This is the length of the context that should be used for continuation to the right.
  out_state.length = HasExtension(out_state.backoff[0]) ? 1 : 0;
  // We'll write the word anyway since it will probably be used and does no harm being there.
  out_state.words[0] = new_word;
  if (context_rbegin == context_rend) return ret;

  // Extend the match leftwards through the context, starting at bigrams (order_minus_2 == 0).
  ResumeScore(context_rbegin, context_rend, 0, node, out_state.backoff + 1, out_state.length, ret);
  CopyRemainingHistory(context_rbegin, out_state);
  return ret;
}
|
|
|
|
// Continue matching ever-longer n-grams, one history word at a time, starting
// at order (order_minus_2 + 2) with the search positioned at `node`.
// For each middle-order n-gram found: store its backoff in *backoff_out,
// update ret.prob/rest/ngram_length, and set next_use to its length when the
// backoff HasExtension().
// Returns early at the end of the history, when ret.independent_left is set,
// or at the first n-gram not found. If the highest order (Order()) is reached,
// ret.independent_left is set and the longest-order table is consulted; a hit
// overrides prob (and sets rest equal to it) with ngram_length = Order().
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::ResumeScore(const WordIndex *hist_iter, const WordIndex *const context_rend, unsigned char order_minus_2, typename Search::Node &node, float *backoff_out, unsigned char &next_use, FullScoreReturn &ret) const {
  for (; ; ++order_minus_2, ++hist_iter, ++backoff_out) {
    if (hist_iter == context_rend) return;
    if (ret.independent_left) return;
    // The next lookup would be of the highest order, which uses LookupLongest below.
    if (order_minus_2 == P::Order() - 2) break;

    typename Search::MiddlePointer pointer(search_.LookupMiddle(order_minus_2, *hist_iter, node, ret.independent_left, ret.extend_left));
    if (!pointer.Found()) return;
    *backoff_out = pointer.Backoff();
    ret.prob = pointer.Prob();
    ret.rest = pointer.Rest();
    ret.ngram_length = order_minus_2 + 2;
    if (HasExtension(*backoff_out)) {
      next_use = ret.ngram_length;
    }
  }
  ret.independent_left = true;
  typename Search::LongestPointer longest(search_.LookupLongest(*hist_iter, node));
  if (longest.Found()) {
    ret.prob = longest.Prob();
    ret.rest = ret.prob;
    // There is no blank in longest_.
    ret.ngram_length = P::Order();
  }
}
|
|
|
|
// Sum of (Prob - Rest) over a chain of n-gram pointers of consecutive orders,
// starting at order first_length. When first_length == 1 the first pointer is
// a WordIndex looked up in the unigram table; an empty range then yields 0.
// Later pointers are decoded with search_.Unpack at increasing orders.
template <class Search, class VocabularyT> float GenericModel<Search, VocabularyT>::InternalUnRest(const uint64_t *pointers_begin, const uint64_t *pointers_end, unsigned char first_length) const {
  typename Search::Node node;
  float total = 0.0;
  if (first_length == 1) {
    if (pointers_begin >= pointers_end) return 0.0;
    bool independent_left;
    uint64_t extend_left;
    typename Search::UnigramPointer uni(search_.LookupUnigram(static_cast<WordIndex>(*pointers_begin), node, independent_left, extend_left));
    total = uni.Prob() - uni.Rest();
    ++first_length;
    ++pointers_begin;
  }
  for (; pointers_begin < pointers_end; ++pointers_begin, ++first_length) {
    typename Search::MiddlePointer mid(search_.Unpack(*pointers_begin, first_length, node));
    total += mid.Prob() - mid.Rest();
  }
  return total;
}
|
|
|
|
// Explicit instantiations: the member templates of GenericModel are defined in
// this translation unit, so each Search/Vocabulary pairing used elsewhere must
// be instantiated here. Hashed search pairs with ProbingVocabulary; the trie
// variants (with or without quantization and Bhiksha pointer compression)
// pair with SortedVocabulary.
template class GenericModel<HashedSearch<BackoffValue>, ProbingVocabulary>;
template class GenericModel<HashedSearch<RestValue>, ProbingVocabulary>;
template class GenericModel<trie::TrieSearch<DontQuantize, trie::DontBhiksha>, SortedVocabulary>;
template class GenericModel<trie::TrieSearch<DontQuantize, trie::ArrayBhiksha>, SortedVocabulary>;
template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::DontBhiksha>, SortedVocabulary>;
template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::ArrayBhiksha>, SortedVocabulary>;
|
|
|
|
} // namespace detail
|
|
} // namespace ngram
|
|
} // namespace lm
|