mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-12-28 14:32:38 +03:00
180 lines
5.7 KiB
C++
180 lines
5.7 KiB
C++
#ifndef LM_SEARCH_HASHED__
|
|
#define LM_SEARCH_HASHED__
|
|
|
|
#include "lm/model_type.hh"
#include "lm/config.hh"
#include "lm/read_arpa.hh"
#include "lm/return.hh"
#include "lm/weights.hh"

#include "util/bit_packing.hh"
#include "util/key_value_packing.hh"
#include "util/probing_hash_table.hh"

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <functional>
#include <iostream>
#include <vector>

#include <stdint.h>
|
|
|
|
namespace util { class FilePiece; }
|
|
|
|
namespace lm {
|
|
namespace ngram {
|
|
struct Backing;
|
|
namespace detail {
|
|
|
|
// Folds one more word into a running n-gram hash.
// The running state is scaled by one odd 64-bit constant, the word term
// (next + 1) by another, and the two products are XORed. All arithmetic is
// unsigned 64-bit and therefore wraps modulo 2^64.
// The +1 presumably keeps word index 0 from contributing a zero term.
// NOTE(review): `1 + next` is evaluated in WordIndex's own type before the
// widening cast, so it would wrap for the maximum representable index.
inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) {
  const uint64_t kStateMultiplier = 8978948897894561157ULL;
  const uint64_t kWordMultiplier = 17894857484156487943ULL;
  const uint64_t scaled_state = current * kStateMultiplier;
  const uint64_t scaled_word = static_cast<uint64_t>(1 + next) * kWordMultiplier;
  return scaled_state ^ scaled_word;
}
|
|
|
|
struct HashedSearch {
|
|
typedef uint64_t Node;
|
|
|
|
class Unigram {
|
|
public:
|
|
Unigram() {}
|
|
|
|
Unigram(void *start, std::size_t /*allocated*/) : unigram_(static_cast<ProbBackoff*>(start)) {}
|
|
|
|
static std::size_t Size(uint64_t count) {
|
|
return (count + 1) * sizeof(ProbBackoff); // +1 for hallucinate <unk>
|
|
}
|
|
|
|
const ProbBackoff &Lookup(WordIndex index) const { return unigram_[index]; }
|
|
|
|
ProbBackoff &Unknown() { return unigram_[0]; }
|
|
|
|
void LoadedBinary() {}
|
|
|
|
// For building.
|
|
ProbBackoff *Raw() { return unigram_; }
|
|
|
|
private:
|
|
ProbBackoff *unigram_;
|
|
};
|
|
|
|
Unigram unigram;
|
|
|
|
void LookupUnigram(WordIndex word, float &backoff, Node &next, FullScoreReturn &ret) const {
|
|
const ProbBackoff &entry = unigram.Lookup(word);
|
|
util::FloatEnc val;
|
|
val.f = entry.prob;
|
|
ret.independent_left = (val.i & util::kSignBit);
|
|
ret.extend_left = static_cast<uint64_t>(word);
|
|
val.i |= util::kSignBit;
|
|
ret.prob = val.f;
|
|
backoff = entry.backoff;
|
|
next = static_cast<Node>(word);
|
|
}
|
|
};
|
|
|
|
template <class MiddleT, class LongestT> class TemplateHashedSearch : public HashedSearch {
|
|
public:
|
|
typedef MiddleT Middle;
|
|
|
|
typedef LongestT Longest;
|
|
Longest longest;
|
|
|
|
static const unsigned int kVersion = 0;
|
|
|
|
// TODO: move probing_multiplier here with next binary file format update.
|
|
static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {}
|
|
|
|
static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) {
|
|
std::size_t ret = Unigram::Size(counts[0]);
|
|
for (unsigned char n = 1; n < counts.size() - 1; ++n) {
|
|
ret += Middle::Size(counts[n], config.probing_multiplier);
|
|
}
|
|
return ret + Longest::Size(counts.back(), config.probing_multiplier);
|
|
}
|
|
|
|
uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config);
|
|
|
|
template <class Voc> void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab, Backing &backing);
|
|
|
|
const Middle *MiddleBegin() const { return &*middle_.begin(); }
|
|
const Middle *MiddleEnd() const { return &*middle_.end(); }
|
|
|
|
Node Unpack(uint64_t extend_pointer, unsigned char extend_length, float &prob) const {
|
|
util::FloatEnc val;
|
|
if (extend_length == 1) {
|
|
val.f = unigram.Lookup(static_cast<uint64_t>(extend_pointer)).prob;
|
|
} else {
|
|
typename Middle::ConstIterator found;
|
|
if (!middle_[extend_length - 2].Find(extend_pointer, found)) {
|
|
std::cerr << "Extend pointer " << extend_pointer << " should have been found for length " << (unsigned) extend_length << std::endl;
|
|
abort();
|
|
}
|
|
val.f = found->GetValue().prob;
|
|
}
|
|
val.i |= util::kSignBit;
|
|
prob = val.f;
|
|
return extend_pointer;
|
|
}
|
|
|
|
bool LookupMiddle(const Middle &middle, WordIndex word, float &backoff, Node &node, FullScoreReturn &ret) const {
|
|
node = CombineWordHash(node, word);
|
|
typename Middle::ConstIterator found;
|
|
if (!middle.Find(node, found)) return false;
|
|
util::FloatEnc enc;
|
|
enc.f = found->GetValue().prob;
|
|
ret.independent_left = (enc.i & util::kSignBit);
|
|
ret.extend_left = node;
|
|
enc.i |= util::kSignBit;
|
|
ret.prob = enc.f;
|
|
backoff = found->GetValue().backoff;
|
|
return true;
|
|
}
|
|
|
|
void LoadedBinary();
|
|
|
|
bool LookupMiddleNoProb(const Middle &middle, WordIndex word, float &backoff, Node &node) const {
|
|
node = CombineWordHash(node, word);
|
|
typename Middle::ConstIterator found;
|
|
if (!middle.Find(node, found)) return false;
|
|
backoff = found->GetValue().backoff;
|
|
return true;
|
|
}
|
|
|
|
bool LookupLongest(WordIndex word, float &prob, Node &node) const {
|
|
// Sign bit is always on because longest n-grams do not extend left.
|
|
node = CombineWordHash(node, word);
|
|
typename Longest::ConstIterator found;
|
|
if (!longest.Find(node, found)) return false;
|
|
prob = found->GetValue().prob;
|
|
return true;
|
|
}
|
|
|
|
// Geenrate a node without necessarily checking that it actually exists.
|
|
// Optionally return false if it's know to not exist.
|
|
bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {
|
|
assert(begin != end);
|
|
node = static_cast<Node>(*begin);
|
|
for (const WordIndex *i = begin + 1; i < end; ++i) {
|
|
node = CombineWordHash(node, *i);
|
|
}
|
|
return true;
|
|
}
|
|
|
|
private:
|
|
std::vector<Middle> middle_;
|
|
};
|
|
|
|
// Hash functor that returns its 64-bit key unchanged.
// In this header it is used on tables keyed by Node values, which are
// already mixed by CombineWordHash. std::identity is an SGI extension, so a
// hand-written functor is used instead.
// It deliberately does not derive from std::unary_function, which was
// deprecated in C++11 and removed in C++17. The equivalent adaptor typedefs
// are declared directly.
struct IdentityHash {
  typedef uint64_t argument_type;
  typedef size_t result_type;

  size_t operator()(uint64_t arg) const { return static_cast<size_t>(arg); }
};
|
|
|
|
// Concrete hashed search backed by util::ProbingHashTable (model type
// HASH_PROBING). Middle orders store ProbBackoff values and the highest order
// stores only Prob. Both kinds of table are keyed by uint64_t and hashed with
// IdentityHash, which passes the key through unchanged.
struct ProbingHashedSearch
    : public TemplateHashedSearch<
          // Middle: one table per order between unigrams and the highest order.
          util::ProbingHashTable<util::ByteAlignedPacking<uint64_t, ProbBackoff>, IdentityHash>,
          // Longest: the single table for the highest order.
          util::ProbingHashTable<util::ByteAlignedPacking<uint64_t, Prob>, IdentityHash>
      > {
  static const ModelType kModelType = HASH_PROBING;
};
|
|
|
|
} // namespace detail
|
|
} // namespace ngram
|
|
} // namespace lm
|
|
|
|
#endif // LM_SEARCH_HASHED__
|