mosesdecoder/lm/search_hashed.hh
2011-11-17 12:49:55 +00:00

180 lines
5.7 KiB
C++

#ifndef LM_SEARCH_HASHED__
#define LM_SEARCH_HASHED__
#include "lm/model_type.hh"
#include "lm/config.hh"
#include "lm/read_arpa.hh"
#include "lm/return.hh"
#include "lm/weights.hh"
#include "util/bit_packing.hh"
#include "util/key_value_packing.hh"
#include "util/probing_hash_table.hh"
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <vector>
namespace util { class FilePiece; }
namespace lm {
namespace ngram {
struct Backing;
namespace detail {
// Extends a running n-gram hash by one word and returns the new hash.
// The running hash and the word id are each multiplied by their own fixed
// 64-bit constant and the two products are XORed together.  The word id is
// offset by one (in WordIndex arithmetic) before it is widened.
inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) {
  const uint64_t history_part = current * 8978948897894561157ULL;
  const uint64_t word_part = static_cast<uint64_t>(1 + next) * 17894857484156487943ULL;
  return history_part ^ word_part;
}
struct HashedSearch {
typedef uint64_t Node;
class Unigram {
public:
Unigram() {}
Unigram(void *start, std::size_t /*allocated*/) : unigram_(static_cast<ProbBackoff*>(start)) {}
static std::size_t Size(uint64_t count) {
return (count + 1) * sizeof(ProbBackoff); // +1 for hallucinate <unk>
}
const ProbBackoff &Lookup(WordIndex index) const { return unigram_[index]; }
ProbBackoff &Unknown() { return unigram_[0]; }
void LoadedBinary() {}
// For building.
ProbBackoff *Raw() { return unigram_; }
private:
ProbBackoff *unigram_;
};
Unigram unigram;
void LookupUnigram(WordIndex word, float &backoff, Node &next, FullScoreReturn &ret) const {
const ProbBackoff &entry = unigram.Lookup(word);
util::FloatEnc val;
val.f = entry.prob;
ret.independent_left = (val.i & util::kSignBit);
ret.extend_left = static_cast<uint64_t>(word);
val.i |= util::kSignBit;
ret.prob = val.f;
backoff = entry.backoff;
next = static_cast<Node>(word);
}
};
// Hashed search over n-grams.  Unigrams are served by HashedSearch; every
// higher order lives in a table keyed by a 64-bit node built with
// CombineWordHash.  MiddleT is the table type used for the middle orders and
// LongestT the table type used for the highest order.
template <class MiddleT, class LongestT> class TemplateHashedSearch : public HashedSearch {
  public:
    typedef MiddleT Middle;

    typedef LongestT Longest;
    Longest longest;

    static const unsigned int kVersion = 0;

    // TODO: move probing_multiplier here with next binary file format update.
    static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {}

    // Total bytes needed by the unigram array and all tables.  counts[0]
    // sizes the unigram array, the last entry sizes the longest table and
    // every entry in between sizes one middle table.  counts must not be
    // empty.
    static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) {
      std::size_t ret = Unigram::Size(counts[0]);
      // A std::size_t index compared as n + 1 < size() cannot wrap around the
      // way a narrow counter compared against size() - 1 could.
      for (std::size_t n = 1; n + 1 < counts.size(); ++n) {
        ret += Middle::Size(counts[n], config.probing_multiplier);
      }
      return ret + Longest::Size(counts.back(), config.probing_multiplier);
    }

    uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config);

    template <class Voc> void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab, Backing &backing);

    // [MiddleBegin(), MiddleEnd()) is the contiguous range of middle tables.
    // Neither function dereferences an end iterator or touches an empty
    // vector; when there are no middle tables both return NULL, so the range
    // is still empty.
    const Middle *MiddleBegin() const {
      return middle_.empty() ? NULL : &middle_.front();
    }

    const Middle *MiddleEnd() const {
      return middle_.empty() ? NULL : &middle_.front() + middle_.size();
    }

    // Recovers the probability of the n-gram identified by extend_pointer,
    // whose order is extend_length (1 means unigram, where the pointer is the
    // word id).  prob receives the stored probability with util::kSignBit
    // forced on.  Returns extend_pointer as the node.  Aborts the process if
    // a middle-order pointer is not present in its table.
    Node Unpack(uint64_t extend_pointer, unsigned char extend_length, float &prob) const {
      util::FloatEnc val;
      if (extend_length == 1) {
        val.f = unigram.Lookup(static_cast<WordIndex>(extend_pointer)).prob;
      } else {
        typename Middle::ConstIterator found;
        // Order k is stored in middle_[k - 2].
        if (!middle_[extend_length - 2].Find(extend_pointer, found)) {
          // std::endl flushes the message before the process is aborted.
          std::cerr << "Extend pointer " << extend_pointer << " should have been found for length " << static_cast<unsigned>(extend_length) << std::endl;
          std::abort();
        }
        val.f = found->GetValue().prob;
      }
      val.i |= util::kSignBit;
      prob = val.f;
      return extend_pointer;
    }

    // Extends node by word and looks the result up in middle.  node is
    // updated even when the lookup fails.  On success fills backoff and ret
    // (independent_left from the util::kSignBit bit of the stored
    // probability, extend_left = node, prob with that bit forced on) and
    // returns true; otherwise returns false.
    bool LookupMiddle(const Middle &middle, WordIndex word, float &backoff, Node &node, FullScoreReturn &ret) const {
      node = CombineWordHash(node, word);
      typename Middle::ConstIterator found;
      if (!middle.Find(node, found)) return false;
      util::FloatEnc enc;
      enc.f = found->GetValue().prob;
      ret.independent_left = (enc.i & util::kSignBit);
      ret.extend_left = node;
      enc.i |= util::kSignBit;
      ret.prob = enc.f;
      backoff = found->GetValue().backoff;
      return true;
    }

    void LoadedBinary();

    // Same as LookupMiddle but only retrieves the backoff.
    bool LookupMiddleNoProb(const Middle &middle, WordIndex word, float &backoff, Node &node) const {
      node = CombineWordHash(node, word);
      typename Middle::ConstIterator found;
      if (!middle.Find(node, found)) return false;
      backoff = found->GetValue().backoff;
      return true;
    }

    // Extends node by word and looks the result up in the longest table.
    // On success stores the probability in prob and returns true.
    bool LookupLongest(WordIndex word, float &prob, Node &node) const {
      // Sign bit is always on because longest n-grams do not extend left.
      node = CombineWordHash(node, word);
      typename Longest::ConstIterator found;
      if (!longest.Find(node, found)) return false;
      prob = found->GetValue().prob;
      return true;
    }

    // Generate a node without necessarily checking that it actually exists.
    // Optionally return false if it's known to not exist.  This
    // implementation performs no lookup and always returns true.
    // [begin, end) must contain at least one word.
    bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {
      assert(begin != end);
      node = static_cast<Node>(*begin);
      for (const WordIndex *i = begin + 1; i < end; ++i) {
        node = CombineWordHash(node, *i);
      }
      return true;
    }

  private:
    std::vector<Middle> middle_;
};
// std::identity is an SGI extension :-(
// Hash functor that returns its 64-bit argument unchanged (converted to
// size_t).  The nodes used as table keys in this file are already produced by
// CombineWordHash.
//
// The adaptable-functor typedefs are declared directly rather than inherited
// from std::unary_function, which is deprecated in C++11 and removed in
// C++17.
struct IdentityHash {
  typedef uint64_t argument_type;
  typedef size_t result_type;

  size_t operator()(uint64_t arg) const { return static_cast<size_t>(arg); }
};
// Concrete hashed search built on util::ProbingHashTable with IdentityHash.
// The first template argument (the Middle table) stores uint64_t keys with
// ProbBackoff values; the second (the Longest table) stores uint64_t keys
// with Prob values.
struct ProbingHashedSearch : public TemplateHashedSearch<
util::ProbingHashTable<util::ByteAlignedPacking<uint64_t, ProbBackoff>, IdentityHash>,
util::ProbingHashTable<util::ByteAlignedPacking<uint64_t, Prob>, IdentityHash> > {
// ModelType value associated with this search implementation.
static const ModelType kModelType = HASH_PROBING;
};
} // namespace detail
} // namespace ngram
} // namespace lm
#endif // LM_SEARCH_HASHED__