Feature function overhaul. Each feature function is computed in one of three ways:

1) Stateless feature functions from the phrase table/generation table: these are computed when
   the TranslationOption is created.  They become part of the ScoreBreakdown object contained in
   the TranslationOption and are added to the feature value vector when a hypothesis is extended.
2) Stateless feature functions that are computed during state exploration. Currently, only
   WordPenalty falls into this category, but these functions implement a method Evaluate which
   does not receive a Hypothesis or any contextual information.
3) Stateful feature functions: these features receive the arc information (translation option),
   compute some value and then return some context information.  The context information created
   by a particular feature function is passed back to it as the previous context when a hypothesis
   originating at the node where the previous edge terminates is created.  States in the search
   space may be recombined if the context information is identical.  The context information must
   be stored in an object implementing the FFState interface.

TODO:
1) the command line interface / MERT interface needs to go to named parameters that are otherwise opaque
2) StatefulFeatureFunction's Evaluate method should just take a TranslationOption and a context object.  It is not good that it takes a hypothesis, because then people may be tempted to access information about the "previous" hypothesis without "declaring" this dependency.
3) Future cost estimates should be handled using feature functions.  All stateful feature functions need some kind of future cost estimate.
4) Philipp's poor-man's cube pruning is broken.



git-svn-id: https://mosesdecoder.svn.sourceforge.net/svnroot/mosesdecoder/trunk@2087 1f5c12ca-751b-0410-a591-d2e778427230
This commit is contained in:
redpony 2009-02-06 15:43:06 +00:00
parent cc95706045
commit 63effe85b5
24 changed files with 428 additions and 214 deletions

View File

@ -12,6 +12,9 @@
/* flag for protobuf */
#undef HAVE_PROTOBUF
/* flag for RandLM */
#undef HAVE_RANDLM
/* flag for SRILM */
#undef HAVE_SRILM

View File

@ -1,12 +1,36 @@
// $Id$
#include <cassert>
#include "FFState.h"
#include "StaticData.h"
#include "DummyScoreProducers.h"
#include "WordsRange.h"
#include "TranslationOption.h"
namespace Moses
{
// State for the traditional distortion model: remembers the source
// range that was just translated plus the position of the first
// uncovered source word ("first gap") in the coverage bitmap.
struct DistortionState_traditional : public FFState {
  WordsRange range;
  int first_gap;
  DistortionState_traditional(const WordsRange& wr, int fg)
    : range(wr), first_gap(fg) {}
  // Recombination for this model depends only on the end position of
  // the last translated source range.
  int Compare(const FFState& other) const {
    const DistortionState_traditional& rhs =
        static_cast<const DistortionState_traditional&>(other);
    if (range.GetEndPos() == rhs.range.GetEndPos()) return 0;
    return (range.GetEndPos() < rhs.range.GetEndPos()) ? -1 : 1;
  }
};
// Placeholder for a planned MQ2007-style distortion state. It does not
// override the pure-virtual Compare, so it is still abstract and cannot
// be instantiated until it is filled in.
struct DistortionState_MQ2007 : public FFState {
//TODO
};
// State for the seed hypothesis: nothing translated yet, so both the
// source range and the first-gap position are NOT_FOUND sentinels.
const FFState* DistortionScoreProducer::EmptyHypothesisState() const {
  DistortionState_traditional* initial =
      new DistortionState_traditional(WordsRange(NOT_FOUND, NOT_FOUND), NOT_FOUND);
  return initial;
}
DistortionScoreProducer::DistortionScoreProducer(ScoreIndexManager &scoreIndexManager)
{
scoreIndexManager.AddScoreProducer(this);
@ -22,11 +46,6 @@ std::string DistortionScoreProducer::GetScoreProducerDescription() const
return "Distortion";
}
//float DistortionScoreProducer::CalculateDistortionScoreOUTDATED(const WordsRange &prev, const WordsRange &curr) const
//{
// return - (float) StaticData::Instance().GetInput()->ComputeDistortionDistance(prev, curr);
//}
float DistortionScoreProducer::CalculateDistortionScore(const WordsRange &prev, const WordsRange &curr, const int FirstGap) const
{
const int USE_OLD = 1;
@ -53,6 +72,23 @@ float DistortionScoreProducer::CalculateDistortionScore(const WordsRange &prev,
return (float) -2*(curr.GetNumWordsBetween(prev) + curr.GetNumWordsCovered());
}
size_t DistortionScoreProducer::GetNumInputScores() const { return 0;}
// Stateful evaluation: scores the jump from the previously translated
// source range (taken from prev_state) to the range covered by this
// hypothesis, and returns the state the next extension will see.
FFState* DistortionScoreProducer::Evaluate(
  const Hypothesis& hypo,
  const FFState* prev_state,
  ScoreComponentCollection* out) const {
  const DistortionState_traditional* prev =
      static_cast<const DistortionState_traditional*>(prev_state);
  const WordsRange& currRange = hypo.GetCurrSourceWordsRange();
  // Score the distortion between the previous range and the current one.
  out->PlusEquals(this,
      CalculateDistortionScore(prev->range, currRange, prev->first_gap));
  // NOTE(review): reaching back through GetPrevHypo() here goes against
  // the stated FFState design (the first gap should arguably come from
  // state) — kept as-is to preserve behavior.
  return new DistortionState_traditional(
      currRange,
      hypo.GetPrevHypo()->GetWordsBitmap().GetFirstGapPos());
}
WordPenaltyProducer::WordPenaltyProducer(ScoreIndexManager &scoreIndexManager)
@ -70,6 +106,13 @@ std::string WordPenaltyProducer::GetScoreProducerDescription() const
return "WordPenalty";
}
size_t WordPenaltyProducer::GetNumInputScores() const { return 0;}
// Word penalty: one unit of negative score for every target word in
// the phrase being applied.
void WordPenaltyProducer::Evaluate(const TargetPhrase& tp, ScoreComponentCollection* out) const
{
  const float penalty = -static_cast<float>(tp.GetSize());
  out->PlusEquals(this, penalty);
}
UnknownWordPenaltyProducer::UnknownWordPenaltyProducer(ScoreIndexManager &scoreIndexManager)
{
scoreIndexManager.AddScoreProducer(this);
@ -85,6 +128,10 @@ std::string UnknownWordPenaltyProducer::GetScoreProducerDescription() const
return "!UnknownWordPenalty";
}
size_t UnknownWordPenaltyProducer::GetNumInputScores() const { return 0;}
// The unknown-word penalty is already part of the TranslationOption's
// ScoreBreakdown when the option is created, so it is not re-evaluated
// during state exploration (see ScoreIndexManager::AddScoreProducer).
bool UnknownWordPenaltyProducer::ComputeValueInTranslationOption() const {
return true;
}
}

View File

@ -3,7 +3,7 @@
#ifndef _DUMMY_SCORE_PRODUCERS_H_
#define _DUMMY_SCORE_PRODUCERS_H_
#include "ScoreProducer.h"
#include "FeatureFunction.h"
namespace Moses
{
@ -12,7 +12,7 @@ class WordsRange;
/** Calculates Distortion scores
*/
class DistortionScoreProducer : public ScoreProducer {
class DistortionScoreProducer : public StatefulFeatureFunction {
public:
DistortionScoreProducer(ScoreIndexManager &scoreIndexManager);
@ -20,26 +20,45 @@ public:
size_t GetNumScoreComponents() const;
std::string GetScoreProducerDescription() const;
size_t GetNumInputScores() const;
virtual const FFState* EmptyHypothesisState() const;
virtual FFState* Evaluate(
const Hypothesis& cur_hypo,
const FFState* prev_state,
ScoreComponentCollection* accumulator) const;
};
/** Doesn't do anything but provide a key into the global
* score array to store the word penalty in.
*/
class WordPenaltyProducer : public ScoreProducer {
class WordPenaltyProducer : public StatelessFeatureFunction {
public:
WordPenaltyProducer(ScoreIndexManager &scoreIndexManager);
size_t GetNumScoreComponents() const;
std::string GetScoreProducerDescription() const;
size_t GetNumInputScores() const;
virtual void Evaluate(
const TargetPhrase& phrase,
ScoreComponentCollection* out) const;
};
/** unknown word penalty */
class UnknownWordPenaltyProducer : public ScoreProducer {
class UnknownWordPenaltyProducer : public StatelessFeatureFunction {
public:
UnknownWordPenaltyProducer(ScoreIndexManager &scoreIndexManager);
size_t GetNumScoreComponents() const;
std::string GetScoreProducerDescription() const;
size_t GetNumInputScores() const;
virtual bool ComputeValueInTranslationOption() const;
};
}

8
moses/src/FFState.cpp Normal file
View File

@ -0,0 +1,8 @@
#include "FFState.h"
namespace Moses {
// Out-of-line definition of the virtual destructor: gives the
// polymorphic FFState base a home translation unit for its vtable.
FFState::~FFState() {}
}

9
moses/src/FFState.h Normal file
View File

@ -0,0 +1,9 @@
#ifndef _FF_STATE_H_
#define _FF_STATE_H_

namespace Moses {

/** Abstract base class for the context carried by a stateful feature
 * function from one hypothesis to the next. Hypotheses may be
 * recombined when all of their FFStates compare equal.
 */
class FFState {
public:
  virtual ~FFState();
  /** Three-way comparison used for recombination: negative if *this
   * orders before other, positive if after, 0 if equivalent. */
  virtual int Compare(const FFState& other) const = 0;
};

}

#endif

View File

@ -0,0 +1,22 @@
#include "FeatureFunction.h"
#include <cassert>

namespace Moses {

// Anchor the base-class vtable in this translation unit.
FeatureFunction::~FeatureFunction() {}

bool StatefulFeatureFunction::IsStateless() const { return false; }

bool StatelessFeatureFunction::IsStateless() const { return true; }

// By default a stateless feature is evaluated during search, not
// pre-computed in the TranslationOption; subclasses opt in to caching
// by overriding this to return true.
bool StatelessFeatureFunction::ComputeValueInTranslationOption() const {
  return false;
}

// Guard implementation: a stateless feature must either override
// Evaluate or declare its value cached in the translation option.
void StatelessFeatureFunction::Evaluate(
    const TargetPhrase& cur_hypo,
    ScoreComponentCollection* accumulator) const {
  assert(!"Please implement Evaluate or set ComputeValueInTranslationOption to true");
}

}

View File

@ -0,0 +1,64 @@
#ifndef _FEATURE_FUNCTION_H_
#define _FEATURE_FUNCTION_H_
#include <vector>
#include "ScoreProducer.h"
namespace Moses {
class TargetPhrase;
class Hypothesis;
class FFState;
class ScoreComponentCollection;
/** Common base of all feature functions. A feature is either stateless
 * or stateful; see the two subclasses below. */
class FeatureFunction: public ScoreProducer {
public:
virtual bool IsStateless() const = 0;
virtual ~FeatureFunction();
};
/** Feature whose value depends only on the current arc (translation
 * option), never on search history. */
class StatelessFeatureFunction: public FeatureFunction {
public:
//! Evaluate for stateless feature functions. Implement this.
virtual void Evaluate(
const TargetPhrase& cur_hypo,
ScoreComponentCollection* accumulator) const;
// If true, this value is expected to be included in the
// ScoreBreakdown in the TranslationOption once it has been
// constructed.
// Default: false (the base implementation in FeatureFunction.cpp
// returns false); producers whose scores are known at option-creation
// time, such as the phrase and generation tables, override this to
// return true.
virtual bool ComputeValueInTranslationOption() const;
bool IsStateless() const;
};
/** Feature that carries context (an FFState) from hypothesis to
 * hypothesis; the states also drive hypothesis recombination. */
class StatefulFeatureFunction: public FeatureFunction {
public:
/**
 * \brief This interface should be implemented.
 * Notes: When evaluating the value of this feature function, you should avoid
 * calling hypo.GetPrevHypo(). If you need something from the "previous"
 * hypothesis, you should store it in an FFState object which will be passed
 * in as prev_state. If you don't do this, you will get in trouble.
 */
virtual FFState* Evaluate(
const Hypothesis& cur_hypo,
const FFState* prev_state,
ScoreComponentCollection* accumulator) const = 0;
//! return the state associated with the empty hypothesis
virtual const FFState* EmptyHypothesisState() const = 0;
bool IsStateless() const;
};
}
#endif

View File

@ -155,6 +155,10 @@ const OutputWordCollection *GenerationDictionary::FindWord(const Word &word) con
return ret;
}
// Generation-table scores are known when the TranslationOption is
// built, so they are cached in its ScoreBreakdown rather than
// recomputed during search.
bool GenerationDictionary::ComputeValueInTranslationOption() const {
return true;
}
}

View File

@ -28,6 +28,7 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
#include "Phrase.h"
#include "TypeDef.h"
#include "Dictionary.h"
#include "FeatureFunction.h"
namespace Moses
{
@ -40,7 +41,7 @@ typedef std::map < Word , ScoreComponentCollection > OutputWordCollection;
/** Implementation of a generation table in a trie.
*/
class GenerationDictionary : public Dictionary, public ScoreProducer
class GenerationDictionary : public Dictionary, public StatelessFeatureFunction
{
typedef std::map<const Word* , OutputWordCollection, WordComparer> Collection;
protected:
@ -82,6 +83,7 @@ public:
* Or NULL if the input word isn't found. The search function used is the WordComparer functor
*/
const OutputWordCollection *FindWord(const Word &word) const;
virtual bool ComputeValueInTranslationOption() const;
};

View File

@ -24,6 +24,8 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
#include <limits>
#include <vector>
#include <algorithm>
#include "FFState.h"
#include "TranslationOption.h"
#include "TranslationOptionCollection.h"
#include "DummyScoreProducers.h"
@ -56,16 +58,18 @@ Hypothesis::Hypothesis(InputType const& source, const TargetPhrase &emptyTarget)
, m_currSourceWordsRange(NOT_FOUND, NOT_FOUND)
, m_currTargetWordsRange(NOT_FOUND, NOT_FOUND)
, m_wordDeleted(false)
, m_languageModelStates(StaticData::Instance().GetLMSize(), LanguageModelSingleFactor::UnknownState)
, m_ffStates(StaticData::Instance().GetScoreIndexManager().GetStatefulFeatureFunctions().size())
, m_arcList(NULL)
, m_id(0)
, m_lmstats(NULL)
, m_alignPair(source.GetSize())
{ // used for initial seeding of trans process
// initialize scores
//_hash_computed = false;
s_HypothesesCreated = 1;
ResetScore();
const vector<const StatefulFeatureFunction*>& ffs = StaticData::Instance().GetScoreIndexManager().GetStatefulFeatureFunctions();
for (unsigned i = 0; i < ffs.size(); ++i)
m_ffStates[i] = ffs[i]->EmptyHypothesisState();
}
/***
@ -73,8 +77,8 @@ Hypothesis::Hypothesis(InputType const& source, const TargetPhrase &emptyTarget)
*/
Hypothesis::Hypothesis(const Hypothesis &prevHypo, const TranslationOption &transOpt)
: m_prevHypo(&prevHypo)
, m_transOpt(&transOpt)
, m_targetPhrase(transOpt.GetTargetPhrase())
, m_transOpt(&transOpt)
, m_sourcePhrase(transOpt.GetSourcePhrase())
, m_sourceCompleted (prevHypo.m_sourceCompleted )
, m_sourceInput (prevHypo.m_sourceInput)
@ -84,11 +88,10 @@ Hypothesis::Hypothesis(const Hypothesis &prevHypo, const TranslationOption &tran
, m_wordDeleted(false)
, m_totalScore(0.0f)
, m_futureScore(0.0f)
, m_ffStates(prevHypo.m_ffStates.size())
, m_scoreBreakdown (prevHypo.m_scoreBreakdown)
, m_languageModelStates(prevHypo.m_languageModelStates)
, m_arcList(NULL)
, m_id(s_HypothesesCreated++)
, m_lmstats(NULL)
, m_alignPair(prevHypo.m_alignPair)
{
// assert that we are not extending our hypothesis by retranslating something
@ -98,11 +101,13 @@ Hypothesis::Hypothesis(const Hypothesis &prevHypo, const TranslationOption &tran
//_hash_computed = false;
m_sourceCompleted.SetValue(m_currSourceWordsRange.GetStartPos(), m_currSourceWordsRange.GetEndPos(), true);
m_wordDeleted = transOpt.IsDeletionOption();
m_scoreBreakdown.PlusEquals(transOpt.GetScoreBreakdown());
}
Hypothesis::~Hypothesis()
{
for (unsigned i = 0; i < m_ffStates.size(); ++i)
delete m_ffStates[i];
if (m_arcList)
{
ArcList::iterator iter;
@ -114,8 +119,6 @@ Hypothesis::~Hypothesis()
delete m_arcList;
m_arcList = NULL;
delete m_lmstats; m_lmstats = NULL;
}
}
@ -233,140 +236,24 @@ Hypothesis* Hypothesis::Create(InputType const& m_source, const TargetPhrase &em
keep an ordered list of hypotheses. This makes recombination
much quicker.
*/
int Hypothesis::NGramCompare(const Hypothesis &compare) const
int Hypothesis::RecombineCompare(const Hypothesis &compare) const
{ // -1 = this < compare
// +1 = this > compare
// 0 = this ==compare
if (m_languageModelStates < compare.m_languageModelStates) return -1;
if (m_languageModelStates > compare.m_languageModelStates) return 1;
int comp = m_sourceCompleted.Compare(compare.m_sourceCompleted);
if (comp != 0)
return comp;
int compareBitmap = m_sourceCompleted.Compare(compare.m_sourceCompleted);
if (compareBitmap != 0)
return compareBitmap;
if (m_currSourceWordsRange.GetEndPos() < compare.m_currSourceWordsRange.GetEndPos()) return -1;
if (m_currSourceWordsRange.GetEndPos() > compare.m_currSourceWordsRange.GetEndPos()) return 1;
if (! StaticData::Instance().GetSourceStartPosMattersForRecombination()) return 0;
if (m_currSourceWordsRange.GetStartPos() < compare.m_currSourceWordsRange.GetStartPos()) return -1;
if (m_currSourceWordsRange.GetStartPos() > compare.m_currSourceWordsRange.GetStartPos()) return 1;
return 0;
}
/** Calculates the overall language model score by combining the scores
* of language models generated for each of the factors. Because the factors
* represent a variety of tag sets, and because factors with smaller tag sets
* (such as POS instead of words) allow us to calculate richer statistics, we
* allow a different length of n-gram to be specified for each factor.
* /param lmListInitial todo - describe this parameter
* /param lmListEnd todo - describe this parameter
*/
void Hypothesis::CalcLMScore(const LMList &languageModels)
{
LMList::const_iterator iterLM;
clock_t t=0;
IFVERBOSE(2) { t = clock(); } // track time
const size_t startPos = m_currTargetWordsRange.GetStartPos();
// will be null if LM stats collection is disabled
if (StaticData::Instance().IsComputeLMBackoffStats()) {
m_lmstats = new vector<vector<unsigned int> >(languageModels.size(), vector<unsigned int>(0));
}
size_t lmIdx = 0;
// already have LM scores from previous and trigram score of poss trans.
// just need n-gram score of the words of the start of current phrase
for (iterLM = languageModels.begin() ; iterLM != languageModels.end() ; ++iterLM,++lmIdx)
{
const LanguageModel &languageModel = **iterLM;
size_t nGramOrder = languageModel.GetNGramOrder();
size_t currEndPos = m_currTargetWordsRange.GetEndPos();
float lmScore;
size_t nLmCallCount = 0;
if(m_currTargetWordsRange.GetNumWordsCovered() == 0) {
lmScore = 0; //the score associated with dropping source words is not part of the language model
} else { //non-empty target phrase
if (m_lmstats)
(*m_lmstats)[lmIdx].resize(m_currTargetWordsRange.GetNumWordsCovered(), 0);
// 1st n-gram
vector<const Word*> contextFactor(nGramOrder);
size_t index = 0;
for (int currPos = (int) startPos - (int) nGramOrder + 1 ; currPos <= (int) startPos ; currPos++)
{
if (currPos >= 0)
contextFactor[index++] = &GetWord(currPos);
else
contextFactor[index++] = &languageModel.GetSentenceStartArray();
}
lmScore = languageModel.GetValue(contextFactor);
if (m_lmstats) { languageModel.GetState(contextFactor, &(*m_lmstats)[lmIdx][nLmCallCount++]); }
//cout<<"context factor: "<<languageModel.GetValue(contextFactor)<<endl;
// main loop
size_t endPos = std::min(startPos + nGramOrder - 2
, currEndPos);
for (size_t currPos = startPos + 1 ; currPos <= endPos ; currPos++)
{
// shift all args down 1 place
for (size_t i = 0 ; i < nGramOrder - 1 ; i++)
contextFactor[i] = contextFactor[i + 1];
// add last factor
contextFactor.back() = &GetWord(currPos);
lmScore += languageModel.GetValue(contextFactor);
if (m_lmstats)
languageModel.GetState(contextFactor, &(*m_lmstats)[lmIdx][nLmCallCount++]);
//cout<<"context factor: "<<languageModel.GetValue(contextFactor)<<endl;
}
// end of sentence
if (m_sourceCompleted.IsComplete())
{
const size_t size = GetSize();
contextFactor.back() = &languageModel.GetSentenceEndArray();
for (size_t i = 0 ; i < nGramOrder - 1 ; i ++)
{
int currPos = (int)(size - nGramOrder + i + 1);
if (currPos < 0)
contextFactor[i] = &languageModel.GetSentenceStartArray();
else
contextFactor[i] = &GetWord((size_t)currPos);
}
if (m_lmstats) {
(*m_lmstats)[lmIdx].resize((*m_lmstats)[lmIdx].size() + 1); // extra space for the last call
lmScore += languageModel.GetValue(contextFactor, &m_languageModelStates[lmIdx], &(*m_lmstats)[lmIdx][nLmCallCount++]);
} else
lmScore += languageModel.GetValue(contextFactor, &m_languageModelStates[lmIdx]);
for (unsigned i = 0; i < m_ffStates.size(); ++i) {
if (m_ffStates[i] == NULL || compare.m_ffStates[i] == NULL) {
comp = m_ffStates[i] - compare.m_ffStates[i];
} else {
for (size_t currPos = endPos+1; currPos <= currEndPos; currPos++) {
for (size_t i = 0 ; i < nGramOrder - 1 ; i++)
contextFactor[i] = contextFactor[i + 1];
contextFactor.back() = &GetWord(currPos);
if (m_lmstats)
languageModel.GetState(contextFactor, &(*m_lmstats)[lmIdx][nLmCallCount++]);
}
m_languageModelStates[lmIdx]=languageModel.GetState(contextFactor);
comp = m_ffStates[i]->Compare(*compare.m_ffStates[i]);
}
if (comp != 0) return comp;
}
m_scoreBreakdown.PlusEquals(&languageModel, lmScore);
}
IFVERBOSE(2) { StaticData::Instance().GetSentenceStats().AddTimeCalcLM( clock()-t ); }
}
void Hypothesis::CalcDistortionScore()
{
const DistortionScoreProducer *dsp = StaticData::Instance().GetDistortionScoreProducer();
float distortionScore = dsp->CalculateDistortionScore(
m_prevHypo->GetCurrSourceWordsRange(),
this->GetCurrSourceWordsRange(),
m_prevHypo->GetWordsBitmap().GetFirstGapPos()
);
m_scoreBreakdown.PlusEquals(dsp, distortionScore);
return 0;
}
void Hypothesis::ResetScore()
@ -380,30 +267,35 @@ void Hypothesis::ResetScore()
*/
void Hypothesis::CalcScore(const SquareMatrix &futureScore)
{
// some stateless score producers cache their values in the translation
// option: add these here
m_scoreBreakdown.PlusEquals(m_transOpt->GetScoreBreakdown());
const StaticData &staticData = StaticData::Instance();
clock_t t=0; // used to track time
// LANGUAGE MODEL COST
CalcLMScore(staticData.GetAllLM());
// compute values of stateless feature functions that were not
// cached in the translation option-- there is no principled distinction
const vector<const StatelessFeatureFunction*>& sfs =
staticData.GetScoreIndexManager().GetStatelessFeatureFunctions();
for (unsigned i = 0; i < sfs.size(); ++i) {
sfs[i]->Evaluate(m_targetPhrase, &m_scoreBreakdown);
}
const vector<const StatefulFeatureFunction*>& ffs =
staticData.GetScoreIndexManager().GetStatefulFeatureFunctions();
for (unsigned i = 0; i < ffs.size(); ++i) {
m_ffStates[i] = ffs[i]->Evaluate(
*this,
m_prevHypo ? m_prevHypo->m_ffStates[i] : NULL,
&m_scoreBreakdown);
}
IFVERBOSE(2) { t = clock(); } // track time excluding LM
// DISTORTION COST
CalcDistortionScore();
// WORD PENALTY
m_scoreBreakdown.PlusEquals(staticData.GetWordPenaltyProducer(), - (float) m_currTargetWordsRange.GetNumWordsCovered());
// FUTURE COST
m_futureScore = futureScore.CalcFutureScore( m_sourceCompleted );
//LEXICAL REORDERING COST
const std::vector<LexicalReordering*> &reorderModels = staticData.GetReorderModels();
for(unsigned int i = 0; i < reorderModels.size(); i++)
{
m_scoreBreakdown.PlusEquals(reorderModels[i], reorderModels[i]->CalcScore(this));
}
// TOTAL
m_totalScore = m_scoreBreakdown.InnerProduct(staticData.GetAllWeights()) + m_futureScore;
@ -421,8 +313,8 @@ float Hypothesis::CalcExpectedScore( const SquareMatrix &futureScore ) {
clock_t t=0;
IFVERBOSE(2) { t = clock(); } // track time excluding LM
// DISTORTION COST
CalcDistortionScore();
assert(!"Need to add code to get the distortion scores");
//CalcDistortionScore();
// LANGUAGE MODEL ESTIMATE (includes word penalty cost)
float estimatedLMScore = m_transOpt->GetFutureScore() - m_transOpt->GetScoreBreakdown().InnerProduct(staticData.GetAllWeights());
@ -450,7 +342,8 @@ void Hypothesis::CalcRemainingScore()
clock_t t=0; // used to track time
// LANGUAGE MODEL COST
CalcLMScore(staticData.GetAllLM());
assert(!"Need to add code to get the LM score(s)");
//CalcLMScore(staticData.GetAllLM());
IFVERBOSE(2) { t = clock(); } // track time excluding LM

View File

@ -46,6 +46,7 @@ class StaticData;
class TranslationOption;
class WordsRange;
class Hypothesis;
class FFState;
typedef std::vector<Hypothesis*> ArcList;
@ -78,18 +79,14 @@ protected:
float m_totalScore; /*! score so far */
float m_futureScore; /*! estimated future cost to translate rest of sentence */
ScoreComponentCollection m_scoreBreakdown; /*! detailed score break-down by components (for instance language model, word penalty, etc) */
std::vector<LanguageModelSingleFactor::State> m_languageModelStates; /*! relevant history for language model scoring -- used for recombination */
std::vector<const FFState*> m_ffStates;
const Hypothesis *m_winningHypo;
ArcList *m_arcList; /*! all arcs that end at the same trellis point as this hypothesis */
AlignmentPair m_alignPair;
const TranslationOption *m_transOpt;
int m_id; /*! numeric ID of this hypothesis, used for logging */
std::vector<std::vector<unsigned int> >* m_lmstats; /*! Statistics: (see IsComputeLMBackoffStats() in StaticData.h */
static unsigned int s_HypothesesCreated; // Statistics: how many hypotheses were created in total
void CalcLMScore(const LMList &languageModels);
void CalcDistortionScore();
//TODO: add appropriate arguments to score calculator
/*! used by initial seeding of the translation process */
Hypothesis(InputType const& source, const TargetPhrase &emptyTarget);
@ -138,7 +135,7 @@ public:
}
/** output length of the translation option used to create this hypothesis */
size_t GetCurrTargetLength() const
inline size_t GetCurrTargetLength() const
{
return m_currTargetWordsRange.GetNumWordsCovered();
}
@ -209,15 +206,11 @@ public:
return m_sourceCompleted;
}
int NGramCompare(const Hypothesis &compare) const;
// inline size_t hash() const
// {
// if (_hash_computed) return _hash;
// GenerateNGramCompareHash();
// return _hash;
// }
inline bool IsSourceCompleted() const {
return m_sourceCompleted.IsComplete();
}
int RecombineCompare(const Hypothesis &compare) const;
void ToStream(std::ostream& out) const
{
@ -298,10 +291,7 @@ public:
//! target span that trans opt would populate if applied to this hypo. Used for alignment check
size_t GetNextStartPos(const TranslationOption &transOpt) const;
std::vector<std::vector<unsigned int> > *GetLMStats() const
{
return m_lmstats;
}
std::vector<std::vector<unsigned int> > *GetLMStats() const { return NULL; }
static unsigned int GetHypothesesCreated()
{
@ -348,22 +338,7 @@ class HypothesisRecombinationOrderer
public:
bool operator()(const Hypothesis* hypoA, const Hypothesis* hypoB) const
{
// Are the last (n-1) words the same on the target side (n for n-gram LM)?
int ret = hypoA->NGramCompare(*hypoB);
// int ret = hypoA->FastNGramCompare(*hypoB, m_NGramMaxOrder - 1);
if (ret != 0)
{
return (ret < 0);
}
//TODO: is this check redundant? NGramCompare already calls wordbitmap.comare.
// same last n-grams. compare source words translated
const WordsBitmap &bitmapA = hypoA->GetWordsBitmap()
, &bitmapB = hypoB->GetWordsBitmap();
ret = bitmapA.Compare(bitmapB);
return (ret < 0);
return hypoA->RecombineCompare(*hypoB) < 0;
}
};

View File

@ -24,6 +24,7 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
#include <iostream>
#include <sstream>
#include "FFState.h"
#include "LanguageModel.h"
#include "TypeDef.h"
#include "Util.h"
@ -95,6 +96,88 @@ LanguageModel::State LanguageModel::GetState(const std::vector<const Word*> &con
return state;
}
// FFState wrapper around the language model's opaque context pointer;
// two hypotheses are LM-equivalent when the pointers compare equal.
struct LMState : public FFState {
  const void* lmstate;
  LMState(const void* lms) : lmstate(lms) {}
  // Arbitrary-but-consistent ordering on the opaque pointers; only
  // equality (return value 0) matters for recombination.
  virtual int Compare(const FFState& o) const {
    const LMState& rhs = static_cast<const LMState&>(o);
    if (lmstate < rhs.lmstate) return 1;
    if (lmstate > rhs.lmstate) return -1;
    return 0;
  }
};
// The seed hypothesis starts with an empty (NULL) LM context.
const FFState* LanguageModel::EmptyHypothesisState() const {
  LMState* initial = new LMState(NULL);
  return initial;
}
// Stateful LM evaluation: scores the n-grams that straddle the boundary
// between previously translated words and the words this hypothesis
// appends, then returns the LM context used for recombination and for
// scoring the next extension.
FFState* LanguageModel::Evaluate(
const Hypothesis& hypo,
const FFState* ps,
ScoreComponentCollection* out) const {
clock_t t=0;
IFVERBOSE(2) { t = clock(); } // track time
// Start from the previous hypothesis' LM context (NULL-safe so the
// seed hypothesis can pass NULL).
const void* prevlm = ps ? (static_cast<const LMState *>(ps)->lmstate) : NULL;
LMState* res = new LMState(prevlm);
// Empty target side (e.g. a deletion option): nothing to score and the
// LM context is unchanged.
if (hypo.GetCurrTargetLength() == 0)
return res;
const size_t currEndPos = hypo.GetCurrTargetWordsRange().GetEndPos();
const size_t startPos = hypo.GetCurrTargetWordsRange().GetStartPos();
// 1st n-gram: its history reaches back into earlier words; positions
// before the sentence start are padded with the sentence-start marker.
vector<const Word*> contextFactor(m_nGramOrder);
size_t index = 0;
for (int currPos = (int) startPos - (int) m_nGramOrder + 1 ; currPos <= (int) startPos ; currPos++)
{
if (currPos >= 0)
contextFactor[index++] = &hypo.GetWord(currPos);
else
contextFactor[index++] = &GetSentenceStartArray();
}
float lmScore = GetValue(contextFactor);
//cout<<"context factor: "<<GetValue(contextFactor)<<endl;
// main loop: the remaining n-grams whose history still overlaps the
// previously translated words (positions up to startPos + order - 2).
// N-grams wholly inside the phrase are presumably pre-scored with the
// translation option — TODO confirm against TargetPhrase scoring.
size_t endPos = std::min(startPos + m_nGramOrder - 2
, currEndPos);
for (size_t currPos = startPos + 1 ; currPos <= endPos ; currPos++)
{
// shift all args down 1 place
for (size_t i = 0 ; i < m_nGramOrder - 1 ; i++)
contextFactor[i] = contextFactor[i + 1];
// add last factor
contextFactor.back() = &hypo.GetWord(currPos);
lmScore += GetValue(contextFactor);
}
// end of sentence: additionally score the sentence-end marker against
// the final words, capturing the resulting LM context in res.
if (hypo.IsSourceCompleted())
{
const size_t size = hypo.GetSize();
contextFactor.back() = &GetSentenceEndArray();
for (size_t i = 0 ; i < m_nGramOrder - 1 ; i ++)
{
int currPos = (int)(size - m_nGramOrder + i + 1);
if (currPos < 0)
contextFactor[i] = &GetSentenceStartArray();
else
contextFactor[i] = &hypo.GetWord((size_t)currPos);
}
lmScore += GetValue(contextFactor, &res->lmstate);
} else {
// Sentence not complete: slide the context window to the last words
// of this phrase (no additional scoring) and record the LM context.
for (size_t currPos = endPos+1; currPos <= currEndPos; currPos++) {
for (size_t i = 0 ; i < m_nGramOrder - 1 ; i++)
contextFactor[i] = contextFactor[i + 1];
contextFactor.back() = &hypo.GetWord(currPos);
}
res->lmstate = GetState(contextFactor);
}
out->PlusEquals(this, lmScore);
IFVERBOSE(2) { StaticData::Instance().GetSentenceStats().AddTimeCalcLM( clock()-t ); }
return res;
}
}

View File

@ -26,7 +26,7 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
#include "Factor.h"
#include "TypeDef.h"
#include "Util.h"
#include "ScoreProducer.h"
#include "FeatureFunction.h"
#include "Word.h"
namespace Moses
@ -37,7 +37,7 @@ class Factor;
class Phrase;
//! Abstract base class which represent a language model on a contiguous phrase
class LanguageModel : public ScoreProducer
class LanguageModel : public StatefulFeatureFunction
{
protected:
float m_weight; //! scoring weight. Shouldn't this now be superceded by ScoreProducer???
@ -125,6 +125,14 @@ public:
//! overrideable funtions for IRST LM to cleanup. Maybe something to do with on demand/cache loading/unloading
virtual void InitializeBeforeSentenceProcessing(){};
virtual void CleanUpAfterSentenceProcessing() {};
virtual const FFState* EmptyHypothesisState() const;
virtual FFState* Evaluate(
const Hypothesis& cur_hypo,
const FFState* prev_state,
ScoreComponentCollection* accumulator) const;
};
}

View File

@ -250,5 +250,20 @@ Score LexicalReordering::GetProb(const Phrase& f, const Phrase& e) const
return m_Table->GetScore(f, e, Phrase(Output));
}
// Adds the lexicalized reordering score for this extension.
FFState* LexicalReordering::Evaluate(
  const Hypothesis& hypo,
  const FFState* prev_state,
  ScoreComponentCollection* out) const {
  // CalcScore is not const-correct yet, hence the cast.
  Hypothesis* mutable_hypo = const_cast<Hypothesis*>(&hypo);
  out->PlusEquals(this, CalcScore(mutable_hypo));
  //TODO need to return proper state, calc score should not use previous
  //hypothesis, it should use the state.
  return NULL;
}
// No real state yet: Evaluate above also returns NULL, so lexical
// reordering does not currently constrain recombination via FFState.
const FFState* LexicalReordering::EmptyHypothesisState() const {
return NULL;
}
}

View File

@ -11,6 +11,7 @@
#include "Util.h"
#include "WordsRange.h"
#include "ScoreProducer.h"
#include "FeatureFunction.h"
#include "LexicalReorderingTable.h"
@ -24,7 +25,7 @@ class InputType;
using namespace std;
class LexicalReordering : public ScoreProducer {
class LexicalReordering : public StatefulFeatureFunction {
public: //types & consts
typedef int OrientationType;
enum Direction {Forward, Backward, Bidirectional, Unidirectional = Backward};
@ -43,6 +44,13 @@ class LexicalReordering : public ScoreProducer {
return m_NumScoreComponents;
};
virtual FFState* Evaluate(
const Hypothesis& cur_hypo,
const FFState* prev_state,
ScoreComponentCollection* accumulator) const;
const FFState* EmptyHypothesisState() const;
virtual std::string GetScoreProducerDescription() const {
return "Generic Lexical Reordering Model... overwrite in subclass.";
};

View File

@ -15,6 +15,8 @@ libmoses_a_SOURCES = \
Factor.cpp \
FactorCollection.cpp \
FactorTypeSet.cpp \
FeatureFunction.cpp \
FFState.cpp \
FloydWarshall.cpp \
GenerationDictionary.cpp \
hash.cpp \

View File

@ -23,6 +23,7 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
#include "PhraseDictionary.h"
#include "StaticData.h"
#include "InputType.h"
#include "TranslationOption.h"
namespace Moses
{
@ -50,5 +51,11 @@ size_t PhraseDictionary::GetNumScoreComponents() const
return m_numScoreComponent;
}
size_t PhraseDictionary::GetNumInputScores() const { return 0;}
// Phrase-table scores are known when the TranslationOption is built,
// so they are cached in its ScoreBreakdown rather than recomputed
// during search.
bool PhraseDictionary::ComputeValueInTranslationOption() const {
return true;
}
}

View File

@ -30,6 +30,7 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
#include "TargetPhrase.h"
#include "Dictionary.h"
#include "TargetPhraseCollection.h"
#include "FeatureFunction.h"
namespace Moses
{
@ -40,7 +41,7 @@ class WordsRange;
/** abstract base class for phrase table classes
*/
class PhraseDictionary : public Dictionary, public ScoreProducer
class PhraseDictionary : public Dictionary, public StatelessFeatureFunction
{
protected:
size_t m_tableLimit;
@ -59,6 +60,12 @@ class PhraseDictionary : public Dictionary, public ScoreProducer
std::string GetScoreProducerDescription() const;
size_t GetNumScoreComponents() const;
size_t GetNumInputScores() const;
virtual bool ComputeValueInTranslationOption() const;
/** set/change translation weights and recalc weighted score for each translation.
* TODO This may be redundant now we use ScoreCollection
*/

View File

@ -106,6 +106,19 @@ public:
}
}
//! Add, from another full-length collection, only the entries that
//! belong to the given ScoreProducer's slot range in the global score
//! vector; all other entries of this collection are left untouched.
void PlusEquals(const ScoreProducer* sp, const ScoreComponentCollection& scores)
{
  const size_t id = sp->GetScoreBookkeepingID();
  const size_t begin = m_sim->GetBeginIndex(id);
  const size_t end = m_sim->GetEndIndex(id);
  for (size_t i = begin; i < end; ++i)
    m_scores[i] += scores.m_scores[i];
}
//! Special version PlusEquals(ScoreProducer, vector<float>)
//! to add the score from a single ScoreProducer that produces
//! a single value

View File

@ -20,6 +20,14 @@ void ScoreIndexManager::AddScoreProducer(const ScoreProducer* sp)
assert(m_begins.size() == (sp->GetScoreBookkeepingID()));
m_producers.push_back(sp);
if (sp->IsStateless()) {
const StatelessFeatureFunction* ff = static_cast<const StatelessFeatureFunction*>(sp);
if (!ff->ComputeValueInTranslationOption())
m_stateless.push_back(ff);
} else {
m_stateful.push_back(static_cast<const StatefulFeatureFunction*>(sp));
}
m_begins.push_back(m_last);
size_t numScoreCompsProduced = sp->GetNumScoreComponents();
assert(numScoreCompsProduced > 0);
@ -50,8 +58,7 @@ void ScoreIndexManager::InitFeatureNames() {
bool add_idx = (m_producers[cur_scoreType]->GetNumInputScores() > 1);
while (nis_idx < m_producers[cur_scoreType]->GetNumInputScores()){
ostringstream os;
//distinguish these from other scores with _InputScore
os << m_producers[cur_scoreType]->GetScoreProducerDescription() << "_InputScore";
os << m_producers[cur_scoreType]->GetScoreProducerDescription();
if (add_idx)
os << '_' << (nis_idx+1);
m_featureNames.push_back(os.str());
@ -116,6 +123,7 @@ std::ostream& operator<<(std::ostream& os, const ScoreIndexManager& sim)
for (size_t i = 0; i < sim.m_featureNames.size(); ++i) {
os << sim.m_featureNames[i] << endl;
}
os << "Stateless: " << sim.m_stateless.size() << "\tStateful: " << sim.m_stateful.size() << endl;
return os;
}

View File

@ -17,6 +17,8 @@ namespace Moses
class ScoreProducer;
class ScoreComponentCollection; // debugging only
class StatefulFeatureFunction;
class StatelessFeatureFunction;
/** Keep track of scores and score producers. Each score producer is reserved contiguous set of slots
* to put their score components. All the score components are arranged in a vector with no gaps.
@ -46,12 +48,17 @@ public:
void SerializeFeatureNamesToPB(hgmert::Hypergraph* hg) const;
#endif
void InitWeightVectorFromFile(const std::string& fnam, std::vector<float>* m_allWeights) const;
const std::vector<const ScoreProducer*>& GetFeatureFunctions() const { return m_producers; }
const std::vector<const StatefulFeatureFunction*>& GetStatefulFeatureFunctions() const { return m_stateful; }
const std::vector<const StatelessFeatureFunction*>& GetStatelessFeatureFunctions() const { return m_stateless; }
private:
ScoreIndexManager(const ScoreIndexManager&); // don't implement
std::vector<size_t> m_begins;
std::vector<size_t> m_ends;
std::vector<const ScoreProducer*> m_producers; /**< all the score producers in this run */
std::vector<const StatefulFeatureFunction*> m_stateful; /**< all the score producers in this run */
std::vector<const StatelessFeatureFunction*> m_stateless; /**< all the score producers in this run */
std::vector<std::string> m_featureNames;
size_t m_last;
};

View File

@ -9,7 +9,10 @@
namespace Moses
{
class Hypothesis;
class ScoreComponentCollection;
class ScoreIndexManager;
class FFState;
/** to keep track of the various things that can produce a score,
* we use this evil implementation-inheritance to give them each
@ -47,6 +50,9 @@ public:
//! returns the number of scores gathered from the input (0 by default)
virtual size_t GetNumInputScores() const { return 0; };
virtual bool IsStateless() const = 0;
};

View File

@ -23,17 +23,21 @@ TRANSLATION_1_NBEST_4=1 3 2 ||| d: -2 lm: -10.7767 I: 0 tm: -6.90776 -6.90776 w:
TRANSLATION_1_NBEST_5=3 1 2 ||| d: -3 lm: -10.9714 I: 0 tm: -6.90776 -6.90776 w: -3
TRANSLATION_1_NBEST_6=2 1 3 ||| d: -1 lm: -13.645 I: 0 tm: -6.90776 -6.90776 w: -3
TRANSLATION_1_NBEST_7=3 12 ||| d: -3 lm: -10.0318 I: 0 tm: -4.60517 -4.60517 w: -2
TRANSLATION_1_NBEST_8=3 2 1 ||| d: -4 lm: -12.57 I: 0 tm: -6.90776 -6.90776 w: -3
TRANSLATION_2_NBEST_1=1 2 3 ||| d: 0 lm: -6.08823 I: 0 tm: -6.90776 -6.90776 w: -3
TRANSLATION_2_NBEST_2=2 3 1 ||| d: -2 lm: -9.26988 I: 0 tm: -6.90776 -6.90776 w: -3
TRANSLATION_2_NBEST_3=1 3 2 ||| d: -2 lm: -10.7767 I: 0 tm: -6.90776 -6.90776 w: -3
TRANSLATION_2_NBEST_4=3 1 2 ||| d: -3 lm: -10.9714 I: 0 tm: -6.90776 -6.90776 w: -3
TRANSLATION_2_NBEST_5=2 1 3 ||| d: -1 lm: -13.645 I: 0 tm: -6.90776 -6.90776 w: -3
TRANSLATION_2_NBEST_6=3 2 1 ||| d: -4 lm: -12.57 I: 0 tm: -6.90776 -6.90776 w: -3
TRANSLATION_3_NBEST_1=1 2 3 ||| d: 0 lm: -6.08823 I: 0 tm: -6.90776 -6.90776 w: -3
TRANSLATION_3_NBEST_2=1 3 2 ||| d: -1 lm: -10.7767 I: 0 tm: -6.90776 -6.90776 w: -3
TRANSLATION_3_NBEST_3=1 12 ||| d: 0 lm: -9.28642 I: 0 tm: -4.60517 -4.60517 w: -2
TRANSLATION_3_NBEST_4=2 3 1 ||| d: -3 lm: -9.26988 I: 0 tm: -6.90776 -6.90776 w: -3
TRANSLATION_3_NBEST_5=3 1 2 ||| d: -3 lm: -10.9714 I: 0 tm: -6.90776 -6.90776 w: -3
TRANSLATION_3_NBEST_6=12 1 ||| d: -3 lm: -8.59503 I: 0 tm: -4.60517 -4.60517 w: -2
TRANSLATION_3_NBEST_7=3 2 1 ||| d: -4 lm: -12.57 I: 0 tm: -6.90776 -6.90776 w: -3
TRANSLATION_3_NBEST_8=2 1 3 ||| d: -3 lm: -13.645 I: 0 tm: -6.90776 -6.90776 w: -3
TRANSLATION_4_NBEST_1=1234 3 4 5 ||| d: -3 lm: -8.02249 I: 0 tm: -9.21034 -9.21034 w: -4
TRANSLATION_4_NBEST_2=4 1234 5 ||| d: -2 lm: -7.07413 I: 0 tm: -6.90776 -6.90776 w: -3
TRANSLATION_4_NBEST_3=5 4 1234 ||| d: -1 lm: -8.17776 I: 0 tm: -6.90776 -6.90776 w: -3