KenLM 5a7efd8fe1db88ee0a9f7e9479b24ac3ca348221 with Hieu's patch to exception.hh

This commit is contained in:
Kenneth Heafield 2014-06-02 10:28:02 -07:00
parent 31a583b0bc
commit dd03f9fb69
111 changed files with 2209 additions and 798 deletions

View File

@ -13,7 +13,16 @@ update-if-changed $(ORDER-LOG) $(max-order) ;
max-order += <dependency>$(ORDER-LOG) ; max-order += <dependency>$(ORDER-LOG) ;
fakelib kenlm : [ glob *.cc : *main.cc *test.cc ] ../util//kenutil : <include>.. $(max-order) : : <include>.. $(max-order) ; wrappers = ;
local with-nplm = [ option.get "with-nplm" ] ;
if $(with-nplm) {
lib neuralLM : : <search>$(with-nplm)/src ;
obj nplm.o : wrappers/nplm.cc : <include>.. <include>$(with-nplm)/src <cxxflags>-fopenmp ;
alias nplm : nplm.o neuralLM ..//boost_thread : : : <cxxflags>-fopenmp <linkflags>-fopenmp <define>WITH_NPLM <library>..//boost_thread ;
wrappers += nplm ;
}
fakelib kenlm : $(wrappers) [ glob *.cc : *main.cc *test.cc ] ../util//kenutil : <include>.. $(max-order) : : <include>.. $(max-order) ;
import testing ; import testing ;

View File

@ -10,8 +10,8 @@
* Currently only used for next pointers. * Currently only used for next pointers.
*/ */
#ifndef LM_BHIKSHA__ #ifndef LM_BHIKSHA_H
#define LM_BHIKSHA__ #define LM_BHIKSHA_H
#include <stdint.h> #include <stdint.h>
#include <assert.h> #include <assert.h>
@ -109,4 +109,4 @@ class ArrayBhiksha {
} // namespace ngram } // namespace ngram
} // namespace lm } // namespace lm
#endif // LM_BHIKSHA__ #endif // LM_BHIKSHA_H

View File

@ -149,7 +149,7 @@ void BinaryFormat::InitializeBinary(int fd, ModelType model_type, unsigned int s
void BinaryFormat::ReadForConfig(void *to, std::size_t amount, uint64_t offset_excluding_header) const { void BinaryFormat::ReadForConfig(void *to, std::size_t amount, uint64_t offset_excluding_header) const {
assert(header_size_ != kInvalidSize); assert(header_size_ != kInvalidSize);
util::PReadOrThrow(file_.get(), to, amount, offset_excluding_header + header_size_); util::ErsatzPRead(file_.get(), to, amount, offset_excluding_header + header_size_);
} }
void *BinaryFormat::LoadBinary(std::size_t size) { void *BinaryFormat::LoadBinary(std::size_t size) {

View File

@ -1,5 +1,5 @@
#ifndef LM_BINARY_FORMAT__ #ifndef LM_BINARY_FORMAT_H
#define LM_BINARY_FORMAT__ #define LM_BINARY_FORMAT_H
#include "lm/config.hh" #include "lm/config.hh"
#include "lm/model_type.hh" #include "lm/model_type.hh"
@ -103,4 +103,4 @@ bool IsBinaryFormat(int fd);
} // namespace ngram } // namespace ngram
} // namespace lm } // namespace lm
#endif // LM_BINARY_FORMAT__ #endif // LM_BINARY_FORMAT_H

View File

@ -1,5 +1,5 @@
#ifndef LM_BLANK__ #ifndef LM_BLANK_H
#define LM_BLANK__ #define LM_BLANK_H
#include <limits> #include <limits>
@ -40,4 +40,4 @@ inline bool HasExtension(const float &backoff) {
} // namespace ngram } // namespace ngram
} // namespace lm } // namespace lm
#endif // LM_BLANK__ #endif // LM_BLANK_H

View File

@ -1,8 +1,9 @@
#include "lm/builder/adjust_counts.hh" #include "lm/builder/adjust_counts.hh"
#include "lm/builder/multi_stream.hh" #include "lm/builder/ngram_stream.hh"
#include "util/stream/timer.hh" #include "util/stream/timer.hh"
#include <algorithm> #include <algorithm>
#include <iostream>
namespace lm { namespace builder { namespace lm { namespace builder {
@ -10,19 +11,19 @@ BadDiscountException::BadDiscountException() throw() {}
BadDiscountException::~BadDiscountException() throw() {} BadDiscountException::~BadDiscountException() throw() {}
namespace { namespace {
// Return last word in full that is different. // Return last word in full that is different.
const WordIndex* FindDifference(const NGram &full, const NGram &lower_last) { const WordIndex* FindDifference(const NGram &full, const NGram &lower_last) {
const WordIndex *cur_word = full.end() - 1; const WordIndex *cur_word = full.end() - 1;
const WordIndex *pre_word = lower_last.end() - 1; const WordIndex *pre_word = lower_last.end() - 1;
// Find last difference. // Find last difference.
for (; pre_word >= lower_last.begin() && *pre_word == *cur_word; --cur_word, --pre_word) {} for (; pre_word >= lower_last.begin() && *pre_word == *cur_word; --cur_word, --pre_word) {}
return cur_word; return cur_word;
} }
class StatCollector { class StatCollector {
public: public:
StatCollector(std::size_t order, std::vector<uint64_t> &counts, std::vector<Discount> &discounts) StatCollector(std::size_t order, std::vector<uint64_t> &counts, std::vector<uint64_t> &counts_pruned, std::vector<Discount> &discounts)
: orders_(order), full_(orders_.back()), counts_(counts), discounts_(discounts) { : orders_(order), full_(orders_.back()), counts_(counts), counts_pruned_(counts_pruned), discounts_(discounts) {
memset(&orders_[0], 0, sizeof(OrderStat) * order); memset(&orders_[0], 0, sizeof(OrderStat) * order);
} }
@ -30,10 +31,12 @@ class StatCollector {
void CalculateDiscounts() { void CalculateDiscounts() {
counts_.resize(orders_.size()); counts_.resize(orders_.size());
counts_pruned_.resize(orders_.size());
discounts_.resize(orders_.size()); discounts_.resize(orders_.size());
for (std::size_t i = 0; i < orders_.size(); ++i) { for (std::size_t i = 0; i < orders_.size(); ++i) {
const OrderStat &s = orders_[i]; const OrderStat &s = orders_[i];
counts_[i] = s.count; counts_[i] = s.count;
counts_pruned_[i] = s.count_pruned;
for (unsigned j = 1; j < 4; ++j) { for (unsigned j = 1; j < 4; ++j) {
// TODO: Specialize error message for j == 3, meaning 3+ // TODO: Specialize error message for j == 3, meaning 3+
@ -52,14 +55,18 @@ class StatCollector {
} }
} }
void Add(std::size_t order_minus_1, uint64_t count) { void Add(std::size_t order_minus_1, uint64_t count, bool pruned = false) {
OrderStat &stat = orders_[order_minus_1]; OrderStat &stat = orders_[order_minus_1];
++stat.count; ++stat.count;
if (!pruned)
++stat.count_pruned;
if (count < 5) ++stat.n[count]; if (count < 5) ++stat.n[count];
} }
void AddFull(uint64_t count) { void AddFull(uint64_t count, bool pruned = false) {
++full_.count; ++full_.count;
if (!pruned)
++full_.count_pruned;
if (count < 5) ++full_.n[count]; if (count < 5) ++full_.n[count];
} }
@ -68,24 +75,27 @@ class StatCollector {
// n_1 in equation 26 of Chen and Goodman etc // n_1 in equation 26 of Chen and Goodman etc
uint64_t n[5]; uint64_t n[5];
uint64_t count; uint64_t count;
uint64_t count_pruned;
}; };
std::vector<OrderStat> orders_; std::vector<OrderStat> orders_;
OrderStat &full_; OrderStat &full_;
std::vector<uint64_t> &counts_; std::vector<uint64_t> &counts_;
std::vector<uint64_t> &counts_pruned_;
std::vector<Discount> &discounts_; std::vector<Discount> &discounts_;
}; };
// Reads all entries in order like NGramStream does. // Reads all entries in order like NGramStream does.
// But deletes any entries that have <s> in the 1st (not 0th) position on the // But deletes any entries that have <s> in the 1st (not 0th) position on the
// way out by putting other entries in their place. This disrupts the sort // way out by putting other entries in their place. This disrupts the sort
// order but we don't care because the data is going to be sorted again. // order but we don't care because the data is going to be sorted again.
class CollapseStream { class CollapseStream {
public: public:
CollapseStream(const util::stream::ChainPosition &position) : CollapseStream(const util::stream::ChainPosition &position, uint64_t prune_threshold) :
current_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())), current_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())),
block_(position) { prune_threshold_(prune_threshold),
block_(position) {
StartBlock(); StartBlock();
} }
@ -96,10 +106,18 @@ class CollapseStream {
CollapseStream &operator++() { CollapseStream &operator++() {
assert(block_); assert(block_);
if (current_.begin()[1] == kBOS && current_.Base() < copy_from_) { if (current_.begin()[1] == kBOS && current_.Base() < copy_from_) {
memcpy(current_.Base(), copy_from_, current_.TotalSize()); memcpy(current_.Base(), copy_from_, current_.TotalSize());
UpdateCopyFrom(); UpdateCopyFrom();
// Mark highest order n-grams for later pruning
if(current_.Count() <= prune_threshold_) {
current_.Mark();
}
} }
current_.NextInMemory(); current_.NextInMemory();
uint8_t *block_base = static_cast<uint8_t*>(block_->Get()); uint8_t *block_base = static_cast<uint8_t*>(block_->Get());
if (current_.Base() == block_base + block_->ValidSize()) { if (current_.Base() == block_base + block_->ValidSize()) {
@ -107,6 +125,12 @@ class CollapseStream {
++block_; ++block_;
StartBlock(); StartBlock();
} }
// Mark highest order n-grams for later pruning
if(current_.Count() <= prune_threshold_) {
current_.Mark();
}
return *this; return *this;
} }
@ -119,9 +143,15 @@ class CollapseStream {
current_.ReBase(block_->Get()); current_.ReBase(block_->Get());
copy_from_ = static_cast<uint8_t*>(block_->Get()) + block_->ValidSize(); copy_from_ = static_cast<uint8_t*>(block_->Get()) + block_->ValidSize();
UpdateCopyFrom(); UpdateCopyFrom();
// Mark highest order n-grams for later pruning
if(current_.Count() <= prune_threshold_) {
current_.Mark();
}
} }
// Find last without bos. // Find last without bos.
void UpdateCopyFrom() { void UpdateCopyFrom() {
for (copy_from_ -= current_.TotalSize(); copy_from_ >= current_.Base(); copy_from_ -= current_.TotalSize()) { for (copy_from_ -= current_.TotalSize(); copy_from_ >= current_.Base(); copy_from_ -= current_.TotalSize()) {
if (NGram(copy_from_, current_.Order()).begin()[1] != kBOS) break; if (NGram(copy_from_, current_.Order()).begin()[1] != kBOS) break;
@ -132,79 +162,103 @@ class CollapseStream {
// Goes backwards in the block // Goes backwards in the block
uint8_t *copy_from_; uint8_t *copy_from_;
uint64_t prune_threshold_;
util::stream::Link block_; util::stream::Link block_;
}; };
} // namespace } // namespace
void AdjustCounts::Run(const ChainPositions &positions) { void AdjustCounts::Run(const util::stream::ChainPositions &positions) {
UTIL_TIMER("(%w s) Adjusted counts\n"); UTIL_TIMER("(%w s) Adjusted counts\n");
const std::size_t order = positions.size(); const std::size_t order = positions.size();
StatCollector stats(order, counts_, discounts_); StatCollector stats(order, counts_, counts_pruned_, discounts_);
if (order == 1) { if (order == 1) {
// Only unigrams. Just collect stats. // Only unigrams. Just collect stats.
for (NGramStream full(positions[0]); full; ++full) for (NGramStream full(positions[0]); full; ++full)
stats.AddFull(full->Count()); stats.AddFull(full->Count());
stats.CalculateDiscounts(); stats.CalculateDiscounts();
return; return;
} }
NGramStreams streams; NGramStreams streams;
streams.Init(positions, positions.size() - 1); streams.Init(positions, positions.size() - 1);
CollapseStream full(positions[positions.size() - 1]);
CollapseStream full(positions[positions.size() - 1], prune_thresholds_.back());
// Initialization: <unk> has count 0 and so does <s>. // Initialization: <unk> has count 0 and so does <s>.
NGramStream *lower_valid = streams.begin(); NGramStream *lower_valid = streams.begin();
streams[0]->Count() = 0; streams[0]->Count() = 0;
*streams[0]->begin() = kUNK; *streams[0]->begin() = kUNK;
stats.Add(0, 0); stats.Add(0, 0);
(++streams[0])->Count() = 0; (++streams[0])->Count() = 0;
*streams[0]->begin() = kBOS; *streams[0]->begin() = kBOS;
// not in stats because it will get put in later. // not in stats because it will get put in later.
std::vector<uint64_t> lower_counts(positions.size(), 0);
// iterate over full (the stream of the highest order ngrams) // iterate over full (the stream of the highest order ngrams)
for (; full; ++full) { for (; full; ++full) {
const WordIndex *different = FindDifference(*full, **lower_valid); const WordIndex *different = FindDifference(*full, **lower_valid);
std::size_t same = full->end() - 1 - different; std::size_t same = full->end() - 1 - different;
// Increment the adjusted count. // Increment the adjusted count.
if (same) ++streams[same - 1]->Count(); if (same) ++streams[same - 1]->Count();
// Output all the valid ones that changed. // Output all the valid ones that changed.
for (; lower_valid >= &streams[same]; --lower_valid) { for (; lower_valid >= &streams[same]; --lower_valid) {
stats.Add(lower_valid - streams.begin(), (*lower_valid)->Count());
// mjd: review this!
uint64_t order = (*lower_valid)->Order();
uint64_t realCount = lower_counts[order - 1];
if(order > 1 && prune_thresholds_[order - 1] && realCount <= prune_thresholds_[order - 1])
(*lower_valid)->Mark();
stats.Add(lower_valid - streams.begin(), (*lower_valid)->UnmarkedCount(), (*lower_valid)->IsMarked());
++*lower_valid; ++*lower_valid;
} }
// Count the true occurrences of lower-order n-grams
for (std::size_t i = 0; i < lower_counts.size(); ++i) {
if (i >= same) {
lower_counts[i] = 0;
}
lower_counts[i] += full->UnmarkedCount();
}
// This is here because bos is also const WordIndex *, so copy gets // This is here because bos is also const WordIndex *, so copy gets
// consistent argument types. // consistent argument types.
const WordIndex *full_end = full->end(); const WordIndex *full_end = full->end();
// Initialize and mark as valid up to bos. // Initialize and mark as valid up to bos.
const WordIndex *bos; const WordIndex *bos;
for (bos = different; (bos > full->begin()) && (*bos != kBOS); --bos) { for (bos = different; (bos > full->begin()) && (*bos != kBOS); --bos) {
++lower_valid; ++lower_valid;
std::copy(bos, full_end, (*lower_valid)->begin()); std::copy(bos, full_end, (*lower_valid)->begin());
(*lower_valid)->Count() = 1; (*lower_valid)->Count() = 1;
} }
// Now bos indicates where <s> is or is the 0th word of full. // Now bos indicates where <s> is or is the 0th word of full.
if (bos != full->begin()) { if (bos != full->begin()) {
// There is an <s> beyond the 0th word. // There is an <s> beyond the 0th word.
NGramStream &to = *++lower_valid; NGramStream &to = *++lower_valid;
std::copy(bos, full_end, to->begin()); std::copy(bos, full_end, to->begin());
to->Count() = full->Count();
// mjd: what is this doing?
to->Count() = full->UnmarkedCount();
} else { } else {
stats.AddFull(full->Count()); stats.AddFull(full->UnmarkedCount(), full->IsMarked());
} }
assert(lower_valid >= &streams[0]); assert(lower_valid >= &streams[0]);
} }
// Output everything valid. // Output everything valid.
for (NGramStream *s = streams.begin(); s <= lower_valid; ++s) { for (NGramStream *s = streams.begin(); s <= lower_valid; ++s) {
stats.Add(s - streams.begin(), (*s)->Count()); if((*s)->Count() <= prune_thresholds_[(*s)->Order() - 1])
(*s)->Mark();
stats.Add(s - streams.begin(), (*s)->UnmarkedCount(), (*s)->IsMarked());
++*s; ++*s;
} }
// Poison everyone! Except the N-grams which were already poisoned by the input. // Poison everyone! Except the N-grams which were already poisoned by the input.
for (NGramStream *s = streams.begin(); s != streams.end(); ++s) for (NGramStream *s = streams.begin(); s != streams.end(); ++s)
s->Poison(); s->Poison();

View File

@ -1,5 +1,5 @@
#ifndef LM_BUILDER_ADJUST_COUNTS__ #ifndef LM_BUILDER_ADJUST_COUNTS_H
#define LM_BUILDER_ADJUST_COUNTS__ #define LM_BUILDER_ADJUST_COUNTS_H
#include "lm/builder/discount.hh" #include "lm/builder/discount.hh"
#include "util/exception.hh" #include "util/exception.hh"
@ -8,11 +8,11 @@
#include <stdint.h> #include <stdint.h>
namespace util { namespace stream { class ChainPositions; } }
namespace lm { namespace lm {
namespace builder { namespace builder {
class ChainPositions;
class BadDiscountException : public util::Exception { class BadDiscountException : public util::Exception {
public: public:
BadDiscountException() throw(); BadDiscountException() throw();
@ -27,18 +27,21 @@ class BadDiscountException : public util::Exception {
*/ */
class AdjustCounts { class AdjustCounts {
public: public:
AdjustCounts(std::vector<uint64_t> &counts, std::vector<Discount> &discounts) AdjustCounts(std::vector<uint64_t> &counts, std::vector<uint64_t> &counts_pruned, std::vector<Discount> &discounts, std::vector<uint64_t> &prune_thresholds)
: counts_(counts), discounts_(discounts) {} : counts_(counts), counts_pruned_(counts_pruned), discounts_(discounts), prune_thresholds_(prune_thresholds)
{}
void Run(const ChainPositions &positions); void Run(const util::stream::ChainPositions &positions);
private: private:
std::vector<uint64_t> &counts_; std::vector<uint64_t> &counts_;
std::vector<uint64_t> &counts_pruned_;
std::vector<Discount> &discounts_; std::vector<Discount> &discounts_;
std::vector<uint64_t> &prune_thresholds_;
}; };
} // namespace builder } // namespace builder
} // namespace lm } // namespace lm
#endif // LM_BUILDER_ADJUST_COUNTS__ #endif // LM_BUILDER_ADJUST_COUNTS_H

View File

@ -1,6 +1,6 @@
#include "lm/builder/adjust_counts.hh" #include "lm/builder/adjust_counts.hh"
#include "lm/builder/multi_stream.hh" #include "lm/builder/ngram_stream.hh"
#include "util/scoped.hh" #include "util/scoped.hh"
#include <boost/thread/thread.hpp> #include <boost/thread/thread.hpp>
@ -61,19 +61,21 @@ BOOST_AUTO_TEST_CASE(Simple) {
util::stream::ChainConfig config; util::stream::ChainConfig config;
config.total_memory = 100; config.total_memory = 100;
config.block_count = 1; config.block_count = 1;
Chains chains(4); util::stream::Chains chains(4);
for (unsigned i = 0; i < 4; ++i) { for (unsigned i = 0; i < 4; ++i) {
config.entry_size = NGram::TotalSize(i + 1); config.entry_size = NGram::TotalSize(i + 1);
chains.push_back(config); chains.push_back(config);
} }
chains[3] >> WriteInput(); chains[3] >> WriteInput();
ChainPositions for_adjust(chains); util::stream::ChainPositions for_adjust(chains);
for (unsigned i = 0; i < 4; ++i) { for (unsigned i = 0; i < 4; ++i) {
chains[i] >> boost::ref(outputs[i]); chains[i] >> boost::ref(outputs[i]);
} }
chains >> util::stream::kRecycle; chains >> util::stream::kRecycle;
BOOST_CHECK_THROW(AdjustCounts(counts, discount).Run(for_adjust), BadDiscountException); std::vector<uint64_t> counts_pruned(4);
std::vector<uint64_t> prune_thresholds(4);
BOOST_CHECK_THROW(AdjustCounts(counts, counts_pruned, discount, prune_thresholds).Run(for_adjust), BadDiscountException);
} }
BOOST_REQUIRE_EQUAL(4UL, counts.size()); BOOST_REQUIRE_EQUAL(4UL, counts.size());
BOOST_CHECK_EQUAL(4UL, counts[0]); BOOST_CHECK_EQUAL(4UL, counts[0]);

View File

@ -2,6 +2,7 @@
#include "lm/builder/ngram.hh" #include "lm/builder/ngram.hh"
#include "lm/lm_exception.hh" #include "lm/lm_exception.hh"
#include "lm/vocab.hh"
#include "lm/word_index.hh" #include "lm/word_index.hh"
#include "util/fake_ofstream.hh" #include "util/fake_ofstream.hh"
#include "util/file.hh" #include "util/file.hh"
@ -37,60 +38,6 @@ struct VocabEntry {
}; };
#pragma pack(pop) #pragma pack(pop)
const float kProbingMultiplier = 1.5;
// Assigns word ids to strings using a probing hash table keyed on the
// MurmurHash of each word, and writes every newly seen word (followed by a
// '\0' byte) to word_list_, which is constructed from the fd passed in.
// Ids are handed out in order of first appearance: the id of a new word is
// the table size at the time it is inserted.
class VocabHandout {
public:
// Bytes needed for a table sized for initial_guess entries (clamped to at
// least 2) at kProbingMultiplier; the result goes through util::CheckOverflow.
static std::size_t MemUsage(WordIndex initial_guess) {
if (initial_guess < 2) initial_guess = 2;
return util::CheckOverflow(Table::Size(initial_guess, kProbingMultiplier));
}
// fd: passed to word_list_, which receives the '\0'-terminated words.
// initial_guess: expected vocabulary size; sizes the initial table and sets
// the first growth cutoff to max(initial_guess * 1.1, 1).
// The three special tokens are looked up first so they receive ids 0, 1, 2.
explicit VocabHandout(int fd, WordIndex initial_guess) :
table_backing_(util::CallocOrThrow(MemUsage(initial_guess))),
table_(table_backing_.get(), MemUsage(initial_guess)),
double_cutoff_(std::max<std::size_t>(initial_guess * 1.1, 1)),
word_list_(fd) {
Lookup("<unk>"); // Force 0
Lookup("<s>"); // Force 1
Lookup("</s>"); // Force 2
}
// Returns the id for word, inserting it if it has not been seen before.
// Throws VocabLoadException (via UTIL_THROW_IF) once the vocabulary size
// reaches the maximum value of lm::WordIndex.
WordIndex Lookup(const StringPiece &word) {
VocabEntry entry;
entry.key = util::MurmurHashNative(word.data(), word.size());
// Candidate id for a new word: the number of entries currently stored.
entry.value = table_.SizeNoSerialization();
Table::MutableIterator it;
// NOTE(review): FindOrInsert appears to return true when the key was
// already present, leaving `it` on the existing entry - confirm against
// util::ProbingHashTable.
if (table_.FindOrInsert(entry, it))
return it->value;
// New word: record it in the word list.
word_list_ << word << '\0';
UTIL_THROW_IF(Size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh.");
// Grow the table once the entry count reaches the cutoff: reallocate the
// backing memory to the doubled size, let the table rehash into it, and
// double the cutoff.
if (Size() >= double_cutoff_) {
table_backing_.call_realloc(table_.DoubleTo());
table_.Double(table_backing_.get());
double_cutoff_ *= 2;
}
return entry.value;
}
// Number of entries currently in the table, i.e. distinct words seen so far.
WordIndex Size() const {
return table_.SizeNoSerialization();
}
private:
// TODO: factor out a resizable probing hash table.
// TODO: use mremap on linux to get all zeros on resizes.
// Memory backing table_; reallocated when the table doubles.
util::scoped_malloc table_backing_;
typedef util::ProbingHashTable<VocabEntry, util::IdentityHash> Table;
Table table_;
// Entry count at which the table is doubled next.
std::size_t double_cutoff_;
// Output stream receiving each new word followed by '\0'.
util::FakeOFStream word_list_;
};
class DedupeHash : public std::unary_function<const WordIndex *, bool> { class DedupeHash : public std::unary_function<const WordIndex *, bool> {
public: public:
explicit DedupeHash(std::size_t order) : size_(order * sizeof(WordIndex)) {} explicit DedupeHash(std::size_t order) : size_(order * sizeof(WordIndex)) {}
@ -127,6 +74,10 @@ struct DedupeEntry {
} }
}; };
// TODO: don't have this here, should be with probing hash table defaults?
const float kProbingMultiplier = 1.5;
typedef util::ProbingHashTable<DedupeEntry, DedupeHash, DedupeEquals> Dedupe; typedef util::ProbingHashTable<DedupeEntry, DedupeHash, DedupeEquals> Dedupe;
class Writer { class Writer {
@ -220,37 +171,50 @@ float CorpusCount::DedupeMultiplier(std::size_t order) {
} }
std::size_t CorpusCount::VocabUsage(std::size_t vocab_estimate) { std::size_t CorpusCount::VocabUsage(std::size_t vocab_estimate) {
return VocabHandout::MemUsage(vocab_estimate); return ngram::GrowableVocab<ngram::WriteUniqueWords>::MemUsage(vocab_estimate);
} }
CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block) CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block, WarningAction disallowed_symbol)
: from_(from), vocab_write_(vocab_write), token_count_(token_count), type_count_(type_count), : from_(from), vocab_write_(vocab_write), token_count_(token_count), type_count_(type_count),
dedupe_mem_size_(Dedupe::Size(entries_per_block, kProbingMultiplier)), dedupe_mem_size_(Dedupe::Size(entries_per_block, kProbingMultiplier)),
dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)) { dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)),
disallowed_symbol_action_(disallowed_symbol) {
} }
void CorpusCount::Run(const util::stream::ChainPosition &position) { namespace {
UTIL_TIMER("(%w s) Counted n-grams\n"); void ComplainDisallowed(StringPiece word, WarningAction &action) {
switch (action) {
case SILENT:
return;
case COMPLAIN:
std::cerr << "Warning: " << word << " appears in the input. All instances of <s>, </s>, and <unk> will be interpreted as whitespace." << std::endl;
action = SILENT;
return;
case THROW_UP:
UTIL_THROW(FormatLoadException, "Special word " << word << " is not allowed in the corpus. I plan to support models containing <unk> in the future. Pass --skip_symbols to convert these symbols to whitespace.");
}
}
} // namespace
VocabHandout vocab(vocab_write_, type_count_); void CorpusCount::Run(const util::stream::ChainPosition &position) {
ngram::GrowableVocab<ngram::WriteUniqueWords> vocab(type_count_, vocab_write_);
token_count_ = 0; token_count_ = 0;
type_count_ = 0; type_count_ = 0;
const WordIndex end_sentence = vocab.Lookup("</s>"); const WordIndex end_sentence = vocab.FindOrInsert("</s>");
Writer writer(NGram::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_); Writer writer(NGram::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_);
uint64_t count = 0; uint64_t count = 0;
bool delimiters[256]; bool delimiters[256];
memset(delimiters, 0, sizeof(delimiters)); util::BoolCharacter::Build("\0\t\n\r ", delimiters);
const char kDelimiterSet[] = "\0\t\n\r ";
for (const char *i = kDelimiterSet; i < kDelimiterSet + sizeof(kDelimiterSet); ++i) {
delimiters[static_cast<unsigned char>(*i)] = true;
}
try { try {
while(true) { while(true) {
StringPiece line(from_.ReadLine()); StringPiece line(from_.ReadLine());
writer.StartSentence(); writer.StartSentence();
for (util::TokenIter<util::BoolCharacter, true> w(line, delimiters); w; ++w) { for (util::TokenIter<util::BoolCharacter, true> w(line, delimiters); w; ++w) {
WordIndex word = vocab.Lookup(*w); WordIndex word = vocab.FindOrInsert(*w);
UTIL_THROW_IF(word <= 2, FormatLoadException, "Special word " << *w << " is not allowed in the corpus. I plan to support models containing <unk> in the future."); if (word <= 2) {
ComplainDisallowed(*w, disallowed_symbol_action_);
continue;
}
writer.Append(word); writer.Append(word);
++count; ++count;
} }

View File

@ -1,6 +1,7 @@
#ifndef LM_BUILDER_CORPUS_COUNT__ #ifndef LM_BUILDER_CORPUS_COUNT_H
#define LM_BUILDER_CORPUS_COUNT__ #define LM_BUILDER_CORPUS_COUNT_H
#include "lm/lm_exception.hh"
#include "lm/word_index.hh" #include "lm/word_index.hh"
#include "util/scoped.hh" #include "util/scoped.hh"
@ -28,7 +29,7 @@ class CorpusCount {
// token_count: out. // token_count: out.
// type_count aka vocabulary size. Initialize to an estimate. It is set to the exact value. // type_count aka vocabulary size. Initialize to an estimate. It is set to the exact value.
CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block); CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block, WarningAction disallowed_symbol);
void Run(const util::stream::ChainPosition &position); void Run(const util::stream::ChainPosition &position);
@ -40,8 +41,10 @@ class CorpusCount {
std::size_t dedupe_mem_size_; std::size_t dedupe_mem_size_;
util::scoped_malloc dedupe_mem_; util::scoped_malloc dedupe_mem_;
WarningAction disallowed_symbol_action_;
}; };
} // namespace builder } // namespace builder
} // namespace lm } // namespace lm
#endif // LM_BUILDER_CORPUS_COUNT__ #endif // LM_BUILDER_CORPUS_COUNT_H

View File

@ -45,7 +45,7 @@ BOOST_AUTO_TEST_CASE(Short) {
NGramStream stream; NGramStream stream;
uint64_t token_count; uint64_t token_count;
WordIndex type_count = 10; WordIndex type_count = 10;
CorpusCount counter(input_piece, vocab.get(), token_count, type_count, chain.BlockSize() / chain.EntrySize()); CorpusCount counter(input_piece, vocab.get(), token_count, type_count, chain.BlockSize() / chain.EntrySize(), SILENT);
chain >> boost::ref(counter) >> stream >> util::stream::kRecycle; chain >> boost::ref(counter) >> stream >> util::stream::kRecycle;
const char *v[] = {"<unk>", "<s>", "</s>", "looking", "on", "a", "little", "more", "loin", "foo", "bar"}; const char *v[] = {"<unk>", "<s>", "</s>", "looking", "on", "a", "little", "more", "loin", "foo", "bar"};

View File

@ -1,5 +1,5 @@
#ifndef BUILDER_DISCOUNT__ #ifndef LM_BUILDER_DISCOUNT_H
#define BUILDER_DISCOUNT__ #define LM_BUILDER_DISCOUNT_H
#include <algorithm> #include <algorithm>
@ -23,4 +23,4 @@ struct Discount {
} // namespace builder } // namespace builder
} // namespace lm } // namespace lm
#endif // BUILDER_DISCOUNT__ #endif // LM_BUILDER_DISCOUNT_H

19
lm/builder/hash_gamma.hh Normal file
View File

@ -0,0 +1,19 @@
#ifndef LM_BUILDER_HASH_GAMMA_H
#define LM_BUILDER_HASH_GAMMA_H

#include <stdint.h>

namespace lm { namespace builder {

// A 64-bit hash value paired with a gamma value.
// Packed to 4-byte alignment so the struct has no tail padding after the
// float: 8 bytes of hash_value followed by 4 bytes of gamma.
#pragma pack(push)
#pragma pack(4)
struct HashGamma {
  uint64_t hash_value;
  float gamma;
};
#pragma pack(pop)

}} // namespaces

#endif // LM_BUILDER_HASH_GAMMA_H

View File

@ -1,5 +1,5 @@
#ifndef LM_BUILDER_HEADER_INFO__ #ifndef LM_BUILDER_HEADER_INFO_H
#define LM_BUILDER_HEADER_INFO__ #define LM_BUILDER_HEADER_INFO_H
#include <string> #include <string>
#include <stdint.h> #include <stdint.h>

View File

@ -3,6 +3,8 @@
#include "lm/builder/discount.hh" #include "lm/builder/discount.hh"
#include "lm/builder/ngram_stream.hh" #include "lm/builder/ngram_stream.hh"
#include "lm/builder/sort.hh" #include "lm/builder/sort.hh"
#include "lm/builder/hash_gamma.hh"
#include "util/murmur_hash.hh"
#include "util/file.hh" #include "util/file.hh"
#include "util/stream/chain.hh" #include "util/stream/chain.hh"
#include "util/stream/io.hh" #include "util/stream/io.hh"
@ -14,55 +16,179 @@ namespace lm { namespace builder {
namespace { namespace {
struct BufferEntry { struct BufferEntry {
// Gamma from page 20 of Chen and Goodman. // Gamma from page 20 of Chen and Goodman.
float gamma; float gamma;
// \sum_w a(c w) for all w. // \sum_w a(c w) for all w.
float denominator; float denominator;
}; };
// Extract an array of gamma from an array of BufferEntry. struct HashBufferEntry : public BufferEntry {
// Hash value of ngram. Used to join contexts with backoffs.
uint64_t hash_value;
};
// Reads all entries in order like NGramStream does, compacting each block in
// place as it goes: for orders above 1, an entry whose CutoffCount() is 0 is
// dropped by letting later surviving entries be copied over it. Unigram
// entries are always kept. When the end of a block is reached its valid size
// is shrunk to cover only the surviving entries.
// NOTE(review): presumably CutoffCount() is 0 exactly for n-grams marked for
// pruning; confirm against NGram.
class PruneNGramStream {
  public:
    // position: chain position to read blocks from; the n-gram order is
    // derived from the chain's entry size.
    PruneNGramStream(const util::stream::ChainPosition &position) :
      current_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())),
      dest_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())),
      currentCount_(0),
      block_(position)
    {
      StartBlock();
    }

    // Access the entry currently being read.
    NGram &operator*() { return current_; }
    NGram *operator->() { return &current_; }

    // True while there is a current block, i.e. the stream is not exhausted.
    operator bool() const {
      return block_;
    }

    // Decide whether the current entry survives, then advance to the next one.
    PruneNGramStream &operator++() {
      assert(block_);

      if (current_.Order() > 1) {
        if (currentCount_ > 0) {
          // Surviving entry: slide it down over any dropped entries.
          if (dest_.Base() < current_.Base()) {
            memcpy(dest_.Base(), current_.Base(), current_.TotalSize());
          }
          dest_.NextInMemory();
        }
        // Otherwise the entry is dropped: dest_ stays put so the next
        // survivor overwrites it.
      } else {
        // Unigrams are never dropped.
        dest_.NextInMemory();
      }

      current_.NextInMemory();

      uint8_t *block_base = static_cast<uint8_t*>(block_->Get());
      if (current_.Base() == block_base + block_->ValidSize()) {
        // End of block: publish only the surviving prefix and move on.
        block_->SetValidSize(dest_.Base() - block_base);
        ++block_;
        // StartBlock() rebases current_ and refreshes currentCount_ when a
        // further non-empty block exists. When the stream is exhausted it
        // leaves current_ pointing past the end of the block just finished,
        // so current_ must not be read here.
        StartBlock();
      } else {
        currentCount_ = current_.CutoffCount();
      }

      return *this;
    }

  private:
    // Skip empty blocks, then point both iterators at the start of the first
    // non-empty block and cache its first entry's cutoff count. Leaves all
    // state untouched if no blocks remain.
    void StartBlock() {
      for (; ; ++block_) {
        if (!block_) return;
        if (block_->ValidSize()) break;
      }
      current_.ReBase(block_->Get());
      currentCount_ = current_.CutoffCount();

      dest_.ReBase(block_->Get());
    }

    NGram current_; // input iterator
    NGram dest_;    // output iterator

    // CutoffCount() of the entry current_ points at, cached before the entry
    // can be overwritten by compaction.
    uint64_t currentCount_;

    util::stream::Link block_;
};
// Extract an array of HashedGamma from an array of BufferEntry.
class OnlyGamma { class OnlyGamma {
public: public:
OnlyGamma(bool pruning) : pruning_(pruning) {}
void Run(const util::stream::ChainPosition &position) { void Run(const util::stream::ChainPosition &position) {
for (util::stream::Link block_it(position); block_it; ++block_it) { for (util::stream::Link block_it(position); block_it; ++block_it) {
float *out = static_cast<float*>(block_it->Get()); if(pruning_) {
const float *in = out; const HashBufferEntry *in = static_cast<const HashBufferEntry*>(block_it->Get());
const float *end = static_cast<const float*>(block_it->ValidEnd()); const HashBufferEntry *end = static_cast<const HashBufferEntry*>(block_it->ValidEnd());
for (out += 1, in += 2; in < end; out += 1, in += 2) {
*out = *in; // Just make it point to the beginning of the stream so it can be overwritten
// With HashGamma values. Do not attempt to interpret the values until set below.
HashGamma *out = static_cast<HashGamma*>(block_it->Get());
for (; in < end; out += 1, in += 1) {
// buffering, otherwise might overwrite values too early
float gamma_buf = in->gamma;
uint64_t hash_buf = in->hash_value;
out->gamma = gamma_buf;
out->hash_value = hash_buf;
}
block_it->SetValidSize((block_it->ValidSize() * sizeof(HashGamma)) / sizeof(HashBufferEntry));
}
else {
float *out = static_cast<float*>(block_it->Get());
const float *in = out;
const float *end = static_cast<const float*>(block_it->ValidEnd());
for (out += 1, in += 2; in < end; out += 1, in += 2) {
*out = *in;
}
block_it->SetValidSize(block_it->ValidSize() / 2);
} }
block_it->SetValidSize(block_it->ValidSize() / 2);
} }
} }
private:
bool pruning_;
}; };
class AddRight { class AddRight {
public: public:
AddRight(const Discount &discount, const util::stream::ChainPosition &input) AddRight(const Discount &discount, const util::stream::ChainPosition &input, bool pruning)
: discount_(discount), input_(input) {} : discount_(discount), input_(input), pruning_(pruning) {}
void Run(const util::stream::ChainPosition &output) { void Run(const util::stream::ChainPosition &output) {
NGramStream in(input_); NGramStream in(input_);
util::stream::Stream out(output); util::stream::Stream out(output);
std::vector<WordIndex> previous(in->Order() - 1); std::vector<WordIndex> previous(in->Order() - 1);
// Silly windows requires this workaround to just get an invalid pointer when empty.
void *const previous_raw = previous.empty() ? NULL : static_cast<void*>(&previous[0]);
const std::size_t size = sizeof(WordIndex) * previous.size(); const std::size_t size = sizeof(WordIndex) * previous.size();
for(; in; ++out) { for(; in; ++out) {
memcpy(&previous[0], in->begin(), size); memcpy(previous_raw, in->begin(), size);
uint64_t denominator = 0; uint64_t denominator = 0;
uint64_t normalizer = 0;
uint64_t counts[4]; uint64_t counts[4];
memset(counts, 0, sizeof(counts)); memset(counts, 0, sizeof(counts));
do { do {
denominator += in->Count(); denominator += in->UnmarkedCount();
++counts[std::min(in->Count(), static_cast<uint64_t>(3))];
} while (++in && !memcmp(&previous[0], in->begin(), size)); // Collect unused probability mass from pruning.
// Becomes 0 for unpruned ngrams.
normalizer += in->UnmarkedCount() - in->CutoffCount();
// Chen&Goodman do not mention counting based on cutoffs, but
// backoff becomes larger than 1 otherwise, so probably needs
// to count cutoffs. Counts normally without pruning.
if(in->CutoffCount() > 0)
++counts[std::min(in->CutoffCount(), static_cast<uint64_t>(3))];
} while (++in && !memcmp(previous_raw, in->begin(), size));
BufferEntry &entry = *reinterpret_cast<BufferEntry*>(out.Get()); BufferEntry &entry = *reinterpret_cast<BufferEntry*>(out.Get());
entry.denominator = static_cast<float>(denominator); entry.denominator = static_cast<float>(denominator);
entry.gamma = 0.0; entry.gamma = 0.0;
for (unsigned i = 1; i <= 3; ++i) { for (unsigned i = 1; i <= 3; ++i) {
entry.gamma += discount_.Get(i) * static_cast<float>(counts[i]); entry.gamma += discount_.Get(i) * static_cast<float>(counts[i]);
} }
// Makes model sum to 1 with pruning (I hope).
entry.gamma += normalizer;
entry.gamma /= entry.denominator; entry.gamma /= entry.denominator;
if(pruning_) {
// If pruning is enabled the stream actually contains HashBufferEntry, see InitialProbabilities(...),
// so add a hash value that identifies the current ngram.
static_cast<HashBufferEntry*>(&entry)->hash_value = util::MurmurHashNative(previous_raw, size);
}
} }
out.Poison(); out.Poison();
} }
@ -70,6 +196,7 @@ class AddRight {
private: private:
const Discount &discount_; const Discount &discount_;
const util::stream::ChainPosition input_; const util::stream::ChainPosition input_;
bool pruning_;
}; };
class MergeRight { class MergeRight {
@ -82,7 +209,7 @@ class MergeRight {
void Run(const util::stream::ChainPosition &primary) { void Run(const util::stream::ChainPosition &primary) {
util::stream::Stream summed(from_adder_); util::stream::Stream summed(from_adder_);
NGramStream grams(primary); PruneNGramStream grams(primary);
// Without interpolation, the interpolation weight goes to <unk>. // Without interpolation, the interpolation weight goes to <unk>.
if (grams->Order() == 1 && !interpolate_unigrams_) { if (grams->Order() == 1 && !interpolate_unigrams_) {
@ -97,15 +224,16 @@ class MergeRight {
++summed; ++summed;
return; return;
} }
std::vector<WordIndex> previous(grams->Order() - 1); std::vector<WordIndex> previous(grams->Order() - 1);
const std::size_t size = sizeof(WordIndex) * previous.size(); const std::size_t size = sizeof(WordIndex) * previous.size();
for (; grams; ++summed) { for (; grams; ++summed) {
memcpy(&previous[0], grams->begin(), size); memcpy(&previous[0], grams->begin(), size);
const BufferEntry &sums = *static_cast<const BufferEntry*>(summed.Get()); const BufferEntry &sums = *static_cast<const BufferEntry*>(summed.Get());
do { do {
Payload &pay = grams->Value(); Payload &pay = grams->Value();
pay.uninterp.prob = discount_.Apply(pay.count) / sums.denominator; pay.uninterp.prob = discount_.Apply(grams->UnmarkedCount()) / sums.denominator;
pay.uninterp.gamma = sums.gamma; pay.uninterp.gamma = sums.gamma;
} while (++grams && !memcmp(&previous[0], grams->begin(), size)); } while (++grams && !memcmp(&previous[0], grams->begin(), size));
} }
@ -119,17 +247,29 @@ class MergeRight {
} // namespace } // namespace
void InitialProbabilities(const InitialProbabilitiesConfig &config, const std::vector<Discount> &discounts, Chains &primary, Chains &second_in, Chains &gamma_out) { void InitialProbabilities(
util::stream::ChainConfig gamma_config = config.adder_out; const InitialProbabilitiesConfig &config,
gamma_config.entry_size = sizeof(BufferEntry); const std::vector<Discount> &discounts,
util::stream::Chains &primary,
util::stream::Chains &second_in,
util::stream::Chains &gamma_out,
const std::vector<uint64_t> &prune_thresholds) {
for (size_t i = 0; i < primary.size(); ++i) { for (size_t i = 0; i < primary.size(); ++i) {
util::stream::ChainConfig gamma_config = config.adder_out;
if(prune_thresholds[i] > 0)
gamma_config.entry_size = sizeof(HashBufferEntry);
else
gamma_config.entry_size = sizeof(BufferEntry);
util::stream::ChainPosition second(second_in[i].Add()); util::stream::ChainPosition second(second_in[i].Add());
second_in[i] >> util::stream::kRecycle; second_in[i] >> util::stream::kRecycle;
gamma_out.push_back(gamma_config); gamma_out.push_back(gamma_config);
gamma_out[i] >> AddRight(discounts[i], second); gamma_out[i] >> AddRight(discounts[i], second, prune_thresholds[i] > 0);
primary[i] >> MergeRight(config.interpolate_unigrams, gamma_out[i].Add(), discounts[i]); primary[i] >> MergeRight(config.interpolate_unigrams, gamma_out[i].Add(), discounts[i]);
// Don't bother with the OnlyGamma thread for something to discard.
if (i) gamma_out[i] >> OnlyGamma(); // Don't bother with the OnlyGamma thread for something to discard.
if (i) gamma_out[i] >> OnlyGamma(prune_thresholds[i] > 0);
} }
} }

View File

@ -1,14 +1,15 @@
#ifndef LM_BUILDER_INITIAL_PROBABILITIES__ #ifndef LM_BUILDER_INITIAL_PROBABILITIES_H
#define LM_BUILDER_INITIAL_PROBABILITIES__ #define LM_BUILDER_INITIAL_PROBABILITIES_H
#include "lm/builder/discount.hh" #include "lm/builder/discount.hh"
#include "util/stream/config.hh" #include "util/stream/config.hh"
#include <vector> #include <vector>
namespace util { namespace stream { class Chains; } }
namespace lm { namespace lm {
namespace builder { namespace builder {
class Chains;
struct InitialProbabilitiesConfig { struct InitialProbabilitiesConfig {
// These should be small buffers to keep the adder from getting too far ahead // These should be small buffers to keep the adder from getting too far ahead
@ -26,9 +27,15 @@ struct InitialProbabilitiesConfig {
* The values are bare floats and should be buffered for interpolation to * The values are bare floats and should be buffered for interpolation to
* use. * use.
*/ */
void InitialProbabilities(const InitialProbabilitiesConfig &config, const std::vector<Discount> &discounts, Chains &primary, Chains &second_in, Chains &gamma_out); void InitialProbabilities(
const InitialProbabilitiesConfig &config,
const std::vector<Discount> &discounts,
util::stream::Chains &primary,
util::stream::Chains &second_in,
util::stream::Chains &gamma_out,
const std::vector<uint64_t> &prune_thresholds);
} // namespace builder } // namespace builder
} // namespace lm } // namespace lm
#endif // LM_BUILDER_INITIAL_PROBABILITIES__ #endif // LM_BUILDER_INITIAL_PROBABILITIES_H

View File

@ -1,9 +1,12 @@
#include "lm/builder/interpolate.hh" #include "lm/builder/interpolate.hh"
#include "lm/builder/hash_gamma.hh"
#include "lm/builder/joint_order.hh" #include "lm/builder/joint_order.hh"
#include "lm/builder/multi_stream.hh" #include "lm/builder/ngram_stream.hh"
#include "lm/builder/sort.hh" #include "lm/builder/sort.hh"
#include "lm/lm_exception.hh" #include "lm/lm_exception.hh"
#include "util/fixed_array.hh"
#include "util/murmur_hash.hh"
#include <assert.h> #include <assert.h>
@ -12,7 +15,8 @@ namespace {
class Callback { class Callback {
public: public:
Callback(float uniform_prob, const ChainPositions &backoffs) : backoffs_(backoffs.size()), probs_(backoffs.size() + 2) { Callback(float uniform_prob, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t> &prune_thresholds)
: backoffs_(backoffs.size()), probs_(backoffs.size() + 2), prune_thresholds_(prune_thresholds) {
probs_[0] = uniform_prob; probs_[0] = uniform_prob;
for (std::size_t i = 0; i < backoffs.size(); ++i) { for (std::size_t i = 0; i < backoffs.size(); ++i) {
backoffs_.push_back(backoffs[i]); backoffs_.push_back(backoffs[i]);
@ -33,12 +37,37 @@ class Callback {
pay.complete.prob = pay.uninterp.prob + pay.uninterp.gamma * probs_[order_minus_1]; pay.complete.prob = pay.uninterp.prob + pay.uninterp.gamma * probs_[order_minus_1];
probs_[order_minus_1 + 1] = pay.complete.prob; probs_[order_minus_1 + 1] = pay.complete.prob;
pay.complete.prob = log10(pay.complete.prob); pay.complete.prob = log10(pay.complete.prob);
// TODO: this is a hack to skip n-grams that don't appear as context. Pruning will require some different handling.
if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != kUNK && *(gram.end() - 1) != kEOS) { if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != kUNK && *(gram.end() - 1) != kEOS) {
pay.complete.backoff = log10(*static_cast<const float*>(backoffs_[order_minus_1].Get())); // This skips over ngrams if backoffs have been exhausted.
++backoffs_[order_minus_1]; if(!backoffs_[order_minus_1]) {
pay.complete.backoff = 0.0;
return;
}
if(prune_thresholds_[order_minus_1 + 1] > 0) {
//Compute hash value for current context
uint64_t current_hash = util::MurmurHashNative(gram.begin(), gram.Order() * sizeof(WordIndex));
const HashGamma *hashed_backoff = static_cast<const HashGamma*>(backoffs_[order_minus_1].Get());
while(backoffs_[order_minus_1] && current_hash != hashed_backoff->hash_value) {
hashed_backoff = static_cast<const HashGamma*>(backoffs_[order_minus_1].Get());
++backoffs_[order_minus_1];
}
if(current_hash == hashed_backoff->hash_value) {
pay.complete.backoff = log10(hashed_backoff->gamma);
++backoffs_[order_minus_1];
} else {
// Has been pruned away so it is not a context anymore
pay.complete.backoff = 0.0;
}
} else {
pay.complete.backoff = log10(*static_cast<const float*>(backoffs_[order_minus_1].Get()));
++backoffs_[order_minus_1];
}
} else { } else {
// Not a context. // Not a context.
pay.complete.backoff = 0.0; pay.complete.backoff = 0.0;
} }
} }
@ -46,19 +75,22 @@ class Callback {
void Exit(unsigned, const NGram &) const {} void Exit(unsigned, const NGram &) const {}
private: private:
FixedArray<util::stream::Stream> backoffs_; util::FixedArray<util::stream::Stream> backoffs_;
std::vector<float> probs_; std::vector<float> probs_;
const std::vector<uint64_t>& prune_thresholds_;
}; };
} // namespace } // namespace
Interpolate::Interpolate(uint64_t unigram_count, const ChainPositions &backoffs) Interpolate::Interpolate(uint64_t vocab_size, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t>& prune_thresholds)
: uniform_prob_(1.0 / static_cast<float>(unigram_count - 1)), backoffs_(backoffs) {} : uniform_prob_(1.0 / static_cast<float>(vocab_size)), // Includes <unk> but excludes <s>.
backoffs_(backoffs),
prune_thresholds_(prune_thresholds) {}
// perform order-wise interpolation // perform order-wise interpolation
void Interpolate::Run(const ChainPositions &positions) { void Interpolate::Run(const util::stream::ChainPositions &positions) {
assert(positions.size() == backoffs_.size() + 1); assert(positions.size() == backoffs_.size() + 1);
Callback callback(uniform_prob_, backoffs_); Callback callback(uniform_prob_, backoffs_, prune_thresholds_);
JointOrder<Callback, SuffixOrder>(positions, callback); JointOrder<Callback, SuffixOrder>(positions, callback);
} }

View File

@ -1,10 +1,12 @@
#ifndef LM_BUILDER_INTERPOLATE__ #ifndef LM_BUILDER_INTERPOLATE_H
#define LM_BUILDER_INTERPOLATE__ #define LM_BUILDER_INTERPOLATE_H
#include "util/stream/multi_stream.hh"
#include <vector>
#include <stdint.h> #include <stdint.h>
#include "lm/builder/multi_stream.hh"
namespace lm { namespace builder { namespace lm { namespace builder {
/* Interpolate step. /* Interpolate step.
@ -14,14 +16,17 @@ namespace lm { namespace builder {
*/ */
class Interpolate { class Interpolate {
public: public:
explicit Interpolate(uint64_t unigram_count, const ChainPositions &backoffs); // Normally vocab_size is the unigram count-1 (since p(<s>) = 0) but might
// be larger when the user specifies a consistent vocabulary size.
explicit Interpolate(uint64_t vocab_size, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t> &prune_thresholds);
void Run(const ChainPositions &positions); void Run(const util::stream::ChainPositions &positions);
private: private:
float uniform_prob_; float uniform_prob_;
ChainPositions backoffs_; util::stream::ChainPositions backoffs_;
const std::vector<uint64_t> prune_thresholds_;
}; };
}} // namespaces }} // namespaces
#endif // LM_BUILDER_INTERPOLATE__ #endif // LM_BUILDER_INTERPOLATE_H

View File

@ -1,14 +1,14 @@
#ifndef LM_BUILDER_JOINT_ORDER__ #ifndef LM_BUILDER_JOINT_ORDER_H
#define LM_BUILDER_JOINT_ORDER__ #define LM_BUILDER_JOINT_ORDER_H
#include "lm/builder/multi_stream.hh" #include "lm/builder/ngram_stream.hh"
#include "lm/lm_exception.hh" #include "lm/lm_exception.hh"
#include <string.h> #include <string.h>
namespace lm { namespace builder { namespace lm { namespace builder {
template <class Callback, class Compare> void JointOrder(const ChainPositions &positions, Callback &callback) { template <class Callback, class Compare> void JointOrder(const util::stream::ChainPositions &positions, Callback &callback) {
// Allow matching to reference streams[-1]. // Allow matching to reference streams[-1].
NGramStreams streams_with_dummy; NGramStreams streams_with_dummy;
streams_with_dummy.InitWithDummy(positions); streams_with_dummy.InitWithDummy(positions);
@ -40,4 +40,4 @@ template <class Callback, class Compare> void JointOrder(const ChainPositions &p
}} // namespaces }} // namespaces
#endif // LM_BUILDER_JOINT_ORDER__ #endif // LM_BUILDER_JOINT_ORDER_H

View File

@ -1,4 +1,5 @@
#include "lm/builder/pipeline.hh" #include "lm/builder/pipeline.hh"
#include "lm/lm_exception.hh"
#include "util/file.hh" #include "util/file.hh"
#include "util/file_piece.hh" #include "util/file_piece.hh"
#include "util/usage.hh" #include "util/usage.hh"
@ -7,6 +8,7 @@
#include <boost/program_options.hpp> #include <boost/program_options.hpp>
#include <boost/version.hpp> #include <boost/version.hpp>
#include <vector>
namespace { namespace {
class SizeNotify { class SizeNotify {
@ -25,6 +27,46 @@ boost::program_options::typed_value<std::string> *SizeOption(std::size_t &to, co
return boost::program_options::value<std::string>()->notifier(SizeNotify(to))->default_value(default_value); return boost::program_options::value<std::string>()->notifier(SizeNotify(to))->default_value(default_value);
} }
// Parse and validate the pruning thresholds given on the command line, then
// return one threshold per n-gram order.
//
// param: the raw threshold strings, one per order starting with unigrams;
//        may be empty, in which case nothing is pruned.
// order: the order of the model; the returned vector has exactly this size.
//
// Throws util::Exception (via UTIL_THROW / UTIL_THROW_IF) when a value is not
// an unsigned integer, when more thresholds than orders are supplied, when the
// unigram threshold is non-zero, or when the thresholds decrease.
std::vector<uint64_t> ParsePruning(const std::vector<std::string> &param, std::size_t order) {
  // Convert to a vector of integers.
  std::vector<uint64_t> prune_thresholds;
  prune_thresholds.reserve(order);
  for (std::vector<std::string>::const_iterator it(param.begin()); it != param.end(); ++it) {
    try {
      prune_thresholds.push_back(boost::lexical_cast<uint64_t>(*it));
    } catch(const boost::bad_lexical_cast &) {
      UTIL_THROW(util::Exception, "Bad pruning threshold " << *it);
    }
  }

  // Nothing specified: fill with zeros, i.e. no pruning for any order.
  if (prune_thresholds.empty()) {
    prune_thresholds.resize(order, 0);
    return prune_thresholds;
  }

  // Validate the thresholds that were specified.
  // Fewer thresholds than orders is fine (padded below); more is an error.
  UTIL_THROW_IF(prune_thresholds.size() > order, util::Exception, "You specified pruning thresholds for orders 1 through " << prune_thresholds.size() << " but the model only has order " << order);
  // The threshold for unigrams can only be 0 (no pruning).
  UTIL_THROW_IF(prune_thresholds[0] != 0, util::Exception, "Unigram pruning is not implemented, so the first pruning threshold must be 0.");

  // Reject any threshold that is smaller than the one before it.
  uint64_t lower_threshold = 0;
  for (std::vector<uint64_t>::iterator it = prune_thresholds.begin(); it != prune_thresholds.end(); ++it) {
    UTIL_THROW_IF(lower_threshold > *it, util::Exception, "Pruning thresholds should be in non-decreasing order.  Otherwise substrings would be removed, which is bad for query-time data structures.");
    lower_threshold = *it;
  }

  // Pad to all orders using the last value.
  prune_thresholds.resize(order, prune_thresholds.back());
  return prune_thresholds;
}
} // namespace } // namespace
int main(int argc, char *argv[]) { int main(int argc, char *argv[]) {
@ -34,25 +76,30 @@ int main(int argc, char *argv[]) {
lm::builder::PipelineConfig pipeline; lm::builder::PipelineConfig pipeline;
std::string text, arpa; std::string text, arpa;
std::vector<std::string> pruning;
options.add_options() options.add_options()
("help", po::bool_switch(), "Show this help message") ("help,h", po::bool_switch(), "Show this help message")
("order,o", po::value<std::size_t>(&pipeline.order) ("order,o", po::value<std::size_t>(&pipeline.order)
#if BOOST_VERSION >= 104200 #if BOOST_VERSION >= 104200
->required() ->required()
#endif #endif
, "Order of the model") , "Order of the model")
("interpolate_unigrams", po::bool_switch(&pipeline.initial_probs.interpolate_unigrams), "Interpolate the unigrams (default: emulate SRILM by not interpolating)") ("interpolate_unigrams", po::bool_switch(&pipeline.initial_probs.interpolate_unigrams), "Interpolate the unigrams (default: emulate SRILM by not interpolating)")
("skip_symbols", po::bool_switch(), "Treat <s>, </s>, and <unk> as whitespace instead of throwing an exception")
("temp_prefix,T", po::value<std::string>(&pipeline.sort.temp_prefix)->default_value("/tmp/lm"), "Temporary file prefix") ("temp_prefix,T", po::value<std::string>(&pipeline.sort.temp_prefix)->default_value("/tmp/lm"), "Temporary file prefix")
("memory,S", SizeOption(pipeline.sort.total_memory, util::GuessPhysicalMemory() ? "80%" : "1G"), "Sorting memory") ("memory,S", SizeOption(pipeline.sort.total_memory, util::GuessPhysicalMemory() ? "80%" : "1G"), "Sorting memory")
("minimum_block", SizeOption(pipeline.minimum_block, "8K"), "Minimum block size to allow") ("minimum_block", SizeOption(pipeline.minimum_block, "8K"), "Minimum block size to allow")
("sort_block", SizeOption(pipeline.sort.buffer_size, "64M"), "Size of IO operations for sort (determines arity)") ("sort_block", SizeOption(pipeline.sort.buffer_size, "64M"), "Size of IO operations for sort (determines arity)")
("vocab_estimate", po::value<lm::WordIndex>(&pipeline.vocab_estimate)->default_value(1000000), "Assume this vocabulary size for purposes of calculating memory in step 1 (corpus count) and pre-sizing the hash table")
("block_count", po::value<std::size_t>(&pipeline.block_count)->default_value(2), "Block count (per order)") ("block_count", po::value<std::size_t>(&pipeline.block_count)->default_value(2), "Block count (per order)")
("vocab_file", po::value<std::string>(&pipeline.vocab_file)->default_value(""), "Location to write vocabulary file") ("vocab_estimate", po::value<lm::WordIndex>(&pipeline.vocab_estimate)->default_value(1000000), "Assume this vocabulary size for purposes of calculating memory in step 1 (corpus count) and pre-sizing the hash table")
("vocab_file", po::value<std::string>(&pipeline.vocab_file)->default_value(""), "Location to write a file containing the unique vocabulary strings delimited by null bytes")
("vocab_pad", po::value<uint64_t>(&pipeline.vocab_size_for_unk)->default_value(0), "If the vocabulary is smaller than this value, pad with <unk> to reach this size. Requires --interpolate_unigrams")
("verbose_header", po::bool_switch(&pipeline.verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc.") ("verbose_header", po::bool_switch(&pipeline.verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc.")
("text", po::value<std::string>(&text), "Read text from a file instead of stdin") ("text", po::value<std::string>(&text), "Read text from a file instead of stdin")
("arpa", po::value<std::string>(&arpa), "Write ARPA to a file instead of stdout"); ("arpa", po::value<std::string>(&arpa), "Write ARPA to a file instead of stdout")
("prune", po::value<std::vector<std::string> >(&pruning)->multitoken(), "Prune n-grams with count less than or equal to the given threshold. Specify one value for each order i.e. 0 0 1 to prune singleton trigrams and above. The sequence of values must be non-decreasing and the last value applies to any remaining orders. Unigram pruning is not implemented, so the first value must be zero. Default is to not prune, which is equivalent to --prune 0.");
po::variables_map vm; po::variables_map vm;
po::store(po::parse_command_line(argc, argv, options), vm); po::store(po::parse_command_line(argc, argv, options), vm);
@ -95,6 +142,20 @@ int main(int argc, char *argv[]) {
} }
#endif #endif
if (pipeline.vocab_size_for_unk && !pipeline.initial_probs.interpolate_unigrams) {
std::cerr << "--vocab_pad requires --interpolate_unigrams" << std::endl;
return 1;
}
if (vm["skip_symbols"].as<bool>()) {
pipeline.disallowed_symbol_action = lm::COMPLAIN;
} else {
pipeline.disallowed_symbol_action = lm::THROW_UP;
}
// parse pruning thresholds. These depend on order, so it is not done as a notifier.
pipeline.prune_thresholds = ParsePruning(pruning, pipeline.order);
util::NormalizeTempPrefix(pipeline.sort.temp_prefix); util::NormalizeTempPrefix(pipeline.sort.temp_prefix);
lm::builder::InitialProbabilitiesConfig &initial = pipeline.initial_probs; lm::builder::InitialProbabilitiesConfig &initial = pipeline.initial_probs;

View File

@ -1,180 +0,0 @@
#ifndef LM_BUILDER_MULTI_STREAM__
#define LM_BUILDER_MULTI_STREAM__
#include "lm/builder/ngram_stream.hh"
#include "util/scoped.hh"
#include "util/stream/chain.hh"
#include <cstddef>
#include <new>
#include <assert.h>
#include <stdlib.h>
namespace lm { namespace builder {
// Array with a capacity fixed at Init time.  Storage is one malloc'd block and
// elements are constructed in place, one at a time, at end(); the array never
// reallocates, so element addresses stay stable.  The caller is responsible
// for not constructing more elements than were requested from Init: no bounds
// checking is done.
template <class T> class FixedArray {
  public:
    // Reserve room for count elements.  None are constructed yet.
    explicit FixedArray(std::size_t count) {
      Init(count);
    }

    // Empty array with no storage.  Call Init before adding elements.
    FixedArray() : newed_end_(NULL) {}

    // Allocate room for count elements.  May only be called while no block
    // is held.  Throws std::bad_alloc if the allocation fails.
    void Init(std::size_t count) {
      assert(!block_.get());
      block_.reset(malloc(sizeof(T) * count));
      // malloc(0) may legitimately return NULL, so a NULL result is only an
      // allocation failure when storage was actually requested.
      if (count && !block_.get()) throw std::bad_alloc();
      newed_end_ = begin();
    }

    // Deep copy: capacity equals the number of constructed elements in from.
    FixedArray(const FixedArray &from) {
      const std::size_t size = from.size();
      Init(size);
      try {
        for (std::size_t i = 0; i < size; ++i) {
          new(end()) T(from[i]);
          Constructed();
        }
      } catch (...) {
        // The destructor does not run for a partially constructed object, so
        // destroy the elements copied so far before propagating.
        clear();
        throw;
      }
    }

    ~FixedArray() { clear(); }

    T *begin() { return static_cast<T*>(block_.get()); }
    const T *begin() const { return static_cast<const T*>(block_.get()); }
    // One past the last constructed element, which is also where the next
    // element goes.  Always call Constructed after successful completion of new.
    T *end() { return newed_end_; }
    const T *end() const { return newed_end_; }

    T &back() { return *(end() - 1); }
    const T &back() const { return *(end() - 1); }

    // Number of constructed elements (not the reserved capacity).
    std::size_t size() const { return end() - begin(); }
    bool empty() const { return begin() == end(); }

    T &operator[](std::size_t i) { return begin()[i]; }
    const T &operator[](std::size_t i) const { return begin()[i]; }

    // Construct a new element at the end from c.
    template <class C> void push_back(const C &c) {
      new (end()) T(c);
      Constructed();
    }

    // Destroy every constructed element.  The storage block is kept.
    void clear() {
      for (T *i = begin(); i != end(); ++i)
        i->~T();
      newed_end_ = begin();
    }

  protected:
    // Record that one more element was constructed at end().
    void Constructed() {
      ++newed_end_;
    }

  private:
    util::scoped_malloc block_;

    T *newed_end_;
};
class Chains;
// A FixedArray of ChainPosition objects, one taken from each chain of a
// Chains collection.
class ChainPositions : public FixedArray<util::stream::ChainPosition> {
  public:
    // Build immediately from chains; equivalent to default construction
    // followed by Init(chains).
    explicit ChainPositions(Chains &chains) {
      Init(chains);
    }

    // Empty; Init must be called before use.
    ChainPositions() {}

    // Allocate chains.size() slots and add one position per chain.
    void Init(Chains &chains);
};
// A fixed-capacity collection of util::stream::Chain that also owns the
// multi-chain worker threads attached through operator>>.
class Chains : public FixedArray<util::stream::Chain> {
  private:
    // Overload filter: this only names a type when T has a member
    //   void Run(const ChainPositions &)
    // so the worker overloads of operator>> below drop out of overload
    // resolution for anything that is not such a worker.
    template <class T, void (T::*ptr)(const ChainPositions &) = &T::Run> struct CheckForRun {
      typedef Chains type;
    };

  public:
    // limit is the maximum number of chains that may be added.
    explicit Chains(std::size_t limit) : FixedArray<util::stream::Chain>(limit) {}

    // Attach a worker that takes a position from every chain at once.
    template <class Worker> typename CheckForRun<Worker>::type &operator>>(const Worker &worker) {
      threads_.push_back(new util::stream::Thread(ChainPositions(*this), worker));
      return *this;
    }

    // Same as above for a worker passed through boost::ref.
    template <class Worker> typename CheckForRun<Worker>::type &operator>>(const boost::reference_wrapper<Worker> &worker) {
      threads_.push_back(new util::stream::Thread(ChainPositions(*this), worker));
      return *this;
    }

    // Forward the recycler to every chain.
    Chains &operator>>(const util::stream::Recycler &recycler) {
      util::stream::Chain *const stop = end();
      for (util::stream::Chain *chain = begin(); chain != stop; ++chain) {
        *chain >> recycler;
      }
      return *this;
    }

    // Drop the multi-chain threads first, then wait on each chain in turn.
    // NOTE(review): presumably destroying a Thread joins it -- confirm in
    // util/stream/chain.hh.
    void Wait(bool release_memory = true) {
      threads_.clear();
      util::stream::Chain *const stop = end();
      for (util::stream::Chain *chain = begin(); chain != stop; ++chain) {
        chain->Wait(release_memory);
      }
    }

  private:
    boost::ptr_vector<util::stream::Thread> threads_;

    // Not copyable: declared but never defined.
    Chains(const Chains &);
    void operator=(const Chains &);
};
// Reserve one slot per chain, then construct each slot in place from the
// position returned by that chain's Add().
inline void ChainPositions::Init(Chains &chains) {
  FixedArray<util::stream::ChainPosition>::Init(chains.size());
  for (std::size_t index = 0; index < chains.size(); ++index) {
    new (end()) util::stream::ChainPosition(chains[index].Add());
    Constructed();
  }
}
// Stream-style shorthand for positions.Init(chains).  Returns chains so that
// further >> operations can be appended.
inline Chains &operator>>(Chains &chains, ChainPositions &positions) {
  positions.Init(chains);
  return chains;
}
// A FixedArray of NGramStream, each constructed from one ChainPosition.
class NGramStreams : public FixedArray<NGramStream> {
  public:
    // Empty; call one of the Init functions before use.
    NGramStreams() {}

    // One stream per position.
    NGramStreams(const ChainPositions &positions) {
      Init(positions);
    }

    // This puts a dummy (default-constructed) NGramStream at index 0, followed
    // by one stream per position.  Useful to algorithms that need to reference
    // something before the first real stream.
    void InitWithDummy(const ChainPositions &positions) {
      FixedArray<NGramStream>::Init(positions.size() + 1);
      new (end()) NGramStream();
      Constructed();
      for (std::size_t index = 0; index < positions.size(); ++index) {
        push_back(positions[index]);
      }
    }

    // Limit restricts to positions[0,limit).  The caller must ensure that
    // limit does not exceed positions.size(); this is not checked.
    void Init(const ChainPositions &positions, std::size_t limit) {
      FixedArray<NGramStream>::Init(limit);
      for (std::size_t index = 0; index < limit; ++index) {
        push_back(positions[index]);
      }
    }

    // Use every position.
    void Init(const ChainPositions &positions) {
      Init(positions, positions.size());
    }
};
// Take a position from every chain and open an NGramStream on each of them.
// Returns chains so that further >> operations can be appended.
inline Chains &operator>>(Chains &chains, NGramStreams &streams) {
  ChainPositions positions(chains);
  streams.Init(positions);
  return chains;
}
}} // namespaces
#endif // LM_BUILDER_MULTI_STREAM__

View File

@ -1,5 +1,5 @@
#ifndef LM_BUILDER_NGRAM__ #ifndef LM_BUILDER_NGRAM_H
#define LM_BUILDER_NGRAM__ #define LM_BUILDER_NGRAM_H
#include "lm/weights.hh" #include "lm/weights.hh"
#include "lm/word_index.hh" #include "lm/word_index.hh"
@ -26,7 +26,7 @@ union Payload {
class NGram { class NGram {
public: public:
NGram(void *begin, std::size_t order) NGram(void *begin, std::size_t order)
: begin_(static_cast<WordIndex*>(begin)), end_(begin_ + order) {} : begin_(static_cast<WordIndex*>(begin)), end_(begin_ + order) {}
const uint8_t *Base() const { return reinterpret_cast<const uint8_t*>(begin_); } const uint8_t *Base() const { return reinterpret_cast<const uint8_t*>(begin_); }
@ -38,12 +38,12 @@ class NGram {
end_ = begin_ + difference; end_ = begin_ + difference;
} }
// Would do operator++ but that can get confusing for a stream. // Would do operator++ but that can get confusing for a stream.
void NextInMemory() { void NextInMemory() {
ReBase(&Value() + 1); ReBase(&Value() + 1);
} }
// Lower-case in deference to STL. // Lower-case in deference to STL.
const WordIndex *begin() const { return begin_; } const WordIndex *begin() const { return begin_; }
WordIndex *begin() { return begin_; } WordIndex *begin() { return begin_; }
const WordIndex *end() const { return end_; } const WordIndex *end() const { return end_; }
@ -61,7 +61,7 @@ class NGram {
return order * sizeof(WordIndex) + sizeof(Payload); return order * sizeof(WordIndex) + sizeof(Payload);
} }
std::size_t TotalSize() const { std::size_t TotalSize() const {
// Compiler should optimize this. // Compiler should optimize this.
return TotalSize(Order()); return TotalSize(Order());
} }
static std::size_t OrderFromSize(std::size_t size) { static std::size_t OrderFromSize(std::size_t size) {
@ -69,6 +69,31 @@ class NGram {
assert(size == TotalSize(ret)); assert(size == TotalSize(ret));
return ret; return ret;
} }
// manipulate msb to signal that ngram can be pruned
/*mjd**********************************************************************/
// True when the most significant bit of the stored count is set.  That bit is
// the flag Mark() sets and Unmark() clears.
bool IsMarked() const {
  const std::size_t top_bit = sizeof(Value().count) * 8 - 1;
  return (Value().count >> top_bit) != 0;
}
// Set the most significant bit of the stored count to flag this n-gram.
// The mask is built from a uint64_t rather than an unsigned long literal:
// unsigned long is only 32 bits on some platforms (e.g. Windows), where
// shifting it by 63 is undefined.  Assumes count is at most 64 bits wide.
void Mark() {
  Value().count |= (static_cast<uint64_t>(1) << (sizeof(Value().count) * 8 - 1));
}
// Clear the most significant bit of the stored count, leaving the remaining
// bits untouched.
// The mask is built from a uint64_t rather than an unsigned long literal:
// unsigned long is only 32 bits on some platforms (e.g. Windows), where
// shifting it by 63 is undefined.  Assumes count is at most 64 bits wide.
void Unmark() {
  Value().count &= ~(static_cast<uint64_t>(1) << (sizeof(Value().count) * 8 - 1));
}
// The stored count with the most significant (mark) bit masked off, i.e. the
// count regardless of whether the n-gram is marked.
// The mask is built from a uint64_t rather than an unsigned long literal:
// unsigned long is only 32 bits on some platforms (e.g. Windows), where
// shifting it by 63 is undefined.  Assumes count is at most 64 bits wide.
uint64_t UnmarkedCount() const {
  return Value().count & ~(static_cast<uint64_t>(1) << (sizeof(Value().count) * 8 - 1));
}
// Count after applying the mark: 0 for a marked n-gram, otherwise the count
// with the mark bit masked off.
uint64_t CutoffCount() const {
  if (IsMarked()) {
    return 0;
  }
  return UnmarkedCount();
}
/*mjd**********************************************************************/
private: private:
WordIndex *begin_, *end_; WordIndex *begin_, *end_;
@ -81,4 +106,4 @@ const WordIndex kEOS = 2;
} // namespace builder } // namespace builder
} // namespace lm } // namespace lm
#endif // LM_BUILDER_NGRAM__ #endif // LM_BUILDER_NGRAM_H

View File

@ -1,8 +1,9 @@
#ifndef LM_BUILDER_NGRAM_STREAM__ #ifndef LM_BUILDER_NGRAM_STREAM_H
#define LM_BUILDER_NGRAM_STREAM__ #define LM_BUILDER_NGRAM_STREAM_H
#include "lm/builder/ngram.hh" #include "lm/builder/ngram.hh"
#include "util/stream/chain.hh" #include "util/stream/chain.hh"
#include "util/stream/multi_stream.hh"
#include "util/stream/stream.hh" #include "util/stream/stream.hh"
#include <cstddef> #include <cstddef>
@ -51,5 +52,7 @@ inline util::stream::Chain &operator>>(util::stream::Chain &chain, NGramStream &
return chain; return chain;
} }
typedef util::stream::GenericStreams<NGramStream> NGramStreams;
}} // namespaces }} // namespaces
#endif // LM_BUILDER_NGRAM_STREAM__ #endif // LM_BUILDER_NGRAM_STREAM_H

View File

@ -2,6 +2,7 @@
#include "lm/builder/adjust_counts.hh" #include "lm/builder/adjust_counts.hh"
#include "lm/builder/corpus_count.hh" #include "lm/builder/corpus_count.hh"
#include "lm/builder/hash_gamma.hh"
#include "lm/builder/initial_probabilities.hh" #include "lm/builder/initial_probabilities.hh"
#include "lm/builder/interpolate.hh" #include "lm/builder/interpolate.hh"
#include "lm/builder/print.hh" #include "lm/builder/print.hh"
@ -20,10 +21,13 @@
namespace lm { namespace builder { namespace lm { namespace builder {
namespace { namespace {
void PrintStatistics(const std::vector<uint64_t> &counts, const std::vector<Discount> &discounts) { void PrintStatistics(const std::vector<uint64_t> &counts, const std::vector<uint64_t> &counts_pruned, const std::vector<Discount> &discounts) {
std::cerr << "Statistics:\n"; std::cerr << "Statistics:\n";
for (size_t i = 0; i < counts.size(); ++i) { for (size_t i = 0; i < counts.size(); ++i) {
std::cerr << (i + 1) << ' ' << counts[i]; std::cerr << (i + 1) << ' ' << counts_pruned[i];
if(counts[i] != counts_pruned[i])
std::cerr << "/" << counts[i];
for (size_t d = 1; d <= 3; ++d) for (size_t d = 1; d <= 3; ++d)
std::cerr << " D" << d << (d == 3 ? "+=" : "=") << discounts[i].amount[d]; std::cerr << " D" << d << (d == 3 ? "+=" : "=") << discounts[i].amount[d];
std::cerr << '\n'; std::cerr << '\n';
@ -39,7 +43,7 @@ class Master {
const PipelineConfig &Config() const { return config_; } const PipelineConfig &Config() const { return config_; }
Chains &MutableChains() { return chains_; } util::stream::Chains &MutableChains() { return chains_; }
template <class T> Master &operator>>(const T &worker) { template <class T> Master &operator>>(const T &worker) {
chains_ >> worker; chains_ >> worker;
@ -64,7 +68,7 @@ class Master {
} }
// For initial probabilities, but this is generic. // For initial probabilities, but this is generic.
void SortAndReadTwice(const std::vector<uint64_t> &counts, Sorts<ContextOrder> &sorts, Chains &second, util::stream::ChainConfig second_config) { void SortAndReadTwice(const std::vector<uint64_t> &counts, Sorts<ContextOrder> &sorts, util::stream::Chains &second, util::stream::ChainConfig second_config) {
// Do merge first before allocating chain memory. // Do merge first before allocating chain memory.
for (std::size_t i = 1; i < config_.order; ++i) { for (std::size_t i = 1; i < config_.order; ++i) {
sorts[i - 1].Merge(0); sorts[i - 1].Merge(0);
@ -198,9 +202,9 @@ class Master {
PipelineConfig config_; PipelineConfig config_;
Chains chains_; util::stream::Chains chains_;
// Often only unigrams, but sometimes all orders. // Often only unigrams, but sometimes all orders.
FixedArray<util::stream::FileBuffer> files_; util::FixedArray<util::stream::FileBuffer> files_;
}; };
void CountText(int text_file /* input */, int vocab_file /* output */, Master &master, uint64_t &token_count, std::string &text_file_name) { void CountText(int text_file /* input */, int vocab_file /* output */, Master &master, uint64_t &token_count, std::string &text_file_name) {
@ -221,7 +225,7 @@ void CountText(int text_file /* input */, int vocab_file /* output */, Master &m
WordIndex type_count = config.vocab_estimate; WordIndex type_count = config.vocab_estimate;
util::FilePiece text(text_file, NULL, &std::cerr); util::FilePiece text(text_file, NULL, &std::cerr);
text_file_name = text.FileName(); text_file_name = text.FileName();
CorpusCount counter(text, vocab_file, token_count, type_count, chain.BlockSize() / chain.EntrySize()); CorpusCount counter(text, vocab_file, token_count, type_count, chain.BlockSize() / chain.EntrySize(), config.disallowed_symbol_action);
chain >> boost::ref(counter); chain >> boost::ref(counter);
util::stream::Sort<SuffixOrder, AddCombiner> sorter(chain, config.sort, SuffixOrder(config.order), AddCombiner()); util::stream::Sort<SuffixOrder, AddCombiner> sorter(chain, config.sort, SuffixOrder(config.order), AddCombiner());
@ -231,21 +235,22 @@ void CountText(int text_file /* input */, int vocab_file /* output */, Master &m
master.InitForAdjust(sorter, type_count); master.InitForAdjust(sorter, type_count);
} }
void InitialProbabilities(const std::vector<uint64_t> &counts, const std::vector<Discount> &discounts, Master &master, Sorts<SuffixOrder> &primary, FixedArray<util::stream::FileBuffer> &gammas) { void InitialProbabilities(const std::vector<uint64_t> &counts, const std::vector<uint64_t> &counts_pruned, const std::vector<Discount> &discounts, Master &master, Sorts<SuffixOrder> &primary,
util::FixedArray<util::stream::FileBuffer> &gammas, const std::vector<uint64_t> &prune_thresholds) {
const PipelineConfig &config = master.Config(); const PipelineConfig &config = master.Config();
Chains second(config.order); util::stream::Chains second(config.order);
{ {
Sorts<ContextOrder> sorts; Sorts<ContextOrder> sorts;
master.SetupSorts(sorts); master.SetupSorts(sorts);
PrintStatistics(counts, discounts); PrintStatistics(counts, counts_pruned, discounts);
lm::ngram::ShowSizes(counts); lm::ngram::ShowSizes(counts_pruned);
std::cerr << "=== 3/5 Calculating and sorting initial probabilities ===" << std::endl; std::cerr << "=== 3/5 Calculating and sorting initial probabilities ===" << std::endl;
master.SortAndReadTwice(counts, sorts, second, config.initial_probs.adder_in); master.SortAndReadTwice(counts_pruned, sorts, second, config.initial_probs.adder_in);
} }
Chains gamma_chains(config.order); util::stream::Chains gamma_chains(config.order);
InitialProbabilities(config.initial_probs, discounts, master.MutableChains(), second, gamma_chains); InitialProbabilities(config.initial_probs, discounts, master.MutableChains(), second, gamma_chains, prune_thresholds);
// Don't care about gamma for 0. // Don't care about gamma for 0.
gamma_chains[0] >> util::stream::kRecycle; gamma_chains[0] >> util::stream::kRecycle;
gammas.Init(config.order - 1); gammas.Init(config.order - 1);
@ -257,19 +262,25 @@ void InitialProbabilities(const std::vector<uint64_t> &counts, const std::vector
master.SetupSorts(primary); master.SetupSorts(primary);
} }
void InterpolateProbabilities(const std::vector<uint64_t> &counts, Master &master, Sorts<SuffixOrder> &primary, FixedArray<util::stream::FileBuffer> &gammas) { void InterpolateProbabilities(const std::vector<uint64_t> &counts, Master &master, Sorts<SuffixOrder> &primary, util::FixedArray<util::stream::FileBuffer> &gammas) {
std::cerr << "=== 4/5 Calculating and writing order-interpolated probabilities ===" << std::endl; std::cerr << "=== 4/5 Calculating and writing order-interpolated probabilities ===" << std::endl;
const PipelineConfig &config = master.Config(); const PipelineConfig &config = master.Config();
master.MaximumLazyInput(counts, primary); master.MaximumLazyInput(counts, primary);
Chains gamma_chains(config.order - 1); util::stream::Chains gamma_chains(config.order - 1);
util::stream::ChainConfig read_backoffs(config.read_backoffs);
read_backoffs.entry_size = sizeof(float);
for (std::size_t i = 0; i < config.order - 1; ++i) { for (std::size_t i = 0; i < config.order - 1; ++i) {
util::stream::ChainConfig read_backoffs(config.read_backoffs);
// Add 1 because here we are skipping unigrams
if(config.prune_thresholds[i + 1] > 0)
read_backoffs.entry_size = sizeof(HashGamma);
else
read_backoffs.entry_size = sizeof(float);
gamma_chains.push_back(read_backoffs); gamma_chains.push_back(read_backoffs);
gamma_chains.back() >> gammas[i].Source(); gamma_chains.back() >> gammas[i].Source();
} }
master >> Interpolate(counts[0], ChainPositions(gamma_chains)); master >> Interpolate(std::max(master.Config().vocab_size_for_unk, counts[0] - 1 /* <s> is not included */), util::stream::ChainPositions(gamma_chains), config.prune_thresholds);
gamma_chains >> util::stream::kRecycle; gamma_chains >> util::stream::kRecycle;
master.BufferFinal(counts); master.BufferFinal(counts);
} }
@ -301,21 +312,22 @@ void Pipeline(PipelineConfig config, int text_file, int out_arpa) {
CountText(text_file, vocab_file.get(), master, token_count, text_file_name); CountText(text_file, vocab_file.get(), master, token_count, text_file_name);
std::vector<uint64_t> counts; std::vector<uint64_t> counts;
std::vector<uint64_t> counts_pruned;
std::vector<Discount> discounts; std::vector<Discount> discounts;
master >> AdjustCounts(counts, discounts); master >> AdjustCounts(counts, counts_pruned, discounts, config.prune_thresholds);
{ {
FixedArray<util::stream::FileBuffer> gammas; util::FixedArray<util::stream::FileBuffer> gammas;
Sorts<SuffixOrder> primary; Sorts<SuffixOrder> primary;
InitialProbabilities(counts, discounts, master, primary, gammas); InitialProbabilities(counts, counts_pruned, discounts, master, primary, gammas, config.prune_thresholds);
InterpolateProbabilities(counts, master, primary, gammas); InterpolateProbabilities(counts_pruned, master, primary, gammas);
} }
std::cerr << "=== 5/5 Writing ARPA model ===" << std::endl; std::cerr << "=== 5/5 Writing ARPA model ===" << std::endl;
VocabReconstitute vocab(vocab_file.get()); VocabReconstitute vocab(vocab_file.get());
UTIL_THROW_IF(vocab.Size() != counts[0], util::Exception, "Vocab words don't match up. Is there a null byte in the input?"); UTIL_THROW_IF(vocab.Size() != counts[0], util::Exception, "Vocab words don't match up. Is there a null byte in the input?");
HeaderInfo header_info(text_file_name, token_count); HeaderInfo header_info(text_file_name, token_count);
master >> PrintARPA(vocab, counts, (config.verbose_header ? &header_info : NULL), out_arpa) >> util::stream::kRecycle; master >> PrintARPA(vocab, counts_pruned, (config.verbose_header ? &header_info : NULL), out_arpa) >> util::stream::kRecycle;
master.MutableChains().Wait(true); master.MutableChains().Wait(true);
} }

View File

@ -1,8 +1,9 @@
#ifndef LM_BUILDER_PIPELINE__ #ifndef LM_BUILDER_PIPELINE_H
#define LM_BUILDER_PIPELINE__ #define LM_BUILDER_PIPELINE_H
#include "lm/builder/initial_probabilities.hh" #include "lm/builder/initial_probabilities.hh"
#include "lm/builder/header_info.hh" #include "lm/builder/header_info.hh"
#include "lm/lm_exception.hh"
#include "lm/word_index.hh" #include "lm/word_index.hh"
#include "util/stream/config.hh" #include "util/stream/config.hh"
#include "util/file_piece.hh" #include "util/file_piece.hh"
@ -30,6 +31,28 @@ struct PipelineConfig {
// Number of blocks to use. This will be overridden to 1 if everything fits. // Number of blocks to use. This will be overridden to 1 if everything fits.
std::size_t block_count; std::size_t block_count;
// n-gram count thresholds for pruning. A value of 0 means no pruning for
// the corresponding n-gram order.
std::vector<uint64_t> prune_thresholds; //mjd
/* Computing the perplexity of LMs with different vocabularies is hard. For
* example, the lowest perplexity is attained by a unigram model that
* predicts p(<unk>) = 1 and has no other vocabulary. Also, linearly
* interpolated models will sum to more than 1 because <unk> is duplicated
* (SRI just pretends p(<unk>) = 0 for these purposes, which makes it sum to
* 1 but comes with its own problems). This option will make the vocabulary
* a particular size by replicating <unk> multiple times for purposes of
* computing vocabulary size. It has no effect if the actual vocabulary is
* larger. This parameter serves the same purpose as IRSTLM's "dub".
*/
uint64_t vocab_size_for_unk;
/* What to do the first time <s>, </s>, or <unk> appears in the input. If
* this is anything but THROW_UP, then the symbol will always be treated as
* whitespace.
*/
WarningAction disallowed_symbol_action;
const std::string &TempPrefix() const { return sort.temp_prefix; } const std::string &TempPrefix() const { return sort.temp_prefix; }
std::size_t TotalMemory() const { return sort.total_memory; } std::size_t TotalMemory() const { return sort.total_memory; }
}; };
@ -38,4 +61,4 @@ struct PipelineConfig {
void Pipeline(PipelineConfig config, int text_file, int out_arpa); void Pipeline(PipelineConfig config, int text_file, int out_arpa);
}} // namespaces }} // namespaces
#endif // LM_BUILDER_PIPELINE__ #endif // LM_BUILDER_PIPELINE_H

View File

@ -42,14 +42,14 @@ PrintARPA::PrintARPA(const VocabReconstitute &vocab, const std::vector<uint64_t>
util::WriteOrThrow(out_fd, as_string.data(), as_string.size()); util::WriteOrThrow(out_fd, as_string.data(), as_string.size());
} }
void PrintARPA::Run(const ChainPositions &positions) { void PrintARPA::Run(const util::stream::ChainPositions &positions) {
util::scoped_fd closer(out_fd_); util::scoped_fd closer(out_fd_);
UTIL_TIMER("(%w s) Wrote ARPA file\n"); UTIL_TIMER("(%w s) Wrote ARPA file\n");
util::FakeOFStream out(out_fd_); util::FakeOFStream out(out_fd_);
for (unsigned order = 1; order <= positions.size(); ++order) { for (unsigned order = 1; order <= positions.size(); ++order) {
out << "\\" << order << "-grams:" << '\n'; out << "\\" << order << "-grams:" << '\n';
for (NGramStream stream(positions[order - 1]); stream; ++stream) { for (NGramStream stream(positions[order - 1]); stream; ++stream) {
// Correcting for numerical precision issues. Take that IRST. // Correcting for numerical precision issues. Take that IRST.
out << std::min(0.0f, stream->Value().complete.prob) << '\t' << vocab_.Lookup(*stream->begin()); out << std::min(0.0f, stream->Value().complete.prob) << '\t' << vocab_.Lookup(*stream->begin());
for (const WordIndex *i = stream->begin() + 1; i != stream->end(); ++i) { for (const WordIndex *i = stream->begin() + 1; i != stream->end(); ++i) {
out << ' ' << vocab_.Lookup(*i); out << ' ' << vocab_.Lookup(*i);
@ -58,6 +58,7 @@ void PrintARPA::Run(const ChainPositions &positions) {
if (backoff != 0.0) if (backoff != 0.0)
out << '\t' << backoff; out << '\t' << backoff;
out << '\n'; out << '\n';
} }
out << '\n'; out << '\n';
} }

View File

@ -1,8 +1,8 @@
#ifndef LM_BUILDER_PRINT__ #ifndef LM_BUILDER_PRINT_H
#define LM_BUILDER_PRINT__ #define LM_BUILDER_PRINT_H
#include "lm/builder/ngram.hh" #include "lm/builder/ngram.hh"
#include "lm/builder/multi_stream.hh" #include "lm/builder/ngram_stream.hh"
#include "lm/builder/header_info.hh" #include "lm/builder/header_info.hh"
#include "util/file.hh" #include "util/file.hh"
#include "util/mmap.hh" #include "util/mmap.hh"
@ -59,7 +59,7 @@ template <class V> class Print {
public: public:
explicit Print(const VocabReconstitute &vocab, std::ostream &to) : vocab_(vocab), to_(to) {} explicit Print(const VocabReconstitute &vocab, std::ostream &to) : vocab_(vocab), to_(to) {}
void Run(const ChainPositions &chains) { void Run(const util::stream::ChainPositions &chains) {
NGramStreams streams(chains); NGramStreams streams(chains);
for (NGramStream *s = streams.begin(); s != streams.end(); ++s) { for (NGramStream *s = streams.begin(); s != streams.end(); ++s) {
DumpStream(*s); DumpStream(*s);
@ -92,7 +92,7 @@ class PrintARPA {
// Takes ownership of out_fd upon Run(). // Takes ownership of out_fd upon Run().
explicit PrintARPA(const VocabReconstitute &vocab, const std::vector<uint64_t> &counts, const HeaderInfo* header_info, int out_fd); explicit PrintARPA(const VocabReconstitute &vocab, const std::vector<uint64_t> &counts, const HeaderInfo* header_info, int out_fd);
void Run(const ChainPositions &positions); void Run(const util::stream::ChainPositions &positions);
private: private:
const VocabReconstitute &vocab_; const VocabReconstitute &vocab_;
@ -100,4 +100,4 @@ class PrintARPA {
}; };
}} // namespaces }} // namespaces
#endif // LM_BUILDER_PRINT__ #endif // LM_BUILDER_PRINT_H

View File

@ -1,7 +1,7 @@
#ifndef LM_BUILDER_SORT__ #ifndef LM_BUILDER_SORT_H
#define LM_BUILDER_SORT__ #define LM_BUILDER_SORT_H
#include "lm/builder/multi_stream.hh" #include "lm/builder/ngram_stream.hh"
#include "lm/builder/ngram.hh" #include "lm/builder/ngram.hh"
#include "lm/word_index.hh" #include "lm/word_index.hh"
#include "util/stream/sort.hh" #include "util/stream/sort.hh"
@ -14,24 +14,71 @@
namespace lm { namespace lm {
namespace builder { namespace builder {
/**
* Abstract parent class for defining custom n-gram comparators.
*/
template <class Child> class Comparator : public std::binary_function<const void *, const void *, bool> { template <class Child> class Comparator : public std::binary_function<const void *, const void *, bool> {
public: public:
/**
* Constructs a comparator capable of comparing two n-grams.
*
* @param order Number of words in each n-gram
*/
explicit Comparator(std::size_t order) : order_(order) {} explicit Comparator(std::size_t order) : order_(order) {}
/**
* Applies the comparator using the Compare method that must be defined in any class that inherits from this class.
*
* @param lhs A pointer to the n-gram on the left-hand side of the comparison
* @param rhs A pointer to the n-gram on the right-hand side of the comparison
*
* @see ContextOrder::Compare
* @see PrefixOrder::Compare
* @see SuffixOrder::Compare
*/
inline bool operator()(const void *lhs, const void *rhs) const { inline bool operator()(const void *lhs, const void *rhs) const {
return static_cast<const Child*>(this)->Compare(static_cast<const WordIndex*>(lhs), static_cast<const WordIndex*>(rhs)); return static_cast<const Child*>(this)->Compare(static_cast<const WordIndex*>(lhs), static_cast<const WordIndex*>(rhs));
} }
/** Gets the n-gram order defined for this comparator. */
std::size_t Order() const { return order_; } std::size_t Order() const { return order_; }
protected: protected:
std::size_t order_; std::size_t order_;
}; };
/**
* N-gram comparator that compares n-grams according to their reverse (suffix) order.
*
* This comparator compares n-grams lexicographically, one word at a time,
* beginning with the last word of each n-gram and ending with the first word of each n-gram.
*
* Some examples of n-gram comparisons as defined by this comparator:
* - a b c == a b c
* - a b c < a b d
* - a b c > a d b
* - a b c > a b b
* - a b c > x a c
* - a b c < x y z
*/
class SuffixOrder : public Comparator<SuffixOrder> { class SuffixOrder : public Comparator<SuffixOrder> {
public: public:
/**
* Constructs a comparator capable of comparing two n-grams.
*
* @param order Number of words in each n-gram
*/
explicit SuffixOrder(std::size_t order) : Comparator<SuffixOrder>(order) {} explicit SuffixOrder(std::size_t order) : Comparator<SuffixOrder>(order) {}
/**
* Compares two n-grams lexicographically, one word at a time,
* beginning with the last word of each n-gram and ending with the first word of each n-gram.
*
* @param lhs A pointer to the n-gram on the left-hand side of the comparison
* @param rhs A pointer to the n-gram on the right-hand side of the comparison
*/
inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const { inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const {
for (std::size_t i = order_ - 1; i != 0; --i) { for (std::size_t i = order_ - 1; i != 0; --i) {
if (lhs[i] != rhs[i]) if (lhs[i] != rhs[i])
@ -43,10 +90,40 @@ class SuffixOrder : public Comparator<SuffixOrder> {
static const unsigned kMatchOffset = 1; static const unsigned kMatchOffset = 1;
}; };
/**
* N-gram comparator that compares n-grams according to the reverse (suffix) order of the n-gram context.
*
* This comparator compares n-grams lexicographically, one word at a time,
* beginning with the penultimate word of each n-gram and ending with the first word of each n-gram;
* finally, this comparator compares the last word of each n-gram.
*
* Some examples of n-gram comparisons as defined by this comparator:
* - a b c == a b c
* - a b c < a b d
* - a b c < a d b
* - a b c > a b b
* - a b c > x a c
* - a b c < x y z
*/
class ContextOrder : public Comparator<ContextOrder> { class ContextOrder : public Comparator<ContextOrder> {
public: public:
/**
* Constructs a comparator capable of comparing two n-grams.
*
* @param order Number of words in each n-gram
*/
explicit ContextOrder(std::size_t order) : Comparator<ContextOrder>(order) {} explicit ContextOrder(std::size_t order) : Comparator<ContextOrder>(order) {}
/**
* Compares two n-grams lexicographically, one word at a time,
* beginning with the penultimate word of each n-gram and ending with the first word of each n-gram;
* finally, this comparator compares the last word of each n-gram.
*
* @param lhs A pointer to the n-gram on the left-hand side of the comparison
* @param rhs A pointer to the n-gram on the right-hand side of the comparison
*/
inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const { inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const {
for (int i = order_ - 2; i >= 0; --i) { for (int i = order_ - 2; i >= 0; --i) {
if (lhs[i] != rhs[i]) if (lhs[i] != rhs[i])
@ -56,10 +133,37 @@ class ContextOrder : public Comparator<ContextOrder> {
} }
}; };
/**
* N-gram comparator that compares n-grams according to their natural (prefix) order.
*
* This comparator compares n-grams lexicographically, one word at a time,
* beginning with the first word of each n-gram and ending with the last word of each n-gram.
*
* Some examples of n-gram comparisons as defined by this comparator:
* - a b c == a b c
* - a b c < a b d
* - a b c < a d b
* - a b c > a b b
* - a b c < x a c
* - a b c < x y z
*/
class PrefixOrder : public Comparator<PrefixOrder> { class PrefixOrder : public Comparator<PrefixOrder> {
public: public:
/**
* Constructs a comparator capable of comparing two n-grams.
*
* @param order Number of words in each n-gram
*/
explicit PrefixOrder(std::size_t order) : Comparator<PrefixOrder>(order) {} explicit PrefixOrder(std::size_t order) : Comparator<PrefixOrder>(order) {}
/**
* Compares two n-grams lexicographically, one word at a time,
* beginning with the first word of each n-gram and ending with the last word of each n-gram.
*
* @param lhs A pointer to the n-gram on the left-hand side of the comparison
* @param rhs A pointer to the n-gram on the right-hand side of the comparison
*/
inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const { inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const {
for (std::size_t i = 0; i < order_; ++i) { for (std::size_t i = 0; i < order_; ++i) {
if (lhs[i] != rhs[i]) if (lhs[i] != rhs[i])
@ -84,15 +188,52 @@ struct AddCombiner {
}; };
// The combiner is only used on a single chain, so I didn't bother to allow // The combiner is only used on a single chain, so I didn't bother to allow
// that template. // that template.
template <class Compare> class Sorts : public FixedArray<util::stream::Sort<Compare> > { /**
* Represents an @ref util::FixedArray "array" capable of storing @ref util::stream::Sort "Sort" objects.
*
* In the anticipated use case, an instance of this class will maintain one @ref util::stream::Sort "Sort" object
* for each n-gram order (ranging from 1 up to the maximum n-gram order being processed).
* Use in this manner would enable the n-grams of each n-gram order to be sorted, in parallel.
*
* @tparam Compare An @ref Comparator "ngram comparator" to use during sorting.
*/
template <class Compare> class Sorts : public util::FixedArray<util::stream::Sort<Compare> > {
private: private:
typedef util::stream::Sort<Compare> S; typedef util::stream::Sort<Compare> S;
typedef FixedArray<S> P; typedef util::FixedArray<S> P;
public: public:
/**
* Constructs, but does not initialize.
*
* @ref util::FixedArray::Init() "Init" must be called before use.
*
* @see util::FixedArray::Init()
*/
Sorts() {}
/**
* Constructs an @ref util::FixedArray "array" capable of storing a fixed number of @ref util::stream::Sort "Sort" objects.
*
* @param number The maximum number of @ref util::stream::Sort "sorters" that can be held by this @ref util::FixedArray "array"
* @see util::FixedArray::FixedArray()
*/
explicit Sorts(std::size_t number) : util::FixedArray<util::stream::Sort<Compare> >(number) {}
/**
* Constructs a new @ref util::stream::Sort "Sort" object which is stored in this @ref util::FixedArray "array".
*
* The new @ref util::stream::Sort "Sort" object is constructed using the provided @ref util::stream::SortConfig "SortConfig" and @ref Comparator "ngram comparator";
* once constructed, a new worker @ref util::stream::Thread "thread" (owned by the @ref util::stream::Chain "chain") will sort the n-gram data stored
* in the @ref util::stream::Block "blocks" of the provided @ref util::stream::Chain "chain".
*
* @see util::stream::Sort::Sort()
* @see util::stream::Chain::operator>>()
*/
void push_back(util::stream::Chain &chain, const util::stream::SortConfig &config, const Compare &compare) { void push_back(util::stream::Chain &chain, const util::stream::SortConfig &config, const Compare &compare) {
new (P::end()) S(chain, config, compare); new (P::end()) S(chain, config, compare); // use "placement new" syntax to initialize S in an already-allocated memory location
P::Constructed(); P::Constructed();
} }
}; };
@ -100,4 +241,4 @@ template <class Compare> class Sorts : public FixedArray<util::stream::Sort<Comp
} // namespace builder } // namespace builder
} // namespace lm } // namespace lm
#endif // LM_BUILDER_SORT__ #endif // LM_BUILDER_SORT_H

View File

@ -1,5 +1,5 @@
#ifndef LM_CONFIG__ #ifndef LM_CONFIG_H
#define LM_CONFIG__ #define LM_CONFIG_H
#include "lm/lm_exception.hh" #include "lm/lm_exception.hh"
#include "util/mmap.hh" #include "util/mmap.hh"
@ -120,4 +120,4 @@ struct Config {
} /* namespace ngram */ } /* namespace lm */ } /* namespace ngram */ } /* namespace lm */
#endif // LM_CONFIG__ #endif // LM_CONFIG_H

View File

@ -1,5 +1,5 @@
#ifndef LM_ENUMERATE_VOCAB__ #ifndef LM_ENUMERATE_VOCAB_H
#define LM_ENUMERATE_VOCAB__ #define LM_ENUMERATE_VOCAB_H
#include "lm/word_index.hh" #include "lm/word_index.hh"
#include "util/string_piece.hh" #include "util/string_piece.hh"
@ -24,5 +24,5 @@ class EnumerateVocab {
} // namespace lm } // namespace lm
#endif // LM_ENUMERATE_VOCAB__ #endif // LM_ENUMERATE_VOCAB_H

View File

@ -1,5 +1,5 @@
#ifndef LM_FACADE__ #ifndef LM_FACADE_H
#define LM_FACADE__ #define LM_FACADE_H
#include "lm/virtual_interface.hh" #include "lm/virtual_interface.hh"
#include "util/string_piece.hh" #include "util/string_piece.hh"
@ -70,4 +70,4 @@ template <class Child, class StateT, class VocabularyT> class ModelFacade : publ
} // namespace base } // namespace base
} // namespace lm } // namespace lm
#endif // LM_FACADE__ #endif // LM_FACADE_H

View File

@ -1,5 +1,5 @@
#ifndef LM_FILTER_ARPA_IO__ #ifndef LM_FILTER_ARPA_IO_H
#define LM_FILTER_ARPA_IO__ #define LM_FILTER_ARPA_IO_H
/* Input and output for ARPA format language model files. /* Input and output for ARPA format language model files.
*/ */
#include "lm/read_arpa.hh" #include "lm/read_arpa.hh"
@ -111,4 +111,4 @@ template <class Output> void ReadARPA(util::FilePiece &in_lm, Output &out) {
} // namespace lm } // namespace lm
#endif // LM_FILTER_ARPA_IO__ #endif // LM_FILTER_ARPA_IO_H

View File

@ -1,5 +1,5 @@
#ifndef LM_FILTER_COUNT_IO__ #ifndef LM_FILTER_COUNT_IO_H
#define LM_FILTER_COUNT_IO__ #define LM_FILTER_COUNT_IO_H
#include <fstream> #include <fstream>
#include <iostream> #include <iostream>
@ -86,4 +86,4 @@ template <class Output> void ReadCount(util::FilePiece &in_file, Output &out) {
} // namespace lm } // namespace lm
#endif // LM_FILTER_COUNT_IO__ #endif // LM_FILTER_COUNT_IO_H

View File

@ -1,5 +1,5 @@
#ifndef LM_FILTER_FORMAT_H__ #ifndef LM_FILTER_FORMAT_H
#define LM_FILTER_FORMAT_H__ #define LM_FILTER_FORMAT_H
#include "lm/filter/arpa_io.hh" #include "lm/filter/arpa_io.hh"
#include "lm/filter/count_io.hh" #include "lm/filter/count_io.hh"
@ -247,4 +247,4 @@ class MultipleOutputBuffer {
} // namespace lm } // namespace lm
#endif // LM_FILTER_FORMAT_H__ #endif // LM_FILTER_FORMAT_H

View File

@ -1,5 +1,5 @@
#ifndef LM_FILTER_PHRASE_H__ #ifndef LM_FILTER_PHRASE_H
#define LM_FILTER_PHRASE_H__ #define LM_FILTER_PHRASE_H
#include "util/murmur_hash.hh" #include "util/murmur_hash.hh"
#include "util/string_piece.hh" #include "util/string_piece.hh"
@ -165,4 +165,4 @@ class Multiple : public detail::ConditionCommon {
} // namespace phrase } // namespace phrase
} // namespace lm } // namespace lm
#endif // LM_FILTER_PHRASE_H__ #endif // LM_FILTER_PHRASE_H

View File

@ -1,5 +1,5 @@
#ifndef LM_FILTER_THREAD_H__ #ifndef LM_FILTER_THREAD_H
#define LM_FILTER_THREAD_H__ #define LM_FILTER_THREAD_H
#include "util/thread_pool.hh" #include "util/thread_pool.hh"
@ -164,4 +164,4 @@ template <class Filter, class OutputBuffer, class RealOutput> class Controller :
} // namespace lm } // namespace lm
#endif // LM_FILTER_THREAD_H__ #endif // LM_FILTER_THREAD_H

View File

@ -1,5 +1,5 @@
#ifndef LM_FILTER_VOCAB_H__ #ifndef LM_FILTER_VOCAB_H
#define LM_FILTER_VOCAB_H__ #define LM_FILTER_VOCAB_H
// Vocabulary-based filters for language models. // Vocabulary-based filters for language models.
@ -130,4 +130,4 @@ class Multiple {
} // namespace vocab } // namespace vocab
} // namespace lm } // namespace lm
#endif // LM_FILTER_VOCAB_H__ #endif // LM_FILTER_VOCAB_H

View File

@ -1,5 +1,5 @@
#ifndef LM_FILTER_WRAPPER_H__ #ifndef LM_FILTER_WRAPPER_H
#define LM_FILTER_WRAPPER_H__ #define LM_FILTER_WRAPPER_H
#include "util/string_piece.hh" #include "util/string_piece.hh"
@ -53,4 +53,4 @@ template <class FilterT> class ContextFilter {
} // namespace lm } // namespace lm
#endif // LM_FILTER_WRAPPER_H__ #endif // LM_FILTER_WRAPPER_H

View File

@ -35,8 +35,8 @@
* phrase, even if hypotheses are generated left-to-right. * phrase, even if hypotheses are generated left-to-right.
*/ */
#ifndef LM_LEFT__ #ifndef LM_LEFT_H
#define LM_LEFT__ #define LM_LEFT_H
#include "lm/max_order.hh" #include "lm/max_order.hh"
#include "lm/state.hh" #include "lm/state.hh"
@ -213,4 +213,4 @@ template <class M> class RuleScore {
} // namespace ngram } // namespace ngram
} // namespace lm } // namespace lm
#endif // LM_LEFT__ #endif // LM_LEFT_H

View File

@ -1,5 +1,5 @@
#ifndef LM_LM_EXCEPTION__ #ifndef LM_LM_EXCEPTION_H
#define LM_LM_EXCEPTION__ #define LM_LM_EXCEPTION_H
// Named to avoid conflict with util/exception.hh. // Named to avoid conflict with util/exception.hh.

View File

@ -1,9 +1,13 @@
/* IF YOUR BUILD SYSTEM PASSES -DKENLM_MAX_ORDER, THEN CHANGE THE BUILD SYSTEM. #ifndef LM_MAX_ORDER_H
#define LM_MAX_ORDER_H
/* IF YOUR BUILD SYSTEM PASSES -DKENLM_MAX_ORDER, THEN CHANGE THE BUILD SYSTEM.
* If not, this is the default maximum order. * If not, this is the default maximum order.
* Having this limit means that State can be * Having this limit means that State can be
* (kMaxOrder - 1) * sizeof(float) bytes instead of * (kMaxOrder - 1) * sizeof(float) bytes instead of
* sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead * sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead
*/ */
#ifndef KENLM_ORDER_MESSAGE #ifndef KENLM_ORDER_MESSAGE
#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER, change it there and recompile. In the KenLM tarball or Moses, use e.g. `bjam --max-kenlm-order=6 -a'. Otherwise, edit lm/max_order.hh." #define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER_H, change it there and recompile. In the KenLM tarball or Moses, use e.g. `bjam --max-kenlm-order=6 -a'. Otherwise, edit lm/max_order.hh."
#endif #endif
#endif // LM_MAX_ORDER_H

View File

@ -1,5 +1,5 @@
#ifndef LM_MODEL__ #ifndef LM_MODEL_H
#define LM_MODEL__ #define LM_MODEL_H
#include "lm/bhiksha.hh" #include "lm/bhiksha.hh"
#include "lm/binary_format.hh" #include "lm/binary_format.hh"
@ -153,4 +153,4 @@ base::Model *LoadVirtual(const char *file_name, const Config &config = Config(),
} // namespace ngram } // namespace ngram
} // namespace lm } // namespace lm
#endif // LM_MODEL__ #endif // LM_MODEL_H

View File

@ -1,5 +1,5 @@
#ifndef LM_MODEL_TYPE__ #ifndef LM_MODEL_TYPE_H
#define LM_MODEL_TYPE__ #define LM_MODEL_TYPE_H
namespace lm { namespace lm {
namespace ngram { namespace ngram {
@ -20,4 +20,4 @@ const static ModelType kArrayAdd = static_cast<ModelType>(ARRAY_TRIE - TRIE);
} // namespace ngram } // namespace ngram
} // namespace lm } // namespace lm
#endif // LM_MODEL_TYPE__ #endif // LM_MODEL_TYPE_H

View File

@ -1,8 +1,9 @@
#ifndef LM_NGRAM_QUERY__ #ifndef LM_NGRAM_QUERY_H
#define LM_NGRAM_QUERY__ #define LM_NGRAM_QUERY_H
#include "lm/enumerate_vocab.hh" #include "lm/enumerate_vocab.hh"
#include "lm/model.hh" #include "lm/model.hh"
#include "util/file_piece.hh"
#include "util/usage.hh" #include "util/usage.hh"
#include <cstdlib> #include <cstdlib>
@ -16,64 +17,94 @@
namespace lm { namespace lm {
namespace ngram { namespace ngram {
template <class Model> void Query(const Model &model, bool sentence_context, std::istream &in_stream, std::ostream &out_stream) { struct BasicPrint {
void Word(StringPiece, WordIndex, const FullScoreReturn &) const {}
void Line(uint64_t oov, float total) const {
std::cout << "Total: " << total << " OOV: " << oov << '\n';
}
void Summary(double, double, uint64_t, uint64_t) {}
};
/**
 * Printer policy for Query that writes every scored word and a final
 * corpus-level summary to standard output.
 *
 * Word() and Summary() replace the no-op versions in BasicPrint; Line() is
 * inherited unchanged from BasicPrint.
 */
struct FullPrint : public BasicPrint {
  /**
   * Prints one scored word as "surface=vocab ngram_length prob" followed by a tab.
   *
   * @param surface The word as it appeared in the input
   * @param vocab The vocabulary index that the word mapped to
   * @param ret The score returned by the model for this word
   */
  void Word(StringPiece surface, WordIndex vocab, const FullScoreReturn &ret) const {
    // The cast makes the length print as an integer (presumably ngram_length
    // is a character-sized type -- confirm against FullScoreReturn).
    std::cout << surface << '=' << vocab << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\t';
  }

  /**
   * Prints the corpus-level statistics, one tab-separated "label:\tvalue" pair per line.
   *
   * @param ppl_including_oov Perplexity computed over all tokens
   * @param ppl_excluding_oov Perplexity computed with OOV tokens left out
   * @param corpus_oov Number of OOV tokens seen
   * @param corpus_tokens Total number of tokens scored
   */
  void Summary(double ppl_including_oov, double ppl_excluding_oov, uint64_t corpus_oov, uint64_t corpus_tokens) {
    std::cout <<
      "Perplexity including OOVs:\t" << ppl_including_oov << "\n"
      "Perplexity excluding OOVs:\t" << ppl_excluding_oov << "\n"
      "OOVs:\t" << corpus_oov << "\n"
      "Tokens:\t" << corpus_tokens << '\n'
      ;
  }
};
template <class Model, class Printer> void Query(const Model &model, bool sentence_context) {
Printer printer;
typename Model::State state, out; typename Model::State state, out;
lm::FullScoreReturn ret; lm::FullScoreReturn ret;
std::string word; StringPiece word;
util::FilePiece in(0);
double corpus_total = 0.0; double corpus_total = 0.0;
double corpus_total_oov_only = 0.0;
uint64_t corpus_oov = 0; uint64_t corpus_oov = 0;
uint64_t corpus_tokens = 0; uint64_t corpus_tokens = 0;
while (in_stream) { while (true) {
state = sentence_context ? model.BeginSentenceState() : model.NullContextState(); state = sentence_context ? model.BeginSentenceState() : model.NullContextState();
float total = 0.0; float total = 0.0;
bool got = false;
uint64_t oov = 0; uint64_t oov = 0;
while (in_stream >> word) {
got = true; while (in.ReadWordSameLine(word)) {
lm::WordIndex vocab = model.GetVocabulary().Index(word); lm::WordIndex vocab = model.GetVocabulary().Index(word);
if (vocab == 0) ++oov;
ret = model.FullScore(state, vocab, out); ret = model.FullScore(state, vocab, out);
if (vocab == model.GetVocabulary().NotFound()) {
++oov;
corpus_total_oov_only += ret.prob;
}
total += ret.prob; total += ret.prob;
out_stream << word << '=' << vocab << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\t'; printer.Word(word, vocab, ret);
++corpus_tokens; ++corpus_tokens;
state = out; state = out;
char c;
while (true) {
c = in_stream.get();
if (!in_stream) break;
if (c == '\n') break;
if (!isspace(c)) {
in_stream.unget();
break;
}
}
if (c == '\n') break;
} }
if (!got && !in_stream) break; // If people don't have a newline after their last query, this won't add a </s>.
// Sue me.
try {
UTIL_THROW_IF('\n' != in.get(), util::Exception, "FilePiece is confused.");
} catch (const util::EndOfFileException &e) { break; }
if (sentence_context) { if (sentence_context) {
ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out); ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out);
total += ret.prob; total += ret.prob;
++corpus_tokens; ++corpus_tokens;
out_stream << "</s>=" << model.GetVocabulary().EndSentence() << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\t'; printer.Word("</s>", model.GetVocabulary().EndSentence(), ret);
} }
out_stream << "Total: " << total << " OOV: " << oov << '\n'; printer.Line(oov, total);
corpus_total += total; corpus_total += total;
corpus_oov += oov; corpus_oov += oov;
} }
out_stream << "Perplexity " << pow(10.0, -(corpus_total / static_cast<double>(corpus_tokens))) << std::endl; printer.Summary(
pow(10.0, -(corpus_total / static_cast<double>(corpus_tokens))), // PPL including OOVs
pow(10.0, -((corpus_total - corpus_total_oov_only) / static_cast<double>(corpus_tokens - corpus_oov))), // PPL excluding OOVs
corpus_oov,
corpus_tokens);
} }
template <class M> void Query(const char *file, bool sentence_context, std::istream &in_stream, std::ostream &out_stream) { template <class Model> void Query(const char *file, const Config &config, bool sentence_context, bool show_words) {
Config config; Model model(file, config);
M model(file, config); if (show_words) {
Query(model, sentence_context, in_stream, out_stream); Query<Model, FullPrint>(model, sentence_context);
} else {
Query<Model, BasicPrint>(model, sentence_context);
}
} }
} // namespace ngram } // namespace ngram
} // namespace lm } // namespace lm
#endif // LM_NGRAM_QUERY__ #endif // LM_NGRAM_QUERY_H

View File

@ -1,5 +1,5 @@
#ifndef LM_PARTIAL__ #ifndef LM_PARTIAL_H
#define LM_PARTIAL__ #define LM_PARTIAL_H
#include "lm/return.hh" #include "lm/return.hh"
#include "lm/state.hh" #include "lm/state.hh"
@ -164,4 +164,4 @@ template <class Model> float Subsume(const Model &model, Left &first_left, const
} // namespace ngram } // namespace ngram
} // namespace lm } // namespace lm
#endif // LM_PARTIAL__ #endif // LM_PARTIAL_H

View File

@ -1,5 +1,5 @@
#ifndef LM_QUANTIZE_H__ #ifndef LM_QUANTIZE_H
#define LM_QUANTIZE_H__ #define LM_QUANTIZE_H
#include "lm/blank.hh" #include "lm/blank.hh"
#include "lm/config.hh" #include "lm/config.hh"
@ -230,4 +230,4 @@ class SeparatelyQuantize {
} // namespace ngram } // namespace ngram
} // namespace lm } // namespace lm
#endif // LM_QUANTIZE_H__ #endif // LM_QUANTIZE_H

View File

@ -1,4 +1,5 @@
#include "lm/ngram_query.hh" #include "lm/ngram_query.hh"
#include "util/getopt.hh"
#ifdef WITH_NPLM #ifdef WITH_NPLM
#include "lm/wrappers/nplm.hh" #include "lm/wrappers/nplm.hh"
@ -7,47 +8,76 @@
#include <stdlib.h> #include <stdlib.h>
void Usage(const char *name) { void Usage(const char *name) {
std::cerr << "KenLM was compiled with maximum order " << KENLM_MAX_ORDER << "." << std::endl; std::cerr <<
std::cerr << "Usage: " << name << " [-n] lm_file" << std::endl; "KenLM was compiled with maximum order " << KENLM_MAX_ORDER << ".\n"
std::cerr << "Input is wrapped in <s> and </s> unless -n is passed." << std::endl; "Usage: " << name << " [-n] [-s] lm_file\n"
"-n: Do not wrap the input in <s> and </s>.\n"
"-s: Sentence totals only.\n"
"-l lazy|populate|read|parallel: Load lazily, with populate, or malloc+read\n"
"The default loading method is populate on Linux and read on others.\n";
exit(1); exit(1);
} }
int main(int argc, char *argv[]) { int main(int argc, char *argv[]) {
if (argc == 1 || (argc == 2 && !strcmp(argv[1], "--help")))
Usage(argv[0]);
lm::ngram::Config config;
bool sentence_context = true; bool sentence_context = true;
const char *file = NULL; bool show_words = true;
for (char **arg = argv + 1; arg != argv + argc; ++arg) {
if (!strcmp(*arg, "-n")) { int opt;
sentence_context = false; while ((opt = getopt(argc, argv, "hnsl:")) != -1) {
} else if (!strcmp(*arg, "-h") || !strcmp(*arg, "--help") || file) { switch (opt) {
Usage(argv[0]); case 'n':
} else { sentence_context = false;
file = *arg; break;
case 's':
show_words = false;
break;
case 'l':
if (!strcmp(optarg, "lazy")) {
config.load_method = util::LAZY;
} else if (!strcmp(optarg, "populate")) {
config.load_method = util::POPULATE_OR_READ;
} else if (!strcmp(optarg, "read")) {
config.load_method = util::READ;
} else if (!strcmp(optarg, "parallel")) {
config.load_method = util::PARALLEL_READ;
} else {
Usage(argv[0]);
}
break;
case 'h':
default:
Usage(argv[0]);
} }
} }
if (!file) Usage(argv[0]); if (optind + 1 != argc)
Usage(argv[0]);
const char *file = argv[optind];
try { try {
using namespace lm::ngram; using namespace lm::ngram;
ModelType model_type; ModelType model_type;
if (RecognizeBinary(file, model_type)) { if (RecognizeBinary(file, model_type)) {
switch(model_type) { switch(model_type) {
case PROBING: case PROBING:
Query<lm::ngram::ProbingModel>(file, sentence_context, std::cin, std::cout); Query<lm::ngram::ProbingModel>(file, config, sentence_context, show_words);
break; break;
case REST_PROBING: case REST_PROBING:
Query<lm::ngram::RestProbingModel>(file, sentence_context, std::cin, std::cout); Query<lm::ngram::RestProbingModel>(file, config, sentence_context, show_words);
break; break;
case TRIE: case TRIE:
Query<TrieModel>(file, sentence_context, std::cin, std::cout); Query<TrieModel>(file, config, sentence_context, show_words);
break; break;
case QUANT_TRIE: case QUANT_TRIE:
Query<QuantTrieModel>(file, sentence_context, std::cin, std::cout); Query<QuantTrieModel>(file, config, sentence_context, show_words);
break; break;
case ARRAY_TRIE: case ARRAY_TRIE:
Query<ArrayTrieModel>(file, sentence_context, std::cin, std::cout); Query<ArrayTrieModel>(file, config, sentence_context, show_words);
break; break;
case QUANT_ARRAY_TRIE: case QUANT_ARRAY_TRIE:
Query<QuantArrayTrieModel>(file, sentence_context, std::cin, std::cout); Query<QuantArrayTrieModel>(file, config, sentence_context, show_words);
break; break;
default: default:
std::cerr << "Unrecognized kenlm model type " << model_type << std::endl; std::cerr << "Unrecognized kenlm model type " << model_type << std::endl;
@ -56,12 +86,15 @@ int main(int argc, char *argv[]) {
#ifdef WITH_NPLM #ifdef WITH_NPLM
} else if (lm::np::Model::Recognize(file)) { } else if (lm::np::Model::Recognize(file)) {
lm::np::Model model(file); lm::np::Model model(file);
Query(model, sentence_context, std::cin, std::cout); if (show_words) {
Query<lm::np::Model, lm::ngram::FullPrint>(model, sentence_context);
} else {
Query<lm::np::Model, lm::ngram::BasicPrint>(model, sentence_context);
}
#endif #endif
} else { } else {
Query<ProbingModel>(file, sentence_context, std::cin, std::cout); Query<ProbingModel>(file, config, sentence_context, show_words);
} }
std::cerr << "Total time including destruction:\n";
util::PrintUsage(std::cerr); util::PrintUsage(std::cerr);
} catch (const std::exception &e) { } catch (const std::exception &e) {
std::cerr << e.what() << std::endl; std::cerr << e.what() << std::endl;

View File

@ -1,5 +1,5 @@
#ifndef LM_READ_ARPA__ #ifndef LM_READ_ARPA_H
#define LM_READ_ARPA__ #define LM_READ_ARPA_H
#include "lm/lm_exception.hh" #include "lm/lm_exception.hh"
#include "lm/word_index.hh" #include "lm/word_index.hh"
@ -28,7 +28,7 @@ void ReadEnd(util::FilePiece &in);
extern const bool kARPASpaces[256]; extern const bool kARPASpaces[256];
// Positive log probability warning. // Positive log probability warning.
class PositiveProbWarn { class PositiveProbWarn {
public: public:
PositiveProbWarn() : action_(THROW_UP) {} PositiveProbWarn() : action_(THROW_UP) {}
@ -41,24 +41,29 @@ class PositiveProbWarn {
WarningAction action_; WarningAction action_;
}; };
template <class Voc, class Weights> void Read1Gram(util::FilePiece &f, Voc &vocab, Weights *unigrams, PositiveProbWarn &warn) { template <class Weights> StringPiece Read1Gram(util::FilePiece &f, Weights &weights, PositiveProbWarn &warn) {
try { try {
float prob = f.ReadFloat(); weights.prob = f.ReadFloat();
if (prob > 0.0) { if (weights.prob > 0.0) {
warn.Warn(prob); warn.Warn(weights.prob);
prob = 0.0; weights.prob = 0.0;
} }
if (f.get() != '\t') UTIL_THROW(FormatLoadException, "Expected tab after probability"); UTIL_THROW_IF(f.get() != '\t', FormatLoadException, "Expected tab after probability");
Weights &value = unigrams[vocab.Insert(f.ReadDelimited(kARPASpaces))]; StringPiece ret(f.ReadDelimited(kARPASpaces));
value.prob = prob; ReadBackoff(f, weights);
ReadBackoff(f, value); return ret;
} catch(util::Exception &e) { } catch(util::Exception &e) {
e << " in the 1-gram at byte " << f.Offset(); e << " in the 1-gram at byte " << f.Offset();
throw; throw;
} }
} }
// Return true if a positive log probability came out. template <class Voc, class Weights> void Read1Gram(util::FilePiece &f, Voc &vocab, Weights *unigrams, PositiveProbWarn &warn) {
Weights temp;
WordIndex word = vocab.Insert(Read1Gram(f, temp, warn));
unigrams[word] = temp;
}
template <class Voc, class Weights> void Read1Grams(util::FilePiece &f, std::size_t count, Voc &vocab, Weights *unigrams, PositiveProbWarn &warn) { template <class Voc, class Weights> void Read1Grams(util::FilePiece &f, std::size_t count, Voc &vocab, Weights *unigrams, PositiveProbWarn &warn) {
ReadNGramHeader(f, 1); ReadNGramHeader(f, 1);
for (std::size_t i = 0; i < count; ++i) { for (std::size_t i = 0; i < count; ++i) {
@ -67,16 +72,16 @@ template <class Voc, class Weights> void Read1Grams(util::FilePiece &f, std::siz
vocab.FinishedLoading(unigrams); vocab.FinishedLoading(unigrams);
} }
// Return true if a positive log probability came out. // Read ngram, write vocab ids to indices_out.
template <class Voc, class Weights> void ReadNGram(util::FilePiece &f, const unsigned char n, const Voc &vocab, WordIndex *const reverse_indices, Weights &weights, PositiveProbWarn &warn) { template <class Voc, class Weights, class Iterator> void ReadNGram(util::FilePiece &f, const unsigned char n, const Voc &vocab, Iterator indices_out, Weights &weights, PositiveProbWarn &warn) {
try { try {
weights.prob = f.ReadFloat(); weights.prob = f.ReadFloat();
if (weights.prob > 0.0) { if (weights.prob > 0.0) {
warn.Warn(weights.prob); warn.Warn(weights.prob);
weights.prob = 0.0; weights.prob = 0.0;
} }
for (WordIndex *vocab_out = reverse_indices + n - 1; vocab_out >= reverse_indices; --vocab_out) { for (unsigned char i = 0; i < n; ++i, ++indices_out) {
*vocab_out = vocab.Index(f.ReadDelimited(kARPASpaces)); *indices_out = vocab.Index(f.ReadDelimited(kARPASpaces));
} }
ReadBackoff(f, weights); ReadBackoff(f, weights);
} catch(util::Exception &e) { } catch(util::Exception &e) {
@ -87,4 +92,4 @@ template <class Voc, class Weights> void ReadNGram(util::FilePiece &f, const uns
} // namespace lm } // namespace lm
#endif // LM_READ_ARPA__ #endif // LM_READ_ARPA_H

View File

@ -1,5 +1,5 @@
#ifndef LM_RETURN__ #ifndef LM_RETURN_H
#define LM_RETURN__ #define LM_RETURN_H
#include <stdint.h> #include <stdint.h>
@ -39,4 +39,4 @@ struct FullScoreReturn {
}; };
} // namespace lm } // namespace lm
#endif // LM_RETURN__ #endif // LM_RETURN_H

View File

@ -178,7 +178,7 @@ template <class Build, class Activate, class Store> void ReadNGrams(
typename Store::Entry entry; typename Store::Entry entry;
std::vector<typename Value::Weights *> between; std::vector<typename Value::Weights *> between;
for (size_t i = 0; i < count; ++i) { for (size_t i = 0; i < count; ++i) {
ReadNGram(f, n, vocab, &*vocab_ids.begin(), entry.value, warn); ReadNGram(f, n, vocab, vocab_ids.rbegin(), entry.value, warn);
build.SetRest(&*vocab_ids.begin(), n, entry.value); build.SetRest(&*vocab_ids.begin(), n, entry.value);
keys[0] = detail::CombineWordHash(static_cast<uint64_t>(vocab_ids.front()), vocab_ids[1]); keys[0] = detail::CombineWordHash(static_cast<uint64_t>(vocab_ids.front()), vocab_ids[1]);

View File

@ -1,5 +1,5 @@
#ifndef LM_SEARCH_HASHED__ #ifndef LM_SEARCH_HASHED_H
#define LM_SEARCH_HASHED__ #define LM_SEARCH_HASHED_H
#include "lm/model_type.hh" #include "lm/model_type.hh"
#include "lm/config.hh" #include "lm/config.hh"
@ -189,4 +189,4 @@ template <class Value> class HashedSearch {
} // namespace ngram } // namespace ngram
} // namespace lm } // namespace lm
#endif // LM_SEARCH_HASHED__ #endif // LM_SEARCH_HASHED_H

View File

@ -561,6 +561,7 @@ template <class Quant, class Bhiksha> uint8_t *TrieSearch<Quant, Bhiksha>::Setup
} }
// Crazy backwards thing so we initialize using pointers to ones that have already been initialized // Crazy backwards thing so we initialize using pointers to ones that have already been initialized
for (unsigned char i = counts.size() - 1; i >= 2; --i) { for (unsigned char i = counts.size() - 1; i >= 2; --i) {
// use "placement new" syntax to initialize Middle in an already-allocated memory location
new (middle_begin_ + i - 2) Middle( new (middle_begin_ + i - 2) Middle(
middle_starts[i-2], middle_starts[i-2],
quant_.MiddleBits(config), quant_.MiddleBits(config),

View File

@ -1,5 +1,5 @@
#ifndef LM_SEARCH_TRIE__ #ifndef LM_SEARCH_TRIE_H
#define LM_SEARCH_TRIE__ #define LM_SEARCH_TRIE_H
#include "lm/config.hh" #include "lm/config.hh"
#include "lm/model_type.hh" #include "lm/model_type.hh"
@ -127,4 +127,4 @@ template <class Quant, class Bhiksha> class TrieSearch {
} // namespace ngram } // namespace ngram
} // namespace lm } // namespace lm
#endif // LM_SEARCH_TRIE__ #endif // LM_SEARCH_TRIE_H

View File

@ -1,5 +1,5 @@
#ifndef LM_SIZES__ #ifndef LM_SIZES_H
#define LM_SIZES__ #define LM_SIZES_H
#include <vector> #include <vector>
@ -14,4 +14,4 @@ void ShowSizes(const std::vector<uint64_t> &counts);
void ShowSizes(const char *file, const lm::ngram::Config &config); void ShowSizes(const char *file, const lm::ngram::Config &config);
}} // namespaces }} // namespaces
#endif // LM_SIZES__ #endif // LM_SIZES_H

View File

@ -1,5 +1,5 @@
#ifndef LM_STATE__ #ifndef LM_STATE_H
#define LM_STATE__ #define LM_STATE_H
#include "lm/max_order.hh" #include "lm/max_order.hh"
#include "lm/word_index.hh" #include "lm/word_index.hh"
@ -122,4 +122,4 @@ inline uint64_t hash_value(const ChartState &state) {
} // namespace ngram } // namespace ngram
} // namespace lm } // namespace lm
#endif // LM_STATE__ #endif // LM_STATE_H

View File

@ -1,5 +1,5 @@
#ifndef LM_TRIE__ #ifndef LM_TRIE_H
#define LM_TRIE__ #define LM_TRIE_H
#include "lm/weights.hh" #include "lm/weights.hh"
#include "lm/word_index.hh" #include "lm/word_index.hh"
@ -143,4 +143,4 @@ class BitPackedLongest : public BitPacked {
} // namespace ngram } // namespace ngram
} // namespace lm } // namespace lm
#endif // LM_TRIE__ #endif // LM_TRIE_H

View File

@ -16,6 +16,7 @@
#include <cstdio> #include <cstdio>
#include <cstdlib> #include <cstdlib>
#include <deque> #include <deque>
#include <iterator>
#include <limits> #include <limits>
#include <vector> #include <vector>
@ -248,11 +249,13 @@ void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vo
uint8_t *out_end = out + std::min(count - done, batch_size) * entry_size; uint8_t *out_end = out + std::min(count - done, batch_size) * entry_size;
if (order == counts.size()) { if (order == counts.size()) {
for (; out != out_end; out += entry_size) { for (; out != out_end; out += entry_size) {
ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<Prob*>(out + words_size), warn); std::reverse_iterator<WordIndex*> it(reinterpret_cast<WordIndex*>(out) + order);
ReadNGram(f, order, vocab, it, *reinterpret_cast<Prob*>(out + words_size), warn);
} }
} else { } else {
for (; out != out_end; out += entry_size) { for (; out != out_end; out += entry_size) {
ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<ProbBackoff*>(out + words_size), warn); std::reverse_iterator<WordIndex*> it(reinterpret_cast<WordIndex*>(out) + order);
ReadNGram(f, order, vocab, it, *reinterpret_cast<ProbBackoff*>(out + words_size), warn);
} }
} }
// Sort full records by full n-gram. // Sort full records by full n-gram.

View File

@ -1,7 +1,7 @@
// Step of trie builder: create sorted files. // Step of trie builder: create sorted files.
#ifndef LM_TRIE_SORT__ #ifndef LM_TRIE_SORT_H
#define LM_TRIE_SORT__ #define LM_TRIE_SORT_H
#include "lm/max_order.hh" #include "lm/max_order.hh"
#include "lm/word_index.hh" #include "lm/word_index.hh"
@ -111,4 +111,4 @@ class SortedFiles {
} // namespace ngram } // namespace ngram
} // namespace lm } // namespace lm
#endif // LM_TRIE_SORT__ #endif // LM_TRIE_SORT_H

View File

@ -1,5 +1,5 @@
#ifndef LM_VALUE__ #ifndef LM_VALUE_H
#define LM_VALUE__ #define LM_VALUE_H
#include "lm/model_type.hh" #include "lm/model_type.hh"
#include "lm/value_build.hh" #include "lm/value_build.hh"
@ -154,4 +154,4 @@ struct RestValue {
} // namespace ngram } // namespace ngram
} // namespace lm } // namespace lm
#endif // LM_VALUE__ #endif // LM_VALUE_H

View File

@ -1,5 +1,5 @@
#ifndef LM_VALUE_BUILD__ #ifndef LM_VALUE_BUILD_H
#define LM_VALUE_BUILD__ #define LM_VALUE_BUILD_H
#include "lm/weights.hh" #include "lm/weights.hh"
#include "lm/word_index.hh" #include "lm/word_index.hh"
@ -94,4 +94,4 @@ template <class Model> class LowerRestBuild {
} // namespace ngram } // namespace ngram
} // namespace lm } // namespace lm
#endif // LM_VALUE_BUILD__ #endif // LM_VALUE_BUILD_H

View File

@ -1,5 +1,5 @@
#ifndef LM_VIRTUAL_INTERFACE__ #ifndef LM_VIRTUAL_INTERFACE_H
#define LM_VIRTUAL_INTERFACE__ #define LM_VIRTUAL_INTERFACE_H
#include "lm/return.hh" #include "lm/return.hh"
#include "lm/word_index.hh" #include "lm/word_index.hh"
@ -157,4 +157,4 @@ class Model {
} // namespace base } // namespace base
} // namespace lm } // namespace lm
#endif // LM_VIRTUAL_INTERFACE__ #endif // LM_VIRTUAL_INTERFACE_H

View File

@ -170,11 +170,15 @@ struct ProbingVocabularyHeader {
ProbingVocabulary::ProbingVocabulary() : enumerate_(NULL) {} ProbingVocabulary::ProbingVocabulary() : enumerate_(NULL) {}
uint64_t ProbingVocabulary::Size(uint64_t entries, const Config &config) { uint64_t ProbingVocabulary::Size(uint64_t entries, float probing_multiplier) {
return ALIGN8(sizeof(detail::ProbingVocabularyHeader)) + Lookup::Size(entries, config.probing_multiplier); return ALIGN8(sizeof(detail::ProbingVocabularyHeader)) + Lookup::Size(entries, probing_multiplier);
} }
void ProbingVocabulary::SetupMemory(void *start, std::size_t allocated, std::size_t /*entries*/, const Config &/*config*/) { uint64_t ProbingVocabulary::Size(uint64_t entries, const Config &config) {
return Size(entries, config.probing_multiplier);
}
void ProbingVocabulary::SetupMemory(void *start, std::size_t allocated) {
header_ = static_cast<detail::ProbingVocabularyHeader*>(start); header_ = static_cast<detail::ProbingVocabularyHeader*>(start);
lookup_ = Lookup(static_cast<uint8_t*>(start) + ALIGN8(sizeof(detail::ProbingVocabularyHeader)), allocated); lookup_ = Lookup(static_cast<uint8_t*>(start) + ALIGN8(sizeof(detail::ProbingVocabularyHeader)), allocated);
bound_ = 1; bound_ = 1;
@ -201,12 +205,12 @@ WordIndex ProbingVocabulary::Insert(const StringPiece &str) {
return 0; return 0;
} else { } else {
if (enumerate_) enumerate_->Add(bound_, str); if (enumerate_) enumerate_->Add(bound_, str);
lookup_.Insert(ProbingVocabuaryEntry::Make(hashed, bound_)); lookup_.Insert(ProbingVocabularyEntry::Make(hashed, bound_));
return bound_++; return bound_++;
} }
} }
void ProbingVocabulary::InternalFinishedLoading() { void ProbingVocabulary::FinishedLoading() {
lookup_.FinishedInserting(); lookup_.FinishedInserting();
header_->bound = bound_; header_->bound = bound_;
header_->version = kProbingVocabularyVersion; header_->version = kProbingVocabularyVersion;

View File

@ -1,9 +1,11 @@
#ifndef LM_VOCAB__ #ifndef LM_VOCAB_H
#define LM_VOCAB__ #define LM_VOCAB_H
#include "lm/enumerate_vocab.hh" #include "lm/enumerate_vocab.hh"
#include "lm/lm_exception.hh" #include "lm/lm_exception.hh"
#include "lm/virtual_interface.hh" #include "lm/virtual_interface.hh"
#include "util/fake_ofstream.hh"
#include "util/murmur_hash.hh"
#include "util/pool.hh" #include "util/pool.hh"
#include "util/probing_hash_table.hh" #include "util/probing_hash_table.hh"
#include "util/sorted_uniform.hh" #include "util/sorted_uniform.hh"
@ -104,17 +106,16 @@ class SortedVocabulary : public base::Vocabulary {
#pragma pack(push) #pragma pack(push)
#pragma pack(4) #pragma pack(4)
struct ProbingVocabuaryEntry { struct ProbingVocabularyEntry {
uint64_t key; uint64_t key;
WordIndex value; WordIndex value;
typedef uint64_t Key; typedef uint64_t Key;
uint64_t GetKey() const { uint64_t GetKey() const { return key; }
return key; void SetKey(uint64_t to) { key = to; }
}
static ProbingVocabuaryEntry Make(uint64_t key, WordIndex value) { static ProbingVocabularyEntry Make(uint64_t key, WordIndex value) {
ProbingVocabuaryEntry ret; ProbingVocabularyEntry ret;
ret.key = key; ret.key = key;
ret.value = value; ret.value = value;
return ret; return ret;
@ -132,13 +133,18 @@ class ProbingVocabulary : public base::Vocabulary {
return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0; return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0;
} }
static uint64_t Size(uint64_t entries, float probing_multiplier);
// This just unwraps Config to get the probing_multiplier.
static uint64_t Size(uint64_t entries, const Config &config); static uint64_t Size(uint64_t entries, const Config &config);
// Vocab words are [0, Bound()). // Vocab words are [0, Bound()).
WordIndex Bound() const { return bound_; } WordIndex Bound() const { return bound_; }
// Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config); void SetupMemory(void *start, std::size_t allocated);
// Overload that accepts an entry count and a Config but ignores both,
// forwarding start and allocated to the two-argument SetupMemory.
void SetupMemory(void *start, std::size_t allocated, std::size_t /*entries*/, const Config &/*config*/) {
SetupMemory(start, allocated);
}
void Relocate(void *new_start); void Relocate(void *new_start);
@ -147,8 +153,9 @@ class ProbingVocabulary : public base::Vocabulary {
WordIndex Insert(const StringPiece &str); WordIndex Insert(const StringPiece &str);
template <class Weights> void FinishedLoading(Weights * /*reorder_vocab*/) { template <class Weights> void FinishedLoading(Weights * /*reorder_vocab*/) {
InternalFinishedLoading(); FinishedLoading();
} }
void FinishedLoading();
std::size_t UnkCountChangePadding() const { return 0; } std::size_t UnkCountChangePadding() const { return 0; }
@ -157,9 +164,7 @@ class ProbingVocabulary : public base::Vocabulary {
void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset); void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset);
private: private:
void InternalFinishedLoading(); typedef util::ProbingHashTable<ProbingVocabularyEntry, util::IdentityHash> Lookup;
typedef util::ProbingHashTable<ProbingVocabuaryEntry, util::IdentityHash> Lookup;
Lookup lookup_; Lookup lookup_;
@ -181,7 +186,64 @@ template <class Vocab> void CheckSpecials(const Config &config, const Vocab &voc
if (vocab.EndSentence() == vocab.NotFound()) MissingSentenceMarker(config, "</s>"); if (vocab.EndSentence() == vocab.NotFound()) MissingSentenceMarker(config, "</s>");
} }
class WriteUniqueWords {
public:
explicit WriteUniqueWords(int fd) : word_list_(fd) {}
void operator()(const StringPiece &word) {
word_list_ << word << '\0';
}
private:
util::FakeOFStream word_list_;
};
// New-word action that does nothing. It is the default NewWordAction of
// GrowableVocab, for callers that do not need to record new words.
class NoOpUniqueWords {
  public:
    NoOpUniqueWords() {}

    // Accepts the word and discards it.
    void operator()(const StringPiece & /*word*/) {}
};
// Vocabulary that hands out word ids in order of first appearance.
// NewWordAction is a functor called with each word that was not already present.
template <class NewWordAction = NoOpUniqueWords> class GrowableVocab {
  public:
    // Memory used by the lookup table for content entries; never sized for fewer than 2.
    static std::size_t MemUsage(WordIndex content) {
      return Lookup::MemUsage(content > 2 ? content : 2);
    }

    // Does not take ownership of new_word_construct; it is only used to
    // initialize the stored NewWordAction.
    // The three special tokens are inserted first so that they get ids 0, 1 and 2.
    template <class NewWordConstruct> GrowableVocab(WordIndex initial_size, const NewWordConstruct &new_word_construct = NewWordAction())
      : lookup_(initial_size), new_word_(new_word_construct) {
      FindOrInsert("<unk>"); // Force 0
      FindOrInsert("<s>"); // Force 1
      FindOrInsert("</s>"); // Force 2
    }

    // Returns the id stored for str, or 0 when str is not in the table.
    WordIndex Index(const StringPiece &str) const {
      Lookup::ConstIterator i;
      return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0;
    }

    // Returns the id of word, adding it with id Size() if it is not present yet.
    // Throws VocabLoadException once the table has reached the largest WordIndex.
    WordIndex FindOrInsert(const StringPiece &word) {
      // The key must come from the same hash that Index() probes with,
      // otherwise a word inserted here may not be found by Index().
      ProbingVocabularyEntry entry = ProbingVocabularyEntry::Make(detail::HashForVocab(word), Size());
      Lookup::MutableIterator it;
      if (!lookup_.FindOrInsert(entry, it)) {
        // Not found, so the entry was just inserted: report the new word.
        new_word_(word);
        UTIL_THROW_IF(Size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh");
      }
      return it->value;
    }

    // Number of entries in the lookup table.
    WordIndex Size() const { return lookup_.Size(); }

  private:
    typedef util::AutoProbing<ProbingVocabularyEntry, util::IdentityHash> Lookup;

    Lookup lookup_;

    NewWordAction new_word_;
};
} // namespace ngram } // namespace ngram
} // namespace lm } // namespace lm
#endif // LM_VOCAB__ #endif // LM_VOCAB_H

View File

@ -1,5 +1,5 @@
#ifndef LM_WEIGHTS__ #ifndef LM_WEIGHTS_H
#define LM_WEIGHTS__ #define LM_WEIGHTS_H
// Weights for n-grams. Probability and possibly a backoff. // Weights for n-grams. Probability and possibly a backoff.
@ -19,4 +19,4 @@ struct RestWeights {
}; };
} // namespace lm } // namespace lm
#endif // LM_WEIGHTS__ #endif // LM_WEIGHTS_H

View File

@ -1,6 +1,6 @@
// Separate header because this is used often. // Separate header because this is used often.
#ifndef LM_WORD_INDEX__ #ifndef LM_WORD_INDEX_H
#define LM_WORD_INDEX__ #define LM_WORD_INDEX_H
#include <limits.h> #include <limits.h>

View File

@ -11,28 +11,26 @@ if [ test_library "lzma" ] && [ test_header "lzma.h" ] {
compressed_deps += lzma ; compressed_deps += lzma ;
} }
local have-clock = [ SHELL "bash -c \"g++ -dM -x c++ -E /dev/null -include time.h 2>/dev/null |grep CLOCK_MONOTONIC\"" : exit-status ] ; #rt is needed for clock_gettime on linux. But it's already included with threading=multi
if $(have-clock[2]) = 0 { lib rt ;
#required for clock_gettime. Threads already have rt.
lib rt : : <runtime-link>static:<link>static <runtime-link>shared:<link>shared ;
} else {
alias rt ;
}
obj read_compressed.o : read_compressed.cc : $(compressed_flags) ; obj read_compressed.o : read_compressed.cc : $(compressed_flags) ;
alias read_compressed : read_compressed.o $(compressed_deps) ; alias read_compressed : read_compressed.o $(compressed_deps) ;
obj read_compressed_test.o : read_compressed_test.cc /top//boost_unit_test_framework : $(compressed_flags) ; obj read_compressed_test.o : read_compressed_test.cc /top//boost_unit_test_framework : $(compressed_flags) ;
obj file_piece_test.o : file_piece_test.cc /top//boost_unit_test_framework : $(compressed_flags) ; obj file_piece_test.o : file_piece_test.cc /top//boost_unit_test_framework : $(compressed_flags) ;
fakelib kenutil : bit_packing.cc ersatz_progress.cc exception.cc file.cc file_piece.cc mmap.cc murmur_hash.cc pool.cc read_compressed scoped.cc string_piece.cc usage.cc double-conversion//double-conversion : <include>.. <os>LINUX,<threading>single:<source>rt : : <include>.. ; fakelib parallel_read : parallel_read.cc : <threading>multi:<source>/top//boost_thread <threading>multi:<define>WITH_THREADS : : <include>.. ;
fakelib kenutil : bit_packing.cc ersatz_progress.cc exception.cc file.cc file_piece.cc mmap.cc murmur_hash.cc parallel_read pool.cc read_compressed scoped.cc string_piece.cc usage.cc double-conversion//double-conversion : <include>.. <os>LINUX,<threading>single:<source>rt : : <include>.. ;
exe cat_compressed : cat_compressed_main.cc kenutil ;
alias programs : cat_compressed ;
import testing ; import testing ;
run file_piece_test.o kenutil /top//boost_unit_test_framework : : file_piece.cc ; run file_piece_test.o kenutil /top//boost_unit_test_framework : : file_piece.cc ;
unit-test pcqueue_test : pcqueue_test.cc kenutil /top//boost_unit_test_framework /top//boost_system : <threading>single:<build>no ; for local t in [ glob *_test.cc : file_piece_test.cc read_compressed_test.cc ] {
for local t in [ glob *_test.cc : file_piece_test.cc read_compressed_test.cc pcqueue_test.cc ] {
local name = [ MATCH "(.*)\.cc" : $(t) ] ; local name = [ MATCH "(.*)\.cc" : $(t) ] ;
unit-test $(name) : $(t) kenutil /top//boost_unit_test_framework /top//boost_system ; unit-test $(name) : $(t) kenutil /top//boost_unit_test_framework /top//boost_system ;
} }

View File

@ -1,5 +1,5 @@
#ifndef UTIL_BIT_PACKING__ #ifndef UTIL_BIT_PACKING_H
#define UTIL_BIT_PACKING__ #define UTIL_BIT_PACKING_H
/* Bit-level packing routines /* Bit-level packing routines
* *
@ -183,4 +183,4 @@ struct BitAddress {
} // namespace util } // namespace util
#endif // UTIL_BIT_PACKING__ #endif // UTIL_BIT_PACKING_H

View File

@ -0,0 +1,47 @@
// Like cat but interprets compressed files.
#include "util/file.hh"
#include "util/read_compressed.hh"
#include <string.h>
#include <iostream>
namespace {

// Size in bytes of the scratch buffer used by Copy.
const std::size_t kBufSize = 16384;

// Drain `from` into file descriptor `to`, kBufSize bytes at a time.
// Stops when Read reports that zero bytes were produced.
void Copy(util::ReadCompressed &from, int to) {
  util::scoped_malloc scratch(util::MallocOrThrow(kBufSize));
  for (;;) {
    const std::size_t got = from.Read(scratch.get(), kBufSize);
    if (got == 0) break;
    util::WriteOrThrow(to, scratch.get(), got);
  }
}

} // namespace
// Entry point: decompress each file named on the command line to stdout
// (file descriptor 1), or decompress stdin when no file is named.
// Returns 1 after printing usage, 2 if an exception was caught, 0 otherwise.
int main(int argc, char *argv[]) {
  // Lane Schwartz likes -h and --help
  // Look for a help request. A literal "--" ends option scanning, so
  // "-h"/"--help" after it are treated as file names.
  int separator = 0; // index of the first "--" in argv; 0 means there is none
  for (int i = 1; i < argc; ++i) {
    const char *arg = argv[i];
    if (!strcmp(arg, "--")) {
      separator = i;
      break;
    }
    if (!strcmp(arg, "-h") || !strcmp(arg, "--help")) {
      std::cerr <<
        "A cat implementation that interprets compressed files.\n"
        "Usage: " << argv[0] << " [file1] [file2] ...\n"
        "If no file is provided, then stdin is read.\n";
      return 1;
    }
  }

  try {
    bool copied_any = false;
    for (int i = 1; i < argc; ++i) {
      // "--" only marks the end of options; it is not a file to open.
      if (i == separator) continue;
      util::ReadCompressed in(util::OpenReadOrThrow(argv[i]));
      Copy(in, 1);
      copied_any = true;
    }
    if (!copied_any) {
      // No file operands (possibly just "--"): read from stdin.
      util::ReadCompressed in(0);
      Copy(in, 1);
    }
  } catch (const std::exception &e) {
    std::cerr << e.what() << std::endl;
    return 2;
  }
  return 0;
}

View File

@ -1,5 +1,5 @@
#ifndef UTIL_ERSATZ_PROGRESS__ #ifndef UTIL_ERSATZ_PROGRESS_H
#define UTIL_ERSATZ_PROGRESS__ #define UTIL_ERSATZ_PROGRESS_H
#include <iostream> #include <iostream>
#include <string> #include <string>
@ -55,4 +55,4 @@ class ErsatzProgress {
} // namespace util } // namespace util
#endif // UTIL_ERSATZ_PROGRESS__ #endif // UTIL_ERSATZ_PROGRESS_H

View File

@ -2,6 +2,9 @@
* Does not support many data types. Currently, it's targeted at writing ARPA * Does not support many data types. Currently, it's targeted at writing ARPA
* files quickly. * files quickly.
*/ */
#ifndef UTIL_FAKE_OFSTREAM_H
#define UTIL_FAKE_OFSTREAM_H
#include "util/double-conversion/double-conversion.h" #include "util/double-conversion/double-conversion.h"
#include "util/double-conversion/utils.h" #include "util/double-conversion/utils.h"
#include "util/file.hh" #include "util/file.hh"
@ -17,7 +20,8 @@ class FakeOFStream {
static const std::size_t kOutBuf = 1048576; static const std::size_t kOutBuf = 1048576;
// Does not take ownership of out. // Does not take ownership of out.
explicit FakeOFStream(int out) // Allows default constructor, but must call SetFD.
explicit FakeOFStream(int out = -1)
: buf_(util::MallocOrThrow(kOutBuf)), : buf_(util::MallocOrThrow(kOutBuf)),
builder_(static_cast<char*>(buf_.get()), kOutBuf), builder_(static_cast<char*>(buf_.get()), kOutBuf),
// Mostly the default but with inf instead. And no flags. // Mostly the default but with inf instead. And no flags.
@ -28,6 +32,11 @@ class FakeOFStream {
if (buf_.get()) Flush(); if (buf_.get()) Flush();
} }
// Redirect subsequent output to file descriptor `to`.  Anything already
// sitting in the buffer is flushed first, before fd_ changes.
void SetFD(int to) {
  const bool pending = builder_.position() != 0;
  if (pending) Flush();
  fd_ = to;
}
FakeOFStream &operator<<(float value) { FakeOFStream &operator<<(float value) {
// Odd, but this is the largest number found in the comments. // Odd, but this is the largest number found in the comments.
EnsureRemaining(double_conversion::DoubleToStringConverter::kMaxPrecisionDigits + 8); EnsureRemaining(double_conversion::DoubleToStringConverter::kMaxPrecisionDigits + 8);
@ -92,3 +101,5 @@ class FakeOFStream {
}; };
} // namespace } // namespace
#endif

View File

@ -5,28 +5,29 @@
#include "util/exception.hh" #include "util/exception.hh"
#include <algorithm>
#include <cstdlib> #include <cstdlib>
#include <cstdio> #include <cstdio>
#include <sstream>
#include <iostream> #include <iostream>
#include <limits>
#include <sstream>
#include <assert.h> #include <assert.h>
#include <errno.h> #include <errno.h>
#include <limits.h>
#include <sys/types.h> #include <sys/types.h>
#include <sys/stat.h> #include <sys/stat.h>
#include <fcntl.h> #include <fcntl.h>
#include <stdint.h> #include <stdint.h>
#if defined __MINGW32__ #if defined(__MINGW32__)
#include <windows.h> #include <windows.h>
#include <unistd.h> #include <unistd.h>
#warning "The file functions on MinGW have not been tested for file sizes above 2^31 - 1. Please read https://stackoverflow.com/questions/12539488/determine-64-bit-file-size-in-c-on-mingw-32-bit and fix" #warning "The file functions on MinGW have not been tested for file sizes above 2^31 - 1. Please read https://stackoverflow.com/questions/12539488/determine-64-bit-file-size-in-c-on-mingw-32-bit and fix"
#elif defined(_WIN32) || defined(_WIN64) #elif defined(_WIN32) || defined(_WIN64)
#include <windows.h> #include <windows.h>
#include <io.h> #include <io.h>
#include <algorithm>
#include <limits.h>
#include <limits>
#else #else
#include <unistd.h> #include <unistd.h>
#endif #endif
@ -111,7 +112,7 @@ uint64_t SizeOrThrow(int fd) {
void ResizeOrThrow(int fd, uint64_t to) { void ResizeOrThrow(int fd, uint64_t to) {
#if defined __MINGW32__ #if defined __MINGW32__
// Does this handle 64-bit? // Does this handle 64-bit?
int ret = ftruncate int ret = ftruncate
#elif defined(_WIN32) || defined(_WIN64) #elif defined(_WIN32) || defined(_WIN64)
errno_t ret = _chsize_s errno_t ret = _chsize_s
@ -128,8 +129,10 @@ namespace {
std::size_t GuardLarge(std::size_t size) { std::size_t GuardLarge(std::size_t size) {
// The following operating systems have broken read/write/pread/pwrite that // The following operating systems have broken read/write/pread/pwrite that
// only supports up to 2^31. // only supports up to 2^31.
// OS X man pages claim to support 64-bit, but Kareem M. Darwish had problems
// building with larger files, so APPLE is also here.
#if defined(_WIN32) || defined(_WIN64) || defined(__APPLE__) || defined(OS_ANDROID) || defined(__MINGW32__) #if defined(_WIN32) || defined(_WIN64) || defined(__APPLE__) || defined(OS_ANDROID) || defined(__MINGW32__)
return std::min(static_cast<std::size_t>(static_cast<unsigned>(-1)), size); return size < INT_MAX ? size : INT_MAX;
#else #else
return size; return size;
#endif #endif
@ -172,46 +175,6 @@ std::size_t ReadOrEOF(int fd, void *to_void, std::size_t amount) {
return amount; return amount;
} }
// Read exactly `size` bytes from `fd`, starting at absolute offset `off`, into
// the buffer `to_void`.  Loops until the whole request has been satisfied.
// On non-Windows builds a read that returns 0 bytes raises EndOfFileException
// and any other failure raises FDException.
void PReadOrThrow(int fd, void *to_void, std::size_t size, uint64_t off) {
uint8_t *to = static_cast<uint8_t*>(to_void);
#if defined(_WIN32) || defined(_WIN64)
// Windows builds throw unconditionally here, so the ReadFile code below is
// never reached on that platform.
UTIL_THROW(Exception, "This pread implementation for windows is broken. Please send me a patch that does not change the file pointer. Atomically. Or send me an implementation of pwrite that is allowed to change the file pointer but can be called concurrently with pread.");
const std::size_t kMaxDWORD = static_cast<std::size_t>(4294967295UL);
#endif
// Each iteration reads one chunk and advances size/off/to by the amount read.
for (;size ;) {
#if defined(_WIN32) || defined(_WIN64)
/* BROKEN: changes file pointer. Even if you save it and change it back, it won't be safe to use concurrently with write() or read() which lmplz does. */
// size_t might be 64-bit. DWORD is always 32.
DWORD reading = static_cast<DWORD>(std::min<std::size_t>(kMaxDWORD, size));
DWORD ret;
// The 64-bit offset is passed to ReadFile split into two 32-bit halves.
OVERLAPPED overlapped;
memset(&overlapped, 0, sizeof(OVERLAPPED));
overlapped.Offset = static_cast<DWORD>(off);
overlapped.OffsetHigh = static_cast<DWORD>(off >> 32);
UTIL_THROW_IF(!ReadFile((HANDLE)_get_osfhandle(fd), to, reading, &ret, &overlapped), Exception, "ReadFile failed for offset " << off);
#else
ssize_t ret;
errno = 0;
// Retry the call for as long as it fails with EINTR.
do {
ret =
#ifdef OS_ANDROID
pread64
#else
pread
#endif
(fd, to, GuardLarge(size), off);
} while (ret == -1 && errno == EINTR);
// ret == 0 is reported as end of file; a negative ret as an FDException.
if (ret <= 0) {
UTIL_THROW_IF(ret == 0, EndOfFileException, " for reading " << size << " bytes at " << off << " from " << NameFromFD(fd));
UTIL_THROW_ARG(FDException, (fd), "while reading " << size << " bytes at offset " << off);
}
#endif
size -= ret;
off += ret;
to += ret;
}
}
void WriteOrThrow(int fd, const void *data_void, std::size_t size) { void WriteOrThrow(int fd, const void *data_void, std::size_t size) {
const uint8_t *data = static_cast<const uint8_t*>(data_void); const uint8_t *data = static_cast<const uint8_t*>(data_void);
while (size) { while (size) {
@ -241,6 +204,83 @@ void WriteOrThrow(FILE *to, const void *data, std::size_t size) {
UTIL_THROW_IF(1 != std::fwrite(data, size, 1, to), ErrnoException, "Short write; requested size " << size); UTIL_THROW_IF(1 != std::fwrite(data, size, 1, to), ErrnoException, "Short write; requested size " << size);
} }
#if defined(_WIN32) || defined(_WIN64)
namespace {
// 2^32 - 1.  ErsatzPRead and ErsatzPWrite cap each ReadFile/WriteFile request
// at this many bytes because the byte count they pass is a DWORD.
const std::size_t kMaxDWORD = static_cast<std::size_t>(4294967295UL);
} // namespace
#endif
// Read exactly `size` bytes from `fd` at absolute offset `off` into `to_void`,
// looping over short reads.
//
// Throws EndOfFileException if the file ends before `size` bytes were read,
// FDException on other POSIX read errors and Exception if ReadFile fails.
void ErsatzPRead(int fd, void *to_void, std::size_t size, uint64_t off) {
  uint8_t *to = static_cast<uint8_t*>(to_void);
  while (size) {
#if defined(_WIN32) || defined(_WIN64)
    /* BROKEN: changes file pointer. Even if you save it and change it back, it won't be safe to use concurrently with write() or read() which lmplz does. */
    // size_t might be 64-bit. DWORD is always 32.
    DWORD reading = static_cast<DWORD>(std::min<std::size_t>(kMaxDWORD, size));
    DWORD ret;
    // The 64-bit offset is handed to ReadFile as two 32-bit halves.
    OVERLAPPED overlapped;
    memset(&overlapped, 0, sizeof(OVERLAPPED));
    overlapped.Offset = static_cast<DWORD>(off);
    overlapped.OffsetHigh = static_cast<DWORD>(off >> 32);
    UTIL_THROW_IF(!ReadFile((HANDLE)_get_osfhandle(fd), to, reading, &ret, &overlapped), Exception, "ReadFile failed for offset " << off);
    // A successful call that read nothing means end of file.  Without this
    // check the loop would spin forever because size never shrinks.
    UTIL_THROW_IF(!ret, EndOfFileException, " for reading " << size << " bytes at " << off << " from " << NameFromFD(fd));
#else
    ssize_t ret;
    errno = 0;
    ret =
#ifdef OS_ANDROID
      pread64
#else
      pread
#endif
      (fd, to, GuardLarge(size), off);
    if (ret <= 0) {
      // Interrupted by a signal: simply try the same request again.
      if (ret == -1 && errno == EINTR) continue;
      UTIL_THROW_IF(ret == 0, EndOfFileException, " for reading " << size << " bytes at " << off << " from " << NameFromFD(fd));
      UTIL_THROW_ARG(FDException, (fd), "while reading " << size << " bytes at offset " << off);
    }
#endif
    // Advance past the bytes that were actually transferred.
    size -= ret;
    off += ret;
    to += ret;
  }
}
// Write exactly `size` bytes from `from_void` to `fd` at absolute offset `off`,
// looping over short writes.
//
// Throws EndOfFileException if a write makes no progress, FDException on other
// POSIX write errors and Exception if WriteFile fails.
void ErsatzPWrite(int fd, const void *from_void, std::size_t size, uint64_t off) {
  const uint8_t *from = static_cast<const uint8_t*>(from_void);
  while (size) {
#if defined(_WIN32) || defined(_WIN64)
    /* Changes file pointer. Even if you save it and change it back, it won't be safe to use concurrently with write() or read() */
    // size_t might be 64-bit. DWORD is always 32.
    DWORD writing = static_cast<DWORD>(std::min<std::size_t>(kMaxDWORD, size));
    DWORD ret;
    // The 64-bit offset is handed to WriteFile as two 32-bit halves.
    OVERLAPPED overlapped;
    memset(&overlapped, 0, sizeof(OVERLAPPED));
    overlapped.Offset = static_cast<DWORD>(off);
    overlapped.OffsetHigh = static_cast<DWORD>(off >> 32);
    UTIL_THROW_IF(!WriteFile((HANDLE)_get_osfhandle(fd), from, writing, &ret, &overlapped), Exception, "WriteFile failed for offset " << off);
    // Mirror the POSIX branch: a call that wrote nothing would otherwise make
    // this loop spin forever because size never shrinks.
    UTIL_THROW_IF(!ret, EndOfFileException, " for writing " << size << " bytes at " << off << " to " << NameFromFD(fd));
#else
    ssize_t ret;
    errno = 0;
    ret =
#ifdef OS_ANDROID
      pwrite64
#else
      pwrite
#endif
      (fd, from, GuardLarge(size), off);
    if (ret <= 0) {
      // Interrupted by a signal: simply try the same request again.
      if (ret == -1 && errno == EINTR) continue;
      UTIL_THROW_IF(ret == 0, EndOfFileException, " for writing " << size << " bytes at " << off << " to " << NameFromFD(fd));
      UTIL_THROW_ARG(FDException, (fd), "while writing " << size << " bytes at offset " << off);
    }
#endif
    // Advance past the bytes that were actually transferred.
    size -= ret;
    off += ret;
    from += ret;
  }
}
void FSyncOrThrow(int fd) { void FSyncOrThrow(int fd) {
// Apparently windows doesn't have fsync? // Apparently windows doesn't have fsync?
#if !defined(_WIN32) && !defined(_WIN64) #if !defined(_WIN32) && !defined(_WIN64)
@ -443,8 +483,8 @@ void NormalizeTempPrefix(std::string &base) {
) base += '/'; ) base += '/';
} }
int MakeTemp(const std::string &base) { int MakeTemp(const StringPiece &base) {
std::string name(base); std::string name(base.data(), base.size());
name += "XXXXXX"; name += "XXXXXX";
name.push_back(0); name.push_back(0);
int ret; int ret;
@ -452,7 +492,7 @@ int MakeTemp(const std::string &base) {
return ret; return ret;
} }
std::FILE *FMakeTemp(const std::string &base) { std::FILE *FMakeTemp(const StringPiece &base) {
util::scoped_fd file(MakeTemp(base)); util::scoped_fd file(MakeTemp(base));
return FDOpenOrThrow(file); return FDOpenOrThrow(file);
} }
@ -478,14 +518,18 @@ bool TryName(int fd, std::string &out) {
if (-1 == lstat(name.c_str(), &sb)) if (-1 == lstat(name.c_str(), &sb))
return false; return false;
out.resize(sb.st_size + 1); out.resize(sb.st_size + 1);
ssize_t ret = readlink(name.c_str(), &out[0], sb.st_size + 1); // lstat gave us a size, but I've seen it grow, possibly due to symlinks on top of symlinks.
if (-1 == ret) while (true) {
return false; ssize_t ret = readlink(name.c_str(), &out[0], out.size());
if (ret > sb.st_size) { if (-1 == ret)
// Increased in size?! return false;
return false; if ((size_t)ret < out.size()) {
out.resize(ret);
break;
}
// Exponential growth.
out.resize(out.size() * 2);
} }
out.resize(ret);
// Don't use the non-file names. // Don't use the non-file names.
if (!out.empty() && out[0] != '/') if (!out.empty() && out[0] != '/')
return false; return false;

View File

@ -1,7 +1,8 @@
#ifndef UTIL_FILE__ #ifndef UTIL_FILE_H
#define UTIL_FILE__ #define UTIL_FILE_H
#include "util/exception.hh" #include "util/exception.hh"
#include "util/string_piece.hh"
#include <cstddef> #include <cstddef>
#include <cstdio> #include <cstdio>
@ -106,12 +107,20 @@ void ResizeOrThrow(int fd, uint64_t to);
std::size_t PartialRead(int fd, void *to, std::size_t size); std::size_t PartialRead(int fd, void *to, std::size_t size);
void ReadOrThrow(int fd, void *to, std::size_t size); void ReadOrThrow(int fd, void *to, std::size_t size);
std::size_t ReadOrEOF(int fd, void *to_void, std::size_t size); std::size_t ReadOrEOF(int fd, void *to_void, std::size_t size);
// Positioned: unix only for now.
void PReadOrThrow(int fd, void *to, std::size_t size, uint64_t off);
void WriteOrThrow(int fd, const void *data_void, std::size_t size); void WriteOrThrow(int fd, const void *data_void, std::size_t size);
void WriteOrThrow(FILE *to, const void *data, std::size_t size); void WriteOrThrow(FILE *to, const void *data, std::size_t size);
/* These call pread/pwrite in a loop, so on POSIX they do not touch the file
 * pointer. On Windows they call ReadFile/WriteFile with an explicit offset,
 * which does change the file pointer. It is therefore safe to call
 * ErsatzPRead and ErsatzPWrite concurrently with each other (in any
 * combination), but on Windows it is not safe to call them concurrently with
 * anything that uses the implicit file pointer, e.g. the Read/Write functions
 * above.
 */
void ErsatzPRead(int fd, void *to, std::size_t size, uint64_t off);
void ErsatzPWrite(int fd, const void *data_void, std::size_t size, uint64_t off);
void FSyncOrThrow(int fd); void FSyncOrThrow(int fd);
// Seeking // Seeking
@ -125,8 +134,8 @@ std::FILE *FDOpenReadOrThrow(scoped_fd &file);
// Temporary files // Temporary files
// Append a / if base is a directory. // Append a / if base is a directory.
void NormalizeTempPrefix(std::string &base); void NormalizeTempPrefix(std::string &base);
int MakeTemp(const std::string &prefix); int MakeTemp(const StringPiece &prefix);
std::FILE *FMakeTemp(const std::string &prefix); std::FILE *FMakeTemp(const StringPiece &prefix);
// dup an fd. // dup an fd.
int DupOrThrow(int fd); int DupOrThrow(int fd);
@ -139,4 +148,4 @@ std::string NameFromFD(int fd);
} // namespace util } // namespace util
#endif // UTIL_FILE__ #endif // UTIL_FILE_H

View File

@ -84,6 +84,13 @@ StringPiece FilePiece::ReadLine(char delim) {
} }
} }
// Non-throwing variant of ReadLine: stores the next line in `to` and returns
// true, or returns false (leaving `to` untouched) when ReadLine raises
// EndOfFileException.
bool FilePiece::ReadLineOrEOF(StringPiece &to, char delim) {
  try {
    to = ReadLine(delim);
    return true;
  } catch (const util::EndOfFileException &) {
    return false;
  }
}
float FilePiece::ReadFloat() { float FilePiece::ReadFloat() {
return ReadNumber<float>(); return ReadNumber<float>();
} }

View File

@ -1,5 +1,5 @@
#ifndef UTIL_FILE_PIECE__ #ifndef UTIL_FILE_PIECE_H
#define UTIL_FILE_PIECE__ #define UTIL_FILE_PIECE_H
#include "util/ersatz_progress.hh" #include "util/ersatz_progress.hh"
#include "util/exception.hh" #include "util/exception.hh"
@ -56,10 +56,33 @@ class FilePiece {
return Consume(FindDelimiterOrEOF(delim)); return Consume(FindDelimiterOrEOF(delim));
} }
// Fetch the next word on the current line into `to`.
// Returns false, without consuming the newline, when the line or the file
// ends before another word starts.  `delim` must mark '\n' as a delimiter.
bool ReadWordSameLine(StringPiece &to, const bool *delim = kSpaces) {
  assert(delim[static_cast<unsigned char>('\n')]);
  // Step over delimiters that are not the newline.
  while (true) {
    if (position_ == position_end_) {
      // Buffer exhausted: pull in more data, treating end of file as "no word".
      try {
        Shift();
      } catch (const util::EndOfFileException &) {
        return false;
      }
      if (position_ == position_end_) return false;
    }
    const unsigned char current = static_cast<unsigned char>(*position_);
    if (!delim[current]) break;
    if (*position_ == '\n') return false;
    ++position_;
  }
  // position_ now rests on a non-delimiter, so there is a word to consume.
  to = Consume(FindDelimiterOrEOF(delim));
  return true;
}
// Unlike ReadDelimited, this includes leading spaces and consumes the delimiter. // Unlike ReadDelimited, this includes leading spaces and consumes the delimiter.
// It is similar to getline in that way. // It is similar to getline in that way.
StringPiece ReadLine(char delim = '\n'); StringPiece ReadLine(char delim = '\n');
// Doesn't throw EndOfFileException, just returns false.
bool ReadLineOrEOF(StringPiece &to, char delim = '\n');
float ReadFloat(); float ReadFloat();
double ReadDouble(); double ReadDouble();
long int ReadLong(); long int ReadLong();
@ -132,4 +155,4 @@ class FilePiece {
} // namespace util } // namespace util
#endif // UTIL_FILE_PIECE__ #endif // UTIL_FILE_PIECE_H

153
util/fixed_array.hh Normal file
View File

@ -0,0 +1,153 @@
#ifndef UTIL_FIXED_ARRAY_H
#define UTIL_FIXED_ARRAY_H
#include "util/scoped.hh"
#include <cstddef>
#include <assert.h>
#include <stdlib.h>
namespace util {
/**
 * Defines a fixed-size collection.
 *
 * Ever want an array of things but they don't have a default constructor or are
 * non-copyable?  FixedArray allows constructing one at a time.
 *
 * Storage for the whole capacity is allocated up front by Init(); elements are
 * then constructed in place, in order, by push_back().
 */
template <class T> class FixedArray {
  public:
    /** Initialize with a given size bound but do not construct the objects. */
    explicit FixedArray(std::size_t limit) {
      Init(limit);
    }

    /**
     * Constructs an instance, but does not initialize it.
     *
     * Any objects constructed in this manner must be subsequently @ref FixedArray::Init() "initialized" prior to use.
     *
     * @see FixedArray::Init()
     */
    FixedArray()
      : newed_end_(NULL)
#ifndef NDEBUG
      , allocated_end_(NULL)
#endif
    {}

    /**
     * Initialize with a given size bound but do not construct the objects.
     *
     * This method is responsible for allocating memory.
     * Objects stored in this array will be constructed in a location within this allocated memory.
     *
     * @param count Maximum number of objects that may be stored.  Zero is allowed.
     * @throws std::bad_alloc if the memory cannot be allocated.
     */
    void Init(std::size_t count) {
      assert(!block_.get());
      // Reject requests whose byte size does not fit in std::size_t.
      if (count > static_cast<std::size_t>(-1) / sizeof(T)) throw std::bad_alloc();
      block_.reset(malloc(sizeof(T) * count));
      // malloc(0) is allowed to return NULL, so NULL only signals failure for
      // a non-empty request.
      if (count && !block_.get()) throw std::bad_alloc();
      newed_end_ = begin();
#ifndef NDEBUG
      allocated_end_ = begin() + count;
#endif
    }

    /**
     * Constructs a copy of the provided array.
     *
     * The copy's capacity equals the number of objects currently stored in from.
     *
     * @param from Array whose elements should be copied into this newly-constructed data structure.
     */
    FixedArray(const FixedArray &from) {
      const std::size_t size = from.size();
      Init(size);
      try {
        for (std::size_t i = 0; i < size; ++i) {
          push_back(from[i]);
        }
      } catch (...) {
        // The destructor does not run for a half-built object, so destroy the
        // elements copied so far before propagating.
        clear();
        throw;
      }
    }

    /**
     * Destroys the stored objects and frees the memory held by this object.
     */
    ~FixedArray() { clear(); }

    /** Gets a pointer to the first object currently stored in this data structure. */
    T *begin() { return static_cast<T*>(block_.get()); }

    /** Gets a const pointer to the first object currently stored in this data structure. */
    const T *begin() const { return static_cast<const T*>(block_.get()); }

    /** Gets a pointer to one past the last object currently stored in this data structure. */
    T *end() { return newed_end_; }

    /** Gets a const pointer to one past the last object currently stored in this data structure. */
    const T *end() const { return newed_end_; }

    /** Gets a reference to the last object currently stored in this data structure.  The array must not be empty. */
    T &back() { return *(end() - 1); }

    /** Gets a const reference to the last object currently stored in this data structure.  The array must not be empty. */
    const T &back() const { return *(end() - 1); }

    /** Gets the number of objects currently stored in this data structure. */
    std::size_t size() const { return end() - begin(); }

    /** Returns true if there are no objects currently stored in this data structure. */
    bool empty() const { return begin() == end(); }

    /**
     * Gets a reference to the object with index i currently stored in this data structure.
     *
     * @param i Index of the object to reference
     */
    T &operator[](std::size_t i) { return begin()[i]; }

    /**
     * Gets a const reference to the object with index i currently stored in this data structure.
     *
     * @param i Index of the object to reference
     */
    const T &operator[](std::size_t i) const { return begin()[i]; }

    /**
     * Constructs a new object using the provided parameter,
     * and stores it in this data structure.
     *
     * The memory backing the constructed object is managed by this data structure.
     * The caller must not store more objects than the bound given to Init().
     */
    template <class C> void push_back(const C &c) {
      new (end()) T(c); // use "placement new" syntax to initialize T in an already-allocated memory location
      Constructed();
    }

    /**
     * Removes all elements from this array, running their destructors.
     * The allocated memory is kept.
     */
    void clear() {
      for (T *i = begin(); i != end(); ++i)
        i->~T();
      newed_end_ = begin();
    }

  protected:
    // Always call Constructed after successful completion of new.
    void Constructed() {
      ++newed_end_;
#ifndef NDEBUG
      assert(newed_end_ <= allocated_end_);
#endif
    }

  private:
    // Raw storage for the elements; released with free().
    util::scoped_malloc block_;

    // One past the last constructed element.
    T *newed_end_;

#ifndef NDEBUG
    // One past the end of the allocated storage; only tracked for the bounds assertion.
    T *allocated_end_;
#endif
};
} // namespace util
#endif // UTIL_FIXED_ARRAY_H

View File

@ -11,8 +11,8 @@ Code given out at the 1985 UNIFORUM conference in Dallas.
#endif #endif
#ifndef __GNUC__ #ifndef __GNUC__
#ifndef _WINGETOPT_H_ #ifndef UTIL_GETOPT_H
#define _WINGETOPT_H_ #define UTIL_GETOPT_H
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
@ -28,6 +28,6 @@ extern int getopt(int argc, char **argv, char *opts);
} }
#endif #endif
#endif /* _GETOPT_H_ */ #endif /* UTIL_GETOPT_H */
#endif /* __GNUC__ */ #endif /* __GNUC__ */

View File

@ -1,6 +1,6 @@
/* Optional packages. You might want to integrate this with your build system e.g. config.h from ./configure. */ /* Optional packages. You might want to integrate this with your build system e.g. config.h from ./configure. */
#ifndef UTIL_HAVE__ #ifndef UTIL_HAVE_H
#define UTIL_HAVE__ #define UTIL_HAVE_H
#ifdef HAVE_CONFIG_H #ifdef HAVE_CONFIG_H
#include "config.h" #include "config.h"
@ -10,4 +10,4 @@
//#define HAVE_ICU //#define HAVE_ICU
#endif #endif
#endif // UTIL_HAVE__ #endif // UTIL_HAVE_H

View File

@ -1,5 +1,5 @@
#ifndef UTIL_JOINT_SORT__ #ifndef UTIL_JOINT_SORT_H
#define UTIL_JOINT_SORT__ #define UTIL_JOINT_SORT_H
/* A terrifying amount of C++ to coax std::sort into soring one range while /* A terrifying amount of C++ to coax std::sort into soring one range while
* also permuting another range the same way. * also permuting another range the same way.
@ -143,4 +143,4 @@ template <class KeyIter, class ValueIter> void JointSort(const KeyIter &key_begi
} // namespace util } // namespace util
#endif // UTIL_JOINT_SORT__ #endif // UTIL_JOINT_SORT_H

View File

@ -6,6 +6,7 @@
#include "util/exception.hh" #include "util/exception.hh"
#include "util/file.hh" #include "util/file.hh"
#include "util/parallel_read.hh"
#include "util/scoped.hh" #include "util/scoped.hh"
#include <iostream> #include <iostream>
@ -40,7 +41,7 @@ void SyncOrThrow(void *start, size_t length) {
#if defined(_WIN32) || defined(_WIN64) #if defined(_WIN32) || defined(_WIN64)
UTIL_THROW_IF(!::FlushViewOfFile(start, length), ErrnoException, "Failed to sync mmap"); UTIL_THROW_IF(!::FlushViewOfFile(start, length), ErrnoException, "Failed to sync mmap");
#else #else
UTIL_THROW_IF(msync(start, length, MS_SYNC), ErrnoException, "Failed to sync mmap"); UTIL_THROW_IF(length && msync(start, length, MS_SYNC), ErrnoException, "Failed to sync mmap");
#endif #endif
} }
@ -154,6 +155,10 @@ void MapRead(LoadMethod method, int fd, uint64_t offset, std::size_t size, scope
SeekOrThrow(fd, offset); SeekOrThrow(fd, offset);
ReadOrThrow(fd, out.get(), size); ReadOrThrow(fd, out.get(), size);
break; break;
case PARALLEL_READ:
out.reset(MallocOrThrow(size), size, scoped_memory::MALLOC_ALLOCATED);
ParallelRead(fd, out.get(), size, offset);
break;
} }
} }
@ -189,4 +194,66 @@ void *MapZeroedWrite(const char *name, std::size_t size, scoped_fd &file) {
} }
} }
// Copy `copy_from` through the assignment operator, then shift the base of the
// copy forward by `increase` bytes.
Rolling::Rolling(const Rolling &copy_from, uint64_t increase) {
  operator=(copy_from);
  IncreaseBase(increase);
}
// Copies the file description and window parameters.  The mapping held in
// mem_ is not copied: a pass-through source shares its pointer, anything else
// is reset so that the next checked access maps afresh.
Rolling &Rolling::operator=(const Rolling &copy_from) {
  // Parameters describing the underlying file region.
  fd_ = copy_from.fd_;
  file_begin_ = copy_from.file_begin_;
  file_end_ = copy_from.file_end_;
  for_write_ = copy_from.for_write_;
  block_ = copy_from.block_;
  read_bound_ = copy_from.read_bound_;

  current_begin_ = 0;
  if (!copy_from.IsPassthrough()) {
    // Force call on next mmap.
    current_end_ = 0;
    ptr_ = NULL;
    return *this;
  }
  current_end_ = copy_from.current_end_;
  ptr_ = copy_from.ptr_;
  return *this;
}
// Set up a rolling map over the bytes [offset, offset + amount) of fd.
// Nothing is mapped yet: the window [current_begin_, current_end_) starts
// empty, so the first CheckedBase call triggers Roll.
// ptr_ starts as NULL rather than indeterminate so that get() and
// IncreaseBase/DecreaseBase are well defined before that first Roll.
Rolling::Rolling(int fd, bool for_write, std::size_t block, std::size_t read_bound, uint64_t offset, uint64_t amount)
  : ptr_(NULL),
    current_begin_(0),
    current_end_(0),
    fd_(fd),
    file_begin_(offset),
    file_end_(offset + amount),
    for_write_(for_write),
    block_(block),
    read_bound_(read_bound) {}
// Produce a pointer to `size` bytes starting `index` bytes past the base.
// `out` is always reset first.  For a pass-through the pointer is just
// get() + index and `out` stays empty; otherwise a fresh mapping is created
// and handed to `out`.
void *Rolling::ExtractNonRolling(scoped_memory &out, uint64_t index, std::size_t size) {
  out.reset();
  if (IsPassthrough()) return static_cast<uint8_t*>(get()) + index;

  const uint64_t file_offset = file_begin_ + index;
  // Round down to multiple of page size.
  const uint64_t page = static_cast<uint64_t>(SizePage());
  const uint64_t slack = file_offset % page;
  const std::size_t mapped_bytes = static_cast<std::size_t>(size + slack);

  void *mapping = MapOrThrow(mapped_bytes, for_write_, kFileFlags, true, fd_, file_offset - slack);
  out.reset(mapping, mapped_bytes, scoped_memory::MMAP_ALLOCATED);
  // Skip the bytes that were only mapped to reach the page boundary.
  return static_cast<uint8_t*>(out.get()) + static_cast<std::size_t>(slack);
}
// Move the mapped window so that it begins at `index`.  Must not be called on
// a pass-through.
void Rolling::Roll(uint64_t index) {
  assert(!IsPassthrough());
  // Bytes left in the file region from index onwards.
  const uint64_t remaining = file_end_ - (index + file_begin_);
  std::size_t amount;
  if (remaining > static_cast<uint64_t>(block_)) {
    // Full block: the usable window stops read_bound_ bytes before its end.
    amount = block_;
    current_end_ = index + amount - read_bound_;
  } else {
    // Tail of the file: map everything that is left.
    amount = static_cast<std::size_t>(remaining);
    current_end_ = index + amount;
  }
  // Store the pointer pre-biased by -index so that ptr_ + index is the first mapped byte.
  uint8_t *window = static_cast<uint8_t*>(ExtractNonRolling(mem_, index, amount));
  ptr_ = window - index;
  current_begin_ = index;
}
} // namespace util } // namespace util

View File

@ -1,8 +1,9 @@
#ifndef UTIL_MMAP__ #ifndef UTIL_MMAP_H
#define UTIL_MMAP__ #define UTIL_MMAP_H
// Utilities for mmaped files. // Utilities for mmaped files.
#include <cstddef> #include <cstddef>
#include <limits>
#include <stdint.h> #include <stdint.h>
#include <sys/types.h> #include <sys/types.h>
@ -52,6 +53,9 @@ class scoped_memory {
public: public:
typedef enum {MMAP_ALLOCATED, ARRAY_ALLOCATED, MALLOC_ALLOCATED, NONE_ALLOCATED} Alloc; typedef enum {MMAP_ALLOCATED, ARRAY_ALLOCATED, MALLOC_ALLOCATED, NONE_ALLOCATED} Alloc;
// Wrap an existing region: `data` points at `size` bytes that were obtained
// by the method named in `source`.
scoped_memory(void *data, std::size_t size, Alloc source)
  : data_(data),
    size_(size),
    source_(source) {
}
scoped_memory() : data_(NULL), size_(0), source_(NONE_ALLOCATED) {} scoped_memory() : data_(NULL), size_(0), source_(NONE_ALLOCATED) {}
~scoped_memory() { reset(); } ~scoped_memory() { reset(); }
@ -72,7 +76,6 @@ class scoped_memory {
void call_realloc(std::size_t to); void call_realloc(std::size_t to);
private: private:
void *data_; void *data_;
std::size_t size_; std::size_t size_;
@ -90,7 +93,9 @@ typedef enum {
// Populate on Linux. malloc and read on non-Linux. // Populate on Linux. malloc and read on non-Linux.
POPULATE_OR_READ, POPULATE_OR_READ,
// malloc and read. // malloc and read.
READ READ,
// malloc and read in parallel (recommended for Lustre)
PARALLEL_READ,
} LoadMethod; } LoadMethod;
extern const int kFileFlags; extern const int kFileFlags;
@ -109,6 +114,79 @@ void *MapZeroedWrite(const char *name, std::size_t size, scoped_fd &file);
// msync wrapper // msync wrapper
void SyncOrThrow(void *start, size_t length); void SyncOrThrow(void *start, size_t length);
// Forward rolling memory map with no overlap.
//
// A Rolling is either a pass-through (a thin wrapper around a plain pointer,
// marked by fd_ == -1) or a window onto a region of a file that is remapped by
// Roll whenever a checked access falls outside [current_begin_, current_end_).
class Rolling {
  public:
    // Leaves the object as a pass-through over NULL with an empty window, so
    // every member has a defined value and the object can be copied safely.
    // Call Init or assign from another Rolling before accessing data.
    Rolling()
      : ptr_(NULL),
        current_begin_(0),
        current_end_(0),
        fd_(-1),
        file_begin_(0),
        file_end_(0),
        for_write_(false),
        block_(0),
        read_bound_(0) {}

    explicit Rolling(void *data) { Init(data); }

    // Copy, then move the base forward by increase bytes.
    Rolling(const Rolling &copy_from, uint64_t increase = 0);
    Rolling &operator=(const Rolling &copy_from);

    // For an actual rolling mmap.
    explicit Rolling(int fd, bool for_write, std::size_t block, std::size_t read_bound, uint64_t offset, uint64_t amount);

    // For a static mapping
    void Init(void *data) {
      ptr_ = data;
      // The window covers every index, so CheckedBase never rolls.
      current_end_ = std::numeric_limits<uint64_t>::max();
      current_begin_ = 0;
      // Mark as a pass-through.
      fd_ = -1;
      // Give the file-related members defined values: operator= copies them
      // and IncreaseBase/DecreaseBase adjust file_begin_ even for a
      // pass-through.
      file_begin_ = 0;
      file_end_ = 0;
      for_write_ = false;
      block_ = 0;
      read_bound_ = 0;
    }

    // Shift the base forward by `by` bytes.  A non-pass-through has its window
    // emptied so that the next checked access rolls.
    void IncreaseBase(uint64_t by) {
      file_begin_ += by;
      ptr_ = static_cast<uint8_t*>(ptr_) + by;
      if (!IsPassthrough()) current_end_ = 0;
    }

    // Shift the base backward by `by` bytes; otherwise as IncreaseBase.
    void DecreaseBase(uint64_t by) {
      file_begin_ -= by;
      ptr_ = static_cast<uint8_t*>(ptr_) - by;
      if (!IsPassthrough()) current_end_ = 0;
    }

    void *ExtractNonRolling(scoped_memory &out, uint64_t index, std::size_t size);

    // Returns base pointer
    void *get() const { return ptr_; }

    // Returns base pointer, rolling first if index is outside the current window.
    void *CheckedBase(uint64_t index) {
      if (index >= current_end_ || index < current_begin_) {
        Roll(index);
      }
      return ptr_;
    }

    // Returns indexed pointer.
    void *CheckedIndex(uint64_t index) {
      return static_cast<uint8_t*>(CheckedBase(index)) + index;
    }

  private:
    void Roll(uint64_t index);

    // True if this is just a thin wrapper on a pointer.
    bool IsPassthrough() const { return fd_ == -1; }

    // Base pointer: callers add an index to it.
    void *ptr_;
    // Indices in [current_begin_, current_end_) are served without rolling.
    uint64_t current_begin_;
    uint64_t current_end_;

    // Mapping handed to ExtractNonRolling by Roll.
    scoped_memory mem_;

    // File descriptor, or -1 for a pass-through.
    int fd_;
    // Byte range of the file covered by this object.
    uint64_t file_begin_;
    uint64_t file_end_;

    bool for_write_;
    // Largest number of bytes mapped by one Roll.
    std::size_t block_;
    // Subtracted from the window end when a full block is mapped.
    std::size_t read_bound_;
};
} // namespace util } // namespace util
#endif // UTIL_MMAP__ #endif // UTIL_MMAP_H

View File

@ -1,5 +1,5 @@
#ifndef UTIL_MULTI_INTERSECTION__ #ifndef UTIL_MULTI_INTERSECTION_H
#define UTIL_MULTI_INTERSECTION__ #define UTIL_MULTI_INTERSECTION_H
#include <boost/optional.hpp> #include <boost/optional.hpp>
#include <boost/range/iterator_range.hpp> #include <boost/range/iterator_range.hpp>
@ -77,4 +77,4 @@ template <class Iterator, class Output> void AllIntersection(std::vector<boost::
} // namespace util } // namespace util
#endif // UTIL_MULTI_INTERSECTION__ #endif // UTIL_MULTI_INTERSECTION_H

View File

@ -1,5 +1,5 @@
#ifndef UTIL_MURMUR_HASH__ #ifndef UTIL_MURMUR_HASH_H
#define UTIL_MURMUR_HASH__ #define UTIL_MURMUR_HASH_H
#include <cstddef> #include <cstddef>
#include <stdint.h> #include <stdint.h>
@ -15,4 +15,4 @@ uint64_t MurmurHashNative(const void * key, std::size_t len, uint64_t seed = 0);
} // namespace util } // namespace util
#endif // UTIL_MURMUR_HASH__ #endif // UTIL_MURMUR_HASH_H

69
util/parallel_read.cc Normal file
View File

@ -0,0 +1,69 @@
#include "util/parallel_read.hh"
#include "util/file.hh"
#ifdef WITH_THREADS
#include "util/thread_pool.hh"
namespace util {
namespace {
// Worker functor for the thread pool: performs one positioned read per request
// on a fixed file descriptor.
class Reader {
  public:
    // One unit of work: read `size` bytes at file offset `offset` into `to`.
    struct Request {
      void *to;
      std::size_t size;
      uint64_t offset;

      // Member-wise equality.
      bool operator==(const Request &other) const {
        if (to != other.to) return false;
        if (size != other.size) return false;
        return offset == other.offset;
      }
    };

    explicit Reader(int fd) : fd_(fd) {}

    void operator()(const Request &request) {
      util::ErsatzPRead(fd_, request.to, request.size, request.offset);
    }

  private:
    int fd_;
};
} // namespace
void ParallelRead(int fd, void *to, std::size_t amount, uint64_t offset) {
Reader::Request poison;
poison.to = NULL;
poison.size = 0;
poison.offset = 0;
unsigned threads = boost::thread::hardware_concurrency();
if (!threads) threads = 2;
ThreadPool<Reader> pool(2 /* don't need much of a queue */, threads, fd, poison);
const std::size_t kBatch = 1ULL << 25; // 32 MB
Reader::Request request;
request.to = to;
request.size = kBatch;
request.offset = offset;
for (; amount > kBatch; amount -= kBatch) {
pool.Produce(request);
request.to = reinterpret_cast<uint8_t*>(request.to) + kBatch;
request.offset += kBatch;
}
request.size = amount;
if (request.size) {
pool.Produce(request);
}
}
} // namespace util
#else // WITH_THREADS
namespace util {
void ParallelRead(int fd, void *to, std::size_t amount, uint64_t offset) {
util::ErsatzPRead(fd, to, amount, offset);
}
} // namespace util
#endif

16
util/parallel_read.hh Normal file
View File

@ -0,0 +1,16 @@
#ifndef UTIL_PARALLEL_READ_H
#define UTIL_PARALLEL_READ_H

/* Read pieces of a file in parallel.  This has a very specific use case:
 * reading files from Lustre is CPU bound so multiple threads actually
 * increase throughput.  Speed matters when an LM takes a terabyte.
 */
#include <cstddef>
#include <stdint.h>

namespace util {
/* Read amount bytes from fd, starting at file offset offset, into to. */
void ParallelRead(int fd, void *to, std::size_t amount, uint64_t offset);
} // namespace util

#endif // UTIL_PARALLEL_READ_H

View File

@ -1,5 +1,5 @@
#ifndef UTIL_PCQUEUE__ #ifndef UTIL_PCQUEUE_H
#define UTIL_PCQUEUE__ #define UTIL_PCQUEUE_H
#include "util/exception.hh" #include "util/exception.hh"
@ -72,7 +72,8 @@ inline void WaitSemaphore (Semaphore &on) {
#endif // __APPLE__ #endif // __APPLE__
/* Producer consumer queue safe for multiple producers and multiple consumers. /**
* Producer consumer queue safe for multiple producers and multiple consumers.
* T must be default constructable and have operator=. * T must be default constructable and have operator=.
* The value is copied twice for Consume(T &out) or three times for Consume(), * The value is copied twice for Consume(T &out) or three times for Consume(),
* so larger objects should be passed via pointer. * so larger objects should be passed via pointer.
@ -152,4 +153,4 @@ template <class T> class PCQueue : boost::noncopyable {
} // namespace util } // namespace util
#endif // UTIL_PCQUEUE__ #endif // UTIL_PCQUEUE_H

View File

@ -1,8 +1,8 @@
// Very simple pool. It can only allocate memory. And all of the memory it // Very simple pool. It can only allocate memory. And all of the memory it
// allocates must be freed at the same time. // allocates must be freed at the same time.
#ifndef UTIL_POOL__ #ifndef UTIL_POOL_H
#define UTIL_POOL__ #define UTIL_POOL_H
#include <vector> #include <vector>
@ -42,4 +42,4 @@ class Pool {
} // namespace util } // namespace util
#endif // UTIL_POOL__ #endif // UTIL_POOL_H

View File

@ -1,5 +1,5 @@
#ifndef UTIL_PROBING_HASH_TABLE__ #ifndef UTIL_PROBING_HASH_TABLE_H
#define UTIL_PROBING_HASH_TABLE__ #define UTIL_PROBING_HASH_TABLE_H
#include "util/exception.hh" #include "util/exception.hh"
#include "util/scoped.hh" #include "util/scoped.hh"
@ -258,6 +258,10 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
private: private:
typedef ProbingHashTable<EntryT, HashT, EqualT> Backend; typedef ProbingHashTable<EntryT, HashT, EqualT> Backend;
public: public:
// Forwards to Backend::Size(size, multiplier).
static std::size_t MemUsage(std::size_t size, float multiplier = 1.5) {
  const std::size_t bytes = Backend::Size(size, multiplier);
  return bytes;
}
typedef EntryT Entry; typedef EntryT Entry;
typedef typename Entry::Key Key; typedef typename Entry::Key Key;
typedef const Entry *ConstIterator; typedef const Entry *ConstIterator;
@ -268,6 +272,7 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
AutoProbing(std::size_t initial_size = 10, const Key &invalid = Key(), const Hash &hash_func = Hash(), const Equal &equal_func = Equal()) : AutoProbing(std::size_t initial_size = 10, const Key &invalid = Key(), const Hash &hash_func = Hash(), const Equal &equal_func = Equal()) :
allocated_(Backend::Size(initial_size, 1.5)), mem_(util::MallocOrThrow(allocated_)), backend_(mem_.get(), allocated_, invalid, hash_func, equal_func) { allocated_(Backend::Size(initial_size, 1.5)), mem_(util::MallocOrThrow(allocated_)), backend_(mem_.get(), allocated_, invalid, hash_func, equal_func) {
threshold_ = initial_size * 1.2; threshold_ = initial_size * 1.2;
Clear();
} }
// Assumes that the key is unique. Multiple insertions won't cause a failure, just inconsistent lookup. // Assumes that the key is unique. Multiple insertions won't cause a failure, just inconsistent lookup.
@ -323,4 +328,4 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
} // namespace util } // namespace util
#endif // UTIL_PROBING_HASH_TABLE__ #endif // UTIL_PROBING_HASH_TABLE_H

View File

@ -1,5 +1,5 @@
#ifndef UTIL_PROXY_ITERATOR__ #ifndef UTIL_PROXY_ITERATOR_H
#define UTIL_PROXY_ITERATOR__ #define UTIL_PROXY_ITERATOR_H
#include <cstddef> #include <cstddef>
#include <iterator> #include <iterator>
@ -98,4 +98,4 @@ template <class Proxy> ProxyIterator<Proxy> operator+(std::ptrdiff_t amount, con
} // namespace util } // namespace util
#endif // UTIL_PROXY_ITERATOR__ #endif // UTIL_PROXY_ITERATOR_H

View File

@ -1,5 +1,5 @@
#ifndef UTIL_READ_COMPRESSED__ #ifndef UTIL_READ_COMPRESSED_H
#define UTIL_READ_COMPRESSED__ #define UTIL_READ_COMPRESSED_H
#include "util/exception.hh" #include "util/exception.hh"
#include "util/scoped.hh" #include "util/scoped.hh"
@ -78,4 +78,4 @@ class ReadCompressed {
} // namespace util } // namespace util
#endif // UTIL_READ_COMPRESSED__ #endif // UTIL_READ_COMPRESSED_H

View File

@ -1,5 +1,5 @@
#ifndef UTIL_SCOPED__ #ifndef UTIL_SCOPED_H
#define UTIL_SCOPED__ #define UTIL_SCOPED_H
/* Other scoped objects in the style of scoped_ptr. */ /* Other scoped objects in the style of scoped_ptr. */
#include "util/exception.hh" #include "util/exception.hh"
@ -101,4 +101,4 @@ template <class T> class scoped_ptr {
} // namespace util } // namespace util
#endif // UTIL_SCOPED__ #endif // UTIL_SCOPED_H

View File

@ -1,5 +1,5 @@
#ifndef UTIL_SIZED_ITERATOR__ #ifndef UTIL_SIZED_ITERATOR_H
#define UTIL_SIZED_ITERATOR__ #define UTIL_SIZED_ITERATOR_H
#include "util/proxy_iterator.hh" #include "util/proxy_iterator.hh"
@ -117,4 +117,4 @@ template <class Delegate, class Proxy = SizedProxy> class SizedCompare : public
}; };
} // namespace util } // namespace util
#endif // UTIL_SIZED_ITERATOR__ #endif // UTIL_SIZED_ITERATOR_H

View File

@ -1,5 +1,5 @@
#ifndef UTIL_SORTED_UNIFORM__ #ifndef UTIL_SORTED_UNIFORM_H
#define UTIL_SORTED_UNIFORM__ #define UTIL_SORTED_UNIFORM_H
#include <algorithm> #include <algorithm>
#include <cstddef> #include <cstddef>
@ -124,4 +124,4 @@ template <class Iterator, class Accessor> Iterator BinaryBelow(
} // namespace util } // namespace util
#endif // UTIL_SORTED_UNIFORM__ #endif // UTIL_SORTED_UNIFORM_H

View File

@ -1,5 +1,5 @@
#ifndef UTIL_STREAM_BLOCK__ #ifndef UTIL_STREAM_BLOCK_H
#define UTIL_STREAM_BLOCK__ #define UTIL_STREAM_BLOCK_H
#include <cstddef> #include <cstddef>
#include <stdint.h> #include <stdint.h>
@ -7,28 +7,77 @@
namespace util { namespace util {
namespace stream { namespace stream {
/**
* Encapsulates a block of memory.
*/
class Block { class Block {
public: public:
/**
* Constructs an empty block.
*/
Block() : mem_(NULL), valid_size_(0) {} Block() : mem_(NULL), valid_size_(0) {}
/**
* Constructs a block that encapsulates a segment of memory.
*
* @param[in] mem The segment of memory to encapsulate
* @param[in] size The size of the memory segment in bytes
*/
Block(void *mem, std::size_t size) : mem_(mem), valid_size_(size) {} Block(void *mem, std::size_t size) : mem_(mem), valid_size_(size) {}
/**
* Set the number of bytes in this block that should be interpreted as valid.
*
* @param[in] to Number of bytes
*/
void SetValidSize(std::size_t to) { valid_size_ = to; } void SetValidSize(std::size_t to) { valid_size_ = to; }
// Read might fill in less than Allocated at EOF.
/**
* Gets the number of bytes in this block that should be interpreted as valid.
* This is important because read might fill in less than Allocated at EOF.
*/
std::size_t ValidSize() const { return valid_size_; } std::size_t ValidSize() const { return valid_size_; }
/** Gets a void pointer to the memory underlying this block. */
void *Get() { return mem_; } void *Get() { return mem_; }
/** Gets a const void pointer to the memory underlying this block. */
const void *Get() const { return mem_; } const void *Get() const { return mem_; }
/**
* Gets a const void pointer to the end of the valid section of memory
* encapsulated by this block.
*/
const void *ValidEnd() const { const void *ValidEnd() const {
return reinterpret_cast<const uint8_t*>(mem_) + valid_size_; return reinterpret_cast<const uint8_t*>(mem_) + valid_size_;
} }
/**
* Returns true if this block encapsulates a valid (non-NULL) block of memory.
*
* This method is a user-defined implicit conversion function to boolean;
* among other things, this method enables bare instances of this class
* to be used as the condition of an if statement.
*/
operator bool() const { return mem_ != NULL; } operator bool() const { return mem_ != NULL; }
/**
* Returns true if this block is empty.
*
* In other words, if Get()==NULL, this method will return true.
*/
bool operator!() const { return mem_ == NULL; } bool operator!() const { return mem_ == NULL; }
private: private:
friend class Link; friend class Link;
/**
* Points this block's memory at NULL.
*
* This class defines poison as a block whose memory pointer is NULL.
*/
void SetToPoison() { void SetToPoison() {
mem_ = NULL; mem_ = NULL;
} }
@ -40,4 +89,4 @@ class Block {
} // namespace stream } // namespace stream
} // namespace util } // namespace util
#endif // UTIL_STREAM_BLOCK__ #endif // UTIL_STREAM_BLOCK_H

View File

@ -59,6 +59,11 @@ Chain &Chain::operator>>(const WriteAndRecycle &writer) {
return *this; return *this;
} }
// Attaches a PWriteAndRecycle worker to this chain.
// A new Thread is constructed from the position returned by Complete() and
// the supplied writer, and is stored in threads_.  CompleteLoop() in chain.hh
// uses the same Complete() call for the Recycler, so this presumably closes
// the chain as well -- confirm against Chain::Complete().
// Returns *this so further operator>> calls can be chained.
Chain &Chain::operator>>(const PWriteAndRecycle &writer) {
threads_.push_back(new Thread(Complete(), writer));
return *this;
}
void Chain::Wait(bool release_memory) { void Chain::Wait(bool release_memory) {
if (queues_.empty()) { if (queues_.empty()) {
assert(threads_.empty()); assert(threads_.empty());
@ -126,7 +131,12 @@ Link::~Link() {
// abort(); // abort();
} else { } else {
if (!poisoned_) { if (!poisoned_) {
// Pass the poison! // Poison is a block whose memory pointer is NULL.
//
// Because we're in the else block,
// we know that the memory pointer of current_ is NULL.
//
// Pass the current (poison) block!
out_->Produce(current_); out_->Produce(current_);
} }
} }

View File

@ -1,5 +1,5 @@
#ifndef UTIL_STREAM_CHAIN__ #ifndef UTIL_STREAM_CHAIN_H
#define UTIL_STREAM_CHAIN__ #define UTIL_STREAM_CHAIN_H
#include "util/stream/block.hh" #include "util/stream/block.hh"
#include "util/stream/config.hh" #include "util/stream/config.hh"
@ -24,7 +24,12 @@ class ChainConfigException : public Exception {
}; };
class Chain; class Chain;
// Specifies position in chain for Link constructor.
/**
* Encapsulates a @ref PCQueue "producer queue" and a @ref PCQueue "consumer queue" within a @ref Chain "chain".
*
* Specifies position in chain for Link constructor.
*/
class ChainPosition { class ChainPosition {
public: public:
const Chain &GetChain() const { return *chain_; } const Chain &GetChain() const { return *chain_; }
@ -41,14 +46,32 @@ class ChainPosition {
WorkerProgress progress_; WorkerProgress progress_;
}; };
// Position is usually ChainPosition but if there are multiple streams involved, this can be ChainPositions.
/**
* Encapsulates a worker thread processing data at a given position in the chain.
*
* Each instance of this class owns one boost thread in which the worker is Run().
*/
class Thread { class Thread {
public: public:
/**
* Constructs a new Thread in which the provided Worker is Run().
*
* Position is usually ChainPosition but if there are multiple streams involved, this can be ChainPositions.
*
* After a call to this constructor, the provided worker will be running within a boost thread owned by the newly constructed Thread object.
*/
template <class Position, class Worker> Thread(const Position &position, const Worker &worker) template <class Position, class Worker> Thread(const Position &position, const Worker &worker)
: thread_(boost::ref(*this), position, worker) {} : thread_(boost::ref(*this), position, worker) {}
~Thread(); ~Thread();
/**
* Launches the provided worker in this object's boost thread.
*
* This method is called automatically by this class's @ref Thread() "constructor".
*/
template <class Position, class Worker> void operator()(const Position &position, Worker &worker) { template <class Position, class Worker> void operator()(const Position &position, Worker &worker) {
try { try {
worker.Run(position); worker.Run(position);
@ -63,14 +86,27 @@ class Thread {
boost::thread thread_; boost::thread thread_;
}; };
/**
* This resets blocks to full valid size. Used to close the loop in Chain by recycling blocks.
*/
class Recycler { class Recycler {
public: public:
/**
* Resets the blocks in the chain such that the blocks' respective valid sizes match the chain's block size.
*
* @see Block::SetValidSize()
* @see Chain::BlockSize()
*/
void Run(const ChainPosition &position); void Run(const ChainPosition &position);
}; };
extern const Recycler kRecycle; extern const Recycler kRecycle;
class WriteAndRecycle; class WriteAndRecycle;
class PWriteAndRecycle;
/**
* Represents a sequence of workers, through which @ref Block "blocks" can pass.
*/
class Chain { class Chain {
private: private:
template <class T, void (T::*ptr)(const ChainPosition &) = &T::Run> struct CheckForRun { template <class T, void (T::*ptr)(const ChainPosition &) = &T::Run> struct CheckForRun {
@ -78,8 +114,20 @@ class Chain {
}; };
public: public:
/**
* Constructs a configured Chain.
*
* @param config Specifies how to configure the Chain.
*/
explicit Chain(const ChainConfig &config); explicit Chain(const ChainConfig &config);
/**
* Destructs a Chain.
*
* This method waits for the chain's threads to complete,
* and frees the memory held by this chain.
*/
~Chain(); ~Chain();
void ActivateProgress() { void ActivateProgress() {
@ -91,24 +139,49 @@ class Chain {
progress_.SetTarget(target); progress_.SetTarget(target);
} }
/**
* Gets the number of bytes in each record of a Block.
*
* @see ChainConfig::entry_size
*/
std::size_t EntrySize() const { std::size_t EntrySize() const {
return config_.entry_size; return config_.entry_size;
} }
/**
* Gets the initial @ref Block::ValidSize "valid size" for @ref Block "blocks" in this chain.
*
* @see Block::ValidSize
*/
std::size_t BlockSize() const { std::size_t BlockSize() const {
return block_size_; return block_size_;
} }
// Two ways to add to the chain: Add() or operator>>. /** Two ways to add to the chain: Add() or operator>>. */
ChainPosition Add(); ChainPosition Add();
// This is for adding threaded workers with a Run method. /**
* Adds a new worker to this chain,
* and runs that worker in a new Thread owned by this chain.
*
* The worker must have a Run method that accepts a position argument.
*
* @see Thread::operator()()
*/
template <class Worker> typename CheckForRun<Worker>::type &operator>>(const Worker &worker) { template <class Worker> typename CheckForRun<Worker>::type &operator>>(const Worker &worker) {
assert(!complete_called_); assert(!complete_called_);
threads_.push_back(new Thread(Add(), worker)); threads_.push_back(new Thread(Add(), worker));
return *this; return *this;
} }
// Avoid copying the worker. /**
* Adds a new worker to this chain (but avoids copying that worker),
* and runs that worker in a new Thread owned by this chain.
*
* The worker must have a Run method that accepts a position argument.
*
* @see Thread::operator()()
*/
template <class Worker> typename CheckForRun<Worker>::type &operator>>(const boost::reference_wrapper<Worker> &worker) { template <class Worker> typename CheckForRun<Worker>::type &operator>>(const boost::reference_wrapper<Worker> &worker) {
assert(!complete_called_); assert(!complete_called_);
threads_.push_back(new Thread(Add(), worker)); threads_.push_back(new Thread(Add(), worker));
@ -122,12 +195,21 @@ class Chain {
threads_.push_back(new Thread(Complete(), kRecycle)); threads_.push_back(new Thread(Complete(), kRecycle));
} }
/**
* Adds a Recycler worker to this chain,
* and runs that worker in a new Thread owned by this chain.
*/
Chain &operator>>(const Recycler &) { Chain &operator>>(const Recycler &) {
CompleteLoop(); CompleteLoop();
return *this; return *this;
} }
/**
* Adds a WriteAndRecycle worker to this chain,
* and runs that worker in a new Thread owned by this chain.
*/
Chain &operator>>(const WriteAndRecycle &writer); Chain &operator>>(const WriteAndRecycle &writer);
Chain &operator>>(const PWriteAndRecycle &writer);
// Chains are reusable. Call Wait to wait for everything to finish and free memory. // Chains are reusable. Call Wait to wait for everything to finish and free memory.
void Wait(bool release_memory = true); void Wait(bool release_memory = true);
@ -156,28 +238,87 @@ class Chain {
}; };
// Create the link in the worker thread using the position token. // Create the link in the worker thread using the position token.
/**
* Represents a C++ style iterator over @ref Block "blocks".
*/
class Link { class Link {
public: public:
// Either default construct and Init or just construct all at once. // Either default construct and Init or just construct all at once.
/**
* Constructs an @ref Init "initialized" link.
*
* @see Init
*/
explicit Link(const ChainPosition &position);
/**
* Constructs a link that must subsequently be @ref Init "initialized".
*
* @see Init
*/
Link(); Link();
/**
* Initializes the link with the input @ref PCQueue "consumer queue" and output @ref PCQueue "producer queue" at a given @ref ChainPosition "position" in the @ref Chain "chain".
*
* @see Link()
*/
void Init(const ChainPosition &position); void Init(const ChainPosition &position);
explicit Link(const ChainPosition &position); /**
* Destructs the link object.
*
* If necessary, this method will pass a poison block
* to this link's output @ref PCQueue "producer queue".
*
* @see Block::SetToPoison()
*/
~Link(); ~Link();
/**
* Gets a reference to the @ref Block "block" at this link.
*/
Block &operator*() { return current_; } Block &operator*() { return current_; }
/**
* Gets a const reference to the @ref Block "block" at this link.
*/
const Block &operator*() const { return current_; } const Block &operator*() const { return current_; }
/**
* Gets a pointer to the @ref Block "block" at this link.
*/
Block *operator->() { return &current_; } Block *operator->() { return &current_; }
/**
* Gets a const pointer to the @ref Block "block" at this link.
*/
const Block *operator->() const { return &current_; } const Block *operator->() const { return &current_; }
/**
* Gets the link at the next @ref ChainPosition "position" in the @ref Chain "chain".
*/
Link &operator++(); Link &operator++();
/**
* Returns true if the @ref Block "block" at this link encapsulates a valid (non-NULL) block of memory.
*
* This method is a user-defined implicit conversion function to boolean;
* among other things, this method enables bare instances of this class
* to be used as the condition of an if statement.
*/
operator bool() const { return current_; } operator bool() const { return current_; }
/**
* @ref Block::SetToPoison() "Poisons" the @ref Block "block" at this link,
* and passes this now-poisoned block to this link's output @ref PCQueue "producer queue".
*
* @see Block::SetToPoison()
*/
void Poison(); void Poison();
private: private:
Block current_; Block current_;
PCQueue<Block> *in_, *out_; PCQueue<Block> *in_, *out_;
@ -195,4 +336,4 @@ inline Chain &operator>>(Chain &chain, Link &link) {
} // namespace stream } // namespace stream
} // namespace util } // namespace util
#endif // UTIL_STREAM_CHAIN__ #endif // UTIL_STREAM_CHAIN_H

View File

@ -1,32 +1,63 @@
#ifndef UTIL_STREAM_CONFIG__ #ifndef UTIL_STREAM_CONFIG_H
#define UTIL_STREAM_CONFIG__ #define UTIL_STREAM_CONFIG_H
#include <cstddef> #include <cstddef>
#include <string> #include <string>
namespace util { namespace stream { namespace util { namespace stream {
/**
* Represents how a chain should be configured.
*/
struct ChainConfig { struct ChainConfig {
/** Constructs a configuration with underspecified (or default) parameters. */
ChainConfig() {} ChainConfig() {}
/**
* Constructs a chain configuration object.
*
* @param [in] in_entry_size Number of bytes in each record.
* @param [in] in_block_count Number of blocks in the chain.
* @param [in] in_total_memory Total number of bytes available to the chain.
* This value will be divided amongst the blocks in the chain.
*/
ChainConfig(std::size_t in_entry_size, std::size_t in_block_count, std::size_t in_total_memory) ChainConfig(std::size_t in_entry_size, std::size_t in_block_count, std::size_t in_total_memory)
: entry_size(in_entry_size), block_count(in_block_count), total_memory(in_total_memory) {} : entry_size(in_entry_size), block_count(in_block_count), total_memory(in_total_memory) {}
/**
* Number of bytes in each record.
*/
std::size_t entry_size; std::size_t entry_size;
/**
* Number of blocks in the chain.
*/
std::size_t block_count; std::size_t block_count;
// Chain's constructor will make this a multiple of entry_size.
/**
* Total number of bytes available to the chain.
* This value will be divided amongst the blocks in the chain.
* Chain's constructor will make this a multiple of entry_size.
*/
std::size_t total_memory; std::size_t total_memory;
}; };
/**
* Represents how a sorter should be configured.
*/
struct SortConfig { struct SortConfig {
/** Filename prefix where temporary files should be placed. */
std::string temp_prefix; std::string temp_prefix;
// Size of each input/output buffer. /** Size of each input/output buffer. */
std::size_t buffer_size; std::size_t buffer_size;
// Total memory to use when running alone. /** Total memory to use when running alone. */
std::size_t total_memory; std::size_t total_memory;
}; };
}} // namespaces }} // namespaces
#endif // UTIL_STREAM_CONFIG__ #endif // UTIL_STREAM_CONFIG_H

View File

@ -36,12 +36,12 @@ void PRead::Run(const ChainPosition &position) {
Link link(position); Link link(position);
uint64_t offset = 0; uint64_t offset = 0;
for (; offset + block_size64 < size; offset += block_size64, ++link) { for (; offset + block_size64 < size; offset += block_size64, ++link) {
PReadOrThrow(file_, link->Get(), block_size, offset); ErsatzPRead(file_, link->Get(), block_size, offset);
link->SetValidSize(block_size); link->SetValidSize(block_size);
} }
// size - offset is <= block_size, so it casts to 32-bit fine. // size - offset is <= block_size, so it casts to 32-bit fine.
if (size - offset) { if (size - offset) {
PReadOrThrow(file_, link->Get(), size - offset, offset); ErsatzPRead(file_, link->Get(), size - offset, offset);
link->SetValidSize(size - offset); link->SetValidSize(size - offset);
++link; ++link;
} }
@ -62,5 +62,15 @@ void WriteAndRecycle::Run(const ChainPosition &position) {
} }
} }
// Worker body: writes every block that arrives at `position` to file_ and
// then hands the block on with its valid size reset.
//
// Blocks are visited through a Link until the link converts to false.  For
// each block, ValidSize() bytes starting at Get() are passed to ErsatzPWrite
// together with a running offset that starts at 0 and grows by the number of
// bytes just written.
// NOTE(review): ErsatzPWrite is presumably a positional write that does not
// depend on the descriptor's current file offset -- confirm in util/file.hh.
void PWriteAndRecycle::Run(const ChainPosition &position) {
// Valid size to restore on each block once its contents have been written.
const std::size_t block_size = position.GetChain().BlockSize();
uint64_t offset = 0;
for (Link link(position); link; ++link) {
ErsatzPWrite(file_, link->Get(), link->ValidSize(), offset);
offset += link->ValidSize();
// Recycle: mark the whole block as usable again.
link->SetValidSize(block_size);
}
}
} // namespace stream } // namespace stream
} // namespace util } // namespace util

View File

@ -1,5 +1,5 @@
#ifndef UTIL_STREAM_IO__ #ifndef UTIL_STREAM_IO_H
#define UTIL_STREAM_IO__ #define UTIL_STREAM_IO_H
#include "util/exception.hh" #include "util/exception.hh"
#include "util/file.hh" #include "util/file.hh"
@ -41,6 +41,8 @@ class Write {
int file_; int file_;
}; };
// It's a common case that stuff is written and then recycled. So rather than
// spawn another thread to Recycle, this combines the two roles.
class WriteAndRecycle { class WriteAndRecycle {
public: public:
explicit WriteAndRecycle(int fd) : file_(fd) {} explicit WriteAndRecycle(int fd) : file_(fd) {}
@ -49,14 +51,23 @@ class WriteAndRecycle {
int file_; int file_;
}; };
// Chain worker that writes blocks to a file descriptor and recycles them,
// tracking the write offset itself (see PWriteAndRecycle::Run in io.cc).
// FileBuffer::Sink() returns one of these.
class PWriteAndRecycle {
public:
// Stores fd as a plain int; nothing in this class closes it.
explicit PWriteAndRecycle(int fd) : file_(fd) {}
// Worker entry point, invoked with this worker's position in the chain.
void Run(const ChainPosition &position);
private:
// Descriptor the blocks are written to.
int file_;
};
// Reuse the same file over and over again to buffer output. // Reuse the same file over and over again to buffer output.
class FileBuffer { class FileBuffer {
public: public:
explicit FileBuffer(int fd) : file_(fd) {} explicit FileBuffer(int fd) : file_(fd) {}
WriteAndRecycle Sink() const { PWriteAndRecycle Sink() const {
util::SeekOrThrow(file_.get(), 0); util::SeekOrThrow(file_.get(), 0);
return WriteAndRecycle(file_.get()); return PWriteAndRecycle(file_.get());
} }
PRead Source() const { PRead Source() const {
@ -73,4 +84,4 @@ class FileBuffer {
} // namespace stream } // namespace stream
} // namespace util } // namespace util
#endif // UTIL_STREAM_IO__ #endif // UTIL_STREAM_IO_H

View File

@ -1,5 +1,5 @@
#ifndef UTIL_STREAM_LINE_INPUT__ #ifndef UTIL_STREAM_LINE_INPUT_H
#define UTIL_STREAM_LINE_INPUT__ #define UTIL_STREAM_LINE_INPUT_H
namespace util {namespace stream { namespace util {namespace stream {
class ChainPosition; class ChainPosition;
@ -19,4 +19,4 @@ class LineInput {
}; };
}} // namespaces }} // namespaces
#endif // UTIL_STREAM_LINE_INPUT__ #endif // UTIL_STREAM_LINE_INPUT_H

View File

@ -1,6 +1,6 @@
/* Progress bar suitable for chains of workers */ /* Progress bar suitable for chains of workers */
#ifndef UTIL_MULTI_PROGRESS__ #ifndef UTIL_STREAM_MULTI_PROGRESS_H
#define UTIL_MULTI_PROGRESS__ #define UTIL_STREAM_MULTI_PROGRESS_H
#include <boost/thread/mutex.hpp> #include <boost/thread/mutex.hpp>
@ -87,4 +87,4 @@ class WorkerProgress {
}} // namespaces }} // namespaces
#endif // UTIL_MULTI_PROGRESS__ #endif // UTIL_STREAM_MULTI_PROGRESS_H

Some files were not shown because too many files have changed in this diff Show More