mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-12-28 14:32:38 +03:00
output search graph
git-svn-id: https://mosesdecoder.svn.sourceforge.net/svnroot/mosesdecoder/trunk@1588 1f5c12ca-751b-0410-a591-d2e778427230
This commit is contained in:
parent
96c68c26f2
commit
4ee468dc03
@ -58,6 +58,7 @@ IOStream::IOStream(
|
|||||||
,m_inputStream(&std::cin)
|
,m_inputStream(&std::cin)
|
||||||
,m_nBestStream(NULL)
|
,m_nBestStream(NULL)
|
||||||
,m_outputWordGraphStream(NULL)
|
,m_outputWordGraphStream(NULL)
|
||||||
|
,m_outputSearchGraphStream(NULL)
|
||||||
{
|
{
|
||||||
Initialization(inputFactorOrder, outputFactorOrder
|
Initialization(inputFactorOrder, outputFactorOrder
|
||||||
, inputFactorUsed
|
, inputFactorUsed
|
||||||
@ -77,6 +78,7 @@ IOStream::IOStream(const std::vector<FactorType> &inputFactorOrder
|
|||||||
,m_inputFile(new InputFileStream(inputFilePath))
|
,m_inputFile(new InputFileStream(inputFilePath))
|
||||||
,m_nBestStream(NULL)
|
,m_nBestStream(NULL)
|
||||||
,m_outputWordGraphStream(NULL)
|
,m_outputWordGraphStream(NULL)
|
||||||
|
,m_outputSearchGraphStream(NULL)
|
||||||
{
|
{
|
||||||
Initialization(inputFactorOrder, outputFactorOrder
|
Initialization(inputFactorOrder, outputFactorOrder
|
||||||
, inputFactorUsed
|
, inputFactorUsed
|
||||||
@ -97,6 +99,10 @@ IOStream::~IOStream()
|
|||||||
{
|
{
|
||||||
delete m_outputWordGraphStream;
|
delete m_outputWordGraphStream;
|
||||||
}
|
}
|
||||||
|
if (m_outputSearchGraphStream != NULL)
|
||||||
|
{
|
||||||
|
delete m_outputSearchGraphStream;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void IOStream::Initialization(const std::vector<FactorType> &inputFactorOrder
|
void IOStream::Initialization(const std::vector<FactorType> &inputFactorOrder
|
||||||
@ -134,6 +140,15 @@ void IOStream::Initialization(const std::vector<FactorType> &inputFactorOrder
|
|||||||
m_outputWordGraphStream = file;
|
m_outputWordGraphStream = file;
|
||||||
file->open(fileName.c_str());
|
file->open(fileName.c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// search graph output
|
||||||
|
if (staticData.GetOutputSearchGraph())
|
||||||
|
{
|
||||||
|
std::ofstream *file = new std::ofstream;
|
||||||
|
string fileName = staticData.GetParam("output-search-graph")[0];
|
||||||
|
m_outputSearchGraphStream = file;
|
||||||
|
file->open(fileName.c_str());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
InputType*IOStream::GetInput(InputType* inputType)
|
InputType*IOStream::GetInput(InputType* inputType)
|
||||||
|
@ -54,7 +54,7 @@ protected:
|
|||||||
const std::vector<FactorType> &m_outputFactorOrder;
|
const std::vector<FactorType> &m_outputFactorOrder;
|
||||||
const FactorMask &m_inputFactorUsed;
|
const FactorMask &m_inputFactorUsed;
|
||||||
std::ostream *m_nBestStream
|
std::ostream *m_nBestStream
|
||||||
,*m_outputWordGraphStream;
|
,*m_outputWordGraphStream,*m_outputSearchGraphStream;
|
||||||
std::string m_inputFilePath;
|
std::string m_inputFilePath;
|
||||||
std::istream *m_inputStream;
|
std::istream *m_inputStream;
|
||||||
InputFileStream *m_inputFile;
|
InputFileStream *m_inputFile;
|
||||||
@ -78,7 +78,7 @@ public:
|
|||||||
, const FactorMask &inputFactorUsed
|
, const FactorMask &inputFactorUsed
|
||||||
, size_t nBestSize
|
, size_t nBestSize
|
||||||
, const std::string &nBestFilePath
|
, const std::string &nBestFilePath
|
||||||
, const std::string &inputFilePath);
|
, const std::string &infilePath);
|
||||||
~IOStream();
|
~IOStream();
|
||||||
|
|
||||||
InputType* GetInput(InputType *inputType);
|
InputType* GetInput(InputType *inputType);
|
||||||
@ -93,4 +93,8 @@ public:
|
|||||||
{
|
{
|
||||||
return *m_outputWordGraphStream;
|
return *m_outputWordGraphStream;
|
||||||
}
|
}
|
||||||
|
std::ostream &GetOutputSearchGraphStream()
|
||||||
|
{
|
||||||
|
return *m_outputSearchGraphStream;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
@ -140,6 +140,9 @@ int main(int argc, char* argv[])
|
|||||||
if (staticData.GetOutputWordGraph())
|
if (staticData.GetOutputWordGraph())
|
||||||
manager.GetWordGraph(source->GetTranslationId(), ioStream->GetOutputWordGraphStream());
|
manager.GetWordGraph(source->GetTranslationId(), ioStream->GetOutputWordGraphStream());
|
||||||
|
|
||||||
|
if (staticData.GetOutputSearchGraph())
|
||||||
|
manager.GetSearchGraph(source->GetTranslationId(), ioStream->GetOutputSearchGraphStream());
|
||||||
|
|
||||||
// pick best translation (maximum a posteriori decoding)
|
// pick best translation (maximum a posteriori decoding)
|
||||||
if (! staticData.UseMBR()) {
|
if (! staticData.UseMBR()) {
|
||||||
ioStream->OutputBestHypo(manager.GetBestHypothesis(), source->GetTranslationId(),
|
ioStream->OutputBestHypo(manager.GetBestHypothesis(), source->GetTranslationId(),
|
||||||
|
@ -430,7 +430,7 @@ void Hypothesis::CleanupArcList()
|
|||||||
*/
|
*/
|
||||||
const StaticData &staticData = StaticData::Instance();
|
const StaticData &staticData = StaticData::Instance();
|
||||||
size_t nBestSize = staticData.GetNBestSize();
|
size_t nBestSize = staticData.GetNBestSize();
|
||||||
bool distinctNBest = staticData.GetDistinctNBest() || staticData.UseMBR();
|
bool distinctNBest = staticData.GetDistinctNBest() || staticData.UseMBR() || staticData.GetOutputSearchGraph();
|
||||||
|
|
||||||
if (!distinctNBest && m_arcList->size() > nBestSize * 5)
|
if (!distinctNBest && m_arcList->size() > nBestSize * 5)
|
||||||
{ // prune arc list only if there too many arcs
|
{ // prune arc list only if there too many arcs
|
||||||
|
@ -243,6 +243,7 @@ public:
|
|||||||
return m_scoreBreakdown;
|
return m_scoreBreakdown;
|
||||||
}
|
}
|
||||||
float GetTotalScore() const { return m_totalScore; }
|
float GetTotalScore() const { return m_totalScore; }
|
||||||
|
//! total score of the hypothesis with m_futureScore subtracted
float GetScore() const
{
  return m_totalScore - m_futureScore;
}
|
||||||
|
|
||||||
std::vector<std::vector<unsigned int> > *GetLMStats() const
|
std::vector<std::vector<unsigned int> > *GetLMStats() const
|
||||||
{
|
{
|
||||||
|
@ -594,3 +594,168 @@ void Manager::GetWordGraph(long translationId, std::ostream &outputWordGraphStre
|
|||||||
} // for (iterStack
|
} // for (iterStack
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void OutputSearchGraph(long translationId, std::ostream &outputSearchGraphStream, const Hypothesis *hypo, const Hypothesis *recombinationHypo, int forward, double fscore)
|
||||||
|
{
|
||||||
|
outputSearchGraphStream << translationId
|
||||||
|
<< " hyp=" << hypo->GetId()
|
||||||
|
<< " stack=" << hypo->GetWordsBitmap().GetNumWordsCovered();
|
||||||
|
if (hypo->GetId() > 0)
|
||||||
|
{
|
||||||
|
const Hypothesis *prevHypo = hypo->GetPrevHypo();
|
||||||
|
outputSearchGraphStream << " back=" << prevHypo->GetId()
|
||||||
|
<< " score=" << hypo->GetScore()
|
||||||
|
<< " transition=" << (hypo->GetScore() - prevHypo->GetScore());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (recombinationHypo != NULL)
|
||||||
|
{
|
||||||
|
outputSearchGraphStream << " recombined=" << recombinationHypo->GetId();
|
||||||
|
}
|
||||||
|
|
||||||
|
outputSearchGraphStream << " forward=" << forward
|
||||||
|
<< " fscore=" << fscore;
|
||||||
|
|
||||||
|
if (hypo->GetId() > 0)
|
||||||
|
{
|
||||||
|
outputSearchGraphStream << " covered=" << hypo->GetCurrSourceWordsRange().GetStartPos()
|
||||||
|
<< "-" << hypo->GetCurrSourceWordsRange().GetEndPos()
|
||||||
|
<< " out=" << hypo->GetCurrTargetPhrase();
|
||||||
|
}
|
||||||
|
|
||||||
|
outputSearchGraphStream << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Manager::GetSearchGraph(long translationId, std::ostream &outputSearchGraphStream) const
|
||||||
|
{
|
||||||
|
std::map < int, bool > connected;
|
||||||
|
std::map < int, int > forward;
|
||||||
|
std::map < int, double > forwardScore;
|
||||||
|
|
||||||
|
// *** find connected hypotheses ***
|
||||||
|
|
||||||
|
std::vector< const Hypothesis *> connectedList;
|
||||||
|
|
||||||
|
// start with the ones in the final stack
|
||||||
|
const HypothesisStack &finalStack = m_hypoStackColl.back();
|
||||||
|
HypothesisStack::const_iterator iterHypo;
|
||||||
|
for (iterHypo = finalStack.begin() ; iterHypo != finalStack.end() ; ++iterHypo)
|
||||||
|
{
|
||||||
|
const Hypothesis *hypo = *iterHypo;
|
||||||
|
connected[ hypo->GetId() ] = true;
|
||||||
|
connectedList.push_back( hypo );
|
||||||
|
}
|
||||||
|
|
||||||
|
// move back from known connected hypotheses
|
||||||
|
for(size_t i=0; i<connectedList.size(); i++) {
|
||||||
|
const Hypothesis *hypo = connectedList[i];
|
||||||
|
|
||||||
|
// add back pointer
|
||||||
|
const Hypothesis *prevHypo = hypo->GetPrevHypo();
|
||||||
|
if (prevHypo->GetId() > 0 // don't add empty hypothesis
|
||||||
|
&& connected.find( prevHypo->GetId() ) == connected.end()) // don't add already added
|
||||||
|
{
|
||||||
|
connected[ prevHypo->GetId() ] = true;
|
||||||
|
connectedList.push_back( prevHypo );
|
||||||
|
}
|
||||||
|
|
||||||
|
// add arcs
|
||||||
|
const ArcList *arcList = hypo->GetArcList();
|
||||||
|
if (arcList != NULL)
|
||||||
|
{
|
||||||
|
ArcList::const_iterator iterArcList;
|
||||||
|
for (iterArcList = arcList->begin() ; iterArcList != arcList->end() ; ++iterArcList)
|
||||||
|
{
|
||||||
|
const Hypothesis *loserHypo = *iterArcList;
|
||||||
|
if (connected.find( loserHypo->GetId() ) == connected.end()) // don't add already added
|
||||||
|
{
|
||||||
|
connected[ loserHypo->GetId() ] = true;
|
||||||
|
connectedList.push_back( loserHypo );
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ** compute best forward path for each hypothesis *** //
|
||||||
|
|
||||||
|
// forward cost of hypotheses on final stack is 0
|
||||||
|
for (iterHypo = finalStack.begin() ; iterHypo != finalStack.end() ; ++iterHypo)
|
||||||
|
{
|
||||||
|
const Hypothesis *hypo = *iterHypo;
|
||||||
|
forwardScore[ hypo->GetId() ] = 0.0f;
|
||||||
|
forward[ hypo->GetId() ] = -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// compete for best forward score of previous hypothesis
|
||||||
|
std::vector < HypothesisStack >::const_iterator iterStack;
|
||||||
|
for (iterStack = --m_hypoStackColl.end() ; iterStack != m_hypoStackColl.begin() ; --iterStack)
|
||||||
|
{
|
||||||
|
const HypothesisStack &stack = *iterStack;
|
||||||
|
HypothesisStack::const_iterator iterHypo;
|
||||||
|
for (iterHypo = stack.begin() ; iterHypo != stack.end() ; ++iterHypo)
|
||||||
|
{
|
||||||
|
const Hypothesis *hypo = *iterHypo;
|
||||||
|
if (connected.find( hypo->GetId() ) != connected.end())
|
||||||
|
{
|
||||||
|
// make a play for previous hypothesis
|
||||||
|
const Hypothesis *prevHypo = hypo->GetPrevHypo();
|
||||||
|
double fscore = forwardScore[ hypo->GetId() ] +
|
||||||
|
hypo->GetScore() - prevHypo->GetScore();
|
||||||
|
if (forwardScore.find( prevHypo->GetId() ) == forwardScore.end()
|
||||||
|
|| forwardScore.find( prevHypo->GetId() )->second < fscore)
|
||||||
|
{
|
||||||
|
forwardScore[ prevHypo->GetId() ] = fscore;
|
||||||
|
forward[ prevHypo->GetId() ] = hypo->GetId();
|
||||||
|
}
|
||||||
|
// all arcs also make a play
|
||||||
|
const ArcList *arcList = hypo->GetArcList();
|
||||||
|
if (arcList != NULL)
|
||||||
|
{
|
||||||
|
ArcList::const_iterator iterArcList;
|
||||||
|
for (iterArcList = arcList->begin() ; iterArcList != arcList->end() ; ++iterArcList)
|
||||||
|
{
|
||||||
|
const Hypothesis *loserHypo = *iterArcList;
|
||||||
|
// make a play
|
||||||
|
const Hypothesis *loserPrevHypo = loserHypo->GetPrevHypo();
|
||||||
|
double fscore = forwardScore[ hypo->GetId() ] +
|
||||||
|
loserHypo->GetScore() - loserPrevHypo->GetScore();
|
||||||
|
if (forwardScore.find( loserPrevHypo->GetId() ) == forwardScore.end()
|
||||||
|
|| forwardScore.find( loserPrevHypo->GetId() )->second < fscore)
|
||||||
|
{
|
||||||
|
forwardScore[ loserPrevHypo->GetId() ] = fscore;
|
||||||
|
forward[ loserPrevHypo->GetId() ] = loserHypo->GetId();
|
||||||
|
}
|
||||||
|
} // end for arc list
|
||||||
|
} // end if arc list empty
|
||||||
|
} // end if hypo connected
|
||||||
|
} // end for hypo
|
||||||
|
} // end for stack
|
||||||
|
|
||||||
|
// *** output all connected hypotheses *** //
|
||||||
|
|
||||||
|
connected[ 0 ] = true;
|
||||||
|
for (iterStack = m_hypoStackColl.begin() ; iterStack != m_hypoStackColl.end() ; ++iterStack)
|
||||||
|
{
|
||||||
|
const HypothesisStack &stack = *iterStack;
|
||||||
|
HypothesisStack::const_iterator iterHypo;
|
||||||
|
for (iterHypo = stack.begin() ; iterHypo != stack.end() ; ++iterHypo)
|
||||||
|
{
|
||||||
|
const Hypothesis *hypo = *iterHypo;
|
||||||
|
if (connected.find( hypo->GetId() ) != connected.end())
|
||||||
|
{
|
||||||
|
OutputSearchGraph(translationId, outputSearchGraphStream, hypo, NULL, forward[ hypo->GetId() ], forwardScore[ hypo->GetId() ]);
|
||||||
|
|
||||||
|
const ArcList *arcList = hypo->GetArcList();
|
||||||
|
if (arcList != NULL)
|
||||||
|
{
|
||||||
|
ArcList::const_iterator iterArcList;
|
||||||
|
for (iterArcList = arcList->begin() ; iterArcList != arcList->end() ; ++iterArcList)
|
||||||
|
{
|
||||||
|
const Hypothesis *loserHypo = *iterArcList;
|
||||||
|
OutputSearchGraph(translationId, outputSearchGraphStream, loserHypo, hypo, forward[ hypo->GetId() ], forwardScore[ hypo->GetId() ]);
|
||||||
|
}
|
||||||
|
} // end if arcList empty
|
||||||
|
} // end if connected
|
||||||
|
} // end for iterHypo
|
||||||
|
} // end for iterStack
|
||||||
|
}
|
||||||
|
|
||||||
|
@ -98,6 +98,7 @@ public:
|
|||||||
void CalcNBest(size_t count, TrellisPathList &ret,bool onlyDistinct=0) const;
|
void CalcNBest(size_t count, TrellisPathList &ret,bool onlyDistinct=0) const;
|
||||||
|
|
||||||
void GetWordGraph(long translationId, std::ostream &outputWordGraphStream) const;
|
void GetWordGraph(long translationId, std::ostream &outputWordGraphStream) const;
|
||||||
|
void GetSearchGraph(long translationId, std::ostream &outputSearchGraphStream) const;
|
||||||
|
|
||||||
/***
|
/***
|
||||||
* to be called after processing a sentence (which may consist of more than just calling ProcessSentence() )
|
* to be called after processing a sentence (which may consist of more than just calling ProcessSentence() )
|
||||||
|
@ -81,6 +81,7 @@ Parameter::Parameter()
|
|||||||
AddParam("use-persistent-cache", "cache translation options across sentences (default true)");
|
AddParam("use-persistent-cache", "cache translation options across sentences (default true)");
|
||||||
AddParam("recover-input-path", "r", "(conf net/word lattice only) - recover input path corresponding to the best translation");
|
AddParam("recover-input-path", "r", "(conf net/word lattice only) - recover input path corresponding to the best translation");
|
||||||
AddParam("output-word-graph", "owg", "Output stack info as word graph. Takes filename, 0=only hypos in stack, 1=stack + nbest hypos");
|
AddParam("output-word-graph", "owg", "Output stack info as word graph. Takes filename, 0=only hypos in stack, 1=stack + nbest hypos");
|
||||||
|
AddParam("output-search-graph", "osg", "Output connected hypotheses of search into specified filename");
|
||||||
}
|
}
|
||||||
|
|
||||||
Parameter::~Parameter()
|
Parameter::~Parameter()
|
||||||
|
@ -144,6 +144,18 @@ bool StaticData::LoadData(Parameter *parameter)
|
|||||||
else
|
else
|
||||||
m_outputWordGraph = false;
|
m_outputWordGraph = false;
|
||||||
|
|
||||||
|
// search graph
|
||||||
|
if (m_parameter->GetParam("output-search-graph").size() > 0)
|
||||||
|
{
|
||||||
|
if (m_parameter->GetParam("output-search-graph").size() != 1) {
|
||||||
|
UserMessage::Add(string("ERROR: wrong format for switch -output-search-graph file"));
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
m_outputSearchGraph = true;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
m_outputSearchGraph = false;
|
||||||
|
|
||||||
// include feature names in the n-best list
|
// include feature names in the n-best list
|
||||||
SetBooleanParameter( &m_labeledNBestList, "labeled-n-best-list", true );
|
SetBooleanParameter( &m_labeledNBestList, "labeled-n-best-list", true );
|
||||||
|
|
||||||
|
@ -125,6 +125,7 @@ protected:
|
|||||||
//! constructor. only the 1 static variable can be created
|
//! constructor. only the 1 static variable can be created
|
||||||
|
|
||||||
bool m_outputWordGraph; //! whether to output word graph
|
bool m_outputWordGraph; //! whether to output word graph
|
||||||
|
bool m_outputSearchGraph; //! whether to output search graph
|
||||||
|
|
||||||
StaticData();
|
StaticData();
|
||||||
|
|
||||||
@ -329,7 +330,7 @@ public:
|
|||||||
return m_nBestFilePath;
|
return m_nBestFilePath;
|
||||||
}
|
}
|
||||||
bool IsNBestEnabled() const {
|
bool IsNBestEnabled() const {
|
||||||
return (!m_nBestFilePath.empty()) || m_mbr;
|
return (!m_nBestFilePath.empty()) || m_mbr || m_outputSearchGraph;
|
||||||
}
|
}
|
||||||
size_t GetNBestFactor() const
|
size_t GetNBestFactor() const
|
||||||
{
|
{
|
||||||
@ -365,6 +366,7 @@ public:
|
|||||||
size_t UseMBR() const { return m_mbr; }
|
size_t UseMBR() const { return m_mbr; }
|
||||||
size_t GetMBRSize() const { return m_mbrSize; }
|
size_t GetMBRSize() const { return m_mbrSize; }
|
||||||
float GetMBRScale() const { return m_mbrScale; }
|
float GetMBRScale() const { return m_mbrScale; }
|
||||||
|
//! whether the search graph is to be output; the bool member
//! m_outputSearchGraph is converted to size_t (0 or 1) on return
size_t GetOutputSearchGraph() const
{
  return m_outputSearchGraph;
}
|
||||||
|
|
||||||
XmlInputType GetXmlInputType() const { return m_xmlInputType; }
|
XmlInputType GetXmlInputType() const { return m_xmlInputType; }
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user