output search graph

git-svn-id: https://mosesdecoder.svn.sourceforge.net/svnroot/mosesdecoder/trunk@1588 1f5c12ca-751b-0410-a591-d2e778427230
This commit is contained in:
phkoehn 2008-03-17 21:34:19 +00:00
parent 96c68c26f2
commit 4ee468dc03
10 changed files with 208 additions and 4 deletions

View File

@ -58,6 +58,7 @@ IOStream::IOStream(
,m_inputStream(&std::cin) ,m_inputStream(&std::cin)
,m_nBestStream(NULL) ,m_nBestStream(NULL)
,m_outputWordGraphStream(NULL) ,m_outputWordGraphStream(NULL)
,m_outputSearchGraphStream(NULL)
{ {
Initialization(inputFactorOrder, outputFactorOrder Initialization(inputFactorOrder, outputFactorOrder
, inputFactorUsed , inputFactorUsed
@ -77,6 +78,7 @@ IOStream::IOStream(const std::vector<FactorType> &inputFactorOrder
,m_inputFile(new InputFileStream(inputFilePath)) ,m_inputFile(new InputFileStream(inputFilePath))
,m_nBestStream(NULL) ,m_nBestStream(NULL)
,m_outputWordGraphStream(NULL) ,m_outputWordGraphStream(NULL)
,m_outputSearchGraphStream(NULL)
{ {
Initialization(inputFactorOrder, outputFactorOrder Initialization(inputFactorOrder, outputFactorOrder
, inputFactorUsed , inputFactorUsed
@ -97,6 +99,10 @@ IOStream::~IOStream()
{ {
delete m_outputWordGraphStream; delete m_outputWordGraphStream;
} }
if (m_outputSearchGraphStream != NULL)
{
delete m_outputSearchGraphStream;
}
} }
void IOStream::Initialization(const std::vector<FactorType> &inputFactorOrder void IOStream::Initialization(const std::vector<FactorType> &inputFactorOrder
@ -134,6 +140,15 @@ void IOStream::Initialization(const std::vector<FactorType> &inputFactorOrder
m_outputWordGraphStream = file; m_outputWordGraphStream = file;
file->open(fileName.c_str()); file->open(fileName.c_str());
} }
// search graph output
if (staticData.GetOutputSearchGraph())
{
std::ofstream *file = new std::ofstream;
string fileName = staticData.GetParam("output-search-graph")[0];
m_outputSearchGraphStream = file;
file->open(fileName.c_str());
}
} }
InputType*IOStream::GetInput(InputType* inputType) InputType*IOStream::GetInput(InputType* inputType)

View File

@ -54,7 +54,7 @@ protected:
const std::vector<FactorType> &m_outputFactorOrder; const std::vector<FactorType> &m_outputFactorOrder;
const FactorMask &m_inputFactorUsed; const FactorMask &m_inputFactorUsed;
std::ostream *m_nBestStream std::ostream *m_nBestStream
,*m_outputWordGraphStream; ,*m_outputWordGraphStream,*m_outputSearchGraphStream;
std::string m_inputFilePath; std::string m_inputFilePath;
std::istream *m_inputStream; std::istream *m_inputStream;
InputFileStream *m_inputFile; InputFileStream *m_inputFile;
@ -78,7 +78,7 @@ public:
, const FactorMask &inputFactorUsed , const FactorMask &inputFactorUsed
, size_t nBestSize , size_t nBestSize
, const std::string &nBestFilePath , const std::string &nBestFilePath
, const std::string &inputFilePath); , const std::string &infilePath);
~IOStream(); ~IOStream();
InputType* GetInput(InputType *inputType); InputType* GetInput(InputType *inputType);
@ -93,4 +93,8 @@ public:
{ {
return *m_outputWordGraphStream; return *m_outputWordGraphStream;
} }
std::ostream &GetOutputSearchGraphStream()
{
return *m_outputSearchGraphStream;
}
}; };

View File

@ -140,6 +140,9 @@ int main(int argc, char* argv[])
if (staticData.GetOutputWordGraph()) if (staticData.GetOutputWordGraph())
manager.GetWordGraph(source->GetTranslationId(), ioStream->GetOutputWordGraphStream()); manager.GetWordGraph(source->GetTranslationId(), ioStream->GetOutputWordGraphStream());
if (staticData.GetOutputSearchGraph())
manager.GetSearchGraph(source->GetTranslationId(), ioStream->GetOutputSearchGraphStream());
// pick best translation (maximum a posteriori decoding) // pick best translation (maximum a posteriori decoding)
if (! staticData.UseMBR()) { if (! staticData.UseMBR()) {
ioStream->OutputBestHypo(manager.GetBestHypothesis(), source->GetTranslationId(), ioStream->OutputBestHypo(manager.GetBestHypothesis(), source->GetTranslationId(),

View File

@ -430,7 +430,7 @@ void Hypothesis::CleanupArcList()
*/ */
const StaticData &staticData = StaticData::Instance(); const StaticData &staticData = StaticData::Instance();
size_t nBestSize = staticData.GetNBestSize(); size_t nBestSize = staticData.GetNBestSize();
bool distinctNBest = staticData.GetDistinctNBest() || staticData.UseMBR(); bool distinctNBest = staticData.GetDistinctNBest() || staticData.UseMBR() || staticData.GetOutputSearchGraph();
if (!distinctNBest && m_arcList->size() > nBestSize * 5) if (!distinctNBest && m_arcList->size() > nBestSize * 5)
{ // prune arc list only if there too many arcs { // prune arc list only if there too many arcs

View File

@ -243,6 +243,7 @@ public:
return m_scoreBreakdown; return m_scoreBreakdown;
} }
float GetTotalScore() const { return m_totalScore; } float GetTotalScore() const { return m_totalScore; }
float GetScore() const { return m_totalScore-m_futureScore; }
std::vector<std::vector<unsigned int> > *GetLMStats() const std::vector<std::vector<unsigned int> > *GetLMStats() const
{ {

View File

@ -594,3 +594,168 @@ void Manager::GetWordGraph(long translationId, std::ostream &outputWordGraphStre
} // for (iterStack } // for (iterStack
} }
void OutputSearchGraph(long translationId, std::ostream &outputSearchGraphStream, const Hypothesis *hypo, const Hypothesis *recombinationHypo, int forward, double fscore)
{
outputSearchGraphStream << translationId
<< " hyp=" << hypo->GetId()
<< " stack=" << hypo->GetWordsBitmap().GetNumWordsCovered();
if (hypo->GetId() > 0)
{
const Hypothesis *prevHypo = hypo->GetPrevHypo();
outputSearchGraphStream << " back=" << prevHypo->GetId()
<< " score=" << hypo->GetScore()
<< " transition=" << (hypo->GetScore() - prevHypo->GetScore());
}
if (recombinationHypo != NULL)
{
outputSearchGraphStream << " recombined=" << recombinationHypo->GetId();
}
outputSearchGraphStream << " forward=" << forward
<< " fscore=" << fscore;
if (hypo->GetId() > 0)
{
outputSearchGraphStream << " covered=" << hypo->GetCurrSourceWordsRange().GetStartPos()
<< "-" << hypo->GetCurrSourceWordsRange().GetEndPos()
<< " out=" << hypo->GetCurrTargetPhrase();
}
outputSearchGraphStream << endl;
}
void Manager::GetSearchGraph(long translationId, std::ostream &outputSearchGraphStream) const
{
std::map < int, bool > connected;
std::map < int, int > forward;
std::map < int, double > forwardScore;
// *** find connected hypotheses ***
std::vector< const Hypothesis *> connectedList;
// start with the ones in the final stack
const HypothesisStack &finalStack = m_hypoStackColl.back();
HypothesisStack::const_iterator iterHypo;
for (iterHypo = finalStack.begin() ; iterHypo != finalStack.end() ; ++iterHypo)
{
const Hypothesis *hypo = *iterHypo;
connected[ hypo->GetId() ] = true;
connectedList.push_back( hypo );
}
// move back from known connected hypotheses
for(size_t i=0; i<connectedList.size(); i++) {
const Hypothesis *hypo = connectedList[i];
// add back pointer
const Hypothesis *prevHypo = hypo->GetPrevHypo();
if (prevHypo->GetId() > 0 // don't add empty hypothesis
&& connected.find( prevHypo->GetId() ) == connected.end()) // don't add already added
{
connected[ prevHypo->GetId() ] = true;
connectedList.push_back( prevHypo );
}
// add arcs
const ArcList *arcList = hypo->GetArcList();
if (arcList != NULL)
{
ArcList::const_iterator iterArcList;
for (iterArcList = arcList->begin() ; iterArcList != arcList->end() ; ++iterArcList)
{
const Hypothesis *loserHypo = *iterArcList;
if (connected.find( loserHypo->GetId() ) == connected.end()) // don't add already added
{
connected[ loserHypo->GetId() ] = true;
connectedList.push_back( loserHypo );
}
}
}
}
// ** compute best forward path for each hypothesis *** //
// forward cost of hypotheses on final stack is 0
for (iterHypo = finalStack.begin() ; iterHypo != finalStack.end() ; ++iterHypo)
{
const Hypothesis *hypo = *iterHypo;
forwardScore[ hypo->GetId() ] = 0.0f;
forward[ hypo->GetId() ] = -1;
}
// compete for best forward score of previous hypothesis
std::vector < HypothesisStack >::const_iterator iterStack;
for (iterStack = --m_hypoStackColl.end() ; iterStack != m_hypoStackColl.begin() ; --iterStack)
{
const HypothesisStack &stack = *iterStack;
HypothesisStack::const_iterator iterHypo;
for (iterHypo = stack.begin() ; iterHypo != stack.end() ; ++iterHypo)
{
const Hypothesis *hypo = *iterHypo;
if (connected.find( hypo->GetId() ) != connected.end())
{
// make a play for previous hypothesis
const Hypothesis *prevHypo = hypo->GetPrevHypo();
double fscore = forwardScore[ hypo->GetId() ] +
hypo->GetScore() - prevHypo->GetScore();
if (forwardScore.find( prevHypo->GetId() ) == forwardScore.end()
|| forwardScore.find( prevHypo->GetId() )->second < fscore)
{
forwardScore[ prevHypo->GetId() ] = fscore;
forward[ prevHypo->GetId() ] = hypo->GetId();
}
// all arcs also make a play
const ArcList *arcList = hypo->GetArcList();
if (arcList != NULL)
{
ArcList::const_iterator iterArcList;
for (iterArcList = arcList->begin() ; iterArcList != arcList->end() ; ++iterArcList)
{
const Hypothesis *loserHypo = *iterArcList;
// make a play
const Hypothesis *loserPrevHypo = loserHypo->GetPrevHypo();
double fscore = forwardScore[ hypo->GetId() ] +
loserHypo->GetScore() - loserPrevHypo->GetScore();
if (forwardScore.find( loserPrevHypo->GetId() ) == forwardScore.end()
|| forwardScore.find( loserPrevHypo->GetId() )->second < fscore)
{
forwardScore[ loserPrevHypo->GetId() ] = fscore;
forward[ loserPrevHypo->GetId() ] = loserHypo->GetId();
}
} // end for arc list
} // end if arc list empty
} // end if hypo connected
} // end for hypo
} // end for stack
// *** output all connected hypotheses *** //
connected[ 0 ] = true;
for (iterStack = m_hypoStackColl.begin() ; iterStack != m_hypoStackColl.end() ; ++iterStack)
{
const HypothesisStack &stack = *iterStack;
HypothesisStack::const_iterator iterHypo;
for (iterHypo = stack.begin() ; iterHypo != stack.end() ; ++iterHypo)
{
const Hypothesis *hypo = *iterHypo;
if (connected.find( hypo->GetId() ) != connected.end())
{
OutputSearchGraph(translationId, outputSearchGraphStream, hypo, NULL, forward[ hypo->GetId() ], forwardScore[ hypo->GetId() ]);
const ArcList *arcList = hypo->GetArcList();
if (arcList != NULL)
{
ArcList::const_iterator iterArcList;
for (iterArcList = arcList->begin() ; iterArcList != arcList->end() ; ++iterArcList)
{
const Hypothesis *loserHypo = *iterArcList;
OutputSearchGraph(translationId, outputSearchGraphStream, loserHypo, hypo, forward[ hypo->GetId() ], forwardScore[ hypo->GetId() ]);
}
} // end if arcList empty
} // end if connected
} // end for iterHypo
} // end for iterStack
}

View File

@ -98,6 +98,7 @@ public:
void CalcNBest(size_t count, TrellisPathList &ret,bool onlyDistinct=0) const; void CalcNBest(size_t count, TrellisPathList &ret,bool onlyDistinct=0) const;
void GetWordGraph(long translationId, std::ostream &outputWordGraphStream) const; void GetWordGraph(long translationId, std::ostream &outputWordGraphStream) const;
void GetSearchGraph(long translationId, std::ostream &outputSearchGraphStream) const;
/*** /***
* to be called after processing a sentence (which may consist of more than just calling ProcessSentence() ) * to be called after processing a sentence (which may consist of more than just calling ProcessSentence() )

View File

@ -81,6 +81,7 @@ Parameter::Parameter()
AddParam("use-persistent-cache", "cache translation options across sentences (default true)"); AddParam("use-persistent-cache", "cache translation options across sentences (default true)");
AddParam("recover-input-path", "r", "(conf net/word lattice only) - recover input path corresponding to the best translation"); AddParam("recover-input-path", "r", "(conf net/word lattice only) - recover input path corresponding to the best translation");
AddParam("output-word-graph", "owg", "Output stack info as word graph. Takes filename, 0=only hypos in stack, 1=stack + nbest hypos"); AddParam("output-word-graph", "owg", "Output stack info as word graph. Takes filename, 0=only hypos in stack, 1=stack + nbest hypos");
AddParam("output-search-graph", "osg", "Output connected hypotheses of search into specified filename");
} }
Parameter::~Parameter() Parameter::~Parameter()

View File

@ -144,6 +144,18 @@ bool StaticData::LoadData(Parameter *parameter)
else else
m_outputWordGraph = false; m_outputWordGraph = false;
// search graph
if (m_parameter->GetParam("output-search-graph").size() > 0)
{
if (m_parameter->GetParam("output-search-graph").size() != 1) {
UserMessage::Add(string("ERROR: wrong format for switch -output-search-graph file"));
return false;
}
m_outputSearchGraph = true;
}
else
m_outputSearchGraph = false;
// include feature names in the n-best list // include feature names in the n-best list
SetBooleanParameter( &m_labeledNBestList, "labeled-n-best-list", true ); SetBooleanParameter( &m_labeledNBestList, "labeled-n-best-list", true );

View File

@ -125,6 +125,7 @@ protected:
//! constructor. only the 1 static variable can be created //! constructor. only the 1 static variable can be created
bool m_outputWordGraph; //! whether to output word graph bool m_outputWordGraph; //! whether to output word graph
bool m_outputSearchGraph; //! whether to output search graph
StaticData(); StaticData();
@ -329,7 +330,7 @@ public:
return m_nBestFilePath; return m_nBestFilePath;
} }
bool IsNBestEnabled() const { bool IsNBestEnabled() const {
return (!m_nBestFilePath.empty()) || m_mbr; return (!m_nBestFilePath.empty()) || m_mbr || m_outputSearchGraph;
} }
size_t GetNBestFactor() const size_t GetNBestFactor() const
{ {
@ -365,6 +366,7 @@ public:
size_t UseMBR() const { return m_mbr; } size_t UseMBR() const { return m_mbr; }
size_t GetMBRSize() const { return m_mbrSize; } size_t GetMBRSize() const { return m_mbrSize; }
float GetMBRScale() const { return m_mbrScale; } float GetMBRScale() const { return m_mbrScale; }
size_t GetOutputSearchGraph() const { return m_outputSearchGraph; }
XmlInputType GetXmlInputType() const { return m_xmlInputType; } XmlInputType GetXmlInputType() const { return m_xmlInputType; }