get weights from param::allWeights. Need to be able to override weights when tuning

This commit is contained in:
Hieu Hoang 2016-08-18 22:23:21 +01:00
parent 2e5b61af86
commit 5c145f7414
6 changed files with 41 additions and 42 deletions

View File

@ -23,7 +23,7 @@ FeatureFunction::FeatureFunction(size_t startInd, const std::string &line)
,m_tuneable(true)
{
ParseLine(line);
cerr << GetName() << " " << m_startInd << "-" << (m_startInd + m_numScores - 1) << endl;
//cerr << GetName() << " " << m_startInd << "-" << (m_startInd + m_numScores - 1) << endl;
}
FeatureFunction::~FeatureFunction()

View File

@ -68,7 +68,7 @@ void FeatureFunctions::Create()
UTIL_THROW_IF2(ffParams == NULL, "Must have [feature] section");
BOOST_FOREACH(const std::string &line, *ffParams){
cerr << "line=" << line << endl;
//cerr << "line=" << line << endl;
FeatureFunction *ff = Create(line);
m_featureFunctions.push_back(ff);
@ -134,11 +134,11 @@ const FeatureFunction *FeatureFunctions::FindFeatureFunction(
const std::string &name) const
{
BOOST_FOREACH(const FeatureFunction *ff, m_featureFunctions){
if (ff->GetName() == name) {
return ff;
}
}
return NULL;
if (ff->GetName() == name) {
return ff;
}
}
return NULL;
}
const PhraseTable *FeatureFunctions::GetPhraseTableExcludeUnknownWordPenalty(size_t ptInd)

View File

@ -84,14 +84,23 @@ System::~System()
void System::LoadWeights()
{
const PARAM_VEC *vec = params.GetParam("weight");
UTIL_THROW_IF2(vec == NULL, "Must have [weight] section");
weights.Init(featureFunctions);
BOOST_FOREACH(const std::string &line, *vec){
cerr << "line=" << line << endl;
weights.CreateFromString(featureFunctions, line);
}
//cerr << "Weights:" << endl;
typedef std::map<std::string, std::vector<float> > WeightMap;
const WeightMap &allWeights = params.GetAllWeights();
BOOST_FOREACH(const WeightMap::value_type &valPair, allWeights) {
const string &ffName = valPair.first;
const std::vector<float> &ffWeights = valPair.second;
/*
cerr << ffName << "=";
for (size_t i = 0; i < ffWeights.size(); ++i) {
cerr << ffWeights[i] << " ";
}
cerr << endl;
*/
weights.SetWeights(featureFunctions, ffName, ffWeights);
}
}
void System::LoadMappings()

View File

@ -46,36 +46,26 @@ std::ostream &Weights::Debug(std::ostream &out, const System &system) const
}
void Weights::CreateFromString(const FeatureFunctions &ffs,
const std::string &line)
{
std::vector<std::string> toks = Tokenize(line);
assert(toks.size());
string ffName = toks[0];
assert(ffName[ffName.size() - 1] == '=');
ffName = ffName.substr(0, ffName.size() - 1);
//cerr << "ffName=" << ffName << endl;
const FeatureFunction *ff = ffs.FindFeatureFunction(ffName);
assert(ff);
size_t startInd = ff->GetStartInd();
size_t numScores = ff->GetNumScores();
assert(numScores == toks.size() - 1);
for (size_t i = 0; i < numScores; ++i) {
SCORE score = Scan<SCORE>(toks[i + 1]);
m_weights[i + startInd] = score;
}
}
std::vector<SCORE> Weights::GetWeights(const FeatureFunction &ff) const
{
std::vector<SCORE> ret(m_weights.begin() + ff.GetStartInd(), m_weights.begin() + ff.GetStartInd() + ff.GetNumScores());
return ret;
}
void Weights::SetWeights(const FeatureFunctions &ffs, const std::string &ffName, const std::vector<float> &weights)
{
const FeatureFunction *ff = ffs.FindFeatureFunction(ffName);
UTIL_THROW_IF2(ff == NULL, "Feature function not found:" << ffName);
size_t startInd = ff->GetStartInd();
size_t numScores = ff->GetNumScores();
UTIL_THROW_IF2(weights.size() != numScores, "Wrong number of weights. " << weights.size() << "!=" << numScores);
for (size_t i = 0; i < numScores; ++i) {
SCORE weight = weights[i];
m_weights[startInd + i] = weight;
}
}
}

View File

@ -29,10 +29,10 @@ public:
std::ostream &Debug(std::ostream &out, const System &system) const;
void CreateFromString(const FeatureFunctions &ffs, const std::string &line);
std::vector<SCORE> GetWeights(const FeatureFunction &ff) const;
void SetWeights(const FeatureFunctions &ffs, const std::string &ffName, const std::vector<float> &weights);
protected:
std::vector<SCORE> m_weights;
};

View File

@ -127,7 +127,7 @@ public:
void OverwriteParam(const std::string &paramName, PARAM_VEC values);
std::vector<float> GetWeights(const std::string &name);
std::map<std::string, std::vector<float> > GetAllWeights() const
const std::map<std::string, std::vector<float> > &GetAllWeights() const
{
return m_weights;
}