priority queue

This commit is contained in:
Hieu Hoang 2016-08-23 19:09:09 +01:00
parent 933e4d29ea
commit 7fa940f105
2 changed files with 37 additions and 4 deletions

View File

@ -96,7 +96,7 @@ const NBest &NBest::GetChild(size_t ind) const
void NBest::CreateDeviants(
const SCFG::Manager &mgr,
const NBestColl &nbestColl,
std::priority_queue<NBest*> &contenders)
Contenders &contenders)
{
if (ind + 1 < arcList->size()) {
NBest *next = new NBest(mgr, nbestColl, *arcList, ind + 1);
@ -169,7 +169,8 @@ void NBestColl::Add(const SCFG::Manager &mgr, const ArcList &arcList)
{
NBests &nbests = GetOrCreateNBests(arcList);
priority_queue<NBest*> contenders;
Contenders contenders;
boost::unordered_set<size_t> distinctHypos;
NBest *contender;
@ -185,10 +186,27 @@ void NBestColl::Add(const SCFG::Manager &mgr, const ArcList &arcList)
NBest *best = contenders.top();
contenders.pop();
nbests.push_back(best);
best->CreateDeviants(mgr, *this, contenders);
bool ok = false;
if (mgr.system.options.nbest.only_distinct) {
string tgtPhrase = path->OutputTargetPhrase(system);
//cerr << "tgtPhrase=" << tgtPhrase << endl;
boost::hash<std::string> string_hash;
size_t hash = string_hash(tgtPhrase);
if (distinctHypos.insert(hash).second) {
ok = true;
}
}
else {
ok = true;
}
if (ok) {
nbests.push_back(best);
}
}
}

View File

@ -22,6 +22,11 @@ class Manager;
class Hypothesis;
class NBestColl;
class NBests;
class NBest;
class NBestScoreOrderer;
/////////////////////////////////////////////////////////////
typedef std::priority_queue<NBest*, std::vector<NBest*>, NBestScoreOrderer> Contenders;
/////////////////////////////////////////////////////////////
class NBest
@ -46,7 +51,7 @@ public:
void CreateDeviants(
const SCFG::Manager &mgr,
const NBestColl &nbestColl,
std::priority_queue<NBest*> &contenders);
Contenders &contenders);
const Scores &GetScores() const
{ return *m_scores; }
@ -64,6 +69,16 @@ protected:
const SCFG::Hypothesis &GetHypo() const;
};
/////////////////////////////////////////////////////////////
class NBestScoreOrderer
{
public:
bool operator()(const NBest* a, const NBest* b) const
{
return a->GetScores().GetTotalScore() > b->GetScores().GetTotalScore();
}
};
/////////////////////////////////////////////////////////////
class NBests : public std::vector<NBest*>
{