new format for [mapping] section

This commit is contained in:
Hieu Hoang 2015-01-07 08:55:31 +05:30
parent 1e0a2835bf
commit bfac071742
4 changed files with 88 additions and 85 deletions

View File

@ -280,7 +280,7 @@ void MosesDecoder::initialize(StaticData& staticData, const std::string& source,
// set weight of BleuScoreFeature
//cerr << "Reload Bleu feature weight: " << bleuObjectiveWeight*bleuScoreWeight << " (" << bleuObjectiveWeight << "*" << bleuScoreWeight << ")" << endl;
staticData.ReLoadBleuScoreFeatureParameter(bleuObjectiveWeight*bleuScoreWeight);
//staticData.ReLoadBleuScoreFeatureParameter(bleuObjectiveWeight*bleuScoreWeight);
m_bleuScoreFeature->SetCurrSourceLength((*m_sentence).GetSize());
if (chartDecoding)

View File

@ -502,7 +502,7 @@ bool StaticData::LoadData(Parameter *parameter)
LoadFeatureFunctions();
}
if (!LoadDecodeGraphs()) return false;
LoadDecodeGraphs();
if (!CheckWeights()) {
@ -597,7 +597,7 @@ void StaticData::LoadChartDecodingParameters()
}
bool StaticData::LoadDecodeGraphs()
void StaticData::LoadDecodeGraphs()
{
vector<string> mappingVector;
vector<size_t> maxChartSpans;
@ -614,6 +614,28 @@ bool StaticData::LoadDecodeGraphs()
maxChartSpans = Scan<size_t>(*params);
}
vector<string> toks = Tokenize(mappingVector[0]);
if (toks.size() == 3) {
// eg 0 T 0
LoadDecodeGraphsOld(mappingVector, maxChartSpans);
}
else if (toks.size() == 2) {
if (toks[0] == "T" || toks[0] == "G") {
// eg. T 0
LoadDecodeGraphsOld(mappingVector, maxChartSpans);
}
else {
// eg. 0 TM1
LoadDecodeGraphsNew(mappingVector, maxChartSpans);
}
}
else {
UTIL_THROW(util::Exception, "Malformed mapping");
}
}
void StaticData::LoadDecodeGraphsOld(const vector<string> &mappingVector, const vector<size_t> &maxChartSpans)
{
const vector<PhraseDictionary*>& pts = PhraseDictionary::GetColl();
const vector<GenerationDictionary*>& gens = GenerationDictionary::GetColl();
@ -627,10 +649,12 @@ bool StaticData::LoadDecodeGraphs()
DecodeType decodeType;
size_t index;
if (token.size() == 2) {
// eg. T 0
decodeGraphInd = 0;
decodeType = token[0] == "T" ? Translate : Generate;
index = Scan<size_t>(token[1]);
} else if (token.size() == 3) {
// eg. 0 T 0
// For specifying multiple translation model
decodeGraphInd = Scan<size_t>(token[0]);
//the vectorList index can only increment by one
@ -670,8 +694,8 @@ bool StaticData::LoadDecodeGraphs()
}
decodeStep = new DecodeStepGeneration(gens[index], prev, *featuresRemaining);
break;
case InsertNullFertilityWord:
UTIL_THROW(util::Exception, "Please implement NullFertilityInsertion.");
default:
UTIL_THROW(util::Exception, "Unknown decode step");
break;
}
@ -707,90 +731,77 @@ bool StaticData::LoadDecodeGraphs()
decodeGraph.SetBackoff(Scan<size_t>(backoffVector->at(i)));
}
}
return true;
}
void StaticData::ReLoadParameter()
void StaticData::LoadDecodeGraphsNew(const std::vector<std::string> &mappingVector, const std::vector<size_t> &maxChartSpans)
{
UTIL_THROW(util::Exception, "completely redo. Too many hardcoded ff"); // TODO completely redo. Too many hardcoded ff
/*
m_verboseLevel = 1;
if (m_parameter->GetParam("verbose").size() == 1) {
m_verboseLevel = Scan<size_t>( m_parameter->GetParam("verbose")[0]);
}
const std::vector<FeatureFunction*> *featuresRemaining = &FeatureFunction::GetFeatureFunctions();
DecodeStep *prev = 0;
size_t prevDecodeGraphInd = 0;
// check whether "weight-u" is already set
if (m_parameter->isParamShortNameSpecified("u")) {
if (m_parameter->GetParamShortName("u").size() < 1 ) {
PARAM_VEC w(1,"1.0");
m_parameter->OverwriteParamShortName("u", w);
}
}
for(size_t i=0; i<mappingVector.size(); i++) {
vector<string> token = Tokenize(mappingVector[i]);
size_t decodeGraphInd;
size_t index;
//loop over all ScoreProducer to update weights
decodeGraphInd = Scan<size_t>(token[0]);
//the vectorList index can only increment by one
UTIL_THROW_IF2(decodeGraphInd != prevDecodeGraphInd && decodeGraphInd != prevDecodeGraphInd + 1,
"Malformed mapping");
if (decodeGraphInd > prevDecodeGraphInd) {
prev = NULL;
}
std::vector<const ScoreProducer*>::const_iterator iterSP;
for (iterSP = transSystem.GetFeatureFunctions().begin() ; iterSP != transSystem.GetFeatureFunctions().end() ; ++iterSP) {
std::string paramShortName = (*iterSP)->GetScoreProducerWeightShortName();
vector<float> Weights = Scan<float>(m_parameter->GetParamShortName(paramShortName));
if (prevDecodeGraphInd < decodeGraphInd) {
featuresRemaining = &FeatureFunction::GetFeatureFunctions();
}
if (paramShortName == "d") { //basic distortion model takes the first weight
if ((*iterSP)->GetScoreProducerDescription() == "Distortion") {
Weights.resize(1); //take only the first element
} else { //lexicalized reordering model takes the other
Weights.erase(Weights.begin()); //remove the first element
}
std::cerr << "this is the Distortion Score Producer -> " << (*iterSP)->GetScoreProducerDescription() << std::cerr;
std::cerr << "this is the Distortion Score Producer; it has " << (*iterSP)->GetNumScoreComponents() << " weights"<< std::cerr;
std::cerr << Weights << std::endl;
} else if (paramShortName == "tm") {
continue;
}
SetWeights(*iterSP, Weights);
}
FeatureFunction &ff = FeatureFunction::FindFeatureFunction(token[1]);
// std::cerr << "There are " << m_phraseDictionary.size() << " m_phraseDictionaryfeatures" << std::endl;
DecodeStep* decodeStep = NULL;
if (typeid(ff) == typeid(PhraseDictionary)) {
decodeStep = new DecodeStepTranslation(&static_cast<PhraseDictionary&>(ff), prev, *featuresRemaining);
}
else if (typeid(ff) == typeid(GenerationDictionary)) {
decodeStep = new DecodeStepGeneration(&static_cast<GenerationDictionary&>(ff), prev, *featuresRemaining);
}
else {
UTIL_THROW(util::Exception, "Unknown decode step");
}
const vector<float> WeightsTM = Scan<float>(m_parameter->GetParamShortName("tm"));
// std::cerr << "WeightsTM: " << WeightsTM << std::endl;
featuresRemaining = &decodeStep->GetFeaturesRemaining();
const vector<float> WeightsLM = Scan<float>(m_parameter->GetParamShortName("lm"));
// std::cerr << "WeightsLM: " << WeightsLM << std::endl;
UTIL_THROW_IF2(decodeStep == NULL, "Null decode step");
if (m_decodeGraphs.size() < decodeGraphInd + 1) {
DecodeGraph *decodeGraph;
if (IsChart()) {
size_t maxChartSpan = (decodeGraphInd < maxChartSpans.size()) ? maxChartSpans[decodeGraphInd] : DEFAULT_MAX_CHART_SPAN;
VERBOSE(1,"max-chart-span: " << maxChartSpans[decodeGraphInd] << endl);
decodeGraph = new DecodeGraph(m_decodeGraphs.size(), maxChartSpan);
} else {
decodeGraph = new DecodeGraph(m_decodeGraphs.size());
}
size_t index_WeightTM = 0;
for(size_t i=0; i<transSystem.GetPhraseDictionaries().size(); ++i) {
PhraseDictionaryFeature &phraseDictionaryFeature = *m_phraseDictionary[i];
m_decodeGraphs.push_back(decodeGraph); // TODO max chart span
}
// std::cerr << "phraseDictionaryFeature.GetNumScoreComponents():" << phraseDictionaryFeature.GetNumScoreComponents() << std::endl;
// std::cerr << "phraseDictionaryFeature.GetNumInputScores():" << phraseDictionaryFeature.GetNumInputScores() << std::endl;
m_decodeGraphs[decodeGraphInd]->Add(decodeStep);
prev = decodeStep;
prevDecodeGraphInd = decodeGraphInd;
}
vector<float> tmp_weights;
for(size_t j=0; j<phraseDictionaryFeature.GetNumScoreComponents(); ++j)
tmp_weights.push_back(WeightsTM[index_WeightTM++]);
// set maximum n-gram size for backoff approach to decoding paths
// default is always use subsequent paths (value = 0)
// if specified, record maxmimum unseen n-gram size
const vector<string> *backoffVector = m_parameter->GetParam("decoding-graph-backoff");
for(size_t i=0; i<m_decodeGraphs.size() && backoffVector && i<backoffVector->size(); i++) {
DecodeGraph &decodeGraph = *m_decodeGraphs[i];
// std::cerr << tmp_weights << std::endl;
if (i < backoffVector->size()) {
decodeGraph.SetBackoff(Scan<size_t>(backoffVector->at(i)));
}
}
SetWeights(&phraseDictionaryFeature, tmp_weights);
}
*/
}
void StaticData::ReLoadBleuScoreFeatureParameter(float weight)
{
assert(false);
/*
//loop over ScoreProducers to update weights of BleuScoreFeature
std::vector<const ScoreProducer*>::const_iterator iterSP;
for (iterSP = transSystem.GetFeatureFunctions().begin() ; iterSP != transSystem.GetFeatureFunctions().end() ; ++iterSP) {
std::string paramShortName = (*iterSP)->GetScoreProducerWeightShortName();
if (paramShortName == "bl") {
SetWeight(*iterSP, weight);
break;
}
}
*/
}
// ScoreComponentCollection StaticData::GetAllWeightsScoreComponentCollection() const {}
@ -829,11 +840,6 @@ float StaticData::GetWeightWordPenalty() const
return weightWP;
}
float StaticData::GetWeightUnknownWordPenalty() const
{
return GetWeight(&UnknownWordPenaltyProducer::Instance());
}
void StaticData::InitializeForInput(const InputType& source) const
{
const std::vector<FeatureFunction*> &producers = FeatureFunction::GetFeatureFunctions();

View File

@ -220,7 +220,9 @@ protected:
void LoadNonTerminals();
//! load decoding steps
bool LoadDecodeGraphs();
void LoadDecodeGraphs();
void LoadDecodeGraphsOld(const std::vector<std::string> &mappingVector, const std::vector<size_t> &maxChartSpans);
void LoadDecodeGraphsNew(const std::vector<std::string> &mappingVector, const std::vector<size_t> &maxChartSpans);
void NoCache();
@ -612,9 +614,6 @@ public:
return m_continuePartialTranslation;
}
void ReLoadParameter();
void ReLoadBleuScoreFeatureParameter(float weight);
Parameter* GetParameter() {
return m_parameter;
}
@ -723,7 +722,6 @@ public:
}
float GetWeightWordPenalty() const;
float GetWeightUnknownWordPenalty() const;
const std::vector<DecodeGraph*>& GetDecodeGraphs() const {
return m_decodeGraphs;

View File

@ -93,7 +93,6 @@ enum FactorDirection {
enum DecodeType {
Translate
,Generate
,InsertNullFertilityWord //! an optional step that attempts to insert a few closed-class words to improve LM scores
};
namespace LexReorderType