score deltas in chart decoding

This commit is contained in:
Matthias Huck 2015-01-07 14:25:43 +00:00
parent 0441fd6ab9
commit 465b475664
13 changed files with 291 additions and 166 deletions

View File

@ -229,7 +229,7 @@ vector< vector<const Word*> > MosesDecoder::runChartDecoder(const std::string& s
for (ChartKBestExtractor::KBestVec::const_iterator p = nBestList.begin();
p != nBestList.end(); ++p) {
const ChartKBestExtractor::Derivation &derivation = **p;
featureValues.push_back(derivation.scoreBreakdown);
featureValues.push_back(*ChartKBestExtractor::GetOutputScoreBreakdown(derivation));
float bleuScore, dynBleuScore, realBleuScore;
dynBleuScore = getBleuScore(featureValues.back());
Phrase outputPhrase = ChartKBestExtractor::GetOutputPhrase(derivation);

View File

@ -61,7 +61,8 @@ ChartHypothesis::ChartHypothesis(const ChartTranslationOptions &transOpt,
const std::vector<HypothesisDimension> &childEntries = item.GetHypothesisDimensions();
m_prevHypos.reserve(childEntries.size());
std::vector<HypothesisDimension>::const_iterator iter;
for (iter = childEntries.begin(); iter != childEntries.end(); ++iter) {
for (iter = childEntries.begin(); iter != childEntries.end(); ++iter)
{
m_prevHypos.push_back(iter->GetHypothesis());
}
}
@ -71,7 +72,6 @@ ChartHypothesis::ChartHypothesis(const ChartTranslationOptions &transOpt,
ChartHypothesis::ChartHypothesis(const ChartHypothesis &pred,
const ChartKBestExtractor & /*unused*/)
:m_currSourceWordsRange(pred.m_currSourceWordsRange)
,m_scoreBreakdown(pred.m_scoreBreakdown)
,m_totalScore(pred.m_totalScore)
,m_arcList(NULL)
,m_winningHypo(NULL)
@ -85,14 +85,17 @@ ChartHypothesis::ChartHypothesis(const ChartHypothesis &pred,
ChartHypothesis::~ChartHypothesis()
{
// delete feature function states
for (unsigned i = 0; i < m_ffStates.size(); ++i) {
for (unsigned i = 0; i < m_ffStates.size(); ++i)
{
delete m_ffStates[i];
}
// delete hypotheses that are not in the chart (recombined away)
if (m_arcList) {
if (m_arcList)
{
ChartArcList::iterator iter;
for (iter = m_arcList->begin() ; iter != m_arcList->end() ; ++iter) {
for (iter = m_arcList->begin() ; iter != m_arcList->end() ; ++iter)
{
ChartHypothesis *hypo = *iter;
Delete(hypo);
}
@ -109,19 +112,25 @@ void ChartHypothesis::GetOutputPhrase(Phrase &outPhrase) const
{
FactorType placeholderFactor = StaticData::Instance().GetPlaceholderFactor();
for (size_t pos = 0; pos < GetCurrTargetPhrase().GetSize(); ++pos) {
for (size_t pos = 0; pos < GetCurrTargetPhrase().GetSize(); ++pos)
{
const Word &word = GetCurrTargetPhrase().GetWord(pos);
if (word.IsNonTerminal()) {
if (word.IsNonTerminal())
{
// non-term. fill out with prev hypo
size_t nonTermInd = GetCurrTargetPhrase().GetAlignNonTerm().GetNonTermIndexMap()[pos];
const ChartHypothesis *prevHypo = m_prevHypos[nonTermInd];
prevHypo->GetOutputPhrase(outPhrase);
} else {
}
else
{
outPhrase.AddWord(word);
if (placeholderFactor != NOT_FOUND) {
if (placeholderFactor != NOT_FOUND)
{
std::set<size_t> sourcePosSet = GetCurrTargetPhrase().GetAlignTerm().GetAlignmentsForTarget(pos);
if (sourcePosSet.size() == 1) {
if (sourcePosSet.size() == 1)
{
const std::vector<const Word*> *ruleSourceFromInputPath = GetTranslationOption().GetSourceRuleFromInputPath();
UTIL_THROW_IF2(ruleSourceFromInputPath == NULL,
"No source rule");
@ -131,7 +140,8 @@ void ChartHypothesis::GetOutputPhrase(Phrase &outPhrase) const
UTIL_THROW_IF2(sourceWord == NULL,
"No source word");
const Factor *factor = sourceWord->GetFactor(placeholderFactor);
if (factor) {
if (factor)
{
outPhrase.Back()[0] = factor;
}
}
@ -149,37 +159,45 @@ Phrase ChartHypothesis::GetOutputPhrase() const
return outPhrase;
}
void ChartHypothesis::GetOutputPhrase(int leftRightMost, int numWords, Phrase &outPhrase) const
/** TODO: this method isn't used anywhere. Remove? */
void ChartHypothesis::GetOutputPhrase(size_t leftRightMost, size_t numWords, Phrase &outPhrase) const
{
const TargetPhrase &tp = GetCurrTargetPhrase();
int targetSize = tp.GetSize();
for (int i = 0; i < targetSize; ++i) {
int pos;
if (leftRightMost == 1) {
pos = i;
}
else if (leftRightMost == 2) {
pos = targetSize - i - 1;
}
else {
abort();
}
size_t targetSize = tp.GetSize();
for (size_t i = 0; i < targetSize; ++i)
{
size_t pos;
if (leftRightMost == 1)
{
pos = i;
}
else if (leftRightMost == 2)
{
pos = targetSize - i - 1;
}
else
{
abort();
}
const Word &word = tp.GetWord(pos);
const Word &word = tp.GetWord(pos);
if (word.IsNonTerminal()) {
// non-term. fill out with prev hypo
size_t nonTermInd = tp.GetAlignNonTerm().GetNonTermIndexMap()[pos];
const ChartHypothesis *prevHypo = m_prevHypos[nonTermInd];
prevHypo->GetOutputPhrase(outPhrase);
} else {
outPhrase.AddWord(word);
}
if (word.IsNonTerminal())
{
// non-term. fill out with prev hypo
size_t nonTermInd = tp.GetAlignNonTerm().GetNonTermIndexMap()[pos];
const ChartHypothesis *prevHypo = m_prevHypos[nonTermInd];
prevHypo->GetOutputPhrase(outPhrase);
}
else
{
outPhrase.AddWord(word);
}
if (outPhrase.GetSize() >= numWords) {
return;
}
if (outPhrase.GetSize() >= numWords) {
return;
}
}
}
@ -209,64 +227,71 @@ int ChartHypothesis::RecombineCompare(const ChartHypothesis &compare) const
return 0;
}
/** calculate total score
* @todo this should be in ScoreBreakdown
*/
/** calculate total score */
void ChartHypothesis::EvaluateWhenApplied()
{
const StaticData &staticData = StaticData::Instance();
// total scores from prev hypos
std::vector<const ChartHypothesis*>::iterator iter;
for (iter = m_prevHypos.begin(); iter != m_prevHypos.end(); ++iter) {
const ChartHypothesis &prevHypo = **iter;
const ScoreComponentCollection &scoreBreakdown = prevHypo.GetScoreBreakdown();
m_scoreBreakdown.PlusEquals(scoreBreakdown);
}
// scores from current translation rule. eg. translation models & word penalty
const ScoreComponentCollection &scoreBreakdown = GetTranslationOption().GetScores();
m_scoreBreakdown.PlusEquals(scoreBreakdown);
// compute values of stateless feature functions that were not
// cached in the translation option-- there is no principled distinction
const std::vector<const StatelessFeatureFunction*>& sfs =
StatelessFeatureFunction::GetStatelessFeatureFunctions();
for (unsigned i = 0; i < sfs.size(); ++i) {
if (! staticData.IsFeatureFunctionIgnored( *sfs[i] )) {
sfs[i]->EvaluateWhenApplied(*this,&m_scoreBreakdown);
for (unsigned i = 0; i < sfs.size(); ++i)
{
if (! staticData.IsFeatureFunctionIgnored( *sfs[i] ))
{
sfs[i]->EvaluateWhenApplied(*this,&m_currScoreBreakdown);
}
}
const std::vector<const StatefulFeatureFunction*>& ffs =
StatefulFeatureFunction::GetStatefulFeatureFunctions();
for (unsigned i = 0; i < ffs.size(); ++i) {
if (! staticData.IsFeatureFunctionIgnored( *ffs[i] )) {
m_ffStates[i] = ffs[i]->EvaluateWhenApplied(*this,i,&m_scoreBreakdown);
for (unsigned i = 0; i < ffs.size(); ++i)
{
if (! staticData.IsFeatureFunctionIgnored( *ffs[i] ))
{
m_ffStates[i] = ffs[i]->EvaluateWhenApplied(*this,i,&m_currScoreBreakdown);
}
}
m_totalScore = m_scoreBreakdown.GetWeightedScore();
// total score from current translation rule
m_totalScore = GetTranslationOption().GetScores().GetWeightedScore();
m_totalScore += m_currScoreBreakdown.GetWeightedScore();
// total scores from prev hypos
for (std::vector<const ChartHypothesis*>::const_iterator iter = m_prevHypos.begin(); iter != m_prevHypos.end(); ++iter) {
const ChartHypothesis &prevHypo = **iter;
m_totalScore += prevHypo.GetTotalScore();
}
}
void ChartHypothesis::AddArc(ChartHypothesis *loserHypo)
{
if (!m_arcList) {
if (loserHypo->m_arcList) { // we don't have an arcList, but loser does
if (!m_arcList)
{
if (loserHypo->m_arcList)
{ // we don't have an arcList, but loser does
this->m_arcList = loserHypo->m_arcList; // take ownership, we'll delete
loserHypo->m_arcList = 0; // prevent a double deletion
} else {
}
else
{
this->m_arcList = new ChartArcList();
}
} else {
if (loserHypo->m_arcList) { // both have an arc list: merge. delete loser
}
else
{
if (loserHypo->m_arcList)
{ // both have an arc list: merge. delete loser
size_t my_size = m_arcList->size();
size_t add_size = loserHypo->m_arcList->size();
this->m_arcList->resize(my_size + add_size, 0);
std::memcpy(&(*m_arcList)[0] + my_size, &(*loserHypo->m_arcList)[0], add_size * sizeof(ChartHypothesis *));
delete loserHypo->m_arcList;
loserHypo->m_arcList = 0;
} else { // loserHypo doesn't have any arcs
}
else
{ // loserHypo doesn't have any arcs
// DO NOTHING
}
}
@ -274,8 +299,10 @@ void ChartHypothesis::AddArc(ChartHypothesis *loserHypo)
}
// sorting helper
struct CompareChartChartHypothesisTotalScore {
bool operator()(const ChartHypothesis* hypo1, const ChartHypothesis* hypo2) const {
struct CompareChartHypothesisTotalScore
{
bool operator()(const ChartHypothesis* hypo1, const ChartHypothesis* hypo2) const
{
return hypo1->GetTotalScore() > hypo2->GetTotalScore();
}
};
@ -295,16 +322,18 @@ void ChartHypothesis::CleanupArcList()
size_t nBestSize = staticData.GetNBestSize();
bool distinctNBest = staticData.GetDistinctNBest() || staticData.UseMBR() || staticData.GetOutputSearchGraph() || staticData.GetOutputSearchGraphHypergraph();
if (!distinctNBest && m_arcList->size() > nBestSize) {
if (!distinctNBest && m_arcList->size() > nBestSize)
{
// prune arc list only if there too many arcs
NTH_ELEMENT4(m_arcList->begin()
, m_arcList->begin() + nBestSize - 1
, m_arcList->end()
, CompareChartChartHypothesisTotalScore());
, CompareChartHypothesisTotalScore());
// delete bad ones
ChartArcList::iterator iter;
for (iter = m_arcList->begin() + nBestSize ; iter != m_arcList->end() ; ++iter) {
for (iter = m_arcList->begin() + nBestSize ; iter != m_arcList->end() ; ++iter)
{
ChartHypothesis *arc = *iter;
ChartHypothesis::Delete(arc);
}
@ -314,7 +343,8 @@ void ChartHypothesis::CleanupArcList()
// set all arc's main hypo variable to this hypo
ChartArcList::iterator iter = m_arcList->begin();
for (; iter != m_arcList->end() ; ++iter) {
for (; iter != m_arcList->end() ; ++iter)
{
ChartHypothesis *arc = *iter;
arc->SetWinningHypo(this);
}
@ -337,11 +367,13 @@ std::ostream& operator<<(std::ostream& out, const ChartHypothesis& hypo)
// recombination
if (hypo.GetWinningHypothesis() != NULL &&
hypo.GetWinningHypothesis() != &hypo) {
hypo.GetWinningHypothesis() != &hypo)
{
out << "->" << hypo.GetWinningHypothesis()->GetId();
}
if (StaticData::Instance().GetIncludeLHSInSearchGraph()) {
if (StaticData::Instance().GetIncludeLHSInSearchGraph())
{
out << " " << hypo.GetTargetLHS() << "=>";
}
out << " " << hypo.GetCurrTargetPhrase()
@ -349,7 +381,8 @@ std::ostream& operator<<(std::ostream& out, const ChartHypothesis& hypo)
<< " " << hypo.GetCurrSourceRange();
HypoList::const_iterator iter;
for (iter = hypo.GetPrevHypos().begin(); iter != hypo.GetPrevHypos().end(); ++iter) {
for (iter = hypo.GetPrevHypos().begin(); iter != hypo.GetPrevHypos().end(); ++iter)
{
const ChartHypothesis &prevHypo = **iter;
out << " " << prevHypo.GetId();
}

View File

@ -21,6 +21,7 @@
#pragma once
#include <vector>
#include <boost/scoped_ptr.hpp>
#include "Util.h"
#include "WordsRange.h"
#include "ScoreComponentCollection.h"
@ -45,7 +46,7 @@ typedef std::vector<ChartHypothesis*> ChartArcList;
class ChartHypothesis
{
friend std::ostream& operator<<(std::ostream&, const ChartHypothesis&);
friend class ChartKBestExtractor;
// friend class ChartKBestExtractor;
protected:
#ifdef USE_HYPO_POOL
@ -56,7 +57,10 @@ protected:
WordsRange m_currSourceWordsRange;
std::vector<const FFState*> m_ffStates; /*! stateful feature function states */
ScoreComponentCollection m_scoreBreakdown /*! detailed score break-down by components (for instance language model, word penalty, etc) */
/*! sum of scores of this hypothesis, and previous hypotheses. Lazily initialised. */
mutable boost::scoped_ptr<ScoreComponentCollection> m_scoreBreakdown;
mutable boost::scoped_ptr<ScoreComponentCollection> m_deltaScoreBreakdown;
ScoreComponentCollection m_currScoreBreakdown /*! scores for this hypothesis only */
,m_lmNGram
,m_lmPrefix;
float m_totalScore;
@ -76,23 +80,23 @@ protected:
//! not implemented
ChartHypothesis(const ChartHypothesis &copy);
//! only used by ChartKBestExtractor
ChartHypothesis(const ChartHypothesis &, const ChartKBestExtractor &);
public:
#ifdef USE_HYPO_POOL
void *operator new(size_t /* num_bytes */) {
void *operator new(size_t /* num_bytes */)
{
void *ptr = s_objectPool.getPtr();
return ptr;
}
//! delete \param hypo. Works with object pool too
static void Delete(ChartHypothesis *hypo) {
static void Delete(ChartHypothesis *hypo)
{
s_objectPool.freeObject(hypo);
}
#else
//! delete \param hypo. Works with object pool too
static void Delete(ChartHypothesis *hypo) {
static void Delete(ChartHypothesis *hypo)
{
delete hypo;
}
#endif
@ -100,38 +104,48 @@ public:
ChartHypothesis(const ChartTranslationOptions &, const RuleCubeItem &item,
ChartManager &manager);
//! only used by ChartKBestExtractor
ChartHypothesis(const ChartHypothesis &, const ChartKBestExtractor &);
~ChartHypothesis();
unsigned GetId() const {
unsigned GetId() const
{
return m_id;
}
const ChartTranslationOption &GetTranslationOption()const {
const ChartTranslationOption &GetTranslationOption() const
{
return *m_transOpt;
}
//! Get the rule that created this hypothesis
const TargetPhrase &GetCurrTargetPhrase()const {
const TargetPhrase &GetCurrTargetPhrase() const
{
return m_transOpt->GetPhrase();
}
//! the source range that this hypothesis spans
const WordsRange &GetCurrSourceRange()const {
const WordsRange &GetCurrSourceRange() const
{
return m_currSourceWordsRange;
}
//! the arc list when creating n-best lists
inline const ChartArcList* GetArcList() const {
inline const ChartArcList* GetArcList() const
{
return m_arcList;
}
//! the feature function states for a particular feature \param featureID
inline const FFState* GetFFState( size_t featureID ) const {
inline const FFState* GetFFState( size_t featureID ) const
{
return m_ffStates[ featureID ];
}
//! reference back to the manager
inline const ChartManager& GetManager() const {
inline const ChartManager& GetManager() const
{
return m_manager;
}
@ -140,7 +154,7 @@ public:
// get leftmost/rightmost words only
// leftRightMost: 1=left, 2=right
void GetOutputPhrase(int leftRightMost, int numWords, Phrase &outPhrase) const;
void GetOutputPhrase(size_t leftRightMost, size_t numWords, Phrase &outPhrase) const;
int RecombineCompare(const ChartHypothesis &compare) const;
@ -151,32 +165,74 @@ public:
void SetWinningHypo(const ChartHypothesis *hypo);
//! get the unweighted score for each feature function
const ScoreComponentCollection &GetScoreBreakdown() const {
return m_scoreBreakdown;
//! get the unweighted score for each feature function
//! (lazily computed: rule scores + this hypothesis' own scores
//!  + the full breakdowns of all predecessor hypotheses)
const ScoreComponentCollection &GetScoreBreakdown() const
{
  // Note: never call this method before m_currScoreBreakdown is fully computed
  if (m_scoreBreakdown.get() == NULL) {
    ScoreComponentCollection *total = new ScoreComponentCollection();
    // scores from the translation rule that created this hypothesis
    if (m_transOpt) {
      total->PlusEquals(GetTranslationOption().GetScores());
    }
    // scores computed by the feature functions for this hypothesis only
    total->PlusEquals(m_currScoreBreakdown);
    // recursively accumulate the breakdowns of all predecessor hypotheses
    std::vector<const ChartHypothesis*>::const_iterator it;
    for (it = m_prevHypos.begin(); it != m_prevHypos.end(); ++it) {
      total->PlusEquals((*it)->GetScoreBreakdown());
    }
    m_scoreBreakdown.reset(total);
  }
  return *m_scoreBreakdown;
}
//! get the unweighted score delta for each feature function
//! (the scores contributed by this hypothesis alone, i.e. without the
//!  breakdowns of its predecessor hypotheses; lazily computed and cached)
const ScoreComponentCollection &GetDeltaScoreBreakdown() const
{
  // Note: never call this method before m_currScoreBreakdown is fully computed
  if (m_deltaScoreBreakdown.get() == NULL) {
    ScoreComponentCollection *delta = new ScoreComponentCollection();
    // scores from the translation rule that created this hypothesis
    if (m_transOpt) {
      delta->PlusEquals(GetTranslationOption().GetScores());
    }
    // scores computed by the feature functions for this hypothesis only
    delta->PlusEquals(m_currScoreBreakdown);
    // deliberately NOT adding predecessor breakdowns -- that is what
    // makes this a delta rather than a total breakdown
    m_deltaScoreBreakdown.reset(delta);
  }
  return *m_deltaScoreBreakdown;
}
//! Get the weighted total score
float GetTotalScore() const {
float GetTotalScore() const
{
// scores from current translation rule. eg. translation models & word penalty
return m_totalScore;
}
//! vector of previous hypotheses this hypo is built on
const std::vector<const ChartHypothesis*> &GetPrevHypos() const {
const std::vector<const ChartHypothesis*> &GetPrevHypos() const
{
return m_prevHypos;
}
//! get a particular previous hypos
const ChartHypothesis* GetPrevHypo(size_t pos) const {
const ChartHypothesis* GetPrevHypo(size_t pos) const
{
return m_prevHypos[pos];
}
//! get the constituency label that covers this hypo
const Word &GetTargetLHS() const {
const Word &GetTargetLHS() const
{
return GetCurrTargetPhrase().GetTargetLHS();
}
//! get the best hypo in the arc list when doing n-best list creation. It's either this hypothesis, or the best hypo is this hypo is in the arc list
const ChartHypothesis* GetWinningHypothesis() const {
const ChartHypothesis* GetWinningHypothesis() const
{
return m_winningHypo;
}

View File

@ -124,6 +124,28 @@ Phrase ChartKBestExtractor::GetOutputPhrase(const Derivation &d)
return ret;
}
// Generate the score breakdown of the derivation d.
// The result is the head hypothesis' delta breakdown plus the (recursively
// computed) breakdowns of the subderivations filling its non-terminals.
boost::shared_ptr<ScoreComponentCollection>
ChartKBestExtractor::GetOutputScoreBreakdown(const Derivation &d)
{
  const ChartHypothesis &hypo = d.edge.head->hypothesis;
  // start from the scores contributed by the head hypothesis alone
  boost::shared_ptr<ScoreComponentCollection> breakdown(new ScoreComponentCollection());
  breakdown->PlusEquals(hypo.GetDeltaScoreBreakdown());

  const TargetPhrase &target = hypo.GetCurrTargetPhrase();
  const AlignmentInfo::NonTermIndexMap &ntIndexMap =
    target.GetAlignNonTerm().GetNonTermIndexMap();
  // fold in the breakdown of each subderivation plugged into a non-terminal
  const std::size_t size = target.GetSize();
  for (std::size_t pos = 0; pos != size; ++pos) {
    if (!target.GetWord(pos).IsNonTerminal()) {
      continue;
    }
    const std::size_t ntInd = ntIndexMap[pos];
    const Derivation &sub = *d.subderivations[ntInd];
    breakdown->PlusEquals(*GetOutputScoreBreakdown(sub));
  }
  return breakdown;
}
// Generate the target tree of the derivation d.
TreePointer ChartKBestExtractor::GetOutputTree(const Derivation &d)
{
@ -286,7 +308,6 @@ ChartKBestExtractor::Derivation::Derivation(const UnweightedHyperarc &e)
boost::shared_ptr<Derivation> sub(pred.kBestList[0]);
subderivations.push_back(sub);
}
scoreBreakdown = edge.head->hypothesis.GetScoreBreakdown();
score = edge.head->hypothesis.GetTotalScore();
}
@ -298,15 +319,14 @@ ChartKBestExtractor::Derivation::Derivation(const Derivation &d, std::size_t i)
backPointers = d.backPointers;
subderivations = d.subderivations;
std::size_t j = ++backPointers[i];
scoreBreakdown = d.scoreBreakdown;
score = d.score;
// Deduct the score of the old subderivation.
scoreBreakdown.MinusEquals(subderivations[i]->scoreBreakdown);
score -= subderivations[i]->score;
// Update the subderivation pointer.
boost::shared_ptr<Derivation> newSub(edge.tail[i]->kBestList[j]);
subderivations[i] = newSub;
// Add the score of the new subderivation.
scoreBreakdown.PlusEquals(subderivations[i]->scoreBreakdown);
score = scoreBreakdown.GetWeightedScore();
score += subderivations[i]->score;
}
} // namespace Moses

View File

@ -26,6 +26,7 @@
#include <boost/unordered_set.hpp>
#include <boost/weak_ptr.hpp>
#include <boost/shared_ptr.hpp>
#include <queue>
#include <vector>
@ -56,7 +57,6 @@ public:
UnweightedHyperarc edge;
std::vector<std::size_t> backPointers;
std::vector<boost::shared_ptr<Derivation> > subderivations;
ScoreComponentCollection scoreBreakdown;
float score;
};
@ -90,6 +90,7 @@ public:
std::size_t k, KBestVec &);
static Phrase GetOutputPhrase(const Derivation &);
static boost::shared_ptr<ScoreComponentCollection> GetOutputScoreBreakdown(const Derivation &);
static TreePointer GetOutputTree(const Derivation &);
private:

View File

@ -365,7 +365,8 @@ void ChartManager::OutputNBestList(OutputCollector *collector,
out << translationId << " ||| ";
OutputSurface(out, outputPhrase, outputFactorOrder, false);
out << " ||| ";
derivation.scoreBreakdown.OutputAllFeatureScores(out);
boost::shared_ptr<ScoreComponentCollection> scoreBreakdown = ChartKBestExtractor::GetOutputScoreBreakdown(derivation);
scoreBreakdown->OutputAllFeatureScores(out);
out << " ||| " << derivation.score;
// optionally, print word alignments

View File

@ -29,6 +29,7 @@ public:
//! the input path this translation option was created from; may be NULL
//! if SetInputPath() has not been called -- TODO confirm with callers
const InputPath *GetInputPath() const {
  return m_inputPath;
}
//! record the input path this translation option was created from
//! (non-owning pointer; the caller retains ownership of the InputPath)
void SetInputPath(const InputPath *inputPath) {
  m_inputPath = inputPath;
}

View File

@ -385,6 +385,15 @@ void FVector::sparsePlusEquals(const FVector& rhs)
set(i->first, get(i->first) + i->second);
}
// add only core features
void FVector::corePlusEquals(const FVector& rhs)
{
if (rhs.m_coreFeatures.size() > m_coreFeatures.size())
resize(rhs.m_coreFeatures.size());
for (size_t i = 0; i < rhs.m_coreFeatures.size(); ++i)
m_coreFeatures[i] += rhs.m_coreFeatures[i];
}
// assign only core features
void FVector::coreAssign(const FVector& rhs)
{

View File

@ -235,6 +235,7 @@ public:
void capMin(FValue minValue);
void sparsePlusEquals(const FVector& rhs);
void corePlusEquals(const FVector& rhs);
void coreAssign(const FVector& rhs);
void incrementSparseHopeFeatures();

View File

@ -280,7 +280,7 @@ FFState* LanguageModelImplementation::EvaluateWhenApplied(const ChartHypothesis&
// get prefixScore and finalizedScore
prefixScore = prevState->GetPrefixScore();
finalizedScore = prevHypo->GetScoreBreakdown().GetScoresForProducer(this)[0] - prefixScore;
finalizedScore = -prefixScore;
// get language model state
delete lmState;
@ -308,13 +308,10 @@ FFState* LanguageModelImplementation::EvaluateWhenApplied(const ChartHypothesis&
updateChartScore( &prefixScore, &finalizedScore, GetValueGivenState(contextFactor, *lmState).score, ++wordPos );
}
finalizedScore -= prevState->GetPrefixScore();
// check if we are dealing with a large sub-phrase
if (subPhraseLength > GetNGramOrder() - 1) {
// add its finalized language model score
finalizedScore +=
prevHypo->GetScoreBreakdown().GetScoresForProducer(this)[0] // full score
- prevState->GetPrefixScore(); // - prefix score
// copy language model state
delete lmState;
lmState = NewState( prevState->GetRightContext() );
@ -337,15 +334,16 @@ FFState* LanguageModelImplementation::EvaluateWhenApplied(const ChartHypothesis&
}
}
// assign combined score to score breakdown
// add combined score to score breakdown
if (OOVFeatureEnabled()) {
vector<float> scores(2);
scores[0] = prefixScore + finalizedScore;
scores[1] = out->GetScoresForProducer(this)[1];
out->Assign(this, scores);
scores[0] = prefixScore + finalizedScore - hypo.GetTranslationOption().GetScores().GetScoresForProducer(this)[0];
// scores[1] = out->GetScoresForProducer(this)[1];
scores[1] = 0;
out->PlusEquals(this, scores);
}
else {
out->Assign(this, prefixScore + finalizedScore);
out->PlusEquals(this, prefixScore + finalizedScore - hypo.GetTranslationOption().GetScores().GetScoresForProducer(this)[0]);
}
ret->Set(prefixScore, lmState);

View File

@ -329,8 +329,7 @@ template <class Model> FFState *LanguageModelKen<Model>::EvaluateWhenApplied(con
// Non-terminal is first so we can copy instead of rescoring.
const ChartHypothesis *prevHypo = hypo.GetPrevHypo(nonTermIndexMap[phrasePos]);
const lm::ngram::ChartState &prevState = static_cast<const LanguageModelChartStateKenLM*>(prevHypo->GetFFState(featureID))->GetChartState();
float prob = UntransformLMScore(prevHypo->GetScoreBreakdown().GetScoresForProducer(this)[0]);
ruleScore.BeginNonTerminal(prevState, prob);
ruleScore.BeginNonTerminal(prevState);
phrasePos++;
}
}
@ -340,8 +339,7 @@ template <class Model> FFState *LanguageModelKen<Model>::EvaluateWhenApplied(con
if (word.IsNonTerminal()) {
const ChartHypothesis *prevHypo = hypo.GetPrevHypo(nonTermIndexMap[phrasePos]);
const lm::ngram::ChartState &prevState = static_cast<const LanguageModelChartStateKenLM*>(prevHypo->GetFFState(featureID))->GetChartState();
float prob = UntransformLMScore(prevHypo->GetScoreBreakdown().GetScoresForProducer(this)[0]);
ruleScore.NonTerminal(prevState, prob);
ruleScore.NonTerminal(prevState);
} else {
ruleScore.Terminal(TranslateID(word));
}
@ -349,62 +347,64 @@ template <class Model> FFState *LanguageModelKen<Model>::EvaluateWhenApplied(con
float score = ruleScore.Finish();
score = TransformLMScore(score);
score -= hypo.GetTranslationOption().GetScores().GetScoresForProducer(this)[0];
if (OOVFeatureEnabled()) {
std::vector<float> scores(2);
scores[0] = score;
scores[1] = 0.0;
accumulator->Assign(this, scores);
accumulator->PlusEquals(this, scores);
}
else {
accumulator->Assign(this, score);
accumulator->PlusEquals(this, score);
}
return newState;
}
template <class Model> FFState *LanguageModelKen<Model>::EvaluateWhenApplied(const Syntax::SHyperedge& hyperedge, int featureID, ScoreComponentCollection *accumulator) const
{
LanguageModelChartStateKenLM *newState = new LanguageModelChartStateKenLM();
lm::ngram::RuleScore<Model> ruleScore(*m_ngram, newState->GetChartState());
const TargetPhrase &target = *hyperedge.translation;
const AlignmentInfo::NonTermIndexMap &nonTermIndexMap =
target.GetAlignNonTerm().GetNonTermIndexMap2();
const size_t size = target.GetSize();
size_t phrasePos = 0;
// Special cases for first word.
if (size) {
const Word &word = target.GetWord(0);
if (word.GetFactor(m_factorType) == m_beginSentenceFactor) {
// Begin of sentence
ruleScore.BeginSentence();
phrasePos++;
} else if (word.IsNonTerminal()) {
// Non-terminal is first so we can copy instead of rescoring.
const Syntax::SVertex *pred = hyperedge.tail[nonTermIndexMap[phrasePos]];
const lm::ngram::ChartState &prevState = static_cast<const LanguageModelChartStateKenLM*>(pred->state[featureID])->GetChartState();
float prob = UntransformLMScore(pred->best->scoreBreakdown.GetScoresForProducer(this)[0]);
ruleScore.BeginNonTerminal(prevState, prob);
phrasePos++;
}
}
for (; phrasePos < size; phrasePos++) {
const Word &word = target.GetWord(phrasePos);
if (word.IsNonTerminal()) {
const Syntax::SVertex *pred = hyperedge.tail[nonTermIndexMap[phrasePos]];
const lm::ngram::ChartState &prevState = static_cast<const LanguageModelChartStateKenLM*>(pred->state[featureID])->GetChartState();
float prob = UntransformLMScore(pred->best->scoreBreakdown.GetScoresForProducer(this)[0]);
ruleScore.NonTerminal(prevState, prob);
} else {
ruleScore.Terminal(TranslateID(word));
}
}
float score = ruleScore.Finish();
score = TransformLMScore(score);
accumulator->Assign(this, score);
return newState;
}
//template <class Model> FFState *LanguageModelKen<Model>::EvaluateWhenApplied(const Syntax::SHyperedge& hyperedge, int featureID, ScoreComponentCollection *accumulator) const
//{
// LanguageModelChartStateKenLM *newState = new LanguageModelChartStateKenLM();
// lm::ngram::RuleScore<Model> ruleScore(*m_ngram, newState->GetChartState());
// const TargetPhrase &target = *hyperedge.translation;
// const AlignmentInfo::NonTermIndexMap &nonTermIndexMap =
// target.GetAlignNonTerm().GetNonTermIndexMap2();
//
// const size_t size = target.GetSize();
// size_t phrasePos = 0;
// // Special cases for first word.
// if (size) {
// const Word &word = target.GetWord(0);
// if (word.GetFactor(m_factorType) == m_beginSentenceFactor) {
// // Begin of sentence
// ruleScore.BeginSentence();
// phrasePos++;
// } else if (word.IsNonTerminal()) {
// // Non-terminal is first so we can copy instead of rescoring.
// const Syntax::SVertex *pred = hyperedge.tail[nonTermIndexMap[phrasePos]];
// const lm::ngram::ChartState &prevState = static_cast<const LanguageModelChartStateKenLM*>(pred->state[featureID])->GetChartState();
// float prob = UntransformLMScore(pred->best->scoreBreakdown.GetScoresForProducer(this)[0]);
// ruleScore.BeginNonTerminal(prevState, prob);
// phrasePos++;
// }
// }
//
// for (; phrasePos < size; phrasePos++) {
// const Word &word = target.GetWord(phrasePos);
// if (word.IsNonTerminal()) {
// const Syntax::SVertex *pred = hyperedge.tail[nonTermIndexMap[phrasePos]];
// const lm::ngram::ChartState &prevState = static_cast<const LanguageModelChartStateKenLM*>(pred->state[featureID])->GetChartState();
// float prob = UntransformLMScore(pred->best->scoreBreakdown.GetScoresForProducer(this)[0]);
// ruleScore.NonTerminal(prevState, prob);
// } else {
// ruleScore.Terminal(TranslateID(word));
// }
// }
//
// float score = ruleScore.Finish();
// score = TransformLMScore(score);
// accumulator->Assign(this, score);
// return newState;
//}
template <class Model> void LanguageModelKen<Model>::IncrementalCallback(Incremental::Manager &manager) const
{

View File

@ -59,7 +59,7 @@ public:
virtual FFState *EvaluateWhenApplied(const ChartHypothesis& cur_hypo, int featureID, ScoreComponentCollection *accumulator) const;
virtual FFState *EvaluateWhenApplied(const Syntax::SHyperedge& hyperedge, int featureID, ScoreComponentCollection *accumulator) const;
// virtual FFState *EvaluateWhenApplied(const Syntax::SHyperedge& hyperedge, int featureID, ScoreComponentCollection *accumulator) const;
virtual void IncrementalCallback(Incremental::Manager &manager) const;
virtual void ReportHistoryOrder(std::ostream &out,const Phrase &phrase) const;

View File

@ -200,6 +200,11 @@ public:
m_scores.sparsePlusEquals(rhs.m_scores);
}
// add only core (dense) features of rhs to this collection;
// sparse feature values are not modified
void CorePlusEquals(const ScoreComponentCollection& rhs) {
  m_scores.corePlusEquals(rhs.m_scores);
}
// add a raw feature vector (core and sparse components) to this collection
void PlusEquals(const FVector& scores) {
  m_scores += scores;
}