KenLM 6f7913cc7ca0f7672c6d899358365f047a742bbb

Mostly fixes from Tetsuo Kiso and Jonathan Graehl
This commit is contained in:
Kenneth Heafield 2013-09-09 10:43:20 +01:00
parent 4c015afb7e
commit 5cca6fafcb
10 changed files with 260 additions and 56 deletions

View File

@ -8,6 +8,7 @@
#include <cstring>
#include <limits>
#include <string>
#include <cstdlib>
#include <stdint.h>
@ -169,21 +170,21 @@ bool IsBinaryFormat(int fd) {
}
Sanity reference_header = Sanity();
reference_header.SetToReference();
if (!memcmp(memory.get(), &reference_header, sizeof(Sanity))) return true;
if (!memcmp(memory.get(), kMagicIncomplete, strlen(kMagicIncomplete))) {
if (!std::memcmp(memory.get(), &reference_header, sizeof(Sanity))) return true;
if (!std::memcmp(memory.get(), kMagicIncomplete, strlen(kMagicIncomplete))) {
UTIL_THROW(FormatLoadException, "This binary file did not finish building");
}
if (!memcmp(memory.get(), kMagicBeforeVersion, strlen(kMagicBeforeVersion))) {
if (!std::memcmp(memory.get(), kMagicBeforeVersion, strlen(kMagicBeforeVersion))) {
char *end_ptr;
const char *begin_version = static_cast<const char*>(memory.get()) + strlen(kMagicBeforeVersion);
long int version = strtol(begin_version, &end_ptr, 10);
long int version = std::strtol(begin_version, &end_ptr, 10);
if ((end_ptr != begin_version) && version != kMagicVersion) {
UTIL_THROW(FormatLoadException, "Binary file has version " << version << " but this implementation expects version " << kMagicVersion << " so you'll have to use the ARPA to rebuild your binary");
}
OldSanity old_sanity = OldSanity();
old_sanity.SetToReference();
UTIL_THROW_IF(!memcmp(memory.get(), &old_sanity, sizeof(OldSanity)), FormatLoadException, "Looks like this is an old 32-bit format. The old 32-bit format has been removed so that 64-bit and 32-bit files are exchangeable.");
UTIL_THROW_IF(!std::memcmp(memory.get(), &old_sanity, sizeof(OldSanity)), FormatLoadException, "Looks like this is an old 32-bit format. The old 32-bit format has been removed so that 64-bit and 32-bit files are exchangeable.");
UTIL_THROW(FormatLoadException, "File looks like it should be loaded with mmap, but the test values don't match. Try rebuilding the binary format LM using the same code revision, compiler, and architecture");
}
return false;

View File

@ -186,6 +186,7 @@ int main(int argc, char *argv[]) {
config.write_mmap = argv[optind + 2];
} else {
Usage(argv[0], default_mem);
return 1;
}
if (!strcmp(model_type, "probing")) {
if (!set_write_method) config.write_method = Config::WRITE_AFTER;

View File

@ -48,21 +48,21 @@ unsigned int ReadMultiple(std::istream &in, Substrings &out) {
return sentence_id + sentence_content;
}
namespace detail { const StringPiece kEndSentence("</s>"); }
namespace {
typedef unsigned int Sentence;
typedef std::vector<Sentence> Sentences;
} // namespace
class Vertex;
namespace detail {
const StringPiece kEndSentence("</s>");
class Arc {
public:
Arc() {}
// For arcs from one vertex to another.
void SetPhrase(Vertex &from, Vertex &to, const Sentences &intersect) {
void SetPhrase(detail::Vertex &from, detail::Vertex &to, const Sentences &intersect) {
Set(to, intersect);
from_ = &from;
}
@ -71,7 +71,7 @@ class Arc {
* aligned). These have no from_ vertex; it implicitly matches every
* sentence. This also handles when the n-gram is a substring of a phrase.
*/
void SetRight(Vertex &to, const Sentences &complete) {
void SetRight(detail::Vertex &to, const Sentences &complete) {
Set(to, complete);
from_ = NULL;
}
@ -97,11 +97,11 @@ class Arc {
void LowerBound(const Sentence to);
private:
void Set(Vertex &to, const Sentences &sentences);
void Set(detail::Vertex &to, const Sentences &sentences);
const Sentence *current_;
const Sentence *last_;
Vertex *from_;
detail::Vertex *from_;
};
struct ArcGreater : public std::binary_function<const Arc *, const Arc *, bool> {
@ -183,7 +183,13 @@ void Vertex::LowerBound(const Sentence to) {
}
}
void BuildGraph(const Substrings &phrase, const std::vector<Hash> &hashes, Vertex *const vertices, Arc *free_arc) {
} // namespace detail
namespace {
void BuildGraph(const Substrings &phrase, const std::vector<Hash> &hashes, detail::Vertex *const vertices, detail::Arc *free_arc) {
using detail::Vertex;
using detail::Arc;
assert(!hashes.empty());
const Hash *const first_word = &*hashes.begin();
@ -231,17 +237,29 @@ void BuildGraph(const Substrings &phrase, const std::vector<Hash> &hashes, Verte
namespace detail {

// The special members live in this translation unit, not the header, because
// the header only forward-declares the element types of vertices_ and arcs_.
ConditionCommon::ConditionCommon(const Substrings &substrings)
  : substrings_(substrings) {}

// Only the substrings reference is carried over; every other member is
// default-constructed here (they are temporaries anyway).
ConditionCommon::ConditionCommon(const ConditionCommon &from)
  : substrings_(from.substrings_) {}

ConditionCommon::~ConditionCommon() {}

// Rebuilds vertices_ and arcs_ for the current contents of hashes_, hands them
// to BuildGraph, and returns the vertex for the last hash.
// Precondition (asserted): hashes_ is not empty.
detail::Vertex &ConditionCommon::MakeGraph() {
  assert(!hashes_.empty());
  const std::vector<Hash>::size_type words = hashes_.size();

  // clear() followed by resize() leaves exactly `words` freshly constructed vertices.
  vertices_.clear();
  vertices_.resize(words);

  // One arc for every substring: words * (words + 1) / 2 of them.
  arcs_.clear();
  arcs_.resize((words * (words + 1)) / 2);

  BuildGraph(substrings_, hashes_, &vertices_[0], &arcs_[0]);
  return vertices_.back();
}

} // namespace detail
bool Union::Evaluate() {
assert(!hashes_.empty());
// Usually there are at most 6 words in an n-gram, so stack allocation is reasonable.
Vertex vertices[hashes_.size()];
// One for every substring.
Arc arcs[((hashes_.size() + 1) * hashes_.size()) / 2];
BuildGraph(substrings_, hashes_, vertices, arcs);
Vertex &last_vertex = vertices[hashes_.size() - 1];
detail::Vertex &last_vertex = MakeGraph();
unsigned int lower = 0;
while (true) {
last_vertex.LowerBound(lower);
@ -252,14 +270,7 @@ bool Union::Evaluate() {
}
template <class Output> void Multiple::Evaluate(const StringPiece &line, Output &output) {
assert(!hashes_.empty());
// Usually there are at most 6 words in an n-gram, so stack allocation is reasonable.
Vertex vertices[hashes_.size()];
// One for every substring.
Arc arcs[((hashes_.size() + 1) * hashes_.size()) / 2];
BuildGraph(substrings_, hashes_, vertices, arcs);
Vertex &last_vertex = vertices[hashes_.size() - 1];
detail::Vertex &last_vertex = MakeGraph();
unsigned int lower = 0;
while (true) {
last_vertex.LowerBound(lower);

View File

@ -103,11 +103,33 @@ template <class Iterator> void MakeHashes(Iterator i, const Iterator &end, std::
}
}
class Vertex;
class Arc;
// Base class shared by Union and Multiple (both derive from it publicly).
// It holds the Substrings reference plus the scratch vectors that MakeGraph
// refills, so derived classes reuse the same storage from call to call.
// Everything is protected or private: the class is only usable through a
// derived type.
class ConditionCommon {
protected:
// Binds substrings_ to the given object; the reference is stored, not copied.
ConditionCommon(const Substrings &substrings);
// Copies only the substrings_ reference from `from`.
ConditionCommon(const ConditionCommon &from);
// Protected and non-virtual: instances are not destroyed through a
// ConditionCommon pointer from outside the hierarchy.
~ConditionCommon();
// Rebuilds vertices_/arcs_ from hashes_ and returns a reference to an element
// of vertices_ (the one for the last hash). hashes_ must be non-empty.
detail::Vertex &MakeGraph();
// Temporaries in PassNGram and Evaluate to avoid reallocation.
std::vector<Hash> hashes_;
private:
// Scratch storage owned here and filled by MakeGraph.
std::vector<detail::Vertex> vertices_;
std::vector<detail::Arc> arcs_;
// Not owned. NOTE(review): the referenced Substrings presumably has to
// outlive this object -- confirm at the construction sites.
const Substrings &substrings_;
};
} // namespace detail
class Union {
class Union : public detail::ConditionCommon {
public:
explicit Union(const Substrings &substrings) : substrings_(substrings) {}
explicit Union(const Substrings &substrings) : detail::ConditionCommon(substrings) {}
template <class Iterator> bool PassNGram(const Iterator &begin, const Iterator &end) {
detail::MakeHashes(begin, end, hashes_);
@ -116,23 +138,19 @@ class Union {
private:
bool Evaluate();
std::vector<Hash> hashes_;
const Substrings &substrings_;
};
class Multiple {
class Multiple : public detail::ConditionCommon {
public:
explicit Multiple(const Substrings &substrings) : substrings_(substrings) {}
explicit Multiple(const Substrings &substrings) : detail::ConditionCommon(substrings) {}
template <class Iterator, class Output> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) {
detail::MakeHashes(begin, end, hashes_);
if (hashes_.empty()) {
output.AddNGram(line);
return;
} else {
Evaluate(line, output);
}
Evaluate(line, output);
}
template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) {
@ -143,10 +161,6 @@ class Multiple {
private:
template <class Output> void Evaluate(const StringPiece &line, Output &output);
std::vector<Hash> hashes_;
const Substrings &substrings_;
};
} // namespace phrase

View File

@ -0,0 +1,165 @@
#include "util/fake_ofstream.hh"
#include "util/file_piece.hh"
#include "util/murmur_hash.hh"
#include "util/pool.hh"
#include "util/string_piece.hh"
#include "util/string_piece_hash.hh"
#include "util/tokenize_piece.hh"
#include <boost/unordered_map.hpp>
#include <boost/unordered_set.hpp>
#include <algorithm>
#include <cstddef>
#include <cstring>
#include <iostream>
#include <string>
#include <utility>
#include <vector>
#include <stdint.h>
namespace {
// Key type for the hash set inside InternString: a StringPiece wrapper whose
// view can be re-pointed after insertion.
struct MutablePiece {
// mutable because InternString::Add assigns to it through a set iterator,
// i.e. through an element the container exposes as const.
mutable StringPiece behind;
// Two keys are equal exactly when their StringPieces compare equal.
bool operator==(const MutablePiece &other) const {
return behind == other.behind;
}
};
std::size_t hash_value(const MutablePiece &m) {
return hash_value(m.behind);
}
class InternString {
public:
const char *Add(StringPiece str) {
MutablePiece mut;
mut.behind = str;
std::pair<boost::unordered_set<MutablePiece>::iterator, bool> res(strs_.insert(mut));
if (res.second) {
void *mem = backing_.Allocate(str.size() + 1);
memcpy(mem, str.data(), str.size());
static_cast<char*>(mem)[str.size()] = 0;
res.first->behind = StringPiece(static_cast<char*>(mem), str.size());
}
return res.first->behind.data();
}
private:
util::Pool backing_;
boost::unordered_set<MutablePiece> strs_;
};
class TargetWords {
public:
void Introduce(StringPiece source) {
vocab_.resize(vocab_.size() + 1);
std::vector<unsigned int> temp(1, vocab_.size() - 1);
Add(temp, source);
}
void Add(const std::vector<unsigned int> &sentences, StringPiece target) {
if (sentences.empty()) return;
interns_.clear();
for (util::TokenIter<util::SingleCharacter, true> i(target, ' '); i; ++i) {
interns_.push_back(intern_.Add(*i));
}
for (std::vector<unsigned int>::const_iterator i(sentences.begin()); i != sentences.end(); ++i) {
boost::unordered_set<const char *> &vocab = vocab_[*i];
for (std::vector<const char *>::const_iterator j = interns_.begin(); j != interns_.end(); ++j) {
vocab.insert(*j);
}
}
}
void Print() const {
util::FakeOFStream out(1);
for (std::vector<boost::unordered_set<const char *> >::const_iterator i = vocab_.begin(); i != vocab_.end(); ++i) {
for (boost::unordered_set<const char *>::const_iterator j = i->begin(); j != i->end(); ++j) {
out << *j << ' ';
}
out << '\n';
}
}
private:
InternString intern_;
std::vector<boost::unordered_set<const char *> > vocab_;
// Temporary in Add.
std::vector<const char *> interns_;
};
// Index from phrase hash to the ids of the input sentences that contain the
// phrase, for phrases of up to max_length words.
class Input {
  public:
    // max_length: the longest phrase, in words, that gets indexed.
    explicit Input(std::size_t max_length)
      : max_length_(max_length), sentence_id_(0), empty_() {}

    // Registers one sentence: announces it to targets, then records the hash
    // of every phrase of 1..max_length_ consecutive words under the current
    // sentence id.  Words are split on NUL, space and tab; empty tokens are
    // skipped and the words are re-joined with single spaces before hashing.
    void AddSentence(StringPiece sentence, TargetWords &targets) {
      canonical_.clear();
      starts_.clear();
      starts_.push_back(0);
      for (util::TokenIter<util::AnyCharacter, true> i(sentence, StringPiece("\0 \t", 3)); i; ++i) {
        canonical_.append(i->data(), i->size());
        canonical_ += ' ';
        starts_.push_back(canonical_.size());
      }
      targets.Introduce(canonical_);

      // starts_[k] is the offset of word k in canonical_; the final entry is
      // canonical_.size().  Work from data() plus offsets so nothing indexes
      // the string at size() through the non-const operator[].
      const char *const base = canonical_.data();
      const unsigned int id = static_cast<unsigned int>(sentence_id_);
      const std::size_t words = starts_.size() - 1;
      for (std::size_t i = 0; i < words; ++i) {
        // The bound does not depend on j, so compute it once per start word.
        const std::size_t stop = std::min(starts_.size(), i + max_length_ + 1);
        for (std::size_t j = i + 1; j < stop; ++j) {
          // Words i..j-1 without the space that follows the last one.
          const std::size_t length = starts_[j] - starts_[i] - 1;
          std::vector<unsigned int> &ids = map_[util::MurmurHash64A(base + starts_[i], length)];
          // Ids arrive in non-decreasing order, so checking the tail is enough
          // to keep a phrase that repeats inside one sentence from being
          // listed more than once.
          if (ids.empty() || ids.back() != id) ids.push_back(id);
        }
      }
      ++sentence_id_;
    }

    // Assumes single space-delimited phrase with no space at the beginning or end.
    // Returns the ids recorded for that phrase's hash, or an empty vector.
    const std::vector<unsigned int> &Matches(StringPiece phrase) const {
      Map::const_iterator i = map_.find(util::MurmurHash64A(phrase.data(), phrase.size()));
      return i == map_.end() ? empty_ : i->second;
    }

  private:
    const std::size_t max_length_;

    // hash of phrase is the key, array of sentences is the value.
    typedef boost::unordered_map<uint64_t, std::vector<unsigned int> > Map;
    Map map_;

    // Id the next AddSentence call will use.
    std::size_t sentence_id_;

    // Temporaries in AddSentence.
    std::string canonical_;
    std::vector<std::size_t> starts_;

    // Returned by Matches when the phrase is unknown.
    const std::vector<unsigned int> empty_;
};
} // namespace
// Usage: <program> source_text < phrase_table
//
// Reads the source sentences from the file named on the command line, then
// reads "source ||| target ||| ..." lines from the FilePiece opened on
// descriptor 0 (presumably stdin).  For every line whose source phrase occurs
// in some sentence, the target words are added to those sentences; finally one
// line of collected words per sentence is printed.
// Returns 1 on a usage error or on a line with no target field, 0 otherwise.
int main(int argc, char *argv[]) {
  if (argc != 2) {
    std::cerr << "Expected source text on the command line" << std::endl;
    return 1;
  }

  Input input(7);
  TargetWords targets;
  try {
    util::FilePiece inputs(argv[1], &std::cerr);
    // ReadLine signals end of input by throwing EndOfFileException.
    while (true)
      input.AddSentence(inputs.ReadLine(), targets);
  } catch (const util::EndOfFileException &e) {}

  util::FilePiece table(0, NULL, &std::cerr);
  StringPiece line;
  const StringPiece pipes("|||");
  while (true) {
    try {
      line = table.ReadLine();
    } catch (const util::EndOfFileException &e) { break; }

    util::TokenIter<util::MultiCharacter> it(line, pipes);
    StringPiece source(*it);
    // Matches expects no trailing space, but the field before ||| usually has one.
    if (!source.empty() && source[source.size() - 1] == ' ')
      source.remove_suffix(1);

    // A line without a second field leaves the iterator exhausted; it must
    // not be dereferenced in that state.
    ++it;
    if (!it) {
      std::cerr << "Phrase table line has no target field after |||: ";
      std::cerr.write(line.data(), line.size());
      std::cerr << std::endl;
      return 1;
    }
    targets.Add(input.Matches(source), *it);
  }
  targets.Print();
  return 0;
}

View File

@ -19,7 +19,7 @@
namespace lm {
// 1 for '\t', '\n', and ' '. This is stricter than isspace.
// 1 for '\t', '\n', and ' '. This is stricter than isspace.
const bool kARPASpaces[256] = {0,0,0,0,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
namespace {
@ -50,7 +50,7 @@ void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) {
// In general, ARPA files can have arbitrary text before "\data\"
// But in KenLM, we require such lines to start with "#", so that
// we can do stricter error checking
while (IsEntirelyWhiteSpace(line) || line.starts_with("#")) {
while (IsEntirelyWhiteSpace(line) || starts_with(line, "#")) {
line = in.ReadLine();
}
@ -58,7 +58,7 @@ void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) {
if ((line.size() >= 2) && (line.data()[0] == 0x1f) && (static_cast<unsigned char>(line.data()[1]) == 0x8b)) {
UTIL_THROW(FormatLoadException, "Looks like a gzip file. If this is an ARPA file, pipe " << in.FileName() << " through zcat. If this already in binary format, you need to decompress it because mmap doesn't work on top of gzip.");
}
if (static_cast<size_t>(line.size()) >= strlen(kBinaryMagic) && StringPiece(line.data(), strlen(kBinaryMagic)) == kBinaryMagic)
if (static_cast<size_t>(line.size()) >= strlen(kBinaryMagic) && StringPiece(line.data(), strlen(kBinaryMagic)) == kBinaryMagic)
UTIL_THROW(FormatLoadException, "This looks like a binary file but got sent to the ARPA parser. Did you compress the binary file or pass a binary file where only ARPA files are accepted?");
UTIL_THROW_IF(line.size() >= 4 && StringPiece(line.data(), 4) == "blmt", FormatLoadException, "This looks like an IRSTLM binary file. Did you forget to pass --text yes to compile-lm?");
UTIL_THROW_IF(line == "iARPA", FormatLoadException, "This looks like an IRSTLM iARPA file. You need an ARPA file. Run\n compile-lm --text yes " << in.FileName() << " " << in.FileName() << ".arpa\nfirst.");
@ -66,7 +66,7 @@ void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) {
}
while (!IsEntirelyWhiteSpace(line = in.ReadLine())) {
if (line.size() < 6 || strncmp(line.data(), "ngram ", 6)) UTIL_THROW(FormatLoadException, "count line \"" << line << "\"doesn't begin with \"ngram \"");
// So strtol doesn't go off the end of line.
// So strtol doesn't go off the end of line.
std::string remaining(line.data() + 6, line.size() - 6);
char *end_ptr;
unsigned int length = std::strtol(remaining.c_str(), &end_ptr, 10);
@ -102,8 +102,8 @@ void ReadBackoff(util::FilePiece &in, Prob &/*weights*/) {
}
void ReadBackoff(util::FilePiece &in, float &backoff) {
// Always make zero negative.
// Negative zero means that no (n+1)-gram has this n-gram as context.
// Always make zero negative.
// Negative zero means that no (n+1)-gram has this n-gram as context.
// Therefore the hypothesis state can be shorter. Of course, many n-grams
// are context for (n+1)-grams. An algorithm in the data structure will go
// back and set the backoff to positive zero in these cases.

View File

@ -11,6 +11,7 @@
#include "util/file_piece.hh"
#include <vector>
#include <cstdlib>
#include <assert.h>
@ -104,12 +105,12 @@ template <class Quant, class Bhiksha> class TrieSearch {
private:
friend void BuildTrie<Quant, Bhiksha>(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing);
// Middles are managed manually so we can delay construction and they don't have to be copyable.
// Middles are managed manually so we can delay construction and they don't have to be copyable.
void FreeMiddles() {
for (const Middle *i = middle_begin_; i != middle_end_; ++i) {
i->~Middle();
}
free(middle_begin_);
std::free(middle_begin_);
}
typedef trie::BitPackedMiddle<Bhiksha> Middle;

View File

@ -9,6 +9,7 @@ namespace ngram {
template <class Model> LowerRestBuild<Model>::LowerRestBuild(const Config &config, unsigned int order, const typename Model::Vocabulary &vocab) {
UTIL_THROW_IF(config.rest_lower_files.size() != order - 1, ConfigException, "This model has order " << order << " so there should be " << (order - 1) << " lower-order models for rest cost purposes.");
Config for_lower = config;
for_lower.write_mmap = NULL;
for_lower.rest_lower_files.clear();
// Unigram models aren't supported, so this is a custom loader.

View File

@ -74,6 +74,12 @@ inline bool operator!=(const StringPiece& x, const StringPiece& y) {
#endif // old version of ICU
U_NAMESPACE_BEGIN
inline bool starts_with(const StringPiece& longer, const StringPiece& prefix) {
int longersize = longer.size(), prefixsize = prefix.size();
return longersize >= prefixsize && std::memcmp(longer.data(), prefix.data(), prefixsize) == 0;
}
#else
#include <algorithm>
@ -212,7 +218,7 @@ class StringPiece {
StringPiece substr(size_type pos, size_type n = npos) const;
static int wordmemcmp(const char* p, const char* p2, size_type N) {
return memcmp(p, p2, N);
return std::memcmp(p, p2, N);
}
};
@ -227,6 +233,10 @@ inline bool operator!=(const StringPiece& x, const StringPiece& y) {
return !(x == y);
}
// Free-function form of StringPiece::starts_with for the branch where
// HAVE_ICU is undefined; it simply forwards to the member function, giving
// callers one spelling, starts_with(a, b), that works in either branch.
inline bool starts_with(const StringPiece& longer, const StringPiece& prefix) {
return longer.starts_with(prefix);
}
#endif // HAVE_ICU undefined
inline bool operator<(const StringPiece& x, const StringPiece& y) {

View File

@ -83,7 +83,7 @@ void PrintUsage(std::ostream &out) {
}
struct rusage usage;
if (getrusage(RUSAGE_CHILDREN, &usage)) {
if (getrusage(RUSAGE_SELF, &usage)) {
perror("getrusage");
return;
}
@ -135,7 +135,7 @@ template <class Num> uint64_t ParseNum(const std::string &arg) {
if (after == "%") {
uint64_t mem = GuessPhysicalMemory();
UTIL_THROW_IF_ARG(!mem, SizeParseError, (arg), "because % was specified but the physical memory size could not be determined.");
return static_cast<double>(value) * static_cast<double>(mem) / 100.0;
return static_cast<uint64_t>(static_cast<double>(value) * static_cast<double>(mem) / 100.0);
}
std::string units("bKMGTPEZY");
@ -144,7 +144,7 @@ template <class Num> uint64_t ParseNum(const std::string &arg) {
for (std::string::size_type i = 0; i < index; ++i) {
value *= 1024;
}
return value;
return static_cast<uint64_t>(value);
}
} // namespace