From e0d618528a9b81e52bf6cfc245f74f2490e56bdb Mon Sep 17 00:00:00 2001 From: heafield Date: Mon, 23 May 2011 02:23:01 +0000 Subject: [PATCH] Speed improvements mostly. Trie went from 803448 queries/s to 990520 queries/s by knowing what the bounds are in advance. Also, set read ahead for files. git-svn-id: https://mosesdecoder.svn.sourceforge.net/svnroot/mosesdecoder/trunk@3988 1f5c12ca-751b-0410-a591-d2e778427230 --- kenlm/lm/build_binary.cc | 4 ++-- kenlm/lm/config.hh | 2 +- kenlm/lm/lm_exception.hh | 4 ++-- kenlm/lm/read_arpa.cc | 6 +++--- kenlm/lm/trie.cc | 18 ++++++++-------- kenlm/lm/trie.hh | 2 +- kenlm/lm/vocab.cc | 4 ++-- kenlm/lm/vocab.hh | 3 ++- kenlm/util/file_piece.cc | 7 +++++- kenlm/util/sorted_uniform.hh | 41 +++++++++++++++++++++--------------- 10 files changed, 52 insertions(+), 39 deletions(-) diff --git a/kenlm/lm/build_binary.cc b/kenlm/lm/build_binary.cc index 2213c0b49..91ad2fb94 100644 --- a/kenlm/lm/build_binary.cc +++ b/kenlm/lm/build_binary.cc @@ -91,10 +91,10 @@ int main(int argc, char *argv[]) { config.building_memory = ParseUInt(optarg) * 1048576; break; case 's': - config.sentence_marker_missing = lm::NOTHING; + config.sentence_marker_missing = lm::SILENT; break; case 'i': - config.positive_log_probability = lm::NOTHING; + config.positive_log_probability = lm::SILENT; break; default: Usage(argv[0]); diff --git a/kenlm/lm/config.hh b/kenlm/lm/config.hh index 9c8761ee8..6c7fe39b9 100644 --- a/kenlm/lm/config.hh +++ b/kenlm/lm/config.hh @@ -34,7 +34,7 @@ struct Config { // If THROW_UP, the exception will be of type util::SpecialWordMissingException. WarningAction sentence_marker_missing; - // What to do with a positive log probability. For COMPLAIN and NOTHING, map + // What to do with a positive log probability. For COMPLAIN and SILENT, map // to 0. WarningAction positive_log_probability; diff --git a/kenlm/lm/lm_exception.hh b/kenlm/lm/lm_exception.hh index 50df38493..f607ced16 100644 --- a/kenlm/lm/lm_exception.hh +++ b/kenlm/lm/lm_exception.hh @@ -11,7 +11,7 @@ namespace lm { -typedef enum {THROW_UP, COMPLAIN, NOTHING} WarningAction; +typedef enum {THROW_UP, COMPLAIN, SILENT} WarningAction; class ConfigException : public util::Exception { public: @@ -47,4 +47,4 @@ class SpecialWordMissingException : public VocabLoadException { } // namespace lm -#endif // LM_LM_EXCEPTION__ +#endif // LM_LM_EXCEPTION diff --git a/kenlm/lm/read_arpa.cc b/kenlm/lm/read_arpa.cc index e871e0fbd..060a97ea0 100644 --- a/kenlm/lm/read_arpa.cc +++ b/kenlm/lm/read_arpa.cc @@ -119,12 +119,12 @@ void ReadEnd(util::FilePiece &in) { void PositiveProbWarn::Warn(float prob) { switch (action_) { case THROW_UP: - UTIL_THROW(FormatLoadException, "Positive log probability " << prob << " in the model. This is a bug in IRSTLM; you can set config.positive_log_probability = NOTHING or pass -i to build_binary to substitute 0.0 for the log probability. Error"); + UTIL_THROW(FormatLoadException, "Positive log probability " << prob << " in the model. This is a bug in IRSTLM; you can set config.positive_log_probability = SILENT or pass -i to build_binary to substitute 0.0 for the log probability. Error"); case COMPLAIN: std::cerr << "There's a positive log probability " << prob << " in the APRA file, probably because of a bug in IRSTLM. This and subsequent entires will be mapepd to 0 log probability." << std::endl; - action_ = NOTHING; + action_ = SILENT; break; - case NOTHING: + case SILENT: break; } } diff --git a/kenlm/lm/trie.cc b/kenlm/lm/trie.cc index 2c6336137..58682ad49 100644 --- a/kenlm/lm/trie.cc +++ b/kenlm/lm/trie.cc @@ -26,7 +26,7 @@ class JustKeyProxy { private: friend class util::ProxyIterator; - friend bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index); + friend bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, uint64_t max_vocab, WordIndex key, uint64_t &at_index); JustKeyProxy(const void *base, uint64_t index, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits) : inner_(index), base_(static_cast(base)), key_mask_(key_mask), key_bits_(key_bits), total_bits_(total_bits) {} @@ -47,11 +47,11 @@ class JustKeyProxy { const uint8_t key_bits_, total_bits_; }; -bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index) { - util::ProxyIterator begin_it(JustKeyProxy(base, begin_index, key_mask, key_bits, total_bits)); - util::ProxyIterator end_it(JustKeyProxy(base, end_index, key_mask, key_bits, total_bits)); +bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, const uint64_t max_vocab, const WordIndex key, uint64_t &at_index) { + util::ProxyIterator before_it(JustKeyProxy(base, begin_index, key_mask, key_bits, total_bits)); + util::ProxyIterator after_it(JustKeyProxy(base, end_index, key_mask, key_bits, total_bits)); util::ProxyIterator out; - if (!util::SortedUniformFind(begin_it, end_it, key, out)) return false; + if (!util::BoundedSortedUniformFind, uint64_t>(before_it - 1, (uint64_t)0, after_it, max_vocab, key, out)) return false; at_index = out.Inner(); return true; } @@ -76,6 +76,7 @@ void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits) base_ = static_cast(base); insert_index_ = 0; + max_vocab_ = max_vocab; } std::size_t BitPackedMiddle::Size(uint64_t entries, uint64_t max_vocab, uint64_t max_ptr) { @@ -88,7 +89,6 @@ void BitPackedMiddle::Init(void *base, uint64_t max_vocab, uint64_t max_next, co next_bits_ = util::RequiredBits(max_next); if (next_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions."); next_mask_ = (1ULL << next_bits_) - 1; - BaseInit(base, max_vocab, backoff_bits_ + next_bits_); } @@ -111,7 +111,7 @@ void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff) { bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const { uint64_t at_pointer; - if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) { + if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) { return false; } at_pointer *= total_bits_; @@ -129,7 +129,7 @@ bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRang bool BitPackedMiddle::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const { uint64_t at_pointer; - if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) return false; + if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return false; at_pointer *= total_bits_; at_pointer += word_bits_; at_pointer += prob_bits_; @@ -159,7 +159,7 @@ void BitPackedLongest::Insert(WordIndex index, float prob) { bool BitPackedLongest::Find(WordIndex word, float &prob, const NodeRange &range) const { uint64_t at_pointer; - if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) return false; + if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return false; at_pointer = at_pointer * total_bits_ + word_bits_; prob = util::ReadNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7); return true; diff --git a/kenlm/lm/trie.hh b/kenlm/lm/trie.hh index 6aef050c3..5e061298a 100644 --- a/kenlm/lm/trie.hh +++ b/kenlm/lm/trie.hh @@ -80,7 +80,7 @@ class BitPacked { uint8_t *base_; - uint64_t insert_index_; + uint64_t insert_index_, max_vocab_; }; class BitPackedMiddle : public BitPacked { diff --git a/kenlm/lm/vocab.cc b/kenlm/lm/vocab.cc index 7c75b8f43..515af5dbb 100644 --- a/kenlm/lm/vocab.cc +++ b/kenlm/lm/vocab.cc @@ -189,7 +189,7 @@ void ProbingVocabulary::LoadedBinary(int fd, EnumerateVocab *to) { void MissingUnknown(const Config &config) throw(SpecialWordMissingException) { switch(config.unknown_missing) { - case NOTHING: + case SILENT: return; case COMPLAIN: if (config.messages) *config.messages << "The ARPA file is missing . Substituting log10 probability " << config.unknown_missing_logprob << "." << std::endl; @@ -201,7 +201,7 @@ void MissingUnknown(const Config &config) throw(SpecialWordMissingException) { void MissingSentenceMarker(const Config &config, const char *str) throw(SpecialWordMissingException) { switch (config.sentence_marker_missing) { - case NOTHING: + case SILENT: return; case COMPLAIN: if (config.messages) *config.messages << "Missing special word " << str << "; will treat it as ."; diff --git a/kenlm/lm/vocab.hh b/kenlm/lm/vocab.hh index 546c16499..707fc8a55 100644 --- a/kenlm/lm/vocab.hh +++ b/kenlm/lm/vocab.hh @@ -9,6 +9,7 @@ #include "util/sorted_uniform.hh" #include "util/string_piece.hh" +#include #include #include @@ -59,7 +60,7 @@ class SortedVocabulary : public base::Vocabulary { WordIndex Index(const StringPiece &str) const { const Entry *found; - if (util::SortedUniformFind(begin_, end_, detail::HashForVocab(str), found)) { + if (util::BoundedSortedUniformFind(begin_ - 1, 0, end_, std::numeric_limits::max(), detail::HashForVocab(str), found)) { return found - begin_ + 1; // +1 because is 0 and does not appear in the lookup table. } else { return 0; diff --git a/kenlm/util/file_piece.cc b/kenlm/util/file_piece.cc index 67681f7e4..f447a70c5 100644 --- a/kenlm/util/file_piece.cc +++ b/kenlm/util/file_piece.cc @@ -237,7 +237,12 @@ void FilePiece::MMapShift(off_t desired_begin) throw() { // Forcibly clear the existing mmap first. data_.reset(); - data_.reset(mmap(NULL, mapped_size, PROT_READ, MAP_PRIVATE, *file_, mapped_offset), mapped_size, scoped_memory::MMAP_ALLOCATED); + data_.reset(mmap(NULL, mapped_size, PROT_READ, MAP_SHARED + // Populate where available on linux +#ifdef MAP_POPULATE + | MAP_POPULATE +#endif + , *file_, mapped_offset), mapped_size, scoped_memory::MMAP_ALLOCATED); if (data_.get() == MAP_FAILED) { if (desired_begin) { if (((off_t)-1) == lseek(*file_, desired_begin, SEEK_SET)) UTIL_THROW(ErrnoException, "mmap failed even though it worked before. lseek failed too, so using read isn't an option either."); diff --git a/kenlm/util/sorted_uniform.hh b/kenlm/util/sorted_uniform.hh index 05826b51d..5a3a0f8ee 100644 --- a/kenlm/util/sorted_uniform.hh +++ b/kenlm/util/sorted_uniform.hh @@ -24,6 +24,29 @@ inline std::size_t Pivot(unsigned char off, unsigned char range, std::size_t wid return static_cast(static_cast(off) * width / static_cast(range)); }*/ +// Search the range [before_it + 1, after_it - 1] for key. +// Preconditions: +// before_v <= key <= after_v +// before_v <= all values in the range [before_it + 1, after_it - 1] <= after_v +// range is sorted. +template bool BoundedSortedUniformFind(Iterator before_it, Key before_v, Iterator after_it, Key after_v, const Key key, Iterator &out) { + while (after_it - before_it - 1 > 0) { + Iterator pivot(before_it + (1 + Pivot(key - before_v, after_v - before_v, static_cast(after_it - before_it - 1)))); + Key mid(pivot->GetKey()); + if (mid < key) { + before_it = pivot; + before_v = mid; + } else if (mid > key) { + after_it = pivot; + after_v = mid; + } else { + out = pivot; + return true; + } + } + return false; +} + template bool SortedUniformFind(Iterator begin, Iterator end, const Key key, Iterator &out) { if (begin == end) return false; Key below(begin->GetKey()); @@ -38,23 +61,7 @@ template bool SortedUniformFind(Iterator begin, Iter if (key == above) { out = end; return true; } return false; } - - // Search the range [begin + 1, end - 1] knowing that *begin == below, *end == above. - while (end - begin > 1) { - Iterator pivot(begin + (1 + Pivot(key - below, above - below, static_cast(end - begin - 1)))); - Key mid(pivot->GetKey()); - if (mid < key) { - begin = pivot; - below = mid; - } else if (mid > key) { - end = pivot; - above = mid; - } else { - out = pivot; - return true; - } - } - return false; + return BoundedSortedUniformFind(begin, below, end, above, key, out); } // To use this template, you need to define a Pivot function to match Key.