mosesdecoder/lm/search_trie.hh

134 lines
4.7 KiB
C++

#ifndef LM_SEARCH_TRIE__
#define LM_SEARCH_TRIE__
#include "lm/config.hh"
#include "lm/model_type.hh"
#include "lm/return.hh"
#include "lm/trie.hh"
#include "lm/weights.hh"
#include "util/file.hh"
#include "util/file_piece.hh"

#include <vector>
#include <cstdlib>
#include <assert.h>
namespace lm {
namespace ngram {
struct Backing;
class SortedVocabulary;
namespace trie {
// Forward declarations.  TrieSearch is defined further down in this header;
// SortedFiles is only passed by reference here, so an incomplete type suffices.
template <class Quant, class Bhiksha> class TrieSearch;
class SortedFiles;
// Declared ahead of TrieSearch so that the class can name the matching
// BuildTrie<Quant, Bhiksha> specialization as a friend.  Only the declaration
// appears in the visible part of this header.
template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing);
// Search structure over a unigram table, an array of "middle" tables (one per
// intermediate n-gram order) and a "longest" table for the highest order.
//
// Quant and Bhiksha are policy classes.  This header relies only on the
// members it uses below: Quant::Middle, Quant::Longest, Quant::Size,
// Quant::MiddleBits, Quant::LongestBits, Quant::UpdateConfigFromBinary,
// Quant::kModelTypeAdd, Bhiksha::UpdateConfigFromBinary and
// Bhiksha::kModelTypeAdd.
//
// NOTE(review): the destructor releases middle_begin_, but the implicitly
// generated copy constructor and copy assignment copy that pointer verbatim.
// Copying an instance whose middles have been allocated would therefore end in
// a double free.  Consider making the class non-copyable.
template <class Quant, class Bhiksha> class TrieSearch {
  public:
    typedef NodeRange Node;

    typedef ::lm::ngram::trie::Unigram Unigram;
    Unigram unigram;

    typedef trie::BitPackedMiddle<typename Quant::Middle, Bhiksha> Middle;

    typedef trie::BitPackedLongest<typename Quant::Longest> Longest;
    Longest longest;

    // Model type tag: TRIE_SORTED offset by the additions contributed by the
    // two policy classes.
    static const ModelType kModelType = static_cast<ModelType>(TRIE_SORTED + Quant::kModelTypeAdd + Bhiksha::kModelTypeAdd);

    static const unsigned int kVersion = 1;

    // Update config from a binary file open on fd.  Quant goes first, then fd
    // is advanced past the Quant and unigram sections (sized from counts),
    // and finally Bhiksha updates config from that position.
    // counts must be non-empty: counts[0] is read.
    static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) {
      Quant::UpdateConfigFromBinary(fd, counts, config);
      util::AdvanceOrThrow(fd, Quant::Size(counts.size(), config) + Unigram::Size(counts[0]));
      Bhiksha::UpdateConfigFromBinary(fd, config);
    }

    // Sum of the sizes reported by Quant, the unigram table, every middle
    // table (orders 2 .. counts.size() - 1) and the longest table.
    // counts[i] is the entry count handed to the table for order i + 1;
    // counts must be non-empty.
    static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) {
      std::size_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]);
      // A full-width index cannot wrap, and "order + 1 < size" avoids the
      // unsigned underflow that "size - 1" has for an empty vector.
      for (std::size_t order = 1; order + 1 < counts.size(); ++order) {
        ret += Middle::Size(Quant::MiddleBits(config), counts[order], counts[0], counts[order + 1], config);
      }
      return ret + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]);
    }

    // Starts with no middle tables allocated.
    TrieSearch() : middle_begin_(NULL), middle_end_(NULL) {}

    ~TrieSearch() { FreeMiddles(); }

    // Defined outside the visible part of this header.
    uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config);

    // Defined outside the visible part of this header.
    void LoadedBinary();

    typedef const Middle *MiddleIter;

    // [MiddleBegin(), MiddleEnd()) is the range of middle tables.
    const Middle *MiddleBegin() const { return middle_begin_; }
    const Middle *MiddleEnd() const { return middle_end_; }

    // Defined outside the visible part of this header.
    void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing);

    // Look up a single word in the unigram table.  Fills ret.prob, backoff and
    // node from unigram.Find, sets ret.independent_left when the resulting
    // node range is empty, and stores the word index in ret.extend_left.
    void LookupUnigram(WordIndex word, float &backoff, Node &node, FullScoreReturn &ret) const {
      unigram.Find(word, ret.prob, backoff, node);
      ret.independent_left = (node.begin == node.end);
      ret.extend_left = static_cast<uint64_t>(word);
    }

    // Look up word in the middle table mid, continuing from node.
    // Returns false, leaving ret.independent_left untouched, when mid.Find
    // fails; otherwise sets ret.independent_left when the new node range is
    // empty and returns true.
    bool LookupMiddle(const Middle &mid, WordIndex word, float &backoff, Node &node, FullScoreReturn &ret) const {
      if (!mid.Find(word, ret.prob, backoff, node, ret.extend_left)) return false;
      ret.independent_left = (node.begin == node.end);
      return true;
    }

    // Same traversal as LookupMiddle but through Middle::FindNoProb, which
    // takes no probability output.
    bool LookupMiddleNoProb(const Middle &mid, WordIndex word, float &backoff, Node &node) const {
      return mid.FindNoProb(word, backoff, node);
    }

    // Look up word in the longest table below node; forwards to longest.Find.
    bool LookupLongest(WordIndex word, float &prob, const Node &node) const {
      return longest.Find(word, prob, node);
    }

    // Build node for the word sequence [begin, end): the first word goes
    // through the unigram table and each following word through the next
    // middle table.  Returns false as soon as a middle lookup fails.
    // Requires begin != end.  The caller must also keep the sequence short
    // enough that every index used stays inside the middle array, which is
    // not checked here.
    bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {
      // TODO: don't decode backoff.
      assert(begin != end);
      FullScoreReturn ignored;
      float ignored_backoff;
      LookupUnigram(*begin, ignored_backoff, node, ignored);
      for (const WordIndex *i = begin + 1; i < end; ++i) {
        // The word at offset k from begin (k >= 1) uses middle table k - 1.
        if (!LookupMiddleNoProb(middle_begin_[i - begin - 1], *i, ignored_backoff, node)) return false;
      }
      return true;
    }

    // Recover the node and probability for an entry identified by
    // extend_pointer at n-gram length extend_length.  Length 1 treats the
    // pointer as a word index into the unigram table; any other length reads
    // the entry from middle table extend_length - 2.
    // Callers are assumed to pass extend_length >= 1; 0 would index before the
    // start of the middle array.
    Node Unpack(uint64_t extend_pointer, unsigned char extend_length, float &prob) const {
      if (extend_length == 1) {
        float ignored;
        Node ret;
        unigram.Find(static_cast<WordIndex>(extend_pointer), prob, ignored, ret);
        return ret;
      }
      return middle_begin_[extend_length - 2].ReadEntry(extend_pointer, prob);
    }

  private:
    friend void BuildTrie<Quant, Bhiksha>(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing);

    // Middles are managed manually so we can delay construction and they don't have to be copyable.
    // Destroys every Middle in [middle_begin_, middle_end_) and releases the
    // array with free().
    void FreeMiddles() {
      for (const Middle *i = middle_begin_; i != middle_end_; ++i) {
        i->~Middle();
      }
      std::free(middle_begin_);
      // Clear both pointers so that calling this again before new storage is
      // installed, or destroying the object afterwards, cannot destroy or
      // free the same memory twice.
      middle_begin_ = NULL;
      middle_end_ = NULL;
    }

    Middle *middle_begin_, *middle_end_;
    Quant quant_;
};
} // namespace trie
} // namespace ngram
} // namespace lm
#endif // LM_SEARCH_TRIE__