KenLM 7408730be415db9b650560a8b2bd3e4e3af49ec9

unistd.hh is dead.
Kenneth Heafield 2015-05-19 15:27:30 -04:00
parent 90309aebfa
commit a70d37e46f
65 changed files with 2421 additions and 617 deletions

View File

@ -170,6 +170,7 @@ void *BinaryFormat::SetupJustVocab(std::size_t memory_size, uint8_t order) {
if (!write_mmap_) {
header_size_ = 0;
util::MapAnonymous(memory_size, memory_vocab_);
util::AdviseHugePages(memory_vocab_.get(), memory_size);
return reinterpret_cast<uint8_t*>(memory_vocab_.get());
}
header_size_ = TotalHeaderSize(order);
@ -189,6 +190,7 @@ void *BinaryFormat::SetupJustVocab(std::size_t memory_size, uint8_t order) {
break;
}
strncpy(reinterpret_cast<char*>(vocab_base), kMagicIncomplete, header_size_);
util::AdviseHugePages(vocab_base, total);
return reinterpret_cast<uint8_t*>(vocab_base) + header_size_;
}
@ -201,6 +203,7 @@ void *BinaryFormat::GrowForSearch(std::size_t memory_size, std::size_t vocab_pad
util::MapAnonymous(memory_size, memory_search_);
assert(header_size_ == 0 || write_mmap_);
vocab_base = reinterpret_cast<uint8_t*>(memory_vocab_.get()) + header_size_;
util::AdviseHugePages(memory_search_.get(), memory_size);
return reinterpret_cast<uint8_t*>(memory_search_.get());
}
@ -214,6 +217,7 @@ void *BinaryFormat::GrowForSearch(std::size_t memory_size, std::size_t vocab_pad
util::ResizeOrThrow(file_.get(), new_size);
void *ret;
MapFile(vocab_base, ret);
util::AdviseHugePages(ret, new_size);
return ret;
}
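
The AdviseHugePages calls added above are purely advisory hints to the kernel. As a reading aid, a minimal sketch of such a helper, assuming a Linux build where madvise and MADV_HUGEPAGE exist (the actual util implementation may differ in details):

#include <sys/mman.h>
#include <cstddef>

namespace util {
// Ask the kernel to back this mapping with transparent huge pages.
// The hint may be unsupported or refused, so any failure is ignored.
void AdviseHugePages(const void *addr, std::size_t size) {
#ifdef MADV_HUGEPAGE
  madvise(const_cast<void*>(addr), size, MADV_HUGEPAGE);
#endif
}
} // namespace util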

View File

@ -1,5 +1,5 @@
fakelib builder : [ glob *.cc : *test.cc *main.cc ]
../../util//kenutil ../../util/stream//stream ../../util/double-conversion//double-conversion ..//kenlm
fakelib builder : [ glob *.cc : *test.cc *main.cc ]
../../util//kenutil ../../util/stream//stream ../../util/double-conversion//double-conversion ..//kenlm ../common//common
: : : <library>/top//boost_thread $(timer-link) ;
exe lmplz : lmplz_main.cc builder /top//boost_program_options ;

View File

@ -1,5 +1,6 @@
#include "lm/builder/adjust_counts.hh"
#include "lm/builder/ngram_stream.hh"
#include "lm/common/ngram_stream.hh"
#include "lm/builder/payload.hh"
#include "util/stream/timer.hh"
#include <algorithm>
@ -13,7 +14,7 @@ BadDiscountException::~BadDiscountException() throw() {}
namespace {
// Return last word in full that is different.
const WordIndex* FindDifference(const NGram &full, const NGram &lower_last) {
const WordIndex* FindDifference(const NGram<BuildingPayload> &full, const NGram<BuildingPayload> &lower_last) {
const WordIndex *cur_word = full.end() - 1;
const WordIndex *pre_word = lower_last.end() - 1;
// Find last difference.
@ -111,15 +112,15 @@ class StatCollector {
class CollapseStream {
public:
CollapseStream(const util::stream::ChainPosition &position, uint64_t prune_threshold, const std::vector<bool>& prune_words) :
current_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())),
current_(NULL, NGram<BuildingPayload>::OrderFromSize(position.GetChain().EntrySize())),
prune_threshold_(prune_threshold),
prune_words_(prune_words),
block_(position) {
StartBlock();
}
const NGram &operator*() const { return current_; }
const NGram *operator->() const { return &current_; }
const NGram<BuildingPayload> &operator*() const { return current_; }
const NGram<BuildingPayload> *operator->() const { return &current_; }
operator bool() const { return block_; }
@ -131,14 +132,14 @@ class CollapseStream {
UpdateCopyFrom();
// Mark highest order n-grams for later pruning
if(current_.Count() <= prune_threshold_) {
current_.Mark();
if(current_.Value().count <= prune_threshold_) {
current_.Value().Mark();
}
if(!prune_words_.empty()) {
for(WordIndex* i = current_.begin(); i != current_.end(); i++) {
if(prune_words_[*i]) {
current_.Mark();
current_.Value().Mark();
break;
}
}
@ -155,14 +156,14 @@ class CollapseStream {
}
// Mark highest order n-grams for later pruning
if(current_.Count() <= prune_threshold_) {
current_.Mark();
if(current_.Value().count <= prune_threshold_) {
current_.Value().Mark();
}
if(!prune_words_.empty()) {
for(WordIndex* i = current_.begin(); i != current_.end(); i++) {
if(prune_words_[*i]) {
current_.Mark();
current_.Value().Mark();
break;
}
}
@ -182,14 +183,14 @@ class CollapseStream {
UpdateCopyFrom();
// Mark highest order n-grams for later pruning
if(current_.Count() <= prune_threshold_) {
current_.Mark();
if(current_.Value().count <= prune_threshold_) {
current_.Value().Mark();
}
if(!prune_words_.empty()) {
for(WordIndex* i = current_.begin(); i != current_.end(); i++) {
if(prune_words_[*i]) {
current_.Mark();
current_.Value().Mark();
break;
}
}
@ -200,11 +201,11 @@ class CollapseStream {
// Find last without bos.
void UpdateCopyFrom() {
for (copy_from_ -= current_.TotalSize(); copy_from_ >= current_.Base(); copy_from_ -= current_.TotalSize()) {
if (NGram(copy_from_, current_.Order()).begin()[1] != kBOS) break;
if (NGram<BuildingPayload>(copy_from_, current_.Order()).begin()[1] != kBOS) break;
}
}
NGram current_;
NGram<BuildingPayload> current_;
// Goes backwards in the block
uint8_t *copy_from_;
@ -223,36 +224,36 @@ void AdjustCounts::Run(const util::stream::ChainPositions &positions) {
if (order == 1) {
// Only unigrams. Just collect stats.
for (NGramStream full(positions[0]); full; ++full) {
for (NGramStream<BuildingPayload> full(positions[0]); full; ++full) {
// Do not prune <s> </s> <unk>
if(*full->begin() > 2) {
if(full->Count() <= prune_thresholds_[0])
full->Mark();
if(full->Value().count <= prune_thresholds_[0])
full->Value().Mark();
if(!prune_words_.empty() && prune_words_[*full->begin()])
full->Mark();
full->Value().Mark();
}
stats.AddFull(full->UnmarkedCount(), full->IsMarked());
stats.AddFull(full->Value().UnmarkedCount(), full->Value().IsMarked());
}
stats.CalculateDiscounts(discount_config_);
return;
}
NGramStreams streams;
NGramStreams<BuildingPayload> streams;
streams.Init(positions, positions.size() - 1);
CollapseStream full(positions[positions.size() - 1], prune_thresholds_.back(), prune_words_);
// Initialization: <unk> has count 0 and so does <s>.
NGramStream *lower_valid = streams.begin();
const NGramStream *const streams_begin = streams.begin();
streams[0]->Count() = 0;
NGramStream<BuildingPayload> *lower_valid = streams.begin();
const NGramStream<BuildingPayload> *const streams_begin = streams.begin();
streams[0]->Value().count = 0;
*streams[0]->begin() = kUNK;
stats.Add(0, 0);
(++streams[0])->Count() = 0;
(++streams[0])->Value().count = 0;
*streams[0]->begin() = kBOS;
// <s> is not in stats yet because it will get put in later.
@ -271,28 +272,28 @@ void AdjustCounts::Run(const util::stream::ChainPositions &positions) {
for (; lower_valid >= &streams[same]; --lower_valid) {
uint64_t order_minus_1 = lower_valid - streams_begin;
if(actual_counts[order_minus_1] <= prune_thresholds_[order_minus_1])
(*lower_valid)->Mark();
(*lower_valid)->Value().Mark();
if(!prune_words_.empty()) {
for(WordIndex* i = (*lower_valid)->begin(); i != (*lower_valid)->end(); i++) {
if(prune_words_[*i]) {
(*lower_valid)->Mark();
(*lower_valid)->Value().Mark();
break;
}
}
}
stats.Add(order_minus_1, (*lower_valid)->UnmarkedCount(), (*lower_valid)->IsMarked());
stats.Add(order_minus_1, (*lower_valid)->Value().UnmarkedCount(), (*lower_valid)->Value().IsMarked());
++*lower_valid;
}
// STEP 2: Update n-grams that still match.
// n-grams that match get count from the full entry.
for (std::size_t i = 0; i < same; ++i) {
actual_counts[i] += full->UnmarkedCount();
actual_counts[i] += full->Value().UnmarkedCount();
}
// Increment the number of unique extensions for the longest match.
if (same) ++streams[same - 1]->Count();
if (same) ++streams[same - 1]->Value().count;
// STEP 3: Initialize new n-grams.
// This is here because bos is also const WordIndex *, so copy gets
@ -301,47 +302,47 @@ void AdjustCounts::Run(const util::stream::ChainPositions &positions) {
// Initialize and mark as valid up to bos.
const WordIndex *bos;
for (bos = different; (bos > full->begin()) && (*bos != kBOS); --bos) {
NGramStream &to = *++lower_valid;
NGramStream<BuildingPayload> &to = *++lower_valid;
std::copy(bos, full_end, to->begin());
to->Count() = 1;
actual_counts[lower_valid - streams_begin] = full->UnmarkedCount();
to->Value().count = 1;
actual_counts[lower_valid - streams_begin] = full->Value().UnmarkedCount();
}
// Now bos indicates where <s> is or is the 0th word of full.
if (bos != full->begin()) {
// There is an <s> beyond the 0th word.
NGramStream &to = *++lower_valid;
NGramStream<BuildingPayload> &to = *++lower_valid;
std::copy(bos, full_end, to->begin());
// Anything that begins with <s> has full non adjusted count.
to->Count() = full->UnmarkedCount();
actual_counts[lower_valid - streams_begin] = full->UnmarkedCount();
to->Value().count = full->Value().UnmarkedCount();
actual_counts[lower_valid - streams_begin] = full->Value().UnmarkedCount();
} else {
stats.AddFull(full->UnmarkedCount(), full->IsMarked());
stats.AddFull(full->Value().UnmarkedCount(), full->Value().IsMarked());
}
assert(lower_valid >= &streams[0]);
}
// The above loop outputs n-grams when it observes changes. This outputs
// the last n-grams.
for (NGramStream *s = streams.begin(); s <= lower_valid; ++s) {
for (NGramStream<BuildingPayload> *s = streams.begin(); s <= lower_valid; ++s) {
uint64_t lower_count = actual_counts[(*s)->Order() - 1];
if(lower_count <= prune_thresholds_[(*s)->Order() - 1])
(*s)->Mark();
(*s)->Value().Mark();
if(!prune_words_.empty()) {
for(WordIndex* i = (*s)->begin(); i != (*s)->end(); i++) {
if(prune_words_[*i]) {
(*s)->Mark();
(*s)->Value().Mark();
break;
}
}
}
stats.Add(s - streams.begin(), lower_count, (*s)->IsMarked());
stats.Add(s - streams.begin(), lower_count, (*s)->Value().IsMarked());
++*s;
}
// Poison everyone! Except the N-grams which were already poisoned by the input.
for (NGramStream *s = streams.begin(); s != streams.end(); ++s)
for (NGramStream<BuildingPayload> *s = streams.begin(); s != streams.end(); ++s)
s->Poison();
stats.CalculateDiscounts(discount_config_);
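
For orientation, the quantity AdjustCounts produces is the Kneser-Ney adjusted count (standard modified Kneser-Ney, stated here as a reading aid rather than taken from the diff): the highest order N keeps raw counts, while each lower-order n-gram is counted by its distinct left extensions,

  a(w_1...w_n) = |{ v : c(v w_1...w_n) > 0 }|  for n < N,    a(w_1...w_N) = c(w_1...w_N),

except that n-grams beginning with <s> keep their raw counts, since nothing can precede <s>. STEP 2's "unique extensions" increment and the special <s> branch in STEP 3 implement exactly these two rules.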

View File

@ -1,6 +1,7 @@
#include "lm/builder/adjust_counts.hh"
#include "lm/builder/ngram_stream.hh"
#include "lm/common/ngram_stream.hh"
#include "lm/builder/payload.hh"
#include "util/scoped.hh"
#include <boost/thread/thread.hpp>
@ -37,7 +38,7 @@ struct Gram4 {
class WriteInput {
public:
void Run(const util::stream::ChainPosition &position) {
NGramStream input(position);
NGramStream<BuildingPayload> input(position);
Gram4 grams[] = {
{{0,0,0,0},10},
{{0,0,3,0},3},
@ -47,7 +48,7 @@ class WriteInput {
};
for (size_t i = 0; i < sizeof(grams) / sizeof(Gram4); ++i, ++input) {
memcpy(input->begin(), grams[i].ids, sizeof(WordIndex) * 4);
input->Count() = grams[i].count;
input->Value().count = grams[i].count;
}
input.Poison();
}
@ -63,7 +64,7 @@ BOOST_AUTO_TEST_CASE(Simple) {
config.block_count = 1;
util::stream::Chains chains(4);
for (unsigned i = 0; i < 4; ++i) {
config.entry_size = NGram::TotalSize(i + 1);
config.entry_size = NGram<BuildingPayload>::TotalSize(i + 1);
chains.push_back(config);
}
@ -86,25 +87,25 @@ BOOST_AUTO_TEST_CASE(Simple) {
/* BOOST_CHECK_EQUAL(4UL, counts[1]);
BOOST_CHECK_EQUAL(3UL, counts[2]);
BOOST_CHECK_EQUAL(3UL, counts[3]);*/
BOOST_REQUIRE_EQUAL(NGram::TotalSize(1) * 4, outputs[0].Size());
NGram uni(outputs[0].Get(), 1);
BOOST_REQUIRE_EQUAL(NGram<BuildingPayload>::TotalSize(1) * 4, outputs[0].Size());
NGram<BuildingPayload> uni(outputs[0].Get(), 1);
BOOST_CHECK_EQUAL(kUNK, *uni.begin());
BOOST_CHECK_EQUAL(0ULL, uni.Count());
BOOST_CHECK_EQUAL(0ULL, uni.Value().count);
uni.NextInMemory();
BOOST_CHECK_EQUAL(kBOS, *uni.begin());
BOOST_CHECK_EQUAL(0ULL, uni.Count());
BOOST_CHECK_EQUAL(0ULL, uni.Value().count);
uni.NextInMemory();
BOOST_CHECK_EQUAL(0UL, *uni.begin());
BOOST_CHECK_EQUAL(2ULL, uni.Count());
BOOST_CHECK_EQUAL(2ULL, uni.Value().count);
uni.NextInMemory();
BOOST_CHECK_EQUAL(2ULL, uni.Count());
BOOST_CHECK_EQUAL(2ULL, uni.Value().count);
BOOST_CHECK_EQUAL(2UL, *uni.begin());
BOOST_REQUIRE_EQUAL(NGram::TotalSize(2) * 4, outputs[1].Size());
NGram bi(outputs[1].Get(), 2);
BOOST_REQUIRE_EQUAL(NGram<BuildingPayload>::TotalSize(2) * 4, outputs[1].Size());
NGram<BuildingPayload> bi(outputs[1].Get(), 2);
BOOST_CHECK_EQUAL(0UL, *bi.begin());
BOOST_CHECK_EQUAL(0UL, *(bi.begin() + 1));
BOOST_CHECK_EQUAL(1ULL, bi.Count());
BOOST_CHECK_EQUAL(1ULL, bi.Value().count);
bi.NextInMemory();
}

View File

@ -0,0 +1,31 @@
#ifndef LM_BUILDER_COMBINE_COUNTS_H
#define LM_BUILDER_COMBINE_COUNTS_H
#include "lm/builder/payload.hh"
#include "lm/common/ngram.hh"
#include "lm/common/compare.hh"
#include "lm/word_index.hh"
#include "util/stream/sort.hh"
#include <functional>
#include <string>
namespace lm {
namespace builder {
// Sum counts for the same n-gram.
struct CombineCounts {
bool operator()(void *first_void, const void *second_void, const SuffixOrder &compare) const {
NGram<BuildingPayload> first(first_void, compare.Order());
// There isn't a const version of NGram.
NGram<BuildingPayload> second(const_cast<void*>(second_void), compare.Order());
if (memcmp(first.begin(), second.begin(), sizeof(WordIndex) * compare.Order())) return false;
first.Value().count += second.Value().count;
return true;
}
};
} // namespace builder
} // namespace lm
#endif // LM_BUILDER_COMBINE_COUNTS_H
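
CombineCounts is the combiner plugged into the streaming sort, replacing the old AddCombiner: whenever two adjacent records compare equal under SuffixOrder, their counts are summed and only one record survives. Its use later in this commit's pipeline.cc, shown in isolation as a sketch:

// Sort counted n-grams by suffix, merging duplicates as they meet.
util::stream::Sort<SuffixOrder, CombineCounts> sorter(
    chain,                     // unsorted NGram<BuildingPayload> records
    config.sort,               // buffer sizes and temp file prefix
    SuffixOrder(config.order),
    CombineCounts());          // sums counts of identical n-grams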

View File

@ -1,6 +1,7 @@
#include "lm/builder/corpus_count.hh"
#include "lm/builder/ngram.hh"
#include "lm/builder/payload.hh"
#include "lm/common/ngram.hh"
#include "lm/lm_exception.hh"
#include "lm/vocab.hh"
#include "lm/word_index.hh"
@ -25,19 +26,6 @@ namespace lm {
namespace builder {
namespace {
#pragma pack(push)
#pragma pack(4)
struct VocabEntry {
typedef uint64_t Key;
uint64_t GetKey() const { return key; }
void SetKey(uint64_t to) { key = to; }
uint64_t key;
lm::WordIndex value;
};
#pragma pack(pop)
class DedupeHash : public std::unary_function<const WordIndex *, bool> {
public:
explicit DedupeHash(std::size_t order) : size_(order * sizeof(WordIndex)) {}
@ -115,17 +103,17 @@ class Writer {
bool found = dedupe_.FindOrInsert(DedupeEntry::Construct(gram_.begin()), at);
if (found) {
// Already present.
NGram already(at->key, gram_.Order());
++(already.Count());
NGram<BuildingPayload> already(at->key, gram_.Order());
++(already.Value().count);
// Shift left by one.
memmove(gram_.begin(), gram_.begin() + 1, sizeof(WordIndex) * (gram_.Order() - 1));
return;
}
// Complete the write.
gram_.Count() = 1;
gram_.Value().count = 1;
// Prepare the next n-gram.
if (reinterpret_cast<uint8_t*>(gram_.begin()) + gram_.TotalSize() != static_cast<uint8_t*>(block_->Get()) + block_size_) {
NGram last(gram_);
NGram<BuildingPayload> last(gram_);
gram_.NextInMemory();
std::copy(last.begin() + 1, last.end(), gram_.begin());
return;
@ -141,7 +129,7 @@ class Writer {
private:
void AddUnigramWord(WordIndex index) {
*gram_.begin() = index;
gram_.Count() = 0;
gram_.Value().count = 0;
gram_.NextInMemory();
if (gram_.Base() == static_cast<uint8_t*>(block_->Get()) + block_size_) {
block_->SetValidSize(block_size_);
@ -151,7 +139,7 @@ class Writer {
util::stream::Link block_;
NGram gram_;
NGram<BuildingPayload> gram_;
// This is the memory behind the invalid value in dedupe_.
std::vector<WordIndex> dedupe_invalid_;
@ -167,7 +155,7 @@ class Writer {
} // namespace
float CorpusCount::DedupeMultiplier(std::size_t order) {
return kProbingMultiplier * static_cast<float>(sizeof(DedupeEntry)) / static_cast<float>(NGram::TotalSize(order));
return kProbingMultiplier * static_cast<float>(sizeof(DedupeEntry)) / static_cast<float>(NGram<BuildingPayload>::TotalSize(order));
}
std::size_t CorpusCount::VocabUsage(std::size_t vocab_estimate) {
@ -202,7 +190,7 @@ void CorpusCount::Run(const util::stream::ChainPosition &position) {
token_count_ = 0;
type_count_ = 0;
const WordIndex end_sentence = vocab.FindOrInsert("</s>");
Writer writer(NGram::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_);
Writer writer(NGram<BuildingPayload>::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_);
uint64_t count = 0;
bool delimiters[256];
util::BoolCharacter::Build("\0\t\n\r ", delimiters);
@ -233,9 +221,8 @@ void CorpusCount::Run(const util::stream::ChainPosition &position) {
prune_words_.resize(vocab.Size(), true);
try {
while (true) {
StringPiece line(prune_vocab_file.ReadLine());
for (util::TokenIter<util::BoolCharacter, true> w(line, delimiters); w; ++w)
prune_words_[vocab.Index(*w)] = false;
StringPiece word(prune_vocab_file.ReadDelimited(delimiters));
prune_words_[vocab.Index(word)] = false;
}
} catch (const util::EndOfFileException &e) {}
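
The replacement reads one whitespace-delimited token at a time rather than pulling whole lines and re-tokenizing them. util::FilePiece::ReadDelimited skips delimiter bytes and throws util::EndOfFileException at end of input, which is what breaks the while (true) loop above. The same idiom as a standalone sketch (file name hypothetical):

util::FilePiece vocab_file("limit_vocab.txt");
bool delimiters[256];
util::BoolCharacter::Build("\0\t\n\r ", delimiters);
try {
  while (true) {
    StringPiece word(vocab_file.ReadDelimited(delimiters));
    // look up `word` and record that it should not be pruned
  }
} catch (const util::EndOfFileException &e) {}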

View File

@ -1,7 +1,8 @@
#include "lm/builder/corpus_count.hh"
#include "lm/builder/ngram.hh"
#include "lm/builder/ngram_stream.hh"
#include "lm/builder/payload.hh"
#include "lm/common/ngram_stream.hh"
#include "lm/common/ngram.hh"
#include "util/file.hh"
#include "util/file_piece.hh"
@ -14,13 +15,13 @@
namespace lm { namespace builder { namespace {
#define Check(str, count) { \
#define Check(str, cnt) { \
BOOST_REQUIRE(stream); \
w = stream->begin(); \
for (util::TokenIter<util::AnyCharacter, true> t(str, " "); t; ++t, ++w) { \
BOOST_CHECK_EQUAL(*t, v[*w]); \
} \
BOOST_CHECK_EQUAL((uint64_t)count, stream->Count()); \
BOOST_CHECK_EQUAL((uint64_t)cnt, stream->Value().count); \
++stream; \
}
@ -35,14 +36,14 @@ BOOST_AUTO_TEST_CASE(Short) {
util::FilePiece input_piece(input_file.release(), "temp file");
util::stream::ChainConfig config;
config.entry_size = NGram::TotalSize(3);
config.entry_size = NGram<BuildingPayload>::TotalSize(3);
config.total_memory = config.entry_size * 20;
config.block_count = 2;
util::scoped_fd vocab(util::MakeTemp("corpus_count_test_vocab"));
util::stream::Chain chain(config);
NGramStream stream;
NGramStream<BuildingPayload> stream;
uint64_t token_count;
WordIndex type_count = 10;
std::vector<bool> prune_words;

View File

@ -1,9 +1,10 @@
#include "lm/builder/initial_probabilities.hh"
#include "lm/builder/discount.hh"
#include "lm/builder/ngram_stream.hh"
#include "lm/builder/sort.hh"
#include "lm/builder/special.hh"
#include "lm/builder/hash_gamma.hh"
#include "lm/builder/payload.hh"
#include "lm/common/ngram_stream.hh"
#include "util/murmur_hash.hh"
#include "util/file.hh"
#include "util/stream/chain.hh"
@ -32,17 +33,18 @@ struct HashBufferEntry : public BufferEntry {
// threshold.
class PruneNGramStream {
public:
PruneNGramStream(const util::stream::ChainPosition &position) :
current_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())),
dest_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())),
PruneNGramStream(const util::stream::ChainPosition &position, const SpecialVocab &specials) :
current_(NULL, NGram<BuildingPayload>::OrderFromSize(position.GetChain().EntrySize())),
dest_(NULL, NGram<BuildingPayload>::OrderFromSize(position.GetChain().EntrySize())),
currentCount_(0),
block_(position)
block_(position),
specials_(specials)
{
StartBlock();
}
NGram &operator*() { return current_; }
NGram *operator->() { return &current_; }
NGram<BuildingPayload> &operator*() { return current_; }
NGram<BuildingPayload> *operator->() { return &current_; }
operator bool() const {
return block_;
@ -50,8 +52,7 @@ class PruneNGramStream {
PruneNGramStream &operator++() {
assert(block_);
if(current_.Order() == 1 && *current_.begin() <= 2)
if(UTIL_UNLIKELY(current_.Order() == 1 && specials_.IsSpecial(*current_.begin())))
dest_.NextInMemory();
else if(currentCount_ > 0) {
if(dest_.Base() < current_.Base()) {
@ -68,10 +69,10 @@ class PruneNGramStream {
++block_;
StartBlock();
if (block_) {
currentCount_ = current_.CutoffCount();
currentCount_ = current_.Value().CutoffCount();
}
} else {
currentCount_ = current_.CutoffCount();
currentCount_ = current_.Value().CutoffCount();
}
return *this;
@ -84,23 +85,25 @@ class PruneNGramStream {
if (block_->ValidSize()) break;
}
current_.ReBase(block_->Get());
currentCount_ = current_.CutoffCount();
currentCount_ = current_.Value().CutoffCount();
dest_.ReBase(block_->Get());
}
NGram current_; // input iterator
NGram dest_; // output iterator
NGram<BuildingPayload> current_; // input iterator
NGram<BuildingPayload> dest_; // output iterator
uint64_t currentCount_;
util::stream::Link block_;
const SpecialVocab specials_;
};
// Extract an array of HashedGamma from an array of BufferEntry.
class OnlyGamma {
public:
OnlyGamma(bool pruning) : pruning_(pruning) {}
explicit OnlyGamma(bool pruning) : pruning_(pruning) {}
void Run(const util::stream::ChainPosition &position) {
for (util::stream::Link block_it(position); block_it; ++block_it) {
@ -143,7 +146,7 @@ class AddRight {
: discount_(discount), input_(input), pruning_(pruning) {}
void Run(const util::stream::ChainPosition &output) {
NGramStream in(input_);
NGramStream<BuildingPayload> in(input_);
util::stream::Stream out(output);
std::vector<WordIndex> previous(in->Order() - 1);
@ -159,17 +162,17 @@ class AddRight {
uint64_t counts[4];
memset(counts, 0, sizeof(counts));
do {
denominator += in->UnmarkedCount();
denominator += in->Value().UnmarkedCount();
// Collect unused probability mass from pruning.
// Becomes 0 for unpruned ngrams.
normalizer += in->UnmarkedCount() - in->CutoffCount();
normalizer += in->Value().UnmarkedCount() - in->Value().CutoffCount();
// Chen&Goodman do not mention counting based on cutoffs, but
// backoff becomes larger than 1 otherwise, so probably needs
// to count cutoffs. Counts normally without pruning.
if(in->CutoffCount() > 0)
++counts[std::min(in->CutoffCount(), static_cast<uint64_t>(3))];
if(in->Value().CutoffCount() > 0)
++counts[std::min(in->Value().CutoffCount(), static_cast<uint64_t>(3))];
} while (++in && !memcmp(previous_raw, in->begin(), size));
@ -202,15 +205,15 @@ class AddRight {
class MergeRight {
public:
MergeRight(bool interpolate_unigrams, const util::stream::ChainPosition &from_adder, const Discount &discount)
: interpolate_unigrams_(interpolate_unigrams), from_adder_(from_adder), discount_(discount) {}
MergeRight(bool interpolate_unigrams, const util::stream::ChainPosition &from_adder, const Discount &discount, const SpecialVocab &specials)
: interpolate_unigrams_(interpolate_unigrams), from_adder_(from_adder), discount_(discount), specials_(specials) {}
// calculate the initial probability of each n-gram (before order-interpolation)
// Run() gets invoked once for each order
void Run(const util::stream::ChainPosition &primary) {
util::stream::Stream summed(from_adder_);
PruneNGramStream grams(primary);
PruneNGramStream grams(primary, specials_);
// Without interpolation, the interpolation weight goes to <unk>.
if (grams->Order() == 1) {
@ -228,17 +231,21 @@ class MergeRight {
grams->Value().uninterp.prob = sums.gamma;
}
grams->Value().uninterp.gamma = gamma_assign;
++grams;
for (++grams; *grams->begin() != specials_.BOS(); ++grams) {
grams->Value().uninterp.prob = discount_.Apply(grams->Value().count) / sums.denominator;
grams->Value().uninterp.gamma = gamma_assign;
}
// Special case for <s>: probability 1.0. This allows <s> to be
// explicitly scores as part of the sentence without impacting
// explicitly scored as part of the sentence without impacting
// probability and computes q correctly as b(<s>).
assert(*grams->begin() == kBOS);
assert(*grams->begin() == specials_.BOS());
grams->Value().uninterp.prob = 1.0;
grams->Value().uninterp.gamma = 0.0;
while (++grams) {
grams->Value().uninterp.prob = discount_.Apply(grams->Count()) / sums.denominator;
grams->Value().uninterp.prob = discount_.Apply(grams->Value().count) / sums.denominator;
grams->Value().uninterp.gamma = gamma_assign;
}
++summed;
@ -252,8 +259,8 @@ class MergeRight {
const BufferEntry &sums = *static_cast<const BufferEntry*>(summed.Get());
do {
Payload &pay = grams->Value();
pay.uninterp.prob = discount_.Apply(grams->UnmarkedCount()) / sums.denominator;
BuildingPayload &pay = grams->Value();
pay.uninterp.prob = discount_.Apply(grams->Value().UnmarkedCount()) / sums.denominator;
pay.uninterp.gamma = sums.gamma;
} while (++grams && !memcmp(&previous[0], grams->begin(), size));
}
@ -263,6 +270,7 @@ class MergeRight {
bool interpolate_unigrams_;
util::stream::ChainPosition from_adder_;
Discount discount_;
const SpecialVocab specials_;
};
} // namespace
@ -274,7 +282,8 @@ void InitialProbabilities(
util::stream::Chains &second_in,
util::stream::Chains &gamma_out,
const std::vector<uint64_t> &prune_thresholds,
bool prune_vocab) {
bool prune_vocab,
const SpecialVocab &specials) {
for (size_t i = 0; i < primary.size(); ++i) {
util::stream::ChainConfig gamma_config = config.adder_out;
if(prune_vocab || prune_thresholds[i] > 0)
@ -287,7 +296,7 @@ void InitialProbabilities(
gamma_out.push_back(gamma_config);
gamma_out[i] >> AddRight(discounts[i], second, prune_vocab || prune_thresholds[i] > 0);
primary[i] >> MergeRight(config.interpolate_unigrams, gamma_out[i].Add(), discounts[i]);
primary[i] >> MergeRight(config.interpolate_unigrams, gamma_out[i].Add(), discounts[i], specials);
// Don't bother with the OnlyGamma thread for something to discard.
if (i) gamma_out[i] >> OnlyGamma(prune_vocab || prune_thresholds[i] > 0);
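
As a reading aid (this is standard modified Kneser-Ney, not something this commit changes): for each context, AddRight accumulates the denominator and the counts-of-counts n_1, n_2, n_{3+} of its extensions after pruning cutoffs. MergeRight then assigns

  p_uninterp(w | context) = (c(context w) - D(c)) / denominator
  gamma(context) = (D_1 n_1 + D_2 n_2 + D_3 n_{3+}) / denominator

where denominator = sum_w c(context w) and the D_i are the order-specific discounts. The probability mass removed by discounting, plus the mass freed by pruning (tracked in normalizer above), is what gamma redistributes to the lower order during interpolation.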

View File

@ -2,6 +2,7 @@
#define LM_BUILDER_INITIAL_PROBABILITIES_H
#include "lm/builder/discount.hh"
#include "lm/word_index.hh"
#include "util/stream/config.hh"
#include <vector>
@ -11,6 +12,8 @@ namespace util { namespace stream { class Chains; } }
namespace lm {
namespace builder {
class SpecialVocab;
struct InitialProbabilitiesConfig {
// These should be small buffers to keep the adder from getting too far ahead
util::stream::ChainConfig adder_in;
@ -34,7 +37,8 @@ void InitialProbabilities(
util::stream::Chains &second_in,
util::stream::Chains &gamma_out,
const std::vector<uint64_t> &prune_thresholds,
bool prune_vocab);
bool prune_vocab,
const SpecialVocab &vocab);
} // namespace builder
} // namespace lm
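
SpecialVocab, declared in lm/builder/special.hh (not shown in this commit view), carries the word ids of <unk>, <s>, and </s> instead of the hard-coded kBOS/kEOS constants; once --renumber is in effect those ids are no longer fixed at 1 and 2. Judging only from its uses in this diff (construction from two ids, BOS(), EOS(), UNK(), IsSpecial()), a compatible sketch:

class SpecialVocab {
 public:
  SpecialVocab(WordIndex bos, WordIndex eos) : bos_(bos), eos_(eos) {}
  WordIndex UNK() const { return 0; }  // assumes <unk> stays at id 0
  WordIndex BOS() const { return bos_; }
  WordIndex EOS() const { return eos_; }
  bool IsSpecial(WordIndex word) const {
    return word == UNK() || word == BOS() || word == EOS();
  }
 private:
  WordIndex bos_;
  WordIndex eos_;
};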

View File

@ -2,8 +2,8 @@
#include "lm/builder/hash_gamma.hh"
#include "lm/builder/joint_order.hh"
#include "lm/builder/ngram_stream.hh"
#include "lm/builder/sort.hh"
#include "lm/common/ngram_stream.hh"
#include "lm/common/compare.hh"
#include "lm/lm_exception.hh"
#include "util/fixed_array.hh"
#include "util/murmur_hash.hh"
@ -65,11 +65,12 @@ class OutputProbBackoff {
template <class Output> class Callback {
public:
Callback(float uniform_prob, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t> &prune_thresholds, bool prune_vocab)
Callback(float uniform_prob, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t> &prune_thresholds, bool prune_vocab, const SpecialVocab &specials)
: backoffs_(backoffs.size()), probs_(backoffs.size() + 2),
prune_thresholds_(prune_thresholds),
prune_vocab_(prune_vocab),
output_(backoffs.size() + 1 /* order */) {
output_(backoffs.size() + 1 /* order */),
specials_(specials) {
probs_[0] = uniform_prob;
for (std::size_t i = 0; i < backoffs.size(); ++i) {
backoffs_.push_back(backoffs[i]);
@ -89,13 +90,13 @@ template <class Output> class Callback {
}
}
void Enter(unsigned order_minus_1, NGram &gram) {
Payload &pay = gram.Value();
void Enter(unsigned order_minus_1, NGram<BuildingPayload> &gram) {
BuildingPayload &pay = gram.Value();
pay.complete.prob = pay.uninterp.prob + pay.uninterp.gamma * probs_[order_minus_1];
probs_[order_minus_1 + 1] = pay.complete.prob;
float out_backoff;
if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != kUNK && *(gram.end() - 1) != kEOS && backoffs_[order_minus_1]) {
if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != specials_.UNK() && *(gram.end() - 1) != specials_.EOS() && backoffs_[order_minus_1]) {
if(prune_vocab_ || prune_thresholds_[order_minus_1 + 1] > 0) {
//Compute hash value for current context
uint64_t current_hash = util::MurmurHashNative(gram.begin(), gram.Order() * sizeof(WordIndex));
@ -123,7 +124,7 @@ template <class Output> class Callback {
output_.Gram(order_minus_1, out_backoff, pay.complete);
}
void Exit(unsigned, const NGram &) const {}
void Exit(unsigned, const NGram<BuildingPayload> &) const {}
private:
util::FixedArray<util::stream::Stream> backoffs_;
@ -133,26 +134,28 @@ template <class Output> class Callback {
bool prune_vocab_;
Output output_;
const SpecialVocab specials_;
};
} // namespace
Interpolate::Interpolate(uint64_t vocab_size, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t>& prune_thresholds, bool prune_vocab, bool output_q)
Interpolate::Interpolate(uint64_t vocab_size, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t>& prune_thresholds, bool prune_vocab, bool output_q, const SpecialVocab &specials)
: uniform_prob_(1.0 / static_cast<float>(vocab_size)), // Includes <unk> but excludes <s>.
backoffs_(backoffs),
prune_thresholds_(prune_thresholds),
prune_vocab_(prune_vocab),
output_q_(output_q) {}
output_q_(output_q),
specials_(specials) {}
// perform order-wise interpolation
void Interpolate::Run(const util::stream::ChainPositions &positions) {
assert(positions.size() == backoffs_.size() + 1);
if (output_q_) {
typedef Callback<OutputQ> C;
C callback(uniform_prob_, backoffs_, prune_thresholds_, prune_vocab_);
C callback(uniform_prob_, backoffs_, prune_thresholds_, prune_vocab_, specials_);
JointOrder<C, SuffixOrder>(positions, callback);
} else {
typedef Callback<OutputProbBackoff> C;
C callback(uniform_prob_, backoffs_, prune_thresholds_, prune_vocab_);
C callback(uniform_prob_, backoffs_, prune_thresholds_, prune_vocab_, specials_);
JointOrder<C, SuffixOrder>(positions, callback);
}
}

View File

@ -1,6 +1,8 @@
#ifndef LM_BUILDER_INTERPOLATE_H
#define LM_BUILDER_INTERPOLATE_H
#include "lm/builder/special.hh"
#include "lm/word_index.hh"
#include "util/stream/multi_stream.hh"
#include <vector>
@ -18,7 +20,7 @@ class Interpolate {
public:
// Normally vocab_size is the unigram count-1 (since p(<s>) = 0) but might
// be larger when the user specifies a consistent vocabulary size.
explicit Interpolate(uint64_t vocab_size, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t> &prune_thresholds, bool prune_vocab, bool output_q_);
explicit Interpolate(uint64_t vocab_size, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t> &prune_thresholds, bool prune_vocab, bool output_q, const SpecialVocab &specials);
void Run(const util::stream::ChainPositions &positions);
@ -28,6 +30,7 @@ class Interpolate {
const std::vector<uint64_t> prune_thresholds_;
bool prune_vocab_;
bool output_q_;
const SpecialVocab specials_;
};
}} // namespaces

View File

@ -1,7 +1,8 @@
#ifndef LM_BUILDER_JOINT_ORDER_H
#define LM_BUILDER_JOINT_ORDER_H
#include "lm/builder/ngram_stream.hh"
#include "lm/common/ngram_stream.hh"
#include "lm/builder/payload.hh"
#include "lm/lm_exception.hh"
#ifdef DEBUG
@ -15,9 +16,9 @@ namespace lm { namespace builder {
template <class Callback, class Compare> void JointOrder(const util::stream::ChainPositions &positions, Callback &callback) {
// Allow matching to reference streams[-1].
NGramStreams streams_with_dummy;
NGramStreams<BuildingPayload> streams_with_dummy;
streams_with_dummy.InitWithDummy(positions);
NGramStream *streams = streams_with_dummy.begin() + 1;
NGramStream<BuildingPayload> *streams = streams_with_dummy.begin() + 1;
unsigned int order;
for (order = 0; order < positions.size() && streams[order]; ++order) {}

View File

@ -87,7 +87,7 @@ int main(int argc, char *argv[]) {
po::options_description options("Language model building options");
lm::builder::PipelineConfig pipeline;
std::string text, arpa;
std::string text, intermediate, arpa;
std::vector<std::string> pruning;
std::vector<std::string> discount_fallback;
std::vector<std::string> discount_fallback_default;
@ -116,6 +116,8 @@ int main(int argc, char *argv[]) {
("verbose_header", po::bool_switch(&verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc.")
("text", po::value<std::string>(&text), "Read text from a file instead of stdin")
("arpa", po::value<std::string>(&arpa), "Write ARPA to a file instead of stdout")
("intermediate", po::value<std::string>(&intermediate), "Write ngrams to an intermediate file. Turns off ARPA output (which can be reactivated by --arpa file). Forces --renumber on. Implicitly makes --vocab_file be the provided name + .vocab.")
("renumber", po::bool_switch(&pipeline.renumber_vocabulary), "Rrenumber the vocabulary identifiers so that they are monotone with the hash of each string. This is consistent with the ordering used by the trie data structure.")
("collapse_values", po::bool_switch(&pipeline.output_q), "Collapse probability and backoff into a single value, q that yields the same sentence-level probabilities. See http://kheafield.com/professional/edinburgh/rest_paper.pdf for more details, including a proof.")
("prune", po::value<std::vector<std::string> >(&pruning)->multitoken(), "Prune n-grams with count less than or equal to the given threshold. Specify one value for each order i.e. 0 0 1 to prune singleton trigrams and above. The sequence of values must be non-decreasing and the last value applies to any remaining orders. Default is to not prune, which is equivalent to --prune 0.")
("limit_vocab_file", po::value<std::string>(&pipeline.prune_vocab_file)->default_value(""), "Read allowed vocabulary separated by whitespace. N-grams that contain vocabulary items not in this list will be pruned. Can be combined with --prune arg")
@ -212,8 +214,19 @@ int main(int argc, char *argv[]) {
}
try {
lm::builder::Output output;
output.Add(new lm::builder::PrintARPA(out.release(), verbose_header));
bool writing_intermediate = vm.count("intermediate");
if (writing_intermediate) {
pipeline.renumber_vocabulary = true;
if (!pipeline.vocab_file.empty()) {
std::cerr << "--intermediate and --vocab_file are incompatible because --intermediate already makes a vocab file." << std::endl;
return 1;
}
pipeline.vocab_file = intermediate + ".vocab";
}
lm::builder::Output output(writing_intermediate ? intermediate : pipeline.sort.temp_prefix, writing_intermediate);
if (!writing_intermediate || vm.count("arpa")) {
output.Add(new lm::builder::PrintARPA(out.release(), verbose_header));
}
lm::builder::Pipeline(pipeline, in.release(), output);
} catch (const util::MallocException &e) {
std::cerr << e.what() << std::endl;

View File

@ -1,14 +1,41 @@
#include "lm/builder/output.hh"
#include "lm/common/model_buffer.hh"
#include "util/stream/multi_stream.hh"
#include <boost/ref.hpp>
#include <iostream>
namespace lm { namespace builder {
OutputHook::~OutputHook() {}
void OutputHook::Apply(util::stream::Chains &chains) {
chains >> boost::ref(*this);
Output::Output(StringPiece file_base, bool keep_buffer)
: file_base_(file_base.data(), file_base.size()), keep_buffer_(keep_buffer) {}
void Output::SinkProbs(util::stream::Chains &chains, bool output_q) {
Apply(PROB_PARALLEL_HOOK, chains);
if (!keep_buffer_ && !Have(PROB_SEQUENTIAL_HOOK)) {
chains >> util::stream::kRecycle;
chains.Wait(true);
return;
}
lm::common::ModelBuffer buf(file_base_, keep_buffer_, output_q);
buf.Sink(chains);
chains >> util::stream::kRecycle;
chains.Wait(false);
if (Have(PROB_SEQUENTIAL_HOOK)) {
std::cerr << "=== 5/5 Writing ARPA model ===" << std::endl;
buf.Source(chains);
Apply(PROB_SEQUENTIAL_HOOK, chains);
chains >> util::stream::kRecycle;
chains.Wait(true);
}
}
void Output::Apply(HookType hook_type, util::stream::Chains &chains) {
for (boost::ptr_vector<OutputHook>::iterator entry = outputs_[hook_type].begin(); entry != outputs_[hook_type].end(); ++entry) {
entry->Sink(chains);
}
}
}} // namespaces
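
The control flow in SinkProbs reflects the hook contract from output.hh: parallel hooks read all orders in lockstep off the live chains, while a sequential hook (the ARPA printer) reads order 1 to completion before touching order 2, which would deadlock against a lockstep producer. So when a sequential hook or --intermediate is present, the chains are first drained to disk through lm::common::ModelBuffer and then replayed. Condensed to the essential calls as used above (local names stand in for the members file_base_, keep_buffer_):

lm::common::ModelBuffer buf(file_base, keep_buffer, output_q);
buf.Sink(chains);                  // write the buffered n-grams to disk
chains >> util::stream::kRecycle;
chains.Wait(false);
buf.Source(chains);                // replay the orders for sequential hooks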

View File

@ -7,16 +7,14 @@
#include <boost/ptr_container/ptr_vector.hpp>
#include <boost/utility.hpp>
#include <map>
namespace util { namespace stream { class Chains; class ChainPositions; } }
/* Outputs from lmplz: ARPA< sharded files, etc */
/* Outputs from lmplz: ARPA, sharded files, etc */
namespace lm { namespace builder {
// These are different types of hooks. Values should be consecutive to enable a vector lookup.
enum HookType {
COUNT_HOOK, // Raw N-gram counts, highest order only.
// TODO: counts.
PROB_PARALLEL_HOOK, // Probability and backoff (or just q). Output must process the orders in parallel or there will be a deadlock.
PROB_SEQUENTIAL_HOOK, // Probability and backoff (or just q). Output can process orders any way it likes. This requires writing the data to disk then reading. Useful for ARPA files, which put unigrams first etc.
NUMBER_OF_HOOKS // Keep this last so we know how many values there are.
@ -30,9 +28,7 @@ class OutputHook {
virtual ~OutputHook();
virtual void Apply(util::stream::Chains &chains);
virtual void Run(const util::stream::ChainPositions &positions) = 0;
virtual void Sink(util::stream::Chains &chains) = 0;
protected:
const HeaderInfo &GetHeader() const;
@ -46,7 +42,7 @@ class OutputHook {
class Output : boost::noncopyable {
public:
Output() {}
Output(StringPiece file_base, bool keep_buffer);
// Takes ownership.
void Add(OutputHook *hook) {
@ -64,16 +60,20 @@ class Output : boost::noncopyable {
void SetHeader(const HeaderInfo &header) { header_ = header; }
const HeaderInfo &GetHeader() const { return header_; }
void Apply(HookType hook_type, util::stream::Chains &chains) {
for (boost::ptr_vector<OutputHook>::iterator entry = outputs_[hook_type].begin(); entry != outputs_[hook_type].end(); ++entry) {
entry->Apply(chains);
}
}
// This is called by the pipeline.
void SinkProbs(util::stream::Chains &chains, bool output_q);
unsigned int Steps() const { return Have(PROB_SEQUENTIAL_HOOK); }
private:
void Apply(HookType hook_type, util::stream::Chains &chains);
boost::ptr_vector<OutputHook> outputs_[NUMBER_OF_HOOKS];
int vocab_fd_;
HeaderInfo header_;
std::string file_base_;
bool keep_buffer_;
};
inline const HeaderInfo &OutputHook::GetHeader() const {

lm/builder/payload.hh (new file, 48 lines)
View File

@ -0,0 +1,48 @@
#ifndef LM_BUILDER_PAYLOAD_H
#define LM_BUILDER_PAYLOAD_H
#include "lm/weights.hh"
#include "lm/word_index.hh"
#include <stdint.h>
namespace lm { namespace builder {
struct Uninterpolated {
float prob; // Uninterpolated probability.
float gamma; // Interpolation weight for lower order.
};
union BuildingPayload {
uint64_t count;
Uninterpolated uninterp;
ProbBackoff complete;
/*mjd**********************************************************************/
bool IsMarked() const {
return count >> (sizeof(count) * 8 - 1);
}
void Mark() {
count |= (1ul << (sizeof(count) * 8 - 1));
}
void Unmark() {
count &= ~(1ul << (sizeof(count) * 8 - 1));
}
uint64_t UnmarkedCount() const {
return count & ~(1ul << (sizeof(count) * 8 - 1));
}
uint64_t CutoffCount() const {
return IsMarked() ? 0 : UnmarkedCount();
}
/*mjd**********************************************************************/
};
const WordIndex kBOS = 1;
const WordIndex kEOS = 2;
}} // namespaces
#endif // LM_BUILDER_PAYLOAD_H
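
The Mark/Unmark helpers steal the top bit of the 64-bit count as a "prune this n-gram" flag, which is safe because real counts never approach 2^63. A small worked example, assuming only this header and <cassert>:

#include <cassert>
#include "lm/builder/payload.hh"

int main() {
  lm::builder::BuildingPayload p;
  p.count = 42;
  p.Mark();                          // sets bit 63: count is now 0x800000000000002a
  assert(p.IsMarked());
  assert(p.UnmarkedCount() == 42);   // flag bit masked back off
  assert(p.CutoffCount() == 0);      // marked entries contribute nothing
  p.Unmark();
  assert(p.count == 42);
}

One nit worth noting: the shifts use the 1ul literal, which is only 32 bits wide on some platforms (e.g. 64-bit Windows); 1ull would be the portable choice.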

View File

@ -1,14 +1,17 @@
#include "lm/builder/pipeline.hh"
#include "lm/builder/adjust_counts.hh"
#include "lm/builder/combine_counts.hh"
#include "lm/builder/corpus_count.hh"
#include "lm/builder/hash_gamma.hh"
#include "lm/builder/initial_probabilities.hh"
#include "lm/builder/interpolate.hh"
#include "lm/builder/output.hh"
#include "lm/builder/sort.hh"
#include "lm/common/compare.hh"
#include "lm/common/renumber.hh"
#include "lm/sizes.hh"
#include "lm/vocab.hh"
#include "util/exception.hh"
#include "util/file.hh"
@ -21,7 +24,10 @@
namespace lm { namespace builder {
using util::stream::Sorts;
namespace {
void PrintStatistics(const std::vector<uint64_t> &counts, const std::vector<uint64_t> &counts_pruned, const std::vector<Discount> &discounts) {
std::cerr << "Statistics:\n";
for (size_t i = 0; i < counts.size(); ++i) {
@ -37,9 +43,9 @@ void PrintStatistics(const std::vector<uint64_t> &counts, const std::vector<uint
class Master {
public:
explicit Master(PipelineConfig &config)
: config_(config), chains_(config.order), files_(config.order) {
config_.minimum_block = std::max(NGram::TotalSize(config_.order), config_.minimum_block);
explicit Master(PipelineConfig &config, unsigned output_steps)
: config_(config), chains_(config.order), unigrams_(util::MakeTemp(config_.TempPrefix())), steps_(output_steps + 4) {
config_.minimum_block = std::max(NGram<BuildingPayload>::TotalSize(config_.order), config_.minimum_block);
}
const PipelineConfig &Config() const { return config_; }
@ -52,40 +58,42 @@ class Master {
}
// This takes the (partially) sorted ngrams and sets up for adjusted counts.
void InitForAdjust(util::stream::Sort<SuffixOrder, AddCombiner> &ngrams, WordIndex types) {
void InitForAdjust(util::stream::Sort<SuffixOrder, CombineCounts> &ngrams, WordIndex types, std::size_t subtract_for_numbering) {
const std::size_t each_order_min = config_.minimum_block * config_.block_count;
// We know how many unigrams there are. Don't allocate more than needed to them.
const std::size_t min_chains = (config_.order - 1) * each_order_min +
std::min(types * NGram::TotalSize(1), each_order_min);
std::min(types * NGram<BuildingPayload>::TotalSize(1), each_order_min);
// Prevent overflow in subtracting.
const std::size_t total = std::max<std::size_t>(config_.TotalMemory(), min_chains + subtract_for_numbering + config_.minimum_block);
// Do merge sort with calculated laziness.
const std::size_t merge_using = ngrams.Merge(std::min(config_.TotalMemory() - min_chains, ngrams.DefaultLazy()));
const std::size_t merge_using = ngrams.Merge(std::min(total - min_chains - subtract_for_numbering, ngrams.DefaultLazy()));
std::vector<uint64_t> count_bounds(1, types);
CreateChains(config_.TotalMemory() - merge_using, count_bounds);
CreateChains(total - merge_using - subtract_for_numbering, count_bounds);
ngrams.Output(chains_.back(), merge_using);
// Setup unigram file.
files_.push_back(util::MakeTemp(config_.TempPrefix()));
}
// For initial probabilities, but this is generic.
void SortAndReadTwice(const std::vector<uint64_t> &counts, Sorts<ContextOrder> &sorts, util::stream::Chains &second, util::stream::ChainConfig second_config) {
bool unigrams_are_sorted = !config_.renumber_vocabulary;
// Do merge first before allocating chain memory.
for (std::size_t i = 1; i < config_.order; ++i) {
sorts[i - 1].Merge(0);
for (std::size_t i = 0; i < config_.order - unigrams_are_sorted; ++i) {
sorts[i].Merge(0);
}
// There's no lazy merge, so just divide memory amongst the chains.
CreateChains(config_.TotalMemory(), counts);
chains_.back().ActivateProgress();
chains_[0] >> files_[0].Source();
second_config.entry_size = NGram::TotalSize(1);
second.push_back(second_config);
second.back() >> files_[0].Source();
for (std::size_t i = 1; i < config_.order; ++i) {
util::scoped_fd fd(sorts[i - 1].StealCompleted());
if (unigrams_are_sorted) {
chains_[0] >> unigrams_.Source();
second_config.entry_size = NGram<BuildingPayload>::TotalSize(1);
second.push_back(second_config);
second.back() >> unigrams_.Source();
}
for (std::size_t i = unigrams_are_sorted; i < config_.order; ++i) {
util::scoped_fd fd(sorts[i - unigrams_are_sorted].StealCompleted());
chains_[i].SetProgressTarget(util::SizeOrThrow(fd.get()));
chains_[i] >> util::stream::PRead(util::DupOrThrow(fd.get()), true);
second_config.entry_size = NGram::TotalSize(i + 1);
second_config.entry_size = NGram<BuildingPayload>::TotalSize(i + 1);
second.push_back(second_config);
second.back() >> util::stream::PRead(fd.release(), true);
}
@ -96,7 +104,7 @@ class Master {
// Determine the minimum we can use for all the chains.
std::size_t min_chains = 0;
for (std::size_t i = 0; i < config_.order; ++i) {
min_chains += std::min(counts[i] * NGram::TotalSize(i + 1), static_cast<uint64_t>(config_.minimum_block));
min_chains += std::min(counts[i] * NGram<BuildingPayload>::TotalSize(i + 1), static_cast<uint64_t>(config_.minimum_block));
}
std::size_t for_merge = min_chains > config_.TotalMemory() ? 0 : (config_.TotalMemory() - min_chains);
std::vector<std::size_t> laziness;
@ -110,36 +118,24 @@ class Master {
CreateChains(for_merge + min_chains, counts);
chains_.back().ActivateProgress();
chains_[0] >> files_[0].Source();
chains_[0] >> unigrams_.Source();
for (std::size_t i = 1; i < config_.order; ++i) {
sorts[i - 1].Output(chains_[i], laziness[i - 1]);
}
}
void BufferFinal(const std::vector<uint64_t> &counts) {
chains_[0] >> files_[0].Sink();
for (std::size_t i = 1; i < config_.order; ++i) {
files_.push_back(util::MakeTemp(config_.TempPrefix()));
chains_[i] >> files_[i].Sink();
}
chains_.Wait(true);
// Use less memory. Because we can.
CreateChains(std::min(config_.sort.buffer_size * config_.order, config_.TotalMemory()), counts);
for (std::size_t i = 0; i < config_.order; ++i) {
chains_[i] >> files_[i].Source();
}
}
template <class Compare> void SetupSorts(Sorts<Compare> &sorts) {
sorts.Init(config_.order - 1);
template <class Compare> void SetupSorts(Sorts<Compare> &sorts, bool exclude_unigrams) {
sorts.Init(config_.order - exclude_unigrams);
// Unigrams don't get sorted because their order is always the same.
chains_[0] >> files_[0].Sink();
for (std::size_t i = 1; i < config_.order; ++i) {
if (exclude_unigrams) chains_[0] >> unigrams_.Sink();
for (std::size_t i = exclude_unigrams; i < config_.order; ++i) {
sorts.push_back(chains_[i], config_.sort, Compare(i + 1));
}
chains_.Wait(true);
}
unsigned int Steps() const { return steps_; }
private:
// Create chains, allocating memory to them. Totally heuristic. Count
// bounds are upper bounds on the counts or not present.
@ -150,7 +146,7 @@ class Master {
for (std::size_t i = 0; i < count_bounds.size(); ++i) {
assignments.push_back(static_cast<std::size_t>(std::min(
static_cast<uint64_t>(remaining_mem),
count_bounds[i] * static_cast<uint64_t>(NGram::TotalSize(i + 1)))));
count_bounds[i] * static_cast<uint64_t>(NGram<BuildingPayload>::TotalSize(i + 1)))));
}
assignments.resize(config_.order, remaining_mem);
@ -160,7 +156,7 @@ class Master {
// Indices of orders that have yet to be assigned.
std::vector<std::size_t> unassigned;
for (std::size_t i = 0; i < config_.order; ++i) {
portions.push_back(static_cast<float>((i+1) * NGram::TotalSize(i+1)));
portions.push_back(static_cast<float>((i+1) * NGram<BuildingPayload>::TotalSize(i+1)));
unassigned.push_back(i);
}
/*If somebody doesn't eat their full dinner, give it to the rest of the
@ -196,7 +192,7 @@ class Master {
std::cerr << "Chain sizes:";
for (std::size_t i = 0; i < config_.order; ++i) {
std::cerr << ' ' << (i+1) << ":" << assignments[i];
chains_.push_back(util::stream::ChainConfig(NGram::TotalSize(i + 1), block_count[i], assignments[i]));
chains_.push_back(util::stream::ChainConfig(NGram<BuildingPayload>::TotalSize(i + 1), block_count[i], assignments[i]));
}
std::cerr << std::endl;
}
@ -204,13 +200,15 @@ class Master {
PipelineConfig &config_;
util::stream::Chains chains_;
// Often only unigrams, but sometimes all orders.
util::FixedArray<util::stream::FileBuffer> files_;
util::stream::FileBuffer unigrams_;
const unsigned int steps_;
};
void CountText(int text_file /* input */, int vocab_file /* output */, Master &master, uint64_t &token_count, std::string &text_file_name, std::vector<bool> &prune_words) {
util::stream::Sort<SuffixOrder, CombineCounts> *CountText(int text_file /* input */, int vocab_file /* output */, Master &master, uint64_t &token_count, WordIndex &type_count, std::string &text_file_name, std::vector<bool> &prune_words) {
const PipelineConfig &config = master.Config();
std::cerr << "=== 1/5 Counting and sorting n-grams ===" << std::endl;
std::cerr << "=== 1/" << master.Steps() << " Counting and sorting n-grams ===" << std::endl;
const std::size_t vocab_usage = CorpusCount::VocabUsage(config.vocab_estimate);
UTIL_THROW_IF(config.TotalMemory() < vocab_usage, util::Exception, "Vocab hash size estimate " << vocab_usage << " exceeds total memory " << config.TotalMemory());
@ -221,37 +219,34 @@ void CountText(int text_file /* input */, int vocab_file /* output */, Master &m
(static_cast<float>(config.block_count) + CorpusCount::DedupeMultiplier(config.order)) *
// Chain likes memory expressed in terms of total memory.
static_cast<float>(config.block_count);
util::stream::Chain chain(util::stream::ChainConfig(NGram::TotalSize(config.order), config.block_count, memory_for_chain));
util::stream::Chain chain(util::stream::ChainConfig(NGram<BuildingPayload>::TotalSize(config.order), config.block_count, memory_for_chain));
WordIndex type_count = config.vocab_estimate;
type_count = config.vocab_estimate;
util::FilePiece text(text_file, NULL, &std::cerr);
text_file_name = text.FileName();
CorpusCount counter(text, vocab_file, token_count, type_count, prune_words, config.prune_vocab_file, chain.BlockSize() / chain.EntrySize(), config.disallowed_symbol_action);
chain >> boost::ref(counter);
util::stream::Sort<SuffixOrder, AddCombiner> sorter(chain, config.sort, SuffixOrder(config.order), AddCombiner());
util::scoped_ptr<util::stream::Sort<SuffixOrder, CombineCounts> > sorter(new util::stream::Sort<SuffixOrder, CombineCounts>(chain, config.sort, SuffixOrder(config.order), CombineCounts()));
chain.Wait(true);
std::cerr << "Unigram tokens " << token_count << " types " << type_count << std::endl;
std::cerr << "=== 2/5 Calculating and sorting adjusted counts ===" << std::endl;
master.InitForAdjust(sorter, type_count);
return sorter.release();
}
void InitialProbabilities(const std::vector<uint64_t> &counts, const std::vector<uint64_t> &counts_pruned, const std::vector<Discount> &discounts, Master &master, Sorts<SuffixOrder> &primary,
util::FixedArray<util::stream::FileBuffer> &gammas, const std::vector<uint64_t> &prune_thresholds, bool prune_vocab) {
void InitialProbabilities(const std::vector<uint64_t> &counts, const std::vector<uint64_t> &counts_pruned, const std::vector<Discount> &discounts, Master &master, Sorts<SuffixOrder> &primary, util::FixedArray<util::stream::FileBuffer> &gammas, const std::vector<uint64_t> &prune_thresholds, bool prune_vocab, const SpecialVocab &specials) {
const PipelineConfig &config = master.Config();
util::stream::Chains second(config.order);
{
Sorts<ContextOrder> sorts;
master.SetupSorts(sorts);
master.SetupSorts(sorts, !config.renumber_vocabulary);
PrintStatistics(counts, counts_pruned, discounts);
lm::ngram::ShowSizes(counts_pruned);
std::cerr << "=== 3/5 Calculating and sorting initial probabilities ===" << std::endl;
std::cerr << "=== 3/" << master.Steps() << " Calculating and sorting initial probabilities ===" << std::endl;
master.SortAndReadTwice(counts_pruned, sorts, second, config.initial_probs.adder_in);
}
util::stream::Chains gamma_chains(config.order);
InitialProbabilities(config.initial_probs, discounts, master.MutableChains(), second, gamma_chains, prune_thresholds, prune_vocab);
InitialProbabilities(config.initial_probs, discounts, master.MutableChains(), second, gamma_chains, prune_thresholds, prune_vocab, specials);
// Don't care about gamma for 0.
gamma_chains[0] >> util::stream::kRecycle;
gammas.Init(config.order - 1);
@ -260,11 +255,11 @@ void InitialProbabilities(const std::vector<uint64_t> &counts, const std::vector
gamma_chains[i] >> gammas[i - 1].Sink();
}
// Has to be done here due to gamma_chains scope.
master.SetupSorts(primary);
master.SetupSorts(primary, true);
}
void InterpolateProbabilities(const std::vector<uint64_t> &counts, Master &master, Sorts<SuffixOrder> &primary, util::FixedArray<util::stream::FileBuffer> &gammas) {
std::cerr << "=== 4/5 Calculating and writing order-interpolated probabilities ===" << std::endl;
void InterpolateProbabilities(const std::vector<uint64_t> &counts, Master &master, Sorts<SuffixOrder> &primary, util::FixedArray<util::stream::FileBuffer> &gammas, Output &output, const SpecialVocab &specials) {
std::cerr << "=== 4/" << master.Steps() << " Calculating and writing order-interpolated probabilities ===" << std::endl;
const PipelineConfig &config = master.Config();
master.MaximumLazyInput(counts, primary);
@ -278,13 +273,62 @@ void InterpolateProbabilities(const std::vector<uint64_t> &counts, Master &maste
read_backoffs.entry_size = sizeof(float);
gamma_chains.push_back(read_backoffs);
gamma_chains.back() >> gammas[i].Source();
gamma_chains.back() >> gammas[i].Source(true);
}
master >> Interpolate(std::max(master.Config().vocab_size_for_unk, counts[0] - 1 /* <s> is not included */), util::stream::ChainPositions(gamma_chains), config.prune_thresholds, config.prune_vocab, config.output_q);
master >> Interpolate(std::max(master.Config().vocab_size_for_unk, counts[0] - 1 /* <s> is not included */), util::stream::ChainPositions(gamma_chains), config.prune_thresholds, config.prune_vocab, config.output_q, specials);
gamma_chains >> util::stream::kRecycle;
master.BufferFinal(counts);
output.SinkProbs(master.MutableChains(), config.output_q);
}
class VocabNumbering {
public:
VocabNumbering(StringPiece vocab_file, StringPiece temp_prefix, bool renumber)
: vocab_file_(vocab_file.data(), vocab_file.size()),
temp_prefix_(temp_prefix.data(), temp_prefix.size()),
renumber_(renumber),
specials_(kBOS, kEOS) {
InitFile(renumber || vocab_file.empty());
}
int File() const { return null_delimited_.get(); }
// Compute the vocabulary mapping and return the memory used.
std::size_t ComputeMapping(WordIndex type_count) {
if (!renumber_) return 0;
util::scoped_fd previous(null_delimited_.release());
InitFile(vocab_file_.empty());
ngram::SortedVocabulary::ComputeRenumbering(type_count, previous.get(), null_delimited_.get(), vocab_mapping_);
return sizeof(WordIndex) * vocab_mapping_.size();
}
void ApplyRenumber(util::stream::Chains &chains) {
if (!renumber_) return;
for (std::size_t i = 0; i < chains.size(); ++i) {
chains[i] >> Renumber(&*vocab_mapping_.begin(), i + 1);
}
specials_ = SpecialVocab(vocab_mapping_[specials_.BOS()], vocab_mapping_[specials_.EOS()]);
}
const SpecialVocab &Specials() const { return specials_; }
private:
void InitFile(bool temp) {
null_delimited_.reset(temp ?
util::MakeTemp(temp_prefix_) :
util::CreateOrThrow(vocab_file_.c_str()));
}
std::string vocab_file_, temp_prefix_;
util::scoped_fd null_delimited_;
bool renumber_;
std::vector<WordIndex> vocab_mapping_;
SpecialVocab specials_;
};
} // namespace
void Pipeline(PipelineConfig &config, int text_file, Output &output) {
@ -293,48 +337,49 @@ void Pipeline(PipelineConfig &config, int text_file, Output &output) {
config.sort.buffer_size = config.TotalMemory() / 4;
std::cerr << "Warning: changing sort block size to " << config.sort.buffer_size << " bytes due to low total memory." << std::endl;
}
if (config.minimum_block < NGram::TotalSize(config.order)) {
config.minimum_block = NGram::TotalSize(config.order);
if (config.minimum_block < NGram<BuildingPayload>::TotalSize(config.order)) {
config.minimum_block = NGram<BuildingPayload>::TotalSize(config.order);
std::cerr << "Warning: raising minimum block to " << config.minimum_block << " to fit an ngram in every block." << std::endl;
}
UTIL_THROW_IF(config.sort.buffer_size < config.minimum_block, util::Exception, "Sort block size " << config.sort.buffer_size << " is below the minimum block size " << config.minimum_block << ".");
UTIL_THROW_IF(config.TotalMemory() < config.minimum_block * config.order * config.block_count, util::Exception,
"Not enough memory to fit " << (config.order * config.block_count) << " blocks with minimum size " << config.minimum_block << ". Increase memory to " << (config.minimum_block * config.order * config.block_count) << " bytes or decrease the minimum block size.");
UTIL_TIMER("(%w s) Total wall time elapsed\n");
Master master(config);
Master master(config, output.Steps());
// master's destructor will wait for chains. But they might be deadlocked if
// this thread dies because e.g. it ran out of memory.
try {
util::scoped_fd vocab_file(config.vocab_file.empty() ?
util::MakeTemp(config.TempPrefix()) :
util::CreateOrThrow(config.vocab_file.c_str()));
output.SetVocabFD(vocab_file.get());
VocabNumbering numbering(config.vocab_file, config.TempPrefix(), config.renumber_vocabulary);
uint64_t token_count;
WordIndex type_count;
std::string text_file_name;
std::vector<bool> prune_words;
CountText(text_file, vocab_file.get(), master, token_count, text_file_name, prune_words);
util::scoped_ptr<util::stream::Sort<SuffixOrder, CombineCounts> > sorted_counts(
CountText(text_file, numbering.File(), master, token_count, type_count, text_file_name, prune_words));
std::cerr << "Unigram tokens " << token_count << " types " << type_count << std::endl;
// Create vocab mapping, which uses temporary memory, while nothing else is happening.
std::size_t subtract_for_numbering = numbering.ComputeMapping(type_count);
output.SetVocabFD(numbering.File());
std::cerr << "=== 2/" << master.Steps() << " Calculating and sorting adjusted counts ===" << std::endl;
master.InitForAdjust(*sorted_counts, type_count, subtract_for_numbering);
sorted_counts.reset();
std::vector<uint64_t> counts;
std::vector<uint64_t> counts_pruned;
std::vector<Discount> discounts;
master >> AdjustCounts(config.prune_thresholds, counts, counts_pruned, prune_words, config.discount, discounts);
numbering.ApplyRenumber(master.MutableChains());
{
util::FixedArray<util::stream::FileBuffer> gammas;
Sorts<SuffixOrder> primary;
InitialProbabilities(counts, counts_pruned, discounts, master, primary, gammas, config.prune_thresholds, config.prune_vocab);
InterpolateProbabilities(counts_pruned, master, primary, gammas);
InitialProbabilities(counts, counts_pruned, discounts, master, primary, gammas, config.prune_thresholds, config.prune_vocab, numbering.Specials());
output.SetHeader(HeaderInfo(text_file_name, token_count, counts_pruned));
// Also does output.
InterpolateProbabilities(counts_pruned, master, primary, gammas, output, numbering.Specials());
}
std::cerr << "=== 5/5 Writing ARPA model ===" << std::endl;
output.SetHeader(HeaderInfo(text_file_name, token_count, counts_pruned));
output.Apply(PROB_SEQUENTIAL_HOOK, master.MutableChains());
master >> util::stream::kRecycle;
master.MutableChains().Wait(true);
} catch (const util::Exception &e) {
std::cerr << e.what() << std::endl;
abort();

View File

@ -39,6 +39,9 @@ struct PipelineConfig {
bool prune_vocab;
std::string prune_vocab_file;
/* Renumber the vocabulary the way the trie likes it? */
bool renumber_vocabulary;
// What to do with discount failures.
DiscountConfig discount;

View File

@ -23,30 +23,29 @@ VocabReconstitute::VocabReconstitute(int fd) {
map_.push_back(i);
}
void PrintARPA::Sink(util::stream::Chains &chains) {
chains >> boost::ref(*this);
}
void PrintARPA::Run(const util::stream::ChainPositions &positions) {
VocabReconstitute vocab(GetVocabFD());
// Write header. TODO: integers in FakeOFStream.
{
std::stringstream stream;
if (verbose_header_) {
stream << "# Input file: " << GetHeader().input_file << '\n';
stream << "# Token count: " << GetHeader().token_count << '\n';
stream << "# Smoothing: Modified Kneser-Ney" << '\n';
}
stream << "\\data\\\n";
for (size_t i = 0; i < positions.size(); ++i) {
stream << "ngram " << (i+1) << '=' << GetHeader().counts_pruned[i] << '\n';
}
stream << '\n';
std::string as_string(stream.str());
util::WriteOrThrow(out_fd_.get(), as_string.data(), as_string.size());
}
util::FakeOFStream out(out_fd_.get());
// Write header.
if (verbose_header_) {
out << "# Input file: " << GetHeader().input_file << '\n';
out << "# Token count: " << GetHeader().token_count << '\n';
out << "# Smoothing: Modified Kneser-Ney" << '\n';
}
out << "\\data\\\n";
for (size_t i = 0; i < positions.size(); ++i) {
out << "ngram " << (i+1) << '=' << GetHeader().counts_pruned[i] << '\n';
}
out << '\n';
for (unsigned order = 1; order <= positions.size(); ++order) {
out << "\\" << order << "-grams:" << '\n';
for (NGramStream stream(positions[order - 1]); stream; ++stream) {
for (NGramStream<BuildingPayload> stream(positions[order - 1]); stream; ++stream) {
// Correcting for numerical precision issues. Take that IRST.
out << stream->Value().complete.prob << '\t' << vocab.Lookup(*stream->begin());
for (const WordIndex *i = stream->begin() + 1; i != stream->end(); ++i) {

View File

@ -1,14 +1,17 @@
#ifndef LM_BUILDER_PRINT_H
#define LM_BUILDER_PRINT_H
#include "lm/builder/ngram.hh"
#include "lm/builder/ngram_stream.hh"
#include "lm/common/ngram_stream.hh"
#include "lm/builder/output.hh"
#include "lm/builder/payload.hh"
#include "lm/common/ngram.hh"
#include "util/fake_ofstream.hh"
#include "util/file.hh"
#include "util/mmap.hh"
#include "util/string_piece.hh"
#include <boost/lexical_cast.hpp>
#include <ostream>
#include <cassert>
@ -43,15 +46,15 @@ class VocabReconstitute {
};
// Not defined, only specialized.
template <class T> void PrintPayload(util::FakeOFStream &to, const Payload &payload);
template <> inline void PrintPayload<uint64_t>(util::FakeOFStream &to, const Payload &payload) {
template <class T> void PrintPayload(util::FakeOFStream &to, const BuildingPayload &payload);
template <> inline void PrintPayload<uint64_t>(util::FakeOFStream &to, const BuildingPayload &payload) {
// TODO slow
to << boost::lexical_cast<std::string>(payload.count);
to << payload.count;
}
template <> inline void PrintPayload<Uninterpolated>(util::FakeOFStream &to, const Payload &payload) {
template <> inline void PrintPayload<Uninterpolated>(util::FakeOFStream &to, const BuildingPayload &payload) {
to << log10(payload.uninterp.prob) << ' ' << log10(payload.uninterp.gamma);
}
template <> inline void PrintPayload<ProbBackoff>(util::FakeOFStream &to, const Payload &payload) {
template <> inline void PrintPayload<ProbBackoff>(util::FakeOFStream &to, const BuildingPayload &payload) {
to << payload.complete.prob << ' ' << payload.complete.backoff;
}
@ -70,8 +73,8 @@ template <class V> class Print {
void Run(const util::stream::ChainPositions &chains) {
util::scoped_fd fd(to_);
util::FakeOFStream out(to_);
NGramStreams streams(chains);
for (NGramStream *s = streams.begin(); s != streams.end(); ++s) {
NGramStreams<BuildingPayload> streams(chains);
for (NGramStream<BuildingPayload> *s = streams.begin(); s != streams.end(); ++s) {
DumpStream(*s, out);
}
}
@ -79,12 +82,12 @@ template <class V> class Print {
void Run(const util::stream::ChainPosition &position) {
util::scoped_fd fd(to_);
util::FakeOFStream out(to_);
NGramStream stream(position);
NGramStream<BuildingPayload> stream(position);
DumpStream(stream, out);
}
private:
void DumpStream(NGramStream &stream, util::FakeOFStream &to) {
void DumpStream(NGramStream<BuildingPayload> &stream, util::FakeOFStream &to) {
for (; stream; ++stream) {
PrintPayload<V>(to, stream->Value());
for (const WordIndex *w = stream->begin(); w != stream->end(); ++w) {
@ -103,6 +106,8 @@ class PrintARPA : public OutputHook {
explicit PrintARPA(int fd, bool verbose_header)
: OutputHook(PROB_SEQUENTIAL_HOOK), out_fd_(fd), verbose_header_(verbose_header) {}
void Sink(util::stream::Chains &chains);
void Run(const util::stream::ChainPositions &positions);
private:

lm/builder/special.hh Normal file
View File

@ -0,0 +1,27 @@
#ifndef LM_BUILDER_SPECIAL_H
#define LM_BUILDER_SPECIAL_H
#include "lm/word_index.hh"
namespace lm { namespace builder {
class SpecialVocab {
public:
SpecialVocab(WordIndex bos, WordIndex eos) : bos_(bos), eos_(eos) {}
bool IsSpecial(WordIndex word) const {
return word == kUNK || word == bos_ || word == eos_;
}
WordIndex UNK() const { return kUNK; }
WordIndex BOS() const { return bos_; }
WordIndex EOS() const { return eos_; }
private:
WordIndex bos_;
WordIndex eos_;
};
}} // namespaces
#endif // LM_BUILDER_SPECIAL_H
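A minimal usage sketch, assuming bos_id, eos_id, and word are WordIndex values obtained elsewhere:
lm::builder::SpecialVocab specials(bos_id, eos_id);
if (!specials.IsSpecial(word)) {
  // word is none of <unk>, <s>, </s>, so it may be pruned or renumbered freely.
}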

lm/common/Jamfile Normal file
View File

@ -0,0 +1,2 @@
fakelib common : [ glob *.cc : *test.cc *main.cc ]
../../util//kenutil ../../util/stream//stream ../../util/double-conversion//double-conversion ..//kenlm ;

View File

@ -1,18 +1,12 @@
#ifndef LM_BUILDER_SORT_H
#define LM_BUILDER_SORT_H
#ifndef LM_COMMON_COMPARE_H
#define LM_COMMON_COMPARE_H
#include "lm/builder/ngram_stream.hh"
#include "lm/builder/ngram.hh"
#include "lm/word_index.hh"
#include "util/stream/sort.hh"
#include "util/stream/timer.hh"
#include <functional>
#include <string>
namespace lm {
namespace builder {
/**
* Abstract parent class for defining custom n-gram comparators.
@ -175,70 +169,6 @@ class PrefixOrder : public Comparator<PrefixOrder> {
static const unsigned kMatchOffset = 0;
};
// Sum counts for the same n-gram.
struct AddCombiner {
bool operator()(void *first_void, const void *second_void, const SuffixOrder &compare) const {
NGram first(first_void, compare.Order());
// There isn't a const version of NGram.
NGram second(const_cast<void*>(second_void), compare.Order());
if (memcmp(first.begin(), second.begin(), sizeof(WordIndex) * compare.Order())) return false;
first.Count() += second.Count();
return true;
}
};
// The combiner is only used on a single chain, so I didn't bother to allow
// that template.
/**
* Represents an @ref util::FixedArray "array" capable of storing @ref util::stream::Sort "Sort" objects.
*
* In the anticipated use case, an instance of this class will maintain one @ref util::stream::Sort "Sort" object
* for each n-gram order (ranging from 1 up to the maximum n-gram order being processed).
* Use in this manner would enable the n-grams of each n-gram order to be sorted in parallel.
*
* @tparam Compare An @ref Comparator "ngram comparator" to use during sorting.
*/
template <class Compare> class Sorts : public util::FixedArray<util::stream::Sort<Compare> > {
private:
typedef util::stream::Sort<Compare> S;
typedef util::FixedArray<S> P;
public:
/**
* Constructs, but does not initialize.
*
* @ref util::FixedArray::Init() "Init" must be called before use.
*
* @see util::FixedArray::Init()
*/
Sorts() {}
/**
* Constructs an @ref util::FixedArray "array" capable of storing a fixed number of @ref util::stream::Sort "Sort" objects.
*
* @param number The maximum number of @ref util::stream::Sort "sorters" that can be held by this @ref util::FixedArray "array"
* @see util::FixedArray::FixedArray()
*/
explicit Sorts(std::size_t number) : util::FixedArray<util::stream::Sort<Compare> >(number) {}
/**
* Constructs a new @ref util::stream::Sort "Sort" object which is stored in this @ref util::FixedArray "array".
*
* The new @ref util::stream::Sort "Sort" object is constructed using the provided @ref util::stream::SortConfig "SortConfig" and @ref Comparator "ngram comparator";
* once constructed, a new worker @ref util::stream::Thread "thread" (owned by the @ref util::stream::Chain "chain") will sort the n-gram data stored
* in the @ref util::stream::Block "blocks" of the provided @ref util::stream::Chain "chain".
*
* @see util::stream::Sort::Sort()
* @see util::stream::Chain::operator>>()
*/
void push_back(util::stream::Chain &chain, const util::stream::SortConfig &config, const Compare &compare) {
new (P::end()) S(chain, config, compare); // use "placement new" syntax to initialize S in an already-allocated memory location
P::Constructed();
}
};
} // namespace builder
} // namespace lm
#endif // LM_BUILDER_SORT_H
#endif // LM_COMMON_COMPARE_H

lm/common/model_buffer.cc Normal file
View File

@ -0,0 +1,82 @@
#include "lm/common/model_buffer.hh"
#include "util/exception.hh"
#include "util/fake_ofstream.hh"
#include "util/file.hh"
#include "util/file_piece.hh"
#include "util/stream/io.hh"
#include "util/stream/multi_stream.hh"
#include <boost/lexical_cast.hpp>
namespace lm { namespace common {
namespace {
const char kMetadataHeader[] = "KenLM intermediate binary file";
} // namespace
ModelBuffer::ModelBuffer(const std::string &file_base, bool keep_buffer, bool output_q)
: file_base_(file_base), keep_buffer_(keep_buffer), output_q_(output_q) {}
ModelBuffer::ModelBuffer(const std::string &file_base)
: file_base_(file_base), keep_buffer_(false) {
const std::string full_name = file_base_ + ".kenlm_intermediate";
util::FilePiece in(full_name.c_str());
StringPiece token = in.ReadLine();
UTIL_THROW_IF2(token != kMetadataHeader, "File " << full_name << " begins with \"" << token << "\" not " << kMetadataHeader);
token = in.ReadDelimited();
UTIL_THROW_IF2(token != "Order", "Expected Order, got \"" << token << "\" in " << full_name);
unsigned long order = in.ReadULong();
token = in.ReadDelimited();
UTIL_THROW_IF2(token != "Payload", "Expected Payload, got \"" << token << "\" in " << full_name);
token = in.ReadDelimited();
if (token == "q") {
output_q_ = true;
} else if (token == "pb") {
output_q_ = false;
} else {
UTIL_THROW(util::Exception, "Unknown payload " << token);
}
files_.Init(order);
for (unsigned long i = 0; i < order; ++i) {
files_.push_back(util::OpenReadOrThrow((file_base_ + '.' + boost::lexical_cast<std::string>(i + 1)).c_str()));
}
}
// Out-of-line destructor.
ModelBuffer::~ModelBuffer() {}
void ModelBuffer::Sink(util::stream::Chains &chains) {
// Open files.
files_.Init(chains.size());
for (std::size_t i = 0; i < chains.size(); ++i) {
if (keep_buffer_) {
files_.push_back(util::CreateOrThrow(
(file_base_ + '.' + boost::lexical_cast<std::string>(i + 1)).c_str()
));
} else {
files_.push_back(util::MakeTemp(file_base_));
}
chains[i] >> util::stream::Write(files_.back().get());
}
if (keep_buffer_) {
util::scoped_fd metadata(util::CreateOrThrow((file_base_ + ".kenlm_intermediate").c_str()));
util::FakeOFStream meta(metadata.get(), 200);
meta << kMetadataHeader << "\nOrder " << chains.size() << "\nPayload " << (output_q_ ? "q" : "pb") << '\n';
}
}
void ModelBuffer::Source(util::stream::Chains &chains) {
assert(chains.size() == files_.size());
for (unsigned int i = 0; i < files_.size(); ++i) {
chains[i] >> util::stream::PRead(files_[i].get());
}
}
std::size_t ModelBuffer::Order() const {
return files_.size();
}
}} // namespaces

lm/common/model_buffer.hh Normal file
View File

@ -0,0 +1,45 @@
#ifndef LM_BUILDER_MODEL_BUFFER_H
#define LM_BUILDER_MODEL_BUFFER_H
/* Format with separate files in suffix order. Each file contains
* n-grams of the same order.
*/
#include "util/file.hh"
#include "util/fixed_array.hh"
#include <string>
namespace util { namespace stream { class Chains; } }
namespace lm { namespace common {
class ModelBuffer {
public:
// Construct for writing.
ModelBuffer(const std::string &file_base, bool keep_buffer, bool output_q);
// Load from file.
explicit ModelBuffer(const std::string &file_base);
// Out-of-line destructor; defined in the .cc file.
~ModelBuffer();
void Sink(util::stream::Chains &chains);
void Source(util::stream::Chains &chains);
// The order of the n-gram model that is associated with the model buffer.
std::size_t Order() const;
private:
const std::string file_base_;
const bool keep_buffer_;
bool output_q_;
util::FixedArray<util::scoped_fd> files_;
};
}} // namespaces
#endif // LM_BUILDER_MODEL_BUFFER_H
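A hedged sketch of reading the intermediate format back, assuming files counts.1 ... counts.N and counts.kenlm_intermediate were written earlier under the base name "counts" (block sizes are illustrative):
lm::common::ModelBuffer buffer("counts");  // parses counts.kenlm_intermediate
util::stream::Chains chains(buffer.Order());
for (std::size_t i = 0; i < buffer.Order(); ++i) {
  // One chain per order; the entry size must match the stored payload.
  chains.push_back(util::stream::ChainConfig(
      lm::NGram<lm::builder::BuildingPayload>::TotalSize(i + 1), 2, 1 << 20));
}
buffer.Source(chains);  // each chain now streams one order's n-grams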

View File

@ -1,5 +1,5 @@
#ifndef LM_BUILDER_NGRAM_H
#define LM_BUILDER_NGRAM_H
#ifndef LM_COMMON_NGRAM_H
#define LM_COMMON_NGRAM_H
#include "lm/weights.hh"
#include "lm/word_index.hh"
@ -10,22 +10,10 @@
#include <cstring>
namespace lm {
namespace builder {
struct Uninterpolated {
float prob; // Uninterpolated probability.
float gamma; // Interpolation weight for lower order.
};
union Payload {
uint64_t count;
Uninterpolated uninterp;
ProbBackoff complete;
};
class NGram {
class NGramHeader {
public:
NGram(void *begin, std::size_t order)
NGramHeader(void *begin, std::size_t order)
: begin_(static_cast<WordIndex*>(begin)), end_(begin_ + order) {}
const uint8_t *Base() const { return reinterpret_cast<const uint8_t*>(begin_); }
@ -37,25 +25,30 @@ class NGram {
end_ = begin_ + difference;
}
// Would do operator++ but that can get confusing for a stream.
void NextInMemory() {
ReBase(&Value() + 1);
}
// These are for the vocab index.
// Lower-case in deference to STL.
const WordIndex *begin() const { return begin_; }
WordIndex *begin() { return begin_; }
const WordIndex *end() const { return end_; }
WordIndex *end() { return end_; }
const Payload &Value() const { return *reinterpret_cast<const Payload *>(end_); }
Payload &Value() { return *reinterpret_cast<Payload *>(end_); }
uint64_t &Count() { return Value().count; }
uint64_t Count() const { return Value().count; }
std::size_t Order() const { return end_ - begin_; }
private:
WordIndex *begin_, *end_;
};
template <class PayloadT> class NGram : public NGramHeader {
public:
typedef PayloadT Payload;
NGram(void *begin, std::size_t order) : NGramHeader(begin, order) {}
// Would do operator++ but that can get confusing for a stream.
void NextInMemory() {
ReBase(&Value() + 1);
}
static std::size_t TotalSize(std::size_t order) {
return order * sizeof(WordIndex) + sizeof(Payload);
}
@ -63,46 +56,17 @@ class NGram {
// Compiler should optimize this.
return TotalSize(Order());
}
static std::size_t OrderFromSize(std::size_t size) {
std::size_t ret = (size - sizeof(Payload)) / sizeof(WordIndex);
assert(size == TotalSize(ret));
return ret;
}
// manipulate msb to signal that ngram can be pruned
/*mjd**********************************************************************/
bool IsMarked() const {
return Value().count >> (sizeof(Value().count) * 8 - 1);
}
void Mark() {
Value().count |= (1ul << (sizeof(Value().count) * 8 - 1));
}
void Unmark() {
Value().count &= ~(1ul << (sizeof(Value().count) * 8 - 1));
}
uint64_t UnmarkedCount() const {
return Value().count & ~(1ul << (sizeof(Value().count) * 8 - 1));
}
uint64_t CutoffCount() const {
return IsMarked() ? 0 : UnmarkedCount();
}
/*mjd**********************************************************************/
private:
WordIndex *begin_, *end_;
const Payload &Value() const { return *reinterpret_cast<const Payload *>(end()); }
Payload &Value() { return *reinterpret_cast<Payload *>(end()); }
};
const WordIndex kUNK = 0;
const WordIndex kBOS = 1;
const WordIndex kEOS = 2;
} // namespace builder
} // namespace lm
#endif // LM_BUILDER_NGRAM_H
#endif // LM_COMMON_NGRAM_H
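A short sketch of the templated API, where base points into a buffer holding order words followed by a BuildingPayload (both names are assumptions):
lm::NGram<lm::builder::BuildingPayload> gram(base, order);
for (lm::WordIndex *w = gram.begin(); w != gram.end(); ++w) *w = 0;
gram.Value().count = 1;  // the payload sits immediately after the words
gram.NextInMemory();     // advance to the next record packed in the buffer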

View File

@ -1,16 +1,16 @@
#ifndef LM_BUILDER_NGRAM_STREAM_H
#define LM_BUILDER_NGRAM_STREAM_H
#include "lm/builder/ngram.hh"
#include "lm/common/ngram.hh"
#include "util/stream/chain.hh"
#include "util/stream/multi_stream.hh"
#include "util/stream/stream.hh"
#include <cstddef>
namespace lm { namespace builder {
namespace lm {
class NGramStream {
template <class Payload> class NGramStream {
public:
NGramStream() : gram_(NULL, 0) {}
@ -20,14 +20,14 @@ class NGramStream {
void Init(const util::stream::ChainPosition &position) {
stream_.Init(position);
gram_ = NGram(stream_.Get(), NGram::OrderFromSize(position.GetChain().EntrySize()));
gram_ = NGram<Payload>(stream_.Get(), NGram<Payload>::OrderFromSize(position.GetChain().EntrySize()));
}
NGram &operator*() { return gram_; }
const NGram &operator*() const { return gram_; }
NGram<Payload> &operator*() { return gram_; }
const NGram<Payload> &operator*() const { return gram_; }
NGram *operator->() { return &gram_; }
const NGram *operator->() const { return &gram_; }
NGram<Payload> *operator->() { return &gram_; }
const NGram<Payload> *operator->() const { return &gram_; }
void *Get() { return stream_.Get(); }
const void *Get() const { return stream_.Get(); }
@ -43,16 +43,22 @@ class NGramStream {
}
private:
NGram gram_;
NGram<Payload> gram_;
util::stream::Stream stream_;
};
inline util::stream::Chain &operator>>(util::stream::Chain &chain, NGramStream &str) {
template <class Payload> inline util::stream::Chain &operator>>(util::stream::Chain &chain, NGramStream<Payload> &str) {
str.Init(chain.Add());
return chain;
}
typedef util::stream::GenericStreams<NGramStream> NGramStreams;
template <class Payload> class NGramStreams : public util::stream::GenericStreams<NGramStream<Payload> > {
private:
typedef util::stream::GenericStreams<NGramStream<Payload> > P;
public:
NGramStreams() : P() {}
NGramStreams(const util::stream::ChainPositions &positions) : P(positions) {}
};
}} // namespaces
} // namespace
#endif // LM_BUILDER_NGRAM_STREAM_H

lm/common/renumber.cc Normal file
View File

@ -0,0 +1,17 @@
#include "lm/common/renumber.hh"
#include "lm/common/ngram.hh"
#include "util/stream/stream.hh"
namespace lm {
void Renumber::Run(const util::stream::ChainPosition &position) {
for (util::stream::Stream stream(position); stream; ++stream) {
NGramHeader gram(stream.Get(), order_);
for (WordIndex *w = gram.begin(); w != gram.end(); ++w) {
*w = new_numbers_[*w];
}
}
}
} // namespace lm

lm/common/renumber.hh Normal file
View File

@ -0,0 +1,30 @@
/* Map vocab ids. This is useful to merge independently collected counts or
* change the vocab ids to the order used by the trie.
*/
#ifndef LM_COMMON_RENUMBER_H
#define LM_COMMON_RENUMBER_H
#include "lm/word_index.hh"
#include <cstddef>
namespace util { namespace stream { class ChainPosition; }}
namespace lm {
class Renumber {
public:
// Assumes the array is large enough to map all words and stays alive while
// the thread is active.
Renumber(const WordIndex *new_numbers, std::size_t order)
: new_numbers_(new_numbers), order_(order) {}
void Run(const util::stream::ChainPosition &position);
private:
const WordIndex *new_numbers_;
std::size_t order_;
};
} // namespace lm
#endif // LM_COMMON_RENUMBER_H
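Typical use is one line; a hedged sketch where chain is a util::stream::Chain and mapping is a std::vector<lm::WordIndex> built elsewhere (e.g. by SortedVocabulary::ComputeRenumbering):
// mapping[old_id] yields the new id; it must cover the whole vocabulary and
// stay alive while the chain runs.
chain >> lm::Renumber(&*mapping.begin(), order);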

lm/kenlm_benchmark_main.cc Normal file
View File

@ -0,0 +1,128 @@
#include "lm/model.hh"
#include "util/fake_ofstream.hh"
#include "util/file.hh"
#include "util/file_piece.hh"
#include "util/usage.hh"
#include <stdint.h>
#include <cstring>
#include <iostream>
namespace {
template <class Model, class Width> void ConvertToBytes(const Model &model, int fd_in) {
util::FilePiece in(fd_in);
util::FakeOFStream out(1);
Width width;
StringPiece word;
const Width end_sentence = (Width)model.GetVocabulary().EndSentence();
while (true) {
while (in.ReadWordSameLine(word)) {
width = (Width)model.GetVocabulary().Index(word);
out.write(&width, sizeof(Width));
}
if (!in.ReadLineOrEOF(word)) break;
out.write(&end_sentence, sizeof(Width));
}
}
template <class Model, class Width> void QueryFromBytes(const Model &model, int fd_in) {
lm::ngram::State state[3];
const lm::ngram::State *const begin_state = &model.BeginSentenceState();
const lm::ngram::State *next_state = begin_state;
Width kEOS = model.GetVocabulary().EndSentence();
Width buf[4096];
float sum = 0.0;
while (true) {
std::size_t got = util::ReadOrEOF(fd_in, buf, sizeof(buf));
if (!got) break;
UTIL_THROW_IF2(got % sizeof(Width), "File size not a multiple of vocab id size " << sizeof(Width));
got /= sizeof(Width);
// Do even stuff first.
const Width *even_end = buf + (got & ~1);
// Alternating states
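// between state[0] and state[1], so each FullScore call reads the context the
// previous call wrote without copying any State.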
const Width *i;
for (i = buf; i != even_end;) {
sum += model.FullScore(*next_state, *i, state[1]).prob;
next_state = (*i++ == kEOS) ? begin_state : &state[1];
sum += model.FullScore(*next_state, *i, state[0]).prob;
next_state = (*i++ == kEOS) ? begin_state : &state[0];
}
// Odd corner case.
if (got & 1) {
sum += model.FullScore(*next_state, *i, state[2]).prob;
next_state = (*i++ == kEOS) ? begin_state : &state[2];
}
}
std::cout << "Sum is " << sum << std::endl;
}
template <class Model, class Width> void DispatchFunction(const Model &model, bool query) {
if (query) {
QueryFromBytes<Model, Width>(model, 0);
} else {
ConvertToBytes<Model, Width>(model, 0);
}
}
template <class Model> void DispatchWidth(const char *file, bool query) {
Model model(file);
lm::WordIndex bound = model.GetVocabulary().Bound();
if (bound <= 256) {
DispatchFunction<Model, uint8_t>(model, query);
} else if (bound <= 65536) {
DispatchFunction<Model, uint16_t>(model, query);
} else if (bound <= (1ULL << 32)) {
DispatchFunction<Model, uint32_t>(model, query);
} else {
DispatchFunction<Model, uint64_t>(model, query);
}
}
void Dispatch(const char *file, bool query) {
using namespace lm::ngram;
lm::ngram::ModelType model_type;
if (lm::ngram::RecognizeBinary(file, model_type)) {
switch(model_type) {
case PROBING:
DispatchWidth<lm::ngram::ProbingModel>(file, query);
break;
case REST_PROBING:
DispatchWidth<lm::ngram::RestProbingModel>(file, query);
break;
case TRIE:
DispatchWidth<lm::ngram::TrieModel>(file, query);
break;
case QUANT_TRIE:
DispatchWidth<lm::ngram::QuantTrieModel>(file, query);
break;
case ARRAY_TRIE:
DispatchWidth<lm::ngram::ArrayTrieModel>(file, query);
break;
case QUANT_ARRAY_TRIE:
DispatchWidth<lm::ngram::QuantArrayTrieModel>(file, query);
break;
default:
UTIL_THROW(util::Exception, "Unrecognized kenlm model type " << model_type);
}
} else {
UTIL_THROW(util::Exception, "Binarize before running benchmarks.");
}
}
} // namespace
int main(int argc, char *argv[]) {
if (argc != 3 || (strcmp(argv[1], "vocab") && strcmp(argv[1], "query"))) {
std::cerr
<< "Benchmark program for KenLM. Intended usage:\n"
<< "#Convert text to vocabulary ids offline. These ids are tied to a model.\n"
<< argv[0] << " vocab $model <$text >$text.vocab\n"
<< "#Ensure files are in RAM.\n"
<< "cat $text.vocab $model >/dev/null\n"
<< "#Timed query against the model, including loading.\n"
<< "time " << argv[0] << " query $model <$text.vocab\n";
return 1;
}
Dispatch(argv[2], !strcmp(argv[1], "query"));
util::PrintUsage(std::cerr);
return 0;
}

View File

@ -3,45 +3,53 @@
#include "lm/enumerate_vocab.hh"
#include "lm/model.hh"
#include "util/fake_ofstream.hh"
#include "util/file_piece.hh"
#include "util/usage.hh"
#include <cstdlib>
#include <iostream>
#include <ostream>
#include <istream>
#include <string>
#include <cmath>
namespace lm {
namespace ngram {
struct BasicPrint {
void Word(StringPiece, WordIndex, const FullScoreReturn &) const {}
void Line(uint64_t oov, float total) const {
std::cout << "Total: " << total << " OOV: " << oov << '\n';
}
void Summary(double, double, uint64_t, uint64_t) {}
class QueryPrinter {
public:
QueryPrinter(int fd, bool print_word, bool print_line, bool print_summary, bool flush)
: out_(fd), print_word_(print_word), print_line_(print_line), print_summary_(print_summary), flush_(flush) {}
void Word(StringPiece surface, WordIndex vocab, const FullScoreReturn &ret) {
if (!print_word_) return;
out_ << surface << '=' << vocab << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\t';
if (flush_) out_.flush();
}
void Line(uint64_t oov, float total) {
if (!print_line_) return;
out_ << "Total: " << total << " OOV: " << oov << '\n';
if (flush_) out_.flush();
}
void Summary(double ppl_including_oov, double ppl_excluding_oov, uint64_t corpus_oov, uint64_t corpus_tokens) {
if (!print_summary_) return;
out_ <<
"Perplexity including OOVs:\t" << ppl_including_oov << "\n"
"Perplexity excluding OOVs:\t" << ppl_excluding_oov << "\n"
"OOVs:\t" << corpus_oov << "\n"
"Tokens:\t" << corpus_tokens << '\n';
out_.flush();
}
private:
util::FakeOFStream out_;
bool print_word_;
bool print_line_;
bool print_summary_;
bool flush_;
};
struct FullPrint : public BasicPrint {
void Word(StringPiece surface, WordIndex vocab, const FullScoreReturn &ret) const {
std::cout << surface << '=' << vocab << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\t';
}
void Summary(double ppl_including_oov, double ppl_excluding_oov, uint64_t corpus_oov, uint64_t corpus_tokens) {
std::cout <<
"Perplexity including OOVs:\t" << ppl_including_oov << "\n"
"Perplexity excluding OOVs:\t" << ppl_excluding_oov << "\n"
"OOVs:\t" << corpus_oov << "\n"
"Tokens:\t" << corpus_tokens << '\n'
;
}
};
template <class Model, class Printer> void Query(const Model &model, bool sentence_context) {
Printer printer;
template <class Model, class Printer> void Query(const Model &model, bool sentence_context, Printer &printer) {
typename Model::State state, out;
lm::FullScoreReturn ret;
StringPiece word;
@ -92,13 +100,9 @@ template <class Model, class Printer> void Query(const Model &model, bool senten
corpus_tokens);
}
template <class Model> void Query(const char *file, const Config &config, bool sentence_context, bool show_words) {
template <class Model> void Query(const char *file, const Config &config, bool sentence_context, QueryPrinter &printer) {
Model model(file, config);
if (show_words) {
Query<Model, FullPrint>(model, sentence_context);
} else {
Query<Model, BasicPrint>(model, sentence_context);
}
Query<Model, QueryPrinter>(model, sentence_context, printer);
}
} // namespace ngram

View File

@ -10,9 +10,10 @@
void Usage(const char *name) {
std::cerr <<
"KenLM was compiled with maximum order " << KENLM_MAX_ORDER << ".\n"
"Usage: " << name << " [-n] [-s] lm_file\n"
"Usage: " << name << " [-b] [-n] [-w] [-s] lm_file\n"
"-b: Do not buffer output.\n"
"-n: Do not wrap the input in <s> and </s>.\n"
"-s: Sentence totals only.\n"
"-v summary|sentence|word: Level of verbosity\n"
"-l lazy|populate|read|parallel: Load lazily, with populate, or malloc+read\n"
"The default loading method is populate on Linux and read on others.\n";
exit(1);
@ -24,16 +25,28 @@ int main(int argc, char *argv[]) {
lm::ngram::Config config;
bool sentence_context = true;
bool show_words = true;
unsigned int verbosity = 2;
bool flush = false;
int opt;
while ((opt = getopt(argc, argv, "hnsl:")) != -1) {
while ((opt = getopt(argc, argv, "bnv:l:")) != -1) {
switch (opt) {
case 'b':
flush = true;
break;
case 'n':
sentence_context = false;
break;
case 's':
show_words = false;
case 'v':
if (!strcmp(optarg, "word") || !strcmp(optarg, "2")) {
verbosity = 2;
} else if (!strcmp(optarg, "sentence") || !strcmp(optarg, "1")) {
verbosity = 1;
} else if (!strcmp(optarg, "summary") || !strcmp(optarg, "0")) {
verbosity = 0;
} else {
Usage(argv[0]);
}
break;
case 'l':
if (!strcmp(optarg, "lazy")) {
@ -55,6 +68,7 @@ int main(int argc, char *argv[]) {
}
if (optind + 1 != argc)
Usage(argv[0]);
lm::ngram::QueryPrinter printer(1, verbosity >= 2, verbosity >= 1, true, flush);
const char *file = argv[optind];
try {
using namespace lm::ngram;
@ -62,22 +76,22 @@ int main(int argc, char *argv[]) {
if (RecognizeBinary(file, model_type)) {
switch(model_type) {
case PROBING:
Query<lm::ngram::ProbingModel>(file, config, sentence_context, show_words);
Query<lm::ngram::ProbingModel>(file, config, sentence_context, printer);
break;
case REST_PROBING:
Query<lm::ngram::RestProbingModel>(file, config, sentence_context, show_words);
Query<lm::ngram::RestProbingModel>(file, config, sentence_context, printer);
break;
case TRIE:
Query<TrieModel>(file, config, sentence_context, show_words);
Query<TrieModel>(file, config, sentence_context, printer);
break;
case QUANT_TRIE:
Query<QuantTrieModel>(file, config, sentence_context, show_words);
Query<QuantTrieModel>(file, config, sentence_context, printer);
break;
case ARRAY_TRIE:
Query<ArrayTrieModel>(file, config, sentence_context, show_words);
Query<ArrayTrieModel>(file, config, sentence_context, printer);
break;
case QUANT_ARRAY_TRIE:
Query<QuantArrayTrieModel>(file, config, sentence_context, show_words);
Query<QuantArrayTrieModel>(file, config, sentence_context, printer);
break;
default:
std::cerr << "Unrecognized kenlm model type " << model_type << std::endl;
@ -86,14 +100,11 @@ int main(int argc, char *argv[]) {
#ifdef WITH_NPLM
} else if (lm::np::Model::Recognize(file)) {
lm::np::Model model(file);
if (show_words) {
Query<lm::np::Model, lm::ngram::FullPrint>(model, sentence_context);
} else {
Query<lm::np::Model, lm::ngram::BasicPrint>(model, sentence_context);
}
Query<lm::np::Model, lm::ngram::QueryPrinter>(model, sentence_context, printer);
#endif
} else {
Query<ProbingModel>(file, config, sentence_context, show_words);
Query<ProbingModel>(file, config, sentence_context, printer);
}
util::PrintUsage(std::cerr);
} catch (const std::exception &e) {

View File

@ -1,6 +1,7 @@
#ifndef LM_VALUE_H
#define LM_VALUE_H
#include "lm/config.hh"
#include "lm/model_type.hh"
#include "lm/value_build.hh"
#include "lm/weights.hh"

View File

@ -6,13 +6,14 @@
#include "lm/config.hh"
#include "lm/weights.hh"
#include "util/exception.hh"
#include "util/fake_ofstream.hh"
#include "util/file.hh"
#include "util/joint_sort.hh"
#include "util/murmur_hash.hh"
#include "util/probing_hash_table.hh"
#include <string>
#include <cstring>
#include <string>
namespace lm {
namespace ngram {
@ -31,6 +32,7 @@ const uint64_t kUnknownHash = detail::HashForVocab("<unk>", 5);
// Sadly some LMs have <UNK>.
const uint64_t kUnknownCapHash = detail::HashForVocab("<UNK>", 5);
// TODO: replace with FilePiece.
void ReadWords(int fd, EnumerateVocab *enumerate, WordIndex expected_count, uint64_t offset) {
util::SeekOrThrow(fd, offset);
// Check that we're at the right place by reading <unk> which is always first.
@ -69,10 +71,17 @@ void ReadWords(int fd, EnumerateVocab *enumerate, WordIndex expected_count, uint
UTIL_THROW_IF(expected_count != index, FormatLoadException, "The binary file has the wrong number of words at the end. This could be caused by a truncated binary file.");
}
// Constructor ordering madness.
int SeekAndReturn(int fd, uint64_t start) {
util::SeekOrThrow(fd, start);
return fd;
}
} // namespace
ImmediateWriteWordsWrapper::ImmediateWriteWordsWrapper(EnumerateVocab *inner, int fd, uint64_t start)
: inner_(inner), stream_(SeekAndReturn(fd, start)) {}
WriteWordsWrapper::WriteWordsWrapper(EnumerateVocab *inner) : inner_(inner) {}
WriteWordsWrapper::~WriteWordsWrapper() {}
void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) {
if (inner_) inner_->Add(index, str);
@ -80,6 +89,14 @@ void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) {
buffer_.push_back(0);
}
void WriteWordsWrapper::Write(int fd, uint64_t start) {
util::SeekOrThrow(fd, start);
util::WriteOrThrow(fd, buffer_.data(), buffer_.size());
// Free memory from the string.
std::string for_swap;
std::swap(buffer_, for_swap);
}
SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL) {}
uint64_t SortedVocabulary::Size(uint64_t entries, const Config &/*config*/) {
@ -126,10 +143,78 @@ WordIndex SortedVocabulary::Insert(const StringPiece &str) {
return end_ - begin_;
}
void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) {
void SortedVocabulary::FinishedLoading(ProbBackoff *reorder) {
GenericFinished(reorder);
}
namespace {
#pragma pack(push)
#pragma pack(4)
struct RenumberEntry {
uint64_t hash;
const char *str;
WordIndex old;
bool operator<(const RenumberEntry &other) const {
return hash < other.hash;
}
};
#pragma pack(pop)
} // namespace
void SortedVocabulary::ComputeRenumbering(WordIndex types, int from_words, int to_words, std::vector<WordIndex> &mapping) {
mapping.clear();
uint64_t file_size = util::SizeOrThrow(from_words);
util::scoped_memory strings;
util::MapRead(util::POPULATE_OR_READ, from_words, 0, file_size, strings);
const char *const start = static_cast<const char*>(strings.get());
UTIL_THROW_IF(memcmp(start, "<unk>", 6), FormatLoadException, "Vocab file does not begin with <unk> followed by null");
std::vector<RenumberEntry> entries;
entries.reserve(types - 1);
RenumberEntry entry;
entry.old = 1;
for (entry.str = start + 6 /* skip <unk>\0 */; entry.str < start + file_size; ++entry.old) {
StringPiece str(entry.str, strlen(entry.str));
entry.hash = detail::HashForVocab(str);
entries.push_back(entry);
entry.str += str.size() + 1;
}
UTIL_THROW_IF2(entries.size() != types - 1, "Wrong number of vocab ids. Got " << (entries.size() + 1) << " expected " << types);
std::sort(entries.begin(), entries.end());
// Write out new vocab file.
{
util::FakeOFStream out(to_words);
out << "<unk>" << '\0';
for (std::vector<RenumberEntry>::const_iterator i = entries.begin(); i != entries.end(); ++i) {
out << i->str << '\0';
}
}
strings.reset();
mapping.resize(types);
mapping[0] = 0; // <unk>
for (std::vector<RenumberEntry>::const_iterator i = entries.begin(); i != entries.end(); ++i) {
mapping[i->old] = i + 1 - entries.begin();
}
}
void SortedVocabulary::Populated() {
saw_unk_ = true;
SetSpecial(Index("<s>"), Index("</s>"), 0);
bound_ = end_ - begin_ + 1;
*(reinterpret_cast<uint64_t*>(begin_) - 1) = end_ - begin_;
}
void SortedVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset) {
end_ = begin_ + *(reinterpret_cast<const uint64_t*>(begin_) - 1);
SetSpecial(Index("<s>"), Index("</s>"), 0);
bound_ = end_ - begin_ + 1;
if (have_words) ReadWords(fd, to, bound_, offset);
}
template <class T> void SortedVocabulary::GenericFinished(T *reorder) {
if (enumerate_) {
if (!strings_to_enumerate_.empty()) {
util::PairedIterator<ProbBackoff*, StringPiece*> values(reorder_vocab + 1, &*strings_to_enumerate_.begin());
util::PairedIterator<T*, StringPiece*> values(reorder + 1, &*strings_to_enumerate_.begin());
util::JointSort(begin_, end_, values);
}
for (WordIndex i = 0; i < static_cast<WordIndex>(end_ - begin_); ++i) {
@ -139,7 +224,7 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) {
strings_to_enumerate_.clear();
string_backing_.FreeAll();
} else {
util::JointSort(begin_, end_, reorder_vocab + 1);
util::JointSort(begin_, end_, reorder + 1);
}
SetSpecial(Index("<s>"), Index("</s>"), 0);
// Save size. Excludes UNK.
@ -148,13 +233,6 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) {
bound_ = end_ - begin_ + 1;
}
void SortedVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset) {
end_ = begin_ + *(reinterpret_cast<const uint64_t*>(begin_) - 1);
SetSpecial(Index("<s>"), Index("</s>"), 0);
bound_ = end_ - begin_ + 1;
if (have_words) ReadWords(fd, to, bound_, offset);
}
namespace {
const unsigned int kProbingVocabularyVersion = 0;
} // namespace
@ -209,7 +287,7 @@ WordIndex ProbingVocabulary::Insert(const StringPiece &str) {
}
}
void ProbingVocabulary::FinishedLoading() {
void ProbingVocabulary::InternalFinishedLoading() {
lookup_.FinishedInserting();
header_->bound = bound_;
header_->version = kProbingVocabularyVersion;

View File

@ -30,15 +30,32 @@ inline uint64_t HashForVocab(const StringPiece &str) {
struct ProbingVocabularyHeader;
} // namespace detail
// Writes words immediately to a file instead of buffering, because we know
// where in the file to put them.
class ImmediateWriteWordsWrapper : public EnumerateVocab {
public:
ImmediateWriteWordsWrapper(EnumerateVocab *inner, int fd, uint64_t start);
void Add(WordIndex index, const StringPiece &str) {
stream_ << str << '\0';
if (inner_) inner_->Add(index, str);
}
private:
EnumerateVocab *inner_;
util::FakeOFStream stream_;
};
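A hedged sketch; fd and start stand in for the file descriptor and the already-known offset of the words section:
lm::ngram::ImmediateWriteWordsWrapper words(NULL, fd, start);
words.Add(0, "<unk>");  // written straight to the file, null-delimited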
// When the binary size isn't known yet.
class WriteWordsWrapper : public EnumerateVocab {
public:
WriteWordsWrapper(EnumerateVocab *inner);
~WriteWordsWrapper();
void Add(WordIndex index, const StringPiece &str);
const std::string &Buffer() const { return buffer_; }
void Write(int fd, uint64_t start);
private:
EnumerateVocab *inner_;
@ -67,6 +84,12 @@ class SortedVocabulary : public base::Vocabulary {
// Size for purposes of file writing
static uint64_t Size(uint64_t entries, const Config &config);
/* Read null-delimited words from file from_words, renumber according to
* hash order, write null-delimited words to to_words, and create a mapping
* from old id to new id. The 0th vocab word must be <unk>.
*/
static void ComputeRenumbering(WordIndex types, int from_words, int to_words, std::vector<WordIndex> &mapping);
// Vocab words are [0, Bound()) Only valid after FinishedLoading/LoadedBinary.
WordIndex Bound() const { return bound_; }
@ -77,8 +100,8 @@ class SortedVocabulary : public base::Vocabulary {
void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries);
// Insert and FinishedLoading go together.
WordIndex Insert(const StringPiece &str);
// Reorders reorder_vocab so that the IDs are sorted.
void FinishedLoading(ProbBackoff *reorder_vocab);
@ -89,7 +112,13 @@ class SortedVocabulary : public base::Vocabulary {
void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset);
uint64_t *&EndHack() { return end_; }
void Populated();
private:
template <class T> void GenericFinished(T *reorder);
uint64_t *begin_, *end_;
WordIndex bound_;
@ -153,9 +182,8 @@ class ProbingVocabulary : public base::Vocabulary {
WordIndex Insert(const StringPiece &str);
template <class Weights> void FinishedLoading(Weights * /*reorder_vocab*/) {
FinishedLoading();
InternalFinishedLoading();
}
void FinishedLoading();
std::size_t UnkCountChangePadding() const { return 0; }
@ -164,6 +192,8 @@ class ProbingVocabulary : public base::Vocabulary {
void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset);
private:
void InternalFinishedLoading();
typedef util::ProbingHashTable<ProbingVocabularyEntry, util::IdentityHash> Lookup;
Lookup lookup_;
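The new ComputeRenumbering entry point above can be used roughly as follows; types, from_fd, and to_fd are assumptions (the type count and two null-delimited word files):
std::vector<lm::WordIndex> mapping;
lm::ngram::SortedVocabulary::ComputeRenumbering(types, from_fd, to_fd, mapping);
// mapping[old_id] is now the hash-sorted id that the trie expects.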

View File

@ -7,6 +7,7 @@
namespace lm {
typedef unsigned int WordIndex;
const WordIndex kMaxWordIndex = UINT_MAX;
const WordIndex kUNK = 0;
} // namespace lm
typedef lm::WordIndex LMWordIndex;

View File

@ -13,8 +13,6 @@
#include <iostream>
#include <string>
#include "util/unistd.hh"
#if defined(__GLIBCXX__) || defined(__GLIBCPP__)
#include <ext/stdio_filebuf.h>

View File

@ -18,7 +18,6 @@
#include "ScoreStats.h"
#include "Util.h"
#include "util/unistd.hh"
using namespace std;

View File

@ -21,10 +21,13 @@ obj file_piece_test.o : file_piece_test.cc /top//boost_unit_test_framework : $(c
fakelib parallel_read : parallel_read.cc : <threading>multi:<source>/top//boost_thread <threading>multi:<define>WITH_THREADS : : <include>.. ;
fakelib kenutil : bit_packing.cc ersatz_progress.cc exception.cc file.cc file_piece.cc mmap.cc murmur_hash.cc parallel_read pool.cc random.cc read_compressed scoped.cc string_piece.cc usage.cc double-conversion//double-conversion : <include>.. <os>LINUX,<threading>single:<source>rt : : <include>.. ;
fakelib kenutil : [ glob *.cc : parallel_read.cc read_compressed.cc *_main.cc *_test.cc ] read_compressed parallel_read double-conversion//double-conversion : <include>.. <os>LINUX,<threading>single:<source>rt : : <include>.. ;
exe cat_compressed : cat_compressed_main.cc kenutil ;
#Does not install this
exe probing_hash_table_benchmark : probing_hash_table_benchmark_main.cc kenutil ;
alias programs : cat_compressed ;
import testing ;
@ -34,3 +37,5 @@ for local t in [ glob *_test.cc : file_piece_test.cc read_compressed_test.cc ] {
local name = [ MATCH "(.*)\.cc" : $(t) ] ;
unit-test $(name) : $(t) kenutil /top//boost_unit_test_framework /top//boost_filesystem /top//boost_system ;
}
build-project stream ;

View File

@ -91,6 +91,12 @@ template <class Except, class Data> typename Except::template ExceptionTag<Excep
#define UTIL_UNLIKELY(x) (x)
#endif
#if __GNUC__ >= 3
#define UTIL_LIKELY(x) __builtin_expect (!!(x), 1)
#else
#define UTIL_LIKELY(x) (x)
#endif
#define UTIL_THROW_IF_ARG(Condition, Exception, Arg, Modify) do { \
if (UTIL_UNLIKELY(Condition)) { \
UTIL_THROW_BACKEND(#Condition, Exception, Arg, Modify); \

View File

@ -1,111 +1,135 @@
/* Like std::ofstream but without being incredibly slow. Backed by a raw fd.
* Does not support many data types. Currently, it's targeted at writing ARPA
* files quickly.
* Supports most of the built-in types except for void* and long double.
*/
#ifndef UTIL_FAKE_OFSTREAM_H
#define UTIL_FAKE_OFSTREAM_H
#include "util/double-conversion/double-conversion.h"
#include "util/double-conversion/utils.h"
#include "util/file.hh"
#include "util/float_to_string.hh"
#include "util/integer_to_string.hh"
#include "util/scoped.hh"
#include "util/string_piece.hh"
#define BOOST_LEXICAL_CAST_ASSUME_C_LOCALE
#include <boost/lexical_cast.hpp>
#include <cassert>
#include <cstring>
#include <stdint.h>
namespace util {
class FakeOFStream {
public:
// Maximum over all ToString operations.
// static const std::size_t kMinBuf = 20;
// This was causing compile failures in debug, so now 20 is written directly.
//
// Does not take ownership of out.
// Allows default constructor, but must call SetFD.
explicit FakeOFStream(int out = -1, std::size_t buffer_size = 1048576)
: buf_(util::MallocOrThrow(buffer_size)),
builder_(static_cast<char*>(buf_.get()), buffer_size),
// Mostly the default but with inf instead. And no flags.
convert_(double_conversion::DoubleToStringConverter::NO_FLAGS, "inf", "NaN", 'e', -6, 21, 6, 0),
fd_(out),
buffer_size_(buffer_size) {}
: buf_(util::MallocOrThrow(std::max(buffer_size, (size_t)20))),
current_(static_cast<char*>(buf_.get())),
end_(current_ + std::max(buffer_size, (size_t)20)),
fd_(out) {}
~FakeOFStream() {
if (buf_.get()) Flush();
// Could have called Finish already
flush();
}
void SetFD(int to) {
if (builder_.position()) Flush();
flush();
fd_ = to;
}
FakeOFStream &Write(const void *data, std::size_t length) {
// Dominant case
if (static_cast<std::size_t>(builder_.size() - builder_.position()) > length) {
builder_.AddSubstring((const char*)data, length);
FakeOFStream &write(const void *data, std::size_t length) {
if (UTIL_LIKELY(current_ + length <= end_)) {
std::memcpy(current_, data, length);
current_ += length;
return *this;
}
Flush();
if (length > buffer_size_) {
util::WriteOrThrow(fd_, data, length);
flush();
if (current_ + length <= end_) {
std::memcpy(current_, data, length);
current_ += length;
} else {
builder_.AddSubstring((const char*)data, length);
util::WriteOrThrow(fd_, data, length);
}
return *this;
}
// This also covers std::string and char*
FakeOFStream &operator<<(StringPiece str) {
return Write(str.data(), str.size());
return write(str.data(), str.size());
}
FakeOFStream &operator<<(float value) {
// Odd, but this is the largest number found in the comments.
EnsureRemaining(double_conversion::DoubleToStringConverter::kMaxPrecisionDigits + 8);
convert_.ToShortestSingle(value, &builder_);
// For anything with ToStringBuf<T>::kBytes, define operator<< using ToString.
// This includes uint64_t, int64_t, uint32_t, int32_t, uint16_t, int16_t,
// float, double
private:
template <int Arg> struct EnableIfKludge {
typedef FakeOFStream type;
};
public:
template <class T> typename EnableIfKludge<ToStringBuf<T>::kBytes>::type &operator<<(const T value) {
EnsureRemaining(ToStringBuf<T>::kBytes);
current_ = ToString(value, current_);
assert(current_ <= end_);
return *this;
}
FakeOFStream &operator<<(double value) {
EnsureRemaining(double_conversion::DoubleToStringConverter::kMaxPrecisionDigits + 8);
convert_.ToShortest(value, &builder_);
return *this;
}
// Inefficient! TODO: more efficient implementation
FakeOFStream &operator<<(unsigned value) {
return *this << boost::lexical_cast<std::string>(value);
}
FakeOFStream &operator<<(char c) {
EnsureRemaining(1);
builder_.AddCharacter(c);
*current_++ = c;
return *this;
}
FakeOFStream &operator<<(unsigned char c) {
EnsureRemaining(1);
*current_++ = static_cast<char>(c);
return *this;
}
/* clang on OS X appears to consider std::size_t aka unsigned long distinct
* from uint64_t. So this function makes clang work. gcc considers
* uint64_t and std::size_t the same (on 64-bit) so this isn't necessary.
* But it does no harm since gcc sees it as a specialization of the
* EnableIfKludge template.
* Also, delegating to *this << static_cast<uint64_t>(value) would loop
* indefinitely on gcc.
*/
FakeOFStream &operator<<(std::size_t value) {
EnsureRemaining(ToStringBuf<uint64_t>::kBytes);
current_ = ToString(static_cast<uint64_t>(value), current_);
return *this;
}
// Note this does not sync.
void Flush() {
util::WriteOrThrow(fd_, buf_.get(), builder_.position());
builder_.Reset();
void flush() {
if (current_ != buf_.get()) {
util::WriteOrThrow(fd_, buf_.get(), current_ - (char*)buf_.get());
current_ = static_cast<char*>(buf_.get());
}
}
// Not necessary, but does assure the data is cleared.
void Finish() {
Flush();
// It will segfault trying to null terminate otherwise.
builder_.Finalize();
flush();
buf_.reset();
current_ = NULL;
util::FSyncOrThrow(fd_);
}
private:
void EnsureRemaining(std::size_t amount) {
if (static_cast<std::size_t>(builder_.size() - builder_.position()) <= amount) {
Flush();
if (UTIL_UNLIKELY(current_ + amount > end_)) {
flush();
assert(current_ + amount <= end_);
}
}
util::scoped_malloc buf_;
double_conversion::StringBuilder builder_;
double_conversion::DoubleToStringConverter convert_;
char *current_, *end_;
int fd_;
const std::size_t buffer_size_;
};
} // namespace
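A brief sketch of the rewritten stream in use (fd 1 is stdout; the values are illustrative):
util::FakeOFStream out(1);
out << "ngram " << static_cast<uint64_t>(5) << '=' << 0.5f << '\n';
out.flush();  // data is otherwise written when the buffer fills or on destruction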

View File

@ -11,19 +11,22 @@
#include <unistd.h>
#endif
#include <iostream>
#include <string>
#include <limits>
#include <cassert>
#include <fcntl.h>
#include <cerrno>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <string>
#include <fcntl.h>
#include <sys/types.h>
#include <sys/stat.h>
namespace util {
ParseNumberException::ParseNumberException(StringPiece value) throw() {
*this << "Could not parse \"" << value << "\" into a number";
*this << "Could not parse \"" << value << "\" into a ";
}
// Sigh this is the only way I could come up with to do a _const_ bool. It has ' ', '\f', '\n', '\r', '\t', and '\v' (same as isspace on C locale).
@ -62,12 +65,17 @@ FilePiece::FilePiece(std::istream &stream, const char *name, std::size_t min_buf
FilePiece::~FilePiece() {}
StringPiece FilePiece::ReadLine(char delim) {
StringPiece FilePiece::ReadLine(char delim, bool strip_cr) {
std::size_t skip = 0;
while (true) {
for (const char *i = position_ + skip; i < position_end_; ++i) {
if (*i == delim) {
StringPiece ret(position_, i - position_);
// End of line.
// Take 1 byte off the end if it's an unwanted carriage return.
const std::size_t subtract_cr = (
(strip_cr && i > position_ && *(i - 1) == '\r') ?
1 : 0);
StringPiece ret(position_, i - position_ - subtract_cr);
position_ = i + 1;
return ret;
}
@ -83,9 +91,9 @@ StringPiece FilePiece::ReadLine(char delim) {
}
}
bool FilePiece::ReadLineOrEOF(StringPiece &to, char delim) {
bool FilePiece::ReadLineOrEOF(StringPiece &to, char delim, bool strip_cr) {
try {
to = ReadLine(delim);
to = ReadLine(delim, strip_cr);
} catch (const util::EndOfFileException &e) { return false; }
return true;
}
@ -145,49 +153,59 @@ static const double_conversion::StringToDoubleConverter kConverter(
"inf",
"NaN");
void ParseNumber(const char *begin, const char *&end, float &out) {
StringPiece FirstToken(StringPiece str) {
const char *i;
for (i = str.data(); i != str.data() + str.size(); ++i) {
if (kSpaces[(unsigned char)*i]) break;
}
return StringPiece(str.data(), i - str.data());
}
const char *ParseNumber(StringPiece str, float &out) {
int count;
out = kConverter.StringToFloat(begin, end - begin, &count);
end = begin + count;
out = kConverter.StringToFloat(str.data(), str.size(), &count);
UTIL_THROW_IF_ARG(isnan(out) && str != "NaN" && str != "nan", ParseNumberException, (FirstToken(str)), "float");
return str.data() + count;
}
void ParseNumber(const char *begin, const char *&end, double &out) {
const char *ParseNumber(StringPiece str, double &out) {
int count;
out = kConverter.StringToDouble(begin, end - begin, &count);
end = begin + count;
out = kConverter.StringToDouble(str.data(), str.size(), &count);
UTIL_THROW_IF_ARG(isnan(out) && str != "NaN" && str != "nan", ParseNumberException, (FirstToken(str)), "double");
return str.data() + count;
}
void ParseNumber(const char *begin, const char *&end, long int &out) {
char *silly_end;
out = strtol(begin, &silly_end, 10);
end = silly_end;
const char *ParseNumber(StringPiece str, long int &out) {
char *end;
errno = 0;
out = strtol(str.data(), &end, 10);
UTIL_THROW_IF_ARG(errno || (end == str.data()), ParseNumberException, (FirstToken(str)), "long int");
return end;
}
void ParseNumber(const char *begin, const char *&end, unsigned long int &out) {
char *silly_end;
out = strtoul(begin, &silly_end, 10);
end = silly_end;
const char *ParseNumber(StringPiece str, unsigned long int &out) {
char *end;
errno = 0;
out = strtoul(str.data(), &end, 10);
UTIL_THROW_IF_ARG(errno || (end == str.data()), ParseNumberException, (FirstToken(str)), "unsigned long int");
return end;
}
} // namespace
template <class T> T FilePiece::ReadNumber() {
SkipSpaces();
while (last_space_ < position_) {
if (at_end_) {
if (UTIL_UNLIKELY(at_end_)) {
// Hallucinate a null off the end of the file.
std::string buffer(position_, position_end_);
const char *buf = buffer.c_str();
const char *end = buf + buffer.size();
T ret;
ParseNumber(buf, end, ret);
if (buf == end) throw ParseNumberException(buffer);
position_ += end - buf;
// Has to be null-terminated.
const char *begin = buffer.c_str();
const char *end = ParseNumber(StringPiece(begin, buffer.size()), ret);
position_ += end - begin;
return ret;
}
Shift();
}
const char *end = last_space_;
T ret;
ParseNumber(position_, end, ret);
if (end == position_) throw ParseNumberException(ReadDelimited());
position_ = end;
position_ = ParseNumber(StringPiece(position_, last_space_ - position_), ret);
return ret;
}

View File

@ -55,7 +55,7 @@ class FilePiece {
return Consume(FindDelimiterOrEOF(delim));
}
// Read word until the line or file ends.
/// Read word until the line or file ends.
bool ReadWordSameLine(StringPiece &to, const bool *delim = kSpaces) {
assert(delim[static_cast<unsigned char>('\n')]);
// Skip non-enter spaces.
@ -75,12 +75,30 @@ class FilePiece {
return true;
}
// Unlike ReadDelimited, this includes leading spaces and consumes the delimiter.
// It is similar to getline in that way.
StringPiece ReadLine(char delim = '\n');
/** Read a line of text from the file.
*
* Unlike ReadDelimited, this includes leading spaces and consumes the
* delimiter. It is similar to getline in that way.
*
* If strip_cr is true, any trailing carriage return (as would be found on
* a file written on Windows) will be left out of the returned line.
*
* Throws EndOfFileException if the end of the file is encountered. If the
* file does not end in a newline, this could mean that the last line is
* never read.
*/
StringPiece ReadLine(char delim = '\n', bool strip_cr = true);
// Doesn't throw EndOfFileException, just returns false.
bool ReadLineOrEOF(StringPiece &to, char delim = '\n');
/** Read a line of text from the file, or return false on EOF.
*
* This is like ReadLine, except it returns false where ReadLine throws
* EndOfFileException. Like ReadLine it may not read the last line in the
* file if the file does not end in a newline.
*
* If strip_cr is true, any trailing carriage return (as would be found on
* a file written on Windows) will be left out of the returned line.
*/
bool ReadLineOrEOF(StringPiece &to, char delim = '\n', bool strip_cr = true);
float ReadFloat();
double ReadDouble();
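Taken together, the new strip_cr parameters behave like this (a hedged sketch; the file name is hypothetical):
util::FilePiece f("corpus.txt");
StringPiece line;
while (f.ReadLineOrEOF(line)) {
  // line carries no trailing '\n' and, by default, no trailing '\r'.
}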

View File

@ -1,6 +1,7 @@
// Tests might fail if you have creative characters in your path. Sue me.
#include "util/file_piece.hh"
#include "util/fake_ofstream.hh"
#include "util/file.hh"
#include "util/scoped.hh"
@ -133,5 +134,21 @@ BOOST_AUTO_TEST_CASE(StreamZipReadLine) {
#endif // HAVE_ZLIB
BOOST_AUTO_TEST_CASE(Numbers) {
scoped_fd file(MakeTemp(FileLocation()));
const float floating = 3.2;
{
util::FakeOFStream writing(file.get());
writing << "94389483984398493890287 " << floating << " 5";
}
SeekOrThrow(file.get(), 0);
util::FilePiece f(file.release());
BOOST_CHECK_THROW(f.ReadULong(), ParseNumberException);
BOOST_CHECK_EQUAL("94389483984398493890287", f.ReadDelimited());
// Yes, exactly equal. Isn't double-conversion wonderful?
BOOST_CHECK_EQUAL(floating, f.ReadFloat());
BOOST_CHECK_EQUAL(5, f.ReadULong());
}
} // namespace
} // namespace util

util/float_to_string.cc Normal file
View File

@ -0,0 +1,23 @@
#include "util/float_to_string.hh"
#include "util/double-conversion/double-conversion.h"
#include "util/double-conversion/utils.h"
namespace util {
namespace {
const double_conversion::DoubleToStringConverter kConverter(double_conversion::DoubleToStringConverter::NO_FLAGS, "inf", "NaN", 'e', -6, 21, 6, 0);
} // namespace
char *ToString(double value, char *to) {
double_conversion::StringBuilder builder(to, ToStringBuf<double>::kBytes);
kConverter.ToShortest(value, &builder);
return &to[builder.position()];
}
char *ToString(float value, char *to) {
double_conversion::StringBuilder builder(to, ToStringBuf<float>::kBytes);
kConverter.ToShortestSingle(value, &builder);
return &to[builder.position()];
}
} // namespace util

util/float_to_string.hh Normal file
View File

@ -0,0 +1,25 @@
#ifndef UTIL_FLOAT_TO_STRING_H
#define UTIL_FLOAT_TO_STRING_H
// Just for ToStringBuf
#include "util/integer_to_string.hh"
namespace util {
template <> struct ToStringBuf<double> {
// DoubleToStringConverter::kBase10MaximalLength + 1 for null paranoia.
static const unsigned kBytes = 18;
};
// Single wasn't documented in double conversion, so be conservative and
// say the same as double.
template <> struct ToStringBuf<float> {
static const unsigned kBytes = 18;
};
char *ToString(double value, char *to);
char *ToString(float value, char *to);
} // namespace util
#endif // UTIL_FLOAT_TO_STRING_H
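A sketch of the buffer contract (the literal 0.1 is illustrative):
char buf[util::ToStringBuf<double>::kBytes];
char *end = util::ToString(0.1, buf);
// [buf, end) holds the shortest round-trip form ("0.1"); no null terminator
// is written, matching the FakeOFStream fast path.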

util/integer_to_string.cc Normal file
View File

@ -0,0 +1,639 @@
/* Fast integer to string conversion.
Source: https://github.com/miloyip/itoa-benchmark
Local modifications:
1. Return end of buffer instead of null terminating
2. Collapse to single file
3. Namespace
4. Remove test hook
5. Non-x86 support from the branch_lut code
6. Rename functions
Copyright (C) 2014 Milo Yip
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
Which is based on: http://0x80.pl/snippets/asm/sse-utoa.c
SSE: conversion integers to decimal representation
Author: Wojciech Muła
e-mail: wojciech_mula@poczta.onet.pl
www: http://0x80.pl/
License: BSD
initial release 2011-10-21
$Id$
*/
#include "util/integer_to_string.hh"
#include <cassert>
#include <stdint.h>
namespace util {
namespace {
const char gDigitsLut[200] = {
'0','0','0','1','0','2','0','3','0','4','0','5','0','6','0','7','0','8','0','9',
'1','0','1','1','1','2','1','3','1','4','1','5','1','6','1','7','1','8','1','9',
'2','0','2','1','2','2','2','3','2','4','2','5','2','6','2','7','2','8','2','9',
'3','0','3','1','3','2','3','3','3','4','3','5','3','6','3','7','3','8','3','9',
'4','0','4','1','4','2','4','3','4','4','4','5','4','6','4','7','4','8','4','9',
'5','0','5','1','5','2','5','3','5','4','5','5','5','6','5','7','5','8','5','9',
'6','0','6','1','6','2','6','3','6','4','6','5','6','6','6','7','6','8','6','9',
'7','0','7','1','7','2','7','3','7','4','7','5','7','6','7','7','7','8','7','9',
'8','0','8','1','8','2','8','3','8','4','8','5','8','6','8','7','8','8','8','9',
'9','0','9','1','9','2','9','3','9','4','9','5','9','6','9','7','9','8','9','9'
};
} // namespace
// SSE2 implementation according to http://0x80.pl/articles/sse-itoa.html
// Modifications: (1) fix incorrect digits (2) accept all ranges (3) write to user provided buffer.
#if defined(i386) || defined(__amd64) || defined(_M_IX86) || defined(_M_X64)
#include <emmintrin.h>
#ifdef _MSC_VER
#include "intrin.h"
#endif
#ifdef _MSC_VER
#define ALIGN_PRE __declspec(align(16))
#define ALIGN_SUF
#else
#define ALIGN_PRE
#define ALIGN_SUF __attribute__ ((aligned(16)))
#endif
namespace {
static const uint32_t kDiv10000 = 0xd1b71759;
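// kDiv10000 == ceil(2^45 / 10000), so (x * kDiv10000) >> 45 computes x / 10000
// for the 8-digit values handled here.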
ALIGN_PRE static const uint32_t kDiv10000Vector[4] ALIGN_SUF = { kDiv10000, kDiv10000, kDiv10000, kDiv10000 };
ALIGN_PRE static const uint32_t k10000Vector[4] ALIGN_SUF = { 10000, 10000, 10000, 10000 };
ALIGN_PRE static const uint16_t kDivPowersVector[8] ALIGN_SUF = { 8389, 5243, 13108, 32768, 8389, 5243, 13108, 32768 }; // 10^3, 10^2, 10^1, 10^0
ALIGN_PRE static const uint16_t kShiftPowersVector[8] ALIGN_SUF = {
1 << (16 - (23 + 2 - 16)),
1 << (16 - (19 + 2 - 16)),
1 << (16 - 1 - 2),
1 << (15),
1 << (16 - (23 + 2 - 16)),
1 << (16 - (19 + 2 - 16)),
1 << (16 - 1 - 2),
1 << (15)
};
ALIGN_PRE static const uint16_t k10Vector[8] ALIGN_SUF = { 10, 10, 10, 10, 10, 10, 10, 10 };
ALIGN_PRE static const char kAsciiZero[16] ALIGN_SUF = { '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0' };
inline __m128i Convert8DigitsSSE2(uint32_t value) {
assert(value <= 99999999);
// abcd, efgh = abcdefgh divmod 10000
const __m128i abcdefgh = _mm_cvtsi32_si128(value);
const __m128i abcd = _mm_srli_epi64(_mm_mul_epu32(abcdefgh, reinterpret_cast<const __m128i*>(kDiv10000Vector)[0]), 45);
const __m128i efgh = _mm_sub_epi32(abcdefgh, _mm_mul_epu32(abcd, reinterpret_cast<const __m128i*>(k10000Vector)[0]));
// v1 = [ abcd, efgh, 0, 0, 0, 0, 0, 0 ]
const __m128i v1 = _mm_unpacklo_epi16(abcd, efgh);
// v1a = v1 * 4 = [ abcd * 4, efgh * 4, 0, 0, 0, 0, 0, 0 ]
const __m128i v1a = _mm_slli_epi64(v1, 2);
// v2 = [ abcd * 4, abcd * 4, abcd * 4, abcd * 4, efgh * 4, efgh * 4, efgh * 4, efgh * 4 ]
const __m128i v2a = _mm_unpacklo_epi16(v1a, v1a);
const __m128i v2 = _mm_unpacklo_epi32(v2a, v2a);
// v4 = v2 div 10^3, 10^2, 10^1, 10^0 = [ a, ab, abc, abcd, e, ef, efg, efgh ]
const __m128i v3 = _mm_mulhi_epu16(v2, reinterpret_cast<const __m128i*>(kDivPowersVector)[0]);
const __m128i v4 = _mm_mulhi_epu16(v3, reinterpret_cast<const __m128i*>(kShiftPowersVector)[0]);
// v5 = v4 * 10 = [ a0, ab0, abc0, abcd0, e0, ef0, efg0, efgh0 ]
const __m128i v5 = _mm_mullo_epi16(v4, reinterpret_cast<const __m128i*>(k10Vector)[0]);
// v6 = v5 << 16 = [ 0, a0, ab0, abc0, 0, e0, ef0, efg0 ]
const __m128i v6 = _mm_slli_epi64(v5, 16);
// v7 = v4 - v6 = { a, b, c, d, e, f, g, h }
const __m128i v7 = _mm_sub_epi16(v4, v6);
return v7;
}
inline __m128i ShiftDigits_SSE2(__m128i a, unsigned digit) {
assert(digit <= 8);
switch (digit) {
case 0: return a;
case 1: return _mm_srli_si128(a, 1);
case 2: return _mm_srli_si128(a, 2);
case 3: return _mm_srli_si128(a, 3);
case 4: return _mm_srli_si128(a, 4);
case 5: return _mm_srli_si128(a, 5);
case 6: return _mm_srli_si128(a, 6);
case 7: return _mm_srli_si128(a, 7);
case 8: return _mm_srli_si128(a, 8);
}
return a; // should not execute here.
}
} // namespace
// Original name: u32toa_sse2
char *ToString(uint32_t value, char* buffer) {
if (value < 10000) {
const uint32_t d1 = (value / 100) << 1;
const uint32_t d2 = (value % 100) << 1;
if (value >= 1000)
*buffer++ = gDigitsLut[d1];
if (value >= 100)
*buffer++ = gDigitsLut[d1 + 1];
if (value >= 10)
*buffer++ = gDigitsLut[d2];
*buffer++ = gDigitsLut[d2 + 1];
//*buffer++ = '\0';
return buffer;
}
else if (value < 100000000) {
// Experiments show that SSE2 is slower for this case
#if 0
const __m128i a = Convert8DigitsSSE2(value);
// Convert to bytes, add '0'
const __m128i va = _mm_add_epi8(_mm_packus_epi16(a, _mm_setzero_si128()), reinterpret_cast<const __m128i*>(kAsciiZero)[0]);
// Count number of digit
const unsigned mask = _mm_movemask_epi8(_mm_cmpeq_epi8(va, reinterpret_cast<const __m128i*>(kAsciiZero)[0]));
unsigned long digit;
#ifdef _MSC_VER
_BitScanForward(&digit, ~mask | 0x8000);
#else
digit = __builtin_ctz(~mask | 0x8000);
#endif
// Shift digits to the beginning
__m128i result = ShiftDigits_SSE2(va, digit);
//__m128i result = _mm_srl_epi64(va, _mm_cvtsi32_si128(digit * 8));
_mm_storel_epi64(reinterpret_cast<__m128i*>(buffer), result);
buffer[8 - digit] = '\0';
#else
// value = bbbbcccc
const uint32_t b = value / 10000;
const uint32_t c = value % 10000;
const uint32_t d1 = (b / 100) << 1;
const uint32_t d2 = (b % 100) << 1;
const uint32_t d3 = (c / 100) << 1;
const uint32_t d4 = (c % 100) << 1;
if (value >= 10000000)
*buffer++ = gDigitsLut[d1];
if (value >= 1000000)
*buffer++ = gDigitsLut[d1 + 1];
if (value >= 100000)
*buffer++ = gDigitsLut[d2];
*buffer++ = gDigitsLut[d2 + 1];
*buffer++ = gDigitsLut[d3];
*buffer++ = gDigitsLut[d3 + 1];
*buffer++ = gDigitsLut[d4];
*buffer++ = gDigitsLut[d4 + 1];
// *buffer++ = '\0';
return buffer;
#endif
}
else {
// value = aabbbbbbbb in decimal
const uint32_t a = value / 100000000; // 1 to 42
value %= 100000000;
if (a >= 10) {
const unsigned i = a << 1;
*buffer++ = gDigitsLut[i];
*buffer++ = gDigitsLut[i + 1];
}
else
*buffer++ = '0' + static_cast<char>(a);
const __m128i b = Convert8DigitsSSE2(value);
const __m128i ba = _mm_add_epi8(_mm_packus_epi16(_mm_setzero_si128(), b), reinterpret_cast<const __m128i*>(kAsciiZero)[0]);
const __m128i result = _mm_srli_si128(ba, 8);
_mm_storel_epi64(reinterpret_cast<__m128i*>(buffer), result);
// buffer[8] = '\0';
return buffer + 8;
}
}
// Original name: u64toa_sse2
char *ToString(uint64_t value, char* buffer) {
if (value < 100000000) {
uint32_t v = static_cast<uint32_t>(value);
if (v < 10000) {
const uint32_t d1 = (v / 100) << 1;
const uint32_t d2 = (v % 100) << 1;
if (v >= 1000)
*buffer++ = gDigitsLut[d1];
if (v >= 100)
*buffer++ = gDigitsLut[d1 + 1];
if (v >= 10)
*buffer++ = gDigitsLut[d2];
*buffer++ = gDigitsLut[d2 + 1];
//*buffer++ = '\0';
return buffer;
}
else {
// Experiments show that SSE2 is slower for this case
#if 0
const __m128i a = Convert8DigitsSSE2(v);
// Convert to bytes, add '0'
const __m128i va = _mm_add_epi8(_mm_packus_epi16(a, _mm_setzero_si128()), reinterpret_cast<const __m128i*>(kAsciiZero)[0]);
// Count number of digit
const unsigned mask = _mm_movemask_epi8(_mm_cmpeq_epi8(va, reinterpret_cast<const __m128i*>(kAsciiZero)[0]));
unsigned long digit;
#ifdef _MSC_VER
_BitScanForward(&digit, ~mask | 0x8000);
#else
digit = __builtin_ctz(~mask | 0x8000);
#endif
// Shift digits to the beginning
__m128i result = ShiftDigits_SSE2(va, digit);
_mm_storel_epi64(reinterpret_cast<__m128i*>(buffer), result);
buffer[8 - digit] = '\0';
#else
// value = bbbbcccc
const uint32_t b = v / 10000;
const uint32_t c = v % 10000;
const uint32_t d1 = (b / 100) << 1;
const uint32_t d2 = (b % 100) << 1;
const uint32_t d3 = (c / 100) << 1;
const uint32_t d4 = (c % 100) << 1;
if (value >= 10000000)
*buffer++ = gDigitsLut[d1];
if (value >= 1000000)
*buffer++ = gDigitsLut[d1 + 1];
if (value >= 100000)
*buffer++ = gDigitsLut[d2];
*buffer++ = gDigitsLut[d2 + 1];
*buffer++ = gDigitsLut[d3];
*buffer++ = gDigitsLut[d3 + 1];
*buffer++ = gDigitsLut[d4];
*buffer++ = gDigitsLut[d4 + 1];
//*buffer++ = '\0';
return buffer;
#endif
}
}
else if (value < 10000000000000000) {
const uint32_t v0 = static_cast<uint32_t>(value / 100000000);
const uint32_t v1 = static_cast<uint32_t>(value % 100000000);
const __m128i a0 = Convert8DigitsSSE2(v0);
const __m128i a1 = Convert8DigitsSSE2(v1);
// Convert to bytes, add '0'
const __m128i va = _mm_add_epi8(_mm_packus_epi16(a0, a1), reinterpret_cast<const __m128i*>(kAsciiZero)[0]);
// Count number of digit
const unsigned mask = _mm_movemask_epi8(_mm_cmpeq_epi8(va, reinterpret_cast<const __m128i*>(kAsciiZero)[0]));
#ifdef _MSC_VER
unsigned long digit;
_BitScanForward(&digit, ~mask | 0x8000);
#else
unsigned digit = __builtin_ctz(~mask | 0x8000);
#endif
// Shift digits to the beginning
__m128i result = ShiftDigits_SSE2(va, digit);
_mm_storeu_si128(reinterpret_cast<__m128i*>(buffer), result);
// buffer[16 - digit] = '\0';
return &buffer[16 - digit];
}
else {
const uint32_t a = static_cast<uint32_t>(value / 10000000000000000); // 1 to 1844
value %= 10000000000000000;
if (a < 10)
*buffer++ = '0' + static_cast<char>(a);
else if (a < 100) {
const uint32_t i = a << 1;
*buffer++ = gDigitsLut[i];
*buffer++ = gDigitsLut[i + 1];
}
else if (a < 1000) {
*buffer++ = '0' + static_cast<char>(a / 100);
const uint32_t i = (a % 100) << 1;
*buffer++ = gDigitsLut[i];
*buffer++ = gDigitsLut[i + 1];
}
else {
const uint32_t i = (a / 100) << 1;
const uint32_t j = (a % 100) << 1;
*buffer++ = gDigitsLut[i];
*buffer++ = gDigitsLut[i + 1];
*buffer++ = gDigitsLut[j];
*buffer++ = gDigitsLut[j + 1];
}
const uint32_t v0 = static_cast<uint32_t>(value / 100000000);
const uint32_t v1 = static_cast<uint32_t>(value % 100000000);
const __m128i a0 = Convert8DigitsSSE2(v0);
const __m128i a1 = Convert8DigitsSSE2(v1);
// Convert to bytes, add '0'
const __m128i va = _mm_add_epi8(_mm_packus_epi16(a0, a1), reinterpret_cast<const __m128i*>(kAsciiZero)[0]);
_mm_storeu_si128(reinterpret_cast<__m128i*>(buffer), va);
// buffer[16] = '\0';
return &buffer[16];
}
}
#else // Generic Non-x86 case
// Original name: u32toa_branchlut
char *ToString(uint32_t value, char* buffer) {
if (value < 10000) {
const uint32_t d1 = (value / 100) << 1;
const uint32_t d2 = (value % 100) << 1;
if (value >= 1000)
*buffer++ = gDigitsLut[d1];
if (value >= 100)
*buffer++ = gDigitsLut[d1 + 1];
if (value >= 10)
*buffer++ = gDigitsLut[d2];
*buffer++ = gDigitsLut[d2 + 1];
}
else if (value < 100000000) {
// value = bbbbcccc
const uint32_t b = value / 10000;
const uint32_t c = value % 10000;
const uint32_t d1 = (b / 100) << 1;
const uint32_t d2 = (b % 100) << 1;
const uint32_t d3 = (c / 100) << 1;
const uint32_t d4 = (c % 100) << 1;
if (value >= 10000000)
*buffer++ = gDigitsLut[d1];
if (value >= 1000000)
*buffer++ = gDigitsLut[d1 + 1];
if (value >= 100000)
*buffer++ = gDigitsLut[d2];
*buffer++ = gDigitsLut[d2 + 1];
*buffer++ = gDigitsLut[d3];
*buffer++ = gDigitsLut[d3 + 1];
*buffer++ = gDigitsLut[d4];
*buffer++ = gDigitsLut[d4 + 1];
}
else {
// value = aabbbbcccc in decimal
const uint32_t a = value / 100000000; // 1 to 42
value %= 100000000;
if (a >= 10) {
const unsigned i = a << 1;
*buffer++ = gDigitsLut[i];
*buffer++ = gDigitsLut[i + 1];
}
else
*buffer++ = '0' + static_cast<char>(a);
const uint32_t b = value / 10000; // 0 to 9999
const uint32_t c = value % 10000; // 0 to 9999
const uint32_t d1 = (b / 100) << 1;
const uint32_t d2 = (b % 100) << 1;
const uint32_t d3 = (c / 100) << 1;
const uint32_t d4 = (c % 100) << 1;
*buffer++ = gDigitsLut[d1];
*buffer++ = gDigitsLut[d1 + 1];
*buffer++ = gDigitsLut[d2];
*buffer++ = gDigitsLut[d2 + 1];
*buffer++ = gDigitsLut[d3];
*buffer++ = gDigitsLut[d3 + 1];
*buffer++ = gDigitsLut[d4];
*buffer++ = gDigitsLut[d4 + 1];
}
return buffer; //*buffer++ = '\0';
}
// Original name: u64toa_branchlut
char *ToString(uint64_t value, char* buffer) {
if (value < 100000000) {
uint32_t v = static_cast<uint32_t>(value);
if (v < 10000) {
const uint32_t d1 = (v / 100) << 1;
const uint32_t d2 = (v % 100) << 1;
if (v >= 1000)
*buffer++ = gDigitsLut[d1];
if (v >= 100)
*buffer++ = gDigitsLut[d1 + 1];
if (v >= 10)
*buffer++ = gDigitsLut[d2];
*buffer++ = gDigitsLut[d2 + 1];
}
else {
// value = bbbbcccc
const uint32_t b = v / 10000;
const uint32_t c = v % 10000;
const uint32_t d1 = (b / 100) << 1;
const uint32_t d2 = (b % 100) << 1;
const uint32_t d3 = (c / 100) << 1;
const uint32_t d4 = (c % 100) << 1;
if (value >= 10000000)
*buffer++ = gDigitsLut[d1];
if (value >= 1000000)
*buffer++ = gDigitsLut[d1 + 1];
if (value >= 100000)
*buffer++ = gDigitsLut[d2];
*buffer++ = gDigitsLut[d2 + 1];
*buffer++ = gDigitsLut[d3];
*buffer++ = gDigitsLut[d3 + 1];
*buffer++ = gDigitsLut[d4];
*buffer++ = gDigitsLut[d4 + 1];
}
}
else if (value < 10000000000000000) {
const uint32_t v0 = static_cast<uint32_t>(value / 100000000);
const uint32_t v1 = static_cast<uint32_t>(value % 100000000);
const uint32_t b0 = v0 / 10000;
const uint32_t c0 = v0 % 10000;
const uint32_t d1 = (b0 / 100) << 1;
const uint32_t d2 = (b0 % 100) << 1;
const uint32_t d3 = (c0 / 100) << 1;
const uint32_t d4 = (c0 % 100) << 1;
const uint32_t b1 = v1 / 10000;
const uint32_t c1 = v1 % 10000;
const uint32_t d5 = (b1 / 100) << 1;
const uint32_t d6 = (b1 % 100) << 1;
const uint32_t d7 = (c1 / 100) << 1;
const uint32_t d8 = (c1 % 100) << 1;
if (value >= 1000000000000000)
*buffer++ = gDigitsLut[d1];
if (value >= 100000000000000)
*buffer++ = gDigitsLut[d1 + 1];
if (value >= 10000000000000)
*buffer++ = gDigitsLut[d2];
if (value >= 1000000000000)
*buffer++ = gDigitsLut[d2 + 1];
if (value >= 100000000000)
*buffer++ = gDigitsLut[d3];
if (value >= 10000000000)
*buffer++ = gDigitsLut[d3 + 1];
if (value >= 1000000000)
*buffer++ = gDigitsLut[d4];
if (value >= 100000000)
*buffer++ = gDigitsLut[d4 + 1];
*buffer++ = gDigitsLut[d5];
*buffer++ = gDigitsLut[d5 + 1];
*buffer++ = gDigitsLut[d6];
*buffer++ = gDigitsLut[d6 + 1];
*buffer++ = gDigitsLut[d7];
*buffer++ = gDigitsLut[d7 + 1];
*buffer++ = gDigitsLut[d8];
*buffer++ = gDigitsLut[d8 + 1];
}
else {
const uint32_t a = static_cast<uint32_t>(value / 10000000000000000); // 1 to 1844
value %= 10000000000000000;
if (a < 10)
*buffer++ = '0' + static_cast<char>(a);
else if (a < 100) {
const uint32_t i = a << 1;
*buffer++ = gDigitsLut[i];
*buffer++ = gDigitsLut[i + 1];
}
else if (a < 1000) {
*buffer++ = '0' + static_cast<char>(a / 100);
const uint32_t i = (a % 100) << 1;
*buffer++ = gDigitsLut[i];
*buffer++ = gDigitsLut[i + 1];
}
else {
const uint32_t i = (a / 100) << 1;
const uint32_t j = (a % 100) << 1;
*buffer++ = gDigitsLut[i];
*buffer++ = gDigitsLut[i + 1];
*buffer++ = gDigitsLut[j];
*buffer++ = gDigitsLut[j + 1];
}
const uint32_t v0 = static_cast<uint32_t>(value / 100000000);
const uint32_t v1 = static_cast<uint32_t>(value % 100000000);
const uint32_t b0 = v0 / 10000;
const uint32_t c0 = v0 % 10000;
const uint32_t d1 = (b0 / 100) << 1;
const uint32_t d2 = (b0 % 100) << 1;
const uint32_t d3 = (c0 / 100) << 1;
const uint32_t d4 = (c0 % 100) << 1;
const uint32_t b1 = v1 / 10000;
const uint32_t c1 = v1 % 10000;
const uint32_t d5 = (b1 / 100) << 1;
const uint32_t d6 = (b1 % 100) << 1;
const uint32_t d7 = (c1 / 100) << 1;
const uint32_t d8 = (c1 % 100) << 1;
*buffer++ = gDigitsLut[d1];
*buffer++ = gDigitsLut[d1 + 1];
*buffer++ = gDigitsLut[d2];
*buffer++ = gDigitsLut[d2 + 1];
*buffer++ = gDigitsLut[d3];
*buffer++ = gDigitsLut[d3 + 1];
*buffer++ = gDigitsLut[d4];
*buffer++ = gDigitsLut[d4 + 1];
*buffer++ = gDigitsLut[d5];
*buffer++ = gDigitsLut[d5 + 1];
*buffer++ = gDigitsLut[d6];
*buffer++ = gDigitsLut[d6 + 1];
*buffer++ = gDigitsLut[d7];
*buffer++ = gDigitsLut[d7 + 1];
*buffer++ = gDigitsLut[d8];
*buffer++ = gDigitsLut[d8 + 1];
}
return buffer;
}
#endif // End of architecture if statement.
// Signed wrappers. The negation is done on the unsigned version because
// doing so has defined behavior for INT_MIN.
char *ToString(int32_t value, char *to) {
uint32_t un = static_cast<uint32_t>(value);
if (value < 0) {
*to++ = '-';
un = -un;
}
return ToString(un, to);
}
char *ToString(int64_t value, char *to) {
uint64_t un = static_cast<uint64_t>(value);
if (value < 0) {
*to++ = '-';
un = -un;
}
return ToString(un, to);
}
// No optimization for this case yet.
char *ToString(int16_t value, char *to) {
return ToString((int32_t)value, to);
}
char *ToString(uint16_t value, char *to) {
return ToString((uint32_t)value, to);
}
} // namespace util
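The two-digit lookup table above is the heart of both paths: indexing by (n % 100) << 1 yields two ASCII digits at once, halving the divisions per digit. A standalone sketch of just that step (the table is rebuilt locally here, purely for illustration):

#include <cstdint>
#include <cstdio>

int main() {
  // Entry 2*k is the tens digit of k, entry 2*k + 1 the ones digit.
  char lut[200];
  for (int k = 0; k < 100; ++k) {
    lut[2 * k] = '0' + k / 10;
    lut[2 * k + 1] = '0' + k % 10;
  }
  uint32_t value = 4273;
  const uint32_t d1 = (value / 100) << 1;  // selects the pair "42"
  const uint32_t d2 = (value % 100) << 1;  // selects the pair "73"
  char out[4] = { lut[d1], lut[d1 + 1], lut[d2], lut[d2 + 1] };
  std::printf("%.4s\n", out);  // prints 4273
  return 0;
}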

56
util/integer_to_string.hh Normal file
View File

@ -0,0 +1,56 @@
#ifndef UTIL_INTEGER_TO_STRING_H
#define UTIL_INTEGER_TO_STRING_H
#include <cstddef>
#include <stdint.h>
namespace util {
/* These functions convert integers to strings and return the end pointer.
*/
char *ToString(uint32_t value, char *to);
char *ToString(uint64_t value, char *to);
// Implemented as wrappers to above
char *ToString(int32_t value, char *to);
char *ToString(int64_t value, char *to);
// Calls the 32-bit versions for now.
char *ToString(uint16_t value, char *to);
char *ToString(int16_t value, char *to);
inline char *ToString(bool value, char *to) {
*to++ = '0' + value;
return to;
}
// How many bytes to reserve in the buffer for these strings:
// g++ 4.9.1 doesn't work with this:
// static const std::size_t kBytes = 5;
// So use enum.
template <class T> struct ToStringBuf;
template <> struct ToStringBuf<bool> {
enum { kBytes = 1 };
};
template <> struct ToStringBuf<uint16_t> {
enum { kBytes = 5 };
};
template <> struct ToStringBuf<int16_t> {
enum { kBytes = 6 };
};
template <> struct ToStringBuf<uint32_t> {
enum { kBytes = 10 };
};
template <> struct ToStringBuf<int32_t> {
enum { kBytes = 11 };
};
template <> struct ToStringBuf<uint64_t> {
enum { kBytes = 20 };
};
template <> struct ToStringBuf<int64_t> {
// Not a typo. 2^63 has 19 digits.
enum { kBytes = 20 };
};
} // namespace util
#endif // UTIL_INTEGER_TO_STRING_H
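The integer overloads follow the same contract as the float ones; a sketch under the same assumption that the header above is on the include path:

#include "util/integer_to_string.hh"
#include <iostream>
#include <string>

int main() {
  char buf[util::ToStringBuf<uint64_t>::kBytes];  // 20 bytes, no terminator written
  char *end = util::ToString(18446744073709551615ULL, buf);
  std::cout << std::string(buf, end - buf) << std::endl;
  return 0;
}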

View File

@ -0,0 +1,65 @@
#define BOOST_LEXICAL_CAST_ASSUME_C_LOCALE
#include "util/integer_to_string.hh"
#include "util/string_piece.hh"
#define BOOST_TEST_MODULE IntegerToStringTest
#include <boost/test/unit_test.hpp>
#include <boost/lexical_cast.hpp>
#include <limits>
namespace util {
namespace {
template <class T> void TestValue(const T value) {
char buf[ToStringBuf<T>::kBytes];
StringPiece result(buf, ToString(value, buf) - buf);
BOOST_REQUIRE_GE(static_cast<std::size_t>(ToStringBuf<T>::kBytes), result.size());
BOOST_CHECK_EQUAL(boost::lexical_cast<std::string>(value), result);
}
template <class T> void TestCorners() {
TestValue(std::numeric_limits<T>::min());
TestValue(std::numeric_limits<T>::max());
TestValue(static_cast<T>(0));
TestValue(static_cast<T>(-1));
TestValue(static_cast<T>(1));
}
BOOST_AUTO_TEST_CASE(Corners) {
TestCorners<uint16_t>();
TestCorners<uint32_t>();
TestCorners<uint64_t>();
TestCorners<int16_t>();
TestCorners<int32_t>();
TestCorners<int64_t>();
}
template <class T> void TestAll() {
for (T i = std::numeric_limits<T>::min(); i < std::numeric_limits<T>::max(); ++i) {
TestValue(i);
}
TestValue(std::numeric_limits<T>::max());
}
BOOST_AUTO_TEST_CASE(Short) {
TestAll<uint16_t>();
TestAll<int16_t>();
}
template <class T> void Test10s() {
for (T i = 1; i < std::numeric_limits<T>::max() / 10; i *= 10) {
TestValue(i);
TestValue(i - 1);
TestValue(i + 1);
}
}
BOOST_AUTO_TEST_CASE(Tens) {
Test10s<uint64_t>();
Test10s<int64_t>();
Test10s<uint32_t>();
Test10s<int32_t>();
}
}} // namespaces

View File

@ -88,7 +88,7 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
#ifdef DEBUG
assert(initialized_);
#endif
for (MutableIterator i = Ideal(t);;) {
for (MutableIterator i = Ideal(t.GetKey());;) {
Key got(i->GetKey());
if (equal_(got, t.GetKey())) { out = i; return true; }
if (equal_(got, invalid_)) {
@ -108,7 +108,7 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
#ifdef DEBUG
assert(initialized_);
#endif
for (MutableIterator i(begin_ + (hash_(key) % buckets_));;) {
for (MutableIterator i(Ideal(key));;) {
Key got(i->GetKey());
if (equal_(got, key)) { out = i; return true; }
if (equal_(got, invalid_)) return false;
@ -118,7 +118,7 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
// Like UnsafeMutableFind, but the key must be there.
template <class Key> MutableIterator UnsafeMutableMustFind(const Key key) {
for (MutableIterator i(begin_ + (hash_(key) % buckets_));;) {
for (MutableIterator i(Ideal(key));;) {
Key got(i->GetKey());
if (equal_(got, key)) { return i; }
assert(!equal_(got, invalid_));
@ -131,7 +131,7 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
#ifdef DEBUG
assert(initialized_);
#endif
for (ConstIterator i(begin_ + (hash_(key) % buckets_));;) {
for (ConstIterator i(Ideal(key));;) {
Key got(i->GetKey());
if (equal_(got, key)) { out = i; return true; }
if (equal_(got, invalid_)) return false;
@ -141,7 +141,7 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
// Like Find but we're sure it must be there.
template <class Key> ConstIterator MustFind(const Key key) const {
for (ConstIterator i(begin_ + (hash_(key) % buckets_));;) {
for (ConstIterator i(Ideal(key));;) {
Key got(i->GetKey());
if (equal_(got, key)) { return i; }
assert(!equal_(got, invalid_));
@ -213,7 +213,7 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
MutableIterator i;
// Beginning can be wrap-arounds.
for (i = begin_; !equal_(i->GetKey(), invalid_); ++i) {
MutableIterator ideal = Ideal(*i);
MutableIterator ideal = Ideal(i->GetKey());
UTIL_THROW_IF(ideal > i && ideal <= last, Exception, "Inconsistency at position " << (i - begin_) << " should be at " << (ideal - begin_));
}
MutableIterator pre_gap = i;
@ -222,7 +222,7 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
pre_gap = i;
continue;
}
MutableIterator ideal = Ideal(*i);
MutableIterator ideal = Ideal(i->GetKey());
UTIL_THROW_IF(ideal > i || ideal <= pre_gap, Exception, "Inconsistency at position " << (i - begin_) << " with ideal " << (ideal - begin_));
}
}
@ -230,12 +230,15 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
private:
friend class AutoProbing<Entry, Hash, Equal>;
template <class T> MutableIterator Ideal(const T &t) {
return begin_ + (hash_(t.GetKey()) % buckets_);
MutableIterator Ideal(const Key key) {
return begin_ + (hash_(key) % buckets_);
}
ConstIterator Ideal(const Key key) const {
return begin_ + (hash_(key) % buckets_);
}
template <class T> MutableIterator UncheckedInsert(const T &t) {
for (MutableIterator i(Ideal(t));;) {
for (MutableIterator i(Ideal(t.GetKey()));;) {
if (equal_(i->GetKey(), invalid_)) { *i = t; return i; }
if (++i == end_) { i = begin_; }
}
@ -277,6 +280,7 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
// Assumes that the key is unique. Multiple insertions won't cause a failure, just inconsistent lookup.
template <class T> MutableIterator Insert(const T &t) {
++backend_.entries_;
DoubleIfNeeded();
return backend_.UncheckedInsert(t);
}
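Taken together, these hunks re-key Ideal() on the key itself, so Find() and friends can hash a bare key without building an entry, and AutoProbing now grows before the insert lands. A minimal sketch of the fixed-size table using the same Entry convention as the benchmark below (sizes illustrative):

#include "util/probing_hash_table.hh"
#include "util/scoped.hh"
#include <cassert>
#include <stdint.h>

struct Entry {
  typedef uint64_t Key;
  Key key;
  Key GetKey() const { return key; }
};

int main() {
  typedef util::ProbingHashTable<Entry, util::IdentityHash> Table;
  std::size_t size = Table::Size(100 /* entries */, 1.5 /* multiplier */);
  util::scoped_malloc backing(util::CallocOrThrow(size));
  Table table(backing.get(), size);
  Entry e;
  e.key = 42;
  table.Insert(e);                     // hashes via Ideal(e.GetKey())
  Table::ConstIterator it;
  assert(table.Find(e.GetKey(), it));  // hashes via Ideal(key); no Entry needed
  assert(it->GetKey() == 42);
  return 0;
}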

View File

@ -0,0 +1,49 @@
#include "util/probing_hash_table.hh"
#include "util/scoped.hh"
#include "util/usage.hh"
#include <boost/random/mersenne_twister.hpp>
#include <boost/random/uniform_int_distribution.hpp>
#include <iostream>
#include <limits>
#include <stdint.h>
namespace util {
namespace {
struct Entry {
typedef uint64_t Key;
Key key;
Key GetKey() const { return key; }
};
typedef util::ProbingHashTable<Entry, util::IdentityHash> Table;
void Test(uint64_t entries, uint64_t lookups, float multiplier = 1.5) {
std::size_t size = Table::Size(entries, multiplier);
scoped_malloc backing(util::CallocOrThrow(size));
Table table(backing.get(), size);
boost::random::mt19937 gen;
boost::random::uniform_int_distribution<uint64_t> dist(std::numeric_limits<uint64_t>::min(), std::numeric_limits<uint64_t>::max());
double start = UserTime();
for (uint64_t i = 0; i < entries; ++i) {
Entry entry;
entry.key = dist(gen);
table.Insert(entry);
}
double inserted = UserTime();
bool meaningless = true;
for (uint64_t i = 0; i < lookups; ++i) {
Table::ConstIterator it;
meaningless ^= table.Find(dist(gen), it);
}
std::cout << meaningless << ' ' << entries << ' ' << multiplier << ' ' << (inserted - start) << ' ' << (UserTime() - inserted) / static_cast<double>(lookups) << std::endl;
}
} // namespace
} // namespace util
int main() {
for (uint64_t i = 1; i <= 10000000ULL; i *= 10) {
util::Test(i, 10000000);
}
}

View File

@ -7,6 +7,8 @@
namespace util {
// TODO: if we're really under memory pressure, don't allocate memory to
// display the error.
MallocException::MallocException(std::size_t requested) throw() {
*this << "for " << requested << " bytes ";
}
@ -16,10 +18,6 @@ MallocException::~MallocException() throw() {}
namespace {
void *InspectAddr(void *addr, std::size_t requested, const char *func_name) {
UTIL_THROW_IF_ARG(!addr && requested, MallocException, (requested), "in " << func_name);
// These routines are often used for large chunks of memory where huge pages help.
#if MADV_HUGEPAGE
madvise(addr, requested, MADV_HUGEPAGE);
#endif
return addr;
}
} // namespace
@ -36,4 +34,10 @@ void scoped_malloc::call_realloc(std::size_t requested) {
p_ = InspectAddr(std::realloc(p_, requested), requested, "realloc");
}
void AdviseHugePages(const void *addr, std::size_t size) {
#if MADV_HUGEPAGE
madvise((void*)addr, size, MADV_HUGEPAGE);
#endif
}
} // namespace util
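The point of the change: the old InspectAddr path ran madvise on every allocation, while AdviseHugePages is applied explicitly, and only to large mappings, right after they are created. A sketch of that call pattern, assuming MapAnonymous and scoped_memory from util/mmap.hh as used elsewhere in this commit:

#include "util/mmap.hh"
#include "util/scoped.hh"
#include <cstddef>

int main() {
  const std::size_t size = std::size_t(1) << 26;  // 64 MiB: large enough to benefit
  util::scoped_memory memory;
  util::MapAnonymous(size, memory);
  // One hint for the whole region; a no-op where MADV_HUGEPAGE is absent.
  util::AdviseHugePages(memory.get(), size);
  return 0;
}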

View File

@ -104,6 +104,8 @@ template <class T> class scoped_ptr : public scoped<T, scoped_delete_forward> {
explicit scoped_ptr(T *p = NULL) : scoped<T, scoped_delete_forward>(p) {}
};
void AdviseHugePages(const void *addr, std::size_t size);
} // namespace util
#endif // UTIL_SCOPED_H

View File

@ -4,9 +4,10 @@
# timer-link = ;
#}
fakelib stream : chain.cc io.cc line_input.cc multi_progress.cc ..//kenutil /top//boost_thread : : : <library>/top//boost_thread ;
fakelib stream : chain.cc rewindable_stream.cc io.cc line_input.cc multi_progress.cc ..//kenutil /top//boost_thread : : : <library>/top//boost_thread ;
import testing ;
unit-test io_test : io_test.cc stream /top//boost_unit_test_framework ;
unit-test stream_test : stream_test.cc stream /top//boost_unit_test_framework ;
unit-test rewindable_stream_test : rewindable_stream_test.cc stream /top//boost_unit_test_framework ;
unit-test sort_test : sort_test.cc stream /top//boost_unit_test_framework ;

View File

@ -72,6 +72,7 @@ class Block {
private:
friend class Link;
friend class RewindableStream;
/**
* Points this block's memory at NULL.

View File

@ -23,6 +23,7 @@ class ChainConfigException : public Exception {
};
class Chain;
class RewindableStream;
/**
* Encapsulates a @ref PCQueue "producer queue" and a @ref PCQueue "consumer queue" within a @ref Chain "chain".
@ -35,6 +36,7 @@ class ChainPosition {
private:
friend class Chain;
friend class Link;
friend class RewindableStream;
ChainPosition(PCQueue<Block> &in, PCQueue<Block> &out, Chain *chain, MultiProgress &progress)
: in_(&in), out_(&out), chain_(chain), progress_(progress.Add()) {}

View File

@ -70,8 +70,8 @@ class FileBuffer {
return PWriteAndRecycle(file_.get());
}
PRead Source() const {
return PRead(file_.get());
PRead Source(bool discard = false) {
return PRead(discard ? file_.release() : file_.get(), discard);
}
uint64_t Size() const {

View File

@ -0,0 +1,117 @@
#include "util/stream/rewindable_stream.hh"
#include "util/pcqueue.hh"
namespace util {
namespace stream {
RewindableStream::RewindableStream()
: current_(NULL), in_(NULL), out_(NULL), poisoned_(true) {
// nothing
}
void RewindableStream::Init(const ChainPosition &position) {
UTIL_THROW_IF2(in_, "RewindableStream::Init twice");
in_ = position.in_;
out_ = position.out_;
poisoned_ = false;
progress_ = position.progress_;
entry_size_ = position.GetChain().EntrySize();
block_size_ = position.GetChain().BlockSize();
FetchBlock();
current_bl_ = &second_bl_;
current_ = static_cast<uint8_t*>(current_bl_->Get());
end_ = current_ + current_bl_->ValidSize();
}
const void *RewindableStream::Get() const {
return current_;
}
void *RewindableStream::Get() {
return current_;
}
RewindableStream &RewindableStream::operator++() {
assert(*this);
assert(current_ < end_);
current_ += entry_size_;
if (current_ == end_) {
// two cases: either we need to fetch the next block, or we've already
// fetched it before. We can check this by looking at the current_bl_
// pointer: if it's at the second_bl_, we need to flush and fetch a new
// block. Otherwise, we can just move over to the second block.
if (current_bl_ == &second_bl_) {
if (first_bl_) {
out_->Produce(first_bl_);
progress_ += first_bl_.ValidSize();
}
first_bl_ = second_bl_;
FetchBlock();
}
current_bl_ = &second_bl_;
current_ = static_cast<uint8_t *>(second_bl_.Get());
end_ = current_ + second_bl_.ValidSize();
}
if (!*current_bl_) {
if (current_bl_ == &second_bl_ && first_bl_) {
out_->Produce(first_bl_);
progress_ += first_bl_.ValidSize();
}
out_->Produce(*current_bl_);
poisoned_ = true;
}
return *this;
}
void RewindableStream::FetchBlock() {
// The loop is needed since it is *feasible* that we're given valid but
// zero-sized blocks.
do {
in_->Consume(second_bl_);
} while (second_bl_ && second_bl_.ValidSize() == 0);
}
void RewindableStream::Mark() {
marked_ = current_;
}
void RewindableStream::Rewind() {
if (marked_ >= first_bl_.Get() && marked_ < first_bl_.ValidEnd()) {
current_bl_ = &first_bl_;
current_ = marked_;
// Restore end_ to match the block we rewound into.
end_ = static_cast<uint8_t *>(first_bl_.Get()) + first_bl_.ValidSize();
} else if (marked_ >= second_bl_.Get() && marked_ < second_bl_.ValidEnd()) {
current_bl_ = &second_bl_;
current_ = marked_;
end_ = static_cast<uint8_t *>(second_bl_.Get()) + second_bl_.ValidSize();
} else { UTIL_THROW2("RewindableStream rewound too far"); }
}
void RewindableStream::Poison() {
assert(!poisoned_);
// Three things: if we have a buffered first block, we need to produce it
// first. Then, produce the partial "current" block, and then send the
// poison down the chain
// if we still have a buffered first block, produce it first
if (current_bl_ == &second_bl_ && first_bl_) {
out_->Produce(first_bl_);
progress_ += first_bl_.ValidSize();
}
// send our partial block
current_bl_->SetValidSize(current_ - static_cast<uint8_t *>(current_bl_->Get()));
out_->Produce(*current_bl_);
progress_ += current_bl_->ValidSize();
// send down the poison
current_bl_->SetToPoison();
out_->Produce(*current_bl_);
poisoned_ = true;
}
} // namespace stream
} // namespace util

View File

@ -0,0 +1,108 @@
#ifndef UTIL_STREAM_REWINDABLE_STREAM_H
#define UTIL_STREAM_REWINDABLE_STREAM_H
#include "util/stream/chain.hh"
#include <boost/noncopyable.hpp>
namespace util {
namespace stream {
/**
* A RewindableStream is like a Stream (but one that is only used for
* creating input at the start of a chain) except that it can be rewound to
* be able to re-write a part of the stream before it is sent. Rewinding
* has a limit of 2 * block_size_ - 1 in distance (it does *not* buffer an
* entire stream into memory, only a maximum of 2 * block_size_).
*/
class RewindableStream : boost::noncopyable {
public:
/**
* Creates an uninitialized RewindableStream. You **must** call Init()
* on it later!
*/
RewindableStream();
/**
* Initializes an existing RewindableStream at a specific position in
* a Chain.
*
* @param position The position in the chain to get input from and
* produce output on
*/
void Init(const ChainPosition &position);
/**
* Constructs a RewindableStream at a specific position in a Chain all
* in one step.
*
* Equivalent to RewindableStream a(); a.Init(....);
*/
explicit RewindableStream(const ChainPosition &position);
/**
* Gets the record at the current stream position. Const version.
*/
const void *Get() const;
/**
* Gets the record at the current stream position.
*/
void *Get();
operator bool() const { return current_; }
bool operator!() const { return !(*this); }
/**
* Marks the current position in the stream to be rewound to later.
* Note that you can only rewind back as far as 2 * block_size_ - 1!
*/
void Mark();
/**
* Rewinds the stream back to the marked position. This will throw an
* exception if the marked position is too far away.
*/
void Rewind();
/**
* Moves the stream forward to the next record. This internally may
* buffer a block for the purposes of rewinding.
*/
RewindableStream& operator++();
/**
* Poisons the stream. This sends any buffered blocks down the chain
* and sends a poison block as well (sending at most 2 non-poison and 1
* poison block).
*/
void Poison();
private:
void FetchBlock();
std::size_t entry_size_;
std::size_t block_size_;
uint8_t *marked_, *current_, *end_;
Block first_bl_;
Block second_bl_;
Block* current_bl_;
PCQueue<Block> *in_, *out_;
bool poisoned_;
WorkerProgress progress_;
};
inline Chain &operator>>(Chain &chain, RewindableStream &stream) {
stream.Init(chain.Add());
return chain;
}
} // namespace stream
} // namespace util
#endif // UTIL_STREAM_REWINDABLE_STREAM_H

View File

@ -0,0 +1,41 @@
#include "util/stream/io.hh"
#include "util/stream/rewindable_stream.hh"
#include "util/file.hh"
#define BOOST_TEST_MODULE RewindableStreamTest
#include <boost/test/unit_test.hpp>
namespace util {
namespace stream {
namespace {
BOOST_AUTO_TEST_CASE(RewindableStreamTest) {
scoped_fd in(MakeTemp("io_test_temp"));
for (uint64_t i = 0; i < 100000; ++i) {
WriteOrThrow(in.get(), &i, sizeof(uint64_t));
}
SeekOrThrow(in.get(), 0);
ChainConfig config;
config.entry_size = 8;
config.total_memory = 100;
config.block_count = 6;
RewindableStream s;
Chain chain(config);
chain >> Read(in.get()) >> s >> kRecycle;
uint64_t i = 0;
for (; s; ++s, ++i) {
BOOST_CHECK_EQUAL(i, *static_cast<const uint64_t*>(s.Get()));
if (100000UL - i == 2)
s.Mark();
}
BOOST_CHECK_EQUAL(100000ULL, i);
s.Rewind();
BOOST_CHECK_EQUAL(100000ULL - 2, *static_cast<const uint64_t*>(s.Get()));
}
} // namespace
} // namespace stream
} // namespace util

View File

@ -25,6 +25,7 @@
#include "util/stream/timer.hh"
#include "util/file.hh"
#include "util/fixed_array.hh"
#include "util/scoped.hh"
#include "util/sized_iterator.hh"
@ -544,6 +545,54 @@ template <class Compare, class Combine> uint64_t BlockingSort(Chain &chain, cons
return size;
}
/**
* Represents an @ref util::FixedArray "array" capable of storing @ref util::stream::Sort "Sort" objects.
*
* In the anticipated use case, an instance of this class will maintain one @ref util::stream::Sort "Sort" object
* for each n-gram order (ranging from 1 up to the maximum n-gram order being processed).
* Used in this manner, it enables the n-grams of each n-gram order to be sorted in parallel.
*
* @tparam Compare An @ref Comparator "ngram comparator" to use during sorting.
*/
template <class Compare, class Combine = NeverCombine> class Sorts : public FixedArray<Sort<Compare, Combine> > {
private:
typedef Sort<Compare, Combine> S;
typedef FixedArray<S> P;
public:
/**
* Constructs, but does not initialize.
*
* @ref util::FixedArray::Init() "Init" must be called before use.
*
* @see util::FixedArray::Init()
*/
Sorts() {}
/**
* Constructs an @ref util::FixedArray "array" capable of storing a fixed number of @ref util::stream::Sort "Sort" objects.
*
* @param number The maximum number of @ref util::stream::Sort "sorters" that can be held by this @ref util::FixedArray "array"
* @see util::FixedArray::FixedArray()
*/
explicit Sorts(std::size_t number) : FixedArray<Sort<Compare, Combine> >(number) {}
/**
* Constructs a new @ref util::stream::Sort "Sort" object which is stored in this @ref util::FixedArray "array".
*
* The new @ref util::stream::Sort "Sort" object is constructed using the provided @ref util::stream::SortConfig "SortConfig" and @ref Comparator "ngram comparator";
* once constructed, a new worker @ref util::stream::Thread "thread" (owned by the @ref util::stream::Chain "chain") will sort the n-gram data stored
* in the @ref util::stream::Block "blocks" of the provided @ref util::stream::Chain "chain".
*
* @see util::stream::Sort::Sort()
* @see util::stream::Chain::operator>>()
*/
void push_back(util::stream::Chain &chain, const util::stream::SortConfig &config, const Compare &compare = Compare(), const Combine &combine = Combine()) {
new (P::end()) S(chain, config, compare, combine); // use "placement new" syntax to initialize S in an already-allocated memory location
P::Constructed();
}
};
} // namespace stream
} // namespace util
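Under the documented use case — one sorter per order — the class is driven roughly like this. This is a fragment, not a full program: the chains and SortConfig come from the surrounding lmplz pipeline, and SuffixOrder stands in for that pipeline's comparator, named here only for illustration.

// One Sort per n-gram order, each attached to that order's chain.
util::stream::Sorts<SuffixOrder> sorts(max_order);
for (std::size_t i = 0; i < max_order; ++i) {
  // push_back placement-constructs the Sort in the FixedArray's storage.
  sorts.push_back(chains[i], sort_config, SuffixOrder(i + 1));
}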

View File

@ -10,13 +10,14 @@
#if defined(_WIN32) || defined(_WIN64)
#include <windows.h>
#else
#include <unistd.h>
#endif
#include <boost/filesystem.hpp>
#include <boost/noncopyable.hpp>
#include "util/exception.hh"
#include "util/unistd.hh"
namespace util
{

View File

@ -1,22 +0,0 @@
#ifndef UTIL_UNISTD_H
#define UTIL_UNISTD_H
#if (defined(_WIN32) || defined(_WIN64)) && !defined(__MINGW32__)
// Windows doesn't define <unistd.h>
//
// So we define what we need here instead:
//
#define STDIN_FILENO=0
#define STDOUT_FILENO=1
#else // Huzzah for POSIX!
#include <unistd.h>
#endif
#endif // UTIL_UNISTD_H

View File

@ -135,6 +135,16 @@ double WallTime() {
return Subtract(GetWall(), kRecordStart.Started());
}
double UserTime() {
#if !defined(_WIN32) && !defined(_WIN64)
struct rusage usage;
if (getrusage(RUSAGE_SELF, &usage))
return 0.0;
return DoubleSec(usage.ru_utime);
#endif
return 0.0;
}
void PrintUsage(std::ostream &out) {
#if !defined(_WIN32) && !defined(_WIN64)
// Linux doesn't set memory usage in getrusage :-(
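UserTime() is what the probing hash table benchmark earlier in this commit samples around its insert and lookup loops; the pattern is simply a difference of two readings. Note that the implementation above falls through to 0.0 on Windows, so the measurement is POSIX-only. A sketch:

#include "util/usage.hh"
#include <iostream>

int main() {
  double start = util::UserTime();
  volatile double sink = 0.0;
  for (int i = 0; i < 10000000; ++i) sink += i * 0.5;  // stand-in CPU-bound work
  std::cout << "user CPU seconds: " << (util::UserTime() - start) << std::endl;
  return 0;
}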

View File

@ -9,6 +9,8 @@ namespace util {
// Time in seconds since process started. Zero on unsupported platforms.
double WallTime();
double UserTime();
void PrintUsage(std::ostream &to);
// Determine how much physical memory there is. Return 0 on failure.
@ -16,5 +18,6 @@ uint64_t GuessPhysicalMemory();
// Parse a size like unix sort. Sadly, this means the default multiplier is K.
uint64_t ParseSize(const std::string &arg);
} // namespace util
#endif // UTIL_USAGE_H