constrained decoding FF works for both pb and hiero

This commit is contained in:
Hieu Hoang 2013-09-17 15:06:17 +02:00
parent 764684bb6f
commit 5ebb81a17a
15 changed files with 71 additions and 66 deletions

View File

@ -197,7 +197,7 @@ BackwardsEdge::Initialize()
Hypothesis *BackwardsEdge::CreateHypothesis(const Hypothesis &hypothesis, const TranslationOption &transOpt)
{
// create hypothesis and calculate all its scores
Hypothesis *newHypo = hypothesis.CreateNext(transOpt, NULL); // TODO FIXME This is absolutely broken - don't pass null here
Hypothesis *newHypo = hypothesis.CreateNext(transOpt); // TODO FIXME This is absolutely broken - don't pass null here
newHypo->Evaluate(m_futurescore);
return newHypo;

View File

@ -1,4 +1,6 @@
#include "ConstrainedDecoding.h"
#include "moses/Hypothesis.h"
#include "moses/Manager.h"
#include "moses/ChartHypothesis.h"
#include "moses/ChartManager.h"
#include "util/exception.hh"
@ -7,6 +9,11 @@ using namespace std;
namespace Moses
{
ConstrainedDecodingState::ConstrainedDecodingState(const Hypothesis &hypo)
{
hypo.GetOutputPhrase(m_outputPhrase);
}
ConstrainedDecodingState::ConstrainedDecodingState(const ChartHypothesis &hypo)
{
hypo.CreateOutputPhrase(m_outputPhrase);
@ -20,11 +27,33 @@ int ConstrainedDecodingState::Compare(const FFState& other) const
}
FFState* ConstrainedDecoding::Evaluate(
const Hypothesis& cur_hypo,
const Hypothesis& hypo,
const FFState* prev_state,
ScoreComponentCollection* accumulator) const
{
UTIL_THROW(util::Exception, "Not implemented");
const Manager &mgr = hypo.GetManager();
const Phrase *ref = mgr.GetConstraint();
CHECK(ref);
ConstrainedDecodingState *ret = new ConstrainedDecodingState(hypo);
const Phrase &outputPhrase = ret->GetPhrase();
size_t searchPos = ref->Find(outputPhrase);
float score;
if (hypo.IsSourceCompleted()) {
// translated entire sentence.
score = (searchPos == 0) && (ref->GetSize() == outputPhrase.GetSize())
? 0 : - std::numeric_limits<float>::infinity();
}
else {
score = (searchPos != NOT_FOUND) ? 0 : - std::numeric_limits<float>::infinity();
}
accumulator->PlusEquals(this, score);
return ret;
}
FFState* ConstrainedDecoding::EvaluateChart(

View File

@ -14,6 +14,7 @@ public:
ConstrainedDecodingState()
{}
ConstrainedDecodingState(const Hypothesis &hypo);
ConstrainedDecodingState(const ChartHypothesis &hypo);
int Compare(const FFState& other) const;

View File

@ -31,7 +31,7 @@
#include "moses/FF/OSM-Feature/OpSequenceModel.h"
#include "moses/FF/ControlRecombination.h"
#include "moses/FF/ExternalFeature.h"
//#include "moses/FF/ConstrainedDecoding.h"
#include "moses/FF/ConstrainedDecoding.h"
#include "moses/FF/SkeletonStatelessFF.h"
#include "moses/FF/SkeletonStatefulFF.h"
@ -143,7 +143,7 @@ FeatureRegistry::FeatureRegistry()
MOSES_FNAME(PhrasePenalty);
MOSES_FNAME2("UnknownWordPenalty", UnknownWordPenaltyProducer);
MOSES_FNAME(ControlRecombination);
// MOSES_FNAME(ConstrainedDecoding);
MOSES_FNAME(ConstrainedDecoding);
MOSES_FNAME(ExternalFeature);
MOSES_FNAME(SkeletonStatelessFF);

View File

@ -147,64 +147,23 @@ void Hypothesis::AddArc(Hypothesis *loserHypo)
/***
* return the subclass of Hypothesis most appropriate to the given translation option
*/
Hypothesis* Hypothesis::CreateNext(const TranslationOption &transOpt, const Phrase* constraint) const
Hypothesis* Hypothesis::CreateNext(const TranslationOption &transOpt) const
{
return Create(*this, transOpt, constraint);
return Create(*this, transOpt);
}
/***
* return the subclass of Hypothesis most appropriate to the given translation option
*/
Hypothesis* Hypothesis::Create(const Hypothesis &prevHypo, const TranslationOption &transOpt, const Phrase* constrainingPhrase)
Hypothesis* Hypothesis::Create(const Hypothesis &prevHypo, const TranslationOption &transOpt)
{
// This method includes code for constraint decoding
bool createHypothesis = true;
if (constrainingPhrase != NULL) {
size_t constraintSize = constrainingPhrase->GetSize();
size_t start = 1 + prevHypo.GetCurrTargetWordsRange().GetEndPos();
const Phrase &transOptPhrase = transOpt.GetTargetPhrase();
size_t transOptSize = transOptPhrase.GetSize();
size_t endpoint = start + transOptSize - 1;
if (endpoint < constraintSize) {
WordsRange range(start, endpoint);
Phrase relevantConstraint = constrainingPhrase->GetSubString(range);
if ( ! relevantConstraint.IsCompatible(transOptPhrase) ) {
createHypothesis = false;
}
} else {
createHypothesis = false;
}
}
if (createHypothesis) {
#ifdef USE_HYPO_POOL
Hypothesis *ptr = s_objectPool.getPtr();
return new(ptr) Hypothesis(prevHypo, transOpt);
#else
return new Hypothesis(prevHypo, transOpt);
#endif
} else {
// If the previous hypothesis plus the proposed translation option
// fail to match the provided constraint,
// return a null hypothesis.
return NULL;
}
}
/***
* return the subclass of Hypothesis most appropriate to the given target phrase

View File

@ -99,7 +99,7 @@ public:
~Hypothesis();
/** return the subclass of Hypothesis most appropriate to the given translation option */
static Hypothesis* Create(const Hypothesis &prevHypo, const TranslationOption &transOpt, const Phrase* constraint);
static Hypothesis* Create(const Hypothesis &prevHypo, const TranslationOption &transOpt);
static Hypothesis* Create(Manager& manager, const WordsBitmap &initialCoverage);
@ -107,7 +107,7 @@ public:
static Hypothesis* Create(Manager& manager, InputType const& source, const TranslationOption &initialTransOpt);
/** return the subclass of Hypothesis most appropriate to the given translation option */
Hypothesis* CreateNext(const TranslationOption &transOpt, const Phrase* constraint) const;
Hypothesis* CreateNext(const TranslationOption &transOpt) const;
void PrintHypothesis() const;

View File

@ -81,6 +81,13 @@ pair<HypothesisStackCubePruning::iterator, bool> HypothesisStackCubePruning::Add
bool HypothesisStackCubePruning::AddPrune(Hypothesis *hypo)
{
if (hypo->GetTotalScore() == - std::numeric_limits<float>::infinity()) {
m_manager.GetSentenceStats().AddDiscarded();
VERBOSE(3,"discarded, constraint" << std::endl);
FREEHYPO(hypo);
return false;
}
if (hypo->GetTotalScore() < m_worstScore) {
// too bad for stack. don't bother adding hypo into collection
m_manager.GetSentenceStats().AddDiscarded();

View File

@ -88,6 +88,13 @@ pair<HypothesisStackNormal::iterator, bool> HypothesisStackNormal::Add(Hypothesi
bool HypothesisStackNormal::AddPrune(Hypothesis *hypo)
{
if (hypo->GetTotalScore() == - std::numeric_limits<float>::infinity()) {
m_manager.GetSentenceStats().AddDiscarded();
VERBOSE(3,"discarded, constraint" << std::endl);
FREEHYPO(hypo);
return false;
}
// too bad for stack. don't bother adding hypo into collection
if (!StaticData::Instance().GetDisableDiscarding() &&
hypo->GetTotalScore() < m_worstScore

View File

@ -60,7 +60,15 @@ Manager::Manager(size_t lineNumber, InputType const& source, SearchAlgorithm sea
,m_lineNumber(lineNumber)
,m_source(source)
{
StaticData::Instance().InitializeForInput(source);
const StaticData &staticData = StaticData::Instance();
staticData.InitializeForInput(source);
long sentenceID = source.GetTranslationId();
m_constraint = staticData.GetConstrainingPhrase(sentenceID);
if (m_constraint) {
VERBOSE(1, "Search constraint to output: " << *m_constraint<<endl);
}
}
Manager::~Manager()

View File

@ -120,6 +120,7 @@ protected:
std::auto_ptr<SentenceStats> m_sentenceStats;
int m_hypoId; //used to number the hypos as they are created.
size_t m_lineNumber;
const Phrase *m_constraint;
void GetConnectedGraph(
std::map< int, bool >* pConnected,
@ -157,6 +158,9 @@ public:
return m_source;
}
const Phrase *GetConstraint() const
{ return m_constraint; }
/***
* to be called after processing a sentence (which may consist of more than just calling ProcessSentence() )
*/

View File

@ -66,7 +66,7 @@ MockHypothesisGuard::MockHypothesisGuard(
m_targetPhrases.back().CreateFromString(Input, factors, *ti, "|", NULL);
m_toptions.push_back(new TranslationOption
(wordsRange,m_targetPhrases.back()));
m_hypothesis = Hypothesis::Create(*prevHypo,*m_toptions.back(),NULL);
m_hypothesis = Hypothesis::Create(*prevHypo,*m_toptions.back());
}

View File

@ -39,7 +39,6 @@ public:
const TranslationOptionCollection &transOptColl);
protected:
const Phrase *m_constraint;
Manager& m_manager;
InputPath m_inputPath; // for initial hypo
TranslationOption m_initialTransOpt; /**< used to seed 1st hypo */

View File

@ -46,11 +46,6 @@ SearchCubePruning::SearchCubePruning(Manager& manager, const InputType &source,
{
const StaticData &staticData = StaticData::Instance();
/* constraint search not implemented in cube pruning
long sentenceID = source.GetTranslationId();
m_constraint = staticData.GetConstrainingPhrase(sentenceID);
*/
std::vector < HypothesisStackCubePruning >::iterator iterStack;
for (size_t ind = 0 ; ind < m_hypoStackColl.size() ; ++ind) {
HypothesisStackCubePruning *sourceHypoColl = new HypothesisStackCubePruning(m_manager);

View File

@ -25,10 +25,6 @@ SearchNormal::SearchNormal(Manager& manager, const InputType &source, const Tran
// only if constraint decoding (having to match a specified output)
long sentenceID = source.GetTranslationId();
m_constraint = staticData.GetConstrainingPhrase(sentenceID);
if (m_constraint) {
VERBOSE(1, "Search constraint to output: " << *m_constraint<<endl);
}
// initialize the stacks: create data structure and set limits
std::vector < HypothesisStackNormal >::iterator iterStack;
@ -292,7 +288,7 @@ void SearchNormal::ExpandHypothesis(const Hypothesis &hypothesis, const Translat
IFVERBOSE(2) {
t = clock();
}
newHypo = hypothesis.CreateNext(transOpt, m_constraint);
newHypo = hypothesis.CreateNext(transOpt);
IFVERBOSE(2) {
stats.AddTimeBuildHyp( clock()-t );
}
@ -327,7 +323,7 @@ void SearchNormal::ExpandHypothesis(const Hypothesis &hypothesis, const Translat
IFVERBOSE(2) {
t = clock();
}
newHypo = hypothesis.CreateNext(transOpt, m_constraint);
newHypo = hypothesis.CreateNext(transOpt);
if (newHypo==NULL) return;
IFVERBOSE(2) {
stats.AddTimeBuildHyp( clock()-t );

View File

@ -125,7 +125,7 @@ void SearchNormalBatch::ExpandHypothesis(const Hypothesis &hypothesis, const Tra
IFVERBOSE(2) {
t = clock();
}
newHypo = hypothesis.CreateNext(transOpt, m_constraint);
newHypo = hypothesis.CreateNext(transOpt);
IFVERBOSE(2) {
stats.AddTimeBuildHyp( clock()-t );
}