create deviant paths

This commit is contained in:
Hieu Hoang 2016-03-20 09:36:42 +00:00
parent 0b3179d4d7
commit 60a8e87147
4 changed files with 66 additions and 14 deletions

View File

@ -209,7 +209,7 @@ void Manager::OutputNBest()
out << "\n";
// create next paths
path->CreateDeviantPaths(contenders);
path->CreateDeviantPaths(contenders, arcLists, GetPool(), system);
++bestInd;
}

View File

@ -8,6 +8,7 @@
#include "TrellisPath.h"
#include "TrellisPaths.h"
#include "Hypothesis.h"
#include "../System.h"
using namespace std;
@ -27,9 +28,38 @@ TrellisPath::TrellisPath(const Hypothesis *hypo, const ArcLists &arcLists)
m_scores = &hypo->GetScores();
}
TrellisPath::TrellisPath(const TrellisPath &origPath, size_t edgeIndex, const Hypothesis *arc)
TrellisPath::TrellisPath(const TrellisPath &origPath,
size_t edgeIndex,
const TrellishNode &newNode,
const ArcLists &arcLists,
MemPool &pool,
const System &system)
:prevEdgeChanged(edgeIndex)
{
nodes.reserve(origPath.nodes.size());
for (size_t currEdge = 0 ; currEdge < edgeIndex ; currEdge++) {
// copy path from parent
nodes.push_back(origPath.nodes[currEdge]);
}
// 1 deviation
nodes.push_back(newNode);
// rest of path comes from following best path backwards
const ArcList &arcList = *newNode.arcList;
const Hypothesis *arc = static_cast<const Hypothesis*>(arcList[newNode.ind]);
const Hypothesis *prevHypo = arc->GetPrevHypo();
while (prevHypo != NULL) {
const ArcList *arcList = arcLists.GetArcList(prevHypo);
assert(arcList);
TrellishNode node(*arcList, 0);
nodes.push_back(node);
prevHypo = prevHypo->GetPrevHypo();
}
CalcScores(origPath.GetScores(), pool, system);
}
TrellisPath::~TrellisPath() {
@ -73,22 +103,35 @@ void TrellisPath::OutputToStream(std::ostream &out, const System &system) const
GetScores().OutputToStream(out, system);
}
void TrellisPath::CreateDeviantPaths(TrellisPaths &paths) const
void TrellisPath::CreateDeviantPaths(TrellisPaths &paths,
const ArcLists &arcLists,
MemPool &pool,
const System &system) const
{
const size_t sizePath = nodes.size();
cerr << "prevEdgeChanged=" << prevEdgeChanged << endl;
for (size_t currEdge = prevEdgeChanged + 1 ; currEdge < sizePath ; currEdge++) {
const TrellishNode &node = nodes[currEdge];
assert(node.ind == 0);
const ArcList &arcList = *node.arcList;
TrellishNode newNode = nodes[currEdge];
assert(newNode.ind == 0);
const ArcList &arcList = *newNode.arcList;
cerr << "arcList=" << arcList.size() << endl;
for (size_t i = 1; i < arcList.size(); ++i) {
const Hypothesis *arcReplace = static_cast<const Hypothesis *>(arcList[i]);
cerr << "i=" << i << endl;
newNode.ind = i;
TrellisPath *deviantPath = new TrellisPath(*this, currEdge, arcReplace);
paths.Add(deviantPath);
TrellisPath *deviantPath = new TrellisPath(*this, currEdge, newNode, arcLists, pool, system);
cerr << "deviantPath=" << deviantPath << endl;
paths.Add(deviantPath);
}
}
}
void TrellisPath::CalcScores(const Scores &origScores, MemPool &pool, const System &system)
{
Scores *scores = new (pool.Allocate<Scores>()) Scores(system, pool, system.featureFunctions.GetNumScores());
m_scores = scores;
}
} /* namespace Moses2 */

View File

@ -12,6 +12,7 @@
namespace Moses2 {
class Scores;
class MemPool;
class Hypothesis;
class System;
class TrellisPaths;
@ -43,7 +44,12 @@ public:
/** create path from another path, deviate at edgeIndex by using arc instead,
* which may change other hypo back from there
*/
TrellisPath(const TrellisPath &origPath, size_t edgeIndex, const Hypothesis *arc);
TrellisPath(const TrellisPath &origPath,
size_t edgeIndex,
const TrellishNode &newNode,
const ArcLists &arcLists,
MemPool &pool,
const System &system);
virtual ~TrellisPath();
@ -54,12 +60,16 @@ public:
void OutputToStream(std::ostream &out, const System &system) const;
//! create a set of next best paths by wiggling 1 of the node at a time.
void CreateDeviantPaths(TrellisPaths &paths) const;
void CreateDeviantPaths(TrellisPaths &paths,
const ArcLists &arcLists,
MemPool &pool,
const System &system) const;
protected:
const Scores *m_scores;
const Scores *m_scores;
void AddNodes(const Hypothesis *hypo, const ArcLists &arcLists);
void AddNodes(const Hypothesis *hypo, const ArcLists &arcLists);
void CalcScores(const Scores &origScores, MemPool &pool, const System &system);
};
} /* namespace Moses2 */

View File

@ -42,7 +42,6 @@ public:
protected:
typedef std::multiset<TrellisPath*, CompareTrellisPathCollection> CollectionType;
CollectionType m_collection;
};
} /* namespace Moses2 */