create deviant paths

This commit is contained in:
Hieu Hoang 2016-08-03 18:56:03 +01:00
parent c4bb814cd5
commit 3c644c8be4
3 changed files with 59 additions and 10 deletions

View File

@ -35,11 +35,14 @@ KBestExtractor::KBestExtractor(const SCFG::Manager &mgr)
contenders.Add(path);
}
cerr << "mgr.system.options.nbest.nbest_size=" << mgr.system.options.nbest.nbest_size << endl;
size_t bestInd = 0;
while (bestInd < mgr.system.options.nbest.nbest_size && !contenders.empty()) {
//cerr << "bestInd=" << bestInd << endl;
SCFG::TrellisPath *path = contenders.Get();
m_coll.push_back(path);
path->CreateDeviantPaths(contenders, mgr);
}
}

View File

@ -8,6 +8,7 @@
#include "Hypothesis.h"
#include "Manager.h"
#include "TargetPhraseImpl.h"
#include "../TrellisPaths.h"
using namespace std;
@ -22,13 +23,27 @@ TrellisNode::TrellisNode(const ArcLists &arcLists, const SCFG::Hypothesis &hypo)
UTIL_THROW_IF2(arcList->size() == 0, "Empty arclist");
ind = 0;
CreateTail(arcLists, hypo);
}
TrellisNode::TrellisNode(const ArcLists &arcLists, const ArcList &varcList, size_t vind)
:arcList(&varcList)
,ind(vind)
{
UTIL_THROW_IF2(vind >= arcList->size(), "arclist out of bound" << ind << " >= " << arcList->size());
const SCFG::Hypothesis &hypo = (*arcList)[ind]->Cast<SCFG::Hypothesis>();
CreateTail(arcLists, hypo);
}
void TrellisNode::CreateTail(const ArcLists &arcLists, const SCFG::Hypothesis &hypo)
{
const Vector<const Hypothesis*> &prevHypos = hypo.GetPrevHypos();
m_prevNodes.resize(prevHypos.size(), NULL);
for (size_t i = 0; i < hypo.GetPrevHypos().size(); ++i) {
const SCFG::Hypothesis &prevHypo = *prevHypos[i];
TrellisNode *prevNode = new TrellisNode(arcLists, prevHypo);
m_prevNodes[i] = prevNode;
const SCFG::Hypothesis &prevHypo = *prevHypos[i];
TrellisNode *prevNode = new TrellisNode(arcLists, prevHypo);
m_prevNodes[i] = prevNode;
}
}
@ -64,6 +79,11 @@ void TrellisNode::OutputToStream(std::stringstream &strm) const
}
}
bool TrellisNode::HasMore() const
{
bool ret = arcList->size() > (ind + 1);
return ret;
}
/////////////////////////////////////////////////////////////////////
@ -75,6 +95,14 @@ TrellisPath::TrellisPath(const SCFG::Manager &mgr, const SCFG::Hypothesis &hypo)
m_scores = m_scores = new (pool.Allocate<Scores>())
Scores(mgr.system, pool, mgr.system.featureFunctions.GetNumScores(), hypo.GetScores());
m_node = new TrellisNode(mgr.arcLists, hypo);
m_prevNodeChanged = m_node;
}
TrellisPath::TrellisPath(const SCFG::Manager &mgr, const SCFG::TrellisPath &origPath, const TrellisNode &nodeToChange)
{
if (origPath.m_node == &nodeToChange) {
m_node = new TrellisNode(mgr.arcLists, *nodeToChange.arcList, nodeToChange.ind + 1);
}
}
void TrellisPath::OutputToStream(std::stringstream &strm)
@ -87,7 +115,16 @@ SCORE TrellisPath::GetFutureScore() const
return m_scores->GetTotalScore();
}
//! create a set of next best paths by wiggling 1 of the node at a time.
void TrellisPath::CreateDeviantPaths(TrellisPaths<SCFG::TrellisPath> &paths, const SCFG::Manager &mgr) const
{
if (m_prevNodeChanged->HasMore()) {
SCFG::TrellisPath *deviantPath = new TrellisPath(mgr, *this, *m_prevNodeChanged);
paths.Add(deviantPath);
}
}
} // namespace
}

View File

@ -11,6 +11,10 @@
namespace Moses2
{
class Scores;
class System;
template<typename T>
class TrellisPaths;
namespace SCFG
{
@ -25,25 +29,27 @@ public:
size_t ind;
TrellisNode(const ArcLists &arcLists, const SCFG::Hypothesis &hypo);
TrellisNode(const ArcList &varcList, size_t vind) :
arcList(&varcList), ind(vind)
{
}
TrellisNode(const ArcLists &arcLists, const ArcList &varcList, size_t vind);
const SCFG::Hypothesis &GetHypothesis() const;
bool HasMore() const;
void OutputToStream(std::stringstream &strm) const;
protected:
std::vector<const TrellisNode*> m_prevNodes;
void CreateTail(const ArcLists &arcLists, const SCFG::Hypothesis &hypo);
};
/////////////////////////////////////////////////////////////////////
class TrellisPath
{
public:
TrellisPath(const SCFG::Manager &mgr, const SCFG::Hypothesis &hypo);
TrellisPath(const SCFG::Manager &mgr, const SCFG::Hypothesis &hypo); // create best path
TrellisPath(const SCFG::Manager &mgr, const SCFG::TrellisPath &origPath, const TrellisNode &nodeToChange); // create original path
void OutputToStream(std::stringstream &strm);
const Scores &GetScores() const
@ -53,13 +59,16 @@ public:
SCORE GetFutureScore() const;
//! create a set of next best paths by wiggling 1 of the node at a time.
void CreateDeviantPaths(TrellisPaths<SCFG::TrellisPath> &paths, const SCFG::Manager &mgr) const;
protected:
Scores *m_scores;
TrellisNode *m_node;
TrellisNode *m_prevNodeChanged;
};
}
} // namespace
}