correct creation of input paths for lattices

This commit is contained in:
Hieu Hoang 2013-10-03 16:58:47 +01:00
parent 8ccc99a947
commit 73513c182d
5 changed files with 26 additions and 76 deletions

View File

@ -2,7 +2,7 @@
#include <map>
#include <iostream>
#include <list>
#include <vector>
#include "Phrase.h"
#include "WordsRange.h"
#include "NonTerminal.h"
@ -17,7 +17,7 @@ class TargetPhrase;
class InputPath;
struct ScorePair;
typedef std::list<InputPath*> InputPathList;
typedef std::vector<InputPath*> InputPathList;
/** Each node contains
1. substring used to searching the phrase table

View File

@ -108,7 +108,7 @@ void PhraseDictionaryOnDisk::GetTargetPhraseCollectionBatch(InputPath &inputPath
const Phrase &phrase = inputPath.GetPhrase();
const InputPath *prevInputPath = inputPath.GetPrevPath();
//cerr << "inputPath=" << inputPath << endl;
cerr << "inputPath=" << inputPath << endl;
const OnDiskPt::PhraseNode *prevPtNode = NULL;

View File

@ -28,13 +28,9 @@ TranslationOptionCollectionLattice::TranslationOptionCollectionLattice(
CHECK(inputFeature);
size_t size = input.GetSize();
m_inputPathMatrix.resize(size);
// 1-word phrases
for (size_t startPos = 0; startPos < size; ++startPos) {
vector<InputPathList> &vec = m_inputPathMatrix[startPos];
vec.push_back(InputPathList());
InputPathList &list = vec.back();
const std::vector<size_t> &nextNodes = input.GetNextNodes(startPos);
@ -55,8 +51,6 @@ TranslationOptionCollectionLattice::TranslationOptionCollectionLattice(
size_t nextNode = nextNodes[i];
path->SetNextNode(nextNode);
list.push_back(path);
m_phraseDictionaryQueue.push_back(path);
}
}
@ -67,23 +61,21 @@ TranslationOptionCollectionLattice::TranslationOptionCollectionLattice(
size_t endPos = startPos + phaseSize -1;
const std::vector<size_t> &nextNodes = input.GetNextNodes(endPos);
WordsRange range(startPos, endPos);
const NonTerminalSet &labels = input.GetLabelSet(startPos, endPos);
vector<InputPathList> &vec = m_inputPathMatrix[startPos];
vec.push_back(InputPathList());
InputPathList &list = vec.back();
// loop thru every previous paths
size_t numPrevPaths = m_phraseDictionaryQueue.size();
// loop thru every previous path
const InputPathList &prevPaths = GetInputPathList(startPos, endPos - 1);
int prevNodesInd = 0;
InputPathList::const_iterator iterPath;
for (iterPath = prevPaths.begin(); iterPath != prevPaths.end(); ++iterPath) {
for (size_t i = 0; i < numPrevPaths; ++i) {
//for (size_t pathInd = 0; pathInd < prevPaths.size(); ++pathInd) {
const InputPath &prevPath = **iterPath;
//const InputPath &prevPath = *prevPaths[pathInd];
const InputPath &prevPath = *m_phraseDictionaryQueue[i];
size_t nextNode = prevPath.GetNextNode();
if (prevPath.GetWordsRange().GetEndPos() + nextNode != endPos) {
continue;
}
WordsRange range(prevPath.GetWordsRange().GetStartPos(), endPos);
const Phrase &prevPhrase = prevPath.GetPhrase();
const ScorePair *prevInputScore = prevPath.GetInputScore();
@ -106,22 +98,21 @@ TranslationOptionCollectionLattice::TranslationOptionCollectionLattice(
size_t nextNode = nextNodes[i];
path->SetNextNode(nextNode);
list.push_back(path);
cerr << *path << endl;
m_phraseDictionaryQueue.push_back(path);
} // for (size_t i = 0; i < col.size(); ++i) {
++prevNodesInd;
} // for (iterPath = prevPaths.begin(); iterPath != prevPaths.end(); ++iterPath) {
} // for (size_t i = 0; i < numPrevPaths; ++i) {
}
}
}
InputPathList &TranslationOptionCollectionLattice::GetInputPathList(size_t startPos, size_t endPos)
{
size_t offset = endPos - startPos;
CHECK(offset < m_inputPathMatrix[startPos].size());
return m_inputPathMatrix[startPos][offset];
// debug
for (size_t i = 0; i < m_phraseDictionaryQueue.size(); ++i) {
const InputPath &prevPath = *m_phraseDictionaryQueue[i];
cerr << prevPath << endl;
}
}
/* forcibly create translation option for a particular source word.
@ -130,22 +121,6 @@ InputPathList &TranslationOptionCollectionLattice::GetInputPathList(size_t start
*/
void TranslationOptionCollectionLattice::ProcessUnknownWord(size_t sourcePos)
{
ConfusionNet const& source=dynamic_cast<ConfusionNet const&>(m_source);
ConfusionNet::Column const& coll=source.GetColumn(sourcePos);
const InputPathList &inputPathList = GetInputPathList(sourcePos, sourcePos);
ConfusionNet::Column::const_iterator iterCol;
InputPathList::const_iterator iterInputPath;
size_t j=0;
for(iterCol = coll.begin(), iterInputPath = inputPathList.begin();
iterCol != coll.end();
++iterCol , ++iterInputPath) {
const InputPath &inputPath = **iterInputPath;
size_t length = source.GetColumnIncrement(sourcePos, j++);
const ScorePair &inputScores = iterCol->second;
ProcessOneUnknownWord(inputPath ,sourcePos, length, &inputScores);
}
}
@ -173,28 +148,7 @@ void TranslationOptionCollectionLattice::CreateTranslationOptionsForRange(
, bool adhereTableLimit
, size_t graphInd)
{
CreateTranslationOptionsForRangeNew(decodeGraph, startPos, endPos, adhereTableLimit, graphInd);
}
void TranslationOptionCollectionLattice::CreateTranslationOptionsForRangeNew(
const DecodeGraph &decodeGraph
, size_t startPos
, size_t endPos
, bool adhereTableLimit
, size_t graphInd)
{
InputPathList &inputPathList = GetInputPathList(startPos, endPos);
InputPathList::iterator iter;
for (iter = inputPathList.begin(); iter != inputPathList.end(); ++iter) {
InputPath &inputPath = **iter;
TranslationOptionCollection::CreateTranslationOptionsForRange(decodeGraph
, startPos
, endPos
, adhereTableLimit
, graphInd
, inputPath);
}
//CreateTranslationOptionsForRangeNew(decodeGraph, startPos, endPos, adhereTableLimit, graphInd);
}
} // namespace

View File

@ -14,14 +14,7 @@ class WordLattice;
*/
class TranslationOptionCollectionLattice : public TranslationOptionCollection
{
public:
typedef std::vector< std::vector<InputPathList> > InputPathMatrix;
protected:
InputPathMatrix m_inputPathMatrix; /*< contains translation options */
InputPathList &GetInputPathList(size_t startPos, size_t endPos);
void CreateTranslationOptionsForRangeNew(const DecodeGraph &decodeStepList
, size_t startPosition
, size_t endPosition

View File

@ -212,12 +212,15 @@ WordLattice::CreateTranslationOptionCollection() const
float translationOptionThreshold = StaticData::Instance().GetTranslationOptionThreshold();
TranslationOptionCollection *rv = NULL;
rv = new TranslationOptionCollectionConfusionNet(*this, maxNoTransOptPerCoverage, translationOptionThreshold);
/*
if (StaticData::Instance().GetUseLegacyPT()) {
rv = new TranslationOptionCollectionConfusionNet(*this, maxNoTransOptPerCoverage, translationOptionThreshold);
}
else {
rv = new TranslationOptionCollectionLattice(*this, maxNoTransOptPerCoverage, translationOptionThreshold);
}
*/
CHECK(rv);
return rv;
}