KenLM c34d00

This commit is contained in:
Kenneth Heafield 2012-09-28 15:04:48 +01:00
parent 287836438c
commit 78f295c0a0
31 changed files with 155 additions and 74 deletions

View File

@ -17,4 +17,4 @@ run model_test.cc ../util//kenutil kenlm ..//boost_unit_test_framework : : test.
exe query : ngram_query.cc kenlm ../util//kenutil ;
exe build_binary : build_binary.cc kenlm ../util//kenutil ;
exe kenlm_max_order : max_order.cc : $(max-order) ;
exe kenlm_max_order : max_order.cc : <include>.. $(max-order) ;

View File

@ -50,7 +50,7 @@ std::size_t ArrayCount(uint64_t max_offset, uint64_t max_next, const Config &con
}
} // namespace
std::size_t ArrayBhiksha::Size(uint64_t max_offset, uint64_t max_next, const Config &config) {
uint64_t ArrayBhiksha::Size(uint64_t max_offset, uint64_t max_next, const Config &config) {
return sizeof(uint64_t) * (1 /* header */ + ArrayCount(max_offset, max_next, config)) + 7 /* 8-byte alignment */;
}

View File

@ -33,7 +33,7 @@ class DontBhiksha {
static void UpdateConfigFromBinary(int /*fd*/, Config &/*config*/) {}
static std::size_t Size(uint64_t /*max_offset*/, uint64_t /*max_next*/, const Config &/*config*/) { return 0; }
static uint64_t Size(uint64_t /*max_offset*/, uint64_t /*max_next*/, const Config &/*config*/) { return 0; }
static uint8_t InlineBits(uint64_t /*max_offset*/, uint64_t max_next, const Config &/*config*/) {
return util::RequiredBits(max_next);
@ -67,7 +67,7 @@ class ArrayBhiksha {
static void UpdateConfigFromBinary(int fd, Config &config);
static std::size_t Size(uint64_t max_offset, uint64_t max_next, const Config &config);
static uint64_t Size(uint64_t max_offset, uint64_t max_next, const Config &config);
static uint8_t InlineBits(uint64_t max_offset, uint64_t max_next, const Config &config);

View File

@ -200,10 +200,10 @@ void SeekPastHeader(int fd, const Parameters &params) {
util::SeekOrThrow(fd, TotalHeaderSize(params.counts.size()));
}
uint8_t *SetupBinary(const Config &config, const Parameters &params, std::size_t memory_size, Backing &backing) {
uint8_t *SetupBinary(const Config &config, const Parameters &params, uint64_t memory_size, Backing &backing) {
const uint64_t file_size = util::SizeFile(backing.file.get());
// The header is smaller than a page, so we have to map the whole header as well.
std::size_t total_map = TotalHeaderSize(params.counts.size()) + memory_size;
std::size_t total_map = util::CheckOverflow(TotalHeaderSize(params.counts.size()) + memory_size);
if (file_size != util::kBadSize && static_cast<uint64_t>(file_size) < total_map)
UTIL_THROW(FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map);

View File

@ -70,7 +70,7 @@ void MatchCheck(ModelType model_type, unsigned int search_version, const Paramet
void SeekPastHeader(int fd, const Parameters &params);
uint8_t *SetupBinary(const Config &config, const Parameters &params, std::size_t memory_size, Backing &backing);
uint8_t *SetupBinary(const Config &config, const Parameters &params, uint64_t memory_size, Backing &backing);
void ComplainAboutARPA(const Config &config, ModelType model_type);
@ -90,7 +90,7 @@ template <class To> void LoadLM(const char *file, const Config &config, To &to)
new_config.probing_multiplier = params.fixed.probing_multiplier;
detail::SeekPastHeader(backing.file.get(), params);
To::UpdateConfigFromBinary(backing.file.get(), params.counts, new_config);
std::size_t memory_size = To::Size(params.counts, new_config);
uint64_t memory_size = To::Size(params.counts, new_config);
uint8_t *start = detail::SetupBinary(new_config, params, memory_size, backing);
to.InitializeFromBinary(start, params, new_config, backing.file.get());
} else {

View File

@ -11,6 +11,8 @@
#ifdef WIN32
#include "util/getopt.hh"
#else
#include <unistd.h>
#endif
namespace lm {
@ -85,16 +87,16 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) {
std::vector<uint64_t> counts;
util::FilePiece f(file);
lm::ReadARPACounts(f, counts);
std::size_t sizes[6];
uint64_t sizes[6];
sizes[0] = ProbingModel::Size(counts, config);
sizes[1] = RestProbingModel::Size(counts, config);
sizes[2] = TrieModel::Size(counts, config);
sizes[3] = QuantTrieModel::Size(counts, config);
sizes[4] = ArrayTrieModel::Size(counts, config);
sizes[5] = QuantArrayTrieModel::Size(counts, config);
std::size_t max_length = *std::max_element(sizes, sizes + sizeof(sizes) / sizeof(size_t));
std::size_t min_length = *std::min_element(sizes, sizes + sizeof(sizes) / sizeof(size_t));
std::size_t divide;
uint64_t max_length = *std::max_element(sizes, sizes + sizeof(sizes) / sizeof(uint64_t));
uint64_t min_length = *std::min_element(sizes, sizes + sizeof(sizes) / sizeof(uint64_t));
uint64_t divide;
char prefix;
if (min_length < (1 << 10) * 10) {
prefix = ' ';

View File

@ -38,6 +38,7 @@
#ifndef LM_LEFT__
#define LM_LEFT__
#include "lm/max_order.hh"
#include "lm/state.hh"
#include "lm/return.hh"

View File

@ -1,3 +1,4 @@
#include "lm/max_order.hh"
#include <iostream>
int main(int argc, char *argv[]) {

12
lm/max_order.hh Normal file
View File

@ -0,0 +1,12 @@
/* IF YOUR BUILD SYSTEM PASSES -DKENLM_MAX_ORDER, THEN CHANGE THE BUILD SYSTEM.
* If not, this is the default maximum order.
* Having this limit means that State can be
* (kMaxOrder - 1) * sizeof(float) bytes instead of
* sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead
*/
#ifndef KENLM_MAX_ORDER
#define KENLM_MAX_ORDER 6
#endif
#ifndef KENLM_ORDER_MESSAGE
#define KENLM_ORDER_MESSAGE "Recompile with e.g. `bjam --kenlm-max-order=6 -a' to change the maximum order."
#endif

View File

@ -5,12 +5,14 @@
#include "lm/search_hashed.hh"
#include "lm/search_trie.hh"
#include "lm/read_arpa.hh"
#include "util/have.hh"
#include "util/murmur_hash.hh"
#include <algorithm>
#include <functional>
#include <numeric>
#include <cmath>
#include <limits>
namespace lm {
namespace ngram {
@ -18,17 +20,18 @@ namespace detail {
template <class Search, class VocabularyT> const ModelType GenericModel<Search, VocabularyT>::kModelType = Search::kModelType;
template <class Search, class VocabularyT> size_t GenericModel<Search, VocabularyT>::Size(const std::vector<uint64_t> &counts, const Config &config) {
template <class Search, class VocabularyT> uint64_t GenericModel<Search, VocabularyT>::Size(const std::vector<uint64_t> &counts, const Config &config) {
return VocabularyT::Size(counts[0], config) + Search::Size(counts, config);
}
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::SetupMemory(void *base, const std::vector<uint64_t> &counts, const Config &config) {
size_t goal_size = util::CheckOverflow(Size(counts, config));
uint8_t *start = static_cast<uint8_t*>(base);
size_t allocated = VocabularyT::Size(counts[0], config);
vocab_.SetupMemory(start, allocated, counts[0], config);
start += allocated;
start = search_.SetupMemory(start, counts, config);
if (static_cast<std::size_t>(start - static_cast<uint8_t*>(base)) != Size(counts, config)) UTIL_THROW(FormatLoadException, "The data structures took " << (start - static_cast<uint8_t*>(base)) << " but Size says they should take " << Size(counts, config));
if (static_cast<std::size_t>(start - static_cast<uint8_t*>(base)) != goal_size) UTIL_THROW(FormatLoadException, "The data structures took " << (start - static_cast<uint8_t*>(base)) << " but Size says they should take " << goal_size);
}
template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &config) {
@ -47,8 +50,19 @@ template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::Ge
P::Init(begin_sentence, null_context, vocab_, search_.Order());
}
namespace {
void CheckCounts(const std::vector<uint64_t> &counts) {
UTIL_THROW_IF(counts.size() > KENLM_MAX_ORDER, FormatLoadException, "This model has order " << counts.size() << " but KenLM was compiled to support up to " << KENLM_MAX_ORDER << ". " << KENLM_ORDER_MESSAGE);
if (sizeof(uint64_t) > sizeof(std::size_t)) {
for (std::vector<uint64_t>::const_iterator i = counts.begin(); i != counts.end(); ++i) {
UTIL_THROW_IF(*i > static_cast<uint64_t>(std::numeric_limits<size_t>::max()), util::OverflowException, "This model has " << *i << " " << (i - counts.begin() + 1) << "-grams which is too many for 32-bit machines.");
}
}
}
} // namespace
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromBinary(void *start, const Parameters &params, const Config &config, int fd) {
UTIL_THROW_IF(params.counts.size() > KENLM_MAX_ORDER, FormatLoadException, "This model has order " << params.counts.size() << ". Re-compile (use -a), passing a number at least this large to bjam's --max-kenlm-order flag.");
CheckCounts(params.counts);
SetupMemory(start, params.counts, config);
vocab_.LoadedBinary(params.fixed.has_vocabulary, fd, config.enumerate_vocab);
search_.LoadedBinary();
@ -61,12 +75,11 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
std::vector<uint64_t> counts;
// File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_.
ReadARPACounts(f, counts);
UTIL_THROW_IF(counts.size() > KENLM_MAX_ORDER, FormatLoadException, "This model has order " << counts.size() << ". Re-compile (use -a), passing a number at least this large to bjam's --max-kenlm-order flag.");
CheckCounts(counts);
if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model.");
if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0");
std::size_t vocab_size = VocabularyT::Size(counts[0], config);
std::size_t vocab_size = util::CheckOverflow(VocabularyT::Size(counts[0], config));
// Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs.
vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config);

View File

@ -41,7 +41,7 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod
* does not include small non-mapped control structures, such as this class
* itself.
*/
static size_t Size(const std::vector<uint64_t> &counts, const Config &config = Config());
static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config = Config());
/* Load the model from a file. It may be an ARPA or binary file. Binary
* files must have the format expected by this class or you'll get an

View File

@ -3,6 +3,7 @@
#include "lm/blank.hh"
#include "lm/config.hh"
#include "lm/max_order.hh"
#include "lm/model_type.hh"
#include "util/bit_packing.hh"
@ -23,7 +24,7 @@ class DontQuantize {
public:
static const ModelType kModelTypeAdd = static_cast<ModelType>(0);
static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {}
static std::size_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; }
static uint64_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; }
static uint8_t MiddleBits(const Config &/*config*/) { return 63; }
static uint8_t LongestBits(const Config &/*config*/) { return 31; }
@ -137,9 +138,9 @@ class SeparatelyQuantize {
static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config);
static std::size_t Size(uint8_t order, const Config &config) {
size_t longest_table = (static_cast<size_t>(1) << static_cast<size_t>(config.prob_bits)) * sizeof(float);
size_t middle_table = (static_cast<size_t>(1) << static_cast<size_t>(config.backoff_bits)) * sizeof(float) + longest_table;
static uint64_t Size(uint8_t order, const Config &config) {
uint64_t longest_table = (static_cast<uint64_t>(1) << static_cast<uint64_t>(config.prob_bits)) * sizeof(float);
uint64_t middle_table = (static_cast<uint64_t>(1) << static_cast<uint64_t>(config.backoff_bits)) * sizeof(float) + longest_table;
// unigrams are currently not quantized so no need for a table.
return (order - 2) * middle_table + longest_table + /* for the bit counts and alignment padding) */ 8;
}

View File

@ -2,12 +2,13 @@
#include "lm/blank.hh"
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <sstream>
#include <vector>
#include <ctype.h>
#include <math.h>
#include <string.h>
#include <stdint.h>
@ -31,6 +32,15 @@ bool IsEntirelyWhiteSpace(const StringPiece &line) {
const char kBinaryMagic[] = "mmap lm http://kheafield.com/code";
// strtoull isn't portable enough :-(
uint64_t ReadCount(const std::string &from) {
std::stringstream stream(from);
uint64_t ret;
stream >> ret;
UTIL_THROW_IF(!stream, FormatLoadException, "Bad count " << from);
return ret;
}
} // namespace
void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) {
@ -52,15 +62,11 @@ void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) {
// So strtol doesn't go off the end of line.
std::string remaining(line.data() + 6, line.size() - 6);
char *end_ptr;
unsigned long int length = std::strtol(remaining.c_str(), &end_ptr, 10);
unsigned int length = std::strtol(remaining.c_str(), &end_ptr, 10);
if ((end_ptr == remaining.c_str()) || (length - 1 != number.size())) UTIL_THROW(FormatLoadException, "ngram count lengths should be consecutive starting with 1: " << line);
if (*end_ptr != '=') UTIL_THROW(FormatLoadException, "Expected = immediately following the first number in the count line " << line);
++end_ptr;
const char *start = end_ptr;
long int count = std::strtol(start, &end_ptr, 10);
if (count < 0) UTIL_THROW(FormatLoadException, "Negative n-gram count " << count);
if (start == end_ptr) UTIL_THROW(FormatLoadException, "Couldn't parse n-gram count from " << line);
number.push_back(count);
number.push_back(ReadCount(end_ptr));
}
}
@ -103,7 +109,7 @@ void ReadBackoff(util::FilePiece &in, float &backoff) {
int float_class = _fpclass(backoff);
UTIL_THROW_IF(float_class == _FPCLASS_SNAN || float_class == _FPCLASS_QNAN || float_class == _FPCLASS_NINF || float_class == _FPCLASS_PINF, FormatLoadException, "Bad backoff " << backoff);
#else
int float_class = fpclassify(backoff);
int float_class = std::fpclassify(backoff);
UTIL_THROW_IF(float_class == FP_NAN || float_class == FP_INFINITE, FormatLoadException, "Bad backoff " << backoff);
#endif
}

View File

@ -74,8 +74,8 @@ template <class Value> class HashedSearch {
// TODO: move probing_multiplier here with next binary file format update.
static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {}
static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) {
std::size_t ret = Unigram::Size(counts[0]);
static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) {
uint64_t ret = Unigram::Size(counts[0]);
for (unsigned char n = 1; n < counts.size() - 1; ++n) {
ret += Middle::Size(counts[n], config.probing_multiplier);
}
@ -160,7 +160,7 @@ template <class Value> class HashedSearch {
#endif
{}
static std::size_t Size(uint64_t count) {
static uint64_t Size(uint64_t count) {
return (count + 1) * sizeof(ProbBackoff); // +1 for hallucinate <unk>
}

View File

@ -5,6 +5,7 @@
#include "lm/binary_format.hh"
#include "lm/blank.hh"
#include "lm/lm_exception.hh"
#include "lm/max_order.hh"
#include "lm/quantize.hh"
#include "lm/trie.hh"
#include "lm/trie_sort.hh"
@ -88,7 +89,7 @@ class BackoffMessages {
if (!HasExtension(weights.backoff)) {
weights.backoff = kExtensionBackoff;
UTIL_THROW_IF(fseek(unigrams, -sizeof(weights), SEEK_CUR), util::ErrnoException, "Seeking backwards to denote unigram extension failed.");
WriteOrThrow(unigrams, &weights, sizeof(weights));
util::WriteOrThrow(unigrams, &weights, sizeof(weights));
}
const ProbPointer &write_to = *reinterpret_cast<const ProbPointer*>(current_ + sizeof(WordIndex));
base[write_to.array][write_to.index] += weights.backoff;

View File

@ -44,8 +44,8 @@ template <class Quant, class Bhiksha> class TrieSearch {
Bhiksha::UpdateConfigFromBinary(fd, config);
}
static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) {
std::size_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]);
static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) {
uint64_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]);
for (unsigned char i = 1; i < counts.size() - 1; ++i) {
ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1], config);
}

View File

@ -1,6 +1,7 @@
#ifndef LM_STATE__
#define LM_STATE__
#include "lm/max_order.hh"
#include "lm/word_index.hh"
#include "util/murmur_hash.hh"

View File

@ -36,7 +36,7 @@ bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_
}
} // namespace
std::size_t BitPacked::BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits) {
uint64_t BitPacked::BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits) {
uint8_t total_bits = util::RequiredBits(max_vocab) + remaining_bits;
// Extra entry for next pointer at the end.
// +7 then / 8 to round up bits and convert to bytes
@ -57,7 +57,7 @@ void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits)
max_vocab_ = max_vocab;
}
template <class Bhiksha> std::size_t BitPackedMiddle<Bhiksha>::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) {
template <class Bhiksha> uint64_t BitPackedMiddle<Bhiksha>::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) {
return Bhiksha::Size(entries + 1, max_ptr, config) + BaseSize(entries, max_vocab, quant_bits + Bhiksha::InlineBits(entries + 1, max_ptr, config));
}

View File

@ -49,7 +49,7 @@ class Unigram {
unigram_ = static_cast<UnigramValue*>(start);
}
static std::size_t Size(uint64_t count) {
static uint64_t Size(uint64_t count) {
// +1 in case unknown doesn't appear. +1 for the final next.
return (count + 2) * sizeof(UnigramValue);
}
@ -84,7 +84,7 @@ class BitPacked {
}
protected:
static std::size_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits);
static uint64_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits);
void BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits);
@ -99,7 +99,7 @@ class BitPacked {
template <class Bhiksha> class BitPackedMiddle : public BitPacked {
public:
static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const Config &config);
static uint64_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const Config &config);
// next_source need not be initialized.
BitPackedMiddle(void *base, uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config);
@ -128,7 +128,7 @@ template <class Bhiksha> class BitPackedMiddle : public BitPacked {
class BitPackedLongest : public BitPacked {
public:
static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) {
static uint64_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) {
return BaseSize(entries, max_vocab, quant_bits);
}

View File

@ -22,12 +22,6 @@
namespace lm {
namespace ngram {
namespace trie {
void WriteOrThrow(FILE *to, const void *data, size_t size) {
assert(size);
if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size);
}
namespace {
typedef util::SizedIterator NGramIter;
@ -95,12 +89,12 @@ FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const util::TempMaker &make
// Write out to file and uniqueify at the same time. Could have used unique_copy if there was an appropriate OutputIterator.
if (context_begin == context_end) return out.release();
PartialIter i(context_begin);
WriteOrThrow(out.get(), i->Data(), context_size);
util::WriteOrThrow(out.get(), i->Data(), context_size);
const void *previous = i->Data();
++i;
for (; i != context_end; ++i) {
if (memcmp(previous, i->Data(), context_size)) {
WriteOrThrow(out.get(), i->Data(), context_size);
util::WriteOrThrow(out.get(), i->Data(), context_size);
previous = i->Data();
}
}
@ -116,7 +110,7 @@ struct ThrowCombine {
// Useful for context files that just contain records with no value.
struct FirstCombine {
void operator()(std::size_t entry_size, const void *first, const void * /*second*/, FILE *out) const {
WriteOrThrow(out, first, entry_size);
util::WriteOrThrow(out, first, entry_size);
}
};
@ -129,10 +123,10 @@ template <class Combine> FILE *MergeSortedFiles(FILE *first_file, FILE *second_f
EntryCompare less(order);
while (first && second) {
if (less(first.Data(), second.Data())) {
WriteOrThrow(out_file.get(), first.Data(), entry_size);
util::WriteOrThrow(out_file.get(), first.Data(), entry_size);
++first;
} else if (less(second.Data(), first.Data())) {
WriteOrThrow(out_file.get(), second.Data(), entry_size);
util::WriteOrThrow(out_file.get(), second.Data(), entry_size);
++second;
} else {
combine(entry_size, first.Data(), second.Data(), out_file.get());
@ -140,7 +134,7 @@ template <class Combine> FILE *MergeSortedFiles(FILE *first_file, FILE *second_f
}
}
for (RecordReader &remains = (first ? first : second); remains; ++remains) {
WriteOrThrow(out_file.get(), remains.Data(), entry_size);
util::WriteOrThrow(out_file.get(), remains.Data(), entry_size);
}
return out_file.release();
}
@ -164,7 +158,7 @@ void RecordReader::Init(FILE *file, std::size_t entry_size) {
void RecordReader::Overwrite(const void *start, std::size_t amount) {
long internal = (uint8_t*)start - (uint8_t*)data_.get();
UTIL_THROW_IF(fseek(file_, internal - entry_size_, SEEK_CUR), util::ErrnoException, "Couldn't seek backwards for revision");
WriteOrThrow(file_, start, amount);
util::WriteOrThrow(file_, start, amount);
long forward = entry_size_ - internal - amount;
#if !defined(_WIN32) && !defined(_WIN64)
if (forward)

View File

@ -3,6 +3,7 @@
#ifndef LM_TRIE_SORT__
#define LM_TRIE_SORT__
#include "lm/max_order.hh"
#include "lm/word_index.hh"
#include "util/file.hh"
@ -28,8 +29,6 @@ struct Config;
namespace trie {
void WriteOrThrow(FILE *to, const void *data, size_t size);
class EntryCompare : public std::binary_function<const void*, const void*, bool> {
public:
explicit EntryCompare(unsigned char order) : order_(order) {}

View File

@ -87,7 +87,7 @@ void WriteWordsWrapper::Write(int fd) {
SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL) {}
std::size_t SortedVocabulary::Size(std::size_t entries, const Config &/*config*/) {
uint64_t SortedVocabulary::Size(uint64_t entries, const Config &/*config*/) {
// Lead with the number of entries.
return sizeof(uint64_t) + sizeof(uint64_t) * entries;
}
@ -165,7 +165,7 @@ struct ProbingVocabularyHeader {
ProbingVocabulary::ProbingVocabulary() : enumerate_(NULL) {}
std::size_t ProbingVocabulary::Size(std::size_t entries, const Config &config) {
uint64_t ProbingVocabulary::Size(uint64_t entries, const Config &config) {
return ALIGN8(sizeof(detail::ProbingVocabularyHeader)) + Lookup::Size(entries, config.probing_multiplier);
}

View File

@ -62,7 +62,7 @@ class SortedVocabulary : public base::Vocabulary {
}
// Size for purposes of file writing
static size_t Size(std::size_t entries, const Config &config);
static uint64_t Size(uint64_t entries, const Config &config);
// Vocab words are [0, Bound()) Only valid after FinishedLoading/LoadedBinary.
WordIndex Bound() const { return bound_; }
@ -129,7 +129,7 @@ class ProbingVocabulary : public base::Vocabulary {
return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0;
}
static size_t Size(std::size_t entries, const Config &config);
static uint64_t Size(uint64_t entries, const Config &config);
// Vocab words are [0, Bound()).
WordIndex Bound() const { return bound_; }

View File

@ -9,16 +9,16 @@ namespace util {
namespace { const unsigned char kWidth = 100; }
ErsatzProgress::ErsatzProgress() : current_(0), next_(std::numeric_limits<std::size_t>::max()), complete_(next_), out_(NULL) {}
ErsatzProgress::ErsatzProgress() : current_(0), next_(std::numeric_limits<uint64_t>::max()), complete_(next_), out_(NULL) {}
ErsatzProgress::~ErsatzProgress() {
if (out_) Finished();
}
ErsatzProgress::ErsatzProgress(std::size_t complete, std::ostream *to, const std::string &message)
ErsatzProgress::ErsatzProgress(uint64_t complete, std::ostream *to, const std::string &message)
: current_(0), next_(complete / kWidth), complete_(complete), stones_written_(0), out_(to) {
if (!out_) {
next_ = std::numeric_limits<std::size_t>::max();
next_ = std::numeric_limits<uint64_t>::max();
return;
}
if (!message.empty()) *out_ << message << '\n';
@ -28,14 +28,14 @@ ErsatzProgress::ErsatzProgress(std::size_t complete, std::ostream *to, const std
void ErsatzProgress::Milestone() {
if (!out_) { current_ = 0; return; }
if (!complete_) return;
unsigned char stone = std::min(static_cast<std::size_t>(kWidth), (current_ * kWidth) / complete_);
unsigned char stone = std::min(static_cast<uint64_t>(kWidth), (current_ * kWidth) / complete_);
for (; stones_written_ < stone; ++stones_written_) {
(*out_) << '*';
}
if (stone == kWidth) {
(*out_) << std::endl;
next_ = std::numeric_limits<std::size_t>::max();
next_ = std::numeric_limits<uint64_t>::max();
out_ = NULL;
} else {
next_ = std::max(next_, (stone * complete_) / kWidth);

View File

@ -4,6 +4,8 @@
#include <iostream>
#include <string>
#include <inttypes.h>
// Ersatz version of boost::progress so core language model doesn't depend on
// boost. Also adds option to print nothing.
@ -14,7 +16,7 @@ class ErsatzProgress {
ErsatzProgress();
// Null means no output. The null value is useful for passing along the ostream pointer from another caller.
explicit ErsatzProgress(std::size_t complete, std::ostream *to = &std::cerr, const std::string &message = "");
explicit ErsatzProgress(uint64_t complete, std::ostream *to = &std::cerr, const std::string &message = "");
~ErsatzProgress();
@ -23,12 +25,12 @@ class ErsatzProgress {
return *this;
}
ErsatzProgress &operator+=(std::size_t amount) {
ErsatzProgress &operator+=(uint64_t amount) {
if ((current_ += amount) >= next_) Milestone();
return *this;
}
void Set(std::size_t to) {
void Set(uint64_t to) {
if ((current_ = to) >= next_) Milestone();
Milestone();
}
@ -40,7 +42,7 @@ class ErsatzProgress {
private:
void Milestone();
std::size_t current_, next_, complete_;
uint64_t current_, next_, complete_;
unsigned char stones_written_;
std::ostream *out_;

View File

@ -84,4 +84,7 @@ EndOfFileException::EndOfFileException() throw() {
}
EndOfFileException::~EndOfFileException() throw() {}
OverflowException::OverflowException() throw() {}
OverflowException::~OverflowException() throw() {}
} // namespace util

View File

@ -2,9 +2,12 @@
#define UTIL_EXCEPTION__
#include <exception>
#include <limits>
#include <sstream>
#include <string>
#include <inttypes.h>
namespace util {
template <class Except, class Data> typename Except::template ExceptionTag<Except&>::Identity operator<<(Except &e, const Data &data);
@ -111,6 +114,25 @@ class EndOfFileException : public Exception {
~EndOfFileException() throw();
};
class OverflowException : public Exception {
public:
OverflowException() throw();
~OverflowException() throw();
};
template <unsigned len> inline std::size_t CheckOverflowInternal(uint64_t value) {
UTIL_THROW_IF(value > static_cast<uint64_t>(std::numeric_limits<std::size_t>::max()), OverflowException, "Integer overflow detected. This model is too big for 32-bit code.");
return value;
}
template <> inline std::size_t CheckOverflowInternal<8>(uint64_t value) {
return value;
}
inline std::size_t CheckOverflow(uint64_t value) {
return CheckOverflowInternal<sizeof(std::size_t)>(value);
}
} // namespace util
#endif // UTIL_EXCEPTION__

View File

@ -6,6 +6,7 @@
#include <cstdio>
#include <iostream>
#include <assert.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>
@ -111,6 +112,11 @@ void WriteOrThrow(int fd, const void *data_void, std::size_t size) {
}
}
void WriteOrThrow(FILE *to, const void *data, std::size_t size) {
assert(size);
if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size);
}
void FSyncOrThrow(int fd) {
// Apparently windows doesn't have fsync?
#if !defined(_WIN32) && !defined(_WIN64)
@ -119,8 +125,13 @@ void FSyncOrThrow(int fd) {
}
namespace {
void InternalSeek(int fd, off_t off, int whence) {
void InternalSeek(int fd, int64_t off, int whence) {
#if defined(_WIN32) || defined(_WIN64)
UTIL_THROW_IF((__int64)-1 == _lseeki64(fd, off, whence), ErrnoException, "Windows seek failed");
#else
UTIL_THROW_IF((off_t)-1 == lseek(fd, off, whence), ErrnoException, "Seek failed");
#endif
}
} // namespace
@ -143,6 +154,12 @@ std::FILE *FDOpenOrThrow(scoped_fd &file) {
return ret;
}
std::FILE *FOpenOrThrow(const char *path, const char *mode) {
std::FILE *ret;
UTIL_THROW_IF(!(ret = fopen(path, mode)), util::ErrnoException, "Could not fopen " << path << " for " << mode);
return ret;
}
TempMaker::TempMaker(const std::string &prefix) : base_(prefix) {
base_ += "XXXXXX";
}

View File

@ -80,6 +80,7 @@ void ReadOrThrow(int fd, void *to, std::size_t size);
std::size_t ReadOrEOF(int fd, void *to_void, std::size_t amount);
void WriteOrThrow(int fd, const void *data_void, std::size_t size);
void WriteOrThrow(FILE *to, const void *data, std::size_t size);
void FSyncOrThrow(int fd);
@ -90,6 +91,8 @@ void SeekEnd(int fd);
std::FILE *FDOpenOrThrow(scoped_fd &file);
std::FILE *FOpenOrThrow(const char *path, const char *mode);
class TempMaker {
public:
explicit TempMaker(const std::string &prefix);

View File

@ -5,6 +5,8 @@
#include "util/mmap.hh"
#ifdef WIN32
#include <io.h>
#else
#include <unistd.h>
#endif // WIN32
#include <iostream>

View File

@ -8,6 +8,7 @@
#include <functional>
#include <assert.h>
#include <inttypes.h>
namespace util {
@ -42,8 +43,8 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
typedef EqualT Equal;
public:
static std::size_t Size(std::size_t entries, float multiplier) {
std::size_t buckets = std::max(entries + 1, static_cast<std::size_t>(multiplier * static_cast<float>(entries)));
static uint64_t Size(uint64_t entries, float multiplier) {
uint64_t buckets = std::max(entries + 1, static_cast<uint64_t>(multiplier * static_cast<float>(entries)));
return buckets * sizeof(Entry);
}