diff --git a/lm/binary_format.cc b/lm/binary_format.cc index 4ad893d44..2b34a778a 100644 --- a/lm/binary_format.cc +++ b/lm/binary_format.cc @@ -170,6 +170,7 @@ void *BinaryFormat::SetupJustVocab(std::size_t memory_size, uint8_t order) { if (!write_mmap_) { header_size_ = 0; util::MapAnonymous(memory_size, memory_vocab_); + util::AdviseHugePages(memory_vocab_.get(), memory_size); return reinterpret_cast(memory_vocab_.get()); } header_size_ = TotalHeaderSize(order); @@ -189,6 +190,7 @@ void *BinaryFormat::SetupJustVocab(std::size_t memory_size, uint8_t order) { break; } strncpy(reinterpret_cast(vocab_base), kMagicIncomplete, header_size_); + util::AdviseHugePages(vocab_base, total); return reinterpret_cast(vocab_base) + header_size_; } @@ -201,6 +203,7 @@ void *BinaryFormat::GrowForSearch(std::size_t memory_size, std::size_t vocab_pad util::MapAnonymous(memory_size, memory_search_); assert(header_size_ == 0 || write_mmap_); vocab_base = reinterpret_cast(memory_vocab_.get()) + header_size_; + util::AdviseHugePages(memory_search_.get(), memory_size); return reinterpret_cast(memory_search_.get()); } @@ -214,6 +217,7 @@ void *BinaryFormat::GrowForSearch(std::size_t memory_size, std::size_t vocab_pad util::ResizeOrThrow(file_.get(), new_size); void *ret; MapFile(vocab_base, ret); + util::AdviseHugePages(ret, new_size); return ret; } diff --git a/lm/builder/Jamfile b/lm/builder/Jamfile index 1e0e18b5f..329a8e076 100644 --- a/lm/builder/Jamfile +++ b/lm/builder/Jamfile @@ -1,5 +1,5 @@ -fakelib builder : [ glob *.cc : *test.cc *main.cc ] - ../../util//kenutil ../../util/stream//stream ../../util/double-conversion//double-conversion ..//kenlm +fakelib builder : [ glob *.cc : *test.cc *main.cc ] + ../../util//kenutil ../../util/stream//stream ../../util/double-conversion//double-conversion ..//kenlm ../common//common : : : /top//boost_thread $(timer-link) ; exe lmplz : lmplz_main.cc builder /top//boost_program_options ; diff --git a/lm/builder/adjust_counts.cc b/lm/builder/adjust_counts.cc index bcaa71998..3ac3e8d20 100644 --- a/lm/builder/adjust_counts.cc +++ b/lm/builder/adjust_counts.cc @@ -1,5 +1,6 @@ #include "lm/builder/adjust_counts.hh" -#include "lm/builder/ngram_stream.hh" +#include "lm/common/ngram_stream.hh" +#include "lm/builder/payload.hh" #include "util/stream/timer.hh" #include @@ -13,7 +14,7 @@ BadDiscountException::~BadDiscountException() throw() {} namespace { // Return last word in full that is different. -const WordIndex* FindDifference(const NGram &full, const NGram &lower_last) { +const WordIndex* FindDifference(const NGram &full, const NGram &lower_last) { const WordIndex *cur_word = full.end() - 1; const WordIndex *pre_word = lower_last.end() - 1; // Find last difference. @@ -111,15 +112,15 @@ class StatCollector { class CollapseStream { public: CollapseStream(const util::stream::ChainPosition &position, uint64_t prune_threshold, const std::vector& prune_words) : - current_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())), + current_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())), prune_threshold_(prune_threshold), prune_words_(prune_words), block_(position) { StartBlock(); } - const NGram &operator*() const { return current_; } - const NGram *operator->() const { return ¤t_; } + const NGram &operator*() const { return current_; } + const NGram *operator->() const { return ¤t_; } operator bool() const { return block_; } @@ -131,14 +132,14 @@ class CollapseStream { UpdateCopyFrom(); // Mark highest order n-grams for later pruning - if(current_.Count() <= prune_threshold_) { - current_.Mark(); + if(current_.Value().count <= prune_threshold_) { + current_.Value().Mark(); } if(!prune_words_.empty()) { for(WordIndex* i = current_.begin(); i != current_.end(); i++) { if(prune_words_[*i]) { - current_.Mark(); + current_.Value().Mark(); break; } } @@ -155,14 +156,14 @@ class CollapseStream { } // Mark highest order n-grams for later pruning - if(current_.Count() <= prune_threshold_) { - current_.Mark(); + if(current_.Value().count <= prune_threshold_) { + current_.Value().Mark(); } if(!prune_words_.empty()) { for(WordIndex* i = current_.begin(); i != current_.end(); i++) { if(prune_words_[*i]) { - current_.Mark(); + current_.Value().Mark(); break; } } @@ -182,14 +183,14 @@ class CollapseStream { UpdateCopyFrom(); // Mark highest order n-grams for later pruning - if(current_.Count() <= prune_threshold_) { - current_.Mark(); + if(current_.Value().count <= prune_threshold_) { + current_.Value().Mark(); } if(!prune_words_.empty()) { for(WordIndex* i = current_.begin(); i != current_.end(); i++) { if(prune_words_[*i]) { - current_.Mark(); + current_.Value().Mark(); break; } } @@ -200,11 +201,11 @@ class CollapseStream { // Find last without bos. void UpdateCopyFrom() { for (copy_from_ -= current_.TotalSize(); copy_from_ >= current_.Base(); copy_from_ -= current_.TotalSize()) { - if (NGram(copy_from_, current_.Order()).begin()[1] != kBOS) break; + if (NGram(copy_from_, current_.Order()).begin()[1] != kBOS) break; } } - NGram current_; + NGram current_; // Goes backwards in the block uint8_t *copy_from_; @@ -223,36 +224,36 @@ void AdjustCounts::Run(const util::stream::ChainPositions &positions) { if (order == 1) { // Only unigrams. Just collect stats. - for (NGramStream full(positions[0]); full; ++full) { + for (NGramStream full(positions[0]); full; ++full) { // Do not prune if(*full->begin() > 2) { - if(full->Count() <= prune_thresholds_[0]) - full->Mark(); + if(full->Value().count <= prune_thresholds_[0]) + full->Value().Mark(); if(!prune_words_.empty() && prune_words_[*full->begin()]) - full->Mark(); + full->Value().Mark(); } - stats.AddFull(full->UnmarkedCount(), full->IsMarked()); + stats.AddFull(full->Value().UnmarkedCount(), full->Value().IsMarked()); } stats.CalculateDiscounts(discount_config_); return; } - NGramStreams streams; + NGramStreams streams; streams.Init(positions, positions.size() - 1); CollapseStream full(positions[positions.size() - 1], prune_thresholds_.back(), prune_words_); // Initialization: has count 0 and so does . - NGramStream *lower_valid = streams.begin(); - const NGramStream *const streams_begin = streams.begin(); - streams[0]->Count() = 0; + NGramStream *lower_valid = streams.begin(); + const NGramStream *const streams_begin = streams.begin(); + streams[0]->Value().count = 0; *streams[0]->begin() = kUNK; stats.Add(0, 0); - (++streams[0])->Count() = 0; + (++streams[0])->Value().count = 0; *streams[0]->begin() = kBOS; // is not in stats yet because it will get put in later. @@ -271,28 +272,28 @@ void AdjustCounts::Run(const util::stream::ChainPositions &positions) { for (; lower_valid >= &streams[same]; --lower_valid) { uint64_t order_minus_1 = lower_valid - streams_begin; if(actual_counts[order_minus_1] <= prune_thresholds_[order_minus_1]) - (*lower_valid)->Mark(); + (*lower_valid)->Value().Mark(); if(!prune_words_.empty()) { for(WordIndex* i = (*lower_valid)->begin(); i != (*lower_valid)->end(); i++) { if(prune_words_[*i]) { - (*lower_valid)->Mark(); + (*lower_valid)->Value().Mark(); break; } } } - stats.Add(order_minus_1, (*lower_valid)->UnmarkedCount(), (*lower_valid)->IsMarked()); + stats.Add(order_minus_1, (*lower_valid)->Value().UnmarkedCount(), (*lower_valid)->Value().IsMarked()); ++*lower_valid; } // STEP 2: Update n-grams that still match. // n-grams that match get count from the full entry. for (std::size_t i = 0; i < same; ++i) { - actual_counts[i] += full->UnmarkedCount(); + actual_counts[i] += full->Value().UnmarkedCount(); } // Increment the number of unique extensions for the longest match. - if (same) ++streams[same - 1]->Count(); + if (same) ++streams[same - 1]->Value().count; // STEP 3: Initialize new n-grams. // This is here because bos is also const WordIndex *, so copy gets @@ -301,47 +302,47 @@ void AdjustCounts::Run(const util::stream::ChainPositions &positions) { // Initialize and mark as valid up to bos. const WordIndex *bos; for (bos = different; (bos > full->begin()) && (*bos != kBOS); --bos) { - NGramStream &to = *++lower_valid; + NGramStream &to = *++lower_valid; std::copy(bos, full_end, to->begin()); - to->Count() = 1; - actual_counts[lower_valid - streams_begin] = full->UnmarkedCount(); + to->Value().count = 1; + actual_counts[lower_valid - streams_begin] = full->Value().UnmarkedCount(); } // Now bos indicates where is or is the 0th word of full. if (bos != full->begin()) { // There is an beyond the 0th word. - NGramStream &to = *++lower_valid; + NGramStream &to = *++lower_valid; std::copy(bos, full_end, to->begin()); // Anything that begins with has full non adjusted count. - to->Count() = full->UnmarkedCount(); - actual_counts[lower_valid - streams_begin] = full->UnmarkedCount(); + to->Value().count = full->Value().UnmarkedCount(); + actual_counts[lower_valid - streams_begin] = full->Value().UnmarkedCount(); } else { - stats.AddFull(full->UnmarkedCount(), full->IsMarked()); + stats.AddFull(full->Value().UnmarkedCount(), full->Value().IsMarked()); } assert(lower_valid >= &streams[0]); } // The above loop outputs n-grams when it observes changes. This outputs // the last n-grams. - for (NGramStream *s = streams.begin(); s <= lower_valid; ++s) { + for (NGramStream *s = streams.begin(); s <= lower_valid; ++s) { uint64_t lower_count = actual_counts[(*s)->Order() - 1]; if(lower_count <= prune_thresholds_[(*s)->Order() - 1]) - (*s)->Mark(); + (*s)->Value().Mark(); if(!prune_words_.empty()) { for(WordIndex* i = (*s)->begin(); i != (*s)->end(); i++) { if(prune_words_[*i]) { - (*s)->Mark(); + (*s)->Value().Mark(); break; } } } - stats.Add(s - streams.begin(), lower_count, (*s)->IsMarked()); + stats.Add(s - streams.begin(), lower_count, (*s)->Value().IsMarked()); ++*s; } // Poison everyone! Except the N-grams which were already poisoned by the input. - for (NGramStream *s = streams.begin(); s != streams.end(); ++s) + for (NGramStream *s = streams.begin(); s != streams.end(); ++s) s->Poison(); stats.CalculateDiscounts(discount_config_); diff --git a/lm/builder/adjust_counts_test.cc b/lm/builder/adjust_counts_test.cc index 2a9d78ae0..fff551f7c 100644 --- a/lm/builder/adjust_counts_test.cc +++ b/lm/builder/adjust_counts_test.cc @@ -1,6 +1,7 @@ #include "lm/builder/adjust_counts.hh" -#include "lm/builder/ngram_stream.hh" +#include "lm/common/ngram_stream.hh" +#include "lm/builder/payload.hh" #include "util/scoped.hh" #include @@ -37,7 +38,7 @@ struct Gram4 { class WriteInput { public: void Run(const util::stream::ChainPosition &position) { - NGramStream input(position); + NGramStream input(position); Gram4 grams[] = { {{0,0,0,0},10}, {{0,0,3,0},3}, @@ -47,7 +48,7 @@ class WriteInput { }; for (size_t i = 0; i < sizeof(grams) / sizeof(Gram4); ++i, ++input) { memcpy(input->begin(), grams[i].ids, sizeof(WordIndex) * 4); - input->Count() = grams[i].count; + input->Value().count = grams[i].count; } input.Poison(); } @@ -63,7 +64,7 @@ BOOST_AUTO_TEST_CASE(Simple) { config.block_count = 1; util::stream::Chains chains(4); for (unsigned i = 0; i < 4; ++i) { - config.entry_size = NGram::TotalSize(i + 1); + config.entry_size = NGram::TotalSize(i + 1); chains.push_back(config); } @@ -86,25 +87,25 @@ BOOST_AUTO_TEST_CASE(Simple) { /* BOOST_CHECK_EQUAL(4UL, counts[1]); BOOST_CHECK_EQUAL(3UL, counts[2]); BOOST_CHECK_EQUAL(3UL, counts[3]);*/ - BOOST_REQUIRE_EQUAL(NGram::TotalSize(1) * 4, outputs[0].Size()); - NGram uni(outputs[0].Get(), 1); + BOOST_REQUIRE_EQUAL(NGram::TotalSize(1) * 4, outputs[0].Size()); + NGram uni(outputs[0].Get(), 1); BOOST_CHECK_EQUAL(kUNK, *uni.begin()); - BOOST_CHECK_EQUAL(0ULL, uni.Count()); + BOOST_CHECK_EQUAL(0ULL, uni.Value().count); uni.NextInMemory(); BOOST_CHECK_EQUAL(kBOS, *uni.begin()); - BOOST_CHECK_EQUAL(0ULL, uni.Count()); + BOOST_CHECK_EQUAL(0ULL, uni.Value().count); uni.NextInMemory(); BOOST_CHECK_EQUAL(0UL, *uni.begin()); - BOOST_CHECK_EQUAL(2ULL, uni.Count()); + BOOST_CHECK_EQUAL(2ULL, uni.Value().count); uni.NextInMemory(); - BOOST_CHECK_EQUAL(2ULL, uni.Count()); + BOOST_CHECK_EQUAL(2ULL, uni.Value().count); BOOST_CHECK_EQUAL(2UL, *uni.begin()); - BOOST_REQUIRE_EQUAL(NGram::TotalSize(2) * 4, outputs[1].Size()); - NGram bi(outputs[1].Get(), 2); + BOOST_REQUIRE_EQUAL(NGram::TotalSize(2) * 4, outputs[1].Size()); + NGram bi(outputs[1].Get(), 2); BOOST_CHECK_EQUAL(0UL, *bi.begin()); BOOST_CHECK_EQUAL(0UL, *(bi.begin() + 1)); - BOOST_CHECK_EQUAL(1ULL, bi.Count()); + BOOST_CHECK_EQUAL(1ULL, bi.Value().count); bi.NextInMemory(); } diff --git a/lm/builder/combine_counts.hh b/lm/builder/combine_counts.hh new file mode 100644 index 000000000..2eda51704 --- /dev/null +++ b/lm/builder/combine_counts.hh @@ -0,0 +1,31 @@ +#ifndef LM_BUILDER_COMBINE_COUNTS_H +#define LM_BUILDER_COMBINE_COUNTS_H + +#include "lm/builder/payload.hh" +#include "lm/common/ngram.hh" +#include "lm/common/compare.hh" +#include "lm/word_index.hh" +#include "util/stream/sort.hh" + +#include +#include + +namespace lm { +namespace builder { + +// Sum counts for the same n-gram. +struct CombineCounts { + bool operator()(void *first_void, const void *second_void, const SuffixOrder &compare) const { + NGram first(first_void, compare.Order()); + // There isn't a const version of NGram. + NGram second(const_cast(second_void), compare.Order()); + if (memcmp(first.begin(), second.begin(), sizeof(WordIndex) * compare.Order())) return false; + first.Value().count += second.Value().count; + return true; + } +}; + +} // namespace builder +} // namespace lm + +#endif // LM_BUILDER_COMBINE_COUNTS_H diff --git a/lm/builder/corpus_count.cc b/lm/builder/corpus_count.cc index 889eeb7a9..9f23b28a8 100644 --- a/lm/builder/corpus_count.cc +++ b/lm/builder/corpus_count.cc @@ -1,6 +1,7 @@ #include "lm/builder/corpus_count.hh" -#include "lm/builder/ngram.hh" +#include "lm/builder/payload.hh" +#include "lm/common/ngram.hh" #include "lm/lm_exception.hh" #include "lm/vocab.hh" #include "lm/word_index.hh" @@ -25,19 +26,6 @@ namespace lm { namespace builder { namespace { -#pragma pack(push) -#pragma pack(4) -struct VocabEntry { - typedef uint64_t Key; - - uint64_t GetKey() const { return key; } - void SetKey(uint64_t to) { key = to; } - - uint64_t key; - lm::WordIndex value; -}; -#pragma pack(pop) - class DedupeHash : public std::unary_function { public: explicit DedupeHash(std::size_t order) : size_(order * sizeof(WordIndex)) {} @@ -115,17 +103,17 @@ class Writer { bool found = dedupe_.FindOrInsert(DedupeEntry::Construct(gram_.begin()), at); if (found) { // Already present. - NGram already(at->key, gram_.Order()); - ++(already.Count()); + NGram already(at->key, gram_.Order()); + ++(already.Value().count); // Shift left by one. memmove(gram_.begin(), gram_.begin() + 1, sizeof(WordIndex) * (gram_.Order() - 1)); return; } // Complete the write. - gram_.Count() = 1; + gram_.Value().count = 1; // Prepare the next n-gram. if (reinterpret_cast(gram_.begin()) + gram_.TotalSize() != static_cast(block_->Get()) + block_size_) { - NGram last(gram_); + NGram last(gram_); gram_.NextInMemory(); std::copy(last.begin() + 1, last.end(), gram_.begin()); return; @@ -141,7 +129,7 @@ class Writer { private: void AddUnigramWord(WordIndex index) { *gram_.begin() = index; - gram_.Count() = 0; + gram_.Value().count = 0; gram_.NextInMemory(); if (gram_.Base() == static_cast(block_->Get()) + block_size_) { block_->SetValidSize(block_size_); @@ -151,7 +139,7 @@ class Writer { util::stream::Link block_; - NGram gram_; + NGram gram_; // This is the memory behind the invalid value in dedupe_. std::vector dedupe_invalid_; @@ -167,7 +155,7 @@ class Writer { } // namespace float CorpusCount::DedupeMultiplier(std::size_t order) { - return kProbingMultiplier * static_cast(sizeof(DedupeEntry)) / static_cast(NGram::TotalSize(order)); + return kProbingMultiplier * static_cast(sizeof(DedupeEntry)) / static_cast(NGram::TotalSize(order)); } std::size_t CorpusCount::VocabUsage(std::size_t vocab_estimate) { @@ -202,7 +190,7 @@ void CorpusCount::Run(const util::stream::ChainPosition &position) { token_count_ = 0; type_count_ = 0; const WordIndex end_sentence = vocab.FindOrInsert(""); - Writer writer(NGram::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_); + Writer writer(NGram::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_); uint64_t count = 0; bool delimiters[256]; util::BoolCharacter::Build("\0\t\n\r ", delimiters); @@ -233,9 +221,8 @@ void CorpusCount::Run(const util::stream::ChainPosition &position) { prune_words_.resize(vocab.Size(), true); try { while (true) { - StringPiece line(prune_vocab_file.ReadLine()); - for (util::TokenIter w(line, delimiters); w; ++w) - prune_words_[vocab.Index(*w)] = false; + StringPiece word(prune_vocab_file.ReadDelimited(delimiters)); + prune_words_[vocab.Index(word)] = false; } } catch (const util::EndOfFileException &e) {} diff --git a/lm/builder/corpus_count_test.cc b/lm/builder/corpus_count_test.cc index 18301656f..82f859690 100644 --- a/lm/builder/corpus_count_test.cc +++ b/lm/builder/corpus_count_test.cc @@ -1,7 +1,8 @@ #include "lm/builder/corpus_count.hh" -#include "lm/builder/ngram.hh" -#include "lm/builder/ngram_stream.hh" +#include "lm/builder/payload.hh" +#include "lm/common/ngram_stream.hh" +#include "lm/common/ngram.hh" #include "util/file.hh" #include "util/file_piece.hh" @@ -14,13 +15,13 @@ namespace lm { namespace builder { namespace { -#define Check(str, count) { \ +#define Check(str, cnt) { \ BOOST_REQUIRE(stream); \ w = stream->begin(); \ for (util::TokenIter t(str, " "); t; ++t, ++w) { \ BOOST_CHECK_EQUAL(*t, v[*w]); \ } \ - BOOST_CHECK_EQUAL((uint64_t)count, stream->Count()); \ + BOOST_CHECK_EQUAL((uint64_t)cnt, stream->Value().count); \ ++stream; \ } @@ -35,14 +36,14 @@ BOOST_AUTO_TEST_CASE(Short) { util::FilePiece input_piece(input_file.release(), "temp file"); util::stream::ChainConfig config; - config.entry_size = NGram::TotalSize(3); + config.entry_size = NGram::TotalSize(3); config.total_memory = config.entry_size * 20; config.block_count = 2; util::scoped_fd vocab(util::MakeTemp("corpus_count_test_vocab")); util::stream::Chain chain(config); - NGramStream stream; + NGramStream stream; uint64_t token_count; WordIndex type_count = 10; std::vector prune_words; diff --git a/lm/builder/initial_probabilities.cc b/lm/builder/initial_probabilities.cc index 80063eb2e..ef8a8ecfd 100644 --- a/lm/builder/initial_probabilities.cc +++ b/lm/builder/initial_probabilities.cc @@ -1,9 +1,10 @@ #include "lm/builder/initial_probabilities.hh" #include "lm/builder/discount.hh" -#include "lm/builder/ngram_stream.hh" -#include "lm/builder/sort.hh" +#include "lm/builder/special.hh" #include "lm/builder/hash_gamma.hh" +#include "lm/builder/payload.hh" +#include "lm/common/ngram_stream.hh" #include "util/murmur_hash.hh" #include "util/file.hh" #include "util/stream/chain.hh" @@ -32,17 +33,18 @@ struct HashBufferEntry : public BufferEntry { // threshold. class PruneNGramStream { public: - PruneNGramStream(const util::stream::ChainPosition &position) : - current_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())), - dest_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())), + PruneNGramStream(const util::stream::ChainPosition &position, const SpecialVocab &specials) : + current_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())), + dest_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())), currentCount_(0), - block_(position) + block_(position), + specials_(specials) { StartBlock(); } - NGram &operator*() { return current_; } - NGram *operator->() { return ¤t_; } + NGram &operator*() { return current_; } + NGram *operator->() { return ¤t_; } operator bool() const { return block_; @@ -50,8 +52,7 @@ class PruneNGramStream { PruneNGramStream &operator++() { assert(block_); - - if(current_.Order() == 1 && *current_.begin() <= 2) + if(UTIL_UNLIKELY(current_.Order() == 1 && specials_.IsSpecial(*current_.begin()))) dest_.NextInMemory(); else if(currentCount_ > 0) { if(dest_.Base() < current_.Base()) { @@ -68,10 +69,10 @@ class PruneNGramStream { ++block_; StartBlock(); if (block_) { - currentCount_ = current_.CutoffCount(); + currentCount_ = current_.Value().CutoffCount(); } } else { - currentCount_ = current_.CutoffCount(); + currentCount_ = current_.Value().CutoffCount(); } return *this; @@ -84,23 +85,25 @@ class PruneNGramStream { if (block_->ValidSize()) break; } current_.ReBase(block_->Get()); - currentCount_ = current_.CutoffCount(); + currentCount_ = current_.Value().CutoffCount(); dest_.ReBase(block_->Get()); } - NGram current_; // input iterator - NGram dest_; // output iterator + NGram current_; // input iterator + NGram dest_; // output iterator uint64_t currentCount_; util::stream::Link block_; + + const SpecialVocab specials_; }; // Extract an array of HashedGamma from an array of BufferEntry. class OnlyGamma { public: - OnlyGamma(bool pruning) : pruning_(pruning) {} + explicit OnlyGamma(bool pruning) : pruning_(pruning) {} void Run(const util::stream::ChainPosition &position) { for (util::stream::Link block_it(position); block_it; ++block_it) { @@ -143,7 +146,7 @@ class AddRight { : discount_(discount), input_(input), pruning_(pruning) {} void Run(const util::stream::ChainPosition &output) { - NGramStream in(input_); + NGramStream in(input_); util::stream::Stream out(output); std::vector previous(in->Order() - 1); @@ -159,17 +162,17 @@ class AddRight { uint64_t counts[4]; memset(counts, 0, sizeof(counts)); do { - denominator += in->UnmarkedCount(); + denominator += in->Value().UnmarkedCount(); // Collect unused probability mass from pruning. // Becomes 0 for unpruned ngrams. - normalizer += in->UnmarkedCount() - in->CutoffCount(); + normalizer += in->Value().UnmarkedCount() - in->Value().CutoffCount(); // Chen&Goodman do not mention counting based on cutoffs, but // backoff becomes larger than 1 otherwise, so probably needs // to count cutoffs. Counts normally without pruning. - if(in->CutoffCount() > 0) - ++counts[std::min(in->CutoffCount(), static_cast(3))]; + if(in->Value().CutoffCount() > 0) + ++counts[std::min(in->Value().CutoffCount(), static_cast(3))]; } while (++in && !memcmp(previous_raw, in->begin(), size)); @@ -202,15 +205,15 @@ class AddRight { class MergeRight { public: - MergeRight(bool interpolate_unigrams, const util::stream::ChainPosition &from_adder, const Discount &discount) - : interpolate_unigrams_(interpolate_unigrams), from_adder_(from_adder), discount_(discount) {} + MergeRight(bool interpolate_unigrams, const util::stream::ChainPosition &from_adder, const Discount &discount, const SpecialVocab &specials) + : interpolate_unigrams_(interpolate_unigrams), from_adder_(from_adder), discount_(discount), specials_(specials) {} // calculate the initial probability of each n-gram (before order-interpolation) // Run() gets invoked once for each order void Run(const util::stream::ChainPosition &primary) { util::stream::Stream summed(from_adder_); - PruneNGramStream grams(primary); + PruneNGramStream grams(primary, specials_); // Without interpolation, the interpolation weight goes to . if (grams->Order() == 1) { @@ -228,17 +231,21 @@ class MergeRight { grams->Value().uninterp.prob = sums.gamma; } grams->Value().uninterp.gamma = gamma_assign; - ++grams; + + for (++grams; *grams->begin() != specials_.BOS(); ++grams) { + grams->Value().uninterp.prob = discount_.Apply(grams->Value().count) / sums.denominator; + grams->Value().uninterp.gamma = gamma_assign; + } // Special case for : probability 1.0. This allows to be - // explicitly scores as part of the sentence without impacting + // explicitly scored as part of the sentence without impacting // probability and computes q correctly as b(). - assert(*grams->begin() == kBOS); + assert(*grams->begin() == specials_.BOS()); grams->Value().uninterp.prob = 1.0; grams->Value().uninterp.gamma = 0.0; while (++grams) { - grams->Value().uninterp.prob = discount_.Apply(grams->Count()) / sums.denominator; + grams->Value().uninterp.prob = discount_.Apply(grams->Value().count) / sums.denominator; grams->Value().uninterp.gamma = gamma_assign; } ++summed; @@ -252,8 +259,8 @@ class MergeRight { const BufferEntry &sums = *static_cast(summed.Get()); do { - Payload &pay = grams->Value(); - pay.uninterp.prob = discount_.Apply(grams->UnmarkedCount()) / sums.denominator; + BuildingPayload &pay = grams->Value(); + pay.uninterp.prob = discount_.Apply(grams->Value().UnmarkedCount()) / sums.denominator; pay.uninterp.gamma = sums.gamma; } while (++grams && !memcmp(&previous[0], grams->begin(), size)); } @@ -263,6 +270,7 @@ class MergeRight { bool interpolate_unigrams_; util::stream::ChainPosition from_adder_; Discount discount_; + const SpecialVocab specials_; }; } // namespace @@ -274,7 +282,8 @@ void InitialProbabilities( util::stream::Chains &second_in, util::stream::Chains &gamma_out, const std::vector &prune_thresholds, - bool prune_vocab) { + bool prune_vocab, + const SpecialVocab &specials) { for (size_t i = 0; i < primary.size(); ++i) { util::stream::ChainConfig gamma_config = config.adder_out; if(prune_vocab || prune_thresholds[i] > 0) @@ -287,7 +296,7 @@ void InitialProbabilities( gamma_out.push_back(gamma_config); gamma_out[i] >> AddRight(discounts[i], second, prune_vocab || prune_thresholds[i] > 0); - primary[i] >> MergeRight(config.interpolate_unigrams, gamma_out[i].Add(), discounts[i]); + primary[i] >> MergeRight(config.interpolate_unigrams, gamma_out[i].Add(), discounts[i], specials); // Don't bother with the OnlyGamma thread for something to discard. if (i) gamma_out[i] >> OnlyGamma(prune_vocab || prune_thresholds[i] > 0); diff --git a/lm/builder/initial_probabilities.hh b/lm/builder/initial_probabilities.hh index a8ecf4dc2..dddbbb913 100644 --- a/lm/builder/initial_probabilities.hh +++ b/lm/builder/initial_probabilities.hh @@ -2,6 +2,7 @@ #define LM_BUILDER_INITIAL_PROBABILITIES_H #include "lm/builder/discount.hh" +#include "lm/word_index.hh" #include "util/stream/config.hh" #include @@ -11,6 +12,8 @@ namespace util { namespace stream { class Chains; } } namespace lm { namespace builder { +class SpecialVocab; + struct InitialProbabilitiesConfig { // These should be small buffers to keep the adder from getting too far ahead util::stream::ChainConfig adder_in; @@ -34,7 +37,8 @@ void InitialProbabilities( util::stream::Chains &second_in, util::stream::Chains &gamma_out, const std::vector &prune_thresholds, - bool prune_vocab); + bool prune_vocab, + const SpecialVocab &vocab); } // namespace builder } // namespace lm diff --git a/lm/builder/interpolate.cc b/lm/builder/interpolate.cc index 5b04cb3ff..84672e068 100644 --- a/lm/builder/interpolate.cc +++ b/lm/builder/interpolate.cc @@ -2,8 +2,8 @@ #include "lm/builder/hash_gamma.hh" #include "lm/builder/joint_order.hh" -#include "lm/builder/ngram_stream.hh" -#include "lm/builder/sort.hh" +#include "lm/common/ngram_stream.hh" +#include "lm/common/compare.hh" #include "lm/lm_exception.hh" #include "util/fixed_array.hh" #include "util/murmur_hash.hh" @@ -65,11 +65,12 @@ class OutputProbBackoff { template class Callback { public: - Callback(float uniform_prob, const util::stream::ChainPositions &backoffs, const std::vector &prune_thresholds, bool prune_vocab) + Callback(float uniform_prob, const util::stream::ChainPositions &backoffs, const std::vector &prune_thresholds, bool prune_vocab, const SpecialVocab &specials) : backoffs_(backoffs.size()), probs_(backoffs.size() + 2), prune_thresholds_(prune_thresholds), prune_vocab_(prune_vocab), - output_(backoffs.size() + 1 /* order */) { + output_(backoffs.size() + 1 /* order */), + specials_(specials) { probs_[0] = uniform_prob; for (std::size_t i = 0; i < backoffs.size(); ++i) { backoffs_.push_back(backoffs[i]); @@ -89,13 +90,13 @@ template class Callback { } } - void Enter(unsigned order_minus_1, NGram &gram) { - Payload &pay = gram.Value(); + void Enter(unsigned order_minus_1, NGram &gram) { + BuildingPayload &pay = gram.Value(); pay.complete.prob = pay.uninterp.prob + pay.uninterp.gamma * probs_[order_minus_1]; probs_[order_minus_1 + 1] = pay.complete.prob; float out_backoff; - if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != kUNK && *(gram.end() - 1) != kEOS && backoffs_[order_minus_1]) { + if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != specials_.UNK() && *(gram.end() - 1) != specials_.EOS() && backoffs_[order_minus_1]) { if(prune_vocab_ || prune_thresholds_[order_minus_1 + 1] > 0) { //Compute hash value for current context uint64_t current_hash = util::MurmurHashNative(gram.begin(), gram.Order() * sizeof(WordIndex)); @@ -123,7 +124,7 @@ template class Callback { output_.Gram(order_minus_1, out_backoff, pay.complete); } - void Exit(unsigned, const NGram &) const {} + void Exit(unsigned, const NGram &) const {} private: util::FixedArray backoffs_; @@ -133,26 +134,28 @@ template class Callback { bool prune_vocab_; Output output_; + const SpecialVocab specials_; }; } // namespace -Interpolate::Interpolate(uint64_t vocab_size, const util::stream::ChainPositions &backoffs, const std::vector& prune_thresholds, bool prune_vocab, bool output_q) +Interpolate::Interpolate(uint64_t vocab_size, const util::stream::ChainPositions &backoffs, const std::vector& prune_thresholds, bool prune_vocab, bool output_q, const SpecialVocab &specials) : uniform_prob_(1.0 / static_cast(vocab_size)), // Includes but excludes . backoffs_(backoffs), prune_thresholds_(prune_thresholds), prune_vocab_(prune_vocab), - output_q_(output_q) {} + output_q_(output_q), + specials_(specials) {} // perform order-wise interpolation void Interpolate::Run(const util::stream::ChainPositions &positions) { assert(positions.size() == backoffs_.size() + 1); if (output_q_) { typedef Callback C; - C callback(uniform_prob_, backoffs_, prune_thresholds_, prune_vocab_); + C callback(uniform_prob_, backoffs_, prune_thresholds_, prune_vocab_, specials_); JointOrder(positions, callback); } else { typedef Callback C; - C callback(uniform_prob_, backoffs_, prune_thresholds_, prune_vocab_); + C callback(uniform_prob_, backoffs_, prune_thresholds_, prune_vocab_, specials_); JointOrder(positions, callback); } } diff --git a/lm/builder/interpolate.hh b/lm/builder/interpolate.hh index 207a16dfd..dcee75adb 100644 --- a/lm/builder/interpolate.hh +++ b/lm/builder/interpolate.hh @@ -1,6 +1,8 @@ #ifndef LM_BUILDER_INTERPOLATE_H #define LM_BUILDER_INTERPOLATE_H +#include "lm/builder/special.hh" +#include "lm/word_index.hh" #include "util/stream/multi_stream.hh" #include @@ -18,7 +20,7 @@ class Interpolate { public: // Normally vocab_size is the unigram count-1 (since p() = 0) but might // be larger when the user specifies a consistent vocabulary size. - explicit Interpolate(uint64_t vocab_size, const util::stream::ChainPositions &backoffs, const std::vector &prune_thresholds, bool prune_vocab, bool output_q_); + explicit Interpolate(uint64_t vocab_size, const util::stream::ChainPositions &backoffs, const std::vector &prune_thresholds, bool prune_vocab, bool output_q, const SpecialVocab &specials); void Run(const util::stream::ChainPositions &positions); @@ -28,6 +30,7 @@ class Interpolate { const std::vector prune_thresholds_; bool prune_vocab_; bool output_q_; + const SpecialVocab specials_; }; }} // namespaces diff --git a/lm/builder/joint_order.hh b/lm/builder/joint_order.hh index b05ef67fd..5f62a4578 100644 --- a/lm/builder/joint_order.hh +++ b/lm/builder/joint_order.hh @@ -1,7 +1,8 @@ #ifndef LM_BUILDER_JOINT_ORDER_H #define LM_BUILDER_JOINT_ORDER_H -#include "lm/builder/ngram_stream.hh" +#include "lm/common/ngram_stream.hh" +#include "lm/builder/payload.hh" #include "lm/lm_exception.hh" #ifdef DEBUG @@ -15,9 +16,9 @@ namespace lm { namespace builder { template void JointOrder(const util::stream::ChainPositions &positions, Callback &callback) { // Allow matching to reference streams[-1]. - NGramStreams streams_with_dummy; + NGramStreams streams_with_dummy; streams_with_dummy.InitWithDummy(positions); - NGramStream *streams = streams_with_dummy.begin() + 1; + NGramStream *streams = streams_with_dummy.begin() + 1; unsigned int order; for (order = 0; order < positions.size() && streams[order]; ++order) {} diff --git a/lm/builder/lmplz_main.cc b/lm/builder/lmplz_main.cc index 5c9d86deb..c27490665 100644 --- a/lm/builder/lmplz_main.cc +++ b/lm/builder/lmplz_main.cc @@ -87,7 +87,7 @@ int main(int argc, char *argv[]) { po::options_description options("Language model building options"); lm::builder::PipelineConfig pipeline; - std::string text, arpa; + std::string text, intermediate, arpa; std::vector pruning; std::vector discount_fallback; std::vector discount_fallback_default; @@ -116,6 +116,8 @@ int main(int argc, char *argv[]) { ("verbose_header", po::bool_switch(&verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc.") ("text", po::value(&text), "Read text from a file instead of stdin") ("arpa", po::value(&arpa), "Write ARPA to a file instead of stdout") + ("intermediate", po::value(&intermediate), "Write ngrams to an intermediate file. Turns off ARPA output (which can be reactivated by --arpa file). Forces --renumber on. Implicitly makes --vocab_file be the provided name + .vocab.") + ("renumber", po::bool_switch(&pipeline.renumber_vocabulary), "Rrenumber the vocabulary identifiers so that they are monotone with the hash of each string. This is consistent with the ordering used by the trie data structure.") ("collapse_values", po::bool_switch(&pipeline.output_q), "Collapse probability and backoff into a single value, q that yields the same sentence-level probabilities. See http://kheafield.com/professional/edinburgh/rest_paper.pdf for more details, including a proof.") ("prune", po::value >(&pruning)->multitoken(), "Prune n-grams with count less than or equal to the given threshold. Specify one value for each order i.e. 0 0 1 to prune singleton trigrams and above. The sequence of values must be non-decreasing and the last value applies to any remaining orders. Default is to not prune, which is equivalent to --prune 0.") ("limit_vocab_file", po::value(&pipeline.prune_vocab_file)->default_value(""), "Read allowed vocabulary separated by whitespace. N-grams that contain vocabulary items not in this list will be pruned. Can be combined with --prune arg") @@ -212,8 +214,19 @@ int main(int argc, char *argv[]) { } try { - lm::builder::Output output; - output.Add(new lm::builder::PrintARPA(out.release(), verbose_header)); + bool writing_intermediate = vm.count("intermediate"); + if (writing_intermediate) { + pipeline.renumber_vocabulary = true; + if (!pipeline.vocab_file.empty()) { + std::cerr << "--intermediate and --vocab_file are incompatible because --intermediate already makes a vocab file." << std::endl; + return 1; + } + pipeline.vocab_file = intermediate + ".vocab"; + } + lm::builder::Output output(writing_intermediate ? intermediate : pipeline.sort.temp_prefix, writing_intermediate); + if (!writing_intermediate || vm.count("arpa")) { + output.Add(new lm::builder::PrintARPA(out.release(), verbose_header)); + } lm::builder::Pipeline(pipeline, in.release(), output); } catch (const util::MallocException &e) { std::cerr << e.what() << std::endl; diff --git a/lm/builder/output.cc b/lm/builder/output.cc index 0fc0197c4..76478ad06 100644 --- a/lm/builder/output.cc +++ b/lm/builder/output.cc @@ -1,14 +1,41 @@ #include "lm/builder/output.hh" + +#include "lm/common/model_buffer.hh" #include "util/stream/multi_stream.hh" -#include +#include namespace lm { namespace builder { OutputHook::~OutputHook() {} -void OutputHook::Apply(util::stream::Chains &chains) { - chains >> boost::ref(*this); +Output::Output(StringPiece file_base, bool keep_buffer) + : file_base_(file_base.data(), file_base.size()), keep_buffer_(keep_buffer) {} + +void Output::SinkProbs(util::stream::Chains &chains, bool output_q) { + Apply(PROB_PARALLEL_HOOK, chains); + if (!keep_buffer_ && !Have(PROB_SEQUENTIAL_HOOK)) { + chains >> util::stream::kRecycle; + chains.Wait(true); + return; + } + lm::common::ModelBuffer buf(file_base_, keep_buffer_, output_q); + buf.Sink(chains); + chains >> util::stream::kRecycle; + chains.Wait(false); + if (Have(PROB_SEQUENTIAL_HOOK)) { + std::cerr << "=== 5/5 Writing ARPA model ===" << std::endl; + buf.Source(chains); + Apply(PROB_SEQUENTIAL_HOOK, chains); + chains >> util::stream::kRecycle; + chains.Wait(true); + } +} + +void Output::Apply(HookType hook_type, util::stream::Chains &chains) { + for (boost::ptr_vector::iterator entry = outputs_[hook_type].begin(); entry != outputs_[hook_type].end(); ++entry) { + entry->Sink(chains); + } } }} // namespaces diff --git a/lm/builder/output.hh b/lm/builder/output.hh index 0ef769ae2..c1e0d1469 100644 --- a/lm/builder/output.hh +++ b/lm/builder/output.hh @@ -7,16 +7,14 @@ #include #include -#include - namespace util { namespace stream { class Chains; class ChainPositions; } } -/* Outputs from lmplz: ARPA< sharded files, etc */ +/* Outputs from lmplz: ARPA, sharded files, etc */ namespace lm { namespace builder { // These are different types of hooks. Values should be consecutive to enable a vector lookup. enum HookType { - COUNT_HOOK, // Raw N-gram counts, highest order only. + // TODO: counts. PROB_PARALLEL_HOOK, // Probability and backoff (or just q). Output must process the orders in parallel or there will be a deadlock. PROB_SEQUENTIAL_HOOK, // Probability and backoff (or just q). Output can process orders any way it likes. This requires writing the data to disk then reading. Useful for ARPA files, which put unigrams first etc. NUMBER_OF_HOOKS // Keep this last so we know how many values there are. @@ -30,9 +28,7 @@ class OutputHook { virtual ~OutputHook(); - virtual void Apply(util::stream::Chains &chains); - - virtual void Run(const util::stream::ChainPositions &positions) = 0; + virtual void Sink(util::stream::Chains &chains) = 0; protected: const HeaderInfo &GetHeader() const; @@ -46,7 +42,7 @@ class OutputHook { class Output : boost::noncopyable { public: - Output() {} + Output(StringPiece file_base, bool keep_buffer); // Takes ownership. void Add(OutputHook *hook) { @@ -64,16 +60,20 @@ class Output : boost::noncopyable { void SetHeader(const HeaderInfo &header) { header_ = header; } const HeaderInfo &GetHeader() const { return header_; } - void Apply(HookType hook_type, util::stream::Chains &chains) { - for (boost::ptr_vector::iterator entry = outputs_[hook_type].begin(); entry != outputs_[hook_type].end(); ++entry) { - entry->Apply(chains); - } - } + // This is called by the pipeline. + void SinkProbs(util::stream::Chains &chains, bool output_q); + + unsigned int Steps() const { return Have(PROB_SEQUENTIAL_HOOK); } private: + void Apply(HookType hook_type, util::stream::Chains &chains); + boost::ptr_vector outputs_[NUMBER_OF_HOOKS]; int vocab_fd_; HeaderInfo header_; + + std::string file_base_; + bool keep_buffer_; }; inline const HeaderInfo &OutputHook::GetHeader() const { diff --git a/lm/builder/payload.hh b/lm/builder/payload.hh new file mode 100644 index 000000000..ba12725a4 --- /dev/null +++ b/lm/builder/payload.hh @@ -0,0 +1,48 @@ +#ifndef LM_BUILDER_PAYLOAD_H +#define LM_BUILDER_PAYLOAD_H + +#include "lm/weights.hh" +#include "lm/word_index.hh" +#include + +namespace lm { namespace builder { + +struct Uninterpolated { + float prob; // Uninterpolated probability. + float gamma; // Interpolation weight for lower order. +}; + +union BuildingPayload { + uint64_t count; + Uninterpolated uninterp; + ProbBackoff complete; + + /*mjd**********************************************************************/ + bool IsMarked() const { + return count >> (sizeof(count) * 8 - 1); + } + + void Mark() { + count |= (1ul << (sizeof(count) * 8 - 1)); + } + + void Unmark() { + count &= ~(1ul << (sizeof(count) * 8 - 1)); + } + + uint64_t UnmarkedCount() const { + return count & ~(1ul << (sizeof(count) * 8 - 1)); + } + + uint64_t CutoffCount() const { + return IsMarked() ? 0 : UnmarkedCount(); + } + /*mjd**********************************************************************/ +}; + +const WordIndex kBOS = 1; +const WordIndex kEOS = 2; + +}} // namespaces + +#endif // LM_BUILDER_PAYLOAD_H diff --git a/lm/builder/pipeline.cc b/lm/builder/pipeline.cc index 1ca2e26f5..d588beedf 100644 --- a/lm/builder/pipeline.cc +++ b/lm/builder/pipeline.cc @@ -1,14 +1,17 @@ #include "lm/builder/pipeline.hh" #include "lm/builder/adjust_counts.hh" +#include "lm/builder/combine_counts.hh" #include "lm/builder/corpus_count.hh" #include "lm/builder/hash_gamma.hh" #include "lm/builder/initial_probabilities.hh" #include "lm/builder/interpolate.hh" #include "lm/builder/output.hh" -#include "lm/builder/sort.hh" +#include "lm/common/compare.hh" +#include "lm/common/renumber.hh" #include "lm/sizes.hh" +#include "lm/vocab.hh" #include "util/exception.hh" #include "util/file.hh" @@ -21,7 +24,10 @@ namespace lm { namespace builder { +using util::stream::Sorts; + namespace { + void PrintStatistics(const std::vector &counts, const std::vector &counts_pruned, const std::vector &discounts) { std::cerr << "Statistics:\n"; for (size_t i = 0; i < counts.size(); ++i) { @@ -37,9 +43,9 @@ void PrintStatistics(const std::vector &counts, const std::vector::TotalSize(config_.order), config_.minimum_block); } const PipelineConfig &Config() const { return config_; } @@ -52,40 +58,42 @@ class Master { } // This takes the (partially) sorted ngrams and sets up for adjusted counts. - void InitForAdjust(util::stream::Sort &ngrams, WordIndex types) { + void InitForAdjust(util::stream::Sort &ngrams, WordIndex types, std::size_t subtract_for_numbering) { const std::size_t each_order_min = config_.minimum_block * config_.block_count; // We know how many unigrams there are. Don't allocate more than needed to them. const std::size_t min_chains = (config_.order - 1) * each_order_min + - std::min(types * NGram::TotalSize(1), each_order_min); + std::min(types * NGram::TotalSize(1), each_order_min); + // Prevent overflow in subtracting. + const std::size_t total = std::max(config_.TotalMemory(), min_chains + subtract_for_numbering + config_.minimum_block); // Do merge sort with calculated laziness. - const std::size_t merge_using = ngrams.Merge(std::min(config_.TotalMemory() - min_chains, ngrams.DefaultLazy())); + const std::size_t merge_using = ngrams.Merge(std::min(total - min_chains - subtract_for_numbering, ngrams.DefaultLazy())); std::vector count_bounds(1, types); - CreateChains(config_.TotalMemory() - merge_using, count_bounds); + CreateChains(total - merge_using - subtract_for_numbering, count_bounds); ngrams.Output(chains_.back(), merge_using); - - // Setup unigram file. - files_.push_back(util::MakeTemp(config_.TempPrefix())); } // For initial probabilities, but this is generic. void SortAndReadTwice(const std::vector &counts, Sorts &sorts, util::stream::Chains &second, util::stream::ChainConfig second_config) { + bool unigrams_are_sorted = !config_.renumber_vocabulary; // Do merge first before allocating chain memory. - for (std::size_t i = 1; i < config_.order; ++i) { - sorts[i - 1].Merge(0); + for (std::size_t i = 0; i < config_.order - unigrams_are_sorted; ++i) { + sorts[i].Merge(0); } // There's no lazy merge, so just divide memory amongst the chains. CreateChains(config_.TotalMemory(), counts); chains_.back().ActivateProgress(); - chains_[0] >> files_[0].Source(); - second_config.entry_size = NGram::TotalSize(1); - second.push_back(second_config); - second.back() >> files_[0].Source(); - for (std::size_t i = 1; i < config_.order; ++i) { - util::scoped_fd fd(sorts[i - 1].StealCompleted()); + if (unigrams_are_sorted) { + chains_[0] >> unigrams_.Source(); + second_config.entry_size = NGram::TotalSize(1); + second.push_back(second_config); + second.back() >> unigrams_.Source(); + } + for (std::size_t i = unigrams_are_sorted; i < config_.order; ++i) { + util::scoped_fd fd(sorts[i - unigrams_are_sorted].StealCompleted()); chains_[i].SetProgressTarget(util::SizeOrThrow(fd.get())); chains_[i] >> util::stream::PRead(util::DupOrThrow(fd.get()), true); - second_config.entry_size = NGram::TotalSize(i + 1); + second_config.entry_size = NGram::TotalSize(i + 1); second.push_back(second_config); second.back() >> util::stream::PRead(fd.release(), true); } @@ -96,7 +104,7 @@ class Master { // Determine the minimum we can use for all the chains. std::size_t min_chains = 0; for (std::size_t i = 0; i < config_.order; ++i) { - min_chains += std::min(counts[i] * NGram::TotalSize(i + 1), static_cast(config_.minimum_block)); + min_chains += std::min(counts[i] * NGram::TotalSize(i + 1), static_cast(config_.minimum_block)); } std::size_t for_merge = min_chains > config_.TotalMemory() ? 0 : (config_.TotalMemory() - min_chains); std::vector laziness; @@ -110,36 +118,24 @@ class Master { CreateChains(for_merge + min_chains, counts); chains_.back().ActivateProgress(); - chains_[0] >> files_[0].Source(); + chains_[0] >> unigrams_.Source(); for (std::size_t i = 1; i < config_.order; ++i) { sorts[i - 1].Output(chains_[i], laziness[i - 1]); } } - void BufferFinal(const std::vector &counts) { - chains_[0] >> files_[0].Sink(); - for (std::size_t i = 1; i < config_.order; ++i) { - files_.push_back(util::MakeTemp(config_.TempPrefix())); - chains_[i] >> files_[i].Sink(); - } - chains_.Wait(true); - // Use less memory. Because we can. - CreateChains(std::min(config_.sort.buffer_size * config_.order, config_.TotalMemory()), counts); - for (std::size_t i = 0; i < config_.order; ++i) { - chains_[i] >> files_[i].Source(); - } - } - - template void SetupSorts(Sorts &sorts) { - sorts.Init(config_.order - 1); + template void SetupSorts(Sorts &sorts, bool exclude_unigrams) { + sorts.Init(config_.order - exclude_unigrams); // Unigrams don't get sorted because their order is always the same. - chains_[0] >> files_[0].Sink(); - for (std::size_t i = 1; i < config_.order; ++i) { + if (exclude_unigrams) chains_[0] >> unigrams_.Sink(); + for (std::size_t i = exclude_unigrams; i < config_.order; ++i) { sorts.push_back(chains_[i], config_.sort, Compare(i + 1)); } chains_.Wait(true); } + unsigned int Steps() const { return steps_; } + private: // Create chains, allocating memory to them. Totally heuristic. Count // bounds are upper bounds on the counts or not present. @@ -150,7 +146,7 @@ class Master { for (std::size_t i = 0; i < count_bounds.size(); ++i) { assignments.push_back(static_cast(std::min( static_cast(remaining_mem), - count_bounds[i] * static_cast(NGram::TotalSize(i + 1))))); + count_bounds[i] * static_cast(NGram::TotalSize(i + 1))))); } assignments.resize(config_.order, remaining_mem); @@ -160,7 +156,7 @@ class Master { // Indices of orders that have yet to be assigned. std::vector unassigned; for (std::size_t i = 0; i < config_.order; ++i) { - portions.push_back(static_cast((i+1) * NGram::TotalSize(i+1))); + portions.push_back(static_cast((i+1) * NGram::TotalSize(i+1))); unassigned.push_back(i); } /*If somebody doesn't eat their full dinner, give it to the rest of the @@ -196,7 +192,7 @@ class Master { std::cerr << "Chain sizes:"; for (std::size_t i = 0; i < config_.order; ++i) { std::cerr << ' ' << (i+1) << ":" << assignments[i]; - chains_.push_back(util::stream::ChainConfig(NGram::TotalSize(i + 1), block_count[i], assignments[i])); + chains_.push_back(util::stream::ChainConfig(NGram::TotalSize(i + 1), block_count[i], assignments[i])); } std::cerr << std::endl; } @@ -204,13 +200,15 @@ class Master { PipelineConfig &config_; util::stream::Chains chains_; - // Often only unigrams, but sometimes all orders. - util::FixedArray files_; + + util::stream::FileBuffer unigrams_; + + const unsigned int steps_; }; -void CountText(int text_file /* input */, int vocab_file /* output */, Master &master, uint64_t &token_count, std::string &text_file_name, std::vector &prune_words) { +util::stream::Sort *CountText(int text_file /* input */, int vocab_file /* output */, Master &master, uint64_t &token_count, WordIndex &type_count, std::string &text_file_name, std::vector &prune_words) { const PipelineConfig &config = master.Config(); - std::cerr << "=== 1/5 Counting and sorting n-grams ===" << std::endl; + std::cerr << "=== 1/" << master.Steps() << " Counting and sorting n-grams ===" << std::endl; const std::size_t vocab_usage = CorpusCount::VocabUsage(config.vocab_estimate); UTIL_THROW_IF(config.TotalMemory() < vocab_usage, util::Exception, "Vocab hash size estimate " << vocab_usage << " exceeds total memory " << config.TotalMemory()); @@ -221,37 +219,34 @@ void CountText(int text_file /* input */, int vocab_file /* output */, Master &m (static_cast(config.block_count) + CorpusCount::DedupeMultiplier(config.order)) * // Chain likes memory expressed in terms of total memory. static_cast(config.block_count); - util::stream::Chain chain(util::stream::ChainConfig(NGram::TotalSize(config.order), config.block_count, memory_for_chain)); + util::stream::Chain chain(util::stream::ChainConfig(NGram::TotalSize(config.order), config.block_count, memory_for_chain)); - WordIndex type_count = config.vocab_estimate; + type_count = config.vocab_estimate; util::FilePiece text(text_file, NULL, &std::cerr); text_file_name = text.FileName(); CorpusCount counter(text, vocab_file, token_count, type_count, prune_words, config.prune_vocab_file, chain.BlockSize() / chain.EntrySize(), config.disallowed_symbol_action); chain >> boost::ref(counter); - util::stream::Sort sorter(chain, config.sort, SuffixOrder(config.order), AddCombiner()); + util::scoped_ptr > sorter(new util::stream::Sort(chain, config.sort, SuffixOrder(config.order), CombineCounts())); chain.Wait(true); - std::cerr << "Unigram tokens " << token_count << " types " << type_count << std::endl; - std::cerr << "=== 2/5 Calculating and sorting adjusted counts ===" << std::endl; - master.InitForAdjust(sorter, type_count); + return sorter.release(); } -void InitialProbabilities(const std::vector &counts, const std::vector &counts_pruned, const std::vector &discounts, Master &master, Sorts &primary, - util::FixedArray &gammas, const std::vector &prune_thresholds, bool prune_vocab) { +void InitialProbabilities(const std::vector &counts, const std::vector &counts_pruned, const std::vector &discounts, Master &master, Sorts &primary, util::FixedArray &gammas, const std::vector &prune_thresholds, bool prune_vocab, const SpecialVocab &specials) { const PipelineConfig &config = master.Config(); util::stream::Chains second(config.order); { Sorts sorts; - master.SetupSorts(sorts); + master.SetupSorts(sorts, !config.renumber_vocabulary); PrintStatistics(counts, counts_pruned, discounts); lm::ngram::ShowSizes(counts_pruned); - std::cerr << "=== 3/5 Calculating and sorting initial probabilities ===" << std::endl; + std::cerr << "=== 3/" << master.Steps() << " Calculating and sorting initial probabilities ===" << std::endl; master.SortAndReadTwice(counts_pruned, sorts, second, config.initial_probs.adder_in); } util::stream::Chains gamma_chains(config.order); - InitialProbabilities(config.initial_probs, discounts, master.MutableChains(), second, gamma_chains, prune_thresholds, prune_vocab); + InitialProbabilities(config.initial_probs, discounts, master.MutableChains(), second, gamma_chains, prune_thresholds, prune_vocab, specials); // Don't care about gamma for 0. gamma_chains[0] >> util::stream::kRecycle; gammas.Init(config.order - 1); @@ -260,11 +255,11 @@ void InitialProbabilities(const std::vector &counts, const std::vector gamma_chains[i] >> gammas[i - 1].Sink(); } // Has to be done here due to gamma_chains scope. - master.SetupSorts(primary); + master.SetupSorts(primary, true); } -void InterpolateProbabilities(const std::vector &counts, Master &master, Sorts &primary, util::FixedArray &gammas) { - std::cerr << "=== 4/5 Calculating and writing order-interpolated probabilities ===" << std::endl; +void InterpolateProbabilities(const std::vector &counts, Master &master, Sorts &primary, util::FixedArray &gammas, Output &output, const SpecialVocab &specials) { + std::cerr << "=== 4/" << master.Steps() << " Calculating and writing order-interpolated probabilities ===" << std::endl; const PipelineConfig &config = master.Config(); master.MaximumLazyInput(counts, primary); @@ -278,13 +273,62 @@ void InterpolateProbabilities(const std::vector &counts, Master &maste read_backoffs.entry_size = sizeof(float); gamma_chains.push_back(read_backoffs); - gamma_chains.back() >> gammas[i].Source(); + gamma_chains.back() >> gammas[i].Source(true); } - master >> Interpolate(std::max(master.Config().vocab_size_for_unk, counts[0] - 1 /* is not included */), util::stream::ChainPositions(gamma_chains), config.prune_thresholds, config.prune_vocab, config.output_q); + master >> Interpolate(std::max(master.Config().vocab_size_for_unk, counts[0] - 1 /* is not included */), util::stream::ChainPositions(gamma_chains), config.prune_thresholds, config.prune_vocab, config.output_q, specials); gamma_chains >> util::stream::kRecycle; - master.BufferFinal(counts); + output.SinkProbs(master.MutableChains(), config.output_q); } +class VocabNumbering { + public: + VocabNumbering(StringPiece vocab_file, StringPiece temp_prefix, bool renumber) + : vocab_file_(vocab_file.data(), vocab_file.size()), + temp_prefix_(temp_prefix.data(), temp_prefix.size()), + renumber_(renumber), + specials_(kBOS, kEOS) { + InitFile(renumber || vocab_file.empty()); + } + + int File() const { return null_delimited_.get(); } + + // Compute the vocabulary mapping and return the memory used. + std::size_t ComputeMapping(WordIndex type_count) { + if (!renumber_) return 0; + util::scoped_fd previous(null_delimited_.release()); + InitFile(vocab_file_.empty()); + ngram::SortedVocabulary::ComputeRenumbering(type_count, previous.get(), null_delimited_.get(), vocab_mapping_); + return sizeof(WordIndex) * vocab_mapping_.size(); + } + + void ApplyRenumber(util::stream::Chains &chains) { + if (!renumber_) return; + for (std::size_t i = 0; i < chains.size(); ++i) { + chains[i] >> Renumber(&*vocab_mapping_.begin(), i + 1); + } + specials_ = SpecialVocab(vocab_mapping_[specials_.BOS()], vocab_mapping_[specials_.EOS()]); + } + + const SpecialVocab &Specials() const { return specials_; } + + private: + void InitFile(bool temp) { + null_delimited_.reset(temp ? + util::MakeTemp(temp_prefix_) : + util::CreateOrThrow(vocab_file_.c_str())); + } + + std::string vocab_file_, temp_prefix_; + + util::scoped_fd null_delimited_; + + bool renumber_; + + std::vector vocab_mapping_; + + SpecialVocab specials_; +}; + } // namespace void Pipeline(PipelineConfig &config, int text_file, Output &output) { @@ -293,48 +337,49 @@ void Pipeline(PipelineConfig &config, int text_file, Output &output) { config.sort.buffer_size = config.TotalMemory() / 4; std::cerr << "Warning: changing sort block size to " << config.sort.buffer_size << " bytes due to low total memory." << std::endl; } - if (config.minimum_block < NGram::TotalSize(config.order)) { - config.minimum_block = NGram::TotalSize(config.order); + if (config.minimum_block < NGram::TotalSize(config.order)) { + config.minimum_block = NGram::TotalSize(config.order); std::cerr << "Warning: raising minimum block to " << config.minimum_block << " to fit an ngram in every block." << std::endl; } UTIL_THROW_IF(config.sort.buffer_size < config.minimum_block, util::Exception, "Sort block size " << config.sort.buffer_size << " is below the minimum block size " << config.minimum_block << "."); UTIL_THROW_IF(config.TotalMemory() < config.minimum_block * config.order * config.block_count, util::Exception, "Not enough memory to fit " << (config.order * config.block_count) << " blocks with minimum size " << config.minimum_block << ". Increase memory to " << (config.minimum_block * config.order * config.block_count) << " bytes or decrease the minimum block size."); - UTIL_TIMER("(%w s) Total wall time elapsed\n"); - - Master master(config); + Master master(config, output.Steps()); // master's destructor will wait for chains. But they might be deadlocked if // this thread dies because e.g. it ran out of memory. try { - util::scoped_fd vocab_file(config.vocab_file.empty() ? - util::MakeTemp(config.TempPrefix()) : - util::CreateOrThrow(config.vocab_file.c_str())); - output.SetVocabFD(vocab_file.get()); + VocabNumbering numbering(config.vocab_file, config.TempPrefix(), config.renumber_vocabulary); uint64_t token_count; + WordIndex type_count; std::string text_file_name; - std::vector prune_words; - CountText(text_file, vocab_file.get(), master, token_count, text_file_name, prune_words); + util::scoped_ptr > sorted_counts( + CountText(text_file, numbering.File(), master, token_count, type_count, text_file_name, prune_words)); + std::cerr << "Unigram tokens " << token_count << " types " << type_count << std::endl; + + // Create vocab mapping, which uses temporary memory, while nothing else is happening. + std::size_t subtract_for_numbering = numbering.ComputeMapping(type_count); + output.SetVocabFD(numbering.File()); + + std::cerr << "=== 2/" << master.Steps() << " Calculating and sorting adjusted counts ===" << std::endl; + master.InitForAdjust(*sorted_counts, type_count, subtract_for_numbering); + sorted_counts.reset(); std::vector counts; std::vector counts_pruned; std::vector discounts; master >> AdjustCounts(config.prune_thresholds, counts, counts_pruned, prune_words, config.discount, discounts); + numbering.ApplyRenumber(master.MutableChains()); { util::FixedArray gammas; Sorts primary; - InitialProbabilities(counts, counts_pruned, discounts, master, primary, gammas, config.prune_thresholds, config.prune_vocab); - InterpolateProbabilities(counts_pruned, master, primary, gammas); + InitialProbabilities(counts, counts_pruned, discounts, master, primary, gammas, config.prune_thresholds, config.prune_vocab, numbering.Specials()); + output.SetHeader(HeaderInfo(text_file_name, token_count, counts_pruned)); + // Also does output. + InterpolateProbabilities(counts_pruned, master, primary, gammas, output, numbering.Specials()); } - - std::cerr << "=== 5/5 Writing ARPA model ===" << std::endl; - - output.SetHeader(HeaderInfo(text_file_name, token_count, counts_pruned)); - output.Apply(PROB_SEQUENTIAL_HOOK, master.MutableChains()); - master >> util::stream::kRecycle; - master.MutableChains().Wait(true); } catch (const util::Exception &e) { std::cerr << e.what() << std::endl; abort(); diff --git a/lm/builder/pipeline.hh b/lm/builder/pipeline.hh index 1987daff1..695ecf7bd 100644 --- a/lm/builder/pipeline.hh +++ b/lm/builder/pipeline.hh @@ -39,6 +39,9 @@ struct PipelineConfig { bool prune_vocab; std::string prune_vocab_file; + /* Renumber the vocabulary the way the trie likes it? */ + bool renumber_vocabulary; + // What to do with discount failures. DiscountConfig discount; diff --git a/lm/builder/print.cc b/lm/builder/print.cc index 56a3134d8..178e54a21 100644 --- a/lm/builder/print.cc +++ b/lm/builder/print.cc @@ -23,30 +23,29 @@ VocabReconstitute::VocabReconstitute(int fd) { map_.push_back(i); } +void PrintARPA::Sink(util::stream::Chains &chains) { + chains >> boost::ref(*this); +} + void PrintARPA::Run(const util::stream::ChainPositions &positions) { VocabReconstitute vocab(GetVocabFD()); - - // Write header. TODO: integers in FakeOFStream. - { - std::stringstream stream; - if (verbose_header_) { - stream << "# Input file: " << GetHeader().input_file << '\n'; - stream << "# Token count: " << GetHeader().token_count << '\n'; - stream << "# Smoothing: Modified Kneser-Ney" << '\n'; - } - stream << "\\data\\\n"; - for (size_t i = 0; i < positions.size(); ++i) { - stream << "ngram " << (i+1) << '=' << GetHeader().counts_pruned[i] << '\n'; - } - stream << '\n'; - std::string as_string(stream.str()); - util::WriteOrThrow(out_fd_.get(), as_string.data(), as_string.size()); - } - util::FakeOFStream out(out_fd_.get()); + + // Write header. + if (verbose_header_) { + out << "# Input file: " << GetHeader().input_file << '\n'; + out << "# Token count: " << GetHeader().token_count << '\n'; + out << "# Smoothing: Modified Kneser-Ney" << '\n'; + } + out << "\\data\\\n"; + for (size_t i = 0; i < positions.size(); ++i) { + out << "ngram " << (i+1) << '=' << GetHeader().counts_pruned[i] << '\n'; + } + out << '\n'; + for (unsigned order = 1; order <= positions.size(); ++order) { out << "\\" << order << "-grams:" << '\n'; - for (NGramStream stream(positions[order - 1]); stream; ++stream) { + for (NGramStream stream(positions[order - 1]); stream; ++stream) { // Correcting for numerical precision issues. Take that IRST. out << stream->Value().complete.prob << '\t' << vocab.Lookup(*stream->begin()); for (const WordIndex *i = stream->begin() + 1; i != stream->end(); ++i) { diff --git a/lm/builder/print.hh b/lm/builder/print.hh index 093a35697..5f293de85 100644 --- a/lm/builder/print.hh +++ b/lm/builder/print.hh @@ -1,14 +1,17 @@ #ifndef LM_BUILDER_PRINT_H #define LM_BUILDER_PRINT_H -#include "lm/builder/ngram.hh" -#include "lm/builder/ngram_stream.hh" +#include "lm/common/ngram_stream.hh" #include "lm/builder/output.hh" +#include "lm/builder/payload.hh" +#include "lm/common/ngram.hh" #include "util/fake_ofstream.hh" #include "util/file.hh" #include "util/mmap.hh" #include "util/string_piece.hh" +#include + #include #include @@ -43,15 +46,15 @@ class VocabReconstitute { }; // Not defined, only specialized. -template void PrintPayload(util::FakeOFStream &to, const Payload &payload); -template <> inline void PrintPayload(util::FakeOFStream &to, const Payload &payload) { +template void PrintPayload(util::FakeOFStream &to, const BuildingPayload &payload); +template <> inline void PrintPayload(util::FakeOFStream &to, const BuildingPayload &payload) { // TODO slow - to << boost::lexical_cast(payload.count); + to << payload.count; } -template <> inline void PrintPayload(util::FakeOFStream &to, const Payload &payload) { +template <> inline void PrintPayload(util::FakeOFStream &to, const BuildingPayload &payload) { to << log10(payload.uninterp.prob) << ' ' << log10(payload.uninterp.gamma); } -template <> inline void PrintPayload(util::FakeOFStream &to, const Payload &payload) { +template <> inline void PrintPayload(util::FakeOFStream &to, const BuildingPayload &payload) { to << payload.complete.prob << ' ' << payload.complete.backoff; } @@ -70,8 +73,8 @@ template class Print { void Run(const util::stream::ChainPositions &chains) { util::scoped_fd fd(to_); util::FakeOFStream out(to_); - NGramStreams streams(chains); - for (NGramStream *s = streams.begin(); s != streams.end(); ++s) { + NGramStreams streams(chains); + for (NGramStream *s = streams.begin(); s != streams.end(); ++s) { DumpStream(*s, out); } } @@ -79,12 +82,12 @@ template class Print { void Run(const util::stream::ChainPosition &position) { util::scoped_fd fd(to_); util::FakeOFStream out(to_); - NGramStream stream(position); + NGramStream stream(position); DumpStream(stream, out); } private: - void DumpStream(NGramStream &stream, util::FakeOFStream &to) { + void DumpStream(NGramStream &stream, util::FakeOFStream &to) { for (; stream; ++stream) { PrintPayload(to, stream->Value()); for (const WordIndex *w = stream->begin(); w != stream->end(); ++w) { @@ -103,6 +106,8 @@ class PrintARPA : public OutputHook { explicit PrintARPA(int fd, bool verbose_header) : OutputHook(PROB_SEQUENTIAL_HOOK), out_fd_(fd), verbose_header_(verbose_header) {} + void Sink(util::stream::Chains &chains); + void Run(const util::stream::ChainPositions &positions); private: diff --git a/lm/builder/special.hh b/lm/builder/special.hh new file mode 100644 index 000000000..c70865ce1 --- /dev/null +++ b/lm/builder/special.hh @@ -0,0 +1,27 @@ +#ifndef LM_BUILDER_SPECIAL_H +#define LM_BUILDER_SPECIAL_H + +#include "lm/word_index.hh" + +namespace lm { namespace builder { + +class SpecialVocab { + public: + SpecialVocab(WordIndex bos, WordIndex eos) : bos_(bos), eos_(eos) {} + + bool IsSpecial(WordIndex word) const { + return word == kUNK || word == bos_ || word == eos_; + } + + WordIndex UNK() const { return kUNK; } + WordIndex BOS() const { return bos_; } + WordIndex EOS() const { return eos_; } + + private: + WordIndex bos_; + WordIndex eos_; +}; + +}} // namespaces + +#endif // LM_BUILDER_SPECIAL_H diff --git a/lm/common/Jamfile b/lm/common/Jamfile new file mode 100644 index 000000000..1c9c37210 --- /dev/null +++ b/lm/common/Jamfile @@ -0,0 +1,2 @@ +fakelib common : [ glob *.cc : *test.cc *main.cc ] + ../../util//kenutil ../../util/stream//stream ../../util/double-conversion//double-conversion ..//kenlm ; diff --git a/lm/builder/sort.hh b/lm/common/compare.hh similarity index 64% rename from lm/builder/sort.hh rename to lm/common/compare.hh index ed20b4b79..1c7cd2499 100644 --- a/lm/builder/sort.hh +++ b/lm/common/compare.hh @@ -1,18 +1,12 @@ -#ifndef LM_BUILDER_SORT_H -#define LM_BUILDER_SORT_H +#ifndef LM_COMMON_COMPARE_H +#define LM_COMMON_COMPARE_H -#include "lm/builder/ngram_stream.hh" -#include "lm/builder/ngram.hh" #include "lm/word_index.hh" -#include "util/stream/sort.hh" - -#include "util/stream/timer.hh" #include #include namespace lm { -namespace builder { /** * Abstract parent class for defining custom n-gram comparators. @@ -175,70 +169,6 @@ class PrefixOrder : public Comparator { static const unsigned kMatchOffset = 0; }; -// Sum counts for the same n-gram. -struct AddCombiner { - bool operator()(void *first_void, const void *second_void, const SuffixOrder &compare) const { - NGram first(first_void, compare.Order()); - // There isn't a const version of NGram. - NGram second(const_cast(second_void), compare.Order()); - if (memcmp(first.begin(), second.begin(), sizeof(WordIndex) * compare.Order())) return false; - first.Count() += second.Count(); - return true; - } -}; - -// The combiner is only used on a single chain, so I didn't bother to allow -// that template. -/** - * Represents an @ref util::FixedArray "array" capable of storing @ref util::stream::Sort "Sort" objects. - * - * In the anticipated use case, an instance of this class will maintain one @ref util::stream::Sort "Sort" object - * for each n-gram order (ranging from 1 up to the maximum n-gram order being processed). - * Use in this manner would enable the n-grams each n-gram order to be sorted, in parallel. - * - * @tparam Compare An @ref Comparator "ngram comparator" to use during sorting. - */ -template class Sorts : public util::FixedArray > { - private: - typedef util::stream::Sort S; - typedef util::FixedArray P; - - public: - - /** - * Constructs, but does not initialize. - * - * @ref util::FixedArray::Init() "Init" must be called before use. - * - * @see util::FixedArray::Init() - */ - Sorts() {} - - /** - * Constructs an @ref util::FixedArray "array" capable of storing a fixed number of @ref util::stream::Sort "Sort" objects. - * - * @param number The maximum number of @ref util::stream::Sort "sorters" that can be held by this @ref util::FixedArray "array" - * @see util::FixedArray::FixedArray() - */ - explicit Sorts(std::size_t number) : util::FixedArray >(number) {} - - /** - * Constructs a new @ref util::stream::Sort "Sort" object which is stored in this @ref util::FixedArray "array". - * - * The new @ref util::stream::Sort "Sort" object is constructed using the provided @ref util::stream::SortConfig "SortConfig" and @ref Comparator "ngram comparator"; - * once constructed, a new worker @ref util::stream::Thread "thread" (owned by the @ref util::stream::Chain "chain") will sort the n-gram data stored - * in the @ref util::stream::Block "blocks" of the provided @ref util::stream::Chain "chain". - * - * @see util::stream::Sort::Sort() - * @see util::stream::Chain::operator>>() - */ - void push_back(util::stream::Chain &chain, const util::stream::SortConfig &config, const Compare &compare) { - new (P::end()) S(chain, config, compare); // use "placement new" syntax to initalize S in an already-allocated memory location - P::Constructed(); - } -}; - -} // namespace builder } // namespace lm -#endif // LM_BUILDER_SORT_H +#endif // LM_COMMON_COMPARE_H diff --git a/lm/common/model_buffer.cc b/lm/common/model_buffer.cc new file mode 100644 index 000000000..d4635da51 --- /dev/null +++ b/lm/common/model_buffer.cc @@ -0,0 +1,82 @@ +#include "lm/common/model_buffer.hh" +#include "util/exception.hh" +#include "util/fake_ofstream.hh" +#include "util/file.hh" +#include "util/file_piece.hh" +#include "util/stream/io.hh" +#include "util/stream/multi_stream.hh" + +#include + +namespace lm { namespace common { + +namespace { +const char kMetadataHeader[] = "KenLM intermediate binary file"; +} // namespace + +ModelBuffer::ModelBuffer(const std::string &file_base, bool keep_buffer, bool output_q) + : file_base_(file_base), keep_buffer_(keep_buffer), output_q_(output_q) {} + +ModelBuffer::ModelBuffer(const std::string &file_base) + : file_base_(file_base), keep_buffer_(false) { + const std::string full_name = file_base_ + ".kenlm_intermediate"; + util::FilePiece in(full_name.c_str()); + StringPiece token = in.ReadLine(); + UTIL_THROW_IF2(token != kMetadataHeader, "File " << full_name << " begins with \"" << token << "\" not " << kMetadataHeader); + + token = in.ReadDelimited(); + UTIL_THROW_IF2(token != "Order", "Expected Order, got \"" << token << "\" in " << full_name); + unsigned long order = in.ReadULong(); + + token = in.ReadDelimited(); + UTIL_THROW_IF2(token != "Payload", "Expected Payload, got \"" << token << "\" in " << full_name); + token = in.ReadDelimited(); + if (token == "q") { + output_q_ = true; + } else if (token == "pb") { + output_q_ = false; + } else { + UTIL_THROW(util::Exception, "Unknown payload " << token); + } + + files_.Init(order); + for (unsigned long i = 0; i < order; ++i) { + files_.push_back(util::OpenReadOrThrow((file_base_ + '.' + boost::lexical_cast(i + 1)).c_str())); + } +} + +// virtual destructor +ModelBuffer::~ModelBuffer() {} + +void ModelBuffer::Sink(util::stream::Chains &chains) { + // Open files. + files_.Init(chains.size()); + for (std::size_t i = 0; i < chains.size(); ++i) { + if (keep_buffer_) { + files_.push_back(util::CreateOrThrow( + (file_base_ + '.' + boost::lexical_cast(i + 1)).c_str() + )); + } else { + files_.push_back(util::MakeTemp(file_base_)); + } + chains[i] >> util::stream::Write(files_.back().get()); + } + if (keep_buffer_) { + util::scoped_fd metadata(util::CreateOrThrow((file_base_ + ".kenlm_intermediate").c_str())); + util::FakeOFStream meta(metadata.get(), 200); + meta << kMetadataHeader << "\nOrder " << chains.size() << "\nPayload " << (output_q_ ? "q" : "pb") << '\n'; + } +} + +void ModelBuffer::Source(util::stream::Chains &chains) { + assert(chains.size() == files_.size()); + for (unsigned int i = 0; i < files_.size(); ++i) { + chains[i] >> util::stream::PRead(files_[i].get()); + } +} + +std::size_t ModelBuffer::Order() const { + return files_.size(); +} + +}} // namespaces diff --git a/lm/common/model_buffer.hh b/lm/common/model_buffer.hh new file mode 100644 index 000000000..6a5c7bf49 --- /dev/null +++ b/lm/common/model_buffer.hh @@ -0,0 +1,45 @@ +#ifndef LM_BUILDER_MODEL_BUFFER_H +#define LM_BUILDER_MODEL_BUFFER_H + +/* Format with separate files in suffix order. Each file contains + * n-grams of the same order. + */ + +#include "util/file.hh" +#include "util/fixed_array.hh" + +#include + +namespace util { namespace stream { class Chains; } } + +namespace lm { namespace common { + +class ModelBuffer { + public: + // Construct for writing. + ModelBuffer(const std::string &file_base, bool keep_buffer, bool output_q); + + // Load from file. + explicit ModelBuffer(const std::string &file_base); + + // explicit for virtual destructor. + ~ModelBuffer(); + + void Sink(util::stream::Chains &chains); + + void Source(util::stream::Chains &chains); + + // The order of the n-gram model that is associated with the model buffer. + std::size_t Order() const; + + private: + const std::string file_base_; + const bool keep_buffer_; + bool output_q_; + + util::FixedArray files_; +}; + +}} // namespaces + +#endif // LM_BUILDER_MODEL_BUFFER_H diff --git a/lm/builder/ngram.hh b/lm/common/ngram.hh similarity index 52% rename from lm/builder/ngram.hh rename to lm/common/ngram.hh index d0033206c..813017640 100644 --- a/lm/builder/ngram.hh +++ b/lm/common/ngram.hh @@ -1,5 +1,5 @@ -#ifndef LM_BUILDER_NGRAM_H -#define LM_BUILDER_NGRAM_H +#ifndef LM_COMMON_NGRAM_H +#define LM_COMMON_NGRAM_H #include "lm/weights.hh" #include "lm/word_index.hh" @@ -10,22 +10,10 @@ #include namespace lm { -namespace builder { -struct Uninterpolated { - float prob; // Uninterpolated probability. - float gamma; // Interpolation weight for lower order. -}; - -union Payload { - uint64_t count; - Uninterpolated uninterp; - ProbBackoff complete; -}; - -class NGram { +class NGramHeader { public: - NGram(void *begin, std::size_t order) + NGramHeader(void *begin, std::size_t order) : begin_(static_cast(begin)), end_(begin_ + order) {} const uint8_t *Base() const { return reinterpret_cast(begin_); } @@ -37,25 +25,30 @@ class NGram { end_ = begin_ + difference; } - // Would do operator++ but that can get confusing for a stream. - void NextInMemory() { - ReBase(&Value() + 1); - } - + // These are for the vocab index. // Lower-case in deference to STL. const WordIndex *begin() const { return begin_; } WordIndex *begin() { return begin_; } const WordIndex *end() const { return end_; } WordIndex *end() { return end_; } - const Payload &Value() const { return *reinterpret_cast(end_); } - Payload &Value() { return *reinterpret_cast(end_); } - - uint64_t &Count() { return Value().count; } - uint64_t Count() const { return Value().count; } - std::size_t Order() const { return end_ - begin_; } + private: + WordIndex *begin_, *end_; +}; + +template class NGram : public NGramHeader { + public: + typedef PayloadT Payload; + + NGram(void *begin, std::size_t order) : NGramHeader(begin, order) {} + + // Would do operator++ but that can get confusing for a stream. + void NextInMemory() { + ReBase(&Value() + 1); + } + static std::size_t TotalSize(std::size_t order) { return order * sizeof(WordIndex) + sizeof(Payload); } @@ -63,46 +56,17 @@ class NGram { // Compiler should optimize this. return TotalSize(Order()); } + static std::size_t OrderFromSize(std::size_t size) { std::size_t ret = (size - sizeof(Payload)) / sizeof(WordIndex); assert(size == TotalSize(ret)); return ret; } - // manipulate msb to signal that ngram can be pruned - /*mjd**********************************************************************/ - - bool IsMarked() const { - return Value().count >> (sizeof(Value().count) * 8 - 1); - } - - void Mark() { - Value().count |= (1ul << (sizeof(Value().count) * 8 - 1)); - } - - void Unmark() { - Value().count &= ~(1ul << (sizeof(Value().count) * 8 - 1)); - } - - uint64_t UnmarkedCount() const { - return Value().count & ~(1ul << (sizeof(Value().count) * 8 - 1)); - } - - uint64_t CutoffCount() const { - return IsMarked() ? 0 : UnmarkedCount(); - } - - /*mjd**********************************************************************/ - - private: - WordIndex *begin_, *end_; + const Payload &Value() const { return *reinterpret_cast(end()); } + Payload &Value() { return *reinterpret_cast(end()); } }; -const WordIndex kUNK = 0; -const WordIndex kBOS = 1; -const WordIndex kEOS = 2; - -} // namespace builder } // namespace lm -#endif // LM_BUILDER_NGRAM_H +#endif // LM_COMMON_NGRAM_H diff --git a/lm/builder/ngram_stream.hh b/lm/common/ngram_stream.hh similarity index 50% rename from lm/builder/ngram_stream.hh rename to lm/common/ngram_stream.hh index ab42734c4..53c4ffcb8 100644 --- a/lm/builder/ngram_stream.hh +++ b/lm/common/ngram_stream.hh @@ -1,16 +1,16 @@ #ifndef LM_BUILDER_NGRAM_STREAM_H #define LM_BUILDER_NGRAM_STREAM_H -#include "lm/builder/ngram.hh" +#include "lm/common/ngram.hh" #include "util/stream/chain.hh" #include "util/stream/multi_stream.hh" #include "util/stream/stream.hh" #include -namespace lm { namespace builder { +namespace lm { -class NGramStream { +template class NGramStream { public: NGramStream() : gram_(NULL, 0) {} @@ -20,14 +20,14 @@ class NGramStream { void Init(const util::stream::ChainPosition &position) { stream_.Init(position); - gram_ = NGram(stream_.Get(), NGram::OrderFromSize(position.GetChain().EntrySize())); + gram_ = NGram(stream_.Get(), NGram::OrderFromSize(position.GetChain().EntrySize())); } - NGram &operator*() { return gram_; } - const NGram &operator*() const { return gram_; } + NGram &operator*() { return gram_; } + const NGram &operator*() const { return gram_; } - NGram *operator->() { return &gram_; } - const NGram *operator->() const { return &gram_; } + NGram *operator->() { return &gram_; } + const NGram *operator->() const { return &gram_; } void *Get() { return stream_.Get(); } const void *Get() const { return stream_.Get(); } @@ -43,16 +43,22 @@ class NGramStream { } private: - NGram gram_; + NGram gram_; util::stream::Stream stream_; }; -inline util::stream::Chain &operator>>(util::stream::Chain &chain, NGramStream &str) { +template inline util::stream::Chain &operator>>(util::stream::Chain &chain, NGramStream &str) { str.Init(chain.Add()); return chain; } -typedef util::stream::GenericStreams NGramStreams; +template class NGramStreams : public util::stream::GenericStreams > { + private: + typedef util::stream::GenericStreams > P; + public: + NGramStreams() : P() {} + NGramStreams(const util::stream::ChainPositions &positions) : P(positions) {} +}; -}} // namespaces +} // namespace #endif // LM_BUILDER_NGRAM_STREAM_H diff --git a/lm/common/renumber.cc b/lm/common/renumber.cc new file mode 100644 index 000000000..0632a149b --- /dev/null +++ b/lm/common/renumber.cc @@ -0,0 +1,17 @@ +#include "lm/common/renumber.hh" +#include "lm/common/ngram.hh" + +#include "util/stream/stream.hh" + +namespace lm { + +void Renumber::Run(const util::stream::ChainPosition &position) { + for (util::stream::Stream stream(position); stream; ++stream) { + NGramHeader gram(stream.Get(), order_); + for (WordIndex *w = gram.begin(); w != gram.end(); ++w) { + *w = new_numbers_[*w]; + } + } +} + +} // namespace lm diff --git a/lm/common/renumber.hh b/lm/common/renumber.hh new file mode 100644 index 000000000..ca25c4dc6 --- /dev/null +++ b/lm/common/renumber.hh @@ -0,0 +1,30 @@ +/* Map vocab ids. This is useful to merge independently collected counts or + * change the vocab ids to the order used by the trie. + */ +#ifndef LM_COMMON_RENUMBER_H +#define LM_COMMON_RENUMBER_H + +#include "lm/word_index.hh" + +#include + +namespace util { namespace stream { class ChainPosition; }} + +namespace lm { + +class Renumber { + public: + // Assumes the array is large enough to map all words and stays alive while + // the thread is active. + Renumber(const WordIndex *new_numbers, std::size_t order) + : new_numbers_(new_numbers), order_(order) {} + + void Run(const util::stream::ChainPosition &position); + + private: + const WordIndex *new_numbers_; + std::size_t order_; +}; + +} // namespace lm +#endif // LM_COMMON_RENUMBER_H diff --git a/lm/kenlm_benchmark_main.cc b/lm/kenlm_benchmark_main.cc new file mode 100644 index 000000000..d8b659139 --- /dev/null +++ b/lm/kenlm_benchmark_main.cc @@ -0,0 +1,128 @@ +#include "lm/model.hh" +#include "util/fake_ofstream.hh" +#include "util/file.hh" +#include "util/file_piece.hh" +#include "util/usage.hh" + +#include + +namespace { + +template void ConvertToBytes(const Model &model, int fd_in) { + util::FilePiece in(fd_in); + util::FakeOFStream out(1); + Width width; + StringPiece word; + const Width end_sentence = (Width)model.GetVocabulary().EndSentence(); + while (true) { + while (in.ReadWordSameLine(word)) { + width = (Width)model.GetVocabulary().Index(word); + out.write(&width, sizeof(Width)); + } + if (!in.ReadLineOrEOF(word)) break; + out.write(&end_sentence, sizeof(Width)); + } +} + +template void QueryFromBytes(const Model &model, int fd_in) { + lm::ngram::State state[3]; + const lm::ngram::State *const begin_state = &model.BeginSentenceState(); + const lm::ngram::State *next_state = begin_state; + Width kEOS = model.GetVocabulary().EndSentence(); + Width buf[4096]; + float sum = 0.0; + while (true) { + std::size_t got = util::ReadOrEOF(fd_in, buf, sizeof(buf)); + if (!got) break; + UTIL_THROW_IF2(got % sizeof(Width), "File size not a multiple of vocab id size " << sizeof(Width)); + got /= sizeof(Width); + // Do even stuff first. + const Width *even_end = buf + (got & ~1); + // Alternating states + const Width *i; + for (i = buf; i != even_end;) { + sum += model.FullScore(*next_state, *i, state[1]).prob; + next_state = (*i++ == kEOS) ? begin_state : &state[1]; + sum += model.FullScore(*next_state, *i, state[0]).prob; + next_state = (*i++ == kEOS) ? begin_state : &state[0]; + } + // Odd corner case. + if (got & 1) { + sum += model.FullScore(*next_state, *i, state[2]).prob; + next_state = (*i++ == kEOS) ? begin_state : &state[2]; + } + } + std::cout << "Sum is " << sum << std::endl; +} + +template void DispatchFunction(const Model &model, bool query) { + if (query) { + QueryFromBytes(model, 0); + } else { + ConvertToBytes(model, 0); + } +} + +template void DispatchWidth(const char *file, bool query) { + Model model(file); + lm::WordIndex bound = model.GetVocabulary().Bound(); + if (bound <= 256) { + DispatchFunction(model, query); + } else if (bound <= 65536) { + DispatchFunction(model, query); + } else if (bound <= (1ULL << 32)) { + DispatchFunction(model, query); + } else { + DispatchFunction(model, query); + } +} + +void Dispatch(const char *file, bool query) { + using namespace lm::ngram; + lm::ngram::ModelType model_type; + if (lm::ngram::RecognizeBinary(file, model_type)) { + switch(model_type) { + case PROBING: + DispatchWidth(file, query); + break; + case REST_PROBING: + DispatchWidth(file, query); + break; + case TRIE: + DispatchWidth(file, query); + break; + case QUANT_TRIE: + DispatchWidth(file, query); + break; + case ARRAY_TRIE: + DispatchWidth(file, query); + break; + case QUANT_ARRAY_TRIE: + DispatchWidth(file, query); + break; + default: + UTIL_THROW(util::Exception, "Unrecognized kenlm model type " << model_type); + } + } else { + UTIL_THROW(util::Exception, "Binarize before running benchmarks."); + } +} + +} // namespace + +int main(int argc, char *argv[]) { + if (argc != 3 || (strcmp(argv[1], "vocab") && strcmp(argv[1], "query"))) { + std::cerr + << "Benchmark program for KenLM. Intended usage:\n" + << "#Convert text to vocabulary ids offline. These ids are tied to a model.\n" + << argv[0] << " vocab $model <$text >$text.vocab\n" + << "#Ensure files are in RAM.\n" + << "cat $text.vocab $model >/dev/null\n" + << "#Timed query against the model, including loading.\n" + << "time " << argv[0] << " query $model <$text.vocab\n"; + return 1; + } + Dispatch(argv[2], !strcmp(argv[1], "query")); + util::PrintUsage(std::cerr); + return 0; +} diff --git a/lm/ngram_query.hh b/lm/ngram_query.hh index 937fe2421..b19c5aa4f 100644 --- a/lm/ngram_query.hh +++ b/lm/ngram_query.hh @@ -3,45 +3,53 @@ #include "lm/enumerate_vocab.hh" #include "lm/model.hh" +#include "util/fake_ofstream.hh" #include "util/file_piece.hh" #include "util/usage.hh" #include -#include -#include -#include #include #include namespace lm { namespace ngram { -struct BasicPrint { - void Word(StringPiece, WordIndex, const FullScoreReturn &) const {} - void Line(uint64_t oov, float total) const { - std::cout << "Total: " << total << " OOV: " << oov << '\n'; - } - void Summary(double, double, uint64_t, uint64_t) {} +class QueryPrinter { + public: + QueryPrinter(int fd, bool print_word, bool print_line, bool print_summary, bool flush) + : out_(fd), print_word_(print_word), print_line_(print_line), print_summary_(print_summary), flush_(flush) {} + void Word(StringPiece surface, WordIndex vocab, const FullScoreReturn &ret) { + if (!print_word_) return; + out_ << surface << '=' << vocab << ' ' << static_cast(ret.ngram_length) << ' ' << ret.prob << '\t'; + if (flush_) out_.flush(); + } + + void Line(uint64_t oov, float total) { + if (!print_line_) return; + out_ << "Total: " << total << " OOV: " << oov << '\n'; + if (flush_) out_.flush(); + } + + void Summary(double ppl_including_oov, double ppl_excluding_oov, uint64_t corpus_oov, uint64_t corpus_tokens) { + if (!print_summary_) return; + out_ << + "Perplexity including OOVs:\t" << ppl_including_oov << "\n" + "Perplexity excluding OOVs:\t" << ppl_excluding_oov << "\n" + "OOVs:\t" << corpus_oov << "\n" + "Tokens:\t" << corpus_tokens << '\n'; + out_.flush(); + } + + private: + util::FakeOFStream out_; + bool print_word_; + bool print_line_; + bool print_summary_; + bool flush_; }; -struct FullPrint : public BasicPrint { - void Word(StringPiece surface, WordIndex vocab, const FullScoreReturn &ret) const { - std::cout << surface << '=' << vocab << ' ' << static_cast(ret.ngram_length) << ' ' << ret.prob << '\t'; - } - - void Summary(double ppl_including_oov, double ppl_excluding_oov, uint64_t corpus_oov, uint64_t corpus_tokens) { - std::cout << - "Perplexity including OOVs:\t" << ppl_including_oov << "\n" - "Perplexity excluding OOVs:\t" << ppl_excluding_oov << "\n" - "OOVs:\t" << corpus_oov << "\n" - "Tokens:\t" << corpus_tokens << '\n' - ; - } -}; - -template void Query(const Model &model, bool sentence_context) { - Printer printer; +template void Query(const Model &model, bool sentence_context, Printer &printer) { typename Model::State state, out; lm::FullScoreReturn ret; StringPiece word; @@ -92,13 +100,9 @@ template void Query(const Model &model, bool senten corpus_tokens); } -template void Query(const char *file, const Config &config, bool sentence_context, bool show_words) { +template void Query(const char *file, const Config &config, bool sentence_context, QueryPrinter &printer) { Model model(file, config); - if (show_words) { - Query(model, sentence_context); - } else { - Query(model, sentence_context); - } + Query(model, sentence_context, printer); } } // namespace ngram diff --git a/lm/query_main.cc b/lm/query_main.cc index 3013ff21e..0bd28f7a9 100644 --- a/lm/query_main.cc +++ b/lm/query_main.cc @@ -10,9 +10,10 @@ void Usage(const char *name) { std::cerr << "KenLM was compiled with maximum order " << KENLM_MAX_ORDER << ".\n" - "Usage: " << name << " [-n] [-s] lm_file\n" + "Usage: " << name << " [-b] [-n] [-w] [-s] lm_file\n" + "-b: Do not buffer output.\n" "-n: Do not wrap the input in and .\n" - "-s: Sentence totals only.\n" + "-v summary|sentence|word: Level of verbosity\n" "-l lazy|populate|read|parallel: Load lazily, with populate, or malloc+read\n" "The default loading method is populate on Linux and read on others.\n"; exit(1); @@ -24,16 +25,28 @@ int main(int argc, char *argv[]) { lm::ngram::Config config; bool sentence_context = true; - bool show_words = true; + unsigned int verbosity = 2; + bool flush = false; int opt; - while ((opt = getopt(argc, argv, "hnsl:")) != -1) { + while ((opt = getopt(argc, argv, "bnv:l:")) != -1) { switch (opt) { + case 'b': + flush = true; + break; case 'n': sentence_context = false; break; - case 's': - show_words = false; + case 'v': + if (!strcmp(optarg, "word") || !strcmp(optarg, "2")) { + verbosity = 2; + } else if (!strcmp(optarg, "sentence") || !strcmp(optarg, "1")) { + verbosity = 1; + } else if (!strcmp(optarg, "summary") || !strcmp(optarg, "0")) { + verbosity = 0; + } else { + Usage(argv[0]); + } break; case 'l': if (!strcmp(optarg, "lazy")) { @@ -55,6 +68,7 @@ int main(int argc, char *argv[]) { } if (optind + 1 != argc) Usage(argv[0]); + lm::ngram::QueryPrinter printer(1, verbosity >= 2, verbosity >= 1, true, flush); const char *file = argv[optind]; try { using namespace lm::ngram; @@ -62,22 +76,22 @@ int main(int argc, char *argv[]) { if (RecognizeBinary(file, model_type)) { switch(model_type) { case PROBING: - Query(file, config, sentence_context, show_words); + Query(file, config, sentence_context, printer); break; case REST_PROBING: - Query(file, config, sentence_context, show_words); + Query(file, config, sentence_context, printer); break; case TRIE: - Query(file, config, sentence_context, show_words); + Query(file, config, sentence_context, printer); break; case QUANT_TRIE: - Query(file, config, sentence_context, show_words); + Query(file, config, sentence_context, printer); break; case ARRAY_TRIE: - Query(file, config, sentence_context, show_words); + Query(file, config, sentence_context, printer); break; case QUANT_ARRAY_TRIE: - Query(file, config, sentence_context, show_words); + Query(file, config, sentence_context, printer); break; default: std::cerr << "Unrecognized kenlm model type " << model_type << std::endl; @@ -86,14 +100,11 @@ int main(int argc, char *argv[]) { #ifdef WITH_NPLM } else if (lm::np::Model::Recognize(file)) { lm::np::Model model(file); - if (show_words) { - Query(model, sentence_context); - } else { - Query(model, sentence_context); - } + Query(model, sentence_context, printer); + Query(model, sentence_context, printer); #endif } else { - Query(file, config, sentence_context, show_words); + Query(file, config, sentence_context, printer); } util::PrintUsage(std::cerr); } catch (const std::exception &e) { diff --git a/lm/value.hh b/lm/value.hh index d017d59fc..d2425cc13 100644 --- a/lm/value.hh +++ b/lm/value.hh @@ -1,6 +1,7 @@ #ifndef LM_VALUE_H #define LM_VALUE_H +#include "lm/config.hh" #include "lm/model_type.hh" #include "lm/value_build.hh" #include "lm/weights.hh" diff --git a/lm/vocab.cc b/lm/vocab.cc index f6d834323..5696e60b3 100644 --- a/lm/vocab.cc +++ b/lm/vocab.cc @@ -6,13 +6,14 @@ #include "lm/config.hh" #include "lm/weights.hh" #include "util/exception.hh" +#include "util/fake_ofstream.hh" #include "util/file.hh" #include "util/joint_sort.hh" #include "util/murmur_hash.hh" #include "util/probing_hash_table.hh" -#include #include +#include namespace lm { namespace ngram { @@ -31,6 +32,7 @@ const uint64_t kUnknownHash = detail::HashForVocab("", 5); // Sadly some LMs have . const uint64_t kUnknownCapHash = detail::HashForVocab("", 5); +// TODO: replace with FilePiece. void ReadWords(int fd, EnumerateVocab *enumerate, WordIndex expected_count, uint64_t offset) { util::SeekOrThrow(fd, offset); // Check that we're at the right place by reading which is always first. @@ -69,10 +71,17 @@ void ReadWords(int fd, EnumerateVocab *enumerate, WordIndex expected_count, uint UTIL_THROW_IF(expected_count != index, FormatLoadException, "The binary file has the wrong number of words at the end. This could be caused by a truncated binary file."); } +// Constructor ordering madness. +int SeekAndReturn(int fd, uint64_t start) { + util::SeekOrThrow(fd, start); + return fd; +} } // namespace +ImmediateWriteWordsWrapper::ImmediateWriteWordsWrapper(EnumerateVocab *inner, int fd, uint64_t start) + : inner_(inner), stream_(SeekAndReturn(fd, start)) {} + WriteWordsWrapper::WriteWordsWrapper(EnumerateVocab *inner) : inner_(inner) {} -WriteWordsWrapper::~WriteWordsWrapper() {} void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) { if (inner_) inner_->Add(index, str); @@ -80,6 +89,14 @@ void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) { buffer_.push_back(0); } +void WriteWordsWrapper::Write(int fd, uint64_t start) { + util::SeekOrThrow(fd, start); + util::WriteOrThrow(fd, buffer_.data(), buffer_.size()); + // Free memory from the string. + std::string for_swap; + std::swap(buffer_, for_swap); +} + SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL) {} uint64_t SortedVocabulary::Size(uint64_t entries, const Config &/*config*/) { @@ -126,10 +143,78 @@ WordIndex SortedVocabulary::Insert(const StringPiece &str) { return end_ - begin_; } -void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) { +void SortedVocabulary::FinishedLoading(ProbBackoff *reorder) { + GenericFinished(reorder); +} + +namespace { +#pragma pack(push) +#pragma pack(4) +struct RenumberEntry { + uint64_t hash; + const char *str; + WordIndex old; + bool operator<(const RenumberEntry &other) const { + return hash < other.hash; + } +}; +#pragma pack(pop) +} // namespace + +void SortedVocabulary::ComputeRenumbering(WordIndex types, int from_words, int to_words, std::vector &mapping) { + mapping.clear(); + uint64_t file_size = util::SizeOrThrow(from_words); + util::scoped_memory strings; + util::MapRead(util::POPULATE_OR_READ, from_words, 0, file_size, strings); + const char *const start = static_cast(strings.get()); + UTIL_THROW_IF(memcmp(start, "", 6), FormatLoadException, "Vocab file does not begin with followed by null"); + std::vector entries; + entries.reserve(types - 1); + RenumberEntry entry; + entry.old = 1; + for (entry.str = start + 6 /* skip \0 */; entry.str < start + file_size; ++entry.old) { + StringPiece str(entry.str, strlen(entry.str)); + entry.hash = detail::HashForVocab(str); + entries.push_back(entry); + entry.str += str.size() + 1; + } + UTIL_THROW_IF2(entries.size() != types - 1, "Wrong number of vocab ids. Got " << (entries.size() + 1) << " expected " << types); + std::sort(entries.begin(), entries.end()); + // Write out new vocab file. + { + util::FakeOFStream out(to_words); + out << "" << '\0'; + for (std::vector::const_iterator i = entries.begin(); i != entries.end(); ++i) { + out << i->str << '\0'; + } + } + strings.reset(); + + mapping.resize(types); + mapping[0] = 0; // + for (std::vector::const_iterator i = entries.begin(); i != entries.end(); ++i) { + mapping[i->old] = i + 1 - entries.begin(); + } +} + +void SortedVocabulary::Populated() { + saw_unk_ = true; + SetSpecial(Index(""), Index(""), 0); + bound_ = end_ - begin_ + 1; + *(reinterpret_cast(begin_) - 1) = end_ - begin_; +} + +void SortedVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset) { + end_ = begin_ + *(reinterpret_cast(begin_) - 1); + SetSpecial(Index(""), Index(""), 0); + bound_ = end_ - begin_ + 1; + if (have_words) ReadWords(fd, to, bound_, offset); +} + +template void SortedVocabulary::GenericFinished(T *reorder) { if (enumerate_) { if (!strings_to_enumerate_.empty()) { - util::PairedIterator values(reorder_vocab + 1, &*strings_to_enumerate_.begin()); + util::PairedIterator values(reorder + 1, &*strings_to_enumerate_.begin()); util::JointSort(begin_, end_, values); } for (WordIndex i = 0; i < static_cast(end_ - begin_); ++i) { @@ -139,7 +224,7 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) { strings_to_enumerate_.clear(); string_backing_.FreeAll(); } else { - util::JointSort(begin_, end_, reorder_vocab + 1); + util::JointSort(begin_, end_, reorder + 1); } SetSpecial(Index(""), Index(""), 0); // Save size. Excludes UNK. @@ -148,13 +233,6 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) { bound_ = end_ - begin_ + 1; } -void SortedVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset) { - end_ = begin_ + *(reinterpret_cast(begin_) - 1); - SetSpecial(Index(""), Index(""), 0); - bound_ = end_ - begin_ + 1; - if (have_words) ReadWords(fd, to, bound_, offset); -} - namespace { const unsigned int kProbingVocabularyVersion = 0; } // namespace @@ -209,7 +287,7 @@ WordIndex ProbingVocabulary::Insert(const StringPiece &str) { } } -void ProbingVocabulary::FinishedLoading() { +void ProbingVocabulary::InternalFinishedLoading() { lookup_.FinishedInserting(); header_->bound = bound_; header_->version = kProbingVocabularyVersion; diff --git a/lm/vocab.hh b/lm/vocab.hh index 2659b9ba8..b42566f23 100644 --- a/lm/vocab.hh +++ b/lm/vocab.hh @@ -30,15 +30,32 @@ inline uint64_t HashForVocab(const StringPiece &str) { struct ProbingVocabularyHeader; } // namespace detail +// Writes words immediately to a file instead of buffering, because we know +// where in the file to put them. +class ImmediateWriteWordsWrapper : public EnumerateVocab { + public: + ImmediateWriteWordsWrapper(EnumerateVocab *inner, int fd, uint64_t start); + + void Add(WordIndex index, const StringPiece &str) { + stream_ << str << '\0'; + if (inner_) inner_->Add(index, str); + } + + private: + EnumerateVocab *inner_; + + util::FakeOFStream stream_; +}; + +// When the binary size isn't known yet. class WriteWordsWrapper : public EnumerateVocab { public: WriteWordsWrapper(EnumerateVocab *inner); - ~WriteWordsWrapper(); - void Add(WordIndex index, const StringPiece &str); const std::string &Buffer() const { return buffer_; } + void Write(int fd, uint64_t start); private: EnumerateVocab *inner_; @@ -67,6 +84,12 @@ class SortedVocabulary : public base::Vocabulary { // Size for purposes of file writing static uint64_t Size(uint64_t entries, const Config &config); + /* Read null-delimited words from file from_words, renumber according to + * hash order, write null-delimited words to to_words, and create a mapping + * from old id to new id. The 0th vocab word must be . + */ + static void ComputeRenumbering(WordIndex types, int from_words, int to_words, std::vector &mapping); + // Vocab words are [0, Bound()) Only valid after FinishedLoading/LoadedBinary. WordIndex Bound() const { return bound_; } @@ -77,8 +100,8 @@ class SortedVocabulary : public base::Vocabulary { void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries); + // Insert and FinishedLoading go together. WordIndex Insert(const StringPiece &str); - // Reorders reorder_vocab so that the IDs are sorted. void FinishedLoading(ProbBackoff *reorder_vocab); @@ -89,7 +112,13 @@ class SortedVocabulary : public base::Vocabulary { void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset); + uint64_t *&EndHack() { return end_; } + + void Populated(); + private: + template void GenericFinished(T *reorder); + uint64_t *begin_, *end_; WordIndex bound_; @@ -153,9 +182,8 @@ class ProbingVocabulary : public base::Vocabulary { WordIndex Insert(const StringPiece &str); template void FinishedLoading(Weights * /*reorder_vocab*/) { - FinishedLoading(); + InternalFinishedLoading(); } - void FinishedLoading(); std::size_t UnkCountChangePadding() const { return 0; } @@ -164,6 +192,8 @@ class ProbingVocabulary : public base::Vocabulary { void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset); private: + void InternalFinishedLoading(); + typedef util::ProbingHashTable Lookup; Lookup lookup_; diff --git a/lm/word_index.hh b/lm/word_index.hh index ad59a7c2f..59b24d7d2 100644 --- a/lm/word_index.hh +++ b/lm/word_index.hh @@ -7,6 +7,7 @@ namespace lm { typedef unsigned int WordIndex; const WordIndex kMaxWordIndex = UINT_MAX; +const WordIndex kUNK = 0; } // namespace lm typedef lm::WordIndex LMWordIndex; diff --git a/mert/Fdstream.h b/mert/Fdstream.h index 23eecc466..2258ef4a5 100644 --- a/mert/Fdstream.h +++ b/mert/Fdstream.h @@ -13,8 +13,6 @@ #include #include -#include "util/unistd.hh" - #if defined(__GLIBCXX__) || defined(__GLIBCPP__) #include diff --git a/mert/MeteorScorer.cpp b/mert/MeteorScorer.cpp index f4c7997ee..d030e52bd 100644 --- a/mert/MeteorScorer.cpp +++ b/mert/MeteorScorer.cpp @@ -18,7 +18,6 @@ #include "ScoreStats.h" #include "Util.h" -#include "util/unistd.hh" using namespace std; diff --git a/util/Jamfile b/util/Jamfile index 2d3cede01..7538f7d17 100644 --- a/util/Jamfile +++ b/util/Jamfile @@ -21,10 +21,13 @@ obj file_piece_test.o : file_piece_test.cc /top//boost_unit_test_framework : $(c fakelib parallel_read : parallel_read.cc : multi:/top//boost_thread multi:WITH_THREADS : : .. ; -fakelib kenutil : bit_packing.cc ersatz_progress.cc exception.cc file.cc file_piece.cc mmap.cc murmur_hash.cc parallel_read pool.cc random.cc read_compressed scoped.cc string_piece.cc usage.cc double-conversion//double-conversion : .. LINUX,single:rt : : .. ; +fakelib kenutil : [ glob *.cc : parallel_read.cc read_compressed.cc *_main.cc *_test.cc ] read_compressed parallel_read double-conversion//double-conversion : .. LINUX,single:rt : : .. ; exe cat_compressed : cat_compressed_main.cc kenutil ; +#Does not install this +exe probing_hash_table_benchmark : probing_hash_table_benchmark_main.cc kenutil ; + alias programs : cat_compressed ; import testing ; @@ -34,3 +37,5 @@ for local t in [ glob *_test.cc : file_piece_test.cc read_compressed_test.cc ] { local name = [ MATCH "(.*)\.cc" : $(t) ] ; unit-test $(name) : $(t) kenutil /top//boost_unit_test_framework /top//boost_filesystem /top//boost_system ; } + +build-project stream ; diff --git a/util/exception.hh b/util/exception.hh index 7a0e7c44a..d67a6f9fb 100644 --- a/util/exception.hh +++ b/util/exception.hh @@ -91,6 +91,12 @@ template typename Except::template ExceptionTag= 3 +#define UTIL_LIKELY(x) __builtin_expect (!!(x), 1) +#else +#define UTIL_LIKELY(x) (x) +#endif + #define UTIL_THROW_IF_ARG(Condition, Exception, Arg, Modify) do { \ if (UTIL_UNLIKELY(Condition)) { \ UTIL_THROW_BACKEND(#Condition, Exception, Arg, Modify); \ diff --git a/util/fake_ofstream.hh b/util/fake_ofstream.hh index 8299ba9ac..d35bf0d83 100644 --- a/util/fake_ofstream.hh +++ b/util/fake_ofstream.hh @@ -1,111 +1,135 @@ /* Like std::ofstream but without being incredibly slow. Backed by a raw fd. - * Does not support many data types. Currently, it's targeted at writing ARPA - * files quickly. + * Supports most of the built-in types except for void* and long double. */ #ifndef UTIL_FAKE_OFSTREAM_H #define UTIL_FAKE_OFSTREAM_H -#include "util/double-conversion/double-conversion.h" -#include "util/double-conversion/utils.h" #include "util/file.hh" +#include "util/float_to_string.hh" +#include "util/integer_to_string.hh" #include "util/scoped.hh" #include "util/string_piece.hh" -#define BOOST_LEXICAL_CAST_ASSUME_C_LOCALE -#include +#include +#include + +#include namespace util { class FakeOFStream { public: + // Maximum over all ToString operations. + // static const std::size_t kMinBuf = 20; + // This was causing compile failures in debug, so now 20 is written directly. + // // Does not take ownership of out. // Allows default constructor, but must call SetFD. explicit FakeOFStream(int out = -1, std::size_t buffer_size = 1048576) - : buf_(util::MallocOrThrow(buffer_size)), - builder_(static_cast(buf_.get()), buffer_size), - // Mostly the default but with inf instead. And no flags. - convert_(double_conversion::DoubleToStringConverter::NO_FLAGS, "inf", "NaN", 'e', -6, 21, 6, 0), - fd_(out), - buffer_size_(buffer_size) {} + : buf_(util::MallocOrThrow(std::max(buffer_size, (size_t)20))), + current_(static_cast(buf_.get())), + end_(current_ + std::max(buffer_size, (size_t)20)), + fd_(out) {} ~FakeOFStream() { - if (buf_.get()) Flush(); + // Could have called Finish already + flush(); } void SetFD(int to) { - if (builder_.position()) Flush(); + flush(); fd_ = to; } - FakeOFStream &Write(const void *data, std::size_t length) { - // Dominant case - if (static_cast(builder_.size() - builder_.position()) > length) { - builder_.AddSubstring((const char*)data, length); + FakeOFStream &write(const void *data, std::size_t length) { + if (UTIL_LIKELY(current_ + length <= end_)) { + std::memcpy(current_, data, length); + current_ += length; return *this; } - Flush(); - if (length > buffer_size_) { - util::WriteOrThrow(fd_, data, length); + flush(); + if (current_ + length <= end_) { + std::memcpy(current_, data, length); + current_ += length; } else { - builder_.AddSubstring((const char*)data, length); + util::WriteOrThrow(fd_, data, length); } return *this; } + // This also covers std::string and char* FakeOFStream &operator<<(StringPiece str) { - return Write(str.data(), str.size()); + return write(str.data(), str.size()); } - FakeOFStream &operator<<(float value) { - // Odd, but this is the largest number found in the comments. - EnsureRemaining(double_conversion::DoubleToStringConverter::kMaxPrecisionDigits + 8); - convert_.ToShortestSingle(value, &builder_); + // For anything with ToStringBuf::kBytes, define operator<< using ToString. + // This includes uint64_t, int64_t, uint32_t, int32_t, uint16_t, int16_t, + // float, double + private: + template struct EnableIfKludge { + typedef FakeOFStream type; + }; + public: + template typename EnableIfKludge::kBytes>::type &operator<<(const T value) { + EnsureRemaining(ToStringBuf::kBytes); + current_ = ToString(value, current_); + assert(current_ <= end_); return *this; } - FakeOFStream &operator<<(double value) { - EnsureRemaining(double_conversion::DoubleToStringConverter::kMaxPrecisionDigits + 8); - convert_.ToShortest(value, &builder_); - return *this; - } - - // Inefficient! TODO: more efficient implementation - FakeOFStream &operator<<(unsigned value) { - return *this << boost::lexical_cast(value); - } - FakeOFStream &operator<<(char c) { EnsureRemaining(1); - builder_.AddCharacter(c); + *current_++ = c; + return *this; + } + + FakeOFStream &operator<<(unsigned char c) { + EnsureRemaining(1); + *current_++ = static_cast(c); + return *this; + } + + /* clang on OS X appears to consider std::size_t aka unsigned long distinct + * from uint64_t. So this function makes clang work. gcc considers + * uint64_t and std::size_t the same (on 64-bit) so this isn't necessary. + * But it does no harm since gcc sees it as a specialization of the + * EnableIfKludge template. + * Also, delegating to *this << static_cast(value) would loop + * indefinitely on gcc. + */ + FakeOFStream &operator<<(std::size_t value) { + EnsureRemaining(ToStringBuf::kBytes); + current_ = ToString(static_cast(value), current_); return *this; } // Note this does not sync. - void Flush() { - util::WriteOrThrow(fd_, buf_.get(), builder_.position()); - builder_.Reset(); + void flush() { + if (current_ != buf_.get()) { + util::WriteOrThrow(fd_, buf_.get(), current_ - (char*)buf_.get()); + current_ = static_cast(buf_.get()); + } } // Not necessary, but does assure the data is cleared. void Finish() { - Flush(); - // It will segfault trying to null terminate otherwise. - builder_.Finalize(); + flush(); buf_.reset(); + current_ = NULL; util::FSyncOrThrow(fd_); } private: void EnsureRemaining(std::size_t amount) { - if (static_cast(builder_.size() - builder_.position()) <= amount) { - Flush(); + if (UTIL_UNLIKELY(current_ + amount > end_)) { + flush(); + assert(current_ + amount <= end_); } } util::scoped_malloc buf_; - double_conversion::StringBuilder builder_; - double_conversion::DoubleToStringConverter convert_; + char *current_, *end_; + int fd_; - const std::size_t buffer_size_; }; } // namespace diff --git a/util/file_piece.cc b/util/file_piece.cc index c808e7d90..92edef27d 100644 --- a/util/file_piece.cc +++ b/util/file_piece.cc @@ -11,19 +11,22 @@ #include #endif -#include -#include -#include #include -#include +#include +#include #include +#include +#include +#include + +#include #include #include namespace util { ParseNumberException::ParseNumberException(StringPiece value) throw() { - *this << "Could not parse \"" << value << "\" into a number"; + *this << "Could not parse \"" << value << "\" into a "; } // Sigh this is the only way I could come up with to do a _const_ bool. It has ' ', '\f', '\n', '\r', '\t', and '\v' (same as isspace on C locale). @@ -62,12 +65,17 @@ FilePiece::FilePiece(std::istream &stream, const char *name, std::size_t min_buf FilePiece::~FilePiece() {} -StringPiece FilePiece::ReadLine(char delim) { +StringPiece FilePiece::ReadLine(char delim, bool strip_cr) { std::size_t skip = 0; while (true) { for (const char *i = position_ + skip; i < position_end_; ++i) { if (*i == delim) { - StringPiece ret(position_, i - position_); + // End of line. + // Take 1 byte off the end if it's an unwanted carriage return. + const std::size_t subtract_cr = ( + (strip_cr && i > position_ && *(i - 1) == '\r') ? + 1 : 0); + StringPiece ret(position_, i - position_ - subtract_cr); position_ = i + 1; return ret; } @@ -83,9 +91,9 @@ StringPiece FilePiece::ReadLine(char delim) { } } -bool FilePiece::ReadLineOrEOF(StringPiece &to, char delim) { +bool FilePiece::ReadLineOrEOF(StringPiece &to, char delim, bool strip_cr) { try { - to = ReadLine(delim); + to = ReadLine(delim, strip_cr); } catch (const util::EndOfFileException &e) { return false; } return true; } @@ -145,49 +153,59 @@ static const double_conversion::StringToDoubleConverter kConverter( "inf", "NaN"); -void ParseNumber(const char *begin, const char *&end, float &out) { +StringPiece FirstToken(StringPiece str) { + const char *i; + for (i = str.data(); i != str.data() + str.size(); ++i) { + if (kSpaces[(unsigned char)*i]) break; + } + return StringPiece(str.data(), i - str.data()); +} + +const char *ParseNumber(StringPiece str, float &out) { int count; - out = kConverter.StringToFloat(begin, end - begin, &count); - end = begin + count; + out = kConverter.StringToFloat(str.data(), str.size(), &count); + UTIL_THROW_IF_ARG(isnan(out) && str != "NaN" && str != "nan", ParseNumberException, (FirstToken(str)), "float"); + return str.data() + count; } -void ParseNumber(const char *begin, const char *&end, double &out) { +const char *ParseNumber(StringPiece str, double &out) { int count; - out = kConverter.StringToDouble(begin, end - begin, &count); - end = begin + count; + out = kConverter.StringToDouble(str.data(), str.size(), &count); + UTIL_THROW_IF_ARG(isnan(out) && str != "NaN" && str != "nan", ParseNumberException, (FirstToken(str)), "double"); + return str.data() + count; } -void ParseNumber(const char *begin, const char *&end, long int &out) { - char *silly_end; - out = strtol(begin, &silly_end, 10); - end = silly_end; +const char *ParseNumber(StringPiece str, long int &out) { + char *end; + errno = 0; + out = strtol(str.data(), &end, 10); + UTIL_THROW_IF_ARG(errno || (end == str.data()), ParseNumberException, (FirstToken(str)), "long int"); + return end; } -void ParseNumber(const char *begin, const char *&end, unsigned long int &out) { - char *silly_end; - out = strtoul(begin, &silly_end, 10); - end = silly_end; +const char *ParseNumber(StringPiece str, unsigned long int &out) { + char *end; + errno = 0; + out = strtoul(str.data(), &end, 10); + UTIL_THROW_IF_ARG(errno || (end == str.data()), ParseNumberException, (FirstToken(str)), "unsigned long int"); + return end; } } // namespace template T FilePiece::ReadNumber() { SkipSpaces(); while (last_space_ < position_) { - if (at_end_) { + if (UTIL_UNLIKELY(at_end_)) { // Hallucinate a null off the end of the file. std::string buffer(position_, position_end_); - const char *buf = buffer.c_str(); - const char *end = buf + buffer.size(); T ret; - ParseNumber(buf, end, ret); - if (buf == end) throw ParseNumberException(buffer); - position_ += end - buf; + // Has to be null-terminated. + const char *begin = buffer.c_str(); + const char *end = ParseNumber(StringPiece(begin, buffer.size()), ret); + position_ += end - begin; return ret; } Shift(); } - const char *end = last_space_; T ret; - ParseNumber(position_, end, ret); - if (end == position_) throw ParseNumberException(ReadDelimited()); - position_ = end; + position_ = ParseNumber(StringPiece(position_, last_space_ - position_), ret); return ret; } diff --git a/util/file_piece.hh b/util/file_piece.hh index 35f5eb648..d3d83054d 100644 --- a/util/file_piece.hh +++ b/util/file_piece.hh @@ -55,7 +55,7 @@ class FilePiece { return Consume(FindDelimiterOrEOF(delim)); } - // Read word until the line or file ends. + /// Read word until the line or file ends. bool ReadWordSameLine(StringPiece &to, const bool *delim = kSpaces) { assert(delim[static_cast('\n')]); // Skip non-enter spaces. @@ -75,12 +75,30 @@ class FilePiece { return true; } - // Unlike ReadDelimited, this includes leading spaces and consumes the delimiter. - // It is similar to getline in that way. - StringPiece ReadLine(char delim = '\n'); + /** Read a line of text from the file. + * + * Unlike ReadDelimited, this includes leading spaces and consumes the + * delimiter. It is similar to getline in that way. + * + * If strip_cr is true, any trailing carriate return (as would be found on + * a file written on Windows) will be left out of the returned line. + * + * Throws EndOfFileException if the end of the file is encountered. If the + * file does not end in a newline, this could mean that the last line is + * never read. + */ + StringPiece ReadLine(char delim = '\n', bool strip_cr = true); - // Doesn't throw EndOfFileException, just returns false. - bool ReadLineOrEOF(StringPiece &to, char delim = '\n'); + /** Read a line of text from the file, or return false on EOF. + * + * This is like ReadLine, except it returns false where ReadLine throws + * EndOfFileException. Like ReadLine it may not read the last line in the + * file if the file does not end in a newline. + * + * If strip_cr is true, any trailing carriate return (as would be found on + * a file written on Windows) will be left out of the returned line. + */ + bool ReadLineOrEOF(StringPiece &to, char delim = '\n', bool strip_cr = true); float ReadFloat(); double ReadDouble(); diff --git a/util/file_piece_test.cc b/util/file_piece_test.cc index 4a2a521ba..11e2ab3aa 100644 --- a/util/file_piece_test.cc +++ b/util/file_piece_test.cc @@ -1,6 +1,7 @@ // Tests might fail if you have creative characters in your path. Sue me. #include "util/file_piece.hh" +#include "util/fake_ofstream.hh" #include "util/file.hh" #include "util/scoped.hh" @@ -133,5 +134,21 @@ BOOST_AUTO_TEST_CASE(StreamZipReadLine) { #endif // HAVE_ZLIB +BOOST_AUTO_TEST_CASE(Numbers) { + scoped_fd file(MakeTemp(FileLocation())); + const float floating = 3.2; + { + util::FakeOFStream writing(file.get()); + writing << "94389483984398493890287 " << floating << " 5"; + } + SeekOrThrow(file.get(), 0); + util::FilePiece f(file.release()); + BOOST_CHECK_THROW(f.ReadULong(), ParseNumberException); + BOOST_CHECK_EQUAL("94389483984398493890287", f.ReadDelimited()); + // Yes, exactly equal. Isn't double-conversion wonderful? + BOOST_CHECK_EQUAL(floating, f.ReadFloat()); + BOOST_CHECK_EQUAL(5, f.ReadULong()); +} + } // namespace } // namespace util diff --git a/util/float_to_string.cc b/util/float_to_string.cc new file mode 100644 index 000000000..1e16d6f99 --- /dev/null +++ b/util/float_to_string.cc @@ -0,0 +1,23 @@ +#include "util/float_to_string.hh" + +#include "util/double-conversion/double-conversion.h" +#include "util/double-conversion/utils.h" + +namespace util { +namespace { +const double_conversion::DoubleToStringConverter kConverter(double_conversion::DoubleToStringConverter::NO_FLAGS, "inf", "NaN", 'e', -6, 21, 6, 0); +} // namespace + +char *ToString(double value, char *to) { + double_conversion::StringBuilder builder(to, ToStringBuf::kBytes); + kConverter.ToShortest(value, &builder); + return &to[builder.position()]; +} + +char *ToString(float value, char *to) { + double_conversion::StringBuilder builder(to, ToStringBuf::kBytes); + kConverter.ToShortestSingle(value, &builder); + return &to[builder.position()]; +} + +} // namespace util diff --git a/util/float_to_string.hh b/util/float_to_string.hh new file mode 100644 index 000000000..d1104e790 --- /dev/null +++ b/util/float_to_string.hh @@ -0,0 +1,25 @@ +#ifndef UTIL_FLOAT_TO_STRING_H +#define UTIL_FLOAT_TO_STRING_H + +// Just for ToStringBuf +#include "util/integer_to_string.hh" + +namespace util { + +template <> struct ToStringBuf { + // DoubleToStringConverter::kBase10MaximalLength + 1 for null paranoia. + static const unsigned kBytes = 18; +}; + +// Single wasn't documented in double conversion, so be conservative and +// say the same as double. +template <> struct ToStringBuf { + static const unsigned kBytes = 18; +}; + +char *ToString(double value, char *to); +char *ToString(float value, char *to); + +} // namespace util + +#endif // UTIL_FLOAT_TO_STRING_H diff --git a/util/integer_to_string.cc b/util/integer_to_string.cc new file mode 100644 index 000000000..32047291d --- /dev/null +++ b/util/integer_to_string.cc @@ -0,0 +1,639 @@ +/* Fast integer to string conversion. +Source: https://github.com/miloyip/itoa-benchmark +Local modifications: +1. Return end of buffer instead of null terminating +2. Collapse to single file +3. Namespace +4. Remove test hook +5. Non-x86 support from the branch_lut code +6. Rename functions + +Copyright (C) 2014 Milo Yip + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. + +Which is based on: http://0x80.pl/snippets/asm/sse-utoa.c + + SSE: conversion integers to decimal representation + + Author: Wojciech MuÅ‚a + e-mail: wojciech_mula@poczta.onet.pl + www: http://0x80.pl/ + + License: BSD + + initial release 2011-10-21 + $Id$ +*/ + +#include "util/integer_to_string.hh" +#include +#include + +namespace util { + +namespace { +const char gDigitsLut[200] = { + '0','0','0','1','0','2','0','3','0','4','0','5','0','6','0','7','0','8','0','9', + '1','0','1','1','1','2','1','3','1','4','1','5','1','6','1','7','1','8','1','9', + '2','0','2','1','2','2','2','3','2','4','2','5','2','6','2','7','2','8','2','9', + '3','0','3','1','3','2','3','3','3','4','3','5','3','6','3','7','3','8','3','9', + '4','0','4','1','4','2','4','3','4','4','4','5','4','6','4','7','4','8','4','9', + '5','0','5','1','5','2','5','3','5','4','5','5','5','6','5','7','5','8','5','9', + '6','0','6','1','6','2','6','3','6','4','6','5','6','6','6','7','6','8','6','9', + '7','0','7','1','7','2','7','3','7','4','7','5','7','6','7','7','7','8','7','9', + '8','0','8','1','8','2','8','3','8','4','8','5','8','6','8','7','8','8','8','9', + '9','0','9','1','9','2','9','3','9','4','9','5','9','6','9','7','9','8','9','9' +}; +} // namespace + +// SSE2 implementation according to http://0x80.pl/articles/sse-itoa.html +// Modifications: (1) fix incorrect digits (2) accept all ranges (3) write to user provided buffer. + +#if defined(i386) || defined(__amd64) || defined(_M_IX86) || defined(_M_X64) + +#include + +#ifdef _MSC_VER +#include "intrin.h" +#endif + +#ifdef _MSC_VER +#define ALIGN_PRE __declspec(align(16)) +#define ALIGN_SUF +#else +#define ALIGN_PRE +#define ALIGN_SUF __attribute__ ((aligned(16))) +#endif + +namespace { + +static const uint32_t kDiv10000 = 0xd1b71759; +ALIGN_PRE static const uint32_t kDiv10000Vector[4] ALIGN_SUF = { kDiv10000, kDiv10000, kDiv10000, kDiv10000 }; +ALIGN_PRE static const uint32_t k10000Vector[4] ALIGN_SUF = { 10000, 10000, 10000, 10000 }; +ALIGN_PRE static const uint16_t kDivPowersVector[8] ALIGN_SUF = { 8389, 5243, 13108, 32768, 8389, 5243, 13108, 32768 }; // 10^3, 10^2, 10^1, 10^0 +ALIGN_PRE static const uint16_t kShiftPowersVector[8] ALIGN_SUF = { + 1 << (16 - (23 + 2 - 16)), + 1 << (16 - (19 + 2 - 16)), + 1 << (16 - 1 - 2), + 1 << (15), + 1 << (16 - (23 + 2 - 16)), + 1 << (16 - (19 + 2 - 16)), + 1 << (16 - 1 - 2), + 1 << (15) +}; +ALIGN_PRE static const uint16_t k10Vector[8] ALIGN_SUF = { 10, 10, 10, 10, 10, 10, 10, 10 }; +ALIGN_PRE static const char kAsciiZero[16] ALIGN_SUF = { '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0' }; + +inline __m128i Convert8DigitsSSE2(uint32_t value) { + assert(value <= 99999999); + + // abcd, efgh = abcdefgh divmod 10000 + const __m128i abcdefgh = _mm_cvtsi32_si128(value); + const __m128i abcd = _mm_srli_epi64(_mm_mul_epu32(abcdefgh, reinterpret_cast(kDiv10000Vector)[0]), 45); + const __m128i efgh = _mm_sub_epi32(abcdefgh, _mm_mul_epu32(abcd, reinterpret_cast(k10000Vector)[0])); + + // v1 = [ abcd, efgh, 0, 0, 0, 0, 0, 0 ] + const __m128i v1 = _mm_unpacklo_epi16(abcd, efgh); + + // v1a = v1 * 4 = [ abcd * 4, efgh * 4, 0, 0, 0, 0, 0, 0 ] + const __m128i v1a = _mm_slli_epi64(v1, 2); + + // v2 = [ abcd * 4, abcd * 4, abcd * 4, abcd * 4, efgh * 4, efgh * 4, efgh * 4, efgh * 4 ] + const __m128i v2a = _mm_unpacklo_epi16(v1a, v1a); + const __m128i v2 = _mm_unpacklo_epi32(v2a, v2a); + + // v4 = v2 div 10^3, 10^2, 10^1, 10^0 = [ a, ab, abc, abcd, e, ef, efg, efgh ] + const __m128i v3 = _mm_mulhi_epu16(v2, reinterpret_cast(kDivPowersVector)[0]); + const __m128i v4 = _mm_mulhi_epu16(v3, reinterpret_cast(kShiftPowersVector)[0]); + + // v5 = v4 * 10 = [ a0, ab0, abc0, abcd0, e0, ef0, efg0, efgh0 ] + const __m128i v5 = _mm_mullo_epi16(v4, reinterpret_cast(k10Vector)[0]); + + // v6 = v5 << 16 = [ 0, a0, ab0, abc0, 0, e0, ef0, efg0 ] + const __m128i v6 = _mm_slli_epi64(v5, 16); + + // v7 = v4 - v6 = { a, b, c, d, e, f, g, h } + const __m128i v7 = _mm_sub_epi16(v4, v6); + + return v7; +} + +inline __m128i ShiftDigits_SSE2(__m128i a, unsigned digit) { + assert(digit <= 8); + switch (digit) { + case 0: return a; + case 1: return _mm_srli_si128(a, 1); + case 2: return _mm_srli_si128(a, 2); + case 3: return _mm_srli_si128(a, 3); + case 4: return _mm_srli_si128(a, 4); + case 5: return _mm_srli_si128(a, 5); + case 6: return _mm_srli_si128(a, 6); + case 7: return _mm_srli_si128(a, 7); + case 8: return _mm_srli_si128(a, 8); + } + return a; // should not execute here. +} + +} // namespace + +// Original name: u32toa_sse2 +char *ToString(uint32_t value, char* buffer) { + if (value < 10000) { + const uint32_t d1 = (value / 100) << 1; + const uint32_t d2 = (value % 100) << 1; + + if (value >= 1000) + *buffer++ = gDigitsLut[d1]; + if (value >= 100) + *buffer++ = gDigitsLut[d1 + 1]; + if (value >= 10) + *buffer++ = gDigitsLut[d2]; + *buffer++ = gDigitsLut[d2 + 1]; + //*buffer++ = '\0'; + return buffer; + } + else if (value < 100000000) { + // Experiment shows that this case SSE2 is slower +#if 0 + const __m128i a = Convert8DigitsSSE2(value); + + // Convert to bytes, add '0' + const __m128i va = _mm_add_epi8(_mm_packus_epi16(a, _mm_setzero_si128()), reinterpret_cast(kAsciiZero)[0]); + + // Count number of digit + const unsigned mask = _mm_movemask_epi8(_mm_cmpeq_epi8(va, reinterpret_cast(kAsciiZero)[0])); + unsigned long digit; +#ifdef _MSC_VER + _BitScanForward(&digit, ~mask | 0x8000); +#else + digit = __builtin_ctz(~mask | 0x8000); +#endif + + // Shift digits to the beginning + __m128i result = ShiftDigits_SSE2(va, digit); + //__m128i result = _mm_srl_epi64(va, _mm_cvtsi32_si128(digit * 8)); + _mm_storel_epi64(reinterpret_cast<__m128i*>(buffer), result); + buffer[8 - digit] = '\0'; +#else + // value = bbbbcccc + const uint32_t b = value / 10000; + const uint32_t c = value % 10000; + + const uint32_t d1 = (b / 100) << 1; + const uint32_t d2 = (b % 100) << 1; + + const uint32_t d3 = (c / 100) << 1; + const uint32_t d4 = (c % 100) << 1; + + if (value >= 10000000) + *buffer++ = gDigitsLut[d1]; + if (value >= 1000000) + *buffer++ = gDigitsLut[d1 + 1]; + if (value >= 100000) + *buffer++ = gDigitsLut[d2]; + *buffer++ = gDigitsLut[d2 + 1]; + + *buffer++ = gDigitsLut[d3]; + *buffer++ = gDigitsLut[d3 + 1]; + *buffer++ = gDigitsLut[d4]; + *buffer++ = gDigitsLut[d4 + 1]; +// *buffer++ = '\0'; + return buffer; +#endif + } + else { + // value = aabbbbbbbb in decimal + + const uint32_t a = value / 100000000; // 1 to 42 + value %= 100000000; + + if (a >= 10) { + const unsigned i = a << 1; + *buffer++ = gDigitsLut[i]; + *buffer++ = gDigitsLut[i + 1]; + } + else + *buffer++ = '0' + static_cast(a); + + const __m128i b = Convert8DigitsSSE2(value); + const __m128i ba = _mm_add_epi8(_mm_packus_epi16(_mm_setzero_si128(), b), reinterpret_cast(kAsciiZero)[0]); + const __m128i result = _mm_srli_si128(ba, 8); + _mm_storel_epi64(reinterpret_cast<__m128i*>(buffer), result); +// buffer[8] = '\0'; + return buffer + 8; + } +} + +// Original name: u64toa_sse2 +char *ToString(uint64_t value, char* buffer) { + if (value < 100000000) { + uint32_t v = static_cast(value); + if (v < 10000) { + const uint32_t d1 = (v / 100) << 1; + const uint32_t d2 = (v % 100) << 1; + + if (v >= 1000) + *buffer++ = gDigitsLut[d1]; + if (v >= 100) + *buffer++ = gDigitsLut[d1 + 1]; + if (v >= 10) + *buffer++ = gDigitsLut[d2]; + *buffer++ = gDigitsLut[d2 + 1]; + //*buffer++ = '\0'; + return buffer; + } + else { + // Experiment shows that this case SSE2 is slower +#if 0 + const __m128i a = Convert8DigitsSSE2(v); + + // Convert to bytes, add '0' + const __m128i va = _mm_add_epi8(_mm_packus_epi16(a, _mm_setzero_si128()), reinterpret_cast(kAsciiZero)[0]); + + // Count number of digit + const unsigned mask = _mm_movemask_epi8(_mm_cmpeq_epi8(va, reinterpret_cast(kAsciiZero)[0])); + unsigned long digit; +#ifdef _MSC_VER + _BitScanForward(&digit, ~mask | 0x8000); +#else + digit = __builtin_ctz(~mask | 0x8000); +#endif + + // Shift digits to the beginning + __m128i result = ShiftDigits_SSE2(va, digit); + _mm_storel_epi64(reinterpret_cast<__m128i*>(buffer), result); + buffer[8 - digit] = '\0'; +#else + // value = bbbbcccc + const uint32_t b = v / 10000; + const uint32_t c = v % 10000; + + const uint32_t d1 = (b / 100) << 1; + const uint32_t d2 = (b % 100) << 1; + + const uint32_t d3 = (c / 100) << 1; + const uint32_t d4 = (c % 100) << 1; + + if (value >= 10000000) + *buffer++ = gDigitsLut[d1]; + if (value >= 1000000) + *buffer++ = gDigitsLut[d1 + 1]; + if (value >= 100000) + *buffer++ = gDigitsLut[d2]; + *buffer++ = gDigitsLut[d2 + 1]; + + *buffer++ = gDigitsLut[d3]; + *buffer++ = gDigitsLut[d3 + 1]; + *buffer++ = gDigitsLut[d4]; + *buffer++ = gDigitsLut[d4 + 1]; + //*buffer++ = '\0'; + return buffer; +#endif + } + } + else if (value < 10000000000000000) { + const uint32_t v0 = static_cast(value / 100000000); + const uint32_t v1 = static_cast(value % 100000000); + + const __m128i a0 = Convert8DigitsSSE2(v0); + const __m128i a1 = Convert8DigitsSSE2(v1); + + // Convert to bytes, add '0' + const __m128i va = _mm_add_epi8(_mm_packus_epi16(a0, a1), reinterpret_cast(kAsciiZero)[0]); + + // Count number of digit + const unsigned mask = _mm_movemask_epi8(_mm_cmpeq_epi8(va, reinterpret_cast(kAsciiZero)[0])); +#ifdef _MSC_VER + unsigned long digit; + _BitScanForward(&digit, ~mask | 0x8000); +#else + unsigned digit = __builtin_ctz(~mask | 0x8000); +#endif + + // Shift digits to the beginning + __m128i result = ShiftDigits_SSE2(va, digit); + _mm_storeu_si128(reinterpret_cast<__m128i*>(buffer), result); +// buffer[16 - digit] = '\0'; + return &buffer[16 - digit]; + } + else { + const uint32_t a = static_cast(value / 10000000000000000); // 1 to 1844 + value %= 10000000000000000; + + if (a < 10) + *buffer++ = '0' + static_cast(a); + else if (a < 100) { + const uint32_t i = a << 1; + *buffer++ = gDigitsLut[i]; + *buffer++ = gDigitsLut[i + 1]; + } + else if (a < 1000) { + *buffer++ = '0' + static_cast(a / 100); + + const uint32_t i = (a % 100) << 1; + *buffer++ = gDigitsLut[i]; + *buffer++ = gDigitsLut[i + 1]; + } + else { + const uint32_t i = (a / 100) << 1; + const uint32_t j = (a % 100) << 1; + *buffer++ = gDigitsLut[i]; + *buffer++ = gDigitsLut[i + 1]; + *buffer++ = gDigitsLut[j]; + *buffer++ = gDigitsLut[j + 1]; + } + + const uint32_t v0 = static_cast(value / 100000000); + const uint32_t v1 = static_cast(value % 100000000); + + const __m128i a0 = Convert8DigitsSSE2(v0); + const __m128i a1 = Convert8DigitsSSE2(v1); + + // Convert to bytes, add '0' + const __m128i va = _mm_add_epi8(_mm_packus_epi16(a0, a1), reinterpret_cast(kAsciiZero)[0]); + _mm_storeu_si128(reinterpret_cast<__m128i*>(buffer), va); +// buffer[16] = '\0'; + return &buffer[16]; + } +} + +#else // Generic Non-x86 case + +// Orignal name: u32toa_branchlut +char *ToString(uint32_t value, char* buffer) { + if (value < 10000) { + const uint32_t d1 = (value / 100) << 1; + const uint32_t d2 = (value % 100) << 1; + + if (value >= 1000) + *buffer++ = gDigitsLut[d1]; + if (value >= 100) + *buffer++ = gDigitsLut[d1 + 1]; + if (value >= 10) + *buffer++ = gDigitsLut[d2]; + *buffer++ = gDigitsLut[d2 + 1]; + } + else if (value < 100000000) { + // value = bbbbcccc + const uint32_t b = value / 10000; + const uint32_t c = value % 10000; + + const uint32_t d1 = (b / 100) << 1; + const uint32_t d2 = (b % 100) << 1; + + const uint32_t d3 = (c / 100) << 1; + const uint32_t d4 = (c % 100) << 1; + + if (value >= 10000000) + *buffer++ = gDigitsLut[d1]; + if (value >= 1000000) + *buffer++ = gDigitsLut[d1 + 1]; + if (value >= 100000) + *buffer++ = gDigitsLut[d2]; + *buffer++ = gDigitsLut[d2 + 1]; + + *buffer++ = gDigitsLut[d3]; + *buffer++ = gDigitsLut[d3 + 1]; + *buffer++ = gDigitsLut[d4]; + *buffer++ = gDigitsLut[d4 + 1]; + } + else { + // value = aabbbbcccc in decimal + + const uint32_t a = value / 100000000; // 1 to 42 + value %= 100000000; + + if (a >= 10) { + const unsigned i = a << 1; + *buffer++ = gDigitsLut[i]; + *buffer++ = gDigitsLut[i + 1]; + } + else + *buffer++ = '0' + static_cast(a); + + const uint32_t b = value / 10000; // 0 to 9999 + const uint32_t c = value % 10000; // 0 to 9999 + + const uint32_t d1 = (b / 100) << 1; + const uint32_t d2 = (b % 100) << 1; + + const uint32_t d3 = (c / 100) << 1; + const uint32_t d4 = (c % 100) << 1; + + *buffer++ = gDigitsLut[d1]; + *buffer++ = gDigitsLut[d1 + 1]; + *buffer++ = gDigitsLut[d2]; + *buffer++ = gDigitsLut[d2 + 1]; + *buffer++ = gDigitsLut[d3]; + *buffer++ = gDigitsLut[d3 + 1]; + *buffer++ = gDigitsLut[d4]; + *buffer++ = gDigitsLut[d4 + 1]; + } + return buffer; //*buffer++ = '\0'; +} + +// Original name: u64toa_branchlut +char *ToString(uint64_t value, char* buffer) { + if (value < 100000000) { + uint32_t v = static_cast(value); + if (v < 10000) { + const uint32_t d1 = (v / 100) << 1; + const uint32_t d2 = (v % 100) << 1; + + if (v >= 1000) + *buffer++ = gDigitsLut[d1]; + if (v >= 100) + *buffer++ = gDigitsLut[d1 + 1]; + if (v >= 10) + *buffer++ = gDigitsLut[d2]; + *buffer++ = gDigitsLut[d2 + 1]; + } + else { + // value = bbbbcccc + const uint32_t b = v / 10000; + const uint32_t c = v % 10000; + + const uint32_t d1 = (b / 100) << 1; + const uint32_t d2 = (b % 100) << 1; + + const uint32_t d3 = (c / 100) << 1; + const uint32_t d4 = (c % 100) << 1; + + if (value >= 10000000) + *buffer++ = gDigitsLut[d1]; + if (value >= 1000000) + *buffer++ = gDigitsLut[d1 + 1]; + if (value >= 100000) + *buffer++ = gDigitsLut[d2]; + *buffer++ = gDigitsLut[d2 + 1]; + + *buffer++ = gDigitsLut[d3]; + *buffer++ = gDigitsLut[d3 + 1]; + *buffer++ = gDigitsLut[d4]; + *buffer++ = gDigitsLut[d4 + 1]; + } + } + else if (value < 10000000000000000) { + const uint32_t v0 = static_cast(value / 100000000); + const uint32_t v1 = static_cast(value % 100000000); + + const uint32_t b0 = v0 / 10000; + const uint32_t c0 = v0 % 10000; + + const uint32_t d1 = (b0 / 100) << 1; + const uint32_t d2 = (b0 % 100) << 1; + + const uint32_t d3 = (c0 / 100) << 1; + const uint32_t d4 = (c0 % 100) << 1; + + const uint32_t b1 = v1 / 10000; + const uint32_t c1 = v1 % 10000; + + const uint32_t d5 = (b1 / 100) << 1; + const uint32_t d6 = (b1 % 100) << 1; + + const uint32_t d7 = (c1 / 100) << 1; + const uint32_t d8 = (c1 % 100) << 1; + + if (value >= 1000000000000000) + *buffer++ = gDigitsLut[d1]; + if (value >= 100000000000000) + *buffer++ = gDigitsLut[d1 + 1]; + if (value >= 10000000000000) + *buffer++ = gDigitsLut[d2]; + if (value >= 1000000000000) + *buffer++ = gDigitsLut[d2 + 1]; + if (value >= 100000000000) + *buffer++ = gDigitsLut[d3]; + if (value >= 10000000000) + *buffer++ = gDigitsLut[d3 + 1]; + if (value >= 1000000000) + *buffer++ = gDigitsLut[d4]; + if (value >= 100000000) + *buffer++ = gDigitsLut[d4 + 1]; + + *buffer++ = gDigitsLut[d5]; + *buffer++ = gDigitsLut[d5 + 1]; + *buffer++ = gDigitsLut[d6]; + *buffer++ = gDigitsLut[d6 + 1]; + *buffer++ = gDigitsLut[d7]; + *buffer++ = gDigitsLut[d7 + 1]; + *buffer++ = gDigitsLut[d8]; + *buffer++ = gDigitsLut[d8 + 1]; + } + else { + const uint32_t a = static_cast(value / 10000000000000000); // 1 to 1844 + value %= 10000000000000000; + + if (a < 10) + *buffer++ = '0' + static_cast(a); + else if (a < 100) { + const uint32_t i = a << 1; + *buffer++ = gDigitsLut[i]; + *buffer++ = gDigitsLut[i + 1]; + } + else if (a < 1000) { + *buffer++ = '0' + static_cast(a / 100); + + const uint32_t i = (a % 100) << 1; + *buffer++ = gDigitsLut[i]; + *buffer++ = gDigitsLut[i + 1]; + } + else { + const uint32_t i = (a / 100) << 1; + const uint32_t j = (a % 100) << 1; + *buffer++ = gDigitsLut[i]; + *buffer++ = gDigitsLut[i + 1]; + *buffer++ = gDigitsLut[j]; + *buffer++ = gDigitsLut[j + 1]; + } + + const uint32_t v0 = static_cast(value / 100000000); + const uint32_t v1 = static_cast(value % 100000000); + + const uint32_t b0 = v0 / 10000; + const uint32_t c0 = v0 % 10000; + + const uint32_t d1 = (b0 / 100) << 1; + const uint32_t d2 = (b0 % 100) << 1; + + const uint32_t d3 = (c0 / 100) << 1; + const uint32_t d4 = (c0 % 100) << 1; + + const uint32_t b1 = v1 / 10000; + const uint32_t c1 = v1 % 10000; + + const uint32_t d5 = (b1 / 100) << 1; + const uint32_t d6 = (b1 % 100) << 1; + + const uint32_t d7 = (c1 / 100) << 1; + const uint32_t d8 = (c1 % 100) << 1; + + *buffer++ = gDigitsLut[d1]; + *buffer++ = gDigitsLut[d1 + 1]; + *buffer++ = gDigitsLut[d2]; + *buffer++ = gDigitsLut[d2 + 1]; + *buffer++ = gDigitsLut[d3]; + *buffer++ = gDigitsLut[d3 + 1]; + *buffer++ = gDigitsLut[d4]; + *buffer++ = gDigitsLut[d4 + 1]; + *buffer++ = gDigitsLut[d5]; + *buffer++ = gDigitsLut[d5 + 1]; + *buffer++ = gDigitsLut[d6]; + *buffer++ = gDigitsLut[d6 + 1]; + *buffer++ = gDigitsLut[d7]; + *buffer++ = gDigitsLut[d7 + 1]; + *buffer++ = gDigitsLut[d8]; + *buffer++ = gDigitsLut[d8 + 1]; + } + return buffer; +} + +#endif // End of architecture if statement. + +// Signed wrappers. The negation is done on the unsigned version because +// doing so has defined behavior for INT_MIN. +char *ToString(int32_t value, char *to) { + uint32_t un = static_cast(value); + if (value < 0) { + *to++ = '-'; + un = -un; + } + return ToString(un, to); +} + +char *ToString(int64_t value, char *to) { + uint64_t un = static_cast(value); + if (value < 0) { + *to++ = '-'; + un = -un; + } + return ToString(un, to); +} + +// No optimization for this case yet. +char *ToString(int16_t value, char *to) { + return ToString((int32_t)value, to); +} +char *ToString(uint16_t value, char *to) { + return ToString((uint32_t)value, to); +} + +} // namespace util diff --git a/util/integer_to_string.hh b/util/integer_to_string.hh new file mode 100644 index 000000000..0d975b14e --- /dev/null +++ b/util/integer_to_string.hh @@ -0,0 +1,56 @@ +#ifndef UTIL_INTEGER_TO_STRING_H +#define UTIL_INTEGER_TO_STRING_H +#include +#include + +namespace util { + +/* These functions convert integers to strings and return the end pointer. + */ +char *ToString(uint32_t value, char *to); +char *ToString(uint64_t value, char *to); + +// Implemented as wrappers to above +char *ToString(int32_t value, char *to); +char *ToString(int64_t value, char *to); + +// Calls the 32-bit versions for now. +char *ToString(uint16_t value, char *to); +char *ToString(int16_t value, char *to); + +inline char *ToString(bool value, char *to) { + *to++ = '0' + value; + return to; +} + +// How many bytes to reserve in the buffer for these strings: +// g++ 4.9.1 doesn't work with this: +// static const std::size_t kBytes = 5; +// So use enum. +template struct ToStringBuf; +template <> struct ToStringBuf { + enum { kBytes = 1 }; +}; +template <> struct ToStringBuf { + enum { kBytes = 5 }; +}; +template <> struct ToStringBuf { + enum { kBytes = 6 }; +}; +template <> struct ToStringBuf { + enum { kBytes = 10 }; +}; +template <> struct ToStringBuf { + enum { kBytes = 11 }; +}; +template <> struct ToStringBuf { + enum { kBytes = 20 }; +}; +template <> struct ToStringBuf { + // Not a typo. 2^63 has 19 digits. + enum { kBytes = 20 }; +}; + +} // namespace util + +#endif // UTIL_INTEGER_TO_STRING_H diff --git a/util/integer_to_string_test.cc b/util/integer_to_string_test.cc new file mode 100644 index 000000000..ded1ecec7 --- /dev/null +++ b/util/integer_to_string_test.cc @@ -0,0 +1,65 @@ +#define BOOST_LEXICAL_CAST_ASSUME_C_LOCALE +#include "util/integer_to_string.hh" +#include "util/string_piece.hh" + +#define BOOST_TEST_MODULE IntegerToStringTest +#include +#include + +#include + +namespace util { +namespace { + +template void TestValue(const T value) { + char buf[ToStringBuf::kBytes]; + StringPiece result(buf, ToString(value, buf) - buf); + BOOST_REQUIRE_GE(static_cast(ToStringBuf::kBytes), result.size()); + BOOST_CHECK_EQUAL(boost::lexical_cast(value), result); +} + +template void TestCorners() { + TestValue(std::numeric_limits::min()); + TestValue(std::numeric_limits::max()); + TestValue(static_cast(0)); + TestValue(static_cast(-1)); + TestValue(static_cast(1)); +} + +BOOST_AUTO_TEST_CASE(Corners) { + TestCorners(); + TestCorners(); + TestCorners(); + TestCorners(); + TestCorners(); + TestCorners(); +} + +template void TestAll() { + for (T i = std::numeric_limits::min(); i < std::numeric_limits::max(); ++i) { + TestValue(i); + } + TestValue(std::numeric_limits::max()); +} + +BOOST_AUTO_TEST_CASE(Short) { + TestAll(); + TestAll(); +} + +template void Test10s() { + for (T i = 1; i < std::numeric_limits::max() / 10; i *= 10) { + TestValue(i); + TestValue(i - 1); + TestValue(i + 1); + } +} + +BOOST_AUTO_TEST_CASE(Tens) { + Test10s(); + Test10s(); + Test10s(); + Test10s(); +} + +}} // namespaces diff --git a/util/probing_hash_table.hh b/util/probing_hash_table.hh index 245340ddb..f32b64ea3 100644 --- a/util/probing_hash_table.hh +++ b/util/probing_hash_table.hh @@ -88,7 +88,7 @@ template GetKey()); if (equal_(got, t.GetKey())) { out = i; return true; } if (equal_(got, invalid_)) { @@ -108,7 +108,7 @@ template GetKey()); if (equal_(got, key)) { out = i; return true; } if (equal_(got, invalid_)) return false; @@ -118,7 +118,7 @@ template MutableIterator UnsafeMutableMustFind(const Key key) { - for (MutableIterator i(begin_ + (hash_(key) % buckets_));;) { + for (MutableIterator i(Ideal(key));;) { Key got(i->GetKey()); if (equal_(got, key)) { return i; } assert(!equal_(got, invalid_)); @@ -131,7 +131,7 @@ template GetKey()); if (equal_(got, key)) { out = i; return true; } if (equal_(got, invalid_)) return false; @@ -141,7 +141,7 @@ template ConstIterator MustFind(const Key key) const { - for (ConstIterator i(begin_ + (hash_(key) % buckets_));;) { + for (ConstIterator i(Ideal(key));;) { Key got(i->GetKey()); if (equal_(got, key)) { return i; } assert(!equal_(got, invalid_)); @@ -213,7 +213,7 @@ template GetKey(), invalid_); ++i) { - MutableIterator ideal = Ideal(*i); + MutableIterator ideal = Ideal(i->GetKey()); UTIL_THROW_IF(ideal > i && ideal <= last, Exception, "Inconsistency at position " << (i - begin_) << " should be at " << (ideal - begin_)); } MutableIterator pre_gap = i; @@ -222,7 +222,7 @@ template GetKey()); UTIL_THROW_IF(ideal > i || ideal <= pre_gap, Exception, "Inconsistency at position " << (i - begin_) << " with ideal " << (ideal - begin_)); } } @@ -230,12 +230,15 @@ template ; - template MutableIterator Ideal(const T &t) { - return begin_ + (hash_(t.GetKey()) % buckets_); + MutableIterator Ideal(const Key key) { + return begin_ + (hash_(key) % buckets_); + } + ConstIterator Ideal(const Key key) const { + return begin_ + (hash_(key) % buckets_); } template MutableIterator UncheckedInsert(const T &t) { - for (MutableIterator i(Ideal(t));;) { + for (MutableIterator i(Ideal(t.GetKey()));;) { if (equal_(i->GetKey(), invalid_)) { *i = t; return i; } if (++i == end_) { i = begin_; } } @@ -277,6 +280,7 @@ template MutableIterator Insert(const T &t) { + ++backend_.entries_; DoubleIfNeeded(); return backend_.UncheckedInsert(t); } diff --git a/util/probing_hash_table_benchmark_main.cc b/util/probing_hash_table_benchmark_main.cc new file mode 100644 index 000000000..3e12290cf --- /dev/null +++ b/util/probing_hash_table_benchmark_main.cc @@ -0,0 +1,49 @@ +#include "util/probing_hash_table.hh" +#include "util/scoped.hh" +#include "util/usage.hh" + +#include +#include + +#include + +namespace util { +namespace { + +struct Entry { + typedef uint64_t Key; + Key key; + Key GetKey() const { return key; } +}; + +typedef util::ProbingHashTable Table; + +void Test(uint64_t entries, uint64_t lookups, float multiplier = 1.5) { + std::size_t size = Table::Size(entries, multiplier); + scoped_malloc backing(util::CallocOrThrow(size)); + Table table(backing.get(), size); + boost::random::mt19937 gen; + boost::random::uniform_int_distribution<> dist(std::numeric_limits::min(), std::numeric_limits::max()); + double start = UserTime(); + for (uint64_t i = 0; i < entries; ++i) { + Entry entry; + entry.key = dist(gen); + table.Insert(entry); + } + double inserted = UserTime(); + bool meaningless = true; + for (uint64_t i = 0; i < lookups; ++i) { + Table::ConstIterator it; + meaningless ^= table.Find(dist(gen), it); + } + std::cout << meaningless << ' ' << entries << ' ' << multiplier << ' ' << (inserted - start) << ' ' << (UserTime() - inserted) / static_cast(lookups) << std::endl; +} + +} // namespace +} // namespace util + +int main() { + for (uint64_t i = 1; i <= 10000000ULL; i *= 10) { + util::Test(i, 10000000); + } +} diff --git a/util/scoped.cc b/util/scoped.cc index de1d9e940..84f4344b7 100644 --- a/util/scoped.cc +++ b/util/scoped.cc @@ -7,6 +7,8 @@ namespace util { +// TODO: if we're really under memory pressure, don't allocate memory to +// display the error. MallocException::MallocException(std::size_t requested) throw() { *this << "for " << requested << " bytes "; } @@ -16,10 +18,6 @@ MallocException::~MallocException() throw() {} namespace { void *InspectAddr(void *addr, std::size_t requested, const char *func_name) { UTIL_THROW_IF_ARG(!addr && requested, MallocException, (requested), "in " << func_name); - // These routines are often used for large chunks of memory where huge pages help. -#if MADV_HUGEPAGE - madvise(addr, requested, MADV_HUGEPAGE); -#endif return addr; } } // namespace @@ -36,4 +34,10 @@ void scoped_malloc::call_realloc(std::size_t requested) { p_ = InspectAddr(std::realloc(p_, requested), requested, "realloc"); } +void AdviseHugePages(const void *addr, std::size_t size) { +#if MADV_HUGEPAGE + madvise((void*)addr, size, MADV_HUGEPAGE); +#endif +} + } // namespace util diff --git a/util/scoped.hh b/util/scoped.hh index c347a43cc..21e9a7566 100644 --- a/util/scoped.hh +++ b/util/scoped.hh @@ -104,6 +104,8 @@ template class scoped_ptr : public scoped { explicit scoped_ptr(T *p = NULL) : scoped(p) {} }; +void AdviseHugePages(const void *addr, std::size_t size); + } // namespace util #endif // UTIL_SCOPED_H diff --git a/util/stream/Jamfile b/util/stream/Jamfile index 2e99979f5..cde0247e7 100644 --- a/util/stream/Jamfile +++ b/util/stream/Jamfile @@ -4,9 +4,10 @@ # timer-link = ; #} -fakelib stream : chain.cc io.cc line_input.cc multi_progress.cc ..//kenutil /top//boost_thread : : : /top//boost_thread ; +fakelib stream : chain.cc rewindable_stream.cc io.cc line_input.cc multi_progress.cc ..//kenutil /top//boost_thread : : : /top//boost_thread ; import testing ; unit-test io_test : io_test.cc stream /top//boost_unit_test_framework ; unit-test stream_test : stream_test.cc stream /top//boost_unit_test_framework ; +unit-test rewindable_stream_test : rewindable_stream_test.cc stream /top//boost_unit_test_framework ; unit-test sort_test : sort_test.cc stream /top//boost_unit_test_framework ; diff --git a/util/stream/block.hh b/util/stream/block.hh index 6a70dba3e..42df13f32 100644 --- a/util/stream/block.hh +++ b/util/stream/block.hh @@ -72,6 +72,7 @@ class Block { private: friend class Link; + friend class RewindableStream; /** * Points this block's memory at NULL. diff --git a/util/stream/chain.hh b/util/stream/chain.hh index 0cd8c2aae..8caa1afcb 100644 --- a/util/stream/chain.hh +++ b/util/stream/chain.hh @@ -23,6 +23,7 @@ class ChainConfigException : public Exception { }; class Chain; +class RewindableStream; /** * Encapsulates a @ref PCQueue "producer queue" and a @ref PCQueue "consumer queue" within a @ref Chain "chain". @@ -35,6 +36,7 @@ class ChainPosition { private: friend class Chain; friend class Link; + friend class RewindableStream; ChainPosition(PCQueue &in, PCQueue &out, Chain *chain, MultiProgress &progress) : in_(&in), out_(&out), chain_(chain), progress_(progress.Add()) {} diff --git a/util/stream/io.hh b/util/stream/io.hh index c3b53bbfe..4605a8a79 100644 --- a/util/stream/io.hh +++ b/util/stream/io.hh @@ -70,8 +70,8 @@ class FileBuffer { return PWriteAndRecycle(file_.get()); } - PRead Source() const { - return PRead(file_.get()); + PRead Source(bool discard = false) { + return PRead(discard ? file_.release() : file_.get(), discard); } uint64_t Size() const { diff --git a/util/stream/rewindable_stream.cc b/util/stream/rewindable_stream.cc new file mode 100644 index 000000000..c7e39231b --- /dev/null +++ b/util/stream/rewindable_stream.cc @@ -0,0 +1,117 @@ +#include "util/stream/rewindable_stream.hh" +#include "util/pcqueue.hh" + +namespace util { +namespace stream { + +RewindableStream::RewindableStream() + : current_(NULL), in_(NULL), out_(NULL), poisoned_(true) { + // nothing +} + +void RewindableStream::Init(const ChainPosition &position) { + UTIL_THROW_IF2(in_, "RewindableStream::Init twice"); + in_ = position.in_; + out_ = position.out_; + poisoned_ = false; + progress_ = position.progress_; + entry_size_ = position.GetChain().EntrySize(); + block_size_ = position.GetChain().BlockSize(); + FetchBlock(); + current_bl_ = &second_bl_; + current_ = static_cast(current_bl_->Get()); + end_ = current_ + current_bl_->ValidSize(); +} + +const void *RewindableStream::Get() const { + return current_; +} + +void *RewindableStream::Get() { + return current_; +} + +RewindableStream &RewindableStream::operator++() { + assert(*this); + assert(current_ < end_); + current_ += entry_size_; + if (current_ == end_) { + // two cases: either we need to fetch the next block, or we've already + // fetched it before. We can check this by looking at the current_bl_ + // pointer: if it's at the second_bl_, we need to flush and fetch a new + // block. Otherwise, we can just move over to the second block. + if (current_bl_ == &second_bl_) { + if (first_bl_) { + out_->Produce(first_bl_); + progress_ += first_bl_.ValidSize(); + } + first_bl_ = second_bl_; + FetchBlock(); + } + current_bl_ = &second_bl_; + current_ = static_cast(second_bl_.Get()); + end_ = current_ + second_bl_.ValidSize(); + } + + if (!*current_bl_) + { + if (current_bl_ == &second_bl_ && first_bl_) + { + out_->Produce(first_bl_); + progress_ += first_bl_.ValidSize(); + } + out_->Produce(*current_bl_); + poisoned_ = true; + } + + return *this; +} + +void RewindableStream::FetchBlock() { + // The loop is needed since it is *feasible* that we're given 0 sized but + // valid blocks + do { + in_->Consume(second_bl_); + } while (second_bl_ && second_bl_.ValidSize() == 0); +} + +void RewindableStream::Mark() { + marked_ = current_; +} + +void RewindableStream::Rewind() { + if (marked_ >= first_bl_.Get() && marked_ < first_bl_.ValidEnd()) { + current_bl_ = &first_bl_; + current_ = marked_; + } else if (marked_ >= second_bl_.Get() && marked_ < second_bl_.ValidEnd()) { + current_bl_ = &second_bl_; + current_ = marked_; + } else { UTIL_THROW2("RewindableStream rewound too far"); } +} + +void RewindableStream::Poison() { + assert(!poisoned_); + + // Three things: if we have a buffered first block, we need to produce it + // first. Then, produce the partial "current" block, and then send the + // poison down the chain + + // if we still have a buffered first block, produce it first + if (current_bl_ == &second_bl_ && first_bl_) { + out_->Produce(first_bl_); + progress_ += first_bl_.ValidSize(); + } + + // send our partial block + current_bl_->SetValidSize(current_ + - static_cast(current_bl_->Get())); + out_->Produce(*current_bl_); + progress_ += current_bl_->ValidSize(); + + // send down the poison + current_bl_->SetToPoison(); + out_->Produce(*current_bl_); + poisoned_ = true; +} +} +} diff --git a/util/stream/rewindable_stream.hh b/util/stream/rewindable_stream.hh new file mode 100644 index 000000000..9ee637c99 --- /dev/null +++ b/util/stream/rewindable_stream.hh @@ -0,0 +1,108 @@ +#ifndef UTIL_STREAM_REWINDABLE_STREAM_H +#define UTIL_STREAM_REWINDABLE_STREAM_H + +#include "util/stream/chain.hh" + +#include + +namespace util { +namespace stream { + +/** + * A RewindableStream is like a Stream (but one that is only used for + * creating input at the start of a chain) except that it can be rewound to + * be able to re-write a part of the stream before it is sent. Rewinding + * has a limit of 2 * block_size_ - 1 in distance (it does *not* buffer an + * entire stream into memory, only a maximum of 2 * block_size_). + */ +class RewindableStream : boost::noncopyable { + public: + /** + * Creates an uninitialized RewindableStream. You **must** call Init() + * on it later! + */ + RewindableStream(); + + /** + * Initializes an existing RewindableStream at a specific position in + * a Chain. + * + * @param position The position in the chain to get input from and + * produce output on + */ + void Init(const ChainPosition &position); + + /** + * Constructs a RewindableStream at a specific position in a Chain all + * in one step. + * + * Equivalent to RewindableStream a(); a.Init(....); + */ + explicit RewindableStream(const ChainPosition &position); + + /** + * Gets the record at the current stream position. Const version. + */ + const void *Get() const; + + /** + * Gets the record at the current stream position. + */ + void *Get(); + + operator bool() const { return current_; } + + bool operator!() const { return !(*this); } + + /** + * Marks the current position in the stream to be rewound to later. + * Note that you can only rewind back as far as 2 * block_size_ - 1! + */ + void Mark(); + + /** + * Rewinds the stream back to the marked position. This will throw an + * exception if the marked position is too far away. + */ + void Rewind(); + + /** + * Moves the stream forward to the next record. This internally may + * buffer a block for the purposes of rewinding. + */ + RewindableStream& operator++(); + + /** + * Poisons the stream. This sends any buffered blocks down the chain + * and sends a poison block as well (sending at most 2 non-poison and 1 + * poison block). + */ + void Poison(); + + private: + void FetchBlock(); + + std::size_t entry_size_; + std::size_t block_size_; + + uint8_t *marked_, *current_, *end_; + + Block first_bl_; + Block second_bl_; + Block* current_bl_; + + PCQueue *in_, *out_; + + bool poisoned_; + + WorkerProgress progress_; +}; + +inline Chain &operator>>(Chain &chain, RewindableStream &stream) { + stream.Init(chain.Add()); + return chain; +} + +} +} +#endif diff --git a/util/stream/rewindable_stream_test.cc b/util/stream/rewindable_stream_test.cc new file mode 100644 index 000000000..3ed87f372 --- /dev/null +++ b/util/stream/rewindable_stream_test.cc @@ -0,0 +1,41 @@ +#include "util/stream/io.hh" + +#include "util/stream/rewindable_stream.hh" +#include "util/file.hh" + +#define BOOST_TEST_MODULE RewindableStreamTest +#include + +namespace util { +namespace stream { +namespace { + +BOOST_AUTO_TEST_CASE(RewindableStreamTest) { + scoped_fd in(MakeTemp("io_test_temp")); + for (uint64_t i = 0; i < 100000; ++i) { + WriteOrThrow(in.get(), &i, sizeof(uint64_t)); + } + SeekOrThrow(in.get(), 0); + + ChainConfig config; + config.entry_size = 8; + config.total_memory = 100; + config.block_count = 6; + + RewindableStream s; + Chain chain(config); + chain >> Read(in.get()) >> s >> kRecycle; + uint64_t i = 0; + for (; s; ++s, ++i) { + BOOST_CHECK_EQUAL(i, *static_cast(s.Get())); + if (100000UL - i == 2) + s.Mark(); + } + BOOST_CHECK_EQUAL(100000ULL, i); + s.Rewind(); + BOOST_CHECK_EQUAL(100000ULL - 2, *static_cast(s.Get())); +} + +} +} +} diff --git a/util/stream/sort.hh b/util/stream/sort.hh index a1e0a8539..1b4801ad6 100644 --- a/util/stream/sort.hh +++ b/util/stream/sort.hh @@ -25,6 +25,7 @@ #include "util/stream/timer.hh" #include "util/file.hh" +#include "util/fixed_array.hh" #include "util/scoped.hh" #include "util/sized_iterator.hh" @@ -544,6 +545,54 @@ template uint64_t BlockingSort(Chain &chain, cons return size; } +/** + * Represents an @ref util::FixedArray "array" capable of storing @ref util::stream::Sort "Sort" objects. + * + * In the anticipated use case, an instance of this class will maintain one @ref util::stream::Sort "Sort" object + * for each n-gram order (ranging from 1 up to the maximum n-gram order being processed). + * Use in this manner would enable the n-grams each n-gram order to be sorted, in parallel. + * + * @tparam Compare An @ref Comparator "ngram comparator" to use during sorting. + */ +template class Sorts : public FixedArray > { + private: + typedef Sort S; + typedef FixedArray P; + + public: + /** + * Constructs, but does not initialize. + * + * @ref util::FixedArray::Init() "Init" must be called before use. + * + * @see util::FixedArray::Init() + */ + Sorts() {} + + /** + * Constructs an @ref util::FixedArray "array" capable of storing a fixed number of @ref util::stream::Sort "Sort" objects. + * + * @param number The maximum number of @ref util::stream::Sort "sorters" that can be held by this @ref util::FixedArray "array" + * @see util::FixedArray::FixedArray() + */ + explicit Sorts(std::size_t number) : FixedArray >(number) {} + + /** + * Constructs a new @ref util::stream::Sort "Sort" object which is stored in this @ref util::FixedArray "array". + * + * The new @ref util::stream::Sort "Sort" object is constructed using the provided @ref util::stream::SortConfig "SortConfig" and @ref Comparator "ngram comparator"; + * once constructed, a new worker @ref util::stream::Thread "thread" (owned by the @ref util::stream::Chain "chain") will sort the n-gram data stored + * in the @ref util::stream::Block "blocks" of the provided @ref util::stream::Chain "chain". + * + * @see util::stream::Sort::Sort() + * @see util::stream::Chain::operator>>() + */ + void push_back(util::stream::Chain &chain, const util::stream::SortConfig &config, const Compare &compare = Compare(), const Combine &combine = Combine()) { + new (P::end()) S(chain, config, compare, combine); // use "placement new" syntax to initalize S in an already-allocated memory location + P::Constructed(); + } +}; + } // namespace stream } // namespace util diff --git a/util/tempfile.hh b/util/tempfile.hh index 9b872a27e..9c28346fc 100644 --- a/util/tempfile.hh +++ b/util/tempfile.hh @@ -10,13 +10,14 @@ #if defined(_WIN32) || defined(_WIN64) #include +#else +#include #endif #include #include #include "util/exception.hh" -#include "util/unistd.hh" namespace util { diff --git a/util/unistd.hh b/util/unistd.hh deleted file mode 100644 index f99be592a..000000000 --- a/util/unistd.hh +++ /dev/null @@ -1,22 +0,0 @@ -#ifndef UTIL_UNISTD_H -#define UTIL_UNISTD_H - -#if (defined(_WIN32) || defined(_WIN64)) && !defined(__MINGW32__) - -// Windows doesn't define -// -// So we define what we need here instead: -// -#define STDIN_FILENO=0 -#define STDOUT_FILENO=1 - - -#else // Huzzah for POSIX! - -#include - -#endif - - - -#endif // UTIL_UNISTD_H diff --git a/util/usage.cc b/util/usage.cc index f2b661014..5f66b17d2 100644 --- a/util/usage.cc +++ b/util/usage.cc @@ -135,6 +135,16 @@ double WallTime() { return Subtract(GetWall(), kRecordStart.Started()); } +double UserTime() { +#if !defined(_WIN32) && !defined(_WIN64) + struct rusage usage; + if (getrusage(RUSAGE_SELF, &usage)) + return 0.0; + return DoubleSec(usage.ru_utime); +#endif + return 0.0; +} + void PrintUsage(std::ostream &out) { #if !defined(_WIN32) && !defined(_WIN64) // Linux doesn't set memory usage in getrusage :-( diff --git a/util/usage.hh b/util/usage.hh index 85bd2119a..dff81b59d 100644 --- a/util/usage.hh +++ b/util/usage.hh @@ -9,6 +9,8 @@ namespace util { // Time in seconds since process started. Zero on unsupported platforms. double WallTime(); +double UserTime(); + void PrintUsage(std::ostream &to); // Determine how much physical memory there is. Return 0 on failure. @@ -16,5 +18,6 @@ uint64_t GuessPhysicalMemory(); // Parse a size like unix sort. Sadly, this means the default multiplier is K. uint64_t ParseSize(const std::string &arg); + } // namespace util #endif // UTIL_USAGE_H