KenLM 31a6644 resizable probing hash table, build fixes

Kenneth Heafield 2013-01-24 12:07:46 +00:00
parent 22bf1c77e9
commit 03b077364a
13 changed files with 240 additions and 43 deletions
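
In outline, the grow-in-place pattern this commit introduces (an editor's sketch, not code from the commit; Table64, Entry64, and MurmurHashEntry64 come from the new probing_hash_table test below, and initial stands for any starting entry count):

  util::scoped_malloc mem(util::MallocOrThrow(Table64::Size(initial, 1.5)));
  Table64 table(mem.get(), Table64::Size(initial, 1.5), std::numeric_limits<uint64_t>::max());
  table.Clear();                       // fill every bucket with the invalid key
  // ... Insert() entries until the table nears capacity, then:
  mem.call_realloc(table.DoubleTo());  // DoubleTo() = byte size for twice the buckets
  table.Double(mem.get());             // rehash in place into the doubled region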

View File

@@ -27,9 +27,12 @@ cflags = [ os.environ "CFLAGS" ] ;
ldflags = [ os.environ "LDFLAGS" ] ;
#Run g++ with empty main and these arguments to see if it passes.
rule test_flags ( flags * ) {
rule test_flags ( flags * : main ? ) {
flags = $(cxxflags) $(ldflags) $(flags) ;
local cmd = "bash -c \"g++ "$(flags:J=" ")" -x c++ - <<<'int main() {}' -o $(TOP)/dummy >/dev/null 2>/dev/null && rm $(TOP)/dummy 2>/dev/null\"" ;
if ! $(main) {
main = "int main() {}" ;
}
local cmd = "bash -c \"g++ "$(flags:J=" ")" -x c++ - <<<'$(main)' -o $(TOP)/dummy >/dev/null 2>/dev/null && rm $(TOP)/dummy 2>/dev/null\"" ;
local ret = [ SHELL $(cmd) : exit-status ] ;
if --debug-configuration in [ modules.peek : ARGV ] {
echo $(cmd) ;
@@ -119,18 +122,29 @@ rule boost-lib ( name macro : deps * ) {
#versions of boost do not have -mt tagged versions of all libraries. Sadly,
#boost.jam does not handle this correctly.
flags = $(L-boost-search)" -lboost_"$(name)"-mt$(boost-lib-version)" ;
if $(boost-auto-shared) != "<link>shared" {
local main ;
if $(name) = "unit_test_framework" {
main = "BOOST_AUTO_TEST_CASE(foo) {}" ;
flags += " -DBOOST_TEST_MODULE=CompileTest $(I-boost-include) -include boost/test/unit_test.hpp" ;
}
if $(boost-auto-shared) = "<link>shared" {
flags += " -DBOOST_$(macro)" ;
} else {
flags += " -static" ;
}
if [ test_flags $(flags) ] {
if [ test_flags $(flags) : $(main) ] {
lib inner_boost_$(name) : : <threading>single $(boost-search) <name>boost_$(name)$(boost-lib-version) : : <library>$(deps) ;
lib inner_boost_$(name) : : <threading>multi $(boost-search) <name>boost_$(name)-mt$(boost-lib-version) : : <library>$(deps) ;
} else {
lib inner_boost_$(name) : : $(boost-search) <name>boost_$(name)$(boost-lib-version) : : <library>$(deps) ;
}
alias boost_$(name) : inner_boost_$(name) : $(boost-auto-shared) : : <link>shared:<define>BOOST_$(macro) ;
requirements += <link>shared:<define>BOOST_$(macro) ;
if $(boost-auto-shared) = "<link>shared" {
alias boost_$(name) : inner_boost_$(name) : <link>shared ;
requirements += <define>BOOST_$(macro) ;
} else {
alias boost_$(name) : inner_boost_$(name) : <link>static ;
}
}
#Argument is e.g. 103600

View File

@@ -23,9 +23,32 @@ namespace lm {
namespace builder {
namespace {
#pragma pack(push)
#pragma pack(4)
struct VocabEntry {
typedef uint64_t Key;
uint64_t GetKey() const { return key; }
void SetKey(uint64_t to) { key = to; }
uint64_t key;
lm::WordIndex value;
};
#pragma pack(pop)
const float kProbingMultiplier = 1.5;
class VocabHandout {
public:
explicit VocabHandout(int fd) {
static std::size_t MemUsage(WordIndex initial_guess) {
if (initial_guess < 2) initial_guess = 2;
return util::CheckOverflow(Table::Size(initial_guess, kProbingMultiplier));
}
explicit VocabHandout(int fd, WordIndex initial_guess) :
table_backing_(util::CallocOrThrow(MemUsage(initial_guess))),
table_(table_backing_.get(), MemUsage(initial_guess)),
double_cutoff_(std::max<std::size_t>(initial_guess * 1.1, 1)) {
util::scoped_fd duped(util::DupOrThrow(fd));
word_list_.reset(util::FDOpenOrThrow(duped));
@@ -35,25 +58,38 @@ class VocabHandout {
}
WordIndex Lookup(const StringPiece &word) {
uint64_t hashed = util::MurmurHashNative(word.data(), word.size());
std::pair<Seen::iterator, bool> ret(seen_.insert(std::pair<uint64_t, lm::WordIndex>(hashed, seen_.size())));
if (ret.second) {
char null_delimit = 0;
util::WriteOrThrow(word_list_.get(), word.data(), word.size());
util::WriteOrThrow(word_list_.get(), &null_delimit, 1);
UTIL_THROW_IF(seen_.size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh.");
VocabEntry entry;
entry.key = util::MurmurHashNative(word.data(), word.size());
entry.value = table_.SizeNoSerialization();
Table::MutableIterator it;
if (table_.FindOrInsert(entry, it))
return it->value;
char null_delimit = 0;
util::WriteOrThrow(word_list_.get(), word.data(), word.size());
util::WriteOrThrow(word_list_.get(), &null_delimit, 1);
UTIL_THROW_IF(Size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh.");
if (Size() >= double_cutoff_) {
table_backing_.call_realloc(table_.DoubleTo());
table_.Double(table_backing_.get());
double_cutoff_ *= 2;
}
return ret.first->second;
return entry.value;
}
WordIndex Size() const {
return seen_.size();
return table_.SizeNoSerialization();
}
private:
typedef boost::unordered_map<uint64_t, lm::WordIndex> Seen;
// TODO: factor out a resizable probing hash table.
// TODO: use mremap on linux to get all zeros on resizes.
util::scoped_malloc table_backing_;
Seen seen_;
typedef util::ProbingHashTable<VocabEntry, util::IdentityHash> Table;
Table table_;
std::size_t double_cutoff_;
util::scoped_FILE word_list_;
};
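
For a sense of scale, assuming a 64-bit build where the pack(4) pragma gives sizeof(VocabEntry) == 12 and Table::Size reserves roughly kProbingMultiplier buckets per expected entry (an editor's estimate, not a figure from the commit):

  MemUsage(1000000) ~= 1.5 * 1,000,000 buckets * 12 B ~= 18 MB

so the default estimate of one million words pre-allocates far less than the fixed 50M the old vocab_memory option assumed.
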
@@ -85,6 +121,7 @@ class DedupeEquals : public std::binary_function<const WordIndex *, const WordIn
struct DedupeEntry {
typedef WordIndex *Key;
Key GetKey() const { return key; }
void SetKey(WordIndex *to) { key = to; }
Key key;
static DedupeEntry Construct(WordIndex *at) {
DedupeEntry ret;
@@ -95,8 +132,6 @@ struct DedupeEntry {
typedef util::ProbingHashTable<DedupeEntry, DedupeHash, DedupeEquals> Dedupe;
const float kProbingMultiplier = 1.5;
class Writer {
public:
Writer(std::size_t order, const util::stream::ChainPosition &position, void *dedupe_mem, std::size_t dedupe_mem_size)
@@ -105,7 +140,7 @@ class Writer {
dedupe_(dedupe_mem, dedupe_mem_size, &dedupe_invalid_[0], DedupeHash(order), DedupeEquals(order)),
buffer_(new WordIndex[order - 1]),
block_size_(position.GetChain().BlockSize()) {
dedupe_.Clear(DedupeEntry::Construct(&dedupe_invalid_[0]));
dedupe_.Clear();
assert(Dedupe::Size(position.GetChain().BlockSize() / position.GetChain().EntrySize(), kProbingMultiplier) == dedupe_mem_size);
if (order == 1) {
// Add special words. AdjustCounts is responsible if order != 1.
@@ -149,7 +184,7 @@ class Writer {
}
// Block end. Need to store the context in a temporary buffer.
std::copy(gram_.begin() + 1, gram_.end(), buffer_.get());
dedupe_.Clear(DedupeEntry::Construct(&dedupe_invalid_[0]));
dedupe_.Clear();
block_->SetValidSize(block_size_);
gram_.ReBase((++block_)->Get());
std::copy(buffer_.get(), buffer_.get() + gram_.Order() - 1, gram_.begin());
@@ -187,18 +222,22 @@ float CorpusCount::DedupeMultiplier(std::size_t order) {
return kProbingMultiplier * static_cast<float>(sizeof(DedupeEntry)) / static_cast<float>(NGram::TotalSize(order));
}
std::size_t CorpusCount::VocabUsage(std::size_t vocab_estimate) {
return VocabHandout::MemUsage(vocab_estimate);
}
CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block)
: from_(from), vocab_write_(vocab_write), token_count_(token_count), type_count_(type_count),
dedupe_mem_size_(Dedupe::Size(entries_per_block, kProbingMultiplier)),
dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)) {
token_count_ = 0;
type_count_ = 0;
}
void CorpusCount::Run(const util::stream::ChainPosition &position) {
UTIL_TIMER("(%w s) Counted n-grams\n");
VocabHandout vocab(vocab_write_);
VocabHandout vocab(vocab_write_, type_count_);
token_count_ = 0;
type_count_ = 0;
const WordIndex end_sentence = vocab.Lookup("</s>");
Writer writer(NGram::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_);
uint64_t count = 0;

View File

@@ -23,6 +23,11 @@ class CorpusCount {
// Memory usage will be DedupeMultiplier(order) * block_size + total_chain_size + unknown vocab_hash_size
static float DedupeMultiplier(std::size_t order);
// How much memory vocabulary will use based on estimated size of the vocab.
static std::size_t VocabUsage(std::size_t vocab_estimate);
// token_count: out.
// type_count: in/out, aka vocabulary size. Initialize to an estimate; it is set to the exact value on completion.
CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block);
void Run(const util::stream::ChainPosition &position);
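
Pieced together from the pipeline and test changes in this commit, the revised calling convention looks like this (an illustrative sketch with the surrounding setup elided):

  uint64_t token_count;
  WordIndex type_count = config.vocab_estimate;  // in: estimate that pre-sizes the vocab table
  CorpusCount counter(text, vocab_file, token_count, type_count,
                      chain.BlockSize() / chain.EntrySize());
  chain >> boost::ref(counter) >> stream >> util::stream::kRecycle;
  // After the chain drains, type_count holds the exact vocabulary size.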

View File

@@ -44,7 +44,7 @@ BOOST_AUTO_TEST_CASE(Short) {
util::stream::Chain chain(config);
NGramStream stream;
uint64_t token_count;
WordIndex type_count;
WordIndex type_count = 10;
CorpusCount counter(input_piece, vocab.get(), token_count, type_count, chain.BlockSize() / chain.EntrySize());
chain >> boost::ref(counter) >> stream >> util::stream::kRecycle;

View File

@@ -24,7 +24,6 @@ struct BufferEntry {
class OnlyGamma {
public:
void Run(const util::stream::ChainPosition &position) {
uint64_t count = 0;
for (util::stream::Link block_it(position); block_it; ++block_it) {
float *out = static_cast<float*>(block_it->Get());
const float *in = out;
@@ -33,10 +32,7 @@ class OnlyGamma {
*out = *in;
}
block_it->SetValidSize(block_it->ValidSize() / 2);
count += block_it->ValidSize() / sizeof(float);
}
std::cerr << std::endl;
std::cerr << "Backoff count is " << count << std::endl;
}
};

View File

@@ -42,9 +42,9 @@ int main(int argc, char *argv[]) {
("interpolate_unigrams", po::bool_switch(&pipeline.initial_probs.interpolate_unigrams), "Interpolate the unigrams (default: emulate SRILM by not interpolating)")
("temp_prefix,T", po::value<std::string>(&pipeline.sort.temp_prefix)->default_value("/tmp/lm"), "Temporary file prefix")
("memory,S", SizeOption(pipeline.sort.total_memory, util::GuessPhysicalMemory() ? "80%" : "1G"), "Sorting memory")
("vocab_memory", SizeOption(pipeline.assume_vocab_hash_size, "50M"), "Assume that the vocabulary hash table will use this much memory for purposes of calculating total memory in the count step")
("minimum_block", SizeOption(pipeline.minimum_block, "8K"), "Minimum block size to allow")
("sort_block", SizeOption(pipeline.sort.buffer_size, "64M"), "Size of IO operations for sort (determines arity)")
("vocab_estimate", po::value<lm::WordIndex>(&pipeline.vocab_estimate)->default_value(1000000), "Assume this vocabulary size for purposes of calculating memory in step 1 (corpus count) and pre-sizing the hash table")
("block_count", po::value<std::size_t>(&pipeline.block_count)->default_value(2), "Block count (per order)")
("vocab_file", po::value<std::string>(&pipeline.vocab_file)->default_value(""), "Location to write vocabulary file")
("verbose_header", po::bool_switch(&pipeline.verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc.");

View File

@@ -207,17 +207,18 @@ void CountText(int text_file /* input */, int vocab_file /* output */, Master &m
const PipelineConfig &config = master.Config();
std::cerr << "=== 1/5 Counting and sorting n-grams ===" << std::endl;
UTIL_THROW_IF(config.TotalMemory() < config.assume_vocab_hash_size, util::Exception, "Vocab hash size estimate " << config.assume_vocab_hash_size << " exceeds total memory " << config.TotalMemory());
const std::size_t vocab_usage = CorpusCount::VocabUsage(config.vocab_estimate);
UTIL_THROW_IF(config.TotalMemory() < vocab_usage, util::Exception, "Vocab hash size estimate " << vocab_usage << " exceeds total memory " << config.TotalMemory());
std::size_t memory_for_chain =
// This much memory to work with after vocab hash table.
static_cast<float>(config.TotalMemory() - config.assume_vocab_hash_size) /
static_cast<float>(config.TotalMemory() - vocab_usage) /
// Solve for block size including the dedupe multiplier for one block.
(static_cast<float>(config.block_count) + CorpusCount::DedupeMultiplier(config.order)) *
// Chain likes memory expressed in terms of total memory.
static_cast<float>(config.block_count);
util::stream::Chain chain(util::stream::ChainConfig(NGram::TotalSize(config.order), config.block_count, memory_for_chain));
WordIndex type_count;
WordIndex type_count = config.vocab_estimate;
util::FilePiece text(text_file, NULL, &std::cerr);
text_file_name = text.FileName();
CorpusCount counter(text, vocab_file, token_count, type_count, chain.BlockSize() / chain.EntrySize());
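
A worked example of the block-size solve above, with every number assumed by the editor rather than taken from the commit: on a 64-bit build at order 5, sizeof(DedupeEntry) == 8 and NGram::TotalSize(5) == 28, so DedupeMultiplier(5) ~= 1.5 * 8 / 28 ~= 0.43. With 1 GB total memory, the default vocab_estimate of 1,000,000 (vocab_usage ~= 18 MB), and block_count 2:

  memory_for_chain ~= (1024 MB - 18 MB) / (2 + 0.43) * 2 ~= 828 MB

which leaves each of the two blocks 414 MB, plus about 178 MB for the dedupe table on the block being written.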

View File

@@ -3,6 +3,7 @@
#include "lm/builder/initial_probabilities.hh"
#include "lm/builder/header_info.hh"
#include "lm/word_index.hh"
#include "util/stream/config.hh"
#include "util/file_piece.hh"
@@ -19,9 +20,9 @@ struct PipelineConfig {
util::stream::ChainConfig read_backoffs;
bool verbose_header;
// Amount of memory to assume that the vocabulary hash table will use. This
// is subtracted from total memory for CorpusCount.
std::size_t assume_vocab_hash_size;
// Estimated vocabulary size. Used to size CorpusCount's memory and to
// pre-size its initial probing hash table.
lm::WordIndex vocab_estimate;
// Minimum block size to tolerate.
std::size_t minimum_block;

View File

@@ -120,7 +120,7 @@ void FilePiece::Initialize(const char *name, std::ostream *show_progress, std::s
}
Shift();
// gzip detect.
if ((position_end_ - position_) >= ReadCompressed::kMagicSize && ReadCompressed::DetectCompressedMagic(position_)) {
if ((position_end_ >= position_ + ReadCompressed::kMagicSize) && ReadCompressed::DetectCompressedMagic(position_)) {
if (!fallback_to_read_) {
at_end_ = false;
TransitionToRead();

View File

@@ -6,6 +6,7 @@
#include <algorithm>
#include <cstddef>
#include <functional>
#include <vector>
#include <assert.h>
#include <stdint.h>
@@ -73,10 +74,7 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
assert(initialized_);
#endif
UTIL_THROW_IF(++entries_ >= buckets_, ProbingSizeException, "Hash table with " << buckets_ << " buckets is full.");
for (MutableIterator i(begin_ + (hash_(t.GetKey()) % buckets_));;) {
if (equal_(i->GetKey(), invalid_)) { *i = t; return i; }
if (++i == end_) { i = begin_; }
}
return UncheckedInsert(t);
}
// Return true if the value was found (and not inserted). This is consistent with Find but the opposite of hash_map!
@@ -126,12 +124,96 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
}
}
void Clear(Entry invalid) {
void Clear() {
Entry invalid;
invalid.SetKey(invalid_);
std::fill(begin_, end_, invalid);
entries_ = 0;
}
// Return number of entries assuming no serialization went on.
std::size_t SizeNoSerialization() const {
return entries_;
}
// Return memory size expected by Double.
std::size_t DoubleTo() const {
return buckets_ * 2 * sizeof(Entry);
}
// Inform the table that it has double the amount of memory.
// Pass clear_new = false if you are sure the new memory is initialized
// properly (to invalid_), e.g. by mremap.
void Double(void *new_base, bool clear_new = true) {
begin_ = static_cast<MutableIterator>(new_base);
MutableIterator old_end = begin_ + buckets_;
buckets_ *= 2;
end_ = begin_ + buckets_;
if (clear_new) {
Entry invalid;
invalid.SetKey(invalid_);
std::fill(old_end, end_, invalid);
}
std::vector<Entry> rolled_over;
// Move roll-over entries to a buffer because they might not roll over anymore. This should be small.
for (MutableIterator i = begin_; i != old_end && !equal_(i->GetKey(), invalid_); ++i) {
rolled_over.push_back(*i);
i->SetKey(invalid_);
}
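// Editor's note: doubling buckets_ means an entry's new ideal slot,
// hash % (2 * B), is either its old ideal slot (hash % B) or that slot
// plus the old bucket count B; ideal positions never decrease, so the
// single sweep below plus the roll-over buffer above reach every entry.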
/* Re-insert everything. Entries might go backwards to take over a
* recently opened gap, stay, move to new territory, or wrap around. If
* an entry wraps around, it might go to a pointer greater than i (which
* can happen at the beginning) and it will be revisited to possibly fill
* in a gap created later.
*/
Entry temp;
for (MutableIterator i = begin_; i != old_end; ++i) {
if (!equal_(i->GetKey(), invalid_)) {
temp = *i;
i->SetKey(invalid_);
UncheckedInsert(temp);
}
}
// Put the roll-over entries back in.
for (typename std::vector<Entry>::const_iterator i(rolled_over.begin()); i != rolled_over.end(); ++i) {
UncheckedInsert(*i);
}
}
// Mostly for tests, check consistency of every entry.
void CheckConsistency() {
MutableIterator last;
for (last = end_ - 1; last >= begin_ && !equal_(last->GetKey(), invalid_); --last) {}
UTIL_THROW_IF(last == begin_, ProbingSizeException, "Completely full");
MutableIterator i;
// Entries at the beginning can be wrap-arounds.
for (i = begin_; !equal_(i->GetKey(), invalid_); ++i) {
MutableIterator ideal = Ideal(*i);
UTIL_THROW_IF(ideal > i && ideal <= last, Exception, "Inconsistency at position " << (i - begin_) << " should be at " << (ideal - begin_));
}
MutableIterator pre_gap = i;
for (; i != end_; ++i) {
if (equal_(i->GetKey(), invalid_)) {
pre_gap = i;
continue;
}
MutableIterator ideal = Ideal(*i);
UTIL_THROW_IF(ideal > i || ideal <= pre_gap, Exception, "Inconsistency at position " << (i - begin_) << " with ideal " << (ideal - begin_));
}
}
private:
template <class T> MutableIterator Ideal(const T &t) {
return begin_ + (hash_(t.GetKey()) % buckets_);
}
template <class T> MutableIterator UncheckedInsert(const T &t) {
for (MutableIterator i(Ideal(t));;) {
if (equal_(i->GetKey(), invalid_)) { *i = t; return i; }
if (++i == end_) { i = begin_; }
}
}
MutableIterator begin_;
std::size_t buckets_;
MutableIterator end_;

View File

@@ -1,10 +1,14 @@
#include "util/probing_hash_table.hh"
#include "util/murmur_hash.hh"
#include "util/scoped.hh"
#define BOOST_TEST_MODULE ProbingHashTableTest
#include <boost/test/unit_test.hpp>
#include <boost/scoped_array.hpp>
#include <boost/functional/hash.hpp>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <stdint.h>
@@ -19,6 +23,10 @@ struct Entry {
return key;
}
void SetKey(unsigned char to) {
key = to;
}
uint64_t GetValue() const {
return value;
}
@@ -46,5 +54,49 @@ BOOST_AUTO_TEST_CASE(simple) {
BOOST_CHECK(!table.Find(2, i));
}
struct Entry64 {
uint64_t key;
typedef uint64_t Key;
Entry64() {}
explicit Entry64(uint64_t key_in) {
key = key_in;
}
Key GetKey() const { return key; }
void SetKey(uint64_t to) { key = to; }
};
struct MurmurHashEntry64 {
std::size_t operator()(uint64_t value) const {
return util::MurmurHash64A(&value, 8);
}
};
typedef ProbingHashTable<Entry64, MurmurHashEntry64> Table64;
BOOST_AUTO_TEST_CASE(Double) {
for (std::size_t initial = 19; initial < 30; ++initial) {
size_t size = Table64::Size(initial, 1.2);
scoped_malloc mem(MallocOrThrow(size));
Table64 table(mem.get(), size, std::numeric_limits<uint64_t>::max());
table.Clear();
for (uint64_t i = 0; i < 19; ++i) {
table.Insert(Entry64(i));
}
table.CheckConsistency();
mem.call_realloc(table.DoubleTo());
table.Double(mem.get());
table.CheckConsistency();
for (uint64_t i = 20; i < 40; ++i) {
table.Insert(Entry64(i));
}
mem.call_realloc(table.DoubleTo());
table.Double(mem.get());
table.CheckConsistency();
}
}
} // namespace
} // namespace util

View File

@@ -16,6 +16,12 @@ void *MallocOrThrow(std::size_t requested) {
return ret;
}
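// Editor's note: the zero-filled allocation matters here. VocabHandout backs
// its probing table with CallocOrThrow, so the table starts out all zeros,
// which (assuming the default invalid key of 0) is already the cleared state;
// no explicit Clear() pass over the fresh memory is needed.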
void *CallocOrThrow(std::size_t requested) {
void *ret;
UTIL_THROW_IF_ARG(!(ret = std::calloc(1, requested)), MallocException, (requested), "in calloc");
return ret;
}
scoped_malloc::~scoped_malloc() {
std::free(p_);
}

View File

@@ -14,6 +14,7 @@ class MallocException : public ErrnoException {
};
void *MallocOrThrow(std::size_t requested);
void *CallocOrThrow(std::size_t requested);
class scoped_malloc {
public: