mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-12-28 14:32:38 +03:00
256 lines
8.2 KiB
C++
256 lines
8.2 KiB
C++
#include "lm/builder/corpus_count.hh"
|
|
|
|
#include "lm/builder/ngram.hh"
|
|
#include "lm/lm_exception.hh"
|
|
#include "lm/vocab.hh"
|
|
#include "lm/word_index.hh"
|
|
#include "util/fake_ofstream.hh"
|
|
#include "util/file.hh"
|
|
#include "util/file_piece.hh"
|
|
#include "util/murmur_hash.hh"
|
|
#include "util/probing_hash_table.hh"
|
|
#include "util/scoped.hh"
|
|
#include "util/stream/chain.hh"
|
|
#include "util/stream/timer.hh"
|
|
#include "util/tokenize_piece.hh"
|
|
|
|
#include <boost/unordered_set.hpp>
|
|
#include <boost/unordered_map.hpp>
|
|
|
|
#include <functional>
|
|
|
|
#include <stdint.h>
|
|
|
|
namespace lm {
|
|
namespace builder {
|
|
namespace {
|
|
|
|
#pragma pack(push)
|
|
#pragma pack(4)
|
|
struct VocabEntry {
|
|
typedef uint64_t Key;
|
|
|
|
uint64_t GetKey() const { return key; }
|
|
void SetKey(uint64_t to) { key = to; }
|
|
|
|
uint64_t key;
|
|
lm::WordIndex value;
|
|
};
|
|
#pragma pack(pop)
|
|
|
|
class DedupeHash : public std::unary_function<const WordIndex *, bool> {
|
|
public:
|
|
explicit DedupeHash(std::size_t order) : size_(order * sizeof(WordIndex)) {}
|
|
|
|
std::size_t operator()(const WordIndex *start) const {
|
|
return util::MurmurHashNative(start, size_);
|
|
}
|
|
|
|
private:
|
|
const std::size_t size_;
|
|
};
|
|
|
|
class DedupeEquals : public std::binary_function<const WordIndex *, const WordIndex *, bool> {
|
|
public:
|
|
explicit DedupeEquals(std::size_t order) : size_(order * sizeof(WordIndex)) {}
|
|
|
|
bool operator()(const WordIndex *first, const WordIndex *second) const {
|
|
return !memcmp(first, second, size_);
|
|
}
|
|
|
|
private:
|
|
const std::size_t size_;
|
|
};
|
|
|
|
struct DedupeEntry {
|
|
typedef WordIndex *Key;
|
|
Key GetKey() const { return key; }
|
|
void SetKey(WordIndex *to) { key = to; }
|
|
Key key;
|
|
static DedupeEntry Construct(WordIndex *at) {
|
|
DedupeEntry ret;
|
|
ret.key = at;
|
|
return ret;
|
|
}
|
|
};
|
|
|
|
|
|
// TODO: don't have this here, should be with probing hash table defaults?
|
|
const float kProbingMultiplier = 1.5;
|
|
|
|
typedef util::ProbingHashTable<DedupeEntry, DedupeHash, DedupeEquals> Dedupe;
|
|
|
|
class Writer {
|
|
public:
|
|
Writer(std::size_t order, const util::stream::ChainPosition &position, void *dedupe_mem, std::size_t dedupe_mem_size)
|
|
: block_(position), gram_(block_->Get(), order),
|
|
dedupe_invalid_(order, std::numeric_limits<WordIndex>::max()),
|
|
dedupe_(dedupe_mem, dedupe_mem_size, &dedupe_invalid_[0], DedupeHash(order), DedupeEquals(order)),
|
|
buffer_(new WordIndex[order - 1]),
|
|
block_size_(position.GetChain().BlockSize()) {
|
|
dedupe_.Clear();
|
|
assert(Dedupe::Size(position.GetChain().BlockSize() / position.GetChain().EntrySize(), kProbingMultiplier) == dedupe_mem_size);
|
|
if (order == 1) {
|
|
// Add special words. AdjustCounts is responsible if order != 1.
|
|
AddUnigramWord(kUNK);
|
|
AddUnigramWord(kBOS);
|
|
}
|
|
}
|
|
|
|
~Writer() {
|
|
block_->SetValidSize(reinterpret_cast<const uint8_t*>(gram_.begin()) - static_cast<const uint8_t*>(block_->Get()));
|
|
(++block_).Poison();
|
|
}
|
|
|
|
// Write context with a bunch of <s>
|
|
void StartSentence() {
|
|
for (WordIndex *i = gram_.begin(); i != gram_.end() - 1; ++i) {
|
|
*i = kBOS;
|
|
}
|
|
}
|
|
|
|
void Append(WordIndex word) {
|
|
*(gram_.end() - 1) = word;
|
|
Dedupe::MutableIterator at;
|
|
bool found = dedupe_.FindOrInsert(DedupeEntry::Construct(gram_.begin()), at);
|
|
if (found) {
|
|
// Already present.
|
|
NGram already(at->key, gram_.Order());
|
|
++(already.Count());
|
|
// Shift left by one.
|
|
memmove(gram_.begin(), gram_.begin() + 1, sizeof(WordIndex) * (gram_.Order() - 1));
|
|
return;
|
|
}
|
|
// Complete the write.
|
|
gram_.Count() = 1;
|
|
// Prepare the next n-gram.
|
|
if (reinterpret_cast<uint8_t*>(gram_.begin()) + gram_.TotalSize() != static_cast<uint8_t*>(block_->Get()) + block_size_) {
|
|
NGram last(gram_);
|
|
gram_.NextInMemory();
|
|
std::copy(last.begin() + 1, last.end(), gram_.begin());
|
|
return;
|
|
}
|
|
// Block end. Need to store the context in a temporary buffer.
|
|
std::copy(gram_.begin() + 1, gram_.end(), buffer_.get());
|
|
dedupe_.Clear();
|
|
block_->SetValidSize(block_size_);
|
|
gram_.ReBase((++block_)->Get());
|
|
std::copy(buffer_.get(), buffer_.get() + gram_.Order() - 1, gram_.begin());
|
|
}
|
|
|
|
private:
|
|
void AddUnigramWord(WordIndex index) {
|
|
*gram_.begin() = index;
|
|
gram_.Count() = 0;
|
|
gram_.NextInMemory();
|
|
if (gram_.Base() == static_cast<uint8_t*>(block_->Get()) + block_size_) {
|
|
block_->SetValidSize(block_size_);
|
|
gram_.ReBase((++block_)->Get());
|
|
}
|
|
}
|
|
|
|
util::stream::Link block_;
|
|
|
|
NGram gram_;
|
|
|
|
// This is the memory behind the invalid value in dedupe_.
|
|
std::vector<WordIndex> dedupe_invalid_;
|
|
// Hash table combiner implementation.
|
|
Dedupe dedupe_;
|
|
|
|
// Small buffer to hold existing ngrams when shifting across a block boundary.
|
|
boost::scoped_array<WordIndex> buffer_;
|
|
|
|
const std::size_t block_size_;
|
|
};
|
|
|
|
} // namespace
|
|
|
|
float CorpusCount::DedupeMultiplier(std::size_t order) {
|
|
return kProbingMultiplier * static_cast<float>(sizeof(DedupeEntry)) / static_cast<float>(NGram::TotalSize(order));
|
|
}
|
|
|
|
std::size_t CorpusCount::VocabUsage(std::size_t vocab_estimate) {
|
|
return ngram::GrowableVocab<ngram::WriteUniqueWords>::MemUsage(vocab_estimate);
|
|
}
|
|
|
|
// Stores the inputs and output references for a later call to Run, and
// allocates the memory for the dedupe hash table.
//
// from:                 corpus to read.
// vocab_write:          passed to the vocabulary built in Run.
// token_count:          output, set by Run.
// type_count:           passed to the vocabulary's constructor in Run, then
//                       overwritten with the vocabulary size.
// prune_words:          output, filled by Run when a prune vocabulary is given.
// prune_vocab_filename: prune vocabulary file; empty means none.
// entries_per_block:    number of entries the dedupe table is sized for.
// disallowed_symbol:    action taken when a special word appears in the corpus.
CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::vector<bool> &prune_words, const std::string& prune_vocab_filename, std::size_t entries_per_block, WarningAction disallowed_symbol)
  : from_(from),
    vocab_write_(vocab_write),
    token_count_(token_count),
    type_count_(type_count),
    prune_words_(prune_words),
    prune_vocab_filename_(prune_vocab_filename),
    dedupe_mem_size_(Dedupe::Size(entries_per_block, kProbingMultiplier)),
    dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)),
    disallowed_symbol_action_(disallowed_symbol) {}
|
|
|
|
namespace {
|
|
// Reacts to a special word found in the corpus according to `action`:
//   SILENT   - does nothing.
//   COMPLAIN - prints a warning to stderr, then sets action to SILENT so the
//              warning is printed only once.
//   THROW_UP - throws FormatLoadException.
void ComplainDisallowed(StringPiece word, WarningAction &action) {
  if (action == SILENT) {
    return;
  }
  if (action == COMPLAIN) {
    std::cerr << "Warning: " << word << " appears in the input. All instances of <s>, </s>, and <unk> will be interpreted as whitespace." << std::endl;
    action = SILENT;
    return;
  }
  if (action == THROW_UP) {
    UTIL_THROW(FormatLoadException, "Special word " << word << " is not allowed in the corpus. I plan to support models containing <unk> in the future. Pass --skip_symbols to convert these symbols to whitespace.");
  }
}
|
|
} // namespace
|
|
|
|
// Reads the corpus line by line, writes its n-grams to the chain at
// `position`, and reports the totals through token_count_ and type_count_.
// If a prune vocabulary file was given, also fills prune_words_.
//
// Each line is treated as one sentence: the context is reset to <s>, every
// token is appended, and </s> is appended at the end.  Reading stops when
// ReadLine throws util::EndOfFileException.
void CorpusCount::Run(const util::stream::ChainPosition &position) {
  typedef ngram::GrowableVocab<ngram::WriteUniqueWords> Vocab;
  typedef util::TokenIter<util::BoolCharacter, true> Tokens;

  // type_count_ is read by the vocabulary constructor before being reset.
  Vocab vocab(type_count_, vocab_write_);
  token_count_ = 0;
  type_count_ = 0;
  const WordIndex end_sentence = vocab.FindOrInsert("</s>");
  Writer writer(NGram::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_);
  uint64_t count = 0;
  bool delimiters[256];
  util::BoolCharacter::Build("\0\t\n\r ", delimiters);
  try {
    for (;;) {
      StringPiece line(from_.ReadLine());
      writer.StartSentence();
      for (Tokens token(line, delimiters); token; ++token) {
        const WordIndex id = vocab.FindOrInsert(*token);
        // Ids up to 2 are treated as disallowed special words and skipped.
        // NOTE(review): presumably these are <unk>, <s> and </s>
        // (kUNK, kBOS, kEOS) -- confirm against lm/word_index.hh.
        if (id > 2) {
          writer.Append(id);
          ++count;
        } else {
          ComplainDisallowed(*token, disallowed_symbol_action_);
        }
      }
      writer.Append(end_sentence);
    }
  } catch (const util::EndOfFileException &) {
    // End of corpus: this is how the read loop terminates.
  }
  // count excludes the </s> appended after each line.
  token_count_ = count;
  type_count_ = vocab.Size();

  // The rest only applies when a prune vocabulary was supplied.
  if (prune_vocab_filename_.empty()) return;

  // Create list of unigrams that are supposed to be pruned: every word starts
  // out marked true, and each word listed in the file is then marked false.
  try {
    util::FilePiece prune_vocab_file(prune_vocab_filename_.c_str());

    prune_words_.resize(vocab.Size(), true);
    try {
      for (;;) {
        StringPiece line(prune_vocab_file.ReadLine());
        for (Tokens token(line, delimiters); token; ++token) {
          prune_words_[vocab.Index(*token)] = false;
        }
      }
    } catch (const util::EndOfFileException &) {
      // End of the prune vocabulary file.
    }

    // Never prune <unk>, <s>, </s>
    prune_words_[kUNK] = false;
    prune_words_[kBOS] = false;
    prune_words_[kEOS] = false;
  } catch (const util::Exception &e) {
    // Any other failure is reported and ends the process.
    std::cerr << e.what() << std::endl;
    abort();
  }
}
|
|
|
|
} // namespace builder
|
|
} // namespace lm
|