s2t decoder: fix scoring

Restore scoring function that was commented out in commit 465b4756...
This commit is contained in:
Phil Williams 2015-01-12 08:47:33 +00:00
parent 832b725c59
commit 866f746f98
2 changed files with 45 additions and 45 deletions

View File

@ -360,50 +360,50 @@ template <class Model> FFState *LanguageModelKen<Model>::EvaluateWhenApplied(con
return newState;
}
//template <class Model> FFState *LanguageModelKen<Model>::EvaluateWhenApplied(const Syntax::SHyperedge& hyperedge, int featureID, ScoreComponentCollection *accumulator) const
//{
// LanguageModelChartStateKenLM *newState = new LanguageModelChartStateKenLM();
// lm::ngram::RuleScore<Model> ruleScore(*m_ngram, newState->GetChartState());
// const TargetPhrase &target = *hyperedge.translation;
// const AlignmentInfo::NonTermIndexMap &nonTermIndexMap =
// target.GetAlignNonTerm().GetNonTermIndexMap2();
//
// const size_t size = target.GetSize();
// size_t phrasePos = 0;
// // Special cases for first word.
// if (size) {
// const Word &word = target.GetWord(0);
// if (word.GetFactor(m_factorType) == m_beginSentenceFactor) {
// // Begin of sentence
// ruleScore.BeginSentence();
// phrasePos++;
// } else if (word.IsNonTerminal()) {
// // Non-terminal is first so we can copy instead of rescoring.
// const Syntax::SVertex *pred = hyperedge.tail[nonTermIndexMap[phrasePos]];
// const lm::ngram::ChartState &prevState = static_cast<const LanguageModelChartStateKenLM*>(pred->state[featureID])->GetChartState();
// float prob = UntransformLMScore(pred->best->scoreBreakdown.GetScoresForProducer(this)[0]);
// ruleScore.BeginNonTerminal(prevState, prob);
// phrasePos++;
// }
// }
//
// for (; phrasePos < size; phrasePos++) {
// const Word &word = target.GetWord(phrasePos);
// if (word.IsNonTerminal()) {
// const Syntax::SVertex *pred = hyperedge.tail[nonTermIndexMap[phrasePos]];
// const lm::ngram::ChartState &prevState = static_cast<const LanguageModelChartStateKenLM*>(pred->state[featureID])->GetChartState();
// float prob = UntransformLMScore(pred->best->scoreBreakdown.GetScoresForProducer(this)[0]);
// ruleScore.NonTerminal(prevState, prob);
// } else {
// ruleScore.Terminal(TranslateID(word));
// }
// }
//
// float score = ruleScore.Finish();
// score = TransformLMScore(score);
// accumulator->Assign(this, score);
// return newState;
//}
template <class Model> FFState *LanguageModelKen<Model>::EvaluateWhenApplied(const Syntax::SHyperedge& hyperedge, int featureID, ScoreComponentCollection *accumulator) const
{
LanguageModelChartStateKenLM *newState = new LanguageModelChartStateKenLM();
lm::ngram::RuleScore<Model> ruleScore(*m_ngram, newState->GetChartState());
const TargetPhrase &target = *hyperedge.translation;
const AlignmentInfo::NonTermIndexMap &nonTermIndexMap =
target.GetAlignNonTerm().GetNonTermIndexMap2();
const size_t size = target.GetSize();
size_t phrasePos = 0;
// Special cases for first word.
if (size) {
const Word &word = target.GetWord(0);
if (word.GetFactor(m_factorType) == m_beginSentenceFactor) {
// Begin of sentence
ruleScore.BeginSentence();
phrasePos++;
} else if (word.IsNonTerminal()) {
// Non-terminal is first so we can copy instead of rescoring.
const Syntax::SVertex *pred = hyperedge.tail[nonTermIndexMap[phrasePos]];
const lm::ngram::ChartState &prevState = static_cast<const LanguageModelChartStateKenLM*>(pred->state[featureID])->GetChartState();
float prob = UntransformLMScore(pred->best->scoreBreakdown.GetScoresForProducer(this)[0]);
ruleScore.BeginNonTerminal(prevState, prob);
phrasePos++;
}
}
for (; phrasePos < size; phrasePos++) {
const Word &word = target.GetWord(phrasePos);
if (word.IsNonTerminal()) {
const Syntax::SVertex *pred = hyperedge.tail[nonTermIndexMap[phrasePos]];
const lm::ngram::ChartState &prevState = static_cast<const LanguageModelChartStateKenLM*>(pred->state[featureID])->GetChartState();
float prob = UntransformLMScore(pred->best->scoreBreakdown.GetScoresForProducer(this)[0]);
ruleScore.NonTerminal(prevState, prob);
} else {
ruleScore.Terminal(TranslateID(word));
}
}
float score = ruleScore.Finish();
score = TransformLMScore(score);
accumulator->Assign(this, score);
return newState;
}
template <class Model> void LanguageModelKen<Model>::IncrementalCallback(Incremental::Manager &manager) const
{

View File

@ -59,7 +59,7 @@ public:
virtual FFState *EvaluateWhenApplied(const ChartHypothesis& cur_hypo, int featureID, ScoreComponentCollection *accumulator) const;
// virtual FFState *EvaluateWhenApplied(const Syntax::SHyperedge& hyperedge, int featureID, ScoreComponentCollection *accumulator) const;
virtual FFState *EvaluateWhenApplied(const Syntax::SHyperedge& hyperedge, int featureID, ScoreComponentCollection *accumulator) const;
virtual void IncrementalCallback(Incremental::Manager &manager) const;
virtual void ReportHistoryOrder(std::ostream &out,const Phrase &phrase) const;