mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-12-27 22:14:57 +03:00
output search graph
git-svn-id: https://mosesdecoder.svn.sourceforge.net/svnroot/mosesdecoder/trunk@1588 1f5c12ca-751b-0410-a591-d2e778427230
This commit is contained in:
parent
96c68c26f2
commit
4ee468dc03
@ -58,6 +58,7 @@ IOStream::IOStream(
|
||||
,m_inputStream(&std::cin)
|
||||
,m_nBestStream(NULL)
|
||||
,m_outputWordGraphStream(NULL)
|
||||
,m_outputSearchGraphStream(NULL)
|
||||
{
|
||||
Initialization(inputFactorOrder, outputFactorOrder
|
||||
, inputFactorUsed
|
||||
@ -77,6 +78,7 @@ IOStream::IOStream(const std::vector<FactorType> &inputFactorOrder
|
||||
,m_inputFile(new InputFileStream(inputFilePath))
|
||||
,m_nBestStream(NULL)
|
||||
,m_outputWordGraphStream(NULL)
|
||||
,m_outputSearchGraphStream(NULL)
|
||||
{
|
||||
Initialization(inputFactorOrder, outputFactorOrder
|
||||
, inputFactorUsed
|
||||
@ -97,6 +99,10 @@ IOStream::~IOStream()
|
||||
{
|
||||
delete m_outputWordGraphStream;
|
||||
}
|
||||
if (m_outputSearchGraphStream != NULL)
|
||||
{
|
||||
delete m_outputSearchGraphStream;
|
||||
}
|
||||
}
|
||||
|
||||
void IOStream::Initialization(const std::vector<FactorType> &inputFactorOrder
|
||||
@ -134,6 +140,15 @@ void IOStream::Initialization(const std::vector<FactorType> &inputFactorOrder
|
||||
m_outputWordGraphStream = file;
|
||||
file->open(fileName.c_str());
|
||||
}
|
||||
|
||||
// search graph output
|
||||
if (staticData.GetOutputSearchGraph())
|
||||
{
|
||||
std::ofstream *file = new std::ofstream;
|
||||
string fileName = staticData.GetParam("output-search-graph")[0];
|
||||
m_outputSearchGraphStream = file;
|
||||
file->open(fileName.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
InputType*IOStream::GetInput(InputType* inputType)
|
||||
|
@ -54,7 +54,7 @@ protected:
|
||||
const std::vector<FactorType> &m_outputFactorOrder;
|
||||
const FactorMask &m_inputFactorUsed;
|
||||
std::ostream *m_nBestStream
|
||||
,*m_outputWordGraphStream;
|
||||
,*m_outputWordGraphStream,*m_outputSearchGraphStream;
|
||||
std::string m_inputFilePath;
|
||||
std::istream *m_inputStream;
|
||||
InputFileStream *m_inputFile;
|
||||
@ -78,7 +78,7 @@ public:
|
||||
, const FactorMask &inputFactorUsed
|
||||
, size_t nBestSize
|
||||
, const std::string &nBestFilePath
|
||||
, const std::string &inputFilePath);
|
||||
, const std::string &infilePath);
|
||||
~IOStream();
|
||||
|
||||
InputType* GetInput(InputType *inputType);
|
||||
@ -93,4 +93,8 @@ public:
|
||||
{
|
||||
return *m_outputWordGraphStream;
|
||||
}
|
||||
std::ostream &GetOutputSearchGraphStream()
|
||||
{
|
||||
return *m_outputSearchGraphStream;
|
||||
}
|
||||
};
|
||||
|
@ -140,6 +140,9 @@ int main(int argc, char* argv[])
|
||||
if (staticData.GetOutputWordGraph())
|
||||
manager.GetWordGraph(source->GetTranslationId(), ioStream->GetOutputWordGraphStream());
|
||||
|
||||
if (staticData.GetOutputSearchGraph())
|
||||
manager.GetSearchGraph(source->GetTranslationId(), ioStream->GetOutputSearchGraphStream());
|
||||
|
||||
// pick best translation (maximum a posteriori decoding)
|
||||
if (! staticData.UseMBR()) {
|
||||
ioStream->OutputBestHypo(manager.GetBestHypothesis(), source->GetTranslationId(),
|
||||
|
@ -430,7 +430,7 @@ void Hypothesis::CleanupArcList()
|
||||
*/
|
||||
const StaticData &staticData = StaticData::Instance();
|
||||
size_t nBestSize = staticData.GetNBestSize();
|
||||
bool distinctNBest = staticData.GetDistinctNBest() || staticData.UseMBR();
|
||||
bool distinctNBest = staticData.GetDistinctNBest() || staticData.UseMBR() || staticData.GetOutputSearchGraph();
|
||||
|
||||
if (!distinctNBest && m_arcList->size() > nBestSize * 5)
|
||||
{ // prune arc list only if there too many arcs
|
||||
|
@ -243,6 +243,7 @@ public:
|
||||
return m_scoreBreakdown;
|
||||
}
|
||||
float GetTotalScore() const { return m_totalScore; }
|
||||
float GetScore() const { return m_totalScore-m_futureScore; }
|
||||
|
||||
std::vector<std::vector<unsigned int> > *GetLMStats() const
|
||||
{
|
||||
|
@ -594,3 +594,168 @@ void Manager::GetWordGraph(long translationId, std::ostream &outputWordGraphStre
|
||||
} // for (iterStack
|
||||
}
|
||||
|
||||
void OutputSearchGraph(long translationId, std::ostream &outputSearchGraphStream, const Hypothesis *hypo, const Hypothesis *recombinationHypo, int forward, double fscore)
|
||||
{
|
||||
outputSearchGraphStream << translationId
|
||||
<< " hyp=" << hypo->GetId()
|
||||
<< " stack=" << hypo->GetWordsBitmap().GetNumWordsCovered();
|
||||
if (hypo->GetId() > 0)
|
||||
{
|
||||
const Hypothesis *prevHypo = hypo->GetPrevHypo();
|
||||
outputSearchGraphStream << " back=" << prevHypo->GetId()
|
||||
<< " score=" << hypo->GetScore()
|
||||
<< " transition=" << (hypo->GetScore() - prevHypo->GetScore());
|
||||
}
|
||||
|
||||
if (recombinationHypo != NULL)
|
||||
{
|
||||
outputSearchGraphStream << " recombined=" << recombinationHypo->GetId();
|
||||
}
|
||||
|
||||
outputSearchGraphStream << " forward=" << forward
|
||||
<< " fscore=" << fscore;
|
||||
|
||||
if (hypo->GetId() > 0)
|
||||
{
|
||||
outputSearchGraphStream << " covered=" << hypo->GetCurrSourceWordsRange().GetStartPos()
|
||||
<< "-" << hypo->GetCurrSourceWordsRange().GetEndPos()
|
||||
<< " out=" << hypo->GetCurrTargetPhrase();
|
||||
}
|
||||
|
||||
outputSearchGraphStream << endl;
|
||||
}
|
||||
|
||||
void Manager::GetSearchGraph(long translationId, std::ostream &outputSearchGraphStream) const
|
||||
{
|
||||
std::map < int, bool > connected;
|
||||
std::map < int, int > forward;
|
||||
std::map < int, double > forwardScore;
|
||||
|
||||
// *** find connected hypotheses ***
|
||||
|
||||
std::vector< const Hypothesis *> connectedList;
|
||||
|
||||
// start with the ones in the final stack
|
||||
const HypothesisStack &finalStack = m_hypoStackColl.back();
|
||||
HypothesisStack::const_iterator iterHypo;
|
||||
for (iterHypo = finalStack.begin() ; iterHypo != finalStack.end() ; ++iterHypo)
|
||||
{
|
||||
const Hypothesis *hypo = *iterHypo;
|
||||
connected[ hypo->GetId() ] = true;
|
||||
connectedList.push_back( hypo );
|
||||
}
|
||||
|
||||
// move back from known connected hypotheses
|
||||
for(size_t i=0; i<connectedList.size(); i++) {
|
||||
const Hypothesis *hypo = connectedList[i];
|
||||
|
||||
// add back pointer
|
||||
const Hypothesis *prevHypo = hypo->GetPrevHypo();
|
||||
if (prevHypo->GetId() > 0 // don't add empty hypothesis
|
||||
&& connected.find( prevHypo->GetId() ) == connected.end()) // don't add already added
|
||||
{
|
||||
connected[ prevHypo->GetId() ] = true;
|
||||
connectedList.push_back( prevHypo );
|
||||
}
|
||||
|
||||
// add arcs
|
||||
const ArcList *arcList = hypo->GetArcList();
|
||||
if (arcList != NULL)
|
||||
{
|
||||
ArcList::const_iterator iterArcList;
|
||||
for (iterArcList = arcList->begin() ; iterArcList != arcList->end() ; ++iterArcList)
|
||||
{
|
||||
const Hypothesis *loserHypo = *iterArcList;
|
||||
if (connected.find( loserHypo->GetId() ) == connected.end()) // don't add already added
|
||||
{
|
||||
connected[ loserHypo->GetId() ] = true;
|
||||
connectedList.push_back( loserHypo );
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ** compute best forward path for each hypothesis *** //
|
||||
|
||||
// forward cost of hypotheses on final stack is 0
|
||||
for (iterHypo = finalStack.begin() ; iterHypo != finalStack.end() ; ++iterHypo)
|
||||
{
|
||||
const Hypothesis *hypo = *iterHypo;
|
||||
forwardScore[ hypo->GetId() ] = 0.0f;
|
||||
forward[ hypo->GetId() ] = -1;
|
||||
}
|
||||
|
||||
// compete for best forward score of previous hypothesis
|
||||
std::vector < HypothesisStack >::const_iterator iterStack;
|
||||
for (iterStack = --m_hypoStackColl.end() ; iterStack != m_hypoStackColl.begin() ; --iterStack)
|
||||
{
|
||||
const HypothesisStack &stack = *iterStack;
|
||||
HypothesisStack::const_iterator iterHypo;
|
||||
for (iterHypo = stack.begin() ; iterHypo != stack.end() ; ++iterHypo)
|
||||
{
|
||||
const Hypothesis *hypo = *iterHypo;
|
||||
if (connected.find( hypo->GetId() ) != connected.end())
|
||||
{
|
||||
// make a play for previous hypothesis
|
||||
const Hypothesis *prevHypo = hypo->GetPrevHypo();
|
||||
double fscore = forwardScore[ hypo->GetId() ] +
|
||||
hypo->GetScore() - prevHypo->GetScore();
|
||||
if (forwardScore.find( prevHypo->GetId() ) == forwardScore.end()
|
||||
|| forwardScore.find( prevHypo->GetId() )->second < fscore)
|
||||
{
|
||||
forwardScore[ prevHypo->GetId() ] = fscore;
|
||||
forward[ prevHypo->GetId() ] = hypo->GetId();
|
||||
}
|
||||
// all arcs also make a play
|
||||
const ArcList *arcList = hypo->GetArcList();
|
||||
if (arcList != NULL)
|
||||
{
|
||||
ArcList::const_iterator iterArcList;
|
||||
for (iterArcList = arcList->begin() ; iterArcList != arcList->end() ; ++iterArcList)
|
||||
{
|
||||
const Hypothesis *loserHypo = *iterArcList;
|
||||
// make a play
|
||||
const Hypothesis *loserPrevHypo = loserHypo->GetPrevHypo();
|
||||
double fscore = forwardScore[ hypo->GetId() ] +
|
||||
loserHypo->GetScore() - loserPrevHypo->GetScore();
|
||||
if (forwardScore.find( loserPrevHypo->GetId() ) == forwardScore.end()
|
||||
|| forwardScore.find( loserPrevHypo->GetId() )->second < fscore)
|
||||
{
|
||||
forwardScore[ loserPrevHypo->GetId() ] = fscore;
|
||||
forward[ loserPrevHypo->GetId() ] = loserHypo->GetId();
|
||||
}
|
||||
} // end for arc list
|
||||
} // end if arc list empty
|
||||
} // end if hypo connected
|
||||
} // end for hypo
|
||||
} // end for stack
|
||||
|
||||
// *** output all connected hypotheses *** //
|
||||
|
||||
connected[ 0 ] = true;
|
||||
for (iterStack = m_hypoStackColl.begin() ; iterStack != m_hypoStackColl.end() ; ++iterStack)
|
||||
{
|
||||
const HypothesisStack &stack = *iterStack;
|
||||
HypothesisStack::const_iterator iterHypo;
|
||||
for (iterHypo = stack.begin() ; iterHypo != stack.end() ; ++iterHypo)
|
||||
{
|
||||
const Hypothesis *hypo = *iterHypo;
|
||||
if (connected.find( hypo->GetId() ) != connected.end())
|
||||
{
|
||||
OutputSearchGraph(translationId, outputSearchGraphStream, hypo, NULL, forward[ hypo->GetId() ], forwardScore[ hypo->GetId() ]);
|
||||
|
||||
const ArcList *arcList = hypo->GetArcList();
|
||||
if (arcList != NULL)
|
||||
{
|
||||
ArcList::const_iterator iterArcList;
|
||||
for (iterArcList = arcList->begin() ; iterArcList != arcList->end() ; ++iterArcList)
|
||||
{
|
||||
const Hypothesis *loserHypo = *iterArcList;
|
||||
OutputSearchGraph(translationId, outputSearchGraphStream, loserHypo, hypo, forward[ hypo->GetId() ], forwardScore[ hypo->GetId() ]);
|
||||
}
|
||||
} // end if arcList empty
|
||||
} // end if connected
|
||||
} // end for iterHypo
|
||||
} // end for iterStack
|
||||
}
|
||||
|
||||
|
@ -98,6 +98,7 @@ public:
|
||||
void CalcNBest(size_t count, TrellisPathList &ret,bool onlyDistinct=0) const;
|
||||
|
||||
void GetWordGraph(long translationId, std::ostream &outputWordGraphStream) const;
|
||||
void GetSearchGraph(long translationId, std::ostream &outputSearchGraphStream) const;
|
||||
|
||||
/***
|
||||
* to be called after processing a sentence (which may consist of more than just calling ProcessSentence() )
|
||||
|
@ -81,6 +81,7 @@ Parameter::Parameter()
|
||||
AddParam("use-persistent-cache", "cache translation options across sentences (default true)");
|
||||
AddParam("recover-input-path", "r", "(conf net/word lattice only) - recover input path corresponding to the best translation");
|
||||
AddParam("output-word-graph", "owg", "Output stack info as word graph. Takes filename, 0=only hypos in stack, 1=stack + nbest hypos");
|
||||
AddParam("output-search-graph", "osg", "Output connected hypotheses of search into specified filename");
|
||||
}
|
||||
|
||||
Parameter::~Parameter()
|
||||
|
@ -144,6 +144,18 @@ bool StaticData::LoadData(Parameter *parameter)
|
||||
else
|
||||
m_outputWordGraph = false;
|
||||
|
||||
// search graph
|
||||
if (m_parameter->GetParam("output-search-graph").size() > 0)
|
||||
{
|
||||
if (m_parameter->GetParam("output-search-graph").size() != 1) {
|
||||
UserMessage::Add(string("ERROR: wrong format for switch -output-search-graph file"));
|
||||
return false;
|
||||
}
|
||||
m_outputSearchGraph = true;
|
||||
}
|
||||
else
|
||||
m_outputSearchGraph = false;
|
||||
|
||||
// include feature names in the n-best list
|
||||
SetBooleanParameter( &m_labeledNBestList, "labeled-n-best-list", true );
|
||||
|
||||
|
@ -125,6 +125,7 @@ protected:
|
||||
//! constructor. only the 1 static variable can be created
|
||||
|
||||
bool m_outputWordGraph; //! whether to output word graph
|
||||
bool m_outputSearchGraph; //! whether to output search graph
|
||||
|
||||
StaticData();
|
||||
|
||||
@ -329,7 +330,7 @@ public:
|
||||
return m_nBestFilePath;
|
||||
}
|
||||
bool IsNBestEnabled() const {
|
||||
return (!m_nBestFilePath.empty()) || m_mbr;
|
||||
return (!m_nBestFilePath.empty()) || m_mbr || m_outputSearchGraph;
|
||||
}
|
||||
size_t GetNBestFactor() const
|
||||
{
|
||||
@ -365,6 +366,7 @@ public:
|
||||
size_t UseMBR() const { return m_mbr; }
|
||||
size_t GetMBRSize() const { return m_mbrSize; }
|
||||
float GetMBRScale() const { return m_mbrScale; }
|
||||
size_t GetOutputSearchGraph() const { return m_outputSearchGraph; }
|
||||
|
||||
XmlInputType GetXmlInputType() const { return m_xmlInputType; }
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user