mosesdecoder/contrib/moses2/HypothesisColl.cpp

313 lines
8.0 KiB
C++
Raw Normal View History

2016-02-29 16:10:55 +03:00
/*
* HypothesisColl.cpp
*
* Created on: 26 Feb 2016
* Author: hieu
*/
2016-06-03 21:57:51 +03:00
#include <iostream>
2016-06-20 16:59:31 +03:00
#include <sstream>
2016-02-29 16:10:55 +03:00
#include <algorithm>
#include <boost/foreach.hpp>
#include "HypothesisColl.h"
#include "ManagerBase.h"
#include "System.h"
2016-10-26 18:50:20 +03:00
#include "MemPoolAllocator.h"
2016-02-29 16:10:55 +03:00
2016-06-20 16:59:31 +03:00
using namespace std;
2016-03-31 23:00:16 +03:00
namespace Moses2
{
2016-02-29 16:10:55 +03:00
/*
 * Build an empty collection.
 *
 * The hypothesis set is given an allocator constructed from the manager's
 * memory pool, and no sorted view exists yet.
 */
HypothesisColl::HypothesisColl(const ManagerBase &mgr)
  : m_coll(MemPoolAllocator<const HypothesisBase*>(mgr.GetPool()))
  , m_sortedHypos(NULL)
{
  // Best-score / beam-score tracking is currently switched off:
  //   m_bestScore = -std::numeric_limits<float>::infinity();
  //   m_minBeamScore = -std::numeric_limits<float>::infinity();

  // Start at +infinity so the first accepted hypothesis can lower it.
  const float initialWorse = std::numeric_limits<float>::infinity();
  m_worseScore = initialWorse;
}
2016-02-29 16:10:55 +03:00
/*
 * Return the hypothesis with the highest future score.
 *
 * Returns NULL when GetSize() reports an empty collection. If a sorted view
 * has already been built, its first element is returned. Otherwise the
 * collection is scanned linearly.
 *
 * The scan is seeded with the first hypothesis it visits, so a non-empty
 * collection always yields a valid pointer, even when no score compares
 * greater than -infinity (e.g. every score is -infinity or NaN). Previously
 * that case returned an uninitialised pointer.
 */
const HypothesisBase *HypothesisColl::GetBestHypo() const
{
  if (GetSize() == 0) {
    return NULL;
  }
  if (m_sortedHypos) {
    return (*m_sortedHypos)[0];
  }

  SCORE bestScore = -std::numeric_limits<SCORE>::infinity();
  const HypothesisBase *bestHypo = NULL;
  BOOST_FOREACH(const HypothesisBase *hypo, m_coll) {
    // read the score once per hypothesis
    const SCORE score = hypo->GetFutureScore();
    if (bestHypo == NULL || score > bestScore) {
      bestScore = score;
      bestHypo = hypo;
    }
  }
  return bestHypo;
}
2016-08-05 00:41:29 +03:00
void HypothesisColl::Add(
2016-12-05 21:04:26 +03:00
const ManagerBase &mgr,
HypothesisBase *hypo,
Recycler<HypothesisBase*> &hypoRecycle,
ArcLists &arcLists)
2016-08-05 00:41:29 +03:00
{
2016-12-05 21:04:26 +03:00
size_t maxStackSize = mgr.system.options.search.stack_size;
2016-12-06 01:47:37 +03:00
if (GetSize() > maxStackSize * 2) {
//cerr << "maxStackSize=" << maxStackSize << " " << GetSize() << endl;
2016-12-05 21:04:26 +03:00
PruneHypos(mgr, mgr.arcLists);
}
2016-12-05 17:34:24 +03:00
2016-12-01 15:55:20 +03:00
SCORE futureScore = hypo->GetFutureScore();
/*
cerr << "scores:"
<< futureScore << " "
<< m_bestScore << " "
2016-12-01 18:31:22 +03:00
<< m_minBeamScore << " "
<< GetSize() << " "
<< endl;
*/
2016-12-05 17:34:24 +03:00
if (GetSize() >= maxStackSize && futureScore < m_worseScore) {
// beam threshold or really bad hypo that won't make the pruning cut
// as more hypos are added, the m_worseScore stat gets out of date and isn't the optimum cut-off point
//cerr << "Discard, really bad score:" << hypo->Debug(system) << endl;
2016-12-01 15:55:20 +03:00
hypoRecycle.Recycle(hypo);
return;
}
/*
if (futureScore < m_minBeamScore) {
// beam threshold or really bad hypo that won't make the pruning cut
// as more hypos are added, the m_worseScore stat gets out of date and isn't the optimum cut-off point
//cerr << "Discard, below beam:" << hypo->Debug(system) << endl;
hypoRecycle.Recycle(hypo);
return;
}
2016-12-01 15:55:20 +03:00
if (futureScore > m_bestScore) {
m_bestScore = hypo->GetFutureScore();
// this may also affect the worst score
SCORE beamWidth = system.options.search.beam_width;
//cerr << "beamWidth=" << beamWidth << endl;
2016-12-01 18:31:22 +03:00
if ( m_bestScore + beamWidth > m_minBeamScore ) {
m_minBeamScore = m_bestScore + beamWidth;
2016-12-01 15:55:20 +03:00
}
}
//cerr << "OK:" << hypo->Debug(system) << endl;
*/
2016-12-01 15:55:20 +03:00
StackAdd added = Add(hypo);
2016-12-05 21:04:26 +03:00
size_t nbestSize = mgr.system.options.nbest.nbest_size;
if (nbestSize) {
arcLists.AddArc(added.added, hypo, added.other);
2016-08-05 00:41:29 +03:00
}
else {
if (!added.added) {
hypoRecycle.Recycle(hypo);
}
2016-12-05 17:34:24 +03:00
else {
if (added.other) {
hypoRecycle.Recycle(added.other);
}
if (GetSize() <= maxStackSize && hypo->GetFutureScore() < m_worseScore) {
m_worseScore = futureScore;
}
}
2016-08-05 00:41:29 +03:00
}
}
2016-02-29 16:10:55 +03:00
/*
 * Insert hypo into the set, recombining with an equivalent entry if one exists.
 *
 * Result (added, other):
 *   (true,  NULL)      no equivalent entry existed; hypo was inserted.
 *   (true,  previous)  hypo scored higher and took the place of 'previous'.
 *   (false, existing)  the stored entry is kept; hypo was not stored.
 */
StackAdd HypothesisColl::Add(const HypothesisBase *hypo)
{
  const std::pair<_HCType::iterator, bool> inserted = m_coll.insert(hypo);

  if (inserted.second) {
    // nothing equivalent was stored before
    return StackAdd(true, NULL);
  }

  // Recombination: the set already holds an entry it considers equal to hypo.
  const HypothesisBase * const &slot = *inserted.first;
  HypothesisBase *incumbent = const_cast<HypothesisBase*>(slot);

  if (!(hypo->GetFutureScore() > incumbent->GetFutureScore())) {
    // the stored hypothesis is at least as good; the caller disposes of hypo
    return StackAdd(false, incumbent);
  }

  // The newcomer wins. Overwrite the stored pointer in place instead of
  // erase + insert; the set treats both hypotheses as the same key.
  const_cast<const HypothesisBase *&>(slot) = hypo;
  return StackAdd(true, incumbent);
}
/*
 * Return the sorted, pruned view of the collection, building it on first use.
 *
 * On the first call the view is allocated from the manager's pool, filled
 * with every hypothesis in the set and handed to SortAndPruneHypos().
 * Later calls return the cached view unchanged.
 */
const Hypotheses &HypothesisColl::GetSortedAndPruneHypos(
  const ManagerBase &mgr,
  ArcLists &arcLists) const
{
  if (m_sortedHypos != NULL) {
    // already built
    return *m_sortedHypos;
  }

  MemPool &pool = mgr.GetPool();
  m_sortedHypos = new (pool.Allocate<Hypotheses>()) Hypotheses(pool,
      m_coll.size());

  // copy the (unordered) set contents into the pre-sized view
  std::copy(m_coll.begin(), m_coll.end(), m_sortedHypos->begin());

  SortAndPruneHypos(mgr, arcLists);
  return *m_sortedHypos;
}
// Return the sorted view without building it.
// Unlike GetSortedAndPruneHypos(), this accessor never creates the view: the
// UTIL_THROW_IF2 check fires when m_sortedHypos is still NULL
// (presumably by throwing an exception - confirm against the macro's definition).
const Hypotheses &HypothesisColl::GetSortedAndPrunedHypos() const
{
UTIL_THROW_IF2(m_sortedHypos == NULL, "m_sortedHypos must be sorted beforehand");
return *m_sortedHypos;
}
2016-03-31 23:00:16 +03:00
void HypothesisColl::SortAndPruneHypos(const ManagerBase &mgr,
ArcLists &arcLists) const
2016-02-29 16:10:55 +03:00
{
size_t stackSize = mgr.system.options.search.stack_size;
Recycler<HypothesisBase*> &recycler = mgr.GetHypoRecycle();
/*
cerr << "UNSORTED hypos: ";
BOOST_FOREACH(const HypothesisBase *hypo, m_coll) {
cerr << hypo << "(" << hypo->GetFutureScore() << ")" << " ";
2016-03-31 23:00:16 +03:00
}
cerr << endl;
*/
Hypotheses::iterator iterMiddle;
iterMiddle =
(stackSize == 0 || m_sortedHypos->size() < stackSize) ?
m_sortedHypos->end() : m_sortedHypos->begin() + stackSize;
std::partial_sort(m_sortedHypos->begin(), iterMiddle, m_sortedHypos->end(),
HypothesisFutureScoreOrderer());
// prune
if (stackSize && m_sortedHypos->size() > stackSize) {
for (size_t i = stackSize; i < m_sortedHypos->size(); ++i) {
HypothesisBase *hypo = const_cast<HypothesisBase*>((*m_sortedHypos)[i]);
recycler.Recycle(hypo);
// delete from arclist
if (mgr.system.options.nbest.nbest_size) {
arcLists.Delete(hypo);
}
}
m_sortedHypos->resize(stackSize);
}
/*
cerr << "sorted hypos: ";
for (size_t i = 0; i < m_sortedHypos->size(); ++i) {
const HypothesisBase *hypo = (*m_sortedHypos)[i];
cerr << hypo << " ";
2016-03-31 23:00:16 +03:00
}
cerr << endl;
*/
2016-02-29 16:10:55 +03:00
}
2016-12-06 01:47:37 +03:00
void HypothesisColl::PruneHypos(const ManagerBase &mgr, ArcLists &arcLists)
2016-12-05 21:04:26 +03:00
{
2016-12-06 01:47:37 +03:00
size_t maxStackSize = mgr.system.options.search.stack_size;
2016-12-05 21:04:26 +03:00
Recycler<HypothesisBase*> &recycler = mgr.GetHypoRecycle();
/*
cerr << "UNSORTED hypos: ";
BOOST_FOREACH(const HypothesisBase *hypo, m_coll) {
cerr << hypo << "(" << hypo->GetFutureScore() << ")" << " ";
}
cerr << endl;
*/
2016-12-06 01:47:37 +03:00
vector<const HypothesisBase*> sortedHypos(GetSize());
2016-12-05 21:04:26 +03:00
size_t ind = 0;
BOOST_FOREACH(const HypothesisBase *hypo, m_coll){
sortedHypos[ind] = hypo;
++ind;
}
vector<const HypothesisBase*>::iterator iterMiddle;
iterMiddle =
2016-12-06 01:47:37 +03:00
(maxStackSize == 0 || sortedHypos.size() < maxStackSize) ?
sortedHypos.end() : sortedHypos.begin() + maxStackSize;
2016-12-05 21:04:26 +03:00
std::partial_sort(sortedHypos.begin(), iterMiddle, sortedHypos.end(),
HypothesisFutureScoreOrderer());
// prune
2016-12-06 01:47:37 +03:00
if (maxStackSize && sortedHypos.size() > maxStackSize) {
for (size_t i = maxStackSize; i < sortedHypos.size(); ++i) {
2016-12-05 21:04:26 +03:00
HypothesisBase *hypo = const_cast<HypothesisBase*>((sortedHypos)[i]);
// delete from arclist
if (mgr.system.options.nbest.nbest_size) {
arcLists.Delete(hypo);
}
// delete from collection
2016-12-06 01:47:37 +03:00
Delete(hypo);
2016-12-06 03:20:41 +03:00
recycler.Recycle(hypo);
2016-12-05 21:04:26 +03:00
}
2016-12-06 01:47:37 +03:00
2016-12-05 21:04:26 +03:00
}
/*
cerr << "sorted hypos: ";
for (size_t i = 0; i < sortedHypos.size(); ++i) {
const HypothesisBase *hypo = sortedHypos[i];
cerr << hypo << " ";
}
cerr << endl;
*/
}
2016-12-06 01:47:37 +03:00
void HypothesisColl::Delete(const HypothesisBase *hypo)
{
2016-12-06 03:20:41 +03:00
//cerr << "hypo=" << hypo << " " << m_coll.size() << endl;
2016-12-06 01:47:37 +03:00
_HCType::const_iterator iter = m_coll.find(hypo);
UTIL_THROW_IF2(iter == m_coll.end(), "Can't find hypo");
m_coll.erase(iter);
}
2016-02-29 16:10:55 +03:00
/*
 * Reset the collection to its freshly constructed state: no hypotheses,
 * no sorted view, and a worst score of +infinity.
 */
void HypothesisColl::Clear()
{
  m_coll.clear();

  // only the pointer is dropped here; nothing is freed by this function
  m_sortedHypos = NULL;

  // Best-score / beam-score tracking is currently switched off:
  //   m_bestScore = -std::numeric_limits<float>::infinity();
  //   m_minBeamScore = -std::numeric_limits<float>::infinity();
  const float initialWorse = std::numeric_limits<float>::infinity();
  m_worseScore = initialWorse;
}
2016-06-20 16:59:31 +03:00
std::string HypothesisColl::Debug(const System &system) const
2016-06-03 21:57:51 +03:00
{
stringstream out;
BOOST_FOREACH (const HypothesisBase *hypo, m_coll) {
out << hypo->Debug(system);
out << std::endl << std::endl;
}
2016-06-18 01:06:02 +03:00
return out.str();
2016-06-03 21:57:51 +03:00
}
2016-02-29 16:10:55 +03:00
} /* namespace Moses2 */