mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-10-05 15:58:03 +03:00
score deltas in chart decoding
This commit is contained in:
parent
0441fd6ab9
commit
465b475664
@ -229,7 +229,7 @@ vector< vector<const Word*> > MosesDecoder::runChartDecoder(const std::string& s
|
||||
for (ChartKBestExtractor::KBestVec::const_iterator p = nBestList.begin();
|
||||
p != nBestList.end(); ++p) {
|
||||
const ChartKBestExtractor::Derivation &derivation = **p;
|
||||
featureValues.push_back(derivation.scoreBreakdown);
|
||||
featureValues.push_back(*ChartKBestExtractor::GetOutputScoreBreakdown(derivation));
|
||||
float bleuScore, dynBleuScore, realBleuScore;
|
||||
dynBleuScore = getBleuScore(featureValues.back());
|
||||
Phrase outputPhrase = ChartKBestExtractor::GetOutputPhrase(derivation);
|
||||
|
@ -61,7 +61,8 @@ ChartHypothesis::ChartHypothesis(const ChartTranslationOptions &transOpt,
|
||||
const std::vector<HypothesisDimension> &childEntries = item.GetHypothesisDimensions();
|
||||
m_prevHypos.reserve(childEntries.size());
|
||||
std::vector<HypothesisDimension>::const_iterator iter;
|
||||
for (iter = childEntries.begin(); iter != childEntries.end(); ++iter) {
|
||||
for (iter = childEntries.begin(); iter != childEntries.end(); ++iter)
|
||||
{
|
||||
m_prevHypos.push_back(iter->GetHypothesis());
|
||||
}
|
||||
}
|
||||
@ -71,7 +72,6 @@ ChartHypothesis::ChartHypothesis(const ChartTranslationOptions &transOpt,
|
||||
ChartHypothesis::ChartHypothesis(const ChartHypothesis &pred,
|
||||
const ChartKBestExtractor & /*unused*/)
|
||||
:m_currSourceWordsRange(pred.m_currSourceWordsRange)
|
||||
,m_scoreBreakdown(pred.m_scoreBreakdown)
|
||||
,m_totalScore(pred.m_totalScore)
|
||||
,m_arcList(NULL)
|
||||
,m_winningHypo(NULL)
|
||||
@ -85,14 +85,17 @@ ChartHypothesis::ChartHypothesis(const ChartHypothesis &pred,
|
||||
ChartHypothesis::~ChartHypothesis()
|
||||
{
|
||||
// delete feature function states
|
||||
for (unsigned i = 0; i < m_ffStates.size(); ++i) {
|
||||
for (unsigned i = 0; i < m_ffStates.size(); ++i)
|
||||
{
|
||||
delete m_ffStates[i];
|
||||
}
|
||||
|
||||
// delete hypotheses that are not in the chart (recombined away)
|
||||
if (m_arcList) {
|
||||
if (m_arcList)
|
||||
{
|
||||
ChartArcList::iterator iter;
|
||||
for (iter = m_arcList->begin() ; iter != m_arcList->end() ; ++iter) {
|
||||
for (iter = m_arcList->begin() ; iter != m_arcList->end() ; ++iter)
|
||||
{
|
||||
ChartHypothesis *hypo = *iter;
|
||||
Delete(hypo);
|
||||
}
|
||||
@ -109,19 +112,25 @@ void ChartHypothesis::GetOutputPhrase(Phrase &outPhrase) const
|
||||
{
|
||||
FactorType placeholderFactor = StaticData::Instance().GetPlaceholderFactor();
|
||||
|
||||
for (size_t pos = 0; pos < GetCurrTargetPhrase().GetSize(); ++pos) {
|
||||
for (size_t pos = 0; pos < GetCurrTargetPhrase().GetSize(); ++pos)
|
||||
{
|
||||
const Word &word = GetCurrTargetPhrase().GetWord(pos);
|
||||
if (word.IsNonTerminal()) {
|
||||
if (word.IsNonTerminal())
|
||||
{
|
||||
// non-term. fill out with prev hypo
|
||||
size_t nonTermInd = GetCurrTargetPhrase().GetAlignNonTerm().GetNonTermIndexMap()[pos];
|
||||
const ChartHypothesis *prevHypo = m_prevHypos[nonTermInd];
|
||||
prevHypo->GetOutputPhrase(outPhrase);
|
||||
} else {
|
||||
}
|
||||
else
|
||||
{
|
||||
outPhrase.AddWord(word);
|
||||
|
||||
if (placeholderFactor != NOT_FOUND) {
|
||||
if (placeholderFactor != NOT_FOUND)
|
||||
{
|
||||
std::set<size_t> sourcePosSet = GetCurrTargetPhrase().GetAlignTerm().GetAlignmentsForTarget(pos);
|
||||
if (sourcePosSet.size() == 1) {
|
||||
if (sourcePosSet.size() == 1)
|
||||
{
|
||||
const std::vector<const Word*> *ruleSourceFromInputPath = GetTranslationOption().GetSourceRuleFromInputPath();
|
||||
UTIL_THROW_IF2(ruleSourceFromInputPath == NULL,
|
||||
"No source rule");
|
||||
@ -131,7 +140,8 @@ void ChartHypothesis::GetOutputPhrase(Phrase &outPhrase) const
|
||||
UTIL_THROW_IF2(sourceWord == NULL,
|
||||
"No source word");
|
||||
const Factor *factor = sourceWord->GetFactor(placeholderFactor);
|
||||
if (factor) {
|
||||
if (factor)
|
||||
{
|
||||
outPhrase.Back()[0] = factor;
|
||||
}
|
||||
}
|
||||
@ -149,37 +159,45 @@ Phrase ChartHypothesis::GetOutputPhrase() const
|
||||
return outPhrase;
|
||||
}
|
||||
|
||||
void ChartHypothesis::GetOutputPhrase(int leftRightMost, int numWords, Phrase &outPhrase) const
|
||||
/** TODO: this method isn't used anywhere. Remove? */
|
||||
void ChartHypothesis::GetOutputPhrase(size_t leftRightMost, size_t numWords, Phrase &outPhrase) const
|
||||
{
|
||||
const TargetPhrase &tp = GetCurrTargetPhrase();
|
||||
|
||||
int targetSize = tp.GetSize();
|
||||
for (int i = 0; i < targetSize; ++i) {
|
||||
int pos;
|
||||
if (leftRightMost == 1) {
|
||||
pos = i;
|
||||
}
|
||||
else if (leftRightMost == 2) {
|
||||
pos = targetSize - i - 1;
|
||||
}
|
||||
else {
|
||||
abort();
|
||||
}
|
||||
size_t targetSize = tp.GetSize();
|
||||
for (size_t i = 0; i < targetSize; ++i)
|
||||
{
|
||||
size_t pos;
|
||||
if (leftRightMost == 1)
|
||||
{
|
||||
pos = i;
|
||||
}
|
||||
else if (leftRightMost == 2)
|
||||
{
|
||||
pos = targetSize - i - 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
abort();
|
||||
}
|
||||
|
||||
const Word &word = tp.GetWord(pos);
|
||||
const Word &word = tp.GetWord(pos);
|
||||
|
||||
if (word.IsNonTerminal()) {
|
||||
// non-term. fill out with prev hypo
|
||||
size_t nonTermInd = tp.GetAlignNonTerm().GetNonTermIndexMap()[pos];
|
||||
const ChartHypothesis *prevHypo = m_prevHypos[nonTermInd];
|
||||
prevHypo->GetOutputPhrase(outPhrase);
|
||||
} else {
|
||||
outPhrase.AddWord(word);
|
||||
}
|
||||
if (word.IsNonTerminal())
|
||||
{
|
||||
// non-term. fill out with prev hypo
|
||||
size_t nonTermInd = tp.GetAlignNonTerm().GetNonTermIndexMap()[pos];
|
||||
const ChartHypothesis *prevHypo = m_prevHypos[nonTermInd];
|
||||
prevHypo->GetOutputPhrase(outPhrase);
|
||||
}
|
||||
else
|
||||
{
|
||||
outPhrase.AddWord(word);
|
||||
}
|
||||
|
||||
if (outPhrase.GetSize() >= numWords) {
|
||||
return;
|
||||
}
|
||||
if (outPhrase.GetSize() >= numWords) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -209,64 +227,71 @@ int ChartHypothesis::RecombineCompare(const ChartHypothesis &compare) const
|
||||
return 0;
|
||||
}
|
||||
|
||||
/** calculate total score
|
||||
* @todo this should be in ScoreBreakdown
|
||||
*/
|
||||
/** calculate total score */
|
||||
void ChartHypothesis::EvaluateWhenApplied()
|
||||
{
|
||||
const StaticData &staticData = StaticData::Instance();
|
||||
// total scores from prev hypos
|
||||
std::vector<const ChartHypothesis*>::iterator iter;
|
||||
for (iter = m_prevHypos.begin(); iter != m_prevHypos.end(); ++iter) {
|
||||
const ChartHypothesis &prevHypo = **iter;
|
||||
const ScoreComponentCollection &scoreBreakdown = prevHypo.GetScoreBreakdown();
|
||||
|
||||
m_scoreBreakdown.PlusEquals(scoreBreakdown);
|
||||
}
|
||||
|
||||
// scores from current translation rule. eg. translation models & word penalty
|
||||
const ScoreComponentCollection &scoreBreakdown = GetTranslationOption().GetScores();
|
||||
m_scoreBreakdown.PlusEquals(scoreBreakdown);
|
||||
|
||||
// compute values of stateless feature functions that were not
|
||||
// cached in the translation option-- there is no principled distinction
|
||||
const std::vector<const StatelessFeatureFunction*>& sfs =
|
||||
StatelessFeatureFunction::GetStatelessFeatureFunctions();
|
||||
for (unsigned i = 0; i < sfs.size(); ++i) {
|
||||
if (! staticData.IsFeatureFunctionIgnored( *sfs[i] )) {
|
||||
sfs[i]->EvaluateWhenApplied(*this,&m_scoreBreakdown);
|
||||
for (unsigned i = 0; i < sfs.size(); ++i)
|
||||
{
|
||||
if (! staticData.IsFeatureFunctionIgnored( *sfs[i] ))
|
||||
{
|
||||
sfs[i]->EvaluateWhenApplied(*this,&m_currScoreBreakdown);
|
||||
}
|
||||
}
|
||||
|
||||
const std::vector<const StatefulFeatureFunction*>& ffs =
|
||||
StatefulFeatureFunction::GetStatefulFeatureFunctions();
|
||||
for (unsigned i = 0; i < ffs.size(); ++i) {
|
||||
if (! staticData.IsFeatureFunctionIgnored( *ffs[i] )) {
|
||||
m_ffStates[i] = ffs[i]->EvaluateWhenApplied(*this,i,&m_scoreBreakdown);
|
||||
for (unsigned i = 0; i < ffs.size(); ++i)
|
||||
{
|
||||
if (! staticData.IsFeatureFunctionIgnored( *ffs[i] ))
|
||||
{
|
||||
m_ffStates[i] = ffs[i]->EvaluateWhenApplied(*this,i,&m_currScoreBreakdown);
|
||||
}
|
||||
}
|
||||
|
||||
m_totalScore = m_scoreBreakdown.GetWeightedScore();
|
||||
// total score from current translation rule
|
||||
m_totalScore = GetTranslationOption().GetScores().GetWeightedScore();
|
||||
m_totalScore += m_currScoreBreakdown.GetWeightedScore();
|
||||
|
||||
// total scores from prev hypos
|
||||
for (std::vector<const ChartHypothesis*>::const_iterator iter = m_prevHypos.begin(); iter != m_prevHypos.end(); ++iter) {
|
||||
const ChartHypothesis &prevHypo = **iter;
|
||||
m_totalScore += prevHypo.GetTotalScore();
|
||||
}
|
||||
}
|
||||
|
||||
void ChartHypothesis::AddArc(ChartHypothesis *loserHypo)
|
||||
{
|
||||
if (!m_arcList) {
|
||||
if (loserHypo->m_arcList) { // we don't have an arcList, but loser does
|
||||
if (!m_arcList)
|
||||
{
|
||||
if (loserHypo->m_arcList)
|
||||
{ // we don't have an arcList, but loser does
|
||||
this->m_arcList = loserHypo->m_arcList; // take ownership, we'll delete
|
||||
loserHypo->m_arcList = 0; // prevent a double deletion
|
||||
} else {
|
||||
}
|
||||
else
|
||||
{
|
||||
this->m_arcList = new ChartArcList();
|
||||
}
|
||||
} else {
|
||||
if (loserHypo->m_arcList) { // both have an arc list: merge. delete loser
|
||||
}
|
||||
else
|
||||
{
|
||||
if (loserHypo->m_arcList)
|
||||
{ // both have an arc list: merge. delete loser
|
||||
size_t my_size = m_arcList->size();
|
||||
size_t add_size = loserHypo->m_arcList->size();
|
||||
this->m_arcList->resize(my_size + add_size, 0);
|
||||
std::memcpy(&(*m_arcList)[0] + my_size, &(*loserHypo->m_arcList)[0], add_size * sizeof(ChartHypothesis *));
|
||||
delete loserHypo->m_arcList;
|
||||
loserHypo->m_arcList = 0;
|
||||
} else { // loserHypo doesn't have any arcs
|
||||
}
|
||||
else
|
||||
{ // loserHypo doesn't have any arcs
|
||||
// DO NOTHING
|
||||
}
|
||||
}
|
||||
@ -274,8 +299,10 @@ void ChartHypothesis::AddArc(ChartHypothesis *loserHypo)
|
||||
}
|
||||
|
||||
// sorting helper
|
||||
struct CompareChartChartHypothesisTotalScore {
|
||||
bool operator()(const ChartHypothesis* hypo1, const ChartHypothesis* hypo2) const {
|
||||
struct CompareChartHypothesisTotalScore
|
||||
{
|
||||
bool operator()(const ChartHypothesis* hypo1, const ChartHypothesis* hypo2) const
|
||||
{
|
||||
return hypo1->GetTotalScore() > hypo2->GetTotalScore();
|
||||
}
|
||||
};
|
||||
@ -295,16 +322,18 @@ void ChartHypothesis::CleanupArcList()
|
||||
size_t nBestSize = staticData.GetNBestSize();
|
||||
bool distinctNBest = staticData.GetDistinctNBest() || staticData.UseMBR() || staticData.GetOutputSearchGraph() || staticData.GetOutputSearchGraphHypergraph();
|
||||
|
||||
if (!distinctNBest && m_arcList->size() > nBestSize) {
|
||||
if (!distinctNBest && m_arcList->size() > nBestSize)
|
||||
{
|
||||
// prune arc list only if there too many arcs
|
||||
NTH_ELEMENT4(m_arcList->begin()
|
||||
, m_arcList->begin() + nBestSize - 1
|
||||
, m_arcList->end()
|
||||
, CompareChartChartHypothesisTotalScore());
|
||||
, CompareChartHypothesisTotalScore());
|
||||
|
||||
// delete bad ones
|
||||
ChartArcList::iterator iter;
|
||||
for (iter = m_arcList->begin() + nBestSize ; iter != m_arcList->end() ; ++iter) {
|
||||
for (iter = m_arcList->begin() + nBestSize ; iter != m_arcList->end() ; ++iter)
|
||||
{
|
||||
ChartHypothesis *arc = *iter;
|
||||
ChartHypothesis::Delete(arc);
|
||||
}
|
||||
@ -314,7 +343,8 @@ void ChartHypothesis::CleanupArcList()
|
||||
|
||||
// set all arc's main hypo variable to this hypo
|
||||
ChartArcList::iterator iter = m_arcList->begin();
|
||||
for (; iter != m_arcList->end() ; ++iter) {
|
||||
for (; iter != m_arcList->end() ; ++iter)
|
||||
{
|
||||
ChartHypothesis *arc = *iter;
|
||||
arc->SetWinningHypo(this);
|
||||
}
|
||||
@ -337,11 +367,13 @@ std::ostream& operator<<(std::ostream& out, const ChartHypothesis& hypo)
|
||||
|
||||
// recombination
|
||||
if (hypo.GetWinningHypothesis() != NULL &&
|
||||
hypo.GetWinningHypothesis() != &hypo) {
|
||||
hypo.GetWinningHypothesis() != &hypo)
|
||||
{
|
||||
out << "->" << hypo.GetWinningHypothesis()->GetId();
|
||||
}
|
||||
|
||||
if (StaticData::Instance().GetIncludeLHSInSearchGraph()) {
|
||||
if (StaticData::Instance().GetIncludeLHSInSearchGraph())
|
||||
{
|
||||
out << " " << hypo.GetTargetLHS() << "=>";
|
||||
}
|
||||
out << " " << hypo.GetCurrTargetPhrase()
|
||||
@ -349,7 +381,8 @@ std::ostream& operator<<(std::ostream& out, const ChartHypothesis& hypo)
|
||||
<< " " << hypo.GetCurrSourceRange();
|
||||
|
||||
HypoList::const_iterator iter;
|
||||
for (iter = hypo.GetPrevHypos().begin(); iter != hypo.GetPrevHypos().end(); ++iter) {
|
||||
for (iter = hypo.GetPrevHypos().begin(); iter != hypo.GetPrevHypos().end(); ++iter)
|
||||
{
|
||||
const ChartHypothesis &prevHypo = **iter;
|
||||
out << " " << prevHypo.GetId();
|
||||
}
|
||||
|
@ -21,6 +21,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <boost/scoped_ptr.hpp>
|
||||
#include "Util.h"
|
||||
#include "WordsRange.h"
|
||||
#include "ScoreComponentCollection.h"
|
||||
@ -45,7 +46,7 @@ typedef std::vector<ChartHypothesis*> ChartArcList;
|
||||
class ChartHypothesis
|
||||
{
|
||||
friend std::ostream& operator<<(std::ostream&, const ChartHypothesis&);
|
||||
friend class ChartKBestExtractor;
|
||||
// friend class ChartKBestExtractor;
|
||||
|
||||
protected:
|
||||
#ifdef USE_HYPO_POOL
|
||||
@ -56,7 +57,10 @@ protected:
|
||||
|
||||
WordsRange m_currSourceWordsRange;
|
||||
std::vector<const FFState*> m_ffStates; /*! stateful feature function states */
|
||||
ScoreComponentCollection m_scoreBreakdown /*! detailed score break-down by components (for instance language model, word penalty, etc) */
|
||||
/*! sum of scores of this hypothesis, and previous hypotheses. Lazily initialised. */
|
||||
mutable boost::scoped_ptr<ScoreComponentCollection> m_scoreBreakdown;
|
||||
mutable boost::scoped_ptr<ScoreComponentCollection> m_deltaScoreBreakdown;
|
||||
ScoreComponentCollection m_currScoreBreakdown /*! scores for this hypothesis only */
|
||||
,m_lmNGram
|
||||
,m_lmPrefix;
|
||||
float m_totalScore;
|
||||
@ -76,23 +80,23 @@ protected:
|
||||
//! not implemented
|
||||
ChartHypothesis(const ChartHypothesis ©);
|
||||
|
||||
//! only used by ChartKBestExtractor
|
||||
ChartHypothesis(const ChartHypothesis &, const ChartKBestExtractor &);
|
||||
|
||||
public:
|
||||
#ifdef USE_HYPO_POOL
|
||||
void *operator new(size_t /* num_bytes */) {
|
||||
void *operator new(size_t /* num_bytes */)
|
||||
{
|
||||
void *ptr = s_objectPool.getPtr();
|
||||
return ptr;
|
||||
}
|
||||
|
||||
//! delete \param hypo. Works with object pool too
|
||||
static void Delete(ChartHypothesis *hypo) {
|
||||
static void Delete(ChartHypothesis *hypo)
|
||||
{
|
||||
s_objectPool.freeObject(hypo);
|
||||
}
|
||||
#else
|
||||
//! delete \param hypo. Works with object pool too
|
||||
static void Delete(ChartHypothesis *hypo) {
|
||||
static void Delete(ChartHypothesis *hypo)
|
||||
{
|
||||
delete hypo;
|
||||
}
|
||||
#endif
|
||||
@ -100,38 +104,48 @@ public:
|
||||
ChartHypothesis(const ChartTranslationOptions &, const RuleCubeItem &item,
|
||||
ChartManager &manager);
|
||||
|
||||
//! only used by ChartKBestExtractor
|
||||
ChartHypothesis(const ChartHypothesis &, const ChartKBestExtractor &);
|
||||
|
||||
~ChartHypothesis();
|
||||
|
||||
unsigned GetId() const {
|
||||
unsigned GetId() const
|
||||
{
|
||||
return m_id;
|
||||
}
|
||||
|
||||
const ChartTranslationOption &GetTranslationOption()const {
|
||||
const ChartTranslationOption &GetTranslationOption() const
|
||||
{
|
||||
return *m_transOpt;
|
||||
}
|
||||
|
||||
//! Get the rule that created this hypothesis
|
||||
const TargetPhrase &GetCurrTargetPhrase()const {
|
||||
const TargetPhrase &GetCurrTargetPhrase() const
|
||||
{
|
||||
return m_transOpt->GetPhrase();
|
||||
}
|
||||
|
||||
//! the source range that this hypothesis spans
|
||||
const WordsRange &GetCurrSourceRange()const {
|
||||
const WordsRange &GetCurrSourceRange() const
|
||||
{
|
||||
return m_currSourceWordsRange;
|
||||
}
|
||||
|
||||
//! the arc list when creating n-best lists
|
||||
inline const ChartArcList* GetArcList() const {
|
||||
inline const ChartArcList* GetArcList() const
|
||||
{
|
||||
return m_arcList;
|
||||
}
|
||||
|
||||
//! the feature function states for a particular feature \param featureID
|
||||
inline const FFState* GetFFState( size_t featureID ) const {
|
||||
inline const FFState* GetFFState( size_t featureID ) const
|
||||
{
|
||||
return m_ffStates[ featureID ];
|
||||
}
|
||||
|
||||
//! reference back to the manager
|
||||
inline const ChartManager& GetManager() const {
|
||||
inline const ChartManager& GetManager() const
|
||||
{
|
||||
return m_manager;
|
||||
}
|
||||
|
||||
@ -140,7 +154,7 @@ public:
|
||||
|
||||
// get leftmost/rightmost words only
|
||||
// leftRightMost: 1=left, 2=right
|
||||
void GetOutputPhrase(int leftRightMost, int numWords, Phrase &outPhrase) const;
|
||||
void GetOutputPhrase(size_t leftRightMost, size_t numWords, Phrase &outPhrase) const;
|
||||
|
||||
int RecombineCompare(const ChartHypothesis &compare) const;
|
||||
|
||||
@ -151,32 +165,74 @@ public:
|
||||
void SetWinningHypo(const ChartHypothesis *hypo);
|
||||
|
||||
//! get the unweighted score for each feature function
|
||||
const ScoreComponentCollection &GetScoreBreakdown() const {
|
||||
return m_scoreBreakdown;
|
||||
const ScoreComponentCollection &GetScoreBreakdown() const
{
  // Returns the unweighted per-feature scores of this hypothesis plus those of
  // all hypotheses it is built on. The sum is computed on first use and then
  // cached in m_scoreBreakdown.
  // Note: never call this method before m_currScoreBreakdown is fully computed
  if (m_scoreBreakdown.get()) {
    return *m_scoreBreakdown;
  }

  m_scoreBreakdown.reset(new ScoreComponentCollection());

  // contribution of the current translation rule (only if one is attached)
  if (m_transOpt) {
    m_scoreBreakdown->PlusEquals(GetTranslationOption().GetScores());
  }

  // scores stored on this hypothesis itself
  m_scoreBreakdown->PlusEquals(m_currScoreBreakdown);

  // accumulated breakdowns of the previous hypotheses (computed recursively)
  for (size_t k = 0; k < m_prevHypos.size(); ++k) {
    m_scoreBreakdown->PlusEquals(m_prevHypos[k]->GetScoreBreakdown());
  }

  return *m_scoreBreakdown;
}
|
||||
|
||||
//! get the unweighted score delta for each feature function
|
||||
const ScoreComponentCollection &GetDeltaScoreBreakdown() const
|
||||
{
|
||||
// Note: never call this method before m_currScoreBreakdown is fully computed
|
||||
if (!m_deltaScoreBreakdown.get())
|
||||
{
|
||||
m_deltaScoreBreakdown.reset(new ScoreComponentCollection());
|
||||
// score breakdown from current translation rule
|
||||
if (m_transOpt)
|
||||
{
|
||||
m_deltaScoreBreakdown->PlusEquals(GetTranslationOption().GetScores());
|
||||
}
|
||||
m_deltaScoreBreakdown->PlusEquals(m_currScoreBreakdown);
|
||||
// delta: score breakdowns from prev hypos _not_ added
|
||||
}
|
||||
return *(m_deltaScoreBreakdown.get());
|
||||
}
|
||||
|
||||
//! Get the weighted total score
|
||||
float GetTotalScore() const {
|
||||
float GetTotalScore() const
|
||||
{
|
||||
// scores from current translation rule. eg. translation models & word penalty
|
||||
return m_totalScore;
|
||||
}
|
||||
|
||||
//! vector of previous hypotheses this hypo is built on
|
||||
const std::vector<const ChartHypothesis*> &GetPrevHypos() const {
|
||||
const std::vector<const ChartHypothesis*> &GetPrevHypos() const
|
||||
{
|
||||
return m_prevHypos;
|
||||
}
|
||||
|
||||
//! get a particular previous hypos
|
||||
const ChartHypothesis* GetPrevHypo(size_t pos) const {
|
||||
const ChartHypothesis* GetPrevHypo(size_t pos) const
|
||||
{
|
||||
return m_prevHypos[pos];
|
||||
}
|
||||
|
||||
//! get the constituency label that covers this hypo
|
||||
const Word &GetTargetLHS() const {
|
||||
const Word &GetTargetLHS() const
|
||||
{
|
||||
return GetCurrTargetPhrase().GetTargetLHS();
|
||||
}
|
||||
|
||||
//! get the best hypo in the arc list when doing n-best list creation. It's either this hypothesis, or the best hypo if this hypo is in the arc list
|
||||
const ChartHypothesis* GetWinningHypothesis() const {
|
||||
const ChartHypothesis* GetWinningHypothesis() const
|
||||
{
|
||||
return m_winningHypo;
|
||||
}
|
||||
|
||||
|
@ -124,6 +124,28 @@ Phrase ChartKBestExtractor::GetOutputPhrase(const Derivation &d)
|
||||
return ret;
|
||||
}
|
||||
|
||||
// Generate the score breakdown of the derivation d.
|
||||
boost::shared_ptr<ScoreComponentCollection>
|
||||
ChartKBestExtractor::GetOutputScoreBreakdown(const Derivation &d)
|
||||
{
|
||||
const ChartHypothesis &hypo = d.edge.head->hypothesis;
|
||||
boost::shared_ptr<ScoreComponentCollection> scoreBreakdown(new ScoreComponentCollection());
|
||||
scoreBreakdown->PlusEquals(hypo.GetDeltaScoreBreakdown());
|
||||
const TargetPhrase &phrase = hypo.GetCurrTargetPhrase();
|
||||
const AlignmentInfo::NonTermIndexMap &nonTermIndexMap =
|
||||
phrase.GetAlignNonTerm().GetNonTermIndexMap();
|
||||
for (std::size_t pos = 0; pos < phrase.GetSize(); ++pos) {
|
||||
const Word &word = phrase.GetWord(pos);
|
||||
if (word.IsNonTerminal()) {
|
||||
std::size_t nonTermInd = nonTermIndexMap[pos];
|
||||
const Derivation &subderivation = *d.subderivations[nonTermInd];
|
||||
scoreBreakdown->PlusEquals(*GetOutputScoreBreakdown(subderivation));
|
||||
}
|
||||
}
|
||||
|
||||
return scoreBreakdown;
|
||||
}
|
||||
|
||||
// Generate the target tree of the derivation d.
|
||||
TreePointer ChartKBestExtractor::GetOutputTree(const Derivation &d)
|
||||
{
|
||||
@ -286,7 +308,6 @@ ChartKBestExtractor::Derivation::Derivation(const UnweightedHyperarc &e)
|
||||
boost::shared_ptr<Derivation> sub(pred.kBestList[0]);
|
||||
subderivations.push_back(sub);
|
||||
}
|
||||
scoreBreakdown = edge.head->hypothesis.GetScoreBreakdown();
|
||||
score = edge.head->hypothesis.GetTotalScore();
|
||||
}
|
||||
|
||||
@ -298,15 +319,14 @@ ChartKBestExtractor::Derivation::Derivation(const Derivation &d, std::size_t i)
|
||||
backPointers = d.backPointers;
|
||||
subderivations = d.subderivations;
|
||||
std::size_t j = ++backPointers[i];
|
||||
scoreBreakdown = d.scoreBreakdown;
|
||||
score = d.score;
|
||||
// Deduct the score of the old subderivation.
|
||||
scoreBreakdown.MinusEquals(subderivations[i]->scoreBreakdown);
|
||||
score -= subderivations[i]->score;
|
||||
// Update the subderivation pointer.
|
||||
boost::shared_ptr<Derivation> newSub(edge.tail[i]->kBestList[j]);
|
||||
subderivations[i] = newSub;
|
||||
// Add the score of the new subderivation.
|
||||
scoreBreakdown.PlusEquals(subderivations[i]->scoreBreakdown);
|
||||
score = scoreBreakdown.GetWeightedScore();
|
||||
score += subderivations[i]->score;
|
||||
}
|
||||
|
||||
} // namespace Moses
|
||||
|
@ -26,6 +26,7 @@
|
||||
|
||||
#include <boost/unordered_set.hpp>
|
||||
#include <boost/weak_ptr.hpp>
|
||||
#include <boost/shared_ptr.hpp>
|
||||
|
||||
#include <queue>
|
||||
#include <vector>
|
||||
@ -56,7 +57,6 @@ public:
|
||||
UnweightedHyperarc edge;
|
||||
std::vector<std::size_t> backPointers;
|
||||
std::vector<boost::shared_ptr<Derivation> > subderivations;
|
||||
ScoreComponentCollection scoreBreakdown;
|
||||
float score;
|
||||
};
|
||||
|
||||
@ -90,6 +90,7 @@ public:
|
||||
std::size_t k, KBestVec &);
|
||||
|
||||
static Phrase GetOutputPhrase(const Derivation &);
|
||||
static boost::shared_ptr<ScoreComponentCollection> GetOutputScoreBreakdown(const Derivation &);
|
||||
static TreePointer GetOutputTree(const Derivation &);
|
||||
|
||||
private:
|
||||
|
@ -365,7 +365,8 @@ void ChartManager::OutputNBestList(OutputCollector *collector,
|
||||
out << translationId << " ||| ";
|
||||
OutputSurface(out, outputPhrase, outputFactorOrder, false);
|
||||
out << " ||| ";
|
||||
derivation.scoreBreakdown.OutputAllFeatureScores(out);
|
||||
boost::shared_ptr<ScoreComponentCollection> scoreBreakdown = ChartKBestExtractor::GetOutputScoreBreakdown(derivation);
|
||||
scoreBreakdown->OutputAllFeatureScores(out);
|
||||
out << " ||| " << derivation.score;
|
||||
|
||||
// optionally, print word alignments
|
||||
|
@ -29,6 +29,7 @@ public:
|
||||
const InputPath *GetInputPath() const {
|
||||
return m_inputPath;
|
||||
}
|
||||
|
||||
void SetInputPath(const InputPath *inputPath) {
|
||||
m_inputPath = inputPath;
|
||||
}
|
||||
|
@ -385,6 +385,15 @@ void FVector::sparsePlusEquals(const FVector& rhs)
|
||||
set(i->first, get(i->first) + i->second);
|
||||
}
|
||||
|
||||
// add only core features
|
||||
void FVector::corePlusEquals(const FVector& rhs)
|
||||
{
|
||||
if (rhs.m_coreFeatures.size() > m_coreFeatures.size())
|
||||
resize(rhs.m_coreFeatures.size());
|
||||
for (size_t i = 0; i < rhs.m_coreFeatures.size(); ++i)
|
||||
m_coreFeatures[i] += rhs.m_coreFeatures[i];
|
||||
}
|
||||
|
||||
// assign only core features
|
||||
void FVector::coreAssign(const FVector& rhs)
|
||||
{
|
||||
|
@ -235,6 +235,7 @@ public:
|
||||
void capMin(FValue minValue);
|
||||
|
||||
void sparsePlusEquals(const FVector& rhs);
|
||||
void corePlusEquals(const FVector& rhs);
|
||||
void coreAssign(const FVector& rhs);
|
||||
|
||||
void incrementSparseHopeFeatures();
|
||||
|
@ -280,7 +280,7 @@ FFState* LanguageModelImplementation::EvaluateWhenApplied(const ChartHypothesis&
|
||||
|
||||
// get prefixScore and finalizedScore
|
||||
prefixScore = prevState->GetPrefixScore();
|
||||
finalizedScore = prevHypo->GetScoreBreakdown().GetScoresForProducer(this)[0] - prefixScore;
|
||||
finalizedScore = -prefixScore;
|
||||
|
||||
// get language model state
|
||||
delete lmState;
|
||||
@ -308,13 +308,10 @@ FFState* LanguageModelImplementation::EvaluateWhenApplied(const ChartHypothesis&
|
||||
updateChartScore( &prefixScore, &finalizedScore, GetValueGivenState(contextFactor, *lmState).score, ++wordPos );
|
||||
}
|
||||
|
||||
finalizedScore -= prevState->GetPrefixScore();
|
||||
|
||||
// check if we are dealing with a large sub-phrase
|
||||
if (subPhraseLength > GetNGramOrder() - 1) {
|
||||
// add its finalized language model score
|
||||
finalizedScore +=
|
||||
prevHypo->GetScoreBreakdown().GetScoresForProducer(this)[0] // full score
|
||||
- prevState->GetPrefixScore(); // - prefix score
|
||||
|
||||
// copy language model state
|
||||
delete lmState;
|
||||
lmState = NewState( prevState->GetRightContext() );
|
||||
@ -337,15 +334,16 @@ FFState* LanguageModelImplementation::EvaluateWhenApplied(const ChartHypothesis&
|
||||
}
|
||||
}
|
||||
|
||||
// assign combined score to score breakdown
|
||||
// add combined score to score breakdown
|
||||
if (OOVFeatureEnabled()) {
|
||||
vector<float> scores(2);
|
||||
scores[0] = prefixScore + finalizedScore;
|
||||
scores[1] = out->GetScoresForProducer(this)[1];
|
||||
out->Assign(this, scores);
|
||||
scores[0] = prefixScore + finalizedScore - hypo.GetTranslationOption().GetScores().GetScoresForProducer(this)[0];
|
||||
// scores[1] = out->GetScoresForProducer(this)[1];
|
||||
scores[1] = 0;
|
||||
out->PlusEquals(this, scores);
|
||||
}
|
||||
else {
|
||||
out->Assign(this, prefixScore + finalizedScore);
|
||||
out->PlusEquals(this, prefixScore + finalizedScore - hypo.GetTranslationOption().GetScores().GetScoresForProducer(this)[0]);
|
||||
}
|
||||
|
||||
ret->Set(prefixScore, lmState);
|
||||
|
100
moses/LM/Ken.cpp
100
moses/LM/Ken.cpp
@ -329,8 +329,7 @@ template <class Model> FFState *LanguageModelKen<Model>::EvaluateWhenApplied(con
|
||||
// Non-terminal is first so we can copy instead of rescoring.
|
||||
const ChartHypothesis *prevHypo = hypo.GetPrevHypo(nonTermIndexMap[phrasePos]);
|
||||
const lm::ngram::ChartState &prevState = static_cast<const LanguageModelChartStateKenLM*>(prevHypo->GetFFState(featureID))->GetChartState();
|
||||
float prob = UntransformLMScore(prevHypo->GetScoreBreakdown().GetScoresForProducer(this)[0]);
|
||||
ruleScore.BeginNonTerminal(prevState, prob);
|
||||
ruleScore.BeginNonTerminal(prevState);
|
||||
phrasePos++;
|
||||
}
|
||||
}
|
||||
@ -340,8 +339,7 @@ template <class Model> FFState *LanguageModelKen<Model>::EvaluateWhenApplied(con
|
||||
if (word.IsNonTerminal()) {
|
||||
const ChartHypothesis *prevHypo = hypo.GetPrevHypo(nonTermIndexMap[phrasePos]);
|
||||
const lm::ngram::ChartState &prevState = static_cast<const LanguageModelChartStateKenLM*>(prevHypo->GetFFState(featureID))->GetChartState();
|
||||
float prob = UntransformLMScore(prevHypo->GetScoreBreakdown().GetScoresForProducer(this)[0]);
|
||||
ruleScore.NonTerminal(prevState, prob);
|
||||
ruleScore.NonTerminal(prevState);
|
||||
} else {
|
||||
ruleScore.Terminal(TranslateID(word));
|
||||
}
|
||||
@ -349,62 +347,64 @@ template <class Model> FFState *LanguageModelKen<Model>::EvaluateWhenApplied(con
|
||||
|
||||
float score = ruleScore.Finish();
|
||||
score = TransformLMScore(score);
|
||||
score -= hypo.GetTranslationOption().GetScores().GetScoresForProducer(this)[0];
|
||||
|
||||
if (OOVFeatureEnabled()) {
|
||||
std::vector<float> scores(2);
|
||||
scores[0] = score;
|
||||
scores[1] = 0.0;
|
||||
accumulator->Assign(this, scores);
|
||||
accumulator->PlusEquals(this, scores);
|
||||
}
|
||||
else {
|
||||
accumulator->Assign(this, score);
|
||||
accumulator->PlusEquals(this, score);
|
||||
}
|
||||
return newState;
|
||||
}
|
||||
|
||||
template <class Model> FFState *LanguageModelKen<Model>::EvaluateWhenApplied(const Syntax::SHyperedge& hyperedge, int featureID, ScoreComponentCollection *accumulator) const
|
||||
{
|
||||
LanguageModelChartStateKenLM *newState = new LanguageModelChartStateKenLM();
|
||||
lm::ngram::RuleScore<Model> ruleScore(*m_ngram, newState->GetChartState());
|
||||
const TargetPhrase &target = *hyperedge.translation;
|
||||
const AlignmentInfo::NonTermIndexMap &nonTermIndexMap =
|
||||
target.GetAlignNonTerm().GetNonTermIndexMap2();
|
||||
|
||||
const size_t size = target.GetSize();
|
||||
size_t phrasePos = 0;
|
||||
// Special cases for first word.
|
||||
if (size) {
|
||||
const Word &word = target.GetWord(0);
|
||||
if (word.GetFactor(m_factorType) == m_beginSentenceFactor) {
|
||||
// Begin of sentence
|
||||
ruleScore.BeginSentence();
|
||||
phrasePos++;
|
||||
} else if (word.IsNonTerminal()) {
|
||||
// Non-terminal is first so we can copy instead of rescoring.
|
||||
const Syntax::SVertex *pred = hyperedge.tail[nonTermIndexMap[phrasePos]];
|
||||
const lm::ngram::ChartState &prevState = static_cast<const LanguageModelChartStateKenLM*>(pred->state[featureID])->GetChartState();
|
||||
float prob = UntransformLMScore(pred->best->scoreBreakdown.GetScoresForProducer(this)[0]);
|
||||
ruleScore.BeginNonTerminal(prevState, prob);
|
||||
phrasePos++;
|
||||
}
|
||||
}
|
||||
|
||||
for (; phrasePos < size; phrasePos++) {
|
||||
const Word &word = target.GetWord(phrasePos);
|
||||
if (word.IsNonTerminal()) {
|
||||
const Syntax::SVertex *pred = hyperedge.tail[nonTermIndexMap[phrasePos]];
|
||||
const lm::ngram::ChartState &prevState = static_cast<const LanguageModelChartStateKenLM*>(pred->state[featureID])->GetChartState();
|
||||
float prob = UntransformLMScore(pred->best->scoreBreakdown.GetScoresForProducer(this)[0]);
|
||||
ruleScore.NonTerminal(prevState, prob);
|
||||
} else {
|
||||
ruleScore.Terminal(TranslateID(word));
|
||||
}
|
||||
}
|
||||
|
||||
float score = ruleScore.Finish();
|
||||
score = TransformLMScore(score);
|
||||
accumulator->Assign(this, score);
|
||||
return newState;
|
||||
}
|
||||
//template <class Model> FFState *LanguageModelKen<Model>::EvaluateWhenApplied(const Syntax::SHyperedge& hyperedge, int featureID, ScoreComponentCollection *accumulator) const
|
||||
//{
|
||||
// LanguageModelChartStateKenLM *newState = new LanguageModelChartStateKenLM();
|
||||
// lm::ngram::RuleScore<Model> ruleScore(*m_ngram, newState->GetChartState());
|
||||
// const TargetPhrase &target = *hyperedge.translation;
|
||||
// const AlignmentInfo::NonTermIndexMap &nonTermIndexMap =
|
||||
// target.GetAlignNonTerm().GetNonTermIndexMap2();
|
||||
//
|
||||
// const size_t size = target.GetSize();
|
||||
// size_t phrasePos = 0;
|
||||
// // Special cases for first word.
|
||||
// if (size) {
|
||||
// const Word &word = target.GetWord(0);
|
||||
// if (word.GetFactor(m_factorType) == m_beginSentenceFactor) {
|
||||
// // Begin of sentence
|
||||
// ruleScore.BeginSentence();
|
||||
// phrasePos++;
|
||||
// } else if (word.IsNonTerminal()) {
|
||||
// // Non-terminal is first so we can copy instead of rescoring.
|
||||
// const Syntax::SVertex *pred = hyperedge.tail[nonTermIndexMap[phrasePos]];
|
||||
// const lm::ngram::ChartState &prevState = static_cast<const LanguageModelChartStateKenLM*>(pred->state[featureID])->GetChartState();
|
||||
// float prob = UntransformLMScore(pred->best->scoreBreakdown.GetScoresForProducer(this)[0]);
|
||||
// ruleScore.BeginNonTerminal(prevState, prob);
|
||||
// phrasePos++;
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// for (; phrasePos < size; phrasePos++) {
|
||||
// const Word &word = target.GetWord(phrasePos);
|
||||
// if (word.IsNonTerminal()) {
|
||||
// const Syntax::SVertex *pred = hyperedge.tail[nonTermIndexMap[phrasePos]];
|
||||
// const lm::ngram::ChartState &prevState = static_cast<const LanguageModelChartStateKenLM*>(pred->state[featureID])->GetChartState();
|
||||
// float prob = UntransformLMScore(pred->best->scoreBreakdown.GetScoresForProducer(this)[0]);
|
||||
// ruleScore.NonTerminal(prevState, prob);
|
||||
// } else {
|
||||
// ruleScore.Terminal(TranslateID(word));
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// float score = ruleScore.Finish();
|
||||
// score = TransformLMScore(score);
|
||||
// accumulator->Assign(this, score);
|
||||
// return newState;
|
||||
//}
|
||||
|
||||
template <class Model> void LanguageModelKen<Model>::IncrementalCallback(Incremental::Manager &manager) const
|
||||
{
|
||||
|
@ -59,7 +59,7 @@ public:
|
||||
|
||||
virtual FFState *EvaluateWhenApplied(const ChartHypothesis& cur_hypo, int featureID, ScoreComponentCollection *accumulator) const;
|
||||
|
||||
virtual FFState *EvaluateWhenApplied(const Syntax::SHyperedge& hyperedge, int featureID, ScoreComponentCollection *accumulator) const;
|
||||
// virtual FFState *EvaluateWhenApplied(const Syntax::SHyperedge& hyperedge, int featureID, ScoreComponentCollection *accumulator) const;
|
||||
|
||||
virtual void IncrementalCallback(Incremental::Manager &manager) const;
|
||||
virtual void ReportHistoryOrder(std::ostream &out,const Phrase &phrase) const;
|
||||
|
@ -200,6 +200,11 @@ public:
|
||||
m_scores.sparsePlusEquals(rhs.m_scores);
|
||||
}
|
||||
|
||||
// add only core features: forwards to corePlusEquals() on the underlying
// m_scores, passing rhs's m_scores as the operand
|
||||
void CorePlusEquals(const ScoreComponentCollection& rhs) {
|
||||
m_scores.corePlusEquals(rhs.m_scores);
|
||||
}
|
||||
|
||||
void PlusEquals(const FVector& scores) {
|
||||
m_scores += scores;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user