// mosesdecoder/moses/TranslationOptionCollectionLattice.cpp
//
// (source viewer residue: 205 lines, 6.1 KiB, C++ -- Raw / Normal View / History)

// $Id$
#include <list>
#include "TranslationOptionCollectionLattice.h"
#include "ConfusionNet.h"
#include "WordLattice.h"
#include "DecodeGraph.h"
#include "DecodeStepTranslation.h"
#include "DecodeStepGeneration.h"
#include "FactorCollection.h"
#include "FF/InputFeature.h"
#include "TranslationModel/PhraseDictionaryTreeAdaptor.h"
#include "util/exception.hh"
2015-12-07 19:07:11 +03:00
#include "TranslationTask.h"
using namespace std;
namespace Moses
{
/** Construct the collection and enumerate all source paths through a lattice.
 *
 * Besides initializing the base class, this walks every lattice position and
 * every arc leaving it. For each arc it:
 *   - creates a one-word InputPath,
 *   - queues that path in m_inputPathQueue, and
 *   - calls Extend() to recursively append the following arcs.
 *
 * Any path whose span covers more than the configured
 * search.max_phrase_length positions is skipped.
 *
 * @param ttask  translation task; its options() supply max_phrase_length.
 * @param input  the word lattice being translated.
 * @throws util::Exception if the legacy binary phrase table is in use, if no
 *         InputFeature instance exists, or if the lattice contains an epsilon arc.
 */
TranslationOptionCollectionLattice
::TranslationOptionCollectionLattice
( ttasksptr const& ttask, const WordLattice &input)
  : TranslationOptionCollection(ttask, input)
{
  UTIL_THROW_IF2(StaticData::Instance().GetUseLegacyPT(),
                 "Not for models using the legacy binary phrase table");

  // An InputFeature must exist; only its presence is checked here.
  UTIL_THROW_IF2(InputFeature::InstancePtr() == NULL,
                 "Input feature must be specified");

  // Read once and reuse both for the 1-word filter below and for Extend(),
  // so the two length checks cannot disagree.
  const size_t maxPhraseLength = ttask->options()->search.max_phrase_length;
  const size_t size = input.GetSize();

  // Seed one single-arc path per (position, arc); Extend() grows each of them.
  for (size_t startPos = 0; startPos < size; ++startPos) {
    const std::vector<size_t> &nextNodes = input.GetNextNodes(startPos);
    const ConfusionNet::Column &col = input.GetColumn(startPos);
    for (size_t i = 0; i < col.size(); ++i) {
      const Word &word = col[i].first;
      UTIL_THROW_IF2(word.IsEpsilon(), "Epsilon not supported");

      // The arc starting at startPos spans positions
      // [startPos, startPos + nextNodes[i] - 1].
      const size_t nextNode = nextNodes[i];
      const size_t endPos = startPos + nextNode - 1;
      Range range(startPos, endPos);
      if (range.GetNumWordsCovered() > maxPhraseLength) {
        continue;
      }

      const NonTerminalSet &labels = input.GetLabelSet(startPos, endPos);
      Phrase subphrase;
      subphrase.AddWord(word);

      // Heap copy of the arc's scores handed to the new path
      // (NOTE(review): ownership presumably passes to InputPath -- confirm).
      ScorePair *inputScore = new ScorePair(col[i].second);
      InputPath *path
        = new InputPath(ttask.get(), subphrase, labels, range, NULL, inputScore);
      path->SetNextNode(nextNode);
      m_inputPathQueue.push_back(path);

      // Recursively append following arcs under the same length limit.
      Extend(*path, input, maxPhraseLength);
    }
  }
}
2015-12-09 03:00:35 +03:00
void
2015-12-07 19:07:11 +03:00
TranslationOptionCollectionLattice::
2015-12-09 03:00:35 +03:00
Extend(const InputPath &prevPath, const WordLattice &input,
2015-12-07 19:07:11 +03:00
size_t const maxPhraseLength)
{
2015-01-14 14:07:42 +03:00
size_t nextPos = prevPath.GetWordsRange().GetEndPos() + 1;
if (nextPos >= input.GetSize()) {
return;
}
2015-01-14 14:07:42 +03:00
size_t startPos = prevPath.GetWordsRange().GetStartPos();
const Phrase &prevPhrase = prevPath.GetPhrase();
const ScorePair *prevInputScore = prevPath.GetInputScore();
UTIL_THROW_IF2(prevInputScore == NULL,
"Null previous score");
2015-01-14 14:07:42 +03:00
const std::vector<size_t> &nextNodes = input.GetNextNodes(nextPos);
2015-01-14 14:07:42 +03:00
const ConfusionNet::Column &col = input.GetColumn(nextPos);
for (size_t i = 0; i < col.size(); ++i) {
const Word &word = col[i].first;
UTIL_THROW_IF2(word.IsEpsilon(), "Epsilon not supported");
2015-01-14 14:07:42 +03:00
size_t nextNode = nextNodes[i];
size_t endPos = nextPos + nextNode - 1;
2015-10-25 16:37:59 +03:00
Range range(startPos, endPos);
2015-12-07 19:07:11 +03:00
// size_t maxPhraseLength = StaticData::Instance().GetMaxPhraseLength();
2015-01-14 14:07:42 +03:00
if (range.GetNumWordsCovered() > maxPhraseLength) {
continue;
}
2015-01-14 14:07:42 +03:00
const NonTerminalSet &labels = input.GetLabelSet(startPos, endPos);
2015-01-14 14:07:42 +03:00
Phrase subphrase(prevPhrase);
subphrase.AddWord(word);
2015-01-14 14:07:42 +03:00
const ScorePair &scores = col[i].second;
ScorePair *inputScore = new ScorePair(*prevInputScore);
inputScore->PlusEquals(scores);
2015-10-19 02:00:40 +03:00
InputPath *path = new InputPath(prevPath.ttask, subphrase, labels,
range, &prevPath, inputScore);
2015-01-14 14:07:42 +03:00
path->SetNextNode(nextNode);
m_inputPathQueue.push_back(path);
2015-01-14 14:07:42 +03:00
// recursive
2015-12-07 19:07:11 +03:00
Extend(*path, input, maxPhraseLength);
2015-01-14 14:07:42 +03:00
}
}
void TranslationOptionCollectionLattice::CreateTranslationOptions()
{
GetTargetPhraseCollectionBatch();
VERBOSE(2,"Translation Option Collection\n " << *this << endl);
const vector <DecodeGraph*> &decodeGraphs = StaticData::Instance().GetDecodeGraphs();
UTIL_THROW_IF2(decodeGraphs.size() != 1, "Multiple decoder graphs not supported yet");
const DecodeGraph &decodeGraph = *decodeGraphs[0];
UTIL_THROW_IF2(decodeGraph.GetSize() != 1, "Factored decomposition not supported yet");
const DecodeStep &decodeStep = **decodeGraph.begin();
const PhraseDictionary &phraseDictionary = *decodeStep.GetPhraseDictionaryFeature();
for (size_t i = 0; i < m_inputPathQueue.size(); ++i) {
const InputPath &path = *m_inputPathQueue[i];
2015-10-19 02:00:40 +03:00
TargetPhraseCollection::shared_ptr tpColl
= path.GetTargetPhrases(phraseDictionary);
2015-10-25 16:37:59 +03:00
const Range &range = path.GetWordsRange();
if (tpColl && tpColl->GetSize()) {
2015-01-14 14:07:42 +03:00
TargetPhraseCollection::const_iterator iter;
for (iter = tpColl->begin(); iter != tpColl->end(); ++iter) {
const TargetPhrase &tp = **iter;
TranslationOption *transOpt = new TranslationOption(range, tp);
transOpt->SetInputPath(path);
transOpt->EvaluateWithSourceContext(m_source);
Add(transOpt);
}
} else if (path.GetPhrase().GetSize() == 1) {
// unknown word processing
ProcessOneUnknownWord(path, path.GetWordsRange().GetStartPos(), path.GetWordsRange().GetNumWordsCovered() , path.GetInputScore());
}
}
// Prune
Prune();
Sort();
// future score matrix
CalcEstimatedScore();
// Cached lex reodering costs
CacheLexReordering();
}
2015-02-19 15:27:23 +03:00
void
TranslationOptionCollectionLattice::
ProcessUnknownWord(size_t sourcePos)
{
UTIL_THROW(util::Exception, "ProcessUnknownWord() not implemented for lattice");
// why??? UG
}
2015-02-19 15:27:23 +03:00
bool
TranslationOptionCollectionLattice::
CreateTranslationOptionsForRange
2015-02-19 15:27:23 +03:00
(const DecodeGraph &decodeStepList, size_t startPosition, size_t endPosition,
bool adhereTableLimit, size_t graphInd)
{
2015-02-19 15:27:23 +03:00
UTIL_THROW(util::Exception,
"CreateTranslationOptionsForRange() not implemented for lattice");
}
} // namespace