mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-12-27 05:55:02 +03:00
optional soft constraints for constrained decodings, multiple reference files
This commit is contained in:
parent
9a91f423e4
commit
6cdd2b6019
@ -25,7 +25,7 @@ ConstrainedDecodingState::ConstrainedDecodingState(const ChartHypothesis &hypo)
|
||||
int ConstrainedDecodingState::Compare(const FFState& other) const
|
||||
{
|
||||
const ConstrainedDecodingState &otherFF = static_cast<const ConstrainedDecodingState&>(other);
|
||||
int ret = m_outputPhrase.Compare(otherFF.m_outputPhrase);
|
||||
int ret = m_outputPhrase.Compare(otherFF.m_outputPhrase);
|
||||
return ret;
|
||||
}
|
||||
|
||||
@ -44,47 +44,48 @@ void ConstrainedDecoding::Load()
|
||||
const StaticData &staticData = StaticData::Instance();
|
||||
bool addBeginEndWord = (staticData.GetSearchAlgorithm() == ChartDecoding) || (staticData.GetSearchAlgorithm() == ChartIncremental);
|
||||
|
||||
InputFileStream constraintFile(m_path);
|
||||
std::string line;
|
||||
long sentenceID = staticData.GetStartTranslationId() - 1;
|
||||
while (getline(constraintFile, line)) {
|
||||
vector<string> vecStr = Tokenize(line, "\t");
|
||||
|
||||
Phrase phrase(0);
|
||||
if (vecStr.size() == 1) {
|
||||
sentenceID++;
|
||||
phrase.CreateFromString(Output, staticData.GetOutputFactorOrder(), vecStr[0], staticData.GetFactorDelimiter(), NULL);
|
||||
} else if (vecStr.size() == 2) {
|
||||
sentenceID = Scan<long>(vecStr[0]);
|
||||
phrase.CreateFromString(Output, staticData.GetOutputFactorOrder(), vecStr[1], staticData.GetFactorDelimiter(), NULL);
|
||||
} else {
|
||||
UTIL_THROW(util::Exception, "Reference file not loaded");
|
||||
for(size_t i = 0; i < m_paths.size(); ++i) {
|
||||
InputFileStream constraintFile(m_paths[i]);
|
||||
std::string line;
|
||||
long sentenceID = staticData.GetStartTranslationId() - 1;
|
||||
while (getline(constraintFile, line)) {
|
||||
vector<string> vecStr = Tokenize(line, "\t");
|
||||
|
||||
Phrase phrase(0);
|
||||
if (vecStr.size() == 1) {
|
||||
sentenceID++;
|
||||
phrase.CreateFromString(Output, staticData.GetOutputFactorOrder(), vecStr[0], staticData.GetFactorDelimiter(), NULL);
|
||||
} else if (vecStr.size() == 2) {
|
||||
sentenceID = Scan<long>(vecStr[0]);
|
||||
phrase.CreateFromString(Output, staticData.GetOutputFactorOrder(), vecStr[1], staticData.GetFactorDelimiter(), NULL);
|
||||
} else {
|
||||
UTIL_THROW(util::Exception, "Reference file not loaded");
|
||||
}
|
||||
|
||||
if (addBeginEndWord) {
|
||||
phrase.InitStartEndWord();
|
||||
}
|
||||
m_constraints[sentenceID].push_back(phrase);
|
||||
}
|
||||
|
||||
if (addBeginEndWord) {
|
||||
phrase.InitStartEndWord();
|
||||
}
|
||||
m_constraints.insert(make_pair(sentenceID,phrase));
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<float> ConstrainedDecoding::DefaultWeights() const
|
||||
{
|
||||
UTIL_THROW_IF2(m_numScoreComponents != 1,
|
||||
"ConstrainedDecoding must only have 1 score");
|
||||
"ConstrainedDecoding must only have 1 score");
|
||||
vector<float> ret(1, 1);
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <class H, class M>
|
||||
const Phrase *GetConstraint(const std::map<long,Phrase> &constraints, const H &hypo)
|
||||
const std::vector<Phrase> *GetConstraint(const std::map<long,std::vector<Phrase> > &constraints, const H &hypo)
|
||||
{
|
||||
const M &mgr = hypo.GetManager();
|
||||
const InputType &input = mgr.GetSource();
|
||||
long id = input.GetTranslationId();
|
||||
|
||||
map<long,Phrase>::const_iterator iter;
|
||||
map<long,std::vector<Phrase> >::const_iterator iter;
|
||||
iter = constraints.find(id);
|
||||
|
||||
if (iter == constraints.end()) {
|
||||
@ -101,30 +102,37 @@ FFState* ConstrainedDecoding::Evaluate(
|
||||
const FFState* prev_state,
|
||||
ScoreComponentCollection* accumulator) const
|
||||
{
|
||||
const Phrase *ref = GetConstraint<Hypothesis, Manager>(m_constraints, hypo);
|
||||
const std::vector<Phrase> *ref = GetConstraint<Hypothesis, Manager>(m_constraints, hypo);
|
||||
assert(ref);
|
||||
|
||||
ConstrainedDecodingState *ret = new ConstrainedDecodingState(hypo);
|
||||
const Phrase &outputPhrase = ret->GetPhrase();
|
||||
const Phrase &outputPhrase = ret->GetPhrase();
|
||||
|
||||
size_t searchPos = ref->Find(outputPhrase, m_maxUnknowns);
|
||||
size_t searchPos = NOT_FOUND;
|
||||
size_t i = 0;
|
||||
size_t size = 0;
|
||||
while(searchPos == NOT_FOUND && i < ref->size()) {
|
||||
searchPos = (*ref)[i].Find(outputPhrase, m_maxUnknowns);
|
||||
size = (*ref)[i].GetSize();
|
||||
i++;
|
||||
}
|
||||
|
||||
float score;
|
||||
if (hypo.IsSourceCompleted()) {
|
||||
// translated entire sentence.
|
||||
bool match = (searchPos == 0) && (ref->GetSize() == outputPhrase.GetSize());
|
||||
if (!m_negate) {
|
||||
score = match ? 0 : - std::numeric_limits<float>::infinity();
|
||||
}
|
||||
else {
|
||||
score = !match ? 0 : - std::numeric_limits<float>::infinity();
|
||||
}
|
||||
bool match = (searchPos == 0) && (size == outputPhrase.GetSize());
|
||||
if (!m_negate) {
|
||||
score = match ? 0 : - ( m_soft ? 1 : std::numeric_limits<float>::infinity());
|
||||
}
|
||||
else {
|
||||
score = !match ? 0 : - ( m_soft ? 1 : std::numeric_limits<float>::infinity());
|
||||
}
|
||||
} else if (m_negate) {
|
||||
// keep all derivations
|
||||
score = 0;
|
||||
// keep all derivations
|
||||
score = 0;
|
||||
}
|
||||
else {
|
||||
score = (searchPos != NOT_FOUND) ? 0 : - std::numeric_limits<float>::infinity();
|
||||
score = (searchPos != NOT_FOUND) ? 0 : - ( m_soft ? 1 : std::numeric_limits<float>::infinity());
|
||||
}
|
||||
|
||||
accumulator->PlusEquals(this, score);
|
||||
@ -137,7 +145,7 @@ FFState* ConstrainedDecoding::EvaluateChart(
|
||||
int /* featureID - used to index the state in the previous hypotheses */,
|
||||
ScoreComponentCollection* accumulator) const
|
||||
{
|
||||
const Phrase *ref = GetConstraint<ChartHypothesis, ChartManager>(m_constraints, hypo);
|
||||
const std::vector<Phrase> *ref = GetConstraint<ChartHypothesis, ChartManager>(m_constraints, hypo);
|
||||
assert(ref);
|
||||
|
||||
const ChartManager &mgr = hypo.GetManager();
|
||||
@ -145,25 +153,33 @@ FFState* ConstrainedDecoding::EvaluateChart(
|
||||
|
||||
ConstrainedDecodingState *ret = new ConstrainedDecodingState(hypo);
|
||||
const Phrase &outputPhrase = ret->GetPhrase();
|
||||
size_t searchPos = ref->Find(outputPhrase, m_maxUnknowns);
|
||||
|
||||
size_t searchPos = NOT_FOUND;
|
||||
size_t i = 0;
|
||||
size_t size = 0;
|
||||
while(searchPos == NOT_FOUND && i < ref->size()) {
|
||||
searchPos = (*ref)[i].Find(outputPhrase, m_maxUnknowns);
|
||||
size = (*ref)[i].GetSize();
|
||||
i++;
|
||||
}
|
||||
|
||||
float score;
|
||||
if (hypo.GetCurrSourceRange().GetStartPos() == 0 &&
|
||||
hypo.GetCurrSourceRange().GetEndPos() == source.GetSize() - 1) {
|
||||
// translated entire sentence.
|
||||
bool match = (searchPos == 0) && (ref->GetSize() == outputPhrase.GetSize());
|
||||
bool match = (searchPos == 0) && (size == outputPhrase.GetSize());
|
||||
|
||||
if (!m_negate) {
|
||||
score = match ? 0 : - std::numeric_limits<float>::infinity();
|
||||
}
|
||||
else {
|
||||
score = !match ? 0 : - std::numeric_limits<float>::infinity();
|
||||
}
|
||||
if (!m_negate) {
|
||||
score = match ? 0 : - ( m_soft ? 1 : std::numeric_limits<float>::infinity());
|
||||
}
|
||||
else {
|
||||
score = !match ? 0 : - ( m_soft ? 1 : std::numeric_limits<float>::infinity());
|
||||
}
|
||||
} else if (m_negate) {
|
||||
// keep all derivations
|
||||
score = 0;
|
||||
// keep all derivations
|
||||
score = 0;
|
||||
} else {
|
||||
score = (searchPos != NOT_FOUND) ? 0 : - std::numeric_limits<float>::infinity();
|
||||
score = (searchPos != NOT_FOUND) ? 0 : - ( m_soft ? 1 : std::numeric_limits<float>::infinity());
|
||||
}
|
||||
|
||||
accumulator->PlusEquals(this, score);
|
||||
@ -174,11 +190,13 @@ FFState* ConstrainedDecoding::EvaluateChart(
|
||||
void ConstrainedDecoding::SetParameter(const std::string& key, const std::string& value)
|
||||
{
|
||||
if (key == "path") {
|
||||
m_path = value;
|
||||
m_paths = Tokenize(value, ",");
|
||||
} else if (key == "max-unknowns") {
|
||||
m_maxUnknowns = Scan<int>(value);
|
||||
} else if (key == "negate") {
|
||||
m_negate = Scan<bool>(value);
|
||||
m_negate = Scan<bool>(value);
|
||||
} else if (key == "soft") {
|
||||
m_soft = Scan<bool>(value);
|
||||
} else {
|
||||
StatefulFeatureFunction::SetParameter(key, value);
|
||||
}
|
||||
|
@ -46,13 +46,15 @@ public:
|
||||
, ScoreComponentCollection &scoreBreakdown
|
||||
, ScoreComponentCollection &estimatedFutureScore) const
|
||||
{}
|
||||
void Evaluate(const InputType &input
|
||||
|
||||
void Evaluate(const InputType &input
|
||||
, const InputPath &inputPath
|
||||
, const TargetPhrase &targetPhrase
|
||||
, const StackVec *stackVec
|
||||
, ScoreComponentCollection &scoreBreakdown
|
||||
, ScoreComponentCollection *estimatedFutureScore = NULL) const
|
||||
{}
|
||||
|
||||
FFState* Evaluate(
|
||||
const Hypothesis& cur_hypo,
|
||||
const FFState* prev_state,
|
||||
@ -72,10 +74,11 @@ public:
|
||||
void SetParameter(const std::string& key, const std::string& value);
|
||||
|
||||
protected:
|
||||
std::string m_path;
|
||||
std::map<long,Phrase> m_constraints;
|
||||
std::vector<std::string> m_paths;
|
||||
std::map<long, std::vector<Phrase> > m_constraints;
|
||||
int m_maxUnknowns;
|
||||
bool m_negate; // only keep translations which DON'T match the reference
|
||||
bool m_soft;
|
||||
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user