Fixed the multiple models tests.

Multiple models won't work properly though, since the default model
is now hard-coded in various spots.
This commit is contained in:
Barry Haddow 2012-07-17 22:36:10 +01:00
parent 2b4e61d826
commit 080fa0fa4e
8 changed files with 14 additions and 17 deletions

View File

@ -236,7 +236,7 @@ void ChartRuleLookupManagerOnDisk::GetChartRuleCollection(
const OnDiskPt::TargetPhraseCollection *tpcollBerkeleyDb = node->GetTargetPhraseCollection(m_dictionary.GetTableLimit(), m_dbWrapper);
std::vector<float> weightT = staticData.GetTranslationSystem(TranslationSystem::DEFAULT).GetTranslationWeights(m_dictionary.GetDictIndex());
std::vector<float> weightT = staticData.GetWeights(m_dictionary.GetFeature());
targetPhraseCollection
= tpcollBerkeleyDb->ConvertToMoses(m_inputFactorsVec
,m_outputFactorsVec

View File

@ -139,8 +139,9 @@ protected:
return 0;
}
//TODO: Multiple models broken here
const TranslationSystem& system = StaticData::Instance().GetTranslationSystem(TranslationSystem::DEFAULT);
std::vector<float> weights = system.GetTranslationWeights(m_obj->GetDictIndex());
std::vector<float> weights = StaticData::Instance().GetWeights(m_obj->GetFeature());
float weightWP = system.GetWeightWordPenalty();
std::vector<TargetPhrase> tCands;
@ -374,7 +375,7 @@ protected:
stack.push_back(State(i, i, m_dict->GetRoot(), std::vector<float>(m_numInputScores,0.0)));
const TranslationSystem& system = StaticData::Instance().GetTranslationSystem(TranslationSystem::DEFAULT);
std::vector<float> weightT = system.GetTranslationWeights(m_obj->GetDictIndex());
std::vector<float> weightT = StaticData::Instance().GetWeights(m_obj->GetFeature());
float weightWP = system.GetWeightWordPenalty();
while(!stack.empty()) {

View File

@ -87,7 +87,7 @@ PhraseDictionaryFeature::PhraseDictionaryFeature
PhraseDictionary* PhraseDictionaryFeature::LoadPhraseTable(const TranslationSystem* system)
{
const StaticData& staticData = StaticData::Instance();
std::vector<float> weightT = system->GetTranslationWeights(m_dictIndex);
std::vector<float> weightT = staticData.GetWeights(this);
if (m_implementation == Memory) {
// memory phrase table

View File

@ -64,7 +64,7 @@ void PhraseDictionaryALSuffixArray::InitializeForInput(InputType const& source)
std::auto_ptr<RuleTableLoader> loader =
RuleTableLoaderFactory::Create(grammarFile);
std::vector<float> weightT = StaticData::Instance().GetTranslationSystem(TranslationSystem::DEFAULT).GetTranslationWeights(GetDictIndex());
std::vector<float> weightT = StaticData::Instance().GetWeights(GetFeature());
bool ret = loader->Load(*m_input, *m_output, inFile, weightT, m_tableLimit,
*m_languageModels, m_wpProducer, *this);

View File

@ -91,7 +91,6 @@ ChartRuleLookupManager *PhraseDictionaryOnDisk::CreateRuleLookupManager(
const InputType &sentence,
const ChartCellCollection &cellCollection)
{
std::vector<float> weightT = StaticData::Instance().GetTranslationSystem(TranslationSystem::DEFAULT).GetTranslationWeights(GetDictIndex());
return new ChartRuleLookupManagerOnDisk(sentence, cellCollection, *this,
m_dbWrapper, m_languageModels,
m_wpProducer, m_inputFactorsVec,

View File

@ -73,7 +73,13 @@ private:
static IndexPair GetIndexes(const ScoreProducer* sp)
{
ScoreIndexMap::const_iterator indexIter = s_scoreIndexes.find(sp);
CHECK(indexIter != s_scoreIndexes.end());
if (indexIter == s_scoreIndexes.end()) {
std::cerr << "ERROR: ScoreProducer: " << sp->GetScoreProducerDescription() <<
" not registered with ScoreIndexMap" << std::endl;
std::cerr << "You must call ScoreComponentCollection.RegisterScoreProducer() " <<
" for every ScoreProducer" << std::endl;
abort();
}
return indexIter->second;
}

View File

@ -175,12 +175,4 @@ namespace Moses {
return StaticData::Instance().GetWeight(m_distortionScoreProducer);
}
std::vector<float> TranslationSystem::GetTranslationWeights(size_t index) const {
std::vector<float> weights = StaticData::Instance().GetWeights(GetTranslationScoreProducer(index));
//VERBOSE(1, "Read weightT from translation sytem.. ");
for (size_t i = 0; i < weights.size(); ++i)
//VERBOSE(1, weights[i] << " ");
//VERBOSE(1, std::endl);
return weights;
}
};

View File

@ -83,12 +83,11 @@ class TranslationSystem {
const UnknownWordPenaltyProducer *GetUnknownWordPenaltyProducer() const { return m_unknownWpProducer; }
const DistortionScoreProducer* GetDistortionProducer() const {return m_distortionScoreProducer;}
const PhraseDictionaryFeature *GetTranslationScoreProducer(size_t index) const { return GetPhraseDictionaries()[index]; }
const PhraseDictionaryFeature *GetTranslationScoreProducer(size_t index) const { return GetPhraseDictionaries().at(index); }
float GetWeightWordPenalty() const;
float GetWeightUnknownWordPenalty() const;
float GetWeightDistortion() const;
std::vector<float> GetTranslationWeights(size_t index) const;
//sentence (and thread) specific initialisationn and cleanup
void InitializeBeforeSentenceProcessing(const InputType& source) const;