clean up weights code for confusion networks & lattices. Works, except for multiple phrase-tables or factors

Author: Hieu Hoang
Date: 2012-12-05 20:21:33 +00:00
commit da9cd0e3aa (parent b8d4c64d6d)

5 changed files with 48 additions and 74 deletions
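In practical terms, the commit replaces the old weight-i / link-param-count pair with a single input-scores parameter. A hypothetical before/after moses.ini excerpt (the values are illustrative, not part of this commit):

    # before: one weight per link score, plus a separate link count
    [weight-i]
    0.5
    0.5
    [link-param-count]
    2

    # after: declare 2 edge scores and 1 'real' word count feature;
    # the weights themselves move into the unified PhraseModel weights
    [input-scores]
    2
    1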

ConfusionNet.cpp

@@ -118,11 +118,12 @@ bool ConfusionNet::ReadFormat0(std::istream& in,
                                const std::vector<FactorType>& factorOrder)
 {
   Clear();
-  std::string line;
-  size_t numLinkParams = StaticData::Instance().GetNumLinkParams();
-  size_t numLinkWeights = StaticData::Instance().GetNumInputScores();
-  bool addRealWordCount = ((numLinkParams + 1) == numLinkWeights);
+  size_t numInputScores = StaticData::Instance().GetNumInputScores();
+  size_t numRealWordCount = StaticData::Instance().GetNumRealWordsInInput();
+  size_t totalCount = numInputScores + numRealWordCount;
+  bool addRealWordCount = (numRealWordCount > 0);
+
+  std::string line;
   while(getline(in,line)) {
     std::istringstream is(line);
     std::string word;
@@ -131,8 +132,8 @@ bool ConfusionNet::ReadFormat0(std::istream& in,
     while(is>>word) {
       Word w;
       String2Word(word,w,factorOrder);
-      std::vector<float> probs(numLinkWeights,0.0);
-      for(size_t i=0; i<numLinkParams; i++) {
+      std::vector<float> probs(totalCount, 0.0);
+      for(size_t i=0; i < numInputScores; i++) {
         double prob;
         if (!(is>>prob)) {
           TRACE_ERR("ERROR: unable to parse CN input - bad link probability, or wrong number of scores\n");
@@ -150,7 +151,7 @@ bool ConfusionNet::ReadFormat0(std::istream& in,
       }
       //store 'real' word count in last feature if we have one more weight than we do arc scores and not epsilon
       if (addRealWordCount && word!=EPSILON && word!="")
-        probs[numLinkParams] = -1.0;
+        probs.back() = -1.0;
       col.push_back(std::make_pair(w,probs));
     }
     if(col.size()) {
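For illustration, a hypothetical Format0 confusion-network column under the new scheme, assuming input-scores declares 2 scores per edge and 1 'real' word score (so totalCount is 3). Each word is followed by its 2 link scores; the third slot of probs is not read from the input but set to -1.0 internally for every non-epsilon word:

    der 0.8 0.693 die 0.15 1.897 *EPS* 0.05 2.996

(The epsilon token is written *EPS* here on the assumption that it matches the decoder's EPSILON constant.)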

Parameter.cpp

@@ -121,7 +121,6 @@ Parameter::Parameter()
   AddParam("cube-pruning-diversity", "cbd", "How many hypotheses should be created for each coverage. (default = 0)");
   AddParam("search-algorithm", "Which search algorithm to use. 0=normal stack, 1=cube pruning, 2=cube growing. (default = 0)");
   AddParam("constraint", "Location of the file with target sentences to produce constraining the search");
-  AddParam("link-param-count", "Number of parameters on word links when using confusion networks or lattices (default = 1)");
   AddParam("description", "Source language, target language, description");
   AddParam("max-chart-span", "maximum num. of source word chart rules can consume (default 10)");
   AddParam("non-terminals", "list of non-term symbols, space separated");
@@ -166,6 +165,9 @@ Parameter::Parameter()
   AddParam("report-segmentation", "t", "report phrase segmentation in the output");
+  AddParam("translation-systems", "DEPRECATED. DO NOT USE. specify multiple translation systems, each consisting of an id, followed by a set of models ids, eg '0 T1 R1 L0'");
+  AddParam("link-param-count", "DEPRECATED. DO NOT USE. Number of parameters on word links when using confusion networks or lattices (default = 1)");
   AddParam("weight-slm", "slm", "DEPRECATED. DO NOT USE. weight(s) for syntactic language model");
   AddParam("weight-bl", "bl", "DEPRECATED. DO NOT USE. weight for bleu score feature");
   AddParam("weight-d", "d", "DEPRECATED. DO NOT USE. weight(s) for distortion (reordering components)");
@@ -189,8 +191,7 @@ Parameter::Parameter()
   AddParam("weight", "weights for ALL models, 1 per line 'WeightName value'. Weight names can be repeated");
   AddParam("weight-overwrite", "special parameter for mert. All on 1 line. Overrides weights specified in 'weights' argument");
-  AddParam("translation-systems", "DEPRECATED. DO NOT USE. specify multiple translation systems, each consisting of an id, followed by a set of models ids, eg '0 T1 R1 L0'");
+  AddParam("input-scores", "2 numbers on 2 lines - [1] Number of scores on each edge of a confusion network or lattice input (default=1). [2] Number of 'real' word scores (0 or 1, default=0)");
 }

 Parameter::~Parameter()
@@ -353,18 +354,7 @@ void Parameter::ConvertWeightArgs()
     cerr << "Do not mix old and new format for specify weights";
   }

-  // input scores. if size=1, add an extra for 'real' word count feature. TODO HACK
-  bool addExtraInputWeight = false;
-  if (m_setting["weight-i"].size() == 1)
-  {
-    addExtraInputWeight = true;
-  }
   ConvertWeightArgs("weight-i", "PhraseModel");
-  if (addExtraInputWeight)
-    m_setting["weight"].push_back("PhraseModel 0.0");
-
   ConvertWeightArgs("weight-t", "PhraseModel");
   ConvertWeightArgs("weight-w", "WordPenalty");
   ConvertWeightArgs("weight-l", "LM");

StaticData.cpp

@@ -91,11 +91,9 @@ StaticData::StaticData()
   ,m_phraseLengthFeature(NULL)
   ,m_targetWordInsertionFeature(NULL)
   ,m_sourceWordDeletionFeature(NULL)
-  ,m_numLinkParams(1)
   ,m_fLMsLoaded(false)
   ,m_sourceStartPosMattersForRecombination(false)
   ,m_inputType(SentenceInput)
-  ,m_numInputScores(0)
   ,m_bleuScoreFeature(NULL)
   ,m_detailedTranslationReportingFilePath()
   ,m_onlyDistinctNBest(false)
@@ -1213,6 +1211,8 @@ bool StaticData::LoadPhraseTables()
   if (m_parameter->GetParam("ttable-file").size() > 0) {
     // weights
     const vector<float> &weightAll = m_parameter->GetWeights("PhraseModel");
+    for (size_t i = 0; i < weightAll.size(); ++i)
+      cerr << weightAll[i] << " " << flush;

     const vector<string> &translationVector = m_parameter->GetParam("ttable-file");
     vector<size_t> maxTargetPhrase = Scan<size_t>(m_parameter->GetParam("ttable-limit"));
@@ -1272,57 +1272,38 @@ bool StaticData::LoadPhraseTables()
       // first InputScores (if any), then translation scores
       vector<float> weight;

-      if(currDict==0 && (m_inputType == ConfusionNetworkInput || m_inputType == WordLatticeInput)) {
-        // TODO. find what the assumptions made by confusion network about phrase table output which makes
-        // it only work with binrary file. This is a hack
-        m_numInputScores=m_parameter->GetParam("weight-i").size();
-        if (implementation == Binary)
-        {
-          for(unsigned k=0; k<m_numInputScores; ++k)
-            weight.push_back(Scan<float>(m_parameter->GetParam("weight-i")[k]));
-        }
-
-        if(m_parameter->GetParam("link-param-count").size())
-          m_numLinkParams = Scan<size_t>(m_parameter->GetParam("link-param-count")[0]);
-
-        //print some info about this interaction:
-        if (implementation == Binary) {
-          if (m_numLinkParams == m_numInputScores) {
-            VERBOSE(1,"specified equal numbers of link parameters and insertion weights, not using non-epsilon 'real' word link count.\n");
-          } else if ((m_numLinkParams + 1) == m_numInputScores) {
-            VERBOSE(1,"WARN: "<< m_numInputScores << " insertion weights found and only "<< m_numLinkParams << " link parameters specified, applying non-epsilon 'real' word link count for last feature weight.\n");
-          } else {
-            stringstream strme;
-            strme << "You specified " << m_numInputScores
-                  << " input weights (weight-i), but you specified " << m_numLinkParams << " link parameters (link-param-count)!";
-            UserMessage::Add(strme.str());
-            return false;
-          }
-        }
-      }
-      if (!m_inputType) {
-        m_numInputScores = 0;
-      }
-      //this number changes depending on what phrase table we're talking about: only 0 has the weights on it
-      size_t tableInputScores = (currDict == 0 && implementation == Binary) ? m_numInputScores : 0;
+      if(m_inputType == ConfusionNetworkInput || m_inputType == WordLatticeInput) {
+        if (currDict==0) { // input scores are attached to the 1st phrase table only. This is a hack
+          // TODO: find what assumptions the confusion network makes about phrase-table output
+          // that make it work only with the binary format. This is a hack
+          CHECK(implementation == Binary);
+
+          if (m_parameter->GetParam("input-scores").size()) {
+            m_numInputScores = Scan<size_t>(m_parameter->GetParam("input-scores")[0]);
+          }
+          else {
+            m_numInputScores = 1;
+          }
+          numScoreComponent += m_numInputScores;
+
+          if (m_parameter->GetParam("input-scores").size() > 1) {
+            m_numRealWordsInInput = Scan<size_t>(m_parameter->GetParam("input-scores")[1]);
+          }
+          else {
+            m_numRealWordsInInput = 0;
+          }
+          numScoreComponent += m_numRealWordsInInput;
+        }
+      }
+      else { // not confusion network or lattice input
+        m_numInputScores = 0;
+        m_numRealWordsInInput = 0;
+      }

       for (size_t currScore = 0 ; currScore < numScoreComponent; currScore++)
         weight.push_back(weightAll[weightAllOffset + currScore]);

-      if(weight.size() - tableInputScores != numScoreComponent) {
-        stringstream strme;
-        strme << "Your phrase table has " << numScoreComponent
-              << " scores, but you specified " << (weight.size() - tableInputScores) << " weights!";
-        UserMessage::Add(strme.str());
-        return false;
-      }
-
       weightAllOffset += numScoreComponent;
-      numScoreComponent += tableInputScores;

       string targetPath, alignmentsFile;
       if (implementation == SuffixArray) {
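The layout that loop now assumes can be summarized in a standalone sketch (not Moses code; all names here are hypothetical): for the first phrase table with CN/lattice input, the flat PhraseModel weight vector is consumed as [input scores][real-word score][phrase-table scores], and the offset simply advances past the whole block.

    // Sketch only -- illustrates the weight layout, not actual Moses code.
    #include <cstddef>
    #include <vector>

    std::vector<float> takeTableWeights(const std::vector<float> &weightAll,
                                        std::size_t &offset,         // cf. weightAllOffset
                                        std::size_t numTableScores,  // scores in the phrase table
                                        std::size_t numInputScores,  // input-scores[0], table 0 only
                                        std::size_t numRealWords)    // input-scores[1], table 0 only
    {
      // Input scores come first, then the translation scores, all drawn
      // from the single flat weight vector.
      const std::size_t total = numInputScores + numRealWords + numTableScores;
      std::vector<float> weight(weightAll.begin() + offset,
                                weightAll.begin() + offset + total);
      offset += total; // the next table's weights start after this block
      return weight;
    }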

StaticData.h

@@ -138,7 +138,7 @@ protected:
   , m_maxNoTransOptPerCoverage
   , m_maxNoPartTransOpt
   , m_maxPhraseLength
-  , m_numLinkParams;
+  , m_numRealWordsInInput;

   std::string
   m_constraintFileName;
@@ -443,8 +443,8 @@ public:
     return m_minlexrMemory;
   }
-  size_t GetNumLinkParams() const {
-    return m_numLinkParams;
+  size_t GetNumRealWordsInInput() const {
+    return m_numRealWordsInInput;
   }
   const std::vector<std::string> &GetDescription() const {
     return m_parameter->GetParam("description");
WordLattice.cpp

@@ -33,12 +33,14 @@ void WordLattice::Print(std::ostream& out) const

 int WordLattice::InitializeFromPCNDataType(const PCN::CN& cn, const std::vector<FactorType>& factorOrder, const std::string& debug_line)
 {
-  size_t numLinkParams = StaticData::Instance().GetNumLinkParams();
-  size_t numLinkWeights = StaticData::Instance().GetNumInputScores();
   size_t maxSizePhrase = StaticData::Instance().GetMaxPhraseLength();
+  size_t numInputScores = StaticData::Instance().GetNumInputScores();
+  size_t numRealWordCount = StaticData::Instance().GetNumRealWordsInInput();
+  size_t totalCount = numInputScores + numRealWordCount;
+  bool addRealWordCount = (numRealWordCount > 0);
+
   //when we have one more weight than params, we add a word count feature
-  bool addRealWordCount = ((numLinkParams + 1) == numLinkWeights);
+
   data.resize(cn.size());
   next_nodes.resize(cn.size());
   for(size_t i=0; i<cn.size(); ++i) {
@@ -51,8 +53,8 @@ int WordLattice::InitializeFromPCNDataType(const PCN::CN& cn, const std::vector<FactorType>& factorOrder, const std::string& debug_line)
       //check for correct number of link parameters
-      if (alt.first.second.size() != numLinkParams) {
-        TRACE_ERR("ERROR: need " << numLinkParams << " link parameters, found " << alt.first.second.size() << " while reading column " << i << " from " << debug_line << "\n");
+      if (alt.first.second.size() != numInputScores) {
+        TRACE_ERR("ERROR: need " << numInputScores << " link parameters, found " << alt.first.second.size() << " while reading column " << i << " from " << debug_line << "\n");
         return false;
       }
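For reference, a hypothetical PLF lattice input matching the check above, assuming the usual ('word', score(s)..., next-node offset) edge format and input-scores declaring 2 scores per edge, so every alt must carry exactly numInputScores = 2 scores; the real-word feature is added by the decoder, not read from the lattice:

    ((('der',0.6,0.2,1),('die',0.4,0.8,1)),(('haus',1.0,1.0,1),))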