Merging of moses mains. mosesmt now does single- and multi-threaded

git-svn-id: https://mosesdecoder.svn.sourceforge.net/svnroot/mosesdecoder/trunk@3264 1f5c12ca-751b-0410-a591-d2e778427230
This commit is contained in:
bhaddow 2010-05-19 16:42:18 +00:00
parent c2f35f614b
commit c9baabe2ea
10 changed files with 213 additions and 75 deletions

View File

@ -162,6 +162,7 @@ void IOWrapper::Initialization(const std::vector<FactorType> &/*inputFactorOrder
{
const std::string &path = staticData.GetDetailedTranslationReportingFilePath();
m_detailedTranslationReportingStream = new std::ofstream(path.c_str());
assert(m_detailedTranslationReportingStream->good());
}
}

View File

@ -121,5 +121,6 @@ void OutputLatticeMBRNBest(std::ostream& out, const std::vector<LatticeMBRSoluti
void OutputBestHypo(const std::vector<Moses::Word>& mbrBestHypo, long /*translationId*/,
bool reportSegmentation, bool reportAllFactors, std::ostream& out);
void OutputBestHypo(const Moses::TrellisPath &path, long /*translationId*/,bool reportSegmentation, bool reportAllFactors, std::ostream &out);
void OutputInput(std::ostream& os, const Hypothesis* hypo);
#endif

View File

@ -151,7 +151,7 @@ int main(int argc, char* argv[])
//Print all derivations in search graph
if (staticData.PrintAllDerivations()) {
manager.PrintAllDerivations(source->GetTranslationId());
manager.PrintAllDerivations(source->GetTranslationId(), std::cerr);
}
// pick best translation (maximum a posteriori decoding)

View File

@ -20,31 +20,49 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
/**
* Main for multithreaded moses.
* Moses main, for single-threaded and multi-threaded.
**/
#include <fstream>
#include <sstream>
#include <vector>
#include <boost/thread/mutex.hpp>
#if defined(BOOST_HAS_PTHREADS)
#include <pthread.h>
#ifdef WIN32
// Include Visual Leak Detector
#include <vld.h>
#endif
#ifdef WITH_THREADS
#include <boost/thread/mutex.hpp>
#endif
#ifdef BOOST_HAS_PTHREADS
#include <pthread.h>
#endif
#include "Hypothesis.h"
#include "IOWrapper.h"
#include "LatticeMBR.h"
#include "Manager.h"
#include "StaticData.h"
#include "ThreadPool.h"
#include "Util.h"
#include "mbr.h"
#include "ThreadPool.h"
#include "TranslationAnalysis.h"
#ifdef HAVE_PROTOBUF
#include "hypergraph.pb.h"
#endif
using namespace std;
using namespace Moses;
/**
 * Enforce rounding on a stream: turn on the fixed-notation flag and
 * set the floating-point precision to 3.
 **/
void fix(std::ostream& stream) {
  // OR the fixed flag into the current flags; every other flag is kept as is.
  const std::ios::fmtflags withFixed = stream.flags() | std::ios::fixed;
  stream.flags(withFixed);
  stream.precision(3);
}
/**
* Makes sure output goes in the correct order.
@ -59,21 +77,23 @@ class OutputCollector {
* Write or cache the output, as appropriate.
**/
void Write(int sourceId,const string& output,const string& debug="") {
#ifdef WITH_THREADS
boost::mutex::scoped_lock lock(m_mutex);
#endif
if (sourceId == m_nextOutput) {
//This is the one we were expecting
*m_outStream << output;
*m_debugStream << debug;
*m_outStream << output << flush;
*m_debugStream << debug << flush;
++m_nextOutput;
//see if there's any more
map<int,string>::iterator iter;
while ((iter = m_outputs.find(m_nextOutput)) != m_outputs.end()) {
*m_outStream << iter->second;
*m_outStream << iter->second << flush;
m_outputs.erase(iter);
++m_nextOutput;
map<int,string>::iterator debugIter = m_debugs.find(iter->first);
if (debugIter != m_debugs.end()) {
*m_debugStream << debugIter->second;
*m_debugStream << debugIter->second << flush;
m_debugs.erase(debugIter);
}
}
@ -90,7 +110,9 @@ class OutputCollector {
int m_nextOutput;
ostream* m_outStream;
ostream* m_debugStream;
#ifdef WITH_THREADS
boost::mutex m_mutex;
#endif
};
/**
@ -101,27 +123,72 @@ class TranslationTask : public Task {
public:
TranslationTask(size_t lineNumber,
InputType* source, OutputCollector* outputCollector, OutputCollector* nbestCollector) :
InputType* source, OutputCollector* outputCollector, OutputCollector* nbestCollector,
OutputCollector* wordGraphCollector, OutputCollector* searchGraphCollector,
OutputCollector* detailedTranslationCollector) :
m_source(source), m_lineNumber(lineNumber),
m_outputCollector(outputCollector), m_nbestCollector(nbestCollector) {}
m_outputCollector(outputCollector), m_nbestCollector(nbestCollector),
m_wordGraphCollector(wordGraphCollector), m_searchGraphCollector(searchGraphCollector),
m_detailedTranslationCollector(detailedTranslationCollector) {}
void Run()
{
#if defined(BOOST_HAS_PTHREADS)
#ifdef BOOST_HAS_PTHREADS
TRACE_ERR("Translating line " << m_lineNumber << " in thread id " << pthread_self() << std::endl);
#endif
const StaticData &staticData = StaticData::Instance();
Sentence sentence(Input);
Manager manager(*m_source,staticData.GetSearchAlgorithm());
manager.ProcessSentence();
//Word Graph
if (m_wordGraphCollector) {
ostringstream out;
fix(out);
manager.GetWordGraph(m_lineNumber, out);
m_wordGraphCollector->Write(m_lineNumber, out.str());
}
//Search Graph
if (m_searchGraphCollector) {
ostringstream out;
fix(out);
manager.OutputSearchGraph(m_lineNumber, out);
m_searchGraphCollector->Write(m_lineNumber, out.str());
#ifdef HAVE_PROTOBUF
if (staticData.GetOutputSearchGraphPB()) {
ostringstream sfn;
sfn << staticData.GetParam("output-search-graph-pb")[0] << '/' << m_lineNumber << ".pb" << ends;
string fn = sfn.str();
VERBOSE(2, "Writing search graph to " << fn << endl);
fstream output(fn.c_str(), ios::trunc | ios::binary | ios::out);
manager.SerializeSearchGraphPB(m_lineNumber, output);
}
#endif
}
if (m_outputCollector) {
ostringstream out;
ostringstream debug;
fix(debug);
//All derivations - send them to debug stream
if (staticData.PrintAllDerivations()) {
manager.PrintAllDerivations(m_lineNumber, debug);
}
//Best hypothesis
const Hypothesis* bestHypo = NULL;
if (!staticData.UseMBR()) {
bestHypo = manager.GetBestHypothesis();
if (bestHypo) {
if (staticData.IsPathRecoveryEnabled()) {
OutputInput(out, bestHypo);
out << "||| ";
}
OutputSurface(
out,
bestHypo,
@ -195,6 +262,18 @@ class TranslationTask : public Task {
OutputNBest(out,nBestList, staticData.GetOutputFactorOrder(), m_lineNumber);
m_nbestCollector->Write(m_lineNumber, out.str());
}
//detailed translation reporting
if (m_detailedTranslationCollector) {
ostringstream out;
fix(out);
TranslationAnalysis::PrintTranslationAnalysis(out, manager.GetBestHypothesis());
m_detailedTranslationCollector->Write(m_lineNumber,out.str());
}
IFVERBOSE(2) { PrintUserTime("Sentence Decoding Time:"); }
manager.CalcDecoderStatistics();
}
~TranslationTask() {delete m_source;}
@ -204,51 +283,81 @@ class TranslationTask : public Task {
size_t m_lineNumber;
OutputCollector* m_outputCollector;
OutputCollector* m_nbestCollector;
OutputCollector* m_wordGraphCollector;
OutputCollector* m_searchGraphCollector;
OutputCollector* m_detailedTranslationCollector;
};
int main(int argc, char** argv) {
//extract pool-size args, send others to moses
char** mosesargv = new char*[argc+2];
int mosesargc = 0;
int threadcount = 10;
for (int i = 0; i < argc; ++i) {
if (!strcmp(argv[i], "-threads")) {
++i;
if (i >= argc) {
cerr << "Error: Missing argument to -threads" << endl;
exit(1);
} else {
threadcount = atoi(argv[i]);
}
} else {
mosesargv[mosesargc] = new char[strlen(argv[i])+1];
strcpy(mosesargv[mosesargc],argv[i]);
++mosesargc;
}
}
if (threadcount <= 0) {
cerr << "Error: Must specify a positive number of threads" << endl;
exit(1);
#ifdef HAVE_PROTOBUF
GOOGLE_PROTOBUF_VERIFY_VERSION;
#endif
IFVERBOSE(1)
{
TRACE_ERR("command: ");
for(int i=0;i<argc;++i) TRACE_ERR(argv[i]<<" ");
TRACE_ERR(endl);
}
fix(cout);
fix(cerr);
Parameter* params = new Parameter();
if (!params->LoadParam(mosesargc,mosesargv)) {
if (!params->LoadParam(argc,argv)) {
params->Explain();
exit(1);
}
//create threadpool, if necessary
int threadcount = (params->GetParam("threads").size() > 0) ?
Scan<size_t>(params->GetParam("threads")[0]) : 1;
#ifdef WITH_THREADS
if (threadcount < 1) {
cerr << "Error: Need to specify a positive number of threads" << endl;
exit(1);
}
ThreadPool pool(threadcount);
#else
if (threadcount > 1) {
cerr << "Error: Thread count of " << threadcount << " but moses not built with thread support" << endl;
exit(1);
}
#endif
if (!StaticData::LoadDataStatic(params)) {
exit(1);
}
const StaticData& staticData = StaticData::Instance();
// set up read/writing class
IOWrapper* ioWrapper = GetIODevice(staticData);
if (!ioWrapper) {
cerr << "Error; Failed to create IO object" << endl;
exit(1);
}
ThreadPool pool(threadcount);
// check on weights
vector<float> weights = staticData.GetAllWeights();
IFVERBOSE(2) {
TRACE_ERR("The score component vector looks like this:\n" << staticData.GetScoreIndexManager());
TRACE_ERR("The global weight vector looks like this:");
for (size_t j=0; j<weights.size(); j++) { TRACE_ERR(" " << weights[j]); }
TRACE_ERR("\n");
}
// every score must have a weight! check that here:
if(weights.size() != staticData.GetScoreIndexManager().GetTotalNumberOfScores()) {
TRACE_ERR("ERROR: " << staticData.GetScoreIndexManager().GetTotalNumberOfScores() << " score components, but " << weights.size() << " weights defined" << std::endl);
exit(1);
}
InputType* source = NULL;
size_t lineCount = 0;
auto_ptr<OutputCollector> outputCollector;//for translations
@ -257,9 +366,8 @@ int main(int argc, char** argv) {
size_t nbestSize = staticData.GetNBestSize();
string nbestFile = staticData.GetNBestFilePath();
if (nbestSize) {
if (nbestFile == "-") {
if (nbestFile == "-" || nbestFile == "/dev/stdout") {
//nbest to stdout, no 1-best
//FIXME: Moses doesn't actually let you pass a '-' on the command line.
nbestCollector.reset(new OutputCollector());
} else {
//nbest to file, 1-best to stdout
@ -272,22 +380,47 @@ int main(int argc, char** argv) {
outputCollector.reset(new OutputCollector());
}
auto_ptr<OutputCollector> wordGraphCollector;
if (staticData.GetOutputWordGraph()) {
wordGraphCollector.reset(new OutputCollector(&(ioWrapper->GetOutputWordGraphStream())));
}
auto_ptr<OutputCollector> searchGraphCollector;
if (staticData.GetOutputSearchGraph()) {
searchGraphCollector.reset(new OutputCollector(&(ioWrapper->GetOutputSearchGraphStream())));
}
auto_ptr<OutputCollector> detailedTranslationCollector;
if (staticData.IsDetailedTranslationReportingEnabled()) {
detailedTranslationCollector.reset(new OutputCollector(&(ioWrapper->GetDetailedTranslationReportingStream())));
}
while(ReadInput(*ioWrapper,staticData.GetInputType(),source)) {
IFVERBOSE(1) {
ResetUserTime();
}
TranslationTask* task =
new TranslationTask(lineCount,source, outputCollector.get(), nbestCollector.get());
new TranslationTask(lineCount,source, outputCollector.get(), nbestCollector.get(), wordGraphCollector.get(),
searchGraphCollector.get(), detailedTranslationCollector.get());
#ifdef WITH_THREADS
pool.Submit(task);
#else
task->Run();
#endif
source = NULL; //make sure it doesn't get deleted
++lineCount;
}
#ifdef WITH_THREADS
pool.Stop(true); //flush remaining jobs
#endif
#ifndef EXIT_RETURN
//This avoids that detructors are called (it can take a long time)
#ifndef EXIT_RETURN
//This avoids that destructors are called (it can take a long time)
exit(EXIT_SUCCESS);
#else
#else
return EXIT_SUCCESS;
#endif
#endif
}

View File

@ -1,8 +1,4 @@
if WITH_THREADS
bin_PROGRAMS = moses mosesmt lmbrgrid
else
bin_PROGRAMS = moses lmbrgrid
endif
bin_PROGRAMS = moses mosesmt lmbrgrid
AM_CPPFLAGS = -W -Wall -ffor-scope -D_FILE_OFFSET_BITS=64 -D_LARGE_FILES -DUSE_HYPO_POOL -I$(top_srcdir)/moses/src $(BOOST_CPPFLAGS)

View File

@ -22,6 +22,8 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
#include "ThreadPool.h"
#ifdef WITH_THREADS
using namespace std;
using namespace Moses;
@ -54,7 +56,7 @@ void Moses::ThreadPool::Execute()
}
m_threadAvailable.notify_all();
} while (!m_stopped);
#if defined(BOOST_HAS_PTHREADS)
#ifdef BOOST_HAS_PTHREADS
TRACE_ERR("Thread " << pthread_self() << " exiting" << endl);
#endif
}
@ -92,6 +94,7 @@ void Moses::ThreadPool::Stop(bool processRemainingJobs)
}
m_threadNeeded.notify_all();
cerr << m_threads.size() << endl;
m_threads.join_all();
}
#endif //WITH_THREADS

View File

@ -26,11 +26,12 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
#include <queue>
#include <vector>
#ifdef WITH_THREADS
#include <boost/bind.hpp>
#include <boost/thread.hpp>
#endif
#if defined(BOOST_HAS_PTHREADS)
#ifdef BOOST_HAS_PTHREADS
#include <pthread.h>
#endif
@ -52,7 +53,9 @@ class Task {
public:
virtual void Run() = 0;
virtual ~Task() {}
};
};
#ifdef WITH_THREADS
class ThreadPool {
public:
@ -98,7 +101,7 @@ class TestTask : public Task {
public:
TestTask(int id) : m_id(id) {}
virtual void Run() {
#if defined(BOOST_HAS_PTHREADS)
#ifdef BOOST_HAS_PTHREADS
pthread_t tid = pthread_self();
#else
pthread_t tid = 0;
@ -112,7 +115,7 @@ class TestTask : public Task {
int m_id;
};
#endif //WITH_THREADS
}
#endif

View File

@ -108,7 +108,7 @@ void Manager::ProcessSentence()
*
*/
void Manager::PrintAllDerivations(long translationId ) const
void Manager::PrintAllDerivations(long translationId, ostream& outputStream) const
{
const std::vector < HypothesisStack* > &hypoStackColl = m_search->GetHypothesisStacks();
@ -126,14 +126,14 @@ void Manager::PrintAllDerivations(long translationId ) const
; iterBestHypo != sortedPureHypo.end()
; ++iterBestHypo)
{
printThisHypothesis(translationId, *iterBestHypo, remainingPhrases, remainingScore);
printDivergentHypothesis(translationId, *iterBestHypo, remainingPhrases, remainingScore);
printThisHypothesis(translationId, *iterBestHypo, remainingPhrases, remainingScore, outputStream);
printDivergentHypothesis(translationId, *iterBestHypo, remainingPhrases, remainingScore, outputStream);
}
}
void Manager::getSntTranslationOptions(std::ostream& outStr) { outStr << *m_transOptColl; }
void Manager::printDivergentHypothesis(long translationId, const Hypothesis* hypo, const vector <const TargetPhrase*> & remainingPhrases, float remainingScore ) const
void Manager::printDivergentHypothesis(long translationId, const Hypothesis* hypo, const vector <const TargetPhrase*> & remainingPhrases, float remainingScore , ostream& outputStream ) const
{
//Backtrack from the predecessor
if (hypo->GetId() > 0) {
@ -141,7 +141,7 @@ void Manager::printDivergentHypothesis(long translationId, const Hypothesis* hyp
followingPhrases.push_back(& (hypo->GetCurrTargetPhrase()));
///((Phrase) hypo->GetPrevHypo()->GetTargetPhrase());
followingPhrases.insert(followingPhrases.end()--, remainingPhrases.begin(), remainingPhrases.end());
printDivergentHypothesis(translationId, hypo->GetPrevHypo(), followingPhrases , remainingScore + hypo->GetScore() - hypo->GetPrevHypo()->GetScore());
printDivergentHypothesis(translationId, hypo->GetPrevHypo(), followingPhrases , remainingScore + hypo->GetScore() - hypo->GetPrevHypo()->GetScore(), outputStream);
}
//Process the arcs
@ -158,33 +158,33 @@ void Manager::printDivergentHypothesis(long translationId, const Hypothesis* hyp
vector <const TargetPhrase* > followingPhrases;
followingPhrases.push_back(&(loserHypo->GetCurrTargetPhrase()));
followingPhrases.insert(followingPhrases.end()--, remainingPhrases.begin(), remainingPhrases.end());
printThisHypothesis(translationId, loserPrevHypo, followingPhrases, remainingScore + arcScore);
printDivergentHypothesis(translationId, loserPrevHypo, followingPhrases, remainingScore + arcScore);
printThisHypothesis(translationId, loserPrevHypo, followingPhrases, remainingScore + arcScore, outputStream);
printDivergentHypothesis(translationId, loserPrevHypo, followingPhrases, remainingScore + arcScore, outputStream);
}
}
}
void Manager::printThisHypothesis(long translationId, const Hypothesis* hypo, const vector <const TargetPhrase*> & remainingPhrases, float remainingScore ) const
void Manager::printThisHypothesis(long translationId, const Hypothesis* hypo, const vector <const TargetPhrase*> & remainingPhrases, float remainingScore, ostream& outputStream) const
{
cerr << translationId << " ||| ";
outputStream << translationId << " ||| ";
//Yield of this hypothesis
hypo->ToStream(cerr);
hypo->ToStream(outputStream);
for (size_t p = 0; p < remainingPhrases.size(); ++p) {
const TargetPhrase * phrase = remainingPhrases[p];
size_t size = phrase->GetSize();
for (size_t pos = 0 ; pos < size ; pos++)
{
const Factor *factor = phrase->GetFactor(pos, 0);
cerr << *factor;
cerr << " ";
outputStream << *factor;
outputStream << " ";
}
}
cerr << "||| " << hypo->GetScore() + remainingScore;
cerr << endl;
outputStream << "||| " << hypo->GetScore() + remainingScore;
outputStream << endl;
}

View File

@ -126,9 +126,9 @@ public:
const Hypothesis *GetBestHypothesis() const;
const Hypothesis *GetActualBestHypothesis() const;
void CalcNBest(size_t count, TrellisPathList &ret,bool onlyDistinct=0) const;
void PrintAllDerivations(long translationId) const;
void printDivergentHypothesis(long translationId, const Hypothesis* hypo, const std::vector <const TargetPhrase*> & remainingPhrases, float remainingScore ) const;
void printThisHypothesis(long translationId, const Hypothesis* hypo, const std::vector <const TargetPhrase* > & remainingPhrases, float remainingScore ) const;
void PrintAllDerivations(long translationId, std::ostream& outputStream) const;
void printDivergentHypothesis(long translationId, const Hypothesis* hypo, const std::vector <const TargetPhrase*> & remainingPhrases, float remainingScore , std::ostream& outputStream) const;
void printThisHypothesis(long translationId, const Hypothesis* hypo, const std::vector <const TargetPhrase* > & remainingPhrases, float remainingScore , std::ostream& outputStream) const;
void GetWordGraph(long translationId, std::ostream &outputWordGraphStream) const;
int GetNextHypoId();
#ifdef HAVE_PROTOBUF

View File

@ -68,6 +68,7 @@ Parameter::Parameter()
AddParam("report-segmentation", "t", "report phrase segmentation in the output");
AddParam("stack", "s", "maximum stack size for histogram pruning");
AddParam("stack-diversity", "sd", "minimum number of hypothesis of each coverage in stack (default 0)");
AddParam("threads","th", "number of threads to use in decoding (defaults to single-threaded)");
AddParam("translation-details", "T", "for each best hypothesis, report translation details to the given file");
AddParam("ttable-file", "location and properties of the translation tables");
AddParam("ttable-limit", "ttl", "maximum number of translation table entries per input phrase");