add max-unknowns arg to ConstrainedDecoding FF

This commit is contained in:
Hieu Hoang 2013-09-18 15:47:49 +02:00
parent 5625d30a26
commit 24f1760c05
5 changed files with 17 additions and 5 deletions

View File

@ -5,6 +5,7 @@
#include "moses/ChartManager.h"
#include "moses/StaticData.h"
#include "moses/InputFileStream.h"
#include "moses/Util.h"
#include "util/exception.hh"
using namespace std;
@ -95,7 +96,7 @@ FFState* ConstrainedDecoding::Evaluate(
ConstrainedDecodingState *ret = new ConstrainedDecodingState(hypo);
const Phrase &outputPhrase = ret->GetPhrase();
size_t searchPos = ref->Find(outputPhrase);
size_t searchPos = ref->Find(outputPhrase, m_maxUnknowns);
float score;
if (hypo.IsSourceCompleted()) {
@ -125,8 +126,7 @@ FFState* ConstrainedDecoding::EvaluateChart(
ConstrainedDecodingState *ret = new ConstrainedDecodingState(hypo);
const Phrase &outputPhrase = ret->GetPhrase();
size_t searchPos = ref->Find(outputPhrase);
size_t searchPos = ref->Find(outputPhrase, m_maxUnknowns);
float score;
if (hypo.GetCurrSourceRange().GetStartPos() == 0 &&
@ -149,6 +149,9 @@ void ConstrainedDecoding::SetParameter(const std::string& key, const std::string
if (key == "path") {
m_path = value;
}
else if (key == "max-unknowns") {
m_maxUnknowns = Scan<int>(value);
}
else {
StatefulFeatureFunction::SetParameter(key, value);
}

View File

@ -34,6 +34,7 @@ class ConstrainedDecoding : public StatefulFeatureFunction
public:
ConstrainedDecoding(const std::string &line)
:StatefulFeatureFunction("ConstrainedDecoding", 1, line)
,m_maxUnknowns(0)
{
m_tuneable = false;
ReadParameters();
@ -76,6 +77,7 @@ public:
protected:
std::string m_path;
std::map<long,Phrase> m_constraints;
int m_maxUnknowns;
};

View File

@ -373,11 +373,12 @@ void Phrase::InitStartEndWord()
AddWord(endWord);
}
size_t Phrase::Find(const Phrase &sought) const
size_t Phrase::Find(const Phrase &sought, int maxUnknown) const
{
size_t maxStartPos = GetSize() - sought.GetSize();
for (size_t startThisPos = 0; startThisPos <= maxStartPos; ++startThisPos) {
size_t thisPos = startThisPos;
int currUnknowns = 0;
size_t soughtPos;
for (soughtPos = 0; soughtPos < sought.GetSize(); ++soughtPos) {
const Word &soughtWord = sought.GetWord(soughtPos);
@ -386,6 +387,11 @@ size_t Phrase::Find(const Phrase &sought) const
if (soughtWord == thisWord) {
++thisPos;
}
else if (soughtWord.IsOOV() && (maxUnknown < 0 || currUnknowns < maxUnknown)) {
// the output has an OOV word. Allow a certain number of OOVs
++currUnknowns;
++thisPos;
}
else {
break;
}

View File

@ -128,7 +128,7 @@ public:
bool Contains(const std::vector< std::vector<std::string> > &subPhraseVector
, const std::vector<FactorType> &inputFactor) const;
size_t Find(const Phrase &sought) const;
size_t Find(const Phrase &sought, int maxUnknown) const;
//! create an empty word at the end of the phrase
Word &AddWord();

View File

@ -126,6 +126,7 @@ void Word::CreateUnknownWord(const Word &sourceWord)
SetFactor(factorType, factorCollection.AddFactor(Output, factorType, sourceFactor->GetString()));
}
m_isNonTerminal = sourceWord.IsNonTerminal();
m_isOOV = true;
}
void Word::OnlyTheseFactors(const FactorMask &factors)