Work toward outputting the search graph in HTK standard lattice format (SLF)

This commit is contained in:
Lane Schwartz 2013-02-15 13:06:54 -05:00
parent 5844fb21a7
commit 774ed64f2e
8 changed files with 126 additions and 5 deletions

View File

@ -189,6 +189,15 @@ InputType*IOWrapper::GetInput(InputType* inputType)
}
}
/***
 * Create the output file stream for one sentence's search graph in HTK
 * standard lattice format (SLF).
 *
 * The file name is "<dir>/<sentenceNumber>.slf", where <dir> is the first
 * value of the "output-search-graph-slf" parameter.
 *
 * \param sentenceNumber number used as the base name of the file
 * \return a stream allocated with new; it is not deleted here, and no check
 *         is made here that the file could actually be opened
 */
ofstream* IOWrapper::GetOutputSearchGraphSLFStream(size_t sentenceNumber) {
  stringstream slfPath;
  slfPath << StaticData::Instance().GetParam("output-search-graph-slf")[0]
          << "/"
          << sentenceNumber
          << ".slf";

  const std::string path = slfPath.str();

  // Constructing with a file name opens the file for writing
  std::ofstream *slfStream = new std::ofstream(path.c_str());
  return slfStream;
}
/***
* print surface factor only for the given phrase
*/

View File

@ -117,6 +117,8 @@ public:
return *m_outputSearchGraphStream;
}
std::ofstream *GetOutputSearchGraphSLFStream(size_t sentenceNumber);
std::ostream &GetDetailedTranslationReportingStream() {
assert (m_detailedTranslationReportingStream);
return *m_detailedTranslationReportingStream;

View File

@ -83,14 +83,16 @@ public:
OutputCollector* wordGraphCollector, OutputCollector* searchGraphCollector,
OutputCollector* detailedTranslationCollector,
OutputCollector* alignmentInfoCollector,
OutputCollector* unknownsCollector) :
OutputCollector* unknownsCollector,
std::ofstream* searchGraphSLFStream) :
m_source(source), m_lineNumber(lineNumber),
m_outputCollector(outputCollector), m_nbestCollector(nbestCollector),
m_latticeSamplesCollector(latticeSamplesCollector),
m_wordGraphCollector(wordGraphCollector), m_searchGraphCollector(searchGraphCollector),
m_detailedTranslationCollector(detailedTranslationCollector),
m_alignmentInfoCollector(alignmentInfoCollector),
m_unknownsCollector(unknownsCollector) {}
m_unknownsCollector(unknownsCollector),
m_searchGraphSLFStream(searchGraphSLFStream) {}
/** Translate one sentence
* gets called by main function implemented at end of this source file */
@ -143,6 +145,19 @@ public:
#endif
}
// Output search graph in HTK standard lattice format (SLF)
if (m_searchGraphSLFStream) {
if (m_searchGraphSLFStream->is_open() && m_searchGraphSLFStream->good()) {
ostringstream out;
fix(out,PRECISION);
manager.OutputSearchGraphAsSLF(m_lineNumber, out);
*m_searchGraphSLFStream << out.str();
m_searchGraphSLFStream -> flush();
} else {
TRACE_ERR("Cannot output HTK standard lattice for line " << m_lineNumber << " because the output file is not open or not ready for writing" << std::endl);
}
}
// apply decision rule and output best translation(s)
if (m_outputCollector) {
ostringstream out;
@ -297,7 +312,14 @@ public:
}
/** Release the resources held by this task: the SLF output stream
 *  (closed first, when one was supplied) and the source input */
~TranslationTask() {
  std::ofstream *slfStream = m_searchGraphSLFStream;
  if (slfStream) {
    slfStream->close();
  }
  // deleting a null pointer is a no-op, so no guard is needed here
  delete slfStream;

  delete m_source;
}
private:
@ -311,6 +333,7 @@ private:
OutputCollector* m_detailedTranslationCollector;
OutputCollector* m_alignmentInfoCollector;
OutputCollector* m_unknownsCollector;
std::ofstream *m_searchGraphSLFStream;
std::ofstream *m_alignmentStream;
@ -533,7 +556,9 @@ int main(int argc, char** argv)
searchGraphCollector.get(),
detailedTranslationCollector.get(),
alignmentInfoCollector.get(),
unknownsCollector.get() );
unknownsCollector.get(),
staticData.GetOutputSearchGraphSLF() ?
ioWrapper->GetOutputSearchGraphSLFStream(lineCount) : NULL);
// execute task
#ifdef WITH_THREADS
pool.Submit(task);

View File

@ -26,8 +26,10 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
#endif
#include <algorithm>
#include <limits>
#include <cmath>
#include <limits>
#include <map>
#include <set>
#include "Manager.h"
#include "TypeDef.h"
#include "Util.h"
@ -628,6 +630,79 @@ void Manager::GetSearchGraph(vector<SearchGraphNode>& searchGraph) const
}
/**! Output search graph in HTK standard lattice format (SLF) */
void Manager::OutputSearchGraphAsSLF(long translationId, std::ostream &outputSearchGraphStream) const
{
vector<SearchGraphNode> searchGraph;
GetSearchGraph(searchGraph);
long numArcs = 0;
long numNodes = 0;
map<int,int> nodes;
set<int> terminalNodes;
// Unique start node
nodes[0] = 0;
numNodes += 1;
for (size_t arcNumber = 0; arcNumber < searchGraph.size(); ++arcNumber) {
numArcs += 1;
int hypothesisID = searchGraph[arcNumber].hypo->GetId();
if (nodes.count(hypothesisID) == 0) {
nodes[hypothesisID] = numNodes;
numNodes += 1;
bool terminalNode = (searchGraph[arcNumber].forward == -1);
if (terminalNode) {
numArcs += 1;
}
}
}
// Unique end node
nodes[numNodes] = numNodes;
outputSearchGraphStream << "UTTERANCE=\"Sentence " << translationId << "\"" << endl;
outputSearchGraphStream << "VERSION=1.1" << endl;
outputSearchGraphStream << "base=e" << endl;
outputSearchGraphStream << "NODES=" << numNodes << endl;
outputSearchGraphStream << "LINKS=" << numArcs << endl;
const vector<FactorType> &outputFactorOrder = StaticData::Instance().GetOutputFactorOrder();
for (size_t arcNumber = 0; arcNumber < searchGraph.size(); ++arcNumber) {
const Hypothesis *thisHypo = searchGraph[arcNumber].hypo;
const Hypothesis *prevHypo = thisHypo->GetPrevHypo();
if (prevHypo) {
int startNode = nodes[prevHypo->GetId()];
int endNode = nodes[thisHypo->GetId()];
bool terminalNode = (searchGraph[arcNumber].forward == -1);
outputSearchGraphStream << "J=" << arcNumber
<< " S=" << startNode
<< " E=" << endNode
<< " W=\"" << thisHypo->GetCurrTargetPhrase().GetStringRep(outputFactorOrder) << "\""
<< endl;
if (terminalNode && terminalNodes.count(endNode) == 0) {
terminalNodes.insert(endNode);
outputSearchGraphStream << "J=" << arcNumber
<< " S=" << endNode
<< " E=" << numNodes
<< endl;
}
}
}
}
void OutputSearchNode(long translationId, std::ostream &outputSearchGraphStream,
const SearchGraphNode& searchNode)
{

View File

@ -137,6 +137,7 @@ public:
#endif
void OutputSearchGraph(long translationId, std::ostream &outputSearchGraphStream) const;
void OutputSearchGraphAsSLF(long translationId, std::ostream &outputSearchGraphStream) const;
void GetSearchGraph(std::vector<SearchGraphNode>& searchGraph) const;
const InputType& GetSource() const {
return m_source;

View File

@ -130,6 +130,7 @@ Parameter::Parameter()
AddParam("output-search-graph", "osg", "Output connected hypotheses of search into specified filename");
AddParam("output-search-graph-extended", "osgx", "Output connected hypotheses of search into specified filename, in extended format");
AddParam("unpruned-search-graph", "usg", "When outputting chart search graph, do not exclude dead ends. Note: stack pruning may have eliminated some hypotheses");
AddParam("output-search-graph-slf", "slf", "Output connected hypotheses of search into specified directory, one file per sentence, in HTK standard lattice format (SLF)");
AddParam("include-lhs-in-search-graph", "lhssg", "When outputting chart search graph, include the label of the LHS of the rule (useful when using syntax)");
#ifdef HAVE_PROTOBUF
AddParam("output-search-graph-pb", "pb", "Write phrase lattice to protocol buffer objects in the specified path.");

View File

@ -235,8 +235,12 @@ bool StaticData::LoadData(Parameter *parameter)
}
m_outputSearchGraph = true;
m_outputSearchGraphExtended = true;
} else
} else {
m_outputSearchGraph = false;
}
if (m_parameter->GetParam("output-search-graph-slf").size() > 0) {
m_outputSearchGraphSLF = true;
}
#ifdef HAVE_PROTOBUF
if (m_parameter->GetParam("output-search-graph-pb").size() > 0) {
if (m_parameter->GetParam("output-search-graph-pb").size() != 1) {

View File

@ -216,6 +216,7 @@ protected:
bool m_outputWordGraph; //! whether to output word graph
bool m_outputSearchGraph; //! whether to output search graph
bool m_outputSearchGraphExtended; //! ... in extended format
bool m_outputSearchGraphSLF; //! whether to output search graph in HTK standard lattice format (SLF)
#ifdef HAVE_PROTOBUF
bool m_outputSearchGraphPB; //! whether to output search graph as a protobuf
#endif
@ -631,6 +632,9 @@ public:
bool GetOutputSearchGraphExtended() const {
return m_outputSearchGraphExtended;
}
//! whether to output search graph in HTK standard lattice format (SLF)
bool GetOutputSearchGraphSLF() const
{
  const bool slfOutputRequested = m_outputSearchGraphSLF;
  return slfOutputRequested;
}
#ifdef HAVE_PROTOBUF
bool GetOutputSearchGraphPB() const {
return m_outputSearchGraphPB;