dynamic updating

git-svn-id: https://mosesdecoder.svn.sourceforge.net/svnroot/mosesdecoder/branches/mira-mtm5@3431 1f5c12ca-751b-0410-a591-d2e778427230
This commit is contained in:
hieuhoang1972 2010-09-14 09:42:37 +00:00
parent b487e4853f
commit ff7ce559b7
8 changed files with 340 additions and 0 deletions

View File

@ -140,6 +140,8 @@ int main(int argc, char* argv[])
line += "\n";
ReadInput(*ioWrapper,staticData.GetInputType(),source, line);
StaticData &ss = StaticData::InstanceNonConst();
//ss.ChangeWeights( "lm", 8989);
// note: source is only valid within this while loop!
IFVERBOSE(1)
ResetUserTime();

View File

@ -1025,6 +1025,7 @@
HEADER_SEARCH_PATHS = (
../irstlm/include,
../srilm/include,
/usr/local/include,
);
ONLY_ACTIVE_ARCH = YES;
PREBINDING = NO;
@ -1048,6 +1049,7 @@
HEADER_SEARCH_PATHS = (
../irstlm/include,
../srilm/include,
/usr/local/include,
);
ONLY_ACTIVE_ARCH = YES;
PREBINDING = NO;

142
moses/src/OnlineCommand.cpp Normal file
View File

@ -0,0 +1,142 @@
// $Id: OnlineCommand.cpp 3428 2010-09-13 17:55:23Z nicolabertoldi $
// vim:tabstop=2
/***********************************************************************
Moses - factored phrase-based language decoder
Copyright (C) 2006 University of Edinburgh
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
#include <stdexcept>
#include "StaticData.h" // needed for debugging purpose only
#include "OnlineCommand.h"
#include "Util.h"
using namespace std;
#define COMMAND_KEYWORD "@CMD@"
namespace Moses
{
/**
 * Creates an empty command and registers the switches that are accepted
 * on a COMMAND_KEYWORD line.
 */
OnlineCommand::OnlineCommand()
{
  VERBOSE(3,"OnlineCommand::OnlineCommand()" << std::endl);
  VERBOSE(3,"COMMAND_KEYWORD: " << COMMAND_KEYWORD << std::endl);

  // Use clear() instead of assigning '\0': assigning a char to a std::string
  // produces a one-character string holding NUL, not an empty string.
  command_type.clear();
  command_value.clear();

  accepted_commands.push_back("-weight-l"); // weight(s) for language models
  accepted_commands.push_back("-weight-t"); // weight(s) for translation model components
  accepted_commands.push_back("-weight-d"); // weight(s) for distortion (reordering components)
  accepted_commands.push_back("-weight-w"); // weight for word penalty
  accepted_commands.push_back("-weight-u"); // weight for unknown words penalty
  accepted_commands.push_back("-weight-g"); // weight(s) for global lexical model components
  accepted_commands.push_back("-verbose");  // verbosity level
}
bool OnlineCommand::Parse(std::string& line)
{
VERBOSE(3,"OnlineCommand::Parse(std::string& line)" << std::endl);
int next_string_pos = 0;
string firststring = GetFirstString(line, next_string_pos);
bool flag = false;
if (firststring.compare(COMMAND_KEYWORD) == 0){
command_type = GetFirstString(line, next_string_pos);
for (vector<string>::const_iterator iterParam = accepted_commands.begin(); iterParam!=accepted_commands.end(); ++iterParam) {
if (command_type.compare(*iterParam) == 0){ //requested command is found
command_value = line.substr(next_string_pos);
flag = true;
}
}
if (!flag){
VERBOSE(3,"OnlineCommand::Parse: This command |" << command_type << "| is unknown." << std::endl);
}
return true;
}else{
return false;
}
}
void OnlineCommand::Execute() const
{
std::cerr << "void OnlineCommand::Execute() const" << std::endl;
VERBOSE(3,"OnlineCommand::Execute() const" << std::endl);
StaticData &staticData = (StaticData&) StaticData::Instance();
VERBOSE(1,"Handling online command: " << COMMAND_KEYWORD << " " << command_type << " " << command_value << std::endl);
// weights
vector<float> actual_weights;
vector<float> weights;
PARAM_VEC values;
bool flag = false;
for(vector<std::string>::const_iterator iterParam = accepted_commands.begin(); iterParam != accepted_commands.end(); iterParam++)
{
std::string paramName = *iterParam;
if (command_type.compare(paramName) == 0){ //requested command is paramName
Tokenize(values, command_value);
//remove initial "-" character
paramName.erase(0,1);
staticData.GetParameter()->OverwriteParam(paramName, values);
staticData.ReLoadParameter();
// check on weights
vector<float> weights = staticData.GetAllWeights();
IFVERBOSE(2) {
TRACE_ERR("The score component vector looks like this:\n" << staticData.GetScoreIndexManager());
TRACE_ERR("The global weight vector looks like this:");
for (size_t j=0; j<weights.size(); j++) { TRACE_ERR(" " << weights[j]); }
TRACE_ERR("\n");
}
flag = true;
}
}
if (!flag){
TRACE_ERR("ERROR: The command |" << command_type << "| is unknown." << std::endl);
}
}
/**
 * Writes the stored command as "<type> -> <value>" followed by a newline.
 *
 * \param out stream that receives the text
 */
void OnlineCommand::Print(std::ostream& out) const
{
  VERBOSE(3,"OnlineCommand::Print(std::ostream& out) const" << std::endl);

  out << command_type;
  out << " -> ";
  out << command_value;
  out << "\n";
}
/**
 * Resets the stored command type and value to empty strings.
 */
void OnlineCommand::Clean()
{
  VERBOSE(3,"OnlineCommand::Clean() const" << std::endl);

  // Use clear() instead of assigning '\0': assigning a char to a std::string
  // produces a one-character string holding NUL, not an empty string.
  command_type.clear();
  command_value.clear();
}
}

59
moses/src/OnlineCommand.h Normal file
View File

@ -0,0 +1,59 @@
// $Id: OnlineCommand.h 3428 2010-09-13 17:55:23Z nicolabertoldi $
/***********************************************************************
Moses - factored phrase-based language decoder
Copyright (C) 2006 University of Edinburgh
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
#ifndef moses_OnlineCommand_h
#define moses_OnlineCommand_h
#include <string>
#include "InputType.h"
#include "StaticData.h"
namespace Moses
{
/***
 * A class used specifically to read online commands that modify the
 * system's parameters on the fly.
 */
class OnlineCommand
{
private:
  std::string command_type;    // name of the command, e.g. "-weight-l"
  std::string command_value;   // text following the command name on the input line
  PARAM_VEC accepted_commands; // command names that Parse() and Execute() recognise

public:
  OnlineCommand();

  // Returns true when str is a command line; stores its type and value.
  bool Parse(std::string& str);

  // Writes "<type> -> <value>" and a newline to out.
  void Print(std::ostream& out = std::cerr) const;

  // Applies the stored command to the global StaticData instance.
  void Execute() const;

  // Resets the stored type and value.
  void Clean();

  // Accessors only read the object, so they are const-qualified.
  std::string GetType() const { return command_type; }
  std::string GetValue() const { return command_value; }
};
}
#endif

View File

@ -29,6 +29,7 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
#include "Util.h"
#include "InputFileStream.h"
#include "UserMessage.h"
#include "StaticData.h"
#if HAVE_CONFIG_H
#include "config.h"
#endif
@ -152,6 +153,7 @@ void Parameter::AddParam(const string &paramName, const string &abbrevName, cons
m_valid[paramName] = true;
m_valid[abbrevName] = true;
m_abbreviation[paramName] = abbrevName;
m_fullname[abbrevName] = paramName;
m_description[paramName] = description;
}
@ -597,6 +599,30 @@ void Parameter::PrintCredit()
cerr << endl << endl;
}
/** update parameter settings with command line switches
* \param paramName full name of parameter
* \param values inew values for paramName */
void Parameter::OverwriteParam(const string &paramName, PARAM_VEC values)
{
VERBOSE(2,"Overwriting parameter " << paramName);
m_setting[paramName]; // defines the parameter, important for boolean switches
if (m_setting[paramName].size() > 1){
VERBOSE(2," (the parameter had " << m_setting[paramName].size() << " previous values)");
assert(m_setting[paramName].size() == values.size());
}else{
VERBOSE(2," (the parameter does not have previous values)");
m_setting[paramName].resize(values.size());
}
VERBOSE(2," with the following values:");
int i=0;
for (PARAM_VEC::iterator iter = values.begin(); iter != values.end() ; iter++, i++){
m_setting[paramName][i] = *iter;
VERBOSE(2, " " << *iter);
}
VERBOSE(2, std::endl);
}
}

View File

@ -45,6 +45,7 @@ protected:
PARAM_BOOL m_valid;
PARAM_STRING m_abbreviation;
PARAM_STRING m_description;
PARAM_STRING m_fullname;
std::string FindParam(const std::string &paramSwitch, int argc, char* argv[]);
void OverwriteParam(const std::string &paramSwitch, const std::string &paramName, int argc, char* argv[]);
@ -76,6 +77,31 @@ public:
return m_setting.find( paramName ) != m_setting.end();
}
bool isParamShortNameSpecified(const std::string &paramName)
{
return m_setting.find( GetFullName(paramName) ) != m_setting.end();
}
const std::string GetFullName(std::string abbr)
{
return m_fullname[abbr];
}
const std::string GetAbbreviation(std::string full)
{
return m_abbreviation[full];
}
// Returns the values of the parameter identified by its abbreviated name,
// by translating the abbreviation to the full name and delegating to GetParam().
const PARAM_VEC &GetParamShortName(const std::string &paramName)
{
  const std::string fullName = GetFullName(paramName);
  return GetParam(fullName);
}
// Overwrites the values stored for the parameter with full name paramName.
void OverwriteParam(const std::string &paramName, PARAM_VEC values);
void OverwriteParamShortName(const std::string &paramShortName, PARAM_VEC values){
OverwriteParam(GetFullName(paramShortName),values);
}
};
}

View File

@ -1242,6 +1242,86 @@ void StaticData::AddTransOptListToCache(const DecodeGraph &decodeGraph, const Ph
ReduceTransOptCache();
}
void StaticData::ReLoadParameter()
{
m_verboseLevel = 1;
if (m_parameter->GetParam("verbose").size() == 1)
{
m_verboseLevel = Scan<size_t>( m_parameter->GetParam("verbose")[0]);
}
// check whether "weight-u" is already set
if (m_parameter->isParamShortNameSpecified("u"))
{
if (m_parameter->GetParamShortName("u").size() < 1 ){
PARAM_VEC w(1,"1.0");
m_parameter->OverwriteParamShortName("u", w);
}
}
//loop over all ScoreProducer to update weights
std::vector<const ScoreProducer*>::const_iterator iterSP;
for (iterSP = GetScoreIndexManager().GetFeatureFunctions().begin() ; iterSP != GetScoreIndexManager().GetFeatureFunctions().end() ; ++iterSP)
{
std::string paramShortName = (*iterSP)->GetScoreProducerWeightShortName();
vector<float> Weights = Scan<float>(m_parameter->GetParamShortName(paramShortName));
if (paramShortName == "d"){ //basic distortion model takes the first weight
if ((*iterSP)->GetScoreProducerDescription() == "Distortion"){
Weights.resize(1); //take only the first element
}else{ //lexicalized reordering model takes the other
Weights.erase(Weights.begin()); //remove the first element
}
// std::cerr << "this is the Distortion Score Producer -> " << (*iterSP)->GetScoreProducerDescription() << std::cerr;
// std::cerr << "this is the Distortion Score Producer; it has " << (*iterSP)->GetNumScoreComponents() << " weights"<< std::cerr;
}
SetWeightsForScoreProducer(*iterSP, Weights);
// std::cerr << Weights << std::endl;
}
m_weightWordPenalty = Scan<float>(m_parameter->GetParamShortName("w")[0]);
m_weightUnknownWord = Scan<float>(m_parameter->GetParamShortName("u")[0]);
m_weightDistortion = Scan<float>(m_parameter->GetParamShortName("d")[0]);
// std::cerr << "There are " << m_phraseDictionary.size() << " m_phraseDictionaryfeatures" << std::endl;
const float weightWP = Scan<float>(m_parameter->GetParamShortName("w"))[0];
const vector<float> WeightsTM = Scan<float>(m_parameter->GetParamShortName("tm"));
// std::cerr << "WeightsTM: " << WeightsTM << std::endl;
const vector<float> WeightsLM = Scan<float>(m_parameter->GetParamShortName("lm"));
// std::cerr << "WeightsLM: " << WeightsLM << std::endl;
size_t index_WeightTM = 0;
for(size_t i=0;i<m_phraseDictionary.size();++i)
{
PhraseDictionaryFeature &phraseDictionaryFeature = *m_phraseDictionary[i];
PhraseDictionary &phraseDictionary = *phraseDictionaryFeature.GetDictionary();
// std::cerr << "phraseDictionaryFeature.GetNumScoreComponents():" << phraseDictionaryFeature.GetNumScoreComponents() << std::endl;
// std::cerr << "phraseDictionaryFeature.GetNumInputScores():" << phraseDictionaryFeature.GetNumInputScores() << std::endl;
vector<float> tmp_weights;
for(size_t j=0;j<phraseDictionaryFeature.GetNumScoreComponents();++j)
tmp_weights.push_back(WeightsTM[index_WeightTM++]);
// std::cerr << tmp_weights << std::endl;
phraseDictionary.SetWeightTransModel(tmp_weights);
}
const LMList &languageModels = GetAllLM();
LMList::const_iterator lmIter;
size_t index_WeightLM = 0;
for (lmIter = languageModels.begin(); lmIter != languageModels.end(); ++lmIter)
{
LanguageModel &lm = **lmIter;
lm.SetWeight(WeightsLM[index_WeightLM++]);
}
}
}

View File

@ -539,6 +539,9 @@ public:
bool ContinuePartialTranslation() const { return m_continuePartialTranslation; }
void ReLoadParameter();
};
}