mosesdecoder/moses/TranslationOptionCollectionConfusionNet.cpp

319 lines
12 KiB
C++
Raw Permalink Normal View History

// -*- mode: c++; indent-tabs-mode: nil; tab-width:2 -*-
#include <list>
#include <vector>
#include "TranslationOptionCollectionConfusionNet.h"
2013-06-21 04:17:17 +04:00
#include "ConfusionNet.h"
#include "DecodeGraph.h"
#include "DecodeStepTranslation.h"
2013-08-07 15:11:39 +04:00
#include "DecodeStepGeneration.h"
#include "FactorCollection.h"
#include "FF/InputFeature.h"
#include "TranslationModel/PhraseDictionaryTreeAdaptor.h"
#include "util/exception.hh"
#include <boost/foreach.hpp>
#include "TranslationTask.h"
using namespace std;
namespace Moses
{
/** Constructor: initializes the base class, then enumerates the input paths
 *  through the confusion network and stores them in m_inputPathMatrix.
 *
 *  m_inputPathMatrix[startPos][len-1] holds the list of paths that cover the
 *  span [startPos, startPos+len-1]. Every created path is also appended to
 *  m_inputPathQueue.
 *
 *  \param ttask translation task; its options supply the maximum phrase length
 *  \param input confusion network to enumerate paths over
 *
 *  Throws (via UTIL_THROW_IF2) if no InputFeature instance is registered or
 *  if a previously built path carries no input score.
 */
TranslationOptionCollectionConfusionNet::
TranslationOptionCollectionConfusionNet(ttasksptr const& ttask,
                                        const ConfusionNet &input)
  : TranslationOptionCollection(ttask,input)
{
  // Prefix checkers are phrase dictionaries that provide a prefix check
  // to indicate that a phrase table entry with a given prefix exists.
  // If no entry with the given prefix exists, there is no point in
  // expanding it further.
  vector<PhraseDictionary*> prefixCheckers;
  BOOST_FOREACH(PhraseDictionary* pd, PhraseDictionary::GetColl())
  if (pd->ProvidesPrefixCheck()) prefixCheckers.push_back(pd);

  const InputFeature *inputFeature = InputFeature::InstancePtr();
  UTIL_THROW_IF2(inputFeature == NULL, "Input feature must be specified");

  size_t inputSize = input.GetSize();
  m_inputPathMatrix.resize(inputSize);

  // Never build phrases longer than the input itself.
  size_t maxSizePhrase = ttask->options()->search.max_phrase_length;
  maxSizePhrase = std::min(inputSize, maxSizePhrase);

  // 1-word phrases: one path per alternative in each column. No prefix
  // check is applied here, so the list for a single position has exactly
  // one entry per column alternative, in column order.
  for (size_t startPos = 0; startPos < inputSize; ++startPos) {
    vector<InputPathList> &vec = m_inputPathMatrix[startPos];
    vec.push_back(InputPathList());
    InputPathList &list = vec.back();

    Range range(startPos, startPos);
    const NonTerminalSet &labels = input.GetLabelSet(startPos, startPos);
    const ConfusionNet::Column &col = input.GetColumn(startPos);

    for (size_t i = 0; i < col.size(); ++i) {
      const Word &word = col[i].first;
      Phrase subphrase;
      subphrase.AddWord(word);

      const ScorePair &scores = col[i].second;
      ScorePair *inputScore = new ScorePair(scores);

      InputPath* path = new InputPath(ttask.get(), subphrase, labels,
                                      range, NULL, inputScore);
      list.push_back(path);
      m_inputPathQueue.push_back(path);
    }
  }

  // subphrases of 2+ words: extend every path of length (phraseSize - 1)
  // by each word alternative found in the column at endPos.
  for (size_t phraseSize = 2; phraseSize <= maxSizePhrase; ++phraseSize) {
    // phraseSize <= maxSizePhrase <= inputSize, so the bound cannot wrap.
    for (size_t startPos = 0; startPos < inputSize - phraseSize + 1; ++startPos) {
      size_t endPos = startPos + phraseSize -1;

      Range range(startPos, endPos);
      const NonTerminalSet &labels = input.GetLabelSet(startPos, endPos);

      // Append the (initially empty) list for this span before looking up
      // the shorter paths, which live in the same vector.
      vector<InputPathList> &vec = m_inputPathMatrix[startPos];
      vec.push_back(InputPathList());
      InputPathList &list = vec.back();

      // loop thru every previous path
      const InputPathList &prevPaths = GetInputPathList(startPos, endPos - 1);

      InputPathList::const_iterator iterPath;
      for (iterPath = prevPaths.begin(); iterPath != prevPaths.end(); ++iterPath) {
        const InputPath &prevPath = **iterPath;

        const Phrase &prevPhrase = prevPath.GetPhrase();
        const ScorePair *prevInputScore = prevPath.GetInputScore();
        UTIL_THROW_IF2(prevInputScore == NULL,
                       "No input score for path: " << prevPath);

        // loop thru every word at this position
        const ConfusionNet::Column &col = input.GetColumn(endPos);

        for (size_t i = 0; i < col.size(); ++i) {
          const Word &word = col[i].first;
          Phrase subphrase(prevPhrase);
          subphrase.AddWord(word);

          // Keep the extension if there are no prefix checkers at all, or
          // if at least one of them reports that the prefix exists.
          bool OK = prefixCheckers.size() == 0;
          for (size_t k = 0; !OK && k < prefixCheckers.size(); ++k)
            OK = prefixCheckers[k]->PrefixExists(m_ttask.lock(), subphrase);
          if (!OK) continue;

          // Score of the extended path = score of the prefix path plus the
          // score of the newly appended word.
          const ScorePair &scores = col[i].second;
          ScorePair *inputScore = new ScorePair(*prevInputScore);
          inputScore->PlusEquals(scores);

          InputPath *path = new InputPath(ttask.get(), subphrase, labels, range,
                                          &prevPath, inputScore);
          list.push_back(path);
          m_inputPathQueue.push_back(path);
        } // for (size_t i = 0; i < col.size(); ++i) {
      } // for (iterPath = prevPaths.begin(); iterPath != prevPaths.end(); ++iterPath) {
    }
  }
}
/** Return the list of input paths covering the span [startPos, endPos].
 *
 *  The list is stored at m_inputPathMatrix[startPos][endPos - startPos].
 *  Throws (via UTIL_THROW_IF2) when either index is out of range. An endPos
 *  smaller than startPos wraps around to a very large unsigned offset and is
 *  therefore rejected by the offset check as well.
 */
InputPathList &TranslationOptionCollectionConfusionNet::GetInputPathList(size_t startPos, size_t endPos)
{
  // Validate the row index first; indexing the outer vector unchecked
  // would be undefined behaviour for an out-of-range start position.
  UTIL_THROW_IF2(startPos >= m_inputPathMatrix.size(),
                 "Out of bound access: start position " << startPos);

  size_t offset = endPos - startPos;
  UTIL_THROW_IF2(offset >= m_inputPathMatrix[startPos].size(),
                 "Out of bound access: " << offset);
  return m_inputPathMatrix[startPos][offset];
}
/* forcibly create translation option for a particular source word.
* call the base class' ProcessOneUnknownWord() for each possible word in the confusion network
* at a particular source position
*/
void TranslationOptionCollectionConfusionNet::ProcessUnknownWord(size_t sourcePos)
{
2015-10-18 14:41:36 +03:00
ConfusionNet const& source=static_cast<ConfusionNet const&>(m_source);
ConfusionNet::Column const& coll=source.GetColumn(sourcePos);
const InputPathList &inputPathList = GetInputPathList(sourcePos, sourcePos);
ConfusionNet::Column::const_iterator iterCol;
InputPathList::const_iterator iterInputPath;
size_t j=0;
for(iterCol = coll.begin(), iterInputPath = inputPathList.begin();
2013-08-16 00:14:04 +04:00
iterCol != coll.end();
++iterCol , ++iterInputPath) {
const InputPath &inputPath = **iterInputPath;
size_t length = source.GetColumnIncrement(sourcePos, j++);
2013-09-08 21:22:55 +04:00
const ScorePair &inputScores = iterCol->second;
ProcessOneUnknownWord(inputPath ,sourcePos, length, &inputScores);
}
}
/** Create all translation options for the input.
 *  Unless the legacy phrase table is in use, GetTargetPhraseCollectionBatch()
 *  is called first; the base class implementation then does the rest.
 */
void
TranslationOptionCollectionConfusionNet
::CreateTranslationOptions()
{
  const bool legacyPT = StaticData::Instance().GetUseLegacyPT();
  if (!legacyPT)
    GetTargetPhraseCollectionBatch();

  TranslationOptionCollection::CreateTranslationOptions();
}
/** create translation options that exactly cover a specific input span.
* Called by CreateTranslationOptions() and ProcessUnknownWord()
* \param decodeGraph list of decoding steps
* \param factorCollection input sentence with all factors
* \param startPos first position in input sentence
* \param lastPos last position in input sentence
* \param adhereTableLimit whether phrase & generation table limits are adhered to
2015-02-19 15:27:23 +03:00
* \return true if there is at least one path for the range has matches
* in the source side of the parallel data, i.e., the phrase prefix exists
* (abortion condition for trie-based lookup if false)
*/
bool
TranslationOptionCollectionConfusionNet::
CreateTranslationOptionsForRange(const DecodeGraph &decodeGraph,
2015-02-19 15:27:23 +03:00
size_t startPos, size_t endPos,
bool adhereTableLimit, size_t graphInd)
{
if (StaticData::Instance().GetUseLegacyPT()) {
2015-02-19 15:27:23 +03:00
return CreateTranslationOptionsForRangeLEGACY(decodeGraph, startPos, endPos,
adhereTableLimit, graphInd);
2013-07-11 19:20:15 +04:00
} else {
2015-02-19 15:27:23 +03:00
return CreateTranslationOptionsForRangeNew(decodeGraph, startPos, endPos,
adhereTableLimit, graphInd);
}
}
bool
TranslationOptionCollectionConfusionNet::
CreateTranslationOptionsForRangeNew
( const DecodeGraph &decodeGraph, size_t startPos, size_t endPos,
bool adhereTableLimit, size_t graphInd)
{
InputPathList &inputPathList = GetInputPathList(startPos, endPos);
if (inputPathList.size() == 0) return false; // no input path matches!
InputPathList::iterator iter;
for (iter = inputPathList.begin(); iter != inputPathList.end(); ++iter) {
2013-07-11 19:20:15 +04:00
InputPath &inputPath = **iter;
TranslationOptionCollection::CreateTranslationOptionsForRange
2015-02-19 15:27:23 +03:00
(decodeGraph, startPos, endPos, adhereTableLimit, graphInd, inputPath);
}
return true;
}
bool
TranslationOptionCollectionConfusionNet::
2015-12-09 03:00:35 +03:00
CreateTranslationOptionsForRangeLEGACY(const DecodeGraph &decodeGraph,
size_t startPos, size_t endPos,
2015-12-07 19:07:11 +03:00
bool adhereTableLimit, size_t graphInd)
{
bool retval = true;
2015-12-09 03:00:35 +03:00
size_t const max_phrase_length
2015-12-10 06:17:36 +03:00
= StaticData::Instance().options()->search.max_phrase_length;
XmlInputType intype = m_ttask.lock()->options()->input.xml_policy;
if ((intype != XmlExclusive) || !HasXmlOptionsOverlappingRange(startPos,endPos)) {
2013-08-07 17:18:12 +04:00
InputPathList &inputPathList = GetInputPathList(startPos, endPos);
2015-02-19 15:27:23 +03:00
2013-07-11 19:20:15 +04:00
// partial trans opt stored in here
2015-12-07 19:07:11 +03:00
PartialTranslOptColl* oldPtoc = new PartialTranslOptColl(max_phrase_length);
2013-07-11 19:20:15 +04:00
size_t totalEarlyPruned = 0;
2013-07-11 19:20:15 +04:00
// initial translation step
list <const DecodeStep* >::const_iterator iterStep = decodeGraph.begin();
const DecodeStep &decodeStep = **iterStep;
2015-11-02 03:00:37 +03:00
DecodeStepTranslation const& dstep
= static_cast<const DecodeStepTranslation&>(decodeStep);
dstep.ProcessInitialTransLEGACY(m_source, *oldPtoc, startPos, endPos,
adhereTableLimit, inputPathList);
2013-07-11 19:20:15 +04:00
// do rest of decode steps
int indexStep = 0;
for (++iterStep ; iterStep != decodeGraph.end() ; ++iterStep) {
2013-08-07 15:11:39 +04:00
const DecodeStep *decodeStep = *iterStep;
const DecodeStepTranslation *transStep =dynamic_cast<const DecodeStepTranslation*>(decodeStep);
const DecodeStepGeneration *genStep =dynamic_cast<const DecodeStepGeneration*>(decodeStep);
2015-12-07 19:07:11 +03:00
PartialTranslOptColl* newPtoc = new PartialTranslOptColl(max_phrase_length);
2013-07-11 19:20:15 +04:00
// go thru each intermediate trans opt just created
const vector<TranslationOption*>& partTransOptList = oldPtoc->GetList();
vector<TranslationOption*>::const_iterator iterPartialTranslOpt;
2014-05-19 17:34:27 +04:00
for (iterPartialTranslOpt = partTransOptList.begin();
iterPartialTranslOpt != partTransOptList.end();
++iterPartialTranslOpt) {
2013-07-11 19:20:15 +04:00
TranslationOption &inputPartialTranslOpt = **iterPartialTranslOpt;
2013-08-07 15:11:39 +04:00
if (transStep) {
2013-08-24 00:34:10 +04:00
transStep->ProcessLEGACY(inputPartialTranslOpt
2013-08-07 17:18:12 +04:00
, *decodeStep
, *newPtoc
, this
, adhereTableLimit);
2013-08-07 17:18:12 +04:00
} else {
assert(genStep);
2013-08-07 17:18:12 +04:00
genStep->Process(inputPartialTranslOpt
2013-08-07 15:11:39 +04:00
, *decodeStep
2013-07-11 19:20:15 +04:00
, *newPtoc
, this
, adhereTableLimit);
2013-08-07 15:11:39 +04:00
}
}
2013-07-11 19:20:15 +04:00
// last but 1 partial trans not required anymore
totalEarlyPruned += newPtoc->GetPrunedCount();
delete oldPtoc;
2013-07-11 19:20:15 +04:00
oldPtoc = newPtoc;
indexStep++;
} // for (++iterStep
// add to fully formed translation option list
PartialTranslOptColl &lastPartialTranslOptColl = *oldPtoc;
const vector<TranslationOption*>& partTransOptList = lastPartialTranslOptColl.GetList();
vector<TranslationOption*>::const_iterator iterColl;
for (iterColl = partTransOptList.begin() ; iterColl != partTransOptList.end() ; ++iterColl) {
TranslationOption *transOpt = *iterColl;
Add(transOpt);
}
lastPartialTranslOptColl.DetachAll();
totalEarlyPruned += oldPtoc->GetPrunedCount();
delete oldPtoc;
// TRACE_ERR( "Early translation options pruned: " << totalEarlyPruned << endl);
} // if ((intype != XmlExclusive) || !HasXmlOptionsOverlappingRange(startPos,endPos))
2015-02-19 15:27:23 +03:00
if (graphInd == 0 && intype != XmlPassThrough &&
HasXmlOptionsOverlappingRange(startPos,endPos)) {
CreateXmlOptionsForRange(startPos, endPos);
}
return retval;
}
}