refactoring VW state class

This commit is contained in:
Ales Tamchyna 2016-03-23 16:45:53 +01:00
parent 01301ac816
commit c7a1d21abd
4 changed files with 98 additions and 31 deletions

View File

@ -77,6 +77,8 @@ FFState* VW::EvaluateWhenApplied(
{
VERBOSE(2, "VW :: Evaluating translation options\n");
const VWState& prevVWState = *static_cast<const VWState *>(prevState);
const std::vector<VWFeatureBase*>& contextFeatures =
VWFeatureBase::GetTargetContextFeatures(GetScoreProducerDescription());
@ -106,12 +108,11 @@ FFState* VW::EvaluateWhenApplied(
Discriminative::Classifier &classifier = *m_tlsClassifier->GetStored();
// extract target context features
const Phrase &targetContext = static_cast<const VWState *>(prevState)->m_phrase;
size_t contextHash = hash_value(targetContext);
const Phrase &targetContext = prevVWState.GetPhrase();
FeatureVectorMap &contextFeaturesCache = *m_tlsTargetContextFeatures->GetStored();
FeatureVectorMap::const_iterator contextIt = contextFeaturesCache.find(contextHash);
FeatureVectorMap::const_iterator contextIt = contextFeaturesCache.find(cacheKey);
if (contextIt == contextFeaturesCache.end()) {
// we have not extracted features for this context yet
@ -120,7 +121,7 @@ FFState* VW::EvaluateWhenApplied(
for(size_t i = 0; i < contextFeatures.size(); ++i)
(*contextFeatures[i])(input, targetContext, alignInfo, classifier, contextVector);
contextFeaturesCache[contextHash] = contextVector;
contextFeaturesCache[cacheKey] = contextVector;
VERBOSE(3, "VW :: context cache miss\n");
} else {
// context already in cache, simply put feature IDs in the classifier object
@ -172,16 +173,16 @@ FFState* VW::EvaluateWhenApplied(
VERBOSE(3, "VW :: adding score: " << newScores[0] << "\n");
accumulator->PlusEquals(this, newScores);
return VWState::UpdateState(prevState, curHypo);
return new VWState(prevVWState, curHypo);
}
// Builds the feature-function state attached to the empty hypothesis: a
// target-context phrase consisting of GetMaximumContextSize() copies of
// m_sentenceStartWord. The `input` argument is not read here.
//
// NOTE(review): this listing is a rendered diff without +/- markers, so the
// pre-change statements (allocate `initial`, fill initial->m_phrase, return
// initial) appear next to their replacements (fill initialPhrase, return a
// VWState constructed from it with span start/end 0). Presumably only the
// latter remain in the committed file -- confirm against the repository.
const FFState* VW::EmptyHypothesisState(const InputType &input) const {
// number of padding words needed so every context feature has enough history
size_t maxContextSize = VWFeatureBase::GetMaximumContextSize(GetScoreProducerDescription());
VWState *initial = new VWState();
Phrase initialPhrase;
for (size_t i = 0; i < maxContextSize; i++)
initial->m_phrase.AddWord(m_sentenceStartWord);
initialPhrase.AddWord(m_sentenceStartWord);
return initial;
// caller receives a heap-allocated state
return new VWState(initialPhrase, 0, 0);
}
void VW::EvaluateTranslationOptionListWithSourceContext(const InputType &input
@ -270,13 +271,13 @@ void VW::EvaluateTranslationOptionListWithSourceContext(const InputType &input
targetContext.AddWord(m_sentenceStartWord);
const Phrase *targetSent = GetStored()->m_sentence;
const AlignmentInfo *alignInfo = GetStored()->m_alignment;
if (currentStart > 0)
targetContext.Append(targetSent->GetSubString(Range(0, currentStart - 1)));
// extract target-context features
AlignmentInfo alignInfo("");
for(size_t i = 0; i < contextFeatures.size(); ++i)
(*contextFeatures[i])(input, targetContext, alignInfo, classifier, dummyVector);
(*contextFeatures[i])(input, targetContext, *alignInfo, classifier, dummyVector);
// go over topts, extract target side features and train the classifier
for (size_t toptIdx = 0; toptIdx < translationOptionList.size(); toptIdx++) {

View File

@ -1,9 +1,11 @@
#pragma once
#include <string>
#include <boost/foreach.hpp>
#include "VWFeatureBase.h"
#include "moses/InputType.h"
#include "moses/TypeDef.h"
#include "moses/Word.h"
namespace Moses
{
@ -56,6 +58,26 @@ protected:
return phrase.GetWord(phrase.GetSize() - posFromEnd - 1).GetString(m_targetFactors, false);
}
// some target-context feature functions also look at the source
inline std::string GetSourceWord(const InputType &input, size_t pos) const {
return input.GetWord(pos).GetString(m_sourceFactors, false);
}
// get source words aligned to a particular context word
std::vector<std::string> GetAlignedSourceWords(const Phrase &contextPhrase
, const InputType &input
, const AlignmentInfo &alignInfo
, size_t posFromEnd) const {
size_t idx = contextPhrase.GetSize() - posFromEnd - 1;
std::set<size_t> alignedToTarget = alignInfo.GetAlignmentsForTarget(idx);
std::vector<std::string> out;
out.reserve(alignedToTarget.size());
BOOST_FOREACH(size_t srcIdx, alignedToTarget) {
out.push_back(GetSourceWord(input, srcIdx));
}
return out;
}
// required context size
size_t m_contextSize;
};

View File

@ -6,25 +6,25 @@
#include "moses/Util.h"
#include "moses/TypeDef.h"
#include "moses/StaticData.h"
#include "moses/TranslationOption.h"
#include <boost/functional/hash.hpp>
namespace Moses {
size_t VWState::hash() const {
return hash_value(m_phrase);
VWState::VWState() : m_spanStart(0), m_spanEnd(0) {
ComputeHash();
}
bool VWState::operator==(const FFState& o) const {
const VWState &other = static_cast<const VWState &>(o);
return m_phrase == other.m_phrase;
VWState::VWState(const Phrase &phrase, size_t spanStart, size_t spanEnd)
: m_phrase(phrase), m_spanStart(spanStart), m_spanEnd(spanEnd) {
ComputeHash();
}
VWState *VWState::UpdateState(const FFState *prevState, const Hypothesis &curHypo) {
const VWState *prevVWState = static_cast<const VWState *>(prevState);
VERBOSE(3, "VW :: updating state\n>> previous state: " << *prevVWState << "\n");
VWState::VWState(const VWState &prevState, const Hypothesis &curHypo) {
VERBOSE(3, "VW :: updating state\n>> previous state: " << prevState << "\n");
// copy phrase from previous state
Phrase phrase = prevVWState->m_phrase;
Phrase phrase = prevState.GetPhrase();
size_t contextSize = phrase.GetSize(); // identical to VWFeatureBase::GetMaximumContextSize()
// add words from current hypothesis
@ -34,19 +34,36 @@ VWState *VWState::UpdateState(const FFState *prevState, const Hypothesis &curHyp
// get a slice of appropriate length
Range range(phrase.GetSize() - contextSize, phrase.GetSize() - 1);
phrase = phrase.GetSubString(range);
m_phrase = phrase.GetSubString(range);
// build the new state
VWState *out = new VWState();
out->m_phrase = phrase;
// set current span start/end
m_spanStart = curHypo.GetTranslationOption().GetStartPos();
m_spanEnd = curHypo.GetTranslationOption().GetEndPos();
VERBOSE(3, ">> updated state: " << *out << "\n");
// compute our hash
ComputeHash();
return out;
VERBOSE(3, ">> updated state: " << *this << "\n");
}
// Two VW states compare equal when their context phrase, span start and span
// end all match. `o` is downcast with static_cast (no runtime type check), so
// callers must only pass a VWState.
bool VWState::operator==(const FFState& o) const {
  const VWState &rhs = static_cast<const VWState &>(o);
  if (!(m_phrase == rhs.GetPhrase())) {
    return false;
  }
  if (!(m_spanStart == rhs.GetSpanStart())) {
    return false;
  }
  return m_spanEnd == rhs.GetSpanEnd();
}
// Recomputes the cached hash (m_hash) from the fields that operator== compares:
// the context phrase, then the span start, then the span end. Must be called
// again whenever any of those members changes.
void VWState::ComputeHash() {
  size_t seed = 0;
  boost::hash_combine(seed, m_phrase);
  boost::hash_combine(seed, m_spanStart);
  boost::hash_combine(seed, m_spanEnd);
  m_hash = seed;
}
// Stream output for a VW state; the second statement below prints it as
// "<phrase>::<spanStart>-<spanEnd>". Returns `out` to allow chaining.
//
// NOTE(review): this listing is a rendered diff without +/- markers. The
// statement reading state.m_phrase directly looks like the pre-change line and
// the getter-based statement like its replacement; presumably only the latter
// remains in the committed file -- confirm against the repository.
std::ostream &operator<<(std::ostream &out, const VWState &state) {
out << state.m_phrase;
out << state.GetPhrase() << "::" << state.GetSpanStart() << "-" << state.GetSpanEnd();
return out;
}

View File

@ -11,14 +11,41 @@ namespace Moses {
/**
* VW state, used in decoding (when target context is enabled).
*/
// NOTE(review): this listing is a rendered diff without +/- markers. The
// `struct` header, the out-of-line hash() declaration and the static
// UpdateState() declaration look like pre-change lines; the `class` header,
// the constructors, the inline hash() and the getters are presumably what the
// commit keeps -- confirm against the repository.
struct VWState : public FFState {
virtual size_t hash() const;
class VWState : public FFState {
public:
// empty state, used only when VWState is ignored
VWState();
// used for construction of the initial VW state
VWState(const Phrase &phrase, size_t spanStart, size_t spanEnd);
// continue from previous VW state with a new hypothesis
VWState(const VWState &prevState, const Hypothesis &curHypo);
// equality over the context phrase and the span start/end
virtual bool operator==(const FFState& o) const;
// shift words in our state, add words from current hypothesis
static VWState *UpdateState(const FFState *prevState, const Hypothesis &curHypo);
// returns the value cached in m_hash by ComputeHash(); nothing is recomputed here
inline virtual size_t hash() const {
return m_hash;
}
// target-context phrase held by this state
inline const Phrase &GetPhrase() const {
return m_phrase;
}
// span start stored in this state
inline size_t GetSpanStart() const {
return m_spanStart;
}
// span end stored in this state
inline size_t GetSpanEnd() const {
return m_spanEnd;
}
private:
// fills m_hash from m_phrase, m_spanStart and m_spanEnd
void ComputeHash();
// target-side context words
Phrase m_phrase;
// start/end position taken from the hypothesis' translation option
// (both 0 for the default-constructed state)
size_t m_spanStart, m_spanEnd;
// cached hash value, kept in sync by ComputeHash()
size_t m_hash;
};
// how to print a VW state