kenlm update

mmap works; utility to build binary format included.  
Configuration struct (including unknown handling options). 
config option to build a binary format while loading an ARPA.  
Doesn't require Boost or ICU. 
Works on 32 and 64 bit. 
query appends </s>. 
Reduced memory consumption: 12 bytes per 5-gram instead of 16 bytes on 64-bit machines.  
Reduced memory consumption: vocabulary takes 8 bytes/word instead of 12 bytes/word when the sorted option is used.
Removed some cruft that wasn't needed by this code.  
Compiles on Mac OS X.  
Add script to run tests; these depend on Boost.  
SRI wrapper works again, is slightly faster, no longer depends on Boost, and has a test.
Debugging code is compiled only with -DDEBUG, so the default build is fast.



git-svn-id: https://mosesdecoder.svn.sourceforge.net/svnroot/mosesdecoder/trunk@3447 1f5c12ca-751b-0410-a591-d2e778427230
This commit is contained in:
heafield 2010-09-14 21:33:11 +00:00
parent 128a885406
commit d00c788760
52 changed files with 2141 additions and 1929 deletions

View File

@ -1,3 +1,10 @@
Most of the code here is licensed under the LGPL. There are exceptions which have their own licenses, listed below. See comments in those files for more details.
util/murmur_hash.cc is under the MIT license.
util/string_piece.hh and util/string_piece.cc are Google code and contain their own license.
For the rest:
Avenue code is free software: you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published
by the Free Software Foundation, either version 3 of the License, or
@ -10,8 +17,3 @@
You should have received a copy of the GNU Lesser General Public License
along with Avenue code. If not, see <http://www.gnu.org/licenses/>.
Most of the code here is licensed under the LGPL. There are exceptions which have their own licenses, listed below. You may not have been provided with some of these directories or files.
util/murmur_hash.cc is under the MIT license.
util/string_piece.hh and util/string_piece.cc are Google code and contain their own license.

View File

@ -1,11 +1,31 @@
This is a language model under active development. However, the API is mostly stable.
Language model inference code by Kenneth Heafield <infer at kheafield.com>
See LICENSE for list of files by other people and their licenses.
Currently, it loads an ARPA file in 2/3 the time SRI takes and uses 6.5 GB when SRI takes 11 GB. I'm working on optimizing this even further.
Compile: ./compile.sh
Run: ./query lm/test.arpa <text
Build binary format: ./build_binary lm/test.arpa test.binary
Use binary format: ./query test.binary <text
Binary format support is coming soon. It already uses mmap; the only change needed is passing a file descriptor to that mmap call.
Test (uses Boost): ./test.sh
Currently it depends on Boost (mostly lexical_cast) and ICU (only StringPiece). I am actively working on removing these dependencies. My normal build system is Boost Jam. I've stripped this out and simplified to a shell script ./compile.sh for you.
Currently, it loads an ARPA file in 2/3 the time SRI takes and uses 6.5 GB when SRI takes 11 GB. These are compared to the default SRI build (i.e. without their smaller structures). I'm working on optimizing this even further.
I recommend copying the code and distributing it with your decoder. However, please send improvements to me so that they can be integrated into the core package.
Binary format via mmap is supported. Run ./build_binary to make one then pass the binary file name instead.
Also included is a wrapper to SRI with the same interface.
Currently, it assumes POSIX APIs for errno, strerror_r, open, close, mmap, munmap, ftruncate, fstat, and read. This is tested on Linux and Mac OS X. I welcome submissions porting (via #ifdef) to other systems (e.g. Windows) but proudly have no machine on which to test it.
A brief note to Mac OS X users: your gcc is too old to recognize the pack pragma. The warning effectively means that, on 64-bit machines, the model will use 16 bytes instead of 12 bytes per n-gram of maximum order.
It does not depend on Boost or ICU. However, if you use Boost and/or ICU in the rest of your code, you should define USE_BOOST and/or USE_ICU in util/string_piece.hh. Defining USE_BOOST will let you hash StringPiece. Defining USE_ICU will use ICU's StringPiece to prevent a conflict with the one provided here. By the way, ICU's StringPiece is buggy and I reported this bug: http://bugs.icu-project.org/trac/ticket/7924 .
The recommended way to use this:
Copy the code and distribute with your decoder.
Set USE_ICU and USE_BOOST at the top of util/string_piece.hh as instructed above.
Look at compile.sh and reimplement using your build system.
Use either the interface in lm/ngram.hh or lm/virtual_interface.hh.
Interface documentation is in comments of lm/virtual_interface.hh (including for lm/ngram.hh).
I recommend copying the code and distributing it with your decoder. However, please send improvements to me so that they can be integrated into the package.
Also included:
A wrapper to SRI with the same interface.

View File

@ -1,2 +1,11 @@
#!/bin/bash
g++ -O3 -I. -licui18n lm/arpa_io.cc lm/exception.cc lm/ngram.cc lm/query.cc lm/virtual_interface.cc util/errno_exception.cc util/file_piece.cc util/murmur_hash.cc util/scoped.cc util/string_piece.cc -o query -licutu -licutu -licudata -licuio -licule -liculx -licuuc
#This is just an example compilation. You should integrate these files into your build system. I can provide boost jam if you want.
#If your code uses ICU, edit util/string_piece.hh and uncomment #define USE_ICU
set -e
for i in util/{ersatz_progress,exception,file_piece,murmur_hash,scoped,string_piece} lm/{exception,virtual_interface,ngram}; do
g++ -I. -O3 -c $i.cc -o $i.o
done
g++ -I. -O3 lm/ngram_build_binary.cc {lm,util}/*.o -o build_binary
g++ -I. -O3 lm/ngram_query.cc {lm,util}/*.o -o query

View File

@ -1,180 +0,0 @@
#include "lm/arpa_io.hh"
#include "util/file_piece.hh"
#include <boost/lexical_cast.hpp>
#include <istream>
#include <ostream>
#include <string>
#include <vector>
#include <ctype.h>
#include <errno.h>
#include <string.h>
namespace lm {
ARPAInputException::ARPAInputException(const StringPiece &message) throw() : what_("Error: ") {
what_.append(message.data(), message.size());
}
ARPAInputException::ARPAInputException(const StringPiece &message, const StringPiece &line) throw() {
what_ = "Error: ";
what_.append(message.data(), message.size());
what_ += " in line '";
what_.append(line.data(), line.size());
what_ += "'.";
}
ARPAOutputException::ARPAOutputException(const char *message, const std::string &file_name) throw()
: what_(std::string(message) + " file " + file_name), file_name_(file_name) {
if (errno) {
char buf[1024];
buf[0] = 0;
const char *add = buf;
if (!strerror_r(errno, buf, 1024)) {
what_ += ": ";
what_ += add;
}
}
}
// Seeking is the responsibility of the caller.
void WriteCounts(std::ostream &out, const std::vector<size_t> &number) {
out << "\n\\data\\\n";
for (unsigned int i = 0; i < number.size(); ++i) {
out << "ngram " << i+1 << "=" << number[i] << '\n';
}
out << '\n';
}
size_t SizeNeededForCounts(const std::vector<size_t> &number) {
std::ostringstream buf;
WriteCounts(buf, number);
return buf.tellp();
}
bool IsEntirelyWhiteSpace(const StringPiece &line) {
for (size_t i = 0; i < static_cast<size_t>(line.size()); ++i) {
if (!isspace(line.data()[i])) return false;
}
return true;
}
void ReadCounts(std::istream &in, std::vector<size_t> &number) throw (ARPAInputException) {
number.clear();
std::string line;
if (!getline(in, line)) throw ARPAInputException("reading input lm");
if (!IsEntirelyWhiteSpace(line)) throw ARPAInputException("first line was not blank", line);
if (!getline(in, line)) throw ARPAInputException("reading \\data\\");
if (!(line == "\\data\\")) throw ARPAInputException("second line was not \\data\\.", line);
while (getline(in, line)) {
if (IsEntirelyWhiteSpace(line)) {
return;
}
if (strncmp(line.c_str(), "ngram ", 6)) throw ARPAInputException("count line doesn't begin with \"ngram \"", line);
size_t equals = line.find('=');
if (equals == std::string::npos) throw ARPAInputException("expected = inside a count line", line);
unsigned int length = boost::lexical_cast<unsigned int>(line.substr(6, equals - 6));
if (length - 1 != number.size()) throw ARPAInputException("ngram count lengths should be consecutive starting with 1", line);
unsigned int count = boost::lexical_cast<unsigned int>(line.substr(equals + 1));
number.push_back(count);
}
throw ARPAInputException("reading counts from input lm failed");
}
void ReadCounts(util::FilePiece &in, std::vector<size_t> &number) throw (ARPAInputException) {
number.clear();
StringPiece line;
if (!IsEntirelyWhiteSpace(line = in.ReadLine())) throw ARPAInputException("first line was not blank", line);
if ((line = in.ReadLine()) != "\\data\\") throw ARPAInputException("second line was not \\data\\.", line);
while (!IsEntirelyWhiteSpace(line = in.ReadLine())) {
if (line.size() < 6 || strncmp(line.data(), "ngram ", 6)) throw ARPAInputException("count line doesn't begin with \"ngram \"", line);
util::PieceIterator<'='> equals(line);
unsigned int length = boost::lexical_cast<unsigned int>(equals->substr(6));
if (length - 1 != number.size()) throw ARPAInputException("ngram count lengths should be consecutive starting with 1", line);
if (!++equals) throw ARPAInputException("expected = inside a count line", line);
unsigned int count = boost::lexical_cast<unsigned int>(*equals);
number.push_back(count);
}
}
void ReadNGramHeader(std::istream &in, unsigned int length) {
std::string line;
do {
if (!getline(in, line)) throw ARPAInputException(std::string("Reading header for n-gram length ") + boost::lexical_cast<std::string>(length) + " from input lm failed");
} while (IsEntirelyWhiteSpace(line));
if (line != (std::string("\\") + boost::lexical_cast<std::string>(length) + "-grams:")) throw ARPAInputException("wrong ngram header", line);
}
void ReadNGramHeader(util::FilePiece &in, unsigned int length) {
StringPiece line;
while (IsEntirelyWhiteSpace(line = in.ReadLine())) {}
if (line != (std::string("\\") + boost::lexical_cast<std::string>(length) + "-grams:")) throw ARPAInputException("wrong ngram header", line);
}
void ReadEnd(std::istream &in_lm) {
std::string line;
do {
if (!getline(in_lm, line)) throw ARPAInputException("reading end marker failed");
} while (IsEntirelyWhiteSpace(line));
if (line != "\\end\\") throw ARPAInputException("expected ending line \\end\\", line);
}
ARPAOutput::ARPAOutput(const char *name, size_t buffer_size) : file_name_(name), buffer_(new char[buffer_size]) {
try {
file_.exceptions(std::ostream::eofbit | std::ostream::failbit | std::ostream::badbit);
if (!file_.rdbuf()->pubsetbuf(buffer_.get(), buffer_size)) {
std::cerr << "Warning: could not enlarge buffer for " << name << std::endl;
buffer_.reset();
}
file_.open(name, std::ios::out | std::ios::binary);
} catch (const std::ios_base::failure &f) {
throw ARPAOutputException("Opening", file_name_);
}
}
void ARPAOutput::ReserveForCounts(std::streampos reserve) {
try {
for (std::streampos i = 0; i < reserve; i += std::streampos(1)) {
file_ << '\n';
}
} catch (const std::ios_base::failure &f) {
throw ARPAOutputException("Writing blanks to reserve space for counts to ", file_name_);
}
}
void ARPAOutput::BeginLength(unsigned int length) {
fast_counter_ = 0;
try {
file_ << '\\' << length << "-grams:" << '\n';
} catch (const std::ios_base::failure &f) {
throw ARPAOutputException("Writing n-gram header to ", file_name_);
}
}
void ARPAOutput::EndLength(unsigned int length) {
try {
file_ << '\n';
} catch (const std::ios_base::failure &f) {
throw ARPAOutputException("Writing blank at end of count list to ", file_name_);
}
if (length > counts_.size()) {
counts_.resize(length);
}
counts_[length - 1] = fast_counter_;
}
void ARPAOutput::Finish() {
try {
file_ << "\\end\\\n";
file_.seekp(0);
WriteCounts(file_, counts_);
file_ << std::flush;
} catch (const std::ios_base::failure &f) {
throw ARPAOutputException("Finishing including writing counts at beginning to ", file_name_);
}
}
} // namespace lm

View File

@ -1,138 +0,0 @@
#ifndef LM_ARPA_IO_H__
#define LM_ARPA_IO_H__
/* Input and output for ARPA format language model files.
* TODO: throw exceptions instead of using err.
*/
#include "util/string_piece.hh"
#include "util/tokenize_piece.hh"
#include <boost/lexical_cast.hpp>
#include <boost/noncopyable.hpp>
#include <boost/progress.hpp>
#include <boost/scoped_array.hpp>
#include <fstream>
#include <istream>
#include <string>
#include <vector>
#include <err.h>
#include <string.h>
namespace util { class FilePiece; }
namespace lm {
class ARPAInputException : public std::exception {
public:
explicit ARPAInputException(const StringPiece &message) throw();
explicit ARPAInputException(const StringPiece &message, const StringPiece &line) throw();
virtual ~ARPAInputException() throw() {}
const char *what() const throw() { return what_.c_str(); }
private:
std::string what_;
};
class ARPAOutputException : public std::exception {
public:
ARPAOutputException(const char *prefix, const std::string &file_name) throw();
virtual ~ARPAOutputException() throw() {}
const char *what() const throw() { return what_.c_str(); }
const std::string &File() const throw() { return file_name_; }
private:
std::string what_;
const std::string file_name_;
};
// Handling for the counts of n-grams at the beginning of ARPA files.
size_t SizeNeededForCounts(const std::vector<size_t> &number);
// TODO: transition to FilePiece.
void ReadCounts(std::istream &in, std::vector<size_t> &number) throw (ARPAInputException);
void ReadCounts(util::FilePiece &in, std::vector<size_t> &number) throw (ARPAInputException);
// Read and verify the headers like \1-grams:
void ReadNGramHeader(util::FilePiece &in_lm, unsigned int length);
void ReadNGramHeader(std::istream &in_lm, unsigned int length);
// Read and verify end marker.
void ReadEnd(std::istream &in_lm);
/* Writes an ARPA file. This has to be seekable so the counts can be written
* at the end. Hence, I just have it own a std::fstream instead of accepting
* a separately held std::ostream.
*/
class ARPAOutput : boost::noncopyable {
public:
explicit ARPAOutput(const char *name, size_t buffer_size = 65536);
void ReserveForCounts(std::streampos reserve);
void BeginLength(unsigned int length);
void AddNGram(const StringPiece &line) {
try {
file_ << line << '\n';
} catch (const std::ios_base::failure &f) {
throw ARPAOutputException("Writing an n-gram", file_name_);
}
++fast_counter_;
}
template <class Iterator> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) {
AddNGram(line);
}
void EndLength(unsigned int length);
void Finish();
private:
const std::string file_name_;
boost::scoped_array<char> buffer_;
std::fstream file_;
size_t fast_counter_;
std::vector<size_t> counts_;
};
template <class Output> void ReadNGrams(std::istream &in, unsigned int length, size_t max_length, size_t number, Output &out) {
std::string line;
ReadNGramHeader(in, length);
out.BeginLength(length);
boost::progress_display display(number, std::cerr, std::string("Length ") + boost::lexical_cast<std::string>(length) + "/" + boost::lexical_cast<std::string>(max_length) + ": " + boost::lexical_cast<std::string>(number) + " total\n");
for (unsigned int i = 0; i < number;) {
if (!std::getline(in, line)) throw ARPAInputException("Reading ngram failed. Maybe the counts are wrong?");
util::PieceIterator<'\t'> tabber(line);
if (!tabber) {
std::cerr << "Warning: empty line inside list of " << length << "-grams." << std::endl;
continue;
}
if (!++tabber) throw ARPAInputException("no tab", line);
out.AddNGram(util::PieceIterator<' '>(*tabber), util::PieceIterator<' '>::end(), line);
++i;
++display;
}
out.EndLength(length);
}
template <class Output> void ReadARPA(std::istream &in_lm, Output &out) {
std::vector<size_t> number;
ReadCounts(in_lm, number);
out.ReserveForCounts(SizeNeededForCounts(number));
for (unsigned int i = 0; i < number.size(); ++i) {
ReadNGrams(in_lm, i + 1, number.size(), number[i], out);
}
ReadEnd(in_lm);
out.Finish();
}
} // namespace lm
#endif // LM_ARPA_IO_H__

View File

@ -1,90 +0,0 @@
#ifndef LM_COUNT_IO_H__
#define LM_COUNT_IO_H__
#include <fstream>
#include <string>
#include <err.h>
namespace lm {
class CountOutput : boost::noncopyable {
public:
explicit CountOutput(const char *name) : file_(name, std::ios::out) {}
void AddNGram(const StringPiece &line) {
if (!(file_ << line << '\n')) {
err(3, "Writing counts file failed");
}
}
template <class Iterator> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) {
AddNGram(line);
}
private:
std::fstream file_;
};
class CountBatch {
public:
explicit CountBatch(std::streamsize initial_read)
: initial_read_(initial_read) {
buffer_.reserve(initial_read);
}
void Read(std::istream &in) {
buffer_.resize(initial_read_);
in.read(&*buffer_.begin(), initial_read_);
buffer_.resize(in.gcount());
char got;
while (in.get(got) && got != '\n')
buffer_.push_back(got);
}
template <class Output> void Send(Output &out) {
for (util::PieceIterator<'\n'> line(StringPiece(&*buffer_.begin(), buffer_.size())); line; ++line) {
util::PieceIterator<'\t'> tabber(*line);
if (!tabber) {
std::cerr << "Warning: empty n-gram count line being removed\n";
continue;
}
util::PieceIterator<' '> words(*tabber);
if (!words) {
std::cerr << "Line has a tab but no words.\n";
continue;
}
out.AddNGram(words, util::PieceIterator<' '>::end(), *line);
}
}
private:
std::streamsize initial_read_;
// This could have been a std::string but that's less happy with raw writes.
std::vector<char> buffer_;
};
template <class Output> void ReadCount(std::istream &in_file, Output &out) {
std::string line;
while (getline(in_file, line)) {
util::PieceIterator<'\t'> tabber(line);
if (!tabber) {
std::cerr << "Warning: empty n-gram count line being removed\n";
continue;
}
util::PieceIterator<' '> words(*tabber);
if (!words) {
std::cerr << "Line has a tab but no words.\n";
continue;
}
out.AddNGram(words, util::PieceIterator<' '>::end(), line);
}
if (!in_file.eof()) {
err(2, "Reading counts file failed");
}
}
} // namespace lm
#endif // LM_COUNT_IO_H__

View File

@ -1,30 +0,0 @@
#ifndef LM_ENCODE_H__
#define LM_ENCODE_H__
#include <inttypes.h>
namespace lm {
namespace encode {
template <class Value> struct StructEncode {
uint64_t GetKey() const { return key; }
Value GetValue() const { return value; }
void SetKey(uint64_t to) { key = to; }
void SetValue(const Value &to) { value = to; }
static size_t Size() { return sizeof(StructEncode<Value>); }
static size_t Bits() { return Size() * 8; }
// Nominally private. public to be a POD.
uint64_t key;
Value value;
};
template <class Value> struct AppendEncode {
};
} // namespace encode
} // namespace lm
#endif // LM_ENCODE_H__

View File

@ -1,41 +1,27 @@
#include "lm/exception.hh"
#include<boost/lexical_cast.hpp>
#include<sstream>
#include<errno.h>
#include<stdio.h>
namespace lm {
NotFoundInVocabException::NotFoundInVocabException(const StringPiece &word) throw() : word_(word.data(), word.length()) {
what_ = "Word '";
what_ += word_;
what_ += "' was not found in the vocabulary.";
LoadException::LoadException() throw() {}
LoadException::~LoadException() throw() {}
VocabLoadException::VocabLoadException() throw() {}
VocabLoadException::~VocabLoadException() throw() {}
AllocateMemoryLoadException::AllocateMemoryLoadException(size_t requested) throw() {
*this << "Failed to allocate memory for " << requested << " bytes.";
}
IDDuplicateVocabLoadException::IDDuplicateVocabLoadException(unsigned int id, const StringPiece &first, const StringPiece &second) throw() {
std::ostringstream tmp;
tmp << "Vocabulary id " << id << " is same for " << first << " and " << second;
what_ = tmp.str();
}
AllocateMemoryLoadException::~AllocateMemoryLoadException() throw() {}
WordDuplicateVocabLoadException::WordDuplicateVocabLoadException(const StringPiece &word, unsigned int first, unsigned int second) throw() {
std::ostringstream tmp;
tmp << "Vocabulary word " << word << " has two ids: " << first << " and " << second;
what_ = tmp.str();
}
FormatLoadException::FormatLoadException() throw() {}
FormatLoadException::~FormatLoadException() throw() {}
AllocateMemoryLoadException::AllocateMemoryLoadException(size_t requested) throw()
: ErrnoException(std::string("Failed to allocate language model memory; asked for ") + boost::lexical_cast<std::string>(requested)) {}
FormatLoadException::FormatLoadException(const StringPiece &complaint, const StringPiece &context) throw() {
what_.assign(complaint.data(), complaint.size());
if (!context.empty()) {
what_ += " at ";
what_.append(context.data(), context.size());
}
SpecialWordMissingException::SpecialWordMissingException(StringPiece which) throw() {
*this << "Missing special word " << which;
}
SpecialWordMissingException::~SpecialWordMissingException() throw() {}
} // namespace lm

View File

@ -1,7 +1,7 @@
#ifndef LM_EXCEPTION_HH__
#define LM_EXCEPTION_HH__
#ifndef LM_EXCEPTION__
#define LM_EXCEPTION__
#include "util/errno_exception.hh"
#include "util/exception.hh"
#include "util/string_piece.hh"
#include <exception>
@ -9,150 +9,39 @@
namespace lm {
class NotFoundInVocabException : public std::exception {
public:
explicit NotFoundInVocabException(const StringPiece &word) throw();
~NotFoundInVocabException() throw() {}
const std::string &Word() const throw() { return word_; }
virtual const char *what() const throw() { return what_.c_str(); }
private:
std::string word_;
std::string what_;
};
class LoadException : public std::exception {
class LoadException : public util::Exception {
public:
virtual ~LoadException() throw() {}
virtual ~LoadException() throw();
protected:
LoadException() throw() {}
LoadException() throw();
};
class VocabLoadException : public LoadException {
public:
virtual ~VocabLoadException() throw() {}
protected:
VocabLoadException() throw() {}
};
// Different words, same ids
class IDDuplicateVocabLoadException : public VocabLoadException {
public:
IDDuplicateVocabLoadException(unsigned int id, const StringPiece &existing, const StringPiece &replacement) throw();
~IDDuplicateVocabLoadException() throw() {}
const char *what() const throw() { return what_.c_str(); }
private:
std::string what_;
};
// One word, two ids.
class WordDuplicateVocabLoadException : public VocabLoadException {
public:
WordDuplicateVocabLoadException(const StringPiece &word, unsigned int first, unsigned int second) throw();
~WordDuplicateVocabLoadException() throw() {}
const char *what() const throw() { return what_.c_str(); }
private:
std::string what_;
virtual ~VocabLoadException() throw();
VocabLoadException() throw();
};
class AllocateMemoryLoadException : public util::ErrnoException {
public:
explicit AllocateMemoryLoadException(size_t requested) throw();
~AllocateMemoryLoadException() throw() {}
};
class OpenFileLoadException : public LoadException {
public:
explicit OpenFileLoadException(const char *name) throw() : name_(name) {
what_ = "Error opening file ";
what_ += name;
}
~OpenFileLoadException() throw() {}
const char *what() const throw() { return what_.c_str(); }
private:
std::string name_;
std::string what_;
};
class ReadFileLoadException : public LoadException {
public:
explicit ReadFileLoadException(const char *name) throw() : name_(name) {
what_ = "Error reading file ";
what_ += name;
}
~ReadFileLoadException() throw() {}
const char *what() const throw() { return what_.c_str(); }
private:
std::string name_;
std::string what_;
~AllocateMemoryLoadException() throw();
};
class FormatLoadException : public LoadException {
public:
explicit FormatLoadException(const StringPiece &complaint, const StringPiece &context = StringPiece()) throw();
~FormatLoadException() throw() {}
const char *what() const throw() { return what_.c_str(); }
private:
std::string what_;
FormatLoadException() throw();
~FormatLoadException() throw();
};
class SpecialWordMissingException : public LoadException {
class SpecialWordMissingException : public VocabLoadException {
public:
virtual ~SpecialWordMissingException() throw() {}
protected:
SpecialWordMissingException() throw() {}
explicit SpecialWordMissingException(StringPiece which) throw();
~SpecialWordMissingException() throw();
};
class BeginSentenceMissingException : public SpecialWordMissingException {
public:
BeginSentenceMissingException() throw() {}
~BeginSentenceMissingException() throw() {}
const char *what() const throw() { return "Begin of sentence marker missing from vocabulary"; }
};
class EndSentenceMissingException : public SpecialWordMissingException {
public:
EndSentenceMissingException() throw() {}
~EndSentenceMissingException() throw() {}
const char *what() const throw() { return "End of sentence marker missing from vocabulary"; }
};
class UnknownMissingException : public SpecialWordMissingException {
public:
UnknownMissingException() throw() {}
~UnknownMissingException() throw() {}
const char *what() const throw() { return "Unknown word missing from vocabulary"; }
};
} // namespace lm
#endif

View File

@ -1,54 +0,0 @@
#ifndef LM_MOCK_H__
#define LM_MOCK_H__
#include "lm/facade.hh"
namespace lm {
namespace mock {
class Vocabulary : public base::Vocabulary {
public:
Vocabulary() {}
WordIndex Index(const StringPiece &str) const { return 0; }
const char *Word(WordIndex index) const {
return "Mock";
}
};
struct State {};
size_t hash_value(const State &state) {
return 87483974;
}
bool operator==(const State &left, const State &right) {
return true;
}
class Model : public base::ModelFacade<Model, State, Vocabulary> {
private:
typedef base::ModelFacade<Model, State, Vocabulary> P;
public:
explicit Model() {
Init(State(), State(), vocab_, 0);
}
FullScoreReturn FullScore(
const State &in_state,
const WordIndex word,
State &out_state) const {
FullScoreReturn ret;
ret.prob = 1.0;
ret.ngram_length = 0;
return ret;
}
Vocabulary vocab_;
};
} // namespace mock
} // namespace lm
#endif // LM_MOCK_H__

View File

@ -1,17 +1,15 @@
#include "lm/ngram.hh"
#include "lm/arpa_io.hh"
#include "lm/exception.hh"
#include "util/file_piece.hh"
#include "util/joint_sort.hh"
#include "util/probing_hash_table.hh"
#include "util/scoped.hh"
#include <boost/lexical_cast.hpp>
#include <boost/progress.hpp>
#include <algorithm>
#include <functional>
#include <numeric>
#include <limits>
#include <string>
#include <cmath>
@ -27,38 +25,89 @@ namespace lm {
namespace ngram {
namespace detail {
// Sadly some LMs have <UNK>.
template <class Search> GenericVocabulary<Search>::GenericVocabulary() : hash_unk_(Hash("<unk>")), hash_unk_cap_(Hash("<UNK>")) {}
template <class Search> void GenericVocabulary<Search>::Init(const typename Search::Init &search_init, char *start, std::size_t entries) {
lookup_ = Lookup(search_init, start, entries);
assert(kNotFound == 0);
available_ = kNotFound + 1;
// Later if available_ != expected_available_ then we can throw UnknownMissingException.
expected_available_ = entries;
void Prob::SetBackoff(float to) {
UTIL_THROW(FormatLoadException, "Attempt to set backoff " << to << " for the highest order n-gram");
}
template <class Search> WordIndex GenericVocabulary<Search>::Insert(const StringPiece &str) {
uint64_t hashed = Hash(str);
// Normally static initialization is a bad idea but MurmurHash is pure arithmetic, so this is ok.
const uint64_t kUnknownHash = HashForVocab("<unk>", 5);
// Sadly some LMs have <UNK>.
const uint64_t kUnknownCapHash = HashForVocab("<UNK>", 5);
} // namespace detail
SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL) {}
std::size_t SortedVocabulary::Size(std::size_t entries, float ignored) {
// Lead with the number of entries.
return sizeof(uint64_t) + sizeof(Entry) * entries;
}
void SortedVocabulary::Init(void *start, std::size_t allocated, std::size_t entries) {
assert(allocated >= Size(entries));
// Leave space for number of entries.
begin_ = reinterpret_cast<Entry*>(reinterpret_cast<uint64_t*>(start) + 1);
end_ = begin_;
saw_unk_ = false;
}
WordIndex SortedVocabulary::Insert(const StringPiece &str) {
uint64_t hashed = detail::HashForVocab(str);
if (hashed == detail::kUnknownHash || hashed == detail::kUnknownCapHash) {
saw_unk_ = true;
return 0;
}
end_->key = hashed;
++end_;
// This is 1 + the offset where it was inserted to make room for unk.
return end_ - begin_;
}
bool SortedVocabulary::FinishedLoading(detail::ProbBackoff *reorder_vocab) {
util::JointSort(begin_, end_, reorder_vocab + 1);
SetSpecial(Index("<s>"), Index("</s>"), 0, end_ - begin_ + 1);
// Save size.
*(reinterpret_cast<uint64_t*>(begin_) - 1) = end_ - begin_;
return saw_unk_;
}
void SortedVocabulary::LoadedBinary() {
end_ = begin_ + *(reinterpret_cast<const uint64_t*>(begin_) - 1);
SetSpecial(Index("<s>"), Index("</s>"), 0, end_ - begin_ + 1);
}
namespace detail {
template <class Search> MapVocabulary<Search>::MapVocabulary() {}
template <class Search> void MapVocabulary<Search>::Init(void *start, std::size_t allocated, std::size_t entries) {
lookup_ = Lookup(start, allocated);
available_ = 1;
// Later if available_ != expected_available_ then we can throw UnknownMissingException.
saw_unk_ = false;
}
template <class Search> WordIndex MapVocabulary<Search>::Insert(const StringPiece &str) {
uint64_t hashed = HashForVocab(str);
// Prevent unknown from going into the table.
if (hashed == hash_unk_ || hashed == hash_unk_cap_) {
return kNotFound;
if (hashed == kUnknownHash || hashed == kUnknownCapHash) {
saw_unk_ = true;
return 0;
} else {
lookup_.Insert(hashed, available_);
lookup_.Insert(Lookup::Packing::Make(hashed, available_));
return available_++;
}
}
template <class Search> void GenericVocabulary<Search>::FinishedLoading() {
template <class Search> bool MapVocabulary<Search>::FinishedLoading(ProbBackoff *reorder_vocab) {
lookup_.FinishedInserting();
const WordIndex *begin, *end;
if (!lookup_.Find(Hash("<s>"), begin)) throw BeginSentenceMissingException();
if (!lookup_.Find(Hash("</s>"), end)) throw EndSentenceMissingException();
if (expected_available_ != available_) {
// TODO: command line option for this.
// throw UnknownMissingException();
}
SetSpecial(*begin, *end, kNotFound, available_);
SetSpecial(Index("<s>"), Index("</s>"), 0, available_);
return saw_unk_;
}
template <class Search> void MapVocabulary<Search>::LoadedBinary() {
lookup_.LoadedBinary();
SetSpecial(Index("<s>"), Index("</s>"), 0, available_);
}
/* All of the entropy is in low order bits and boost::hash does poorly with
@ -79,44 +128,78 @@ uint64_t ChainedWordHash(const WordIndex *word, const WordIndex *word_end) {
return current;
}
bool IsEntirelyWhiteSpace(const StringPiece &line) {
for (size_t i = 0; i < static_cast<size_t>(line.size()); ++i) {
if (!isspace(line.data()[i])) return false;
}
return true;
}
void ReadARPACounts(util::FilePiece &in, std::vector<size_t> &number) {
number.clear();
StringPiece line;
if (!IsEntirelyWhiteSpace(line = in.ReadLine())) UTIL_THROW(FormatLoadException, "First line was \"" << line << "\" not blank");
if ((line = in.ReadLine()) != "\\data\\") UTIL_THROW(FormatLoadException, "second line was \"" << line << "\" not \\data\\.");
while (!IsEntirelyWhiteSpace(line = in.ReadLine())) {
if (line.size() < 6 || strncmp(line.data(), "ngram ", 6)) UTIL_THROW(FormatLoadException, "count line \"" << line << "\" doesn't begin with \"ngram \"");
// So strtol doesn't go off the end of line.
std::string remaining(line.data() + 6, line.size() - 6);
char *end_ptr;
unsigned long int length = std::strtol(remaining.c_str(), &end_ptr, 10);
if ((end_ptr == remaining.c_str()) || (length - 1 != number.size())) UTIL_THROW(FormatLoadException, "ngram count lengths should be consecutive starting with 1: " << line);
if (*end_ptr != '=') UTIL_THROW(FormatLoadException, "Expected = immediately following the first number in the count line " << line);
++end_ptr;
const char *start = end_ptr;
long int count = std::strtol(start, &end_ptr, 10);
if (count < 0) UTIL_THROW(FormatLoadException, "Negative n-gram count " << count);
if (start == end_ptr) UTIL_THROW(FormatLoadException, "Couldn't parse n-gram count from " << line);
number.push_back(count);
}
}
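ReadARPACounts above walks the `\data\` section with `strtol`, checking that the orders are consecutive and the counts non-negative. As a standalone illustration (not part of the library; `ParseCountLine` is a hypothetical helper), the per-line parse of the standard ARPA `ngram N=count` format reduces to:

```cpp
#include <cassert>
#include <cstdlib>
#include <string>

// Parse one ARPA "\data\" section line of the form "ngram N=count".
// Fills order and count and returns true on success; a minimal sketch of
// the validation ReadARPACounts performs (it additionally checks that the
// orders come out consecutive starting from 1).
bool ParseCountLine(const std::string &line, unsigned long &order, long &count) {
  if (line.compare(0, 6, "ngram ") != 0) return false;
  char *end;
  order = std::strtoul(line.c_str() + 6, &end, 10);
  if (end == line.c_str() + 6 || *end != '=') return false;  // no number or no '='
  const char *start = end + 1;
  count = std::strtol(start, &end, 10);
  return end != start && count >= 0;
}
```

Using `strtol` on a `std::string` copy (as the real code does with `remaining`) avoids reading past the end of an unterminated line buffer.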
void ReadNGramHeader(util::FilePiece &in, unsigned int length) {
StringPiece line;
while (IsEntirelyWhiteSpace(line = in.ReadLine())) {}
std::stringstream expected;
expected << '\\' << length << "-grams:";
if (line != expected.str()) UTIL_THROW(FormatLoadException, "Was expecting n-gram header " << expected.str() << " but got " << line << " instead.");
}
// Special unigram reader because unigram's data structure is different and because we're inserting vocab words.
template <class Voc> void Read1Grams(util::FilePiece &f, const size_t count, Voc &vocab, ProbBackoff *unigrams) {
ReadNGramHeader(f, 1);
boost::progress_display progress(count, std::cerr, "Loading 1-grams\n");
for (size_t i = 0; i < count; ++i, ++progress) {
for (size_t i = 0; i < count; ++i) {
try {
float prob = f.ReadFloat();
if (f.get() != '\t')
throw FormatLoadException("Expected tab after probability");
if (f.get() != '\t') UTIL_THROW(FormatLoadException, "Expected tab after probability");
ProbBackoff &value = unigrams[vocab.Insert(f.ReadDelimited())];
value.prob = prob;
switch (f.get()) {
case '\t':
value.SetBackoff(f.ReadFloat());
if ((f.get() != '\n')) throw FormatLoadException("Expected newline after backoff");
if ((f.get() != '\n')) UTIL_THROW(FormatLoadException, "Expected newline after backoff");
break;
case '\n':
value.ZeroBackoff();
break;
default:
throw FormatLoadException("Expected tab or newline after unigram");
UTIL_THROW(FormatLoadException, "Expected tab or newline after unigram");
}
} catch (const std::exception &f) {
throw FormatLoadException("Error reading the " + boost::lexical_cast<std::string>(i) + "th 1-gram. " + f.what());
} catch(util::Exception &e) {
e << " in the " << i << "th 1-gram at byte " << f.Offset();
throw;
}
}
if (f.ReadLine().size()) throw FormatLoadException("Blank line after ngrams not blank");
vocab.FinishedLoading();
if (f.ReadLine().size()) UTIL_THROW(FormatLoadException, "Expected blank line after unigrams at byte " << f.Offset());
}
template <class Voc, class Store> void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, Store &store) {
ReadNGramHeader(f, n);
boost::progress_display progress(count, std::cerr, std::string("Loading ") + boost::lexical_cast<std::string>(n) + "-grams\n");
// vocab ids of words in reverse order
WordIndex vocab_ids[n];
typename Store::Value value;
for (size_t i = 0; i < count; ++i, ++progress) {
typename Store::Packing::Value value;
for (size_t i = 0; i < count; ++i) {
try {
value.prob = f.ReadFloat();
for (WordIndex *vocab_out = &vocab_ids[n-1]; vocab_out >= vocab_ids; --vocab_out) {
@@ -127,71 +210,191 @@ template <class Voc, class Store> void ReadNGrams(util::FilePiece &f, const unsi
switch (f.get()) {
case '\t':
value.SetBackoff(f.ReadFloat());
if ((f.get() != '\n')) UTIL_THROW(FormatLoadException, "Expected newline after backoff");
break;
case '\n':
value.ZeroBackoff();
break;
default:
throw FormatLoadException("Got unexpected delimiter before backoff weight");
UTIL_THROW(FormatLoadException, "Expected tab or newline after n-gram");
}
store.Insert(key, value);
} catch (const std::exception &f) {
throw FormatLoadException("Error reading the " + boost::lexical_cast<std::string>(i) + "th " + boost::lexical_cast<std::string>(n) + "-gram. " + f.what());
store.Insert(Store::Packing::Make(key, value));
} catch(util::Exception &e) {
e << " in the " << i << "th " << n << "-gram at byte " << f.Offset();
throw;
}
}
if (f.ReadLine().size()) throw FormatLoadException("Blank line after ngrams not blank");
if (f.ReadLine().size()) UTIL_THROW(FormatLoadException, "Expected blank line after " << n << "-grams at byte " << f.Offset());
store.FinishedInserting();
}
void Prob::SetBackoff(float to) {
throw FormatLoadException("Attempt to set backoff " + boost::lexical_cast<std::string>(to) + " for an n-gram with longest order.");
}
template <class Search> size_t GenericModel<Search>::Size(const typename Search::Init &search_init, const std::vector<size_t> &counts) {
if (counts.size() < 2)
throw FormatLoadException("This ngram implementation assumes at least a bigram model.");
size_t memory_size = GenericVocabulary<Search>::Size(search_init, counts[0]);
memory_size += sizeof(ProbBackoff) * counts[0];
template <class Search, class VocabularyT> size_t GenericModel<Search, VocabularyT>::Size(const std::vector<size_t> &counts, const Config &config) {
if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit ngram.hh's kMaxOrder to at least this value and recompile.");
if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model.");
size_t memory_size = VocabularyT::Size(counts[0], config.probing_multiplier);
memory_size += sizeof(ProbBackoff) * (counts[0] + 1); // +1 for hallucinate <unk>
for (unsigned char n = 2; n < counts.size(); ++n) {
memory_size += Middle::Size(search_init, counts[n - 1]);
memory_size += Middle::Size(counts[n - 1], config.probing_multiplier);
}
memory_size += Longest::Size(search_init, counts.back());
memory_size += Longest::Size(counts.back(), config.probing_multiplier);
return memory_size;
}
template <class Search> GenericModel<Search>::GenericModel(const char *file, const typename Search::Init &search_init) {
util::FilePiece f(file);
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::SetupMemory(char *base, const std::vector<size_t> &counts, const Config &config) {
char *start = base;
size_t allocated = VocabularyT::Size(counts[0], config.probing_multiplier);
vocab_.Init(start, allocated, counts[0]);
start += allocated;
unigram_ = reinterpret_cast<ProbBackoff*>(start);
start += sizeof(ProbBackoff) * (counts[0] + 1);
for (unsigned int n = 2; n < counts.size(); ++n) {
allocated = Middle::Size(counts[n - 1], config.probing_multiplier);
middle_.push_back(Middle(start, allocated));
start += allocated;
}
allocated = Longest::Size(counts.back(), config.probing_multiplier);
longest_ = Longest(start, allocated);
start += allocated;
if (static_cast<std::size_t>(start - base) != Size(counts, config)) UTIL_THROW(FormatLoadException, "The data structures took " << (start - base) << " but Size says they should take " << Size(counts, config));
}
const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 0\n\0";
struct BinaryFileHeader {
char magic[sizeof(kMagicBytes)];
float zero_f, one_f, minus_half_f;
WordIndex one_word_index, max_word_index;
uint64_t one_uint64;
void SetToReference() {
std::memcpy(magic, kMagicBytes, sizeof(magic));
zero_f = 0.0; one_f = 1.0; minus_half_f = -0.5;
one_word_index = 1;
max_word_index = std::numeric_limits<WordIndex>::max();
one_uint64 = 1;
}
};
bool IsBinaryFormat(int fd, off_t size) {
if (size == util::kBadSize || (size <= static_cast<off_t>(sizeof(BinaryFileHeader)))) return false;
// Try reading the header.
util::scoped_mmap memory(mmap(NULL, sizeof(BinaryFileHeader), PROT_READ, MAP_FILE | MAP_PRIVATE, fd, 0), sizeof(BinaryFileHeader));
if (memory.get() == MAP_FAILED) return false;
BinaryFileHeader reference_header = BinaryFileHeader();
reference_header.SetToReference();
if (!memcmp(memory.get(), &reference_header, sizeof(BinaryFileHeader))) return true;
if (!memcmp(memory.get(), "mmap lm ", 8)) UTIL_THROW(FormatLoadException, "File looks like it should be loaded with mmap, but the test values don't match. Was it built on a different machine or with a different compiler?");
return false;
}
std::size_t Align8(std::size_t in) {
std::size_t off = in % 8;
if (!off) return in;
return in + 8 - off;
}
std::size_t TotalHeaderSize(unsigned int order) {
return Align8(sizeof(BinaryFileHeader) + 1 /* order */ + sizeof(uint64_t) * order /* counts */ + sizeof(float) /* probing multiplier */ + 1 /* search_tag */);
}
void ReadBinaryHeader(const void *from, off_t size, std::vector<size_t> &out, float &probing_multiplier, unsigned char &search_tag) {
const char *from_char = reinterpret_cast<const char*>(from);
if (size < static_cast<off_t>(1 + sizeof(BinaryFileHeader))) UTIL_THROW(FormatLoadException, "File too short to have count information.");
// Skip over the BinaryFileHeader which was read by IsBinaryFormat.
from_char += sizeof(BinaryFileHeader);
unsigned char order = *reinterpret_cast<const unsigned char*>(from_char);
if (size < static_cast<off_t>(TotalHeaderSize(order))) UTIL_THROW(FormatLoadException, "File too short to have full header.");
out.resize(static_cast<std::size_t>(order));
const uint64_t *counts = reinterpret_cast<const uint64_t*>(from_char + 1);
for (std::size_t i = 0; i < out.size(); ++i) {
out[i] = static_cast<std::size_t>(counts[i]);
}
const float *probing_ptr = reinterpret_cast<const float*>(counts + out.size());
probing_multiplier = *probing_ptr;
search_tag = *reinterpret_cast<const char*>(probing_ptr + 1);
}
void WriteBinaryHeader(void *to, const std::vector<size_t> &from, float probing_multiplier, char search_tag) {
BinaryFileHeader header = BinaryFileHeader();
header.SetToReference();
memcpy(to, &header, sizeof(BinaryFileHeader));
char *out = reinterpret_cast<char*>(to) + sizeof(BinaryFileHeader);
*reinterpret_cast<unsigned char*>(out) = static_cast<unsigned char>(from.size());
uint64_t *counts = reinterpret_cast<uint64_t*>(out + 1);
for (std::size_t i = 0; i < from.size(); ++i) {
counts[i] = from[i];
}
float *probing_ptr = reinterpret_cast<float*>(counts + from.size());
*probing_ptr = probing_multiplier;
*reinterpret_cast<char*>(probing_ptr + 1) = search_tag;
}
template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, Config config) : mapped_file_(util::OpenReadOrThrow(file)) {
const off_t file_size = util::SizeFile(mapped_file_.get());
std::vector<size_t> counts;
ReadCounts(f, counts);
if (counts.size() < 2)
throw FormatLoadException("This ngram implementation assumes at least a bigram model.");
if (counts.size() > kMaxOrder)
throw FormatLoadException(std::string("Edit ngram.hh and change kMaxOrder to at least ") + boost::lexical_cast<std::string>(counts.size()));
unsigned char order = counts.size();
if (IsBinaryFormat(mapped_file_.get(), file_size)) {
memory_.reset(mmap(NULL, file_size, PROT_READ,
#ifdef MAP_POPULATE // Linux specific
(config.prefault ? MAP_POPULATE : 0) |
#endif
MAP_FILE | MAP_PRIVATE, mapped_file_.get(), 0), file_size);
if (MAP_FAILED == memory_.get()) UTIL_THROW(util::ErrnoException, "Couldn't mmap the whole " << file);
const size_t memory_size = Size(search_init, counts);
memory_.reset(mmap(NULL, memory_size, PROT_READ | PROT_WRITE, MAP_ANON | MAP_PRIVATE, -1, 0), memory_size);
if (memory_.get() == MAP_FAILED) throw AllocateMemoryLoadException(memory_size);
unsigned char search_tag;
ReadBinaryHeader(memory_.begin(), file_size, counts, config.probing_multiplier, search_tag);
if (config.probing_multiplier < 1.0) UTIL_THROW(FormatLoadException, "Binary format claims to have a probing multiplier of " << config.probing_multiplier << " which is < 1.0.");
if (search_tag != Search::kBinaryTag) UTIL_THROW(FormatLoadException, "The binary file has a different search strategy than the one requested.");
size_t memory_size = Size(counts, config);
char *start = static_cast<char*>(memory_.get());
vocab_.Init(search_init, start, counts[0]);
start += GenericVocabulary<Search>::Size(search_init, counts[0]);
unigram_ = reinterpret_cast<ProbBackoff*>(start);
start += sizeof(ProbBackoff) * counts[0];
for (unsigned int n = 2; n < order; ++n) {
middle_.push_back(Middle(search_init, start, counts[n - 1]));
start += Middle::Size(search_init, counts[n - 1]);
}
longest_ = Longest(search_init, start, counts[order - 1]);
assert(static_cast<size_t>(start + Longest::Size(search_init, counts[order - 1]) - reinterpret_cast<char*>(memory_.get())) == memory_size);
char *start = reinterpret_cast<char*>(memory_.get()) + TotalHeaderSize(counts.size());
if (memory_size != static_cast<size_t>(memory_.end() - start)) UTIL_THROW(FormatLoadException, "The mmap file " << file << " has size " << file_size << " but " << (memory_size + TotalHeaderSize(counts.size())) << " was expected based on the number of counts and configuration.");
LoadFromARPA(f, counts);
SetupMemory(start, counts, config);
vocab_.LoadedBinary();
for (typename std::vector<Middle>::iterator i = middle_.begin(); i != middle_.end(); ++i) {
i->LoadedBinary();
}
longest_.LoadedBinary();
if (std::fabs(unigram_[GenericVocabulary<Search>::kNotFound].backoff) > 0.0000001) {
throw FormatLoadException(std::string("Backoff for unknown word with index is ") + boost::lexical_cast<std::string>(unigram_[GenericVocabulary<Search>::kNotFound].backoff) + std::string(" not zero"));
} else {
if (config.probing_multiplier <= 1.0) UTIL_THROW(FormatLoadException, "probing multiplier must be > 1.0");
util::FilePiece f(file, mapped_file_.release(), config.messages);
ReadARPACounts(f, counts);
size_t memory_size = Size(counts, config);
char *start;
if (config.write_mmap) {
// Write out an mmap file.
// O_TRUNC ensures that the later ftruncate call fills with zeros. The data structures like being initialized with zeros.
mapped_file_.reset(open(config.write_mmap, O_CREAT | O_RDWR | O_TRUNC, S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH));
if (-1 == mapped_file_.get()) UTIL_THROW(util::ErrnoException, "Couldn't create " << config.write_mmap);
size_t total_size = TotalHeaderSize(counts.size()) + memory_size;
if (-1 == ftruncate(mapped_file_.get(), total_size)) UTIL_THROW(util::ErrnoException, "ftruncate on " << config.write_mmap << " to " << total_size << " failed.");
memory_.reset(mmap(NULL, total_size, PROT_READ | PROT_WRITE, MAP_FILE | MAP_SHARED, mapped_file_.get(), 0), total_size);
if (memory_.get() == MAP_FAILED) UTIL_THROW(util::ErrnoException, "Failed to mmap " << config.write_mmap);
WriteBinaryHeader(memory_.get(), counts, config.probing_multiplier, Search::kBinaryTag);
start = reinterpret_cast<char*>(memory_.get()) + TotalHeaderSize(counts.size());
} else {
memory_.reset(mmap(NULL, memory_size, PROT_READ | PROT_WRITE,
#ifdef MAP_ANONYMOUS
MAP_ANONYMOUS // Linux
#else
MAP_ANON // BSD
#endif
| MAP_PRIVATE, -1, 0), memory_size);
if (memory_.get() == MAP_FAILED) throw AllocateMemoryLoadException(memory_size);
start = reinterpret_cast<char*>(memory_.get());
}
SetupMemory(start, counts, config);
try {
LoadFromARPA(f, counts, config);
} catch (FormatLoadException &e) {
e << " in file " << file;
throw;
}
}
// g++ prints warnings unless these are fully initialized.
@@ -201,21 +404,38 @@ template <class Search> GenericModel<Search>::GenericModel(const char *file, con
begin_sentence.backoff_[0] = unigram_[begin_sentence.history_[0]].backoff;
State null_context = State();
null_context.valid_length_ = 0;
P::Init(begin_sentence, null_context, vocab_, order);
P::Init(begin_sentence, null_context, vocab_, counts.size());
}
template <class Search> void GenericModel<Search>::LoadFromARPA(util::FilePiece &f, const std::vector<size_t> &counts) {
// Default for <unk> is skip.
unigram_[0].prob = 0.0;
unigram_[0].backoff = 0.0;
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::LoadFromARPA(util::FilePiece &f, const std::vector<size_t> &counts, const Config &config) {
// Read the unigrams.
Read1Grams(f, counts[0], vocab_, unigram_);
bool saw_unk = vocab_.FinishedLoading(unigram_);
if (!saw_unk) {
switch(config.unknown_missing) {
case Config::THROW_UP:
{
SpecialWordMissingException e("<unk>");
e << " and configuration was set to throw if unknown is missing";
throw e;
}
case Config::COMPLAIN:
if (config.messages) *config.messages << "Language model is missing <unk>. Substituting probability " << config.unknown_missing_prob << "." << std::endl;
// No break here: COMPLAIN deliberately falls through to SILENT.
case Config::SILENT:
// Default probabilities for unknown.
unigram_[0].backoff = 0.0;
unigram_[0].prob = config.unknown_missing_prob;
break;
}
}
// Read the n-grams.
for (unsigned int n = 2; n < counts.size(); ++n) {
ReadNGrams(f, n, counts[n-1], vocab_, middle_[n-2]);
}
ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab_, longest_);
if (std::fabs(unigram_[0].backoff) > 0.0000001) UTIL_THROW(FormatLoadException, "Backoff for unknown word should be zero, but was given as " << unigram_[0].backoff);
}
/* Ugly optimized function.
@@ -225,7 +445,7 @@ template <class Search> void GenericModel<Search>::LoadFromARPA(util::FilePiece
*
* The search goes in increasing order of ngram length.
*/
template <class Search> FullScoreReturn GenericModel<Search>::FullScore(
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(
const State &in_state,
const WordIndex new_word,
State &out_state) const {
@@ -233,7 +453,7 @@ template <class Search> FullScoreReturn GenericModel<Search>::FullScore(
FullScoreReturn ret;
// This is end pointer passed to SumBackoffs.
const ProbBackoff &unigram = unigram_[new_word];
if (new_word == GenericVocabulary<Search>::kNotFound) {
if (new_word == 0) {
ret.ngram_length = out_state.valid_length_ = 0;
// all of backoff.
ret.prob = std::accumulate(
@@ -269,7 +489,7 @@ template <class Search> FullScoreReturn GenericModel<Search>::FullScore(
}
lookup_hash = CombineWordHash(lookup_hash, *hist_iter);
if (mid_iter == middle_.end()) break;
const ProbBackoff *found;
typename Middle::ConstIterator found;
if (!mid_iter->Find(lookup_hash, found)) {
// Didn't find an ngram using hist_iter.
// The history used in the found n-gram is [in_state.history_, hist_iter).
@@ -282,11 +502,11 @@ template <class Search> FullScoreReturn GenericModel<Search>::FullScore(
ret.prob);
return ret;
}
*backoff_out = found->backoff;
ret.prob = found->prob;
*backoff_out = found->GetValue().backoff;
ret.prob = found->GetValue().prob;
}
const Prob *found;
typename Longest::ConstIterator found;
if (!longest_.Find(lookup_hash, found)) {
// It's a (P::Order()-1)-gram
std::copy(in_state.history_, in_state.history_ + P::Order() - 2, out_state.history_ + 1);
@@ -299,13 +519,12 @@ template <class Search> FullScoreReturn GenericModel<Search>::FullScore(
std::copy(in_state.history_, in_state.history_ + P::Order() - 2, out_state.history_ + 1);
out_state.valid_length_ = P::Order() - 1;
ret.ngram_length = P::Order();
ret.prob = found->prob;
ret.prob = found->GetValue().prob;
return ret;
}
// This also instantiates GenericVocabulary.
template class GenericModel<ProbingSearch>;
template class GenericModel<SortedUniformSearch>;
template class GenericModel<ProbingSearch, MapVocabulary<ProbingSearch> >;
template class GenericModel<SortedUniformSearch, SortedVocabulary>;
} // namespace detail
} // namespace ngram
} // namespace lm
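The binary-format header above packs a sanity block (magic bytes plus reference values), an order byte, one `uint64_t` count per order, the probing multiplier, and a search tag, then rounds up to 8-byte alignment. A standalone sketch of that size arithmetic (not library code; `kSanityBytes` stands in for `sizeof(BinaryFileHeader)`, whose exact value depends on padding):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Round a byte count up to the next multiple of 8, mirroring Align8 above.
std::size_t Align8(std::size_t in) {
  std::size_t off = in % 8;
  return off ? in + 8 - off : in;
}

// Total header size for a given order: sanity block + order byte + one
// uint64_t count per order + probing multiplier + search tag, 8-byte aligned,
// mirroring TotalHeaderSize above.
std::size_t HeaderSize(unsigned int order, std::size_t kSanityBytes) {
  return Align8(kSanityBytes + 1 /* order */ + sizeof(uint64_t) * order /* counts */
                + sizeof(float) /* probing multiplier */ + 1 /* search tag */);
}
```

The alignment matters because the data structures mapped immediately after the header assume 8-byte-aligned storage.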


@@ -2,6 +2,8 @@
#define LM_NGRAM__
#include "lm/facade.hh"
#include "lm/ngram_config.hh"
#include "util/key_value_packing.hh"
#include "util/probing_hash_table.hh"
#include "util/sorted_uniform.hh"
#include "util/string_piece.hh"
@@ -38,64 +40,27 @@ class State {
}
// You shouldn't need to touch anything below this line, but the members are public so that State qualifies as a POD.
unsigned char valid_length_;
float backoff_[kMaxOrder - 1];
// This order minimizes total size of the struct if WordIndex is 64 bit, float is 32 bit, and alignment of 64 bit integers is 64 bit.
WordIndex history_[kMaxOrder - 1];
float backoff_[kMaxOrder - 1];
unsigned char valid_length_;
};
inline size_t hash_value(const State &state) {
// If the histories are equal, so are the backoffs.
return MurmurHash64A(state.history_, sizeof(WordIndex) * state.valid_length_, 0);
return util::MurmurHashNative(state.history_, sizeof(WordIndex) * state.valid_length_);
}
namespace detail {
// std::identity is an SGI extension :-(
struct IdentityHash : public std::unary_function<uint64_t, size_t> {
size_t operator()(uint64_t arg) const { return static_cast<size_t>(arg); }
};
template <class Search> class GenericVocabulary : public base::Vocabulary {
public:
GenericVocabulary();
WordIndex Index(const StringPiece &str) const {
const WordIndex *ret;
return lookup_.Find(Hash(str), ret) ? *ret : kNotFound;
}
static size_t Size(const typename Search::Init &search_init, std::size_t entries) {
return Lookup::Size(search_init, entries);
}
/* This class forces unknown to zero. The constructor starts vocab ids
* after this value. The present hash function maps any string of 0s to 0.
* But that's fine because we never look up a string of <unk>. In short,
* don't change this.
*/
const static WordIndex kNotFound = 0;
// Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
void Init(const typename Search::Init &search_init, char *start, std::size_t entries);
WordIndex Insert(const StringPiece &str);
void FinishedLoading();
private:
static uint64_t Hash(const StringPiece &str) {
// This proved faster than Boost's hash in speed trials: total load time Murmur 67090000, Boost 72210000
return MurmurHash64A(str.data(), str.length(), 0);
}
typedef typename Search::template Table<WordIndex>::T Lookup;
Lookup lookup_;
// Safety check to ensure we were provided with all the expected entries.
std::size_t expected_available_;
// These could be static if I trusted the static initialization fiasco.
const uint64_t hash_unk_, hash_unk_cap_;
};
inline uint64_t HashForVocab(const char *str, std::size_t len) {
// This proved faster than Boost's hash in speed trials: total load time Murmur 67090000, Boost 72210000
// Chose to use 64A instead of native so binary format will be portable across 64 and 32 bit.
return util::MurmurHash64A(str, len, 0);
}
inline uint64_t HashForVocab(const StringPiece &str) {
return HashForVocab(str.data(), str.length());
}
struct Prob {
float prob;
@@ -110,27 +75,116 @@ struct ProbBackoff {
void ZeroBackoff() { backoff = 0.0; }
};
// Should return the same results as SRI except ln instead of log10
template <class Search> class GenericModel : public base::ModelFacade<GenericModel<Search>, State, GenericVocabulary<Search> > {
} // namespace detail
// Vocabulary based on sorted uniform find storing only uint64_t values and using their offsets as indices.
class SortedVocabulary : public base::Vocabulary {
private:
typedef base::ModelFacade<GenericModel<Search>, State, GenericVocabulary<Search> > P;
// Sorted uniform requires a GetKey function.
struct Entry {
uint64_t GetKey() const { return key; }
uint64_t key;
bool operator<(const Entry &other) const {
return key < other.key;
}
};
public:
SortedVocabulary();
WordIndex Index(const StringPiece &str) const {
const Entry *found;
if (util::SortedUniformFind<const Entry *, uint64_t>(begin_, end_, detail::HashForVocab(str), found)) {
return found - begin_ + 1; // +1 because <unk> is 0 and does not appear in the lookup table.
} else {
return 0;
}
}
// Ignores second argument for consistency with probing hash which has a float here.
static size_t Size(std::size_t entries, float ignored = 0.0);
// Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
void Init(void *start, std::size_t allocated, std::size_t entries);
WordIndex Insert(const StringPiece &str);
// Returns true if unknown was seen. Reorders reorder_vocab so that the IDs are sorted.
bool FinishedLoading(detail::ProbBackoff *reorder_vocab);
void LoadedBinary();
private:
Entry *begin_, *end_;
bool saw_unk_;
};
namespace detail {
// Vocabulary storing a map from uint64_t to WordIndex.
template <class Search> class MapVocabulary : public base::Vocabulary {
public:
MapVocabulary();
WordIndex Index(const StringPiece &str) const {
typename Lookup::ConstIterator i;
return lookup_.Find(HashForVocab(str), i) ? i->GetValue() : 0;
}
static size_t Size(std::size_t entries, float probing_multiplier) {
return Lookup::Size(entries, probing_multiplier);
}
// Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
void Init(void *start, std::size_t allocated, std::size_t entries);
WordIndex Insert(const StringPiece &str);
// Returns true if unknown was seen. Does nothing with reorder_vocab.
bool FinishedLoading(ProbBackoff *reorder_vocab);
void LoadedBinary();
private:
typedef typename Search::template Table<WordIndex>::T Lookup;
Lookup lookup_;
bool saw_unk_;
};
// std::identity is an SGI extension :-(
struct IdentityHash : public std::unary_function<uint64_t, size_t> {
size_t operator()(uint64_t arg) const { return static_cast<size_t>(arg); }
};
// Should return the same results as SRI.
// Why VocabularyT instead of just Vocabulary? ModelFacade defines Vocabulary.
template <class Search, class VocabularyT> class GenericModel : public base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> {
private:
typedef base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> P;
public:
// Get the size of memory that will be mapped given ngram counts. This
// does not include small non-mapped control structures, such as this class
// itself.
static size_t Size(const typename Search::Init &search_init, const std::vector<size_t> &counts);
static size_t Size(const std::vector<size_t> &counts, const Config &config = Config());
GenericModel(const char *file, const typename Search::Init &init);
GenericModel(const char *file, Config config = Config());
FullScoreReturn FullScore(const State &in_state, const WordIndex new_word, State &out_state) const;
private:
void LoadFromARPA(util::FilePiece &f, const std::vector<size_t> &counts);
// Appears after Size in the cc.
void SetupMemory(char *start, const std::vector<size_t> &counts, const Config &config);
void LoadFromARPA(util::FilePiece &f, const std::vector<size_t> &counts, const Config &config);
util::scoped_fd mapped_file_;
// memory_ is the raw block of memory backing vocab_, unigram_, [middle.begin(), middle.end()), and longest_.
util::scoped_mmap memory_;
GenericVocabulary<Search> vocab_;
VocabularyT vocab_;
ProbBackoff *unigram_;
@@ -143,26 +197,35 @@ template <class Search> class GenericModel : public base::ModelFacade<GenericMod
struct ProbingSearch {
typedef float Init;
static const unsigned char kBinaryTag = 1;
template <class Value> struct Table {
typedef util::ProbingMap<uint64_t, Value, IdentityHash> T;
typedef util::ByteAlignedPacking<uint64_t, Value> Packing;
typedef util::ProbingHashTable<Packing, IdentityHash> T;
};
};
struct SortedUniformSearch {
typedef util::SortedUniformInit Init;
// This is ignored.
typedef float Init;
static const unsigned char kBinaryTag = 2;
template <class Value> struct Table {
typedef util::SortedUniformMap<uint64_t, Value> T;
typedef util::ByteAlignedPacking<uint64_t, Value> Packing;
typedef util::SortedUniformMap<Packing> T;
};
};
} // namespace detail
// These must also be instantiated in the cc file.
typedef detail::GenericVocabulary<detail::ProbingSearch> Vocabulary;
typedef detail::GenericModel<detail::ProbingSearch> Model;
typedef detail::MapVocabulary<detail::ProbingSearch> Vocabulary;
typedef detail::GenericModel<detail::ProbingSearch, Vocabulary> Model;
typedef detail::GenericVocabulary<detail::SortedUniformSearch> SortedVocabulary;
typedef detail::GenericModel<detail::SortedUniformSearch> SortedModel;
// SortedVocabulary was defined above.
typedef detail::GenericModel<detail::SortedUniformSearch, SortedVocabulary> SortedModel;
} // namespace ngram
} // namespace lm
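SortedVocabulary above resolves a word by searching sorted 64-bit hashes and returning the entry's offset plus one, so that 0 stays reserved for `<unk>`; that is how it gets by on 8 bytes per word. A simplified stand-in (not library code) that uses `std::lower_bound` in place of `util::SortedUniformFind`:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Simplified vocabulary lookup over sorted word hashes.  Returns the entry's
// offset + 1; 0 means not found (i.e. <unk>), matching the convention of
// SortedVocabulary::Index above.
uint32_t IndexOf(const std::vector<uint64_t> &sorted, uint64_t hash) {
  std::vector<uint64_t>::const_iterator it =
      std::lower_bound(sorted.begin(), sorted.end(), hash);
  if (it == sorted.end() || *it != hash) return 0;
  return static_cast<uint32_t>(it - sorted.begin()) + 1;
}
```

The real code replaces the binary search with an interpolation-style search, which exploits the near-uniform distribution of good hash values to converge faster than halving.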


@@ -1,308 +0,0 @@
#include "lm/ngram.hh"
#include "lm/arpa_io.hh"
#include "lm/exception.hh"
#include "util/file_piece.hh"
#include "util/probing_hash_table.hh"
#include "util/scoped.hh"
#include <boost/lexical_cast.hpp>
#include <boost/progress.hpp>
#include <algorithm>
#include <functional>
#include <numeric>
#include <string>
#include <cmath>
#include <fcntl.h>
#include <errno.h>
#include <stdlib.h>
#include <sys/mman.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <unistd.h>
namespace lm {
namespace ngram {
namespace detail {
// Sadly some LMs have <UNK>.
template <class Search> GenericVocabulary<Search>::GenericVocabulary() : hash_unk_(Hash("<unk>")), hash_unk_cap_(Hash("<UNK>")) {}
template <class Search> void GenericVocabulary<Search>::Init(const typename Search::Init &search_init, char *start, std::size_t entries) {
lookup_ = Lookup(search_init, start, entries);
assert(kNotFound == 0);
available_ = kNotFound + 1;
// Later if available_ != expected_available_ then we can throw UnknownMissingException.
expected_available_ = entries;
}
template <class Search> WordIndex GenericVocabulary<Search>::Insert(const StringPiece &str) {
uint64_t hashed = Hash(str);
// Prevent unknown from going into the table.
if (hashed == hash_unk_ || hashed == hash_unk_cap_) {
return kNotFound;
} else {
lookup_.Insert(hashed, available_);
return available_++;
}
}
template <class Search> void GenericVocabulary<Search>::FinishedLoading() {
lookup_.FinishedInserting();
const WordIndex *begin, *end;
if (expected_available_ != available_) {
std::cerr << "HHH";
throw UnknownMissingException();
}
if (!lookup_.Find(Hash("<s>"), begin)) throw BeginSentenceMissingException();
if (!lookup_.Find(Hash("</s>"), end)) throw EndSentenceMissingException();
SetSpecial(*begin, *end, kNotFound, available_);
}
/* All of the entropy is in low order bits and boost::hash does poorly with
* these. Odd numbers near 2^64 chosen by mashing on the keyboard. There is a
* stable point: 0. But 0 is <unk> which won't be queried here anyway.
*/
inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) {
uint64_t ret = (current * 8978948897894561157ULL) ^ (static_cast<uint64_t>(next) * 17894857484156487943ULL);
return ret;
}
uint64_t ChainedWordHash(const WordIndex *word, const WordIndex *word_end) {
if (word == word_end) return 0;
uint64_t current = static_cast<uint64_t>(*word);
for (++word; word != word_end; ++word) {
current = CombineWordHash(current, *word);
}
return current;
}
// Special unigram reader because unigram's data structure is different and because we're inserting vocab words.
template <class Voc> void Read1Grams(util::FilePiece &f, const size_t count, Voc &vocab, ProbBackoff *unigrams) {
ReadNGramHeader(f, 1);
boost::progress_display progress(count, std::cerr, "Loading 1-grams\n");
for (size_t i = 0; i < count; ++i, ++progress) {
try {
float prob = f.ReadFloat();
if (f.get() != '\t')
throw FormatLoadException("Expected tab after probability");
ProbBackoff &value = unigrams[vocab.Insert(f.ReadDelimited())];
value.prob = prob;
switch (f.get()) {
case '\t':
value.SetBackoff(f.ReadFloat());
if ((f.get() != '\n')) throw FormatLoadException("Expected newline after backoff");
break;
case '\n':
value.ZeroBackoff();
break;
default:
throw FormatLoadException("Expected tab or newline after unigram");
}
} catch (const std::exception &f) {
throw FormatLoadException("Error reading the " + boost::lexical_cast<std::string>(i) + "th 1-gram. " + f.what());
}
}
if (f.ReadLine().size()) throw FormatLoadException("Blank line after ngrams not blank");
vocab.FinishedLoading();
}
template <class Voc, class Store> void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, Store &store) {
ReadNGramHeader(f, n);
boost::progress_display progress(count, std::cerr, std::string("Loading ") + boost::lexical_cast<std::string>(n) + "-grams\n");
// vocab ids of words in reverse order
WordIndex vocab_ids[n];
typename Store::Value value;
for (size_t i = 0; i < count; ++i, ++progress) {
try {
value.prob = f.ReadFloat();
for (WordIndex *vocab_out = &vocab_ids[n-1]; vocab_out >= vocab_ids; --vocab_out) {
*vocab_out = vocab.Index(f.ReadDelimited());
}
uint64_t key = ChainedWordHash(vocab_ids, vocab_ids + n);
switch (f.get()) {
case '\t':
value.SetBackoff(f.ReadFloat());
break;
case '\n':
value.ZeroBackoff();
break;
default:
throw FormatLoadException("Got unexpected delimiter before backoff weight");
}
store.Insert(key, value);
} catch (const std::exception &f) {
throw FormatLoadException("Error reading the " + boost::lexical_cast<std::string>(i) + "th " + boost::lexical_cast<std::string>(n) + "-gram." + f.what());
}
}
  if (f.ReadLine().size()) throw FormatLoadException("Expected blank line after n-grams");
store.FinishedInserting();
}
void Prob::SetBackoff(float to) {
throw FormatLoadException("Attempt to set backoff " + boost::lexical_cast<std::string>(to) + " for an n-gram with longest order.");
}
template <class Search> size_t GenericModel<Search>::Size(const typename Search::Init &search_init, const std::vector<size_t> &counts) {
if (counts.size() < 2)
throw FormatLoadException("This ngram implementation assumes at least a bigram model.");
size_t memory_size = GenericVocabulary<Search>::Size(search_init, counts[0]);
memory_size += sizeof(ProbBackoff) * counts[0];
for (unsigned char n = 2; n < counts.size(); ++n) {
memory_size += Middle::Size(search_init, counts[n - 1]);
}
memory_size += Longest::Size(search_init, counts.back());
return memory_size;
}
template <class Search> GenericModel<Search>::GenericModel(const char *file, const typename Search::Init &search_init) {
util::FilePiece f(file);
std::vector<size_t> counts;
ReadCounts(f, counts);
if (counts.size() < 2)
throw FormatLoadException("This ngram implementation assumes at least a bigram model.");
if (counts.size() > kMaxOrder)
throw FormatLoadException(std::string("Edit ngram.hh and change kMaxOrder to at least ") + boost::lexical_cast<std::string>(counts.size()));
unsigned char order = counts.size();
const size_t memory_size = Size(search_init, counts);
memory_.reset(mmap(NULL, memory_size, PROT_READ | PROT_WRITE, MAP_ANON | MAP_PRIVATE, -1, 0), memory_size);
if (memory_.get() == MAP_FAILED) throw AllocateMemoryLoadException(memory_size);
char *start = static_cast<char*>(memory_.get());
vocab_.Init(search_init, start, counts[0]);
start += GenericVocabulary<Search>::Size(search_init, counts[0]);
unigram_ = reinterpret_cast<ProbBackoff*>(start);
start += sizeof(ProbBackoff) * counts[0];
for (unsigned int n = 2; n < order; ++n) {
middle_.push_back(Middle(search_init, start, counts[n - 1]));
start += Middle::Size(search_init, counts[n - 1]);
}
longest_ = Longest(search_init, start, counts[order - 1]);
assert(static_cast<size_t>(start + Longest::Size(search_init, counts[order - 1]) - reinterpret_cast<char*>(memory_.get())) == memory_size);
LoadFromARPA(f, counts);
if (std::fabs(unigram_[GenericVocabulary<Search>::kNotFound].backoff) > 0.0000001) {
throw FormatLoadException(std::string("Backoff for unknown word with index is ") + boost::lexical_cast<std::string>(unigram_[GenericVocabulary<Search>::kNotFound].backoff) + std::string(" not zero"));
}
// g++ prints warnings unless these are fully initialized.
State begin_sentence = State();
begin_sentence.valid_length_ = 1;
begin_sentence.history_[0] = vocab_.BeginSentence();
begin_sentence.backoff_[0] = unigram_[begin_sentence.history_[0]].backoff;
State null_context = State();
null_context.valid_length_ = 0;
P::Init(begin_sentence, null_context, vocab_, order);
}
template <class Search> void GenericModel<Search>::LoadFromARPA(util::FilePiece &f, const std::vector<size_t> &counts) {
// Read the unigrams.
Read1Grams(f, counts[0], vocab_, unigram_);
// Read the n-grams.
for (unsigned int n = 2; n < counts.size(); ++n) {
ReadNGrams(f, n, counts[n-1], vocab_, middle_[n-2]);
}
ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab_, longest_);
}
/* Ugly optimized function.
 * in_state contains the previous ngram's length and backoff probabilities to
* be used here. out_state is populated with the found ngram length and
* backoffs that the next call will find useful.
*
* The search goes in increasing order of ngram length.
*/
template <class Search> FullScoreReturn GenericModel<Search>::FullScore(
const State &in_state,
const WordIndex new_word,
State &out_state) const {
FullScoreReturn ret;
  // Look up the unigram now; backoff weights are summed below with std::accumulate.
const ProbBackoff &unigram = unigram_[new_word];
if (new_word == GenericVocabulary<Search>::kNotFound) {
ret.ngram_length = out_state.valid_length_ = 0;
// all of backoff.
ret.prob = std::accumulate(
in_state.backoff_,
in_state.backoff_ + in_state.valid_length_,
unigram.prob);
return ret;
}
float *backoff_out(out_state.backoff_);
*backoff_out = unigram.backoff;
ret.prob = unigram.prob;
out_state.history_[0] = new_word;
if (in_state.valid_length_ == 0) {
ret.ngram_length = out_state.valid_length_ = 1;
// No backoff because NGramLength() == 0 and unknown can't have backoff.
return ret;
}
++backoff_out;
  // OK, now we know that the bigram contains known words. Start by looking it up.
uint64_t lookup_hash = static_cast<uint64_t>(new_word);
const WordIndex *hist_iter = in_state.history_;
const WordIndex *const hist_end = hist_iter + in_state.valid_length_;
typename std::vector<Middle>::const_iterator mid_iter = middle_.begin();
for (; ; ++mid_iter, ++hist_iter, ++backoff_out) {
if (hist_iter == hist_end) {
// Used history [in_state.history_, hist_end) and ran out. No backoff.
std::copy(in_state.history_, hist_end, out_state.history_ + 1);
ret.ngram_length = out_state.valid_length_ = in_state.valid_length_ + 1;
// ret.prob was already set.
return ret;
}
lookup_hash = CombineWordHash(lookup_hash, *hist_iter);
if (mid_iter == middle_.end()) break;
const ProbBackoff *found;
if (!mid_iter->Find(lookup_hash, found)) {
// Didn't find an ngram using hist_iter.
// The history used in the found n-gram is [in_state.history_, hist_iter).
std::copy(in_state.history_, hist_iter, out_state.history_ + 1);
// Therefore, we found a (hist_iter - in_state.history_ + 1)-gram including the last word.
ret.ngram_length = out_state.valid_length_ = (hist_iter - in_state.history_) + 1;
ret.prob = std::accumulate(
in_state.backoff_ + (mid_iter - middle_.begin()),
in_state.backoff_ + in_state.valid_length_,
ret.prob);
return ret;
}
*backoff_out = found->backoff;
ret.prob = found->prob;
}
const Prob *found;
if (!longest_.Find(lookup_hash, found)) {
    // It's a (P::Order()-1)-gram
std::copy(in_state.history_, in_state.history_ + P::Order() - 2, out_state.history_ + 1);
ret.ngram_length = out_state.valid_length_ = P::Order() - 1;
ret.prob += in_state.backoff_[P::Order() - 2];
return ret;
}
  // It's a P::Order()-gram
// out_state.valid_length_ is still P::Order() - 1 because the next lookup will only need that much.
std::copy(in_state.history_, in_state.history_ + P::Order() - 2, out_state.history_ + 1);
out_state.valid_length_ = P::Order() - 1;
ret.ngram_length = P::Order();
ret.prob = found->prob;
return ret;
}
// This also instantiates GenericVocabulary.
template class GenericModel<ProbingSearch>;
template class GenericModel<SortedUniformSearch>;
} // namespace detail
} // namespace ngram
} // namespace lm


@ -0,0 +1,13 @@
#include "lm/ngram.hh"
#include <iostream>
int main(int argc, char *argv[]) {
if (argc != 3) {
std::cerr << "Usage: " << argv[0] << " input.arpa output.mmap" << std::endl;
return 1;
}
lm::ngram::Config config;
config.write_mmap = argv[2];
lm::ngram::Model(argv[1], config);
}

kenlm/lm/ngram_config.hh Normal file

@ -0,0 +1,58 @@
#ifndef LM_NGRAM_CONFIG__
#define LM_NGRAM_CONFIG__
/* Configuration for ngram model. Separate header to reduce pollution. */
#include <iostream>
namespace lm { namespace ngram {
struct Config {
/* EFFECTIVE FOR BOTH ARPA AND BINARY READS */
// Where to log messages including the progress bar. Set to NULL for
// silence.
std::ostream *messages;
/* ONLY EFFECTIVE WHEN READING ARPA */
// What to do when <unk> isn't in the provided model.
typedef enum {THROW_UP, COMPLAIN, SILENT} UnknownMissing;
UnknownMissing unknown_missing;
// The probability to substitute for <unk> if it's missing from the model.
// No effect if the model has <unk> or unknown_missing == THROW_UP.
float unknown_missing_prob;
// Size multiplier for probing hash table. Must be > 1. Space is linear in
// this. Time is probing_multiplier / (probing_multiplier - 1). No effect
// for sorted variant.
// If you find yourself setting this to a low number, consider using the
// Sorted version instead which has lower memory consumption.
float probing_multiplier;
// While loading an ARPA file, also write out this binary format file. Set
// to NULL to disable.
const char *write_mmap;
/* ONLY EFFECTIVE WHEN READING BINARY */
bool prefault;
// Defaults.
Config() :
messages(&std::cerr),
unknown_missing(COMPLAIN),
unknown_missing_prob(0.0),
probing_multiplier(1.5),
write_mmap(NULL),
prefault(false) {}
};
} /* namespace ngram */ } /* namespace lm */
#endif // LM_NGRAM_CONFIG__


@ -1,12 +1,9 @@
#include "util/tokenize_piece.hh"
#include "lm/ngram.hh"
#include <boost/lexical_cast.hpp>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
#include <sys/resource.h>
#include <sys/time.h>
@ -37,19 +34,26 @@ void PrintUsage(const char *message) {
template <class Model> void Query(const Model &model) {
PrintUsage("Loading statistics:\n");
std::string line;
typename Model::State state;
while (std::getline(std::cin, line)) {
typename Model::State state, out;
lm::FullScoreReturn ret;
std::string word;
while (std::cin) {
state = model.BeginSentenceState();
float total = 0.0;
for (util::PieceIterator<' '> it(line); it; ++it) {
LMWordIndex index = model.GetVocabulary().Index(*it);
typename Model::State out;
lm::FullScoreReturn ret = model.FullScore(state, index, out);
bool got = false;
while (std::cin >> word) {
got = true;
ret = model.FullScore(state, model.GetVocabulary().Index(word), out);
total += ret.prob;
std::cout << word << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << ' ';
state = out;
std::cout << *it << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << ' ';
if (std::cin.get() == '\n') break;
}
if (!got && !std::cin) break;
ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out);
total += ret.prob;
std::cout << "</s> " << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << ' ';
std::cout << "Total: " << total << '\n';
}
PrintUsage("After queries:\n");
@ -57,10 +61,12 @@ template <class Model> void Query(const Model &model) {
int main(int argc, char *argv[]) {
if (argc < 2) {
std::cerr << "Pass language model APRA file." << std::endl;
std::cerr << "Pass language model name." << std::endl;
return 0;
}
lm::ngram::Model ngram(argv[1], 1.5);
Query(ngram);
{
lm::ngram::Model ngram(argv[1]);
Query(ngram);
}
PrintUsage("Total time including destruction:\n");
}


@ -1,7 +1,8 @@
#include "lm/ngram.hh"
#define BOOST_TEST_MODULE NGramTest
#include <stdlib.h>
#define BOOST_TEST_MODULE NGramTest
#include <boost/test/unit_test.hpp>
namespace lm {
@ -11,7 +12,7 @@ namespace {
#define StartTest(word, ngram, score) \
ret = model.FullScore( \
state, \
Lookup(word), \
model.GetVocabulary().Index(word), \
out);\
BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \
BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \
@ -21,19 +22,7 @@ namespace {
StartTest(word, ngram, score) \
state = out;
struct Fixture {
Fixture() : model("test.arpa", 1.5) {}
Model model;
unsigned int Lookup(const char *value) const {
return model.GetVocabulary().Index(StringPiece(value));
}
};
BOOST_FIXTURE_TEST_SUITE(f, Fixture)
BOOST_AUTO_TEST_CASE(starters_probing) {
template <class M> void Starters(M &model) {
FullScoreReturn ret;
Model::State state(model.BeginSentenceState());
Model::State out;
@ -46,7 +35,7 @@ BOOST_AUTO_TEST_CASE(starters_probing) {
StartTest("this_is_not_found", 0, -1.995635 + -0.4149733);
}
BOOST_AUTO_TEST_CASE(continuation_probing) {
template <class M> void Continuation(M &model) {
FullScoreReturn ret;
Model::State state(model.BeginSentenceState());
Model::State out;
@ -68,56 +57,34 @@ BOOST_AUTO_TEST_CASE(continuation_probing) {
AppendTest("loin", 5, -0.0432557);
}
BOOST_AUTO_TEST_SUITE_END()
BOOST_AUTO_TEST_CASE(starters_probing) { Model m("test.arpa"); Starters(m); }
BOOST_AUTO_TEST_CASE(continuation_probing) { Model m("test.arpa"); Continuation(m); }
BOOST_AUTO_TEST_CASE(starters_sorted) { SortedModel m("test.arpa"); Starters(m); }
BOOST_AUTO_TEST_CASE(continuation_sorted) { SortedModel m("test.arpa"); Continuation(m); }
struct SortedFixture {
SortedFixture() : model("test.arpa", detail::SortedUniformSearch::Init()) {}
SortedModel model;
unsigned int Lookup(const char *value) const {
return model.GetVocabulary().Index(StringPiece(value));
BOOST_AUTO_TEST_CASE(write_and_read_probing) {
Config config;
config.write_mmap = "test.binary";
{
Model copy_model("test.arpa", config);
}
};
BOOST_FIXTURE_TEST_SUITE(s, SortedFixture)
BOOST_AUTO_TEST_CASE(starters_sorted) {
FullScoreReturn ret;
Model::State state(model.BeginSentenceState());
Model::State out;
StartTest("looking", 2, -0.4846522);
// , probability plus <s> backoff
StartTest(",", 1, -1.383514 + -0.4149733);
// <unk> probability plus <s> backoff
StartTest("this_is_not_found", 0, -1.995635 + -0.4149733);
Model binary("test.binary");
Starters(binary);
Continuation(binary);
}
BOOST_AUTO_TEST_CASE(continuation_sorted) {
FullScoreReturn ret;
Model::State state(model.BeginSentenceState());
Model::State out;
AppendTest("looking", 2, -0.484652);
AppendTest("on", 3, -0.348837);
AppendTest("a", 4, -0.0155266);
AppendTest("little", 5, -0.00306122);
State preserve = state;
AppendTest("the", 1, -4.04005);
AppendTest("biarritz", 1, -1.9889);
AppendTest("not_found", 0, -2.29666);
AppendTest("more", 1, -1.20632);
AppendTest(".", 2, -0.51363);
AppendTest("</s>", 3, -0.0191651);
state = preserve;
AppendTest("more", 5, -0.00181395);
AppendTest("loin", 5, -0.0432557);
BOOST_AUTO_TEST_CASE(write_and_read_sorted) {
Config config;
config.write_mmap = "test.binary";
config.prefault = true;
{
SortedModel copy_model("test.arpa", config);
}
SortedModel binary("test.binary");
Starters(binary);
Continuation(binary);
}
BOOST_AUTO_TEST_SUITE_END()
} // namespace
} // namespace ngram


@ -4,7 +4,7 @@
#include <Ngram.h>
#include <Vocab.h>
#include <iostream>
#include <errno.h>
namespace lm {
namespace sri {
@ -29,64 +29,85 @@ const char *Vocabulary::Word(WordIndex index) const {
void Vocabulary::FinishedLoading() {
SetSpecial(
sri_->getIndex(Vocab_SentStart),
sri_->getIndex(Vocab_SentEnd),
sri_->getIndex(Vocab_Unknown),
sri_->ssIndex(),
sri_->seIndex(),
sri_->unkIndex(),
sri_->highIndex() + 1);
}
namespace {
Ngram *MakeSRIModel(const char *file_name, unsigned int ngram_length, Vocab &sri_vocab) throw (ReadFileLoadException) {
Ngram *MakeSRIModel(const char *file_name, unsigned int ngram_length, Vocab &sri_vocab) {
sri_vocab.unkIsWord() = true;
std::auto_ptr<Ngram> ret(new Ngram(sri_vocab, ngram_length));
File file(file_name, "r");
errno = 0;
if (!ret->read(file)) {
throw ReadFileLoadException(file_name);
UTIL_THROW(FormatLoadException, "reading file " << file_name << " with SRI failed.");
}
return ret.release();
}
} // namespace
Model::Model(const char *file_name, unsigned int ngram_length) : sri_(MakeSRIModel(file_name, ngram_length, *vocab_.sri_)) {
// TODO: exception this?
if (!sri_->setorder()) {
std::cerr << "Can't have order 0 SRI" << std::endl;
abort();
UTIL_THROW(FormatLoadException, "Can't have an SRI model with order 0.");
}
vocab_.FinishedLoading();
State begin_state = State();
begin_state.valid_length_ = 1;
begin_state.history_[0] = vocab_.BeginSentence();
if (kMaxOrder > 1) {
begin_state.history_[0] = vocab_.BeginSentence();
if (kMaxOrder > 2) begin_state.history_[1] = Vocab_None;
}
State null_state = State();
null_state.valid_length_ = 0;
if (kMaxOrder > 1) null_state.history_[0] = Vocab_None;
Init(begin_state, null_state, vocab_, sri_->setorder());
not_found_ = vocab_.NotFound();
}
Model::~Model() {}
namespace {
/* Argh SRI's wordProb knows the ngram length but doesn't return it. One more
* reason you should use my model. */
// TODO(stolcke): fix SRILM so I don't have to do this.
unsigned int MatchedLength(Ngram &model, const WordIndex new_word, const SRIVocabIndex *const_history) {
unsigned int out_length = 0;
// This gets the length of context used, which is ngram_length - 1 unless new_word is OOV in which case it is 0.
model.contextID(new_word, const_history, out_length);
return out_length + 1;
}
} // namespace
FullScoreReturn Model::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const {
  // If you get a compiler error in this function, change SRIVocabIndex in sri.hh to match the one found in SRI's Vocab.h.
// TODO: optimize this to use the new state's history.
SRIVocabIndex history[Order()];
std::copy(in_state.history_, in_state.history_ + in_state.valid_length_, history);
history[in_state.valid_length_] = Vocab_None;
const SRIVocabIndex *const_history = history;
const SRIVocabIndex *const_history;
SRIVocabIndex local_history[Order()];
if (in_state.valid_length_ < kMaxOrder - 1) {
const_history = in_state.history_;
} else {
std::copy(in_state.history_, in_state.history_ + in_state.valid_length_, local_history);
local_history[in_state.valid_length_] = Vocab_None;
const_history = local_history;
}
FullScoreReturn ret;
// TODO: avoid double backoff.
if (new_word != not_found_) {
// This gets the length of context used, which is ngram_length - 1 unless new_word is OOV in which case it is 0.
unsigned int out_length = 0;
sri_->contextID(new_word, const_history, out_length);
ret.ngram_length = out_length + 1;
ret.ngram_length = MatchedLength(*sri_, new_word, const_history);
out_state.history_[0] = new_word;
out_state.valid_length_ = std::min<unsigned char>(ret.ngram_length, Order() - 1);
std::copy(history, history + out_state.valid_length_ - 1, out_state.history_ + 1);
std::copy(const_history, const_history + out_state.valid_length_ - 1, out_state.history_ + 1);
if (out_state.valid_length_ < kMaxOrder - 1) {
out_state.history_[out_state.valid_length_] = Vocab_None;
}
} else {
ret.ngram_length = 0;
if (kMaxOrder > 1) out_state.history_[0] = Vocab_None;
out_state.valid_length_ = 0;
}
ret.prob = sri_->wordProb(new_word, const_history);
// SRI uses log10, we use log.
return ret;
}


@ -2,12 +2,11 @@
#define LM_SRI__
#include "lm/facade.hh"
#include <boost/functional/hash/hash.hpp>
#include <boost/scoped_ptr.hpp>
#include "util/murmur_hash.hh"
#include <cmath>
#include <exception>
#include <memory>
class Ngram;
class Vocab;
@ -33,8 +32,10 @@ typedef unsigned int SRIVocabIndex;
class State {
public:
unsigned char valid_length_;
// You shouldn't need to touch these, but they're public so State will be a POD.
// If valid_length_ < kMaxOrder - 1 then history_[valid_length_] == Vocab_None.
SRIVocabIndex history_[kMaxOrder - 1];
unsigned char valid_length_;
};
inline bool operator==(const State &left, const State &right) {
@ -50,7 +51,7 @@ inline bool operator==(const State &left, const State &right) {
}
inline size_t hash_value(const State &state) {
return boost::hash_range(state.history_, state.history_ + state.valid_length_);
return util::MurmurHashNative(&state.history_, sizeof(SRIVocabIndex) * state.valid_length_);
}
class Vocabulary : public base::Vocabulary {
@ -74,7 +75,9 @@ class Vocabulary : public base::Vocabulary {
friend class Model;
void FinishedLoading();
mutable boost::scoped_ptr<Vocab> sri_;
// The parent class isn't copyable so auto_ptr is the same as scoped_ptr
// but without the boost dependence.
mutable std::auto_ptr<Vocab> sri_;
};
class Model : public base::ModelFacade<Model, State, Vocabulary> {
@ -88,7 +91,7 @@ class Model : public base::ModelFacade<Model, State, Vocabulary> {
private:
Vocabulary vocab_;
mutable boost::scoped_ptr<Ngram> sri_;
mutable std::auto_ptr<Ngram> sri_;
WordIndex not_found_;
};

kenlm/lm/sri_test.cc Normal file

@ -0,0 +1,65 @@
#include "lm/sri.hh"
#include <stdlib.h>
#define BOOST_TEST_MODULE SRITest
#include <boost/test/unit_test.hpp>
namespace lm {
namespace sri {
namespace {
#define StartTest(word, ngram, score) \
ret = model.FullScore( \
state, \
model.GetVocabulary().Index(word), \
out);\
BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \
BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \
BOOST_CHECK_EQUAL(std::min<unsigned char>(ngram, 5 - 1), out.valid_length_);
#define AppendTest(word, ngram, score) \
StartTest(word, ngram, score) \
state = out;
template <class M> void Starters(M &model) {
FullScoreReturn ret;
Model::State state(model.BeginSentenceState());
Model::State out;
StartTest("looking", 2, -0.4846522);
// , probability plus <s> backoff
StartTest(",", 1, -1.383514 + -0.4149733);
// <unk> probability plus <s> backoff
StartTest("this_is_not_found", 0, -1.995635 + -0.4149733);
}
template <class M> void Continuation(M &model) {
FullScoreReturn ret;
Model::State state(model.BeginSentenceState());
Model::State out;
AppendTest("looking", 2, -0.484652);
AppendTest("on", 3, -0.348837);
AppendTest("a", 4, -0.0155266);
AppendTest("little", 5, -0.00306122);
State preserve = state;
AppendTest("the", 1, -4.04005);
AppendTest("biarritz", 1, -1.9889);
AppendTest("not_found", 0, -2.29666);
AppendTest("more", 1, -1.20632);
AppendTest(".", 2, -0.51363);
AppendTest("</s>", 3, -0.0191651);
state = preserve;
AppendTest("more", 5, -0.00181395);
AppendTest("loin", 5, -0.0432557);
}
BOOST_AUTO_TEST_CASE(starters) { Model m("test.arpa", 5); Starters(m); }
BOOST_AUTO_TEST_CASE(continuation) { Model m("test.arpa", 5); Continuation(m); }
} // namespace
} // namespace sri
} // namespace lm


@ -1,127 +0,0 @@
#ifndef LM_STORE_H__
#define LM_STORE_H__
namespace lm {
template <class Key, class Value, class Encoding, class Ptr, size_t Multiply> class SimpleIterator {
private:
typedef SimpleIterator<Key, Value, Encoding, Ptr, Multiply> S;
public:
SimpleIterator() {}
explicit SimpleIterator(const char *begin) : ptr_(reinterpret_cast<Ptr>(begin)) {}
explicit SimpleIterator(char *begin) : ptr_(reinterpret_cast<Ptr>(begin)) {}
bool operator==(const S &other) { return ptr_ == other.ptr_; }
S &operator++() {
ptr_ += Multiply;
return *this;
}
S &operator+=(size_t amount) {
ptr_ += Multiply * amount;
return *this;
}
S &operator--() {
ptr_ -= Multiply;
return *this;
}
S &operator-=(size_t amount) {
ptr_ -= Multiply * amount;
return *this;
}
size_t operator-(const S &other) {
return (ptr_ - other.ptr_) / Multiply;
}
Key GetKey() { return Entry::GetKey(ptr_); }
Value GetValue() { return Entry::GetValue(ptr_); }
// These shouldn't be a problem due to template magic.
void SetKey(Key to) { Entry::SetKey(ptr_, to); }
void SetValue(const Value &to) { Entry::SetValue(ptr_, to); }
protected:
Ptr ptr_;
};
template <class Key, class Value> class AlignedArray {
public:
static const size_t kBytes = sizeof(Entry);
static const size_t kBits = kBytes * 8;
typedef SimpleIterator<Key, Value, Encoding, const Entry*, 1> ConstIterator;
typedef SimpleIterator<Key, Value, Encoding, Entry*, 1> MutableIterator;
private:
struct Entry {
Key key;
Value value;
};
struct Encoding {
static Key GetKey(const Entry *e) { return e->key; }
static Value GetValue(const Entry *e) { return e->value; }
static void SetKey(Entry *e, Key key) { e->key = key; }
static void SetValue(Entry *e, const Value &value) { e->value = value; }
};
};
template <class Key, class Value> class ByteAlignedArray {
public:
static const size_t kBytes = sizeof(Key) + sizeof(Value);
static const size_t kBits = kBytes * 8;
typedef SimpleIterator<Key, Value, Encoding, const char *, kBytes> ConstIterator;
typedef SimpleIterator<Key, Value, Encoding, char*, kBytes> MutableIterator;
private:
struct Encoding {
static Key GetKey(const char *a) {
return *reinterpret_cast<const Key *>(a);
}
static Value GetValue(const char *a) {
return *reinterpret_cast<const Value*>(a + sizeof(Key));
}
static void SetKey(char *a, Key key) {
*reinterpret_cast<Key *>(a) = key;
}
static void SetValue(char *a, const Value &value) {
*reinterpret_cast<Value*>(a + sizeof(Key)) = value;
}
};
};
template <class Key, class Value> class AlternatingArray {
public:
static const size_t kBytes = sizeof(Entry) / 2;
static const size_t kBits = sizeof(Entry) * 4;
typedef SimpleIterator<Key, Value, Encoding, size_t, 1> ConstIterator;
typedef SimpleIterator<Key, Value, Encoding, size_t, 1> MutableIterator;
private:
// Here's hoping the % operations compile to bit operations.
struct Encoding {
static Key GetKey(const char *a) {
const Entry &ent = *reinterpret_cast<const Entry *>(a);
return (reinterpret_cast<std::size_t>(a) % sizeof(Entry)) ? ent.key0 : ent.key1;
}
static Value GetValue(const char *a) {
const Entry &val = *reinterpret_cast<const Entry *>(a);
return (a % sizeof(Entry)) ? ent.value0 : ent.value1;
}
static void SetKey(char *a, Key key) {
Entry &val = *reinterpret_cast<Entry *>(a);
((a % sizeof(Entry)) ? ent.key0 : ent.key1) = key;
}
static void SetValue(char *a, const Value &value) {
Entry &val = *reinterpret_cast<Entry *>(a);
((a % sizeof(Entry)) ? ent.value0 : ent.value1) = value;
}
};
};
} // namespace lm
#endif // LM_STORE_H__


@ -12,8 +12,8 @@ void Vocabulary::SetSpecial(WordIndex begin_sentence, WordIndex end_sentence, Wo
end_sentence_ = end_sentence;
not_found_ = not_found;
available_ = available;
if (begin_sentence_ == not_found_) throw BeginSentenceMissingException();
if (end_sentence_ == not_found_) throw EndSentenceMissingException();
if (begin_sentence_ == not_found_) throw SpecialWordMissingException("<s>");
if (end_sentence_ == not_found_) throw SpecialWordMissingException("</s>");
}
Model::~Model() {}


@ -4,8 +4,6 @@
#include "lm/word_index.hh"
#include "util/string_piece.hh"
#include <boost/noncopyable.hpp>
#include <string>
namespace lm {
@ -32,7 +30,7 @@ template <class T, class U, class V> class ModelFacade;
* GetVocabulary() for the actual implementation (in which case you'll need the
* actual implementation of the Model too).
*/
class Vocabulary : boost::noncopyable {
class Vocabulary {
public:
virtual ~Vocabulary();
@ -65,6 +63,12 @@ class Vocabulary : boost::noncopyable {
void SetSpecial(WordIndex begin_sentence, WordIndex end_sentence, WordIndex not_found, WordIndex available);
WordIndex begin_sentence_, end_sentence_, not_found_, available_;
private:
// Disable copy constructors. They're private and undefined.
// Ersatz boost::noncopyable.
Vocabulary(const Vocabulary &);
Vocabulary &operator=(const Vocabulary &);
};
/* There are two ways to access a Model.
@ -113,7 +117,7 @@ class Vocabulary : boost::noncopyable {
* All the State objects are POD, so it's ok to use raw memory for storing
* State.
*/
class Model : boost::noncopyable {
class Model {
public:
virtual ~Model();
@ -139,6 +143,11 @@ class Model : boost::noncopyable {
const Vocabulary *base_vocab_;
unsigned char order_;
// Disable copy constructors. They're private and undefined.
// Ersatz boost::noncopyable.
Model(const Model &);
Model &operator=(const Model &);
};
} // namespace base


@ -1,94 +0,0 @@
#ifndef LM_VOCAB_H__
#define LM_VOCAB_H__
#include "lm/exception.hh"
#include "lm/word_index.hh"
#include <boost/unordered_map.hpp>
#include <boost/ptr_container/ptr_vector.hpp>
#include <memory>
namespace lm {
/* This doesn't inherit from Vocabulary so it can be used where the special
* tags are not applicable.
* TODO: make ngram.* and SALM use this.
*/
class GenericVocabulary {
public:
static const WordIndex kNotFoundIndex;
static const char *const kNotFoundWord;
GenericVocabulary() {
strings_.push_back(new std::string(kNotFoundWord));
ids_[strings_[0]] = kNotFoundIndex;
strings_.push_back(new std::string());
available_ = 1;
}
/* Query API */
WordIndex Index(const StringPiece &str) const {
boost::unordered_map<StringPiece, WordIndex>::const_iterator i(ids_.find(str));
return (__builtin_expect(i == ids_.end(), 0)) ? kNotFoundIndex : i->second;
}
// Note that the literal token <unk> is in the index.
WordIndex IndexOrThrow(const StringPiece &str) const {
boost::unordered_map<StringPiece, WordIndex>::const_iterator i(ids_.find(str));
if (i == ids_.end()) throw NotFoundInVocabException(str);
return i->second;
}
bool Known(const StringPiece &str) const {
return ids_.find(str) != ids_.end();
}
const char *Word(WordIndex index) const {
return strings_[index].c_str();
}
/* Insertion API */
void Reserve(size_t to) {
strings_.reserve(to);
ids_.rehash(to + 1);
}
std::string &Temp() {
return strings_.back();
}
// Take the string returned by Temp() and insert it.
WordIndex InsertOrFind() {
std::pair<boost::unordered_map<StringPiece, WordIndex>::const_iterator, bool> res(ids_.insert(std::make_pair(StringPiece(strings_.back()), available_)));
if (res.second) {
++available_;
strings_.push_back(new std::string());
}
return res.first->second;
}
// Insert a word. Throw up if already found. Take ownership of the word in either case.
WordIndex InsertOrThrow() throw(WordDuplicateVocabLoadException) {
std::pair<boost::unordered_map<StringPiece, WordIndex>::const_iterator, bool> res(ids_.insert(std::make_pair(StringPiece(strings_.back()), available_)));
if (!res.second) {
throw WordDuplicateVocabLoadException(strings_.back(), res.first->second, available_);
}
++available_;
strings_.push_back(new std::string());
return res.first->second;
}
private:
// TODO: optimize memory use here by using one giant buffer, preferably premade by a binary file format.
boost::ptr_vector<std::string> strings_;
boost::unordered_map<StringPiece, WordIndex> ids_;
WordIndex available_;
};
} // namespace lm
#endif // LM_VOCAB_H__


@ -1,6 +1,6 @@
// Separate header because this is used often.
#ifndef LM_WORD_INDEX_HH__
#define LM_WORD_INDEX_HH__
#ifndef LM_WORD_INDEX__
#define LM_WORD_INDEX__
namespace lm {
typedef unsigned int WordIndex;

kenlm/test.sh Executable file

@ -0,0 +1,8 @@
#!/bin/bash
#Run tests. Requires Boost.
set -e
./compile.sh
for i in util/{file_piece,joint_sort,key_value_packing,probing_hash_table,sorted_uniform}_test lm/ngram_test; do
g++ -I. -O3 $i.cc {lm,util}/*.o -lboost_test_exec_monitor -o $i
pushd $(dirname $i) && ./$(basename $i); popd
done


@ -1,21 +0,0 @@
#include "util/errno_exception.hh"
#include <boost/lexical_cast.hpp>
#include <errno.h>
#include <stdio.h>
namespace util {
ErrnoException::ErrnoException(const StringPiece &problem) throw() : errno_(errno), what_(problem.data(), problem.size()) {
char buf[200];
buf[0] = 0;
const char *add = buf;
if (!strerror_r(errno, buf, 200)) {
what_ += add;
}
}
ErrnoException::~ErrnoException() throw() {}
} // namespace util


@ -1,28 +0,0 @@
#ifndef UTIL_ERRNO_EXCEPTION__
#define UTIL_ERRNO_EXCEPTION__
#include <exception>
#include <string>
#include "util/string_piece.hh"
namespace util {
class ErrnoException : public std::exception {
public:
explicit ErrnoException(const StringPiece &problem) throw();
virtual ~ErrnoException() throw();
virtual const char *what() const throw() { return what_.c_str(); }
int Error() { return errno_; }
private:
int errno_;
std::string what_;
};
} // namespace util
#endif // UTIL_ERRNO_EXCEPTION__


@ -0,0 +1,47 @@
#include "util/ersatz_progress.hh"
#include <algorithm>
#include <ostream>
#include <limits>
#include <string>
namespace util {
namespace { const unsigned char kWidth = 100; }
ErsatzProgress::ErsatzProgress() : current_(0), next_(std::numeric_limits<std::size_t>::max()), complete_(next_), out_(NULL) {}
ErsatzProgress::~ErsatzProgress() {
if (!out_) return;
for (; stones_written_ < kWidth; ++stones_written_) {
(*out_) << '*';
}
*out_ << '\n';
}
ErsatzProgress::ErsatzProgress(std::ostream *to, const std::string &message, std::size_t complete)
: current_(0), next_(complete / kWidth), complete_(complete), stones_written_(0), out_(to) {
if (!out_) {
next_ = std::numeric_limits<std::size_t>::max();
return;
}
*out_ << message << "\n----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n";
}
void ErsatzProgress::Milestone() {
if (!out_) { current_ = 0; return; }
if (!complete_) return;
unsigned char stone = std::min(static_cast<std::size_t>(kWidth), (current_ * kWidth) / complete_);
for (; stones_written_ < stone; ++stones_written_) {
(*out_) << '*';
}
if (current_ >= complete_) {
next_ = std::numeric_limits<std::size_t>::max();
} else {
next_ = std::max(next_, (stone * complete_) / kWidth);
}
}
} // namespace util


@ -0,0 +1,50 @@
#ifndef UTIL_ERSATZ_PROGRESS__
#define UTIL_ERSATZ_PROGRESS__
#include <iosfwd>
#include <string>
// Ersatz version of boost::progress so core language model doesn't depend on
// boost. Also adds option to print nothing.
namespace util {
class ErsatzProgress {
public:
// No output.
ErsatzProgress();
// Null means no output. The null value is useful for passing along the ostream pointer from another caller.
ErsatzProgress(std::ostream *to, const std::string &message, std::size_t complete);
~ErsatzProgress();
ErsatzProgress &operator++() {
if (++current_ == next_) Milestone();
return *this;
}
ErsatzProgress &operator+=(std::size_t amount) {
if ((current_ += amount) >= next_) Milestone();
return *this;
}
void Set(std::size_t to) {
if ((current_ = to) >= next_) Milestone();
}
private:
void Milestone();
std::size_t current_, next_, complete_;
unsigned char stones_written_;
std::ostream *out_;
// noncopyable
ErsatzProgress(const ErsatzProgress &other);
ErsatzProgress &operator=(const ErsatzProgress &other);
};
} // namespace util
#endif // UTIL_ERSATZ_PROGRESS__

kenlm/util/exception.cc Normal file

@ -0,0 +1,38 @@
#include "util/exception.hh"
#include <errno.h>
#include <string.h>
namespace util {
Exception::Exception() throw() {}
Exception::~Exception() throw() {}
Exception::Exception(const Exception &other) throw() : stream_(other.stream_.str()) {}
Exception &Exception::operator=(const Exception &other) throw() { stream_.str(other.stream_.str()); return *this; }
const char *Exception::what() const throw() { return stream_.str().c_str(); }
namespace {
// The XOPEN version.
const char *HandleStrerror(int ret, const char *buf) {
if (!ret) return buf;
return NULL;
}
// The GNU version.
const char *HandleStrerror(const char *ret, const char *buf) {
return ret;
}
} // namespace
ErrnoException::ErrnoException() throw() : errno_(errno) {
char buf[200];
buf[0] = 0;
const char *add = HandleStrerror(strerror_r(errno, buf, 200), buf);
if (add) {
*this << add << ' ';
}
}
ErrnoException::~ErrnoException() throw() {}
} // namespace util

kenlm/util/exception.hh Normal file

@ -0,0 +1,54 @@
#ifndef UTIL_ERRNO_EXCEPTION__
#define UTIL_ERRNO_EXCEPTION__
#include <exception>
#include <sstream>
namespace util {
class Exception : public std::exception {
public:
Exception() throw();
virtual ~Exception() throw();
Exception(const Exception &other) throw();
Exception &operator=(const Exception &other) throw();
virtual const char *what() const throw();
std::stringstream &str() { return stream_; }
// This helps restrict operator<< defined below.
template <class T> struct ExceptionTag {
typedef T Identity;
};
protected:
std::stringstream stream_;
};
/* This implements the normal operator<< for Exception and all its children.
* SFINAE means it only applies to Exception. Think of this as an ersatz
* boost::enable_if.
*/
template <class Except, class Data> typename Except::template ExceptionTag<Except&>::Identity operator<<(Except &e, const Data &data) {
e.str() << data;
return e;
}
#define UTIL_THROW(Exception, Modify) { Exception UTIL_e; {UTIL_e << Modify;} throw UTIL_e; }
class ErrnoException : public Exception {
public:
ErrnoException() throw();
virtual ~ErrnoException() throw();
int Error() { return errno_; }
private:
int errno_;
};
} // namespace util
#endif // UTIL_ERRNO_EXCEPTION__


@ -1,9 +1,12 @@
#include "util/file_piece.hh"
#include "util/errno_exception.hh"
#include "util/exception.hh"
#include <iostream>
#include <string>
#include <limits>
#include <assert.h>
#include <cstdlib>
#include <ctype.h>
#include <fcntl.h>
@ -14,33 +17,51 @@
namespace util {
namespace {
int OpenOrThrow(const char *name) {
EndOfFileException::EndOfFileException() throw() {
stream_ << "End of file";
}
EndOfFileException::~EndOfFileException() throw() {}
ParseNumberException::ParseNumberException(StringPiece value) throw() {
stream_ << "Could not parse \"" << value << "\" into a float";
}
int OpenReadOrThrow(const char *name) {
int ret = open(name, O_RDONLY);
if (ret == -1) throw ErrnoException(std::string("open ") + name);
if (ret == -1) UTIL_THROW(ErrnoException, "in open (" << name << ") for reading");
return ret;
}
off_t SizeOrThrow(int fd, const char *name) {
off_t SizeFile(int fd) {
struct stat sb;
if (fstat(fd, &sb) == -1) throw ErrnoException(std::string("stat ") + name);
if (fstat(fd, &sb) == -1 || (!sb.st_size && !S_ISREG(sb.st_mode))) return kBadSize;
return sb.st_size;
}
} // namespace
ParseNumberException::ParseNumberException(StringPiece value) throw() {
what_ = "Could not parse \"";
what_.append(value.data(), value.length());
what_ += "\" into a float.";
FilePiece::FilePiece(const char *name, std::ostream *show_progress, off_t min_buffer) :
file_(OpenReadOrThrow(name)), total_size_(SizeFile(file_.get())), page_(sysconf(_SC_PAGE_SIZE)),
progress_(total_size_ == kBadSize ? NULL : show_progress, std::string("Reading ") + name, total_size_) {
Initialize(name, show_progress, min_buffer);
}
FilePiece::FilePiece(const char *name, off_t min_buffer) :
file_(OpenOrThrow(name)), total_size_(SizeOrThrow(file_.get(), name)), page_(sysconf(_SC_PAGE_SIZE)) {
FilePiece::FilePiece(const char *name, int fd, std::ostream *show_progress, off_t min_buffer) :
file_(fd), total_size_(SizeFile(file_.get())), page_(sysconf(_SC_PAGE_SIZE)),
progress_(total_size_ == kBadSize ? NULL : show_progress, std::string("Reading ") + name, total_size_) {
Initialize(name, show_progress, min_buffer);
}
void FilePiece::Initialize(const char *name, std::ostream *show_progress, off_t min_buffer) {
if (total_size_ == kBadSize) {
fallback_to_read_ = true;
if (show_progress)
*show_progress << "File " << name << " isn't normal. Using slower read() instead of mmap(). No progress bar." << std::endl;
} else {
fallback_to_read_ = false;
}
default_map_size_ = page_ * std::max<off_t>((min_buffer / page_ + 1), 2);
position_ = NULL;
position_end_ = NULL;
mapped_offset_ = data_.begin() - position_end_;
mapped_offset_ = 0;
at_end_ = false;
Shift();
}
@ -49,6 +70,7 @@ float FilePiece::ReadFloat() throw(EndOfFileException, ParseNumberException) {
SkipSpaces();
while (last_space_ < position_) {
if (at_end_) {
// Hallucinate a null off the end of the file.
std::string buffer(position_, position_end_);
char *end;
float ret = std::strtof(buffer.c_str(), &end);
@ -108,30 +130,95 @@ StringPiece FilePiece::ReadLine(char delim) throw (EndOfFileException) {
void FilePiece::Shift() throw(EndOfFileException) {
if (at_end_) throw EndOfFileException();
off_t desired_begin = position_ - data_.begin() + mapped_offset_;
off_t ignore = desired_begin % page_;
// Duplicate request for Shift means give more data.
if (position_ == data_.begin() + ignore) {
default_map_size_ *= 2;
}
mapped_offset_ = desired_begin - ignore;
progress_.Set(desired_begin);
if (!fallback_to_read_) MMapShift(desired_begin);
// Notice an mmap failure might set the fallback.
if (fallback_to_read_) ReadShift(desired_begin);
// The normal operation of this loop is to run once. However, it may run
// multiple times if we can't find a newline character.
off_t mapped_size;
if (default_map_size_ >= total_size_ - mapped_offset_) {
at_end_ = true;
mapped_size = total_size_ - mapped_offset_;
} else {
mapped_size = default_map_size_;
}
data_.reset();
data_.reset(mmap(NULL, mapped_size, PROT_READ, MAP_PRIVATE, *file_, mapped_offset_), mapped_size);
if (data_.get() == MAP_FAILED) throw ErrnoException("mmap language model file for reading");
position_ = data_.begin() + ignore;
position_end_ = data_.begin() + mapped_size;
for (last_space_ = position_end_ - 1; last_space_ >= position_; --last_space_) {
if (isspace(*last_space_)) break;
}
}
void FilePiece::MMapShift(off_t desired_begin) throw() {
// Use mmap.
off_t ignore = desired_begin % page_;
// Duplicate request for Shift means give more data.
if (position_ == data_.begin() + ignore) {
default_map_size_ *= 2;
}
// Local version so that in case of failure it doesn't overwrite the class variable.
off_t mapped_offset = desired_begin - ignore;
off_t mapped_size;
if (default_map_size_ >= static_cast<size_t>(total_size_ - mapped_offset)) {
at_end_ = true;
mapped_size = total_size_ - mapped_offset;
} else {
mapped_size = default_map_size_;
}
// Forcibly clear the existing mmap first.
data_.reset();
data_.reset(mmap(NULL, mapped_size, PROT_READ, MAP_PRIVATE, *file_, mapped_offset), mapped_size, scoped_memory::MMAP_ALLOCATED);
if (data_.get() == MAP_FAILED) {
fallback_to_read_ = true;
if (desired_begin) {
if (((off_t)-1) == lseek(*file_, desired_begin, SEEK_SET)) UTIL_THROW(ErrnoException, "mmap failed even though it worked before. lseek failed too, so using read isn't an option either.");
}
return;
}
mapped_offset_ = mapped_offset;
position_ = data_.begin() + ignore;
position_end_ = data_.begin() + mapped_size;
}
void FilePiece::ReadShift(off_t desired_begin) throw() {
assert(fallback_to_read_);
if (data_.source() != scoped_memory::MALLOC_ALLOCATED) {
// First call.
data_.reset();
data_.reset(malloc(default_map_size_), default_map_size_, scoped_memory::MALLOC_ALLOCATED);
if (!data_.get()) UTIL_THROW(ErrnoException, "malloc failed for " << default_map_size_);
position_ = data_.begin();
position_end_ = position_;
}
// Bytes [data_.begin(), position_) have been consumed.
// Bytes [position_, position_end_) have been read into the buffer.
// Start at the beginning of the buffer if there's nothing useful in it.
if (position_ == position_end_) {
mapped_offset_ += (position_end_ - data_.begin());
position_ = data_.begin();
position_end_ = position_;
}
std::size_t already_read = position_end_ - data_.begin();
if (already_read == default_map_size_) {
if (position_ == data_.begin()) {
// Buffer too small.
std::size_t valid_length = position_end_ - position_;
default_map_size_ *= 2;
data_.call_realloc(default_map_size_);
if (!data_.get()) UTIL_THROW(ErrnoException, "realloc failed for " << default_map_size_);
position_ = data_.begin();
position_end_ = position_ + valid_length;
} else {
size_t moving = position_end_ - position_;
memmove(data_.get(), position_, moving);
position_ = data_.begin();
position_end_ = position_ + moving;
already_read = moving;
}
}
ssize_t read_return = read(file_.get(), static_cast<char*>(data_.get()) + already_read, default_map_size_ - already_read);
if (read_return == -1) UTIL_THROW(ErrnoException, "read failed");
if (read_return == 0) at_end_ = true;
position_end_ += read_return;
}
} // namespace util


@ -1,41 +1,41 @@
#ifndef UTIL_FILE_PIECE__
#define UTIL_FILE_PIECE__
#include "util/ersatz_progress.hh"
#include "util/exception.hh"
#include "util/scoped.hh"
#include "util/string_piece.hh"
#include <exception>
#include <string>
#include <cstddef>
namespace util {
class EndOfFileException : public std::exception {
class EndOfFileException : public Exception {
public:
EndOfFileException() throw() {}
~EndOfFileException() throw() {}
const char *what() const throw() { return "End of file."; }
EndOfFileException() throw();
~EndOfFileException() throw();
};
class ParseNumberException : public std::exception {
class ParseNumberException : public Exception {
public:
explicit ParseNumberException(StringPiece value) throw();
~ParseNumberException() throw() {}
const char *what() const throw() { return what_.c_str(); }
private:
std::string what_;
};
int OpenReadOrThrow(const char *name);
// Return value for SizeFile when it can't size properly.
const off_t kBadSize = -1;
off_t SizeFile(int fd);
class FilePiece {
public:
// 32 MB default.
explicit FilePiece(const char *file, off_t min_buffer = 33554432);
explicit FilePiece(const char *file, std::ostream *show_progress = NULL, off_t min_buffer = 33554432);
// Takes ownership of fd. name is used for messages.
explicit FilePiece(const char *name, int fd, std::ostream *show_progress = NULL, off_t min_buffer = 33554432);
char get() throw(EndOfFileException) {
if (position_ == position_end_) Shift();
@ -55,8 +55,19 @@ class FilePiece {
float ReadFloat() throw(EndOfFileException, ParseNumberException);
void SkipSpaces() throw (EndOfFileException);
off_t Offset() const {
return position_ - data_.begin() + mapped_offset_;
}
// Only for testing.
void ForceFallbackToRead() {
fallback_to_read_ = true;
}
private:
void Initialize(const char *name, std::ostream *show_progress, off_t min_buffer);
StringPiece Consume(const char *to) {
StringPiece ret(position_, to - position_);
position_ = to;
@ -66,6 +77,9 @@ class FilePiece {
const char *FindDelimiterOrEOF() throw(EndOfFileException);
void Shift() throw (EndOfFileException);
// Backends to Shift().
void MMapShift(off_t desired_begin) throw ();
void ReadShift(off_t desired_begin) throw ();
const char *position_, *last_space_, *position_end_;
@ -73,13 +87,16 @@ class FilePiece {
const off_t total_size_;
const off_t page_;
off_t default_map_size_;
size_t default_map_size_;
off_t mapped_offset_;
// Order matters: file_ should always be destroyed after this.
scoped_mmap data_;
scoped_memory data_;
bool at_end_;
bool fallback_to_read_;
ErsatzProgress progress_;
};
} // namespace util


@ -8,16 +8,32 @@
namespace util {
namespace {
BOOST_AUTO_TEST_CASE(ReadLine) {
std::fstream ref("file_piece.hh", std::ios::in);
FilePiece test("file_piece.hh", 1);
/* mmap implementation */
BOOST_AUTO_TEST_CASE(MMapLine) {
std::fstream ref("file_piece.cc", std::ios::in);
FilePiece test("file_piece.cc", NULL, 1);
std::string ref_line;
while (getline(ref, ref_line)) {
StringPiece test_line(test.ReadLine());
if (test_line != ref_line) {
std::cerr << test_line.size() << " " << ref_line.size() << std::endl;
// I submitted a bug report to ICU: http://bugs.icu-project.org/trac/ticket/7924
if (!test_line.empty() || !ref_line.empty()) {
BOOST_CHECK_EQUAL(ref_line, test_line);
}
}
}
/* read() implementation */
BOOST_AUTO_TEST_CASE(ReadLine) {
std::fstream ref("file_piece.cc", std::ios::in);
FilePiece test("file_piece.cc", NULL, 1);
test.ForceFallbackToRead();
std::string ref_line;
while (getline(ref, ref_line)) {
StringPiece test_line(test.ReadLine());
// I submitted a bug report to ICU: http://bugs.icu-project.org/trac/ticket/7924
if (!test_line.empty() || !ref_line.empty()) {
BOOST_CHECK_EQUAL(ref_line, test_line);
}
}
}

kenlm/util/joint_sort.hh Normal file

@ -0,0 +1,145 @@
#ifndef UTIL_JOINT_SORT__
#define UTIL_JOINT_SORT__
/* A terrifying amount of C++ to coax std::sort into sorting one range while
* also permuting another range the same way.
*/
#include "util/proxy_iterator.hh"
#include <algorithm>
#include <functional>
#include <iostream>
namespace util {
namespace detail {
template <class KeyIter, class ValueIter> class JointProxy;
template <class KeyIter, class ValueIter> class JointIter {
public:
JointIter() {}
JointIter(const KeyIter &key_iter, const ValueIter &value_iter) : key_(key_iter), value_(value_iter) {}
bool operator==(const JointIter<KeyIter, ValueIter> &other) const { return key_ == other.key_; }
bool operator<(const JointIter<KeyIter, ValueIter> &other) const { return (key_ < other.key_); }
std::ptrdiff_t operator-(const JointIter<KeyIter, ValueIter> &other) const { return key_ - other.key_; }
JointIter<KeyIter, ValueIter> &operator+=(std::ptrdiff_t amount) {
key_ += amount;
value_ += amount;
return *this;
}
void swap(JointIter &other) {
std::swap(key_, other.key_);
std::swap(value_, other.value_);
}
private:
friend class JointProxy<KeyIter, ValueIter>;
KeyIter key_;
ValueIter value_;
};
template <class KeyIter, class ValueIter> class JointProxy {
private:
typedef JointIter<KeyIter, ValueIter> InnerIterator;
public:
typedef struct {
typename std::iterator_traits<KeyIter>::value_type key;
typename std::iterator_traits<ValueIter>::value_type value;
const typename std::iterator_traits<KeyIter>::value_type &GetKey() const { return key; }
} value_type;
JointProxy(const KeyIter &key_iter, const ValueIter &value_iter) : inner_(key_iter, value_iter) {}
JointProxy(const JointProxy<KeyIter, ValueIter> &other) : inner_(other.inner_) {}
operator const value_type() const {
value_type ret;
ret.key = *inner_.key_;
ret.value = *inner_.value_;
return ret;
}
JointProxy &operator=(const JointProxy &other) {
*inner_.key_ = *other.inner_.key_;
*inner_.value_ = *other.inner_.value_;
return *this;
}
JointProxy &operator=(const value_type &other) {
*inner_.key_ = other.key;
*inner_.value_ = other.value;
return *this;
}
typename std::iterator_traits<KeyIter>::reference GetKey() const {
return *(inner_.key_);
}
void swap(JointProxy<KeyIter, ValueIter> &other) {
std::swap(*inner_.key_, *other.inner_.key_);
std::swap(*inner_.value_, *other.inner_.value_);
}
private:
friend class ProxyIterator<JointProxy<KeyIter, ValueIter> >;
InnerIterator &Inner() { return inner_; }
const InnerIterator &Inner() const { return inner_; }
InnerIterator inner_;
};
template <class Proxy, class Less> class LessWrapper : public std::binary_function<const typename Proxy::value_type &, const typename Proxy::value_type &, bool> {
public:
explicit LessWrapper(const Less &less) : less_(less) {}
bool operator()(const Proxy &left, const Proxy &right) const {
return less_(left.GetKey(), right.GetKey());
}
bool operator()(const Proxy &left, const typename Proxy::value_type &right) const {
return less_(left.GetKey(), right.GetKey());
}
bool operator()(const typename Proxy::value_type &left, const Proxy &right) const {
return less_(left.GetKey(), right.GetKey());
}
bool operator()(const typename Proxy::value_type &left, const typename Proxy::value_type &right) const {
return less_(left.GetKey(), right.GetKey());
}
private:
const Less less_;
};
} // namespace detail
template <class KeyIter, class ValueIter, class Less> void JointSort(const KeyIter &key_begin, const KeyIter &key_end, const ValueIter &value_begin, const Less &less) {
ProxyIterator<detail::JointProxy<KeyIter, ValueIter> > full_begin(detail::JointProxy<KeyIter, ValueIter>(key_begin, value_begin));
detail::LessWrapper<detail::JointProxy<KeyIter, ValueIter>, Less> less_wrap(less);
std::sort(full_begin, full_begin + (key_end - key_begin), less_wrap);
}
template <class KeyIter, class ValueIter> void JointSort(const KeyIter &key_begin, const KeyIter &key_end, const ValueIter &value_begin) {
JointSort(key_begin, key_end, value_begin, std::less<typename std::iterator_traits<KeyIter>::value_type>());
}
} // namespace util
namespace std {
template <class KeyIter, class ValueIter> void swap(util::detail::JointIter<KeyIter, ValueIter> &left, util::detail::JointIter<KeyIter, ValueIter> &right) {
left.swap(right);
}
template <class KeyIter, class ValueIter> void swap(util::detail::JointProxy<KeyIter, ValueIter> &left, util::detail::JointProxy<KeyIter, ValueIter> &right) {
left.swap(right);
}
} // namespace std
#endif // UTIL_JOINT_SORT__


@ -0,0 +1,50 @@
#include "util/joint_sort.hh"
#define BOOST_TEST_MODULE JointSortTest
#include <boost/test/unit_test.hpp>
namespace util { namespace {
BOOST_AUTO_TEST_CASE(just_flip) {
char keys[2];
int values[2];
keys[0] = 1; values[0] = 327;
keys[1] = 0; values[1] = 87897;
JointSort<char *, int *>(keys + 0, keys + 2, values + 0);
BOOST_CHECK_EQUAL(0, keys[0]);
BOOST_CHECK_EQUAL(87897, values[0]);
BOOST_CHECK_EQUAL(1, keys[1]);
BOOST_CHECK_EQUAL(327, values[1]);
}
BOOST_AUTO_TEST_CASE(three) {
char keys[3];
int values[3];
keys[0] = 1; values[0] = 327;
keys[1] = 2; values[1] = 87897;
keys[2] = 0; values[2] = 10;
JointSort<char *, int *>(keys + 0, keys + 3, values + 0);
BOOST_CHECK_EQUAL(0, keys[0]);
BOOST_CHECK_EQUAL(1, keys[1]);
BOOST_CHECK_EQUAL(2, keys[2]);
}
BOOST_AUTO_TEST_CASE(char_int) {
char keys[4];
int values[4];
keys[0] = 3; values[0] = 327;
keys[1] = 1; values[1] = 87897;
keys[2] = 2; values[2] = 10;
keys[3] = 0; values[3] = 24347;
JointSort<char *, int *>(keys + 0, keys + 4, values + 0);
BOOST_CHECK_EQUAL(0, keys[0]);
BOOST_CHECK_EQUAL(24347, values[0]);
BOOST_CHECK_EQUAL(1, keys[1]);
BOOST_CHECK_EQUAL(87897, values[1]);
BOOST_CHECK_EQUAL(2, keys[2]);
BOOST_CHECK_EQUAL(10, values[2]);
BOOST_CHECK_EQUAL(3, keys[3]);
BOOST_CHECK_EQUAL(327, values[3]);
}
}} // namespace anonymous util


@ -0,0 +1,122 @@
#ifndef UTIL_KEY_VALUE_PACKING__
#define UTIL_KEY_VALUE_PACKING__
/* Why such a general interface? I'm planning on doing bit-level packing. */
#include <algorithm>
#include <cstddef>
#include <cstring>
#include <inttypes.h>
namespace util {
template <class Key, class Value> struct Entry {
Key key;
Value value;
const Key &GetKey() const { return key; }
const Value &GetValue() const { return value; }
void Set(const Key &key_in, const Value &value_in) {
SetKey(key_in);
SetValue(value_in);
}
void SetKey(const Key &key_in) { key = key_in; }
void SetValue(const Value &value_in) { value = value_in; }
bool operator<(const Entry<Key, Value> &other) const { return GetKey() < other.GetKey(); }
};
// And now for a brief interlude to specialize std::swap.
} // namespace util
namespace std {
template <class Key, class Value> void swap(util::Entry<Key, Value> &first, util::Entry<Key, Value> &second) {
swap(first.key, second.key);
swap(first.value, second.value);
}
}// namespace std
namespace util {
template <class KeyT, class ValueT> class AlignedPacking {
public:
typedef KeyT Key;
typedef ValueT Value;
public:
static const std::size_t kBytes = sizeof(Entry<Key, Value>);
static const std::size_t kBits = kBytes * 8;
typedef Entry<Key, Value> * MutableIterator;
typedef const Entry<Key, Value> * ConstIterator;
typedef const Entry<Key, Value> & ConstReference;
static MutableIterator FromVoid(void *start) {
return reinterpret_cast<MutableIterator>(start);
}
static Entry<Key, Value> Make(const Key &key, const Value &value) {
Entry<Key, Value> ret;
ret.Set(key, value);
return ret;
}
};
template <class KeyT, class ValueT> class ByteAlignedPacking {
public:
typedef KeyT Key;
typedef ValueT Value;
private:
#pragma pack(push)
#pragma pack(1)
struct RawEntry {
Key key;
Value value;
const Key &GetKey() const { return key; }
const Value &GetValue() const { return value; }
void Set(const Key &key_in, const Value &value_in) {
SetKey(key_in);
SetValue(value_in);
}
void SetKey(const Key &key_in) { key = key_in; }
void SetValue(const Value &value_in) { value = value_in; }
bool operator<(const RawEntry &other) const { return GetKey() < other.GetKey(); }
};
#pragma pack(pop)
friend void std::swap<>(RawEntry&, RawEntry&);
public:
typedef RawEntry *MutableIterator;
typedef const RawEntry *ConstIterator;
typedef const RawEntry &ConstReference;
static const std::size_t kBytes = sizeof(RawEntry);
static const std::size_t kBits = kBytes * 8;
static MutableIterator FromVoid(void *start) {
return MutableIterator(reinterpret_cast<RawEntry*>(start));
}
static RawEntry Make(const Key &key, const Value &value) {
RawEntry ret;
ret.Set(key, value);
return ret;
}
};
} // namespace util
namespace std {
template <class Key, class Value> void swap(
typename util::ByteAlignedPacking<Key, Value>::RawEntry &first,
typename util::ByteAlignedPacking<Key, Value>::RawEntry &second) {
swap(first.key, second.key);
swap(first.value, second.value);
}
}// namespace std
#endif // UTIL_KEY_VALUE_PACKING__


@ -0,0 +1,75 @@
#include "util/key_value_packing.hh"
#include <boost/random/mersenne_twister.hpp>
#include <boost/random/uniform_int.hpp>
#include <boost/random/variate_generator.hpp>
#include <boost/scoped_array.hpp>
#define BOOST_TEST_MODULE KeyValueStoreTest
#include <boost/test/unit_test.hpp>
#include <limits>
#include <stdlib.h>
namespace util {
namespace {
BOOST_AUTO_TEST_CASE(basic_in_out) {
typedef ByteAlignedPacking<uint64_t, unsigned char> Packing;
void *backing = malloc(Packing::kBytes * 2);
Packing::MutableIterator i(Packing::FromVoid(backing));
i->SetKey(10);
BOOST_CHECK_EQUAL(10, i->GetKey());
i->SetValue(3);
BOOST_CHECK_EQUAL(3, i->GetValue());
++i;
i->SetKey(5);
BOOST_CHECK_EQUAL(5, i->GetKey());
i->SetValue(42);
BOOST_CHECK_EQUAL(42, i->GetValue());
Packing::ConstIterator c(i);
BOOST_CHECK_EQUAL(5, c->GetKey());
--c;
BOOST_CHECK_EQUAL(10, c->GetKey());
BOOST_CHECK_EQUAL(42, i->GetValue());
BOOST_CHECK_EQUAL(5, i->GetKey());
free(backing);
}
BOOST_AUTO_TEST_CASE(simple_sort) {
typedef ByteAlignedPacking<uint64_t, unsigned char> Packing;
char foo[Packing::kBytes * 4];
Packing::MutableIterator begin(Packing::FromVoid(foo));
Packing::MutableIterator i = begin;
i->SetKey(0); ++i;
i->SetKey(2); ++i;
i->SetKey(3); ++i;
i->SetKey(1); ++i;
std::sort(begin, i);
BOOST_CHECK_EQUAL(0, begin[0].GetKey());
BOOST_CHECK_EQUAL(1, begin[1].GetKey());
BOOST_CHECK_EQUAL(2, begin[2].GetKey());
BOOST_CHECK_EQUAL(3, begin[3].GetKey());
}
BOOST_AUTO_TEST_CASE(big_sort) {
typedef ByteAlignedPacking<uint64_t, unsigned char> Packing;
boost::scoped_array<char> memory(new char[Packing::kBytes * 1000]);
Packing::MutableIterator begin(Packing::FromVoid(memory.get()));
boost::mt19937 rng;
boost::uniform_int<uint64_t> range(0, std::numeric_limits<uint64_t>::max());
boost::variate_generator<boost::mt19937&, boost::uniform_int<uint64_t> > gen(rng, range);
for (size_t i = 0; i < 1000; ++i) {
(begin + i)->SetKey(gen());
}
std::sort(begin, begin + 1000);
for (size_t i = 0; i < 999; ++i) {
BOOST_CHECK(begin[i] < begin[i+1]);
}
}
} // namespace
} // namespace util


@ -1,9 +1,18 @@
// Downloaded from http://sites.google.com/site/murmurhash/ which says "All
// code is released to the public domain. For business purposes, Murmurhash is
// under the MIT license."
/* Downloaded from http://sites.google.com/site/murmurhash/ which says "All
* code is released to the public domain. For business purposes, Murmurhash is
* under the MIT license."
* This is modified from the original:
* ULL tag on 0xc6a4a7935bd1e995 so this will compile on 32-bit.
* length changed to unsigned int.
* placed in namespace util
* add MurmurHashNative
* default option = 0 for seed
*/
#include "util/murmur_hash.hh"
namespace util {
//-----------------------------------------------------------------------------
// MurmurHash2, 64-bit versions, by Austin Appleby
@ -12,9 +21,9 @@
// 64-bit hash for 64-bit platforms
uint64_t MurmurHash64A ( const void * key, int len, unsigned int seed )
uint64_t MurmurHash64A ( const void * key, std::size_t len, unsigned int seed )
{
const uint64_t m = 0xc6a4a7935bd1e995;
const uint64_t m = 0xc6a4a7935bd1e995ULL;
const int r = 47;
uint64_t h = seed ^ (len * m);
@ -58,7 +67,7 @@ uint64_t MurmurHash64A ( const void * key, int len, unsigned int seed )
// 64-bit hash for 32-bit platforms
uint64_t MurmurHash64B ( const void * key, int len, unsigned int seed )
uint64_t MurmurHash64B ( const void * key, std::size_t len, unsigned int seed )
{
const unsigned int m = 0x5bd1e995;
const int r = 24;
@ -107,4 +116,14 @@ uint64_t MurmurHash64B ( const void * key, int len, unsigned int seed )
h = (h << 32) | h2;
return h;
}
}
uint64_t MurmurHashNative(const void * key, std::size_t len, unsigned int seed) {
if (sizeof(int) == 4) {
return MurmurHash64B(key, len, seed);
} else {
return MurmurHash64A(key, len, seed);
}
}
} // namespace util


@ -1,8 +1,14 @@
#ifndef UTIL_MURMUR_HASH_H__
#define UTIL_MURMUR_HASH_H__
#ifndef UTIL_MURMUR_HASH__
#define UTIL_MURMUR_HASH__
#include <cstddef>
#include <stdint.h>
uint64_t MurmurHash64A (const void * key, int len, unsigned int seed);
uint64_t MurmurHash64B (const void * key, int len, unsigned int seed);
namespace util {
#endif // UTIL_MURMUR_HASH_H__
uint64_t MurmurHash64A(const void * key, std::size_t len, unsigned int seed = 0);
uint64_t MurmurHash64B(const void * key, std::size_t len, unsigned int seed = 0);
uint64_t MurmurHashNative(const void * key, std::size_t len, unsigned int seed = 0);
} // namespace util
#endif // UTIL_MURMUR_HASH__


@ -4,7 +4,8 @@
#include <algorithm>
#include <cstddef>
#include <functional>
#include <utility>
#include <assert.h>
namespace util {
@ -15,148 +16,80 @@ namespace util {
* serialize these to disk and load them quickly.
* Uses linear probing to find value.
* Only insert and lookup operations.
* Generic find operation.
*
*/
template <class ValueT, class HashT, class EqualT = std::equal_to<ValueT>, class PointerT = const ValueT *> class ReadProbingHashTable {
template <class PackingT, class HashT, class EqualT = std::equal_to<typename PackingT::Key> > class ProbingHashTable {
public:
typedef ValueT Value;
typedef PackingT Packing;
typedef typename Packing::Key Key;
typedef typename Packing::MutableIterator MutableIterator;
typedef typename Packing::ConstIterator ConstIterator;
typedef HashT Hash;
typedef EqualT Equal;
ReadProbingHashTable() {}
static std::size_t Size(std::size_t entries, float multiplier) {
return std::max(entries + 1, static_cast<std::size_t>(multiplier * static_cast<float>(entries))) * Packing::kBytes;
}
ReadProbingHashTable(
PointerT start,
std::size_t buckets,
const Value &invalid,
const Hash &hash_func = Hash(),
const Equal &equal_func = Equal())
: start_(start), end_(start + buckets), buckets_(buckets), invalid_(invalid), hash_(hash_func), equal_(equal_func) {}
// Must be assigned to later.
ProbingHashTable()
#ifdef DEBUG
: initialized_(false), entries_(0)
#endif
{}
template <class Key> const Value *Find(const Key &key) const {
const Value *it = start_ + (hash_(key) % buckets_);
while (true) {
if (equal_(*it, invalid_)) return NULL;
if (equal_(*it, key)) return it;
++it;
if (it == end_) it = start_;
ProbingHashTable(void *start, std::size_t allocated, const Key &invalid = Key(), const Hash &hash_func = Hash(), const Equal &equal_func = Equal())
: begin_(Packing::FromVoid(start)),
buckets_(allocated / Packing::kBytes),
end_(begin_ + (allocated / Packing::kBytes)),
invalid_(invalid),
hash_(hash_func),
equal_(equal_func)
#ifdef DEBUG
, initialized_(true),
entries_(0)
#endif
{}
template <class T> void Insert(const T &t) {
#ifdef DEBUG
assert(initialized_);
assert(++entries_ < buckets_);
#endif
for (MutableIterator i(begin_ + (hash_(t.GetKey()) % buckets_));;) {
if (equal_(i->GetKey(), invalid_)) { *i = t; return; }
if (++i == end_) { i = begin_; }
}
}
protected:
PointerT start_, end_;
std::size_t buckets_;
Value invalid_;
Hash hash_;
Equal equal_;
};
template <class ValueT, class HashT, class EqualT = std::equal_to<ValueT> > class ProbingHashTable : public ReadProbingHashTable<ValueT, HashT, EqualT, ValueT *> {
private:
typedef ReadProbingHashTable<ValueT, HashT, EqualT, ValueT *> P;
public:
ProbingHashTable() {}
// Memory should be initialized buckets copies of invalid.
ProbingHashTable(
typename P::Value *start,
std::size_t buckets,
const typename P::Value &invalid,
const typename P::Hash &hash_func = typename P::Hash(),
const typename P::Equal &equal_func = typename P::Equal())
: P(start, buckets, invalid, hash_func, equal_func) {}
std::pair<const typename P::Value *, bool> Insert(const typename P::Value &value) {
typename P::Value *it = P::start_ + (P::hash_(value) % P::buckets_);
while (!P::equal_(*it, P::invalid_)) {
if (P::equal_(*it, value)) return std::pair<const typename P::Value*, bool>(it, false);
++it;
if (it == P::end_) it = P::start_;
}
*it = value;
return std::pair<const typename P::Value*, bool>(it, true);
}
const typename P::Value *InsertAlreadyUnique(const typename P::Value &value) {
typename P::Value *it = P::start_ + (P::hash_(value) % P::buckets_);
while (!P::equal_(*it, P::invalid_)) {
++it;
if (it == P::end_) it = P::start_;
}
*it = value;
return it;
}
};
// Default configuration of the above: table from keys to values.
template <class KeyT, class ValueT, class HashT, class EqualsT = std::equal_to<KeyT> > class ProbingMap {
public:
typedef KeyT Key;
typedef ValueT Value;
typedef HashT Hash;
typedef EqualsT Equals;
static std::size_t Size(float multiplier, std::size_t entries) {
return std::max(entries + 1, static_cast<std::size_t>(multiplier * static_cast<float>(entries))) * sizeof(Entry);
}
ProbingMap() {}
ProbingMap(float multiplier, char *start, std::size_t entries, const Hash &hasher = Hash(), const Equals &equals = Equals())
: table_(
reinterpret_cast<Entry*>(start),
Size(multiplier, entries) / sizeof(Entry),
Entry(),
HashKeyOnly(hasher),
EqualsKeyOnly(equals)) {}
bool Find(const Key &key, const Value *&value) const {
const Entry *e = table_.Find(key);
if (!e) return false;
value = &e->value;
return true;
}
void Insert(const Key &key, const Value &value) {
Entry e;
e.key = key;
e.value = value;
table_.Insert(e);
}
void FinishedInserting() {}
void LoadedBinary() {}
template <class Key> bool Find(const Key key, ConstIterator &out) const {
#ifdef DEBUG
assert(initialized_);
#endif
for (ConstIterator i(begin_ + (hash_(key) % buckets_));;) {
Key got(i->GetKey());
if (equal_(got, key)) { out = i; return true; }
if (equal_(got, invalid_)) { return false; }
if (++i == end_) { i = begin_; }
}
}
private:
struct Entry {
Key key;
Value value;
};
class HashKeyOnly : public std::unary_function<const Entry &, std::size_t> {
public:
HashKeyOnly() {}
explicit HashKeyOnly(const Hash &hasher) : hasher_(hasher) {}
std::size_t operator()(const Entry &e) const { return hasher_(e.key); }
std::size_t operator()(const Key value) const { return hasher_(value); }
private:
Hash hasher_;
};
struct EqualsKeyOnly : public std::binary_function<const Entry &, const Entry &, bool> {
public:
EqualsKeyOnly() {}
explicit EqualsKeyOnly(const Equals &equals) : equals_(equals) {}
bool operator()(const Entry &a, const Entry &b) const { return equals_(a.key, b.key); }
bool operator()(const Entry &a, const Key k) const { return equals_(a.key, k); }
private:
Equals equals_;
};
ProbingHashTable<Entry, HashKeyOnly, EqualsKeyOnly> table_;
MutableIterator begin_;
std::size_t buckets_;
MutableIterator end_;
Key invalid_;
Hash hash_;
Equal equal_;
#ifdef DEBUG
bool initialized_;
std::size_t entries_;
#endif
};
} // namespace util
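
The `Insert` logic above is plain linear probing with a sentinel "invalid" value: hash to a bucket, then walk forward (wrapping at the end) until an empty slot or a match is found. A minimal standalone sketch of that technique (not the class above, which is templated over value, hash, and equality):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Minimal linear-probing table over ints, using -1 style sentinel slots the
// way ProbingHashTable uses its "invalid" value.
struct MiniProbing {
  MiniProbing(std::size_t buckets, int invalid)
      : table_(buckets, invalid), invalid_(invalid) {}

  // Insert value; returns false if it was already present.
  bool Insert(int value) {
    std::size_t i = Hash(value) % table_.size();
    while (table_[i] != invalid_) {
      if (table_[i] == value) return false;  // already there
      if (++i == table_.size()) i = 0;       // wrap around
    }
    table_[i] = value;
    return true;
  }

  bool Find(int value) const {
    std::size_t i = Hash(value) % table_.size();
    while (table_[i] != invalid_) {
      if (table_[i] == value) return true;
      if (++i == table_.size()) i = 0;
    }
    return false;  // hit an empty slot: not present
  }

 private:
  static std::size_t Hash(int v) {
    return static_cast<std::size_t>(v) * 2654435761u;  // Knuth multiplicative hash
  }
  std::vector<int> table_;
  int invalid_;
};
```

As in the class above, correctness depends on the caller never inserting the sentinel value and never filling the table completely (otherwise the probe loop would not terminate).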

View File

@ -1,5 +1,7 @@
#include "util/probing_hash_table.hh"
#include "util/key_value_packing.hh"
#define BOOST_TEST_MODULE ProbingHashTableTest
#include <boost/test/unit_test.hpp>
#include <boost/functional/hash.hpp>
@ -7,14 +9,21 @@
namespace util {
namespace {
typedef AlignedPacking<char, uint64_t> Packing;
typedef ProbingHashTable<Packing, boost::hash<char> > Table;
BOOST_AUTO_TEST_CASE(simple) {
char mem[10];
char mem[Table::Size(10, 1.2)];
memset(mem, 0, sizeof(mem));
ProbingHashTable<char, boost::hash<char> > table(mem, 10, 0);
BOOST_CHECK_EQUAL((char*)NULL, table.Find(2));
BOOST_CHECK_EQUAL((char)2, *table.Insert(2).first);
BOOST_REQUIRE(table.Find(2));
BOOST_CHECK_EQUAL((char)2, *table.Find(2));
Table table(mem, sizeof(mem));
Packing::ConstIterator i = Packing::ConstIterator();
BOOST_CHECK(!table.Find(2, i));
table.Insert(Packing::Make(3, 328920));
BOOST_REQUIRE(table.Find(3, i));
BOOST_CHECK_EQUAL(3, i->GetKey());
BOOST_CHECK_EQUAL(static_cast<uint64_t>(328920), i->GetValue());
BOOST_CHECK(!table.Find(2, i));
}
} // namespace

View File

@ -0,0 +1,94 @@
#ifndef UTIL_PROXY_ITERATOR__
#define UTIL_PROXY_ITERATOR__
#include <cstddef>
#include <iterator>
/* This is a RandomAccessIterator that uses a proxy to access the underlying
* data. Useful for packing data at bit offsets but still using STL
* algorithms.
*
* Normally I would use boost::iterator_facade but some people are too lazy to
* install boost and still want to use my language model. It's amazing how
* many operators an iterator has.
*
* The Proxy needs to provide:
* class InnerIterator;
* InnerIterator &Inner();
* const InnerIterator &Inner() const;
*
* InnerIterator has to implement:
* operator==(InnerIterator)
* operator<(InnerIterator)
* operator+=(std::ptrdiff_t)
* operator-(InnerIterator)
* and of course whatever Proxy needs to dereference it.
*
* It's also a good idea to specialize std::swap for Proxy.
*/
namespace util {
template <class Proxy> class ProxyIterator {
private:
// Self.
typedef ProxyIterator<Proxy> S;
typedef typename Proxy::InnerIterator InnerIterator;
public:
typedef std::random_access_iterator_tag iterator_category;
typedef typename Proxy::value_type value_type;
typedef std::ptrdiff_t difference_type;
typedef Proxy reference;
typedef Proxy * pointer;
ProxyIterator() {}
// For cast from non const to const.
template <class AlternateProxy> ProxyIterator(const ProxyIterator<AlternateProxy> &in) : p_(*in) {}
explicit ProxyIterator(const Proxy &p) : p_(p) {}
// p_'s operator= does value copying, but here we want iterator copying.
S &operator=(const S &other) {
I() = other.I();
return *this;
}
bool operator==(const S &other) const { return I() == other.I(); }
bool operator!=(const S &other) const { return !(*this == other); }
bool operator<(const S &other) const { return I() < other.I(); }
bool operator>(const S &other) const { return other < *this; }
bool operator<=(const S &other) const { return !(*this > other); }
bool operator>=(const S &other) const { return !(*this < other); }
S &operator++() { return *this += 1; }
S operator++(int) { S ret(*this); ++*this; return ret; }
S &operator+=(std::ptrdiff_t amount) { I() += amount; return *this; }
S operator+(std::ptrdiff_t amount) const { S ret(*this); ret += amount; return ret; }
S &operator--() { return *this -= 1; }
S operator--(int) { S ret(*this); --*this; return ret; }
S &operator-=(std::ptrdiff_t amount) { I() += (-amount); return *this; }
S operator-(std::ptrdiff_t amount) const { S ret(*this); ret -= amount; return ret; }
std::ptrdiff_t operator-(const S &other) const { return I() - other.I(); }
Proxy operator*() { return p_; }
const Proxy operator*() const { return p_; }
Proxy *operator->() { return &p_; }
const Proxy *operator->() const { return &p_; }
Proxy operator[](std::ptrdiff_t amount) const { return *(*this + amount); }
private:
InnerIterator &I() { return p_.Inner(); }
const InnerIterator &I() const { return p_.Inner(); }
Proxy p_;
};
template <class Proxy> ProxyIterator<Proxy> operator+(std::ptrdiff_t amount, const ProxyIterator<Proxy> &it) {
return it + amount;
}
} // namespace util
#endif // UTIL_PROXY_ITERATOR__

View File

@ -1,7 +1,9 @@
#include "util/scoped.hh"
#include <assert.h>
#include <err.h>
#include <sys/mman.h>
#include <stdlib.h>
#include <unistd.h>
namespace util {
@ -17,4 +19,33 @@ scoped_mmap::~scoped_mmap() {
}
}
void scoped_memory::reset(void *data, std::size_t size, Alloc source) {
switch(source_) {
case MMAP_ALLOCATED:
scoped_mmap(data_, size_);
break;
case ARRAY_ALLOCATED:
delete [] reinterpret_cast<char*>(data_);
break;
case MALLOC_ALLOCATED:
free(data_);
break;
case NONE_ALLOCATED:
break;
}
data_ = data;
size_ = size;
source_ = source;
}
void scoped_memory::call_realloc(std::size_t size) {
assert(source_ == MALLOC_ALLOCATED || source_ == NONE_ALLOCATED);
void *new_data = realloc(data_, size);
if (!new_data) {
reset();
} else {
reset(new_data, size, MALLOC_ALLOCATED);
}
}
} // namespace util
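
The `reset` switch above releases the old buffer with whichever deallocator matches how it was obtained. A minimal sketch of that pattern (simplified: no mmap case, no size tracking):

```cpp
#include <cstdlib>

// Minimal sketch of the scoped_memory idea: one owner that remembers how its
// buffer was allocated and releases it the matching way in reset() and the
// destructor.
class MiniScoped {
 public:
  enum Alloc { MALLOC_ALLOCATED, ARRAY_ALLOCATED, NONE_ALLOCATED };

  MiniScoped() : data_(NULL), source_(NONE_ALLOCATED) {}
  ~MiniScoped() { reset(NULL, NONE_ALLOCATED); }

  void *get() const { return data_; }

  void reset(void *data, Alloc source) {
    switch (source_) {  // free the old buffer the way it was allocated
      case MALLOC_ALLOCATED: free(data_); break;
      case ARRAY_ALLOCATED: delete [] static_cast<char*>(data_); break;
      case NONE_ALLOCATED: break;
    }
    data_ = data;
    source_ = source;
  }

 private:
  MiniScoped(const MiniScoped &);            // noncopyable, like the original
  MiniScoped &operator=(const MiniScoped &);
  void *data_;
  Alloc source_;
};
```
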

View File

@ -1,13 +1,13 @@
#ifndef UTIL_SCOPED_H__
#define UTIL_SCOPED_H__
#ifndef UTIL_SCOPED__
#define UTIL_SCOPED__
#include <boost/noncopyable.hpp>
/* Other scoped objects in the style of scoped_ptr. */
#include <cstddef>
namespace util {
template <class T, class R, R (*Free)(T*)> class scoped_thing : boost::noncopyable {
template <class T, class R, R (*Free)(T*)> class scoped_thing {
public:
explicit scoped_thing(T *c = static_cast<T*>(0)) : c_(c) {}
@ -26,9 +26,12 @@ template <class T, class R, R (*Free)(T*)> class scoped_thing : boost::noncopyab
private:
T *c_;
scoped_thing(const scoped_thing &);
scoped_thing &operator=(const scoped_thing &);
};
class scoped_fd : boost::noncopyable {
class scoped_fd {
public:
scoped_fd() : fd_(-1) {}
@ -45,12 +48,21 @@ class scoped_fd : boost::noncopyable {
int operator*() const { return fd_; }
int release() {
int ret = fd_;
fd_ = -1;
return ret;
}
private:
int fd_;
scoped_fd(const scoped_fd &);
scoped_fd &operator=(const scoped_fd &);
};
// (void*)-1 is MAP_FAILED; this is done to avoid including the mmap header here.
class scoped_mmap : boost::noncopyable {
class scoped_mmap {
public:
scoped_mmap() : data_((void*)-1), size_(0) {}
scoped_mmap(void *data, std::size_t size) : data_(data), size_(size) {}
@ -75,9 +87,49 @@ class scoped_mmap : boost::noncopyable {
private:
void *data_;
std::size_t size_;
scoped_mmap(const scoped_mmap &);
scoped_mmap &operator=(const scoped_mmap &);
};
/* For when the memory might come from mmap, new char[], or malloc. Uses NULL
 * and 0 for blanks even though mmap signals errors with (void*)-1. The reset
* function checks that blank for mmap.
*/
class scoped_memory {
public:
typedef enum {MMAP_ALLOCATED, ARRAY_ALLOCATED, MALLOC_ALLOCATED, NONE_ALLOCATED} Alloc;
scoped_memory() : data_(NULL), size_(0), source_(NONE_ALLOCATED) {}
~scoped_memory() { reset(); }
void *get() const { return data_; }
const char *begin() const { return reinterpret_cast<char*>(data_); }
const char *end() const { return reinterpret_cast<char*>(data_) + size_; }
std::size_t size() const { return size_; }
Alloc source() const { return source_; }
void reset() { reset(NULL, 0, NONE_ALLOCATED); }
void reset(void *data, std::size_t size, Alloc from);
// realloc may free or move the current data, hence the need for this call.
// If realloc fails, destroys the original too and get() returns NULL.
void call_realloc(std::size_t to);
private:
void *data_;
std::size_t size_;
Alloc source_;
scoped_memory(const scoped_memory &);
scoped_memory &operator=(const scoped_memory &);
};
} // namespace util
#endif // UTIL_SCOPED_H__
#endif // UTIL_SCOPED__

View File

@ -1,10 +1,10 @@
#ifndef UTIL_SORTED_UNIFORM_H__
#define UTIL_SORTED_UNIFORM_H__
#ifndef UTIL_SORTED_UNIFORM__
#define UTIL_SORTED_UNIFORM__
#include <algorithm>
#include <cstddef>
#include <functional>
#include <assert.h>
#include <inttypes.h>
namespace util {
@ -14,90 +14,126 @@ inline std::size_t Pivot(uint64_t off, uint64_t range, std::size_t width) {
// Cap for floating point rounding
return (ret < width) ? ret : width - 1;
}
inline std::size_t Pivot(uint32_t off, uint32_t range, std::size_t width) {
/*inline std::size_t Pivot(uint32_t off, uint32_t range, std::size_t width) {
return static_cast<std::size_t>(static_cast<uint64_t>(off) * static_cast<uint64_t>(width) / static_cast<uint64_t>(range));
}
inline std::size_t Pivot(uint16_t off, uint16_t range, std::size_t width) {
return static_cast<std::size_t>(static_cast<std::size_t>(off) * width / static_cast<std::size_t>(range));
}
inline std::size_t Pivot(uint8_t off, uint8_t range, std::size_t width) {
inline std::size_t Pivot(unsigned char off, unsigned char range, std::size_t width) {
return static_cast<std::size_t>(static_cast<std::size_t>(off) * width / static_cast<std::size_t>(range));
}*/
template <class Iterator, class Key> bool SortedUniformFind(Iterator begin, Iterator end, const Key key, Iterator &out) {
if (begin == end) return false;
Key below(begin->GetKey());
if (key <= below) {
if (key == below) { out = begin; return true; }
return false;
}
// Make the range [begin, end].
--end;
Key above(end->GetKey());
if (key >= above) {
if (key == above) { out = end; return true; }
return false;
}
// Search the range [begin + 1, end - 1] knowing that *begin == below, *end == above.
while (end - begin > 1) {
Iterator pivot(begin + (1 + Pivot(key - below, above - below, static_cast<std::size_t>(end - begin - 1))));
Key mid(pivot->GetKey());
if (mid < key) {
begin = pivot;
below = mid;
} else if (mid > key) {
end = pivot;
above = mid;
} else {
out = pivot;
return true;
}
}
return false;
}
// For consistent API with ProbingSearch.
struct SortedUniformInit {};
// Define a Pivot function to match Key.
template <class KeyT, class ValueT> class SortedUniformMap {
// To use this template, you need to define a Pivot function to match Key.
template <class PackingT> class SortedUniformMap {
public:
typedef KeyT Key;
typedef ValueT Value;
typedef SortedUniformInit Init;
typedef PackingT Packing;
typedef typename Packing::ConstIterator ConstIterator;
static std::size_t Size(Init ignore, std::size_t entries) {
return entries * sizeof(Entry);
public:
// Offer consistent API with probing hash.
static std::size_t Size(std::size_t entries, float ignore = 0.0) {
return sizeof(uint64_t) + entries * Packing::kBytes;
}
SortedUniformMap() {}
SortedUniformMap()
#ifdef DEBUG
: initialized_(false), loaded_(false)
#endif
{}
SortedUniformMap(Init ignore, char *start, std::size_t entries) : begin_(reinterpret_cast<Entry*>(start)), end_(begin_) {}
SortedUniformMap(void *start, std::size_t allocated) :
begin_(Packing::FromVoid(reinterpret_cast<uint64_t*>(start) + 1)),
end_(begin_), size_ptr_(reinterpret_cast<uint64_t*>(start))
#ifdef DEBUG
, initialized_(true), loaded_(false)
#endif
{}
void LoadedBinary() {
#ifdef DEBUG
assert(initialized_);
assert(!loaded_);
loaded_ = true;
#endif
// Restore the size.
end_ = begin_ + *size_ptr_;
}
// Caller responsible for not exceeding specified size. Do not call after FinishedInserting.
void Insert(const Key &key, const Value &value) {
end_->key = key;
end_->value = value;
template <class T> void Insert(const T &t) {
#ifdef DEBUG
assert(initialized_);
assert(!loaded_);
#endif
*end_ = t;
++end_;
}
void FinishedInserting() {
std::sort(begin_, end_, LessEntry());
#ifdef DEBUG
assert(initialized_);
assert(!loaded_);
loaded_ = true;
#endif
std::sort(begin_, end_);
*size_ptr_ = (end_ - begin_);
}
// Do not call before FinishedInserting.
bool Find(const Key &key, const Value *&value) const {
const Entry *begin = begin_;
const Entry *end = end_;
while (begin != end) {
if (key <= begin->key) {
if (key != begin->key) return false;
value = &begin->value;
return true;
}
if (key >= (end - 1)->key) {
if (key != (end - 1)->key) return false;
value = &(end - 1)->value;
return true;
}
Key off = key - begin->key;
const Entry *pivot = begin + Pivot(off, (end - 1)->key - begin->key, end - begin);
if (pivot->key > key) {
end = pivot;
} else if (pivot->key < key) {
begin = pivot + 1;
} else {
value = &pivot->value;
return true;
}
}
return false;
template <class Key> bool Find(const Key key, ConstIterator &out) const {
#ifdef DEBUG
assert(initialized_);
assert(loaded_);
#endif
return SortedUniformFind<ConstIterator, Key>(ConstIterator(begin_), ConstIterator(end_), key, out);
}
ConstIterator begin() const { return begin_; }
ConstIterator end() const { return end_; }
private:
struct Entry {
Key key;
Value value;
};
struct LessEntry : public std::binary_function<const Entry &, const Entry &, bool> {
bool operator()(const Entry &left, const Entry &right) const {
return left.key < right.key;
}
};
Entry *begin_;
Entry *end_;
typename Packing::MutableIterator begin_, end_;
uint64_t *size_ptr_;
#ifdef DEBUG
bool initialized_;
bool loaded_;
#endif
};
} // namespace util
#endif // UTIL_SORTED_UNIFORM_H__
#endif // UTIL_SORTED_UNIFORM__
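
`SortedUniformFind` above is an interpolation search: rather than bisecting, it guesses the key's position by linearly interpolating between the smallest and largest keys still in range (the `Pivot` helper), which probes far fewer entries than binary search when keys are roughly uniformly distributed. A standalone sketch over a raw `uint64_t` array, using floating point for the interpolation instead of the integer `Pivot`:

```cpp
#include <cstddef>
#include <stdint.h>

// Interpolation search over a sorted uint64_t range, mirroring the structure
// of SortedUniformFind: peel off the endpoints, then repeatedly guess a pivot
// proportional to where key falls between the current bounds.
inline bool InterpolationFind(const uint64_t *begin, const uint64_t *end,
                              uint64_t key, const uint64_t *&out) {
  if (begin == end) return false;
  uint64_t below = *begin;
  if (key <= below) { if (key == below) { out = begin; return true; } return false; }
  --end;  // make the range inclusive: [begin, end]
  uint64_t above = *end;
  if (key >= above) { if (key == above) { out = end; return true; } return false; }
  // Invariant: *begin == below < key < above == *end.
  while (end - begin > 1) {
    std::size_t width = static_cast<std::size_t>(end - begin - 1);
    std::size_t off = static_cast<std::size_t>(
        static_cast<double>(key - below) / static_cast<double>(above - below)
        * static_cast<double>(width));
    if (off >= width) off = width - 1;  // cap for rounding, as Pivot does
    const uint64_t *pivot = begin + 1 + off;
    uint64_t mid = *pivot;
    if (mid < key) { begin = pivot; below = mid; }
    else if (mid > key) { end = pivot; above = mid; }
    else { out = pivot; return true; }
  }
  return false;
}
```

On uniformly distributed keys this takes O(log log n) probes on average, versus O(log n) for binary search; on adversarial distributions it can degrade toward linear, which is acceptable here because vocabulary hashes are close to uniform.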

View File

@ -1,11 +1,13 @@
#include "util/sorted_uniform.hh"
#include "util/key_value_packing.hh"
#include <boost/random/mersenne_twister.hpp>
#include <boost/random/uniform_int.hpp>
#include <boost/random/variate_generator.hpp>
#include <boost/scoped_array.hpp>
#include <boost/unordered_map.hpp>
#define BOOST_TEST_MODULE SortedUniformBoundTest
#define BOOST_TEST_MODULE SortedUniformTest
#include <boost/test/unit_test.hpp>
#include <algorithm>
@ -15,35 +17,47 @@
namespace util {
namespace {
template <class Key, class Value> void Check(const SortedUniformMap<Key, Value> &map, const boost::unordered_map<Key, Value> &reference, const Key &key) {
template <class Map, class Key, class Value> void Check(const Map &map, const boost::unordered_map<Key, Value> &reference, const Key key) {
typename boost::unordered_map<Key, Value>::const_iterator ref = reference.find(key);
typename Map::ConstIterator i = typename Map::ConstIterator();
if (ref == reference.end()) {
const Value *val;
BOOST_CHECK(!map.Find(key, val));
BOOST_CHECK(!map.Find(key, i));
} else {
// g++ can't tell that require will crash and burn.
const Value *val = NULL;
BOOST_REQUIRE(map.Find(key, val));
BOOST_CHECK_EQUAL(ref->second, *val);
BOOST_REQUIRE(map.Find(key, i));
BOOST_CHECK_EQUAL(ref->second, i->GetValue());
}
}
/*BOOST_AUTO_TEST_CASE(empty) {
uint64_t foo;
Check<uint64_t>(&foo, &foo, 1);
typedef SortedUniformMap<AlignedPacking<uint64_t, uint32_t> > TestMap;
BOOST_AUTO_TEST_CASE(empty) {
char buf[TestMap::Size(0)];
TestMap map(buf, TestMap::Size(0));
map.FinishedInserting();
TestMap::ConstIterator i;
BOOST_CHECK(!map.Find(42, i));
}
BOOST_AUTO_TEST_CASE(one) {
uint64_t array[] = {1};
Check<uint64_t>(&array[0], &array[1], 1);
Check<uint64_t>(&array[0], &array[1], 0);
}*/
char buf[TestMap::Size(1)];
TestMap map(buf, sizeof(buf));
Entry<uint64_t, uint32_t> e;
e.Set(42,2);
map.Insert(e);
map.FinishedInserting();
TestMap::ConstIterator i = TestMap::ConstIterator();
BOOST_REQUIRE(map.Find(42, i));
BOOST_CHECK(i == map.begin());
BOOST_CHECK(!map.Find(43, i));
BOOST_CHECK(!map.Find(41, i));
}
template <class Key> void RandomTest(Key upper, size_t entries, size_t queries) {
typedef unsigned char Value;
typedef SortedUniformMap<Key, unsigned char> Map;
boost::scoped_array<char> buffer(new char[Map::Size(typename Map::Init(), entries)]);
Map map(typename Map::Init(), buffer.get(), entries);
typedef SortedUniformMap<AlignedPacking<Key, unsigned char> > Map;
boost::scoped_array<char> buffer(new char[Map::Size(entries)]);
Map map(buffer.get(), entries);
boost::mt19937 rng;
boost::uniform_int<Key> range_key(0, upper);
boost::uniform_int<Value> range_value(0, 255);
@ -51,11 +65,13 @@ template <class Key> void RandomTest(Key upper, size_t entries, size_t queries)
boost::variate_generator<boost::mt19937&, boost::uniform_int<unsigned char> > gen_value(rng, range_value);
boost::unordered_map<Key, unsigned char> reference;
Entry<Key, unsigned char> ent;
for (size_t i = 0; i < entries; ++i) {
Key key = gen_key();
unsigned char value = gen_value();
if (reference.insert(std::make_pair(key, value)).second) {
map.Insert(key, value);
ent.Set(key, value);
map.Insert(Entry<Key, unsigned char>(ent));
}
}
map.FinishedInserting();
@ -63,17 +79,17 @@ template <class Key> void RandomTest(Key upper, size_t entries, size_t queries)
// Random queries.
for (size_t i = 0; i < queries; ++i) {
const Key key = gen_key();
Check<Key, Value>(map, reference, key);
Check<Map, Key, unsigned char>(map, reference, key);
}
typename boost::unordered_map<Key, unsigned char>::const_iterator it = reference.begin();
for (size_t i = 0; (i < queries) && (it != reference.end()); ++i, ++it) {
Check<Key,Value>(map, reference, it->second);
Check<Map, Key, unsigned char>(map, reference, it->second);
}
}
BOOST_AUTO_TEST_CASE(sparse_random) {
RandomTest<uint64_t>(std::numeric_limits<uint64_t>::max(), 100000, 2000);
BOOST_AUTO_TEST_CASE(basic) {
RandomTest<uint8_t>(11, 10, 200);
}
BOOST_AUTO_TEST_CASE(tiny_dense_random) {
@ -92,5 +108,9 @@ BOOST_AUTO_TEST_CASE(medium_sparse_random) {
RandomTest<uint16_t>(32000, 1000, 2000);
}
BOOST_AUTO_TEST_CASE(sparse_random) {
RandomTest<uint64_t>(std::numeric_limits<uint64_t>::max(), 100000, 2000);
}
} // namespace
} // namespace util

View File

@ -30,20 +30,28 @@
#include "util/string_piece.hh"
#ifdef USE_BOOST
#include <boost/functional/hash/hash.hpp>
#endif
#include <algorithm>
#include <iostream>
#ifdef USE_ICU
U_NAMESPACE_BEGIN
#endif
std::ostream& operator<<(std::ostream& o, const StringPiece& piece) {
o.write(piece.data(), static_cast<std::streamsize>(piece.size()));
return o;
}
#ifdef USE_BOOST
size_t hash_value(const StringPiece &str) {
return boost::hash_range(str.data(), str.data() + str.length());
}
#endif
#ifdef USE_ICU
U_NAMESPACE_END
#endif

View File

@ -1,4 +1,7 @@
/* This supplements the deficient implementation of StringPiece provided by ICU */
/* If you use ICU in your program, then compile with -DUSE_ICU -licui18n. If
* you don't use ICU, then this will use the Google implementation from Chrome.
* This has been modified from the original version to let you choose.
*/
// Copyright 2008, Google Inc.
// All rights reserved.
@ -45,12 +48,166 @@
#ifndef BASE_STRING_PIECE_H__
#define BASE_STRING_PIECE_H__
#include <unicode/stringpiece.h>
//Uncomment this line if you use ICU in your code.
//#define USE_ICU
//Uncomment this line if you want boost hashing for your StringPieces.
//#define USE_BOOST
#include <cstring>
#include <iosfwd>
#ifdef USE_ICU
#include <unicode/stringpiece.h>
U_NAMESPACE_BEGIN
#else
#include <algorithm>
#include <ostream>
#include <string>
#include <string.h>
class StringPiece {
public:
typedef size_t size_type;
private:
const char* ptr_;
size_type length_;
public:
// We provide non-explicit singleton constructors so users can pass
// in a "const char*" or a "string" wherever a "StringPiece" is
// expected.
StringPiece() : ptr_(NULL), length_(0) { }
StringPiece(const char* str)
: ptr_(str), length_((str == NULL) ? 0 : strlen(str)) { }
StringPiece(const std::string& str)
: ptr_(str.data()), length_(str.size()) { }
StringPiece(const char* offset, size_type len)
: ptr_(offset), length_(len) { }
// data() may return a pointer to a buffer with embedded NULs, and the
// returned buffer may or may not be null terminated. Therefore it is
// typically a mistake to pass data() to a routine that expects a NUL
// terminated string.
const char* data() const { return ptr_; }
size_type size() const { return length_; }
size_type length() const { return length_; }
bool empty() const { return length_ == 0; }
void clear() { ptr_ = NULL; length_ = 0; }
void set(const char* data, size_type len) { ptr_ = data; length_ = len; }
void set(const char* str) {
ptr_ = str;
length_ = str ? strlen(str) : 0;
}
void set(const void* data, size_type len) {
ptr_ = reinterpret_cast<const char*>(data);
length_ = len;
}
char operator[](size_type i) const { return ptr_[i]; }
void remove_prefix(size_type n) {
ptr_ += n;
length_ -= n;
}
void remove_suffix(size_type n) {
length_ -= n;
}
int compare(const StringPiece& x) const {
int r = wordmemcmp(ptr_, x.ptr_, std::min(length_, x.length_));
if (r == 0) {
if (length_ < x.length_) r = -1;
else if (length_ > x.length_) r = +1;
}
return r;
}
std::string as_string() const {
// std::string doesn't like to take a NULL pointer even with a 0 size.
return std::string(!empty() ? data() : "", size());
}
void CopyToString(std::string* target) const;
void AppendToString(std::string* target) const;
// Does "this" start with "x"
bool starts_with(const StringPiece& x) const {
return ((length_ >= x.length_) &&
(wordmemcmp(ptr_, x.ptr_, x.length_) == 0));
}
// Does "this" end with "x"
bool ends_with(const StringPiece& x) const {
return ((length_ >= x.length_) &&
(wordmemcmp(ptr_ + (length_-x.length_), x.ptr_, x.length_) == 0));
}
// standard STL container boilerplate
typedef char value_type;
typedef const char* pointer;
typedef const char& reference;
typedef const char& const_reference;
typedef ptrdiff_t difference_type;
static const size_type npos;
typedef const char* const_iterator;
typedef const char* iterator;
typedef std::reverse_iterator<const_iterator> const_reverse_iterator;
typedef std::reverse_iterator<iterator> reverse_iterator;
iterator begin() const { return ptr_; }
iterator end() const { return ptr_ + length_; }
const_reverse_iterator rbegin() const {
return const_reverse_iterator(ptr_ + length_);
}
const_reverse_iterator rend() const {
return const_reverse_iterator(ptr_);
}
size_type max_size() const { return length_; }
size_type capacity() const { return length_; }
size_type copy(char* buf, size_type n, size_type pos = 0) const;
size_type find(const StringPiece& s, size_type pos = 0) const;
size_type find(char c, size_type pos = 0) const;
size_type rfind(const StringPiece& s, size_type pos = npos) const;
size_type rfind(char c, size_type pos = npos) const;
size_type find_first_of(const StringPiece& s, size_type pos = 0) const;
size_type find_first_of(char c, size_type pos = 0) const {
return find(c, pos);
}
size_type find_first_not_of(const StringPiece& s, size_type pos = 0) const;
size_type find_first_not_of(char c, size_type pos = 0) const;
size_type find_last_of(const StringPiece& s, size_type pos = npos) const;
size_type find_last_of(char c, size_type pos = npos) const {
return rfind(c, pos);
}
size_type find_last_not_of(const StringPiece& s, size_type pos = npos) const;
size_type find_last_not_of(char c, size_type pos = npos) const;
StringPiece substr(size_type pos, size_type n = npos) const;
static int wordmemcmp(const char* p, const char* p2, size_type N) {
return memcmp(p, p2, N);
}
};
inline bool operator==(const StringPiece& x, const StringPiece& y) {
if (x.size() != y.size())
return false;
return std::memcmp(x.data(), y.data(), x.size()) == 0;
}
inline bool operator!=(const StringPiece& x, const StringPiece& y) {
return !(x == y);
}
#endif
inline bool operator<(const StringPiece& x, const StringPiece& y) {
const int r = std::memcmp(x.data(), y.data(),
@ -73,6 +230,7 @@ inline bool operator>=(const StringPiece& x, const StringPiece& y) {
// allow StringPiece to be logged (needed for unit testing).
extern std::ostream& operator<<(std::ostream& o, const StringPiece& piece);
#ifdef USE_BOOST
size_t hash_value(const StringPiece &str);
/* Support for lookup of StringPiece in boost::unordered_map<std::string> */
@ -81,6 +239,7 @@ struct StringPieceCompatibleHash : public std::unary_function<const StringPiece
return hash_value(str);
}
};
struct StringPieceCompatibleEquals : public std::binary_function<const StringPiece &, const std::string &, bool> {
bool operator()(const StringPiece &first, const StringPiece &second) const {
return first == second;
@ -92,8 +251,10 @@ template <class T> typename T::const_iterator FindStringPiece(const T &t, const
template <class T> typename T::iterator FindStringPiece(T &t, const StringPiece &key) {
return t.find(key, StringPieceCompatibleHash(), StringPieceCompatibleEquals());
}
#endif
#ifdef USE_ICU
U_NAMESPACE_END
#endif
#endif // BASE_STRING_PIECE_H__

View File

@ -1,69 +0,0 @@
#ifndef UTIL_TOKENIZE_PIECE__
#define UTIL_TOKENIZE_PIECE__
#include "util/string_piece.hh"
#include <boost/iterator/iterator_facade.hpp>
/* Usage:
*
* for (PieceIterator<' '> i(" foo \r\n bar "); i; ++i) {
* std::cout << *i << "\n";
* }
*
*/
namespace util {
// Tokenize a StringPiece using an iterator interface. boost::tokenizer doesn't work with StringPiece.
template <char d> class PieceIterator : public boost::iterator_facade<PieceIterator<d>, const StringPiece, boost::forward_traversal_tag> {
public:
// Default construction is the end iterator, which is also accessed by kEndPieceIterator.
PieceIterator() {}
explicit PieceIterator(const StringPiece &str)
: after_(str) {
increment();
}
bool operator!() const {
return after_.data() == 0;
}
operator bool() const {
return after_.data() != 0;
}
static PieceIterator<d> end() {
return PieceIterator<d>();
}
private:
friend class boost::iterator_core_access;
void increment() {
const char *start = after_.data();
for (; (start != after_.data() + after_.size()) && (d == *start); ++start) {}
if (start == after_.data() + after_.size()) {
// End condition.
after_.clear();
return;
}
const char *finish = start;
for (; (finish != after_.data() + after_.size()) && (d != *finish); ++finish) {}
current_ = StringPiece(start, finish - start);
after_ = StringPiece(finish, after_.data() + after_.size() - finish);
}
bool equal(const PieceIterator &other) const {
return after_.data() == other.after_.data();
}
const StringPiece &dereference() const { return current_; }
StringPiece current_;
StringPiece after_;
};
} // namespace util
#endif // UTIL_TOKENIZE_PIECE__

View File

@ -1,57 +0,0 @@
#include "util/tokenize_piece.hh"
#include "util/string_piece.hh"
#define BOOST_TEST_MODULE TokenIteratorTest
#include <boost/test/unit_test.hpp>
namespace util {
namespace {
BOOST_AUTO_TEST_CASE(simple) {
PieceIterator<' '> it("single spaced words.");
BOOST_REQUIRE(it);
BOOST_CHECK_EQUAL(StringPiece("single"), *it);
++it;
BOOST_REQUIRE(it);
BOOST_CHECK_EQUAL(StringPiece("spaced"), *it);
++it;
BOOST_REQUIRE(it);
BOOST_CHECK_EQUAL(StringPiece("words."), *it);
++it;
BOOST_CHECK(!it);
}
BOOST_AUTO_TEST_CASE(null_delimiter) {
const char str[] = "\0first\0\0second\0\0\0third\0fourth\0\0\0";
PieceIterator<'\0'> it(StringPiece(str, sizeof(str) - 1));
BOOST_REQUIRE(it);
BOOST_CHECK_EQUAL(StringPiece("first"), *it);
++it;
BOOST_REQUIRE(it);
BOOST_CHECK_EQUAL(StringPiece("second"), *it);
++it;
BOOST_REQUIRE(it);
BOOST_CHECK_EQUAL(StringPiece("third"), *it);
++it;
BOOST_REQUIRE(it);
BOOST_CHECK_EQUAL(StringPiece("fourth"), *it);
++it;
BOOST_CHECK(!it);
}
BOOST_AUTO_TEST_CASE(null_entries) {
const char str[] = "\0split\0\0 \0me\0 ";
PieceIterator<' '> it(StringPiece(str, sizeof(str) - 1));
BOOST_REQUIRE(it);
const char first[] = "\0split\0\0";
BOOST_CHECK_EQUAL(StringPiece(first, sizeof(first) - 1), *it);
++it;
BOOST_REQUIRE(it);
const char second[] = "\0me\0";
BOOST_CHECK_EQUAL(StringPiece(second, sizeof(second) - 1), *it);
++it;
BOOST_CHECK(!it);
}
} // namespace
} // namespace util