mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-12-26 05:14:36 +03:00
Add max-unknowns argument to the ConstrainedDecoding feature function (FF)
This commit is contained in:
parent
5625d30a26
commit
24f1760c05
@ -5,6 +5,7 @@
|
||||
#include "moses/ChartManager.h"
|
||||
#include "moses/StaticData.h"
|
||||
#include "moses/InputFileStream.h"
|
||||
#include "moses/Util.h"
|
||||
#include "util/exception.hh"
|
||||
|
||||
using namespace std;
|
||||
@ -95,7 +96,7 @@ FFState* ConstrainedDecoding::Evaluate(
|
||||
ConstrainedDecodingState *ret = new ConstrainedDecodingState(hypo);
|
||||
const Phrase &outputPhrase = ret->GetPhrase();
|
||||
|
||||
size_t searchPos = ref->Find(outputPhrase);
|
||||
size_t searchPos = ref->Find(outputPhrase, m_maxUnknowns);
|
||||
|
||||
float score;
|
||||
if (hypo.IsSourceCompleted()) {
|
||||
@ -125,8 +126,7 @@ FFState* ConstrainedDecoding::EvaluateChart(
|
||||
|
||||
ConstrainedDecodingState *ret = new ConstrainedDecodingState(hypo);
|
||||
const Phrase &outputPhrase = ret->GetPhrase();
|
||||
|
||||
size_t searchPos = ref->Find(outputPhrase);
|
||||
size_t searchPos = ref->Find(outputPhrase, m_maxUnknowns);
|
||||
|
||||
float score;
|
||||
if (hypo.GetCurrSourceRange().GetStartPos() == 0 &&
|
||||
@ -149,6 +149,9 @@ void ConstrainedDecoding::SetParameter(const std::string& key, const std::string
|
||||
if (key == "path") {
|
||||
m_path = value;
|
||||
}
|
||||
else if (key == "max-unknowns") {
|
||||
m_maxUnknowns = Scan<int>(value);
|
||||
}
|
||||
else {
|
||||
StatefulFeatureFunction::SetParameter(key, value);
|
||||
}
|
||||
|
@ -34,6 +34,7 @@ class ConstrainedDecoding : public StatefulFeatureFunction
|
||||
public:
|
||||
ConstrainedDecoding(const std::string &line)
|
||||
:StatefulFeatureFunction("ConstrainedDecoding", 1, line)
|
||||
,m_maxUnknowns(0)
|
||||
{
|
||||
m_tuneable = false;
|
||||
ReadParameters();
|
||||
@ -76,6 +77,7 @@ public:
|
||||
protected:
|
||||
std::string m_path;
|
||||
std::map<long,Phrase> m_constraints;
|
||||
int m_maxUnknowns;
|
||||
|
||||
};
|
||||
|
||||
|
@ -373,11 +373,12 @@ void Phrase::InitStartEndWord()
|
||||
AddWord(endWord);
|
||||
}
|
||||
|
||||
size_t Phrase::Find(const Phrase &sought) const
|
||||
size_t Phrase::Find(const Phrase &sought, int maxUnknown) const
|
||||
{
|
||||
size_t maxStartPos = GetSize() - sought.GetSize();
|
||||
for (size_t startThisPos = 0; startThisPos <= maxStartPos; ++startThisPos) {
|
||||
size_t thisPos = startThisPos;
|
||||
int currUnknowns = 0;
|
||||
size_t soughtPos;
|
||||
for (soughtPos = 0; soughtPos < sought.GetSize(); ++soughtPos) {
|
||||
const Word &soughtWord = sought.GetWord(soughtPos);
|
||||
@ -386,6 +387,11 @@ size_t Phrase::Find(const Phrase &sought) const
|
||||
if (soughtWord == thisWord) {
|
||||
++thisPos;
|
||||
}
|
||||
else if (soughtWord.IsOOV() && (maxUnknown < 0 || currUnknowns < maxUnknown)) {
|
||||
// the output has an OOV word. Allow a certain number of OOVs
|
||||
++currUnknowns;
|
||||
++thisPos;
|
||||
}
|
||||
else {
|
||||
break;
|
||||
}
|
||||
|
@ -128,7 +128,7 @@ public:
|
||||
bool Contains(const std::vector< std::vector<std::string> > &subPhraseVector
|
||||
, const std::vector<FactorType> &inputFactor) const;
|
||||
|
||||
size_t Find(const Phrase &sought) const;
|
||||
size_t Find(const Phrase &sought, int maxUnknown) const;
|
||||
|
||||
//! create an empty word at the end of the phrase
|
||||
Word &AddWord();
|
||||
|
@ -126,6 +126,7 @@ void Word::CreateUnknownWord(const Word &sourceWord)
|
||||
SetFactor(factorType, factorCollection.AddFactor(Output, factorType, sourceFactor->GetString()));
|
||||
}
|
||||
m_isNonTerminal = sourceWord.IsNonTerminal();
|
||||
m_isOOV = true;
|
||||
}
|
||||
|
||||
void Word::OnlyTheseFactors(const FactorMask &factors)
|
||||
|
Loading…
Reference in New Issue
Block a user