mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-12-27 05:55:02 +03:00
clean up weights code for confusion networks & lattices. Works, except for multiple phrase-tables or factors
This commit is contained in:
parent
b8d4c64d6d
commit
da9cd0e3aa
@ -118,11 +118,12 @@ bool ConfusionNet::ReadFormat0(std::istream& in,
|
|||||||
const std::vector<FactorType>& factorOrder)
|
const std::vector<FactorType>& factorOrder)
|
||||||
{
|
{
|
||||||
Clear();
|
Clear();
|
||||||
std::string line;
|
size_t numInputScores = StaticData::Instance().GetNumInputScores();
|
||||||
size_t numLinkParams = StaticData::Instance().GetNumLinkParams();
|
size_t numRealWordCount = StaticData::Instance().GetNumRealWordsInInput();
|
||||||
size_t numLinkWeights = StaticData::Instance().GetNumInputScores();
|
size_t totalCount = numInputScores + numRealWordCount;
|
||||||
bool addRealWordCount = ((numLinkParams + 1) == numLinkWeights);
|
bool addRealWordCount = (numRealWordCount > 0);
|
||||||
|
|
||||||
|
std::string line;
|
||||||
while(getline(in,line)) {
|
while(getline(in,line)) {
|
||||||
std::istringstream is(line);
|
std::istringstream is(line);
|
||||||
std::string word;
|
std::string word;
|
||||||
@ -131,8 +132,8 @@ bool ConfusionNet::ReadFormat0(std::istream& in,
|
|||||||
while(is>>word) {
|
while(is>>word) {
|
||||||
Word w;
|
Word w;
|
||||||
String2Word(word,w,factorOrder);
|
String2Word(word,w,factorOrder);
|
||||||
std::vector<float> probs(numLinkWeights,0.0);
|
std::vector<float> probs(totalCount, 0.0);
|
||||||
for(size_t i=0; i<numLinkParams; i++) {
|
for(size_t i=0; i < numInputScores; i++) {
|
||||||
double prob;
|
double prob;
|
||||||
if (!(is>>prob)) {
|
if (!(is>>prob)) {
|
||||||
TRACE_ERR("ERROR: unable to parse CN input - bad link probability, or wrong number of scores\n");
|
TRACE_ERR("ERROR: unable to parse CN input - bad link probability, or wrong number of scores\n");
|
||||||
@ -150,7 +151,7 @@ bool ConfusionNet::ReadFormat0(std::istream& in,
|
|||||||
}
|
}
|
||||||
//store 'real' word count in last feature if we have one more weight than we do arc scores and not epsilon
|
//store 'real' word count in last feature if we have one more weight than we do arc scores and not epsilon
|
||||||
if (addRealWordCount && word!=EPSILON && word!="")
|
if (addRealWordCount && word!=EPSILON && word!="")
|
||||||
probs[numLinkParams] = -1.0;
|
probs.back() = -1.0;
|
||||||
col.push_back(std::make_pair(w,probs));
|
col.push_back(std::make_pair(w,probs));
|
||||||
}
|
}
|
||||||
if(col.size()) {
|
if(col.size()) {
|
||||||
|
@ -121,7 +121,6 @@ Parameter::Parameter()
|
|||||||
AddParam("cube-pruning-diversity", "cbd", "How many hypotheses should be created for each coverage. (default = 0)");
|
AddParam("cube-pruning-diversity", "cbd", "How many hypotheses should be created for each coverage. (default = 0)");
|
||||||
AddParam("search-algorithm", "Which search algorithm to use. 0=normal stack, 1=cube pruning, 2=cube growing. (default = 0)");
|
AddParam("search-algorithm", "Which search algorithm to use. 0=normal stack, 1=cube pruning, 2=cube growing. (default = 0)");
|
||||||
AddParam("constraint", "Location of the file with target sentences to produce constraining the search");
|
AddParam("constraint", "Location of the file with target sentences to produce constraining the search");
|
||||||
AddParam("link-param-count", "Number of parameters on word links when using confusion networks or lattices (default = 1)");
|
|
||||||
AddParam("description", "Source language, target language, description");
|
AddParam("description", "Source language, target language, description");
|
||||||
AddParam("max-chart-span", "maximum num. of source word chart rules can consume (default 10)");
|
AddParam("max-chart-span", "maximum num. of source word chart rules can consume (default 10)");
|
||||||
AddParam("non-terminals", "list of non-term symbols, space separated");
|
AddParam("non-terminals", "list of non-term symbols, space separated");
|
||||||
@ -166,6 +165,9 @@ Parameter::Parameter()
|
|||||||
|
|
||||||
AddParam("report-segmentation", "t", "report phrase segmentation in the output");
|
AddParam("report-segmentation", "t", "report phrase segmentation in the output");
|
||||||
|
|
||||||
|
AddParam("translation-systems", "DEPRECATED. DO NOT USE. specify multiple translation systems, each consisting of an id, followed by a set of models ids, eg '0 T1 R1 L0'");
|
||||||
|
AddParam("link-param-count", "DEPRECATED. DO NOT USE. Number of parameters on word links when using confusion networks or lattices (default = 1)");
|
||||||
|
|
||||||
AddParam("weight-slm", "slm", "DEPRECATED. DO NOT USE. weight(s) for syntactic language model");
|
AddParam("weight-slm", "slm", "DEPRECATED. DO NOT USE. weight(s) for syntactic language model");
|
||||||
AddParam("weight-bl", "bl", "DEPRECATED. DO NOT USE. weight for bleu score feature");
|
AddParam("weight-bl", "bl", "DEPRECATED. DO NOT USE. weight for bleu score feature");
|
||||||
AddParam("weight-d", "d", "DEPRECATED. DO NOT USE. weight(s) for distortion (reordering components)");
|
AddParam("weight-d", "d", "DEPRECATED. DO NOT USE. weight(s) for distortion (reordering components)");
|
||||||
@ -189,8 +191,7 @@ Parameter::Parameter()
|
|||||||
AddParam("weight", "weights for ALL models, 1 per line 'WeightName value'. Weight names can be repeated");
|
AddParam("weight", "weights for ALL models, 1 per line 'WeightName value'. Weight names can be repeated");
|
||||||
AddParam("weight-overwrite", "special parameter for mert. All on 1 line. Overrides weights specified in 'weights' argument");
|
AddParam("weight-overwrite", "special parameter for mert. All on 1 line. Overrides weights specified in 'weights' argument");
|
||||||
|
|
||||||
AddParam("translation-systems", "DEPRECATED. DO NOT USE. specify multiple translation systems, each consisting of an id, followed by a set of models ids, eg '0 T1 R1 L0'");
|
AddParam("input-scores", "2 numbers on 2 lines - [1] of scores on each edge of a confusion network or lattice input (default=1). [2] Number of 'real' word scores (0 or 1. default=0)");
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Parameter::~Parameter()
|
Parameter::~Parameter()
|
||||||
@ -353,18 +354,7 @@ void Parameter::ConvertWeightArgs()
|
|||||||
cerr << "Do not mix old and new format for specify weights";
|
cerr << "Do not mix old and new format for specify weights";
|
||||||
}
|
}
|
||||||
|
|
||||||
// input scores. if size=1, add an extra for 'real' word count feature. TODO HACK
|
|
||||||
bool addExtraInputWeight = false;
|
|
||||||
if (m_setting["weight-i"].size() == 1)
|
|
||||||
{
|
|
||||||
addExtraInputWeight = true;
|
|
||||||
}
|
|
||||||
ConvertWeightArgs("weight-i", "PhraseModel");
|
ConvertWeightArgs("weight-i", "PhraseModel");
|
||||||
|
|
||||||
if (addExtraInputWeight)
|
|
||||||
m_setting["weight"].push_back("PhraseModel 0.0");
|
|
||||||
|
|
||||||
|
|
||||||
ConvertWeightArgs("weight-t", "PhraseModel");
|
ConvertWeightArgs("weight-t", "PhraseModel");
|
||||||
ConvertWeightArgs("weight-w", "WordPenalty");
|
ConvertWeightArgs("weight-w", "WordPenalty");
|
||||||
ConvertWeightArgs("weight-l", "LM");
|
ConvertWeightArgs("weight-l", "LM");
|
||||||
|
@ -91,11 +91,9 @@ StaticData::StaticData()
|
|||||||
,m_phraseLengthFeature(NULL)
|
,m_phraseLengthFeature(NULL)
|
||||||
,m_targetWordInsertionFeature(NULL)
|
,m_targetWordInsertionFeature(NULL)
|
||||||
,m_sourceWordDeletionFeature(NULL)
|
,m_sourceWordDeletionFeature(NULL)
|
||||||
,m_numLinkParams(1)
|
|
||||||
,m_fLMsLoaded(false)
|
,m_fLMsLoaded(false)
|
||||||
,m_sourceStartPosMattersForRecombination(false)
|
,m_sourceStartPosMattersForRecombination(false)
|
||||||
,m_inputType(SentenceInput)
|
,m_inputType(SentenceInput)
|
||||||
,m_numInputScores(0)
|
|
||||||
,m_bleuScoreFeature(NULL)
|
,m_bleuScoreFeature(NULL)
|
||||||
,m_detailedTranslationReportingFilePath()
|
,m_detailedTranslationReportingFilePath()
|
||||||
,m_onlyDistinctNBest(false)
|
,m_onlyDistinctNBest(false)
|
||||||
@ -1213,6 +1211,8 @@ bool StaticData::LoadPhraseTables()
|
|||||||
if (m_parameter->GetParam("ttable-file").size() > 0) {
|
if (m_parameter->GetParam("ttable-file").size() > 0) {
|
||||||
// weights
|
// weights
|
||||||
const vector<float> &weightAll = m_parameter->GetWeights("PhraseModel");
|
const vector<float> &weightAll = m_parameter->GetWeights("PhraseModel");
|
||||||
|
for (int i = 0; i < weightAll.size(); ++i)
|
||||||
|
cerr << weightAll[i] << " " << flush;
|
||||||
|
|
||||||
const vector<string> &translationVector = m_parameter->GetParam("ttable-file");
|
const vector<string> &translationVector = m_parameter->GetParam("ttable-file");
|
||||||
vector<size_t> maxTargetPhrase = Scan<size_t>(m_parameter->GetParam("ttable-limit"));
|
vector<size_t> maxTargetPhrase = Scan<size_t>(m_parameter->GetParam("ttable-limit"));
|
||||||
@ -1272,57 +1272,38 @@ bool StaticData::LoadPhraseTables()
|
|||||||
// first InputScores (if any), then translation scores
|
// first InputScores (if any), then translation scores
|
||||||
vector<float> weight;
|
vector<float> weight;
|
||||||
|
|
||||||
if(currDict==0 && (m_inputType == ConfusionNetworkInput || m_inputType == WordLatticeInput)) {
|
if(m_inputType == ConfusionNetworkInput || m_inputType == WordLatticeInput) {
|
||||||
|
if (currDict==0) { // only the 1st pt. THis is shit
|
||||||
// TODO. find what the assumptions made by confusion network about phrase table output which makes
|
// TODO. find what the assumptions made by confusion network about phrase table output which makes
|
||||||
// it only work with binrary file. This is a hack
|
// it only work with binary file. This is a hack
|
||||||
|
CHECK(implementation == Binary);
|
||||||
|
|
||||||
m_numInputScores=m_parameter->GetParam("weight-i").size();
|
if (m_parameter->GetParam("input-scores").size()) {
|
||||||
|
m_numInputScores = Scan<size_t>(m_parameter->GetParam("input-scores")[0]);
|
||||||
if (implementation == Binary)
|
|
||||||
{
|
|
||||||
for(unsigned k=0; k<m_numInputScores; ++k)
|
|
||||||
weight.push_back(Scan<float>(m_parameter->GetParam("weight-i")[k]));
|
|
||||||
}
|
}
|
||||||
|
else {
|
||||||
|
m_numInputScores = 1;
|
||||||
|
}
|
||||||
|
numScoreComponent += m_numInputScores;
|
||||||
|
|
||||||
if(m_parameter->GetParam("link-param-count").size())
|
if (m_parameter->GetParam("input-scores").size() > 1) {
|
||||||
m_numLinkParams = Scan<size_t>(m_parameter->GetParam("link-param-count")[0]);
|
m_numRealWordsInInput = Scan<size_t>(m_parameter->GetParam("input-scores")[1]);
|
||||||
|
}
|
||||||
//print some info about this interaction:
|
else {
|
||||||
if (implementation == Binary) {
|
m_numRealWordsInInput = 0;
|
||||||
if (m_numLinkParams == m_numInputScores) {
|
}
|
||||||
VERBOSE(1,"specified equal numbers of link parameters and insertion weights, not using non-epsilon 'real' word link count.\n");
|
numScoreComponent += m_numRealWordsInInput;
|
||||||
} else if ((m_numLinkParams + 1) == m_numInputScores) {
|
|
||||||
VERBOSE(1,"WARN: "<< m_numInputScores << " insertion weights found and only "<< m_numLinkParams << " link parameters specified, applying non-epsilon 'real' word link count for last feature weight.\n");
|
|
||||||
} else {
|
|
||||||
stringstream strme;
|
|
||||||
strme << "You specified " << m_numInputScores
|
|
||||||
<< " input weights (weight-i), but you specified " << m_numLinkParams << " link parameters (link-param-count)!";
|
|
||||||
UserMessage::Add(strme.str());
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
else { // not confusion network or lattice input
|
||||||
}
|
|
||||||
if (!m_inputType) {
|
|
||||||
m_numInputScores = 0;
|
m_numInputScores = 0;
|
||||||
|
m_numRealWordsInInput = 0;
|
||||||
}
|
}
|
||||||
//this number changes depending on what phrase table we're talking about: only 0 has the weights on it
|
|
||||||
size_t tableInputScores = (currDict == 0 && implementation == Binary) ? m_numInputScores : 0;
|
|
||||||
|
|
||||||
for (size_t currScore = 0 ; currScore < numScoreComponent; currScore++)
|
for (size_t currScore = 0 ; currScore < numScoreComponent; currScore++)
|
||||||
weight.push_back(weightAll[weightAllOffset + currScore]);
|
weight.push_back(weightAll[weightAllOffset + currScore]);
|
||||||
|
|
||||||
|
|
||||||
if(weight.size() - tableInputScores != numScoreComponent) {
|
|
||||||
stringstream strme;
|
|
||||||
strme << "Your phrase table has " << numScoreComponent
|
|
||||||
<< " scores, but you specified " << (weight.size() - tableInputScores) << " weights!";
|
|
||||||
UserMessage::Add(strme.str());
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
weightAllOffset += numScoreComponent;
|
weightAllOffset += numScoreComponent;
|
||||||
numScoreComponent += tableInputScores;
|
|
||||||
|
|
||||||
string targetPath, alignmentsFile;
|
string targetPath, alignmentsFile;
|
||||||
if (implementation == SuffixArray) {
|
if (implementation == SuffixArray) {
|
||||||
|
@ -138,7 +138,7 @@ protected:
|
|||||||
, m_maxNoTransOptPerCoverage
|
, m_maxNoTransOptPerCoverage
|
||||||
, m_maxNoPartTransOpt
|
, m_maxNoPartTransOpt
|
||||||
, m_maxPhraseLength
|
, m_maxPhraseLength
|
||||||
, m_numLinkParams;
|
, m_numRealWordsInInput;
|
||||||
|
|
||||||
std::string
|
std::string
|
||||||
m_constraintFileName;
|
m_constraintFileName;
|
||||||
@ -443,8 +443,8 @@ public:
|
|||||||
return m_minlexrMemory;
|
return m_minlexrMemory;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t GetNumLinkParams() const {
|
size_t GetNumRealWordsInInput() const {
|
||||||
return m_numLinkParams;
|
return m_numRealWordsInInput;
|
||||||
}
|
}
|
||||||
const std::vector<std::string> &GetDescription() const {
|
const std::vector<std::string> &GetDescription() const {
|
||||||
return m_parameter->GetParam("description");
|
return m_parameter->GetParam("description");
|
||||||
|
@ -33,12 +33,14 @@ void WordLattice::Print(std::ostream& out) const
|
|||||||
|
|
||||||
int WordLattice::InitializeFromPCNDataType(const PCN::CN& cn, const std::vector<FactorType>& factorOrder, const std::string& debug_line)
|
int WordLattice::InitializeFromPCNDataType(const PCN::CN& cn, const std::vector<FactorType>& factorOrder, const std::string& debug_line)
|
||||||
{
|
{
|
||||||
size_t numLinkParams = StaticData::Instance().GetNumLinkParams();
|
|
||||||
size_t numLinkWeights = StaticData::Instance().GetNumInputScores();
|
|
||||||
size_t maxSizePhrase = StaticData::Instance().GetMaxPhraseLength();
|
size_t maxSizePhrase = StaticData::Instance().GetMaxPhraseLength();
|
||||||
|
|
||||||
|
size_t numInputScores = StaticData::Instance().GetNumInputScores();
|
||||||
|
size_t numRealWordCount = StaticData::Instance().GetNumRealWordsInInput();
|
||||||
|
size_t totalCount = numInputScores + numRealWordCount;
|
||||||
|
bool addRealWordCount = (numRealWordCount > 0);
|
||||||
|
|
||||||
//when we have one more weight than params, we add a word count feature
|
//when we have one more weight than params, we add a word count feature
|
||||||
bool addRealWordCount = ((numLinkParams + 1) == numLinkWeights);
|
|
||||||
data.resize(cn.size());
|
data.resize(cn.size());
|
||||||
next_nodes.resize(cn.size());
|
next_nodes.resize(cn.size());
|
||||||
for(size_t i=0; i<cn.size(); ++i) {
|
for(size_t i=0; i<cn.size(); ++i) {
|
||||||
@ -51,8 +53,8 @@ int WordLattice::InitializeFromPCNDataType(const PCN::CN& cn, const std::vector<
|
|||||||
|
|
||||||
|
|
||||||
//check for correct number of link parameters
|
//check for correct number of link parameters
|
||||||
if (alt.first.second.size() != numLinkParams) {
|
if (alt.first.second.size() != numInputScores) {
|
||||||
TRACE_ERR("ERROR: need " << numLinkParams << " link parameters, found " << alt.first.second.size() << " while reading column " << i << " from " << debug_line << "\n");
|
TRACE_ERR("ERROR: need " << numInputScores << " link parameters, found " << alt.first.second.size() << " while reading column " << i << " from " << debug_line << "\n");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user