Merge github.com:moses-smt/mosesdecoder into weight-new

This commit is contained in:
Hieu Hoang 2013-02-22 22:39:24 +00:00
commit 30e0d1e0fa
9 changed files with 548 additions and 10 deletions

View File

@ -189,6 +189,33 @@ InputType*IOWrapper::GetInput(InputType* inputType)
}
}
ofstream* IOWrapper::GetOutputSearchGraphSLFStream(size_t sentenceNumber) {
const StaticData &staticData = StaticData::Instance();
stringstream fileName;
fileName << staticData.GetParam("output-search-graph-slf")[0] << "/" << sentenceNumber << ".slf";
std::ofstream *file = new std::ofstream;
file->open(fileName.str().c_str());
return file;
}
ofstream* IOWrapper::GetOutputSearchGraphHypergraphStream(size_t sentenceNumber) {
const StaticData &staticData = StaticData::Instance();
stringstream fileName;
fileName << staticData.GetParam("output-search-graph-hypergraph")[0] << "/" << sentenceNumber;
std::ofstream *file = new std::ofstream;
file->open(fileName.str().c_str());
return file;
}
ofstream* IOWrapper::GetOutputSearchGraphHypergraphWeightsStream() {
const StaticData &staticData = StaticData::Instance();
stringstream fileName;
fileName << staticData.GetParam("output-search-graph-hypergraph")[1];
std::ofstream *file = new std::ofstream;
file->open(fileName.str().c_str());
return file;
}
/***
* print surface factor only for the given phrase
*/

View File

@ -124,6 +124,10 @@ public:
return *m_outputSearchGraphStream;
}
std::ofstream *GetOutputSearchGraphSLFStream(size_t sentenceNumber);
std::ofstream *GetOutputSearchGraphHypergraphStream(size_t sentenceNumber);
std::ofstream *GetOutputSearchGraphHypergraphWeightsStream();
std::ostream &GetDetailedTranslationReportingStream() {
assert (m_detailedTranslationReportingStream);
return *m_detailedTranslationReportingStream;

View File

@ -83,14 +83,18 @@ public:
OutputCollector* wordGraphCollector, OutputCollector* searchGraphCollector,
OutputCollector* detailedTranslationCollector,
OutputCollector* alignmentInfoCollector,
OutputCollector* unknownsCollector) :
OutputCollector* unknownsCollector,
std::ofstream* searchGraphSLFStream,
std::ofstream* searchGraphHypergraphStream) :
m_source(source), m_lineNumber(lineNumber),
m_outputCollector(outputCollector), m_nbestCollector(nbestCollector),
m_latticeSamplesCollector(latticeSamplesCollector),
m_wordGraphCollector(wordGraphCollector), m_searchGraphCollector(searchGraphCollector),
m_detailedTranslationCollector(detailedTranslationCollector),
m_alignmentInfoCollector(alignmentInfoCollector),
m_unknownsCollector(unknownsCollector) {}
m_unknownsCollector(unknownsCollector),
m_searchGraphSLFStream(searchGraphSLFStream),
m_searchGraphHypergraphStream(searchGraphHypergraphStream) {}
/** Translate one sentence
* gets called by main function implemented at end of this source file */
@ -143,6 +147,32 @@ public:
#endif
}
// Output search graph in HTK standard lattice format (SLF)
if (m_searchGraphSLFStream) {
if (m_searchGraphSLFStream->is_open() && m_searchGraphSLFStream->good()) {
ostringstream out;
fix(out,PRECISION);
manager.OutputSearchGraphAsSLF(m_lineNumber, out);
*m_searchGraphSLFStream << out.str();
m_searchGraphSLFStream -> flush();
} else {
TRACE_ERR("Cannot output HTK standard lattice for line " << m_lineNumber << " because the output file is not open or not ready for writing" << std::endl);
}
}
// Output search graph in hypergraph format for Kenneth Heafield's lazy hypergraph decoder
if (m_searchGraphHypergraphStream) {
if (m_searchGraphHypergraphStream->is_open() && m_searchGraphHypergraphStream->good()) {
ostringstream out;
fix(out,PRECISION);
manager.OutputSearchGraphAsHypergraph(m_lineNumber, out);
*m_searchGraphHypergraphStream << out.str();
m_searchGraphHypergraphStream -> flush();
} else {
TRACE_ERR("Cannot output hypergraph for line " << m_lineNumber << " because the output file is not open or not ready for writing" << std::endl);
}
}
// apply decision rule and output best translation(s)
if (m_outputCollector) {
ostringstream out;
@ -297,7 +327,15 @@ public:
}
~TranslationTask() {
if (m_searchGraphSLFStream) {
m_searchGraphSLFStream->close();
}
delete m_searchGraphSLFStream;
delete m_searchGraphHypergraphStream;
delete m_source;
}
private:
@ -311,6 +349,8 @@ private:
OutputCollector* m_detailedTranslationCollector;
OutputCollector* m_alignmentInfoCollector;
OutputCollector* m_unknownsCollector;
std::ofstream *m_searchGraphSLFStream;
std::ofstream *m_searchGraphHypergraphStream;
std::ofstream *m_alignmentStream;
@ -358,6 +398,63 @@ static void ShowWeights()
}
}
size_t OutputFeatureWeightsForHypergraph(size_t index, const FeatureFunction* ff, std::ostream &outputSearchGraphStream)
{
size_t numScoreComps = ff->GetNumScoreComponents();
if (numScoreComps != ScoreProducer::unlimited) {
vector<float> values = StaticData::Instance().GetAllWeights().GetScoresForProducer(ff);
if (numScoreComps > 1) {
for (size_t i = 0; i < numScoreComps; ++i) {
outputSearchGraphStream << ff->GetScoreProducerWeightShortName()
<< i
<< "=" << values[i] << endl;
}
} else {
outputSearchGraphStream << ff->GetScoreProducerWeightShortName()
<< "=" << values[0] << endl;
}
return index+numScoreComps;
} else {
cerr << "Sparse features are not yet supported when outputting hypergraph format" << endl;
assert(false);
return 0;
}
}
void OutputFeatureWeightsForHypergraph(std::ostream &outputSearchGraphStream)
{
outputSearchGraphStream.setf(std::ios::fixed);
outputSearchGraphStream.precision(6);
const StaticData& staticData = StaticData::Instance();
const TranslationSystem& system = staticData.GetTranslationSystem(TranslationSystem::DEFAULT);
const vector<const StatelessFeatureFunction*>& slf =system.GetStatelessFeatureFunctions();
const vector<const StatefulFeatureFunction*>& sff = system.GetStatefulFeatureFunctions();
size_t featureIndex = 1;
for (size_t i = 0; i < sff.size(); ++i) {
featureIndex = OutputFeatureWeightsForHypergraph(featureIndex, sff[i], outputSearchGraphStream);
}
for (size_t i = 0; i < slf.size(); ++i) {
if (slf[i]->GetScoreProducerWeightShortName() != "u" &&
slf[i]->GetScoreProducerWeightShortName() != "tm" &&
slf[i]->GetScoreProducerWeightShortName() != "I" &&
slf[i]->GetScoreProducerWeightShortName() != "g")
{
featureIndex = OutputFeatureWeightsForHypergraph(featureIndex, slf[i], outputSearchGraphStream);
}
}
const vector<PhraseDictionaryFeature*>& pds = system.GetPhraseDictionaries();
for( size_t i=0; i<pds.size(); i++ ) {
featureIndex = OutputFeatureWeightsForHypergraph(featureIndex, pds[i], outputSearchGraphStream);
}
const vector<GenerationDictionary*>& gds = system.GetGenerationDictionaries();
for( size_t i=0; i<gds.size(); i++ ) {
featureIndex = OutputFeatureWeightsForHypergraph(featureIndex, gds[i], outputSearchGraphStream);
}
}
} //namespace
/** main function of the command line version of the decoder **/
@ -421,6 +518,14 @@ int main(int argc, char** argv)
TRACE_ERR(weights);
TRACE_ERR("\n");
}
if (staticData.GetOutputSearchGraphHypergraph() && staticData.GetParam("output-search-graph-hypergraph").size() > 1) {
ofstream* weightsOut = ioWrapper->GetOutputSearchGraphHypergraphWeightsStream();
OutputFeatureWeightsForHypergraph(*weightsOut);
weightsOut->flush();
weightsOut->close();
delete weightsOut;
}
// initialize output streams
// note: we can't just write to STDOUT or files
@ -524,7 +629,11 @@ int main(int argc, char** argv)
searchGraphCollector.get(),
detailedTranslationCollector.get(),
alignmentInfoCollector.get(),
unknownsCollector.get() );
unknownsCollector.get(),
staticData.GetOutputSearchGraphSLF() ?
ioWrapper->GetOutputSearchGraphSLFStream(lineCount) : NULL,
staticData.GetOutputSearchGraphHypergraph() ?
ioWrapper->GetOutputSearchGraphHypergraphStream(lineCount) : NULL);
// execute task
#ifdef WITH_THREADS
pool.Submit(task);

View File

@ -462,7 +462,7 @@ void Hypothesis::CleanupArcList()
*/
const StaticData &staticData = StaticData::Instance();
size_t nBestSize = staticData.GetNBestSize();
bool distinctNBest = staticData.GetDistinctNBest() || staticData.UseMBR() || staticData.GetOutputSearchGraph() || staticData.UseLatticeMBR() ;
bool distinctNBest = staticData.GetDistinctNBest() || staticData.UseMBR() || staticData.GetOutputSearchGraph() || staticData.GetOutputSearchGraphSLF() || staticData.GetOutputSearchGraphHypergraph() || staticData.UseLatticeMBR() ;
if (!distinctNBest && m_arcList->size() > nBestSize * 5) {
// prune arc list only if there too many arcs

View File

@ -26,8 +26,10 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
#endif
#include <algorithm>
#include <limits>
#include <cmath>
#include <limits>
#include <map>
#include <set>
#include "Manager.h"
#include "TypeDef.h"
#include "Util.h"
@ -51,12 +53,12 @@ using namespace std;
namespace Moses
{
Manager::Manager(size_t lineNumber, InputType const& source, SearchAlgorithm searchAlgorithm, const TranslationSystem* system)
:m_lineNumber(lineNumber)
,m_system(system)
:m_system(system)
,m_transOptColl(source.CreateTranslationOptionCollection(system))
,m_search(Search::CreateSearch(*this, source, searchAlgorithm, *m_transOptColl))
,interrupted_flag(0)
,m_hypoId(0)
,m_lineNumber(lineNumber)
,m_source(source)
{
StaticData::Instance().InitializeForInput(source);
@ -637,6 +639,366 @@ void Manager::GetSearchGraph(vector<SearchGraphNode>& searchGraph) const
}
void Manager::OutputFeatureWeightsForSLF(std::ostream &outputSearchGraphStream) const
{
outputSearchGraphStream.setf(std::ios::fixed);
outputSearchGraphStream.precision(6);
const StaticData& staticData = StaticData::Instance();
const TranslationSystem& system = staticData.GetTranslationSystem(TranslationSystem::DEFAULT);
const vector<const StatelessFeatureFunction*>& slf =system.GetStatelessFeatureFunctions();
const vector<const StatefulFeatureFunction*>& sff = system.GetStatefulFeatureFunctions();
size_t featureIndex = 1;
for (size_t i = 0; i < sff.size(); ++i) {
featureIndex = OutputFeatureWeightsForSLF(featureIndex, sff[i], outputSearchGraphStream);
}
for (size_t i = 0; i < slf.size(); ++i) {
if (slf[i]->GetScoreProducerWeightShortName() != "u" &&
slf[i]->GetScoreProducerWeightShortName() != "tm" &&
slf[i]->GetScoreProducerWeightShortName() != "I" &&
slf[i]->GetScoreProducerWeightShortName() != "g")
{
featureIndex = OutputFeatureWeightsForSLF(featureIndex, slf[i], outputSearchGraphStream);
}
}
const vector<PhraseDictionaryFeature*>& pds = system.GetPhraseDictionaries();
for( size_t i=0; i<pds.size(); i++ ) {
featureIndex = OutputFeatureWeightsForSLF(featureIndex, pds[i], outputSearchGraphStream);
}
const vector<GenerationDictionary*>& gds = system.GetGenerationDictionaries();
for( size_t i=0; i<gds.size(); i++ ) {
featureIndex = OutputFeatureWeightsForSLF(featureIndex, gds[i], outputSearchGraphStream);
}
}
void Manager::OutputFeatureValuesForSLF(const Hypothesis* hypo, bool zeros, std::ostream &outputSearchGraphStream) const
{
outputSearchGraphStream.setf(std::ios::fixed);
outputSearchGraphStream.precision(6);
// outputSearchGraphStream << endl;
// outputSearchGraphStream << (*hypo) << endl;
// const ScoreComponentCollection& scoreCollection = hypo->GetScoreBreakdown();
// outputSearchGraphStream << scoreCollection << endl;
const StaticData& staticData = StaticData::Instance();
const TranslationSystem& system = staticData.GetTranslationSystem(TranslationSystem::DEFAULT);
const vector<const StatelessFeatureFunction*>& slf =system.GetStatelessFeatureFunctions();
const vector<const StatefulFeatureFunction*>& sff = system.GetStatefulFeatureFunctions();
size_t featureIndex = 1;
for (size_t i = 0; i < sff.size(); ++i) {
featureIndex = OutputFeatureValuesForSLF(featureIndex, zeros, hypo, sff[i], outputSearchGraphStream);
}
for (size_t i = 0; i < slf.size(); ++i) {
if (slf[i]->GetScoreProducerWeightShortName() != "u" &&
slf[i]->GetScoreProducerWeightShortName() != "tm" &&
slf[i]->GetScoreProducerWeightShortName() != "I" &&
slf[i]->GetScoreProducerWeightShortName() != "g")
{
featureIndex = OutputFeatureValuesForSLF(featureIndex, zeros, hypo, slf[i], outputSearchGraphStream);
}
}
const vector<PhraseDictionaryFeature*>& pds = system.GetPhraseDictionaries();
for( size_t i=0; i<pds.size(); i++ ) {
featureIndex = OutputFeatureValuesForSLF(featureIndex, zeros, hypo, pds[i], outputSearchGraphStream);
}
const vector<GenerationDictionary*>& gds = system.GetGenerationDictionaries();
for( size_t i=0; i<gds.size(); i++ ) {
featureIndex = OutputFeatureValuesForSLF(featureIndex, zeros, hypo, gds[i], outputSearchGraphStream);
}
}
void Manager::OutputFeatureValuesForHypergraph(const Hypothesis* hypo, std::ostream &outputSearchGraphStream) const
{
outputSearchGraphStream.setf(std::ios::fixed);
outputSearchGraphStream.precision(6);
const StaticData& staticData = StaticData::Instance();
const TranslationSystem& system = staticData.GetTranslationSystem(TranslationSystem::DEFAULT);
const vector<const StatelessFeatureFunction*>& slf =system.GetStatelessFeatureFunctions();
const vector<const StatefulFeatureFunction*>& sff = system.GetStatefulFeatureFunctions();
size_t featureIndex = 1;
for (size_t i = 0; i < sff.size(); ++i) {
featureIndex = OutputFeatureValuesForHypergraph(featureIndex, hypo, sff[i], outputSearchGraphStream);
}
for (size_t i = 0; i < slf.size(); ++i) {
if (slf[i]->GetScoreProducerWeightShortName() != "u" &&
slf[i]->GetScoreProducerWeightShortName() != "tm" &&
slf[i]->GetScoreProducerWeightShortName() != "I" &&
slf[i]->GetScoreProducerWeightShortName() != "g")
{
featureIndex = OutputFeatureValuesForHypergraph(featureIndex, hypo, slf[i], outputSearchGraphStream);
}
}
const vector<PhraseDictionaryFeature*>& pds = system.GetPhraseDictionaries();
for( size_t i=0; i<pds.size(); i++ ) {
featureIndex = OutputFeatureValuesForHypergraph(featureIndex, hypo, pds[i], outputSearchGraphStream);
}
const vector<GenerationDictionary*>& gds = system.GetGenerationDictionaries();
for( size_t i=0; i<gds.size(); i++ ) {
featureIndex = OutputFeatureValuesForHypergraph(featureIndex, hypo, gds[i], outputSearchGraphStream);
}
}
size_t Manager::OutputFeatureWeightsForSLF(size_t index, const FeatureFunction* ff, std::ostream &outputSearchGraphStream) const
{
size_t numScoreComps = ff->GetNumScoreComponents();
if (numScoreComps != ScoreProducer::unlimited) {
vector<float> values = StaticData::Instance().GetAllWeights().GetScoresForProducer(ff);
for (size_t i = 0; i < numScoreComps; ++i) {
outputSearchGraphStream << "# " << ff->GetScoreProducerDescription()
<< " " << ff->GetScoreProducerWeightShortName()
<< " " << (i+1) << " of " << numScoreComps << endl
<< "x" << (index+i) << "scale=" << values[i] << endl;
}
return index+numScoreComps;
} else {
cerr << "Sparse features are not supported when outputting HTK standard lattice format" << endl;
assert(false);
return 0;
}
}
size_t Manager::OutputFeatureValuesForSLF(size_t index, bool zeros, const Hypothesis* hypo, const FeatureFunction* ff, std::ostream &outputSearchGraphStream) const
{
// { const FeatureFunction* sp = ff;
// const FVector& m_scores = scoreCollection.GetScoresVector();
// FVector& scores = const_cast<FVector&>(m_scores);
// std::string prefix = sp->GetScoreProducerDescription() + FName::SEP;
// // std::cout << "prefix==" << prefix << endl;
// // cout << "m_scores==" << m_scores << endl;
// // cout << "m_scores.size()==" << m_scores.size() << endl;
// // cout << "m_scores.coreSize()==" << m_scores.coreSize() << endl;
// // cout << "m_scores.cbegin() ?= m_scores.cend()\t" << (m_scores.cbegin() == m_scores.cend()) << endl;
// // for(FVector::FNVmap::const_iterator i = m_scores.cbegin(); i != m_scores.cend(); i++) {
// // std::cout<<prefix << "\t" << (i->first) << "\t" << (i->second) << std::endl;
// // }
// for(int i=0, n=v.size(); i<n; i+=1) {
// // outputSearchGraphStream << prefix << i << "==" << v[i] << std::endl;
// }
// }
// FVector featureValues = scoreCollection.GetVectorForProducer(ff);
// outputSearchGraphStream << featureValues << endl;
const ScoreComponentCollection& scoreCollection = hypo->GetScoreBreakdown();
vector<float> featureValues = scoreCollection.GetScoresForProducer(ff);
size_t numScoreComps = featureValues.size();//featureValues.coreSize();
// if (numScoreComps != ScoreProducer::unlimited) {
// vector<float> values = StaticData::Instance().GetAllWeights().GetScoresForProducer(ff);
for (size_t i = 0; i < numScoreComps; ++i) {
outputSearchGraphStream << "x" << (index+i) << "=" << ((zeros) ? 0.0 : featureValues[i]) << " ";
}
return index+numScoreComps;
// } else {
// cerr << "Sparse features are not supported when outputting HTK standard lattice format" << endl;
// assert(false);
// return 0;
// }
}
size_t Manager::OutputFeatureValuesForHypergraph(size_t index, const Hypothesis* hypo, const FeatureFunction* ff, std::ostream &outputSearchGraphStream) const
{
const ScoreComponentCollection& scoreCollection = hypo->GetScoreBreakdown();
vector<float> featureValues = scoreCollection.GetScoresForProducer(ff);
size_t numScoreComps = featureValues.size();
if (numScoreComps > 1) {
for (size_t i = 0; i < numScoreComps; ++i) {
outputSearchGraphStream << ff->GetScoreProducerWeightShortName() << i << "=" << featureValues[i] << " ";
}
} else {
outputSearchGraphStream << ff->GetScoreProducerWeightShortName() << "=" << featureValues[0] << " ";
}
return index+numScoreComps;
}
/**! Output search graph in hypergraph format of Kenneth Heafield's lazy hypergraph decoder */
void Manager::OutputSearchGraphAsHypergraph(long translationId, std::ostream &outputSearchGraphStream) const
{
vector<SearchGraphNode> searchGraph;
GetSearchGraph(searchGraph);
long numNodes = 0;
map<int,int> nodes;
set<int> terminalNodes;
multimap<int,int> nodeToLines;
for (size_t arcNumber = 0, size=searchGraph.size(); arcNumber < size; ++arcNumber) {
// Record that this arc ends at this node
nodeToLines.insert(pair<int,int>(numNodes,arcNumber));
int hypothesisID = searchGraph[arcNumber].hypo->GetId();
if (nodes.count(hypothesisID) == 0) {
nodes[hypothesisID] = numNodes;
numNodes += 1;
bool terminalNode = (searchGraph[arcNumber].forward == -1);
if (terminalNode) {
// Final arc to end node, representing the end of the sentence </s>
terminalNodes.insert(numNodes);
}
}
}
// Unique end node
nodes[numNodes] = numNodes;
numNodes += 1;
long numArcs = searchGraph.size() + terminalNodes.size();
// Print number of nodes and arcs
outputSearchGraphStream << numNodes << " " << numArcs << endl;
for (int nodeNumber=0; nodeNumber <= numNodes; nodeNumber+=1) {
size_t count = nodeToLines.count(nodeNumber);
if (count > 0) {
outputSearchGraphStream << count << endl;
pair<multimap<int,int>::iterator, multimap<int,int>::iterator> range = nodeToLines.equal_range(nodeNumber);
for (multimap<int,int>::iterator it=range.first; it!=range.second; ++it) {
int lineNumber = (*it).second;
const Hypothesis *thisHypo = searchGraph[lineNumber].hypo;
const Hypothesis *prevHypo = thisHypo->GetPrevHypo();
if (prevHypo==NULL) {
outputSearchGraphStream << "<s> ||| " << endl;
} else {
int startNode = nodes[prevHypo->GetId()];
const TargetPhrase &targetPhrase = thisHypo->GetCurrTargetPhrase();
int targetWordCount = targetPhrase.GetSize();
outputSearchGraphStream << "[" << startNode << "]";
for (int targetWordIndex=0; targetWordIndex<targetWordCount; targetWordIndex+=1) {
outputSearchGraphStream << " " << targetPhrase.GetWord(targetWordIndex);
}
outputSearchGraphStream << " ||| ";
OutputFeatureValuesForHypergraph(thisHypo, outputSearchGraphStream);
outputSearchGraphStream << endl;
}
}
}
}
// Print node and arc(s) for end of sentence </s>
outputSearchGraphStream << terminalNodes.size() << endl;
for (set<int>::iterator it=terminalNodes.begin(); it!=terminalNodes.end(); ++it) {
outputSearchGraphStream << "[" << (*it) << "] </s> ||| " << endl;
}
}
/**! Output search graph in HTK standard lattice format (SLF) */
void Manager::OutputSearchGraphAsSLF(long translationId, std::ostream &outputSearchGraphStream) const
{
vector<SearchGraphNode> searchGraph;
GetSearchGraph(searchGraph);
long numArcs = 0;
long numNodes = 0;
map<int,int> nodes;
set<int> terminalNodes;
// Unique start node
nodes[0] = 0;
for (size_t arcNumber = 0; arcNumber < searchGraph.size(); ++arcNumber) {
int targetWordCount = searchGraph[arcNumber].hypo->GetCurrTargetPhrase().GetSize();
numArcs += targetWordCount;
int hypothesisID = searchGraph[arcNumber].hypo->GetId();
if (nodes.count(hypothesisID) == 0) {
numNodes += targetWordCount;
nodes[hypothesisID] = numNodes;
//numNodes += 1;
bool terminalNode = (searchGraph[arcNumber].forward == -1);
if (terminalNode) {
numArcs += 1;
}
}
}
numNodes += 1;
// Unique end node
nodes[numNodes] = numNodes;
outputSearchGraphStream << "UTTERANCE=Sentence_" << translationId << endl;
outputSearchGraphStream << "VERSION=1.1" << endl;
outputSearchGraphStream << "base=2.71828182845905" << endl;
outputSearchGraphStream << "NODES=" << (numNodes+1) << endl;
outputSearchGraphStream << "LINKS=" << numArcs << endl;
OutputFeatureWeightsForSLF(outputSearchGraphStream);
for (size_t arcNumber = 0, lineNumber = 0; lineNumber < searchGraph.size(); ++lineNumber) {
const Hypothesis *thisHypo = searchGraph[lineNumber].hypo;
const Hypothesis *prevHypo = thisHypo->GetPrevHypo();
if (prevHypo) {
int startNode = nodes[prevHypo->GetId()];
int endNode = nodes[thisHypo->GetId()];
bool terminalNode = (searchGraph[lineNumber].forward == -1);
const TargetPhrase &targetPhrase = thisHypo->GetCurrTargetPhrase();
int targetWordCount = targetPhrase.GetSize();
for (int targetWordIndex=0; targetWordIndex<targetWordCount; targetWordIndex+=1) {
int x = (targetWordCount-targetWordIndex);
outputSearchGraphStream << "J=" << arcNumber;
if (targetWordIndex==0) {
outputSearchGraphStream << " S=" << startNode;
} else {
outputSearchGraphStream << " S=" << endNode - x;
}
outputSearchGraphStream << " E=" << endNode - (x-1)
<< " W=" << targetPhrase.GetWord(targetWordIndex);
OutputFeatureValuesForSLF(thisHypo, (targetWordIndex>0), outputSearchGraphStream);
outputSearchGraphStream << endl;
arcNumber += 1;
}
if (terminalNode && terminalNodes.count(endNode) == 0) {
terminalNodes.insert(endNode);
outputSearchGraphStream << "J=" << arcNumber
<< " S=" << endNode
<< " E=" << numNodes
<< endl;
arcNumber += 1;
}
}
}
}
void OutputSearchNode(long translationId, std::ostream &outputSearchGraphStream,
const SearchGraphNode& searchNode)
{

View File

@ -93,6 +93,19 @@ class Manager
Manager(Manager const&);
void operator=(Manager const&);
const TranslationSystem* m_system;
private:
// Helper functions to output search graph in HTK standard lattice format
void OutputFeatureWeightsForSLF(std::ostream &outputSearchGraphStream) const;
size_t OutputFeatureWeightsForSLF(size_t index, const FeatureFunction* ff, std::ostream &outputSearchGraphStream) const;
void OutputFeatureValuesForSLF(const Hypothesis* hypo, bool zeros, std::ostream &outputSearchGraphStream) const;
size_t OutputFeatureValuesForSLF(size_t index, bool zeros, const Hypothesis* hypo, const FeatureFunction* ff, std::ostream &outputSearchGraphStream) const;
// Helper functions to output search graph in the hypergraph format of Kenneth Heafield's lazy hypergraph decoder
void OutputFeatureValuesForHypergraph(const Hypothesis* hypo, std::ostream &outputSearchGraphStream) const;
size_t OutputFeatureValuesForHypergraph(size_t index, const Hypothesis* hypo, const FeatureFunction* ff, std::ostream &outputSearchGraphStream) const;
protected:
// data
// InputType const& m_source; /**< source sentence to be translated */
@ -103,6 +116,7 @@ protected:
size_t interrupted_flag;
std::auto_ptr<SentenceStats> m_sentenceStats;
int m_hypoId; //used to number the hypos as they are created.
size_t m_lineNumber;
void GetConnectedGraph(
std::map< int, bool >* pConnected,
@ -113,7 +127,6 @@ protected:
public:
size_t m_lineNumber;
InputType const& m_source; /**< source sentence to be translated */
Manager(size_t lineNumber, InputType const& source, SearchAlgorithm searchAlgorithm, const TranslationSystem* system);
~Manager();
@ -137,6 +150,8 @@ public:
#endif
void OutputSearchGraph(long translationId, std::ostream &outputSearchGraphStream) const;
void OutputSearchGraphAsSLF(long translationId, std::ostream &outputSearchGraphStream) const;
void OutputSearchGraphAsHypergraph(long translationId, std::ostream &outputSearchGraphStream) const;
void GetSearchGraph(std::vector<SearchGraphNode>& searchGraph) const;
const InputType& GetSource() const {
return m_source;

View File

@ -112,6 +112,8 @@ Parameter::Parameter()
AddParam("output-search-graph", "osg", "Output connected hypotheses of search into specified filename");
AddParam("output-search-graph-extended", "osgx", "Output connected hypotheses of search into specified filename, in extended format");
AddParam("unpruned-search-graph", "usg", "When outputting chart search graph, do not exclude dead ends. Note: stack pruning may have eliminated some hypotheses");
AddParam("output-search-graph-slf", "slf", "Output connected hypotheses of search into specified directory, one file per sentence, in HTK standard lattice format (SLF)");
AddParam("output-search-graph-hypergraph", "Output connected hypotheses of search into specified directory, one file per sentence, in a hypergraph format (see Kenneth Heafield's lazy hypergraph decoder)");
AddParam("include-lhs-in-search-graph", "lhssg", "When outputting chart search graph, include the label of the LHS of the rule (useful when using syntax)");
#ifdef HAVE_PROTOBUF
AddParam("output-search-graph-pb", "pb", "Write phrase lattice to protocol buffer objects in the specified path.");

View File

@ -239,8 +239,19 @@ bool StaticData::LoadData(Parameter *parameter)
}
m_outputSearchGraph = true;
m_outputSearchGraphExtended = true;
} else
} else {
m_outputSearchGraph = false;
}
if (m_parameter->GetParam("output-search-graph-slf").size() > 0) {
m_outputSearchGraphSLF = true;
} else {
m_outputSearchGraphSLF = false;
}
if (m_parameter->GetParam("output-search-graph-hypergraph").size() > 0) {
m_outputSearchGraphHypergraph = true;
} else {
m_outputSearchGraphHypergraph = false;
}
#ifdef HAVE_PROTOBUF
if (m_parameter->GetParam("output-search-graph-pb").size() > 0) {
if (m_parameter->GetParam("output-search-graph-pb").size() != 1) {

View File

@ -185,6 +185,8 @@ protected:
bool m_outputWordGraph; //! whether to output word graph
bool m_outputSearchGraph; //! whether to output search graph
bool m_outputSearchGraphExtended; //! ... in extended format
bool m_outputSearchGraphSLF; //! whether to output search graph in HTK standard lattice format (SLF)
bool m_outputSearchGraphHypergraph; //! whether to output search graph in hypergraph
#ifdef HAVE_PROTOBUF
bool m_outputSearchGraphPB; //! whether to output search graph as a protobuf
#endif
@ -394,7 +396,7 @@ public:
return m_nBestFilePath;
}
bool IsNBestEnabled() const {
return (!m_nBestFilePath.empty()) || m_mbr || m_useLatticeMBR || m_mira || m_outputSearchGraph || m_useConsensusDecoding || !m_latticeSamplesFilePath.empty()
return (!m_nBestFilePath.empty()) || m_mbr || m_useLatticeMBR || m_mira || m_outputSearchGraph || m_outputSearchGraphSLF || m_outputSearchGraphHypergraph || m_useConsensusDecoding || !m_latticeSamplesFilePath.empty()
#ifdef HAVE_PROTOBUF
|| m_outputSearchGraphPB
#endif
@ -557,6 +559,12 @@ public:
bool GetOutputSearchGraphExtended() const {
return m_outputSearchGraphExtended;
}
bool GetOutputSearchGraphSLF() const {
return m_outputSearchGraphSLF;
}
bool GetOutputSearchGraphHypergraph() const {
return m_outputSearchGraphHypergraph;
}
#ifdef HAVE_PROTOBUF
bool GetOutputSearchGraphPB() const {
return m_outputSearchGraphPB;