mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-12-27 05:55:02 +03:00
Speed improvements mostly. Trie went from 803448 queries/s to 990520 queries/s by knowing what the bounds are in advance. Also, set read ahead for
files. git-svn-id: https://mosesdecoder.svn.sourceforge.net/svnroot/mosesdecoder/trunk@3988 1f5c12ca-751b-0410-a591-d2e778427230
This commit is contained in:
parent
7408636328
commit
e0d618528a
@ -91,10 +91,10 @@ int main(int argc, char *argv[]) {
|
|||||||
config.building_memory = ParseUInt(optarg) * 1048576;
|
config.building_memory = ParseUInt(optarg) * 1048576;
|
||||||
break;
|
break;
|
||||||
case 's':
|
case 's':
|
||||||
config.sentence_marker_missing = lm::NOTHING;
|
config.sentence_marker_missing = lm::SILENT;
|
||||||
break;
|
break;
|
||||||
case 'i':
|
case 'i':
|
||||||
config.positive_log_probability = lm::NOTHING;
|
config.positive_log_probability = lm::SILENT;
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
Usage(argv[0]);
|
Usage(argv[0]);
|
||||||
|
@ -34,7 +34,7 @@ struct Config {
|
|||||||
// If THROW_UP, the exception will be of type util::SpecialWordMissingException.
|
// If THROW_UP, the exception will be of type util::SpecialWordMissingException.
|
||||||
WarningAction sentence_marker_missing;
|
WarningAction sentence_marker_missing;
|
||||||
|
|
||||||
// What to do with a positive log probability. For COMPLAIN and NOTHING, map
|
// What to do with a positive log probability. For COMPLAIN and SILENT, map
|
||||||
// to 0.
|
// to 0.
|
||||||
WarningAction positive_log_probability;
|
WarningAction positive_log_probability;
|
||||||
|
|
||||||
|
@ -11,7 +11,7 @@
|
|||||||
|
|
||||||
namespace lm {
|
namespace lm {
|
||||||
|
|
||||||
typedef enum {THROW_UP, COMPLAIN, NOTHING} WarningAction;
|
typedef enum {THROW_UP, COMPLAIN, SILENT} WarningAction;
|
||||||
|
|
||||||
class ConfigException : public util::Exception {
|
class ConfigException : public util::Exception {
|
||||||
public:
|
public:
|
||||||
@ -47,4 +47,4 @@ class SpecialWordMissingException : public VocabLoadException {
|
|||||||
|
|
||||||
} // namespace lm
|
} // namespace lm
|
||||||
|
|
||||||
#endif // LM_LM_EXCEPTION__
|
#endif // LM_LM_EXCEPTION
|
||||||
|
@ -119,12 +119,12 @@ void ReadEnd(util::FilePiece &in) {
|
|||||||
void PositiveProbWarn::Warn(float prob) {
|
void PositiveProbWarn::Warn(float prob) {
|
||||||
switch (action_) {
|
switch (action_) {
|
||||||
case THROW_UP:
|
case THROW_UP:
|
||||||
UTIL_THROW(FormatLoadException, "Positive log probability " << prob << " in the model. This is a bug in IRSTLM; you can set config.positive_log_probability = NOTHING or pass -i to build_binary to substitute 0.0 for the log probability. Error");
|
UTIL_THROW(FormatLoadException, "Positive log probability " << prob << " in the model. This is a bug in IRSTLM; you can set config.positive_log_probability = SILENT or pass -i to build_binary to substitute 0.0 for the log probability. Error");
|
||||||
case COMPLAIN:
|
case COMPLAIN:
|
||||||
std::cerr << "There's a positive log probability " << prob << " in the APRA file, probably because of a bug in IRSTLM. This and subsequent entires will be mapepd to 0 log probability." << std::endl;
|
std::cerr << "There's a positive log probability " << prob << " in the APRA file, probably because of a bug in IRSTLM. This and subsequent entires will be mapepd to 0 log probability." << std::endl;
|
||||||
action_ = NOTHING;
|
action_ = SILENT;
|
||||||
break;
|
break;
|
||||||
case NOTHING:
|
case SILENT:
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -26,7 +26,7 @@ class JustKeyProxy {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
friend class util::ProxyIterator<JustKeyProxy>;
|
friend class util::ProxyIterator<JustKeyProxy>;
|
||||||
friend bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index);
|
friend bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, uint64_t max_vocab, WordIndex key, uint64_t &at_index);
|
||||||
|
|
||||||
JustKeyProxy(const void *base, uint64_t index, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits)
|
JustKeyProxy(const void *base, uint64_t index, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits)
|
||||||
: inner_(index), base_(static_cast<const uint8_t*>(base)), key_mask_(key_mask), key_bits_(key_bits), total_bits_(total_bits) {}
|
: inner_(index), base_(static_cast<const uint8_t*>(base)), key_mask_(key_mask), key_bits_(key_bits), total_bits_(total_bits) {}
|
||||||
@ -47,11 +47,11 @@ class JustKeyProxy {
|
|||||||
const uint8_t key_bits_, total_bits_;
|
const uint8_t key_bits_, total_bits_;
|
||||||
};
|
};
|
||||||
|
|
||||||
bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index) {
|
bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, const uint64_t max_vocab, const WordIndex key, uint64_t &at_index) {
|
||||||
util::ProxyIterator<JustKeyProxy> begin_it(JustKeyProxy(base, begin_index, key_mask, key_bits, total_bits));
|
util::ProxyIterator<JustKeyProxy> before_it(JustKeyProxy(base, begin_index, key_mask, key_bits, total_bits));
|
||||||
util::ProxyIterator<JustKeyProxy> end_it(JustKeyProxy(base, end_index, key_mask, key_bits, total_bits));
|
util::ProxyIterator<JustKeyProxy> after_it(JustKeyProxy(base, end_index, key_mask, key_bits, total_bits));
|
||||||
util::ProxyIterator<JustKeyProxy> out;
|
util::ProxyIterator<JustKeyProxy> out;
|
||||||
if (!util::SortedUniformFind(begin_it, end_it, key, out)) return false;
|
if (!util::BoundedSortedUniformFind<util::ProxyIterator<JustKeyProxy>, uint64_t>(before_it - 1, (uint64_t)0, after_it, max_vocab, key, out)) return false;
|
||||||
at_index = out.Inner();
|
at_index = out.Inner();
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@ -76,6 +76,7 @@ void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits)
|
|||||||
|
|
||||||
base_ = static_cast<uint8_t*>(base);
|
base_ = static_cast<uint8_t*>(base);
|
||||||
insert_index_ = 0;
|
insert_index_ = 0;
|
||||||
|
max_vocab_ = max_vocab;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::size_t BitPackedMiddle::Size(uint64_t entries, uint64_t max_vocab, uint64_t max_ptr) {
|
std::size_t BitPackedMiddle::Size(uint64_t entries, uint64_t max_vocab, uint64_t max_ptr) {
|
||||||
@ -88,7 +89,6 @@ void BitPackedMiddle::Init(void *base, uint64_t max_vocab, uint64_t max_next, co
|
|||||||
next_bits_ = util::RequiredBits(max_next);
|
next_bits_ = util::RequiredBits(max_next);
|
||||||
if (next_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions.");
|
if (next_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions.");
|
||||||
next_mask_ = (1ULL << next_bits_) - 1;
|
next_mask_ = (1ULL << next_bits_) - 1;
|
||||||
|
|
||||||
BaseInit(base, max_vocab, backoff_bits_ + next_bits_);
|
BaseInit(base, max_vocab, backoff_bits_ + next_bits_);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -111,7 +111,7 @@ void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff) {
|
|||||||
|
|
||||||
bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const {
|
bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const {
|
||||||
uint64_t at_pointer;
|
uint64_t at_pointer;
|
||||||
if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) {
|
if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
at_pointer *= total_bits_;
|
at_pointer *= total_bits_;
|
||||||
@ -129,7 +129,7 @@ bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRang
|
|||||||
|
|
||||||
bool BitPackedMiddle::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const {
|
bool BitPackedMiddle::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const {
|
||||||
uint64_t at_pointer;
|
uint64_t at_pointer;
|
||||||
if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) return false;
|
if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return false;
|
||||||
at_pointer *= total_bits_;
|
at_pointer *= total_bits_;
|
||||||
at_pointer += word_bits_;
|
at_pointer += word_bits_;
|
||||||
at_pointer += prob_bits_;
|
at_pointer += prob_bits_;
|
||||||
@ -159,7 +159,7 @@ void BitPackedLongest::Insert(WordIndex index, float prob) {
|
|||||||
|
|
||||||
bool BitPackedLongest::Find(WordIndex word, float &prob, const NodeRange &range) const {
|
bool BitPackedLongest::Find(WordIndex word, float &prob, const NodeRange &range) const {
|
||||||
uint64_t at_pointer;
|
uint64_t at_pointer;
|
||||||
if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) return false;
|
if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return false;
|
||||||
at_pointer = at_pointer * total_bits_ + word_bits_;
|
at_pointer = at_pointer * total_bits_ + word_bits_;
|
||||||
prob = util::ReadNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7);
|
prob = util::ReadNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7);
|
||||||
return true;
|
return true;
|
||||||
|
@ -80,7 +80,7 @@ class BitPacked {
|
|||||||
|
|
||||||
uint8_t *base_;
|
uint8_t *base_;
|
||||||
|
|
||||||
uint64_t insert_index_;
|
uint64_t insert_index_, max_vocab_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class BitPackedMiddle : public BitPacked {
|
class BitPackedMiddle : public BitPacked {
|
||||||
|
@ -189,7 +189,7 @@ void ProbingVocabulary::LoadedBinary(int fd, EnumerateVocab *to) {
|
|||||||
|
|
||||||
void MissingUnknown(const Config &config) throw(SpecialWordMissingException) {
|
void MissingUnknown(const Config &config) throw(SpecialWordMissingException) {
|
||||||
switch(config.unknown_missing) {
|
switch(config.unknown_missing) {
|
||||||
case NOTHING:
|
case SILENT:
|
||||||
return;
|
return;
|
||||||
case COMPLAIN:
|
case COMPLAIN:
|
||||||
if (config.messages) *config.messages << "The ARPA file is missing <unk>. Substituting log10 probability " << config.unknown_missing_logprob << "." << std::endl;
|
if (config.messages) *config.messages << "The ARPA file is missing <unk>. Substituting log10 probability " << config.unknown_missing_logprob << "." << std::endl;
|
||||||
@ -201,7 +201,7 @@ void MissingUnknown(const Config &config) throw(SpecialWordMissingException) {
|
|||||||
|
|
||||||
void MissingSentenceMarker(const Config &config, const char *str) throw(SpecialWordMissingException) {
|
void MissingSentenceMarker(const Config &config, const char *str) throw(SpecialWordMissingException) {
|
||||||
switch (config.sentence_marker_missing) {
|
switch (config.sentence_marker_missing) {
|
||||||
case NOTHING:
|
case SILENT:
|
||||||
return;
|
return;
|
||||||
case COMPLAIN:
|
case COMPLAIN:
|
||||||
if (config.messages) *config.messages << "Missing special word " << str << "; will treat it as <unk>.";
|
if (config.messages) *config.messages << "Missing special word " << str << "; will treat it as <unk>.";
|
||||||
|
@ -9,6 +9,7 @@
|
|||||||
#include "util/sorted_uniform.hh"
|
#include "util/sorted_uniform.hh"
|
||||||
#include "util/string_piece.hh"
|
#include "util/string_piece.hh"
|
||||||
|
|
||||||
|
#include <limits>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
@ -59,7 +60,7 @@ class SortedVocabulary : public base::Vocabulary {
|
|||||||
|
|
||||||
WordIndex Index(const StringPiece &str) const {
|
WordIndex Index(const StringPiece &str) const {
|
||||||
const Entry *found;
|
const Entry *found;
|
||||||
if (util::SortedUniformFind<const Entry *, uint64_t>(begin_, end_, detail::HashForVocab(str), found)) {
|
if (util::BoundedSortedUniformFind<const Entry *, uint64_t>(begin_ - 1, 0, end_, std::numeric_limits<uint64_t>::max(), detail::HashForVocab(str), found)) {
|
||||||
return found - begin_ + 1; // +1 because <unk> is 0 and does not appear in the lookup table.
|
return found - begin_ + 1; // +1 because <unk> is 0 and does not appear in the lookup table.
|
||||||
} else {
|
} else {
|
||||||
return 0;
|
return 0;
|
||||||
|
@ -237,7 +237,12 @@ void FilePiece::MMapShift(off_t desired_begin) throw() {
|
|||||||
|
|
||||||
// Forcibly clear the existing mmap first.
|
// Forcibly clear the existing mmap first.
|
||||||
data_.reset();
|
data_.reset();
|
||||||
data_.reset(mmap(NULL, mapped_size, PROT_READ, MAP_PRIVATE, *file_, mapped_offset), mapped_size, scoped_memory::MMAP_ALLOCATED);
|
data_.reset(mmap(NULL, mapped_size, PROT_READ, MAP_SHARED
|
||||||
|
// Populate where available on linux
|
||||||
|
#ifdef MAP_POPULATE
|
||||||
|
| MAP_POPULATE
|
||||||
|
#endif
|
||||||
|
, *file_, mapped_offset), mapped_size, scoped_memory::MMAP_ALLOCATED);
|
||||||
if (data_.get() == MAP_FAILED) {
|
if (data_.get() == MAP_FAILED) {
|
||||||
if (desired_begin) {
|
if (desired_begin) {
|
||||||
if (((off_t)-1) == lseek(*file_, desired_begin, SEEK_SET)) UTIL_THROW(ErrnoException, "mmap failed even though it worked before. lseek failed too, so using read isn't an option either.");
|
if (((off_t)-1) == lseek(*file_, desired_begin, SEEK_SET)) UTIL_THROW(ErrnoException, "mmap failed even though it worked before. lseek failed too, so using read isn't an option either.");
|
||||||
|
@ -24,6 +24,29 @@ inline std::size_t Pivot(unsigned char off, unsigned char range, std::size_t wid
|
|||||||
return static_cast<std::size_t>(static_cast<std::size_t>(off) * width / static_cast<std::size_t>(range));
|
return static_cast<std::size_t>(static_cast<std::size_t>(off) * width / static_cast<std::size_t>(range));
|
||||||
}*/
|
}*/
|
||||||
|
|
||||||
|
// Search the range [before_it + 1, after_it - 1] for key.
|
||||||
|
// Preconditions:
|
||||||
|
// before_v <= key <= after_v
|
||||||
|
// before_v <= all values in the range [before_it + 1, after_it - 1] <= after_v
|
||||||
|
// range is sorted.
|
||||||
|
template <class Iterator, class Key> bool BoundedSortedUniformFind(Iterator before_it, Key before_v, Iterator after_it, Key after_v, const Key key, Iterator &out) {
|
||||||
|
while (after_it - before_it - 1 > 0) {
|
||||||
|
Iterator pivot(before_it + (1 + Pivot(key - before_v, after_v - before_v, static_cast<std::size_t>(after_it - before_it - 1))));
|
||||||
|
Key mid(pivot->GetKey());
|
||||||
|
if (mid < key) {
|
||||||
|
before_it = pivot;
|
||||||
|
before_v = mid;
|
||||||
|
} else if (mid > key) {
|
||||||
|
after_it = pivot;
|
||||||
|
after_v = mid;
|
||||||
|
} else {
|
||||||
|
out = pivot;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
template <class Iterator, class Key> bool SortedUniformFind(Iterator begin, Iterator end, const Key key, Iterator &out) {
|
template <class Iterator, class Key> bool SortedUniformFind(Iterator begin, Iterator end, const Key key, Iterator &out) {
|
||||||
if (begin == end) return false;
|
if (begin == end) return false;
|
||||||
Key below(begin->GetKey());
|
Key below(begin->GetKey());
|
||||||
@ -38,23 +61,7 @@ template <class Iterator, class Key> bool SortedUniformFind(Iterator begin, Iter
|
|||||||
if (key == above) { out = end; return true; }
|
if (key == above) { out = end; return true; }
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
return BoundedSortedUniformFind(begin, below, end, above, key, out);
|
||||||
// Search the range [begin + 1, end - 1] knowing that *begin == below, *end == above.
|
|
||||||
while (end - begin > 1) {
|
|
||||||
Iterator pivot(begin + (1 + Pivot(key - below, above - below, static_cast<std::size_t>(end - begin - 1))));
|
|
||||||
Key mid(pivot->GetKey());
|
|
||||||
if (mid < key) {
|
|
||||||
begin = pivot;
|
|
||||||
below = mid;
|
|
||||||
} else if (mid > key) {
|
|
||||||
end = pivot;
|
|
||||||
above = mid;
|
|
||||||
} else {
|
|
||||||
out = pivot;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// To use this template, you need to define a Pivot function to match Key.
|
// To use this template, you need to define a Pivot function to match Key.
|
||||||
|
Loading…
Reference in New Issue
Block a user