correct creation of input paths for lattices

This commit is contained in:
Hieu Hoang 2013-10-02 18:42:56 +01:00
parent 1b12b0c4a2
commit 6e32bd3e19
5 changed files with 396 additions and 0 deletions

View File

@ -871,6 +871,16 @@
<type>1</type>
<locationURI>PARENT-3-PROJECT_LOC/moses/TranslationOptionCollectionConfusionNet.h</locationURI>
</link>
<link>
<name>TranslationOptionCollectionLattice.cpp</name>
<type>1</type>
<locationURI>PARENT-3-PROJECT_LOC/moses/TranslationOptionCollectionLattice.cpp</locationURI>
</link>
<link>
<name>TranslationOptionCollectionLattice.h</name>
<type>1</type>
<locationURI>PARENT-3-PROJECT_LOC/moses/TranslationOptionCollectionLattice.h</locationURI>
</link>
<link>
<name>TranslationOptionCollectionText.cpp</name>
<type>1</type>

View File

@ -0,0 +1,319 @@
// $Id$
#include <list>
#include "TranslationOptionCollectionLattice.h"
#include "ConfusionNet.h"
#include "WordLattice.h"
#include "DecodeStep.h"
#include "DecodeStepTranslation.h"
#include "DecodeStepGeneration.h"
#include "FactorCollection.h"
#include "FF/InputFeature.h"
#include "TranslationModel/PhraseDictionaryTreeAdaptor.h"
using namespace std;
namespace Moses
{
/** constructor; just initialize the base class */
TranslationOptionCollectionLattice::TranslationOptionCollectionLattice(
const ConfusionNet &input
, size_t maxNoTransOptPerCoverage, float translationOptionThreshold)
: TranslationOptionCollection(input, maxNoTransOptPerCoverage, translationOptionThreshold)
{
const InputFeature *inputFeature = StaticData::Instance().GetInputFeature();
CHECK(inputFeature);
const WordLattice *lattice = dynamic_cast<const WordLattice*>(&input);
if (lattice) {
cerr << *lattice << endl;
}
size_t size = input.GetSize();
m_inputPathMatrix.resize(size);
// 1-word phrases
for (size_t startPos = 0; startPos < size; ++startPos) {
vector<InputPathList> &vec = m_inputPathMatrix[startPos];
vec.push_back(InputPathList());
InputPathList &list = vec.back();
const std::vector<size_t> *nextNodes = NULL;
if (lattice) {
nextNodes = &lattice->GetNextNodes(startPos);
}
WordsRange range(startPos, startPos);
const NonTerminalSet &labels = input.GetLabelSet(startPos, startPos);
const ConfusionNet::Column &col = input.GetColumn(startPos);
for (size_t i = 0; i < col.size(); ++i) {
const Word &word = col[i].first;
Phrase subphrase;
subphrase.AddWord(word);
const ScorePair &scores = col[i].second;
ScorePair *inputScore = new ScorePair(scores);
InputPath *path = new InputPath(subphrase, labels, range, NULL, inputScore);
if (nextNodes) {
size_t nextNode = nextNodes->at(i);
path->SetNextNode(nextNode);
}
list.push_back(path);
m_phraseDictionaryQueue.push_back(path);
}
}
// subphrases of 2+ words
for (size_t phaseSize = 2; phaseSize <= size; ++phaseSize) {
for (size_t startPos = 0; startPos < size - phaseSize + 1; ++startPos) {
size_t endPos = startPos + phaseSize -1;
const std::vector<size_t> *nextNodes = NULL;
if (lattice) {
nextNodes = &lattice->GetNextNodes(endPos);
}
WordsRange range(startPos, endPos);
const NonTerminalSet &labels = input.GetLabelSet(startPos, endPos);
vector<InputPathList> &vec = m_inputPathMatrix[startPos];
vec.push_back(InputPathList());
InputPathList &list = vec.back();
// loop thru every previous path
const InputPathList &prevPaths = GetInputPathList(startPos, endPos - 1);
int prevNodesInd = 0;
InputPathList::const_iterator iterPath;
for (iterPath = prevPaths.begin(); iterPath != prevPaths.end(); ++iterPath) {
//for (size_t pathInd = 0; pathInd < prevPaths.size(); ++pathInd) {
const InputPath &prevPath = **iterPath;
//const InputPath &prevPath = *prevPaths[pathInd];
const Phrase &prevPhrase = prevPath.GetPhrase();
const ScorePair *prevInputScore = prevPath.GetInputScore();
CHECK(prevInputScore);
// loop thru every word at this position
const ConfusionNet::Column &col = input.GetColumn(endPos);
for (size_t i = 0; i < col.size(); ++i) {
const Word &word = col[i].first;
Phrase subphrase(prevPhrase);
subphrase.AddWord(word);
const ScorePair &scores = col[i].second;
ScorePair *inputScore = new ScorePair(*prevInputScore);
inputScore->PlusEquals(scores);
InputPath *path = new InputPath(subphrase, labels, range, &prevPath, inputScore);
if (nextNodes) {
size_t nextNode = nextNodes->at(i);
path->SetNextNode(nextNode);
}
list.push_back(path);
m_phraseDictionaryQueue.push_back(path);
} // for (size_t i = 0; i < col.size(); ++i) {
++prevNodesInd;
} // for (iterPath = prevPaths.begin(); iterPath != prevPaths.end(); ++iterPath) {
}
}
// check whether we should be using the old code to supportbinary phrase-table.
// eventually, we'll stop support the binary phrase-table and delete this legacy code
CheckLEGACY();
}
InputPathList &TranslationOptionCollectionLattice::GetInputPathList(size_t startPos, size_t endPos)
{
size_t offset = endPos - startPos;
CHECK(offset < m_inputPathMatrix[startPos].size());
return m_inputPathMatrix[startPos][offset];
}
/* forcibly create translation option for a particular source word.
* call the base class' ProcessOneUnknownWord() for each possible word in the confusion network
* at a particular source position
*/
void TranslationOptionCollectionLattice::ProcessUnknownWord(size_t sourcePos)
{
ConfusionNet const& source=dynamic_cast<ConfusionNet const&>(m_source);
ConfusionNet::Column const& coll=source.GetColumn(sourcePos);
const InputPathList &inputPathList = GetInputPathList(sourcePos, sourcePos);
ConfusionNet::Column::const_iterator iterCol;
InputPathList::const_iterator iterInputPath;
size_t j=0;
for(iterCol = coll.begin(), iterInputPath = inputPathList.begin();
iterCol != coll.end();
++iterCol , ++iterInputPath) {
const InputPath &inputPath = **iterInputPath;
size_t length = source.GetColumnIncrement(sourcePos, j++);
const ScorePair &inputScores = iterCol->second;
ProcessOneUnknownWord(inputPath ,sourcePos, length, &inputScores);
}
}
void TranslationOptionCollectionLattice::CreateTranslationOptions()
{
if (!m_useLegacy) {
GetTargetPhraseCollectionBatch();
}
TranslationOptionCollection::CreateTranslationOptions();
}
/** create translation options that exactly cover a specific input span.
* Called by CreateTranslationOptions() and ProcessUnknownWord()
* \param decodeGraph list of decoding steps
* \param factorCollection input sentence with all factors
* \param startPos first position in input sentence
* \param lastPos last position in input sentence
* \param adhereTableLimit whether phrase & generation table limits are adhered to
*/
void TranslationOptionCollectionLattice::CreateTranslationOptionsForRange(
const DecodeGraph &decodeGraph
, size_t startPos
, size_t endPos
, bool adhereTableLimit
, size_t graphInd)
{
if (m_useLegacy) {
CreateTranslationOptionsForRangeLEGACY(decodeGraph, startPos, endPos, adhereTableLimit, graphInd);
} else {
CreateTranslationOptionsForRangeNew(decodeGraph, startPos, endPos, adhereTableLimit, graphInd);
}
}
void TranslationOptionCollectionLattice::CreateTranslationOptionsForRangeNew(
const DecodeGraph &decodeGraph
, size_t startPos
, size_t endPos
, bool adhereTableLimit
, size_t graphInd)
{
InputPathList &inputPathList = GetInputPathList(startPos, endPos);
InputPathList::iterator iter;
for (iter = inputPathList.begin(); iter != inputPathList.end(); ++iter) {
InputPath &inputPath = **iter;
TranslationOptionCollection::CreateTranslationOptionsForRange(decodeGraph
, startPos
, endPos
, adhereTableLimit
, graphInd
, inputPath);
}
}
void TranslationOptionCollectionLattice::CreateTranslationOptionsForRangeLEGACY(
const DecodeGraph &decodeGraph
, size_t startPos
, size_t endPos
, bool adhereTableLimit
, size_t graphInd)
{
if ((StaticData::Instance().GetXmlInputType() != XmlExclusive) || !HasXmlOptionsOverlappingRange(startPos,endPos)) {
InputPathList &inputPathList = GetInputPathList(startPos, endPos);
// partial trans opt stored in here
PartialTranslOptColl* oldPtoc = new PartialTranslOptColl;
size_t totalEarlyPruned = 0;
// initial translation step
list <const DecodeStep* >::const_iterator iterStep = decodeGraph.begin();
const DecodeStep &decodeStep = **iterStep;
static_cast<const DecodeStepTranslation&>(decodeStep).ProcessInitialTranslationLEGACY
(m_source, *oldPtoc
, startPos, endPos, adhereTableLimit, inputPathList );
// do rest of decode steps
int indexStep = 0;
for (++iterStep ; iterStep != decodeGraph.end() ; ++iterStep) {
const DecodeStep *decodeStep = *iterStep;
const DecodeStepTranslation *transStep =dynamic_cast<const DecodeStepTranslation*>(decodeStep);
const DecodeStepGeneration *genStep =dynamic_cast<const DecodeStepGeneration*>(decodeStep);
PartialTranslOptColl* newPtoc = new PartialTranslOptColl;
// go thru each intermediate trans opt just created
const vector<TranslationOption*>& partTransOptList = oldPtoc->GetList();
vector<TranslationOption*>::const_iterator iterPartialTranslOpt;
for (iterPartialTranslOpt = partTransOptList.begin() ; iterPartialTranslOpt != partTransOptList.end() ; ++iterPartialTranslOpt) {
TranslationOption &inputPartialTranslOpt = **iterPartialTranslOpt;
if (transStep) {
transStep->ProcessLEGACY(inputPartialTranslOpt
, *decodeStep
, *newPtoc
, this
, adhereTableLimit);
} else {
CHECK(genStep);
genStep->Process(inputPartialTranslOpt
, *decodeStep
, *newPtoc
, this
, adhereTableLimit);
}
}
// last but 1 partial trans not required anymore
totalEarlyPruned += newPtoc->GetPrunedCount();
delete oldPtoc;
oldPtoc = newPtoc;
indexStep++;
} // for (++iterStep
// add to fully formed translation option list
PartialTranslOptColl &lastPartialTranslOptColl = *oldPtoc;
const vector<TranslationOption*>& partTransOptList = lastPartialTranslOptColl.GetList();
vector<TranslationOption*>::const_iterator iterColl;
for (iterColl = partTransOptList.begin() ; iterColl != partTransOptList.end() ; ++iterColl) {
TranslationOption *transOpt = *iterColl;
Add(transOpt);
}
lastPartialTranslOptColl.DetachAll();
totalEarlyPruned += oldPtoc->GetPrunedCount();
delete oldPtoc;
// TRACE_ERR( "Early translation options pruned: " << totalEarlyPruned << endl);
} // if ((StaticData::Instance().GetXmlInputType() != XmlExclusive) || !HasXmlOptionsOverlappingRange(startPos,endPos))
if (graphInd == 0 && StaticData::Instance().GetXmlInputType() != XmlPassThrough && HasXmlOptionsOverlappingRange(startPos,endPos)) {
CreateXmlOptionsForRange(startPos, endPos);
}
}
void TranslationOptionCollectionLattice::CheckLEGACY()
{
const std::vector<PhraseDictionary*> &pts = StaticData::Instance().GetPhraseDictionaries();
for (size_t i = 0; i < pts.size(); ++i) {
const PhraseDictionary *phraseDictionary = pts[i];
if (dynamic_cast<const PhraseDictionaryTreeAdaptor*>(phraseDictionary) != NULL) {
m_useLegacy = true;
return;
}
}
m_useLegacy = false;
}
}

View File

@ -0,0 +1,55 @@
// $Id$
#pragma once
#include "TranslationOptionCollection.h"
#include "InputPath.h"
namespace Moses
{
class ConfusionNet;
/** Holds all translation options, for all spans, of a particular confusion network input
* Inherited from TranslationOptionCollection.
*/
class TranslationOptionCollectionLattice : public TranslationOptionCollection
{
public:
typedef std::vector< std::vector<InputPathList> > InputPathMatrix;
protected:
bool m_useLegacy;
InputPathMatrix m_inputPathMatrix; /*< contains translation options */
InputPathList &GetInputPathList(size_t startPos, size_t endPos);
void CreateTranslationOptionsForRangeNew(const DecodeGraph &decodeStepList
, size_t startPosition
, size_t endPosition
, bool adhereTableLimit
, size_t graphInd);
void CheckLEGACY();
void CreateTranslationOptionsForRangeLEGACY(const DecodeGraph &decodeStepList
, size_t startPosition
, size_t endPosition
, bool adhereTableLimit
, size_t graphInd);
public:
TranslationOptionCollectionLattice(const ConfusionNet &source, size_t maxNoTransOptPerCoverage, float translationOptionThreshold);
void ProcessUnknownWord(size_t sourcePos);
void CreateTranslationOptions();
void CreateTranslationOptionsForRange(const DecodeGraph &decodeStepList
, size_t startPosition
, size_t endPosition
, bool adhereTableLimit
, size_t graphInd);
protected:
};
}

View File

@ -4,6 +4,7 @@
#include "PCNTools.h"
#include "Util.h"
#include "FloydWarshall.h"
#include "TranslationOptionCollectionLattice.h"
#include "moses/FF/InputFeature.h"
#include "util/check.hh"
@ -203,6 +204,15 @@ bool WordLattice::CanIGetFromAToB(size_t start, size_t end) const
return distances[start][end] < 100000;
}
TranslationOptionCollection*
WordLattice::CreateTranslationOptionCollection() const
{
size_t maxNoTransOptPerCoverage = StaticData::Instance().GetMaxNoTransOptPerCoverage();
float translationOptionThreshold = StaticData::Instance().GetTranslationOptionThreshold();
TranslationOptionCollection *rv= new TranslationOptionCollectionLattice(*this, maxNoTransOptPerCoverage, translationOptionThreshold);
CHECK(rv);
return rv;
}
std::ostream& operator<<(std::ostream &out, const WordLattice &obj)
{

View File

@ -45,6 +45,8 @@ public:
const std::vector<size_t> &GetNextNodes(size_t pos) const
{ return next_nodes[pos]; }
TranslationOptionCollection *CreateTranslationOptionCollection() const;
};
}