Managers and feature functions now have access to the entire TranslationTask, not just the InputType.

This commit is contained in:
Ulrich Germann 2015-03-26 18:25:54 +00:00
parent 4410e9225a
commit 9dc75bfd8a
24 changed files with 365 additions and 225 deletions

View File

@ -82,12 +82,16 @@ include $(TOP)/jam-files/sanity.jam ;
home = [ os.environ "HOME" ] ;
if [ path.exists $(home)/moses-environment.jam ]
{ include $(home)/moses-environment.jam ; }
{
# for those of use who don't like typing in command line bjam options all day long
include $(home)/moses-environment.jam ;
}
include $(TOP)/jam-files/check-environment.jam ; # get resource locations
# from environment variables
include $(TOP)/jam-files/xmlrpc-c.jam ; # xmlrpc-c stuff for the server
include $(TOP)/jam-files/curlpp.jam ; # curlpp stuff for bias lookup (MMT only)
# exit "done" : 0 ;
max-order = [ option.get "max-kenlm-order" : 6 : 6 ] ;
if ! [ option.get "max-kenlm-order" ]

View File

@ -37,6 +37,7 @@ int main(int argc, char** argv)
#include "moses/Manager.h"
#include "moses/StaticData.h"
#include "moses/ThreadPool.h"
#include "moses/TranslationTask.h"
#include "moses/TranslationModel/PhraseDictionaryDynSuffixArray.h"
#include "moses/TranslationModel/PhraseDictionaryMultiModelCounts.h"
#if PT_UG
@ -232,8 +233,8 @@ public:
/**
* Required so that translations can be sent to a thread pool.
**/
class TranslationTask : public virtual Moses::Task {
public:
class TranslationTask : public virtual Moses::TranslationTask {
protected:
TranslationTask(xmlrpc_c::paramList const& paramList,
boost::condition_variable& cond, boost::mutex& mut)
: m_paramList(paramList),
@ -242,23 +243,33 @@ public:
m_done(false)
{}
public:
static boost::shared_ptr<TranslationTask>
create(xmlrpc_c::paramList const& paramList,
boost::condition_variable& cond, boost::mutex& mut)
{
boost::shared_ptr<TranslationTask> ret(new TranslationTask(paramList, cond, mut));
ret->m_self = ret;
return ret;
}
virtual bool DeleteAfterExecution() {return false;}
bool IsDone() const {return m_done;}
const map<string, xmlrpc_c::value>& GetRetData() { return m_retData;}
virtual void Run() {
virtual void
Run()
{
using namespace xmlrpc_c;
const params_t params = m_paramList.getStruct(0);
m_paramList.verifyEnd(1);
params_t::const_iterator si = params.find("text");
if (si == params.end()) {
throw xmlrpc_c::fault(
"Missing source text",
xmlrpc_c::fault::CODE_PARSE);
throw fault("Missing source text", fault::CODE_PARSE);
}
const string source((xmlrpc_c::value_string(si->second)));
const string source = value_string(si->second);
XVERBOSE(1,"Input: " << source << endl);
si = params.find("align");
@ -272,7 +283,7 @@ public:
si = params.find("report-all-factors");
bool reportAllFactors = (si != params.end());
si = params.find("nbest");
int nbest_size = (si == params.end()) ? 0 : int(xmlrpc_c::value_int(si->second));
int nbest_size = (si == params.end()) ? 0 : int(value_int(si->second));
si = params.find("nbest-distinct");
bool nbest_distinct = (si != params.end());
@ -281,21 +292,25 @@ public:
vector<float> multiModelWeights;
si = params.find("lambda");
if (si != params.end()) {
xmlrpc_c::value_array multiModelArray = xmlrpc_c::value_array(si->second);
vector<xmlrpc_c::value> multiModelValueVector(multiModelArray.vectorValueValue());
for (size_t i=0;i < multiModelValueVector.size();i++) {
multiModelWeights.push_back(xmlrpc_c::value_double(multiModelValueVector[i]));
}
}
if (si != params.end())
{
value_array multiModelArray = value_array(si->second);
vector<value> multiModelValueVector(multiModelArray.vectorValueValue());
for (size_t i=0;i < multiModelValueVector.size();i++)
{
multiModelWeights.push_back(value_double(multiModelValueVector[i]));
}
}
si = params.find("model_name");
if (si != params.end() && multiModelWeights.size() > 0) {
const string model_name = xmlrpc_c::value_string(si->second);
PhraseDictionaryMultiModel* pdmm = (PhraseDictionaryMultiModel*) FindPhraseDictionary(model_name);
if (si != params.end() && multiModelWeights.size() > 0)
{
const string model_name = value_string(si->second);
PhraseDictionaryMultiModel* pdmm
= (PhraseDictionaryMultiModel*) FindPhraseDictionary(model_name);
pdmm->SetTemporaryMultiModelWeightsVector(multiModelWeights);
}
}
const StaticData &staticData = StaticData::Instance();
//Make sure alternative paths are retained, if necessary
@ -306,13 +321,14 @@ public:
stringstream out, graphInfo, transCollOpts;
if (staticData.IsSyntax()) {
TreeInput tinput;
const vector<FactorType>&
inputFactorOrder = staticData.GetInputFactorOrder();
stringstream in(source + "\n");
tinput.Read(in,inputFactorOrder);
ChartManager manager(tinput);
if (staticData.IsSyntax())
{
boost::shared_ptr<TreeInput> tinput(new TreeInput);
const vector<FactorType>& IFO = staticData.GetInputFactorOrder();
istringstream in(source + "\n");
tinput->Read(in,IFO);
ttasksptr task = Moses::TranslationTask::create(tinput);
ChartManager manager(task);
manager.Decode();
const ChartHypothesis *hypo = manager.GetBestHypothesis();
outputChartHypo(out,hypo);
@ -320,57 +336,50 @@ public:
// const size_t translationId = tinput.GetTranslationId();
std::ostringstream sgstream;
manager.OutputSearchGraphMoses(sgstream);
m_retData.insert(pair<string, xmlrpc_c::value>("sg", xmlrpc_c::value_string(sgstream.str())));
m_retData["sg"] = value_string(sgstream.str());
}
} else {
}
else
{
size_t lineNumber = 0; // TODO: Include sentence request number here?
Sentence sentence;
sentence.SetTranslationId(lineNumber);
const vector<FactorType> &
inputFactorOrder = staticData.GetInputFactorOrder();
stringstream in(source + "\n");
sentence.Read(in,inputFactorOrder);
Manager manager(sentence);
manager.Decode();
boost::shared_ptr<Sentence> sentence(new Sentence(0,source));
ttasksptr task = Moses::TranslationTask::create(sentence);
Manager manager(task);
manager.Decode();
const Hypothesis* hypo = manager.GetBestHypothesis();
vector<xmlrpc_c::value> alignInfo;
outputHypo(out,hypo,addAlignInfo,alignInfo,reportAllFactors);
if (addAlignInfo) {
m_retData.insert(pair<string, xmlrpc_c::value>("align", xmlrpc_c::value_array(alignInfo)));
}
if (addWordAlignInfo) {
stringstream wordAlignment;
hypo->OutputAlignment(wordAlignment);
vector<xmlrpc_c::value> alignments;
string alignmentPair;
while (wordAlignment >> alignmentPair) {
if (addAlignInfo) m_retData["align"] = value_array(alignInfo);
if (addWordAlignInfo)
{
stringstream wordAlignment;
hypo->OutputAlignment(wordAlignment);
vector<xmlrpc_c::value> alignments;
string alignmentPair;
while (wordAlignment >> alignmentPair)
{
int pos = alignmentPair.find('-');
map<string, xmlrpc_c::value> wordAlignInfo;
wordAlignInfo["source-word"] = xmlrpc_c::value_int(atoi(alignmentPair.substr(0, pos).c_str()));
wordAlignInfo["target-word"] = xmlrpc_c::value_int(atoi(alignmentPair.substr(pos + 1).c_str()));
alignments.push_back(xmlrpc_c::value_struct(wordAlignInfo));
}
m_retData.insert(pair<string, xmlrpc_c::value_array>("word-align", alignments));
}
if (addGraphInfo) {
insertGraphInfo(manager,m_retData);
}
if (addTopts) {
insertTranslationOptions(manager,m_retData);
}
if (nbest_size>0) {
outputNBest(manager, m_retData, nbest_size, nbest_distinct,
reportAllFactors, addAlignInfo, addScoreBreakdown);
}
wordAlignInfo["source-word"]
= value_int(atoi(alignmentPair.substr(0, pos).c_str()));
wordAlignInfo["target-word"]
= value_int(atoi(alignmentPair.substr(pos + 1).c_str()));
alignments.push_back(value_struct(wordAlignInfo));
}
m_retData["word-align"] = value_array(alignments);
}
if (addGraphInfo) insertGraphInfo(manager,m_retData);
if (addTopts) insertTranslationOptions(manager,m_retData);
if (nbest_size > 0)
{
outputNBest(manager, m_retData, nbest_size, nbest_distinct,
reportAllFactors, addAlignInfo, addScoreBreakdown);
}
(const_cast<StaticData&>(staticData)).SetOutputSearchGraph(false);
}
pair<string, xmlrpc_c::value>
text("text", xmlrpc_c::value_string(out.str()));
m_retData.insert(text);
}
m_retData["text"] = value_string(out.str());
XVERBOSE(1,"Output: " << out.str() << endl);
{
boost::lock_guard<boost::mutex> lock(m_mut);
@ -380,9 +389,12 @@ public:
}
void outputHypo(ostream& out, const Hypothesis* hypo, bool addAlignmentInfo, vector<xmlrpc_c::value>& alignInfo, bool reportAllFactors = false) {
void outputHypo(ostream& out, const Hypothesis* hypo,
bool addAlignmentInfo, vector<xmlrpc_c::value>& alignInfo,
bool reportAllFactors = false) {
if (hypo->GetPrevHypo() != NULL) {
outputHypo(out,hypo->GetPrevHypo(),addAlignmentInfo, alignInfo, reportAllFactors);
outputHypo(out,hypo->GetPrevHypo(),addAlignmentInfo,
alignInfo, reportAllFactors);
Phrase p = hypo->GetCurrTargetPhrase();
if(reportAllFactors) {
out << p << " ";
@ -595,7 +607,7 @@ public:
boost::condition_variable cond;
boost::mutex mut;
typedef ::TranslationTask TTask;
boost::shared_ptr<TTask> task(new TTask(paramList,cond,mut));
boost::shared_ptr<TTask> task = TTask::create(paramList,cond,mut);
m_threadPool.Submit(task);
boost::unique_lock<boost::mutex> lock(mut);
while (!task->IsDone()) {

View File

@ -21,7 +21,7 @@ else
if $(where)
{
option.set "with-curlpp" : $(where) ;
local msg = "setting --with-curlpp=$(where) via environment " ;
local msg = "CURLPP: setting --with-curlpp=$(where) via environment" ;
echo "$(msg) variable CURLPP_ROOT" ;
}
curlpp = [ option.get "with-curlpp" ] ;

View File

@ -53,6 +53,7 @@ POSSIBILITY OF SUCH DAMAGE.
#include "util/exception.hh"
#include <boost/foreach.hpp>
#include "moses/TranslationTask.h"
using namespace std;
using namespace Moses;
@ -175,10 +176,13 @@ int main(int argc, char* argv[])
const vector<float>& prune_grid = grid.getGrid(lmbr_prune);
const vector<float>& scale_grid = grid.getGrid(lmbr_scale);
for (boost::shared_ptr<InputType> source = ioWrapper->ReadInput();
source != NULL; source = ioWrapper->ReadInput())
boost::shared_ptr<InputType> source;
while((source = ioWrapper->ReadInput()) != NULL)
{
Manager manager(*source);
// set up task of translating one sentence
boost::shared_ptr<TranslationTask> ttask;
ttask = TranslationTask::create(source, ioWrapper);
Manager manager(ttask);
manager.Decode();
TrellisPathList nBestList;
manager.CalcNBest(nBestSize, nBestList,true);

View File

@ -153,8 +153,8 @@ int main(int argc, char** argv)
FeatureFunction::CallChangeSource(foo);
// set up task of training one sentence
boost::shared_ptr<TrainingTask>
task(new TrainingTask(source.get(), *ioWrapper));
boost::shared_ptr<TrainingTask> task;
task = TrainingTask::create(source, ioWrapper);
// execute task
#ifdef WITH_THREADS

View File

@ -39,21 +39,21 @@ using namespace std;
namespace Moses
{
extern bool g_mosesDebug;
/* constructor. Initialize everything prior to decoding a particular sentence.
* \param source the sentence to be decoded
* \param system which particular set of models to use.
*/
ChartManager::ChartManager(InputType const& source)
:BaseManager(source)
,m_hypoStackColl(source, *this)
,m_start(clock())
,m_hypothesisId(0)
,m_parser(source, m_hypoStackColl)
,m_translationOptionList(StaticData::Instance().GetRuleLimit(), source)
{
}
ChartManager::ChartManager(ttasksptr const& ttask)
: BaseManager(ttask)
, m_hypoStackColl(m_source, *this)
, m_start(clock())
, m_hypothesisId(0)
, m_parser(ttask, m_hypoStackColl)
, m_translationOptionList(StaticData::Instance().GetRuleLimit(), m_source)
{ }
ChartManager::~ChartManager()
{
@ -67,6 +67,7 @@ ChartManager::~ChartManager()
//! decode the sentence. This contains the main laps. Basically, the CKY++ algorithm
void ChartManager::Decode()
{
VERBOSE(1,"Translating: " << m_source << endl);
ResetSentenceStats(m_source);

View File

@ -33,8 +33,6 @@
#include "BaseManager.h"
#include "moses/Syntax/KBestExtractor.h"
#include <boost/shared_ptr.hpp>
namespace Moses
{
@ -103,7 +101,7 @@ private:
void Backtrack(const ChartHypothesis *hypo) const;
public:
ChartManager(InputType const& source);
ChartManager(ttasksptr const& ttask);
~ChartManager();
void Decode();
void AddXmlChartOptions();

View File

@ -28,6 +28,7 @@
#include "DecodeGraph.h"
#include "moses/FF/UnknownWordPenaltyProducer.h"
#include "moses/TranslationModel/PhraseDictionary.h"
#include "moses/TranslationTask.h"
using namespace std;
using namespace Moses;
@ -35,7 +36,10 @@ using namespace Moses;
namespace Moses
{
ChartParserUnknown::ChartParserUnknown() {}
ChartParserUnknown
::ChartParserUnknown(ttasksptr const& ttask)
: m_ttask(ttask)
{ }
ChartParserUnknown::~ChartParserUnknown()
{
@ -136,13 +140,16 @@ void ChartParserUnknown::Process(const Word &sourceWord, const WordsRange &range
}
}
ChartParser::ChartParser(InputType const &source, ChartCellCollectionBase &cells) :
m_decodeGraphList(StaticData::Instance().GetDecodeGraphs()),
m_source(source)
ChartParser
::ChartParser(ttasksptr const& ttask, ChartCellCollectionBase &cells)
: m_ttask(ttask)
, m_unknown(ttask)
, m_decodeGraphList(StaticData::Instance().GetDecodeGraphs())
, m_source(*(ttask->GetSource().get()))
{
const StaticData &staticData = StaticData::Instance();
staticData.InitializeForInput(source);
staticData.InitializeForInput(ttask);
CreateInputPaths(m_source);
const std::vector<PhraseDictionary*> &dictionaries = PhraseDictionary::GetColl();
@ -161,7 +168,7 @@ ChartParser::ChartParser(InputType const &source, ChartCellCollectionBase &cells
ChartParser::~ChartParser()
{
RemoveAllInColl(m_ruleLookupManagers);
StaticData::Instance().CleanUpAfterSentenceProcessing(m_source);
StaticData::Instance().CleanUpAfterSentenceProcessing(m_ttask.lock());
InputPathMatrix::const_iterator iterOuter;
for (iterOuter = m_inputPathMatrix.begin(); iterOuter != m_inputPathMatrix.end(); ++iterOuter) {

View File

@ -1,3 +1,4 @@
// -*- c++ -*-
// $Id$
// vim:tabstop=2
/***********************************************************************
@ -42,8 +43,9 @@ class DecodeGraph;
class ChartParserUnknown
{
ttaskwptr m_ttask;
public:
ChartParserUnknown();
ChartParserUnknown(ttasksptr const& ttask);
~ChartParserUnknown();
void Process(const Word &sourceWord, const WordsRange &range, ChartParserCallback &to);
@ -59,8 +61,9 @@ private:
class ChartParser
{
ttaskwptr m_ttask;
public:
ChartParser(const InputType &source, ChartCellCollectionBase &cells);
ChartParser(ttasksptr const& ttask, ChartCellCollectionBase &cells);
~ChartParser();
void Create(const WordsRange &range, ChartParserCallback &to);

View File

@ -6,6 +6,7 @@
#include "moses/Hypothesis.h"
#include "moses/Manager.h"
#include "moses/TranslationOption.h"
#include "moses/TranslationTask.h"
#include "moses/Util.h"
#include "moses/FF/DistortionScoreProducer.h"
@ -186,5 +187,15 @@ void FeatureFunction::SetTuneableComponents(const std::string& value)
}
}
void
FeatureFunction
::InitializeForInput(ttasksptr const& ttask)
{ InitializeForInput(*(ttask->GetSource().get())); }
void
FeatureFunction
::CleanUpAfterSentenceProcessing(ttasksptr const& ttask)
{ CleanupAfterSentenceProcessing(*(ttask->GetSource().get())); }
}

View File

@ -114,31 +114,43 @@ public:
virtual std::vector<float> DefaultWeights() const;
protected:
virtual void
InitializeForInput(InputType const& source) { }
virtual void
CleanupAfterSentenceProcessing(InputType const& source) { }
public:
//! Called before search and collecting of translation options
virtual void InitializeForInput(InputType const& source) {
}
virtual void
InitializeForInput(ttasksptr const& ttask);
// clean up temporary memory, called after processing each sentence
virtual void CleanUpAfterSentenceProcessing(const InputType& source) {
}
virtual void
CleanUpAfterSentenceProcessing(ttasksptr const& ttask);
const std::string &GetArgLine() const {
return m_argLine;
}
const std::string &
GetArgLine() const { return m_argLine; }
// given a target phrase containing only factors specified in mask
// return true if the feature function can be evaluated
virtual bool IsUseable(const FactorMask &mask) const = 0;
// used by stateless ff and stateful ff. Calculate initial score estimate during loading of phrase table
// source phrase is the substring that the phrase table uses to look up the target phrase,
// used by stateless ff and stateful ff. Calculate initial score
// estimate during loading of phrase table
//
// source phrase is the substring that the phrase table uses to look
// up the target phrase,
//
// may have more factors than actually need, but not guaranteed.
// For SCFG decoding, the source contains non-terminals, NOT the raw source from the input sentence
virtual void EvaluateInIsolation(const Phrase &source
, const TargetPhrase &targetPhrase
, ScoreComponentCollection &scoreBreakdown
, ScoreComponentCollection &estimatedFutureScore) const = 0;
// For SCFG decoding, the source contains non-terminals, NOT the raw
// source from the input sentence
virtual void
EvaluateInIsolation(const Phrase &source, const TargetPhrase &targetPhrase,
ScoreComponentCollection& scoreBreakdown,
ScoreComponentCollection& estimatedFutureScore) const = 0;
// override this method if you want to change the input before decoding
virtual void ChangeSource(InputType * const&input) const { }

View File

@ -203,15 +203,15 @@ struct ChartCellBaseFactory {
} // namespace
Manager::Manager(const InputType &source) :
BaseManager(source),
cells_(source, ChartCellBaseFactory()),
parser_(source, cells_),
n_best_(search::NBestConfig(StaticData::Instance().GetNBestSize())) {}
Manager::Manager(ttasksptr const& ttask)
: BaseManager(ttask)
, cells_(m_source, ChartCellBaseFactory())
, parser_(ttask, cells_)
, n_best_(search::NBestConfig(StaticData::Instance().GetNBestSize()))
{ }
Manager::~Manager()
{
}
{ }
template <class Model, class Best> search::History Manager::PopulateBest(const Model &model, const std::vector<lm::WordIndex> &words, Best &out)
{

View File

@ -1,3 +1,4 @@
// -*- c++ -*-
#pragma once
#include "lm/word_index.hh"
@ -24,7 +25,7 @@ namespace Incremental
class Manager : public BaseManager
{
public:
Manager(const InputType &source);
Manager(ttasksptr const& ttask);
~Manager();

View File

@ -44,6 +44,7 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
#include "moses/LM/Base.h"
#include "moses/TranslationModel/PhraseDictionary.h"
#include "moses/TranslationAnalysis.h"
#include "moses/TranslationTask.h"
#include "moses/HypergraphOutput.h"
#include "moses/mbr.h"
#include "moses/LatticeMBR.h"
@ -85,7 +86,7 @@ Manager::~Manager()
const InputType&
Manager::GetSource() const
{ return m_source) ; }
{ return m_source ; }
/**
* Main decoder loop that translates a sentence by expanding

View File

@ -17,13 +17,12 @@ License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
#include "MockHypothesis.h"
#include "TranslationOption.h"
#include "TranslationTask.h"
#include <boost/test/unit_test.hpp>
#include "TranslationOption.h"
using namespace Moses;
using namespace std;
@ -31,29 +30,23 @@ using namespace std;
namespace MosesTest
{
MockHypothesisGuard::MockHypothesisGuard(
const string& sourceSentence,
MockHypothesisGuard
::MockHypothesisGuard
( const string& sourceSentence,
const vector<Alignment>& alignments,
const vector<string>& targetSegments)
: m_initialTransOpt(),
m_wp("WordPenalty"),
m_uwp("UnknownWordPenalty"),
m_dist("Distortion"),
m_manager(m_sentence)
: m_initialTransOpt(), m_wp("WordPenalty"),
m_uwp("UnknownWordPenalty"), m_dist("Distortion")
{
BOOST_CHECK_EQUAL(alignments.size(), targetSegments.size());
std::vector<Moses::FactorType> factors;
factors.push_back(0);
stringstream in(sourceSentence + "\n");
m_sentence.Read(in,factors);
std::vector<Moses::FactorType> factors(1,0);
m_sentence.reset(new Sentence(0, sourceSentence, &factors));
m_ttask = TranslationTask::create(m_sentence);
m_manager.reset(new Manager(m_ttask));
//Initial empty hypothesis
m_manager.ResetSentenceStats(m_sentence);
m_hypothesis = Hypothesis::Create(m_manager, m_sentence, m_initialTransOpt);
m_manager->ResetSentenceStats(*m_sentence);
m_hypothesis = Hypothesis::Create(*m_manager, *m_sentence, m_initialTransOpt);
//create the chain
vector<Alignment>::const_iterator ai = alignments.begin();

View File

@ -1,3 +1,4 @@
// -*- c++ -*-
/***********************************************************************
Moses - factored phrase-based language decoder
Copyright (C) 2010 University of Edinburgh
@ -45,10 +46,11 @@ class MockHypothesisGuard
public:
/** Creates a phrase-based hypothesis.
*/
MockHypothesisGuard(
const std::string& sourceSentence,
MockHypothesisGuard
( const std::string& sourceSentence,
const std::vector<Alignment>& alignments,
const std::vector<std::string>& targetSegments);
Moses::Hypothesis* operator*() const {
return m_hypothesis;
}
@ -58,11 +60,12 @@ public:
private:
Moses::TranslationOption m_initialTransOpt;
Moses::Sentence m_sentence;
boost::shared_ptr<Moses::Sentence> m_sentence;
Moses::WordPenaltyProducer m_wp;
Moses::UnknownWordPenaltyProducer m_uwp;
Moses::DistortionScoreProducer m_dist;
Moses::Manager m_manager;
boost::shared_ptr<Moses::Manager> m_manager;
boost::shared_ptr<Moses::TranslationTask> m_ttask;
Moses::Hypothesis* m_hypothesis;
std::vector<Moses::TargetPhrase> m_targetPhrases;
std::vector<Moses::TranslationOption*> m_toptions;

View File

@ -383,10 +383,12 @@ CreateFromString(vector<FactorType> const& FOrder, string const& phraseString)
}
Sentence::
Sentence(size_t const transId, string const& stext) : InputType(transId)
Sentence(size_t const transId, string const& stext,
vector<FactorType> const* IFO)
: InputType(transId)
{
vector<FactorType> const& IFO = StaticData::Instance().GetInputFactorOrder();
init(stext, IFO);
if (IFO) init(stext, *IFO);
else init(stext, StaticData::Instance().GetInputFactorOrder());
}
}

View File

@ -63,7 +63,9 @@ namespace Moses
public:
Sentence();
Sentence(size_t const transId, std::string const& stext);
Sentence(size_t const transId, std::string const& stext,
std::vector<FactorType> const* IFO = NULL);
// Sentence(size_t const transId, std::string const& stext);
~Sentence();
InputTypeEnum GetType() const {

View File

@ -846,27 +846,33 @@ float StaticData::GetWeightWordPenalty() const
return weightWP;
}
void StaticData::InitializeForInput(const InputType& source) const
void
StaticData
::InitializeForInput(ttasksptr const& ttask) const
{
const std::vector<FeatureFunction*> &producers = FeatureFunction::GetFeatureFunctions();
const std::vector<FeatureFunction*> &producers = FeatureFunction::GetFeatureFunctions();
for(size_t i=0; i<producers.size(); ++i) {
FeatureFunction &ff = *producers[i];
if (! IsFeatureFunctionIgnored(ff)) {
Timer iTime;
iTime.start();
ff.InitializeForInput(source);
VERBOSE(3,"InitializeForInput( " << ff.GetScoreProducerDescription() << " ) = " << iTime << endl);
ff.InitializeForInput(ttask);
VERBOSE(3,"InitializeForInput( " << ff.GetScoreProducerDescription() << " )"
<< "= " << iTime << endl);
}
}
}
void StaticData::CleanUpAfterSentenceProcessing(const InputType& source) const
void
StaticData
::CleanUpAfterSentenceProcessing(ttasksptr const& ttask) const
{
const std::vector<FeatureFunction*> &producers = FeatureFunction::GetFeatureFunctions();
const std::vector<FeatureFunction*> &producers
= FeatureFunction::GetFeatureFunctions();
for(size_t i=0; i<producers.size(); ++i) {
FeatureFunction &ff = *producers[i];
if (! IsFeatureFunctionIgnored(ff)) {
ff.CleanUpAfterSentenceProcessing(source);
ff.CleanUpAfterSentenceProcessing(ttask);
}
}
}

View File

@ -446,18 +446,28 @@ public:
SearchAlgorithm GetSearchAlgorithm() const {
return m_searchAlgorithm;
}
bool IsSyntax() const {
return m_searchAlgorithm == CYKPlus ||
m_searchAlgorithm == ChartIncremental ||
m_searchAlgorithm == SyntaxS2T ||
m_searchAlgorithm == SyntaxT2S ||
m_searchAlgorithm == SyntaxT2S_SCFG ||
m_searchAlgorithm == SyntaxF2S;
}
const ScoreComponentCollection& GetAllWeights() const {
return m_allWeights;
// bool IsSyntax() const {
// return m_searchAlgorithm == CYKPlus ||
// m_searchAlgorithm == ChartIncremental ||
// m_searchAlgorithm == SyntaxS2T ||
// m_searchAlgorithm == SyntaxT2S ||
// m_searchAlgorithm == SyntaxT2S_SCFG ||
// m_searchAlgorithm == SyntaxF2S;
// }
bool IsSyntax(SearchAlgorithm algo = DefaultSearchAlgorithm) const
{
if (algo == DefaultSearchAlgorithm)
algo = m_searchAlgorithm;
return (algo == CYKPlus || algo == ChartIncremental ||
algo == SyntaxS2T || algo == SyntaxT2S ||
algo == SyntaxF2S || algo == SyntaxT2S_SCFG);
}
const ScoreComponentCollection&
GetAllWeights() const
{ return m_allWeights; }
void SetAllWeights(const ScoreComponentCollection& weights) {
m_allWeights = weights;
@ -742,8 +752,9 @@ public:
}
//sentence (and thread) specific initialisationn and cleanup
void InitializeForInput(const InputType& source) const;
void CleanUpAfterSentenceProcessing(const InputType& source) const;
// void InitializeForInput(const InputType& source, ttaskptr const& ttask) const;
void InitializeForInput(ttasksptr const& ttask) const;
void CleanUpAfterSentenceProcessing(ttasksptr const& ttask) const;
void LoadFeatureFunctions();
bool CheckWeights() const;

View File

@ -1,9 +1,11 @@
//-*- c++ -*-
#pragma once
#include <boost/smart_ptr/shared_ptr.hpp>
#include "moses/ThreadPool.h"
#include "moses/TranslationOptionCollection.h"
#include "moses/IOWrapper.h"
#include "moses/TranslationTask.h"
namespace Moses
{
@ -11,35 +13,57 @@ class InputType;
class OutputCollector;
class TrainingTask : public Moses::Task
class TrainingTask : public Moses::TranslationTask
{
protected:
TrainingTask(boost::shared_ptr<Moses::InputType> const source,
boost::shared_ptr<Moses::IOWrapper> const ioWrapper)
: TranslationTask(source, ioWrapper)
{ }
public:
TrainingTask(Moses::InputType* source, Moses::IOWrapper &ioWrapper)
: m_source(source)
, m_ioWrapper(ioWrapper) {
// factory function
static boost::shared_ptr<TrainingTask>
create(boost::shared_ptr<InputType> const& source)
{
boost::shared_ptr<IOWrapper> nix;
boost::shared_ptr<TrainingTask> ret(new TrainingTask(source, nix));
ret->m_self = ret;
return ret;
}
~TrainingTask() {
// factory function
static boost::shared_ptr<TrainingTask>
create(boost::shared_ptr<InputType> const& source,
boost::shared_ptr<IOWrapper> const& ioWrapper)
{
boost::shared_ptr<TrainingTask> ret(new TrainingTask(source, ioWrapper));
ret->m_self = ret;
return ret;
}
~TrainingTask()
{ }
void Run() {
StaticData::Instance().InitializeForInput(*m_source);
StaticData::Instance().InitializeForInput(this->self());
std::cerr << *m_source << std::endl;
TranslationOptionCollection *transOptColl = m_source->CreateTranslationOptionCollection();
TranslationOptionCollection *transOptColl
= m_source->CreateTranslationOptionCollection();
transOptColl->CreateTranslationOptions();
delete transOptColl;
StaticData::Instance().CleanUpAfterSentenceProcessing(*m_source);
StaticData::Instance().CleanUpAfterSentenceProcessing(this->self());
}
private:
Moses::InputType* m_source;
Moses::IOWrapper &m_ioWrapper;
// Moses::InputType* m_source;
// Moses::IOWrapper &m_ioWrapper;
};

View File

@ -23,6 +23,16 @@ using namespace std;
namespace Moses
{
boost::shared_ptr<TranslationTask>
TranslationTask
::create(boost::shared_ptr<InputType> const& source)
{
boost::shared_ptr<IOWrapper> nix;
boost::shared_ptr<TranslationTask> ret(new TranslationTask(source, nix));
ret->m_self = ret;
return ret;
}
boost::shared_ptr<TranslationTask>
TranslationTask
::create(boost::shared_ptr<InputType> const& source,
@ -42,6 +52,59 @@ TranslationTask
TranslationTask::~TranslationTask()
{ }
boost::shared_ptr<BaseManager>
TranslationTask
::SetupManager(SearchAlgorithm algo)
{
boost::shared_ptr<BaseManager> manager;
StaticData const& staticData = StaticData::Instance();
if (algo == DefaultSearchAlgorithm) algo = staticData.GetSearchAlgorithm();
if (!staticData.IsSyntax(algo))
manager.reset(new Manager(this->self())); // phrase-based
else if (algo == SyntaxF2S || algo == SyntaxT2S)
{ // STSG-based tree-to-string / forest-to-string decoding (ask Phil Williams)
typedef Syntax::F2S::RuleMatcherCallback Callback;
typedef Syntax::F2S::RuleMatcherHyperTree<Callback> RuleMatcher;
manager.reset(new Syntax::F2S::Manager<RuleMatcher>(this->self()));
}
else if (algo == SyntaxS2T)
{ // new-style string-to-tree decoding (ask Phil Williams)
S2TParsingAlgorithm algorithm = staticData.GetS2TParsingAlgorithm();
if (algorithm == RecursiveCYKPlus)
{
typedef Syntax::S2T::EagerParserCallback Callback;
typedef Syntax::S2T::RecursiveCYKPlusParser<Callback> Parser;
manager.reset(new Syntax::S2T::Manager<Parser>(this->self()));
}
else if (algorithm == Scope3)
{
typedef Syntax::S2T::StandardParserCallback Callback;
typedef Syntax::S2T::Scope3Parser<Callback> Parser;
manager.reset(new Syntax::S2T::Manager<Parser>(this->self()));
}
else UTIL_THROW2("ERROR: unhandled S2T parsing algorithm");
}
else if (algo == SyntaxT2S_SCFG)
{ // SCFG-based tree-to-string decoding (ask Phil Williams)
typedef Syntax::F2S::RuleMatcherCallback Callback;
typedef Syntax::T2S::RuleMatcherSCFG<Callback> RuleMatcher;
manager.reset(new Syntax::T2S::Manager<RuleMatcher>(this->self()));
}
else if (algo == ChartIncremental) // Ken's incremental decoding
manager.reset(new Incremental::Manager(this->self()));
else // original SCFG manager
manager.reset(new ChartManager(this->self()));
return manager;
}
void TranslationTask::Run()
{
UTIL_THROW_IF2(!m_source || !m_ioWrapper,
@ -69,52 +132,22 @@ void TranslationTask::Run()
Timer initTime;
initTime.start();
// which manager
boost::scoped_ptr<BaseManager> manager;
if (!staticData.IsSyntax()) {
// phrase-based
manager.reset(new Manager(*m_source));
} else if (staticData.GetSearchAlgorithm() == SyntaxF2S ||
staticData.GetSearchAlgorithm() == SyntaxT2S) {
// STSG-based tree-to-string / forest-to-string decoding (ask Phil Williams)
typedef Syntax::F2S::RuleMatcherCallback Callback;
typedef Syntax::F2S::RuleMatcherHyperTree<Callback> RuleMatcher;
manager.reset(new Syntax::F2S::Manager<RuleMatcher>(*m_source));
} else if (staticData.GetSearchAlgorithm() == SyntaxS2T) {
// new-style string-to-tree decoding (ask Phil Williams)
S2TParsingAlgorithm algorithm = staticData.GetS2TParsingAlgorithm();
if (algorithm == RecursiveCYKPlus) {
typedef Syntax::S2T::EagerParserCallback Callback;
typedef Syntax::S2T::RecursiveCYKPlusParser<Callback> Parser;
manager.reset(new Syntax::S2T::Manager<Parser>(*m_source));
} else if (algorithm == Scope3) {
typedef Syntax::S2T::StandardParserCallback Callback;
typedef Syntax::S2T::Scope3Parser<Callback> Parser;
manager.reset(new Syntax::S2T::Manager<Parser>(*m_source));
} else {
UTIL_THROW2("ERROR: unhandled S2T parsing algorithm");
}
} else if (staticData.GetSearchAlgorithm() == SyntaxT2S_SCFG) {
// SCFG-based tree-to-string decoding (ask Phil Williams)
typedef Syntax::F2S::RuleMatcherCallback Callback;
typedef Syntax::T2S::RuleMatcherSCFG<Callback> RuleMatcher;
manager.reset(new Syntax::T2S::Manager<RuleMatcher>(*m_source));
} else if (staticData.GetSearchAlgorithm() == ChartIncremental) {
// Ken's incremental decoding
manager.reset(new Incremental::Manager(*m_source));
} else {
// original SCFG manager
manager.reset(new ChartManager(*m_source));
}
boost::shared_ptr<BaseManager> manager = SetupManager();
VERBOSE(1, "Line " << translationId << ": Initialize search took "
<< initTime << " seconds total" << endl);
manager->Decode();
OutputCollector* ocoll;
// new: stop here if m_ioWrapper is NULL. This means that the
// owner of the TranslationTask will take care of the output
// oh, and by the way, all the output should be handled by the
// output wrapper along the lines of *m_iwWrapper << *manager;
// Just sayin' ...
if (m_ioWrapper == NULL) return;
// we are done with search, let's look what we got
OutputCollector* ocoll;
Timer additionalReportingTime;
additionalReportingTime.start();

View File

@ -70,7 +70,12 @@ public:
// creator functions
static boost::shared_ptr<TranslationTask> create();
static boost::shared_ptr<TranslationTask>
static
boost::shared_ptr<TranslationTask>
create(boost::shared_ptr<Moses::InputType> const& source);
static
boost::shared_ptr<TranslationTask>
create(boost::shared_ptr<Moses::InputType> const& source,
boost::shared_ptr<Moses::IOWrapper> const& ioWrapper);
@ -78,8 +83,15 @@ public:
/** Translate one sentence
* gets called by main function implemented at end of this source file */
virtual void Run();
boost::shared_ptr<Moses::InputType>
GetSource() const { return m_source; }
private:
boost::shared_ptr<BaseManager>
SetupManager(SearchAlgorithm algo = DefaultSearchAlgorithm);
protected:
boost::shared_ptr<Moses::InputType> m_source;
boost::shared_ptr<Moses::IOWrapper> m_ioWrapper;

View File

@ -252,7 +252,7 @@ namespace MosesServer
m_reportAllFactors = check(params, "report-all-factors");
m_nbestDistinct = check(params, "nbest-distinct");
m_withScoreBreakdown = check(params, "add-score-breakdown");
m_source.reset(new Sentence(0,m_source_string));
si = params.find("lambda");
if (si != params.end())
{
@ -292,7 +292,7 @@ namespace MosesServer
istringstream buf(m_source_string + "\n");
tinput.Read(buf, StaticData::Instance().GetInputFactorOrder());
Moses::ChartManager manager(tinput);
Moses::ChartManager manager(this->self());
manager.Decode();
const Moses::ChartHypothesis *hypo = manager.GetBestHypothesis();
@ -356,7 +356,7 @@ namespace MosesServer
TranslationRequest::
run_phrase_decoder()
{
Manager manager(Sentence(0, m_source_string));
Manager manager(this->self());
// if (m_bias.size()) manager.SetBias(&m_bias);
manager.Decode();