move decodeGraph to staticData

This commit is contained in:
Hieu Hoang 2012-12-21 15:28:34 +00:00
parent 8799e721ac
commit ea647fc658
6 changed files with 31 additions and 48 deletions

View File

@ -123,7 +123,7 @@ void ChartParserUnknown::Process(const Word &sourceWord, const WordsRange &range
ChartParser::ChartParser(InputType const &source, const TranslationSystem &system, ChartCellCollectionBase &cells) :
m_unknown(system),
m_decodeGraphList(system.GetDecodeGraphs()),
m_decodeGraphList(StaticData::Instance().GetDecodeGraphs()),
m_source(source) {
const StaticData &staticData = StaticData::Instance();

View File

@ -608,13 +608,6 @@ bool StaticData::LoadData(Parameter *parameter)
}
}
for (size_t k = 0; k < m_decodeGraphs.size(); ++k) {
if (!tableIds.size() || tableIds.find(k) != tableIds.end()) {
VERBOSE(2,"Adding decoder graph " << k << " to translation system " << config[0] << endl);
m_translationSystems.find(config[0])->second.AddDecodeGraph(m_decodeGraphs[k],m_decodeGraphBackoff[k]);
}
}
size_t lmid = 0;
for (LMList::const_iterator k = m_languageModel.begin(); k != m_languageModel.end(); ++k, ++lmid) {
if (!tableIds.size() || tableIds.find(lmid) != tableIds.end()) {
@ -630,7 +623,7 @@ bool StaticData::LoadData(Parameter *parameter)
}
//Instigate dictionary loading
m_translationSystems.find(config[0])->second.ConfigDictionaries();
ConfigDictionaries();
for (int i = 0; i < m_phraseDictionary.size(); i++)
cerr << m_phraseDictionary[i] << " ";
@ -2016,6 +2009,25 @@ float StaticData::GetWeightDistortion() const {
return StaticData::Instance().GetWeight(m_distortionScoreProducer);
}
void StaticData::ConfigDictionaries() {
for (vector<DecodeGraph*>::const_iterator i = m_decodeGraphs.begin();
i != m_decodeGraphs.end(); ++i) {
for (DecodeGraph::const_iterator j = (*i)->begin(); j != (*i)->end(); ++j) {
const DecodeStep* step = *j;
PhraseDictionaryFeature* pdict = const_cast<PhraseDictionaryFeature*>(step->GetPhraseDictionaryFeature());
if (pdict) {
StaticData::InstanceNonConst().AddFeatureFunction(pdict);
pdict->InitDictionary(NULL);
}
GenerationDictionary* gdict = const_cast<GenerationDictionary*>(step->GetGenerationDictionaryFeature());
if (gdict) {
StaticData::InstanceNonConst().AddFeatureFunction(gdict);
}
}
}
}
} // namespace

View File

@ -783,6 +783,12 @@ public:
return weights;
}
const std::vector<DecodeGraph*>& GetDecodeGraphs() const {return m_decodeGraphs;}
const std::vector<size_t>& GetDecodeGraphBackoff() const {return m_decodeGraphBackoff;}
//Called after adding the tables in order to set up the dictionaries
void ConfigDictionaries();
};
}

View File

@ -162,7 +162,7 @@ void TranslationOptionCollection::Prune()
void TranslationOptionCollection::ProcessUnknownWord()
{
const vector<DecodeGraph*>& decodeGraphList = m_system->GetDecodeGraphs();
const vector<DecodeGraph*>& decodeGraphList = StaticData::Instance().GetDecodeGraphs();
size_t size = m_source.GetSize();
// try to translation for coverage with no trans by expanding table limit
for (size_t graphInd = 0 ; graphInd < decodeGraphList.size() ; graphInd++) {
@ -367,8 +367,8 @@ void TranslationOptionCollection::CreateTranslationOptions()
// for all phrases
// there may be multiple decoding graphs (factorizations of decoding)
const vector <DecodeGraph*> &decodeGraphList = m_system->GetDecodeGraphs();
const vector <size_t> &decodeGraphBackoff = m_system->GetDecodeGraphBackoff();
const vector <DecodeGraph*> &decodeGraphList = StaticData::Instance().GetDecodeGraphs();
const vector <size_t> &decodeGraphBackoff = StaticData::Instance().GetDecodeGraphBackoff();
// length of the sentence
size_t size = m_source.GetSize();

View File

@ -53,31 +53,6 @@ namespace Moses {
}
}
//Insert core 'big' features
void TranslationSystem::AddDecodeGraph(DecodeGraph* decodeGraph, size_t backoff) {
m_decodeGraphs.push_back(decodeGraph);
m_decodeGraphBackoff.push_back(backoff);
}
void TranslationSystem::ConfigDictionaries() {
for (vector<DecodeGraph*>::const_iterator i = m_decodeGraphs.begin();
i != m_decodeGraphs.end(); ++i) {
for (DecodeGraph::const_iterator j = (*i)->begin(); j != (*i)->end(); ++j) {
const DecodeStep* step = *j;
PhraseDictionaryFeature* pdict = const_cast<PhraseDictionaryFeature*>(step->GetPhraseDictionaryFeature());
if (pdict) {
StaticData::InstanceNonConst().AddFeatureFunction(pdict);
const_cast<PhraseDictionaryFeature*>(pdict)->InitDictionary(this);
}
GenerationDictionary* gdict = const_cast<GenerationDictionary*>(step->GetGenerationDictionaryFeature());
if (gdict) {
StaticData::InstanceNonConst().AddFeatureFunction(gdict);
}
}
}
}
void TranslationSystem::InitializeBeforeSentenceProcessing(const InputType& source) const {
const StaticData &staticData = StaticData::Instance();
const std::vector<PhraseDictionaryFeature*> &phraseDictionaries = staticData.GetPhraseDictionaries();

View File

@ -54,18 +54,11 @@ class TranslationSystem {
const UnknownWordPenaltyProducer* uwpProducer,
const DistortionScoreProducer* distortionProducer);
//Insert core 'big' features
void AddDecodeGraph(DecodeGraph* decodeGraph, size_t backoff);
//Called after adding the tables in order to set up the dictionaries
void ConfigDictionaries();
const std::string& GetId() const {return m_id;}
//Lists of tables relevant to this system.
const std::vector<DecodeGraph*>& GetDecodeGraphs() const {return m_decodeGraphs;}
const std::vector<size_t>& GetDecodeGraphBackoff() const {return m_decodeGraphBackoff;}
//sentence (and thread) specific initialisationn and cleanup
void InitializeBeforeSentenceProcessing(const InputType& source) const;
@ -76,10 +69,7 @@ class TranslationSystem {
private:
std::string m_id;
std::vector<DecodeGraph*> m_decodeGraphs;
std::vector<size_t> m_decodeGraphBackoff;
};