basic support for alternate weight settings

This commit is contained in:
phikoehn 2013-05-31 12:28:57 +01:00
parent 68501f5a36
commit d1650a5aa7
7 changed files with 140 additions and 13 deletions

View File

@ -48,10 +48,12 @@ protected:
long m_translationId; //< contiguous Id
long m_documentId;
long m_topicId;
std::string m_weightSetting;
std::vector<std::string> m_topicIdAndProb;
bool m_useTopicId;
bool m_useTopicIdAndProb;
bool m_hasMetaData;
bool m_specifiesWeightSetting;
long m_segId;
ReorderingConstraint m_reorderingConstraint; /**< limits on reordering specified either by "-mp" switch or xml tags */
std::string m_textType;
@ -109,6 +111,18 @@ public:
std::string GetTextType() const {
return m_textType;
}
void SetSpecifiesWeightSetting(bool specifiesWeightSetting) {
m_specifiesWeightSetting = specifiesWeightSetting;
}
bool GetSpecifiesWeightSetting() const {
return m_specifiesWeightSetting;
}
void SetWeightSetting(std::string settingName) {
m_weightSetting = settingName;
}
std::string GetWeightSetting() const {
return m_weightSetting;
}
void SetTextType(std::string type) {
m_textType = type;
}

View File

@ -80,7 +80,22 @@ void Manager::ProcessSentence()
{
// reset statistics
ResetSentenceStats(m_source);
// check if alternate weight setting is used
// this is not thread safe! it changes StaticData
if (StaticData::Instance().GetHasAlternateWeightSettings()) {
std::cerr << "config defines weight setting\n";
if (m_source.GetSpecifiesWeightSetting()) {
std::cerr << "sentence specifies weight setting\n";
std::cerr << "calling SetWeightSetting( " << m_source.GetWeightSetting() << ")\n";
StaticData::Instance().SetWeightSetting(m_source.GetWeightSetting());
}
else {
StaticData::Instance().SetWeightSetting("default");
}
}
// get translation options
Timer getOptionsTime;
getOptionsTime.start();
m_transOptColl->CreateTranslationOptions();

View File

@ -194,7 +194,7 @@ Parameter::Parameter()
AddParam("feature", "");
AddParam("print-id", "prefix translations with id. Default if false");
AddParam("alternate-weight-setting", "aws", "alternate set of weights to used per xml specification");
}
Parameter::~Parameter()

View File

@ -110,6 +110,13 @@ int Sentence::Read(std::istream& in,const std::vector<FactorType>& factorOrder)
this->SetUseTopicIdAndProb(true);
}
}
if (meta.find("weight-setting") != meta.end()) {
this->SetWeightSetting(meta["weight-setting"]);
this->SetSpecifiesWeightSetting(true);
}
else {
this->SetSpecifiesWeightSetting(false);
}
// parse XML markup in translation line
//const StaticData &staticData = StaticData::Instance();
@ -156,6 +163,7 @@ int Sentence::Read(std::istream& in,const std::vector<FactorType>& factorOrder)
}
// reordering walls and zones
m_reorderingConstraint.InitializeWalls( GetSize() );
// set reordering walls, if "-monotone-at-punction" is set

View File

@ -570,6 +570,7 @@ bool StaticData::LoadData(Parameter *parameter)
vector<string> toks = Tokenize(line);
const string &feature = toks[0];
//int featureIndex = GetFeatureIndex(featureIndexMap, feature);
if (feature == "GlobalLexicalModel") {
GlobalLexicalModel *model = new GlobalLexicalModel(line);
@ -706,7 +707,6 @@ bool StaticData::LoadData(Parameter *parameter)
UserMessage::Add("Unknown feature function:" + feature);
return false;
}
}
CollectFeatureFunctions();
@ -738,6 +738,10 @@ bool StaticData::LoadData(Parameter *parameter)
//cerr << endl << "m_allWeights=" << m_allWeights << endl;
// alternate weight settings
if (m_parameter->GetParam("alternate-weight-setting").size() > 0) {
ProcessAlternateWeightSettings();
}
return true;
}
@ -1181,6 +1185,70 @@ bool StaticData::CheckWeights() const
return true;
}
void StaticData::ProcessAlternateWeightSettings() {
const vector<string> &weightSpecification = m_parameter->GetParam("alternate-weight-setting");
// get mapping from feature names to feature functions
map<string,FeatureFunction*> nameToFF;
const std::vector<FeatureFunction*> &ffs = FeatureFunction::GetFeatureFunctions();
for (size_t i = 0; i < ffs.size(); ++i) {
nameToFF[ ffs[i]->GetScoreProducerDescription() ] = ffs[i];
}
// copy main weight setting as default
m_weightSetting["default"] = new ScoreComponentCollection( m_allWeights );
// go through specification in config file
string currentId = "";
bool hasErrors = false;
for (size_t i=0; i<weightSpecification.size(); ++i) {
// identifier line (with optional additional specifications)
if (weightSpecification[i].find("id=") == 0) {
vector<string> tokens = Tokenize(weightSpecification[i]);
vector<string> args = Tokenize(tokens[0], "=");
currentId = args[1];
cerr << "alternate weight setting " << currentId << endl;
CHECK(m_weightSetting.find(currentId) == m_weightSetting.end());
m_weightSetting[ currentId ] = new ScoreComponentCollection;
// other specifications
for(size_t j=1; j<tokens.size(); j++) {
vector<string> args = Tokenize(tokens[j], "=");
if (args[0] == "weight-file") {
// TODO: support for sparse weights
}
}
}
// weight lines
else {
CHECK(currentId != "");
vector<string> tokens = Tokenize(weightSpecification[i]);
CHECK(tokens.size() >= 2);
// get name and weight values
string name = tokens[0];
name = name.substr(0, name.size() - 1); // remove trailing "="
vector<float> weights(tokens.size() - 1);
for (size_t i = 1; i < tokens.size(); ++i) {
float weight = Scan<float>(tokens[i]);
weights[i - 1] = weight;
}
// check if a valid nane
map<string,FeatureFunction*>::iterator ffLookUp = nameToFF.find(name);
if (ffLookUp == nameToFF.end()) {
cerr << "ERROR: alternate weight setting " << currentId << " specifies weight(s) for " << name << " but there is no such feature function" << endl;
hasErrors = true;
}
else {
m_weightSetting[ currentId ]->Assign( nameToFF[name], weights);
}
}
}
CHECK(!hasErrors);
}
} // namespace

View File

@ -75,7 +75,7 @@ protected:
std::vector<const GenerationDictionary*> m_generationDictionary;
Parameter *m_parameter;
std::vector<FactorType> m_inputFactorOrder, m_outputFactorOrder;
ScoreComponentCollection m_allWeights;
mutable ScoreComponentCollection m_allWeights;
std::vector<DecodeGraph*> m_decodeGraphs;
std::vector<size_t> m_decodeGraphBackoff;
@ -206,6 +206,9 @@ protected:
int m_threadCount;
long m_startTranslationId;
// alternate weight settings
std::map< std::string, ScoreComponentCollection* > m_weightSetting;
StaticData();
@ -658,6 +661,24 @@ public:
return m_nBestIncludesSegmentation;
}
bool GetHasAlternateWeightSettings() const {
return m_weightSetting.size() > 0;
}
void SetWeightSetting(const std::string &settingName) const {
std::cerr << "SetWeightSetting( " << settingName << ")\n";
CHECK(GetHasAlternateWeightSettings());
std::map< std::string, ScoreComponentCollection* >::const_iterator i =
m_weightSetting.find( settingName );
// if not found, resort to default
std::cerr << "using weight setting " << settingName << std::endl;
if (i == m_weightSetting.end()) {
i = m_weightSetting.find( "default" );
std::cerr << "not found, using default weight setting instead\n";
}
m_allWeights = *(i->second);
}
float GetWeightWordPenalty() const;
float GetWeightUnknownWordPenalty() const;
@ -688,6 +709,7 @@ public:
void CollectFeatureFunctions();
bool CheckWeights() const;
void ProcessAlternateWeightSettings();
void SetTemporaryMultiModelWeightsVector(std::vector<float> weights) const {

View File

@ -385,10 +385,12 @@ void TranslationOptionCollection::CreateTranslationOptions()
// ... and that end at endPos
for (size_t endPos = startPos ; endPos < startPos + maxSize ; endPos++) {
if (graphInd > 0 && // only skip subsequent graphs
decodeGraphBackoff[graphInd] != 0 && // use of backoff specified
(endPos-startPos+1 >= decodeGraphBackoff[graphInd] || // size exceeds backoff limit or ...
m_collection[startPos][endPos-startPos].size() > 0)) { // no phrases found so far
VERBOSE(3,"No backoff to graph " << graphInd << " for span [" << startPos << ";" << endPos << "]" << endl);
decodeGraphBackoff[graphInd] != 0 && // limited use of backoff specified
(endPos-startPos+1 > decodeGraphBackoff[graphInd] || // size exceeds backoff limit or ...
m_collection[startPos][endPos-startPos].size() > 0)) { // already covered
VERBOSE(3,"No backoff to graph " << graphInd << " for span [" << startPos << ";" << endPos << "]");
VERBOSE(3,", length limit: " << decodeGraphBackoff[graphInd]);
VERBOSE(3,", found so far: " << m_collection[startPos][endPos-startPos].size() << endl);
// do not create more options
continue;
}
@ -505,11 +507,10 @@ void TranslationOptionCollection::CreateTranslationOptionsForRange(
, startPos, endPos, adhereTableLimit );
// do rest of decode steps
int indexStep = 0;
int indexStep = 1;
for (++iterStep ; iterStep != decodeGraph.end() ; ++iterStep) {
const DecodeStep &decodeStep = **iterStep;
for (++iterStep; iterStep != decodeGraph.end() ; ++iterStep, ++indexStep) {
const DecodeStep &decodeStep = **iterStep;
PartialTranslOptColl* newPtoc = new PartialTranslOptColl;
// go thru each intermediate trans opt just created
@ -531,7 +532,6 @@ void TranslationOptionCollection::CreateTranslationOptionsForRange(
delete oldPtoc;
oldPtoc = newPtoc;
indexStep++;
} // for (++iterStep
// add to fully formed translation option list