#include "ConstrainedDecoding.h"

#include <limits>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include "moses/Hypothesis.h"
#include "moses/Manager.h"
#include "moses/ChartHypothesis.h"
#include "moses/ChartManager.h"
#include "moses/StaticData.h"
#include "moses/InputFileStream.h"
#include "moses/Util.h"
#include "util/exception.hh"

using namespace std;

namespace Moses
{

/** Build the state from a phrase-based hypothesis by collecting the output
 *  phrase the hypothesis has produced so far.
 */
ConstrainedDecodingState::ConstrainedDecodingState(const Hypothesis &hypo)
{
  hypo.GetOutputPhrase(m_outputPhrase);
}

/** Build the state from a chart hypothesis by collecting the output phrase
 *  the hypothesis has produced so far.
 */
ConstrainedDecodingState::ConstrainedDecodingState(const ChartHypothesis &hypo)
{
  hypo.GetOutputPhrase(m_outputPhrase);
}

/** Order two states by comparing their output phrases.
 *
 *  @param other  state to compare against; it is downcast without a runtime
 *                check, so it must really be a ConstrainedDecodingState.
 *  @return whatever Phrase::Compare returns for the two output phrases.
 */
int ConstrainedDecodingState::Compare(const FFState& other) const
{
  const ConstrainedDecodingState &otherFF =
    static_cast<const ConstrainedDecodingState&>(other);

  // Keep the full int result. Storing it in a bool would turn every non-zero
  // value (including a negative "less than") into 1, so a.Compare(b) and
  // b.Compare(a) could both claim "greater" and the ordering would be broken.
  const int ret = m_outputPhrase.Compare(otherFF.m_outputPhrase);
  return ret;
}

//////////////////////////////////////////////////////////////////

/** Read the constraint file at m_path into m_constraints.
 *
 *  Each line is split on tab characters:
 *   - one field:  the field is the constraint for the next sentence id
 *                 (the running id is incremented first);
 *   - two fields: the first field is the sentence id, the second is the
 *                 constraint;
 *   - any other field count trips CHECK(false).
 *
 *  When the search algorithm is ChartDecoding or ChartIncremental,
 *  InitStartEndWord() is called on every constraint phrase before it is stored.
 */
void ConstrainedDecoding::Load()
{
  const StaticData &staticData = StaticData::Instance();
  const bool addBeginEndWord =
    (staticData.GetSearchAlgorithm() == ChartDecoding)
    || (staticData.GetSearchAlgorithm() == ChartIncremental);

  InputFileStream constraintFile(m_path);
  std::string line;

  // Starts one below the first translation id so that the first id-less line
  // is assigned exactly GetStartTranslationId().
  long sentenceID = staticData.GetStartTranslationId() - 1;

  while (getline(constraintFile, line)) {
    vector<string> vecStr = Tokenize(line, "\t");

    Phrase phrase(0);
    if (vecStr.size() == 1) {
      sentenceID++;
      phrase.CreateFromString(Output, staticData.GetOutputFactorOrder(),
                              vecStr[0], staticData.GetFactorDelimiter(), NULL);
    } else if (vecStr.size() == 2) {
      sentenceID = Scan<long>(vecStr[0]);
      phrase.CreateFromString(Output, staticData.GetOutputFactorOrder(),
                              vecStr[1], staticData.GetFactorDelimiter(), NULL);
    } else {
      CHECK(false);
    }

    if (addBeginEndWord) {
      phrase.InitStartEndWord();
    }

    // std::map::insert keeps the first entry if the id is already present.
    m_constraints.insert(make_pair(sentenceID, phrase));
  }
}

/** Default weight vector for this feature: a single weight of 1.
 *
 *  NOTE(review): the element type is declared in the header, which is not
 *  visible here; float is assumed - confirm against ConstrainedDecoding.h.
 */
std::vector<float> ConstrainedDecoding::DefaultWeights() const
{
  CHECK(m_numScoreComponents == 1);
  vector<float> ret(1, 1);
  return ret;
}

/** Look up the constraint for the sentence a hypothesis belongs to.
 *
 *  @tparam H  hypothesis type (provides GetManager()).
 *  @tparam M  manager type returned by H::GetManager() (provides GetSource()).
 *             It cannot be deduced, so callers name both arguments.
 *  @return pointer to the stored constraint phrase, or NULL when the input's
 *          translation id has no entry in the map.
 */
template <class H, class M>
const Phrase *GetConstraint(const std::map<long, Phrase> &constraints,
                            const H &hypo)
{
  const M &mgr = hypo.GetManager();
  const InputType &input = mgr.GetSource();
  const long id = input.GetTranslationId();

  std::map<long, Phrase>::const_iterator iter = constraints.find(id);
  if (iter == constraints.end()) {
    return NULL;
  }
  return &iter->second;
}

/** Score an output phrase against its constraint (shared by both decoders).
 *
 *  @param ref             the constraint phrase.
 *  @param outputPhrase    output produced so far by the hypothesis.
 *  @param searchPos       result of ref.Find(outputPhrase, ...).
 *  @param coversSentence  true when the hypothesis spans the whole input.
 *  @return 0 when the output is acceptable, negative infinity otherwise.
 *          A complete hypothesis must be found at position 0 and have the
 *          same size as the constraint; a partial one only needs to be found
 *          somewhere (searchPos != NOT_FOUND).
 */
static float ConstraintScore(const Phrase &ref,
                             const Phrase &outputPhrase,
                             size_t searchPos,
                             bool coversSentence)
{
  bool matches;
  if (coversSentence) {
    matches = (searchPos == 0) && (ref.GetSize() == outputPhrase.GetSize());
  } else {
    matches = (searchPos != NOT_FOUND);
  }
  return matches ? 0 : -std::numeric_limits<float>::infinity();
}

/** Phrase-based decoding: score a hypothesis against its sentence constraint.
 *
 *  A constraint must exist for the sentence (CHECK). The score is added to
 *  the accumulator for this feature.
 *
 *  @return a newly allocated state holding the hypothesis' output phrase.
 */
FFState* ConstrainedDecoding::Evaluate(
  const Hypothesis& hypo,
  const FFState* prev_state,
  ScoreComponentCollection* accumulator) const
{
  const Phrase *ref = GetConstraint<Hypothesis, Manager>(m_constraints, hypo);
  CHECK(ref);

  ConstrainedDecodingState *ret = new ConstrainedDecodingState(hypo);
  const Phrase &outputPhrase = ret->GetPhrase();

  const size_t searchPos = ref->Find(outputPhrase, m_maxUnknowns);

  const float score = ConstraintScore(*ref, outputPhrase, searchPos,
                                      hypo.IsSourceCompleted());

  accumulator->PlusEquals(this, score);

  return ret;
}

/** Chart decoding: score a hypothesis against its sentence constraint.
 *
 *  The hypothesis counts as complete when its source range starts at 0 and
 *  ends at the last position of the source sentence. A constraint must exist
 *  for the sentence (CHECK). The score is added to the accumulator for this
 *  feature.
 *
 *  @return a newly allocated state holding the hypothesis' output phrase.
 */
FFState* ConstrainedDecoding::EvaluateChart(
  const ChartHypothesis &hypo,
  int /* featureID - used to index the state in the previous hypotheses */,
  ScoreComponentCollection* accumulator) const
{
  const Phrase *ref =
    GetConstraint<ChartHypothesis, ChartManager>(m_constraints, hypo);
  CHECK(ref);

  const ChartManager &mgr = hypo.GetManager();
  // Unchecked downcast: the chart manager's input is treated as a Sentence.
  const Sentence &source = static_cast<const Sentence&>(mgr.GetSource());

  ConstrainedDecodingState *ret = new ConstrainedDecodingState(hypo);
  const Phrase &outputPhrase = ret->GetPhrase();

  const size_t searchPos = ref->Find(outputPhrase, m_maxUnknowns);

  const bool coversSentence =
    hypo.GetCurrSourceRange().GetStartPos() == 0
    && hypo.GetCurrSourceRange().GetEndPos() == source.GetSize() - 1;

  const float score = ConstraintScore(*ref, outputPhrase, searchPos,
                                      coversSentence);

  accumulator->PlusEquals(this, score);

  return ret;
}

/** Handle feature parameters.
 *
 *  "path" sets the constraint file path and "max-unknowns" sets
 *  m_maxUnknowns; every other key is forwarded to
 *  StatefulFeatureFunction::SetParameter.
 */
void ConstrainedDecoding::SetParameter(const std::string& key, const std::string& value)
{
  if (key == "path") {
    m_path = value;
  } else if (key == "max-unknowns") {
    // NOTE(review): m_maxUnknowns is declared in the header, which is not
    // visible here; int is assumed - confirm against ConstrainedDecoding.h.
    m_maxUnknowns = Scan<int>(value);
  } else {
    StatefulFeatureFunction::SetParameter(key, value);
  }
}

}