evaluate VW with target context

This commit is contained in:
Ales Tamchyna 2016-03-07 13:17:38 +01:00
parent 48d9d8e0d4
commit fa8df45656

View File

@ -5,6 +5,7 @@
#include <limits>
#include <boost/unordered_map.hpp>
#include <boost/functional/hash.hpp>
#include "moses/FF/StatefulFeatureFunction.h"
#include "moses/PP/CountsPhraseProperty.h"
@ -108,6 +109,7 @@ typedef ThreadLocalByFeatureStorage<VWTargetSentence> TLSTargetSentence;
typedef boost::unordered_map<size_t, float> FloatHashMap;
typedef ThreadLocalByFeatureStorage<FloatHashMap> TLSFloatHashMap;
typedef ThreadLocalByFeatureStorage<boost::unordered_map<size_t, FloatHashMap> > TLSStateExtensions;
class VW : public StatefulFeatureFunction, public TLSTargetSentence
{
@ -124,6 +126,7 @@ public:
m_tlsClassifier = new TLSClassifier(this, *classifierFactory);
m_tlsFutureScores = new TLSFloatHashMap(this);
m_tlsComputedStateExtensions = new TLSStateExtensions(this);
if (! m_normalizer) {
VERBOSE(1, "VW :: No loss function specified, assuming logistic loss.\n");
@ -160,8 +163,8 @@ public:
}
virtual FFState* EvaluateWhenApplied(
const Hypothesis& cur_hypo,
const FFState* prev_state,
const Hypothesis& curHypo,
const FFState* prevState,
ScoreComponentCollection* accumulator) const
{
VERBOSE(2, "VW :: Evaluating translation options\n");
@ -183,10 +186,29 @@ public:
return new DummyState();
}
size_t spanStart = cur_hypo.GetTranslationOption().GetStartPos();
size_t spanEnd = cur_hypo.GetTranslationOption().GetEndPos();
const TranslationOptionList *topts =
cur_hypo.GetManager().getSntTranslationOptions()->GetTranslationOptionList(spanStart, spanEnd);
size_t spanStart = curHypo.GetTranslationOption().GetStartPos();
size_t spanEnd = curHypo.GetTranslationOption().GetEndPos();
// compute our current key
size_t cacheKey = 0;
boost::hash_combine(cacheKey, prevState);
boost::hash_combine(cacheKey, spanStart);
boost::hash_combine(cacheKey, spanEnd);
boost::unordered_map<size_t, FloatHashMap> &computedStateExtensions
= *m_tlsComputedStateExtensions->GetStored();
if (computedStateExtensions.find(cacheKey) == computedStateExtensions.end()) {
// we have not computed this set of translation options yet
const TranslationOptionList *topts =
curHypo.GetManager().getSntTranslationOptions()->GetTranslationOptionList(spanStart, spanEnd);
}
// now our cache is guaranteed to contain the required score, simply look it up
std::vector<float> newScores(m_numScoreComponents);
size_t toptHash = hash_value(curHypo.GetTranslationOption());
newScores[0] = computedStateExtensions[cacheKey][toptHash];
accumulator->PlusEquals(this, newScores);
/*
@ -451,8 +473,11 @@ public:
targetSent.m_sentence = target;
targetSent.m_alignment = alignment;
FloatHashMap &futureScores = *m_tlsFutureScores->GetStored();
futureScores.clear(); // do not keep future cost estimates across sentences!
// do not keep future cost estimates across sentences!
m_tlsFutureScores->GetStored()->clear();
// invalidate our caches after each sentence
m_tlsComputedStateExtensions->GetStored()->clear();
// pre-compute max- and min- aligned points for faster translation option checking
targetSent.SetConstraints(source.GetSize());
@ -609,6 +634,7 @@ private:
TLSClassifier *m_tlsClassifier;
TLSFloatHashMap *m_tlsFutureScores;
TLSStateExtensions *m_tlsComputedStateExtensions;
};
}