// Mirror of https://github.com/moses-smt/mosesdecoder.git
// Synced 2024-12-27 22:14:57 +03:00
// 156 lines, 3.7 KiB, C++
#ifndef LM_TRIE__
|
|
#define LM_TRIE__
|
|
|
|
#include "lm/weights.hh"
|
|
#include "lm/word_index.hh"
|
|
#include "util/bit_packing.hh"
|
|
|
|
#include <cstddef>
|
|
|
|
#include <stdint.h>
|
|
|
|
namespace lm {
|
|
namespace ngram {
|
|
// Forward declaration: this header only takes Config by const reference
// (BitPackedMiddle::Size, its constructor and FinishedLoading), so the full
// definition is not needed here.
struct Config;
|
|
namespace trie {
|
|
|
|
// A span of positions in the next trie level. Unigram::Find fills it with one
// record's `next` as `begin` and the following record's `next` as `end`, so the
// span presumably runs from `begin` up to but excluding `end`.
struct NodeRange {
  uint64_t begin;
  uint64_t end;
};
|
|
|
|
// TODO: if the number of unigrams is a concern, also bit pack these records.
|
|
struct UnigramValue {
|
|
ProbBackoff weights;
|
|
uint64_t next;
|
|
uint64_t Next() const { return next; }
|
|
};
|
|
|
|
class UnigramPointer {
|
|
public:
|
|
explicit UnigramPointer(const ProbBackoff &to) : to_(&to) {}
|
|
|
|
UnigramPointer() : to_(NULL) {}
|
|
|
|
bool Found() const { return to_ != NULL; }
|
|
|
|
float Prob() const { return to_->prob; }
|
|
float Backoff() const { return to_->backoff; }
|
|
float Rest() const { return Prob(); }
|
|
|
|
private:
|
|
const ProbBackoff *to_;
|
|
};
|
|
|
|
class Unigram {
|
|
public:
|
|
Unigram() {}
|
|
|
|
void Init(void *start) {
|
|
unigram_ = static_cast<UnigramValue*>(start);
|
|
}
|
|
|
|
static uint64_t Size(uint64_t count) {
|
|
// +1 in case unknown doesn't appear. +1 for the final next.
|
|
return (count + 2) * sizeof(UnigramValue);
|
|
}
|
|
|
|
const ProbBackoff &Lookup(WordIndex index) const { return unigram_[index].weights; }
|
|
|
|
ProbBackoff &Unknown() { return unigram_[0].weights; }
|
|
|
|
UnigramValue *Raw() {
|
|
return unigram_;
|
|
}
|
|
|
|
void LoadedBinary() {}
|
|
|
|
UnigramPointer Find(WordIndex word, NodeRange &next) const {
|
|
UnigramValue *val = unigram_ + word;
|
|
next.begin = val->next;
|
|
next.end = (val+1)->next;
|
|
return UnigramPointer(val->weights);
|
|
}
|
|
|
|
private:
|
|
UnigramValue *unigram_;
|
|
};
|
|
|
|
// Shared state for the bit-packed trie levels. Records are fixed-width
// (total_bits_ each) at base_, starting with a word_bits_-wide word field
// (see BitPackedMiddle::ReadEntry). BaseSize/BaseInit are defined elsewhere.
class BitPacked {
  public:
    // Zero every field so a level that has not gone through BaseInit yet has a
    // well-defined state (InsertIndex() == 0, base_ == NULL) instead of
    // indeterminate values. The initializer order matches declaration order.
    BitPacked()
      : word_bits_(0), total_bits_(0), word_mask_(0), base_(NULL),
        insert_index_(0), max_vocab_(0) {}

    // Index at which the next inserted record will be written.
    uint64_t InsertIndex() const {
      return insert_index_;
    }

  protected:
    static uint64_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits);

    void BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits);

    uint8_t word_bits_;
    uint8_t total_bits_;
    uint64_t word_mask_;

    uint8_t *base_;

    uint64_t insert_index_, max_vocab_;
};
|
|
|
|
template <class Bhiksha> class BitPackedMiddle : public BitPacked {
|
|
public:
|
|
static uint64_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const Config &config);
|
|
|
|
// next_source need not be initialized.
|
|
BitPackedMiddle(void *base, uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config);
|
|
|
|
util::BitAddress Insert(WordIndex word);
|
|
|
|
void FinishedLoading(uint64_t next_end, const Config &config);
|
|
|
|
void LoadedBinary() { bhiksha_.LoadedBinary(); }
|
|
|
|
util::BitAddress Find(WordIndex word, NodeRange &range, uint64_t &pointer) const;
|
|
|
|
util::BitAddress ReadEntry(uint64_t pointer, NodeRange &range) {
|
|
uint64_t addr = pointer * total_bits_;
|
|
addr += word_bits_;
|
|
bhiksha_.ReadNext(base_, addr + quant_bits_, pointer, total_bits_, range);
|
|
return util::BitAddress(base_, addr);
|
|
}
|
|
|
|
private:
|
|
uint8_t quant_bits_;
|
|
Bhiksha bhiksha_;
|
|
|
|
const BitPacked *next_source_;
|
|
};
|
|
|
|
class BitPackedLongest : public BitPacked {
|
|
public:
|
|
static uint64_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) {
|
|
return BaseSize(entries, max_vocab, quant_bits);
|
|
}
|
|
|
|
BitPackedLongest() {}
|
|
|
|
void Init(void *base, uint8_t quant_bits, uint64_t max_vocab) {
|
|
BaseInit(base, max_vocab, quant_bits);
|
|
}
|
|
|
|
void LoadedBinary() {}
|
|
|
|
util::BitAddress Insert(WordIndex word);
|
|
|
|
util::BitAddress Find(WordIndex word, const NodeRange &node) const;
|
|
|
|
private:
|
|
uint8_t quant_bits_;
|
|
};
|
|
|
|
} // namespace trie
|
|
} // namespace ngram
|
|
} // namespace lm
|
|
|
|
#endif // LM_TRIE__
|