mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-10-27 03:49:57 +03:00
merge RunPb and RunChart
This commit is contained in:
parent
0036d8bb4d
commit
cec03c949e
@ -154,15 +154,7 @@ int main(int argc, char** argv)
|
||||
FeatureFunction::CallChangeSource(source);
|
||||
|
||||
// set up task of translating one sentence
|
||||
TranslationTask* task;
|
||||
if (staticData.IsChart()) {
|
||||
// scfg
|
||||
task = new TranslationTask(source, *ioWrapper, 2);
|
||||
}
|
||||
else {
|
||||
// pb
|
||||
task = new TranslationTask(source, *ioWrapper, 1);
|
||||
}
|
||||
TranslationTask* task = new TranslationTask(source, *ioWrapper);
|
||||
|
||||
// execute task
|
||||
#ifdef WITH_THREADS
|
||||
|
@ -46,6 +46,9 @@ protected:
|
||||
}
|
||||
|
||||
public:
|
||||
virtual ~BaseManager()
|
||||
{}
|
||||
|
||||
//! the input sentence being decoded
|
||||
const InputType& GetSource() const {
|
||||
return m_source;
|
||||
@ -53,6 +56,7 @@ public:
|
||||
|
||||
virtual void Decode() = 0;
|
||||
// outputs
|
||||
virtual void OutputBest(OutputCollector *collector) const = 0;
|
||||
virtual void OutputNBest(OutputCollector *collector) const = 0;
|
||||
virtual void OutputLatticeSamples(OutputCollector *collector) const = 0;
|
||||
virtual void OutputAlignment(OutputCollector *collector) const = 0;
|
||||
@ -60,9 +64,14 @@ public:
|
||||
virtual void OutputDetailedTreeFragmentsTranslationReport(OutputCollector *collector) const = 0;
|
||||
virtual void OutputWordGraph(OutputCollector *collector) const = 0;
|
||||
virtual void OutputSearchGraph(OutputCollector *collector) const = 0;
|
||||
virtual void OutputUnknowns(OutputCollector *collector) const = 0;
|
||||
virtual void OutputSearchGraphSLF() const = 0;
|
||||
virtual void OutputSearchGraphHypergraph() const = 0;
|
||||
|
||||
/***
|
||||
* to be called after processing a sentence
|
||||
*/
|
||||
virtual void CalcDecoderStatistics() const = 0;
|
||||
|
||||
};
|
||||
|
||||
|
@ -126,12 +126,8 @@ public:
|
||||
return m_hypoStackColl;
|
||||
}
|
||||
|
||||
/***
|
||||
* to be called after processing a sentence (which may consist of more than just calling ProcessSentence() )
|
||||
* currently an empty function
|
||||
*/
|
||||
void CalcDecoderStatistics() const {
|
||||
}
|
||||
void CalcDecoderStatistics() const
|
||||
{}
|
||||
|
||||
void ResetSentenceStats(const InputType& source) {
|
||||
m_sentenceStats = std::auto_ptr<SentenceStats>(new SentenceStats(source));
|
||||
|
@ -102,6 +102,11 @@ private:
|
||||
void OutputBestHypo(OutputCollector *collector, search::Applied applied, long translationId) const;
|
||||
void OutputBestNone(OutputCollector *collector, long translationId) const;
|
||||
|
||||
void OutputUnknowns(OutputCollector *collector) const
|
||||
{}
|
||||
void CalcDecoderStatistics() const
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
// Just get the phrase.
|
||||
|
@ -32,6 +32,9 @@ class Manager : public BaseManager
|
||||
void OutputSearchGraphHypergraph() const {}
|
||||
void OutputSearchGraphSLF() const {}
|
||||
void OutputWordGraph(OutputCollector *collector) const {}
|
||||
void OutputDetailedTranslationReport(OutputCollector *collector) const {}
|
||||
|
||||
void CalcDecoderStatistics() const {}
|
||||
|
||||
// Syntax-specific virtual functions that derived classes must implement.
|
||||
virtual void ExtractKBest(
|
||||
|
@ -20,10 +20,9 @@ using namespace std;
|
||||
namespace Moses
|
||||
{
|
||||
|
||||
TranslationTask::TranslationTask(InputType* source, Moses::IOWrapper &ioWrapper, int pbOrChart)
|
||||
TranslationTask::TranslationTask(InputType* source, Moses::IOWrapper &ioWrapper)
|
||||
: m_source(source)
|
||||
, m_ioWrapper(ioWrapper)
|
||||
, m_pbOrChart(pbOrChart)
|
||||
{}
|
||||
|
||||
TranslationTask::~TranslationTask() {
|
||||
@ -31,26 +30,10 @@ TranslationTask::~TranslationTask() {
|
||||
}
|
||||
|
||||
void TranslationTask::Run()
|
||||
{
|
||||
switch (m_pbOrChart)
|
||||
{
|
||||
case 1:
|
||||
RunPb();
|
||||
break;
|
||||
case 2:
|
||||
RunChart();
|
||||
break;
|
||||
default:
|
||||
UTIL_THROW(util::Exception, "Unknown value: " << m_pbOrChart);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void TranslationTask::RunPb()
|
||||
{
|
||||
// shorthand for "global data"
|
||||
const StaticData &staticData = StaticData::Instance();
|
||||
const size_t translationId = m_source->GetTranslationId();
|
||||
const size_t translationId = m_source->GetTranslationId();
|
||||
|
||||
// input sentence
|
||||
Sentence sentence;
|
||||
@ -70,7 +53,39 @@ void TranslationTask::RunPb()
|
||||
// we still need to apply the decision rule (MAP, MBR, ...)
|
||||
Timer initTime;
|
||||
initTime.start();
|
||||
Manager *manager = new Manager(*m_source);
|
||||
|
||||
// which manager
|
||||
BaseManager *manager;
|
||||
|
||||
switch (staticData.IsChart())
|
||||
{
|
||||
case false:
|
||||
manager = new Manager(*m_source);
|
||||
break;
|
||||
case true:
|
||||
if (staticData.UseS2TDecoder()) {
|
||||
S2TParsingAlgorithm algorithm = staticData.GetS2TParsingAlgorithm();
|
||||
if (algorithm == RecursiveCYKPlus) {
|
||||
typedef Syntax::S2T::EagerParserCallback Callback;
|
||||
typedef Syntax::S2T::RecursiveCYKPlusParser<Callback> Parser;
|
||||
manager = new Syntax::S2T::Manager<Parser>(*m_source);
|
||||
} else if (algorithm == Scope3) {
|
||||
typedef Syntax::S2T::StandardParserCallback Callback;
|
||||
typedef Syntax::S2T::Scope3Parser<Callback> Parser;
|
||||
manager = new Syntax::S2T::Manager<Parser>(*m_source);
|
||||
} else {
|
||||
UTIL_THROW2("ERROR: unhandled S2T parsing algorithm");
|
||||
}
|
||||
}
|
||||
else if (staticData.GetSearchAlgorithm() == ChartIncremental) {
|
||||
manager = new Incremental::Manager(*m_source);
|
||||
}
|
||||
else {
|
||||
manager = new ChartManager(*m_source);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
VERBOSE(1, "Line " << translationId << ": Initialize search took " << initTime << " seconds total" << endl);
|
||||
manager->Decode();
|
||||
|
||||
@ -118,70 +133,4 @@ void TranslationTask::RunPb()
|
||||
delete manager;
|
||||
}
|
||||
|
||||
|
||||
void TranslationTask::RunChart()
|
||||
{
|
||||
const StaticData &staticData = StaticData::Instance();
|
||||
const size_t translationId = m_source->GetTranslationId();
|
||||
|
||||
VERBOSE(2,"\nTRANSLATING(" << translationId << "): " << *m_source);
|
||||
|
||||
if (staticData.UseS2TDecoder()) {
|
||||
S2TParsingAlgorithm algorithm = staticData.GetS2TParsingAlgorithm();
|
||||
if (algorithm == RecursiveCYKPlus) {
|
||||
typedef Syntax::S2T::EagerParserCallback Callback;
|
||||
typedef Syntax::S2T::RecursiveCYKPlusParser<Callback> Parser;
|
||||
DecodeS2T<Parser>();
|
||||
} else if (algorithm == Scope3) {
|
||||
typedef Syntax::S2T::StandardParserCallback Callback;
|
||||
typedef Syntax::S2T::Scope3Parser<Callback> Parser;
|
||||
DecodeS2T<Parser>();
|
||||
} else {
|
||||
UTIL_THROW2("ERROR: unhandled S2T parsing algorithm");
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (staticData.GetSearchAlgorithm() == ChartIncremental) {
|
||||
Incremental::Manager manager(*m_source);
|
||||
manager.Decode();
|
||||
manager.OutputBest(m_ioWrapper.GetSingleBestOutputCollector());
|
||||
manager.OutputDetailedTranslationReport(m_ioWrapper.GetDetailedTranslationCollector());
|
||||
manager.OutputDetailedTreeFragmentsTranslationReport(m_ioWrapper.GetDetailTreeFragmentsOutputCollector());
|
||||
manager.OutputNBest(m_ioWrapper.GetNBestOutputCollector());
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
ChartManager manager(*m_source);
|
||||
manager.Decode();
|
||||
|
||||
UTIL_THROW_IF2(staticData.UseMBR(), "Cannot use MBR");
|
||||
|
||||
// Output search graph in hypergraph format for Kenneth Heafield's lazy hypergraph decoder
|
||||
manager.OutputSearchGraphHypergraph();
|
||||
|
||||
// 1-best
|
||||
manager.OutputBest(m_ioWrapper.GetSingleBestOutputCollector());
|
||||
|
||||
IFVERBOSE(2) {
|
||||
PrintUserTime("Best Hypothesis Generation Time:");
|
||||
}
|
||||
|
||||
manager.OutputAlignment(m_ioWrapper.GetAlignmentInfoCollector());
|
||||
manager.OutputDetailedTranslationReport(m_ioWrapper.GetDetailedTranslationCollector());
|
||||
manager.OutputDetailedTreeFragmentsTranslationReport(m_ioWrapper.GetDetailTreeFragmentsOutputCollector());
|
||||
manager.OutputUnknowns(m_ioWrapper.GetUnknownsCollector());
|
||||
|
||||
// n-best
|
||||
manager.OutputNBest(m_ioWrapper.GetNBestOutputCollector());
|
||||
|
||||
manager.OutputSearchGraph(m_ioWrapper.GetSearchGraphOutputCollector());
|
||||
|
||||
IFVERBOSE(2) {
|
||||
PrintUserTime("Sentence Decoding Time:");
|
||||
}
|
||||
manager.CalcDecoderStatistics();
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -26,7 +26,7 @@ class TranslationTask : public Moses::Task
|
||||
|
||||
public:
|
||||
|
||||
TranslationTask(Moses::InputType* source, Moses::IOWrapper &ioWrapper, int pbOrChart);
|
||||
TranslationTask(Moses::InputType* source, Moses::IOWrapper &ioWrapper);
|
||||
|
||||
~TranslationTask();
|
||||
|
||||
@ -36,33 +36,9 @@ public:
|
||||
|
||||
|
||||
private:
|
||||
int m_pbOrChart; // 1=pb. 2=chart
|
||||
Moses::InputType* m_source;
|
||||
Moses::IOWrapper &m_ioWrapper;
|
||||
|
||||
void RunPb();
|
||||
void RunChart();
|
||||
|
||||
|
||||
template<typename Parser>
|
||||
void DecodeS2T() {
|
||||
const StaticData &staticData = StaticData::Instance();
|
||||
const std::size_t translationId = m_source->GetTranslationId();
|
||||
Syntax::S2T::Manager<Parser> manager(*m_source);
|
||||
manager.Decode();
|
||||
// 1-best
|
||||
manager.OutputBest(m_ioWrapper.GetSingleBestOutputCollector());
|
||||
|
||||
// n-best
|
||||
manager.OutputNBest(m_ioWrapper.GetNBestOutputCollector());
|
||||
|
||||
// Write 1-best derivation (-translation-details / -T option).
|
||||
|
||||
manager.OutputDetailedTranslationReport(m_ioWrapper.GetDetailedTranslationCollector());
|
||||
|
||||
manager.OutputUnknowns(m_ioWrapper.GetUnknownsCollector());
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user