Merge branch 'weight-new' of github.com:moses-smt/mosesdecoder into weight-new

This commit is contained in:
Hieu Hoang 2013-02-07 09:42:23 +00:00
commit 2e9af629d2
9 changed files with 83 additions and 25 deletions

View File

@ -33,8 +33,8 @@ using namespace std;
namespace Moses {
LanguageModel::LanguageModel(const std::string &line) :
StatefulFeatureFunction("LM", StaticData::Instance().GetLMEnableOOVFeature() ? 2 : 1, line )
LanguageModel::LanguageModel(const std::string& description, const std::string &line) :
StatefulFeatureFunction(description, StaticData::Instance().GetLMEnableOOVFeature() ? 2 : 1, line )
{
m_enableOOVFeature = StaticData::Instance().GetLMEnableOOVFeature();
}

View File

@ -39,7 +39,7 @@ class Phrase;
//! Abstract base class which represent a language model on a contiguous phrase
class LanguageModel : public StatefulFeatureFunction {
protected:
LanguageModel(const std::string &line);
LanguageModel(const std::string& description, const std::string &line);
// This can't be in the constructor for virual function dispatch reasons

View File

@ -121,8 +121,8 @@ public:
class LMRefCount : public LanguageModel {
public:
LMRefCount(LanguageModelImplementation *impl, const std::string &line)
: LanguageModel(line)
LMRefCount(LanguageModelImplementation *impl, const std::string& description, const std::string &line)
: LanguageModel(description, line)
, m_impl(impl) {}
LanguageModel *Duplicate() const {
@ -162,7 +162,7 @@ class LMRefCount : public LanguageModel {
private:
LMRefCount(const LMRefCount &copy_from)
: LanguageModel(copy_from.GetArgLine())
: LanguageModel(copy_from.GetScoreProducerDescription(), copy_from.GetArgLine())
, m_impl(copy_from.m_impl) {}
boost::shared_ptr<LanguageModelImplementation> m_impl;

View File

@ -63,7 +63,7 @@ struct KenLMState : public FFState {
*/
template <class Model> class LanguageModelKen : public LanguageModel {
public:
LanguageModelKen(const std::string &line, const std::string &file, FactorType factorType, bool lazy);
LanguageModelKen(const std::string &description, const std::string &line, const std::string &file, FactorType factorType, bool lazy);
LanguageModel *Duplicate() const;
@ -138,8 +138,8 @@ private:
std::vector<lm::WordIndex> &m_mapping;
};
template <class Model> LanguageModelKen<Model>::LanguageModelKen(const std::string &line, const std::string &file, FactorType factorType, bool lazy)
:LanguageModel(line)
template <class Model> LanguageModelKen<Model>::LanguageModelKen(const std::string &description, const std::string &line, const std::string &file, FactorType factorType, bool lazy)
:LanguageModel(description, line)
,m_factorType(factorType)
{
lm::ngram::Config config;
@ -163,7 +163,7 @@ template <class Model> LanguageModel *LanguageModelKen<Model>::Duplicate() const
}
template <class Model> LanguageModelKen<Model>::LanguageModelKen(const LanguageModelKen<Model> &copy_from)
:LanguageModel(copy_from.GetArgLine()),
:LanguageModel(copy_from.GetScoreProducerDescription(), copy_from.GetArgLine()),
m_ngram(copy_from.m_ngram),
// TODO: don't copy this.
m_lmIdLookup(copy_from.m_lmIdLookup),
@ -336,7 +336,7 @@ template <class Model> FFState *LanguageModelKen<Model>::EvaluateChart(const Cha
} // namespace
LanguageModel *ConstructKenLM(const std::string &line)
LanguageModel *ConstructKenLM(const std::string &description, const std::string &line)
{
cerr << "line=" << line << endl;
FactorType factorType;
@ -367,32 +367,32 @@ LanguageModel *ConstructKenLM(const std::string &line)
}
}
return ConstructKenLM(line, filePath, factorType, lazy);
return ConstructKenLM(description, line, filePath, factorType, lazy);
}
LanguageModel *ConstructKenLM(const std::string &line, const std::string &file, FactorType factorType, bool lazy) {
LanguageModel *ConstructKenLM(const std::string &description, const std::string &line, const std::string &file, FactorType factorType, bool lazy) {
try {
lm::ngram::ModelType model_type;
if (lm::ngram::RecognizeBinary(file.c_str(), model_type)) {
switch(model_type) {
case lm::ngram::PROBING:
return new LanguageModelKen<lm::ngram::ProbingModel>(line, file, factorType, lazy);
return new LanguageModelKen<lm::ngram::ProbingModel>(description, line, file, factorType, lazy);
case lm::ngram::REST_PROBING:
return new LanguageModelKen<lm::ngram::RestProbingModel>(line, file, factorType, lazy);
return new LanguageModelKen<lm::ngram::RestProbingModel>(description, line, file, factorType, lazy);
case lm::ngram::TRIE:
return new LanguageModelKen<lm::ngram::TrieModel>(line, file, factorType, lazy);
return new LanguageModelKen<lm::ngram::TrieModel>(description, line, file, factorType, lazy);
case lm::ngram::QUANT_TRIE:
return new LanguageModelKen<lm::ngram::QuantTrieModel>(line, file, factorType, lazy);
return new LanguageModelKen<lm::ngram::QuantTrieModel>(description, line, file, factorType, lazy);
case lm::ngram::ARRAY_TRIE:
return new LanguageModelKen<lm::ngram::ArrayTrieModel>(line, file, factorType, lazy);
return new LanguageModelKen<lm::ngram::ArrayTrieModel>(description, line, file, factorType, lazy);
case lm::ngram::QUANT_ARRAY_TRIE:
return new LanguageModelKen<lm::ngram::QuantArrayTrieModel>(line, file, factorType, lazy);
return new LanguageModelKen<lm::ngram::QuantArrayTrieModel>(description, line, file, factorType, lazy);
default:
std::cerr << "Unrecognized kenlm model type " << model_type << std::endl;
abort();
}
} else {
return new LanguageModelKen<lm::ngram::ProbingModel>(line, file, factorType, lazy);
return new LanguageModelKen<lm::ngram::ProbingModel>(description, line, file, factorType, lazy);
}
} catch (std::exception &e) {
std::cerr << e.what() << std::endl;

View File

@ -30,10 +30,10 @@ namespace Moses {
class LanguageModel;
LanguageModel *ConstructKenLM(const std::string &line);
LanguageModel *ConstructKenLM(const std::string &description, const std::string &line);
//! This will also load. Returns a templated KenLM class
LanguageModel *ConstructKenLM(const std::string &line, const std::string &file, FactorType factorType, bool lazy);
LanguageModel *ConstructKenLM(const std::string &description, const std::string &line, const std::string &file, FactorType factorType, bool lazy);
} // namespace Moses

View File

@ -324,7 +324,14 @@ bool Parameter::LoadParam(int argc, char* argv[])
std::vector<float> &Parameter::GetWeights(const std::string &name, size_t ind)
{
  // Return (creating on first access) the weight vector for the feature
  // identified by its name concatenated with its index, e.g. "KENLM0".
  // NOTE: map::operator[] default-constructs an empty vector for an unknown
  // key, so callers always receive a valid (possibly empty) reference.
  // The unconditional debug dump of the vector to cerr has been removed:
  // it printed on every lookup and polluted stderr in normal runs.
  return m_weights[name + SPrint(ind)];
}
void Parameter::SetWeight(const std::string &name, size_t ind, float weight)
@ -1097,6 +1104,17 @@ void Parameter::OverwriteParam(const string &paramName, PARAM_VEC values)
}
VERBOSE(2, std::endl);
}
std::set<std::string> Parameter::GetWeightNames() const
{
  // Collect the name of every weight entry currently stored (the keys of
  // m_weights), so callers can cross-check them against feature functions.
  std::set<std::string> names;
  std::map<std::string, std::vector<float> >::const_iterator it;
  for (it = m_weights.begin(); it != m_weights.end(); ++it) {
    names.insert(it->first);
  }
  return names;
}
}

View File

@ -23,6 +23,7 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
#define moses_Parameter_h
#include <string>
#include <set>
#include <map>
#include <vector>
#include "TypeDef.h"
@ -113,6 +114,7 @@ public:
}
std::vector<float> &GetWeights(const std::string &name, size_t ind);
std::set<std::string> GetWeightNames() const;
const PARAM_MAP &GetParams() const
{ return m_setting; }

View File

@ -595,13 +595,13 @@ bool StaticData::LoadData(Parameter *parameter)
SetWeights(model, weights);
}
else if (feature == "KENLM") {
LanguageModel *model = ConstructKenLM(line);
LanguageModel *model = ConstructKenLM(feature, line);
const vector<float> &weights = m_parameter->GetWeights(feature, featureIndex);
SetWeights(model, weights);
}
else if (feature == "IRSTLM") {
LanguageModelIRST *irstlm = new LanguageModelIRST(line);
LanguageModel *model = new LMRefCount(irstlm, line);
LanguageModel *model = new LMRefCount(irstlm, feature, line);
const vector<float> &weights = m_parameter->GetWeights(feature, featureIndex);
SetWeights(model, weights);
}
@ -661,6 +661,10 @@ bool StaticData::LoadData(Parameter *parameter)
if (!LoadPhraseTables()) return false;
if (!LoadDecodeGraphs()) return false;
if (!CheckWeights()) {
return false;
}
// report individual sparse features in n-best list
if (m_parameter->GetParam("report-sparse-features").size() > 0) {
for(size_t i=0; i<m_parameter->GetParam("report-sparse-features").size(); i++) {
@ -702,6 +706,8 @@ bool StaticData::LoadData(Parameter *parameter)
m_allWeights.PlusEquals(extraWeights);
}
cerr << endl << "m_allWeights=" << m_allWeights << endl;
return true;
}
@ -1267,6 +1273,36 @@ void StaticData::CollectFeatureFunctions()
}
bool StaticData::CheckWeights() const
{
set<string> weightNames = m_parameter->GetWeightNames();
const std::vector<FeatureFunction*> &ffs = FeatureFunction::GetFeatureFunctions();
for (size_t i = 0; i < ffs.size(); ++i) {
const FeatureFunction &ff = *ffs[i];
const string &descr = ff.GetScoreProducerDescription();
set<string>::iterator iter = weightNames.find(descr);
if (iter == weightNames.end()) {
cerr << "Can't find weights for feature function " << descr << endl;
}
else {
weightNames.erase(iter);
}
}
if (!weightNames.empty()) {
cerr << "The following weights have no feature function. Maybe incorrectly spelt weights: ";
set<string>::iterator iter;
for (iter = weightNames.begin(); iter != weightNames.end(); ++iter) {
cerr << *iter << ",";
}
return false;
}
return true;
}
} // namespace

View File

@ -711,6 +711,8 @@ public:
void CleanUpAfterSentenceProcessing(const InputType& source) const;
void CollectFeatureFunctions();
bool CheckWeights() const;
};
}