KenLM 0e5d259 including read_compressed fix

This commit is contained in:
Kenneth Heafield 2013-01-04 21:02:47 +00:00
parent 3203f7c92d
commit f9ee7ae4b3
14 changed files with 111 additions and 98 deletions

View File

@ -233,7 +233,8 @@ void ComplainAboutARPA(const Config &config, ModelType model_type) {
if (config.write_mmap || !config.messages) return; if (config.write_mmap || !config.messages) return;
if (config.arpa_complain == Config::ALL) { if (config.arpa_complain == Config::ALL) {
*config.messages << "Loading the LM will be faster if you build a binary file." << std::endl; *config.messages << "Loading the LM will be faster if you build a binary file." << std::endl;
} else if (config.arpa_complain == Config::EXPENSIVE && model_type == TRIE_SORTED) { } else if (config.arpa_complain == Config::EXPENSIVE &&
(model_type == TRIE || model_type == QUANT_TRIE || model_type == ARRAY_TRIE || model_type == QUANT_ARRAY_TRIE)) {
*config.messages << "Building " << kModelNames[model_type] << " from ARPA is expensive. Save time by building a binary format." << std::endl; *config.messages << "Building " << kModelNames[model_type] << " from ARPA is expensive. Save time by building a binary format." << std::endl;
} }
} }

View File

@ -6,6 +6,7 @@ namespace lm {
namespace ngram { namespace ngram {
Config::Config() : Config::Config() :
show_progress(true),
messages(&std::cerr), messages(&std::cerr),
enumerate_vocab(NULL), enumerate_vocab(NULL),
unknown_missing(COMPLAIN), unknown_missing(COMPLAIN),

View File

@ -19,17 +19,23 @@ namespace ngram {
struct Config { struct Config {
// EFFECTIVE FOR BOTH ARPA AND BINARY READS // EFFECTIVE FOR BOTH ARPA AND BINARY READS
// (default true) print progress bar to messages
bool show_progress;
// Where to log messages including the progress bar. Set to NULL for // Where to log messages including the progress bar. Set to NULL for
// silence. // silence.
std::ostream *messages; std::ostream *messages;
// Stream that progress output should be written to: `messages` while
// show_progress is set, otherwise a null pointer.
std::ostream *ProgressMessages() const {
  if (!show_progress) return 0;
  return messages;
}
// This will be called with every string in the vocabulary. See // This will be called with every string in the vocabulary. See
// enumerate_vocab.hh for more detail. Config does not take ownership; you // enumerate_vocab.hh for more detail. Config does not take ownership; you
// are still responsible for deleting it (or stack allocating). // are still responsible for deleting it (or stack allocating).
EnumerateVocab *enumerate_vocab; EnumerateVocab *enumerate_vocab;
// ONLY EFFECTIVE WHEN READING ARPA // ONLY EFFECTIVE WHEN READING ARPA
// What to do when <unk> isn't in the provided model. // What to do when <unk> isn't in the provided model.
@ -92,7 +98,6 @@ struct Config {
std::vector<std::string> rest_lower_files; std::vector<std::string> rest_lower_files;
// Quantization options. Only effective for QuantTrieModel. One value is // Quantization options. Only effective for QuantTrieModel. One value is
// reserved for each of prob and backoff, so 2^bits - 1 buckets will be used // reserved for each of prob and backoff, so 2^bits - 1 buckets will be used
// to quantize (and one of the remaining backoffs will be 0). // to quantize (and one of the remaining backoffs will be 0).
@ -102,7 +107,6 @@ struct Config {
uint8_t pointer_bhiksha_bits; uint8_t pointer_bhiksha_bits;
// ONLY EFFECTIVE WHEN READING BINARY // ONLY EFFECTIVE WHEN READING BINARY
// How to get the giant array into memory: lazy mmap, populate, read etc. // How to get the giant array into memory: lazy mmap, populate, read etc.
@ -110,7 +114,6 @@ struct Config {
util::LoadMethod load_method; util::LoadMethod load_method;
// Set defaults. // Set defaults.
Config(); Config();
}; };

View File

@ -70,7 +70,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, const Config &config) { template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, const Config &config) {
// Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any. // Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any.
util::FilePiece f(backing_.file.release(), file, config.messages); util::FilePiece f(backing_.file.release(), file, config.ProgressMessages());
try { try {
std::vector<uint64_t> counts; std::vector<uint64_t> counts;
// File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_. // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_.

View File

@ -486,7 +486,7 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve
util::scoped_memory unigrams; util::scoped_memory unigrams;
MapRead(util::POPULATE_OR_READ, unigram_fd.get(), 0, counts[0] * sizeof(ProbBackoff), unigrams); MapRead(util::POPULATE_OR_READ, unigram_fd.get(), 0, counts[0] * sizeof(ProbBackoff), unigrams);
FindBlanks finder(counts.size(), reinterpret_cast<const ProbBackoff*>(unigrams.get()), sri); FindBlanks finder(counts.size(), reinterpret_cast<const ProbBackoff*>(unigrams.get()), sri);
RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Identifying n-grams omitted by SRI", finder); RecursiveInsert(counts.size(), counts[0], inputs, config.ProgressMessages(), "Identifying n-grams omitted by SRI", finder);
fixed_counts = finder.Counts(); fixed_counts = finder.Counts();
} }
unigram_file.reset(util::FDOpenOrThrow(unigram_fd)); unigram_file.reset(util::FDOpenOrThrow(unigram_fd));
@ -504,7 +504,8 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve
inputs[i-2].Rewind(); inputs[i-2].Rewind();
} }
if (Quant::kTrain) { if (Quant::kTrain) {
util::ErsatzProgress progress(std::accumulate(counts.begin() + 1, counts.end(), 0), config.messages, "Quantizing"); util::ErsatzProgress progress(std::accumulate(counts.begin() + 1, counts.end(), 0),
config.ProgressMessages(), "Quantizing");
for (unsigned char i = 2; i < counts.size(); ++i) { for (unsigned char i = 2; i < counts.size(); ++i) {
TrainQuantizer(i, counts[i-1], sri.Values(i), inputs[i-2], progress, quant); TrainQuantizer(i, counts[i-1], sri.Values(i), inputs[i-2], progress, quant);
} }
@ -522,7 +523,7 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve
// Fill entries except unigram probabilities. // Fill entries except unigram probabilities.
{ {
WriteEntries<Quant, Bhiksha> writer(contexts, quant, unigrams, out.middle_begin_, out.longest_, counts.size(), sri); WriteEntries<Quant, Bhiksha> writer(contexts, quant, unigrams, out.middle_begin_, out.longest_, counts.size(), sri);
RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Writing trie", writer); RecursiveInsert(counts.size(), counts[0], inputs, config.ProgressMessages(), "Writing trie", writer);
} }
// Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation. // Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation.

View File

@ -38,7 +38,7 @@ void ErsatzProgress::Milestone() {
next_ = std::numeric_limits<uint64_t>::max(); next_ = std::numeric_limits<uint64_t>::max();
out_ = NULL; out_ = NULL;
} else { } else {
next_ = std::max(next_, (stone * complete_) / kWidth); next_ = std::max(next_, ((stone + 1) * complete_ + kWidth - 1) / kWidth);
} }
} }

View File

@ -32,7 +32,6 @@ class ErsatzProgress {
void Set(uint64_t to) { void Set(uint64_t to) {
if ((current_ = to) >= next_) Milestone(); if ((current_ = to) >= next_) Milestone();
Milestone();
} }
void Finished() { void Finished() {

View File

@ -1,3 +1,5 @@
#define _LARGEFILE64_SOURCE
#include "util/file.hh" #include "util/file.hh"
#include "util/exception.hh" #include "util/exception.hh"
@ -91,7 +93,7 @@ void ReadOrThrow(int fd, void *to_void, std::size_t amount) {
uint8_t *to = static_cast<uint8_t*>(to_void); uint8_t *to = static_cast<uint8_t*>(to_void);
while (amount) { while (amount) {
std::size_t ret = PartialRead(fd, to, amount); std::size_t ret = PartialRead(fd, to, amount);
UTIL_THROW_IF(ret == 0, EndOfFileException, "Hit EOF in fd " << fd << " but there should be " << amount << " more bytes to read."); UTIL_THROW_IF(ret == 0, EndOfFileException, " in fd " << fd << " but there should be " << amount << " more bytes to read.");
amount -= ret; amount -= ret;
to += ret; to += ret;
} }
@ -141,7 +143,7 @@ void InternalSeek(int fd, int64_t off, int whence) {
UTIL_THROW_IF((__int64)-1 == _lseeki64(fd, off, whence), ErrnoException, "Windows seek failed"); UTIL_THROW_IF((__int64)-1 == _lseeki64(fd, off, whence), ErrnoException, "Windows seek failed");
#else #else
UTIL_THROW_IF((off_t)-1 == lseek(fd, off, whence), ErrnoException, "Seek failed"); UTIL_THROW_IF((off_t)-1 == lseek64(fd, off, whence), ErrnoException, "Seek failed");
#endif #endif
} }
} // namespace } // namespace

View File

@ -32,8 +32,6 @@ class scoped_fd {
return ret; return ret;
} }
operator bool() { return fd_ != -1; }
private: private:
int fd_; int fd_;

View File

@ -6,8 +6,8 @@
//#define HAVE_ICU //#define HAVE_ICU
#endif #endif
#ifndef HAVE_THREADS #ifndef HAVE_BOOST
//#define HAVE_THREADS #define HAVE_BOOST
#endif #endif
#endif // UTIL_HAVE__ #endif // UTIL_HAVE__

View File

@ -60,7 +60,7 @@ template <class KeyIter, class ValueIter> class JointProxy {
JointProxy(const KeyIter &key_iter, const ValueIter &value_iter) : inner_(key_iter, value_iter) {} JointProxy(const KeyIter &key_iter, const ValueIter &value_iter) : inner_(key_iter, value_iter) {}
JointProxy(const JointProxy<KeyIter, ValueIter> &other) : inner_(other.inner_) {} JointProxy(const JointProxy<KeyIter, ValueIter> &other) : inner_(other.inner_) {}
operator const value_type() const { operator value_type() const {
value_type ret; value_type ret;
ret.key = *inner_.key_; ret.key = *inner_.key_;
ret.value = *inner_.value_; ret.value = *inner_.value_;

View File

@ -370,7 +370,7 @@ ReadBase *ReadFactory(int fd, uint64_t &raw_amount) {
break; break;
} }
try { try {
AdvanceOrThrow(fd, -ReadCompressed::kMagicSize); SeekOrThrow(fd, 0);
} catch (const util::ErrnoException &e) { } catch (const util::ErrnoException &e) {
return new UncompressedWithHeader(hold.release(), header, ReadCompressed::kMagicSize); return new UncompressedWithHeader(hold.release(), header, ReadCompressed::kMagicSize);
} }

View File

@ -49,7 +49,11 @@
#define BASE_STRING_PIECE_H__ #define BASE_STRING_PIECE_H__
#include "util/have.hh" #include "util/have.hh"
#ifdef HAVE_BOOST
#include <boost/functional/hash/hash.hpp> #include <boost/functional/hash/hash.hpp>
#endif // HAVE_BOOST
#include <cstring> #include <cstring>
#include <iosfwd> #include <iosfwd>
#include <ostream> #include <ostream>
@ -252,6 +256,7 @@ inline std::ostream& operator<<(std::ostream& o, const StringPiece& piece) {
return o.write(piece.data(), static_cast<std::streamsize>(piece.size())); return o.write(piece.data(), static_cast<std::streamsize>(piece.size()));
} }
#ifdef HAVE_BOOST
inline size_t hash_value(const StringPiece &str) { inline size_t hash_value(const StringPiece &str) {
return boost::hash_range(str.data(), str.data() + str.length()); return boost::hash_range(str.data(), str.data() + str.length());
} }
@ -285,9 +290,12 @@ template <class T> typename T::iterator FindStringPiece(T &t, const StringPiece
return t.find(key, StringPieceCompatibleHash(), StringPieceCompatibleEquals()); return t.find(key, StringPieceCompatibleHash(), StringPieceCompatibleEquals());
#endif #endif
} }
#endif
#ifdef HAVE_ICU #ifdef HAVE_ICU
U_NAMESPACE_END U_NAMESPACE_END
using U_NAMESPACE_QUALIFIER StringPiece;
#endif #endif
#endif // BASE_STRING_PIECE_H__ #endif // BASE_STRING_PIECE_H__