Preparation for allowing context-aware decoding.

This commit is contained in:
Ulrich Germann 2015-05-19 02:35:39 +01:00
parent b085ca542d
commit dcb8e5d3e0
41 changed files with 86 additions and 61 deletions

View File

@ -105,10 +105,10 @@ bool PhraseDictionaryInterpolated::Load(
return true;
}
// Per-input initialization: forwards the call to every entry of
// m_dictionaries, in index order.
// NOTE(review): the line taking `source` appears to be the pre-change form,
// listed directly above its `ttask` replacement — confirm against the repo.
void PhraseDictionaryInterpolated::InitializeForInput(InputType const& source)
void PhraseDictionaryInterpolated::InitializeForInput(ttasksptr const& ttask)
{
for (size_t i = 0; i < m_dictionaries.size(); ++i) {
m_dictionaries[i]->InitializeForInput(source);
m_dictionaries[i]->InitializeForInput(ttask);
}
}

View File

@ -53,7 +53,7 @@ public:
, float weightWP);
virtual const TargetPhraseCollection *GetTargetPhraseCollection(const Phrase& src) const;
virtual void InitializeForInput(InputType const& source);
virtual void InitializeForInput(ttasksptr const& ttask);
virtual ChartRuleLookupManager *CreateRuleLookupManager(
const InputType &,
const ChartCellCollectionBase &) {

View File

@ -190,12 +190,12 @@ void FeatureFunction::SetTuneableComponents(const std::string& value)
}
}
// Task-based entry point: fetches the task's source, dereferences the raw
// pointer obtained via .get(), and hands the resulting object to the
// overload of InitializeForInput that accepts it.
// NOTE(review): no null check is made on ttask or on GetSource() here.
void
FeatureFunction
::InitializeForInput(ttasksptr const& ttask)
{
InitializeForInput(*(ttask->GetSource().get()));
}
// void
// FeatureFunction
// ::InitializeForInput(ttasksptr const& ttask)
// {
// InitializeForInput(*(ttask->GetSource().get()));
// }
void
FeatureFunction

View File

@ -121,15 +121,13 @@ public:
size_t SetIndex(size_t const idx);
protected:
// Per-input hook taking the input object directly; the default is a no-op.
virtual void
InitializeForInput(InputType const& source) { }
// Post-sentence cleanup hook; the default is a no-op.
virtual void
CleanUpAfterSentenceProcessing(InputType const& source) { }
public:
//! Called before search and collecting of translation options.
//! Receives the translation task; the inline default body is empty.
// NOTE(review): the bare declaration appears to be the pre-change form and
// the inline empty body its replacement; the `;` after `{ }` is redundant.
virtual void
InitializeForInput(ttasksptr const& ttask);
InitializeForInput(ttasksptr const& ttask) { };
// clean up temporary memory, called after processing each sentence
virtual void

View File

@ -3,6 +3,7 @@
#include "moses/StaticData.h"
#include "moses/InputFileStream.h"
#include "moses/TranslationOption.h"
#include "moses/TranslationTask.h"
#include "moses/FactorCollection.h"
#include "util/exception.hh"
@ -108,10 +109,13 @@ void GlobalLexicalModel::Load()
}
}
void GlobalLexicalModel::InitializeForInput( Sentence const& in )
void GlobalLexicalModel::InitializeForInput(ttasksptr const& ttask)
{
UTIL_THROW_IF2(ttask->GetSource()->GetType() != SentenceInput,
"GlobalLexicalModel works only with sentence input.");
Sentence const* s = reinterpret_cast<Sentence const*>(ttask->GetSource().get());
m_local.reset(new ThreadLocalStorage);
m_local->input = &in;
m_local->input = s;
}
float GlobalLexicalModel::ScorePhrase( const TargetPhrase& targetPhrase ) const

View File

@ -66,7 +66,7 @@ public:
void SetParameter(const std::string& key, const std::string& value);
void InitializeForInput( Sentence const& in );
void InitializeForInput(ttasksptr const& ttask);
bool IsUseable(const FactorMask &mask) const;

View File

@ -3,6 +3,7 @@
#include "moses/StaticData.h"
#include "moses/InputFileStream.h"
#include "moses/Hypothesis.h"
#include "moses/TranslationTask.h"
#include "util/string_piece_hash.hh"
using namespace std;
@ -104,10 +105,13 @@ bool GlobalLexicalModelUnlimited::Load(const std::string &filePathSource,
return true;
}
void GlobalLexicalModelUnlimited::InitializeForInput( Sentence const& in )
void GlobalLexicalModelUnlimited::InitializeForInput(ttasksptr const& ttask)
{
UTIL_THROW_IF2(ttask->GetSource()->GetType() != SentenceInput,
"GlobalLexicalModel works only with sentence input.");
Sentence const* s = reinterpret_cast<Sentence const*>(ttask->GetSource().get());
m_local.reset(new ThreadLocalStorage);
m_local->input = &in;
m_local->input = s;
}
void GlobalLexicalModelUnlimited::EvaluateWhenApplied(const Hypothesis& cur_hypo, ScoreComponentCollection* accumulator) const

View File

@ -42,6 +42,7 @@ class GlobalLexicalModelUnlimited : public StatelessFeatureFunction
typedef std::map< std::string, short > StringHash;
// Scratch state for the input currently being processed.
// NOTE(review): presumably one instance per thread, going by the name — confirm.
struct ThreadLocalStorage {
// const Sentence *input;
// Raw pointer to the input sentence; nothing in this struct releases it.
const Sentence *input;
};
@ -73,7 +74,7 @@ public:
bool Load(const std::string &filePathSource, const std::string &filePathTarget);
void InitializeForInput( Sentence const& in );
void InitializeForInput(ttasksptr const& ttask);
const FFState* EmptyHypothesisState(const InputType &) const {
return new DummyState();

View File

@ -44,8 +44,8 @@ public:
EmptyHypothesisState(const InputType &input) const;
// Delegates per-input initialization to m_table when one is set;
// does nothing when m_table is null.
// NOTE(review): the lines using `i` appear to be the pre-change form, listed
// directly above their `ttask` replacements — confirm against the repo.
void
InitializeForInput(const InputType& i) {
if (m_table) m_table->InitializeForInput(i);
InitializeForInput(ttasksptr const& ttask) {
if (m_table) m_table->InitializeForInput(ttask);
}
Scores

View File

@ -7,6 +7,7 @@
#include "moses/GenerationDictionary.h"
#include "moses/TargetPhrase.h"
#include "moses/TargetPhraseCollection.h"
#include "moses/TranslationTask.h"
#if !defined WIN32 || defined __MINGW32__ || defined HAVE_CMPH
#include "moses/TranslationModel/CompactPT/LexicalReorderingTableCompact.h"
@ -290,8 +291,9 @@ auxFindScoreForContext(const Candidates& cands, const Phrase& context)
void
LexicalReorderingTableTree::
InitializeForInput(const InputType& input)
InitializeForInput(ttasksptr const& ttask)
{
const InputType& input = *ttask->GetSource();
ClearCache();
if(ConfusionNet const* cn = dynamic_cast<ConfusionNet const*>(&input)) {
Cache(*cn);

View File

@ -54,7 +54,7 @@ public:
// Per-input hook with an empty default body; meant to be overridden.
// NOTE(review): the unnamed-parameter line appears to be the pre-change
// signature, listed above its `ttask` replacement — confirm against the repo.
virtual
void
InitializeForInput(const InputType&) {
InitializeForInput(ttasksptr const& ttask) {
/* override for on-demand loading */
};
@ -177,7 +177,7 @@ public:
virtual
void
InitializeForInput(const InputType& input);
InitializeForInput(ttasksptr const& ttask);
virtual
void

View File

@ -306,7 +306,8 @@ public:
}
}
virtual void InitializeForInput(InputType const& source) {
virtual void InitializeForInput(ttasksptr const& ttask) {
InputType const& source = ttask->GetSource();
// tabbed sentence is assumed only in training
if (! m_train)
return;

View File

@ -39,7 +39,8 @@ public:
VWFeatureSource::SetParameter(key, value);
}
virtual void InitializeForInput(InputType const& source) {
virtual void InitializeForInput(ttasksptr const& ttask) {
InputType const& source = ttask->GetSource();
UTIL_THROW_IF2(source.GetType() != TabbedSentenceInput,
"This feature function requires the TabbedSentence input type");

View File

@ -402,7 +402,7 @@ bool LMCacheCleanup(const int sentences_done, const size_t m_lmcache_cleanup_thr
return false;
}
void LanguageModelIRST::InitializeForInput(InputType const& source)
void LanguageModelIRST::InitializeForInput(ttasksptr const& ttask)
{
//nothing to do
#ifdef TRACE_CACHE

View File

@ -104,7 +104,7 @@ public:
*/
void InitializeForInput(InputType const& source);
void InitializeForInput(ttasksptr const& ttask);
void CleanUpAfterSentenceProcessing(const InputType& source);
void set_dictionary_upperbound(int dub) {

View File

@ -83,7 +83,7 @@ public:
LDHT::Client* getClientSafe();
LDHT::Client* initTSSClient();
virtual ~LanguageModelLDHT();
virtual void InitializeForInput(InputType const& source);
virtual void InitializeForInput(ttasksptr const& ttask);
virtual void CleanUpAfterSentenceProcessing(const InputType &source);
virtual const FFState* EmptyHypothesisState(const InputType& input) const;
virtual void CalcScore(const Phrase& phrase,
@ -189,7 +189,7 @@ LDHT::Client* LanguageModelLDHT::initTSSClient()
return client;
}
void LanguageModelLDHT::InitializeForInput(InputType const& source)
void LanguageModelLDHT::InitializeForInput(ttasksptr const& ttask)
{
getClientSafe()->clearCache();
m_start_tick = LDHT::Util::rdtsc();

View File

@ -134,7 +134,7 @@ LMResult LanguageModelRandLM::GetValue(const vector<const Word*> &contextFactor,
return ret;
}
// Per-input initialization: asks the wrapped model (m_lm) to set up its
// thread-specific data. The argument itself is not used.
// NOTE(review): the line taking `source` appears to be the pre-change
// signature, listed above its `ttask` replacement — confirm against the repo.
void LanguageModelRandLM::InitializeForInput(InputType const& source)
void LanguageModelRandLM::InitializeForInput(ttasksptr const& ttask)
{
m_lm->initThreadSpecificData(); // Creates thread-specific data iff compiled with multithreading.
}

View File

@ -41,7 +41,7 @@ public:
void Load();
virtual LMResult GetValue(const std::vector<const Word*> &contextFactor, State* finalState = NULL) const;
void InitializeForInput(InputType const& source);
void InitializeForInput(ttasksptr const& ttask);
void CleanUpAfterSentenceProcessing(const InputType& source);
protected:

View File

@ -176,9 +176,10 @@ void OxLM<Model>::savePersistentCache(const string& cache_file) const
}
template<class Model>
void OxLM<Model>::InitializeForInput(const InputType& source)
void OxLM<Model>::InitializeForInput(ttasksptr const& ttask)
{
LanguageModelSingleFactor::InitializeForInput(source);
const InputType& source = ttask->GetSource();
LanguageModelSingleFactor::InitializeForInput(ttask);
if (persistentCache) {
if (!cache.get()) {

View File

@ -30,7 +30,7 @@ public:
const std::vector<const Word*> &contextFactor,
State* finalState = 0) const;
virtual void InitializeForInput(const InputType& source);
virtual void InitializeForInput(ttasksptr const& ttask);
virtual void CleanUpAfterSentenceProcessing(const InputType& source);

View File

@ -101,9 +101,10 @@ void SourceOxLM::SetParameter(const string& key, const string& value)
}
}
void SourceOxLM::InitializeForInput(const InputType& source)
void SourceOxLM::InitializeForInput(ttasksptr const& ttask)
{
BilingualLM::InitializeForInput(source);
const InputType& source = ttasksptr->GetSource();
BilingualLM::InitializeForInput(ttask, source);
if (persistentCache) {
if (!cache.get()) {

View File

@ -31,7 +31,7 @@ private:
void SetParameter(const std::string& key, const std::string& value);
void InitializeForInput(const InputType& source);
void InitializeForInput(ttasksptr const& ttask);
void CleanUpAfterSentenceProcessing(const InputType& source);

View File

@ -117,7 +117,7 @@ public:
// Task-aware lookup overload. This default ignores `ttask` and returns
// whatever the single-argument overload returns for `src`.
// NOTE(review): the non-const signature appears to be the pre-change form,
// listed above its const-qualified replacement — confirm against the repo.
virtual
TargetPhraseCollection const *
GetTargetPhraseCollectionLEGACY(ttasksptr const& ttask, const Phrase& src) {
GetTargetPhraseCollectionLEGACY(ttasksptr const& ttask, const Phrase& src) const {
return GetTargetPhraseCollectionLEGACY(src);
}
@ -133,7 +133,7 @@ public:
}
//! Per-input initialization hook; this default implementation does nothing.
// NOTE(review): the line taking `source` appears to be the pre-change
// signature, listed above its `ttask` replacement — confirm against the repo.
virtual void InitializeForInput(InputType const& source) {
virtual void InitializeForInput(ttasksptr const& ttask) {
}
// clean up temporary memory, called after processing each sentence
virtual void CleanUpAfterSentenceProcessing(const InputType& source) {

View File

@ -145,7 +145,7 @@ void PhraseDictionaryDynamicCacheBased::SetParameter(const std::string& key, con
}
}
// Per-input initialization: calls ReduceCache(). The argument is not used.
// NOTE(review): the line taking `source` appears to be the pre-change
// signature, listed above its `ttask` replacement — confirm against the repo.
void PhraseDictionaryDynamicCacheBased::InitializeForInput(InputType const& source)
void PhraseDictionaryDynamicCacheBased::InitializeForInput(ttasksptr const& ttask)
{
ReduceCache();
}

View File

@ -123,7 +123,7 @@ public:
void SetParameter(const std::string& key, const std::string& value);
void InitializeForInput(InputType const& source);
void InitializeForInput(ttasksptr const& ttask);
// virtual void InitializeForInput(InputType const&) {
// /* Don't do anything source specific here as this object is shared between threads.*/

View File

@ -85,7 +85,7 @@ public:
#endif
// functions below required by base class
virtual const TargetPhraseCollection* GetTargetPhraseCollectionLEGACY(const Phrase& src) const;
// Per-input hook, intentionally empty (reason given in the body comment).
// NOTE(review): the unnamed-parameter line appears to be the pre-change
// signature, listed above its `ttask` replacement — confirm against the repo.
virtual void InitializeForInput(InputType const&) {
virtual void InitializeForInput(ttasksptr const& ttask) {
/* Don't do anything source specific here as this object is shared between threads.*/
}
ChartRuleLookupManager *CreateRuleLookupManager(const ChartParser &, const ChartCellCollectionBase&, std::size_t);

View File

@ -96,7 +96,7 @@ public:
std::vector<float> MinimizePerplexity(std::vector<std::pair<std::string, std::string> > &phrase_pair_vector);
#endif
// functions below required by base class
// Per-input hook, intentionally empty (reason given in the body comment).
// NOTE(review): the unnamed-parameter line appears to be the pre-change
// signature, listed above its `ttask` replacement — confirm against the repo.
virtual void InitializeForInput(InputType const&) {
virtual void InitializeForInput(ttasksptr const& ttask) {
/* Don't do anything source specific here as this object is shared between threads.*/
}

View File

@ -13,6 +13,7 @@
#include "moses/StaticData.h"
#include "moses/UniqueObject.h"
#include "moses/PDTAimp.h"
#include "moses/TranslationTask.h"
#include "util/exception.hh"
using namespace std;
@ -40,8 +41,9 @@ void PhraseDictionaryTreeAdaptor::Load()
SetFeaturesToApply();
}
void PhraseDictionaryTreeAdaptor::InitializeForInput(InputType const& source)
void PhraseDictionaryTreeAdaptor::InitializeForInput(ttasksptr const& ttask)
{
InputType const& source = *ttask->GetSource();
const StaticData &staticData = StaticData::Instance();
ReduceCache();

View File

@ -61,7 +61,7 @@ public:
// returns null pointer if nothing found
TargetPhraseCollection const* GetTargetPhraseCollectionNonCacheLEGACY(Phrase const &src) const;
void InitializeForInput(InputType const& source);
void InitializeForInput(ttasksptr const& ttask);
void CleanUpAfterSentenceProcessing(InputType const& source);
virtual ChartRuleLookupManager *CreateRuleLookupManager(

View File

@ -61,7 +61,7 @@ void ProbingPT::Load()
}
}
// Per-input initialization: calls ReduceCache(). The argument is not used.
// NOTE(review): the line taking `source` appears to be the pre-change
// signature, listed above its `ttask` replacement — confirm against the repo.
void ProbingPT::InitializeForInput(InputType const& source)
void ProbingPT::InitializeForInput(ttasksptr const& ttask)
{
ReduceCache();
}

View File

@ -23,7 +23,7 @@ public:
void Load();
void InitializeForInput(InputType const& source);
void InitializeForInput(ttasksptr const& ttask);
// for phrase-based model
void GetTargetPhraseCollectionBatch(const InputPathList &inputPathQueue) const;

View File

@ -11,6 +11,7 @@
#include "moses/InputType.h"
#include "moses/InputFileStream.h"
#include "moses/TypeDef.h"
#include "moses/TranslationTask.h"
#include "moses/StaticData.h"
#include "Loader.h"
#include "LoaderFactory.h"
@ -36,8 +37,9 @@ void PhraseDictionaryALSuffixArray::Load()
SetFeaturesToApply();
}
void PhraseDictionaryALSuffixArray::InitializeForInput(InputType const& source)
void PhraseDictionaryALSuffixArray::InitializeForInput(ttasksptr const& ttask)
{
InputType const& source = *ttask->GetSource();
// populate with rules for this sentence
long translationId = source.GetTranslationId();

View File

@ -24,7 +24,7 @@ class PhraseDictionaryALSuffixArray : public PhraseDictionaryMemory
public:
PhraseDictionaryALSuffixArray(const std::string &line);
void Load();
void InitializeForInput(InputType const& source);
void InitializeForInput(ttasksptr const& ttask);
void CleanUpAfterSentenceProcessing(const InputType& source);
protected:

View File

@ -43,6 +43,7 @@
#include "moses/TranslationModel/CYKPlusParser/ChartRuleLookupManagerMemoryPerSentence.h"
#include "moses/TranslationModel/fuzzy-match/FuzzyMatchWrapper.h"
#include "moses/TranslationModel/fuzzy-match/SentenceAlignment.h"
#include "moses/TranslationTask.h"
#include "util/file.hh"
#include "util/exception.hh"
#include "util/random.hh"
@ -172,8 +173,9 @@ int removedirectoryrecursively(const char *dirname)
return 1;
}
void PhraseDictionaryFuzzyMatch::InitializeForInput(InputType const& inputSentence)
void PhraseDictionaryFuzzyMatch::InitializeForInput(ttasksptr const& ttask)
{
InputType const& inputSentence = *ttask->GetSource();
#if defined __MINGW32__
char dirName[] = "moses.XXXXXX";
#else

View File

@ -51,7 +51,7 @@ public:
const ChartParser &parser,
const ChartCellCollectionBase &,
std::size_t);
void InitializeForInput(InputType const& inputSentence);
void InitializeForInput(ttasksptr const& ttask);
void CleanUpAfterSentenceProcessing(const InputType& source);
void SetParameter(const std::string& key, const std::string& value);

View File

@ -25,6 +25,7 @@
#include "moses/InputPath.h"
#include "moses/TranslationModel/CYKPlusParser/DotChartOnDisk.h"
#include "moses/TranslationModel/CYKPlusParser/ChartRuleLookupManagerOnDisk.h"
#include "moses/TranslationTask.h"
#include "OnDiskPt/OnDiskWrapper.h"
#include "OnDiskPt/Word.h"
@ -78,8 +79,9 @@ const OnDiskPt::OnDiskWrapper &PhraseDictionaryOnDisk::GetImplementation() const
return *dict;
}
void PhraseDictionaryOnDisk::InitializeForInput(InputType const& source)
void PhraseDictionaryOnDisk::InitializeForInput(ttasksptr const& ttask)
{
InputType const& source = *ttask->GetSource();
ReduceCache();
OnDiskPt::OnDiskWrapper *obj = new OnDiskPt::OnDiskWrapper();

View File

@ -75,7 +75,7 @@ public:
const ChartCellCollectionBase &,
std::size_t);
virtual void InitializeForInput(InputType const& source);
virtual void InitializeForInput(ttasksptr const& ttask);
void GetTargetPhraseCollectionBatch(const InputPathList &inputPathQueue) const;
const TargetPhraseCollection *GetTargetPhraseCollection(const OnDiskPt::PhraseNode *ptNode) const;

View File

@ -17,7 +17,7 @@ void SkeletonPT::Load()
SetFeaturesToApply();
}
// Per-input initialization: calls ReduceCache(). The argument is not used.
// NOTE(review): the line taking `source` appears to be the pre-change
// signature, listed above its `ttask` replacement — confirm against the repo.
void SkeletonPT::InitializeForInput(InputType const& source)
void SkeletonPT::InitializeForInput(ttasksptr const& ttask)
{
ReduceCache();
}

View File

@ -18,7 +18,7 @@ public:
void Load();
void InitializeForInput(InputType const& source);
void InitializeForInput(ttasksptr const& ttask);
// for phrase-based model
void GetTargetPhraseCollectionBatch(const InputPathList &inputPathQueue) const;

View File

@ -593,7 +593,7 @@ namespace Moses
inputPath.SetTargetPhrases(*this, targetPhrases, NULL);
}
}
TargetPhraseCollection const*
Mmsapt::
GetTargetPhraseCollectionLEGACY(const Phrase& src) const
@ -645,6 +645,7 @@ namespace Moses
// get context-specific cache of items previously looked up
sptr<ContextScope> const& scope = ttask->GetScope();
sptr<TPCollCache> cache = scope->get<TPCollCache>(cache_key);
if (!cache) cache = m_cache;
TPCollWrapper* ret = cache->get(phrasekey, dyn->revision());
// TO DO: we should revise the revision mechanism: we take the length
// of the dynamic bitext (in sentences) at the time the PT entry

View File

@ -1,5 +1,6 @@
#include "mmsapt.h"
#include "moses/TranslationModel/PhraseDictionaryTreeAdaptor.h"
#include "moses/TranslationTask.h"
#include <boost/foreach.hpp>
#include <boost/format.hpp>
#include <boost/tokenizer.hpp>
@ -67,17 +68,19 @@ int main(int argc, char* argv[])
string line;
while (true)
{
Sentence phrase;
if (!phrase.Read(cin,ifo)) break;
boost::shared_ptr<Sentence> phrase(new Sentence);
if (!phrase->Read(cin,ifo)) break;
boost::shared_ptr<TranslationTask> ttask;
ttask = TranslationTask::create(phrase);
if (pdta)
{
pdta->InitializeForInput(phrase);
pdta->InitializeForInput(ttask);
// do we also need to call CleanupAfterSentenceProcessing at the end?
}
Phrase& p = phrase;
Phrase& p = *phrase;
cout << p << endl;
TargetPhraseCollection const* trg = PT->GetTargetPhraseCollectionLEGACY(p);
TargetPhraseCollection const* trg = PT->GetTargetPhraseCollectionLEGACY(ttask,p);
if (!trg) continue;
vector<size_t> order(trg->GetSize());
for (size_t i = 0; i < order.size(); ++i) order[i] = i;