Mirror of https://github.com/moses-smt/mosesdecoder.git
clean up weights code for confusion networks & lattices. Works, except for multiple phrase-tables or factors
parent b8d4c64d6d
commit da9cd0e3aa
@@ -118,11 +118,12 @@ bool ConfusionNet::ReadFormat0(std::istream& in,
                                const std::vector<FactorType>& factorOrder)
 {
   Clear();
-  std::string line;
-  size_t numLinkParams = StaticData::Instance().GetNumLinkParams();
-  size_t numLinkWeights = StaticData::Instance().GetNumInputScores();
-  bool addRealWordCount = ((numLinkParams + 1) == numLinkWeights);
+  size_t numInputScores = StaticData::Instance().GetNumInputScores();
+  size_t numRealWordCount = StaticData::Instance().GetNumRealWordsInInput();
+  size_t totalCount = numInputScores + numRealWordCount;
+  bool addRealWordCount = (numRealWordCount > 0);
+
+  std::string line;
   while(getline(in,line)) {
     std::istringstream is(line);
     std::string word;
@@ -131,8 +132,8 @@ bool ConfusionNet::ReadFormat0(std::istream& in,
     while(is>>word) {
       Word w;
       String2Word(word,w,factorOrder);
-      std::vector<float> probs(numLinkWeights,0.0);
-      for(size_t i=0; i<numLinkParams; i++) {
+      std::vector<float> probs(totalCount, 0.0);
+      for(size_t i=0; i < numInputScores; i++) {
         double prob;
         if (!(is>>prob)) {
           TRACE_ERR("ERROR: unable to parse CN input - bad link probability, or wrong number of scores\n");
@@ -150,7 +151,7 @@ bool ConfusionNet::ReadFormat0(std::istream& in,
       }
       //store 'real' word count in last feature if we have one more weight than we do arc scores and not epsilon
       if (addRealWordCount && word!=EPSILON && word!="")
-        probs[numLinkParams] = -1.0;
+        probs.back() = -1.0;
       col.push_back(std::make_pair(w,probs));
     }
     if(col.size()) {
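For illustration, a minimal sketch of what the reworked ReadFormat0 now consumes (the words and scores are invented; *EPS* as the epsilon token follows the usual moses convention). With [input-scores] declaring 2 edge scores plus 1 'real' word score, a format-0 confusion-network column line such as

    das 0.8 0.9 *EPS* 0.2 0.1

yields two alternatives: each word is followed by numInputScores (= 2) link scores, probs is sized to totalCount (= 3), and for the non-epsilon word 'das' the last slot, probs.back(), is set to -1.0 as the word-count feature.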
@@ -121,7 +121,6 @@ Parameter::Parameter()
   AddParam("cube-pruning-diversity", "cbd", "How many hypotheses should be created for each coverage. (default = 0)");
   AddParam("search-algorithm", "Which search algorithm to use. 0=normal stack, 1=cube pruning, 2=cube growing. (default = 0)");
   AddParam("constraint", "Location of the file with target sentences to produce constraining the search");
-  AddParam("link-param-count", "Number of parameters on word links when using confusion networks or lattices (default = 1)");
   AddParam("description", "Source language, target language, description");
   AddParam("max-chart-span", "maximum num. of source word chart rules can consume (default 10)");
   AddParam("non-terminals", "list of non-term symbols, space separated");
@@ -166,6 +165,9 @@ Parameter::Parameter()
 
   AddParam("report-segmentation", "t", "report phrase segmentation in the output");
 
+  AddParam("translation-systems", "DEPRECATED. DO NOT USE. specify multiple translation systems, each consisting of an id, followed by a set of models ids, eg '0 T1 R1 L0'");
+  AddParam("link-param-count", "DEPRECATED. DO NOT USE. Number of parameters on word links when using confusion networks or lattices (default = 1)");
+
   AddParam("weight-slm", "slm", "DEPRECATED. DO NOT USE. weight(s) for syntactic language model");
   AddParam("weight-bl", "bl", "DEPRECATED. DO NOT USE. weight for bleu score feature");
   AddParam("weight-d", "d", "DEPRECATED. DO NOT USE. weight(s) for distortion (reordering components)");
@@ -189,8 +191,7 @@ Parameter::Parameter()
   AddParam("weight", "weights for ALL models, 1 per line 'WeightName value'. Weight names can be repeated");
   AddParam("weight-overwrite", "special parameter for mert. All on 1 line. Overrides weights specified in 'weights' argument");
 
-  AddParam("translation-systems", "DEPRECATED. DO NOT USE. specify multiple translation systems, each consisting of an id, followed by a set of models ids, eg '0 T1 R1 L0'");
-
+  AddParam("input-scores", "2 numbers on 2 lines - [1] of scores on each edge of a confusion network or lattice input (default=1). [2] Number of 'real' word scores (0 or 1. default=0)");
 }
 
 Parameter::~Parameter()
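As a hedged illustration of the new parameter (the values are invented): a confusion-network setup with two scores per edge plus one 'real' word score would put the following in moses.ini, replacing the old weight-i/link-param-count pair:

    [inputtype]
    1

    [input-scores]
    2
    1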
@@ -349,22 +350,11 @@ void Parameter::ConvertWeightArgs()
       m_setting.count("weight-l") || m_setting.count("weight-u") || m_setting.count("weight-lex") ||
       m_setting.count("weight-generation") || m_setting.count("weight-lr") || m_setting.count("weight-d")
      ))
   {
     cerr << "Do not mix old and new format for specify weights";
   }
 
-  // input scores. if size=1, add an extra for 'real' word count feature. TODO HACK
-  bool addExtraInputWeight = false;
-  if (m_setting["weight-i"].size() == 1)
-  {
-    addExtraInputWeight = true;
-  }
-
-  ConvertWeightArgs("weight-i", "PhraseModel");
-
-  if (addExtraInputWeight)
-    m_setting["weight"].push_back("PhraseModel 0.0");
 
   ConvertWeightArgs("weight-t", "PhraseModel");
   ConvertWeightArgs("weight-w", "WordPenalty");
   ConvertWeightArgs("weight-l", "LM");
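A hedged example of the conversion this code performs (feature names per the AddParam help above; the numeric values are invented): an old-style configuration such as

    [weight-t]
    0.2 0.2 0.2 0.2 0.2
    [weight-w]
    -1

is folded into the single new-style weight block, one 'WeightName value(s)' entry per model:

    [weight]
    PhraseModel 0.2 0.2 0.2 0.2 0.2
    WordPenalty -1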
@@ -91,11 +91,9 @@ StaticData::StaticData()
   ,m_phraseLengthFeature(NULL)
   ,m_targetWordInsertionFeature(NULL)
   ,m_sourceWordDeletionFeature(NULL)
-  ,m_numLinkParams(1)
   ,m_fLMsLoaded(false)
   ,m_sourceStartPosMattersForRecombination(false)
   ,m_inputType(SentenceInput)
-  ,m_numInputScores(0)
   ,m_bleuScoreFeature(NULL)
   ,m_detailedTranslationReportingFilePath()
   ,m_onlyDistinctNBest(false)
@@ -1213,6 +1211,8 @@ bool StaticData::LoadPhraseTables()
   if (m_parameter->GetParam("ttable-file").size() > 0) {
     // weights
     const vector<float> &weightAll = m_parameter->GetWeights("PhraseModel");
+    for (int i = 0; i < weightAll.size(); ++i)
+      cerr << weightAll[i] << " " << flush;
 
     const vector<string> &translationVector = m_parameter->GetParam("ttable-file");
     vector<size_t> maxTargetPhrase = Scan<size_t>(m_parameter->GetParam("ttable-limit"));
@@ -1272,57 +1272,38 @@ bool StaticData::LoadPhraseTables()
       // first InputScores (if any), then translation scores
       vector<float> weight;
 
-      if(currDict==0 && (m_inputType == ConfusionNetworkInput || m_inputType == WordLatticeInput)) {
-        // TODO. find what the assumptions made by confusion network about phrase table output which makes
-        // it only work with binrary file. This is a hack
-        m_numInputScores=m_parameter->GetParam("weight-i").size();
-
-        if (implementation == Binary)
-        {
-          for(unsigned k=0; k<m_numInputScores; ++k)
-            weight.push_back(Scan<float>(m_parameter->GetParam("weight-i")[k]));
-        }
-
-        if(m_parameter->GetParam("link-param-count").size())
-          m_numLinkParams = Scan<size_t>(m_parameter->GetParam("link-param-count")[0]);
-
-        //print some info about this interaction:
-        if (implementation == Binary) {
-          if (m_numLinkParams == m_numInputScores) {
-            VERBOSE(1,"specified equal numbers of link parameters and insertion weights, not using non-epsilon 'real' word link count.\n");
-          } else if ((m_numLinkParams + 1) == m_numInputScores) {
-            VERBOSE(1,"WARN: "<< m_numInputScores << " insertion weights found and only "<< m_numLinkParams << " link parameters specified, applying non-epsilon 'real' word link count for last feature weight.\n");
-          } else {
-            stringstream strme;
-            strme << "You specified " << m_numInputScores
-                  << " input weights (weight-i), but you specified " << m_numLinkParams << " link parameters (link-param-count)!";
-            UserMessage::Add(strme.str());
-            return false;
-          }
-        }
+      if(m_inputType == ConfusionNetworkInput || m_inputType == WordLatticeInput) {
+        if (currDict==0) { // only the 1st pt. THis is shit
+          // TODO. find what the assumptions made by confusion network about phrase table output which makes
+          // it only work with binary file. This is a hack
+          CHECK(implementation == Binary);
+
+          if (m_parameter->GetParam("input-scores").size()) {
+            m_numInputScores = Scan<size_t>(m_parameter->GetParam("input-scores")[0]);
+          }
+          else {
+            m_numInputScores = 1;
+          }
+          numScoreComponent += m_numInputScores;
+
+          if (m_parameter->GetParam("input-scores").size() > 1) {
+            m_numRealWordsInInput = Scan<size_t>(m_parameter->GetParam("input-scores")[1]);
+          }
+          else {
+            m_numRealWordsInInput = 0;
+          }
+          numScoreComponent += m_numRealWordsInInput;
+        }
       }
-      if (!m_inputType) {
-        m_numInputScores=0;
+      else { // not confusion network or lattice input
+        m_numInputScores = 0;
+        m_numRealWordsInInput = 0;
       }
       //this number changes depending on what phrase table we're talking about: only 0 has the weights on it
       size_t tableInputScores = (currDict == 0 && implementation == Binary) ? m_numInputScores : 0;
 
       for (size_t currScore = 0 ; currScore < numScoreComponent; currScore++)
         weight.push_back(weightAll[weightAllOffset + currScore]);
 
       if(weight.size() - tableInputScores != numScoreComponent) {
         stringstream strme;
         strme << "Your phrase table has " << numScoreComponent
               << " scores, but you specified " << (weight.size() - tableInputScores) << " weights!";
         UserMessage::Add(strme.str());
         return false;
       }
 
       weightAllOffset += numScoreComponent;
       numScoreComponent += tableInputScores;
 
       string targetPath, alignmentsFile;
       if (implementation == SuffixArray) {
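A minimal standalone sketch of the score/weight accounting above (this is not moses code; the numbers are invented to match the [input-scores] example earlier):

    #include <cstddef>
    #include <iostream>

    int main() {
      std::size_t numScoreComponent = 5; // scores per entry in the 1st phrase table
      std::size_t numInputScores    = 2; // [input-scores] line 1: scores per CN/lattice edge
      std::size_t numRealWords      = 1; // [input-scores] line 2: 'real' word score (0 or 1)

      // As in LoadPhraseTables: the first table absorbs the input scores,
      // so its weight vector must be long enough to cover them as well.
      numScoreComponent += numInputScores;
      numScoreComponent += numRealWords;

      std::cout << "1st phrase table consumes " << numScoreComponent
                << " PhraseModel weights" << std::endl; // 5 + 2 + 1 = 8
      return 0;
    }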
@@ -138,7 +138,7 @@ protected:
   , m_maxNoTransOptPerCoverage
   , m_maxNoPartTransOpt
   , m_maxPhraseLength
-  , m_numLinkParams;
+  , m_numRealWordsInInput;
 
   std::string
   m_constraintFileName;
@@ -443,8 +443,8 @@ public:
     return m_minlexrMemory;
   }
 
-  size_t GetNumLinkParams() const {
-    return m_numLinkParams;
+  size_t GetNumRealWordsInInput() const {
+    return m_numRealWordsInInput;
   }
   const std::vector<std::string> &GetDescription() const {
     return m_parameter->GetParam("description");
@@ -33,12 +33,14 @@ void WordLattice::Print(std::ostream& out) const
 
 int WordLattice::InitializeFromPCNDataType(const PCN::CN& cn, const std::vector<FactorType>& factorOrder, const std::string& debug_line)
 {
-  size_t numLinkParams = StaticData::Instance().GetNumLinkParams();
-  size_t numLinkWeights = StaticData::Instance().GetNumInputScores();
   size_t maxSizePhrase = StaticData::Instance().GetMaxPhraseLength();
 
-  //when we have one more weight than params, we add a word count feature
-  bool addRealWordCount = ((numLinkParams + 1) == numLinkWeights);
+  size_t numInputScores = StaticData::Instance().GetNumInputScores();
+  size_t numRealWordCount = StaticData::Instance().GetNumRealWordsInInput();
+  size_t totalCount = numInputScores + numRealWordCount;
+  bool addRealWordCount = (numRealWordCount > 0);
+
   data.resize(cn.size());
   next_nodes.resize(cn.size());
   for(size_t i=0; i<cn.size(); ++i) {
@@ -51,8 +53,8 @@ int WordLattice::InitializeFromPCNDataType(const PCN::CN& cn, const std::vector<
 
 
       //check for correct number of link parameters
-      if (alt.first.second.size() != numLinkParams) {
-        TRACE_ERR("ERROR: need " << numLinkParams << " link parameters, found " << alt.first.second.size() << " while reading column " << i << " from " << debug_line << "\n");
+      if (alt.first.second.size() != numInputScores) {
+        TRACE_ERR("ERROR: need " << numInputScores << " link parameters, found " << alt.first.second.size() << " while reading column " << i << " from " << debug_line << "\n");
         return false;
       }
 
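For context, an illustrative word-lattice input in the PCN/PLF form this function parses (the words and scores are invented; the layout follows the usual moses lattice conventions): with [input-scores] declaring 2 scores per edge, every arc must carry exactly numInputScores scores before the trailing next-node offset, or the check above rejects the column:

    ((('das', 0.8, 0.9, 1), ('*EPS*', 0.2, 0.1, 1),),
     (('haus', 1.0, 1.0, 1),),)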