merge RunPb and RunChart

This commit is contained in:
Hieu Hoang 2015-01-04 15:32:44 +05:30
parent 0036d8bb4d
commit cec03c949e
7 changed files with 56 additions and 126 deletions

View File

@ -154,15 +154,7 @@ int main(int argc, char** argv)
FeatureFunction::CallChangeSource(source);
// set up task of translating one sentence
TranslationTask* task;
if (staticData.IsChart()) {
// scfg
task = new TranslationTask(source, *ioWrapper, 2);
}
else {
// pb
task = new TranslationTask(source, *ioWrapper, 1);
}
TranslationTask* task = new TranslationTask(source, *ioWrapper);
// execute task
#ifdef WITH_THREADS

View File

@ -46,6 +46,9 @@ protected:
}
public:
virtual ~BaseManager()
{}
//! the input sentence being decoded
const InputType& GetSource() const {
return m_source;
@ -53,6 +56,7 @@ public:
virtual void Decode() = 0;
// outputs
virtual void OutputBest(OutputCollector *collector) const = 0;
virtual void OutputNBest(OutputCollector *collector) const = 0;
virtual void OutputLatticeSamples(OutputCollector *collector) const = 0;
virtual void OutputAlignment(OutputCollector *collector) const = 0;
@ -60,9 +64,14 @@ public:
virtual void OutputDetailedTreeFragmentsTranslationReport(OutputCollector *collector) const = 0;
virtual void OutputWordGraph(OutputCollector *collector) const = 0;
virtual void OutputSearchGraph(OutputCollector *collector) const = 0;
virtual void OutputUnknowns(OutputCollector *collector) const = 0;
virtual void OutputSearchGraphSLF() const = 0;
virtual void OutputSearchGraphHypergraph() const = 0;
/***
* to be called after processing a sentence
*/
virtual void CalcDecoderStatistics() const = 0;
};

View File

@ -126,12 +126,8 @@ public:
return m_hypoStackColl;
}
/***
* to be called after processing a sentence (which may consist of more than just calling ProcessSentence() )
* currently an empty function
*/
void CalcDecoderStatistics() const {
}
void CalcDecoderStatistics() const
{}
void ResetSentenceStats(const InputType& source) {
m_sentenceStats = std::auto_ptr<SentenceStats>(new SentenceStats(source));

View File

@ -102,6 +102,11 @@ private:
void OutputBestHypo(OutputCollector *collector, search::Applied applied, long translationId) const;
void OutputBestNone(OutputCollector *collector, long translationId) const;
void OutputUnknowns(OutputCollector *collector) const
{}
void CalcDecoderStatistics() const
{}
};
// Just get the phrase.

View File

@ -32,6 +32,9 @@ class Manager : public BaseManager
void OutputSearchGraphHypergraph() const {}
void OutputSearchGraphSLF() const {}
void OutputWordGraph(OutputCollector *collector) const {}
void OutputDetailedTranslationReport(OutputCollector *collector) const {}
void CalcDecoderStatistics() const {}
// Syntax-specific virtual functions that derived classes must implement.
virtual void ExtractKBest(

View File

@ -20,10 +20,9 @@ using namespace std;
namespace Moses
{
TranslationTask::TranslationTask(InputType* source, Moses::IOWrapper &ioWrapper, int pbOrChart)
TranslationTask::TranslationTask(InputType* source, Moses::IOWrapper &ioWrapper)
: m_source(source)
, m_ioWrapper(ioWrapper)
, m_pbOrChart(pbOrChart)
{}
TranslationTask::~TranslationTask() {
@ -31,26 +30,10 @@ TranslationTask::~TranslationTask() {
}
void TranslationTask::Run()
{
switch (m_pbOrChart)
{
case 1:
RunPb();
break;
case 2:
RunChart();
break;
default:
UTIL_THROW(util::Exception, "Unknown value: " << m_pbOrChart);
}
}
void TranslationTask::RunPb()
{
// shorthand for "global data"
const StaticData &staticData = StaticData::Instance();
const size_t translationId = m_source->GetTranslationId();
const size_t translationId = m_source->GetTranslationId();
// input sentence
Sentence sentence;
@ -70,7 +53,39 @@ void TranslationTask::RunPb()
// we still need to apply the decision rule (MAP, MBR, ...)
Timer initTime;
initTime.start();
Manager *manager = new Manager(*m_source);
// which manager
BaseManager *manager;
switch (staticData.IsChart())
{
case false:
manager = new Manager(*m_source);
break;
case true:
if (staticData.UseS2TDecoder()) {
S2TParsingAlgorithm algorithm = staticData.GetS2TParsingAlgorithm();
if (algorithm == RecursiveCYKPlus) {
typedef Syntax::S2T::EagerParserCallback Callback;
typedef Syntax::S2T::RecursiveCYKPlusParser<Callback> Parser;
manager = new Syntax::S2T::Manager<Parser>(*m_source);
} else if (algorithm == Scope3) {
typedef Syntax::S2T::StandardParserCallback Callback;
typedef Syntax::S2T::Scope3Parser<Callback> Parser;
manager = new Syntax::S2T::Manager<Parser>(*m_source);
} else {
UTIL_THROW2("ERROR: unhandled S2T parsing algorithm");
}
}
else if (staticData.GetSearchAlgorithm() == ChartIncremental) {
manager = new Incremental::Manager(*m_source);
}
else {
manager = new ChartManager(*m_source);
}
break;
}
VERBOSE(1, "Line " << translationId << ": Initialize search took " << initTime << " seconds total" << endl);
manager->Decode();
@ -118,70 +133,4 @@ void TranslationTask::RunPb()
delete manager;
}
void TranslationTask::RunChart()
{
const StaticData &staticData = StaticData::Instance();
const size_t translationId = m_source->GetTranslationId();
VERBOSE(2,"\nTRANSLATING(" << translationId << "): " << *m_source);
if (staticData.UseS2TDecoder()) {
S2TParsingAlgorithm algorithm = staticData.GetS2TParsingAlgorithm();
if (algorithm == RecursiveCYKPlus) {
typedef Syntax::S2T::EagerParserCallback Callback;
typedef Syntax::S2T::RecursiveCYKPlusParser<Callback> Parser;
DecodeS2T<Parser>();
} else if (algorithm == Scope3) {
typedef Syntax::S2T::StandardParserCallback Callback;
typedef Syntax::S2T::Scope3Parser<Callback> Parser;
DecodeS2T<Parser>();
} else {
UTIL_THROW2("ERROR: unhandled S2T parsing algorithm");
}
return;
}
if (staticData.GetSearchAlgorithm() == ChartIncremental) {
Incremental::Manager manager(*m_source);
manager.Decode();
manager.OutputBest(m_ioWrapper.GetSingleBestOutputCollector());
manager.OutputDetailedTranslationReport(m_ioWrapper.GetDetailedTranslationCollector());
manager.OutputDetailedTreeFragmentsTranslationReport(m_ioWrapper.GetDetailTreeFragmentsOutputCollector());
manager.OutputNBest(m_ioWrapper.GetNBestOutputCollector());
return;
}
ChartManager manager(*m_source);
manager.Decode();
UTIL_THROW_IF2(staticData.UseMBR(), "Cannot use MBR");
// Output search graph in hypergraph format for Kenneth Heafield's lazy hypergraph decoder
manager.OutputSearchGraphHypergraph();
// 1-best
manager.OutputBest(m_ioWrapper.GetSingleBestOutputCollector());
IFVERBOSE(2) {
PrintUserTime("Best Hypothesis Generation Time:");
}
manager.OutputAlignment(m_ioWrapper.GetAlignmentInfoCollector());
manager.OutputDetailedTranslationReport(m_ioWrapper.GetDetailedTranslationCollector());
manager.OutputDetailedTreeFragmentsTranslationReport(m_ioWrapper.GetDetailTreeFragmentsOutputCollector());
manager.OutputUnknowns(m_ioWrapper.GetUnknownsCollector());
// n-best
manager.OutputNBest(m_ioWrapper.GetNBestOutputCollector());
manager.OutputSearchGraph(m_ioWrapper.GetSearchGraphOutputCollector());
IFVERBOSE(2) {
PrintUserTime("Sentence Decoding Time:");
}
manager.CalcDecoderStatistics();
}
}

View File

@ -26,7 +26,7 @@ class TranslationTask : public Moses::Task
public:
TranslationTask(Moses::InputType* source, Moses::IOWrapper &ioWrapper, int pbOrChart);
TranslationTask(Moses::InputType* source, Moses::IOWrapper &ioWrapper);
~TranslationTask();
@ -36,33 +36,9 @@ public:
private:
int m_pbOrChart; // 1=pb. 2=chart
Moses::InputType* m_source;
Moses::IOWrapper &m_ioWrapper;
void RunPb();
void RunChart();
template<typename Parser>
void DecodeS2T() {
const StaticData &staticData = StaticData::Instance();
const std::size_t translationId = m_source->GetTranslationId();
Syntax::S2T::Manager<Parser> manager(*m_source);
manager.Decode();
// 1-best
manager.OutputBest(m_ioWrapper.GetSingleBestOutputCollector());
// n-best
manager.OutputNBest(m_ioWrapper.GetNBestOutputCollector());
// Write 1-best derivation (-translation-details / -T option).
manager.OutputDetailedTranslationReport(m_ioWrapper.GetDetailedTranslationCollector());
manager.OutputUnknowns(m_ioWrapper.GetUnknownsCollector());
}
};