Merge branch 'master' of github.com:moses-smt/mosesdecoder

This commit is contained in:
Hieu Hoang 2014-05-31 19:52:41 +01:00
commit c4d0f7dc93
3 changed files with 95 additions and 55 deletions

View File

@ -25,7 +25,7 @@ ConstrainedDecodingState::ConstrainedDecodingState(const ChartHypothesis &hypo)
/**
 * Compares this state with another one by their output phrases.
 *
 * @param other state to compare against. It is downcast with an unchecked
 *              static_cast, so it must really be a ConstrainedDecodingState.
 * @return the value produced by m_outputPhrase.Compare() for the two phrases.
 */
int ConstrainedDecodingState::Compare(const FFState& other) const
{
  const ConstrainedDecodingState &otherFF = static_cast<const ConstrainedDecodingState&>(other);
  int ret = m_outputPhrase.Compare(otherFF.m_outputPhrase);
  return ret;
}
@ -34,6 +34,7 @@ ConstrainedDecoding::ConstrainedDecoding(const std::string &line)
:StatefulFeatureFunction(1, line)
,m_maxUnknowns(0)
,m_negate(false)
,m_soft(false)
{
m_tuneable = false;
ReadParameters();
@ -44,47 +45,48 @@ void ConstrainedDecoding::Load()
const StaticData &staticData = StaticData::Instance();
bool addBeginEndWord = (staticData.GetSearchAlgorithm() == ChartDecoding) || (staticData.GetSearchAlgorithm() == ChartIncremental);
InputFileStream constraintFile(m_path);
std::string line;
long sentenceID = staticData.GetStartTranslationId() - 1;
while (getline(constraintFile, line)) {
vector<string> vecStr = Tokenize(line, "\t");
Phrase phrase(0);
if (vecStr.size() == 1) {
sentenceID++;
phrase.CreateFromString(Output, staticData.GetOutputFactorOrder(), vecStr[0], staticData.GetFactorDelimiter(), NULL);
} else if (vecStr.size() == 2) {
sentenceID = Scan<long>(vecStr[0]);
phrase.CreateFromString(Output, staticData.GetOutputFactorOrder(), vecStr[1], staticData.GetFactorDelimiter(), NULL);
} else {
UTIL_THROW(util::Exception, "Reference file not loaded");
for(size_t i = 0; i < m_paths.size(); ++i) {
InputFileStream constraintFile(m_paths[i]);
std::string line;
long sentenceID = staticData.GetStartTranslationId() - 1;
while (getline(constraintFile, line)) {
vector<string> vecStr = Tokenize(line, "\t");
Phrase phrase(0);
if (vecStr.size() == 1) {
sentenceID++;
phrase.CreateFromString(Output, staticData.GetOutputFactorOrder(), vecStr[0], staticData.GetFactorDelimiter(), NULL);
} else if (vecStr.size() == 2) {
sentenceID = Scan<long>(vecStr[0]);
phrase.CreateFromString(Output, staticData.GetOutputFactorOrder(), vecStr[1], staticData.GetFactorDelimiter(), NULL);
} else {
UTIL_THROW(util::Exception, "Reference file not loaded");
}
if (addBeginEndWord) {
phrase.InitStartEndWord();
}
m_constraints[sentenceID].push_back(phrase);
}
if (addBeginEndWord) {
phrase.InitStartEndWord();
}
m_constraints.insert(make_pair(sentenceID,phrase));
}
}
/**
 * Supplies the default weight vector for this feature function.
 *
 * UTIL_THROW_IF2 guards the expectation that the feature has exactly one
 * score component (presumably by raising an exception otherwise; the macro
 * is defined elsewhere).
 *
 * @return a vector containing a single weight with value 1.
 */
std::vector<float> ConstrainedDecoding::DefaultWeights() const
{
  UTIL_THROW_IF2(m_numScoreComponents != 1,
                 "ConstrainedDecoding must only have 1 score");
  vector<float> ret(1, 1);
  return ret;
}
template <class H, class M>
const Phrase *GetConstraint(const std::map<long,Phrase> &constraints, const H &hypo)
const std::vector<Phrase> *GetConstraint(const std::map<long,std::vector<Phrase> > &constraints, const H &hypo)
{
const M &mgr = hypo.GetManager();
const InputType &input = mgr.GetSource();
long id = input.GetTranslationId();
map<long,Phrase>::const_iterator iter;
map<long,std::vector<Phrase> >::const_iterator iter;
iter = constraints.find(id);
if (iter == constraints.end()) {
@ -101,30 +103,37 @@ FFState* ConstrainedDecoding::Evaluate(
const FFState* prev_state,
ScoreComponentCollection* accumulator) const
{
const Phrase *ref = GetConstraint<Hypothesis, Manager>(m_constraints, hypo);
const std::vector<Phrase> *ref = GetConstraint<Hypothesis, Manager>(m_constraints, hypo);
assert(ref);
ConstrainedDecodingState *ret = new ConstrainedDecodingState(hypo);
const Phrase &outputPhrase = ret->GetPhrase();
const Phrase &outputPhrase = ret->GetPhrase();
size_t searchPos = ref->Find(outputPhrase, m_maxUnknowns);
size_t searchPos = NOT_FOUND;
size_t i = 0;
size_t size = 0;
while(searchPos == NOT_FOUND && i < ref->size()) {
searchPos = (*ref)[i].Find(outputPhrase, m_maxUnknowns);
size = (*ref)[i].GetSize();
i++;
}
float score;
if (hypo.IsSourceCompleted()) {
// translated entire sentence.
bool match = (searchPos == 0) && (ref->GetSize() == outputPhrase.GetSize());
if (!m_negate) {
score = match ? 0 : - std::numeric_limits<float>::infinity();
}
else {
score = !match ? 0 : - std::numeric_limits<float>::infinity();
}
bool match = (searchPos == 0) && (size == outputPhrase.GetSize());
if (!m_negate) {
score = match ? 0 : - ( m_soft ? 1 : std::numeric_limits<float>::infinity());
}
else {
score = !match ? 0 : - ( m_soft ? 1 : std::numeric_limits<float>::infinity());
}
} else if (m_negate) {
// keep all derivations
score = 0;
// keep all derivations
score = 0;
}
else {
score = (searchPos != NOT_FOUND) ? 0 : - std::numeric_limits<float>::infinity();
score = (searchPos != NOT_FOUND) ? 0 : - ( m_soft ? 1 : std::numeric_limits<float>::infinity());
}
accumulator->PlusEquals(this, score);
@ -137,7 +146,7 @@ FFState* ConstrainedDecoding::EvaluateChart(
int /* featureID - used to index the state in the previous hypotheses */,
ScoreComponentCollection* accumulator) const
{
const Phrase *ref = GetConstraint<ChartHypothesis, ChartManager>(m_constraints, hypo);
const std::vector<Phrase> *ref = GetConstraint<ChartHypothesis, ChartManager>(m_constraints, hypo);
assert(ref);
const ChartManager &mgr = hypo.GetManager();
@ -145,25 +154,33 @@ FFState* ConstrainedDecoding::EvaluateChart(
ConstrainedDecodingState *ret = new ConstrainedDecodingState(hypo);
const Phrase &outputPhrase = ret->GetPhrase();
size_t searchPos = ref->Find(outputPhrase, m_maxUnknowns);
size_t searchPos = NOT_FOUND;
size_t i = 0;
size_t size = 0;
while(searchPos == NOT_FOUND && i < ref->size()) {
searchPos = (*ref)[i].Find(outputPhrase, m_maxUnknowns);
size = (*ref)[i].GetSize();
i++;
}
float score;
if (hypo.GetCurrSourceRange().GetStartPos() == 0 &&
hypo.GetCurrSourceRange().GetEndPos() == source.GetSize() - 1) {
// translated entire sentence.
bool match = (searchPos == 0) && (ref->GetSize() == outputPhrase.GetSize());
bool match = (searchPos == 0) && (size == outputPhrase.GetSize());
if (!m_negate) {
score = match ? 0 : - std::numeric_limits<float>::infinity();
}
else {
score = !match ? 0 : - std::numeric_limits<float>::infinity();
}
if (!m_negate) {
score = match ? 0 : - ( m_soft ? 1 : std::numeric_limits<float>::infinity());
}
else {
score = !match ? 0 : - ( m_soft ? 1 : std::numeric_limits<float>::infinity());
}
} else if (m_negate) {
// keep all derivations
score = 0;
// keep all derivations
score = 0;
} else {
score = (searchPos != NOT_FOUND) ? 0 : - std::numeric_limits<float>::infinity();
score = (searchPos != NOT_FOUND) ? 0 : - ( m_soft ? 1 : std::numeric_limits<float>::infinity());
}
accumulator->PlusEquals(this, score);
@ -174,11 +191,13 @@ FFState* ConstrainedDecoding::EvaluateChart(
void ConstrainedDecoding::SetParameter(const std::string& key, const std::string& value)
{
if (key == "path") {
m_path = value;
m_paths = Tokenize(value, ",");
} else if (key == "max-unknowns") {
m_maxUnknowns = Scan<int>(value);
} else if (key == "negate") {
m_negate = Scan<bool>(value);
m_negate = Scan<bool>(value);
} else if (key == "soft") {
m_soft = Scan<bool>(value);
} else {
StatefulFeatureFunction::SetParameter(key, value);
}

View File

@ -46,13 +46,15 @@ public:
, ScoreComponentCollection &scoreBreakdown
, ScoreComponentCollection &estimatedFutureScore) const
{}
void Evaluate(const InputType &input
void Evaluate(const InputType &input
, const InputPath &inputPath
, const TargetPhrase &targetPhrase
, const StackVec *stackVec
, ScoreComponentCollection &scoreBreakdown
, ScoreComponentCollection *estimatedFutureScore = NULL) const
{}
FFState* Evaluate(
const Hypothesis& cur_hypo,
const FFState* prev_state,
@ -72,10 +74,11 @@ public:
void SetParameter(const std::string& key, const std::string& value);
protected:
std::string m_path;
std::map<long,Phrase> m_constraints;
std::vector<std::string> m_paths;
std::map<long, std::vector<Phrase> > m_constraints;
int m_maxUnknowns;
bool m_negate; // only keep translations which DON'T match the reference
bool m_soft;
};

View File

@ -0,0 +1,18 @@
#!/usr/bin/perl
# Line filter: copies STDIN to STDOUT, turning every character matched by
# \p{C} (the Unicode "Other" general category, per perlunicode) into a space.
use utf8;
# Apply the :utf8 layer to all three standard streams.
binmode($_, ":utf8") for (\*STDIN, \*STDOUT, \*STDERR);
while (defined(my $text = <STDIN>)) {
  # Drop the trailing newline so it is not itself replaced by a space below.
  chomp $text;
  # Disabled alternative clean-ups, kept for reference:
  #$text =~ tr/\040-\176/ /c;
  #$text =~ s/[^[:print:]]/ /g;
  #$text =~ s/\s+/ /g;
  # One space per matched character; runs of spaces are not collapsed.
  $text =~ s/\p{C}/ /g;
  print $text, "\n";
}