Speed improvements mostly. Trie went from 803448 queries/s to 990520 queries/s by knowing what the bounds are in advance. Also, set read ahead for

files.  



git-svn-id: https://mosesdecoder.svn.sourceforge.net/svnroot/mosesdecoder/trunk@3988 1f5c12ca-751b-0410-a591-d2e778427230
This commit is contained in:
heafield 2011-05-23 02:23:01 +00:00
parent 7408636328
commit e0d618528a
10 changed files with 52 additions and 39 deletions

View File

@ -91,10 +91,10 @@ int main(int argc, char *argv[]) {
config.building_memory = ParseUInt(optarg) * 1048576;
break;
case 's':
config.sentence_marker_missing = lm::NOTHING;
config.sentence_marker_missing = lm::SILENT;
break;
case 'i':
config.positive_log_probability = lm::NOTHING;
config.positive_log_probability = lm::SILENT;
break;
default:
Usage(argv[0]);

View File

@ -34,7 +34,7 @@ struct Config {
// If THROW_UP, the exception will be of type util::SpecialWordMissingException.
WarningAction sentence_marker_missing;
// What to do with a positive log probability. For COMPLAIN and NOTHING, map
// What to do with a positive log probability. For COMPLAIN and SILENT, map
// to 0.
WarningAction positive_log_probability;

View File

@ -11,7 +11,7 @@
namespace lm {
typedef enum {THROW_UP, COMPLAIN, NOTHING} WarningAction;
typedef enum {THROW_UP, COMPLAIN, SILENT} WarningAction;
class ConfigException : public util::Exception {
public:
@ -47,4 +47,4 @@ class SpecialWordMissingException : public VocabLoadException {
} // namespace lm
#endif // LM_LM_EXCEPTION__
#endif // LM_LM_EXCEPTION

View File

@ -119,12 +119,12 @@ void ReadEnd(util::FilePiece &in) {
void PositiveProbWarn::Warn(float prob) {
switch (action_) {
case THROW_UP:
UTIL_THROW(FormatLoadException, "Positive log probability " << prob << " in the model. This is a bug in IRSTLM; you can set config.positive_log_probability = NOTHING or pass -i to build_binary to substitute 0.0 for the log probability. Error");
UTIL_THROW(FormatLoadException, "Positive log probability " << prob << " in the model. This is a bug in IRSTLM; you can set config.positive_log_probability = SILENT or pass -i to build_binary to substitute 0.0 for the log probability. Error");
case COMPLAIN:
std::cerr << "There's a positive log probability " << prob << " in the APRA file, probably because of a bug in IRSTLM. This and subsequent entires will be mapepd to 0 log probability." << std::endl;
action_ = NOTHING;
action_ = SILENT;
break;
case NOTHING:
case SILENT:
break;
}
}

View File

@ -26,7 +26,7 @@ class JustKeyProxy {
private:
friend class util::ProxyIterator<JustKeyProxy>;
friend bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index);
friend bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, uint64_t max_vocab, WordIndex key, uint64_t &at_index);
JustKeyProxy(const void *base, uint64_t index, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits)
: inner_(index), base_(static_cast<const uint8_t*>(base)), key_mask_(key_mask), key_bits_(key_bits), total_bits_(total_bits) {}
@ -47,11 +47,11 @@ class JustKeyProxy {
const uint8_t key_bits_, total_bits_;
};
bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index) {
util::ProxyIterator<JustKeyProxy> begin_it(JustKeyProxy(base, begin_index, key_mask, key_bits, total_bits));
util::ProxyIterator<JustKeyProxy> end_it(JustKeyProxy(base, end_index, key_mask, key_bits, total_bits));
bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, const uint64_t max_vocab, const WordIndex key, uint64_t &at_index) {
util::ProxyIterator<JustKeyProxy> before_it(JustKeyProxy(base, begin_index, key_mask, key_bits, total_bits));
util::ProxyIterator<JustKeyProxy> after_it(JustKeyProxy(base, end_index, key_mask, key_bits, total_bits));
util::ProxyIterator<JustKeyProxy> out;
if (!util::SortedUniformFind(begin_it, end_it, key, out)) return false;
if (!util::BoundedSortedUniformFind<util::ProxyIterator<JustKeyProxy>, uint64_t>(before_it - 1, (uint64_t)0, after_it, max_vocab, key, out)) return false;
at_index = out.Inner();
return true;
}
@ -76,6 +76,7 @@ void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits)
base_ = static_cast<uint8_t*>(base);
insert_index_ = 0;
max_vocab_ = max_vocab;
}
std::size_t BitPackedMiddle::Size(uint64_t entries, uint64_t max_vocab, uint64_t max_ptr) {
@ -88,7 +89,6 @@ void BitPackedMiddle::Init(void *base, uint64_t max_vocab, uint64_t max_next, co
next_bits_ = util::RequiredBits(max_next);
if (next_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions.");
next_mask_ = (1ULL << next_bits_) - 1;
BaseInit(base, max_vocab, backoff_bits_ + next_bits_);
}
@ -111,7 +111,7 @@ void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff) {
bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const {
uint64_t at_pointer;
if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) {
if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) {
return false;
}
at_pointer *= total_bits_;
@ -129,7 +129,7 @@ bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRang
bool BitPackedMiddle::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const {
uint64_t at_pointer;
if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) return false;
if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return false;
at_pointer *= total_bits_;
at_pointer += word_bits_;
at_pointer += prob_bits_;
@ -159,7 +159,7 @@ void BitPackedLongest::Insert(WordIndex index, float prob) {
bool BitPackedLongest::Find(WordIndex word, float &prob, const NodeRange &range) const {
uint64_t at_pointer;
if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) return false;
if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return false;
at_pointer = at_pointer * total_bits_ + word_bits_;
prob = util::ReadNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7);
return true;

View File

@ -80,7 +80,7 @@ class BitPacked {
uint8_t *base_;
uint64_t insert_index_;
uint64_t insert_index_, max_vocab_;
};
class BitPackedMiddle : public BitPacked {

View File

@ -189,7 +189,7 @@ void ProbingVocabulary::LoadedBinary(int fd, EnumerateVocab *to) {
void MissingUnknown(const Config &config) throw(SpecialWordMissingException) {
switch(config.unknown_missing) {
case NOTHING:
case SILENT:
return;
case COMPLAIN:
if (config.messages) *config.messages << "The ARPA file is missing <unk>. Substituting log10 probability " << config.unknown_missing_logprob << "." << std::endl;
@ -201,7 +201,7 @@ void MissingUnknown(const Config &config) throw(SpecialWordMissingException) {
void MissingSentenceMarker(const Config &config, const char *str) throw(SpecialWordMissingException) {
switch (config.sentence_marker_missing) {
case NOTHING:
case SILENT:
return;
case COMPLAIN:
if (config.messages) *config.messages << "Missing special word " << str << "; will treat it as <unk>.";

View File

@ -9,6 +9,7 @@
#include "util/sorted_uniform.hh"
#include "util/string_piece.hh"
#include <limits>
#include <string>
#include <vector>
@ -59,7 +60,7 @@ class SortedVocabulary : public base::Vocabulary {
WordIndex Index(const StringPiece &str) const {
const Entry *found;
if (util::SortedUniformFind<const Entry *, uint64_t>(begin_, end_, detail::HashForVocab(str), found)) {
if (util::BoundedSortedUniformFind<const Entry *, uint64_t>(begin_ - 1, 0, end_, std::numeric_limits<uint64_t>::max(), detail::HashForVocab(str), found)) {
return found - begin_ + 1; // +1 because <unk> is 0 and does not appear in the lookup table.
} else {
return 0;

View File

@ -237,7 +237,12 @@ void FilePiece::MMapShift(off_t desired_begin) throw() {
// Forcibly clear the existing mmap first.
data_.reset();
data_.reset(mmap(NULL, mapped_size, PROT_READ, MAP_PRIVATE, *file_, mapped_offset), mapped_size, scoped_memory::MMAP_ALLOCATED);
data_.reset(mmap(NULL, mapped_size, PROT_READ, MAP_SHARED
// Populate where available on linux
#ifdef MAP_POPULATE
| MAP_POPULATE
#endif
, *file_, mapped_offset), mapped_size, scoped_memory::MMAP_ALLOCATED);
if (data_.get() == MAP_FAILED) {
if (desired_begin) {
if (((off_t)-1) == lseek(*file_, desired_begin, SEEK_SET)) UTIL_THROW(ErrnoException, "mmap failed even though it worked before. lseek failed too, so using read isn't an option either.");

View File

@ -24,6 +24,29 @@ inline std::size_t Pivot(unsigned char off, unsigned char range, std::size_t wid
return static_cast<std::size_t>(static_cast<std::size_t>(off) * width / static_cast<std::size_t>(range));
}*/
// Search the range [before_it + 1, after_it - 1] for key.
// Preconditions:
// before_v <= key <= after_v
// before_v <= all values in the range [before_it + 1, after_it - 1] <= after_v
// range is sorted.
template <class Iterator, class Key> bool BoundedSortedUniformFind(Iterator before_it, Key before_v, Iterator after_it, Key after_v, const Key key, Iterator &out) {
while (after_it - before_it - 1 > 0) {
Iterator pivot(before_it + (1 + Pivot(key - before_v, after_v - before_v, static_cast<std::size_t>(after_it - before_it - 1))));
Key mid(pivot->GetKey());
if (mid < key) {
before_it = pivot;
before_v = mid;
} else if (mid > key) {
after_it = pivot;
after_v = mid;
} else {
out = pivot;
return true;
}
}
return false;
}
template <class Iterator, class Key> bool SortedUniformFind(Iterator begin, Iterator end, const Key key, Iterator &out) {
if (begin == end) return false;
Key below(begin->GetKey());
@ -38,23 +61,7 @@ template <class Iterator, class Key> bool SortedUniformFind(Iterator begin, Iter
if (key == above) { out = end; return true; }
return false;
}
// Search the range [begin + 1, end - 1] knowing that *begin == below, *end == above.
while (end - begin > 1) {
Iterator pivot(begin + (1 + Pivot(key - below, above - below, static_cast<std::size_t>(end - begin - 1))));
Key mid(pivot->GetKey());
if (mid < key) {
begin = pivot;
below = mid;
} else if (mid > key) {
end = pivot;
above = mid;
} else {
out = pivot;
return true;
}
}
return false;
return BoundedSortedUniformFind(begin, below, end, above, key, out);
}
// To use this template, you need to define a Pivot function to match Key.