Merged latest trunk into lane-syntax branch. This includes overwriting my branch's version of moses-parallel.pl with the one from trunk. I want to merge in the syntactic LM changes, and the moses-parallel changes can wait for another day. To get those changes, look at the lane-syntax branch before this commit.

git-svn-id: https://mosesdecoder.svn.sourceforge.net/svnroot/mosesdecoder/branches/lane-syntax@3948 1f5c12ca-751b-0410-a591-d2e778427230
This commit is contained in:
dowobeha 2011-04-14 17:22:34 +00:00
commit 75156a7486
515 changed files with 39237 additions and 35809 deletions

View File

@ -158,7 +158,14 @@
isa = PBXProject;
buildConfigurationList = 1DEB923508733DC60010E9CD /* Build configuration list for PBXProject "CreateOnDisk" */;
compatibilityVersion = "Xcode 3.1";
developmentRegion = English;
hasScannedForEncodings = 1;
knownRegions = (
English,
Japanese,
French,
German,
);
mainGroup = 08FB7794FE84155DC02AAC07 /* CreateOnDisk */;
projectDirPath = "";
projectReferences = (
@ -231,10 +238,10 @@
GCC_OPTIMIZATION_LEVEL = 0;
INSTALL_PATH = /usr/local/bin;
LIBRARY_SEARCH_PATHS = (
../irstlm/lib/i386,
../irstlm/lib,
../srilm/lib/macosx,
../randlm/lib,
../kenlm/lm,
../kenlm,
);
OTHER_LDFLAGS = (
"-lz",
@ -259,10 +266,10 @@
GCC_MODEL_TUNING = G5;
INSTALL_PATH = /usr/local/bin;
LIBRARY_SEARCH_PATHS = (
../irstlm/lib/i386,
../irstlm/lib,
../srilm/lib/macosx,
../randlm/lib,
../kenlm/lm,
../kenlm,
);
OTHER_LDFLAGS = (
"-lz",
@ -287,6 +294,8 @@
GCC_OPTIMIZATION_LEVEL = 0;
GCC_WARN_ABOUT_RETURN_TYPE = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
HEADER_SEARCH_PATHS = ../../irstlm/include;
LIBRARY_SEARCH_PATHS = "";
ONLY_ACTIVE_ARCH = YES;
PREBINDING = NO;
SDKROOT = macosx10.6;
@ -300,6 +309,8 @@
GCC_C_LANGUAGE_STANDARD = gnu99;
GCC_WARN_ABOUT_RETURN_TYPE = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
HEADER_SEARCH_PATHS = ../../irstlm/include;
LIBRARY_SEARCH_PATHS = "";
ONLY_ACTIVE_ARCH = YES;
PREBINDING = NO;
SDKROOT = macosx10.6;

View File

@ -2,17 +2,17 @@
/***********************************************************************
Moses - factored phrase-based, hierarchical and syntactic language decoder
Copyright (C) 2009 Hieu Hoang
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
@ -39,217 +39,202 @@ using namespace OnDiskPt;
int main (int argc, char * const argv[])
{
// insert code here...
Moses::ResetUserTime();
Moses::PrintUserTime("Starting");
assert(argc == 8);
int numSourceFactors = Moses::Scan<int>(argv[1])
, numTargetFactors = Moses::Scan<int>(argv[2])
, numScores = Moses::Scan<int>(argv[3])
, tableLimit = Moses::Scan<int>(argv[4]);
TargetPhraseCollection::s_sortScoreInd = Moses::Scan<int>(argv[5]);
const string filePath = argv[6]
,destPath = argv[7];
Moses::InputFileStream inStream(filePath);
OnDiskWrapper onDiskWrapper;
bool retDb = onDiskWrapper.BeginSave(destPath, numSourceFactors, numTargetFactors, numScores);
assert(retDb);
PhraseNode &rootNode = onDiskWrapper.GetRootSourceNode();
size_t lineNum = 0;
char line[100000];
// insert code here...
Moses::ResetUserTime();
Moses::PrintUserTime("Starting");
//while(getline(inStream, line))
while(inStream.getline(line, 100000))
{
lineNum++;
assert(argc == 8);
int numSourceFactors = Moses::Scan<int>(argv[1])
, numTargetFactors = Moses::Scan<int>(argv[2])
, numScores = Moses::Scan<int>(argv[3])
, tableLimit = Moses::Scan<int>(argv[4]);
TargetPhraseCollection::s_sortScoreInd = Moses::Scan<int>(argv[5]);
assert(TargetPhraseCollection::s_sortScoreInd < numScores);
const string filePath = argv[6]
,destPath = argv[7];
Moses::InputFileStream inStream(filePath);
OnDiskWrapper onDiskWrapper;
bool retDb = onDiskWrapper.BeginSave(destPath, numSourceFactors, numTargetFactors, numScores);
assert(retDb);
PhraseNode &rootNode = onDiskWrapper.GetRootSourceNode();
size_t lineNum = 0;
char line[100000];
//while(getline(inStream, line))
while(inStream.getline(line, 100000)) {
lineNum++;
if (lineNum%1000 == 0) cerr << "." << flush;
if (lineNum%10000 == 0) cerr << ":" << flush;
if (lineNum%100000 == 0) cerr << lineNum << flush;
//cerr << lineNum << " " << line << endl;
std::vector<float> misc(1);
SourcePhrase sourcePhrase;
TargetPhrase *targetPhrase = new TargetPhrase(numScores);
Tokenize(sourcePhrase, *targetPhrase, line, onDiskWrapper, numScores, misc);
assert(misc.size() == onDiskWrapper.GetNumCounts());
rootNode.AddTargetPhrase(sourcePhrase, targetPhrase, onDiskWrapper, tableLimit, misc);
}
rootNode.Save(onDiskWrapper, 0, tableLimit);
onDiskWrapper.EndSave();
Moses::PrintUserTime("Finished");
//pause();
return 0;
//cerr << lineNum << " " << line << endl;
std::vector<float> misc(1);
SourcePhrase sourcePhrase;
TargetPhrase *targetPhrase = new TargetPhrase(numScores);
Tokenize(sourcePhrase, *targetPhrase, line, onDiskWrapper, numScores, misc);
assert(misc.size() == onDiskWrapper.GetNumCounts());
rootNode.AddTargetPhrase(sourcePhrase, targetPhrase, onDiskWrapper, tableLimit, misc);
}
rootNode.Save(onDiskWrapper, 0, tableLimit);
onDiskWrapper.EndSave();
Moses::PrintUserTime("Finished");
//pause();
return 0;
} // main()
bool Flush(const OnDiskPt::SourcePhrase *prevSourcePhrase, const OnDiskPt::SourcePhrase *currSourcePhrase)
{
if (prevSourcePhrase == NULL)
return false;
assert(currSourcePhrase);
bool ret = (*currSourcePhrase > *prevSourcePhrase);
//cerr << *prevSourcePhrase << endl << *currSourcePhrase << " " << ret << endl << endl;
if (prevSourcePhrase == NULL)
return false;
return ret;
assert(currSourcePhrase);
bool ret = (*currSourcePhrase > *prevSourcePhrase);
//cerr << *prevSourcePhrase << endl << *currSourcePhrase << " " << ret << endl << endl;
return ret;
}
void Tokenize(SourcePhrase &sourcePhrase, TargetPhrase &targetPhrase, char *line, OnDiskWrapper &onDiskWrapper, int numScores, vector<float> &misc)
{
size_t scoreInd = 0;
// MAIN LOOP
size_t stage = 0;
/* 0 = source phrase
1 = target phrase
2 = scores
3 = align
4 = count
*/
char *tok = strtok (line," ");
while (tok != NULL)
{
if (0 == strcmp(tok, "|||"))
{
++stage;
}
else
{
switch (stage)
{
case 0:
{
Tokenize(sourcePhrase, tok, true, true, onDiskWrapper);
break;
}
case 1:
{
Tokenize(targetPhrase, tok, false, true, onDiskWrapper);
break;
}
case 2:
{
float score = Moses::Scan<float>(tok);
targetPhrase.SetScore(score, scoreInd);
++scoreInd;
break;
}
case 3:
{
targetPhrase.Create1AlignFromString(tok);
break;
}
case 4:
++stage;
break;
case 5:
{ // count info. Only store the 2nd one
float val = Moses::Scan<float>(tok);
misc[0] = val;
++stage;
break;
}
default:
assert(false);
break;
}
}
tok = strtok (NULL, " ");
} // while (tok != NULL)
assert(scoreInd == numScores);
targetPhrase.SortAlign();
size_t scoreInd = 0;
// MAIN LOOP
size_t stage = 0;
/* 0 = source phrase
1 = target phrase
2 = scores
3 = align
4 = count
*/
char *tok = strtok (line," ");
while (tok != NULL) {
if (0 == strcmp(tok, "|||")) {
++stage;
} else {
switch (stage) {
case 0: {
Tokenize(sourcePhrase, tok, true, true, onDiskWrapper);
break;
}
case 1: {
Tokenize(targetPhrase, tok, false, true, onDiskWrapper);
break;
}
case 2: {
float score = Moses::Scan<float>(tok);
targetPhrase.SetScore(score, scoreInd);
++scoreInd;
break;
}
case 3: {
targetPhrase.Create1AlignFromString(tok);
break;
}
case 4:
++stage;
break;
case 5: {
// count info. Only store the 2nd one
float val = Moses::Scan<float>(tok);
misc[0] = val;
++stage;
break;
}
default:
assert(false);
break;
}
}
tok = strtok (NULL, " ");
} // while (tok != NULL)
assert(scoreInd == numScores);
targetPhrase.SortAlign();
} // Tokenize()
void Tokenize(OnDiskPt::Phrase &phrase
, const std::string &token, bool addSourceNonTerm, bool addTargetNonTerm
, OnDiskPt::OnDiskWrapper &onDiskWrapper)
, const std::string &token, bool addSourceNonTerm, bool addTargetNonTerm
, OnDiskPt::OnDiskWrapper &onDiskWrapper)
{
bool nonTerm = false;
size_t tokSize = token.size();
int comStr =token.compare(0, 1, "[");
if (comStr == 0)
{
comStr = token.compare(tokSize - 1, 1, "]");
nonTerm = comStr == 0;
}
if (nonTerm)
{ // non-term
size_t splitPos = token.find_first_of("[", 2);
string wordStr = token.substr(0, splitPos);
if (splitPos == string::npos)
{ // lhs - only 1 word
Word *word = new Word();
word->CreateFromString(wordStr, onDiskWrapper.GetVocab());
phrase.AddWord(word);
}
else
{ // source & target non-terms
if (addSourceNonTerm)
{
Word *word = new Word();
word->CreateFromString(wordStr, onDiskWrapper.GetVocab());
phrase.AddWord(word);
}
wordStr = token.substr(splitPos, tokSize - splitPos);
if (addTargetNonTerm)
{
Word *word = new Word();
word->CreateFromString(wordStr, onDiskWrapper.GetVocab());
phrase.AddWord(word);
}
}
}
else
{ // term
Word *word = new Word();
word->CreateFromString(token, onDiskWrapper.GetVocab());
phrase.AddWord(word);
}
bool nonTerm = false;
size_t tokSize = token.size();
int comStr =token.compare(0, 1, "[");
if (comStr == 0) {
comStr = token.compare(tokSize - 1, 1, "]");
nonTerm = comStr == 0;
}
if (nonTerm) {
// non-term
size_t splitPos = token.find_first_of("[", 2);
string wordStr = token.substr(0, splitPos);
if (splitPos == string::npos) {
// lhs - only 1 word
Word *word = new Word();
word->CreateFromString(wordStr, onDiskWrapper.GetVocab());
phrase.AddWord(word);
} else {
// source & target non-terms
if (addSourceNonTerm) {
Word *word = new Word();
word->CreateFromString(wordStr, onDiskWrapper.GetVocab());
phrase.AddWord(word);
}
wordStr = token.substr(splitPos, tokSize - splitPos);
if (addTargetNonTerm) {
Word *word = new Word();
word->CreateFromString(wordStr, onDiskWrapper.GetVocab());
phrase.AddWord(word);
}
}
} else {
// term
Word *word = new Word();
word->CreateFromString(token, onDiskWrapper.GetVocab());
phrase.AddWord(word);
}
}
void InsertTargetNonTerminals(std::vector<std::string> &sourceToks, const std::vector<std::string> &targetToks, const ::AlignType &alignments)
{
for (int ind = alignments.size() - 1; ind >= 0; --ind)
{
const ::AlignPair &alignPair = alignments[ind];
size_t sourcePos = alignPair.first
,targetPos = alignPair.second;
const string &target = targetToks[targetPos];
sourceToks.insert(sourceToks.begin() + sourcePos + 1, target);
}
for (int ind = alignments.size() - 1; ind >= 0; --ind) {
const ::AlignPair &alignPair = alignments[ind];
size_t sourcePos = alignPair.first
,targetPos = alignPair.second;
const string &target = targetToks[targetPos];
sourceToks.insert(sourceToks.begin() + sourcePos + 1, target);
}
}
class AlignOrderer
{
public:
bool operator()(const ::AlignPair &a, const ::AlignPair &b) const
{
return a.first < b.first;
}
};
{
public:
bool operator()(const ::AlignPair &a, const ::AlignPair &b) const {
return a.first < b.first;
}
};
void SortAlign(::AlignType &alignments)
{
std::sort(alignments.begin(), alignments.end(), AlignOrderer());
std::sort(alignments.begin(), alignments.end(), AlignOrderer());
}

View File

@ -3,17 +3,17 @@
/***********************************************************************
Moses - factored phrase-based, hierarchical and syntactic language decoder
Copyright (C) 2009 Hieu Hoang
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
@ -26,12 +26,12 @@ typedef std::pair<size_t, size_t> AlignPair;
typedef std::vector<AlignPair> AlignType;
void Tokenize(OnDiskPt::Phrase &phrase
, const std::string &token, bool addSourceNonTerm, bool addTargetNonTerm
, OnDiskPt::OnDiskWrapper &onDiskWrapper);
, const std::string &token, bool addSourceNonTerm, bool addTargetNonTerm
, OnDiskPt::OnDiskWrapper &onDiskWrapper);
void Tokenize(OnDiskPt::SourcePhrase &sourcePhrase, OnDiskPt::TargetPhrase &targetPhrase
, char *line, OnDiskPt::OnDiskWrapper &onDiskWrapper
, int numScores
, std::vector<float> &misc);
, char *line, OnDiskPt::OnDiskWrapper &onDiskWrapper
, int numScores
, std::vector<float> &misc);
void InsertTargetNonTerminals(std::vector<std::string> &sourceToks, const std::vector<std::string> &targetToks, const AlignType &alignments);
void SortAlign(AlignType &alignments);

View File

@ -11,4 +11,4 @@ endif
if WITH_SERVER
SERVER = server
endif
SUBDIRS = kenlm moses/src moses-chart/src OnDiskPt/src moses-cmd/src misc moses-chart-cmd/src CreateOnDisk/src $(MERT) $(SERVER)
SUBDIRS = kenlm moses/src OnDiskPt/src moses-cmd/src misc moses-chart-cmd/src CreateOnDisk/src $(MERT) $(SERVER)

View File

@ -149,7 +149,14 @@
isa = PBXProject;
buildConfigurationList = 1DEB91EF08733DB70010E9CD /* Build configuration list for PBXProject "OnDiskPt" */;
compatibilityVersion = "Xcode 3.1";
developmentRegion = English;
hasScannedForEncodings = 1;
knownRegions = (
English,
Japanese,
French,
German,
);
mainGroup = 08FB7794FE84155DC02AAC07 /* OnDiskPt */;
projectDirPath = "";
projectRoot = "";

View File

@ -2,17 +2,17 @@
/***********************************************************************
Moses - factored phrase-based, hierarchical and syntactic language decoder
Copyright (C) 2009 Hieu Hoang
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
@ -41,194 +41,191 @@ OnDiskWrapper::~OnDiskWrapper()
bool OnDiskWrapper::BeginLoad(const std::string &filePath)
{
if (!OpenForLoad(filePath))
return false;
if (!m_vocab.Load(*this))
return false;
UINT64 rootFilePos = GetMisc("RootNodeOffset");
m_rootSourceNode = new PhraseNode(rootFilePos, *this);
if (!OpenForLoad(filePath))
return false;
return true;
if (!m_vocab.Load(*this))
return false;
UINT64 rootFilePos = GetMisc("RootNodeOffset");
m_rootSourceNode = new PhraseNode(rootFilePos, *this);
return true;
}
bool OnDiskWrapper::OpenForLoad(const std::string &filePath)
{
m_fileSource.open((filePath + "/Source.dat").c_str(), ios::in | ios::binary);
assert(m_fileSource.is_open());
m_fileTargetInd.open((filePath + "/TargetInd.dat").c_str(), ios::in | ios::binary);
assert(m_fileTargetInd.is_open());
m_fileTargetColl.open((filePath + "/TargetColl.dat").c_str(), ios::in | ios::binary);
assert(m_fileTargetColl.is_open());
m_fileVocab.open((filePath + "/Vocab.dat").c_str(), ios::in);
assert(m_fileVocab.is_open());
m_fileMisc.open((filePath + "/Misc.dat").c_str(), ios::in);
assert(m_fileMisc.is_open());
// set up root node
LoadMisc();
m_numSourceFactors = GetMisc("NumSourceFactors");
m_numTargetFactors = GetMisc("NumTargetFactors");
m_numScores = GetMisc("NumScores");
return true;
m_fileSource.open((filePath + "/Source.dat").c_str(), ios::in | ios::binary);
assert(m_fileSource.is_open());
m_fileTargetInd.open((filePath + "/TargetInd.dat").c_str(), ios::in | ios::binary);
assert(m_fileTargetInd.is_open());
m_fileTargetColl.open((filePath + "/TargetColl.dat").c_str(), ios::in | ios::binary);
assert(m_fileTargetColl.is_open());
m_fileVocab.open((filePath + "/Vocab.dat").c_str(), ios::in);
assert(m_fileVocab.is_open());
m_fileMisc.open((filePath + "/Misc.dat").c_str(), ios::in);
assert(m_fileMisc.is_open());
// set up root node
LoadMisc();
m_numSourceFactors = GetMisc("NumSourceFactors");
m_numTargetFactors = GetMisc("NumTargetFactors");
m_numScores = GetMisc("NumScores");
return true;
}
bool OnDiskWrapper::LoadMisc()
{
char line[100000];
while(m_fileMisc.getline(line, 100000))
{
vector<string> tokens;
Moses::Tokenize(tokens, line);
assert(tokens.size() == 2);
const string &key = tokens[0];
m_miscInfo[key] = Moses::Scan<UINT64>(tokens[1]);
}
return true;
char line[100000];
while(m_fileMisc.getline(line, 100000)) {
vector<string> tokens;
Moses::Tokenize(tokens, line);
assert(tokens.size() == 2);
const string &key = tokens[0];
m_miscInfo[key] = Moses::Scan<UINT64>(tokens[1]);
}
return true;
}
bool OnDiskWrapper::BeginSave(const std::string &filePath
, int numSourceFactors, int numTargetFactors, int numScores)
, int numSourceFactors, int numTargetFactors, int numScores)
{
m_numSourceFactors = numSourceFactors;
m_numTargetFactors = numTargetFactors;
m_numScores = numScores;
m_filePath = filePath;
m_numSourceFactors = numSourceFactors;
m_numTargetFactors = numTargetFactors;
m_numScores = numScores;
m_filePath = filePath;
#ifdef WIN32
mkdir(filePath.c_str());
mkdir(filePath.c_str());
#else
mkdir(filePath.c_str(), 0777);
mkdir(filePath.c_str(), 0777);
#endif
m_fileSource.open((filePath + "/Source.dat").c_str(), ios::out | ios::in | ios::binary | ios::ate | ios::trunc);
assert(m_fileSource.is_open());
m_fileTargetInd.open((filePath + "/TargetInd.dat").c_str(), ios::out | ios::binary | ios::ate | ios::trunc);
assert(m_fileTargetInd.is_open());
m_fileSource.open((filePath + "/Source.dat").c_str(), ios::out | ios::in | ios::binary | ios::ate | ios::trunc);
assert(m_fileSource.is_open());
m_fileTargetColl.open((filePath + "/TargetColl.dat").c_str(), ios::out | ios::binary | ios::ate | ios::trunc);
assert(m_fileTargetColl.is_open());
m_fileTargetInd.open((filePath + "/TargetInd.dat").c_str(), ios::out | ios::binary | ios::ate | ios::trunc);
assert(m_fileTargetInd.is_open());
m_fileVocab.open((filePath + "/Vocab.dat").c_str(), ios::out | ios::ate | ios::trunc);
assert(m_fileVocab.is_open());
m_fileTargetColl.open((filePath + "/TargetColl.dat").c_str(), ios::out | ios::binary | ios::ate | ios::trunc);
assert(m_fileTargetColl.is_open());
m_fileMisc.open((filePath + "/Misc.dat").c_str(), ios::out | ios::ate | ios::trunc);
assert(m_fileMisc.is_open());
m_fileVocab.open((filePath + "/Vocab.dat").c_str(), ios::out | ios::ate | ios::trunc);
assert(m_fileVocab.is_open());
// offset by 1. 0 offset is reserved
char c = 0xff;
m_fileSource.write(&c, 1);
assert(1 == m_fileSource.tellp());
m_fileTargetInd.write(&c, 1);
assert(1 == m_fileTargetInd.tellp());
m_fileMisc.open((filePath + "/Misc.dat").c_str(), ios::out | ios::ate | ios::trunc);
assert(m_fileMisc.is_open());
m_fileTargetColl.write(&c, 1);
assert(1 == m_fileTargetColl.tellp());
// offset by 1. 0 offset is reserved
char c = 0xff;
m_fileSource.write(&c, 1);
assert(1 == m_fileSource.tellp());
// set up root node
assert(GetNumCounts() == 1);
vector<float> counts(GetNumCounts());
counts[0] = DEFAULT_COUNT;
m_rootSourceNode = new PhraseNode();
m_rootSourceNode->AddCounts(counts);
m_fileTargetInd.write(&c, 1);
assert(1 == m_fileTargetInd.tellp());
return true;
m_fileTargetColl.write(&c, 1);
assert(1 == m_fileTargetColl.tellp());
// set up root node
assert(GetNumCounts() == 1);
vector<float> counts(GetNumCounts());
counts[0] = DEFAULT_COUNT;
m_rootSourceNode = new PhraseNode();
m_rootSourceNode->AddCounts(counts);
return true;
}
void OnDiskWrapper::EndSave()
{
bool ret = m_rootSourceNode->Saved();
assert(ret);
bool ret = m_rootSourceNode->Saved();
assert(ret);
GetVocab().Save(*this);
SaveMisc();
GetVocab().Save(*this);
m_fileMisc.close();
m_fileVocab.close();
m_fileSource.close();
m_fileTarget.close();
m_fileTargetInd.close();
m_fileTargetColl.close();
SaveMisc();
m_fileMisc.close();
m_fileVocab.close();
m_fileSource.close();
m_fileTarget.close();
m_fileTargetInd.close();
m_fileTargetColl.close();
}
void OnDiskWrapper::SaveMisc()
{
m_fileMisc << "Version 3" << endl;
m_fileMisc << "NumSourceFactors " << m_numSourceFactors << endl;
m_fileMisc << "NumTargetFactors " << m_numTargetFactors << endl;
m_fileMisc << "NumScores " << m_numScores << endl;
m_fileMisc << "RootNodeOffset " << m_rootSourceNode->GetFilePos() << endl;
m_fileMisc << "Version 3" << endl;
m_fileMisc << "NumSourceFactors " << m_numSourceFactors << endl;
m_fileMisc << "NumTargetFactors " << m_numTargetFactors << endl;
m_fileMisc << "NumScores " << m_numScores << endl;
m_fileMisc << "RootNodeOffset " << m_rootSourceNode->GetFilePos() << endl;
}
size_t OnDiskWrapper::GetSourceWordSize() const
{
return m_numSourceFactors * sizeof(UINT64) + sizeof(char);
return m_numSourceFactors * sizeof(UINT64) + sizeof(char);
}
size_t OnDiskWrapper::GetTargetWordSize() const
{
return m_numTargetFactors * sizeof(UINT64) + sizeof(char);
return m_numTargetFactors * sizeof(UINT64) + sizeof(char);
}
UINT64 OnDiskWrapper::GetMisc(const std::string &key) const
{
std::map<std::string, UINT64>::const_iterator iter;
iter = m_miscInfo.find(key);
assert(iter != m_miscInfo.end());
return iter->second;
std::map<std::string, UINT64>::const_iterator iter;
iter = m_miscInfo.find(key);
assert(iter != m_miscInfo.end());
return iter->second;
}
PhraseNode &OnDiskWrapper::GetRootSourceNode()
{ return *m_rootSourceNode; }
Word *OnDiskWrapper::ConvertFromMoses(Moses::FactorDirection direction
, const std::vector<Moses::FactorType> &factorsVec
, const Moses::Word &origWord) const
{
bool isNonTerminal = origWord.IsNonTerminal();
Word *newWord = new Word(1, isNonTerminal); // TODO - num of factors
for (size_t ind = 0 ; ind < factorsVec.size() ; ++ind)
{
size_t factorType = factorsVec[ind];
const Moses::Factor *factor = origWord.GetFactor(factorType);
assert(factor);
string str = factor->GetString();
if (isNonTerminal)
{
str = "[" + str + "]";
}
bool found;
UINT64 vocabId = m_vocab.GetVocabId(str, found);
if (!found)
{ // factor not in phrase table -> phrse definately not in. exit
delete newWord;
return NULL;
}
else
{
newWord->SetVocabId(ind, vocabId);
}
} // for (size_t factorType
return newWord;
return *m_rootSourceNode;
}
Word *OnDiskWrapper::ConvertFromMoses(Moses::FactorDirection /* direction */
, const std::vector<Moses::FactorType> &factorsVec
, const Moses::Word &origWord) const
{
bool isNonTerminal = origWord.IsNonTerminal();
Word *newWord = new Word(1, isNonTerminal); // TODO - num of factors
for (size_t ind = 0 ; ind < factorsVec.size() ; ++ind) {
size_t factorType = factorsVec[ind];
const Moses::Factor *factor = origWord.GetFactor(factorType);
assert(factor);
string str = factor->GetString();
if (isNonTerminal) {
str = "[" + str + "]";
}
bool found;
UINT64 vocabId = m_vocab.GetVocabId(str, found);
if (!found) {
// factor not in phrase table -> phrse definately not in. exit
delete newWord;
return NULL;
} else {
newWord->SetVocabId(ind, vocabId);
}
} // for (size_t factorType
return newWord;
}
}

View File

@ -3,17 +3,17 @@
/***********************************************************************
Moses - factored phrase-based, hierarchical and syntactic language decoder
Copyright (C) 2009 Hieu Hoang
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
@ -27,66 +27,75 @@
namespace OnDiskPt
{
const float DEFAULT_COUNT = 66666;
class OnDiskWrapper
{
protected:
Vocab m_vocab;
std::string m_filePath;
int m_numSourceFactors, m_numTargetFactors, m_numScores;
std::fstream m_fileMisc, m_fileVocab, m_fileSource, m_fileTarget, m_fileTargetInd, m_fileTargetColl;
Vocab m_vocab;
std::string m_filePath;
int m_numSourceFactors, m_numTargetFactors, m_numScores;
std::fstream m_fileMisc, m_fileVocab, m_fileSource, m_fileTarget, m_fileTargetInd, m_fileTargetColl;
size_t m_defaultNodeSize;
PhraseNode *m_rootSourceNode;
size_t m_defaultNodeSize;
PhraseNode *m_rootSourceNode;
std::map<std::string, UINT64> m_miscInfo;
void SaveMisc();
bool OpenForLoad(const std::string &filePath);
bool LoadMisc();
std::map<std::string, UINT64> m_miscInfo;
void SaveMisc();
bool OpenForLoad(const std::string &filePath);
bool LoadMisc();
public:
OnDiskWrapper();
~OnDiskWrapper();
OnDiskWrapper();
~OnDiskWrapper();
bool BeginLoad(const std::string &filePath);
bool BeginLoad(const std::string &filePath);
bool BeginSave(const std::string &filePath
, int numSourceFactors, int numTargetFactors, int numScores);
void EndSave();
Vocab &GetVocab()
{ return m_vocab; }
size_t GetSourceWordSize() const;
size_t GetTargetWordSize() const;
std::fstream &GetFileSource()
{ return m_fileSource; }
std::fstream &GetFileTargetInd()
{ return m_fileTargetInd; }
std::fstream &GetFileTargetColl()
{ return m_fileTargetColl; }
std::fstream &GetFileVocab()
{ return m_fileVocab; }
size_t GetNumSourceFactors() const
{ return m_numSourceFactors; }
size_t GetNumTargetFactors() const
{ return m_numTargetFactors; }
size_t GetNumScores() const
{ return m_numScores; }
size_t GetNumCounts() const
{ return 1; }
PhraseNode &GetRootSourceNode();
UINT64 GetMisc(const std::string &key) const;
bool BeginSave(const std::string &filePath
, int numSourceFactors, int numTargetFactors, int numScores);
void EndSave();
Word *ConvertFromMoses(Moses::FactorDirection direction
, const std::vector<Moses::FactorType> &factorsVec
, const Moses::Word &origWord) const;
Vocab &GetVocab() {
return m_vocab;
}
size_t GetSourceWordSize() const;
size_t GetTargetWordSize() const;
std::fstream &GetFileSource() {
return m_fileSource;
}
std::fstream &GetFileTargetInd() {
return m_fileTargetInd;
}
std::fstream &GetFileTargetColl() {
return m_fileTargetColl;
}
std::fstream &GetFileVocab() {
return m_fileVocab;
}
size_t GetNumSourceFactors() const {
return m_numSourceFactors;
}
size_t GetNumTargetFactors() const {
return m_numTargetFactors;
}
size_t GetNumScores() const {
return m_numScores;
}
size_t GetNumCounts() const {
return 1;
}
PhraseNode &GetRootSourceNode();
UINT64 GetMisc(const std::string &key) const;
Word *ConvertFromMoses(Moses::FactorDirection direction
, const std::vector<Moses::FactorType> &factorsVec
, const Moses::Word &origWord) const;
};

View File

@ -2,17 +2,17 @@
/***********************************************************************
Moses - factored phrase-based, hierarchical and syntactic language decoder
Copyright (C) 2009 Hieu Hoang
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
@ -28,90 +28,85 @@ namespace OnDiskPt
{
Phrase::Phrase(const Phrase &copy)
:m_words(copy.GetSize())
:m_words(copy.GetSize())
{
for (size_t pos = 0; pos < copy.GetSize(); ++pos)
{
const Word &oldWord = copy.GetWord(pos);
Word *newWord = new Word(oldWord);
m_words[pos] = newWord;
}
for (size_t pos = 0; pos < copy.GetSize(); ++pos) {
const Word &oldWord = copy.GetWord(pos);
Word *newWord = new Word(oldWord);
m_words[pos] = newWord;
}
}
Phrase::~Phrase()
{
Moses::RemoveAllInColl(m_words);
Moses::RemoveAllInColl(m_words);
}
void Phrase::AddWord(Word *word)
{
m_words.push_back(word);
m_words.push_back(word);
}
void Phrase::AddWord(Word *word, size_t pos)
{
assert(pos < m_words.size());
m_words.insert(m_words.begin() + pos + 1, word);
assert(pos < m_words.size());
m_words.insert(m_words.begin() + pos + 1, word);
}
int Phrase::Compare(const Phrase &compare) const
{
int ret = 0;
for (size_t pos = 0; pos < GetSize(); ++pos)
{
if (pos >= compare.GetSize())
{ // we're bigger than the other. Put 1st
ret = -1;
break;
}
const Word &thisWord = GetWord(pos)
,&compareWord = compare.GetWord(pos);
int wordRet = thisWord.Compare(compareWord);
if (wordRet != 0)
{
ret = wordRet;
break;
}
}
int ret = 0;
for (size_t pos = 0; pos < GetSize(); ++pos) {
if (pos >= compare.GetSize()) {
// we're bigger than the other. Put 1st
ret = -1;
break;
}
if (ret == 0)
{
assert(compare.GetSize() >= GetSize());
ret = (compare.GetSize() > GetSize()) ? 1 : 0;
}
return ret;
const Word &thisWord = GetWord(pos)
,&compareWord = compare.GetWord(pos);
int wordRet = thisWord.Compare(compareWord);
if (wordRet != 0) {
ret = wordRet;
break;
}
}
if (ret == 0) {
assert(compare.GetSize() >= GetSize());
ret = (compare.GetSize() > GetSize()) ? 1 : 0;
}
return ret;
}
//! transitive comparison
bool Phrase::operator<(const Phrase &compare) const
{
int ret = Compare(compare);
return ret < 0;
{
int ret = Compare(compare);
return ret < 0;
}
bool Phrase::operator>(const Phrase &compare) const
{
int ret = Compare(compare);
return ret > 0;
{
int ret = Compare(compare);
return ret > 0;
}
bool Phrase::operator==(const Phrase &compare) const
{
int ret = Compare(compare);
return ret == 0;
{
int ret = Compare(compare);
return ret == 0;
}
std::ostream& operator<<(std::ostream &out, const Phrase &phrase)
{
for (size_t pos = 0; pos < phrase.GetSize(); ++pos)
{
const Word &word = phrase.GetWord(pos);
out << word << " ";
}
return out;
}
for (size_t pos = 0; pos < phrase.GetSize(); ++pos) {
const Word &word = phrase.GetWord(pos);
out << word << " ";
}
return out;
}
}

View File

@ -3,17 +3,17 @@
/***********************************************************************
Moses - factored phrase-based, hierarchical and syntactic language decoder
Copyright (C) 2009 Hieu Hoang
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
@ -27,29 +27,31 @@ namespace OnDiskPt
class Phrase
{
friend std::ostream& operator<<(std::ostream&, const Phrase&);
friend std::ostream& operator<<(std::ostream&, const Phrase&);
protected:
std::vector<Word*> m_words;
std::vector<Word*> m_words;
public:
Phrase()
{}
Phrase(const Phrase &copy);
virtual ~Phrase();
void AddWord(Word *word);
void AddWord(Word *word, size_t pos);
Phrase()
{}
Phrase(const Phrase &copy);
virtual ~Phrase();
const Word &GetWord(size_t pos) const
{ return *m_words[pos]; }
size_t GetSize() const
{ return m_words.size(); }
void AddWord(Word *word);
void AddWord(Word *word, size_t pos);
int Compare(const Phrase &compare) const;
bool operator<(const Phrase &compare) const;
bool operator>(const Phrase &compare) const;
bool operator==(const Phrase &compare) const;
const Word &GetWord(size_t pos) const {
return *m_words[pos];
}
size_t GetSize() const {
return m_words.size();
}
int Compare(const Phrase &compare) const;
bool operator<(const Phrase &compare) const;
bool operator>(const Phrase &compare) const;
bool operator==(const Phrase &compare) const;
};
}

View File

@ -2,17 +2,17 @@
/***********************************************************************
Moses - factored phrase-based, hierarchical and syntactic language decoder
Copyright (C) 2009 Hieu Hoang
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
@ -31,245 +31,241 @@ namespace OnDiskPt
size_t PhraseNode::GetNodeSize(size_t numChildren, size_t wordSize, size_t countSize)
{
size_t ret = sizeof(UINT64) * 2 // num children, value
+ (wordSize + sizeof(UINT64)) * numChildren // word + ptr to next source node
+ sizeof(float) * countSize; // count info
return ret;
size_t ret = sizeof(UINT64) * 2 // num children, value
+ (wordSize + sizeof(UINT64)) * numChildren // word + ptr to next source node
+ sizeof(float) * countSize; // count info
return ret;
}
PhraseNode::PhraseNode()
:m_currChild(NULL)
,m_saved(false)
,m_memLoad(NULL)
,m_value(0)
:m_currChild(NULL)
,m_saved(false)
,m_memLoad(NULL)
,m_value(0)
{
}
PhraseNode::PhraseNode(UINT64 filePos, OnDiskWrapper &onDiskWrapper)
:m_counts(onDiskWrapper.GetNumCounts())
{ // load saved node
m_filePos = filePos;
size_t countSize = onDiskWrapper.GetNumCounts();
std::fstream &file = onDiskWrapper.GetFileSource();
file.seekg(filePos);
assert(filePos == file.tellg());
file.read((char*) &m_numChildrenLoad, sizeof(UINT64));
size_t memAlloc = GetNodeSize(m_numChildrenLoad, onDiskWrapper.GetSourceWordSize(), countSize);
m_memLoad = (char*) malloc(memAlloc);
// go to start of node again
file.seekg(filePos);
assert(filePos == file.tellg());
:m_counts(onDiskWrapper.GetNumCounts())
{
// load saved node
m_filePos = filePos;
// read everything into memory
file.read(m_memLoad, memAlloc);
assert(filePos + memAlloc == file.tellg());
// get value
m_value = ((UINT64*)m_memLoad)[1];
// get counts
float *memFloat = (float*) (m_memLoad + sizeof(UINT64) * 2);
size_t countSize = onDiskWrapper.GetNumCounts();
assert(countSize == 1);
m_counts[0] = memFloat[0];
m_memLoadLast = m_memLoad + memAlloc;
std::fstream &file = onDiskWrapper.GetFileSource();
file.seekg(filePos);
assert(filePos == file.tellg());
file.read((char*) &m_numChildrenLoad, sizeof(UINT64));
size_t memAlloc = GetNodeSize(m_numChildrenLoad, onDiskWrapper.GetSourceWordSize(), countSize);
m_memLoad = (char*) malloc(memAlloc);
// go to start of node again
file.seekg(filePos);
assert(filePos == file.tellg());
// read everything into memory
file.read(m_memLoad, memAlloc);
assert(filePos + memAlloc == file.tellg());
// get value
m_value = ((UINT64*)m_memLoad)[1];
// get counts
float *memFloat = (float*) (m_memLoad + sizeof(UINT64) * 2);
assert(countSize == 1);
m_counts[0] = memFloat[0];
m_memLoadLast = m_memLoad + memAlloc;
}
PhraseNode::~PhraseNode()
{
free(m_memLoad);
//assert(m_saved);
free(m_memLoad);
//assert(m_saved);
}
float PhraseNode::GetCount(size_t ind) const
{ return m_counts[ind]; }
{
return m_counts[ind];
}
void PhraseNode::Save(OnDiskWrapper &onDiskWrapper, size_t pos, size_t tableLimit)
{
assert(!m_saved);
assert(!m_saved);
// save this node
m_targetPhraseColl.Sort(tableLimit);
m_targetPhraseColl.Save(onDiskWrapper);
m_value = m_targetPhraseColl.GetFilePos();
size_t numCounts = onDiskWrapper.GetNumCounts();
size_t memAlloc = GetNodeSize(GetSize(), onDiskWrapper.GetSourceWordSize(), numCounts);
char *mem = (char*) malloc(memAlloc);
//memset(mem, 0xfe, memAlloc);
size_t memUsed = 0;
UINT64 *memArray = (UINT64*) mem;
memArray[0] = GetSize(); // num of children
memArray[1] = m_value; // file pos of corresponding target phrases
memUsed += 2 * sizeof(UINT64);
// count info
float *memFloat = (float*) (mem + memUsed);
assert(numCounts == 1);
memFloat[0] = (m_counts.size() == 0) ? DEFAULT_COUNT : m_counts[0]; // if count = 0, put in very large num to make sure its still used. HACK
memUsed += sizeof(float) * numCounts;
// recursively save chm_countsildren
ChildColl::iterator iter;
for (iter = m_children.begin(); iter != m_children.end(); ++iter)
{
const Word &childWord = iter->first;
PhraseNode &childNode = iter->second;
// save this node
m_targetPhraseColl.Sort(tableLimit);
m_targetPhraseColl.Save(onDiskWrapper);
m_value = m_targetPhraseColl.GetFilePos();
// recursive
if (!childNode.Saved())
childNode.Save(onDiskWrapper, pos + 1, tableLimit);
size_t numCounts = onDiskWrapper.GetNumCounts();
char *currMem = mem + memUsed;
size_t wordMemUsed = childWord.WriteToMemory(currMem);
memUsed += wordMemUsed;
size_t memAlloc = GetNodeSize(GetSize(), onDiskWrapper.GetSourceWordSize(), numCounts);
char *mem = (char*) malloc(memAlloc);
//memset(mem, 0xfe, memAlloc);
UINT64 *memArray = (UINT64*) (mem + memUsed);
memArray[0] = childNode.GetFilePos();
memUsed += sizeof(UINT64);
}
size_t memUsed = 0;
UINT64 *memArray = (UINT64*) mem;
memArray[0] = GetSize(); // num of children
memArray[1] = m_value; // file pos of corresponding target phrases
memUsed += 2 * sizeof(UINT64);
// save this node
//Moses::DebugMem(mem, memAlloc);
assert(memUsed == memAlloc);
// count info
float *memFloat = (float*) (mem + memUsed);
assert(numCounts == 1);
memFloat[0] = (m_counts.size() == 0) ? DEFAULT_COUNT : m_counts[0]; // if count = 0, put in very large num to make sure its still used. HACK
memUsed += sizeof(float) * numCounts;
std::fstream &file = onDiskWrapper.GetFileSource();
m_filePos = file.tellp();
file.seekp(0, ios::end);
file.write(mem, memUsed);
// recursively save chm_countsildren
ChildColl::iterator iter;
for (iter = m_children.begin(); iter != m_children.end(); ++iter) {
const Word &childWord = iter->first;
PhraseNode &childNode = iter->second;
UINT64 endPos = file.tellp();
assert(m_filePos + memUsed == endPos);
// recursive
if (!childNode.Saved())
childNode.Save(onDiskWrapper, pos + 1, tableLimit);
free(mem);
char *currMem = mem + memUsed;
size_t wordMemUsed = childWord.WriteToMemory(currMem);
memUsed += wordMemUsed;
m_children.clear();
m_saved = true;
UINT64 *memArray = (UINT64*) (mem + memUsed);
memArray[0] = childNode.GetFilePos();
memUsed += sizeof(UINT64);
}
// save this node
//Moses::DebugMem(mem, memAlloc);
assert(memUsed == memAlloc);
std::fstream &file = onDiskWrapper.GetFileSource();
m_filePos = file.tellp();
file.seekp(0, ios::end);
file.write(mem, memUsed);
UINT64 endPos = file.tellp();
assert(m_filePos + memUsed == endPos);
free(mem);
m_children.clear();
m_saved = true;
}
void PhraseNode::AddTargetPhrase(const SourcePhrase &sourcePhrase, TargetPhrase *targetPhrase
, OnDiskWrapper &onDiskWrapper, size_t tableLimit
, const std::vector<float> &counts)
, OnDiskWrapper &onDiskWrapper, size_t tableLimit
, const std::vector<float> &counts)
{
AddTargetPhrase(0, sourcePhrase, targetPhrase, onDiskWrapper, tableLimit, counts);
AddTargetPhrase(0, sourcePhrase, targetPhrase, onDiskWrapper, tableLimit, counts);
}
void PhraseNode::AddTargetPhrase(size_t pos, const SourcePhrase &sourcePhrase
, TargetPhrase *targetPhrase, OnDiskWrapper &onDiskWrapper
, size_t tableLimit, const std::vector<float> &counts)
, TargetPhrase *targetPhrase, OnDiskWrapper &onDiskWrapper
, size_t tableLimit, const std::vector<float> &counts)
{
size_t phraseSize = sourcePhrase.GetSize();
if (pos < phraseSize)
{
const Word &word = sourcePhrase.GetWord(pos);
PhraseNode &node = m_children[word];
if (m_currChild != &node)
{ // new node
node.SetPos(pos);
if (m_currChild)
{
m_currChild->Save(onDiskWrapper, pos, tableLimit);
}
m_currChild = &node;
}
node.AddTargetPhrase(pos + 1, sourcePhrase, targetPhrase, onDiskWrapper, tableLimit, counts);
}
else
{ // drilled down to the right node
m_counts = counts;
m_targetPhraseColl.AddTargetPhrase(targetPhrase);
}
size_t phraseSize = sourcePhrase.GetSize();
if (pos < phraseSize) {
const Word &word = sourcePhrase.GetWord(pos);
PhraseNode &node = m_children[word];
if (m_currChild != &node) {
// new node
node.SetPos(pos);
if (m_currChild) {
m_currChild->Save(onDiskWrapper, pos, tableLimit);
}
m_currChild = &node;
}
node.AddTargetPhrase(pos + 1, sourcePhrase, targetPhrase, onDiskWrapper, tableLimit, counts);
} else {
// drilled down to the right node
m_counts = counts;
m_targetPhraseColl.AddTargetPhrase(targetPhrase);
}
}
const PhraseNode *PhraseNode::GetChild(const Word &wordSought, OnDiskWrapper &onDiskWrapper) const
{
const PhraseNode *ret = NULL;
int l = 0;
int r = m_numChildrenLoad - 1;
int x;
while (r >= l)
{
x = (l + r) / 2;
int l = 0;
int r = m_numChildrenLoad - 1;
int x;
Word wordFound;
UINT64 childFilePos;
GetChild(wordFound, childFilePos, x, onDiskWrapper);
if (wordSought == wordFound)
{
ret = new PhraseNode(childFilePos, onDiskWrapper);
break;
}
if (wordSought < wordFound)
r = x - 1;
else
l = x + 1;
}
return ret;
while (r >= l) {
x = (l + r) / 2;
Word wordFound;
UINT64 childFilePos;
GetChild(wordFound, childFilePos, x, onDiskWrapper);
if (wordSought == wordFound) {
ret = new PhraseNode(childFilePos, onDiskWrapper);
break;
}
if (wordSought < wordFound)
r = x - 1;
else
l = x + 1;
}
return ret;
}
void PhraseNode::GetChild(Word &wordFound, UINT64 &childFilePos, size_t ind, OnDiskWrapper &onDiskWrapper) const
{
size_t wordSize = onDiskWrapper.GetSourceWordSize();
size_t childSize = wordSize + sizeof(UINT64);
size_t numFactors = onDiskWrapper.GetNumSourceFactors();
size_t wordSize = onDiskWrapper.GetSourceWordSize();
size_t childSize = wordSize + sizeof(UINT64);
size_t numFactors = onDiskWrapper.GetNumSourceFactors();
char *currMem = m_memLoad
+ sizeof(UINT64) * 2 // size & file pos of target phrase coll
+ sizeof(float) * onDiskWrapper.GetNumCounts() // count info
+ childSize * ind;
char *currMem = m_memLoad
+ sizeof(UINT64) * 2 // size & file pos of target phrase coll
+ sizeof(float) * onDiskWrapper.GetNumCounts() // count info
+ childSize * ind;
size_t memRead = ReadChild(wordFound, childFilePos, currMem, numFactors);
assert(memRead == childSize);
size_t memRead = ReadChild(wordFound, childFilePos, currMem, numFactors);
assert(memRead == childSize);
}
size_t PhraseNode::ReadChild(Word &wordFound, UINT64 &childFilePos, const char *mem, size_t numFactors) const
{
size_t memRead = wordFound.ReadFromMemory(mem, numFactors);
const char *currMem = mem + memRead;
UINT64 *memArray = (UINT64*) (currMem);
childFilePos = memArray[0];
memRead += sizeof(UINT64);
return memRead;
size_t memRead = wordFound.ReadFromMemory(mem, numFactors);
const char *currMem = mem + memRead;
UINT64 *memArray = (UINT64*) (currMem);
childFilePos = memArray[0];
memRead += sizeof(UINT64);
return memRead;
}
const TargetPhraseCollection *PhraseNode::GetTargetPhraseCollection(size_t tableLimit, OnDiskWrapper &onDiskWrapper) const
{
TargetPhraseCollection *ret = new TargetPhraseCollection();
if (m_value > 0)
ret->ReadFromFile(tableLimit, m_value, onDiskWrapper);
else
{
TargetPhraseCollection *ret = new TargetPhraseCollection();
}
return ret;
if (m_value > 0)
ret->ReadFromFile(tableLimit, m_value, onDiskWrapper);
else {
}
return ret;
}
std::ostream& operator<<(std::ostream &out, const PhraseNode &node)
{
out << "node (" << node.GetFilePos() << "," << node.GetValue() << "," << node.m_pos << ")";
return out;
out << "node (" << node.GetFilePos() << "," << node.GetValue() << "," << node.m_pos << ")";
return out;
}
}

View File

@ -3,17 +3,17 @@
/***********************************************************************
Moses - factored phrase-based, hierarchical and syntactic language decoder
Copyright (C) 2009 Hieu Hoang
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
@ -29,67 +29,74 @@ namespace OnDiskPt
class OnDiskWrapper;
class SourcePhrase;
class PhraseNode
{
friend std::ostream& operator<<(std::ostream&, const PhraseNode&);
friend std::ostream& operator<<(std::ostream&, const PhraseNode&);
protected:
UINT64 m_filePos, m_value;
UINT64 m_filePos, m_value;
typedef std::map<Word, PhraseNode> ChildColl;
ChildColl m_children;
PhraseNode *m_currChild;
bool m_saved;
size_t m_pos;
std::vector<float> m_counts;
TargetPhraseCollection m_targetPhraseColl;
char *m_memLoad, *m_memLoadLast;
UINT64 m_numChildrenLoad;
typedef std::map<Word, PhraseNode> ChildColl;
ChildColl m_children;
PhraseNode *m_currChild;
bool m_saved;
size_t m_pos;
std::vector<float> m_counts;
void AddTargetPhrase(size_t pos, const SourcePhrase &sourcePhrase
, TargetPhrase *targetPhrase, OnDiskWrapper &onDiskWrapper
, size_t tableLimit, const std::vector<float> &counts);
size_t ReadChild(Word &wordFound, UINT64 &childFilePos, const char *mem, size_t numFactors) const;
void GetChild(Word &wordFound, UINT64 &childFilePos, size_t ind, OnDiskWrapper &onDiskWrapper) const;
TargetPhraseCollection m_targetPhraseColl;
char *m_memLoad, *m_memLoadLast;
UINT64 m_numChildrenLoad;
void AddTargetPhrase(size_t pos, const SourcePhrase &sourcePhrase
, TargetPhrase *targetPhrase, OnDiskWrapper &onDiskWrapper
, size_t tableLimit, const std::vector<float> &counts);
size_t ReadChild(Word &wordFound, UINT64 &childFilePos, const char *mem, size_t numFactors) const;
void GetChild(Word &wordFound, UINT64 &childFilePos, size_t ind, OnDiskWrapper &onDiskWrapper) const;
public:
static size_t GetNodeSize(size_t numChildren, size_t wordSize, size_t countSize);
static size_t GetNodeSize(size_t numChildren, size_t wordSize, size_t countSize);
PhraseNode(); // unsaved node
PhraseNode(UINT64 filePos, OnDiskWrapper &onDiskWrapper); // load saved node
~PhraseNode();
void Add(const Word &word, UINT64 nextFilePos, size_t wordSize);
void Save(OnDiskWrapper &onDiskWrapper, size_t pos, size_t tableLimit);
PhraseNode(); // unsaved node
PhraseNode(UINT64 filePos, OnDiskWrapper &onDiskWrapper); // load saved node
~PhraseNode();
void AddTargetPhrase(const SourcePhrase &sourcePhrase, TargetPhrase *targetPhrase
, OnDiskWrapper &onDiskWrapper, size_t tableLimit
, const std::vector<float> &counts);
void Add(const Word &word, UINT64 nextFilePos, size_t wordSize);
void Save(OnDiskWrapper &onDiskWrapper, size_t pos, size_t tableLimit);
UINT64 GetFilePos() const
{ return m_filePos; }
UINT64 GetValue() const
{ return m_value; }
void SetValue(UINT64 value)
{ m_value = value; }
size_t GetSize() const
{ return m_children.size(); }
void AddTargetPhrase(const SourcePhrase &sourcePhrase, TargetPhrase *targetPhrase
, OnDiskWrapper &onDiskWrapper, size_t tableLimit
, const std::vector<float> &counts);
bool Saved() const
{ return m_saved; }
UINT64 GetFilePos() const {
return m_filePos;
}
UINT64 GetValue() const {
return m_value;
}
void SetValue(UINT64 value) {
m_value = value;
}
size_t GetSize() const {
return m_children.size();
}
void SetPos(size_t pos)
{ m_pos = pos; }
bool Saved() const {
return m_saved;
}
void SetPos(size_t pos) {
m_pos = pos;
}
const PhraseNode *GetChild(const Word &wordSought, OnDiskWrapper &onDiskWrapper) const;
const TargetPhraseCollection *GetTargetPhraseCollection(size_t tableLimit, OnDiskWrapper &onDiskWrapper) const;
void AddCounts(const std::vector<float> &counts) {
m_counts = counts;
}
float GetCount(size_t ind) const;
const PhraseNode *GetChild(const Word &wordSought, OnDiskWrapper &onDiskWrapper) const;
const TargetPhraseCollection *GetTargetPhraseCollection(size_t tableLimit, OnDiskWrapper &onDiskWrapper) const;
void AddCounts(const std::vector<float> &counts)
{ m_counts = counts; }
float GetCount(size_t ind) const;
};
}

View File

@ -2,17 +2,17 @@
/***********************************************************************
Moses - factored phrase-based, hierarchical and syntactic language decoder
Copyright (C) 2009 Hieu Hoang
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
@ -22,7 +22,7 @@
namespace OnDiskPt
{
}

View File

@ -3,17 +3,17 @@
/***********************************************************************
Moses - factored phrase-based, hierarchical and syntactic language decoder
Copyright (C) 2009 Hieu Hoang
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA

View File

@ -2,17 +2,17 @@
/***********************************************************************
Moses - factored phrase-based, hierarchical and syntactic language decoder
Copyright (C) 2009 Hieu Hoang
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
@ -33,15 +33,15 @@ namespace OnDiskPt
{
TargetPhrase::TargetPhrase(size_t numScores)
:m_scores(numScores)
:m_scores(numScores)
{
}
TargetPhrase::TargetPhrase(const TargetPhrase &copy)
:Phrase(copy)
,m_scores(copy.m_scores)
:Phrase(copy)
,m_scores(copy.m_scores)
{
}
TargetPhrase::~TargetPhrase()
@ -50,287 +50,277 @@ TargetPhrase::~TargetPhrase()
void TargetPhrase::SetLHS(Word *lhs)
{
AddWord(lhs);
AddWord(lhs);
}
void TargetPhrase::Create1AlignFromString(const std::string &align1Str)
{
vector<size_t> alignPoints;
Moses::Tokenize<size_t>(alignPoints, align1Str, "-");
assert(alignPoints.size() == 2);
m_align.push_back(pair<size_t, size_t>(alignPoints[0], alignPoints[1]) );
vector<size_t> alignPoints;
Moses::Tokenize<size_t>(alignPoints, align1Str, "-");
assert(alignPoints.size() == 2);
m_align.push_back(pair<size_t, size_t>(alignPoints[0], alignPoints[1]) );
}
void TargetPhrase::SetScore(float score, size_t ind)
{
assert(ind < m_scores.size());
m_scores[ind] = score;
assert(ind < m_scores.size());
m_scores[ind] = score;
}
class AlignOrderer
{
public:
bool operator()(const AlignPair &a, const AlignPair &b) const
{
return a.first < b.first;
}
public:
bool operator()(const AlignPair &a, const AlignPair &b) const {
return a.first < b.first;
}
};
void TargetPhrase::SortAlign()
{
std::sort(m_align.begin(), m_align.end(), AlignOrderer());
std::sort(m_align.begin(), m_align.end(), AlignOrderer());
}
char *TargetPhrase::WriteToMemory(OnDiskWrapper &onDiskWrapper, size_t &memUsed) const
{
size_t phraseSize = GetSize();
size_t targetWordSize = onDiskWrapper.GetTargetWordSize();
size_t memNeeded = sizeof(UINT64) // num of words
+ targetWordSize * phraseSize; // actual words. lhs as last words
memUsed = 0;
UINT64 *mem = (UINT64*) malloc(memNeeded);
// write size
mem[0] = phraseSize;
memUsed += sizeof(UINT64);
size_t phraseSize = GetSize();
size_t targetWordSize = onDiskWrapper.GetTargetWordSize();
// write each word
for (size_t pos = 0; pos < phraseSize; ++pos)
{
const Word &word = GetWord(pos);
char *currPtr = (char*)mem + memUsed;
memUsed += word.WriteToMemory((char*) currPtr);
}
assert(memUsed == memNeeded);
return (char *) mem;
size_t memNeeded = sizeof(UINT64) // num of words
+ targetWordSize * phraseSize; // actual words. lhs as last words
memUsed = 0;
UINT64 *mem = (UINT64*) malloc(memNeeded);
// write size
mem[0] = phraseSize;
memUsed += sizeof(UINT64);
// write each word
for (size_t pos = 0; pos < phraseSize; ++pos) {
const Word &word = GetWord(pos);
char *currPtr = (char*)mem + memUsed;
memUsed += word.WriteToMemory((char*) currPtr);
}
assert(memUsed == memNeeded);
return (char *) mem;
}
void TargetPhrase::Save(OnDiskWrapper &onDiskWrapper)
{
// save in target ind
size_t memUsed;
char *mem = WriteToMemory(onDiskWrapper, memUsed);
// save in target ind
size_t memUsed;
char *mem = WriteToMemory(onDiskWrapper, memUsed);
std::fstream &file = onDiskWrapper.GetFileTargetInd();
UINT64 startPos = file.tellp();
file.seekp(0, ios::end);
file.write(mem, memUsed);
UINT64 endPos = file.tellp();
assert(startPos + memUsed == endPos);
m_filePos = startPos;
free(mem);
std::fstream &file = onDiskWrapper.GetFileTargetInd();
UINT64 startPos = file.tellp();
file.seekp(0, ios::end);
file.write(mem, memUsed);
UINT64 endPos = file.tellp();
assert(startPos + memUsed == endPos);
m_filePos = startPos;
free(mem);
}
char *TargetPhrase::WriteOtherInfoToMemory(OnDiskWrapper &onDiskWrapper, size_t &memUsed) const
{
// allocate mem
size_t numScores = onDiskWrapper.GetNumScores()
,numAlign = GetAlign().size();
size_t memNeeded = sizeof(UINT64); // file pos (phrase id)
memNeeded += sizeof(UINT64) + 2 * sizeof(UINT64) * numAlign; // align
memNeeded += sizeof(float) * numScores; // scores
char *mem = (char*) malloc(memNeeded);
//memset(mem, 0, memNeeded);
memUsed = 0;
// phrase id
memcpy(mem, &m_filePos, sizeof(UINT64));
memUsed += sizeof(UINT64);
// align
memUsed += WriteAlignToMemory(mem + memUsed);
// scores
memUsed += WriteScoresToMemory(mem + memUsed);
//DebugMem(mem, memNeeded);
assert(memNeeded == memUsed);
return mem;
// allocate mem
size_t numScores = onDiskWrapper.GetNumScores()
,numAlign = GetAlign().size();
size_t memNeeded = sizeof(UINT64); // file pos (phrase id)
memNeeded += sizeof(UINT64) + 2 * sizeof(UINT64) * numAlign; // align
memNeeded += sizeof(float) * numScores; // scores
char *mem = (char*) malloc(memNeeded);
//memset(mem, 0, memNeeded);
memUsed = 0;
// phrase id
memcpy(mem, &m_filePos, sizeof(UINT64));
memUsed += sizeof(UINT64);
// align
memUsed += WriteAlignToMemory(mem + memUsed);
// scores
memUsed += WriteScoresToMemory(mem + memUsed);
//DebugMem(mem, memNeeded);
assert(memNeeded == memUsed);
return mem;
}
size_t TargetPhrase::WriteAlignToMemory(char *mem) const
{
size_t memUsed = 0;
// num of alignments
UINT64 numAlign = m_align.size();
memcpy(mem, &numAlign, sizeof(numAlign));
memUsed += sizeof(numAlign);
// actual alignments
AlignType::const_iterator iter;
for (iter = m_align.begin(); iter != m_align.end(); ++iter)
{
const AlignPair &alignPair = *iter;
memcpy(mem + memUsed, &alignPair.first, sizeof(alignPair.first));
memUsed += sizeof(alignPair.first);
memcpy(mem + memUsed, &alignPair.second, sizeof(alignPair.second));
memUsed += sizeof(alignPair.second);
}
return memUsed;
size_t memUsed = 0;
// num of alignments
UINT64 numAlign = m_align.size();
memcpy(mem, &numAlign, sizeof(numAlign));
memUsed += sizeof(numAlign);
// actual alignments
AlignType::const_iterator iter;
for (iter = m_align.begin(); iter != m_align.end(); ++iter) {
const AlignPair &alignPair = *iter;
memcpy(mem + memUsed, &alignPair.first, sizeof(alignPair.first));
memUsed += sizeof(alignPair.first);
memcpy(mem + memUsed, &alignPair.second, sizeof(alignPair.second));
memUsed += sizeof(alignPair.second);
}
return memUsed;
}
size_t TargetPhrase::WriteScoresToMemory(char *mem) const
{
float *scoreMem = (float*) mem;
for (size_t ind = 0; ind < m_scores.size(); ++ind)
scoreMem[ind] = m_scores[ind];
size_t memUsed = sizeof(float) * m_scores.size();
return memUsed;
{
float *scoreMem = (float*) mem;
for (size_t ind = 0; ind < m_scores.size(); ++ind)
scoreMem[ind] = m_scores[ind];
size_t memUsed = sizeof(float) * m_scores.size();
return memUsed;
}
Moses::TargetPhrase *TargetPhrase::ConvertToMoses(const std::vector<Moses::FactorType> &inputFactors
, const std::vector<Moses::FactorType> &outputFactors
, const Vocab &vocab
, const Moses::PhraseDictionary &phraseDict
, const std::vector<float> &weightT
, const Moses::WordPenaltyProducer* wpProducer
, const Moses::LMList &lmList) const
Moses::TargetPhrase *TargetPhrase::ConvertToMoses(const std::vector<Moses::FactorType> & /*inputFactors */
, const std::vector<Moses::FactorType> &outputFactors
, const Vocab &vocab
, const Moses::PhraseDictionary &phraseDict
, const std::vector<float> &weightT
, const Moses::WordPenaltyProducer* wpProducer
, const Moses::LMList &lmList) const
{
Moses::TargetPhrase *ret = new Moses::TargetPhrase(Moses::Output);
// words
size_t phraseSize = GetSize();
assert(phraseSize > 0); // last word is lhs
--phraseSize;
for (size_t pos = 0; pos < phraseSize; ++pos)
{
Moses::Word *mosesWord = GetWord(pos).ConvertToMoses(Moses::Output, outputFactors, vocab);
ret->AddWord(*mosesWord);
delete mosesWord;
}
// scores
ret->SetScoreChart(phraseDict.GetFeature(), m_scores, weightT, lmList, wpProducer);
// alignments
std::list<std::pair<size_t, size_t> > alignmentInfo;
for (size_t ind = 0; ind < m_align.size(); ++ind)
{
const std::pair<size_t, size_t> &entry = m_align[ind];
alignmentInfo.push_back(entry);
}
ret->SetAlignmentInfo(alignmentInfo);
Moses::Word *lhs = GetWord(GetSize() - 1).ConvertToMoses(Moses::Output, outputFactors, vocab);
ret->SetTargetLHS(*lhs);
delete lhs;
return ret;
Moses::TargetPhrase *ret = new Moses::TargetPhrase(Moses::Output);
// words
size_t phraseSize = GetSize();
assert(phraseSize > 0); // last word is lhs
--phraseSize;
for (size_t pos = 0; pos < phraseSize; ++pos) {
Moses::Word *mosesWord = GetWord(pos).ConvertToMoses(Moses::Output, outputFactors, vocab);
ret->AddWord(*mosesWord);
delete mosesWord;
}
// scores
ret->SetScoreChart(phraseDict.GetFeature(), m_scores, weightT, lmList, wpProducer);
// alignments
std::list<std::pair<size_t, size_t> > alignmentInfo;
for (size_t ind = 0; ind < m_align.size(); ++ind) {
const std::pair<size_t, size_t> &entry = m_align[ind];
alignmentInfo.push_back(entry);
}
ret->SetAlignmentInfo(alignmentInfo);
Moses::Word *lhs = GetWord(GetSize() - 1).ConvertToMoses(Moses::Output, outputFactors, vocab);
ret->SetTargetLHS(*lhs);
delete lhs;
return ret;
}
UINT64 TargetPhrase::ReadOtherInfoFromFile(UINT64 filePos, std::fstream &fileTPColl)
{
assert(filePos == fileTPColl.tellg());
UINT64 memUsed = 0;
fileTPColl.read((char*) &m_filePos, sizeof(UINT64));
memUsed += sizeof(UINT64);
assert(m_filePos != 0);
memUsed += ReadAlignFromFile(fileTPColl);
assert((memUsed + filePos) == fileTPColl.tellg());
memUsed += ReadScoresFromFile(fileTPColl);
assert((memUsed + filePos) == fileTPColl.tellg());
assert(filePos == fileTPColl.tellg());
return memUsed;
UINT64 memUsed = 0;
fileTPColl.read((char*) &m_filePos, sizeof(UINT64));
memUsed += sizeof(UINT64);
assert(m_filePos != 0);
memUsed += ReadAlignFromFile(fileTPColl);
assert((memUsed + filePos) == fileTPColl.tellg());
memUsed += ReadScoresFromFile(fileTPColl);
assert((memUsed + filePos) == fileTPColl.tellg());
return memUsed;
}
UINT64 TargetPhrase::ReadFromFile(std::fstream &fileTP, size_t numFactors)
{
UINT64 bytesRead = 0;
UINT64 bytesRead = 0;
fileTP.seekg(m_filePos);
fileTP.seekg(m_filePos);
UINT64 numWords;
fileTP.read((char*) &numWords, sizeof(UINT64));
bytesRead += sizeof(UINT64);
for (size_t ind = 0; ind < numWords; ++ind)
{
Word *word = new Word();
bytesRead += word->ReadFromFile(fileTP, numFactors);
AddWord(word);
}
return bytesRead;
UINT64 numWords;
fileTP.read((char*) &numWords, sizeof(UINT64));
bytesRead += sizeof(UINT64);
for (size_t ind = 0; ind < numWords; ++ind) {
Word *word = new Word();
bytesRead += word->ReadFromFile(fileTP, numFactors);
AddWord(word);
}
return bytesRead;
}
UINT64 TargetPhrase::ReadAlignFromFile(std::fstream &fileTPColl)
{
UINT64 bytesRead = 0;
UINT64 numAlign;
fileTPColl.read((char*) &numAlign, sizeof(UINT64));
bytesRead += sizeof(UINT64);
for (size_t ind = 0; ind < numAlign; ++ind)
{
AlignPair alignPair;
fileTPColl.read((char*) &alignPair.first, sizeof(UINT64));
fileTPColl.read((char*) &alignPair.second, sizeof(UINT64));
m_align.push_back(alignPair);
bytesRead += sizeof(UINT64) * 2;
}
return bytesRead;
UINT64 bytesRead = 0;
UINT64 numAlign;
fileTPColl.read((char*) &numAlign, sizeof(UINT64));
bytesRead += sizeof(UINT64);
for (size_t ind = 0; ind < numAlign; ++ind) {
AlignPair alignPair;
fileTPColl.read((char*) &alignPair.first, sizeof(UINT64));
fileTPColl.read((char*) &alignPair.second, sizeof(UINT64));
m_align.push_back(alignPair);
bytesRead += sizeof(UINT64) * 2;
}
return bytesRead;
}
UINT64 TargetPhrase::ReadScoresFromFile(std::fstream &fileTPColl)
{
assert(m_scores.size() > 0);
UINT64 bytesRead = 0;
for (size_t ind = 0; ind < m_scores.size(); ++ind)
{
fileTPColl.read((char*) &m_scores[ind], sizeof(float));
bytesRead += sizeof(float);
}
std::transform(m_scores.begin(),m_scores.end(),m_scores.begin(), Moses::TransformScore);
std::transform(m_scores.begin(),m_scores.end(),m_scores.begin(), Moses::FloorScore);
assert(m_scores.size() > 0);
return bytesRead;
UINT64 bytesRead = 0;
for (size_t ind = 0; ind < m_scores.size(); ++ind) {
fileTPColl.read((char*) &m_scores[ind], sizeof(float));
bytesRead += sizeof(float);
}
std::transform(m_scores.begin(),m_scores.end(),m_scores.begin(), Moses::TransformScore);
std::transform(m_scores.begin(),m_scores.end(),m_scores.begin(), Moses::FloorScore);
return bytesRead;
}
std::ostream& operator<<(std::ostream &out, const TargetPhrase &phrase)
{
out << (const Phrase&) phrase << ", " ;
for (size_t ind = 0; ind < phrase.m_align.size(); ++ind)
{
const AlignPair &alignPair = phrase.m_align[ind];
out << alignPair.first << "-" << alignPair.second << " ";
}
out << ", ";
for (size_t ind = 0; ind < phrase.m_scores.size(); ++ind)
{
out << phrase.m_scores[ind] << " ";
}
out << (const Phrase&) phrase << ", " ;
return out;
for (size_t ind = 0; ind < phrase.m_align.size(); ++ind) {
const AlignPair &alignPair = phrase.m_align[ind];
out << alignPair.first << "-" << alignPair.second << " ";
}
out << ", ";
for (size_t ind = 0; ind < phrase.m_scores.size(); ++ind) {
out << phrase.m_scores[ind] << " ";
}
return out;
}
} // namespace

View File

@ -3,17 +3,17 @@
/***********************************************************************
Moses - factored phrase-based, hierarchical and syntactic language decoder
Copyright (C) 2009 Hieu Hoang
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
@ -27,11 +27,11 @@
namespace Moses
{
class PhraseDictionary;
class TargetPhrase;
class LMList;
class Phrase;
class WordPenaltyProducer;
class PhraseDictionary;
class TargetPhrase;
class LMList;
class Phrase;
class WordPenaltyProducer;
}
namespace OnDiskPt
@ -42,52 +42,55 @@ typedef std::vector<AlignPair> AlignType;
class TargetPhrase: public Phrase
{
friend std::ostream& operator<<(std::ostream&, const TargetPhrase&);
friend std::ostream& operator<<(std::ostream&, const TargetPhrase&);
protected:
AlignType m_align;
AlignType m_align;
std::vector<float> m_scores;
UINT64 m_filePos;
size_t WriteAlignToMemory(char *mem) const;
size_t WriteScoresToMemory(char *mem) const;
std::vector<float> m_scores;
UINT64 m_filePos;
UINT64 ReadAlignFromFile(std::fstream &fileTPColl);
UINT64 ReadScoresFromFile(std::fstream &fileTPColl);
size_t WriteAlignToMemory(char *mem) const;
size_t WriteScoresToMemory(char *mem) const;
UINT64 ReadAlignFromFile(std::fstream &fileTPColl);
UINT64 ReadScoresFromFile(std::fstream &fileTPColl);
public:
TargetPhrase(size_t numScores);
TargetPhrase(const TargetPhrase &copy);
virtual ~TargetPhrase();
TargetPhrase(size_t numScores);
TargetPhrase(const TargetPhrase &copy);
virtual ~TargetPhrase();
void SetLHS(Word *lhs);
void SetLHS(Word *lhs);
void Create1AlignFromString(const std::string &align1Str);
void SetScore(float score, size_t ind);
void Create1AlignFromString(const std::string &align1Str);
void SetScore(float score, size_t ind);
const AlignType &GetAlign() const
{ return m_align; }
void SortAlign();
const AlignType &GetAlign() const {
return m_align;
}
void SortAlign();
char *WriteToMemory(OnDiskWrapper &onDiskWrapper, size_t &memUsed) const;
char *WriteOtherInfoToMemory(OnDiskWrapper &onDiskWrapper, size_t &memUsed) const;
void Save(OnDiskWrapper &onDiskWrapper);
char *WriteToMemory(OnDiskWrapper &onDiskWrapper, size_t &memUsed) const;
char *WriteOtherInfoToMemory(OnDiskWrapper &onDiskWrapper, size_t &memUsed) const;
void Save(OnDiskWrapper &onDiskWrapper);
UINT64 GetFilePos() const
{ return m_filePos; }
float GetScore(size_t ind) const
{ return m_scores[ind]; }
UINT64 GetFilePos() const {
return m_filePos;
}
float GetScore(size_t ind) const {
return m_scores[ind];
}
Moses::TargetPhrase *ConvertToMoses(const std::vector<Moses::FactorType> &inputFactors
, const std::vector<Moses::FactorType> &outputFactors
, const Vocab &vocab
, const Moses::PhraseDictionary &phraseDict
, const std::vector<float> &weightT
, const Moses::WordPenaltyProducer* wpProducer
, const Moses::LMList &lmList) const;
UINT64 ReadOtherInfoFromFile(UINT64 filePos, std::fstream &fileTPColl);
UINT64 ReadFromFile(std::fstream &fileTP, size_t numFactors);
Moses::TargetPhrase *ConvertToMoses(const std::vector<Moses::FactorType> &inputFactors
, const std::vector<Moses::FactorType> &outputFactors
, const Vocab &vocab
, const Moses::PhraseDictionary &phraseDict
, const std::vector<float> &weightT
, const Moses::WordPenaltyProducer* wpProducer
, const Moses::LMList &lmList) const;
UINT64 ReadOtherInfoFromFile(UINT64 filePos, std::fstream &fileTPColl);
UINT64 ReadFromFile(std::fstream &fileTP, size_t numFactors);
};
}

View File

@ -2,17 +2,17 @@
/***********************************************************************
Moses - factored phrase-based, hierarchical and syntactic language decoder
Copyright (C) 2009 Hieu Hoang
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
@ -35,165 +35,166 @@ namespace OnDiskPt
size_t TargetPhraseCollection::s_sortScoreInd;
TargetPhraseCollection::TargetPhraseCollection()
:m_filePos(777)
:m_filePos(777)
{}
TargetPhraseCollection::TargetPhraseCollection(const TargetPhraseCollection &copy)
:m_filePos(copy.m_filePos)
,m_debugStr(copy.m_debugStr)
:m_filePos(copy.m_filePos)
,m_debugStr(copy.m_debugStr)
{
}
TargetPhraseCollection::~TargetPhraseCollection()
{
Moses::RemoveAllInColl(m_coll);
Moses::RemoveAllInColl(m_coll);
}
void TargetPhraseCollection::AddTargetPhrase(TargetPhrase *targetPhrase)
{
m_coll.push_back(targetPhrase);
m_coll.push_back(targetPhrase);
}
void TargetPhraseCollection::Sort(size_t tableLimit)
{
std::sort(m_coll.begin(), m_coll.end(), TargetPhraseOrderByScore());
if (m_coll.size() > tableLimit)
{
CollType::iterator iter;
for (iter = m_coll.begin() + tableLimit ; iter != m_coll.end(); ++iter)
{
delete *iter;
}
m_coll.resize(tableLimit);
}
std::sort(m_coll.begin(), m_coll.end(), TargetPhraseOrderByScore());
if (m_coll.size() > tableLimit) {
CollType::iterator iter;
for (iter = m_coll.begin() + tableLimit ; iter != m_coll.end(); ++iter) {
delete *iter;
}
m_coll.resize(tableLimit);
}
}
void TargetPhraseCollection::Save(OnDiskWrapper &onDiskWrapper)
{
std::fstream &file = onDiskWrapper.GetFileTargetColl();
std::fstream &file = onDiskWrapper.GetFileTargetColl();
size_t memUsed = sizeof(UINT64);
char *mem = (char*) malloc(memUsed);
size_t memUsed = sizeof(UINT64);
char *mem = (char*) malloc(memUsed);
// size of coll
UINT64 numPhrases = GetSize();
((UINT64*)mem)[0] = numPhrases;
// MAIN LOOP
CollType::iterator iter;
for (iter = m_coll.begin(); iter != m_coll.end(); ++iter)
{
// save phrase
TargetPhrase &targetPhrase = **iter;
targetPhrase.Save(onDiskWrapper);
// save coll
size_t memUsedTPOtherInfo;
char *memTPOtherInfo = targetPhrase.WriteOtherInfoToMemory(onDiskWrapper, memUsedTPOtherInfo);
// size of coll
UINT64 numPhrases = GetSize();
((UINT64*)mem)[0] = numPhrases;
// expand existing mem
mem = (char*) realloc(mem, memUsed + memUsedTPOtherInfo);
memcpy(mem + memUsed, memTPOtherInfo, memUsedTPOtherInfo);
memUsed += memUsedTPOtherInfo;
free(memTPOtherInfo);
}
// MAIN LOOP
CollType::iterator iter;
for (iter = m_coll.begin(); iter != m_coll.end(); ++iter) {
// save phrase
TargetPhrase &targetPhrase = **iter;
targetPhrase.Save(onDiskWrapper);
// total number of bytes
//((UINT64*)mem)[0] = (UINT64) memUsed;
UINT64 startPos = file.tellp();
file.seekp(0, ios::end);
file.write((char*) mem, memUsed);
// save coll
size_t memUsedTPOtherInfo;
char *memTPOtherInfo = targetPhrase.WriteOtherInfoToMemory(onDiskWrapper, memUsedTPOtherInfo);
free(mem);
UINT64 endPos = file.tellp();
assert(startPos + memUsed == endPos);
m_filePos = startPos;
// expand existing mem
mem = (char*) realloc(mem, memUsed + memUsedTPOtherInfo);
memcpy(mem + memUsed, memTPOtherInfo, memUsedTPOtherInfo);
memUsed += memUsedTPOtherInfo;
free(memTPOtherInfo);
}
// total number of bytes
//((UINT64*)mem)[0] = (UINT64) memUsed;
UINT64 startPos = file.tellp();
file.seekp(0, ios::end);
file.write((char*) mem, memUsed);
free(mem);
UINT64 endPos = file.tellp();
assert(startPos + memUsed == endPos);
m_filePos = startPos;
}
Moses::TargetPhraseCollection *TargetPhraseCollection::ConvertToMoses(const std::vector<Moses::FactorType> &inputFactors
, const std::vector<Moses::FactorType> &outputFactors
, const Moses::PhraseDictionary &phraseDict
, const std::vector<float> &weightT
, const Moses::WordPenaltyProducer* wpProducer
, const Moses::LMList &lmList
, const std::string &filePath
, Vocab &vocab) const
, const std::vector<Moses::FactorType> &outputFactors
, const Moses::PhraseDictionary &phraseDict
, const std::vector<float> &weightT
, const Moses::WordPenaltyProducer* wpProducer
, const Moses::LMList &lmList
, const std::string & /* filePath */
, Vocab &vocab) const
{
Moses::TargetPhraseCollection *ret = new Moses::TargetPhraseCollection();
CollType::const_iterator iter;
for (iter = m_coll.begin(); iter != m_coll.end(); ++iter)
{
const TargetPhrase &tp = **iter;
Moses::TargetPhrase *mosesPhrase = tp.ConvertToMoses(inputFactors, outputFactors
, vocab
, phraseDict
, weightT
, wpProducer
, lmList);
/*
// debugging output
stringstream strme;
strme << filePath << " " << *mosesPhrase;
mosesPhrase->SetDebugOutput(strme.str());
*/
ret->Add(mosesPhrase);
}
ret->Prune(true, phraseDict.GetTableLimit());
return ret;
Moses::TargetPhraseCollection *ret = new Moses::TargetPhraseCollection();
CollType::const_iterator iter;
for (iter = m_coll.begin(); iter != m_coll.end(); ++iter) {
const TargetPhrase &tp = **iter;
Moses::TargetPhrase *mosesPhrase = tp.ConvertToMoses(inputFactors, outputFactors
, vocab
, phraseDict
, weightT
, wpProducer
, lmList);
/*
// debugging output
stringstream strme;
strme << filePath << " " << *mosesPhrase;
mosesPhrase->SetDebugOutput(strme.str());
*/
ret->Add(mosesPhrase);
}
ret->Prune(true, phraseDict.GetTableLimit());
return ret;
}
void TargetPhraseCollection::ReadFromFile(size_t tableLimit, UINT64 filePos, OnDiskWrapper &onDiskWrapper)
{
fstream &fileTPColl = onDiskWrapper.GetFileTargetColl();
fstream &fileTP = onDiskWrapper.GetFileTargetInd();
size_t numScores = onDiskWrapper.GetNumScores();
size_t numTargetFactors = onDiskWrapper.GetNumTargetFactors();
UINT64 numPhrases;
fstream &fileTPColl = onDiskWrapper.GetFileTargetColl();
fstream &fileTP = onDiskWrapper.GetFileTargetInd();
UINT64 currFilePos = filePos;
fileTPColl.seekg(filePos);
fileTPColl.read((char*) &numPhrases, sizeof(UINT64));
// table limit
numPhrases = std::min(numPhrases, (UINT64) tableLimit);
currFilePos += sizeof(UINT64);
for (size_t ind = 0; ind < numPhrases; ++ind)
{
TargetPhrase *tp = new TargetPhrase(numScores);
UINT64 sizeOtherInfo = tp->ReadOtherInfoFromFile(currFilePos, fileTPColl);
tp->ReadFromFile(fileTP, numTargetFactors);
currFilePos += sizeOtherInfo;
m_coll.push_back(tp);
}
size_t numScores = onDiskWrapper.GetNumScores();
size_t numTargetFactors = onDiskWrapper.GetNumTargetFactors();
UINT64 numPhrases;
UINT64 currFilePos = filePos;
fileTPColl.seekg(filePos);
fileTPColl.read((char*) &numPhrases, sizeof(UINT64));
// table limit
numPhrases = std::min(numPhrases, (UINT64) tableLimit);
currFilePos += sizeof(UINT64);
for (size_t ind = 0; ind < numPhrases; ++ind) {
TargetPhrase *tp = new TargetPhrase(numScores);
UINT64 sizeOtherInfo = tp->ReadOtherInfoFromFile(currFilePos, fileTPColl);
tp->ReadFromFile(fileTP, numTargetFactors);
currFilePos += sizeOtherInfo;
m_coll.push_back(tp);
}
}
UINT64 TargetPhraseCollection::GetFilePos() const
{ return m_filePos; }
{
return m_filePos;
}
const std::string TargetPhraseCollection::GetDebugStr() const
{ return m_debugStr; }
{
return m_debugStr;
}
void TargetPhraseCollection::SetDebugStr(const std::string &str)
{ m_debugStr = str; }
{
m_debugStr = str;
}
}

View File

@ -2,17 +2,17 @@
/***********************************************************************
Moses - factored phrase-based, hierarchical and syntactic language decoder
Copyright (C) 2009 Hieu Hoang
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
@ -21,13 +21,13 @@
#include "TargetPhrase.h"
#include "Vocab.h"
namespace Moses
{
class TargetPhraseCollection;
class PhraseDictionary;
class LMList;
class WordPenaltyProducer;
class TargetPhraseCollection;
class PhraseDictionary;
class LMList;
class WordPenaltyProducer;
}
namespace OnDiskPt
@ -35,49 +35,49 @@ namespace OnDiskPt
class TargetPhraseCollection
{
class TargetPhraseOrderByScore
{
public:
bool operator()(const TargetPhrase* a, const TargetPhrase *b) const
{
return a->GetScore(s_sortScoreInd) > b->GetScore(s_sortScoreInd);
}
};
class TargetPhraseOrderByScore
{
public:
bool operator()(const TargetPhrase* a, const TargetPhrase *b) const {
return a->GetScore(s_sortScoreInd) > b->GetScore(s_sortScoreInd);
}
};
protected:
typedef std::vector<TargetPhrase*> CollType;
CollType m_coll;
UINT64 m_filePos;
std::string m_debugStr;
typedef std::vector<TargetPhrase*> CollType;
CollType m_coll;
UINT64 m_filePos;
std::string m_debugStr;
public:
static size_t s_sortScoreInd;
static size_t s_sortScoreInd;
TargetPhraseCollection();
TargetPhraseCollection(const TargetPhraseCollection &copy);
~TargetPhraseCollection();
void AddTargetPhrase(TargetPhrase *targetPhrase);
void Sort(size_t tableLimit);
TargetPhraseCollection();
TargetPhraseCollection(const TargetPhraseCollection &copy);
void Save(OnDiskWrapper &onDiskWrapper);
~TargetPhraseCollection();
void AddTargetPhrase(TargetPhrase *targetPhrase);
void Sort(size_t tableLimit);
size_t GetSize() const
{ return m_coll.size(); }
UINT64 GetFilePos() const;
void Save(OnDiskWrapper &onDiskWrapper);
Moses::TargetPhraseCollection *ConvertToMoses(const std::vector<Moses::FactorType> &inputFactors
, const std::vector<Moses::FactorType> &outputFactors
, const Moses::PhraseDictionary &phraseDict
, const std::vector<float> &weightT
, const Moses::WordPenaltyProducer* wpProducer
, const Moses::LMList &lmList
, const std::string &filePath
, Vocab &vocab) const;
void ReadFromFile(size_t tableLimit, UINT64 filePos, OnDiskWrapper &onDiskWrapper);
size_t GetSize() const {
return m_coll.size();
}
UINT64 GetFilePos() const;
const std::string GetDebugStr() const;
void SetDebugStr(const std::string &str);
Moses::TargetPhraseCollection *ConvertToMoses(const std::vector<Moses::FactorType> &inputFactors
, const std::vector<Moses::FactorType> &outputFactors
, const Moses::PhraseDictionary &phraseDict
, const std::vector<float> &weightT
, const Moses::WordPenaltyProducer* wpProducer
, const Moses::LMList &lmList
, const std::string &filePath
, Vocab &vocab) const;
void ReadFromFile(size_t tableLimit, UINT64 filePos, OnDiskWrapper &onDiskWrapper);
const std::string GetDebugStr() const;
void SetDebugStr(const std::string &str);
};

View File

@ -2,17 +2,17 @@
/***********************************************************************
Moses - factored phrase-based, hierarchical and syntactic language decoder
Copyright (C) 2009 Hieu Hoang
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
@ -27,90 +27,83 @@ using namespace std;
namespace OnDiskPt
{
bool Vocab::Load(OnDiskWrapper &onDiskWrapper)
{
fstream &file = onDiskWrapper.GetFileVocab();
string line;
while(getline(file, line))
{
vector<string> tokens;
Moses::Tokenize(tokens, line);
assert(tokens.size() == 2);
const string &key = tokens[0];
m_vocabColl[key] = Moses::Scan<UINT64>(tokens[1]);
}
// create lookup
// assume contiguous vocab id
m_lookup.resize(m_vocabColl.size() + 1);
CollType::const_iterator iter;
for (iter = m_vocabColl.begin(); iter != m_vocabColl.end(); ++iter)
{
UINT32 vocabId = iter->second;
const std::string &word = iter->first;
m_lookup[vocabId] = word;
}
return true;
fstream &file = onDiskWrapper.GetFileVocab();
string line;
while(getline(file, line)) {
vector<string> tokens;
Moses::Tokenize(tokens, line);
assert(tokens.size() == 2);
const string &key = tokens[0];
m_vocabColl[key] = Moses::Scan<UINT64>(tokens[1]);
}
// create lookup
// assume contiguous vocab id
m_lookup.resize(m_vocabColl.size() + 1);
CollType::const_iterator iter;
for (iter = m_vocabColl.begin(); iter != m_vocabColl.end(); ++iter) {
UINT32 vocabId = iter->second;
const std::string &word = iter->first;
m_lookup[vocabId] = word;
}
return true;
}
void Vocab::Save(OnDiskWrapper &onDiskWrapper)
{
fstream &file = onDiskWrapper.GetFileVocab();
CollType::const_iterator iterVocab;
for (iterVocab = m_vocabColl.begin(); iterVocab != m_vocabColl.end(); ++iterVocab)
{
const string &word = iterVocab->first;
UINT32 vocabId = iterVocab->second;
file << word << " " << vocabId << endl;
}
fstream &file = onDiskWrapper.GetFileVocab();
CollType::const_iterator iterVocab;
for (iterVocab = m_vocabColl.begin(); iterVocab != m_vocabColl.end(); ++iterVocab) {
const string &word = iterVocab->first;
UINT32 vocabId = iterVocab->second;
file << word << " " << vocabId << endl;
}
}
UINT64 Vocab::AddVocabId(const std::string &factorString)
{
// find string id
CollType::const_iterator iter = m_vocabColl.find(factorString);
if (iter == m_vocabColl.end())
{ // add new vocab entry
m_vocabColl[factorString] = m_nextId;
return m_nextId++;
}
else
{ // return existing entry
return iter->second;
}
// find string id
CollType::const_iterator iter = m_vocabColl.find(factorString);
if (iter == m_vocabColl.end()) {
// add new vocab entry
m_vocabColl[factorString] = m_nextId;
return m_nextId++;
} else {
// return existing entry
return iter->second;
}
}
UINT64 Vocab::GetVocabId(const std::string &factorString, bool &found) const
{
// find string id
CollType::const_iterator iter = m_vocabColl.find(factorString);
if (iter == m_vocabColl.end())
{
found = false;
return 0; //return whatever
}
else
{ // return existing entry
found = true;
return iter->second;
}
// find string id
CollType::const_iterator iter = m_vocabColl.find(factorString);
if (iter == m_vocabColl.end()) {
found = false;
return 0; //return whatever
} else {
// return existing entry
found = true;
return iter->second;
}
}
const Moses::Factor *Vocab::GetFactor(UINT32 vocabId, Moses::FactorType factorType, Moses::FactorDirection direction, bool isNonTerminal) const
{
string str = GetString(vocabId);
if (isNonTerminal)
{
str = str.substr(1, str.size() - 2);
}
const Moses::Factor *factor = Moses::FactorCollection::Instance().AddFactor(direction, factorType, str);
return factor;
string str = GetString(vocabId);
if (isNonTerminal) {
str = str.substr(1, str.size() - 2);
}
const Moses::Factor *factor = Moses::FactorCollection::Instance().AddFactor(direction, factorType, str);
return factor;
}
}

View File

@ -3,17 +3,17 @@
/***********************************************************************
Moses - factored phrase-based, hierarchical and syntactic language decoder
Copyright (C) 2009 Hieu Hoang
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
@ -24,7 +24,7 @@
namespace Moses
{
class Factor;
class Factor;
}
namespace OnDiskPt
@ -34,26 +34,27 @@ class OnDiskWrapper;
class Vocab
{
protected:
typedef std::map<std::string, UINT64> CollType;
CollType m_vocabColl;
protected:
typedef std::map<std::string, UINT64> CollType;
CollType m_vocabColl;
std::vector<std::string> m_lookup; // opposite of m_vocabColl
UINT64 m_nextId; // starts @ 1
const std::string &GetString(UINT32 vocabId) const
{ return m_lookup[vocabId]; }
std::vector<std::string> m_lookup; // opposite of m_vocabColl
UINT64 m_nextId; // starts @ 1
const std::string &GetString(UINT32 vocabId) const {
return m_lookup[vocabId];
}
public:
Vocab()
:m_nextId(1)
{}
UINT64 AddVocabId(const std::string &factorString);
UINT64 GetVocabId(const std::string &factorString, bool &found) const;
const Moses::Factor *GetFactor(UINT32 vocabId, Moses::FactorType factorType, Moses::FactorDirection direction, bool isNonTerminal) const;
Vocab()
:m_nextId(1)
{}
UINT64 AddVocabId(const std::string &factorString);
UINT64 GetVocabId(const std::string &factorString, bool &found) const;
const Moses::Factor *GetFactor(UINT32 vocabId, Moses::FactorType factorType, Moses::FactorDirection direction, bool isNonTerminal) const;
bool Load(OnDiskWrapper &onDiskWrapper);
void Save(OnDiskWrapper &onDiskWrapper);
bool Load(OnDiskWrapper &onDiskWrapper);
void Save(OnDiskWrapper &onDiskWrapper);
};
}

View File

@ -2,17 +2,17 @@
/***********************************************************************
Moses - factored phrase-based, hierarchical and syntactic language decoder
Copyright (C) 2009 Hieu Hoang
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
@ -28,139 +28,135 @@ namespace OnDiskPt
{
Word::Word(const Word &copy)
:m_isNonTerminal(copy.m_isNonTerminal)
,m_factors(copy.m_factors)
:m_isNonTerminal(copy.m_isNonTerminal)
,m_factors(copy.m_factors)
{}
Word::~Word()
{}
void Word::CreateFromString(const std::string &inString, Vocab &vocab)
{
if (inString.substr(0, 1) == "[" && inString.substr(inString.size() - 1, 1) == "]")
{ // non-term
m_isNonTerminal = true;
}
else
{
m_isNonTerminal = false;
}
if (inString.substr(0, 1) == "[" && inString.substr(inString.size() - 1, 1) == "]") {
// non-term
m_isNonTerminal = true;
} else {
m_isNonTerminal = false;
}
m_factors.resize(1);
m_factors[0] = vocab.AddVocabId(inString);
m_factors.resize(1);
m_factors[0] = vocab.AddVocabId(inString);
}
size_t Word::WriteToMemory(char *mem) const
{
UINT64 *vocabMem = (UINT64*) mem;
// factors
for (size_t ind = 0; ind < m_factors.size(); ind++)
vocabMem[ind] = m_factors[ind];
size_t size = sizeof(UINT64) * m_factors.size();
{
UINT64 *vocabMem = (UINT64*) mem;
// is non-term
char bNonTerm = (char) m_isNonTerminal;
mem[size] = bNonTerm;
++size;
// factors
for (size_t ind = 0; ind < m_factors.size(); ind++)
vocabMem[ind] = m_factors[ind];
return size;
size_t size = sizeof(UINT64) * m_factors.size();
// is non-term
char bNonTerm = (char) m_isNonTerminal;
mem[size] = bNonTerm;
++size;
return size;
}
size_t Word::ReadFromMemory(const char *mem, size_t numFactors)
{
m_factors.resize(numFactors);
UINT64 *vocabMem = (UINT64*) mem;
// factors
for (size_t ind = 0; ind < m_factors.size(); ind++)
m_factors[ind] = vocabMem[ind];
size_t memUsed = sizeof(UINT64) * m_factors.size();
// is non-term
char bNonTerm;
bNonTerm = mem[memUsed];
m_isNonTerminal = (bool) bNonTerm;
++memUsed;
return memUsed;
m_factors.resize(numFactors);
UINT64 *vocabMem = (UINT64*) mem;
// factors
for (size_t ind = 0; ind < m_factors.size(); ind++)
m_factors[ind] = vocabMem[ind];
size_t memUsed = sizeof(UINT64) * m_factors.size();
// is non-term
char bNonTerm;
bNonTerm = mem[memUsed];
m_isNonTerminal = (bool) bNonTerm;
++memUsed;
return memUsed;
}
size_t Word::ReadFromFile(std::fstream &file, size_t numFactors)
{
size_t memAlloc = numFactors * sizeof(UINT64) + sizeof(char);
char *mem = (char*) malloc(memAlloc);
file.read(mem, memAlloc);
size_t memUsed = ReadFromMemory(mem, numFactors);
assert(memAlloc == memUsed);
free(mem);
return memUsed;
size_t memAlloc = numFactors * sizeof(UINT64) + sizeof(char);
char *mem = (char*) malloc(memAlloc);
file.read(mem, memAlloc);
size_t memUsed = ReadFromMemory(mem, numFactors);
assert(memAlloc == memUsed);
free(mem);
return memUsed;
}
Moses::Word *Word::ConvertToMoses(Moses::FactorDirection direction
, const std::vector<Moses::FactorType> &outputFactorsVec
, const Vocab &vocab) const
, const std::vector<Moses::FactorType> &outputFactorsVec
, const Vocab &vocab) const
{
Moses::Word *ret = new Moses::Word(m_isNonTerminal);
for (size_t ind = 0; ind < m_factors.size(); ++ind)
{
Moses::FactorType factorType = outputFactorsVec[ind];
UINT32 vocabId = m_factors[ind];
const Moses::Factor *factor = vocab.GetFactor(vocabId, factorType, direction, m_isNonTerminal);
ret->SetFactor(factorType, factor);
}
return ret;
Moses::Word *ret = new Moses::Word(m_isNonTerminal);
for (size_t ind = 0; ind < m_factors.size(); ++ind) {
Moses::FactorType factorType = outputFactorsVec[ind];
UINT32 vocabId = m_factors[ind];
const Moses::Factor *factor = vocab.GetFactor(vocabId, factorType, direction, m_isNonTerminal);
ret->SetFactor(factorType, factor);
}
return ret;
}
int Word::Compare(const Word &compare) const
{
int ret;
if (m_isNonTerminal != compare.m_isNonTerminal)
return m_isNonTerminal ?-1 : 1;
if (m_factors < compare.m_factors)
ret = -1;
else if (m_factors > compare.m_factors)
ret = 1;
else
ret = 0;
int ret;
return ret;
if (m_isNonTerminal != compare.m_isNonTerminal)
return m_isNonTerminal ?-1 : 1;
if (m_factors < compare.m_factors)
ret = -1;
else if (m_factors > compare.m_factors)
ret = 1;
else
ret = 0;
return ret;
}
bool Word::operator<(const Word &compare) const
{
int ret = Compare(compare);
return ret < 0;
{
int ret = Compare(compare);
return ret < 0;
}
bool Word::operator==(const Word &compare) const
{
int ret = Compare(compare);
return ret == 0;
{
int ret = Compare(compare);
return ret == 0;
}
std::ostream& operator<<(std::ostream &out, const Word &word)
{
out << "[";
std::vector<UINT64>::const_iterator iter;
for (iter = word.m_factors.begin(); iter != word.m_factors.end(); ++iter)
{
out << *iter << "|";
}
out << (word.m_isNonTerminal ? "n" : "t");
out << "]";
return out;
out << "[";
std::vector<UINT64>::const_iterator iter;
for (iter = word.m_factors.begin(); iter != word.m_factors.end(); ++iter) {
out << *iter << "|";
}
out << (word.m_isNonTerminal ? "n" : "t");
out << "]";
return out;
}
}

View File

@ -3,17 +3,17 @@
/***********************************************************************
Moses - factored phrase-based, hierarchical and syntactic language decoder
Copyright (C) 2009 Hieu Hoang
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
@ -26,7 +26,7 @@
namespace Moses
{
class Word;
class Word;
}
namespace OnDiskPt
@ -34,43 +34,45 @@ namespace OnDiskPt
class Word
{
friend std::ostream& operator<<(std::ostream&, const Word&);
friend std::ostream& operator<<(std::ostream&, const Word&);
protected:
bool m_isNonTerminal;
std::vector<UINT64> m_factors;
bool m_isNonTerminal;
std::vector<UINT64> m_factors;
public:
explicit Word()
{}
explicit Word(size_t numFactors, bool isNonTerminal)
:m_factors(numFactors)
,m_isNonTerminal(isNonTerminal)
{}
Word(const Word &copy);
~Word();
explicit Word()
{}
void CreateFromString(const std::string &inString, Vocab &vocab);
bool IsNonTerminal() const
{ return m_isNonTerminal; }
explicit Word(size_t numFactors, bool isNonTerminal)
:m_isNonTerminal(isNonTerminal)
,m_factors(numFactors)
{}
size_t WriteToMemory(char *mem) const;
size_t ReadFromMemory(const char *mem, size_t numFactors);
size_t ReadFromFile(std::fstream &file, size_t numFactors);
Word(const Word &copy);
~Word();
void SetVocabId(size_t ind, UINT32 vocabId)
{ m_factors[ind] = vocabId; }
Moses::Word *ConvertToMoses(Moses::FactorDirection direction
, const std::vector<Moses::FactorType> &outputFactorsVec
, const Vocab &vocab) const;
int Compare(const Word &compare) const;
bool operator<(const Word &compare) const;
bool operator==(const Word &compare) const;
void CreateFromString(const std::string &inString, Vocab &vocab);
bool IsNonTerminal() const {
return m_isNonTerminal;
}
size_t WriteToMemory(char *mem) const;
size_t ReadFromMemory(const char *mem, size_t numFactors);
size_t ReadFromFile(std::fstream &file, size_t numFactors);
void SetVocabId(size_t ind, UINT32 vocabId) {
m_factors[ind] = vocabId;
}
Moses::Word *ConvertToMoses(Moses::FactorDirection direction
, const std::vector<Moses::FactorType> &outputFactorsVec
, const Vocab &vocab) const;
int Compare(const Word &compare) const;
bool operator<(const Word &compare) const;
bool operator==(const Word &compare) const;
};
}

View File

@ -1,10 +1,10 @@
/* config.h.in. Generated from configure.in by autoheader. */
/* define if the Boost library is available */
/* Defined if the requested minimum BOOST version is satisfied */
#undef HAVE_BOOST
/* define if the Boost::Thread library is available */
#undef HAVE_BOOST_THREAD
/* Define to 1 if you have <boost/thread.hpp> */
#undef HAVE_BOOST_THREAD_HPP
/* Define to 1 if you have the <dlfcn.h> header file. */
#undef HAVE_DLFCN_H
@ -24,6 +24,9 @@
/* Define to 1 if you have the `oolm' library (-loolm). */
#undef HAVE_LIBOOLM
/* Define to 1 if you have the `tcmalloc' library (-ltcmalloc). */
#undef HAVE_LIBTCMALLOC
/* Define to 1 if you have the <memory.h> header file. */
#undef HAVE_MEMORY_H
@ -84,5 +87,8 @@
/* Define to 1 if you have the ANSI C header files. */
#undef STDC_HEADERS
/* Flag to enable use of Boost pool */
#undef USE_BOOST_POOL
/* Version number of package */
#undef VERSION

View File

@ -84,20 +84,41 @@ AC_ARG_ENABLE(optimization,
AC_ARG_ENABLE(threads,
[AC_HELP_STRING([--enable-threads], [compile threadsafe library and multi-threaded moses (mosesmt)])],
[with_threads=yes]
[],
[enable_threads=no]
)
AC_ARG_ENABLE(boost,
[AC_HELP_STRING([--enable-boost], [use Boost library])],
[enable_boost=yes]
[],
[enable_boost=no]
)
AC_ARG_WITH(zlib,
[AC_HELP_STRING([--with-zlib=PATH], [(optional) path to zlib])],
boost [AC_HELP_STRING([--with-zlib=PATH], [(optional) path to zlib])],
[with_zlib=$withval],
[with_zlib=no]
)
AC_ARG_WITH(tcmalloc,
[AC_HELP_STRING([--with-tcmalloc], [(optional) link with tcmalloc; default is no])],
[with_tcmalloc=$withval],
[with_tcmalloc=no]
)
require_boost=no
if test "x$enable_threads" != 'xno' || test "x$enable_boost" != 'xno' || test "x$with_synlm" != 'xno'
then
require_boost=yes
fi
AC_ARG_ENABLE(boost-pool,
[AC_HELP_STRING([--enable-boost-pool], [(optional) try to improve speed by selectively using Boost pool allocation (may increase total memory use); default is yes if Boost enabled])],
[],
[enable_boost_pool=$require_boost]
)
AM_CONDITIONAL([INTERNAL_LM], false)
AM_CONDITIONAL([SRI_LM], false)
AM_CONDITIONAL([IRST_LM], false)
@ -117,13 +138,13 @@ else
CPPFLAGS="$CPPFLAGS -DTRACE_ENABLE=1"
fi
if test "x$with_threads" = 'xyes' || test "x$enable_boost" = 'xyes' || test "x$with_synlm"
if test "x$require_boost" = 'xyes'
then
AC_MSG_NOTICE([Using Boost library])
BOOST_REQUIRE([1.36.0])
fi
if test "x$with_threads" = 'xyes' || test "x$with_synlm"
if test "x$enable_threads" = 'xyes' || test "x$with_synlm"
then
AC_MSG_NOTICE([Building threaded moses])
BOOST_THREADS
@ -245,15 +266,25 @@ then
AM_CONDITIONAL([RAND_LM], true)
fi
if test "x$with_tcmalloc" != 'xno'
then
AC_CHECK_LIB([tcmalloc], [malloc], [], [AC_MSG_ERROR([Cannot find tcmalloc])])
fi
if test "x$enable_boost_pool" != 'xno'
then
AC_CHECK_HEADER(boost/pool/object_pool.hpp,
[AC_DEFINE([USE_BOOST_POOL], [], [Flag to enable use of Boost pool])],
[AC_MSG_WARN([Cannot find boost/pool/object_pool.hpp])]
)
fi
if test "x$with_synlm" != 'xno'
then
SAVE_CPPFLAGS="$CPPFLAGS"
CPPFLAGS="$CPPFLAGS -DWITH_THREADS -I${with_synlm}/rvtl/include -I${with_synlm}/wsjparse/include -lm"
AC_CHECK_HEADERS(nl-cpt.h,
[AC_DEFINE([HAVE_SYNLM], [], [flag for Syntactic Parser])])
@ -268,7 +299,7 @@ AC_CHECK_HEADERS([getopt.h],
[AC_MSG_WARN([Cannot find getopt.h - disabling new mert])])
AM_CONDITIONAL([WITH_SERVER],false)
if test "x$have_xmlrpc_c" = "xyes" && test "x$with_threads" = "xyes"; then
if test "x$have_xmlrpc_c" = "xyes" && test "x$enable_threads" = "xyes"; then
AM_CONDITIONAL([WITH_SERVER],true)
else
AC_MSG_NOTICE([Disabling server])
@ -283,6 +314,6 @@ fi
LIBS="$LIBS -lz"
AC_CONFIG_FILES(Makefile OnDiskPt/src/Makefile moses/src/Makefile moses-chart/src/Makefile moses-cmd/src/Makefile moses-chart-cmd/src/Makefile misc/Makefile mert/Makefile server/Makefile CreateOnDisk/src/Makefile kenlm/Makefile)
AC_CONFIG_FILES(Makefile OnDiskPt/src/Makefile moses/src/Makefile moses-cmd/src/Makefile moses-chart-cmd/src/Makefile misc/Makefile mert/Makefile server/Makefile CreateOnDisk/src/Makefile kenlm/Makefile)
AC_OUTPUT()

View File

@ -13,14 +13,12 @@ libkenlm_la_SOURCES = \
lm/read_arpa.cc \
lm/virtual_interface.cc \
lm/vocab.cc \
util/string_piece.cc \
util/scoped.cc \
util/murmur_hash.cc \
util/mmap.cc \
util/file_piece.cc \
util/ersatz_progress.cc \
util/exception.cc \
util/string_piece.cc \
util/bit_packing.cc
query_SOURCES = lm/ngram_query.cc

View File

@ -1,9 +1,12 @@
Language model inference code by Kenneth Heafield <infer at kheafield.com>
The official website is http://kheafield.com/code/mt/infer.html . If you're a decoder developer, please download the latest version from there instead of copying from another decoder.
The official website is http://kheafield.com/code/kenlm/ . If you're a decoder developer, please download the latest version from there instead of copying from Moses.
This documentation is directed at decoder developers.
While the primary means of building kenlm for use in Moses is the Moses build system, you can also compile independently using:
./compile.sh to compile the code
./test.sh to compile and run tests; requires Boost
./clean.sh to clean
Currently, it loads an ARPA file in 2/3 the time SRI takes and uses 6.5 GB when SRI takes 11 GB. These are compared to the default SRI build (i.e. without their smaller structures). I'm working on optimizing this even further.
The rest of the documentation is directed at decoder developers.
Binary format via mmap is supported. Run ./build_binary to make one then pass the binary file name instead.
@ -11,14 +14,14 @@ Currently, it assumes POSIX APIs for errno, sterror_r, open, close, mmap, munmap
A brief note to Mac OS X users: your gcc is too old to recognize the pack pragma. The warning effectively means that, on 64-bit machines, the model will use 16 bytes instead of 12 bytes per n-gram of maximum order (those of lower order are already 16 bytes) in the probing and sorted models. The trie is not impacted by this.
It does not depend on Boost or ICU. However, if you use Boost and/or ICU in the rest of your code, you should define HAVE_BOOST and/or HAVE_ICU in util/string_piece.hh. Defining HAVE_BOOST will let you hash StringPiece. Defining HAVE_ICU will use ICU's StringPiece to prevent a conflict with the one provided here. By the way, ICU's StringPiece is buggy and I reported this bug: http://bugs.icu-project.org/trac/ticket/7924 .
It does not depend on Boost or ICU. However, if you use Boost and/or ICU in the rest of your code, you should define HAVE_BOOST and/or HAVE_ICU in util/have.hh. Defining HAVE_BOOST will let you hash StringPiece. Defining HAVE_ICU will use ICU's StringPiece to prevent a conflict with the one provided here.
The recommend way to use this:
Copy the code and distribute with your decoder.
Set HAVE_ICU and HAVE_BOOST at the top of util/string_piece.hh as instructed above.
Set HAVE_ICU and HAVE_BOOST at the top of util/have.hh as instructed above.
Look at compile.sh and reimplement using your build system.
Use either the interface in lm/ngram.hh or lm/virtual_interface.hh
Interface documentation is in comments of lm/virtual_interface.hh (including for lm/ngram.hh).
Use either the interface in lm/model.hh or lm/virtual_interface.hh
Interface documentation is in comments of lm/virtual_interface.hh (including for lm/model.hh).
I recommend copying the code and distributing it with your decoder. However, please send improvements to me so that they can be integrated into the package.

2
kenlm/clean.sh Executable file
View File

@ -0,0 +1,2 @@
#!/bin/bash
rm -rf */*.o query build_binary */*_test lm/test.binary* lm/test.arpa?????? util/file_piece.cc.gz

14
kenlm/compile.sh Executable file
View File

@ -0,0 +1,14 @@
#!/bin/bash
#This is just an example compilation. You should integrate these files into your build system. I can provide boost jam if you want.
#If your code uses ICU, edit util/string_piece.hh and uncomment #define USE_ICU
#I use zlib by default. If you don't want to depend on zlib, remove #define USE_ZLIB from util/file_piece.hh
#don't need to use if compiling with moses Makefiles already
set -e
for i in util/{bit_packing,ersatz_progress,exception,file_piece,murmur_hash,scoped,mmap} lm/{binary_format,config,lm_exception,model,read_arpa,search_hashed,search_trie,trie,virtual_interface,vocab}; do
g++ -I. -O3 $CXXFLAGS -c $i.cc -o $i.o
done
g++ -I. -O3 $CXXFLAGS lm/build_binary.cc {lm,util}/*.o -lz -o build_binary
g++ -I. -O3 $CXXFLAGS lm/ngram_query.cc {lm,util}/*.o -lz -o query

View File

@ -48,7 +48,6 @@
1EBB16EA126C158600AE6102 /* scoped.hh in Headers */ = {isa = PBXBuildFile; fileRef = 1EBB16D2126C158600AE6102 /* scoped.hh */; };
1EBB16EB126C158600AE6102 /* sorted_uniform_test.cc in Sources */ = {isa = PBXBuildFile; fileRef = 1EBB16D3126C158600AE6102 /* sorted_uniform_test.cc */; };
1EBB16EC126C158600AE6102 /* sorted_uniform.hh in Headers */ = {isa = PBXBuildFile; fileRef = 1EBB16D4126C158600AE6102 /* sorted_uniform.hh */; };
1EBB16ED126C158600AE6102 /* string_piece.cc in Sources */ = {isa = PBXBuildFile; fileRef = 1EBB16D5126C158600AE6102 /* string_piece.cc */; };
1EBB16EE126C158600AE6102 /* string_piece.hh in Headers */ = {isa = PBXBuildFile; fileRef = 1EBB16D6126C158600AE6102 /* string_piece.hh */; };
1EBB1717126C15C500AE6102 /* facade.hh in Headers */ = {isa = PBXBuildFile; fileRef = 1EBB1708126C15C500AE6102 /* facade.hh */; };
1EBB171A126C15C500AE6102 /* ngram_query.cc in Sources */ = {isa = PBXBuildFile; fileRef = 1EBB170B126C15C500AE6102 /* ngram_query.cc */; };
@ -106,7 +105,6 @@
1EBB16D2126C158600AE6102 /* scoped.hh */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; name = scoped.hh; path = util/scoped.hh; sourceTree = "<group>"; };
1EBB16D3126C158600AE6102 /* sorted_uniform_test.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = sorted_uniform_test.cc; path = util/sorted_uniform_test.cc; sourceTree = "<group>"; };
1EBB16D4126C158600AE6102 /* sorted_uniform.hh */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; name = sorted_uniform.hh; path = util/sorted_uniform.hh; sourceTree = "<group>"; };
1EBB16D5126C158600AE6102 /* string_piece.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = string_piece.cc; path = util/string_piece.cc; sourceTree = "<group>"; };
1EBB16D6126C158600AE6102 /* string_piece.hh */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; name = string_piece.hh; path = util/string_piece.hh; sourceTree = "<group>"; };
1EBB1708126C15C500AE6102 /* facade.hh */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; name = facade.hh; path = lm/facade.hh; sourceTree = "<group>"; };
1EBB170B126C15C500AE6102 /* ngram_query.cc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = ngram_query.cc; path = lm/ngram_query.cc; sourceTree = "<group>"; };
@ -198,7 +196,6 @@
1EBB16D2126C158600AE6102 /* scoped.hh */,
1EBB16D3126C158600AE6102 /* sorted_uniform_test.cc */,
1EBB16D4126C158600AE6102 /* sorted_uniform.hh */,
1EBB16D5126C158600AE6102 /* string_piece.cc */,
1EBB16D6126C158600AE6102 /* string_piece.hh */,
1E2B85C112555DB1000770D6 /* lm_exception.cc */,
1E2B85C212555DB1000770D6 /* lm_exception.hh */,
@ -287,7 +284,14 @@
isa = PBXProject;
buildConfigurationList = 1DEB91EF08733DB70010E9CD /* Build configuration list for PBXProject "kenlm" */;
compatibilityVersion = "Xcode 3.1";
developmentRegion = English;
hasScannedForEncodings = 1;
knownRegions = (
English,
Japanese,
French,
German,
);
mainGroup = 08FB7794FE84155DC02AAC07 /* kenlm */;
projectDirPath = "";
projectRoot = "";
@ -314,7 +318,6 @@
1EBB16E6126C158600AE6102 /* probing_hash_table_test.cc in Sources */,
1EBB16E9126C158600AE6102 /* scoped.cc in Sources */,
1EBB16EB126C158600AE6102 /* sorted_uniform_test.cc in Sources */,
1EBB16ED126C158600AE6102 /* string_piece.cc in Sources */,
1EBB171A126C15C500AE6102 /* ngram_query.cc in Sources */,
1EBB171C126C15C500AE6102 /* read_arpa.cc in Sources */,
1EBB171E126C15C500AE6102 /* sri_test.cc in Sources */,

View File

@ -9,6 +9,7 @@
#include <fcntl.h>
#include <errno.h>
#include <stdlib.h>
#include <string.h>
#include <sys/mman.h>
#include <sys/types.h>
#include <sys/stat.h>
@ -18,8 +19,10 @@ namespace lm {
namespace ngram {
namespace {
const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version";
const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 1\n\0";
const long int kMagicVersion = 1;
const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 4\n\0";
// This must be shorter than kMagicBytes and indicates an incomplete binary file (i.e. build failed).
const char kMagicIncomplete[] = "mmap lm http://kheafield.com/code incomplete\n";
const long int kMagicVersion = 4;
// Test values.
struct Sanity {
@ -76,6 +79,50 @@ void WriteHeader(void *to, const Parameters &params) {
}
} // namespace
uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing) {
if (config.write_mmap) {
std::size_t total = TotalHeaderSize(order) + memory_size;
backing.vocab.reset(util::MapZeroedWrite(config.write_mmap, total, backing.file), total, util::scoped_memory::MMAP_ALLOCATED);
strncpy(reinterpret_cast<char*>(backing.vocab.get()), kMagicIncomplete, TotalHeaderSize(order));
return reinterpret_cast<uint8_t*>(backing.vocab.get()) + TotalHeaderSize(order);
} else {
backing.vocab.reset(util::MapAnonymous(memory_size), memory_size, util::scoped_memory::MMAP_ALLOCATED);
return reinterpret_cast<uint8_t*>(backing.vocab.get());
}
}
uint8_t *GrowForSearch(const Config &config, std::size_t memory_size, Backing &backing) {
if (config.write_mmap) {
// Grow the file to accomodate the search, using zeros.
if (-1 == ftruncate(backing.file.get(), backing.vocab.size() + memory_size))
UTIL_THROW(util::ErrnoException, "ftruncate on " << config.write_mmap << " to " << (backing.vocab.size() + memory_size) << " failed");
// We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down.
off_t page_size = sysconf(_SC_PAGE_SIZE);
off_t alignment_cruft = backing.vocab.size() % page_size;
backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), backing.vocab.size() - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED);
return reinterpret_cast<uint8_t*>(backing.search.get()) + alignment_cruft;
} else {
backing.search.reset(util::MapAnonymous(memory_size), memory_size, util::scoped_memory::MMAP_ALLOCATED);
return reinterpret_cast<uint8_t*>(backing.search.get());
}
}
void FinishFile(const Config &config, ModelType model_type, const std::vector<uint64_t> &counts, Backing &backing) {
if (config.write_mmap) {
// header and vocab share the same mmap. The header is written here because we know the counts.
Parameters params;
params.counts = counts;
params.fixed.order = counts.size();
params.fixed.probing_multiplier = config.probing_multiplier;
params.fixed.model_type = model_type;
params.fixed.has_vocabulary = config.include_vocab;
WriteHeader(backing.vocab.get(), params);
}
}
namespace detail {
bool IsBinaryFormat(int fd) {
@ -91,14 +138,17 @@ bool IsBinaryFormat(int fd) {
Sanity reference_header = Sanity();
reference_header.SetToReference();
if (!memcmp(memory.get(), &reference_header, sizeof(Sanity))) return true;
if (!memcmp(memory.get(), kMagicIncomplete, strlen(kMagicIncomplete))) {
UTIL_THROW(FormatLoadException, "This binary file did not finish building");
}
if (!memcmp(memory.get(), kMagicBeforeVersion, strlen(kMagicBeforeVersion))) {
char *end_ptr;
const char *begin_version = static_cast<const char*>(memory.get()) + strlen(kMagicBeforeVersion);
long int version = strtol(begin_version, &end_ptr, 10);
if ((end_ptr != begin_version) && version != kMagicVersion) {
UTIL_THROW(FormatLoadException, "Binary file has version " << version << " but this implementation expects version " << kMagicVersion << " so you'll have to rebuild your binary LM from the ARPA. Sorry.");
UTIL_THROW(FormatLoadException, "Binary file has version " << version << " but this implementation expects version " << kMagicVersion << " so you'll have to use the ARPA to rebuild your binary");
}
UTIL_THROW(FormatLoadException, "File looks like it should be loaded with mmap, but the test values don't match. Try rebuilding the binary format LM using the same code revision, compiler, and architecture.");
UTIL_THROW(FormatLoadException, "File looks like it should be loaded with mmap, but the test values don't match. Try rebuilding the binary format LM using the same code revision, compiler, and architecture");
}
return false;
}
@ -128,7 +178,7 @@ uint8_t *SetupBinary(const Config &config, const Parameters &params, std::size_t
if (file_size != util::kBadSize && static_cast<uint64_t>(file_size) < total_map)
UTIL_THROW(FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map);
util::MapRead(config.load_method, backing.file.get(), 0, total_map, backing.memory);
util::MapRead(config.load_method, backing.file.get(), 0, total_map, backing.search);
if (config.enumerate_vocab && !params.fixed.has_vocabulary)
UTIL_THROW(FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. You may need to rebuild the binary file with an updated version of build_binary.");
@ -137,33 +187,7 @@ uint8_t *SetupBinary(const Config &config, const Parameters &params, std::size_t
if ((off_t)-1 == lseek(backing.file.get(), total_map, SEEK_SET))
UTIL_THROW(util::ErrnoException, "Failed to seek in binary file to vocab words");
}
return reinterpret_cast<uint8_t*>(backing.memory.get()) + TotalHeaderSize(params.counts.size());
}
uint8_t *SetupZeroed(const Config &config, ModelType model_type, const std::vector<uint64_t> &counts, std::size_t memory_size, Backing &backing) {
if (config.write_mmap) {
std::size_t total_map = TotalHeaderSize(counts.size()) + memory_size;
// Write out an mmap file.
backing.memory.reset(util::MapZeroedWrite(config.write_mmap, total_map, backing.file), total_map, util::scoped_memory::MMAP_ALLOCATED);
Parameters params;
params.counts = counts;
params.fixed.order = counts.size();
params.fixed.probing_multiplier = config.probing_multiplier;
params.fixed.model_type = model_type;
params.fixed.has_vocabulary = config.include_vocab;
WriteHeader(backing.memory.get(), params);
if (params.fixed.has_vocabulary) {
if ((off_t)-1 == lseek(backing.file.get(), total_map, SEEK_SET))
UTIL_THROW(util::ErrnoException, "Failed to seek in binary file " << config.write_mmap << " to vocab words");
}
return reinterpret_cast<uint8_t*>(backing.memory.get()) + TotalHeaderSize(counts.size());
} else {
backing.memory.reset(util::MapAnonymous(memory_size), memory_size, util::scoped_memory::MMAP_ALLOCATED);
return reinterpret_cast<uint8_t*>(backing.memory.get());
}
return reinterpret_cast<uint8_t*>(backing.search.get()) + TotalHeaderSize(params.counts.size());
}
void ComplainAboutARPA(const Config &config, ModelType model_type) {

View File

@ -35,10 +35,18 @@ struct Parameters {
struct Backing {
// File behind memory, if any.
util::scoped_fd file;
// Vocabulary lookup table. Not to be confused with the vocab words themselves.
util::scoped_memory vocab;
// Raw block of memory backing the language model data structures
util::scoped_memory memory;
util::scoped_memory search;
};
uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing);
// Grow the binary file for the search data structure and set backing.search, returning the memory address where the search data structure should begin.
uint8_t *GrowForSearch(const Config &config, std::size_t memory_size, Backing &backing);
void FinishFile(const Config &config, ModelType model_type, const std::vector<uint64_t> &counts, Backing &backing);
namespace detail {
bool IsBinaryFormat(int fd);
@ -49,8 +57,6 @@ void MatchCheck(ModelType model_type, const Parameters &params);
uint8_t *SetupBinary(const Config &config, const Parameters &params, std::size_t memory_size, Backing &backing);
uint8_t *SetupZeroed(const Config &config, ModelType model_type, const std::vector<uint64_t> &counts, std::size_t memory_size, Backing &backing);
void ComplainAboutARPA(const Config &config, ModelType model_type);
} // namespace detail
@ -61,13 +67,12 @@ template <class To> void LoadLM(const char *file, const Config &config, To &to)
Backing &backing = to.MutableBacking();
backing.file.reset(util::OpenReadOrThrow(file));
Parameters params;
try {
if (detail::IsBinaryFormat(backing.file.get())) {
Parameters params;
detail::ReadHeader(backing.file.get(), params);
detail::MatchCheck(To::kModelType, params);
// Replace the probing_multiplier.
// Replace the run-time configured probing_multiplier with the one in the file.
Config new_config(config);
new_config.probing_multiplier = params.fixed.probing_multiplier;
std::size_t memory_size = To::Size(params.counts, new_config);
@ -75,15 +80,10 @@ template <class To> void LoadLM(const char *file, const Config &config, To &to)
to.InitializeFromBinary(start, params, new_config, backing.file.get());
} else {
detail::ComplainAboutARPA(config, To::kModelType);
util::FilePiece f(backing.file.release(), file, config.messages);
ReadARPACounts(f, params.counts);
std::size_t memory_size = To::Size(params.counts, config);
uint8_t *start = detail::SetupZeroed(config, To::kModelType, params.counts, memory_size, backing);
to.InitializeFromARPA(file, f, start, params, config);
to.InitializeFromARPA(file, config);
}
} catch (util::Exception &e) {
e << " in file " << file;
e << " File: " << file;
throw;
}

53
kenlm/lm/blank.hh Normal file
View File

@ -0,0 +1,53 @@
#ifndef LM_BLANK__
#define LM_BLANK__
#include <limits>
#include <inttypes.h>
#include <math.h>
namespace lm {
namespace ngram {
/* Suppose "foo bar" appears with zero backoff but there is no trigram
* beginning with these words. Then, when scoring "foo bar", the model could
* return out_state containing "bar" or even null context if "bar" also has no
* backoff and is never followed by another word. Then the backoff is set to
* kNoExtensionBackoff. If the n-gram might be extended, then out_state must
* contain the full n-gram, in which case kExtensionBackoff is set. In any
* case, if an n-gram has non-zero backoff, the full state is returned so
* backoff can be properly charged.
* These differ only in sign bit because the backoff is in fact zero in either
* case.
*/
const float kNoExtensionBackoff = -0.0;
const float kExtensionBackoff = 0.0;
inline void SetExtension(float &backoff) {
if (backoff == kNoExtensionBackoff) backoff = kExtensionBackoff;
}
// This compiles down nicely.
inline bool HasExtension(const float &backoff) {
typedef union { float f; uint32_t i; } UnionValue;
UnionValue compare, interpret;
compare.f = kNoExtensionBackoff;
interpret.f = backoff;
return compare.i != interpret.i;
}
/* Suppose "foo bar baz quux" appears in the ARPA but not "bar baz quux" or
* "baz quux" (because they were pruned). 1.2% of n-grams generated by SRI
* with default settings on the benchmark data set are like this. Since search
* proceeds by finding "quux", "baz quux", "bar baz quux", and finally
* "foo bar baz quux" and the trie needs pointer nodes anyway, blanks are
* inserted. The blanks have probability kBlankProb and backoff kBlankBackoff.
* A blank is recognized by kBlankProb in the probability field; kBlankBackoff
* must be 0 so that inference asseses zero backoff from these blanks.
*/
const float kBlankProb = -std::numeric_limits<float>::infinity();
const float kBlankBackoff = kNoExtensionBackoff;
} // namespace ngram
} // namespace lm
#endif // LM_BLANK__

View File

@ -1,6 +1,8 @@
#include "lm/model.hh"
#include "util/file_piece.hh"
#include <cstdlib>
#include <exception>
#include <iostream>
#include <iomanip>
@ -13,18 +15,21 @@ namespace ngram {
namespace {
void Usage(const char *name) {
std::cerr << "Usage: " << name << " [-u unknown_probability] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [type] input.arpa output.mmap\n\n"
"Where type is one of probing, trie, or sorted:\n\n"
std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [type] input.arpa output.mmap\n\n"
"-u sets the default log10 probability for <unk> if the ARPA file does not have\n"
"one.\n"
"-s allows models to be built even if they do not have <s> and </s>.\n\n"
"type is one of probing, trie, or sorted:\n\n"
"probing uses a probing hash table. It is the fastest but uses the most memory.\n"
"-p sets the space multiplier and must be >1.0. The default is 1.5.\n\n"
"trie is a straightforward trie with bit-level packing. It uses the least\n"
"memory and is still faster than SRI or IRST. Building the trie format uses an\n"
"on-disk sort to save memory.\n"
"-t is the temporary directory prefix. Default is the output file name.\n"
"-m is the amount of memory to use, in MB. Default is 1024MB (1GB).\n\n"
"sorted is like probing but uses a sorted uniform map instead of a hash table.\n"
"-m limits memory use for sorting. Measured in MB. Default is 1024MB.\n\n"
/*"sorted is like probing but uses a sorted uniform map instead of a hash table.\n"
"It uses more memory than trie and is also slower, so there's no real reason to\n"
"use it.\n\n"
"use it.\n\n"*/
"See http://kheafield.com/code/kenlm/benchmark/ for data structure benchmarks.\n"
"Passing only an input file will print memory usage of each data structure.\n"
"If the ARPA file does not have <unk>, -u sets <unk>'s probability; default 0.0.\n";
@ -52,13 +57,13 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) {
std::size_t probing_size = ProbingModel::Size(counts, config);
// probing is always largest so use it to determine number of columns.
long int length = std::max<long int>(5, lrint(ceil(log10(probing_size))));
std::cout << "Memory usage:\ntype ";
std::cout << "Memory estimate:\ntype ";
// right align bytes.
for (long int i = 0; i < length - 5; ++i) std::cout << ' ';
std::cout << "bytes\n"
"probing " << std::setw(length) << probing_size << " assuming -p " << config.probing_multiplier << "\n"
"trie " << std::setw(length) << TrieModel::Size(counts, config) << "\n"
"sorted " << std::setw(length) << SortedModel::Size(counts, config) << "\n";
"trie " << std::setw(length) << TrieModel::Size(counts, config) << "\n";
/* "sorted " << std::setw(length) << SortedModel::Size(counts, config) << "\n";*/
}
} // namespace ngram
@ -68,46 +73,55 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) {
int main(int argc, char *argv[]) {
using namespace lm::ngram;
lm::ngram::Config config;
int opt;
while ((opt = getopt(argc, argv, "u:p:t:m:")) != -1) {
switch(opt) {
case 'u':
config.unknown_missing_prob = ParseFloat(optarg);
break;
case 'p':
config.probing_multiplier = ParseFloat(optarg);
break;
case 't':
config.temporary_directory_prefix = optarg;
break;
case 'm':
config.building_memory = ParseUInt(optarg) * 1048576;
break;
default:
Usage(argv[0]);
try {
lm::ngram::Config config;
int opt;
while ((opt = getopt(argc, argv, "su:p:t:m:")) != -1) {
switch(opt) {
case 'u':
config.unknown_missing_logprob = ParseFloat(optarg);
break;
case 'p':
config.probing_multiplier = ParseFloat(optarg);
break;
case 't':
config.temporary_directory_prefix = optarg;
break;
case 'm':
config.building_memory = ParseUInt(optarg) * 1048576;
break;
case 's':
config.sentence_marker_missing = lm::ngram::Config::SILENT;
break;
default:
Usage(argv[0]);
}
}
}
if (optind + 1 == argc) {
ShowSizes(argv[optind], config);
} else if (optind + 2 == argc) {
config.write_mmap = argv[optind + 1];
ProbingModel(argv[optind], config);
} else if (optind + 3 == argc) {
const char *model_type = argv[optind];
const char *from_file = argv[optind + 1];
config.write_mmap = argv[optind + 2];
if (!strcmp(model_type, "probing")) {
ProbingModel(from_file, config);
} else if (!strcmp(model_type, "sorted")) {
SortedModel(from_file, config);
} else if (!strcmp(model_type, "trie")) {
TrieModel(from_file, config);
if (optind + 1 == argc) {
ShowSizes(argv[optind], config);
} else if (optind + 2 == argc) {
config.write_mmap = argv[optind + 1];
ProbingModel(argv[optind], config);
} else if (optind + 3 == argc) {
const char *model_type = argv[optind];
const char *from_file = argv[optind + 1];
config.write_mmap = argv[optind + 2];
if (!strcmp(model_type, "probing")) {
ProbingModel(from_file, config);
} else if (!strcmp(model_type, "sorted")) {
SortedModel(from_file, config);
} else if (!strcmp(model_type, "trie")) {
TrieModel(from_file, config);
} else {
Usage(argv[0]);
}
} else {
Usage(argv[0]);
}
} else {
Usage(argv[0]);
}
catch (std::exception &e) {
std::cerr << e.what() << std::endl;
abort();
}
return 0;
}

View File

@ -9,7 +9,8 @@ Config::Config() :
messages(&std::cerr),
enumerate_vocab(NULL),
unknown_missing(COMPLAIN),
unknown_missing_prob(0.0),
sentence_marker_missing(THROW_UP),
unknown_missing_logprob(-100.0),
probing_multiplier(1.5),
building_memory(1073741824ULL), // 1 GB
temporary_directory_prefix(NULL),

View File

@ -27,19 +27,22 @@ struct Config {
// ONLY EFFECTIVE WHEN READING ARPA
typedef enum {THROW_UP, COMPLAIN, SILENT} WarningAction;
// What to do when <unk> isn't in the provided model.
typedef enum {THROW_UP, COMPLAIN, SILENT} UnknownMissing;
UnknownMissing unknown_missing;
WarningAction unknown_missing;
// What to do when <s> or </s> is missing from the model.
// If THROW_UP, the exception will be of type util::SpecialWordMissingException.
WarningAction sentence_marker_missing;
// The probability to substitute for <unk> if it's missing from the model.
// No effect if the model has <unk> or unknown_missing == THROW_UP.
float unknown_missing_prob;
float unknown_missing_logprob;
// Size multiplier for probing hash table. Must be > 1. Space is linear in
// this. Time is probing_multiplier / (probing_multiplier - 1). No effect
// for sorted variant.
// If you find yourself setting this to a low number, consider using the
// Sorted version instead which has lower memory consumption.
// TrieModel which has lower memory consumption.
float probing_multiplier;
// Amount of memory to use for building. The actual memory usage will be
@ -53,7 +56,7 @@ struct Config {
// defaults to input file name.
const char *temporary_directory_prefix;
// Level of complaining to do when an ARPA instead of a binary format.
// Level of complaining to do when loading from ARPA instead of binary format.
typedef enum {ALL, EXPENSIVE, NONE} ARPALoadComplain;
ARPALoadComplain arpa_complain;

View File

@ -17,9 +17,7 @@ FormatLoadException::~FormatLoadException() throw() {}
VocabLoadException::VocabLoadException() throw() {}
VocabLoadException::~VocabLoadException() throw() {}
SpecialWordMissingException::SpecialWordMissingException(StringPiece which) throw() {
*this << "Missing special word " << which;
}
SpecialWordMissingException::SpecialWordMissingException() throw() {}
SpecialWordMissingException::~SpecialWordMissingException() throw() {}
} // namespace lm

View File

@ -39,7 +39,7 @@ class VocabLoadException : public LoadException {
class SpecialWordMissingException : public VocabLoadException {
public:
explicit SpecialWordMissingException(StringPiece which) throw();
explicit SpecialWordMissingException() throw();
~SpecialWordMissingException() throw();
};

14
kenlm/lm/max_order.hh Normal file
View File

@ -0,0 +1,14 @@
#ifndef LM_MAX_ORDER__
#define LM_MAX_ORDER__
namespace lm {
namespace ngram {
// If you need higher order, change this and recompile.
// Having this limit means that State can be
// (kMaxOrder - 1) * sizeof(float) bytes instead of
// sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead
const unsigned char kMaxOrder = 6;
} // namespace ngram
} // namespace lm
#endif // LM_MAX_ORDER__

View File

@ -1,5 +1,6 @@
#include "lm/model.hh"
#include "lm/blank.hh"
#include "lm/lm_exception.hh"
#include "lm/search_hashed.hh"
#include "lm/search_trie.hh"
@ -21,9 +22,6 @@ size_t hash_value(const State &state) {
namespace detail {
template <class Search, class VocabularyT> size_t GenericModel<Search, VocabularyT>::Size(const std::vector<uint64_t> &counts, const Config &config) {
if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit ngram.hh's kMaxOrder to at least this value and recompile.");
if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model.");
if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0");
return VocabularyT::Size(counts[0], config) + Search::Size(counts, config);
}
@ -59,99 +57,105 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
search_.longest.LoadedBinary();
}
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, util::FilePiece &f, void *start, const Parameters &params, const Config &config) {
SetupMemory(start, params.counts, config);
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, const Config &config) {
// Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any.
util::FilePiece f(backing_.file.release(), file, config.messages);
std::vector<uint64_t> counts;
// File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_.
ReadARPACounts(f, counts);
if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit lm/max_order.hh, set kMaxOrder to at least this value, and recompile.");
if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model.");
if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0");
std::size_t vocab_size = VocabularyT::Size(counts[0], config);
// Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs.
vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config);
if (config.write_mmap) {
WriteWordsWrapper wrap(config.enumerate_vocab, backing_.file.get());
vocab_.ConfigureEnumerate(&wrap, params.counts[0]);
search_.InitializeFromARPA(file, f, params.counts, config, vocab_);
WriteWordsWrapper wrap(config.enumerate_vocab);
vocab_.ConfigureEnumerate(&wrap, counts[0]);
search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
wrap.Write(backing_.file.get());
} else {
vocab_.ConfigureEnumerate(config.enumerate_vocab, params.counts[0]);
search_.InitializeFromARPA(file, f, params.counts, config, vocab_);
vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]);
search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
}
// TODO: fail faster?
if (!vocab_.SawUnk()) {
switch(config.unknown_missing) {
case Config::THROW_UP:
{
SpecialWordMissingException e("<unk>");
e << " and configuration was set to throw if unknown is missing";
throw e;
}
case Config::COMPLAIN:
if (config.messages) *config.messages << "Language model is missing <unk>. Substituting probability " << config.unknown_missing_prob << "." << std::endl;
// There's no break;. This is by design.
case Config::SILENT:
// Default probabilities for unknown.
search_.unigram.Unknown().backoff = 0.0;
search_.unigram.Unknown().prob = config.unknown_missing_prob;
break;
}
assert(config.unknown_missing != Config::THROW_UP);
// Default probabilities for unknown.
search_.unigram.Unknown().backoff = 0.0;
search_.unigram.Unknown().prob = config.unknown_missing_logprob;
}
if (std::fabs(search_.unigram.Unknown().backoff) > 0.0000001) UTIL_THROW(FormatLoadException, "Backoff for unknown word should be zero, but was given as " << search_.unigram.Unknown().backoff);
FinishFile(config, kModelType, counts, backing_);
}
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const {
unsigned char backoff_start;
FullScoreReturn ret = ScoreExceptBackoff(in_state.history_, in_state.history_ + in_state.valid_length_, new_word, backoff_start, out_state);
if (backoff_start - 1 < in_state.valid_length_) {
ret.prob = std::accumulate(in_state.backoff_ + backoff_start - 1, in_state.backoff_ + in_state.valid_length_, ret.prob);
FullScoreReturn ret = ScoreExceptBackoff(in_state.history_, in_state.history_ + in_state.valid_length_, new_word, out_state);
if (ret.ngram_length - 1 < in_state.valid_length_) {
ret.prob = std::accumulate(in_state.backoff_ + ret.ngram_length - 1, in_state.backoff_ + in_state.valid_length_, ret.prob);
}
return ret;
}
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const {
unsigned char backoff_start;
context_rend = std::min(context_rend, context_rbegin + P::Order() - 1);
FullScoreReturn ret = ScoreExceptBackoff(context_rbegin, context_rend, new_word, backoff_start, out_state);
ret.prob += SlowBackoffLookup(context_rbegin, context_rend, backoff_start);
FullScoreReturn ret = ScoreExceptBackoff(context_rbegin, context_rend, new_word, out_state);
// Add the backoff weights for n-grams of order start to (context_rend - context_rbegin).
unsigned char start = ret.ngram_length;
if (context_rend - context_rbegin < static_cast<std::ptrdiff_t>(start)) return ret;
if (start <= 1) {
ret.prob += search_.unigram.Lookup(*context_rbegin).backoff;
start = 2;
}
typename Search::Node node;
if (!search_.FastMakeNode(context_rbegin, context_rbegin + start - 1, node)) {
return ret;
}
float backoff;
// i is the order of the backoff we're looking for.
for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i) {
if (!search_.LookupMiddleNoProb(search_.middle[i - context_rbegin - 1], *i, backoff, node)) break;
ret.prob += backoff;
}
return ret;
}
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const {
// Generate a state from context.
context_rend = std::min(context_rend, context_rbegin + P::Order() - 1);
if (context_rend == context_rbegin || *context_rbegin == 0) {
if (context_rend == context_rbegin) {
out_state.valid_length_ = 0;
return;
}
float ignored_prob;
typename Search::Node node;
search_.LookupUnigram(*context_rbegin, ignored_prob, out_state.backoff_[0], node);
out_state.valid_length_ = HasExtension(out_state.backoff_[0]) ? 1 : 0;
float *backoff_out = out_state.backoff_ + 1;
const WordIndex *i = context_rbegin + 1;
for (; i < context_rend; ++i, ++backoff_out) {
if (!search_.LookupMiddleNoProb(search_.middle[i - context_rbegin - 1], *i, *backoff_out, node)) {
out_state.valid_length_ = i - context_rbegin;
std::copy(context_rbegin, i, out_state.history_);
const typename Search::Middle *mid = &*search_.middle.begin();
for (const WordIndex *i = context_rbegin + 1; i < context_rend; ++i, ++backoff_out, ++mid) {
if (!search_.LookupMiddleNoProb(*mid, *i, *backoff_out, node)) {
std::copy(context_rbegin, context_rbegin + out_state.valid_length_, out_state.history_);
return;
}
if (HasExtension(*backoff_out)) out_state.valid_length_ = i - context_rbegin + 1;
}
std::copy(context_rbegin, context_rend, out_state.history_);
out_state.valid_length_ = static_cast<unsigned char>(context_rend - context_rbegin);
std::copy(context_rbegin, context_rbegin + out_state.valid_length_, out_state.history_);
}
template <class Search, class VocabularyT> float GenericModel<Search, VocabularyT>::SlowBackoffLookup(
const WordIndex *const context_rbegin, const WordIndex *const context_rend, unsigned char start) const {
// Add the backoff weights for n-grams of order start to (context_rend - context_rbegin).
if (context_rend - context_rbegin < static_cast<std::ptrdiff_t>(start)) return 0.0;
float ret = 0.0;
if (start == 1) {
ret += search_.unigram.Lookup(*context_rbegin).backoff;
start = 2;
}
typename Search::Node node;
if (!search_.FastMakeNode(context_rbegin, context_rbegin + start - 1, node)) {
return 0.0;
}
float backoff;
// i is the order of the backoff we're looking for.
for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i) {
if (!search_.LookupMiddleNoProb(search_.middle[i - context_rbegin - 1], *i, backoff, node)) break;
ret += backoff;
}
return ret;
namespace {
// Do a paraonoid copy of history, assuming new_word has already been copied
// (hence the -1). out_state.valid_length_ could be zero so I avoided using
// std::copy.
void CopyRemainingHistory(const WordIndex *from, State &out_state) {
WordIndex *out = out_state.history_ + 1;
const WordIndex *in_end = from + static_cast<ptrdiff_t>(out_state.valid_length_) - 1;
for (const WordIndex *in = from; in < in_end; ++in, ++out) *out = *in;
}
} // namespace
/* Ugly optimized function. Produce a score excluding backoff.
* The search goes in increasing order of ngram length.
@ -162,72 +166,64 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
const WordIndex *context_rbegin,
const WordIndex *context_rend,
const WordIndex new_word,
unsigned char &backoff_start,
State &out_state) const {
FullScoreReturn ret;
// ret.ngram_length contains the last known non-blank ngram length.
ret.ngram_length = 1;
typename Search::Node node;
float *backoff_out(out_state.backoff_);
search_.LookupUnigram(new_word, ret.prob, *backoff_out, node);
if (new_word == 0) {
ret.ngram_length = out_state.valid_length_ = 0;
// All of backoff.
backoff_start = 1;
return ret;
}
// This is the length of the context that should be used for continuation.
out_state.valid_length_ = HasExtension(*backoff_out) ? 1 : 0;
// We'll write the word anyway since it will probably be used and does no harm being there.
out_state.history_[0] = new_word;
if (context_rbegin == context_rend) {
ret.ngram_length = out_state.valid_length_ = 1;
// No backoff because we don't have the history for it.
backoff_start = P::Order();
return ret;
}
if (context_rbegin == context_rend) return ret;
++backoff_out;
// Ok now we now that the bigram contains known words. Start by looking it up.
const WordIndex *hist_iter = context_rbegin;
typename std::vector<Middle>::const_iterator mid_iter = search_.middle.begin();
for (; ; ++mid_iter, ++hist_iter, ++backoff_out) {
if (hist_iter == context_rend) {
// Ran out of history. No backoff.
backoff_start = P::Order();
std::copy(context_rbegin, context_rend, out_state.history_ + 1);
ret.ngram_length = out_state.valid_length_ = (context_rend - context_rbegin) + 1;
// Ran out of history. Typically no backoff, but this could be a blank.
CopyRemainingHistory(context_rbegin, out_state);
// ret.prob was already set.
return ret;
}
if (mid_iter == search_.middle.end()) break;
float revert = ret.prob;
if (!search_.LookupMiddle(*mid_iter, *hist_iter, ret.prob, *backoff_out, node)) {
// Didn't find an ngram using hist_iter.
// The history used in the found n-gram is [context_rbegin, hist_iter).
std::copy(context_rbegin, hist_iter, out_state.history_ + 1);
// Therefore, we found a (hist_iter - context_rbegin + 1)-gram including the last word.
ret.ngram_length = out_state.valid_length_ = (hist_iter - context_rbegin) + 1;
backoff_start = mid_iter - search_.middle.begin() + 1;
CopyRemainingHistory(context_rbegin, out_state);
// ret.prob was already set.
return ret;
}
if (ret.prob == kBlankProb) {
// It's a blank. Go back to the old probability.
ret.prob = revert;
} else {
ret.ngram_length = hist_iter - context_rbegin + 2;
if (HasExtension(*backoff_out)) {
out_state.valid_length_ = ret.ngram_length;
}
}
}
// It passed every lookup in search_.middle. That means it's at least a (P::Order() - 1)-gram.
// All that's left is to check search_.longest.
// It passed every lookup in search_.middle. All that's left is to check search_.longest.
if (!search_.LookupLongest(*hist_iter, ret.prob, node)) {
// It's an (P::Order()-1)-gram
std::copy(context_rbegin, context_rbegin + P::Order() - 2, out_state.history_ + 1);
ret.ngram_length = out_state.valid_length_ = P::Order() - 1;
backoff_start = P::Order() - 1;
// Failed to find a longest n-gram. Fall back to the most recent non-blank.
CopyRemainingHistory(context_rbegin, out_state);
// ret.prob was already set.
return ret;
}
// It's an P::Order()-gram
// out_state.valid_length_ is still P::Order() - 1 because the next lookup will only need that much.
std::copy(context_rbegin, context_rbegin + P::Order() - 2, out_state.history_ + 1);
out_state.valid_length_ = P::Order() - 1;
// It's an P::Order()-gram.
CopyRemainingHistory(context_rbegin, out_state);
// There is no blank in longest_.
ret.ngram_length = P::Order();
backoff_start = P::Order();
return ret;
}

View File

@ -4,6 +4,7 @@
#include "lm/binary_format.hh"
#include "lm/config.hh"
#include "lm/facade.hh"
#include "lm/max_order.hh"
#include "lm/search_hashed.hh"
#include "lm/search_trie.hh"
#include "lm/vocab.hh"
@ -19,12 +20,6 @@ namespace util { class FilePiece; }
namespace lm {
namespace ngram {
// If you need higher order, change this and recompile.
// Having this limit means that State can be
// (kMaxOrder - 1) * sizeof(float) bytes instead of
// sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead
const unsigned char kMaxOrder = 6;
// This is a POD but if you want memcmp to return the same as operator==, call
// ZeroRemaining first.
class State {
@ -56,6 +51,8 @@ class State {
}
}
unsigned char ValidLength() const { return valid_length_; }
// You shouldn't need to touch anything below this line, but the members are public so FullState will qualify as a POD.
// This order minimizes total size of the struct if WordIndex is 64 bit, float is 32 bit, and alignment of 64 bit integers is 64 bit.
WordIndex history_[kMaxOrder - 1];
@ -102,14 +99,14 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod
float SlowBackoffLookup(const WordIndex *const context_rbegin, const WordIndex *const context_rend, unsigned char start) const;
FullScoreReturn ScoreExceptBackoff(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, unsigned char &backoff_start, State &out_state) const;
FullScoreReturn ScoreExceptBackoff(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const;
// Appears after Size in the cc file.
void SetupMemory(void *start, const std::vector<uint64_t> &counts, const Config &config);
void InitializeFromBinary(void *start, const Parameters &params, const Config &config, int fd);
void InitializeFromARPA(const char *file, util::FilePiece &f, void *start, const Parameters &params, const Config &config);
void InitializeFromARPA(const char *file, const Config &config);
Backing &MutableBacking() { return backing_; }

View File

@ -8,6 +8,15 @@
namespace lm {
namespace ngram {
std::ostream &operator<<(std::ostream &o, const State &state) {
o << "State length " << static_cast<unsigned int>(state.valid_length_) << ':';
for (const WordIndex *i = state.history_; i < state.history_ + state.valid_length_; ++i) {
o << ' ' << *i;
}
return o;
}
namespace {
#define StartTest(word, ngram, score) \
@ -17,7 +26,15 @@ namespace {
out);\
BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \
BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \
BOOST_CHECK_EQUAL(std::min<unsigned char>(ngram, 5 - 1), out.valid_length_);
BOOST_CHECK_GE(std::min<unsigned char>(ngram, 5 - 1), out.valid_length_); \
{\
WordIndex context[state.valid_length_ + 1]; \
context[0] = model.GetVocabulary().Index(word); \
std::copy(state.history_, state.history_ + state.valid_length_, context + 1); \
State get_state; \
model.GetState(context, context + state.valid_length_ + 1, get_state); \
BOOST_CHECK_EQUAL(out, get_state); \
}
#define AppendTest(word, ngram, score) \
StartTest(word, ngram, score) \
@ -33,7 +50,7 @@ template <class M> void Starters(const M &model) {
// , probability plus <s> backoff
StartTest(",", 1, -1.383514 + -0.4149733);
// <unk> probability plus <s> backoff
StartTest("this_is_not_found", 0, -1.995635 + -0.4149733);
StartTest("this_is_not_found", 1, -1.995635 + -0.4149733);
}
template <class M> void Continuation(const M &model) {
@ -48,14 +65,77 @@ template <class M> void Continuation(const M &model) {
State preserve = state;
AppendTest("the", 1, -4.04005);
AppendTest("biarritz", 1, -1.9889);
AppendTest("not_found", 0, -2.29666);
AppendTest("more", 1, -1.20632);
AppendTest("not_found", 1, -2.29666);
AppendTest("more", 1, -1.20632 - 20.0);
AppendTest(".", 2, -0.51363);
AppendTest("</s>", 3, -0.0191651);
BOOST_CHECK_EQUAL(0, state.valid_length_);
state = preserve;
AppendTest("more", 5, -0.00181395);
BOOST_CHECK_EQUAL(4, state.valid_length_);
AppendTest("loin", 5, -0.0432557);
BOOST_CHECK_EQUAL(1, state.valid_length_);
}
template <class M> void Blanks(const M &model) {
FullScoreReturn ret;
State state(model.NullContextState());
State out;
AppendTest("also", 1, -1.687872);
AppendTest("would", 2, -2);
AppendTest("consider", 3, -3);
State preserve = state;
AppendTest("higher", 4, -4);
AppendTest("looking", 5, -5);
BOOST_CHECK_EQUAL(1, state.valid_length_);
state = preserve;
AppendTest("not_found", 1, -1.995635 - 7.0 - 0.30103);
state = model.NullContextState();
// higher looking is a blank.
AppendTest("higher", 1, -1.509559);
AppendTest("looking", 1, -1.285941 - 0.30103);
AppendTest("not_found", 1, -1.995635 - 0.4771212);
}
template <class M> void Unknowns(const M &model) {
FullScoreReturn ret;
State state(model.NullContextState());
State out;
AppendTest("not_found", 1, -1.995635);
State preserve = state;
AppendTest("not_found2", 2, -15.0);
AppendTest("not_found3", 2, -15.0 - 2.0);
state = preserve;
AppendTest("however", 2, -4);
AppendTest("not_found3", 3, -6);
}
template <class M> void MinimalState(const M &model) {
FullScoreReturn ret;
State state(model.NullContextState());
State out;
AppendTest("baz", 1, -6.535897);
BOOST_CHECK_EQUAL(0, state.valid_length_);
state = model.NullContextState();
AppendTest("foo", 1, -3.141592);
BOOST_CHECK_EQUAL(1, state.valid_length_);
AppendTest("bar", 2, -6.0);
// Has to include the backoff weight.
BOOST_CHECK_EQUAL(1, state.valid_length_);
AppendTest("bar", 1, -2.718281 + 3.0);
BOOST_CHECK_EQUAL(1, state.valid_length_);
state = model.NullContextState();
AppendTest("to", 1, -1.687872);
AppendTest("look", 2, -0.2922095);
BOOST_CHECK_EQUAL(2, state.valid_length_);
AppendTest("good", 3, -7);
}
#define StatelessTest(word, provide, ngram, score) \
@ -103,16 +183,24 @@ template <class M> void Stateless(const M &model) {
// biarritz
StatelessTest(6, 1, 1, -1.9889);
// not found
StatelessTest(7, 1, 0, -2.29666);
StatelessTest(7, 0, 0, -1.995635);
StatelessTest(7, 1, 1, -2.29666);
StatelessTest(7, 0, 1, -1.995635);
WordIndex unk[1];
unk[0] = 0;
model.GetState(unk, unk + 1, state);
BOOST_CHECK_EQUAL(0, state.valid_length_);
BOOST_CHECK_EQUAL(1, state.valid_length_);
BOOST_CHECK_EQUAL(static_cast<WordIndex>(0), state.history_[0]);
}
//const char *kExpectedOrderProbing[] = {"<unk>", ",", ".", "</s>", "<s>", "a", "also", "beyond", "biarritz", "call", "concerns", "consider", "considering", "for", "higher", "however", "i", "immediate", "in", "is", "little", "loin", "look", "looking", "more", "on", "screening", "small", "the", "to", "watch", "watching", "what", "would"};
template <class M> void Everything(const M &m) {
Starters(m);
Continuation(m);
Blanks(m);
Unknowns(m);
MinimalState(m);
Stateless(m);
}
class ExpectEnumerateVocab : public EnumerateVocab {
public:
@ -124,7 +212,7 @@ class ExpectEnumerateVocab : public EnumerateVocab {
}
void Check(const base::Vocabulary &vocab) {
BOOST_CHECK_EQUAL(34ULL, seen.size());
BOOST_CHECK_EQUAL(37ULL, seen.size());
BOOST_REQUIRE(!seen.empty());
BOOST_CHECK_EQUAL("<unk>", seen[0]);
for (WordIndex i = 0; i < seen.size(); ++i) {
@ -148,18 +236,16 @@ template <class ModelT> void LoadingTest() {
config.probing_multiplier = 2.0;
ModelT m("test.arpa", config);
enumerate.Check(m.GetVocabulary());
Starters(m);
Continuation(m);
Stateless(m);
Everything(m);
}
BOOST_AUTO_TEST_CASE(probing) {
LoadingTest<Model>();
}
BOOST_AUTO_TEST_CASE(sorted) {
/*BOOST_AUTO_TEST_CASE(sorted) {
LoadingTest<SortedModel>();
}
}*/
BOOST_AUTO_TEST_CASE(trie) {
LoadingTest<TrieModel>();
}
@ -175,24 +261,23 @@ template <class ModelT> void BinaryTest() {
ModelT copy_model("test.arpa", config);
enumerate.Check(copy_model.GetVocabulary());
enumerate.Clear();
Everything(copy_model);
}
config.write_mmap = NULL;
ModelT binary("test.binary", config);
enumerate.Check(binary.GetVocabulary());
Starters(binary);
Continuation(binary);
Stateless(binary);
Everything(binary);
unlink("test.binary");
}
BOOST_AUTO_TEST_CASE(write_and_read_probing) {
BinaryTest<Model>();
}
BOOST_AUTO_TEST_CASE(write_and_read_sorted) {
/*BOOST_AUTO_TEST_CASE(write_and_read_sorted) {
BinaryTest<SortedModel>();
}
}*/
BOOST_AUTO_TEST_CASE(write_and_read_trie) {
BinaryTest<TrieModel>();
}

View File

@ -1,3 +1,4 @@
#include "lm/enumerate_vocab.hh"
#include "lm/model.hh"
#include <cstdlib>
@ -5,6 +6,8 @@
#include <iostream>
#include <string>
#include <ctype.h>
#include <sys/resource.h>
#include <sys/time.h>
@ -32,41 +35,79 @@ void PrintUsage(const char *message) {
}
}
template <class Model> void Query(const Model &model) {
template <class Model> void Query(const Model &model, bool sentence_context) {
PrintUsage("Loading statistics:\n");
typename Model::State state, out;
lm::FullScoreReturn ret;
std::string word;
while (std::cin) {
state = model.BeginSentenceState();
state = sentence_context ? model.BeginSentenceState() : model.NullContextState();
float total = 0.0;
bool got = false;
unsigned int oov = 0;
while (std::cin >> word) {
got = true;
ret = model.FullScore(state, model.GetVocabulary().Index(word), out);
lm::WordIndex vocab = model.GetVocabulary().Index(word);
if (vocab == 0) ++oov;
ret = model.FullScore(state, vocab, out);
total += ret.prob;
std::cout << word << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << ' ';
std::cout << word << '=' << vocab << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\t';
state = out;
if (std::cin.get() == '\n') break;
char c;
while (true) {
c = std::cin.get();
if (!std::cin) break;
if (c == '\n') break;
if (!isspace(c)) {
std::cin.unget();
break;
}
}
if (c == '\n') break;
}
if (!got && !std::cin) break;
ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out);
total += ret.prob;
std::cout << "</s> " << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << ' ';
std::cout << "Total: " << total << '\n';
if (sentence_context) {
ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out);
total += ret.prob;
std::cout << "</s>=" << model.GetVocabulary().EndSentence() << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\t';
}
std::cout << "Total: " << total << " OOV: " << oov << '\n';
}
PrintUsage("After queries:\n");
}
int main(int argc, char *argv[]) {
if (argc < 2) {
std::cerr << "Pass language model name." << std::endl;
return 0;
}
{
lm::ngram::Model ngram(argv[1]);
Query(ngram);
}
PrintUsage("Total time including destruction:\n");
template <class Model> void Query(const char *name) {
lm::ngram::Config config;
Model model(name, config);
Query(model);
}
int main(int argc, char *argv[]) {
if (!(argc == 2 || (argc == 3 && !strcmp(argv[2], "null")))) {
std::cerr << "Usage: " << argv[0] << " lm_file [null]" << std::endl;
std::cerr << "Input is wrapped in <s> and </s> unless null is passed." << std::endl;
return 1;
}
bool sentence_context = (argc == 2);
lm::ngram::ModelType model_type;
if (lm::ngram::RecognizeBinary(argv[1], model_type)) {
switch(model_type) {
case lm::ngram::HASH_PROBING:
Query<lm::ngram::ProbingModel>(argv[1], sentence_context);
break;
case lm::ngram::TRIE_SORTED:
Query<lm::ngram::TrieModel>(argv[1], sentence_context);
break;
case lm::ngram::HASH_SORTED:
default:
std::cerr << "Unrecognized kenlm model type " << model_type << std::endl;
abort();
}
} else {
Query<lm::ngram::ProbingModel>(argv[1], sentence_context);
}
PrintUsage("Total time including destruction:\n");
return 0;
}

View File

@ -1,13 +1,19 @@
#include "lm/read_arpa.hh"
#include "lm/blank.hh"
#include <cstdlib>
#include <vector>
#include <ctype.h>
#include <string.h>
#include <inttypes.h>
namespace lm {
// 1 for '\t', '\n', and ' '. This is stricter than isspace.
const bool kARPASpaces[256] = {0,0,0,0,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
namespace {
bool IsEntirelyWhiteSpace(const StringPiece &line) {
@ -17,14 +23,20 @@ bool IsEntirelyWhiteSpace(const StringPiece &line) {
return true;
}
template <class F> void GenericReadARPACounts(F &in, std::vector<uint64_t> &number) {
const char kBinaryMagic[] = "mmap lm http://kheafield.com/code";
} // namespace
void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) {
number.clear();
StringPiece line;
if (!IsEntirelyWhiteSpace(line = in.ReadLine())) {
if ((line.size() >= 2) && (line.data()[0] == 0x1f) && (static_cast<unsigned char>(line.data()[1]) == 0x8b)) {
UTIL_THROW(FormatLoadException, "Looks like a gzip file. If this is an ARPA file, run\nzcat " << in.FileName() << " |kenlm/build_binary /dev/stdin " << in.FileName() << ".binary\nIf this already in binary format, you need to decompress it because mmap doesn't work on top of gzip.");
UTIL_THROW(FormatLoadException, "Looks like a gzip file. If this is an ARPA file, pipe " << in.FileName() << " through zcat. If this already in binary format, you need to decompress it because mmap doesn't work on top of gzip.");
}
UTIL_THROW(FormatLoadException, "First line was \"" << static_cast<int>(line.data()[1]) << "\" not blank");
if (static_cast<size_t>(line.size()) >= strlen(kBinaryMagic) && StringPiece(line.data(), strlen(kBinaryMagic)) == kBinaryMagic)
UTIL_THROW(FormatLoadException, "This looks like a binary file but got sent to the ARPA parser. Did you compress the binary file or pass a binary file where only ARPA files are accepted?");
UTIL_THROW(FormatLoadException, "First line was \"" << line.data() << "\" not blank");
}
if ((line = in.ReadLine()) != "\\data\\") UTIL_THROW(FormatLoadException, "second line was \"" << line << "\" not \\data\\.");
while (!IsEntirelyWhiteSpace(line = in.ReadLine())) {
@ -44,66 +56,14 @@ template <class F> void GenericReadARPACounts(F &in, std::vector<uint64_t> &numb
}
}
template <class F> void GenericReadNGramHeader(F &in, unsigned int length) {
StringPiece line;
void ReadNGramHeader(util::FilePiece &in, unsigned int length) {
StringPiece line;
while (IsEntirelyWhiteSpace(line = in.ReadLine())) {}
std::stringstream expected;
expected << '\\' << length << "-grams:";
if (line != expected.str()) UTIL_THROW(FormatLoadException, "Was expecting n-gram header " << expected.str() << " but got " << line << " instead");
}
template <class F> void GenericReadEnd(F &in) {
StringPiece line;
do {
line = in.ReadLine();
} while (IsEntirelyWhiteSpace(line));
if (line != "\\end\\") UTIL_THROW(FormatLoadException, "Expected \\end\\ but the ARPA file has " << line);
}
class FakeFilePiece {
public:
explicit FakeFilePiece(std::istream &in) : in_(in) {
in_.exceptions(std::ios::failbit | std::ios::badbit | std::ios::eofbit);
}
StringPiece ReadLine() throw(util::EndOfFileException) {
getline(in_, buffer_);
return StringPiece(buffer_);
}
float ReadFloat() {
float ret;
in_ >> ret;
return ret;
}
const char *FileName() const {
// This only used for error messages and we don't know the file name. . .
return "$file";
}
private:
std::istream &in_;
std::string buffer_;
};
} // namespace
void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) {
GenericReadARPACounts(in, number);
}
void ReadARPACounts(std::istream &in, std::vector<uint64_t> &number) {
FakeFilePiece fake(in);
GenericReadARPACounts(fake, number);
}
void ReadNGramHeader(util::FilePiece &in, unsigned int length) {
GenericReadNGramHeader(in, length);
}
void ReadNGramHeader(std::istream &in, unsigned int length) {
FakeFilePiece fake(in);
GenericReadNGramHeader(fake, length);
}
void ReadBackoff(util::FilePiece &in, Prob &/*weights*/) {
switch (in.get()) {
case '\t':
@ -116,39 +76,43 @@ void ReadBackoff(util::FilePiece &in, Prob &/*weights*/) {
case '\n':
break;
default:
UTIL_THROW(FormatLoadException, "Expected tab or newline after unigram");
UTIL_THROW(FormatLoadException, "Expected tab or newline for backoff");
}
}
void ReadBackoff(util::FilePiece &in, ProbBackoff &weights) {
// Always make zero negative.
// Negative zero means that no (n+1)-gram has this n-gram as context.
// Therefore the hypothesis state can be shorter. Of course, many n-grams
// are context for (n+1)-grams. An algorithm in the data structure will go
// back and set the backoff to positive zero in these cases.
switch (in.get()) {
case '\t':
weights.backoff = in.ReadFloat();
if (weights.backoff == ngram::kExtensionBackoff) weights.backoff = ngram::kNoExtensionBackoff;
if ((in.get() != '\n')) UTIL_THROW(FormatLoadException, "Expected newline after backoff");
break;
case '\n':
weights.backoff = 0.0;
weights.backoff = ngram::kNoExtensionBackoff;
break;
default:
UTIL_THROW(FormatLoadException, "Expected tab or newline after unigram");
UTIL_THROW(FormatLoadException, "Expected tab or newline for backoff");
}
}
void ReadEnd(util::FilePiece &in) {
GenericReadEnd(in);
StringPiece line;
do {
line = in.ReadLine();
} while (IsEntirelyWhiteSpace(line));
if (line != "\\end\\") UTIL_THROW(FormatLoadException, "Expected \\end\\ but the ARPA file has " << line);
try {
while (true) {
line = in.ReadLine();
if (!IsEntirelyWhiteSpace(line)) UTIL_THROW(FormatLoadException, "Trailing line " << line);
}
} catch (const util::EndOfFileException &e) {
return;
}
}
void ReadEnd(std::istream &in) {
FakeFilePiece fake(in);
GenericReadEnd(fake);
} catch (const util::EndOfFileException &e) {}
}
} // namespace lm

View File

@ -13,22 +13,21 @@
namespace lm {
void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number);
void ReadARPACounts(std::istream &in, std::vector<uint64_t> &number);
void ReadNGramHeader(util::FilePiece &in, unsigned int length);
void ReadNGramHeader(std::istream &in, unsigned int length);
void ReadBackoff(util::FilePiece &in, Prob &weights);
void ReadBackoff(util::FilePiece &in, ProbBackoff &weights);
void ReadEnd(util::FilePiece &in);
void ReadEnd(std::istream &in);
extern const bool kARPASpaces[256];
template <class Voc> void Read1Gram(util::FilePiece &f, Voc &vocab, ProbBackoff *unigrams) {
try {
float prob = f.ReadFloat();
if (prob > 0) UTIL_THROW(FormatLoadException, "Positive probability " << prob);
if (f.get() != '\t') UTIL_THROW(FormatLoadException, "Expected tab after probability");
ProbBackoff &value = unigrams[vocab.Insert(f.ReadDelimited())];
ProbBackoff &value = unigrams[vocab.Insert(f.ReadDelimited(kARPASpaces))];
value.prob = prob;
ReadBackoff(f, value);
} catch(util::Exception &e) {
@ -50,7 +49,7 @@ template <class Voc, class Weights> void ReadNGram(util::FilePiece &f, const uns
weights.prob = f.ReadFloat();
if (weights.prob > 0) UTIL_THROW(FormatLoadException, "Positive probability " << weights.prob);
for (WordIndex *vocab_out = reverse_indices + n - 1; vocab_out >= reverse_indices; --vocab_out) {
*vocab_out = vocab.Index(f.ReadDelimited());
*vocab_out = vocab.Index(f.ReadDelimited(kARPASpaces));
}
ReadBackoff(f, weights);
} catch(util::Exception &e) {

View File

@ -1,5 +1,6 @@
#include "lm/search_hashed.hh"
#include "lm/blank.hh"
#include "lm/lm_exception.hh"
#include "lm/read_arpa.hh"
#include "lm/vocab.hh"
@ -13,34 +14,65 @@ namespace ngram {
namespace {
/* All of the entropy is in low order bits and boost::hash does poorly with
* these. Odd numbers near 2^64 chosen by mashing on the keyboard. There is a
* stable point: 0. But 0 is <unk> which won't be queried here anyway.
*/
inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) {
uint64_t ret = (current * 8978948897894561157ULL) ^ (static_cast<uint64_t>(next) * 17894857484156487943ULL);
return ret;
}
/* These are passed to ReadNGrams so that n-grams with zero backoff that appear as context will still be used in state. */
template <class Middle> class ActivateLowerMiddle {
public:
explicit ActivateLowerMiddle(Middle &middle) : modify_(middle) {}
uint64_t ChainedWordHash(const WordIndex *word, const WordIndex *word_end) {
if (word == word_end) return 0;
uint64_t current = static_cast<uint64_t>(*word);
for (++word; word != word_end; ++word) {
current = CombineWordHash(current, *word);
}
return current;
}
void operator()(const WordIndex *vocab_ids, const unsigned int n) {
uint64_t hash = static_cast<WordIndex>(vocab_ids[1]);
for (const WordIndex *i = vocab_ids + 2; i < vocab_ids + n; ++i) {
hash = detail::CombineWordHash(hash, *i);
}
typename Middle::MutableIterator i;
// TODO: somehow get text of n-gram for this error message.
if (!modify_.UnsafeMutableFind(hash, i))
UTIL_THROW(FormatLoadException, "The context of every " << n << "-gram should appear as a " << (n-1) << "-gram");
SetExtension(i->MutableValue().backoff);
}
template <class Voc, class Store> void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, Store &store) {
private:
Middle &modify_;
};
class ActivateUnigram {
public:
explicit ActivateUnigram(ProbBackoff *unigram) : modify_(unigram) {}
void operator()(const WordIndex *vocab_ids, const unsigned int /*n*/) {
// assert(n == 2);
SetExtension(modify_[vocab_ids[1]].backoff);
}
private:
ProbBackoff *modify_;
};
template <class Voc, class Store, class Middle, class Activate> void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, std::vector<Middle> &middle, Activate activate, Store &store) {
ReadNGramHeader(f, n);
ProbBackoff blank;
blank.prob = kBlankProb;
blank.backoff = kBlankBackoff;
// vocab ids of words in reverse order
WordIndex vocab_ids[n];
uint64_t keys[n - 1];
typename Store::Packing::Value value;
typename Middle::ConstIterator found;
for (size_t i = 0; i < count; ++i) {
ReadNGram(f, n, vocab, vocab_ids, value);
uint64_t key = ChainedWordHash(vocab_ids, vocab_ids + n);
store.Insert(Store::Packing::Make(key, value));
keys[0] = detail::CombineWordHash(static_cast<uint64_t>(*vocab_ids), vocab_ids[1]);
for (unsigned int h = 1; h < n - 1; ++h) {
keys[h] = detail::CombineWordHash(keys[h-1], vocab_ids[h+1]);
}
store.Insert(Store::Packing::Make(keys[n-2], value));
// Go back and insert blanks.
for (int lower = n - 3; lower >= 0; --lower) {
if (middle[lower].Find(keys[lower], found)) break;
middle[lower].Insert(Middle::Packing::Make(keys[lower], blank));
}
activate(vocab_ids, n);
}
store.FinishedInserting();
@ -49,17 +81,37 @@ template <class Voc, class Store> void ReadNGrams(util::FilePiece &f, const unsi
} // namespace
namespace detail {
template <class MiddleT, class LongestT> template <class Voc> void TemplateHashedSearch<MiddleT, LongestT>::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &/*config*/, Voc &vocab) {
Read1Grams(f, counts[0], vocab, unigram.Raw());
// Read the n-grams.
for (unsigned int n = 2; n < counts.size(); ++n) {
ReadNGrams(f, n, counts[n-1], vocab, middle[n-2]);
template <class MiddleT, class LongestT> template <class Voc> void TemplateHashedSearch<MiddleT, LongestT>::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab, Backing &backing) {
// TODO: fix sorted.
SetupMemory(GrowForSearch(config, Size(counts, config), backing), counts, config);
Read1Grams(f, counts[0], vocab, unigram.Raw());
CheckSpecials(config, vocab);
try {
if (counts.size() > 2) {
ReadNGrams(f, 2, counts[1], vocab, middle, ActivateUnigram(unigram.Raw()), middle[0]);
}
for (unsigned int n = 3; n < counts.size(); ++n) {
ReadNGrams(f, n, counts[n-1], vocab, middle, ActivateLowerMiddle<Middle>(middle[n-3]), middle[n-2]);
}
if (counts.size() > 2) {
ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle, ActivateLowerMiddle<Middle>(middle.back()), longest);
} else {
ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle, ActivateUnigram(unigram.Raw()), longest);
}
} catch (util::ProbingSizeException &e) {
UTIL_THROW(util::ProbingSizeException, "Avoid pruning n-grams like \"bar baz quux\" when \"foo bar baz quux\" is still in the model. KenLM will work when this pruning happens, but the probing model assumes these events are rare enough that using blank space in the probing hash table will cover all of them. Increase probing_multiplier (-p to build_binary) to add more blank spaces.\n");
}
ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, longest);
ReadEnd(f);
}
template void TemplateHashedSearch<ProbingHashedSearch::Middle, ProbingHashedSearch::Longest>::InitializeFromARPA(const char *, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &, ProbingVocabulary &vocab);
template void TemplateHashedSearch<SortedHashedSearch::Middle, SortedHashedSearch::Longest>::InitializeFromARPA(const char *, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &, SortedVocabulary &vocab);
template void TemplateHashedSearch<ProbingHashedSearch::Middle, ProbingHashedSearch::Longest>::InitializeFromARPA(const char *, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &, ProbingVocabulary &vocab, Backing &backing);
template void TemplateHashedSearch<SortedHashedSearch::Middle, SortedHashedSearch::Longest>::InitializeFromARPA(const char *, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &, SortedVocabulary &vocab, Backing &backing);
SortedHashedSearch::SortedHashedSearch() {
UTIL_THROW(util::Exception, "Sorted is broken at the moment, sorry");
}
} // namespace detail
} // namespace ngram

View File

@ -17,10 +17,11 @@ namespace util { class FilePiece; }
namespace lm {
namespace ngram {
struct Backing;
namespace detail {
inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) {
uint64_t ret = (current * 8978948897894561157ULL) ^ (static_cast<uint64_t>(next) * 17894857484156487943ULL);
uint64_t ret = (current * 8978948897894561157ULL) ^ (static_cast<uint64_t>(1 + next) * 17894857484156487943ULL);
return ret;
}
@ -91,7 +92,7 @@ template <class MiddleT, class LongestT> struct TemplateHashedSearch : public Ha
return start;
}
template <class Voc> void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab);
template <class Voc> void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab, Backing &backing);
bool LookupMiddle(const Middle &middle, WordIndex word, float &prob, float &backoff, Node &node) const {
node = CombineWordHash(node, word);
@ -145,6 +146,8 @@ struct ProbingHashedSearch : public TemplateHashedSearch<
struct SortedHashedSearch : public TemplateHashedSearch<
util::SortedUniformMap<util::ByteAlignedPacking<uint64_t, ProbBackoff> >,
util::SortedUniformMap<util::ByteAlignedPacking<uint64_t, Prob> > > {
SortedHashedSearch();
static const ModelType kModelType = HASH_SORTED;
};

View File

@ -1,7 +1,9 @@
/* This is where the trie is built. It's on-disk. */
#include "lm/search_trie.hh"
#include "lm/blank.hh"
#include "lm/lm_exception.hh"
#include "lm/max_order.hh"
#include "lm/read_arpa.hh"
#include "lm/trie.hh"
#include "lm/vocab.hh"
@ -9,16 +11,16 @@
#include "lm/word_index.hh"
#include "util/ersatz_progress.hh"
#include "util/file_piece.hh"
#include "util/have.hh"
#include "util/proxy_iterator.hh"
#include "util/scoped.hh"
#include <algorithm>
#include <cmath>
#include <cstring>
#include <cstdio>
#include <deque>
#include <iostream>
#include <limits>
//#include <parallel/algorithm>
#include <vector>
#include <sys/mman.h>
@ -26,6 +28,7 @@
#include <sys/stat.h>
#include <fcntl.h>
#include <stdlib.h>
#include <unistd.h>
namespace lm {
namespace ngram {
@ -97,7 +100,7 @@ class EntryProxy {
}
const WordIndex *Indices() const {
return static_cast<const WordIndex*>(inner_.Data());
return reinterpret_cast<const WordIndex*>(inner_.Data());
}
private:
@ -113,21 +116,61 @@ class EntryProxy {
typedef util::ProxyIterator<EntryProxy> NGramIter;
class CompareRecords : public std::binary_function<const EntryProxy &, const EntryProxy &, bool> {
// Proxy for an entry except there is some extra cruft between the entries. This is used to sort (n-1)-grams using the same memory as the sorted n-grams.
class PartialViewProxy {
public:
PartialViewProxy() : attention_size_(0), inner_() {}
PartialViewProxy(void *ptr, std::size_t block_size, std::size_t attention_size) : attention_size_(attention_size), inner_(ptr, block_size) {}
operator std::string() const {
return std::string(reinterpret_cast<const char*>(inner_.Data()), attention_size_);
}
PartialViewProxy &operator=(const PartialViewProxy &from) {
memcpy(inner_.Data(), from.inner_.Data(), attention_size_);
return *this;
}
PartialViewProxy &operator=(const std::string &from) {
memcpy(inner_.Data(), from.data(), attention_size_);
return *this;
}
const WordIndex *Indices() const {
return reinterpret_cast<const WordIndex*>(inner_.Data());
}
private:
friend class util::ProxyIterator<PartialViewProxy>;
typedef std::string value_type;
const std::size_t attention_size_;
typedef EntryIterator InnerIterator;
InnerIterator &Inner() { return inner_; }
const InnerIterator &Inner() const { return inner_; }
InnerIterator inner_;
};
typedef util::ProxyIterator<PartialViewProxy> PartialIter;
template <class Proxy> class CompareRecords : public std::binary_function<const Proxy &, const Proxy &, bool> {
public:
explicit CompareRecords(unsigned char order) : order_(order) {}
bool operator()(const EntryProxy &first, const EntryProxy &second) const {
bool operator()(const Proxy &first, const Proxy &second) const {
return Compare(first.Indices(), second.Indices());
}
bool operator()(const EntryProxy &first, const std::string &second) const {
bool operator()(const Proxy &first, const std::string &second) const {
return Compare(first.Indices(), reinterpret_cast<const WordIndex*>(second.data()));
}
bool operator()(const std::string &first, const EntryProxy &second) const {
bool operator()(const std::string &first, const Proxy &second) const {
return Compare(reinterpret_cast<const WordIndex*>(first.data()), second.Indices());
}
bool operator()(const std::string &first, const std::string &second) const {
return Compare(reinterpret_cast<const WordIndex*>(first.data()), reinterpret_cast<const WordIndex*>(first.data()));
return Compare(reinterpret_cast<const WordIndex*>(first.data()), reinterpret_cast<const WordIndex*>(second.data()));
}
private:
@ -143,6 +186,12 @@ class CompareRecords : public std::binary_function<const EntryProxy &, const Ent
unsigned char order_;
};
FILE *OpenOrThrow(const char *name, const char *mode) {
FILE *ret = fopen(name, mode);
if (!ret) UTIL_THROW(util::ErrnoException, "Could not open " << name << " for " << mode);
return ret;
}
void WriteOrThrow(FILE *to, const void *data, size_t size) {
assert(size);
if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size);
@ -152,28 +201,42 @@ void ReadOrThrow(FILE *from, void *data, size_t size) {
if (1 != std::fread(data, size, 1, from)) UTIL_THROW(util::ErrnoException, "Short read; requested size " << size);
}
const std::size_t kCopyBufSize = 512;
void CopyOrThrow(FILE *from, FILE *to, size_t size) {
const size_t kBufSize = 512;
char buf[kBufSize];
for (size_t i = 0; i < size; i += kBufSize) {
std::size_t amount = std::min(size - i, kBufSize);
char buf[std::min<size_t>(size, kCopyBufSize)];
for (size_t i = 0; i < size; i += kCopyBufSize) {
std::size_t amount = std::min(size - i, kCopyBufSize);
ReadOrThrow(from, buf, amount);
WriteOrThrow(to, buf, amount);
}
}
void CopyRestOrThrow(FILE *from, FILE *to) {
char buf[kCopyBufSize];
size_t amount;
while ((amount = fread(buf, 1, kCopyBufSize, from))) {
WriteOrThrow(to, buf, amount);
}
if (!feof(from)) UTIL_THROW(util::ErrnoException, "Short read");
}
void RemoveOrThrow(const char *name) {
if (std::remove(name)) UTIL_THROW(util::ErrnoException, "Could not remove " << name);
}
std::string DiskFlush(const void *mem_begin, const void *mem_end, const std::string &file_prefix, std::size_t batch, unsigned char order, std::size_t weights_size) {
const std::size_t entry_size = sizeof(WordIndex) * order + weights_size;
const std::size_t prefix_size = sizeof(WordIndex) * (order - 1);
std::stringstream assembled;
assembled << file_prefix << static_cast<unsigned int>(order) << '_' << batch;
std::string ret(assembled.str());
util::scoped_FILE out(fopen(ret.c_str(), "w"));
if (!out.get()) UTIL_THROW(util::ErrnoException, "Couldn't open " << assembled.str().c_str() << " for writing");
util::scoped_FILE out(OpenOrThrow(ret.c_str(), "w"));
// Compress entries that being with the same (order-1) words.
for (const uint8_t *group_begin = static_cast<const uint8_t*>(mem_begin); group_begin != static_cast<const uint8_t*>(mem_end);) {
const uint8_t *group_end = group_begin;
for (group_end += entry_size; (group_end != static_cast<const uint8_t*>(mem_end)) && !memcmp(group_begin, group_end, prefix_size); group_end += entry_size) {}
const uint8_t *group_end;
for (group_end = group_begin + entry_size;
(group_end != static_cast<const uint8_t*>(mem_end)) && !memcmp(group_begin, group_end, prefix_size);
group_end += entry_size) {}
WriteOrThrow(out.get(), group_begin, prefix_size);
WordIndex group_size = (group_end - group_begin) / entry_size;
WriteOrThrow(out.get(), &group_size, sizeof(group_size));
@ -188,11 +251,10 @@ std::string DiskFlush(const void *mem_begin, const void *mem_end, const std::str
class SortedFileReader {
public:
SortedFileReader() {}
SortedFileReader() : ended_(false) {}
void Init(const std::string &name, unsigned char order) {
file_.reset(fopen(name.c_str(), "r"));
if (!file_.get()) UTIL_THROW(util::ErrnoException, "Opening " << name << " for read");
file_.reset(OpenOrThrow(name.c_str(), "r"));
header_.resize(order - 1);
NextHeader();
}
@ -206,25 +268,39 @@ class SortedFileReader {
std::size_t HeaderBytes() const { return header_.size() * sizeof(WordIndex); }
void NextHeader() {
if (1 != fread(&*header_.begin(), HeaderBytes(), 1, file_.get()) && !Ended()) {
UTIL_THROW(util::ErrnoException, "Short read of counts");
if (1 != fread(&*header_.begin(), HeaderBytes(), 1, file_.get())) {
if (feof(file_.get())) {
ended_ = true;
} else {
UTIL_THROW(util::ErrnoException, "Short read of counts");
}
}
}
void ReadCount(WordIndex &to) {
ReadOrThrow(file_.get(), &to, sizeof(WordIndex));
WordIndex ReadCount() {
WordIndex ret;
ReadOrThrow(file_.get(), &ret, sizeof(WordIndex));
return ret;
}
void ReadWord(WordIndex &to) {
ReadOrThrow(file_.get(), &to, sizeof(WordIndex));
WordIndex ReadWord() {
WordIndex ret;
ReadOrThrow(file_.get(), &ret, sizeof(WordIndex));
return ret;
}
template <class Weights> void ReadWeights(Weights &to) {
ReadOrThrow(file_.get(), &to, sizeof(Weights));
template <class Weights> void ReadWeights(Weights &weights) {
ReadOrThrow(file_.get(), &weights, sizeof(Weights));
}
bool Ended() {
return feof(file_.get());
bool Ended() const {
return ended_;
}
void Rewind() {
rewind(file_.get());
ended_ = false;
NextHeader();
}
FILE *File() { return file_.get(); }
@ -233,23 +309,25 @@ class SortedFileReader {
util::scoped_FILE file_;
std::vector<WordIndex> header_;
bool ended_;
};
void CopyFullRecord(SortedFileReader &from, FILE *to, std::size_t weights_size) {
WriteOrThrow(to, from.Header(), from.HeaderBytes());
WordIndex count;
from.ReadCount(count);
WordIndex count = from.ReadCount();
WriteOrThrow(to, &count, sizeof(WordIndex));
CopyOrThrow(from.File(), to, (weights_size + sizeof(WordIndex)) * count);
}
void MergeSortedFiles(const char *first_name, const char *second_name, const char *out, std::size_t weights_size, unsigned char order) {
void MergeSortedFiles(const std::string &first_name, const std::string &second_name, const std::string &out, std::size_t weights_size, unsigned char order) {
SortedFileReader first, second;
first.Init(first_name, order);
second.Init(second_name, order);
util::scoped_FILE out_file(fopen(out, "w"));
if (!out_file.get()) UTIL_THROW(util::ErrnoException, "Could not open " << out << " for write");
first.Init(first_name.c_str(), order);
RemoveOrThrow(first_name.c_str());
second.Init(second_name.c_str(), order);
RemoveOrThrow(second_name.c_str());
util::scoped_FILE out_file(OpenOrThrow(out.c_str(), "w"));
while (!first.Ended() && !second.Ended()) {
if (first.HeaderVector() < second.HeaderVector()) {
CopyFullRecord(first, out_file.get(), weights_size);
@ -263,25 +341,23 @@ void MergeSortedFiles(const char *first_name, const char *second_name, const cha
}
// Merge at the entry level.
WriteOrThrow(out_file.get(), first.Header(), first.HeaderBytes());
WordIndex first_count, second_count;
first.ReadCount(first_count); second.ReadCount(second_count);
WordIndex first_count = first.ReadCount(), second_count = second.ReadCount();
WordIndex total_count = first_count + second_count;
WriteOrThrow(out_file.get(), &total_count, sizeof(WordIndex));
WordIndex first_word, second_word;
first.ReadWord(first_word); second.ReadWord(second_word);
WordIndex first_word = first.ReadWord(), second_word = second.ReadWord();
WordIndex first_index = 0, second_index = 0;
while (true) {
if (first_word < second_word) {
WriteOrThrow(out_file.get(), &first_word, sizeof(WordIndex));
CopyOrThrow(first.File(), out_file.get(), weights_size);
if (++first_index == first_count) break;
first.ReadWord(first_word);
first_word = first.ReadWord();
} else {
WriteOrThrow(out_file.get(), &second_word, sizeof(WordIndex));
CopyOrThrow(second.File(), out_file.get(), weights_size);
if (++second_index == second_count) break;
second.ReadWord(second_word);
second_word = second.ReadWord();
}
}
if (first_index == first_count) {
@ -300,10 +376,111 @@ void MergeSortedFiles(const char *first_name, const char *second_name, const cha
}
}
void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order) {
if (order == 1) return;
ConvertToSorted(f, vocab, counts, mem, file_prefix, order - 1);
const char *kContextSuffix = "_contexts";
void WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &ngram_file_name, std::size_t entry_size, unsigned char order) {
const size_t context_size = sizeof(WordIndex) * (order - 1);
// Sort just the contexts using the same memory.
PartialIter context_begin(PartialViewProxy(begin + sizeof(WordIndex), entry_size, context_size));
PartialIter context_end(PartialViewProxy(end + sizeof(WordIndex), entry_size, context_size));
std::sort(context_begin, context_end, CompareRecords<PartialViewProxy>(order - 1));
std::string name(ngram_file_name + kContextSuffix);
util::scoped_FILE out(OpenOrThrow(name.c_str(), "w"));
// Write out to file and uniqueify at the same time. Could have used unique_copy if there was an appropriate OutputIterator.
if (context_begin == context_end) return;
PartialIter i(context_begin);
WriteOrThrow(out.get(), i->Indices(), context_size);
const WordIndex *previous = i->Indices();
++i;
for (; i != context_end; ++i) {
if (memcmp(previous, i->Indices(), context_size)) {
WriteOrThrow(out.get(), i->Indices(), context_size);
previous = i->Indices();
}
}
}
class ContextReader {
public:
ContextReader() : valid_(false) {}
ContextReader(const char *name, unsigned char order) {
Reset(name, order);
}
void Reset(const char *name, unsigned char order) {
file_.reset(OpenOrThrow(name, "r"));
length_ = sizeof(WordIndex) * static_cast<size_t>(order);
words_.resize(order);
valid_ = true;
++*this;
}
ContextReader &operator++() {
if (1 != fread(&*words_.begin(), length_, 1, file_.get())) {
if (!feof(file_.get()))
UTIL_THROW(util::ErrnoException, "Short read");
valid_ = false;
}
return *this;
}
const WordIndex *operator*() const { return &*words_.begin(); }
operator bool() const { return valid_; }
FILE *GetFile() { return file_.get(); }
private:
util::scoped_FILE file_;
size_t length_;
std::vector<WordIndex> words_;
bool valid_;
};
void MergeContextFiles(const std::string &first_base, const std::string &second_base, const std::string &out_base, unsigned char order) {
const size_t context_size = sizeof(WordIndex) * (order - 1);
std::string first_name(first_base + kContextSuffix);
std::string second_name(second_base + kContextSuffix);
ContextReader first(first_name.c_str(), order - 1), second(second_name.c_str(), order - 1);
RemoveOrThrow(first_name.c_str());
RemoveOrThrow(second_name.c_str());
std::string out_name(out_base + kContextSuffix);
util::scoped_FILE out(OpenOrThrow(out_name.c_str(), "w"));
while (first && second) {
for (const WordIndex *f = *first, *s = *second; ; ++f, ++s) {
if (f == *first + order - 1) {
// Equal.
WriteOrThrow(out.get(), *first, context_size);
++first;
++second;
break;
}
if (*f < *s) {
// First lower
WriteOrThrow(out.get(), *first, context_size);
++first;
break;
} else if (*f > *s) {
WriteOrThrow(out.get(), *second, context_size);
++second;
break;
}
}
}
ContextReader &remaining = first ? first : second;
if (!remaining) return;
WriteOrThrow(out.get(), *remaining, context_size);
CopyRestOrThrow(remaining.GetFile(), out.get());
}
void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order) {
ReadNGramHeader(f, order);
const size_t count = counts[order - 1];
// Size of weights. Does it include backoff?
@ -325,11 +502,13 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st
ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<ProbBackoff*>(out + words_size));
}
}
// TODO: __gnu_parallel::sort here.
// Sort full records by full n-gram.
EntryProxy proxy_begin(begin, entry_size), proxy_end(out_end, entry_size);
std::sort(NGramIter(proxy_begin), NGramIter(proxy_end), CompareRecords(order));
// parallel_sort uses too much RAM
std::sort(NGramIter(proxy_begin), NGramIter(proxy_end), CompareRecords<EntryProxy>(order));
files.push_back(DiskFlush(begin, out_end, file_prefix, batch, order, weights_size));
WriteContextFile(begin, out_end, files.back(), entry_size, order);
done += (out_end - begin) / entry_size;
}
@ -340,10 +519,9 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st
std::stringstream assembled;
assembled << file_prefix << static_cast<unsigned int>(order) << "_merge_" << (merge_count++);
files.push_back(assembled.str());
MergeSortedFiles(files[0].c_str(), files[1].c_str(), files.back().c_str(), weights_size, order);
if (std::remove(files[0].c_str())) UTIL_THROW(util::ErrnoException, "Could not remove " << files[0]);
MergeSortedFiles(files[0], files[1], files.back(), weights_size, order);
MergeContextFiles(files[0], files[1], files.back(), order);
files.pop_front();
if (std::remove(files[0].c_str())) UTIL_THROW(util::ErrnoException, "Could not remove " << files[0]);
files.pop_front();
}
if (!files.empty()) {
@ -351,129 +529,351 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st
assembled << file_prefix << static_cast<unsigned int>(order) << "_merged";
std::string merged_name(assembled.str());
if (std::rename(files[0].c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << files[0].c_str() << " to " << merged_name.c_str());
std::string context_name = files[0] + kContextSuffix;
merged_name += kContextSuffix;
if (std::rename(context_name.c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << context_name << " to " << merged_name.c_str());
}
}
void ARPAToSortedFiles(util::FilePiece &f, const std::vector<uint64_t> &counts, std::size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) {
void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) {
{
std::string unigram_name = file_prefix + "unigrams";
util::scoped_fd unigram_file;
util::scoped_mmap unigram_mmap;
unigram_mmap.reset(util::MapZeroedWrite(unigram_name.c_str(), counts[0] * sizeof(ProbBackoff), unigram_file), counts[0] * sizeof(ProbBackoff));
// In case <unk> appears.
size_t extra_count = counts[0] + 1;
util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), extra_count * sizeof(ProbBackoff), unigram_file), extra_count * sizeof(ProbBackoff));
Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get()));
CheckSpecials(config, vocab);
if (!vocab.SawUnk()) ++counts[0];
}
// Only use as much buffer as we need.
size_t buffer_use = 0;
for (unsigned int order = 2; order < counts.size(); ++order) {
buffer_use = std::max<size_t>(buffer_use, static_cast<size_t>((sizeof(WordIndex) * order + 2 * sizeof(float)) * counts[order - 1]));
}
buffer_use = std::max<size_t>(buffer_use, static_cast<size_t>((sizeof(WordIndex) * counts.size() + sizeof(float)) * counts.back()));
buffer = std::min<size_t>(buffer, buffer_use);
util::scoped_memory mem;
mem.reset(malloc(buffer), buffer, util::scoped_memory::ARRAY_ALLOCATED);
mem.reset(malloc(buffer), buffer, util::scoped_memory::MALLOC_ALLOCATED);
if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer);
ConvertToSorted(f, vocab, counts, mem, file_prefix, counts.size());
for (unsigned char order = 2; order <= counts.size(); ++order) {
ConvertToSorted(f, vocab, counts, mem, file_prefix, order);
}
ReadEnd(f);
}
struct RecursiveInsertParams {
WordIndex *words;
SortedFileReader *files;
unsigned char max_order;
// This is an array of size order - 2.
BitPackedMiddle *middle;
// This has exactly one entry.
BitPackedLongest *longest;
};
uint64_t RecursiveInsert(RecursiveInsertParams &params, unsigned char order) {
SortedFileReader &file = params.files[order - 2];
const uint64_t ret = (order == params.max_order) ? params.longest->InsertIndex() : params.middle[order - 2].InsertIndex();
if (std::memcmp(params.words, file.Header(), sizeof(WordIndex) * (order - 1)))
return ret;
WordIndex count;
file.ReadCount(count);
WordIndex key;
if (order == params.max_order) {
Prob value;
for (WordIndex i = 0; i < count; ++i) {
file.ReadWord(key);
file.ReadWeights(value);
params.longest->Insert(key, value.prob);
bool HeadMatch(const WordIndex *words, const WordIndex *const words_end, const WordIndex *header) {
for (; words != words_end; ++words, ++header) {
if (*words != *header) {
//assert(*words <= *header);
return false;
}
file.NextHeader();
return ret;
}
ProbBackoff value;
for (WordIndex i = 0; i < count; ++i) {
file.ReadWord(params.words[order - 1]);
file.ReadWeights(value);
params.middle[order - 2].Insert(
params.words[order - 1],
value.prob,
value.backoff,
RecursiveInsert(params, order + 1));
}
file.NextHeader();
return ret;
return true;
}
void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &counts, std::ostream *messages, TrieSearch &out) {
UnigramValue *unigrams = out.unigram.Raw();
// Load unigrams. Leave the next pointers uninitialized.
{
std::string name(file_prefix + "unigrams");
util::scoped_FILE file(fopen(name.c_str(), "r"));
if (!file.get()) UTIL_THROW(util::ErrnoException, "Opening " << name << " failed");
for (WordIndex i = 0; i < counts[0]; ++i) {
ReadOrThrow(file.get(), &unigrams[i].weights, sizeof(ProbBackoff));
}
unlink(name.c_str());
}
// Phase to count n-grams, including blanks inserted because they were pruned but have extensions
class JustCount {
public:
JustCount(ContextReader * /*contexts*/, UnigramValue * /*unigrams*/, BitPackedMiddle * /*middle*/, BitPackedLongest &/*longest*/, uint64_t *counts, unsigned char order)
: counts_(counts), longest_counts_(counts + order - 1) {}
void Unigrams(WordIndex begin, WordIndex end) {
counts_[0] += end - begin;
}
void MiddleBlank(const unsigned char mid_idx, WordIndex /* idx */) {
++counts_[mid_idx + 1];
}
void Middle(const unsigned char mid_idx, const WordIndex * /*before*/, WordIndex /*key*/, const ProbBackoff &/*weights*/) {
++counts_[mid_idx + 1];
}
void Longest(WordIndex /*key*/, Prob /*prob*/) {
++*longest_counts_;
}
// Unigrams wrote one past.
void Cleanup() {
--counts_[0];
}
private:
uint64_t *const counts_, *const longest_counts_;
};
// Phase to actually write n-grams to the trie.
class WriteEntries {
public:
WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) :
contexts_(contexts),
unigrams_(unigrams),
middle_(middle),
longest_(longest),
bigram_pack_((order == 2) ? static_cast<BitPacked&>(longest_) : static_cast<BitPacked&>(*middle_)) {}
void Unigrams(WordIndex begin, WordIndex end) {
uint64_t next = bigram_pack_.InsertIndex();
for (UnigramValue *i = unigrams_ + begin; i < unigrams_ + end; ++i) {
i->next = next;
}
}
void MiddleBlank(const unsigned char mid_idx, WordIndex key) {
middle_[mid_idx].Insert(key, kBlankProb, kBlankBackoff);
}
void Middle(const unsigned char mid_idx, const WordIndex *before, WordIndex key, ProbBackoff weights) {
// Order (mid_idx+2).
ContextReader &context = contexts_[mid_idx + 1];
if (context && !memcmp(before, *context, sizeof(WordIndex) * (mid_idx + 1)) && (*context)[mid_idx + 1] == key) {
SetExtension(weights.backoff);
++context;
}
middle_[mid_idx].Insert(key, weights.prob, weights.backoff);
}
void Longest(WordIndex key, Prob prob) {
longest_.Insert(key, prob.prob);
}
void Cleanup() {}
private:
ContextReader *contexts_;
UnigramValue *const unigrams_;
BitPackedMiddle *const middle_;
BitPackedLongest &longest_;
BitPacked &bigram_pack_;
};
template <class Doing> class RecursiveInsert {
public:
RecursiveInsert(SortedFileReader *inputs, ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, uint64_t *counts, unsigned char order) :
doing_(contexts, unigrams, middle, longest, counts, order), inputs_(inputs), inputs_end_(inputs + order - 1), order_minus_2_(order - 2) {
}
// Outer unigram loop.
void Apply(std::ostream *progress_out, const char *message, WordIndex unigram_count) {
util::ErsatzProgress progress(progress_out, message, unigram_count + 1);
for (words_[0] = 0; ; ++words_[0]) {
progress.Set(words_[0]);
WordIndex min_continue = unigram_count;
for (SortedFileReader *other = inputs_; other != inputs_end_; ++other) {
if (other->Ended()) continue;
min_continue = std::min(min_continue, other->Header()[0]);
}
// This will write at unigram_count. This is by design so that the next pointers will make sense.
doing_.Unigrams(words_[0], min_continue + 1);
if (min_continue == unigram_count) break;
words_[0] = min_continue;
Middle(0);
}
doing_.Cleanup();
}
private:
// Merge one middle order: emits, via doing_, every (mid_idx + 2)-gram whose
// context is words_[0..mid_idx].  Wherever a longer order continues through a
// word this order lacks, a blank entry (doing_.MiddleBlank) is inserted so
// next pointers remain consistent; after each continuation word it recurses
// into the next order.  Reads inputs_[mid_idx] in order.
void Middle(const unsigned char mid_idx) {
// (mid_idx + 2)-gram.
if (mid_idx == order_minus_2_) {
Longest();
return;
}
// Orders [2, order)
SortedFileReader &reader = inputs_[mid_idx];
if (reader.Ended() || !HeadMatch(words_, words_ + mid_idx + 1, reader.Header())) {
// This order doesn't have a header match, but longer ones might.
MiddleAllBlank(mid_idx);
return;
}
// There is a header match.
WordIndex count = reader.ReadCount();
WordIndex current = reader.ReadWord();
while (count) {
// Smallest word at which some longer order continues under this context.
WordIndex min_continue = std::numeric_limits<WordIndex>::max();
for (SortedFileReader *other = inputs_ + mid_idx + 1; other < inputs_end_; ++other) {
if (!other->Ended() && HeadMatch(words_, words_ + mid_idx + 1, other->Header()))
min_continue = std::min(min_continue, other->Header()[mid_idx + 1]);
}
while (true) {
if (current > min_continue) {
// A longer order needs an entry this order lacks: emit a blank,
// recurse under it, then recompute min_continue.
doing_.MiddleBlank(mid_idx, min_continue);
words_[mid_idx + 1] = min_continue;
Middle(mid_idx + 1);
break;
}
ProbBackoff weights;
reader.ReadWeights(weights);
doing_.Middle(mid_idx, words_, current, weights);
--count;
if (current == min_continue) {
// Real entry that also continues into longer orders: recurse before
// advancing to the next word at this order.
words_[mid_idx + 1] = min_continue;
Middle(mid_idx + 1);
if (count) current = reader.ReadWord();
break;
}
if (!count) break;
current = reader.ReadWord();
}
}
// Count is now zero. Finish off remaining blanks.
MiddleAllBlank(mid_idx);
reader.NextHeader();
}
// Called when this order has no (more) real entries under the current context:
// keeps emitting blank entries and recursing for as long as any longer order
// still continues under words_[0..mid_idx].
void MiddleAllBlank(const unsigned char mid_idx) {
while (true) {
WordIndex min_continue = std::numeric_limits<WordIndex>::max();
for (SortedFileReader *other = inputs_ + mid_idx + 1; other < inputs_end_; ++other) {
if (!other->Ended() && HeadMatch(words_, words_ + mid_idx + 1, other->Header()))
min_continue = std::min(min_continue, other->Header()[mid_idx + 1]);
}
// No longer order continues under this context; nothing left to blank.
if (min_continue == std::numeric_limits<WordIndex>::max()) return;
doing_.MiddleBlank(mid_idx, min_continue);
words_[mid_idx + 1] = min_continue;
Middle(mid_idx + 1);
}
}
void Longest() {
SortedFileReader &reader = *(inputs_end_ - 1);
if (reader.Ended() || !HeadMatch(words_, words_ + order_minus_2_ + 1, reader.Header())) return;
WordIndex count = reader.ReadCount();
for (WordIndex i = 0; i < count; ++i) {
WordIndex word = reader.ReadWord();
Prob prob;
reader.ReadWeights(prob);
doing_.Longest(word, prob);
}
reader.NextHeader();
return;
}
Doing doing_;
SortedFileReader *inputs_;
SortedFileReader *inputs_end_;
WordIndex words_[kMaxOrder];
const unsigned char order_minus_2_;
};
// Verify recounted n-gram totals against the counts declared in the ARPA
// header.  Middle orders may grow (blank entries are inserted for pruned
// contexts) but must never shrink, and the unigram and highest-order counts
// must be unchanged.  Throws util::Exception on violation.
void SanityCheckCounts(const std::vector<uint64_t> &initial, const std::vector<uint64_t> &fixed) {
  if (fixed[0] != initial[0]) UTIL_THROW(util::Exception, "Unigram count should be constant but initial is " << initial[0] << " and recounted is " << fixed[0]);
  if (fixed.back() != initial.back()) UTIL_THROW(util::Exception, "Longest count should be constant but it changed from " << initial.back() << " to " << fixed.back());
  // Use the vector's own size type: the previous unsigned char index would
  // wrap (looping forever) if there were ever more than 255 orders.
  for (std::vector<uint64_t>::size_type i = 0; i < initial.size(); ++i) {
    if (fixed[i] < initial[i]) UTIL_THROW(util::Exception, "Counts came out lower than expected. This shouldn't happen");
  }
}
void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch &out, Backing &backing) {
std::vector<SortedFileReader> inputs(counts.size() - 1);
std::vector<ContextReader> contexts(counts.size() - 1);
// inputs[0] is bigrams.
SortedFileReader inputs[counts.size() - 1];
for (unsigned char i = 2; i <= counts.size(); ++i) {
std::stringstream assembled;
assembled << file_prefix << static_cast<unsigned int>(i) << "_merged";
inputs[i-2].Init(assembled.str(), i);
unlink(assembled.str().c_str());
RemoveOrThrow(assembled.str().c_str());
assembled << kContextSuffix;
contexts[i-2].Reset(assembled.str().c_str(), i-1);
RemoveOrThrow(assembled.str().c_str());
}
// words[0] is unigrams.
WordIndex words[counts.size()];
RecursiveInsertParams params;
params.words = words;
params.files = inputs;
params.max_order = static_cast<unsigned char>(counts.size());
params.middle = &*out.middle.begin();
params.longest = &out.longest;
std::vector<uint64_t> fixed_counts(counts.size());
{
util::ErsatzProgress progress(messages, "Building trie", counts[0]);
for (words[0] = 0; words[0] < counts[0]; ++words[0], ++progress) {
unigrams[words[0]].next = RecursiveInsert(params, 2);
RecursiveInsert<JustCount> counter(&*inputs.begin(), &*contexts.begin(), NULL, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size());
counter.Apply(config.messages, "Counting n-grams that should not have been pruned", counts[0]);
}
for (std::vector<SortedFileReader>::const_iterator i = inputs.begin(); i != inputs.end(); ++i) {
if (!i->Ended()) UTIL_THROW(FormatLoadException, "There's a bug in the trie implementation: the " << (i - inputs.begin() + 2) << "-gram table did not complete reading");
}
SanityCheckCounts(counts, fixed_counts);
counts = fixed_counts;
out.SetupMemory(GrowForSearch(config, TrieSearch::Size(fixed_counts, config), backing), fixed_counts, config);
for (unsigned char i = 2; i <= counts.size(); ++i) {
inputs[i-2].Rewind();
}
UnigramValue *unigrams = out.unigram.Raw();
// Fill entries except unigram probabilities.
{
RecursiveInsert<WriteEntries> inserter(&*inputs.begin(), &*contexts.begin(), unigrams, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size());
inserter.Apply(config.messages, "Building trie", fixed_counts[0]);
}
// Fill unigram probabilities.
try {
std::string name(file_prefix + "unigrams");
util::scoped_FILE file(OpenOrThrow(name.c_str(), "r"));
for (WordIndex i = 0; i < counts[0]; ++i) {
ReadOrThrow(file.get(), &unigrams[i].weights, sizeof(ProbBackoff));
if (contexts[0] && **contexts[0] == i) {
SetExtension(unigrams[i].weights.backoff);
++contexts[0];
}
}
RemoveOrThrow(name.c_str());
} catch (util::Exception &e) {
e << " while re-reading unigram probabilities";
throw;
}
// Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation.
for (unsigned char order = 2; order <= counts.size(); ++order) {
const ContextReader &context = contexts[order - 2];
if (context) {
FormatLoadException e;
e << "An " << static_cast<unsigned int>(order) << "-gram has the context (i.e. all but the last word):";
for (const WordIndex *i = *context; i != *context + order - 1; ++i) {
e << ' ' << *i;
}
e << " so this context must appear in the model as a " << static_cast<unsigned int>(order - 1) << "-gram but it does not";
throw e;
}
}
/* Set ending offsets so the last entry will be sized properly */
// Last entry for unigrams was already set.
if (!out.middle.empty()) {
unigrams[counts[0]].next = out.middle.front().InsertIndex();
for (size_t i = 0; i < out.middle.size() - 1; ++i) {
out.middle[i].FinishedLoading(out.middle[i+1].InsertIndex());
}
out.middle.back().FinishedLoading(out.longest.InsertIndex());
} else {
unigrams[counts[0]].next = out.longest.InsertIndex();
}
}
}
// Returns true iff path exists and names a directory; false on any stat error.
bool IsDirectory(const char *path) {
  struct stat info;
  return (stat(path, &info) == 0) && S_ISDIR(info.st_mode);
}
} // namespace
void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab) {
void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) {
std::string temporary_directory;
if (config.temporary_directory_prefix) {
temporary_directory = config.temporary_directory_prefix;
if (!temporary_directory.empty() && temporary_directory[temporary_directory.size() - 1] != '/' && IsDirectory(temporary_directory.c_str()))
temporary_directory += '/';
} else if (config.write_mmap) {
temporary_directory = config.write_mmap;
} else {
temporary_directory = file;
}
// Null on end is kludge to ensure null termination.
temporary_directory += "-tmp-XXXXXX\0";
temporary_directory += "_trie_tmp_XXXXXX";
temporary_directory += '\0';
if (!mkdtemp(&temporary_directory[0])) {
UTIL_THROW(util::ErrnoException, "Failed to make a temporary directory based on the name " << temporary_directory.c_str());
}
@ -482,10 +882,11 @@ void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, const
// Add directory delimiter. Assumes a real operating system.
temporary_directory += '/';
// At least 1MB sorting memory.
ARPAToSortedFiles(f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_directory.c_str(), vocab);
BuildTrie(temporary_directory.c_str(), counts, config.messages, *this);
if (rmdir(temporary_directory.c_str())) {
std::cerr << "Failed to delete " << temporary_directory << std::endl;
ARPAToSortedFiles(config, f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_directory.c_str(), vocab);
BuildTrie(temporary_directory, counts, config, *this, backing);
if (rmdir(temporary_directory.c_str()) && config.messages) {
*config.messages << "Failed to delete " << temporary_directory << std::endl;
}
}

View File

@ -9,6 +9,7 @@
namespace lm {
namespace ngram {
struct Backing;
class SortedVocabulary;
namespace trie {
@ -39,14 +40,18 @@ struct TrieSearch {
start += Unigram::Size(counts[0]);
middle.resize(counts.size() - 2);
for (unsigned char i = 1; i < counts.size() - 1; ++i) {
middle[i-1].Init(start, counts[0], counts[i+1]);
middle[i-1].Init(
start,
counts[0],
counts[i+1],
(i == counts.size() - 2) ? static_cast<const BitPacked&>(longest) : static_cast<const BitPacked &>(middle[i]));
start += Middle::Size(counts[i], counts[0], counts[i+1]);
}
longest.Init(start, counts[0]);
return start + Longest::Size(counts.back(), counts[0]);
}
void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab);
void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing);
bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const {
return unigram.Find(word, prob, backoff, node);
@ -65,7 +70,7 @@ struct TrieSearch {
}
bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {
// TODO: don't decode prob.
// TODO: don't decode backoff.
assert(begin != end);
float ignored_prob, ignored_backoff;
LookupUnigram(*begin, ignored_prob, ignored_backoff, node);

View File

@ -1,17 +1,17 @@
\data\
ngram 1=34
ngram 2=43
ngram 3=8
ngram 4=5
ngram 5=3
ngram 1=37
ngram 2=47
ngram 3=11
ngram 4=6
ngram 5=4
\1-grams:
-1.383514 , -0.30103
-1.139057 . -0.845098
-1.029493 </s>
-99 <s> -0.4149733
-1.995635 <unk>
-1.995635 <unk> -20
-1.285941 a -0.69897
-1.687872 also -0.30103
-1.687872 beyond -0.30103
@ -41,6 +41,9 @@ ngram 5=3
-1.687872 watching -0.30103
-1.687872 what -0.30103
-1.687872 would -0.30103
-3.141592 foo
-2.718281 bar 3.0
-6.535897 baz -0.0
\2-grams:
-0.6925742 , .
@ -86,6 +89,10 @@ ngram 5=3
-0.2922095 watching considering
-0.2922095 what i
-0.2922095 would also
-2 also would -6
-15 <unk> <unk> -2
-4 <unk> however -1
-6 foo bar
\3-grams:
-0.01916512 more . </s>
@ -96,6 +103,9 @@ ngram 5=3
-0.3488368 <s> looking on -0.4771212
-0.1892331 little more loin
-0.04835128 looking on a -0.4771212
-3 also would consider -7
-6 <unk> however <unk> -12
-7 to look good
\4-grams:
-0.009249173 looking on a little -0.4771212
@ -103,10 +113,12 @@ ngram 5=3
-0.005464747 screening a little more
-0.1453306 a little more loin
-0.01552657 <s> looking on a -0.4771212
-4 also would consider higher -8
\5-grams:
-0.003061223 <s> looking on a little
-0.001813953 looking on a little more
-0.0432557 on a little more loin
-5 also would consider higher looking
\end\

View File

@ -82,7 +82,8 @@ std::size_t BitPackedMiddle::Size(uint64_t entries, uint64_t max_vocab, uint64_t
return BaseSize(entries, max_vocab, 32 + util::RequiredBits(max_ptr));
}
void BitPackedMiddle::Init(void *base, uint64_t max_vocab, uint64_t max_next) {
void BitPackedMiddle::Init(void *base, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source) {
next_source_ = &next_source;
backoff_bits_ = 32;
next_bits_ = util::RequiredBits(max_next);
if (next_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions.");
@ -91,9 +92,8 @@ void BitPackedMiddle::Init(void *base, uint64_t max_vocab, uint64_t max_next) {
BaseInit(base, max_vocab, backoff_bits_ + next_bits_);
}
void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff, uint64_t next) {
void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff) {
assert(word <= word_mask_);
assert(next <= next_mask_);
uint64_t at_pointer = insert_index_ * total_bits_;
util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, word_bits_, word);
@ -102,6 +102,8 @@ void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff, uint64_t
at_pointer += prob_bits_;
util::WriteFloat32(base_ + (at_pointer >> 3), at_pointer & 7, backoff);
at_pointer += backoff_bits_;
uint64_t next = next_source_->InsertIndex();
assert(next <= next_mask_);
util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next);
++insert_index_;
@ -109,7 +111,9 @@ void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff, uint64_t
bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const {
uint64_t at_pointer;
if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) return false;
if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) {
return false;
}
at_pointer *= total_bits_;
at_pointer += word_bits_;
prob = util::ReadNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7);
@ -144,7 +148,6 @@ void BitPackedMiddle::FinishedLoading(uint64_t next_end) {
util::WriteInt57(base_ + (last_next_write >> 3), last_next_write & 7, next_bits_, next_end);
}
void BitPackedLongest::Insert(WordIndex index, float prob) {
assert(index <= word_mask_);
uint64_t at_pointer = insert_index_ * total_bits_;

View File

@ -89,9 +89,10 @@ class BitPackedMiddle : public BitPacked {
static std::size_t Size(uint64_t entries, uint64_t max_vocab, uint64_t max_next);
void Init(void *base, uint64_t max_vocab, uint64_t max_next);
// next_source need not be initialized.
void Init(void *base, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source);
void Insert(WordIndex word, float prob, float backoff, uint64_t next);
void Insert(WordIndex word, float prob, float backoff);
bool Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const;
@ -102,6 +103,8 @@ class BitPackedMiddle : public BitPacked {
private:
uint8_t backoff_bits_, next_bits_;
uint64_t next_mask_;
const BitPacked *next_source_;
};

View File

@ -11,8 +11,6 @@ void Vocabulary::SetSpecial(WordIndex begin_sentence, WordIndex end_sentence, Wo
begin_sentence_ = begin_sentence;
end_sentence_ = end_sentence;
not_found_ = not_found;
if (begin_sentence_ == not_found_) throw SpecialWordMissingException("<s>");
if (end_sentence_ == not_found_) throw SpecialWordMissingException("</s>");
}
Model::~Model() {}

View File

@ -68,15 +68,19 @@ void WriteOrThrow(int fd, const void *data_void, std::size_t size) {
} // namespace
WriteWordsWrapper::WriteWordsWrapper(EnumerateVocab *inner, int fd) : inner_(inner), fd_(fd) {}
WriteWordsWrapper::WriteWordsWrapper(EnumerateVocab *inner) : inner_(inner) {}
WriteWordsWrapper::~WriteWordsWrapper() {}
void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) {
if (inner_) inner_->Add(index, str);
WriteOrThrow(fd_, str.data(), str.size());
char null_byte = 0;
// Inefficient because it's unbuffered. Sue me.
WriteOrThrow(fd_, &null_byte, 1);
buffer_.append(str.data(), str.size());
buffer_.push_back(0);
}
// Append the buffered vocabulary bytes to the end of fd.  Seeks to the end
// first so earlier sections of the binary file are preserved; throws
// util::ErrnoException if the seek fails.
void WriteWordsWrapper::Write(int fd) {
if ((off_t)-1 == lseek(fd, 0, SEEK_END))
UTIL_THROW(util::ErrnoException, "Failed to seek in binary to vocab words");
WriteOrThrow(fd, buffer_.data(), buffer_.size());
}
SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL) {}
@ -183,5 +187,29 @@ void ProbingVocabulary::LoadedBinary(int fd, EnumerateVocab *to) {
SetSpecial(Index("<s>"), Index("</s>"), 0);
}
// React to an ARPA file that lacks <unk>, according to config.unknown_missing:
// stay silent, print a warning to config.messages, or throw.
void MissingUnknown(const Config &config) throw(SpecialWordMissingException) {
  if (config.unknown_missing == Config::SILENT) return;
  if (config.unknown_missing == Config::COMPLAIN) {
    if (config.messages) *config.messages << "The ARPA file is missing <unk>. Substituting log10 probability " << config.unknown_missing_logprob << "." << std::endl;
    return;
  }
  if (config.unknown_missing == Config::THROW_UP) {
    UTIL_THROW(SpecialWordMissingException, "The ARPA file is missing <unk> and the model is configured to throw an exception.");
  }
}
// React to a missing sentence marker (str is "<s>" or "</s>"), according to
// config.sentence_marker_missing: stay silent, warn, or throw.
void MissingSentenceMarker(const Config &config, const char *str) throw(SpecialWordMissingException) {
  if (config.sentence_marker_missing == Config::SILENT) return;
  if (config.sentence_marker_missing == Config::COMPLAIN) {
    if (config.messages) *config.messages << "Missing special word " << str << "; will treat it as <unk>.";
    return;
  }
  if (config.sentence_marker_missing == Config::THROW_UP) {
    UTIL_THROW(SpecialWordMissingException, "The ARPA file is missing " << str << " and the model is configured to reject these models. Run build_binary -s to disable this check.");
  }
}
} // namespace ngram
} // namespace lm

View File

@ -2,6 +2,7 @@
#define LM_VOCAB__
#include "lm/enumerate_vocab.hh"
#include "lm/lm_exception.hh"
#include "lm/virtual_interface.hh"
#include "util/key_value_packing.hh"
#include "util/probing_hash_table.hh"
@ -27,15 +28,18 @@ inline uint64_t HashForVocab(const StringPiece &str) {
class WriteWordsWrapper : public EnumerateVocab {
public:
WriteWordsWrapper(EnumerateVocab *inner, int fd);
WriteWordsWrapper(EnumerateVocab *inner);
~WriteWordsWrapper();
void Add(WordIndex index, const StringPiece &str);
void Write(int fd);
private:
EnumerateVocab *inner_;
int fd_;
std::string buffer_;
};
// Vocabulary based on sorted uniform find storing only uint64_t values and using their offsets as indices.
@ -62,7 +66,6 @@ class SortedVocabulary : public base::Vocabulary {
}
}
// Ignores second argument for consistency with probing hash which has a float here.
static size_t Size(std::size_t entries, const Config &config);
// Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
@ -132,6 +135,15 @@ class ProbingVocabulary : public base::Vocabulary {
EnumerateVocab *enumerate_;
};
void MissingUnknown(const Config &config) throw(SpecialWordMissingException);
void MissingSentenceMarker(const Config &config, const char *str) throw(SpecialWordMissingException);
// Validate the special words after a vocabulary is loaded: complains or throws
// (per config) if <unk> was never seen, or if <s> / </s> resolved to the
// not-found index.
template <class Vocab> void CheckSpecials(const Config &config, const Vocab &vocab) throw(SpecialWordMissingException) {
if (!vocab.SawUnk()) MissingUnknown(config);
if (vocab.BeginSentence() == vocab.NotFound()) MissingSentenceMarker(config, "<s>");
if (vocab.EndSentence() == vocab.NotFound()) MissingSentenceMarker(config, "</s>");
}
} // namespace ngram
} // namespace lm

8
kenlm/test.sh Executable file
View File

@ -0,0 +1,8 @@
#!/bin/bash
#Run tests. Requires Boost.
set -e
# Build the library objects first; the test links reuse {lm,util}/*.o below.
./compile.sh
for i in util/{bit_packing,file_piece,joint_sort,key_value_packing,probing_hash_table,sorted_uniform}_test lm/model_test; do
# Link each test against the prebuilt objects and Boost's test monitor.
g++ -I. -O3 $CXXFLAGS $i.cc {lm,util}/*.o -lboost_test_exec_monitor -lz -o $i
# Run from the test's own directory (presumably so relative data paths
# resolve); report but do not abort on an individual test failure.
pushd $(dirname $i) >/dev/null && ./$(basename $i) || echo "$i failed"; popd >/dev/null
done

View File

@ -22,7 +22,7 @@ uint8_t RequiredBits(uint64_t max_value) {
}
void BitPackingSanity() {
const detail::FloatEnc neg1 = { -1.0 }, pos1 = { 1.0 };
const FloatEnc neg1 = { -1.0 }, pos1 = { 1.0 };
if ((neg1.i ^ pos1.i) != 0x80000000) UTIL_THROW(Exception, "Sign bit is not 0x80000000");
char mem[57+8];
memset(mem, 0, sizeof(mem));

View File

@ -28,16 +28,19 @@ namespace util {
* but it may be called multiple times when that's inconvenient.
*/
inline uint8_t BitPackShift(uint8_t bit, uint8_t length) {
// Fun fact: __BYTE_ORDER is wrong on Solaris Sparc, but the version without __ is correct.
#if BYTE_ORDER == LITTLE_ENDIAN
inline uint8_t BitPackShift(uint8_t bit, uint8_t /*length*/) {
return bit;
}
#elif BYTE_ORDER == BIG_ENDIAN
inline uint8_t BitPackShift(uint8_t bit, uint8_t length) {
return 64 - length - bit;
}
#else
#error "Bit packing code isn't written for your byte order."
#endif
}
/* Pack integers up to 57 bits using their least significant digits.
* The length is specified using mask:
@ -53,30 +56,32 @@ inline void WriteInt57(void *base, uint8_t bit, uint8_t length, uint64_t value)
*reinterpret_cast<uint64_t*>(base) |= (value << BitPackShift(bit, length));
}
namespace detail { typedef union { float f; uint32_t i; } FloatEnc; }
typedef union { float f; uint32_t i; } FloatEnc;
inline float ReadFloat32(const void *base, uint8_t bit) {
detail::FloatEnc encoded;
FloatEnc encoded;
encoded.i = *reinterpret_cast<const uint64_t*>(base) >> BitPackShift(bit, 32);
return encoded.f;
}
inline void WriteFloat32(void *base, uint8_t bit, float value) {
detail::FloatEnc encoded;
FloatEnc encoded;
encoded.f = value;
WriteInt57(base, bit, 32, encoded.i);
}
const uint32_t kSignBit = 0x80000000;
inline float ReadNonPositiveFloat31(const void *base, uint8_t bit) {
detail::FloatEnc encoded;
FloatEnc encoded;
encoded.i = *reinterpret_cast<const uint64_t*>(base) >> BitPackShift(bit, 31);
// Sign bit set means negative.
encoded.i |= 0x80000000;
encoded.i |= kSignBit;
return encoded.f;
}
inline void WriteNonPositiveFloat31(void *base, uint8_t bit, float value) {
assert(value <= 0.0);
detail::FloatEnc encoded;
FloatEnc encoded;
encoded.f = value;
encoded.i &= ~0x80000000;
encoded.i &= ~kSignBit;
WriteInt57(base, bit, 31, encoded.i);
}

View File

@ -36,6 +36,7 @@ void ErsatzProgress::Milestone() {
if (stone == kWidth) {
(*out_) << std::endl;
next_ = std::numeric_limits<std::size_t>::max();
out_ = NULL;
} else {
next_ = std::max(next_, (stone * complete_) / kWidth);
}

View File

@ -8,6 +8,20 @@ namespace util {
Exception::Exception() throw() {}
Exception::~Exception() throw() {}
// Copy constructor: std::stringstream is not copyable, so transfer the
// accumulated message text through str() into the fresh (empty) stream.
Exception::Exception(const Exception &from) : std::exception() {
stream_ << from.stream_.str();
}
// Assignment replaces this exception's message with from's.  The previous
// version streamed from's text in with operator<<, which *appended* to any
// text already in stream_, leaving stale message content behind and diverging
// from the copy constructor's semantics.  std::stringstream is not assignable
// here, so the text is replaced via str().
Exception &Exception::operator=(const Exception &from) {
  stream_.str(from.stream_.str());
  return *this;
}
// Materialize the stream into the mutable text_ member so the returned pointer
// remains valid for the lifetime of the exception (stream_.str() alone would
// be a temporary).  Not threadsafe.
const char *Exception::what() const throw() {
text_ = stream_.str();
return text_.c_str();
}
namespace {
// The XOPEN version.
const char *HandleStrerror(int ret, const char *buf) {
@ -16,7 +30,7 @@ const char *HandleStrerror(int ret, const char *buf) {
}
// The GNU version.
const char *HandleStrerror(const char *ret, const char *buf) {
const char *HandleStrerror(const char *ret, const char * /*buf*/) {
return ret;
}
} // namespace

View File

@ -9,24 +9,29 @@
namespace util {
template <class Except, class Data> typename Except::template ExceptionTag<Except&>::Identity operator<<(Except &e, const Data &data);
class Exception : public std::exception {
public:
Exception() throw();
virtual ~Exception() throw();
const char *what() const throw() { return what_.c_str(); }
Exception(const Exception &from);
Exception &operator=(const Exception &from);
// Not threadsafe, but probably doesn't matter. FWIW, Boost's exception guidance implies that what() isn't threadsafe.
const char *what() const throw();
private:
template <class Except, class Data> friend typename Except::template ExceptionTag<Except&>::Identity operator<<(Except &e, const Data &data);
// This helps restrict operator<< defined below.
template <class T> struct ExceptionTag {
typedef T Identity;
};
std::string &Str() {
return what_;
}
protected:
std::string what_;
std::stringstream stream_;
mutable std::string text_;
};
/* This implements the normal operator<< for Exception and all its children.
@ -34,22 +39,7 @@ class Exception : public std::exception {
* boost::enable_if.
*/
template <class Except, class Data> typename Except::template ExceptionTag<Except&>::Identity operator<<(Except &e, const Data &data) {
// Argh I had a stringstream in the exception, but the only way to get the string is by calling str(). But that's a temporary string, so virtual const char *what() const can't actually return it.
std::stringstream stream;
stream << data;
e.Str() += stream.str();
return e;
}
template <class Except> typename Except::template ExceptionTag<Except&>::Identity operator<<(Except &e, const char *data) {
e.Str() += data;
return e;
}
template <class Except> typename Except::template ExceptionTag<Except&>::Identity operator<<(Except &e, const std::string &data) {
e.Str() += data;
return e;
}
template <class Except> typename Except::template ExceptionTag<Except&>::Identity operator<<(Except &e, const StringPiece &str) {
e.Str().append(str.data(), str.length());
e.stream_ << data;
return e;
}

View File

@ -37,6 +37,9 @@ GZException::GZException(void *file) {
#endif // HAVE_ZLIB
}
// Sigh this is the only way I could come up with to do a _const_ bool. It has ' ', '\f', '\n', '\r', '\t', and '\v' (same as isspace on C locale).
const bool kSpaces[256] = {0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
int OpenReadOrThrow(const char *name) {
int ret = open(name, O_RDONLY);
if (ret == -1) UTIL_THROW(ErrnoException, "in open (" << name << ") for reading");
@ -76,22 +79,22 @@ FilePiece::~FilePiece() {
}
StringPiece FilePiece::ReadLine(char delim) throw (GZException, EndOfFileException) {
const char *start = position_;
do {
for (const char *i = start; i < position_end_; ++i) {
size_t skip = 0;
while (true) {
for (const char *i = position_ + skip; i < position_end_; ++i) {
if (*i == delim) {
StringPiece ret(position_, i - position_);
position_ = i + 1;
return ret;
}
}
size_t skip = position_end_ - position_;
if (at_end_) {
if (position_ == position_end_) Shift();
return Consume(position_end_);
}
skip = position_end_ - position_;
Shift();
start = position_ + skip;
} while (!at_end_);
StringPiece ret(position_, position_end_ - position_);
position_ = position_end_;
return ret;
}
}
float FilePiece::ReadFloat() throw(GZException, EndOfFileException, ParseNumberException) {
@ -107,13 +110,6 @@ unsigned long int FilePiece::ReadULong() throw(GZException, EndOfFileException,
return ReadNumber<unsigned long int>();
}
void FilePiece::SkipSpaces() throw (GZException, EndOfFileException) {
for (; ; ++position_) {
if (position_ == position_end_) Shift();
if (!isspace(*position_)) return;
}
}
void FilePiece::Initialize(const char *name, std::ostream *show_progress, off_t min_buffer) throw (GZException) {
#ifdef HAVE_ZLIB
gz_file_ = NULL;
@ -190,18 +186,19 @@ template <class T> T FilePiece::ReadNumber() throw(GZException, EndOfFileExcepti
return ret;
}
const char *FilePiece::FindDelimiterOrEOF() throw (GZException, EndOfFileException) {
for (const char *i = position_; i <= last_space_; ++i) {
if (isspace(*i)) return i;
}
while (!at_end_) {
size_t skip = position_end_ - position_;
Shift();
for (const char *i = position_ + skip; i <= last_space_; ++i) {
if (isspace(*i)) return i;
const char *FilePiece::FindDelimiterOrEOF(const bool *delim) throw (GZException, EndOfFileException) {
size_t skip = 0;
while (true) {
for (const char *i = position_ + skip; i < position_end_; ++i) {
if (delim[static_cast<unsigned char>(*i)]) return i;
}
if (at_end_) {
if (position_ == position_end_) Shift();
return position_end_;
}
skip = position_end_ - position_;
Shift();
}
return position_end_;
}
void FilePiece::Shift() throw(GZException, EndOfFileException) {

View File

@ -3,6 +3,7 @@
#include "util/ersatz_progress.hh"
#include "util/exception.hh"
#include "util/have.hh"
#include "util/mmap.hh"
#include "util/scoped.hh"
#include "util/string_piece.hh"
@ -11,8 +12,6 @@
#include <cstddef>
#define HAVE_ZLIB
namespace util {
class EndOfFileException : public Exception {
@ -36,10 +35,13 @@ class GZException : public Exception {
int OpenReadOrThrow(const char *name);
extern const bool kSpaces[256];
// Return value for SizeFile when it can't size properly.
const off_t kBadSize = -1;
off_t SizeFile(int fd);
// Memory backing the returned StringPiece may vanish on the next call.
class FilePiece {
public:
// 32 MB default.
@ -57,12 +59,12 @@ class FilePiece {
return *(position_++);
}
// Memory backing the returned StringPiece may vanish on the next call.
// Leaves the delimiter, if any, to be returned by get().
StringPiece ReadDelimited() throw(GZException, EndOfFileException) {
SkipSpaces();
return Consume(FindDelimiterOrEOF());
// Leaves the delimiter, if any, to be returned by get(). Delimiters defined by isspace().
StringPiece ReadDelimited(const bool *delim = kSpaces) throw(GZException, EndOfFileException) {
SkipSpaces(delim);
return Consume(FindDelimiterOrEOF(delim));
}
// Unlike ReadDelimited, this includes leading spaces and consumes the delimiter.
// It is similar to getline in that way.
StringPiece ReadLine(char delim = '\n') throw(GZException, EndOfFileException);
@ -72,7 +74,13 @@ class FilePiece {
long int ReadLong() throw(GZException, EndOfFileException, ParseNumberException);
unsigned long int ReadULong() throw(GZException, EndOfFileException, ParseNumberException);
void SkipSpaces() throw (GZException, EndOfFileException);
// Skip spaces defined by isspace.
void SkipSpaces(const bool *delim = kSpaces) throw (GZException, EndOfFileException) {
for (; ; ++position_) {
if (position_ == position_end_) Shift();
if (!delim[static_cast<unsigned char>(*position_)]) return;
}
}
off_t Offset() const {
return position_ - data_.begin() + mapped_offset_;
@ -91,7 +99,7 @@ class FilePiece {
return ret;
}
const char *FindDelimiterOrEOF() throw(EndOfFileException, GZException);
const char *FindDelimiterOrEOF(const bool *delim = kSpaces) throw (GZException, EndOfFileException);
void Shift() throw (EndOfFileException, GZException);
// Backends to Shift().

9
kenlm/util/have.hh Normal file
View File

@ -0,0 +1,9 @@
/* This ties kenlm's config into Moses's build system. If you are using kenlm
 * outside Moses, see http://kheafield.com/code/kenlm/developers/ .
 */
#ifndef UTIL_HAVE__
#define UTIL_HAVE__
// zlib is assumed available in the Moses build; this enables the gzip code
// paths guarded by HAVE_ZLIB elsewhere in util/.
#define HAVE_ZLIB
#endif // UTIL_HAVE__

View File

@ -18,6 +18,8 @@ template <class Key, class Value> struct Entry {
const Key &GetKey() const { return key; }
const Value &GetValue() const { return value; }
Value &MutableValue() { return value; }
void Set(const Key &key_in, const Value &value_in) {
SetKey(key_in);
SetValue(value_in);
@ -77,6 +79,8 @@ template <class KeyT, class ValueT> class ByteAlignedPacking {
const Key &GetKey() const { return key; }
const Value &GetValue() const { return value; }
Value &MutableValue() { return value; }
void Set(const Key &key_in, const Value &value_in) {
SetKey(key_in);
SetValue(value_in);

View File

@ -77,6 +77,8 @@ void ReadAll(int fd, void *to_void, std::size_t amount) {
}
}
} // namespace
const int kFileFlags =
#ifdef MAP_FILE
MAP_FILE | MAP_SHARED
@ -85,8 +87,6 @@ const int kFileFlags =
#endif
;
} // namespace
void MapRead(LoadMethod method, int fd, off_t offset, std::size_t size, scoped_memory &out) {
switch (method) {
case LAZY:

View File

@ -91,9 +91,11 @@ typedef enum {
READ
} LoadMethod;
extern const int kFileFlags;
// Wrapper around mmap to check it worked and hide some platform macros.
void *MapOrThrow(std::size_t size, bool for_write, int flags, bool prefault, int fd, off_t offset = 0);
void MapRead(LoadMethod method, int fd, off_t offset, std::size_t size, scoped_memory &out);
void *MapAnonymous(std::size_t size);

View File

@ -1,6 +1,8 @@
#ifndef UTIL_PROBING_HASH_TABLE__
#define UTIL_PROBING_HASH_TABLE__
#include "util/exception.hh"
#include <algorithm>
#include <cstddef>
#include <functional>
@ -9,6 +11,13 @@
namespace util {
/* Thrown when table grows too large */
class ProbingSizeException : public Exception {
public:
ProbingSizeException() throw() {}
~ProbingSizeException() throw() {}
};
/* Non-standard hash table
* Buckets must be set at the beginning and must be greater than maximum number
* of elements, else an infinite loop happens.
@ -33,9 +42,9 @@ template <class PackingT, class HashT, class EqualT = std::equal_to<typename Pac
}
// Must be assigned to later.
ProbingHashTable()
ProbingHashTable() : entries_(0)
#ifdef DEBUG
: initialized_(false), entries_(0)
, initialized_(false)
#endif
{}
@ -45,17 +54,18 @@ template <class PackingT, class HashT, class EqualT = std::equal_to<typename Pac
end_(begin_ + (allocated / Packing::kBytes)),
invalid_(invalid),
hash_(hash_func),
equal_(equal_func)
equal_(equal_func),
entries_(0)
#ifdef DEBUG
, initialized_(true),
entries_(0)
#endif
{}
template <class T> void Insert(const T &t) {
if (++entries_ >= buckets_)
UTIL_THROW(ProbingSizeException, "Hash table with " << buckets_ << " buckets is full.");
#ifdef DEBUG
assert(initialized_);
assert(++entries_ < buckets_);
#endif
for (MutableIterator i(begin_ + (hash_(t.GetKey()) % buckets_));;) {
if (equal_(i->GetKey(), invalid_)) { *i = t; return; }
@ -67,6 +77,16 @@ template <class PackingT, class HashT, class EqualT = std::equal_to<typename Pac
void LoadedBinary() {}
// Don't change anything related to GetKey,
template <class Key> bool UnsafeMutableFind(const Key key, MutableIterator &out) {
for (MutableIterator i(begin_ + (hash_(key) % buckets_));;) {
Key got(i->GetKey());
if (equal_(got, key)) { out = i; return true; }
if (equal_(got, invalid_)) return false;
if (++i == end_) i = begin_;
}
}
template <class Key> bool Find(const Key key, ConstIterator &out) const {
#ifdef DEBUG
assert(initialized_);
@ -74,8 +94,8 @@ template <class PackingT, class HashT, class EqualT = std::equal_to<typename Pac
for (ConstIterator i(begin_ + (hash_(key) % buckets_));;) {
Key got(i->GetKey());
if (equal_(got, key)) { out = i; return true; }
if (equal_(got, invalid_)) { return false; }
if (++i == end_) { i = begin_; }
if (equal_(got, invalid_)) return false;
if (++i == end_) i = begin_;
}
}
@ -86,9 +106,9 @@ template <class PackingT, class HashT, class EqualT = std::equal_to<typename Pac
Key invalid_;
Hash hash_;
Equal equal_;
std::size_t entries_;
#ifdef DEBUG
bool initialized_;
std::size_t entries_;
#endif
};

View File

@ -20,7 +20,9 @@ template <class T, class R, R (*Free)(T*)> class scoped_thing {
}
T &operator*() { return *c_; }
const T&operator*() const { return *c_; }
T &operator->() { return *c_; }
const T&operator->() const { return *c_; }
T *get() { return c_; }
const T *get() const { return c_; }
@ -80,6 +82,34 @@ class scoped_FILE {
std::FILE *file_;
};
// Hat tip to boost.
template <class T> class scoped_array {
public:
explicit scoped_array(T *content = NULL) : c_(content) {}
~scoped_array() { delete [] c_; }
T *get() { return c_; }
const T* get() const { return c_; }
T &operator*() { return *c_; }
const T&operator*() const { return *c_; }
T &operator->() { return *c_; }
const T&operator->() const { return *c_; }
T &operator[](std::size_t idx) { return c_[idx]; }
const T &operator[](std::size_t idx) const { return c_[idx]; }
void reset(T *to = NULL) {
scoped_array<T> other(c_);
c_ = to;
}
private:
T *c_;
};
} // namespace util
#endif // UTIL_SCOPED__

View File

@ -62,6 +62,7 @@ template <class PackingT> class SortedUniformMap {
public:
typedef PackingT Packing;
typedef typename Packing::ConstIterator ConstIterator;
typedef typename Packing::MutableIterator MutableIterator;
public:
// Offer consistent API with probing hash.
@ -113,6 +114,15 @@ template <class PackingT> class SortedUniformMap {
*size_ptr_ = (end_ - begin_);
}
// Don't use this to change the key.
template <class Key> bool UnsafeMutableFind(const Key key, MutableIterator &out) {
#ifdef DEBUG
assert(initialized_);
assert(loaded_);
#endif
return SortedUniformFind<MutableIterator, Key>(begin_, end_, key, out);
}
// Do not call before FinishedInserting.
template <class Key> bool Find(const Key key, ConstIterator &out) const {
#ifdef DEBUG

View File

@ -1,51 +0,0 @@
// Copyright 2008, Google Inc.
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Copied from strings/stringpiece.cc with modifications
#include "util/string_piece.hh"
#ifdef HAVE_BOOST
#include <boost/functional/hash/hash.hpp>
#endif
#include <algorithm>
#ifdef HAVE_ICU
U_NAMESPACE_BEGIN
#endif
#ifdef HAVE_BOOST
size_t hash_value(const StringPiece &str) {
return boost::hash_range(str.data(), str.data() + str.length());
}
#endif
#ifdef HAVE_ICU
U_NAMESPACE_END
#endif

View File

@ -48,14 +48,14 @@
#ifndef BASE_STRING_PIECE_H__
#define BASE_STRING_PIECE_H__
//Uncomment this line if you use ICU in your code.
//#define HAVE_ICU
//Uncomment this line if you want boost hashing for your StringPieces.
//#define HAVE_BOOST
#include "util/have.hh"
#ifdef HAVE_BOOST
#include <boost/functional/hash/hash.hpp>
#endif // HAVE_BOOST
#include <cstring>
#include <iosfwd>
#include <ostream>
#ifdef HAVE_ICU
@ -64,6 +64,7 @@ U_NAMESPACE_BEGIN
#else
#include <algorithm>
#include <cstddef>
#include <string>
#include <string.h>
@ -234,7 +235,9 @@ inline std::ostream& operator<<(std::ostream& o, const StringPiece& piece) {
}
#ifdef HAVE_BOOST
size_t hash_value(const StringPiece &str);
inline size_t hash_value(const StringPiece &str) {
return boost::hash_range(str.data(), str.data() + str.length());
}
/* Support for lookup of StringPiece in boost::unordered_map<std::string> */
struct StringPieceCompatibleHash : public std::unary_function<const StringPiece &, size_t> {

View File

@ -6,170 +6,174 @@ const int BleuScorer::LENGTH = 4;
/**
* count the ngrams of each type, up to the given length in the input line.
**/
size_t BleuScorer::countNgrams(const string& line, counts_t& counts, unsigned int n) {
vector<int> encoded_tokens;
//cerr << line << endl;
encode(line,encoded_tokens);
//copy(encoded_tokens.begin(), encoded_tokens.end(), ostream_iterator<int>(cerr," "));
//cerr << endl;
for (size_t k = 1; k <= n; ++k) {
//ngram order longer than sentence - no point
if (k > encoded_tokens.size()) {
continue;
}
for (size_t i = 0; i < encoded_tokens.size()-k+1; ++i) {
vector<int> ngram;
for (size_t j = i; j < i+k && j < encoded_tokens.size(); ++j) {
ngram.push_back(encoded_tokens[j]);
}
int count = 1;
counts_it oldcount = counts.find(ngram);
if (oldcount != counts.end()) {
count = (oldcount->second) + 1;
}
//cerr << count << endl;
counts[ngram] = count;
//cerr << endl;
}
}
//cerr << "counted ngrams" << endl;
//dump_counts(counts);
return encoded_tokens.size();
size_t BleuScorer::countNgrams(const string& line, counts_t& counts, unsigned int n)
{
vector<int> encoded_tokens;
//cerr << line << endl;
encode(line,encoded_tokens);
//copy(encoded_tokens.begin(), encoded_tokens.end(), ostream_iterator<int>(cerr," "));
//cerr << endl;
for (size_t k = 1; k <= n; ++k) {
//ngram order longer than sentence - no point
if (k > encoded_tokens.size()) {
continue;
}
for (size_t i = 0; i < encoded_tokens.size()-k+1; ++i) {
vector<int> ngram;
for (size_t j = i; j < i+k && j < encoded_tokens.size(); ++j) {
ngram.push_back(encoded_tokens[j]);
}
int count = 1;
counts_it oldcount = counts.find(ngram);
if (oldcount != counts.end()) {
count = (oldcount->second) + 1;
}
//cerr << count << endl;
counts[ngram] = count;
//cerr << endl;
}
}
//cerr << "counted ngrams" << endl;
//dump_counts(counts);
return encoded_tokens.size();
}
void BleuScorer::setReferenceFiles(const vector<string>& referenceFiles) {
//make sure reference data is clear
_refcounts.clear();
_reflengths.clear();
_encodings.clear();
void BleuScorer::setReferenceFiles(const vector<string>& referenceFiles)
{
//make sure reference data is clear
_refcounts.clear();
_reflengths.clear();
_encodings.clear();
//load reference data
for (size_t i = 0; i < referenceFiles.size(); ++i) {
TRACE_ERR("Loading reference from " << referenceFiles[i] << endl);
ifstream refin(referenceFiles[i].c_str());
if (!refin) {
throw runtime_error("Unable to open: " + referenceFiles[i]);
}
string line;
size_t sid = 0; //sentence counter
while (getline(refin,line)) {
//cerr << line << endl;
if (i == 0) {
counts_t* counts = new counts_t(); //these get leaked
_refcounts.push_back(counts);
vector<size_t> lengths;
_reflengths.push_back(lengths);
}
if (_refcounts.size() <= sid) {
throw runtime_error("File " + referenceFiles[i] + " has too many sentences");
}
counts_t counts;
size_t length = countNgrams(line,counts,LENGTH);
//for any counts larger than those already there, merge them in
for (counts_it ci = counts.begin(); ci != counts.end(); ++ci) {
counts_it oldcount_it = _refcounts[sid]->find(ci->first);
int oldcount = 0;
if (oldcount_it != _refcounts[sid]->end()) {
oldcount = oldcount_it->second;
}
int newcount = ci->second;
if (newcount > oldcount) {
_refcounts[sid]->operator[](ci->first) = newcount;
}
}
//add in the length
_reflengths[sid].push_back(length);
if (sid > 0 && sid % 100 == 0) {
TRACE_ERR(".");
}
++sid;
}
TRACE_ERR(endl);
}
//load reference data
for (size_t i = 0; i < referenceFiles.size(); ++i) {
TRACE_ERR("Loading reference from " << referenceFiles[i] << endl);
ifstream refin(referenceFiles[i].c_str());
if (!refin) {
throw runtime_error("Unable to open: " + referenceFiles[i]);
}
string line;
size_t sid = 0; //sentence counter
while (getline(refin,line)) {
//cerr << line << endl;
if (i == 0) {
counts_t* counts = new counts_t(); //these get leaked
_refcounts.push_back(counts);
vector<size_t> lengths;
_reflengths.push_back(lengths);
}
if (_refcounts.size() <= sid) {
throw runtime_error("File " + referenceFiles[i] + " has too many sentences");
}
counts_t counts;
size_t length = countNgrams(line,counts,LENGTH);
//for any counts larger than those already there, merge them in
for (counts_it ci = counts.begin(); ci != counts.end(); ++ci) {
counts_it oldcount_it = _refcounts[sid]->find(ci->first);
int oldcount = 0;
if (oldcount_it != _refcounts[sid]->end()) {
oldcount = oldcount_it->second;
}
int newcount = ci->second;
if (newcount > oldcount) {
_refcounts[sid]->operator[](ci->first) = newcount;
}
}
//add in the length
_reflengths[sid].push_back(length);
if (sid > 0 && sid % 100 == 0) {
TRACE_ERR(".");
}
++sid;
}
TRACE_ERR(endl);
}
}
void BleuScorer::prepareStats(size_t sid, const string& text, ScoreStats& entry) {
void BleuScorer::prepareStats(size_t sid, const string& text, ScoreStats& entry)
{
// cerr << text << endl;
// cerr << sid << endl;
//dump_counts(*_refcounts[sid]);
if (sid >= _refcounts.size()) {
stringstream msg;
msg << "Sentence id (" << sid << ") not found in reference set";
throw runtime_error(msg.str());
}
counts_t testcounts;
//stats for this line
vector<float> stats(LENGTH*2);;
size_t length = countNgrams(text,testcounts,LENGTH);
//dump_counts(testcounts);
if (_refLengthStrategy == BLEU_SHORTEST) {
//cerr << reflengths.size() << " " << sid << endl;
int shortest = *min_element(_reflengths[sid].begin(),_reflengths[sid].end());
stats.push_back(shortest);
} else if (_refLengthStrategy == BLEU_AVERAGE) {
int total = 0;
for (size_t i = 0; i < _reflengths[sid].size(); ++i) {
total += _reflengths[sid][i];
}
float mean = (float)total/_reflengths[sid].size();
stats.push_back(mean);
} else if (_refLengthStrategy == BLEU_CLOSEST) {
int min_diff = INT_MAX;
int min_idx = 0;
for (size_t i = 0; i < _reflengths[sid].size(); ++i) {
int reflength = _reflengths[sid][i];
if (abs(reflength-(int)length) < abs(min_diff)) { //look for the closest reference
min_diff = reflength-length;
min_idx = i;
}else if (abs(reflength-(int)length) == abs(min_diff)) { // if two references has the same closest length, take the shortest
if (reflength < (int)_reflengths[sid][min_idx]){
min_idx = i;
}
}
}
stats.push_back(_reflengths[sid][min_idx]);
} else {
throw runtime_error("Unsupported reflength strategy");
}
//cerr << "computed length" << endl;
//precision on each ngram type
for (counts_it testcounts_it = testcounts.begin();
testcounts_it != testcounts.end(); ++testcounts_it) {
counts_it refcounts_it = _refcounts[sid]->find(testcounts_it->first);
int correct = 0;
int guess = testcounts_it->second;
if (refcounts_it != _refcounts[sid]->end()) {
correct = min(refcounts_it->second,guess);
}
size_t len = testcounts_it->first.size();
stats[len*2-2] += correct;
stats[len*2-1] += guess;
}
stringstream sout;
copy(stats.begin(),stats.end(),ostream_iterator<float>(sout," "));
//TRACE_ERR(sout.str() << endl);
string stats_str = sout.str();
entry.set(stats_str);
//dump_counts(*_refcounts[sid]);
if (sid >= _refcounts.size()) {
stringstream msg;
msg << "Sentence id (" << sid << ") not found in reference set";
throw runtime_error(msg.str());
}
counts_t testcounts;
//stats for this line
vector<float> stats(LENGTH*2);;
size_t length = countNgrams(text,testcounts,LENGTH);
//dump_counts(testcounts);
if (_refLengthStrategy == BLEU_SHORTEST) {
//cerr << reflengths.size() << " " << sid << endl;
int shortest = *min_element(_reflengths[sid].begin(),_reflengths[sid].end());
stats.push_back(shortest);
} else if (_refLengthStrategy == BLEU_AVERAGE) {
int total = 0;
for (size_t i = 0; i < _reflengths[sid].size(); ++i) {
total += _reflengths[sid][i];
}
float mean = (float)total/_reflengths[sid].size();
stats.push_back(mean);
} else if (_refLengthStrategy == BLEU_CLOSEST) {
int min_diff = INT_MAX;
int min_idx = 0;
for (size_t i = 0; i < _reflengths[sid].size(); ++i) {
int reflength = _reflengths[sid][i];
if (abs(reflength-(int)length) < abs(min_diff)) { //look for the closest reference
min_diff = reflength-length;
min_idx = i;
} else if (abs(reflength-(int)length) == abs(min_diff)) { // if two references has the same closest length, take the shortest
if (reflength < (int)_reflengths[sid][min_idx]) {
min_idx = i;
}
}
}
stats.push_back(_reflengths[sid][min_idx]);
} else {
throw runtime_error("Unsupported reflength strategy");
}
//cerr << "computed length" << endl;
//precision on each ngram type
for (counts_it testcounts_it = testcounts.begin();
testcounts_it != testcounts.end(); ++testcounts_it) {
counts_it refcounts_it = _refcounts[sid]->find(testcounts_it->first);
int correct = 0;
int guess = testcounts_it->second;
if (refcounts_it != _refcounts[sid]->end()) {
correct = min(refcounts_it->second,guess);
}
size_t len = testcounts_it->first.size();
stats[len*2-2] += correct;
stats[len*2-1] += guess;
}
stringstream sout;
copy(stats.begin(),stats.end(),ostream_iterator<float>(sout," "));
//TRACE_ERR(sout.str() << endl);
string stats_str = sout.str();
entry.set(stats_str);
}
float BleuScorer::calculateScore(const vector<int>& comps) {
//cerr << "BLEU: ";
//copy(comps.begin(),comps.end(), ostream_iterator<int>(cerr," "));
float logbleu = 0.0;
for (int i = 0; i < LENGTH; ++i) {
if (comps[2*i] == 0) {
return 0.0;
}
logbleu += log(comps[2*i]) - log(comps[2*i+1]);
}
logbleu /= LENGTH;
float brevity = 1.0 - (float)comps[LENGTH*2]/comps[1];//reflength divided by test length
if (brevity < 0.0) {
logbleu += brevity;
}
//cerr << " " << exp(logbleu) << endl;
return exp(logbleu);
float BleuScorer::calculateScore(const vector<int>& comps)
{
//cerr << "BLEU: ";
//copy(comps.begin(),comps.end(), ostream_iterator<int>(cerr," "));
float logbleu = 0.0;
for (int i = 0; i < LENGTH; ++i) {
if (comps[2*i] == 0) {
return 0.0;
}
logbleu += log(comps[2*i]) - log(comps[2*i+1]);
}
logbleu /= LENGTH;
float brevity = 1.0 - (float)comps[LENGTH*2]/comps[1];//reflength divided by test length
if (brevity < 0.0) {
logbleu += brevity;
}
//cerr << " " << exp(logbleu) << endl;
return exp(logbleu);
}

View File

@ -23,84 +23,88 @@ enum BleuReferenceLengthStrategy { BLEU_AVERAGE, BLEU_SHORTEST, BLEU_CLOSEST };
/**
* Bleu scoring
**/
class BleuScorer: public StatisticsBasedScorer {
public:
BleuScorer(const string& config = "") : StatisticsBasedScorer("BLEU",config),_refLengthStrategy(BLEU_CLOSEST) {
class BleuScorer: public StatisticsBasedScorer
{
public:
BleuScorer(const string& config = "") : StatisticsBasedScorer("BLEU",config),_refLengthStrategy(BLEU_CLOSEST) {
//configure regularisation
static string KEY_REFLEN = "reflen";
static string REFLEN_AVERAGE = "average";
static string REFLEN_SHORTEST = "shortest";
static string REFLEN_CLOSEST = "closest";
string reflen = getConfig(KEY_REFLEN,REFLEN_CLOSEST);
if (reflen == REFLEN_AVERAGE) {
_refLengthStrategy = BLEU_AVERAGE;
_refLengthStrategy = BLEU_AVERAGE;
} else if (reflen == REFLEN_SHORTEST) {
_refLengthStrategy = BLEU_SHORTEST;
_refLengthStrategy = BLEU_SHORTEST;
} else if (reflen == REFLEN_CLOSEST) {
_refLengthStrategy = BLEU_CLOSEST;
_refLengthStrategy = BLEU_CLOSEST;
} else {
throw runtime_error("Unknown reference length strategy: " + reflen);
throw runtime_error("Unknown reference length strategy: " + reflen);
}
cerr << "Using reference length strategy: " << reflen << endl;
}
virtual void setReferenceFiles(const vector<string>& referenceFiles);
virtual void prepareStats(size_t sid, const string& text, ScoreStats& entry);
static const int LENGTH;
size_t NumberOfScores(){ cerr << "BleuScorer: " << (2 * LENGTH + 1) << endl; return (2 * LENGTH + 1); };
protected:
float calculateScore(const vector<int>& comps);
private:
//no copy
BleuScorer(const BleuScorer&);
~BleuScorer(){};
BleuScorer& operator=(const BleuScorer&);
//Used to construct the ngram map
struct CompareNgrams {
int operator() (const vector<int>& a, const vector<int>& b) {
size_t i;
size_t as = a.size();
size_t bs = b.size();
for (i = 0; i < as && i < bs; ++i) {
if (a[i] < b[i]) {
//cerr << "true" << endl;
return true;
}
if (a[i] > b[i]) {
//cerr << "false" << endl;
return false;
}
}
//entries are equal, shortest wins
return as < bs;;
}
};
}
virtual void setReferenceFiles(const vector<string>& referenceFiles);
virtual void prepareStats(size_t sid, const string& text, ScoreStats& entry);
static const int LENGTH;
typedef map<vector<int>,int,CompareNgrams> counts_t;
typedef map<vector<int>,int,CompareNgrams>::iterator counts_it;
size_t NumberOfScores() {
cerr << "BleuScorer: " << (2 * LENGTH + 1) << endl;
return (2 * LENGTH + 1);
};
typedef vector<counts_t*> refcounts_t;
size_t countNgrams(const string& line, counts_t& counts, unsigned int n);
protected:
float calculateScore(const vector<int>& comps);
void dump_counts(counts_t& counts) {
for (counts_it i = counts.begin(); i != counts.end(); ++i) {
cerr << "(";
copy(i->first.begin(), i->first.end(), ostream_iterator<int>(cerr," "));
cerr << ") " << i->second << ", ";
}
cerr << endl;
}
BleuReferenceLengthStrategy _refLengthStrategy;
// data extracted from reference files
refcounts_t _refcounts;
vector<vector<size_t> > _reflengths;
private:
//no copy
BleuScorer(const BleuScorer&);
~BleuScorer() {};
BleuScorer& operator=(const BleuScorer&);
//Used to construct the ngram map
struct CompareNgrams {
int operator() (const vector<int>& a, const vector<int>& b) {
size_t i;
size_t as = a.size();
size_t bs = b.size();
for (i = 0; i < as && i < bs; ++i) {
if (a[i] < b[i]) {
//cerr << "true" << endl;
return true;
}
if (a[i] > b[i]) {
//cerr << "false" << endl;
return false;
}
}
//entries are equal, shortest wins
return as < bs;;
}
};
typedef map<vector<int>,int,CompareNgrams> counts_t;
typedef map<vector<int>,int,CompareNgrams>::iterator counts_it;
typedef vector<counts_t*> refcounts_t;
size_t countNgrams(const string& line, counts_t& counts, unsigned int n);
void dump_counts(counts_t& counts) {
for (counts_it i = counts.begin(); i != counts.end(); ++i) {
cerr << "(";
copy(i->first.begin(), i->first.end(), ostream_iterator<int>(cerr," "));
cerr << ") " << i->second << ", ";
}
cerr << endl;
}
BleuReferenceLengthStrategy _refLengthStrategy;
// data extracted from reference files
refcounts_t _refcounts;
vector<vector<size_t> > _reflengths;
};

View File

@ -13,94 +13,93 @@
Data::Data(Scorer& ptr):
theScorer(&ptr)
theScorer(&ptr)
{
score_type = (*theScorer).getName();
TRACE_ERR("Data::score_type " << score_type << std::endl);
TRACE_ERR("Data::Scorer type from Scorer: " << theScorer->getName() << endl);
score_type = (*theScorer).getName();
TRACE_ERR("Data::score_type " << score_type << std::endl);
TRACE_ERR("Data::Scorer type from Scorer: " << theScorer->getName() << endl);
featdata=new FeatureData;
scoredata=new ScoreData(*theScorer);
};
void Data::loadnbest(const std::string &file)
{
TRACE_ERR("loading nbest from " << file << std::endl);
TRACE_ERR("loading nbest from " << file << std::endl);
FeatureStats featentry;
ScoreStats scoreentry;
std::string sentence_index;
FeatureStats featentry;
ScoreStats scoreentry;
std::string sentence_index;
inputfilestream inp(file); // matches a stream with a file. Opens the file
inputfilestream inp(file); // matches a stream with a file. Opens the file
if (!inp.good())
throw runtime_error("Unable to open: " + file);
if (!inp.good())
throw runtime_error("Unable to open: " + file);
std::string substring, subsubstring, stringBuf;
std::string theSentence;
std::string::size_type loc;
std::string substring, subsubstring, stringBuf;
std::string theSentence;
std::string::size_type loc;
while (getline(inp,stringBuf,'\n')){
if (stringBuf.empty()) continue;
while (getline(inp,stringBuf,'\n')) {
if (stringBuf.empty()) continue;
// TRACE_ERR("stringBuf: " << stringBuf << std::endl);
// TRACE_ERR("stringBuf: " << stringBuf << std::endl);
getNextPound(stringBuf, substring, "|||"); //first field
sentence_index = substring;
getNextPound(stringBuf, substring, "|||"); //first field
sentence_index = substring;
getNextPound(stringBuf, substring, "|||"); //second field
theSentence = substring;
getNextPound(stringBuf, substring, "|||"); //second field
theSentence = substring;
// adding statistics for error measures
featentry.reset();
scoreentry.clear();
featentry.reset();
scoreentry.clear();
theScorer->prepareStats(sentence_index, theSentence, scoreentry);
theScorer->prepareStats(sentence_index, theSentence, scoreentry);
scoredata->add(scoreentry, sentence_index);
scoredata->add(scoreentry, sentence_index);
getNextPound(stringBuf, substring, "|||"); //third field
getNextPound(stringBuf, substring, "|||"); //third field
if (!existsFeatureNames()){
std::string stringsupport=substring;
// adding feature names
std::string features="";
std::string tmpname="";
if (!existsFeatureNames()) {
std::string stringsupport=substring;
// adding feature names
std::string features="";
std::string tmpname="";
size_t tmpidx=0;
while (!stringsupport.empty()) {
// TRACE_ERR("Decompounding: " << substring << std::endl);
getNextPound(stringsupport, subsubstring);
// string ending with ":" are skipped, because they are the names of the features
if ((loc = subsubstring.find(":")) != subsubstring.length()-1) {
features+=tmpname+"_"+stringify(tmpidx)+" ";
tmpidx++;
} else {
tmpidx=0;
tmpname=subsubstring.substr(0,subsubstring.size() - 1);
}
}
featdata->setFeatureMap(features);
}
size_t tmpidx=0;
while (!stringsupport.empty()){
// TRACE_ERR("Decompounding: " << substring << std::endl);
getNextPound(stringsupport, subsubstring);
// string ending with ":" are skipped, because they are the names of the features
if ((loc = subsubstring.find(":")) != subsubstring.length()-1){
features+=tmpname+"_"+stringify(tmpidx)+" ";
tmpidx++;
}
else{
tmpidx=0;
tmpname=subsubstring.substr(0,subsubstring.size() - 1);
}
}
featdata->setFeatureMap(features);
}
// adding features
while (!substring.empty()){
// TRACE_ERR("Decompounding: " << substring << std::endl);
getNextPound(substring, subsubstring);
while (!substring.empty()) {
// TRACE_ERR("Decompounding: " << substring << std::endl);
getNextPound(substring, subsubstring);
// string ending with ":" are skipped, because they are the names of the features
if ((loc = subsubstring.find(":")) != subsubstring.length()-1){
featentry.add(ATOFST(subsubstring.c_str()));
}
}
featdata->add(featentry,sentence_index);
}
inp.close();
if ((loc = subsubstring.find(":")) != subsubstring.length()-1) {
featentry.add(ATOFST(subsubstring.c_str()));
}
}
featdata->add(featentry,sentence_index);
}
inp.close();
}

View File

@ -24,49 +24,70 @@ class Scorer;
class Data
{
protected:
ScoreData* scoredata;
FeatureData* featdata;
ScoreData* scoredata;
FeatureData* featdata;
private:
Scorer* theScorer;
Scorer* theScorer;
std::string score_type;
size_t number_of_scores; //number of scores
size_t number_of_scores; //number of scores
public:
Data(Scorer& sc);
~Data(){};
inline void clear() { scoredata->clear(); featdata->clear(); }
ScoreData* getScoreData() { return scoredata; };
FeatureData* getFeatureData() { return featdata; };
inline size_t NumberOfFeatures() const{ return featdata->NumberOfFeatures(); }
inline void NumberOfFeatures(size_t v){ featdata->NumberOfFeatures(v); }
inline std::string Features() const{ return featdata->Features(); }
inline void Features(const std::string f){ featdata->Features(f); }
Data(Scorer& sc);
void loadnbest(const std::string &file);
~Data() {};
void load(const std::string &featfile,const std::string &scorefile){
featdata->load(featfile);
scoredata->load(scorefile);
inline void clear() {
scoredata->clear();
featdata->clear();
}
void save(const std::string &featfile,const std::string &scorefile, bool bin=false){
if (bin) cerr << "Binary write mode is selected" << endl;
else cerr << "Binary write mode is NOT selected" << endl;
featdata->save(featfile, bin);
scoredata->save(scorefile, bin);
}
inline bool existsFeatureNames(){ return featdata->existsFeatureNames(); };
inline std::string getFeatureName(size_t idx){ return featdata->getFeatureName(idx); };
inline size_t getFeatureIndex(const std::string& name){ return featdata->getFeatureIndex(name); };
ScoreData* getScoreData() {
return scoredata;
};
FeatureData* getFeatureData() {
return featdata;
};
inline size_t NumberOfFeatures() const {
return featdata->NumberOfFeatures();
}
inline void NumberOfFeatures(size_t v) {
featdata->NumberOfFeatures(v);
}
inline std::string Features() const {
return featdata->Features();
}
inline void Features(const std::string f) {
featdata->Features(f);
}
void loadnbest(const std::string &file);
void load(const std::string &featfile,const std::string &scorefile) {
featdata->load(featfile);
scoredata->load(scorefile);
}
void save(const std::string &featfile,const std::string &scorefile, bool bin=false) {
if (bin) cerr << "Binary write mode is selected" << endl;
else cerr << "Binary write mode is NOT selected" << endl;
featdata->save(featfile, bin);
scoredata->save(scorefile, bin);
}
inline bool existsFeatureNames() {
return featdata->existsFeatureNames();
};
inline std::string getFeatureName(size_t idx) {
return featdata->getFeatureName(idx);
};
inline size_t getFeatureIndex(const std::string& name) {
return featdata->getFeatureIndex(name);
};
};

View File

@ -16,137 +16,137 @@ FeatureArray::FeatureArray(): idx("")
void FeatureArray::savetxt(std::ofstream& outFile)
{
outFile << FEATURES_TXT_BEGIN << " " << idx << " " << array_.size()
<< " " << number_of_features << " " << features << std::endl;
for (featarray_t::iterator i = array_.begin(); i !=array_.end(); i++){
i->savetxt(outFile);
outFile << std::endl;
}
outFile << FEATURES_TXT_END << std::endl;
outFile << FEATURES_TXT_BEGIN << " " << idx << " " << array_.size()
<< " " << number_of_features << " " << features << std::endl;
for (featarray_t::iterator i = array_.begin(); i !=array_.end(); i++) {
i->savetxt(outFile);
outFile << std::endl;
}
outFile << FEATURES_TXT_END << std::endl;
}
void FeatureArray::savebin(std::ofstream& outFile)
{
outFile << FEATURES_BIN_BEGIN << " " << idx << " " << array_.size()
<< " " << number_of_features << " " << features << std::endl;
outFile << FEATURES_BIN_BEGIN << " " << idx << " " << array_.size()
<< " " << number_of_features << " " << features << std::endl;
for (featarray_t::iterator i = array_.begin(); i !=array_.end(); i++)
i->savebin(outFile);
i->savebin(outFile);
outFile << FEATURES_BIN_END << std::endl;
outFile << FEATURES_BIN_END << std::endl;
}
void FeatureArray::save(std::ofstream& inFile, bool bin)
{
if (size()>0)
(bin)?savebin(inFile):savetxt(inFile);
if (size()>0)
(bin)?savebin(inFile):savetxt(inFile);
}
void FeatureArray::save(const std::string &file, bool bin)
{
std::ofstream outFile(file.c_str(), std::ios::out); // matches a stream with a file. Opens the file
std::ofstream outFile(file.c_str(), std::ios::out); // matches a stream with a file. Opens the file
save(outFile);
save(outFile);
outFile.close();
outFile.close();
}
void FeatureArray::loadbin(ifstream& inFile, size_t n)
{
FeatureStats entry(number_of_features);
FeatureStats entry(number_of_features);
for (size_t i=0 ; i < n; i++){
entry.loadbin(inFile);
add(entry);
}
for (size_t i=0 ; i < n; i++) {
entry.loadbin(inFile);
add(entry);
}
}
void FeatureArray::loadtxt(ifstream& inFile, size_t n)
{
FeatureStats entry(number_of_features);
for (size_t i=0 ; i < n; i++){
entry.loadtxt(inFile);
add(entry);
}
FeatureStats entry(number_of_features);
for (size_t i=0 ; i < n; i++) {
entry.loadtxt(inFile);
add(entry);
}
}
void FeatureArray::load(ifstream& inFile)
{
size_t number_of_entries=0;
bool binmode=false;
std::string substring, stringBuf;
bool binmode=false;
std::string substring, stringBuf;
std::string::size_type loc;
std::getline(inFile, stringBuf);
if (!inFile.good()){
return;
}
std::getline(inFile, stringBuf);
if (!inFile.good()) {
return;
}
if (!stringBuf.empty()){
if ((loc = stringBuf.find(FEATURES_TXT_BEGIN)) == 0){
binmode=false;
}else if ((loc = stringBuf.find(FEATURES_BIN_BEGIN)) == 0){
binmode=true;
}else{
TRACE_ERR("ERROR: FeatureArray::load(): Wrong header");
return;
}
getNextPound(stringBuf, substring);
getNextPound(stringBuf, substring);
if (!stringBuf.empty()) {
if ((loc = stringBuf.find(FEATURES_TXT_BEGIN)) == 0) {
binmode=false;
} else if ((loc = stringBuf.find(FEATURES_BIN_BEGIN)) == 0) {
binmode=true;
} else {
TRACE_ERR("ERROR: FeatureArray::load(): Wrong header");
return;
}
getNextPound(stringBuf, substring);
getNextPound(stringBuf, substring);
idx = substring;
getNextPound(stringBuf, substring);
getNextPound(stringBuf, substring);
number_of_entries = atoi(substring.c_str());
getNextPound(stringBuf, substring);
getNextPound(stringBuf, substring);
number_of_features = atoi(substring.c_str());
features = stringBuf;
}
features = stringBuf;
}
(binmode)?loadbin(inFile, number_of_entries):loadtxt(inFile, number_of_entries);
(binmode)?loadbin(inFile, number_of_entries):loadtxt(inFile, number_of_entries);
std::getline(inFile, stringBuf);
if (!stringBuf.empty()){
if ((loc = stringBuf.find(FEATURES_TXT_END)) != 0 && (loc = stringBuf.find(FEATURES_BIN_END)) != 0){
TRACE_ERR("ERROR: FeatureArray::load(): Wrong footer");
return;
}
}
std::getline(inFile, stringBuf);
if (!stringBuf.empty()) {
if ((loc = stringBuf.find(FEATURES_TXT_END)) != 0 && (loc = stringBuf.find(FEATURES_BIN_END)) != 0) {
TRACE_ERR("ERROR: FeatureArray::load(): Wrong footer");
return;
}
}
}
void FeatureArray::load(const std::string &file)
{
TRACE_ERR("loading data from " << file << std::endl);
TRACE_ERR("loading data from " << file << std::endl);
inputfilestream inFile(file); // matches a stream with a file. Opens the file
inputfilestream inFile(file); // matches a stream with a file. Opens the file
load((ifstream&) inFile);
load((ifstream&) inFile);
inFile.close();
inFile.close();
}
void FeatureArray::merge(FeatureArray& e)
{
//dummy implementation
for (size_t i=0; i<e.size(); i++)
add(e.get(i));
//dummy implementation
for (size_t i=0; i<e.size(); i++)
add(e.get(i));
}
bool FeatureArray::check_consistency()
{
size_t sz = NumberOfFeatures();
if (sz == 0)
return true;
for (featarray_t::iterator i=array_.begin(); i!=array_.end(); i++)
if (i->size()!=sz)
return false;
return true;
size_t sz = NumberOfFeatures();
if (sz == 0)
return true;
for (featarray_t::iterator i=array_.begin(); i!=array_.end(); i++)
if (i->size()!=sz)
return false;
return true;
}

View File

@ -27,47 +27,71 @@ using namespace std;
class FeatureArray
{
protected:
featarray_t array_;
size_t number_of_features;
std::string features;
featarray_t array_;
size_t number_of_features;
std::string features;
private:
std::string idx; // idx to identify the utterance, it can differ from the index inside the vector
std::string idx; // idx to identify the utterance, it can differ from the index inside the vector
public:
FeatureArray();
~FeatureArray(){};
inline void clear() { array_.clear(); }
inline std::string getIndex(){ return idx; }
inline void setIndex(const std::string & value){ idx=value; }
FeatureArray();
inline FeatureStats& get(size_t i){ return array_.at(i); }
inline const FeatureStats& get(size_t i)const{ return array_.at(i); }
void add(FeatureStats e){ array_.push_back(e); }
~FeatureArray() {};
void merge(FeatureArray& e);
inline void clear() {
array_.clear();
}
inline size_t size(){ return array_.size(); }
inline size_t NumberOfFeatures() const{ return number_of_features; }
inline void NumberOfFeatures(size_t v){ number_of_features = v; }
inline std::string Features() const{ return features; }
inline void Features(const std::string f){ features = f; }
void savetxt(ofstream& outFile);
void savebin(ofstream& outFile);
void save(ofstream& outFile, bool bin=false);
void save(const std::string &file, bool bin=false);
inline void save(bool bin=false){ save("/dev/stdout",bin); }
inline std::string getIndex() {
return idx;
}
inline void setIndex(const std::string & value) {
idx=value;
}
void loadtxt(ifstream& inFile, size_t n);
void loadbin(ifstream& inFile, size_t n);
void load(ifstream& inFile);
void load(const std::string &file);
bool check_consistency();
inline FeatureStats& get(size_t i) {
return array_.at(i);
}
inline const FeatureStats& get(size_t i)const {
return array_.at(i);
}
void add(FeatureStats e) {
array_.push_back(e);
}
void merge(FeatureArray& e);
inline size_t size() {
return array_.size();
}
inline size_t NumberOfFeatures() const {
return number_of_features;
}
inline void NumberOfFeatures(size_t v) {
number_of_features = v;
}
inline std::string Features() const {
return features;
}
inline void Features(const std::string f) {
features = f;
}
void savetxt(ofstream& outFile);
void savebin(ofstream& outFile);
void save(ofstream& outFile, bool bin=false);
void save(const std::string &file, bool bin=false);
inline void save(bool bin=false) {
save("/dev/stdout",bin);
}
void loadtxt(ifstream& inFile, size_t n);
void loadbin(ifstream& inFile, size_t n);
void load(ifstream& inFile);
void load(const std::string &file);
bool check_consistency();
};

View File

@ -18,127 +18,127 @@ FeatureData::FeatureData() {};
void FeatureData::save(std::ofstream& outFile, bool bin)
{
for (featdata_t::iterator i = array_.begin(); i !=array_.end(); i++)
i->save(outFile, bin);
for (featdata_t::iterator i = array_.begin(); i !=array_.end(); i++)
i->save(outFile, bin);
}
void FeatureData::save(const std::string &file, bool bin)
{
if (file.empty()) return;
if (file.empty()) return;
TRACE_ERR("saving the array into " << file << std::endl);
TRACE_ERR("saving the array into " << file << std::endl);
std::ofstream outFile(file.c_str(), std::ios::out); // matches a stream with a file. Opens the file
std::ofstream outFile(file.c_str(), std::ios::out); // matches a stream with a file. Opens the file
save(outFile, bin);
save(outFile, bin);
outFile.close();
outFile.close();
}
void FeatureData::load(ifstream& inFile)
{
FeatureArray entry;
while (!inFile.eof()){
while (!inFile.eof()) {
if (!inFile.good()){
std::cerr << "ERROR FeatureData::load inFile.good()" << std::endl;
}
if (!inFile.good()) {
std::cerr << "ERROR FeatureData::load inFile.good()" << std::endl;
}
entry.clear();
entry.load(inFile);
entry.clear();
entry.load(inFile);
if (entry.size() == 0)
break;
if (entry.size() == 0)
break;
if (size() == 0){
setFeatureMap(entry.Features());
}
add(entry);
}
if (size() == 0) {
setFeatureMap(entry.Features());
}
add(entry);
}
}
void FeatureData::load(const std::string &file)
{
TRACE_ERR("loading feature data from " << file << std::endl);
TRACE_ERR("loading feature data from " << file << std::endl);
inputfilestream inFile(file); // matches a stream with a file. Opens the file
inputfilestream inFile(file); // matches a stream with a file. Opens the file
if (!inFile) {
throw runtime_error("Unable to open feature file: " + file);
}
if (!inFile) {
throw runtime_error("Unable to open feature file: " + file);
}
load((ifstream&) inFile);
load((ifstream&) inFile);
inFile.close();
inFile.close();
}
void FeatureData::add(FeatureArray& e){
if (exists(e.getIndex())){ // array at position e.getIndex() already exists
//enlarge array at position e.getIndex()
size_t pos = getIndex(e.getIndex());
array_.at(pos).merge(e);
}
else{
array_.push_back(e);
setIndex();
}
void FeatureData::add(FeatureArray& e)
{
if (exists(e.getIndex())) { // array at position e.getIndex() already exists
//enlarge array at position e.getIndex()
size_t pos = getIndex(e.getIndex());
array_.at(pos).merge(e);
} else {
array_.push_back(e);
setIndex();
}
}
void FeatureData::add(FeatureStats& e, const std::string & sent_idx){
if (exists(sent_idx)){ // array at position e.getIndex() already exists
//enlarge array at position e.getIndex()
size_t pos = getIndex(sent_idx);
// TRACE_ERR("Inserting " << e << " in array " << sent_idx << std::endl);
array_.at(pos).add(e);
}
else{
// TRACE_ERR("Creating a new entry in the array and inserting " << e << std::endl);
FeatureArray a;
a.NumberOfFeatures(number_of_features);
a.Features(features);
a.setIndex(sent_idx);
a.add(e);
add(a);
}
}
void FeatureData::add(FeatureStats& e, const std::string & sent_idx)
{
if (exists(sent_idx)) { // array at position e.getIndex() already exists
//enlarge array at position e.getIndex()
size_t pos = getIndex(sent_idx);
// TRACE_ERR("Inserting " << e << " in array " << sent_idx << std::endl);
array_.at(pos).add(e);
} else {
// TRACE_ERR("Creating a new entry in the array and inserting " << e << std::endl);
FeatureArray a;
a.NumberOfFeatures(number_of_features);
a.Features(features);
a.setIndex(sent_idx);
a.add(e);
add(a);
}
}
bool FeatureData::check_consistency()
{
if (array_.size() == 0)
return true;
for (featdata_t::iterator i = array_.begin(); i !=array_.end(); i++)
if (!i->check_consistency()) return false;
if (array_.size() == 0)
return true;
return true;
for (featdata_t::iterator i = array_.begin(); i !=array_.end(); i++)
if (!i->check_consistency()) return false;
return true;
}
void FeatureData::setIndex()
{
size_t j=0;
for (featdata_t::iterator i = array_.begin(); i !=array_.end(); i++){
idx2arrayname_[j]=(*i).getIndex();
arrayname2idx_[(*i).getIndex()] = j;
j++;
}
size_t j=0;
for (featdata_t::iterator i = array_.begin(); i !=array_.end(); i++) {
idx2arrayname_[j]=(*i).getIndex();
arrayname2idx_[(*i).getIndex()] = j;
j++;
}
}
void FeatureData::setFeatureMap(const std::string feat)
{
number_of_features = 0;
features=feat;
number_of_features = 0;
features=feat;
std::string substring, stringBuf;
stringBuf=features;
while (!stringBuf.empty()){
getNextPound(stringBuf, substring);
featname2idx_[substring]=idx2featname_.size();
idx2featname_[idx2featname_.size()]=substring;
number_of_features++;
}
std::string substring, stringBuf;
stringBuf=features;
while (!stringBuf.empty()) {
getNextPound(stringBuf, substring);
featname2idx_[substring]=idx2featname_.size();
idx2featname_[idx2featname_.size()]=substring;
number_of_features++;
}
}

View File

@ -20,86 +20,116 @@ using namespace std;
class FeatureData
{
protected:
featdata_t array_;
idx2name idx2arrayname_; //map from index to name of array
name2idx arrayname2idx_; //map from name to index of array
featdata_t array_;
idx2name idx2arrayname_; //map from index to name of array
name2idx arrayname2idx_; //map from name to index of array
private:
size_t number_of_features;
std::string features;
size_t number_of_features;
std::string features;
map<std::string, size_t> featname2idx_; //map from name to index of features
map<size_t, std::string> idx2featname_; //map from index to name of features
map<std::string, size_t> featname2idx_; //map from name to index of features
map<size_t, std::string> idx2featname_; //map from index to name of features
public:
FeatureData();
~FeatureData(){};
inline void clear() { array_.clear(); }
inline FeatureArray get(const std::string& idx){ return array_.at(getIndex(idx)); }
inline FeatureArray& get(size_t idx){ return array_.at(idx); }
inline const FeatureArray& get(size_t idx) const{ return array_.at(idx); }
FeatureData();
inline bool exists(const std::string & sent_idx){ return exists(getIndex(sent_idx)); }
inline bool exists(int sent_idx){ return (sent_idx>-1 && sent_idx<(int) array_.size())?true:false; }
~FeatureData() {};
inline FeatureStats& get(size_t i, size_t j){ return array_.at(i).get(j); }
inline const FeatureStats& get(size_t i, size_t j) const { return array_.at(i).get(j); }
void add(FeatureArray& e);
void add(FeatureStats& e, const std::string& sent_idx);
inline size_t size(){ return array_.size(); }
inline size_t NumberOfFeatures() const{ return number_of_features; }
inline void NumberOfFeatures(size_t v){ number_of_features = v; }
inline std::string Features() const{ return features; }
inline void Features(const std::string f){ features = f; }
void save(const std::string &file, bool bin=false);
void save(ofstream& outFile, bool bin=false);
inline void save(bool bin=false){ save("/dev/stdout", bin); }
void load(ifstream& inFile);
void load(const std::string &file);
bool check_consistency();
void setIndex();
inline int getIndex(const std::string& idx){
name2idx::iterator i = arrayname2idx_.find(idx);
if (i!=arrayname2idx_.end())
return i->second;
else
return -1;
inline void clear() {
array_.clear();
}
inline std::string getIndex(size_t idx){
idx2name::iterator i = idx2arrayname_.find(idx);
if (i!=idx2arrayname_.end())
throw runtime_error("there is no entry at index " + idx);
return i->second;
}
bool existsFeatureNames(){ return (idx2featname_.size() > 0)?true:false; };
std::string getFeatureName(size_t idx){
if (idx >= idx2featname_.size())
throw runtime_error("Error: you required an too big index");
return idx2featname_[idx];
};
size_t getFeatureIndex(const std::string& name){
if (featname2idx_.find(name)==featname2idx_.end())
throw runtime_error("Error: feature " + name +" is unknown");
return featname2idx_[name];
};
inline FeatureArray get(const std::string& idx) {
return array_.at(getIndex(idx));
}
inline FeatureArray& get(size_t idx) {
return array_.at(idx);
}
inline const FeatureArray& get(size_t idx) const {
return array_.at(idx);
}
inline bool exists(const std::string & sent_idx) {
return exists(getIndex(sent_idx));
}
inline bool exists(int sent_idx) {
return (sent_idx>-1 && sent_idx<(int) array_.size())?true:false;
}
inline FeatureStats& get(size_t i, size_t j) {
return array_.at(i).get(j);
}
inline const FeatureStats& get(size_t i, size_t j) const {
return array_.at(i).get(j);
}
void add(FeatureArray& e);
void add(FeatureStats& e, const std::string& sent_idx);
inline size_t size() {
return array_.size();
}
inline size_t NumberOfFeatures() const {
return number_of_features;
}
inline void NumberOfFeatures(size_t v) {
number_of_features = v;
}
inline std::string Features() const {
return features;
}
inline void Features(const std::string f) {
features = f;
}
void save(const std::string &file, bool bin=false);
void save(ofstream& outFile, bool bin=false);
inline void save(bool bin=false) {
save("/dev/stdout", bin);
}
void load(ifstream& inFile);
void load(const std::string &file);
bool check_consistency();
void setIndex();
inline int getIndex(const std::string& idx) {
name2idx::iterator i = arrayname2idx_.find(idx);
if (i!=arrayname2idx_.end())
return i->second;
else
return -1;
}
inline std::string getIndex(size_t idx) {
idx2name::iterator i = idx2arrayname_.find(idx);
if (i!=idx2arrayname_.end())
throw runtime_error("there is no entry at index " + idx);
return i->second;
}
bool existsFeatureNames() {
return (idx2featname_.size() > 0)?true:false;
};
std::string getFeatureName(size_t idx) {
if (idx >= idx2featname_.size())
throw runtime_error("Error: you required an too big index");
return idx2featname_[idx];
};
size_t getFeatureIndex(const std::string& name) {
if (featname2idx_.find(name)==featname2idx_.end())
throw runtime_error("Error: feature " + name +" is unknown");
return featname2idx_[name];
};
void setFeatureMap(const std::string feat);
};

View File

@ -14,123 +14,124 @@
FeatureStats::FeatureStats()
{
available_ = AVAILABLE_;
entries_ = 0;
array_ = new FeatureStatsType[available_];
available_ = AVAILABLE_;
entries_ = 0;
array_ = new FeatureStatsType[available_];
};
FeatureStats::~FeatureStats()
{
delete array_;
delete array_;
};
FeatureStats::FeatureStats(const FeatureStats &stats)
{
available_ = stats.available();
entries_ = stats.size();
array_ = new FeatureStatsType[available_];
memcpy(array_,stats.getArray(),featbytes_);
available_ = stats.available();
entries_ = stats.size();
array_ = new FeatureStatsType[available_];
memcpy(array_,stats.getArray(),featbytes_);
};
FeatureStats::FeatureStats(const size_t size)
{
available_ = size;
entries_ = size;
array_ = new FeatureStatsType[available_];
memset(array_,0,featbytes_);
available_ = size;
entries_ = size;
array_ = new FeatureStatsType[available_];
memset(array_,0,featbytes_);
};
FeatureStats::FeatureStats(std::string &theString)
{
set(theString);
set(theString);
}
void FeatureStats::expand()
{
available_*=2;
featstats_t t_ = new FeatureStatsType[available_];
memcpy(t_,array_,featbytes_);
delete array_;
array_=t_;
available_*=2;
featstats_t t_ = new FeatureStatsType[available_];
memcpy(t_,array_,featbytes_);
delete array_;
array_=t_;
}
void FeatureStats::add(FeatureStatsType v)
{
if (isfull()) expand();
array_[entries_++]=v;
if (isfull()) expand();
array_[entries_++]=v;
}
void FeatureStats::set(std::string &theString)
{
std::string substring, stringBuf;
reset();
while (!theString.empty()){
getNextPound(theString, substring);
add(ATOFST(substring.c_str()));
}
reset();
while (!theString.empty()) {
getNextPound(theString, substring);
add(ATOFST(substring.c_str()));
}
}
void FeatureStats::loadbin(std::ifstream& inFile)
{
inFile.read((char*) array_, featbytes_);
}
inFile.read((char*) array_, featbytes_);
}
void FeatureStats::loadtxt(std::ifstream& inFile)
{
std::string theString;
std::getline(inFile, theString);
set(theString);
std::string theString;
std::getline(inFile, theString);
set(theString);
}
void FeatureStats::loadtxt(const std::string &file)
{
// TRACE_ERR("loading the stats from " << file << std::endl);
// TRACE_ERR("loading the stats from " << file << std::endl);
std::ifstream inFile(file.c_str(), std::ios::in); // matches a stream with a file. Opens the file
std::ifstream inFile(file.c_str(), std::ios::in); // matches a stream with a file. Opens the file
loadtxt(inFile);
loadtxt(inFile);
}
void FeatureStats::savetxt(const std::string &file)
{
// TRACE_ERR("saving the stats into " << file << std::endl);
// TRACE_ERR("saving the stats into " << file << std::endl);
std::ofstream outFile(file.c_str(), std::ios::out); // matches a stream with a file. Opens the file
std::ofstream outFile(file.c_str(), std::ios::out); // matches a stream with a file. Opens the file
savetxt(outFile);
savetxt(outFile);
}
void FeatureStats::savetxt(std::ofstream& outFile)
{
// TRACE_ERR("saving the stats" << std::endl);
outFile << *this;
// TRACE_ERR("saving the stats" << std::endl);
outFile << *this;
}
void FeatureStats::savebin(std::ofstream& outFile)
{
outFile.write((char*) array_, featbytes_);
}
outFile.write((char*) array_, featbytes_);
}
FeatureStats& FeatureStats::operator=(const FeatureStats &stats)
{
delete array_;
available_ = stats.available();
entries_ = stats.size();
array_ = new FeatureStatsType[available_];
memcpy(array_,stats.getArray(),featbytes_);
return *this;
delete array_;
available_ = stats.available();
entries_ = stats.size();
array_ = new FeatureStatsType[available_];
memcpy(array_,stats.getArray(),featbytes_);
return *this;
}
/**write the whole object to a stream*/
ostream& operator<<(ostream& o, const FeatureStats& e){
for (size_t i=0; i< e.size(); i++)
o << e.get(i) << " ";
return o;
ostream& operator<<(ostream& o, const FeatureStats& e)
{
for (size_t i=0; i< e.size(); i++)
o << e.get(i) << " ";
return o;
}

View File

@ -25,46 +25,67 @@ using namespace std;
class FeatureStats
{
private:
featstats_t array_;
size_t entries_;
size_t available_;
featstats_t array_;
size_t entries_;
size_t available_;
public:
FeatureStats();
FeatureStats(const size_t size);
FeatureStats(const FeatureStats &stats);
FeatureStats(std::string &theString);
FeatureStats& operator=(const FeatureStats &stats);
~FeatureStats();
bool isfull(){return (entries_ < available_)?0:1; }
void expand();
void add(FeatureStatsType v);
inline void clear() { memset((void*) array_,0,featbytes_); }
inline FeatureStatsType get(size_t i){ return array_[i]; }
inline FeatureStatsType get(size_t i)const{ return array_[i]; }
inline featstats_t getArray() const { return array_; }
FeatureStats();
FeatureStats(const size_t size);
FeatureStats(const FeatureStats &stats);
FeatureStats(std::string &theString);
FeatureStats& operator=(const FeatureStats &stats);
void set(std::string &theString);
~FeatureStats();
inline size_t bytes() const{ return featbytes_; }
inline size_t size() const{ return entries_; }
inline size_t available() const{ return available_; }
void savetxt(const std::string &file);
void savetxt(ofstream& outFile);
void savebin(ofstream& outFile);
inline void savetxt(){ savetxt("/dev/stdout"); }
void loadtxt(const std::string &file);
void loadtxt(ifstream& inFile);
void loadbin(ifstream& inFile);
bool isfull() {
return (entries_ < available_)?0:1;
}
void expand();
void add(FeatureStatsType v);
inline void clear() {
memset((void*) array_,0,featbytes_);
}
inline FeatureStatsType get(size_t i) {
return array_[i];
}
inline FeatureStatsType get(size_t i)const {
return array_[i];
}
inline featstats_t getArray() const {
return array_;
}
void set(std::string &theString);
inline size_t bytes() const {
return featbytes_;
}
inline size_t size() const {
return entries_;
}
inline size_t available() const {
return available_;
}
void savetxt(const std::string &file);
void savetxt(ofstream& outFile);
void savebin(ofstream& outFile);
inline void savetxt() {
savetxt("/dev/stdout");
}
void loadtxt(const std::string &file);
void loadtxt(ifstream& inFile);
void loadbin(ifstream& inFile);
inline void reset() {
entries_ = 0;
clear();
}
inline void reset(){ entries_ = 0; clear(); }
/**write the whole object to a stream*/
friend ostream& operator<<(ostream& o, const FeatureStats& e);
};

View File

@ -14,31 +14,34 @@ static const float MAX_FLOAT=numeric_limits<float>::max();
void Optimizer::SetScorer(Scorer *S){
void Optimizer::SetScorer(Scorer *S)
{
if(scorer)
delete scorer;
scorer=S;
}
void Optimizer::SetFData(FeatureData *F){
void Optimizer::SetFData(FeatureData *F)
{
if(FData)
delete FData;
FData=F;
};
Optimizer::Optimizer(unsigned Pd,vector<unsigned> i2O,vector<parameter_t> start):scorer(NULL),FData(NULL){
Optimizer::Optimizer(unsigned Pd,vector<unsigned> i2O,vector<parameter_t> start):scorer(NULL),FData(NULL)
{
//warning: the init vector is a full set of parameters, of dimension pdim!
Point::pdim=Pd;
assert(start.size()==Pd);
Point::dim=i2O.size();
Point::optindices=i2O;
if (Point::pdim>Point::dim){
for (unsigned int i=0;i<Point::pdim;i++){
if (Point::pdim>Point::dim) {
for (unsigned int i=0; i<Point::pdim; i++) {
unsigned int j = 0;
while (j<Point::dim && i!=i2O[j])
j++;
j++;
if (j==Point::dim)//the index i wasnt found on optindices, it is a fixed index, we use the value of the start vector
Point::fixedweights[i]=start[i];
@ -46,12 +49,14 @@ Optimizer::Optimizer(unsigned Pd,vector<unsigned> i2O,vector<parameter_t> start)
}
};
Optimizer::~Optimizer(){
Optimizer::~Optimizer()
{
delete scorer;
delete FData;
}
statscore_t Optimizer::GetStatScore(const Point& param)const{
statscore_t Optimizer::GetStatScore(const Point& param)const
{
vector<unsigned> bests;
Get1bests(param,bests);
//copy(bests.begin(),bests.end(),ostream_iterator<unsigned>(cerr," "));
@ -60,23 +65,25 @@ statscore_t Optimizer::GetStatScore(const Point& param)const{
};
/**compute the intersection of 2 lines*/
float intersect (float m1, float b1,float m2,float b2){
float intersect (float m1, float b1,float m2,float b2)
{
float isect = ((b2-b1)/(m1-m2));
if (!isfinite(isect)) {
isect = MAX_FLOAT;
isect = MAX_FLOAT;
}
return isect;
}
map<float,diff_t >::iterator AddThreshold(map<float,diff_t >& thresholdmap,float newt,pair<unsigned,unsigned> newdiff){
map<float,diff_t >::iterator AddThreshold(map<float,diff_t >& thresholdmap,float newt,pair<unsigned,unsigned> newdiff)
{
map<float,diff_t>::iterator it=thresholdmap.find(newt);
if(it!=thresholdmap.end()){
if(it!=thresholdmap.end()) {
//the threshold already exists!! this is very unlikely
if(it->second.back().first==newdiff.first)
it->second.back().second=newdiff.second;//there was already a diff for this sentence, we change the 1 best;
else
it->second.push_back(newdiff);
}else{
} else {
//normal case
pair< map<float,diff_t >::iterator,bool > ins=thresholdmap.insert(threshold(newt,diff_t(1,newdiff)));
assert(ins.second);//we really inserted something
@ -86,244 +93,247 @@ map<float,diff_t >::iterator AddThreshold(map<float,diff_t >& thresholdmap,float
};
statscore_t Optimizer::LineOptimize(const Point& origin,const Point& direction,Point& bestpoint)const{
statscore_t Optimizer::LineOptimize(const Point& origin,const Point& direction,Point& bestpoint)const
{
// we are looking for the best Point on the line y=Origin+x*direction
float min_int=0.0001;
//typedef pair<unsigned,unsigned> diff;//first the sentence that changes, second is the new 1best for this sentence
//list<threshold> thresholdlist;
map<float,diff_t> thresholdmap;
thresholdmap[MIN_FLOAT]=diff_t();
vector<unsigned> first1best;//the vector of nbests for x=-inf
for(unsigned int S=0;S<size();S++){
for(unsigned int S=0; S<size(); S++) {
map<float,diff_t >::iterator previnserted=thresholdmap.begin();
//first we determine the translation with the best feature score for each sentence and each value of x
//cerr << "Sentence " << S << endl;
multimap<float,unsigned> gradient;
vector<float> f0;
f0.resize(FData->get(S).size());
for(unsigned j=0;j<FData->get(S).size();j++){
for(unsigned j=0; j<FData->get(S).size(); j++) {
gradient.insert(pair<float,unsigned>(direction*(FData->get(S,j)),j));//gradient of the feature function for this particular target sentence
f0[j]=origin*FData->get(S,j);//compute the feature function at the origin point
}
//now lets compute the 1best for each value of x
// vector<pair<float,unsigned> > onebest;
multimap<float,unsigned>::iterator gradientit=gradient.begin();
multimap<float,unsigned>::iterator highest_f0=gradient.begin();
float smallest=gradientit->first;//smallest gradient
//several candidates can have the lowest slope (eg for word penalty where the gradient is an integer )
gradientit++;
while(gradientit!=gradient.end()&&gradientit->first==smallest){
while(gradientit!=gradient.end()&&gradientit->first==smallest) {
// cerr<<"ni"<<gradientit->second<<endl;;
//cerr<<"fos"<<f0[gradientit->second]<<" "<<f0[index]<<" "<<index<<endl;
if(f0[gradientit->second]>f0[highest_f0->second])
highest_f0=gradientit;//the highest line is the one with he highest f0
highest_f0=gradientit;//the highest line is the one with he highest f0
gradientit++;
}
gradientit = highest_f0;
first1best.push_back(highest_f0->second);
first1best.push_back(highest_f0->second);
//now we look for the intersections points indicating a change of 1 best
//we use the fact that the function is convex, which means that the gradient can only go up
while(gradientit!=gradient.end()){
//we use the fact that the function is convex, which means that the gradient can only go up
while(gradientit!=gradient.end()) {
map<float,unsigned>::iterator leftmost=gradientit;
float m=gradientit->first;
float b=f0[gradientit->second];
multimap<float,unsigned>::iterator gradientit2=gradientit;
gradientit2++;
float leftmostx=MAX_FLOAT;
for(;gradientit2!=gradient.end();gradientit2++){
//cerr<<"--"<<d++<<' '<<gradientit2->first<<' '<<gradientit2->second<<endl;
//look for all candidate with a gradient bigger than the current one and find the one with the leftmost intersection
float curintersect;
if(m!=gradientit2->first){
curintersect=intersect(m,b,gradientit2->first,f0[gradientit2->second]);
for(; gradientit2!=gradient.end(); gradientit2++) {
//cerr<<"--"<<d++<<' '<<gradientit2->first<<' '<<gradientit2->second<<endl;
//look for all candidate with a gradient bigger than the current one and find the one with the leftmost intersection
float curintersect;
if(m!=gradientit2->first) {
curintersect=intersect(m,b,gradientit2->first,f0[gradientit2->second]);
//cerr << "curintersect: " << curintersect << " leftmostx: " << leftmostx << endl;
if(curintersect<=leftmostx){
//we have found an intersection to the left of the leftmost we had so far.
//we might have curintersect==leftmostx for example is 2 candidates are the same
//in that case its better its better to update leftmost to gradientit2 to avoid some recomputing later
leftmostx=curintersect;
leftmost=gradientit2;//this is the new reference
}
}
if(curintersect<=leftmostx) {
//we have found an intersection to the left of the leftmost we had so far.
//we might have curintersect==leftmostx for example is 2 candidates are the same
//in that case its better its better to update leftmost to gradientit2 to avoid some recomputing later
leftmostx=curintersect;
leftmost=gradientit2;//this is the new reference
}
}
}
if (leftmost == gradientit) {
//we didn't find any more intersections
//the rightmost bestindex is the one with the highest slope.
assert(abs(leftmost->first-gradient.rbegin()->first)<0.0001);//they should be egal but there might be
//a small difference due to rounding error
break;
//we didn't find any more intersections
//the rightmost bestindex is the one with the highest slope.
assert(abs(leftmost->first-gradient.rbegin()->first)<0.0001);//they should be egal but there might be
//a small difference due to rounding error
break;
}
//we have found the next intersection!
pair<unsigned,unsigned> newd(S,leftmost->second);//new onebest for Sentence S is leftmost->second
if(leftmostx-previnserted->first<min_int){
/* Require that the intersection Point be at least min_int
to the right of the previous one(for this sentence). If not, we replace the
previous intersection Point with this one. Yes, it can even
happen that the new intersection Point is slightly to the
left of the old one, because of numerical imprecision.
we do not check that we are to the right of the penultimate point also. it this happen the 1best the inteval will be wrong
we are going to replace previnsert by the new one because we do not want to keep
2 very close threshold: if the minima is there it could be an artifact
*/
map<float,diff_t>::iterator tit=thresholdmap.find(leftmostx);
if(tit==previnserted){
//the threshold is the same as before can happen if 2 candidates are the same for example
assert(previnserted->second.back().first==newd.first);
previnserted->second.back()=newd;//just replace the 1 best fors sentence S
//previnsert doesnt change
}else{
if(leftmostx-previnserted->first<min_int) {
/* Require that the intersection Point be at least min_int
to the right of the previous one(for this sentence). If not, we replace the
previous intersection Point with this one. Yes, it can even
happen that the new intersection Point is slightly to the
left of the old one, because of numerical imprecision.
we do not check that we are to the right of the penultimate point also. it this happen the 1best the inteval will be wrong
we are going to replace previnsert by the new one because we do not want to keep
2 very close threshold: if the minima is there it could be an artifact
*/
map<float,diff_t>::iterator tit=thresholdmap.find(leftmostx);
if(tit==previnserted) {
//the threshold is the same as before can happen if 2 candidates are the same for example
assert(previnserted->second.back().first==newd.first);
previnserted->second.back()=newd;//just replace the 1 best fors sentence S
//previnsert doesnt change
} else {
if(tit==thresholdmap.end()){
thresholdmap[leftmostx]=previnserted->second;//We keep the diffs at previnsert
thresholdmap.erase(previnserted);//erase old previnsert
previnserted=thresholdmap.find(leftmostx);//point previnsert to the new threshold
previnserted->second.back()=newd;//we update the diff for sentence S
}else{//threshold already exists but is not the previous one.
//we append the diffs in previnsert to tit before destroying previnsert
tit->second.insert(tit->second.end(),previnserted->second.begin(),previnserted->second.end());
assert(tit->second.back().first==newd.first);
tit->second.back()=newd;//change diff for sentence S
thresholdmap.erase(previnserted);//erase old previnsert
previnserted=tit;//point previnsert to the new threshold
}
}
if(tit==thresholdmap.end()) {
thresholdmap[leftmostx]=previnserted->second;//We keep the diffs at previnsert
thresholdmap.erase(previnserted);//erase old previnsert
previnserted=thresholdmap.find(leftmostx);//point previnsert to the new threshold
previnserted->second.back()=newd;//we update the diff for sentence S
} else { //threshold already exists but is not the previous one.
//we append the diffs in previnsert to tit before destroying previnsert
tit->second.insert(tit->second.end(),previnserted->second.begin(),previnserted->second.end());
assert(tit->second.back().first==newd.first);
tit->second.back()=newd;//change diff for sentence S
thresholdmap.erase(previnserted);//erase old previnsert
previnserted=tit;//point previnsert to the new threshold
}
}
assert(previnserted != thresholdmap.end());
}else{//normal insertion process
previnserted=AddThreshold(thresholdmap,leftmostx,newd);
assert(previnserted != thresholdmap.end());
} else { //normal insertion process
previnserted=AddThreshold(thresholdmap,leftmostx,newd);
}
gradientit=leftmost;
} //while(gradientit!=gradient.end()){
} //loop on S
//now the thresholdlist is up to date:
//now the thresholdlist is up to date:
//it contains a list of all the parameter_ts where the function changed its value, along with the nbest list for the interval after each threshold
map<float,diff_t >::iterator thrit;
if(verboselevel()>6){
if(verboselevel()>6) {
cerr << "Thresholds:(" <<thresholdmap.size()<<")"<< endl;
for (thrit = thresholdmap.begin();thrit!=thresholdmap.end();thrit++){
for (thrit = thresholdmap.begin(); thrit!=thresholdmap.end(); thrit++) {
cerr << "x: " << thrit->first << " diffs";
for (size_t j = 0; j < thrit->second.size(); ++j) {
cerr << " " <<thrit->second[j].first << "," << thrit->second[j].second;
cerr << " " <<thrit->second[j].first << "," << thrit->second[j].second;
}
cerr << endl;
}
}
//last thing to do is compute the Stat score (ie BLEU) and find the minimum
thrit=thresholdmap.begin();
++thrit;//first diff corrrespond to MIN_FLOAT and first1best
diffs_t diffs;
for(;thrit!=thresholdmap.end();thrit++)
for(; thrit!=thresholdmap.end(); thrit++)
diffs.push_back(thrit->second);
vector<statscore_t> scores=GetIncStatScore(first1best,diffs);
thrit=thresholdmap.begin();
statscore_t bestscore=MIN_FLOAT;
float bestx=MIN_FLOAT;
assert(scores.size()==thresholdmap.size());//we skipped the first el of thresholdlist but GetIncStatScore return 1 more for first1best
for(unsigned int sc=0;sc!=scores.size();sc++){
for(unsigned int sc=0; sc!=scores.size(); sc++) {
//cerr << "x=" << thrit->first << " => " << scores[sc] << endl;
if (scores[sc] > bestscore) {
//This is the score for the interval [lit2->first, (lit2+1)->first]
//unless we're at the last score, when it's the score
//for the interval [lit2->first,+inf]
bestscore = scores[sc];
//This is the score for the interval [lit2->first, (lit2+1)->first]
//unless we're at the last score, when it's the score
//for the interval [lit2->first,+inf]
bestscore = scores[sc];
//if we're not in [-inf,x1] or [xn,+inf] then just take the value
//if x which splits the interval in half. For the rightmost interval,
//take x to be the last interval boundary + 0.1, and for the leftmost
//interval, take x to be the first interval boundary - 1000.
//These values are taken from cmert.
float leftx = thrit->first;
if (thrit == thresholdmap.begin()) {
leftx = MIN_FLOAT;
}
++thrit;
float rightx = MAX_FLOAT;
if (thrit != thresholdmap.end()) {
rightx = thrit->first;
}
--thrit;
//cerr << "leftx: " << leftx << " rightx: " << rightx << endl;
if (leftx == MIN_FLOAT) {
bestx = rightx-1000;
} else if (rightx == MAX_FLOAT) {
bestx = leftx+0.1;
} else {
bestx = 0.5 * (rightx + leftx);
}
//cerr << "x = " << "set new bestx to: " << bestx << endl;
//if we're not in [-inf,x1] or [xn,+inf] then just take the value
//if x which splits the interval in half. For the rightmost interval,
//take x to be the last interval boundary + 0.1, and for the leftmost
//interval, take x to be the first interval boundary - 1000.
//These values are taken from cmert.
float leftx = thrit->first;
if (thrit == thresholdmap.begin()) {
leftx = MIN_FLOAT;
}
++thrit;
float rightx = MAX_FLOAT;
if (thrit != thresholdmap.end()) {
rightx = thrit->first;
}
--thrit;
//cerr << "leftx: " << leftx << " rightx: " << rightx << endl;
if (leftx == MIN_FLOAT) {
bestx = rightx-1000;
} else if (rightx == MAX_FLOAT) {
bestx = leftx+0.1;
} else {
bestx = 0.5 * (rightx + leftx);
}
//cerr << "x = " << "set new bestx to: " << bestx << endl;
}
++thrit;
}
if(abs(bestx)<0.00015){
if(abs(bestx)<0.00015) {
bestx=0.0;//the origin of the line is the best point!we put it back at 0 so we do not propagate rounding erros
//finally! we manage to extract the best score;
//now we convert bestx (position on the line) to a point!
//finally! we manage to extract the best score;
//now we convert bestx (position on the line) to a point!
if(verboselevel()>4)
cerr<<"best point on line at origin"<<endl;
}
if(verboselevel()>3){
if(verboselevel()>3) {
// cerr<<"end Lineopt, bestx="<<bestx<<endl;
}
bestpoint=direction*bestx+origin;
}
bestpoint=direction*bestx+origin;
bestpoint.score=bestscore;
return bestscore;
return bestscore;
};
void Optimizer::Get1bests(const Point& P,vector<unsigned>& bests)const{
void Optimizer::Get1bests(const Point& P,vector<unsigned>& bests)const
{
assert(FData);
bests.clear();
bests.resize(size());
for(unsigned i=0;i<size();i++){
for(unsigned i=0; i<size(); i++) {
float bestfs=MIN_FLOAT;
unsigned idx=0;
unsigned j;
for(j=0;j<FData->get(i).size();j++){
for(j=0; j<FData->get(i).size(); j++) {
float curfs=P*FData->get(i,j);
if(curfs>bestfs){
bestfs=curfs;
idx=j;
if(curfs>bestfs) {
bestfs=curfs;
idx=j;
}
}
bests[i]=idx;
}
}
statscore_t Optimizer::Run(Point& P)const{
if(!FData){
statscore_t Optimizer::Run(Point& P)const
{
if(!FData) {
cerr<<"error trying to optimize without Features loaded"<<endl;
exit(2);
}
if(!scorer){
if(!scorer) {
cerr<<"error trying to optimize without a Scorer loaded"<<endl;
exit(2);
}
if (scorer->getReferenceSize()!=FData->size()){
if (scorer->getReferenceSize()!=FData->size()) {
cerr<<"error length mismatch between feature file and score file"<<endl;
exit(2);
}
statscore_t score=GetStatScore(P);
P.score=score;
if(verboselevel()>2)
statscore_t score=GetStatScore(P);
P.score=score;
if(verboselevel()>2)
cerr<<"Starting point: "<< P << " => "<< P.score << endl;
statscore_t s=TrueRun(P);
P.score=s;//just in case its not done in TrueRun
@ -331,9 +341,10 @@ statscore_t Optimizer::Run(Point& P)const{
cerr<<"Ending point: "<< P <<" => "<< s << endl;
return s;
}
vector<statscore_t> Optimizer::GetIncStatScore(vector<unsigned> thefirst,vector<vector <pair<unsigned,unsigned> > > thediffs)const{
vector<statscore_t> Optimizer::GetIncStatScore(vector<unsigned> thefirst,vector<vector <pair<unsigned,unsigned> > > thediffs)const
{
assert(scorer);
vector<statscore_t> theres;
@ -347,61 +358,62 @@ vector<statscore_t> Optimizer::GetIncStatScore(vector<unsigned> thefirst,vector<
//---------------- code for the powell optimizer
float SimpleOptimizer::eps=0.0001;
statscore_t SimpleOptimizer::TrueRun(Point& P)const{
statscore_t SimpleOptimizer::TrueRun(Point& P)const
{
statscore_t prevscore=0;
statscore_t bestscore=MIN_FLOAT;
Point best;
//If P is already defined and provides a score
//If P is already defined and provides a score
//we must improve over this score
if(P.score>bestscore){
bestscore=P.score;
best=P;
}
if(P.score>bestscore) {
bestscore=P.score;
best=P;
}
int nrun=0;
do{
++nrun;
do {
++nrun;
if(verboselevel()>2&&nrun>1)
cerr<<"last diff="<<bestscore-prevscore<<" nrun "<<nrun<<endl;
prevscore=bestscore;
Point linebest;
for(unsigned int d=0;d<Point::getdim();d++){
if(verboselevel()>4){
// cerr<<"minimizing along direction "<<d<<endl;
cerr<<"starting point: " << P << " => " << prevscore << endl;
for(unsigned int d=0; d<Point::getdim(); d++) {
if(verboselevel()>4) {
// cerr<<"minimizing along direction "<<d<<endl;
cerr<<"starting point: " << P << " => " << prevscore << endl;
}
Point direction;
for(unsigned int i=0;i<Point::getdim();i++)
direction[i];
for(unsigned int i=0; i<Point::getdim(); i++)
direction[i];
direction[d]=1.0;
statscore_t curscore=LineOptimize(P,direction,linebest);//find the minimum on the line
if(verboselevel()>5){
cerr<<"direction: "<< d << " => " << curscore << endl;
cerr<<"\tending point: "<< linebest << " => " << curscore << endl;
}
if(curscore>bestscore){
bestscore=curscore;
best=linebest;
if(verboselevel()>3){
cerr<<"new best dir:"<<d<<" ("<<nrun<<")"<<endl;
cerr<<"new best Point "<<best<< " => " <<curscore<<endl;
}
}
if(verboselevel()>5) {
cerr<<"direction: "<< d << " => " << curscore << endl;
cerr<<"\tending point: "<< linebest << " => " << curscore << endl;
}
if(curscore>bestscore) {
bestscore=curscore;
best=linebest;
if(verboselevel()>3) {
cerr<<"new best dir:"<<d<<" ("<<nrun<<")"<<endl;
cerr<<"new best Point "<<best<< " => " <<curscore<<endl;
}
}
}
P=best;//update the current vector with the best point on all line tested
if(verboselevel()>3)
cerr<<nrun<<"\t"<<P<<endl;
}while(bestscore-prevscore>eps);
if(verboselevel()>2){
if(verboselevel()>3)
cerr<<nrun<<"\t"<<P<<endl;
} while(bestscore-prevscore>eps);
if(verboselevel()>2) {
cerr<<"end Powell Algo, nrun="<<nrun<<endl;
cerr<<"last diff="<<bestscore-prevscore<<endl;
cerr<<"\t"<<P<<endl;
}
}
return bestscore;
}
@ -409,58 +421,63 @@ statscore_t SimpleOptimizer::TrueRun(Point& P)const{
/**RandomOptimizer to use as beaseline and test.\n
Just return a random point*/
statscore_t RandomOptimizer::TrueRun(Point& P)const{
statscore_t RandomOptimizer::TrueRun(Point& P)const
{
vector<parameter_t> min(Point::getdim());
vector<parameter_t> max(Point::getdim());
for(unsigned int d=0;d<Point::getdim();d++){
for(unsigned int d=0; d<Point::getdim(); d++) {
min[d]=0.0;
max[d]=1.0;
}
P.Randomize(min,max);
statscore_t score=GetStatScore(P);
P.score=score;
return score;
P.Randomize(min,max);
statscore_t score=GetStatScore(P);
P.score=score;
return score;
}
//--------------------------------------
vector<string> OptimizerFactory::typenames;
void OptimizerFactory::SetTypeNames(){
if(typenames.empty()){
void OptimizerFactory::SetTypeNames()
{
if(typenames.empty()) {
typenames.resize(NOPTIMIZER);
typenames[POWELL]="powell";
typenames[RANDOM]="random";
//add new type there
}
}
}
vector<string> OptimizerFactory::GetTypeNames(){
vector<string> OptimizerFactory::GetTypeNames()
{
if(typenames.empty())
SetTypeNames();
return typenames;
}
OptimizerFactory::OptType OptimizerFactory::GetOType(string type){
OptimizerFactory::OptType OptimizerFactory::GetOType(string type)
{
unsigned int thetype;
if(typenames.empty())
SetTypeNames();
for(thetype=0;thetype<typenames.size();thetype++)
for(thetype=0; thetype<typenames.size(); thetype++)
if(typenames[thetype]==type)
break;
return((OptType)thetype);
};
Optimizer* OptimizerFactory::BuildOptimizer(unsigned dim,vector<unsigned> i2o,vector<parameter_t> start,string type){
Optimizer* OptimizerFactory::BuildOptimizer(unsigned dim,vector<unsigned> i2o,vector<parameter_t> start,string type)
{
OptType T=GetOType(type);
if(T==NOPTIMIZER){
if(T==NOPTIMIZER) {
cerr<<"Error: unknown Optimizer type "<<type<<endl;
cerr<<"Known Algorithm are:"<<endl;
unsigned int thetype;
for(thetype=0;thetype<typenames.size();thetype++)
for(thetype=0; thetype<typenames.size(); thetype++)
cerr<<typenames[thetype]<<endl;
throw ("unknown Optimizer Type");
}
switch((OptType)T){
switch((OptType)T) {
case POWELL:
return new SimpleOptimizer(dim,i2o,start);
break;
@ -469,6 +486,6 @@ Optimizer* OptimizerFactory::BuildOptimizer(unsigned dim,vector<unsigned> i2o,ve
break;
default:
cerr<<"Error: unknown optimizer"<<type<<endl;
return NULL;
}
return NULL;
}
}

View File

@ -15,61 +15,69 @@ typedef float featurescore;
using namespace std;
/**abstract virtual class*/
class Optimizer{
protected:
Scorer * scorer; //no accessor for them only child can use them
FeatureData * FData;//no accessor for them only child can use them
public:
class Optimizer
{
protected:
Scorer * scorer; //no accessor for them only child can use them
FeatureData * FData;//no accessor for them only child can use them
public:
Optimizer(unsigned Pd,vector<unsigned> i2O,vector<parameter_t> start);
void SetScorer(Scorer *S);
void SetFData(FeatureData *F);
virtual ~Optimizer();
unsigned size()const{return (FData?FData->size():0);}
unsigned size()const {
return (FData?FData->size():0);
}
/**Generic wrapper around TrueRun to check a few things. Non virtual*/
statscore_t Run(Point&)const;
/**main function that perform an optimization*/
/**main function that perform an optimization*/
virtual statscore_t TrueRun(Point&)const=0;
/**given a set of lambdas, get the nbest for each sentence*/
void Get1bests(const Point& param,vector<unsigned>& bests)const;
/**given a set of nbests, get the Statistical score*/
statscore_t GetStatScore(const vector<unsigned>& nbests)const{return scorer->score(nbests);};
statscore_t GetStatScore(const vector<unsigned>& nbests)const {
return scorer->score(nbests);
};
/**given a set of lambdas, get the total statistical score*/
statscore_t GetStatScore(const Point& param)const;
statscore_t GetStatScore(const Point& param)const;
vector<statscore_t > GetIncStatScore(vector<unsigned> ref,vector<vector <pair<unsigned,unsigned> > >)const;
statscore_t LineOptimize(const Point& start,const Point& direction,Point& best)const;//Get the optimal Lambda and the best score in a particular direction from a given Point
};
/**default basic optimizer*/
class SimpleOptimizer: public Optimizer{
class SimpleOptimizer: public Optimizer
{
private:
static float eps;
static float eps;
public:
SimpleOptimizer(unsigned dim,vector<unsigned> i2O,vector<parameter_t> start):Optimizer(dim,i2O,start){};
SimpleOptimizer(unsigned dim,vector<unsigned> i2O,vector<parameter_t> start):Optimizer(dim,i2O,start) {};
virtual statscore_t TrueRun(Point&)const;
};
class RandomOptimizer: public Optimizer{
class RandomOptimizer: public Optimizer
{
public:
RandomOptimizer(unsigned dim,vector<unsigned> i2O,vector<parameter_t> start):Optimizer(dim,i2O,start){};
RandomOptimizer(unsigned dim,vector<unsigned> i2O,vector<parameter_t> start):Optimizer(dim,i2O,start) {};
virtual statscore_t TrueRun(Point&)const;
};
class OptimizerFactory{
public:
class OptimizerFactory
{
public:
// unsigned dim;
//Point Start;
static vector<string> GetTypeNames();
static Optimizer* BuildOptimizer(unsigned dim,vector<unsigned>tooptimize,vector<parameter_t> start,string type);
private:
enum OptType{POWELL=0,RANDOM,NOPTIMIZER};//Add new optimizer here BEFORE NOPTIMZER
private:
enum OptType {POWELL=0,RANDOM,NOPTIMIZER}; //Add new optimizer here BEFORE NOPTIMZER
static OptType GetOType(string);
static vector<string> typenames;
static void SetTypeNames();
};

View File

@ -1,69 +1,72 @@
#include "PerScorer.h"
void PerScorer::setReferenceFiles(const vector<string>& referenceFiles) {
// for each line in the reference file, create a multiset of the
// word ids
if (referenceFiles.size() != 1) {
throw runtime_error("PER only supports a single reference");
void PerScorer::setReferenceFiles(const vector<string>& referenceFiles)
{
// for each line in the reference file, create a multiset of the
// word ids
if (referenceFiles.size() != 1) {
throw runtime_error("PER only supports a single reference");
}
_reftokens.clear();
_reflengths.clear();
ifstream in(referenceFiles[0].c_str());
if (!in) {
throw runtime_error("Unable to open " + referenceFiles[0]);
}
string line;
int sid = 0;
while (getline(in,line)) {
vector<int> tokens;
encode(line,tokens);
_reftokens.push_back(multiset<int>());
for (size_t i = 0; i < tokens.size(); ++i) {
_reftokens.back().insert(tokens[i]);
}
_reftokens.clear();
_reflengths.clear();
ifstream in(referenceFiles[0].c_str());
if (!in) {
throw runtime_error("Unable to open " + referenceFiles[0]);
_reflengths.push_back(tokens.size());
if (sid > 0 && sid % 100 == 0) {
TRACE_ERR(".");
}
string line;
int sid = 0;
while (getline(in,line)) {
vector<int> tokens;
encode(line,tokens);
_reftokens.push_back(multiset<int>());
for (size_t i = 0; i < tokens.size(); ++i) {
_reftokens.back().insert(tokens[i]);
}
_reflengths.push_back(tokens.size());
if (sid > 0 && sid % 100 == 0) {
TRACE_ERR(".");
}
++sid;
}
TRACE_ERR(endl);
++sid;
}
TRACE_ERR(endl);
}
void PerScorer::prepareStats(size_t sid, const string& text, ScoreStats& entry) {
if (sid >= _reflengths.size()) {
stringstream msg;
msg << "Sentence id (" << sid << ") not found in reference set";
throw runtime_error(msg.str());
}
//calculate correct, output_length and ref_length for
//the line and store it in entry
vector<int> testtokens;
encode(text,testtokens);
multiset<int> testtokens_all(testtokens.begin(),testtokens.end());
set<int> testtokens_unique(testtokens.begin(),testtokens.end());
int correct = 0;
for (set<int>::iterator i = testtokens_unique.begin();
i != testtokens_unique.end(); ++i) {
int token = *i;
correct += min(_reftokens[sid].count(token), testtokens_all.count(token));
}
ostringstream stats;
stats << correct << " " << testtokens.size() << " " << _reflengths[sid] << " " ;
string stats_str = stats.str();
entry.set(stats_str);
void PerScorer::prepareStats(size_t sid, const string& text, ScoreStats& entry)
{
if (sid >= _reflengths.size()) {
stringstream msg;
msg << "Sentence id (" << sid << ") not found in reference set";
throw runtime_error(msg.str());
}
//calculate correct, output_length and ref_length for
//the line and store it in entry
vector<int> testtokens;
encode(text,testtokens);
multiset<int> testtokens_all(testtokens.begin(),testtokens.end());
set<int> testtokens_unique(testtokens.begin(),testtokens.end());
int correct = 0;
for (set<int>::iterator i = testtokens_unique.begin();
i != testtokens_unique.end(); ++i) {
int token = *i;
correct += min(_reftokens[sid].count(token), testtokens_all.count(token));
}
ostringstream stats;
stats << correct << " " << testtokens.size() << " " << _reflengths[sid] << " " ;
string stats_str = stats.str();
entry.set(stats_str);
}
float PerScorer::calculateScore(const vector<int>& comps) {
float denom = comps[2];
float num = comps[0] - max(0,comps[1]-comps[2]);
if (denom == 0) {
//shouldn't happen!
return 0.0;
} else {
return num/denom;
}
float PerScorer::calculateScore(const vector<int>& comps)
{
float denom = comps[2];
float num = comps[0] - max(0,comps[1]-comps[2]);
if (denom == 0) {
//shouldn't happen!
return 0.0;
} else {
return num/denom;
}
}

View File

@ -22,32 +22,36 @@ using namespace std;
* as 1 - (correct - max(0,output_length - ref_length)) / ref_length
* In fact, we ignore the " 1 - " so that it can be maximised.
**/
class PerScorer: public StatisticsBasedScorer {
public:
PerScorer(const string& config = "") : StatisticsBasedScorer("PER",config) {}
virtual void setReferenceFiles(const vector<string>& referenceFiles);
virtual void prepareStats(size_t sid, const string& text, ScoreStats& entry);
virtual void whoami() {
cerr << "I AM PerScorer" << std::endl;
}
size_t NumberOfScores(){ cerr << "PerScorer: 3" << endl; return 3; };
protected:
virtual float calculateScore(const vector<int>& comps) ;
private:
//no copy
PerScorer(const PerScorer&);
~PerScorer(){};
PerScorer& operator=(const PerScorer&);
// data extracted from reference files
vector<size_t> _reflengths;
vector<multiset<int> > _reftokens;
class PerScorer: public StatisticsBasedScorer
{
public:
PerScorer(const string& config = "") : StatisticsBasedScorer("PER",config) {}
virtual void setReferenceFiles(const vector<string>& referenceFiles);
virtual void prepareStats(size_t sid, const string& text, ScoreStats& entry);
virtual void whoami() {
cerr << "I AM PerScorer" << std::endl;
}
size_t NumberOfScores() {
cerr << "PerScorer: 3" << endl;
return 3;
};
protected:
virtual float calculateScore(const vector<int>& comps) ;
private:
//no copy
PerScorer(const PerScorer&);
~PerScorer() {};
PerScorer& operator=(const PerScorer&);
// data extracted from reference files
vector<size_t> _reflengths;
vector<multiset<int> > _reftokens;
};
#endif //__PERSCORER_H

View File

@ -10,22 +10,24 @@ vector<unsigned> Point::optindices;
unsigned Point::dim=0;
map<unsigned,statscore_t> Point::fixedweights;
unsigned Point::pdim=0;
unsigned Point::ncall=0;
void Point::Randomize(const vector<parameter_t>& min,const vector<parameter_t>& max){
void Point::Randomize(const vector<parameter_t>& min,const vector<parameter_t>& max)
{
assert(min.size()==Point::dim);
assert(max.size()==Point::dim);
for (unsigned int i=0; i<size(); i++)
operator[](i)= min[i] + (float)random()/(float)RAND_MAX * (float)(max[i]-min[i]);
}
void Point::NormalizeL2(){
void Point::NormalizeL2()
{
parameter_t norm=0.0;
for (unsigned int i=0; i<size(); i++)
norm+= operator[](i)*operator[](i);
if(norm!=0.0){
if(norm!=0.0) {
norm=sqrt(norm);
for (unsigned int i=0; i<size(); i++)
operator[](i)/=norm;
@ -33,22 +35,24 @@ void Point::NormalizeL2(){
}
void Point::NormalizeL1(){
void Point::NormalizeL1()
{
parameter_t norm=0.0;
for (unsigned int i=0; i<size(); i++)
norm+= abs(operator[](i));
if(norm!=0.0){
for (unsigned int i=0; i<size(); i++)
operator[](i)/=norm;
}
if(norm!=0.0) {
for (unsigned int i=0; i<size(); i++)
operator[](i)/=norm;
}
}
//Can initialize from a vector of dim or pdim
Point::Point(const vector<parameter_t>& init):vector<parameter_t>(Point::dim){
if(init.size()==dim){
Point::Point(const vector<parameter_t>& init):vector<parameter_t>(Point::dim)
{
if(init.size()==dim) {
for (unsigned int i=0; i<Point::dim; i++)
operator[](i)=init[i];
}else{
} else {
assert(init.size()==pdim);
for (unsigned int i=0; i<Point::dim; i++)
operator[](i)=init[optindices[i]];
@ -56,59 +60,64 @@ Point::Point(const vector<parameter_t>& init):vector<parameter_t>(Point::dim){
};
double Point::operator*(const FeatureStats& F)const{
double Point::operator*(const FeatureStats& F)const
{
ncall++;//to track performance
double prod=0.0;
if(OptimizeAll())
for (unsigned i=0; i<size(); i++)
prod+= operator[](i)*F.get(i);
else{
else {
for (unsigned i=0; i<size(); i++)
prod+= operator[](i)*F.get(optindices[i]);
for(map<unsigned,float >::iterator it=fixedweights.begin();it!=fixedweights.end();it++)
for(map<unsigned,float >::iterator it=fixedweights.begin(); it!=fixedweights.end(); it++)
prod+=it->second*F.get(it->first);
}
return prod;
}
Point Point::operator+(const Point& p2)const{
Point Point::operator+(const Point& p2)const
{
assert(p2.size()==size());
Point Res(*this);
for(unsigned i=0;i<size();i++)
for(unsigned i=0; i<size(); i++)
Res[i]+=p2[i];
Res.score=numeric_limits<statscore_t>::max();
return Res;
};
Point Point::operator*(float l)const{
Point Point::operator*(float l)const
{
Point Res(*this);
for(unsigned i=0;i<size();i++)
for(unsigned i=0; i<size(); i++)
Res[i]*=l;
Res.score=numeric_limits<statscore_t>::max();
return Res;
};
ostream& operator<<(ostream& o,const Point& P){
vector<parameter_t> w=P.GetAllWeights();
ostream& operator<<(ostream& o,const Point& P)
{
vector<parameter_t> w=P.GetAllWeights();
// o << "[" << Point::pdim << "] ";
for(unsigned int i=0;i<Point::pdim;i++)
o << w[i] << " ";
for(unsigned int i=0; i<Point::pdim; i++)
o << w[i] << " ";
// o << "=> " << P.GetScore();
return o;
return o;
};
vector<parameter_t> Point::GetAllWeights()const{
vector<parameter_t> Point::GetAllWeights()const
{
vector<parameter_t> w;
if(OptimizeAll()){
if(OptimizeAll()) {
w=*this;
}else{
} else {
w.resize(pdim);
for (unsigned int i=0; i<size(); i++)
w[optindices[i]]=operator[](i);
for(map<unsigned,float >::iterator it=fixedweights.begin();it!=fixedweights.end();it++)
w[it->first]=it->second;
for(map<unsigned,float >::iterator it=fixedweights.begin(); it!=fixedweights.end(); it++)
w[it->first]=it->second;
}
return w;
};

View File

@ -10,9 +10,10 @@ class Optimizer;
/**class that handle the subset of the Feature weight on which we run the optimization*/
class Point:public vector<parameter_t>{
class Point:public vector<parameter_t>
{
friend class Optimizer;
private:
private:
/**The indices over which we optimize*/
static vector<unsigned int> optindices;
/**dimension of optindices and of the parent vector*/
@ -22,12 +23,18 @@ class Point:public vector<parameter_t>{
/**total size of the parameter space; we have pdim=FixedWeight.size()+optinidices.size()*/
static unsigned int pdim;
static unsigned int ncall;
public:
static unsigned int getdim(){return dim;}
static unsigned int getpdim(){return pdim;}
static bool OptimizeAll(){return fixedweights.empty();};
public:
static unsigned int getdim() {
return dim;
}
static unsigned int getpdim() {
return pdim;
}
static bool OptimizeAll() {
return fixedweights.empty();
};
statscore_t score;
Point():vector<parameter_t>(dim){};
Point():vector<parameter_t>(dim) {};
Point(const vector<parameter_t>& init);
void Randomize(const vector<parameter_t>& min,const vector<parameter_t>& max);
@ -36,12 +43,16 @@ class Point:public vector<parameter_t>{
Point operator*(float)const;
/**write the Whole featureweight to a stream (ie pdim float)*/
friend ostream& operator<<(ostream& o,const Point& P);
void Normalize(){ NormalizeL2(); };
void Normalize() {
NormalizeL2();
};
void NormalizeL2();
void NormalizeL1();
/**return a vector of size pdim where all weights have been put(including fixed ones)*/
vector<parameter_t> GetAllWeights()const;
statscore_t GetScore()const { return score; };
statscore_t GetScore()const {
return score;
};
};
#endif

View File

@ -15,134 +15,134 @@ ScoreArray::ScoreArray(): idx("")
void ScoreArray::savetxt(std::ofstream& outFile, const std::string& sctype)
{
outFile << SCORES_TXT_BEGIN << " " << idx << " " << array_.size()
<< " " << number_of_scores << " " << sctype << std::endl;
for (scorearray_t::iterator i = array_.begin(); i !=array_.end(); i++){
i->savetxt(outFile);
outFile << std::endl;
}
outFile << SCORES_TXT_END << std::endl;
outFile << SCORES_TXT_BEGIN << " " << idx << " " << array_.size()
<< " " << number_of_scores << " " << sctype << std::endl;
for (scorearray_t::iterator i = array_.begin(); i !=array_.end(); i++) {
i->savetxt(outFile);
outFile << std::endl;
}
outFile << SCORES_TXT_END << std::endl;
}
void ScoreArray::savebin(std::ofstream& outFile, const std::string& sctype)
{
outFile << SCORES_BIN_BEGIN << " " << idx << " " << array_.size()
<< " " << number_of_scores << " " << sctype << std::endl;
for (scorearray_t::iterator i = array_.begin(); i !=array_.end(); i++)
i->savebin(outFile);
outFile << SCORES_BIN_END << std::endl;
outFile << SCORES_BIN_BEGIN << " " << idx << " " << array_.size()
<< " " << number_of_scores << " " << sctype << std::endl;
for (scorearray_t::iterator i = array_.begin(); i !=array_.end(); i++)
i->savebin(outFile);
outFile << SCORES_BIN_END << std::endl;
}
void ScoreArray::save(std::ofstream& inFile, const std::string& sctype, bool bin)
{
if (size()>0)
(bin)?savebin(inFile, sctype):savetxt(inFile, sctype);
if (size()>0)
(bin)?savebin(inFile, sctype):savetxt(inFile, sctype);
}
void ScoreArray::save(const std::string &file, const std::string& sctype, bool bin)
{
std::ofstream outFile(file.c_str(), std::ios::out); // matches a stream with a file. Opens the file
std::ofstream outFile(file.c_str(), std::ios::out); // matches a stream with a file. Opens the file
save(outFile, sctype, bin);
save(outFile, sctype, bin);
outFile.close();
outFile.close();
}
void ScoreArray::loadbin(ifstream& inFile, size_t n)
{
ScoreStats entry(number_of_scores);
for (size_t i=0 ; i < n; i++){
entry.loadbin(inFile);
add(entry);
}
ScoreStats entry(number_of_scores);
for (size_t i=0 ; i < n; i++) {
entry.loadbin(inFile);
add(entry);
}
}
void ScoreArray::loadtxt(ifstream& inFile, size_t n)
{
ScoreStats entry(number_of_scores);
for (size_t i=0 ; i < n; i++){
entry.loadtxt(inFile);
add(entry);
}
ScoreStats entry(number_of_scores);
for (size_t i=0 ; i < n; i++) {
entry.loadtxt(inFile);
add(entry);
}
}
void ScoreArray::load(ifstream& inFile)
{
size_t number_of_entries=0;
bool binmode=false;
std::string substring, stringBuf;
bool binmode=false;
std::string substring, stringBuf;
std::string::size_type loc;
std::getline(inFile, stringBuf);
if (!inFile.good()){
return;
}
if (!stringBuf.empty()){
if ((loc = stringBuf.find(SCORES_TXT_BEGIN)) == 0){
binmode=false;
}else if ((loc = stringBuf.find(SCORES_BIN_BEGIN)) == 0){
binmode=true;
}else{
TRACE_ERR("ERROR: ScoreArray::load(): Wrong header");
return;
}
getNextPound(stringBuf, substring);
getNextPound(stringBuf, substring);
idx = substring;
getNextPound(stringBuf, substring);
std::getline(inFile, stringBuf);
if (!inFile.good()) {
return;
}
if (!stringBuf.empty()) {
if ((loc = stringBuf.find(SCORES_TXT_BEGIN)) == 0) {
binmode=false;
} else if ((loc = stringBuf.find(SCORES_BIN_BEGIN)) == 0) {
binmode=true;
} else {
TRACE_ERR("ERROR: ScoreArray::load(): Wrong header");
return;
}
getNextPound(stringBuf, substring);
getNextPound(stringBuf, substring);
idx = substring;
getNextPound(stringBuf, substring);
number_of_entries = atoi(substring.c_str());
getNextPound(stringBuf, substring);
getNextPound(stringBuf, substring);
number_of_scores = atoi(substring.c_str());
getNextPound(stringBuf, substring);
score_type = substring;
}
(binmode)?loadbin(inFile, number_of_entries):loadtxt(inFile, number_of_entries);
std::getline(inFile, stringBuf);
if (!stringBuf.empty()){
if ((loc = stringBuf.find(SCORES_TXT_END)) != 0 && (loc = stringBuf.find(SCORES_BIN_END)) != 0){
TRACE_ERR("ERROR: ScoreArray::load(): Wrong footer");
return;
}
}
getNextPound(stringBuf, substring);
score_type = substring;
}
(binmode)?loadbin(inFile, number_of_entries):loadtxt(inFile, number_of_entries);
std::getline(inFile, stringBuf);
if (!stringBuf.empty()) {
if ((loc = stringBuf.find(SCORES_TXT_END)) != 0 && (loc = stringBuf.find(SCORES_BIN_END)) != 0) {
TRACE_ERR("ERROR: ScoreArray::load(): Wrong footer");
return;
}
}
}
void ScoreArray::load(const std::string &file)
{
TRACE_ERR("loading data from " << file << std::endl);
TRACE_ERR("loading data from " << file << std::endl);
inputfilestream inFile(file); // matches a stream with a file. Opens the file
inputfilestream inFile(file); // matches a stream with a file. Opens the file
load((ifstream&) inFile);
load((ifstream&) inFile);
inFile.close();
inFile.close();
}
void ScoreArray::merge(ScoreArray& e)
{
//dummy implementation
for (size_t i=0; i<e.size(); i++)
add(e.get(i));
//dummy implementation
for (size_t i=0; i<e.size(); i++)
add(e.get(i));
}
bool ScoreArray::check_consistency()
{
size_t sz = NumberOfScores();
if (sz == 0)
return true;
for (scorearray_t::iterator i=array_.begin(); i!=array_.end(); i++)
if (i->size()!=sz)
return false;
return true;
size_t sz = NumberOfScores();
if (sz == 0)
return true;
for (scorearray_t::iterator i=array_.begin(); i!=array_.end(); i++)
if (i->size()!=sz)
return false;
return true;
}

View File

@ -27,52 +27,76 @@ using namespace std;
class ScoreArray
{
protected:
scorearray_t array_;
std::string score_type;
size_t number_of_scores;
private:
std::string idx; // idx to identify the utterance, it can differ from the index inside the vector
scorearray_t array_;
std::string score_type;
size_t number_of_scores;
private:
std::string idx; // idx to identify the utterance, it can differ from the index inside the vector
public:
ScoreArray();
~ScoreArray(){};
inline void clear() { array_.clear(); }
inline std::string getIndex(){ return idx; }
inline void setIndex(const std::string& value){ idx=value; }
ScoreArray();
~ScoreArray() {};
inline void clear() {
array_.clear();
}
inline std::string getIndex() {
return idx;
}
inline void setIndex(const std::string& value) {
idx=value;
}
// inline ScoreStats get(size_t i){ return array_.at(i); }
inline ScoreStats& get(size_t i){ return array_.at(i); }
inline const ScoreStats& get(size_t i)const{ return array_.at(i); }
void add(const ScoreStats& e){ array_.push_back(e); }
inline ScoreStats& get(size_t i) {
return array_.at(i);
}
inline const ScoreStats& get(size_t i)const {
return array_.at(i);
}
void merge(ScoreArray& e);
void add(const ScoreStats& e) {
array_.push_back(e);
}
inline std::string name() const{ return score_type; };
inline void name(std::string &sctype){ score_type = sctype; };
void merge(ScoreArray& e);
inline size_t size(){ return array_.size(); }
inline size_t NumberOfScores() const{ return number_of_scores; }
inline void NumberOfScores(size_t v){ number_of_scores = v; }
void savetxt(ofstream& outFile, const std::string& sctype);
void savebin(ofstream& outFile, const std::string& sctype);
void save(ofstream& outFile, const std::string& sctype, bool bin=false);
void save(const std::string &file, const std::string& sctype, bool bin=false);
inline void save(const std::string& sctype, bool bin=false){ save("/dev/stdout", sctype, bin); }
void loadtxt(ifstream& inFile, size_t n);
void loadbin(ifstream& inFile, size_t n);
void load(ifstream& inFile);
void load(const std::string &file);
bool check_consistency();
inline std::string name() const {
return score_type;
};
inline void name(std::string &sctype) {
score_type = sctype;
};
inline size_t size() {
return array_.size();
}
inline size_t NumberOfScores() const {
return number_of_scores;
}
inline void NumberOfScores(size_t v) {
number_of_scores = v;
}
void savetxt(ofstream& outFile, const std::string& sctype);
void savebin(ofstream& outFile, const std::string& sctype);
void save(ofstream& outFile, const std::string& sctype, bool bin=false);
void save(const std::string &file, const std::string& sctype, bool bin=false);
inline void save(const std::string& sctype, bool bin=false) {
save("/dev/stdout", sctype, bin);
}
void loadtxt(ifstream& inFile, size_t n);
void loadbin(ifstream& inFile, size_t n);
void load(ifstream& inFile);
void load(const std::string &file);
bool check_consistency();
};

View File

@ -13,121 +13,121 @@
ScoreData::ScoreData(Scorer& ptr):
theScorer(&ptr)
theScorer(&ptr)
{
score_type = theScorer->getName();
theScorer->setScoreData(this);//this is not dangerous: we dont use the this pointer in SetScoreData
number_of_scores = theScorer->NumberOfScores();
TRACE_ERR("ScoreData: number_of_scores: " << number_of_scores << std::endl);
score_type = theScorer->getName();
theScorer->setScoreData(this);//this is not dangerous: we dont use the this pointer in SetScoreData
number_of_scores = theScorer->NumberOfScores();
TRACE_ERR("ScoreData: number_of_scores: " << number_of_scores << std::endl);
};
void ScoreData::save(std::ofstream& outFile, bool bin)
{
for (scoredata_t::iterator i = array_.begin(); i !=array_.end(); i++){
i->save(outFile, score_type, bin);
}
for (scoredata_t::iterator i = array_.begin(); i !=array_.end(); i++) {
i->save(outFile, score_type, bin);
}
}
void ScoreData::save(const std::string &file, bool bin)
{
if (file.empty()) return;
TRACE_ERR("saving the array into " << file << std::endl);
if (file.empty()) return;
TRACE_ERR("saving the array into " << file << std::endl);
std::ofstream outFile(file.c_str(), std::ios::out); // matches a stream with a file. Opens the file
std::ofstream outFile(file.c_str(), std::ios::out); // matches a stream with a file. Opens the file
ScoreStats entry;
save(outFile, bin);
save(outFile, bin);
outFile.close();
outFile.close();
}
void ScoreData::load(ifstream& inFile)
{
ScoreArray entry;
while (!inFile.eof()){
if (!inFile.good()){
std::cerr << "ERROR ScoreData::load inFile.good()" << std::endl;
}
entry.clear();
entry.load(inFile);
while (!inFile.eof()) {
if (entry.size() == 0){
break;
}
add(entry);
}
if (!inFile.good()) {
std::cerr << "ERROR ScoreData::load inFile.good()" << std::endl;
}
entry.clear();
entry.load(inFile);
if (entry.size() == 0) {
break;
}
add(entry);
}
}
void ScoreData::load(const std::string &file)
{
TRACE_ERR("loading score data from " << file << std::endl);
TRACE_ERR("loading score data from " << file << std::endl);
inputfilestream inFile(file); // matches a stream with a file. Opens the file
inputfilestream inFile(file); // matches a stream with a file. Opens the file
if (!inFile) {
throw runtime_error("Unable to open score file: " + file);
}
if (!inFile) {
throw runtime_error("Unable to open score file: " + file);
}
load((ifstream&) inFile);
load((ifstream&) inFile);
inFile.close();
inFile.close();
}
void ScoreData::add(ScoreArray& e){
if (exists(e.getIndex())){ // array at position e.getIndex() already exists
//enlarge array at position e.getIndex()
size_t pos = getIndex(e.getIndex());
array_.at(pos).merge(e);
}
else{
array_.push_back(e);
setIndex();
}
void ScoreData::add(ScoreArray& e)
{
if (exists(e.getIndex())) { // array at position e.getIndex() already exists
//enlarge array at position e.getIndex()
size_t pos = getIndex(e.getIndex());
array_.at(pos).merge(e);
} else {
array_.push_back(e);
setIndex();
}
}
void ScoreData::add(const ScoreStats& e, const std::string& sent_idx){
if (exists(sent_idx)){ // array at position e.getIndex() already exists
//enlarge array at position e.getIndex()
size_t pos = getIndex(sent_idx);
// TRACE_ERR("Inserting in array " << sent_idx << std::endl);
array_.at(pos).add(e);
// TRACE_ERR("size: " << size() << " -> " << a.size() << std::endl);
}
else{
// TRACE_ERR("Creating a new entry in the array" << std::endl);
ScoreArray a;
a.NumberOfScores(number_of_scores);
a.add(e);
a.setIndex(sent_idx);
add(a);
// TRACE_ERR("size: " << size() << " -> " << a.size() << std::endl);
}
}
void ScoreData::add(const ScoreStats& e, const std::string& sent_idx)
{
if (exists(sent_idx)) { // array at position e.getIndex() already exists
//enlarge array at position e.getIndex()
size_t pos = getIndex(sent_idx);
// TRACE_ERR("Inserting in array " << sent_idx << std::endl);
array_.at(pos).add(e);
// TRACE_ERR("size: " << size() << " -> " << a.size() << std::endl);
} else {
// TRACE_ERR("Creating a new entry in the array" << std::endl);
ScoreArray a;
a.NumberOfScores(number_of_scores);
a.add(e);
a.setIndex(sent_idx);
add(a);
// TRACE_ERR("size: " << size() << " -> " << a.size() << std::endl);
}
}
bool ScoreData::check_consistency()
{
if (array_.size() == 0)
return true;
for (scoredata_t::iterator i = array_.begin(); i !=array_.end(); i++)
if (!i->check_consistency()) return false;
return true;
if (array_.size() == 0)
return true;
for (scoredata_t::iterator i = array_.begin(); i !=array_.end(); i++)
if (!i->check_consistency()) return false;
return true;
}
void ScoreData::setIndex()
{
size_t j=0;
for (scoredata_t::iterator i = array_.begin(); i !=array_.end(); i++){
idx2arrayname_[j]=i->getIndex();
arrayname2idx_[i->getIndex()]=j;
j++;
}
size_t j=0;
for (scoredata_t::iterator i = array_.begin(); i !=array_.end(); i++) {
idx2arrayname_[j]=i->getIndex();
arrayname2idx_[i->getIndex()]=j;
j++;
}
}

View File

@ -23,64 +23,90 @@ class Scorer;
class ScoreData
{
protected:
scoredata_t array_;
idx2name idx2arrayname_; //map from index to name of array
name2idx arrayname2idx_; //map from name to index of array
scoredata_t array_;
idx2name idx2arrayname_; //map from index to name of array
name2idx arrayname2idx_; //map from name to index of array
private:
Scorer* theScorer;
std::string score_type;
size_t number_of_scores;
Scorer* theScorer;
std::string score_type;
size_t number_of_scores;
public:
ScoreData(Scorer& sc);
~ScoreData(){};
inline void clear() { array_.clear(); }
inline ScoreArray get(const std::string& idx){ return array_.at(getIndex(idx)); }
inline ScoreArray& get(size_t idx){ return array_.at(idx); }
inline const ScoreArray& get(size_t idx) const { return array_.at(idx); }
inline bool exists(const std::string & sent_idx){ return exists(getIndex(sent_idx)); }
inline bool exists(int sent_idx){ return (sent_idx>-1 && sent_idx<(int)array_.size())?true:false; }
inline ScoreStats& get(size_t i, size_t j){ return array_.at(i).get(j); }
inline const ScoreStats& get(size_t i, size_t j) const { return array_.at(i).get(j); }
inline std::string name(){ return score_type; };
inline std::string name(std::string &sctype){ return score_type = sctype; };
ScoreData(Scorer& sc);
void add(ScoreArray& e);
void add(const ScoreStats& e, const std::string& sent_idx);
inline size_t NumberOfScores(){ return number_of_scores; }
inline size_t size(){ return array_.size(); }
void save(const std::string &file, bool bin=false);
void save(ofstream& outFile, bool bin=false);
inline void save(bool bin=false){ save("/dev/stdout", bin); }
~ScoreData() {};
void load(ifstream& inFile);
void load(const std::string &file);
bool check_consistency();
void setIndex();
inline int getIndex(const std::string& idx){
name2idx::iterator i = arrayname2idx_.find(idx);
if (i!=arrayname2idx_.end())
return i->second;
else
return -1;
inline void clear() {
array_.clear();
}
inline ScoreArray get(const std::string& idx) {
return array_.at(getIndex(idx));
}
inline ScoreArray& get(size_t idx) {
return array_.at(idx);
}
inline const ScoreArray& get(size_t idx) const {
return array_.at(idx);
}
inline bool exists(const std::string & sent_idx) {
return exists(getIndex(sent_idx));
}
inline bool exists(int sent_idx) {
return (sent_idx>-1 && sent_idx<(int)array_.size())?true:false;
}
inline ScoreStats& get(size_t i, size_t j) {
return array_.at(i).get(j);
}
inline const ScoreStats& get(size_t i, size_t j) const {
return array_.at(i).get(j);
}
inline std::string name() {
return score_type;
};
inline std::string name(std::string &sctype) {
return score_type = sctype;
};
void add(ScoreArray& e);
void add(const ScoreStats& e, const std::string& sent_idx);
inline size_t NumberOfScores() {
return number_of_scores;
}
inline size_t size() {
return array_.size();
}
void save(const std::string &file, bool bin=false);
void save(ofstream& outFile, bool bin=false);
inline void save(bool bin=false) {
save("/dev/stdout", bin);
}
void load(ifstream& inFile);
void load(const std::string &file);
bool check_consistency();
void setIndex();
inline int getIndex(const std::string& idx) {
name2idx::iterator i = arrayname2idx_.find(idx);
if (i!=arrayname2idx_.end())
return i->second;
else
return -1;
}
inline std::string getIndex(size_t idx) {
idx2name::iterator i = idx2arrayname_.find(idx);
if (i!=idx2arrayname_.end())
throw runtime_error("there is no entry at index " + idx);
return i->second;
}
inline std::string getIndex(size_t idx){
idx2name::iterator i = idx2arrayname_.find(idx);
if (i!=idx2arrayname_.end())
throw runtime_error("there is no entry at index " + idx);
return i->second;
}
};

View File

@ -14,123 +14,124 @@
ScoreStats::ScoreStats()
{
available_ = AVAILABLE_;
entries_ = 0;
array_ = new ScoreStatsType[available_];
available_ = AVAILABLE_;
entries_ = 0;
array_ = new ScoreStatsType[available_];
};
ScoreStats::~ScoreStats()
{
delete array_;
delete array_;
};
ScoreStats::ScoreStats(const ScoreStats &stats)
ScoreStats::ScoreStats(const ScoreStats &stats)
{
available_ = stats.available();
entries_ = stats.size();
array_ = new ScoreStatsType[available_];
memcpy(array_,stats.getArray(),scorebytes_);
available_ = stats.available();
entries_ = stats.size();
array_ = new ScoreStatsType[available_];
memcpy(array_,stats.getArray(),scorebytes_);
};
ScoreStats::ScoreStats(const size_t size)
{
available_ = size;
entries_ = size;
array_ = new ScoreStatsType[available_];
memset(array_,0,scorebytes_);
available_ = size;
entries_ = size;
array_ = new ScoreStatsType[available_];
memset(array_,0,scorebytes_);
};
ScoreStats::ScoreStats(std::string &theString)
{
set(theString);
set(theString);
}
void ScoreStats::expand()
{
available_*=2;
scorestats_t t_ = new ScoreStatsType[available_];
memcpy(t_,array_,scorebytes_);
delete array_;
array_=t_;
available_*=2;
scorestats_t t_ = new ScoreStatsType[available_];
memcpy(t_,array_,scorebytes_);
delete array_;
array_=t_;
}
void ScoreStats::add(ScoreStatsType v)
{
if (isfull()) expand();
array_[entries_++]=v;
if (isfull()) expand();
array_[entries_++]=v;
}
void ScoreStats::set(std::string &theString)
{
std::string substring, stringBuf;
reset();
while (!theString.empty()){
getNextPound(theString, substring);
add(ATOSST(substring.c_str()));
}
reset();
while (!theString.empty()) {
getNextPound(theString, substring);
add(ATOSST(substring.c_str()));
}
}
void ScoreStats::loadbin(std::ifstream& inFile)
{
inFile.read((char*) array_, scorebytes_);
}
inFile.read((char*) array_, scorebytes_);
}
void ScoreStats::loadtxt(std::ifstream& inFile)
{
std::string theString;
std::getline(inFile, theString);
set(theString);
std::getline(inFile, theString);
set(theString);
}
void ScoreStats::loadtxt(const std::string &file)
{
// TRACE_ERR("loading the stats from " << file << std::endl);
// TRACE_ERR("loading the stats from " << file << std::endl);
std::ifstream inFile(file.c_str(), std::ios::in); // matches a stream with a file. Opens the file
std::ifstream inFile(file.c_str(), std::ios::in); // matches a stream with a file. Opens the file
loadtxt(inFile);
loadtxt(inFile);
}
void ScoreStats::savetxt(const std::string &file)
{
// TRACE_ERR("saving the stats into " << file << std::endl);
// TRACE_ERR("saving the stats into " << file << std::endl);
std::ofstream outFile(file.c_str(), std::ios::out); // matches a stream with a file. Opens the file
std::ofstream outFile(file.c_str(), std::ios::out); // matches a stream with a file. Opens the file
savetxt(outFile);
savetxt(outFile);
}
void ScoreStats::savetxt(std::ofstream& outFile)
{
outFile << *this;
outFile << *this;
}
void ScoreStats::savebin(std::ofstream& outFile)
{
outFile.write((char*) array_, scorebytes_);
}
outFile.write((char*) array_, scorebytes_);
}
ScoreStats& ScoreStats::operator=(const ScoreStats &stats)
{
delete array_;
available_ = stats.available();
entries_ = stats.size();
array_ = new ScoreStatsType[available_];
memcpy(array_,stats.getArray(),scorebytes_);
return *this;
delete array_;
available_ = stats.available();
entries_ = stats.size();
array_ = new ScoreStatsType[available_];
memcpy(array_,stats.getArray(),scorebytes_);
return *this;
}
/**write the whole object to a stream*/
ostream& operator<<(ostream& o, const ScoreStats& e){
for (size_t i=0; i< e.size(); i++)
o << e.get(i) << " ";
return o;
ostream& operator<<(ostream& o, const ScoreStats& e)
{
for (size_t i=0; i< e.size(); i++)
o << e.get(i) << " ";
return o;
}

View File

@ -26,51 +26,72 @@ using namespace std;
class ScoreStats
{
private:
scorestats_t array_;
size_t entries_;
size_t available_;
scorestats_t array_;
size_t entries_;
size_t available_;
public:
ScoreStats();
ScoreStats();
ScoreStats(const size_t size);
ScoreStats(const ScoreStats &stats);
ScoreStats(std::string &theString);
ScoreStats& operator=(const ScoreStats &stats);
~ScoreStats();
bool isfull(){return (entries_ < available_)?0:1; }
void expand();
void add(ScoreStatsType v);
inline void clear() { memset((void*) array_,0,scorebytes_); }
inline ScoreStatsType get(size_t i){ return array_[i]; }
inline ScoreStatsType get(size_t i)const{ return array_[i]; }
inline scorestats_t getArray() const { return array_; }
void set(std::string &theString);
ScoreStats(const ScoreStats &stats);
ScoreStats(std::string &theString);
ScoreStats& operator=(const ScoreStats &stats);
inline size_t bytes() const{ return scorebytes_; }
inline size_t size() const{ return entries_; }
inline size_t available() const{ return available_; }
void savetxt(const std::string &file);
void savetxt(ofstream& outFile);
void savebin(ofstream& outFile);
inline void savetxt(){ savetxt("/dev/stdout"); }
~ScoreStats();
void loadtxt(const std::string &file);
void loadtxt(ifstream& inFile);
void loadbin(ifstream& inFile);
inline void reset(){ entries_ = 0; clear(); }
bool isfull() {
return (entries_ < available_)?0:1;
}
void expand();
void add(ScoreStatsType v);
/**write the whole object to a stream*/
friend ostream& operator<<(ostream& o, const ScoreStats& e);
inline void clear() {
memset((void*) array_,0,scorebytes_);
}
inline ScoreStatsType get(size_t i) {
return array_[i];
}
inline ScoreStatsType get(size_t i)const {
return array_[i];
}
inline scorestats_t getArray() const {
return array_;
}
void set(std::string &theString);
inline size_t bytes() const {
return scorebytes_;
}
inline size_t size() const {
return entries_;
}
inline size_t available() const {
return available_;
}
void savetxt(const std::string &file);
void savetxt(ofstream& outFile);
void savebin(ofstream& outFile);
inline void savetxt() {
savetxt("/dev/stdout");
}
void loadtxt(const std::string &file);
void loadtxt(ifstream& inFile);
void loadbin(ifstream& inFile);
inline void reset() {
entries_ = 0;
clear();
}
/**write the whole object to a stream*/
friend ostream& operator<<(ostream& o, const ScoreStats& e);
};

View File

@ -1,96 +1,99 @@
#include "Scorer.h"
//regularisation strategies
static float score_min(const statscores_t& scores, size_t start, size_t end) {
float min = numeric_limits<float>::max();
for (size_t i = start; i < end; ++i) {
if (scores[i] < min) {
min = scores[i];
}
}
return min;
static float score_min(const statscores_t& scores, size_t start, size_t end)
{
float min = numeric_limits<float>::max();
for (size_t i = start; i < end; ++i) {
if (scores[i] < min) {
min = scores[i];
}
}
return min;
}
static float score_average(const statscores_t& scores, size_t start, size_t end) {
if ((end - start) < 1) {
//shouldn't happen
return 0;
}
float total = 0;
for (size_t j = start; j < end; ++j) {
total += scores[j];
}
static float score_average(const statscores_t& scores, size_t start, size_t end)
{
if ((end - start) < 1) {
//shouldn't happen
return 0;
}
float total = 0;
for (size_t j = start; j < end; ++j) {
total += scores[j];
}
return total / (end - start);
return total / (end - start);
}
void StatisticsBasedScorer::score(const candidates_t& candidates, const diffs_t& diffs,
statscores_t& scores) {
if (!_scoreData) {
throw runtime_error("Score data not loaded");
}
//calculate the score for the candidates
if (_scoreData->size() == 0) {
throw runtime_error("Score data is empty");
statscores_t& scores)
{
if (!_scoreData) {
throw runtime_error("Score data not loaded");
}
//calculate the score for the candidates
if (_scoreData->size() == 0) {
throw runtime_error("Score data is empty");
}
if (candidates.size() == 0) {
throw runtime_error("No candidates supplied");
}
int numCounts = _scoreData->get(0,candidates[0]).size();
vector<int> totals(numCounts);
for (size_t i = 0; i < candidates.size(); ++i) {
ScoreStats stats = _scoreData->get(i,candidates[i]);
if (stats.size() != totals.size()) {
stringstream msg;
msg << "Statistics for (" << "," << candidates[i] << ") have incorrect "
<< "number of fields. Found: " << stats.size() << " Expected: "
<< totals.size();
throw runtime_error(msg.str());
}
if (candidates.size() == 0) {
throw runtime_error("No candidates supplied");
for (size_t k = 0; k < totals.size(); ++k) {
totals[k] += stats.get(k);
}
}
scores.push_back(calculateScore(totals));
candidates_t last_candidates(candidates);
//apply each of the diffs, and get new scores
for (size_t i = 0; i < diffs.size(); ++i) {
for (size_t j = 0; j < diffs[i].size(); ++j) {
size_t sid = diffs[i][j].first;
size_t nid = diffs[i][j].second;
size_t last_nid = last_candidates[sid];
for (size_t k = 0; k < totals.size(); ++k) {
int diff = _scoreData->get(sid,nid).get(k)
- _scoreData->get(sid,last_nid).get(k);
totals[k] += diff;
}
last_candidates[sid] = nid;
}
int numCounts = _scoreData->get(0,candidates[0]).size();
vector<int> totals(numCounts);
for (size_t i = 0; i < candidates.size(); ++i) {
ScoreStats stats = _scoreData->get(i,candidates[i]);
if (stats.size() != totals.size()) {
stringstream msg;
msg << "Statistics for (" << "," << candidates[i] << ") have incorrect "
<< "number of fields. Found: " << stats.size() << " Expected: "
<< totals.size();
throw runtime_error(msg.str());
}
for (size_t k = 0; k < totals.size(); ++k) {
totals[k] += stats.get(k);
}
}
scores.push_back(calculateScore(totals));
}
candidates_t last_candidates(candidates);
//apply each of the diffs, and get new scores
for (size_t i = 0; i < diffs.size(); ++i) {
for (size_t j = 0; j < diffs[i].size(); ++j) {
size_t sid = diffs[i][j].first;
size_t nid = diffs[i][j].second;
size_t last_nid = last_candidates[sid];
for (size_t k = 0; k < totals.size(); ++k) {
int diff = _scoreData->get(sid,nid).get(k)
- _scoreData->get(sid,last_nid).get(k);
totals[k] += diff;
}
last_candidates[sid] = nid;
}
scores.push_back(calculateScore(totals));
}
//regularisation. This can either be none, or the min or average as described in
//Cer, Jurafsky and Manning at WMT08
if (_regularisationStrategy == REG_NONE || _regularisationWindow <= 0) {
//no regularisation
return;
}
//regularisation. This can either be none, or the min or average as described in
//Cer, Jurafsky and Manning at WMT08
if (_regularisationStrategy == REG_NONE || _regularisationWindow <= 0) {
//no regularisation
return;
//window size specifies the +/- in each direction
statscores_t raw_scores(scores);//copy scores
for (size_t i = 0; i < scores.size(); ++i) {
size_t start = 0;
if (i >= _regularisationWindow) {
start = i - _regularisationWindow;
}
//window size specifies the +/- in each direction
statscores_t raw_scores(scores);//copy scores
for (size_t i = 0; i < scores.size(); ++i) {
size_t start = 0;
if (i >= _regularisationWindow) {
start = i - _regularisationWindow;
}
size_t end = min(scores.size(), i + _regularisationWindow+1);
if (_regularisationStrategy == REG_AVERAGE) {
scores[i] = score_average(raw_scores,start,end);
} else {
scores[i] = score_min(raw_scores,start,end);
}
size_t end = min(scores.size(), i + _regularisationWindow+1);
if (_regularisationStrategy == REG_AVERAGE) {
scores[i] = score_average(raw_scores,start,end);
} else {
scores[i] = score_min(raw_scores,start,end);
}
}
}

View File

@ -23,163 +23,168 @@ class ScoreStats;
/**
* Superclass of all scorers and dummy implementation. In order to add a new
* scorer it should be sufficient to override prepareStats(), setReferenceFiles()
* and score() (or calculateScore()).
* and score() (or calculateScore()).
**/
class Scorer {
private:
string _name;
public:
Scorer(const string& name, const string& config): _name(name), _scoreData(0), _preserveCase(true){
cerr << "Scorer config string: " << config << endl;
size_t start = 0;
while (start < config.size()) {
size_t end = config.find(",",start);
if (end == string::npos) {
end = config.size();
}
string nv = config.substr(start,end-start);
size_t split = nv.find(":");
if (split == string::npos) {
throw runtime_error("Missing colon when processing scorer config: " + config);
}
string name = nv.substr(0,split);
string value = nv.substr(split+1,nv.size()-split-1);
cerr << "name: " << name << " value: " << value << endl;
_config[name] = value;
start = end+1;
}
class Scorer
{
private:
string _name;
};
virtual ~Scorer(){};
public:
/**
* returns the number of statistics needed for the computation of the score
**/
virtual size_t NumberOfScores(){ cerr << "Scorer: 0" << endl; return 0; };
/**
* set the reference files. This must be called before prepareStats.
**/
virtual void setReferenceFiles(const vector<string>& referenceFiles) {
//do nothing
}
/**
* Process the given guessed text, corresponding to the given reference sindex
* and add the appropriate statistics to the entry.
**/
virtual void prepareStats(size_t sindex, const string& text, ScoreStats& entry)
{}
virtual void prepareStats(const string& sindex, const string& text, ScoreStats& entry)
{
// cerr << sindex << endl;
this->prepareStats((size_t) atoi(sindex.c_str()), text, entry);
//cerr << text << std::endl;
}
/**
* Score using each of the candidate index, then go through the diffs
* applying each in turn, and calculating a new score each time.
**/
virtual void score(const candidates_t& candidates, const diffs_t& diffs,
statscores_t& scores) {
//dummy impl
if (!_scoreData) {
throw runtime_error("score data not loaded");
}
scores.push_back(0);
for (size_t i = 0; i < diffs.size(); ++i) {
scores.push_back(0);
}
}
/**
* Calculate the score of the sentences corresponding to the list of candidate
* indices. Each index indicates the 1-best choice from the n-best list.
**/
float score(const candidates_t& candidates) {
diffs_t diffs;
statscores_t scores;
score(candidates, diffs, scores);
return scores[0];
}
const string& getName() const {return _name;}
size_t getReferenceSize() {
if (_scoreData) {
return _scoreData->size();
}
return 0;
}
/**
* Set the score data, prior to scoring.
**/
void setScoreData(ScoreData* data) {
_scoreData = data;
}
protected:
typedef map<string,int> encodings_t;
typedef map<string,int>::iterator encodings_it;
ScoreData* _scoreData;
encodings_t _encodings;
bool _preserveCase;
/**
* Value of config variable. If not provided, return default.
**/
string getConfig(const string& key, const string& def="") {
map<string,string>::iterator i = _config.find(key);
if (i == _config.end()) {
return def;
} else {
return i->second;
}
Scorer(const string& name, const string& config): _name(name), _scoreData(0), _preserveCase(true) {
cerr << "Scorer config string: " << config << endl;
size_t start = 0;
while (start < config.size()) {
size_t end = config.find(",",start);
if (end == string::npos) {
end = config.size();
}
/**
* Tokenise line and encode.
* Note: We assume that all tokens are separated by single spaces
**/
void encode(const string& line, vector<int>& encoded) {
//cerr << line << endl;
istringstream in (line);
string token;
while (in >> token) {
if (!_preserveCase) {
for (string::iterator i = token.begin(); i != token.end(); ++i) {
*i = tolower(*i);
}
}
encodings_it encoding = _encodings.find(token);
int encoded_token;
if (encoding == _encodings.end()) {
encoded_token = (int)_encodings.size();
_encodings[token] = encoded_token;
//cerr << encoded_token << "(n) ";
} else {
encoded_token = encoding->second;
//cerr << encoded_token << " ";
}
encoded.push_back(encoded_token);
}
//cerr << endl;
string nv = config.substr(start,end-start);
size_t split = nv.find(":");
if (split == string::npos) {
throw runtime_error("Missing colon when processing scorer config: " + config);
}
string name = nv.substr(0,split);
string value = nv.substr(split+1,nv.size()-split-1);
cerr << "name: " << name << " value: " << value << endl;
_config[name] = value;
start = end+1;
}
private:
map<string,string> _config;
};
virtual ~Scorer() {};
/**
* returns the number of statistics needed for the computation of the score
**/
virtual size_t NumberOfScores() {
cerr << "Scorer: 0" << endl;
return 0;
};
/**
* set the reference files. This must be called before prepareStats.
**/
virtual void setReferenceFiles(const vector<string>& referenceFiles) {
//do nothing
}
/**
* Process the given guessed text, corresponding to the given reference sindex
* and add the appropriate statistics to the entry.
**/
virtual void prepareStats(size_t sindex, const string& text, ScoreStats& entry)
{}
virtual void prepareStats(const string& sindex, const string& text, ScoreStats& entry) {
// cerr << sindex << endl;
this->prepareStats((size_t) atoi(sindex.c_str()), text, entry);
//cerr << text << std::endl;
}
/**
* Score using each of the candidate index, then go through the diffs
* applying each in turn, and calculating a new score each time.
**/
virtual void score(const candidates_t& candidates, const diffs_t& diffs,
statscores_t& scores) {
//dummy impl
if (!_scoreData) {
throw runtime_error("score data not loaded");
}
scores.push_back(0);
for (size_t i = 0; i < diffs.size(); ++i) {
scores.push_back(0);
}
}
/**
* Calculate the score of the sentences corresponding to the list of candidate
* indices. Each index indicates the 1-best choice from the n-best list.
**/
float score(const candidates_t& candidates) {
diffs_t diffs;
statscores_t scores;
score(candidates, diffs, scores);
return scores[0];
}
const string& getName() const {
return _name;
}
size_t getReferenceSize() {
if (_scoreData) {
return _scoreData->size();
}
return 0;
}
/**
* Set the score data, prior to scoring.
**/
void setScoreData(ScoreData* data) {
_scoreData = data;
}
protected:
typedef map<string,int> encodings_t;
typedef map<string,int>::iterator encodings_it;
ScoreData* _scoreData;
encodings_t _encodings;
bool _preserveCase;
/**
* Value of config variable. If not provided, return default.
**/
string getConfig(const string& key, const string& def="") {
map<string,string>::iterator i = _config.find(key);
if (i == _config.end()) {
return def;
} else {
return i->second;
}
}
/**
* Tokenise line and encode.
* Note: We assume that all tokens are separated by single spaces
**/
void encode(const string& line, vector<int>& encoded) {
//cerr << line << endl;
istringstream in (line);
string token;
while (in >> token) {
if (!_preserveCase) {
for (string::iterator i = token.begin(); i != token.end(); ++i) {
*i = tolower(*i);
}
}
encodings_it encoding = _encodings.find(token);
int encoded_token;
if (encoding == _encodings.end()) {
encoded_token = (int)_encodings.size();
_encodings[token] = encoded_token;
//cerr << encoded_token << "(n) ";
} else {
encoded_token = encoding->second;
//cerr << encoded_token << " ";
}
encoded.push_back(encoded_token);
}
//cerr << endl;
}
private:
map<string,string> _config;
};
@ -187,11 +192,12 @@ class Scorer {
/**
* Abstract base class for scorers that work by adding statistics across all
* Abstract base class for scorers that work by adding statistics across all
* outout sentences, then apply some formula, e.g. bleu, per. **/
class StatisticsBasedScorer : public Scorer {
class StatisticsBasedScorer : public Scorer
{
public:
public:
StatisticsBasedScorer(const string& name, const string& config): Scorer(name,config) {
//configure regularisation
static string KEY_TYPE = "regtype";
@ -202,45 +208,45 @@ class StatisticsBasedScorer : public Scorer {
static string TYPE_MINIMUM = "min";
static string TRUE = "true";
static string FALSE = "false";
string type = getConfig(KEY_TYPE,TYPE_NONE);
if (type == TYPE_NONE) {
_regularisationStrategy = REG_NONE;
_regularisationStrategy = REG_NONE;
} else if (type == TYPE_AVERAGE) {
_regularisationStrategy = REG_AVERAGE;
_regularisationStrategy = REG_AVERAGE;
} else if (type == TYPE_MINIMUM) {
_regularisationStrategy = REG_MINIMUM;
_regularisationStrategy = REG_MINIMUM;
} else {
throw runtime_error("Unknown scorer regularisation strategy: " + type);
throw runtime_error("Unknown scorer regularisation strategy: " + type);
}
cerr << "Using scorer regularisation strategy: " << type << endl;
string window = getConfig(KEY_WINDOW,"0");
_regularisationWindow = atoi(window.c_str());
cerr << "Using scorer regularisation window: " << _regularisationWindow << endl;
string preservecase = getConfig(KEY_CASE,TRUE);
if (preservecase == TRUE) {
_preserveCase = true;
}else if (preservecase == FALSE) {
_preserveCase = false;
_preserveCase = true;
} else if (preservecase == FALSE) {
_preserveCase = false;
}
cerr << "Using case preservation: " << _preserveCase << endl;
}
~StatisticsBasedScorer(){};
virtual void score(const candidates_t& candidates, const diffs_t& diffs,
statscores_t& scores);
~StatisticsBasedScorer() {};
virtual void score(const candidates_t& candidates, const diffs_t& diffs,
statscores_t& scores);
protected:
//calculate the actual score
virtual statscore_t calculateScore(const vector<int>& totals) = 0;
protected:
//calculate the actual score
virtual statscore_t calculateScore(const vector<int>& totals) = 0;
//regularisation
ScorerRegularisationStrategy _regularisationStrategy;
size_t _regularisationWindow;
//regularisation
ScorerRegularisationStrategy _regularisationStrategy;
size_t _regularisationWindow;
};

View File

@ -17,25 +17,26 @@
using namespace std;
class ScorerFactory {
class ScorerFactory
{
public:
vector<string> getTypes() {
vector<string> types;
types.push_back(string("BLEU"));
types.push_back(string("PER"));
return types;
}
public:
vector<string> getTypes() {
vector<string> types;
types.push_back(string("BLEU"));
types.push_back(string("PER"));
return types;
}
Scorer* getScorer(const string& type, const string& config = "") {
if (type == "BLEU") {
return (BleuScorer*) new BleuScorer(config);
} else if (type == "PER") {
return (PerScorer*) new PerScorer(config);
} else {
throw runtime_error("Unknown scorer type: " + type);
}
}
Scorer* getScorer(const string& type, const string& config = "") {
if (type == "BLEU") {
return (BleuScorer*) new BleuScorer(config);
} else if (type == "PER") {
return (PerScorer*) new PerScorer(config);
} else {
throw runtime_error("Unknown scorer type: " + type);
}
}
};
#endif //__SCORER_FACTORY_H

View File

@ -12,8 +12,8 @@
*/
double Timer::elapsed_time()
{
time_t now;
time(&now);
time_t now;
time(&now);
return difftime(now, start_time);
}
@ -36,7 +36,7 @@ double Timer::get_elapsed_time()
void Timer::start(const char* msg)
{
// Print an optional message, something like "Starting timer t";
if (msg) TRACE_ERR( msg << std::endl);
if (msg) TRACE_ERR( msg << std::endl);
// Return immediately if the timer is already running
if (running) return;

View File

@ -8,16 +8,16 @@
class Timer
{
friend std::ostream& operator<<(std::ostream& os, Timer& t);
friend std::ostream& operator<<(std::ostream& os, Timer& t);
private:
private:
bool running;
time_t start_time;
//TODO in seconds?
//TODO in seconds?
double elapsed_time();
public:
public:
/***
* 'running' is initially false. A timer needs to be explicitly started
* using 'start' or 'restart'

View File

@ -1,7 +1,7 @@
/*
* Util.cpp
* met - Minimum Error Training
*
*
* Created by Nicola Bertoldi on 13/05/08.
*
*/
@ -18,47 +18,47 @@ Timer g_timer;
int verbose=0;
int verboselevel(){
int verboselevel()
{
return verbose;
}
int setverboselevel(int v){
int setverboselevel(int v)
{
verbose=v;
return verbose;
}
int getNextPound(std::string &theString, std::string &substring, const std::string delimiter)
{
unsigned int pos = 0;
//skip all occurrences of delimiter
while ( pos == 0 )
{
if ((pos = theString.find(delimiter)) != std::string::npos){
substring.assign(theString, 0, pos);
theString.erase(0,pos + delimiter.size());
}
else{
substring.assign(theString);
theString.assign("");
}
}
return (pos);
unsigned int pos = 0;
//skip all occurrences of delimiter
while ( pos == 0 ) {
if ((pos = theString.find(delimiter)) != std::string::npos) {
substring.assign(theString, 0, pos);
theString.erase(0,pos + delimiter.size());
} else {
substring.assign(theString);
theString.assign("");
}
}
return (pos);
};
inputfilestream::inputfilestream(const std::string &filePath)
: std::istream(0),
m_streambuf(0)
: std::istream(0),
m_streambuf(0)
{
//check if file is readable
std::filebuf* fb = new std::filebuf();
_good=(fb->open(filePath.c_str(), std::ios::in)!=NULL);
if (filePath.size() > 3 &&
filePath.substr(filePath.size() - 3, 3) == ".gz")
{
fb->close(); delete fb;
m_streambuf = new gzfilebuf(filePath.c_str());
filePath.substr(filePath.size() - 3, 3) == ".gz") {
fb->close();
delete fb;
m_streambuf = new gzfilebuf(filePath.c_str());
} else {
m_streambuf = fb;
}
@ -67,7 +67,8 @@ m_streambuf(0)
inputfilestream::~inputfilestream()
{
delete m_streambuf; m_streambuf = 0;
delete m_streambuf;
m_streambuf = 0;
}
void inputfilestream::close()
@ -75,16 +76,15 @@ void inputfilestream::close()
}
outputfilestream::outputfilestream(const std::string &filePath)
: std::ostream(0),
m_streambuf(0)
: std::ostream(0),
m_streambuf(0)
{
//check if file is readable
std::filebuf* fb = new std::filebuf();
_good=(fb->open(filePath.c_str(), std::ios::out)!=NULL);
if (filePath.size() > 3 && filePath.substr(filePath.size() - 3, 3) == ".gz")
{
throw runtime_error("Output to a zipped file not supported!");
_good=(fb->open(filePath.c_str(), std::ios::out)!=NULL);
if (filePath.size() > 3 && filePath.substr(filePath.size() - 3, 3) == ".gz") {
throw runtime_error("Output to a zipped file not supported!");
} else {
m_streambuf = fb;
}
@ -93,7 +93,8 @@ m_streambuf(0)
outputfilestream::~outputfilestream()
{
delete m_streambuf; m_streambuf = 0;
delete m_streambuf;
m_streambuf = 0;
}
void outputfilestream::close()
@ -103,10 +104,14 @@ void outputfilestream::close()
int swapbytes(char *p, int sz, int n)
{
char c, *l, *h;
if((n<1) || (sz<2)) return 0;
for(; n--; p+=sz) for(h=(l=p)+sz; --h>l; l++) { c=*h; *h=*l; *l=c; }
return 0;
for(; n--; p+=sz) for(h=(l=p)+sz; --h>l; l++) {
c=*h;
*h=*l;
*l=c;
}
return 0;
};
@ -116,12 +121,12 @@ void ResetUserTime()
};
void PrintUserTime(const std::string &message)
{
g_timer.check(message.c_str());
{
g_timer.check(message.c_str());
}
double GetUserTime()
{
return g_timer.get_elapsed_time();
return g_timer.get_elapsed_time();
}

View File

@ -51,45 +51,49 @@ int getNextPound(std::string &theString, std::string &substring, const std::stri
template<typename T>
inline T Scan(const std::string &input)
{
std::stringstream stream(input);
T ret;
stream >> ret;
return ret;
std::stringstream stream(input);
T ret;
stream >> ret;
return ret;
};
class inputfilestream : public std::istream
{
protected:
std::streambuf *m_streambuf;
bool _good;
std::streambuf *m_streambuf;
bool _good;
public:
inputfilestream(const std::string &filePath);
~inputfilestream();
bool good(){return _good;}
void close();
inputfilestream(const std::string &filePath);
~inputfilestream();
bool good() {
return _good;
}
void close();
};
class outputfilestream : public std::ostream
{
protected:
std::streambuf *m_streambuf;
bool _good;
std::streambuf *m_streambuf;
bool _good;
public:
outputfilestream(const std::string &filePath);
~outputfilestream();
bool good(){return _good;}
void close();
outputfilestream(const std::string &filePath);
~outputfilestream();
bool good() {
return _good;
}
void close();
};
template<typename T>
inline std::string stringify(T x)
{
std::ostringstream o;
if (!(o << x))
throw std::runtime_error("stringify(template<typename T>)");
return o.str();
std::ostringstream o;
if (!(o << x))
throw std::runtime_error("stringify(template<typename T>)");
return o.str();
}
// Utilities to measure decoding time

View File

@ -18,7 +18,8 @@
using namespace std;
void usage() {
void usage()
{
cerr<<"usage: extractor [options])"<<endl;
cerr<<"[--sctype|-s] the scorer type (default BLEU)"<<endl;
cerr<<"[--scconfig|-c] configuration string passed to scorer"<<endl;
@ -28,7 +29,7 @@ void usage() {
cerr<<"[--nbest|-n] the nbest file"<<endl;
cerr<<"[--scfile|-S] the scorer data output file"<<endl;
cerr<<"[--ffile|-F] the feature data output file"<<endl;
cerr<<"[--prev-ffile|-E] comma separated list of previous feature data" <<endl;
cerr<<"[--prev-ffile|-E] comma separated list of previous feature data" <<endl;
cerr<<"[--prev-scfile|-R] comma separated list of previous scorer data"<<endl;
cerr<<"[-v] verbose level"<<endl;
cerr<<"[--help|-h] print this message and exit"<<endl;
@ -36,184 +37,184 @@ cerr<<"[--prev-ffile|-E] comma separated list of previous feature data" <<endl;
}
static struct option long_options[] =
{
{"sctype",required_argument,0,'s'},
{"scconfig",required_argument,0,'c'},
{"reference",required_argument,0,'r'},
{"binary",no_argument,0,'b'},
{"nbest",required_argument,0,'n'},
{"scfile",required_argument,0,'S'},
{"ffile",required_argument,0,'F'},
{"prev-scfile",required_argument,0,'R'},
{"prev-ffile",required_argument,0,'E'},
{"verbose",required_argument,0,'v'},
{"help",no_argument,0,'h'},
{0, 0, 0, 0}
};
static struct option long_options[] = {
{"sctype",required_argument,0,'s'},
{"scconfig",required_argument,0,'c'},
{"reference",required_argument,0,'r'},
{"binary",no_argument,0,'b'},
{"nbest",required_argument,0,'n'},
{"scfile",required_argument,0,'S'},
{"ffile",required_argument,0,'F'},
{"prev-scfile",required_argument,0,'R'},
{"prev-ffile",required_argument,0,'E'},
{"verbose",required_argument,0,'v'},
{"help",no_argument,0,'h'},
{0, 0, 0, 0}
};
int option_index;
int main(int argc, char** argv) {
ResetUserTime();
/*
Timer timer;
timer.start("Starting...");
*/
//defaults
string scorerType("BLEU");
string scorerConfig("");
string referenceFile("");
string nbestFile("");
string scoreDataFile("statscore.data");
string featureDataFile("features.data");
string prevScoreDataFile("");
string prevFeatureDataFile("");
bool binmode = false;
int verbosity = 0;
int c;
while ((c=getopt_long (argc,argv, "s:r:n:S:F:R:E:v:hb", long_options, &option_index)) != -1) {
switch(c) {
case 's':
scorerType = string(optarg);
break;
case 'c':
scorerConfig = string(optarg);
break;
case 'r':
referenceFile = string(optarg);
break;
case 'b':
binmode = true;
break;
case 'n':
nbestFile = string(optarg);
break;
case 'S':
scoreDataFile = string(optarg);
break;
case 'F':
featureDataFile = string(optarg);
break;
case 'E':
prevFeatureDataFile = string(optarg);
break;
case 'R':
prevScoreDataFile = string(optarg);
break;
case 'v':
verbosity = atoi(optarg);
break;
default:
usage();
}
int main(int argc, char** argv)
{
ResetUserTime();
/*
Timer timer;
timer.start("Starting...");
*/
//defaults
string scorerType("BLEU");
string scorerConfig("");
string referenceFile("");
string nbestFile("");
string scoreDataFile("statscore.data");
string featureDataFile("features.data");
string prevScoreDataFile("");
string prevFeatureDataFile("");
bool binmode = false;
int verbosity = 0;
int c;
while ((c=getopt_long (argc,argv, "s:r:n:S:F:R:E:v:hb", long_options, &option_index)) != -1) {
switch(c) {
case 's':
scorerType = string(optarg);
break;
case 'c':
scorerConfig = string(optarg);
break;
case 'r':
referenceFile = string(optarg);
break;
case 'b':
binmode = true;
break;
case 'n':
nbestFile = string(optarg);
break;
case 'S':
scoreDataFile = string(optarg);
break;
case 'F':
featureDataFile = string(optarg);
break;
case 'E':
prevFeatureDataFile = string(optarg);
break;
case 'R':
prevScoreDataFile = string(optarg);
break;
case 'v':
verbosity = atoi(optarg);
break;
default:
usage();
}
try {
}
try {
//check whether score statistics file is specified
if (scoreDataFile.length() == 0){
throw runtime_error("Error: output score statistics file is not specified");
if (scoreDataFile.length() == 0) {
throw runtime_error("Error: output score statistics file is not specified");
}
//check wheter feature file is specified
if (featureDataFile.length() == 0){
throw runtime_error("Error: output feature file is not specified");
if (featureDataFile.length() == 0) {
throw runtime_error("Error: output feature file is not specified");
}
//check whether reference file is specified when nbest is specified
if ((nbestFile.length() > 0 && referenceFile.length() == 0)){
throw runtime_error("Error: reference file is not specified; you can not score the nbest");
if ((nbestFile.length() > 0 && referenceFile.length() == 0)) {
throw runtime_error("Error: reference file is not specified; you can not score the nbest");
}
vector<string> nbestFiles;
if (nbestFile.length() > 0){
std::string substring;
while (!nbestFile.empty()){
getNextPound(nbestFile, substring, ",");
nbestFiles.push_back(substring);
}
if (nbestFile.length() > 0) {
std::string substring;
while (!nbestFile.empty()) {
getNextPound(nbestFile, substring, ",");
nbestFiles.push_back(substring);
}
}
vector<string> referenceFiles;
if (referenceFile.length() > 0){
std::string substring;
while (!referenceFile.empty()){
getNextPound(referenceFile, substring, ",");
referenceFiles.push_back(substring);
}
if (referenceFile.length() > 0) {
std::string substring;
while (!referenceFile.empty()) {
getNextPound(referenceFile, substring, ",");
referenceFiles.push_back(substring);
}
}
vector<string> prevScoreDataFiles;
if (prevScoreDataFile.length() > 0){
std::string substring;
while (!prevScoreDataFile.empty()){
getNextPound(prevScoreDataFile, substring, ",");
prevScoreDataFiles.push_back(substring);
}
if (prevScoreDataFile.length() > 0) {
std::string substring;
while (!prevScoreDataFile.empty()) {
getNextPound(prevScoreDataFile, substring, ",");
prevScoreDataFiles.push_back(substring);
}
}
vector<string> prevFeatureDataFiles;
if (prevFeatureDataFile.length() > 0){
std::string substring;
while (!prevFeatureDataFile.empty()){
getNextPound(prevFeatureDataFile, substring, ",");
prevFeatureDataFiles.push_back(substring);
}
if (prevFeatureDataFile.length() > 0) {
std::string substring;
while (!prevFeatureDataFile.empty()) {
getNextPound(prevFeatureDataFile, substring, ",");
prevFeatureDataFiles.push_back(substring);
}
}
if (prevScoreDataFiles.size() != prevFeatureDataFiles.size()){
throw runtime_error("Error: there is a different number of previous score and feature files");
if (prevScoreDataFiles.size() != prevFeatureDataFiles.size()) {
throw runtime_error("Error: there is a different number of previous score and feature files");
}
if (binmode) cerr << "Binary write mode is selected" << endl;
else cerr << "Binary write mode is NOT selected" << endl;
TRACE_ERR("Scorer type: " << scorerType << endl);
ScorerFactory sfactory;
Scorer* scorer = sfactory.getScorer(scorerType,scorerConfig);
//load references
if (referenceFiles.size() > 0)
scorer->setReferenceFiles(referenceFiles);
PrintUserTime("References loaded");
Data data(*scorer);
//load old data
for (size_t i=0;i < prevScoreDataFiles.size(); i++){
data.load(prevFeatureDataFiles.at(i), prevScoreDataFiles.at(i));
}
PrintUserTime("Previous data loaded");
//computing score statistics of each nbest file
for (size_t i=0;i < nbestFiles.size(); i++){
data.loadnbest(nbestFiles.at(i));
}
if (binmode) cerr << "Binary write mode is selected" << endl;
else cerr << "Binary write mode is NOT selected" << endl;
PrintUserTime("Nbest entries loaded and scored");
if (binmode)
cerr << "Binary write mode is selected" << endl;
else
cerr << "Binary write mode is NOT selected" << endl;
data.save(featureDataFile, scoreDataFile, binmode);
PrintUserTime("Stopping...");
/*
timer.stop("Stopping...");
*/
return EXIT_SUCCESS;
} catch (const exception& e) {
cerr << "Exception: " << e.what() << endl;
return EXIT_FAILURE;
TRACE_ERR("Scorer type: " << scorerType << endl);
ScorerFactory sfactory;
Scorer* scorer = sfactory.getScorer(scorerType,scorerConfig);
//load references
if (referenceFiles.size() > 0)
scorer->setReferenceFiles(referenceFiles);
PrintUserTime("References loaded");
Data data(*scorer);
//load old data
for (size_t i=0; i < prevScoreDataFiles.size(); i++) {
data.load(prevFeatureDataFiles.at(i), prevScoreDataFiles.at(i));
}
PrintUserTime("Previous data loaded");
//computing score statistics of each nbest file
for (size_t i=0; i < nbestFiles.size(); i++) {
data.loadnbest(nbestFiles.at(i));
}
PrintUserTime("Nbest entries loaded and scored");
if (binmode)
cerr << "Binary write mode is selected" << endl;
else
cerr << "Binary write mode is NOT selected" << endl;
data.save(featureDataFile, scoreDataFile, binmode);
PrintUserTime("Stopping...");
/*
timer.stop("Stopping...");
*/
return EXIT_SUCCESS;
} catch (const exception& e) {
cerr << "Exception: " << e.what() << endl;
return EXIT_FAILURE;
}
}

Some files were not shown because too many files have changed in this diff Show More