#include "lm/builder/corpus_count.hh"
|
|
|
|
#include "lm/builder/ngram.hh"
|
|
#include "lm/lm_exception.hh"
|
|
#include "lm/word_index.hh"
|
|
#include "util/fake_ofstream.hh"
|
|
#include "util/file.hh"
|
|
#include "util/file_piece.hh"
|
|
#include "util/murmur_hash.hh"
|
|
#include "util/probing_hash_table.hh"
|
|
#include "util/scoped.hh"
|
|
#include "util/stream/chain.hh"
|
|
#include "util/stream/timer.hh"
|
|
#include "util/tokenize_piece.hh"
|
|
|
|
#include <boost/unordered_set.hpp>
|
|
#include <boost/unordered_map.hpp>
|
|
|
|
#include <functional>
|
|
|
|
#include <stdint.h>
|
|
|
|
namespace lm {
namespace builder {
namespace {

#pragma pack(push)
#pragma pack(4)
struct VocabEntry {
  typedef uint64_t Key;

  uint64_t GetKey() const { return key; }
  void SetKey(uint64_t to) { key = to; }

  uint64_t key;
  lm::WordIndex value;
};
#pragma pack(pop)
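
// With 4-byte packing, VocabEntry is 12 bytes (an 8-byte hash key plus a
// 4-byte id, assuming WordIndex is a 32-bit type) rather than padding out
// to 16, which shrinks the probing table by a quarter.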

const float kProbingMultiplier = 1.5;

class VocabHandout {
  public:
    static std::size_t MemUsage(WordIndex initial_guess) {
      if (initial_guess < 2) initial_guess = 2;
      return util::CheckOverflow(Table::Size(initial_guess, kProbingMultiplier));
    }

    explicit VocabHandout(int fd, WordIndex initial_guess) :
        table_backing_(util::CallocOrThrow(MemUsage(initial_guess))),
        table_(table_backing_.get(), MemUsage(initial_guess)),
        double_cutoff_(std::max<std::size_t>(initial_guess * 1.1, 1)),
        word_list_(fd) {
      Lookup("<unk>"); // Force 0
      Lookup("<s>"); // Force 1
      Lookup("</s>"); // Force 2
    }
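
    // Ids are handed out densely in first-appearance order: the candidate
    // entry's value is the table's current size, so it becomes the next id
    // only if the word's 64-bit MurmurHash was not already present.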
    WordIndex Lookup(const StringPiece &word) {
      VocabEntry entry;
      entry.key = util::MurmurHashNative(word.data(), word.size());
      entry.value = table_.SizeNoSerialization();

      Table::MutableIterator it;
      if (table_.FindOrInsert(entry, it))
        return it->value;
      word_list_ << word << '\0';
      UTIL_THROW_IF(Size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh.");
      if (Size() >= double_cutoff_) {
        table_backing_.call_realloc(table_.DoubleTo());
        table_.Double(table_backing_.get());
        double_cutoff_ *= 2;
      }
      return entry.value;
    }

    WordIndex Size() const {
      return table_.SizeNoSerialization();
    }

  private:
    // TODO: factor out a resizable probing hash table.
    // TODO: use mremap on linux to get all zeros on resizes.
    util::scoped_malloc table_backing_;

    typedef util::ProbingHashTable<VocabEntry, util::IdentityHash> Table;
    Table table_;

    std::size_t double_cutoff_;

    util::FakeOFStream word_list_;
};
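
// An n-gram being assembled lives directly in the output block; the dedupe
// table keys are pointers to those in-block words, hashed and compared as
// order consecutive WordIndex values.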
class DedupeHash : public std::unary_function<const WordIndex *, std::size_t> {
  public:
    explicit DedupeHash(std::size_t order) : size_(order * sizeof(WordIndex)) {}

    std::size_t operator()(const WordIndex *start) const {
      return util::MurmurHashNative(start, size_);
    }

  private:
    const std::size_t size_;
};

class DedupeEquals : public std::binary_function<const WordIndex *, const WordIndex *, bool> {
  public:
    explicit DedupeEquals(std::size_t order) : size_(order * sizeof(WordIndex)) {}

    bool operator()(const WordIndex *first, const WordIndex *second) const {
      return !memcmp(first, second, size_);
    }

  private:
    const std::size_t size_;
};

struct DedupeEntry {
  typedef WordIndex *Key;
  Key GetKey() const { return key; }
  void SetKey(WordIndex *to) { key = to; }
  Key key;
  static DedupeEntry Construct(WordIndex *at) {
    DedupeEntry ret;
    ret.key = at;
    return ret;
  }
};

typedef util::ProbingHashTable<DedupeEntry, DedupeHash, DedupeEquals> Dedupe;
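
// Writer fills one chain block at a time with fixed-width n-gram records,
// combining repeats within a block via the dedupe table.  gram_ always
// points at the record under construction: its first order - 1 slots hold
// the current context and Append() writes each new word into the last slot.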
class Writer {
  public:
    Writer(std::size_t order, const util::stream::ChainPosition &position, void *dedupe_mem, std::size_t dedupe_mem_size)
      : block_(position), gram_(block_->Get(), order),
        dedupe_invalid_(order, std::numeric_limits<WordIndex>::max()),
        dedupe_(dedupe_mem, dedupe_mem_size, &dedupe_invalid_[0], DedupeHash(order), DedupeEquals(order)),
        buffer_(new WordIndex[order - 1]),
        block_size_(position.GetChain().BlockSize()) {
      dedupe_.Clear();
      assert(Dedupe::Size(position.GetChain().BlockSize() / position.GetChain().EntrySize(), kProbingMultiplier) == dedupe_mem_size);
      if (order == 1) {
        // Add special words. AdjustCounts is responsible if order != 1.
        AddUnigramWord(kUNK);
        AddUnigramWord(kBOS);
      }
    }

    ~Writer() {
      block_->SetValidSize(reinterpret_cast<const uint8_t*>(gram_.begin()) - static_cast<const uint8_t*>(block_->Get()));
      (++block_).Poison();
    }

    // Write context with a bunch of <s>
    void StartSentence() {
      for (WordIndex *i = gram_.begin(); i != gram_.end() - 1; ++i) {
        *i = kBOS;
      }
    }
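
    // Append() distinguishes three cases: the n-gram already has an entry
    // in this block's dedupe table (increment its count in place and slide
    // the context window); the block has room for another record (finish
    // this one and start the next); or the block is full (flush it, clear
    // the dedupe table, and rebase into a fresh block).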
    void Append(WordIndex word) {
      *(gram_.end() - 1) = word;
      Dedupe::MutableIterator at;
      bool found = dedupe_.FindOrInsert(DedupeEntry::Construct(gram_.begin()), at);
      if (found) {
        // Already present.
        NGram already(at->key, gram_.Order());
        ++(already.Count());
        // Shift left by one.
        memmove(gram_.begin(), gram_.begin() + 1, sizeof(WordIndex) * (gram_.Order() - 1));
        return;
      }
      // Complete the write.
      gram_.Count() = 1;
      // Prepare the next n-gram.
      if (reinterpret_cast<uint8_t*>(gram_.begin()) + gram_.TotalSize() != static_cast<uint8_t*>(block_->Get()) + block_size_) {
        NGram last(gram_);
        gram_.NextInMemory();
        std::copy(last.begin() + 1, last.end(), gram_.begin());
        return;
      }
      // Block end. Need to store the context in a temporary buffer.
      std::copy(gram_.begin() + 1, gram_.end(), buffer_.get());
      dedupe_.Clear();
      block_->SetValidSize(block_size_);
      gram_.ReBase((++block_)->Get());
      std::copy(buffer_.get(), buffer_.get() + gram_.Order() - 1, gram_.begin());
    }

  private:
    void AddUnigramWord(WordIndex index) {
      *gram_.begin() = index;
      gram_.Count() = 0;
      gram_.NextInMemory();
      if (gram_.Base() == static_cast<uint8_t*>(block_->Get()) + block_size_) {
        block_->SetValidSize(block_size_);
        gram_.ReBase((++block_)->Get());
      }
    }

    util::stream::Link block_;

    NGram gram_;

    // This is the memory behind the invalid value in dedupe_.
    std::vector<WordIndex> dedupe_invalid_;
    // Hash table combiner implementation.
    Dedupe dedupe_;

    // Small buffer to hold existing ngrams when shifting across a block boundary.
    boost::scoped_array<WordIndex> buffer_;

    const std::size_t block_size_;
};

} // namespace
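
// Bytes of dedupe table required per byte of n-gram block: one DedupeEntry
// per record, scaled by the probing table's space overhead.  Callers use
// this ratio to size dedupe_mem_ against their chosen block size.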
float CorpusCount::DedupeMultiplier(std::size_t order) {
  return kProbingMultiplier * static_cast<float>(sizeof(DedupeEntry)) / static_cast<float>(NGram::TotalSize(order));
}

std::size_t CorpusCount::VocabUsage(std::size_t vocab_estimate) {
  return VocabHandout::MemUsage(vocab_estimate);
}

CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block)
  : from_(from), vocab_write_(vocab_write), token_count_(token_count), type_count_(type_count),
    dedupe_mem_size_(Dedupe::Size(entries_per_block, kProbingMultiplier)),
    dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)) {
}
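
// A minimal sketch of how a driver could feed this stage (illustrative,
// assuming util::stream's Chain/ChainConfig API; text, vocab_fd, order,
// block_count, and memory are placeholders, not this file's API):
//
//   util::stream::Chain chain(util::stream::ChainConfig(
//       NGram::TotalSize(order), block_count, memory));
//   CorpusCount counter(text, vocab_fd, token_count, type_count,
//                       chain.BlockSize() / chain.EntrySize());
//   chain >> boost::ref(counter);  // The chain invokes Run() on a thread.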

void CorpusCount::Run(const util::stream::ChainPosition &position) {
  UTIL_TIMER("(%w s) Counted n-grams\n");

  VocabHandout vocab(vocab_write_, type_count_);
  token_count_ = 0;
  type_count_ = 0;
  const WordIndex end_sentence = vocab.Lookup("</s>");
  Writer writer(NGram::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_);
  uint64_t count = 0;
  StringPiece delimiters("\0\t\r ", 4);
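  // '\0' is included as a delimiter, presumably so a stray NUL byte in the
  // corpus splits tokens rather than landing inside a word; words are
  // written NUL-terminated to the vocabulary file above.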
  try {
    while(true) {
      StringPiece line(from_.ReadLine());
      writer.StartSentence();
      for (util::TokenIter<util::AnyCharacter, true> w(line, delimiters); w; ++w) {
        WordIndex word = vocab.Lookup(*w);
        UTIL_THROW_IF(word <= 2, FormatLoadException, "Special word " << *w << " is not allowed in the corpus. I plan to support models containing <unk> in the future.");
        writer.Append(word);
        ++count;
      }
      writer.Append(end_sentence);
    }
  } catch (const util::EndOfFileException &e) {}
  token_count_ = count;
  type_count_ = vocab.Size();
}

} // namespace builder
} // namespace lm