mosesdecoder/contrib/moses2/SCFG/Manager.cpp

392 lines
10 KiB
C++
Raw Normal View History

2016-02-26 15:26:32 +03:00
/*
* Manager.cpp
*
* Created on: 23 Oct 2015
* Author: hieu
*/
#include <boost/foreach.hpp>
2016-05-06 13:09:52 +03:00
#include <cstdlib>
2016-02-26 15:26:32 +03:00
#include <vector>
#include <sstream>
2016-06-02 14:58:29 +03:00
#include "../System.h"
#include "../TranslationModel/PhraseTable.h"
2016-02-26 15:26:32 +03:00
#include "Manager.h"
2016-03-03 16:04:27 +03:00
#include "InputPath.h"
2016-04-17 09:16:58 +03:00
#include "Hypothesis.h"
2016-04-30 13:47:51 +03:00
#include "TargetPhraseImpl.h"
2016-05-06 13:09:52 +03:00
#include "ActiveChart.h"
2016-04-27 21:59:03 +03:00
#include "Sentence.h"
2016-02-26 15:26:32 +03:00
#include "nbest/KBestExtractor.h"
2016-08-02 18:35:55 +03:00
2016-02-26 15:26:32 +03:00
using namespace std;
namespace Moses2
{
2016-02-26 15:26:32 +03:00
namespace SCFG
{
2016-03-31 23:00:16 +03:00
Manager::Manager(System &sys, const TranslationTask &task,
2016-06-14 23:30:12 +03:00
const std::string &inputStr, long translationId)
:ManagerBase(sys, task, inputStr, translationId)
2016-02-26 15:51:50 +03:00
{
}
2016-02-26 15:26:32 +03:00
2016-02-26 15:51:50 +03:00
Manager::~Manager()
{
2016-02-26 15:26:32 +03:00
}
void Manager::Decode()
{
2016-03-31 23:00:16 +03:00
// init pools etc
2016-04-16 16:56:15 +03:00
//cerr << "START InitPools()" << endl;
2016-03-31 23:00:16 +03:00
InitPools();
2016-06-22 18:43:43 +03:00
//cerr << "START ParseInput()" << endl;
2016-02-26 15:35:24 +03:00
2016-04-27 12:36:15 +03:00
FactorCollection &vocab = system.GetVocab();
m_input = Sentence::CreateFromString(GetPool(), vocab, system, m_inputStr,
m_translationId);
const SCFG::Sentence &sentence = static_cast<const SCFG::Sentence&>(GetInput());
2016-04-27 12:36:15 +03:00
size_t inputSize = sentence.GetSize();
2016-06-22 18:43:43 +03:00
//cerr << "inputSize=" << inputSize << endl;
2016-03-02 00:41:32 +03:00
2016-04-27 12:36:15 +03:00
m_inputPaths.Init(sentence, *this);
2016-06-22 18:43:43 +03:00
//cerr << "CREATED m_inputPaths" << endl;
2016-03-02 00:41:32 +03:00
2016-04-16 20:59:15 +03:00
m_stacks.Init(*this, inputSize);
2016-06-22 18:43:43 +03:00
//cerr << "CREATED m_stacks" << endl;
2016-03-01 02:28:24 +03:00
2016-04-16 20:59:15 +03:00
for (int startPos = inputSize - 1; startPos >= 0; --startPos) {
2016-06-22 18:43:43 +03:00
//cerr << endl << "startPos=" << startPos << endl;
2016-06-16 19:14:18 +03:00
SCFG::InputPath &initPath = *m_inputPaths.GetMatrix().GetValue(startPos, 0);
2016-06-01 16:29:48 +03:00
2016-06-22 18:43:43 +03:00
//cerr << "BEFORE InitActiveChart=" << initPath.Debug(system) << endl;
2016-06-16 19:14:18 +03:00
InitActiveChart(initPath);
2016-06-22 18:43:43 +03:00
//cerr << "AFTER InitActiveChart=" << initPath.Debug(system) << endl;
2016-03-02 00:41:32 +03:00
int maxPhraseSize = inputSize - startPos + 1;
for (int phraseSize = 1; phraseSize < maxPhraseSize; ++phraseSize) {
2016-06-22 18:43:43 +03:00
//cerr << endl << "phraseSize=" << phraseSize << endl;
2016-06-16 19:14:18 +03:00
SCFG::InputPath &path = *m_inputPaths.GetMatrix().GetValue(startPos, phraseSize);
2016-04-29 01:45:23 +03:00
Stack &stack = m_stacks.GetStack(startPos, phraseSize);
2016-06-22 18:43:43 +03:00
//cerr << "BEFORE LOOKUP path=" << path.Debug(system) << endl;
2016-04-29 01:45:23 +03:00
Lookup(path);
2016-06-22 18:43:43 +03:00
//cerr << "AFTER LOOKUP path=" << path.Debug(system) << endl;
2016-04-29 01:45:23 +03:00
Decode(path, stack);
2016-06-22 18:43:43 +03:00
//cerr << "AFTER DECODE path=" << path.Debug(system) << endl;
2016-06-01 17:59:35 +03:00
2016-04-29 01:45:23 +03:00
LookupUnary(path);
2016-06-22 18:43:43 +03:00
//cerr << "AFTER LookupUnary path=" << path.Debug(system) << endl;
2016-04-29 02:09:02 +03:00
2016-04-29 14:23:43 +03:00
//cerr << "#rules=" << path.GetNumRules() << endl;
2016-03-31 23:00:16 +03:00
}
}
2016-04-17 21:33:21 +03:00
2016-08-27 01:08:21 +03:00
/*
const Stack *stack;
stack = &m_stacks.GetStack(0, 5);
cerr << "stack 0,12:" << stack->Debug(system) << endl;
*/
2016-08-27 02:03:12 +03:00
//m_stacks.OutputStacks();
2016-02-26 15:26:32 +03:00
}
2016-06-01 16:29:48 +03:00
void Manager::InitActiveChart(SCFG::InputPath &path)
2016-03-02 00:41:32 +03:00
{
2016-03-31 23:00:16 +03:00
size_t numPt = system.mappings.size();
2016-04-16 16:56:15 +03:00
//cerr << "numPt=" << numPt << endl;
2016-03-31 23:00:16 +03:00
for (size_t i = 0; i < numPt; ++i) {
2016-04-14 17:55:13 +03:00
const PhraseTable &pt = *system.mappings[i];
2016-04-16 16:56:15 +03:00
//cerr << "START InitActiveChart" << endl;
2016-06-18 00:54:32 +03:00
pt.InitActiveChart(GetPool(), *this, path);
2016-04-16 16:56:15 +03:00
//cerr << "FINISHED InitActiveChart" << endl;
2016-03-31 23:00:16 +03:00
}
2016-04-15 15:38:01 +03:00
}
2016-06-02 14:58:29 +03:00
void Manager::Lookup(SCFG::InputPath &path)
2016-04-15 15:38:01 +03:00
{
size_t numPt = system.mappings.size();
2016-04-16 16:56:15 +03:00
//cerr << "numPt=" << numPt << endl;
2016-04-14 17:55:13 +03:00
2016-04-15 15:38:01 +03:00
for (size_t i = 0; i < numPt; ++i) {
const PhraseTable &pt = *system.mappings[i];
2016-05-25 18:22:24 +03:00
size_t maxChartSpan = system.maxChartSpans[i];
pt.Lookup(GetPool(), *this, maxChartSpan, m_stacks, path);
2016-04-15 15:38:01 +03:00
}
2016-04-16 16:56:15 +03:00
/*
2016-04-16 16:56:15 +03:00
size_t tpsNum = path.targetPhrases.GetSize();
if (tpsNum) {
2016-04-17 21:33:21 +03:00
cerr << tpsNum << " " << path << endl;
2016-04-16 16:56:15 +03:00
}
*/
2016-03-02 00:41:32 +03:00
}
2016-06-02 14:58:29 +03:00
void Manager::LookupUnary(SCFG::InputPath &path)
2016-04-29 01:41:09 +03:00
{
size_t numPt = system.mappings.size();
//cerr << "numPt=" << numPt << endl;
for (size_t i = 0; i < numPt; ++i) {
const PhraseTable &pt = *system.mappings[i];
pt.LookupUnary(GetPool(), *this, m_stacks, path);
}
/*
size_t tpsNum = path.targetPhrases.GetSize();
if (tpsNum) {
cerr << tpsNum << " " << path << endl;
}
*/
}
2016-06-02 14:58:29 +03:00
///////////////////////////////////////////////////////////////
// CUBE-PRUNING
///////////////////////////////////////////////////////////////
void Manager::Decode(SCFG::InputPath &path, Stack &stack)
{
2016-06-04 03:14:46 +03:00
// clear cube pruning data
2016-06-06 01:37:13 +03:00
//std::vector<QueueItem*> &container = Container(m_queue);
//container.clear();
Recycler<HypothesisBase*> &hypoRecycler = GetHypoRecycle();
while (!m_queue.empty()) {
QueueItem *item = m_queue.top();
m_queue.pop();
// recycle unused hypos from queue
Hypothesis *hypo = item->hypo;
hypoRecycler.Recycle(hypo);
// recycle queue item
m_queueItemRecycler.push_back(item);
}
2016-06-03 23:21:42 +03:00
2016-06-04 03:14:46 +03:00
m_seenPositions.clear();
// init queue
BOOST_FOREACH(const InputPath::Coll::value_type &valPair, path.targetPhrases) {
2016-06-02 14:58:29 +03:00
const SymbolBind &symbolBind = valPair.first;
const SCFG::TargetPhrases &tps = *valPair.second;
2016-06-02 18:21:15 +03:00
CreateQueue(path, symbolBind, tps);
2016-06-02 14:58:29 +03:00
}
2016-06-04 03:14:46 +03:00
// MAIN LOOP
2016-06-02 14:58:29 +03:00
size_t pops = 0;
while (!m_queue.empty() && pops < system.options.cube.pop_limit) {
//cerr << "pops=" << pops << endl;
2016-06-02 18:21:15 +03:00
QueueItem *item = m_queue.top();
m_queue.pop();
// add hypo to stack
Hypothesis *hypo = item->hypo;
//cerr << "hypo=" << *hypo << " " << endl;
2016-06-02 18:21:15 +03:00
stack.Add(hypo, GetHypoRecycle(), arcLists);
//cerr << "Added " << *hypo << " " << endl;
2016-06-02 18:21:15 +03:00
item->CreateNext(GetSystemPool(), GetPool(), *this, m_queue, m_seenPositions, path);
//cerr << "Created next " << endl;
2016-06-24 22:23:18 +03:00
m_queueItemRecycler.push_back(item);
2016-06-02 14:58:29 +03:00
++pops;
}
2016-06-02 15:57:29 +03:00
2016-06-02 14:58:29 +03:00
}
2016-06-02 18:21:15 +03:00
void Manager::CreateQueue(
const SCFG::InputPath &path,
const SymbolBind &symbolBind,
const SCFG::TargetPhrases &tps)
2016-06-02 14:58:29 +03:00
{
MemPool &pool = GetPool();
SeenPosition *seenItem = new (pool.Allocate<SeenPosition>()) SeenPosition(pool, symbolBind, tps, symbolBind.numNT);
bool unseen = m_seenPositions.Add(seenItem);
assert(unseen);
2016-06-06 01:29:27 +03:00
QueueItem *item = QueueItem::Create(GetPool(), *this);
2016-08-17 16:10:32 +03:00
item->Init(GetPool(), symbolBind, tps, seenItem->hypoIndColl);
2016-06-02 14:58:29 +03:00
for (size_t i = 0; i < symbolBind.coll.size(); ++i) {
const SymbolBindElement &ele = symbolBind.coll[i];
if (ele.hypos) {
const Moses2::Hypotheses *hypos = ele.hypos;
item->AddHypos(*hypos);
2016-06-02 14:58:29 +03:00
}
}
item->CreateHypo(GetSystemPool(), *this, path, symbolBind);
2016-06-27 23:50:03 +03:00
2016-06-28 17:20:50 +03:00
//cerr << "hypo=" << item->hypo->Debug(system) << endl;
2016-06-27 23:50:03 +03:00
2016-06-02 19:47:04 +03:00
m_queue.push(item);
2016-06-02 14:58:29 +03:00
}
///////////////////////////////////////////////////////////////
// NON CUBE-PRUNING
///////////////////////////////////////////////////////////////
/*
void Manager::Decode(SCFG::InputPath &path, Stack &stack)
2016-04-17 09:16:58 +03:00
{
2016-05-31 20:34:26 +03:00
//cerr << "path=" << path << endl;
2016-05-26 16:46:27 +03:00
boost::unordered_map<SCFG::SymbolBind, SCFG::TargetPhrases*>::const_iterator iterOuter;
2016-05-26 13:42:00 +03:00
for (iterOuter = path.targetPhrases->begin(); iterOuter != path.targetPhrases->end(); ++iterOuter) {
2016-04-20 22:22:57 +03:00
const SCFG::SymbolBind &symbolBind = iterOuter->first;
2016-05-26 16:46:27 +03:00
const SCFG::TargetPhrases &tps = *iterOuter->second;
2016-05-25 01:35:43 +03:00
//cerr << "symbolBind=" << symbolBind << " tps=" << tps.GetSize() << endl;
2016-04-20 19:13:05 +03:00
SCFG::TargetPhrases::const_iterator iter;
for (iter = tps.begin(); iter != tps.end(); ++iter) {
const SCFG::TargetPhraseImpl &tp = **iter;
2016-05-25 01:35:43 +03:00
//cerr << "tp=" << tp << endl;
2016-05-05 19:41:50 +03:00
ExpandHypo(path, symbolBind, tp, stack);
2016-04-20 19:13:05 +03:00
}
2016-04-17 09:16:58 +03:00
}
}
2016-06-02 14:58:29 +03:00
*/
2016-04-17 09:16:58 +03:00
2016-05-31 20:34:26 +03:00
void Manager::ExpandHypo(
const SCFG::InputPath &path,
const SCFG::SymbolBind &symbolBind,
const SCFG::TargetPhraseImpl &tp,
Stack &stack)
{
Recycler<HypothesisBase*> &hypoRecycler = GetHypoRecycle();
std::vector<const SymbolBindElement*> ntEles = symbolBind.GetNTElements();
2016-06-03 18:19:22 +03:00
Vector<size_t> prevHyposIndices(GetPool(), symbolBind.numNT);
2016-05-31 20:34:26 +03:00
assert(ntEles.size() == symbolBind.numNT);
//cerr << "ntEles:" << ntEles.size() << endl;
size_t ind = 0;
while (IncrPrevHypoIndices(prevHyposIndices, ind, ntEles)) {
2016-06-05 23:35:38 +03:00
SCFG::Hypothesis *hypo = SCFG::Hypothesis::Create(GetSystemPool(), *this);
2016-05-31 20:34:26 +03:00
hypo->Init(*this, path, symbolBind, tp, prevHyposIndices);
hypo->EvaluateWhenApplied();
2016-06-02 01:44:46 +03:00
stack.Add(hypo, hypoRecycler, arcLists);
2016-05-31 20:34:26 +03:00
++ind;
}
}
2016-05-06 13:09:52 +03:00
bool Manager::IncrPrevHypoIndices(
2016-06-03 18:19:22 +03:00
Vector<size_t> &prevHyposIndices,
2016-05-06 13:09:52 +03:00
size_t ind,
const std::vector<const SymbolBindElement*> ntEles)
{
if (ntEles.size() == 0) {
// no nt. Do the 1st
return ind ? false : true;
}
size_t numHypos = 0;
2016-05-25 01:35:43 +03:00
//cerr << "IncrPrevHypoIndices:" << ind << " " << ntEles.size() << " ";
2016-05-06 13:09:52 +03:00
for (size_t i = 0; i < ntEles.size() - 1; ++i) {
const SymbolBindElement &ele = *ntEles[i];
const Hypotheses &hypos = *ele.hypos;
2016-05-06 13:09:52 +03:00
numHypos = hypos.size();
std::div_t divRet = std::div((int)ind, (int)numHypos);
ind = divRet.quot;
size_t hypoInd = divRet.rem;
prevHyposIndices[i] = hypoInd;
2016-05-25 01:35:43 +03:00
//cerr << "(" << i << "," << ind << "," << numHypos << "," << hypoInd << ")";
2016-05-06 13:09:52 +03:00
}
// last
prevHyposIndices.back() = ind;
// check if last is over limit
const SymbolBindElement &ele = *ntEles.back();
const Hypotheses &hypos = *ele.hypos;
2016-05-06 13:09:52 +03:00
numHypos = hypos.size();
2016-05-25 01:35:43 +03:00
//cerr << "(" << (ntEles.size() - 1) << "," << ind << "," << numHypos << "," << ind << ")";
//cerr << endl;
2016-05-06 13:09:52 +03:00
if (ind >= numHypos) {
return false;
}
else {
return true;
}
}
2016-05-06 17:41:50 +03:00
std::string Manager::OutputBest() const
{
2016-06-29 13:50:29 +03:00
string out;
2016-05-06 17:41:50 +03:00
const Stack &lastStack = m_stacks.GetLastStack();
2016-08-06 10:16:31 +03:00
const SCFG::Hypothesis *bestHypo = lastStack.GetBestHypo();
2016-05-06 17:41:50 +03:00
if (bestHypo) {
2016-08-04 21:34:53 +03:00
//cerr << "BEST TRANSLATION: " << bestHypo << bestHypo->Debug(system) << endl;
2016-06-23 00:23:51 +03:00
//cerr << " " << out.str() << endl;
2016-06-29 13:50:29 +03:00
stringstream outStrm;
2016-08-14 22:56:08 +03:00
Moses2::FixPrecision(outStrm);
2016-06-29 13:50:29 +03:00
bestHypo->OutputToStream(outStrm);
2016-06-29 13:45:33 +03:00
2016-06-29 13:50:29 +03:00
out = outStrm.str();
out = out.substr(4, out.size() - 10);
2016-06-29 13:45:33 +03:00
if (system.options.output.ReportHypoScore) {
2016-06-29 13:50:29 +03:00
out = SPrint(bestHypo->GetScores().GetTotalScore()) + " " + out;
2016-06-29 13:45:33 +03:00
}
2016-05-06 17:41:50 +03:00
}
else {
if (system.options.output.ReportHypoScore) {
2016-06-29 13:50:29 +03:00
out = "0 ";
}
2016-08-04 21:34:53 +03:00
//cerr << "NO TRANSLATION " << GetTranslationId() << endl;
2016-05-06 17:41:50 +03:00
}
2016-06-29 13:50:29 +03:00
return out;
2016-02-26 15:26:32 +03:00
}
2016-05-06 17:41:50 +03:00
2016-08-02 18:35:55 +03:00
std::string Manager::OutputNBest()
{
2016-08-03 11:37:49 +03:00
stringstream out;
2016-08-18 19:41:37 +03:00
//Moses2::FixPrecision(out);
2016-08-02 18:35:55 +03:00
2016-08-03 14:12:39 +03:00
arcLists.Sort();
2016-08-24 00:38:27 +03:00
//cerr << "arcs=" << arcLists.Debug(system) << endl;
2016-08-03 14:12:39 +03:00
2016-08-02 18:35:55 +03:00
KBestExtractor extractor(*this);
2016-08-03 13:03:11 +03:00
extractor.OutputToStream(out);
2016-08-02 18:35:55 +03:00
2016-08-03 11:37:49 +03:00
return out.str();
2016-08-02 18:35:55 +03:00
}
2016-08-15 15:21:29 +03:00
std::string Manager::OutputTransOpt()
{
const Stack &lastStack = m_stacks.GetLastStack();
const SCFG::Hypothesis *bestHypo = lastStack.GetBestHypo();
if (bestHypo) {
stringstream outStrm;
bestHypo->OutputTransOpt(outStrm);
return outStrm.str();
}
else {
return "";
}
}
2016-05-06 17:41:50 +03:00
} // namespace
2016-02-26 15:26:32 +03:00
}