Managers and feature functions now have access to the entire TranslationTask, not just the InputType.

This commit is contained in:
Ulrich Germann 2015-03-26 18:25:54 +00:00
parent 4410e9225a
commit 9dc75bfd8a
24 changed files with 365 additions and 225 deletions

View File

@ -82,12 +82,16 @@ include $(TOP)/jam-files/sanity.jam ;
home = [ os.environ "HOME" ] ; home = [ os.environ "HOME" ] ;
if [ path.exists $(home)/moses-environment.jam ] if [ path.exists $(home)/moses-environment.jam ]
{ include $(home)/moses-environment.jam ; } {
# for those of us who don't like typing in command line bjam options all day long
include $(home)/moses-environment.jam ;
}
include $(TOP)/jam-files/check-environment.jam ; # get resource locations include $(TOP)/jam-files/check-environment.jam ; # get resource locations
# from environment variables # from environment variables
include $(TOP)/jam-files/xmlrpc-c.jam ; # xmlrpc-c stuff for the server include $(TOP)/jam-files/xmlrpc-c.jam ; # xmlrpc-c stuff for the server
include $(TOP)/jam-files/curlpp.jam ; # curlpp stuff for bias lookup (MMT only) include $(TOP)/jam-files/curlpp.jam ; # curlpp stuff for bias lookup (MMT only)
# exit "done" : 0 ;
max-order = [ option.get "max-kenlm-order" : 6 : 6 ] ; max-order = [ option.get "max-kenlm-order" : 6 : 6 ] ;
if ! [ option.get "max-kenlm-order" ] if ! [ option.get "max-kenlm-order" ]

View File

@ -37,6 +37,7 @@ int main(int argc, char** argv)
#include "moses/Manager.h" #include "moses/Manager.h"
#include "moses/StaticData.h" #include "moses/StaticData.h"
#include "moses/ThreadPool.h" #include "moses/ThreadPool.h"
#include "moses/TranslationTask.h"
#include "moses/TranslationModel/PhraseDictionaryDynSuffixArray.h" #include "moses/TranslationModel/PhraseDictionaryDynSuffixArray.h"
#include "moses/TranslationModel/PhraseDictionaryMultiModelCounts.h" #include "moses/TranslationModel/PhraseDictionaryMultiModelCounts.h"
#if PT_UG #if PT_UG
@ -232,8 +233,8 @@ public:
/** /**
* Required so that translations can be sent to a thread pool. * Required so that translations can be sent to a thread pool.
**/ **/
class TranslationTask : public virtual Moses::Task { class TranslationTask : public virtual Moses::TranslationTask {
public: protected:
TranslationTask(xmlrpc_c::paramList const& paramList, TranslationTask(xmlrpc_c::paramList const& paramList,
boost::condition_variable& cond, boost::mutex& mut) boost::condition_variable& cond, boost::mutex& mut)
: m_paramList(paramList), : m_paramList(paramList),
@ -242,23 +243,33 @@ public:
m_done(false) m_done(false)
{} {}
public:
static boost::shared_ptr<TranslationTask>
create(xmlrpc_c::paramList const& paramList,
boost::condition_variable& cond, boost::mutex& mut)
{
boost::shared_ptr<TranslationTask> ret(new TranslationTask(paramList, cond, mut));
ret->m_self = ret;
return ret;
}
virtual bool DeleteAfterExecution() {return false;} virtual bool DeleteAfterExecution() {return false;}
bool IsDone() const {return m_done;} bool IsDone() const {return m_done;}
const map<string, xmlrpc_c::value>& GetRetData() { return m_retData;} const map<string, xmlrpc_c::value>& GetRetData() { return m_retData;}
virtual void Run() { virtual void
Run()
{
using namespace xmlrpc_c;
const params_t params = m_paramList.getStruct(0); const params_t params = m_paramList.getStruct(0);
m_paramList.verifyEnd(1); m_paramList.verifyEnd(1);
params_t::const_iterator si = params.find("text"); params_t::const_iterator si = params.find("text");
if (si == params.end()) { if (si == params.end()) {
throw xmlrpc_c::fault( throw fault("Missing source text", fault::CODE_PARSE);
"Missing source text",
xmlrpc_c::fault::CODE_PARSE);
} }
const string source((xmlrpc_c::value_string(si->second))); const string source = value_string(si->second);
XVERBOSE(1,"Input: " << source << endl); XVERBOSE(1,"Input: " << source << endl);
si = params.find("align"); si = params.find("align");
@ -272,7 +283,7 @@ public:
si = params.find("report-all-factors"); si = params.find("report-all-factors");
bool reportAllFactors = (si != params.end()); bool reportAllFactors = (si != params.end());
si = params.find("nbest"); si = params.find("nbest");
int nbest_size = (si == params.end()) ? 0 : int(xmlrpc_c::value_int(si->second)); int nbest_size = (si == params.end()) ? 0 : int(value_int(si->second));
si = params.find("nbest-distinct"); si = params.find("nbest-distinct");
bool nbest_distinct = (si != params.end()); bool nbest_distinct = (si != params.end());
@ -281,21 +292,25 @@ public:
vector<float> multiModelWeights; vector<float> multiModelWeights;
si = params.find("lambda"); si = params.find("lambda");
if (si != params.end()) { if (si != params.end())
xmlrpc_c::value_array multiModelArray = xmlrpc_c::value_array(si->second); {
vector<xmlrpc_c::value> multiModelValueVector(multiModelArray.vectorValueValue()); value_array multiModelArray = value_array(si->second);
for (size_t i=0;i < multiModelValueVector.size();i++) { vector<value> multiModelValueVector(multiModelArray.vectorValueValue());
multiModelWeights.push_back(xmlrpc_c::value_double(multiModelValueVector[i])); for (size_t i=0;i < multiModelValueVector.size();i++)
} {
} multiModelWeights.push_back(value_double(multiModelValueVector[i]));
}
}
si = params.find("model_name"); si = params.find("model_name");
if (si != params.end() && multiModelWeights.size() > 0) { if (si != params.end() && multiModelWeights.size() > 0)
const string model_name = xmlrpc_c::value_string(si->second); {
PhraseDictionaryMultiModel* pdmm = (PhraseDictionaryMultiModel*) FindPhraseDictionary(model_name); const string model_name = value_string(si->second);
PhraseDictionaryMultiModel* pdmm
= (PhraseDictionaryMultiModel*) FindPhraseDictionary(model_name);
pdmm->SetTemporaryMultiModelWeightsVector(multiModelWeights); pdmm->SetTemporaryMultiModelWeightsVector(multiModelWeights);
} }
const StaticData &staticData = StaticData::Instance(); const StaticData &staticData = StaticData::Instance();
//Make sure alternative paths are retained, if necessary //Make sure alternative paths are retained, if necessary
@ -306,13 +321,14 @@ public:
stringstream out, graphInfo, transCollOpts; stringstream out, graphInfo, transCollOpts;
if (staticData.IsSyntax()) { if (staticData.IsSyntax())
TreeInput tinput; {
const vector<FactorType>& boost::shared_ptr<TreeInput> tinput(new TreeInput);
inputFactorOrder = staticData.GetInputFactorOrder(); const vector<FactorType>& IFO = staticData.GetInputFactorOrder();
stringstream in(source + "\n"); istringstream in(source + "\n");
tinput.Read(in,inputFactorOrder); tinput->Read(in,IFO);
ChartManager manager(tinput); ttasksptr task = Moses::TranslationTask::create(tinput);
ChartManager manager(task);
manager.Decode(); manager.Decode();
const ChartHypothesis *hypo = manager.GetBestHypothesis(); const ChartHypothesis *hypo = manager.GetBestHypothesis();
outputChartHypo(out,hypo); outputChartHypo(out,hypo);
@ -320,57 +336,50 @@ public:
// const size_t translationId = tinput.GetTranslationId(); // const size_t translationId = tinput.GetTranslationId();
std::ostringstream sgstream; std::ostringstream sgstream;
manager.OutputSearchGraphMoses(sgstream); manager.OutputSearchGraphMoses(sgstream);
m_retData.insert(pair<string, xmlrpc_c::value>("sg", xmlrpc_c::value_string(sgstream.str()))); m_retData["sg"] = value_string(sgstream.str());
} }
} else { }
else
{
size_t lineNumber = 0; // TODO: Include sentence request number here? size_t lineNumber = 0; // TODO: Include sentence request number here?
Sentence sentence; boost::shared_ptr<Sentence> sentence(new Sentence(0,source));
sentence.SetTranslationId(lineNumber); ttasksptr task = Moses::TranslationTask::create(sentence);
Manager manager(task);
const vector<FactorType> & manager.Decode();
inputFactorOrder = staticData.GetInputFactorOrder();
stringstream in(source + "\n");
sentence.Read(in,inputFactorOrder);
Manager manager(sentence);
manager.Decode();
const Hypothesis* hypo = manager.GetBestHypothesis(); const Hypothesis* hypo = manager.GetBestHypothesis();
vector<xmlrpc_c::value> alignInfo; vector<xmlrpc_c::value> alignInfo;
outputHypo(out,hypo,addAlignInfo,alignInfo,reportAllFactors); outputHypo(out,hypo,addAlignInfo,alignInfo,reportAllFactors);
if (addAlignInfo) { if (addAlignInfo) m_retData["align"] = value_array(alignInfo);
m_retData.insert(pair<string, xmlrpc_c::value>("align", xmlrpc_c::value_array(alignInfo))); if (addWordAlignInfo)
} {
if (addWordAlignInfo) { stringstream wordAlignment;
stringstream wordAlignment; hypo->OutputAlignment(wordAlignment);
hypo->OutputAlignment(wordAlignment); vector<xmlrpc_c::value> alignments;
vector<xmlrpc_c::value> alignments; string alignmentPair;
string alignmentPair; while (wordAlignment >> alignmentPair)
while (wordAlignment >> alignmentPair) { {
int pos = alignmentPair.find('-'); int pos = alignmentPair.find('-');
map<string, xmlrpc_c::value> wordAlignInfo; map<string, xmlrpc_c::value> wordAlignInfo;
wordAlignInfo["source-word"] = xmlrpc_c::value_int(atoi(alignmentPair.substr(0, pos).c_str())); wordAlignInfo["source-word"]
wordAlignInfo["target-word"] = xmlrpc_c::value_int(atoi(alignmentPair.substr(pos + 1).c_str())); = value_int(atoi(alignmentPair.substr(0, pos).c_str()));
alignments.push_back(xmlrpc_c::value_struct(wordAlignInfo)); wordAlignInfo["target-word"]
} = value_int(atoi(alignmentPair.substr(pos + 1).c_str()));
m_retData.insert(pair<string, xmlrpc_c::value_array>("word-align", alignments)); alignments.push_back(value_struct(wordAlignInfo));
} }
m_retData["word-align"] = value_array(alignments);
if (addGraphInfo) { }
insertGraphInfo(manager,m_retData);
} if (addGraphInfo) insertGraphInfo(manager,m_retData);
if (addTopts) { if (addTopts) insertTranslationOptions(manager,m_retData);
insertTranslationOptions(manager,m_retData); if (nbest_size > 0)
} {
if (nbest_size>0) { outputNBest(manager, m_retData, nbest_size, nbest_distinct,
outputNBest(manager, m_retData, nbest_size, nbest_distinct, reportAllFactors, addAlignInfo, addScoreBreakdown);
reportAllFactors, addAlignInfo, addScoreBreakdown); }
}
(const_cast<StaticData&>(staticData)).SetOutputSearchGraph(false); (const_cast<StaticData&>(staticData)).SetOutputSearchGraph(false);
}
} m_retData["text"] = value_string(out.str());
pair<string, xmlrpc_c::value>
text("text", xmlrpc_c::value_string(out.str()));
m_retData.insert(text);
XVERBOSE(1,"Output: " << out.str() << endl); XVERBOSE(1,"Output: " << out.str() << endl);
{ {
boost::lock_guard<boost::mutex> lock(m_mut); boost::lock_guard<boost::mutex> lock(m_mut);
@ -380,9 +389,12 @@ public:
} }
void outputHypo(ostream& out, const Hypothesis* hypo, bool addAlignmentInfo, vector<xmlrpc_c::value>& alignInfo, bool reportAllFactors = false) { void outputHypo(ostream& out, const Hypothesis* hypo,
bool addAlignmentInfo, vector<xmlrpc_c::value>& alignInfo,
bool reportAllFactors = false) {
if (hypo->GetPrevHypo() != NULL) { if (hypo->GetPrevHypo() != NULL) {
outputHypo(out,hypo->GetPrevHypo(),addAlignmentInfo, alignInfo, reportAllFactors); outputHypo(out,hypo->GetPrevHypo(),addAlignmentInfo,
alignInfo, reportAllFactors);
Phrase p = hypo->GetCurrTargetPhrase(); Phrase p = hypo->GetCurrTargetPhrase();
if(reportAllFactors) { if(reportAllFactors) {
out << p << " "; out << p << " ";
@ -595,7 +607,7 @@ public:
boost::condition_variable cond; boost::condition_variable cond;
boost::mutex mut; boost::mutex mut;
typedef ::TranslationTask TTask; typedef ::TranslationTask TTask;
boost::shared_ptr<TTask> task(new TTask(paramList,cond,mut)); boost::shared_ptr<TTask> task = TTask::create(paramList,cond,mut);
m_threadPool.Submit(task); m_threadPool.Submit(task);
boost::unique_lock<boost::mutex> lock(mut); boost::unique_lock<boost::mutex> lock(mut);
while (!task->IsDone()) { while (!task->IsDone()) {

View File

@ -21,7 +21,7 @@ else
if $(where) if $(where)
{ {
option.set "with-curlpp" : $(where) ; option.set "with-curlpp" : $(where) ;
local msg = "setting --with-curlpp=$(where) via environment " ; local msg = "CURLPP: setting --with-curlpp=$(where) via environment" ;
echo "$(msg) variable CURLPP_ROOT" ; echo "$(msg) variable CURLPP_ROOT" ;
} }
curlpp = [ option.get "with-curlpp" ] ; curlpp = [ option.get "with-curlpp" ] ;

View File

@ -53,6 +53,7 @@ POSSIBILITY OF SUCH DAMAGE.
#include "util/exception.hh" #include "util/exception.hh"
#include <boost/foreach.hpp> #include <boost/foreach.hpp>
#include "moses/TranslationTask.h"
using namespace std; using namespace std;
using namespace Moses; using namespace Moses;
@ -175,10 +176,13 @@ int main(int argc, char* argv[])
const vector<float>& prune_grid = grid.getGrid(lmbr_prune); const vector<float>& prune_grid = grid.getGrid(lmbr_prune);
const vector<float>& scale_grid = grid.getGrid(lmbr_scale); const vector<float>& scale_grid = grid.getGrid(lmbr_scale);
for (boost::shared_ptr<InputType> source = ioWrapper->ReadInput(); boost::shared_ptr<InputType> source;
source != NULL; source = ioWrapper->ReadInput()) while((source = ioWrapper->ReadInput()) != NULL)
{ {
Manager manager(*source); // set up task of translating one sentence
boost::shared_ptr<TranslationTask> ttask;
ttask = TranslationTask::create(source, ioWrapper);
Manager manager(ttask);
manager.Decode(); manager.Decode();
TrellisPathList nBestList; TrellisPathList nBestList;
manager.CalcNBest(nBestSize, nBestList,true); manager.CalcNBest(nBestSize, nBestList,true);

View File

@ -153,8 +153,8 @@ int main(int argc, char** argv)
FeatureFunction::CallChangeSource(foo); FeatureFunction::CallChangeSource(foo);
// set up task of training one sentence // set up task of training one sentence
boost::shared_ptr<TrainingTask> boost::shared_ptr<TrainingTask> task;
task(new TrainingTask(source.get(), *ioWrapper)); task = TrainingTask::create(source, ioWrapper);
// execute task // execute task
#ifdef WITH_THREADS #ifdef WITH_THREADS

View File

@ -39,21 +39,21 @@ using namespace std;
namespace Moses namespace Moses
{ {
extern bool g_mosesDebug; extern bool g_mosesDebug;
/* constructor. Initialize everything prior to decoding a particular sentence. /* constructor. Initialize everything prior to decoding a particular sentence.
* \param source the sentence to be decoded * \param source the sentence to be decoded
* \param system which particular set of models to use. * \param system which particular set of models to use.
*/ */
ChartManager::ChartManager(InputType const& source) ChartManager::ChartManager(ttasksptr const& ttask)
:BaseManager(source) : BaseManager(ttask)
,m_hypoStackColl(source, *this) , m_hypoStackColl(m_source, *this)
,m_start(clock()) , m_start(clock())
,m_hypothesisId(0) , m_hypothesisId(0)
,m_parser(source, m_hypoStackColl) , m_parser(ttask, m_hypoStackColl)
,m_translationOptionList(StaticData::Instance().GetRuleLimit(), source) , m_translationOptionList(StaticData::Instance().GetRuleLimit(), m_source)
{ { }
}
ChartManager::~ChartManager() ChartManager::~ChartManager()
{ {
@ -67,6 +67,7 @@ ChartManager::~ChartManager()
//! decode the sentence. This contains the main laps. Basically, the CKY++ algorithm //! decode the sentence. This contains the main laps. Basically, the CKY++ algorithm
void ChartManager::Decode() void ChartManager::Decode()
{ {
VERBOSE(1,"Translating: " << m_source << endl); VERBOSE(1,"Translating: " << m_source << endl);
ResetSentenceStats(m_source); ResetSentenceStats(m_source);

View File

@ -33,8 +33,6 @@
#include "BaseManager.h" #include "BaseManager.h"
#include "moses/Syntax/KBestExtractor.h" #include "moses/Syntax/KBestExtractor.h"
#include <boost/shared_ptr.hpp>
namespace Moses namespace Moses
{ {
@ -103,7 +101,7 @@ private:
void Backtrack(const ChartHypothesis *hypo) const; void Backtrack(const ChartHypothesis *hypo) const;
public: public:
ChartManager(InputType const& source); ChartManager(ttasksptr const& ttask);
~ChartManager(); ~ChartManager();
void Decode(); void Decode();
void AddXmlChartOptions(); void AddXmlChartOptions();

View File

@ -28,6 +28,7 @@
#include "DecodeGraph.h" #include "DecodeGraph.h"
#include "moses/FF/UnknownWordPenaltyProducer.h" #include "moses/FF/UnknownWordPenaltyProducer.h"
#include "moses/TranslationModel/PhraseDictionary.h" #include "moses/TranslationModel/PhraseDictionary.h"
#include "moses/TranslationTask.h"
using namespace std; using namespace std;
using namespace Moses; using namespace Moses;
@ -35,7 +36,10 @@ using namespace Moses;
namespace Moses namespace Moses
{ {
ChartParserUnknown::ChartParserUnknown() {} ChartParserUnknown
::ChartParserUnknown(ttasksptr const& ttask)
: m_ttask(ttask)
{ }
ChartParserUnknown::~ChartParserUnknown() ChartParserUnknown::~ChartParserUnknown()
{ {
@ -136,13 +140,16 @@ void ChartParserUnknown::Process(const Word &sourceWord, const WordsRange &range
} }
} }
ChartParser::ChartParser(InputType const &source, ChartCellCollectionBase &cells) : ChartParser
m_decodeGraphList(StaticData::Instance().GetDecodeGraphs()), ::ChartParser(ttasksptr const& ttask, ChartCellCollectionBase &cells)
m_source(source) : m_ttask(ttask)
, m_unknown(ttask)
, m_decodeGraphList(StaticData::Instance().GetDecodeGraphs())
, m_source(*(ttask->GetSource().get()))
{ {
const StaticData &staticData = StaticData::Instance(); const StaticData &staticData = StaticData::Instance();
staticData.InitializeForInput(source); staticData.InitializeForInput(ttask);
CreateInputPaths(m_source); CreateInputPaths(m_source);
const std::vector<PhraseDictionary*> &dictionaries = PhraseDictionary::GetColl(); const std::vector<PhraseDictionary*> &dictionaries = PhraseDictionary::GetColl();
@ -161,7 +168,7 @@ ChartParser::ChartParser(InputType const &source, ChartCellCollectionBase &cells
ChartParser::~ChartParser() ChartParser::~ChartParser()
{ {
RemoveAllInColl(m_ruleLookupManagers); RemoveAllInColl(m_ruleLookupManagers);
StaticData::Instance().CleanUpAfterSentenceProcessing(m_source); StaticData::Instance().CleanUpAfterSentenceProcessing(m_ttask.lock());
InputPathMatrix::const_iterator iterOuter; InputPathMatrix::const_iterator iterOuter;
for (iterOuter = m_inputPathMatrix.begin(); iterOuter != m_inputPathMatrix.end(); ++iterOuter) { for (iterOuter = m_inputPathMatrix.begin(); iterOuter != m_inputPathMatrix.end(); ++iterOuter) {

View File

@ -1,3 +1,4 @@
// -*- c++ -*-
// $Id$ // $Id$
// vim:tabstop=2 // vim:tabstop=2
/*********************************************************************** /***********************************************************************
@ -42,8 +43,9 @@ class DecodeGraph;
class ChartParserUnknown class ChartParserUnknown
{ {
ttaskwptr m_ttask;
public: public:
ChartParserUnknown(); ChartParserUnknown(ttasksptr const& ttask);
~ChartParserUnknown(); ~ChartParserUnknown();
void Process(const Word &sourceWord, const WordsRange &range, ChartParserCallback &to); void Process(const Word &sourceWord, const WordsRange &range, ChartParserCallback &to);
@ -59,8 +61,9 @@ private:
class ChartParser class ChartParser
{ {
ttaskwptr m_ttask;
public: public:
ChartParser(const InputType &source, ChartCellCollectionBase &cells); ChartParser(ttasksptr const& ttask, ChartCellCollectionBase &cells);
~ChartParser(); ~ChartParser();
void Create(const WordsRange &range, ChartParserCallback &to); void Create(const WordsRange &range, ChartParserCallback &to);

View File

@ -6,6 +6,7 @@
#include "moses/Hypothesis.h" #include "moses/Hypothesis.h"
#include "moses/Manager.h" #include "moses/Manager.h"
#include "moses/TranslationOption.h" #include "moses/TranslationOption.h"
#include "moses/TranslationTask.h"
#include "moses/Util.h" #include "moses/Util.h"
#include "moses/FF/DistortionScoreProducer.h" #include "moses/FF/DistortionScoreProducer.h"
@ -186,5 +187,15 @@ void FeatureFunction::SetTuneableComponents(const std::string& value)
} }
} }
void
FeatureFunction
::InitializeForInput(ttasksptr const& ttask)
{ InitializeForInput(*(ttask->GetSource().get())); }
void
FeatureFunction
::CleanUpAfterSentenceProcessing(ttasksptr const& ttask)
{ CleanupAfterSentenceProcessing(*(ttask->GetSource().get())); }
} }

View File

@ -114,31 +114,43 @@ public:
virtual std::vector<float> DefaultWeights() const; virtual std::vector<float> DefaultWeights() const;
protected:
virtual void
InitializeForInput(InputType const& source) { }
virtual void
CleanupAfterSentenceProcessing(InputType const& source) { }
public:
//! Called before search and collecting of translation options //! Called before search and collecting of translation options
virtual void InitializeForInput(InputType const& source) { virtual void
} InitializeForInput(ttasksptr const& ttask);
// clean up temporary memory, called after processing each sentence // clean up temporary memory, called after processing each sentence
virtual void CleanUpAfterSentenceProcessing(const InputType& source) { virtual void
} CleanUpAfterSentenceProcessing(ttasksptr const& ttask);
const std::string &GetArgLine() const { const std::string &
return m_argLine; GetArgLine() const { return m_argLine; }
}
// given a target phrase containing only factors specified in mask // given a target phrase containing only factors specified in mask
// return true if the feature function can be evaluated // return true if the feature function can be evaluated
virtual bool IsUseable(const FactorMask &mask) const = 0; virtual bool IsUseable(const FactorMask &mask) const = 0;
// used by stateless ff and stateful ff. Calculate initial score estimate during loading of phrase table // used by stateless ff and stateful ff. Calculate initial score
// source phrase is the substring that the phrase table uses to look up the target phrase, // estimate during loading of phrase table
//
// source phrase is the substring that the phrase table uses to look
// up the target phrase,
//
// may have more factors than actually need, but not guaranteed. // may have more factors than actually need, but not guaranteed.
// For SCFG decoding, the source contains non-terminals, NOT the raw source from the input sentence // For SCFG decoding, the source contains non-terminals, NOT the raw
virtual void EvaluateInIsolation(const Phrase &source // source from the input sentence
, const TargetPhrase &targetPhrase virtual void
, ScoreComponentCollection &scoreBreakdown EvaluateInIsolation(const Phrase &source, const TargetPhrase &targetPhrase,
, ScoreComponentCollection &estimatedFutureScore) const = 0; ScoreComponentCollection& scoreBreakdown,
ScoreComponentCollection& estimatedFutureScore) const = 0;
// override this method if you want to change the input before decoding // override this method if you want to change the input before decoding
virtual void ChangeSource(InputType * const&input) const { } virtual void ChangeSource(InputType * const&input) const { }

View File

@ -203,15 +203,15 @@ struct ChartCellBaseFactory {
} // namespace } // namespace
Manager::Manager(const InputType &source) : Manager::Manager(ttasksptr const& ttask)
BaseManager(source), : BaseManager(ttask)
cells_(source, ChartCellBaseFactory()), , cells_(m_source, ChartCellBaseFactory())
parser_(source, cells_), , parser_(ttask, cells_)
n_best_(search::NBestConfig(StaticData::Instance().GetNBestSize())) {} , n_best_(search::NBestConfig(StaticData::Instance().GetNBestSize()))
{ }
Manager::~Manager() Manager::~Manager()
{ { }
}
template <class Model, class Best> search::History Manager::PopulateBest(const Model &model, const std::vector<lm::WordIndex> &words, Best &out) template <class Model, class Best> search::History Manager::PopulateBest(const Model &model, const std::vector<lm::WordIndex> &words, Best &out)
{ {

View File

@ -1,3 +1,4 @@
// -*- c++ -*-
#pragma once #pragma once
#include "lm/word_index.hh" #include "lm/word_index.hh"
@ -24,7 +25,7 @@ namespace Incremental
class Manager : public BaseManager class Manager : public BaseManager
{ {
public: public:
Manager(const InputType &source); Manager(ttasksptr const& ttask);
~Manager(); ~Manager();

View File

@ -44,6 +44,7 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
#include "moses/LM/Base.h" #include "moses/LM/Base.h"
#include "moses/TranslationModel/PhraseDictionary.h" #include "moses/TranslationModel/PhraseDictionary.h"
#include "moses/TranslationAnalysis.h" #include "moses/TranslationAnalysis.h"
#include "moses/TranslationTask.h"
#include "moses/HypergraphOutput.h" #include "moses/HypergraphOutput.h"
#include "moses/mbr.h" #include "moses/mbr.h"
#include "moses/LatticeMBR.h" #include "moses/LatticeMBR.h"
@ -85,7 +86,7 @@ Manager::~Manager()
const InputType& const InputType&
Manager::GetSource() const Manager::GetSource() const
{ return m_source) ; } { return m_source ; }
/** /**
* Main decoder loop that translates a sentence by expanding * Main decoder loop that translates a sentence by expanding

View File

@ -17,13 +17,12 @@ License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/ ***********************************************************************/
#include "MockHypothesis.h" #include "MockHypothesis.h"
#include "TranslationOption.h"
#include "TranslationTask.h"
#include <boost/test/unit_test.hpp> #include <boost/test/unit_test.hpp>
#include "TranslationOption.h"
using namespace Moses; using namespace Moses;
using namespace std; using namespace std;
@ -31,29 +30,23 @@ using namespace std;
namespace MosesTest namespace MosesTest
{ {
MockHypothesisGuard
MockHypothesisGuard::MockHypothesisGuard( ::MockHypothesisGuard
const string& sourceSentence, ( const string& sourceSentence,
const vector<Alignment>& alignments, const vector<Alignment>& alignments,
const vector<string>& targetSegments) const vector<string>& targetSegments)
: m_initialTransOpt(), : m_initialTransOpt(), m_wp("WordPenalty"),
m_wp("WordPenalty"), m_uwp("UnknownWordPenalty"), m_dist("Distortion")
m_uwp("UnknownWordPenalty"),
m_dist("Distortion"),
m_manager(m_sentence)
{ {
BOOST_CHECK_EQUAL(alignments.size(), targetSegments.size()); BOOST_CHECK_EQUAL(alignments.size(), targetSegments.size());
std::vector<Moses::FactorType> factors(1,0);
std::vector<Moses::FactorType> factors; m_sentence.reset(new Sentence(0, sourceSentence, &factors));
factors.push_back(0); m_ttask = TranslationTask::create(m_sentence);
m_manager.reset(new Manager(m_ttask));
stringstream in(sourceSentence + "\n");
m_sentence.Read(in,factors);
//Initial empty hypothesis //Initial empty hypothesis
m_manager.ResetSentenceStats(m_sentence); m_manager->ResetSentenceStats(*m_sentence);
m_hypothesis = Hypothesis::Create(m_manager, m_sentence, m_initialTransOpt); m_hypothesis = Hypothesis::Create(*m_manager, *m_sentence, m_initialTransOpt);
//create the chain //create the chain
vector<Alignment>::const_iterator ai = alignments.begin(); vector<Alignment>::const_iterator ai = alignments.begin();

View File

@ -1,3 +1,4 @@
// -*- c++ -*-
/*********************************************************************** /***********************************************************************
Moses - factored phrase-based language decoder Moses - factored phrase-based language decoder
Copyright (C) 2010 University of Edinburgh Copyright (C) 2010 University of Edinburgh
@ -45,10 +46,11 @@ class MockHypothesisGuard
public: public:
/** Creates a phrase-based hypothesis. /** Creates a phrase-based hypothesis.
*/ */
MockHypothesisGuard( MockHypothesisGuard
const std::string& sourceSentence, ( const std::string& sourceSentence,
const std::vector<Alignment>& alignments, const std::vector<Alignment>& alignments,
const std::vector<std::string>& targetSegments); const std::vector<std::string>& targetSegments);
Moses::Hypothesis* operator*() const { Moses::Hypothesis* operator*() const {
return m_hypothesis; return m_hypothesis;
} }
@ -58,11 +60,12 @@ public:
private: private:
Moses::TranslationOption m_initialTransOpt; Moses::TranslationOption m_initialTransOpt;
Moses::Sentence m_sentence; boost::shared_ptr<Moses::Sentence> m_sentence;
Moses::WordPenaltyProducer m_wp; Moses::WordPenaltyProducer m_wp;
Moses::UnknownWordPenaltyProducer m_uwp; Moses::UnknownWordPenaltyProducer m_uwp;
Moses::DistortionScoreProducer m_dist; Moses::DistortionScoreProducer m_dist;
Moses::Manager m_manager; boost::shared_ptr<Moses::Manager> m_manager;
boost::shared_ptr<Moses::TranslationTask> m_ttask;
Moses::Hypothesis* m_hypothesis; Moses::Hypothesis* m_hypothesis;
std::vector<Moses::TargetPhrase> m_targetPhrases; std::vector<Moses::TargetPhrase> m_targetPhrases;
std::vector<Moses::TranslationOption*> m_toptions; std::vector<Moses::TranslationOption*> m_toptions;

View File

@ -383,10 +383,12 @@ CreateFromString(vector<FactorType> const& FOrder, string const& phraseString)
} }
Sentence:: Sentence::
Sentence(size_t const transId, string const& stext) : InputType(transId) Sentence(size_t const transId, string const& stext,
vector<FactorType> const* IFO)
: InputType(transId)
{ {
vector<FactorType> const& IFO = StaticData::Instance().GetInputFactorOrder(); if (IFO) init(stext, *IFO);
init(stext, IFO); else init(stext, StaticData::Instance().GetInputFactorOrder());
} }
} }

View File

@ -63,7 +63,9 @@ namespace Moses
public: public:
Sentence(); Sentence();
Sentence(size_t const transId, std::string const& stext); Sentence(size_t const transId, std::string const& stext,
std::vector<FactorType> const* IFO = NULL);
// Sentence(size_t const transId, std::string const& stext);
~Sentence(); ~Sentence();
InputTypeEnum GetType() const { InputTypeEnum GetType() const {

View File

@ -846,27 +846,33 @@ float StaticData::GetWeightWordPenalty() const
return weightWP; return weightWP;
} }
void StaticData::InitializeForInput(const InputType& source) const void
StaticData
::InitializeForInput(ttasksptr const& ttask) const
{ {
const std::vector<FeatureFunction*> &producers = FeatureFunction::GetFeatureFunctions(); const std::vector<FeatureFunction*> &producers = FeatureFunction::GetFeatureFunctions();
for(size_t i=0; i<producers.size(); ++i) { for(size_t i=0; i<producers.size(); ++i) {
FeatureFunction &ff = *producers[i]; FeatureFunction &ff = *producers[i];
if (! IsFeatureFunctionIgnored(ff)) { if (! IsFeatureFunctionIgnored(ff)) {
Timer iTime; Timer iTime;
iTime.start(); iTime.start();
ff.InitializeForInput(source); ff.InitializeForInput(ttask);
VERBOSE(3,"InitializeForInput( " << ff.GetScoreProducerDescription() << " ) = " << iTime << endl); VERBOSE(3,"InitializeForInput( " << ff.GetScoreProducerDescription() << " )"
<< "= " << iTime << endl);
} }
} }
} }
void StaticData::CleanUpAfterSentenceProcessing(const InputType& source) const void
StaticData
::CleanUpAfterSentenceProcessing(ttasksptr const& ttask) const
{ {
const std::vector<FeatureFunction*> &producers = FeatureFunction::GetFeatureFunctions(); const std::vector<FeatureFunction*> &producers
= FeatureFunction::GetFeatureFunctions();
for(size_t i=0; i<producers.size(); ++i) { for(size_t i=0; i<producers.size(); ++i) {
FeatureFunction &ff = *producers[i]; FeatureFunction &ff = *producers[i];
if (! IsFeatureFunctionIgnored(ff)) { if (! IsFeatureFunctionIgnored(ff)) {
ff.CleanUpAfterSentenceProcessing(source); ff.CleanUpAfterSentenceProcessing(ttask);
} }
} }
} }

View File

@ -446,18 +446,28 @@ public:
SearchAlgorithm GetSearchAlgorithm() const { SearchAlgorithm GetSearchAlgorithm() const {
return m_searchAlgorithm; return m_searchAlgorithm;
} }
bool IsSyntax() const {
return m_searchAlgorithm == CYKPlus ||
m_searchAlgorithm == ChartIncremental ||
m_searchAlgorithm == SyntaxS2T ||
m_searchAlgorithm == SyntaxT2S ||
m_searchAlgorithm == SyntaxT2S_SCFG ||
m_searchAlgorithm == SyntaxF2S;
}
const ScoreComponentCollection& GetAllWeights() const { // bool IsSyntax() const {
return m_allWeights; // return m_searchAlgorithm == CYKPlus ||
// m_searchAlgorithm == ChartIncremental ||
// m_searchAlgorithm == SyntaxS2T ||
// m_searchAlgorithm == SyntaxT2S ||
// m_searchAlgorithm == SyntaxT2S_SCFG ||
// m_searchAlgorithm == SyntaxF2S;
// }
bool IsSyntax(SearchAlgorithm algo = DefaultSearchAlgorithm) const
{
if (algo == DefaultSearchAlgorithm)
algo = m_searchAlgorithm;
return (algo == CYKPlus || algo == ChartIncremental ||
algo == SyntaxS2T || algo == SyntaxT2S ||
algo == SyntaxF2S || algo == SyntaxT2S_SCFG);
} }
const ScoreComponentCollection&
GetAllWeights() const
{ return m_allWeights; }
void SetAllWeights(const ScoreComponentCollection& weights) { void SetAllWeights(const ScoreComponentCollection& weights) {
m_allWeights = weights; m_allWeights = weights;
@ -742,8 +752,9 @@ public:
} }
//sentence (and thread) specific initialisation and cleanup //sentence (and thread) specific initialisation and cleanup
void InitializeForInput(const InputType& source) const; // void InitializeForInput(const InputType& source, ttaskptr const& ttask) const;
void CleanUpAfterSentenceProcessing(const InputType& source) const; void InitializeForInput(ttasksptr const& ttask) const;
void CleanUpAfterSentenceProcessing(ttasksptr const& ttask) const;
void LoadFeatureFunctions(); void LoadFeatureFunctions();
bool CheckWeights() const; bool CheckWeights() const;

View File

@ -1,9 +1,11 @@
//-*- c++ -*-
#pragma once #pragma once
#include <boost/smart_ptr/shared_ptr.hpp> #include <boost/smart_ptr/shared_ptr.hpp>
#include "moses/ThreadPool.h" #include "moses/ThreadPool.h"
#include "moses/TranslationOptionCollection.h" #include "moses/TranslationOptionCollection.h"
#include "moses/IOWrapper.h" #include "moses/IOWrapper.h"
#include "moses/TranslationTask.h"
namespace Moses namespace Moses
{ {
@ -11,35 +13,57 @@ class InputType;
class OutputCollector; class OutputCollector;
class TrainingTask : public Moses::Task class TrainingTask : public Moses::TranslationTask
{ {
protected:
TrainingTask(boost::shared_ptr<Moses::InputType> const source,
boost::shared_ptr<Moses::IOWrapper> const ioWrapper)
: TranslationTask(source, ioWrapper)
{ }
public: public:
TrainingTask(Moses::InputType* source, Moses::IOWrapper &ioWrapper) // factory function
: m_source(source) static boost::shared_ptr<TrainingTask>
, m_ioWrapper(ioWrapper) { create(boost::shared_ptr<InputType> const& source)
{
boost::shared_ptr<IOWrapper> nix;
boost::shared_ptr<TrainingTask> ret(new TrainingTask(source, nix));
ret->m_self = ret;
return ret;
} }
~TrainingTask() { // factory function
static boost::shared_ptr<TrainingTask>
create(boost::shared_ptr<InputType> const& source,
boost::shared_ptr<IOWrapper> const& ioWrapper)
{
boost::shared_ptr<TrainingTask> ret(new TrainingTask(source, ioWrapper));
ret->m_self = ret;
return ret;
} }
~TrainingTask()
{ }
void Run() { void Run() {
StaticData::Instance().InitializeForInput(*m_source); StaticData::Instance().InitializeForInput(this->self());
std::cerr << *m_source << std::endl; std::cerr << *m_source << std::endl;
TranslationOptionCollection *transOptColl = m_source->CreateTranslationOptionCollection(); TranslationOptionCollection *transOptColl
= m_source->CreateTranslationOptionCollection();
transOptColl->CreateTranslationOptions(); transOptColl->CreateTranslationOptions();
delete transOptColl; delete transOptColl;
StaticData::Instance().CleanUpAfterSentenceProcessing(*m_source); StaticData::Instance().CleanUpAfterSentenceProcessing(this->self());
} }
private: private:
Moses::InputType* m_source; // Moses::InputType* m_source;
Moses::IOWrapper &m_ioWrapper; // Moses::IOWrapper &m_ioWrapper;
}; };

View File

@ -23,6 +23,16 @@ using namespace std;
namespace Moses namespace Moses
{ {
boost::shared_ptr<TranslationTask>
TranslationTask
::create(boost::shared_ptr<InputType> const& source)
{
boost::shared_ptr<IOWrapper> nix;
boost::shared_ptr<TranslationTask> ret(new TranslationTask(source, nix));
ret->m_self = ret;
return ret;
}
boost::shared_ptr<TranslationTask> boost::shared_ptr<TranslationTask>
TranslationTask TranslationTask
::create(boost::shared_ptr<InputType> const& source, ::create(boost::shared_ptr<InputType> const& source,
@ -42,6 +52,59 @@ TranslationTask
TranslationTask::~TranslationTask() TranslationTask::~TranslationTask()
{ } { }
boost::shared_ptr<BaseManager>
TranslationTask
::SetupManager(SearchAlgorithm algo)
{
boost::shared_ptr<BaseManager> manager;
StaticData const& staticData = StaticData::Instance();
if (algo == DefaultSearchAlgorithm) algo = staticData.GetSearchAlgorithm();
if (!staticData.IsSyntax(algo))
manager.reset(new Manager(this->self())); // phrase-based
else if (algo == SyntaxF2S || algo == SyntaxT2S)
{ // STSG-based tree-to-string / forest-to-string decoding (ask Phil Williams)
typedef Syntax::F2S::RuleMatcherCallback Callback;
typedef Syntax::F2S::RuleMatcherHyperTree<Callback> RuleMatcher;
manager.reset(new Syntax::F2S::Manager<RuleMatcher>(this->self()));
}
else if (algo == SyntaxS2T)
{ // new-style string-to-tree decoding (ask Phil Williams)
S2TParsingAlgorithm algorithm = staticData.GetS2TParsingAlgorithm();
if (algorithm == RecursiveCYKPlus)
{
typedef Syntax::S2T::EagerParserCallback Callback;
typedef Syntax::S2T::RecursiveCYKPlusParser<Callback> Parser;
manager.reset(new Syntax::S2T::Manager<Parser>(this->self()));
}
else if (algorithm == Scope3)
{
typedef Syntax::S2T::StandardParserCallback Callback;
typedef Syntax::S2T::Scope3Parser<Callback> Parser;
manager.reset(new Syntax::S2T::Manager<Parser>(this->self()));
}
else UTIL_THROW2("ERROR: unhandled S2T parsing algorithm");
}
else if (algo == SyntaxT2S_SCFG)
{ // SCFG-based tree-to-string decoding (ask Phil Williams)
typedef Syntax::F2S::RuleMatcherCallback Callback;
typedef Syntax::T2S::RuleMatcherSCFG<Callback> RuleMatcher;
manager.reset(new Syntax::T2S::Manager<RuleMatcher>(this->self()));
}
else if (algo == ChartIncremental) // Ken's incremental decoding
manager.reset(new Incremental::Manager(this->self()));
else // original SCFG manager
manager.reset(new ChartManager(this->self()));
return manager;
}
void TranslationTask::Run() void TranslationTask::Run()
{ {
UTIL_THROW_IF2(!m_source || !m_ioWrapper, UTIL_THROW_IF2(!m_source || !m_ioWrapper,
@ -69,52 +132,22 @@ void TranslationTask::Run()
Timer initTime; Timer initTime;
initTime.start(); initTime.start();
// which manager boost::shared_ptr<BaseManager> manager = SetupManager();
boost::scoped_ptr<BaseManager> manager;
if (!staticData.IsSyntax()) {
// phrase-based
manager.reset(new Manager(*m_source));
} else if (staticData.GetSearchAlgorithm() == SyntaxF2S ||
staticData.GetSearchAlgorithm() == SyntaxT2S) {
// STSG-based tree-to-string / forest-to-string decoding (ask Phil Williams)
typedef Syntax::F2S::RuleMatcherCallback Callback;
typedef Syntax::F2S::RuleMatcherHyperTree<Callback> RuleMatcher;
manager.reset(new Syntax::F2S::Manager<RuleMatcher>(*m_source));
} else if (staticData.GetSearchAlgorithm() == SyntaxS2T) {
// new-style string-to-tree decoding (ask Phil Williams)
S2TParsingAlgorithm algorithm = staticData.GetS2TParsingAlgorithm();
if (algorithm == RecursiveCYKPlus) {
typedef Syntax::S2T::EagerParserCallback Callback;
typedef Syntax::S2T::RecursiveCYKPlusParser<Callback> Parser;
manager.reset(new Syntax::S2T::Manager<Parser>(*m_source));
} else if (algorithm == Scope3) {
typedef Syntax::S2T::StandardParserCallback Callback;
typedef Syntax::S2T::Scope3Parser<Callback> Parser;
manager.reset(new Syntax::S2T::Manager<Parser>(*m_source));
} else {
UTIL_THROW2("ERROR: unhandled S2T parsing algorithm");
}
} else if (staticData.GetSearchAlgorithm() == SyntaxT2S_SCFG) {
// SCFG-based tree-to-string decoding (ask Phil Williams)
typedef Syntax::F2S::RuleMatcherCallback Callback;
typedef Syntax::T2S::RuleMatcherSCFG<Callback> RuleMatcher;
manager.reset(new Syntax::T2S::Manager<RuleMatcher>(*m_source));
} else if (staticData.GetSearchAlgorithm() == ChartIncremental) {
// Ken's incremental decoding
manager.reset(new Incremental::Manager(*m_source));
} else {
// original SCFG manager
manager.reset(new ChartManager(*m_source));
}
VERBOSE(1, "Line " << translationId << ": Initialize search took " VERBOSE(1, "Line " << translationId << ": Initialize search took "
<< initTime << " seconds total" << endl); << initTime << " seconds total" << endl);
manager->Decode(); manager->Decode();
OutputCollector* ocoll; // new: stop here if m_ioWrapper is NULL. This means that the
// owner of the TranslationTask will take care of the output
// oh, and by the way, all the output should be handled by the
// output wrapper along the lines of *m_ioWrapper << *manager;
// Just sayin' ...
if (m_ioWrapper == NULL) return;
// we are done with search, let's look what we got // we are done with search, let's look what we got
OutputCollector* ocoll;
Timer additionalReportingTime; Timer additionalReportingTime;
additionalReportingTime.start(); additionalReportingTime.start();

View File

@ -70,7 +70,12 @@ public:
// creator functions // creator functions
static boost::shared_ptr<TranslationTask> create(); static boost::shared_ptr<TranslationTask> create();
static boost::shared_ptr<TranslationTask> static
boost::shared_ptr<TranslationTask>
create(boost::shared_ptr<Moses::InputType> const& source);
static
boost::shared_ptr<TranslationTask>
create(boost::shared_ptr<Moses::InputType> const& source, create(boost::shared_ptr<Moses::InputType> const& source,
boost::shared_ptr<Moses::IOWrapper> const& ioWrapper); boost::shared_ptr<Moses::IOWrapper> const& ioWrapper);
@ -78,8 +83,15 @@ public:
/** Translate one sentence /** Translate one sentence
* gets called by main function implemented at end of this source file */ * gets called by main function implemented at end of this source file */
virtual void Run(); virtual void Run();
boost::shared_ptr<Moses::InputType>
GetSource() const { return m_source; }
private: boost::shared_ptr<BaseManager>
SetupManager(SearchAlgorithm algo = DefaultSearchAlgorithm);
protected:
boost::shared_ptr<Moses::InputType> m_source; boost::shared_ptr<Moses::InputType> m_source;
boost::shared_ptr<Moses::IOWrapper> m_ioWrapper; boost::shared_ptr<Moses::IOWrapper> m_ioWrapper;

View File

@ -252,7 +252,7 @@ namespace MosesServer
m_reportAllFactors = check(params, "report-all-factors"); m_reportAllFactors = check(params, "report-all-factors");
m_nbestDistinct = check(params, "nbest-distinct"); m_nbestDistinct = check(params, "nbest-distinct");
m_withScoreBreakdown = check(params, "add-score-breakdown"); m_withScoreBreakdown = check(params, "add-score-breakdown");
m_source.reset(new Sentence(0,m_source_string));
si = params.find("lambda"); si = params.find("lambda");
if (si != params.end()) if (si != params.end())
{ {
@ -292,7 +292,7 @@ namespace MosesServer
istringstream buf(m_source_string + "\n"); istringstream buf(m_source_string + "\n");
tinput.Read(buf, StaticData::Instance().GetInputFactorOrder()); tinput.Read(buf, StaticData::Instance().GetInputFactorOrder());
Moses::ChartManager manager(tinput); Moses::ChartManager manager(this->self());
manager.Decode(); manager.Decode();
const Moses::ChartHypothesis *hypo = manager.GetBestHypothesis(); const Moses::ChartHypothesis *hypo = manager.GetBestHypothesis();
@ -356,7 +356,7 @@ namespace MosesServer
TranslationRequest:: TranslationRequest::
run_phrase_decoder() run_phrase_decoder()
{ {
Manager manager(Sentence(0, m_source_string)); Manager manager(this->self());
// if (m_bias.size()) manager.SetBias(&m_bias); // if (m_bias.size()) manager.SetBias(&m_bias);
manager.Decode(); manager.Decode();