mirror of https://github.com/moses-smt/mosesdecoder.git
synced 2024-12-28 14:32:38 +03:00

Merge branch 'master' of github.com:moses-smt/mosesdecoder

This commit is contained in:
commit 5d5ee9e6b1
@@ -3,6 +3,7 @@ import option ;
import os ;
import path ;
import project ;
import build-system ;

#Shell with trailing line removed http://lists.boost.org/boost-build/2007/08/17051.php
rule trim-nl ( str extras * ) {
@@ -186,7 +187,59 @@ rule install-headers ( name : list * : source-root ? ) {
}

rule build-projects ( projects * ) {
for p in $(projects) {
for local p in $(projects) {
build-project $(p) ;
}
}

#Only one post build hook is allowed. Allow multiple.
post-hooks = ;
rule post-build ( ok ? ) {
for local r in $(post-hooks) {
$(r) $(ok) ;
}
}
IMPORT $(__name__) : post-build : : $(__name__).post-build ;
build-system.set-post-build-hook $(__name__).post-build ;
rule add-post-hook ( names * ) {
post-hooks += $(names) ;
}

#Backend for writing content to files after build completes.
post-files = ;
post-contents = ;
rule save-post-build ( ok ? ) {
if $(ok) {
while $(post-files) {
local ignored = @($(post-files[1]):E=$(post-contents[1])) ;
post-files = $(post-files[2-]) ;
post-contents = $(post-contents[2-]) ;
}
}
}
add-post-hook save-post-build ;

#Queue content to be written to file when build completes successfully.
rule add-post-write ( name content ) {
post-files += $(name) ;
post-contents += $(content) ;
}

#Compare contents of file with current. If they're different, force the targets to rebuild then overwrite the file.
rule always-if-changed ( file current : targets * ) {
local previous = inconsistent ;
if [ path.exists $(file) ] {
previous = [ _shell "cat $(file)" ] ;
}
if $(current) != $(previous) {
#Write inconsistent while the build is running
if [ path.exists $(file) ] {
local ignored = @($(file):E=inconsistent) ;
}
add-post-write $(file) $(current) ;
for local i in $(targets) {
always $(i) ;
}
}
}

@@ -258,6 +258,32 @@ void Hypothesis::ResetScore()
m_futureScore = m_totalScore = 0.0f;
}

void Hypothesis::IncorporateTransOptScores() {
m_scoreBreakdown.PlusEquals(m_transOpt->GetScoreBreakdown());
}

void Hypothesis::EvaluateWith(StatefulFeatureFunction* sfff,
int state_idx) {
m_ffStates[state_idx] = sfff->Evaluate(
*this,
m_prevHypo ? m_prevHypo->m_ffStates[state_idx] : NULL,
&m_scoreBreakdown);

}

void Hypothesis::EvaluateWith(const StatelessFeatureFunction* slff) {
slff->Evaluate(m_targetPhrase, &m_scoreBreakdown);
}

void Hypothesis::CalculateFutureScore(const SquareMatrix& futureScore) {
m_futureScore = futureScore.CalcFutureScore( m_sourceCompleted );
}

void Hypothesis::CalculateFinalScore() {
m_totalScore = m_scoreBreakdown.InnerProduct(
StaticData::Instance().GetAllWeights()) + m_futureScore;
}

/***
* calculate the logarithm of our total translation score (sum up components)
*/
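
Taken together, these new methods break what used to be a single scoring call into stages that a batched search can interleave with asynchronous LM requests. A hedged sketch of the intended ordering, mirroring SearchNormalBatch::EvalAndMergePartialHypos later in this commit (hypo, the feature-function containers, and futureScoreMatrix are assumed to be in scope):

// Hedged sketch: staged scoring of one batched hypothesis.
hypo->IncorporateTransOptScores();              // phrase-table scores first
for (std::map<int, StatefulFeatureFunction*>::iterator it = stateful_ffs.begin();
     it != stateful_ffs.end(); ++it)
  hypo->EvaluateWith(it->second, it->first);    // non-batched stateful FFs
for (size_t i = 0; i < stateless_ffs.size(); ++i)
  hypo->EvaluateWith(stateless_ffs[i]);         // stateless FFs
hypo->CalculateFutureScore(futureScoreMatrix);  // outside estimate
// ... once the distributed LM has sync()ed, fold in its scores ...
for (std::map<int, LanguageModel*>::iterator it = dlm_ffs.begin();
     it != dlm_ffs.end(); ++it)
  hypo->EvaluateWith(it->second, it->first);
hypo->CalculateFinalScore();                    // inner product with weights, plus future score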
@@ -236,9 +236,19 @@ public:
float GetScore() const {
return m_totalScore-m_futureScore;
}
const FFState* GetFFState(int idx) const {
return m_ffStates[idx];
}
void SetFFState(int idx, FFState* state) {
m_ffStates[idx] = state;
}

// Added by oliver.wilson@ed.ac.uk for async lm stuff.
void IncorporateTransOptScores();
void EvaluateWith(StatefulFeatureFunction* sfff, int state_idx);
void EvaluateWith(const StatelessFeatureFunction* slff);
void CalculateFutureScore(const SquareMatrix& futureScore);
void CalculateFinalScore();

//! target span that trans opt would populate if applied to this hypo. Used for alignment check
size_t GetNextStartPos(const TranslationOption &transOpt) const;

@@ -84,6 +84,15 @@ public:
* \param oovCount number of LM OOVs
*/
virtual void CalcScore(const Phrase &phrase, float &fullScore, float &ngramScore, std::size_t &oovCount) const = 0;

virtual void IssueRequestsFor(Hypothesis& hypo,
const FFState* input_state) {
}
virtual void sync() {
}
virtual void SetFFStateIdx(int state_idx) {
}

};

}
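
These three no-op virtuals define the whole batching contract: issue LM requests for many hypotheses, block once, then read the scores back through the normal Evaluate() path. A hedged caller-side sketch (dlm, state_idx, and the batch vector are assumptions here; LanguageModelLDHT below provides the real overrides):

// Hedged sketch of the request/sync/read protocol enabled by these hooks.
for (size_t i = 0; i < batch.size(); ++i) {
  Hypothesis* h = batch[i];
  const FFState* prev = h->GetPrevHypo()
      ? h->GetPrevHypo()->GetFFState(state_idx) : NULL;
  dlm->IssueRequestsFor(*h, prev);  // fire asynchronous n-gram requests
}
dlm->sync();                        // one blocking round trip for the whole batch
for (size_t i = 0; i < batch.size(); ++i)
  batch[i]->EvaluateWith(dlm, state_idx);  // scores read from the stored tags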
@@ -5,7 +5,7 @@
#Each optional model has a section below. The top level rule is lib LM, which
#appears after the optional models.

import option path build-system ;
import option path ;

local dependencies = ;

@@ -82,33 +82,15 @@ lib LM : Base.cpp Factory.o Implementation.cpp Joint.cpp Ken.cpp MultiFactor.cpp
#the install path doesn't encode features). It stores a file lm.log with the
#previous options and forces a rebuild if the current options differ.
path-constant LM-LOG : bin/lm.log ;
#Is there no other way to read a file with bjam?
local previous = none ;
if [ path.exists $(LM-LOG) ] {
previous = [ _shell "cat $(LM-LOG)" ] ;
}
current = "" ;

local current = ;
for local i in srilm irstlm randlm {
local optval = [ option.get "with-$(i)" ] ;
if $(optval) {
current = "$(current) --with-$(i)=$(optval)" ;
current += "--with-$(i)=$(optval)" ;
}
}
current = $(current:J=" ") ;
current ?= "" ;

if $(current) != $(previous) {
#Write inconsistent while the build is running
if [ path.exists $(LM-LOG) ] {
local ignored = @($(LM-LOG):E=inconsistent) ;
}
#Write $(current) to $(LM-LOG) after the build completes.
rule post-build ( ok ? ) {
if $(ok) {
local ignored = @($(LM-LOG):E=$(current)) ;
}
}
IMPORT $(__name__) : post-build : : $(__name__).post-build ;
build-system.set-post-build-hook $(__name__).post-build ;

always Factory.o ;
always LM ;
}
always-if-changed $(LM-LOG) $(current) : Factory.o LM ;

@@ -7,6 +7,7 @@
#include "../FFState.h"
#include "../TypeDef.h"
#include "../Hypothesis.h"
#include "../StaticData.h"

#include <LDHT/Client.h>
#include <LDHT/ClientLocal.h>
@@ -19,14 +20,43 @@ namespace Moses {

struct LDHTLMState : public FFState {
LDHT::NewNgram gram_fingerprints;
bool finalised;
std::vector<int> request_tags;

LDHTLMState(): finalised(false) {
}

void setFinalised() {
this->finalised = true;
}

void appendRequestTag(int tag) {
this->request_tags.push_back(tag);
}

void clearRequestTags() {
this->request_tags.clear();
}

std::vector<int>::iterator requestTagsBegin() {
return this->request_tags.begin();
}

std::vector<int>::iterator requestTagsEnd() {
return this->request_tags.end();
}

int Compare(const FFState& uncast_other) const {
const LDHTLMState &other = static_cast<const LDHTLMState&>(uncast_other);
//if (!this->finalised)
// return -1;

return gram_fingerprints.compareMoses(other.gram_fingerprints);
}

void copyFrom(const LDHTLMState& other) {
gram_fingerprints.copyFrom(other.gram_fingerprints);
finalised = false;
}
};
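
The state object acts as a mailbox: IssueRequestsFor() deposits request tags, and calcScoreFromState() later drains them. A hedged sketch of that lifecycle (client, prev, and ngram are assumed to be in scope; the calls match those used later in this file):

// Hedged sketch of the LDHTLMState request-tag lifecycle.
LDHTLMState* state = new LDHTLMState();
state->copyFrom(*prev);                                // inherit n-gram context
state->appendRequestTag(client->requestNgram(ngram));  // async request; keep the tag
// ... later, after the client has awaited its responses ...
float score = 0.0f;
for (std::vector<int>::iterator it = state->requestTagsBegin();
     it != state->requestTagsEnd(); ++it) {
  score += client->getNgramScore(*it);                 // redeem each tag for a score
}
state->clearRequestTags();
state->setFinalised();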

@@ -40,7 +70,7 @@ public:
LanguageModelLDHT& copyFrom);
std::string GetScoreProducerDescription(unsigned) const {
std::ostringstream oss;
oss << "LM_" << LDHT::NewNgram::k_max_order << "gram";
oss << "DLM_" << LDHT::NewNgram::k_max_order << "gram";
return oss.str();
}
LDHT::Client* getClientUnsafe() const;
@@ -64,10 +94,18 @@ public:
int featureID,
ScoreComponentCollection* accumulator) const;

virtual void IssueRequestsFor(Hypothesis& hypo,
const FFState* input_state);
float calcScoreFromState(LDHTLMState* hypo) const;
void sync();
void SetFFStateIdx(int state_idx);

protected:
boost::thread_specific_ptr<LDHT::Client> m_client;
std::string m_configPath;
FactorType m_factorType;
int m_state_idx;
uint64_t m_start_tick;

};

@@ -99,7 +137,7 @@ LanguageModelLDHT::LanguageModelLDHT(const std::string& path,

LanguageModelLDHT::~LanguageModelLDHT() {
// TODO(wilson): should cleanup for each individual thread.
delete getClientSafe();
//delete getClientSafe();
}

LanguageModel* LanguageModelLDHT::Duplicate(
@@ -131,8 +169,8 @@ LDHT::Client* LanguageModelLDHT::initTSSClient() {
LDHT::FactoryCollection::createDefaultFactoryCollection();

LDHT::Client* client;
client = new LDHT::ClientLocal();
//client = new LDHT::Client();
//client = new LDHT::ClientLocal();
client = new LDHT::Client();
client->fromXmlFiles(*factory_collection,
ldht_config_path,
ldhtlm_config_path);
@@ -141,9 +179,26 @@ LDHT::Client* LanguageModelLDHT::initTSSClient() {

void LanguageModelLDHT::InitializeBeforeSentenceProcessing() {
getClientSafe()->clearCache();
m_start_tick = LDHT::Util::rdtsc();
}

void LanguageModelLDHT::CleanUpAfterSentenceProcessing() {
LDHT::Client* client = getClientSafe();

std::cerr << "LDHT sentence stats:" << std::endl;
std::cerr << " ngrams submitted: " << client->getNumNgramsSubmitted() << std::endl
<< " ngrams requested: " << client->getNumNgramsRequested() << std::endl
<< " ngrams not found: " << client->getKeyNotFoundCount() << std::endl
<< " cache hits: " << client->getCacheHitCount() << std::endl
<< " inferences: " << client->getInferenceCount() << std::endl
<< " pcnt latency: " << (float)client->getLatencyTicks() / (float)(LDHT::Util::rdtsc() - m_start_tick) * 100.0 << std::endl;
m_start_tick = 0;
client->resetLatencyTicks();
client->resetNumNgramsSubmitted();
client->resetNumNgramsRequested();
client->resetInferenceCount();
client->resetCacheHitCount();
client->resetKeyNotFoundCount();
}

const FFState* LanguageModelLDHT::EmptyHypothesisState(
@@ -159,6 +214,9 @@ void LanguageModelLDHT::CalcScore(const Phrase& phrase,
float& fullScore,
float& ngramScore,
std::size_t& oovCount) const {
// Issue requests for phrase internal ngrams.
// Sync if necessary. (or autosync).

// TODO(wilson): handle nonterminal words.
LDHT::Client* client = getClientUnsafe();
// Score the first order - 1 words of the phrase.
@@ -203,10 +261,8 @@ void LanguageModelLDHT::CalcScore(const Phrase& phrase,
oovCount = 0;
}

FFState* LanguageModelLDHT::Evaluate(
const Hypothesis& hypo,
const FFState* input_state,
ScoreComponentCollection* score_output) const {
void LanguageModelLDHT::IssueRequestsFor(Hypothesis& hypo,
const FFState* input_state) {
// TODO(wilson): handle nonterminal words.
LDHT::Client* client = getClientUnsafe();

@@ -236,11 +292,10 @@ FFState* LanguageModelLDHT::Evaluate(
int overlap_end = std::min(phrase_end, phrase_start + order - 1);
int word_idx = overlap_start;
LDHT::NewNgram& ngram = new_state->gram_fingerprints;
std::deque<int> request_tags;
for (; word_idx < overlap_end; ++word_idx) {
ngram.appendGram(
hypo.GetFactor(word_idx, m_factorType)->GetString().c_str());
request_tags.push_back(client->requestNgram(ngram));
new_state->appendRequestTag(client->requestNgram(ngram));
}
// No need to score phrase internal ngrams, but keep track of them
// in the state (which in this case is the NewNgram containing the
@@ -253,22 +308,36 @@ FFState* LanguageModelLDHT::Evaluate(
// with the end of sentence marker on it.
if (hypo.IsSourceCompleted()) {
ngram.appendGram(EOS_);
request_tags.push_back(client->requestNgram(ngram));
//request_tags.push_back(client->requestNgram(ngram));
new_state->appendRequestTag(client->requestNgram(ngram));
}
// Await responses from the server.
client->awaitResponses();
hypo.SetFFState(m_state_idx, new_state);
}

// Calculate scores given the request tags.
float score = 0;
while (!request_tags.empty()) {
score += client->getNgramScore(request_tags.front());
request_tags.pop_front();
}
void LanguageModelLDHT::sync() {
getClientUnsafe()->awaitResponses();
}

void LanguageModelLDHT::SetFFStateIdx(int state_idx) {
m_state_idx = state_idx;
}

FFState* LanguageModelLDHT::Evaluate(
const Hypothesis& hypo,
const FFState* input_state_ignored,
ScoreComponentCollection* score_output) const {
// Input state is the state from the previous hypothesis, which
// we are not interested in. The requests for this hypo should
// already have been issued via IssueRequestsFor() and the LM then
// synced and all responses processed, and the tags placed in our
// FFState of hypo.
LDHTLMState* state = const_cast<LDHTLMState*>(static_cast<const LDHTLMState*>(hypo.GetFFState(m_state_idx)));

float score = calcScoreFromState(state);
score = FloorScore(TransformLMScore(score));
score_output->PlusEquals(this, score);

return new_state;
return state;
}

FFState* LanguageModelLDHT::EvaluateChart(
@@ -278,5 +347,19 @@ FFState* LanguageModelLDHT::EvaluateChart(
return NULL;
}

float LanguageModelLDHT::calcScoreFromState(LDHTLMState* state) const {
float score = 0.0;
std::vector<int>::iterator tag_iter;
LDHT::Client* client = getClientUnsafe();
for (tag_iter = state->requestTagsBegin();
tag_iter != state->requestTagsEnd();
++tag_iter) {
score += client->getNgramScore(*tag_iter);
}
state->clearRequestTags();
state->setFinalised();
return score;
}

} // namespace Moses.

@@ -42,6 +42,22 @@ public:
~LMList();

void CalcScore(const Phrase &phrase, float &retFullScore, float &retNGramScore, float &retOOVScore, ScoreComponentCollection* breakdown) const;
void InitializeBeforeSentenceProcessing() {
std::list<LanguageModel*>::iterator lm_iter;
for (lm_iter = m_coll.begin();
lm_iter != m_coll.end();
++lm_iter) {
(*lm_iter)->InitializeBeforeSentenceProcessing();
}
}
void CleanUpAfterSentenceProcessing() {
std::list<LanguageModel*>::iterator lm_iter;
for (lm_iter = m_coll.begin();
lm_iter != m_coll.end();
++lm_iter) {
(*lm_iter)->CleanUpAfterSentenceProcessing();
}
}

void Add(LanguageModel *lm);

@@ -127,7 +127,7 @@ Parameter::Parameter()
AddParam("cube-pruning-diversity", "cbd", "How many hypotheses should be created for each coverage. (default = 0)");
AddParam("cube-pruning-lazy-scoring", "cbls", "Don't fully score a hypothesis until it is popped");
AddParam("parsing-algorithm", "Which parsing algorithm to use. 0=CYK+, 1=scope-3. (default = 0)");
AddParam("search-algorithm", "Which search algorithm to use. 0=normal stack, 1=cube pruning, 2=cube growing. (default = 0)");
AddParam("search-algorithm", "Which search algorithm to use. 0=normal stack, 1=cube pruning, 2=cube growing, 4=stack with batched lm requests (default = 0)");
AddParam("constraint", "Location of the file with target sentences to produce constraining the search");
AddParam("use-alignment-info", "Use word-to-word alignment: actually it is only used to output the word-to-word alignment. Word-to-word alignments are taken from the phrase table if any. Default is false.");
AddParam("print-alignment-info", "Output word-to-word alignment into the log file. Word-to-word alignments are taken from the phrase table if any. Default is false");
@@ -65,6 +65,8 @@ bool PhraseDictionaryMemory::Load(const std::vector<FactorType> &input
, const LMList &languageModels
, float weightWP)
{
const_cast<LMList&>(languageModels).InitializeBeforeSentenceProcessing();

const StaticData &staticData = StaticData::Instance();

m_tableLimit = tableLimit;
@@ -161,6 +163,8 @@ bool PhraseDictionaryMemory::Load(const std::vector<FactorType> &input
// sort each target phrase collection
m_collection.Sort(m_tableLimit);

const_cast<LMList&>(languageModels).CleanUpAfterSentenceProcessing();

return true;
}

@@ -2,6 +2,7 @@
#include "Manager.h"
#include "SearchCubePruning.h"
#include "SearchNormal.h"
#include "SearchNormalBatch.h"
#include "UserMessage.h"

namespace Moses
@@ -18,6 +19,8 @@ Search *Search::CreateSearch(Manager& manager, const InputType &source,
return new SearchCubePruning(manager, source, transOptColl);
case CubeGrowing:
return NULL;
case NormalBatch:
return new SearchNormalBatch(manager, source, transOptColl);
default:
UserMessage::Add("ERROR: search. Aborting\n");
abort();
@@ -32,7 +32,7 @@ protected:
// functions for creating hypotheses
void ProcessOneHypothesis(const Hypothesis &hypothesis);
void ExpandAllHypotheses(const Hypothesis &hypothesis, size_t startPos, size_t endPos);
void ExpandHypothesis(const Hypothesis &hypothesis,const TranslationOption &transOpt, float expectedScore);
virtual void ExpandHypothesis(const Hypothesis &hypothesis,const TranslationOption &transOpt, float expectedScore);

public:
SearchNormal(Manager& manager, const InputType &source, const TranslationOptionCollection &transOptColl);
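
Making ExpandHypothesis virtual is the hook that lets SearchNormalBatch (added below) override just the expansion step while reusing SearchNormal's hypothesis-processing loop, which calls it via ProcessOneHypothesis and ExpandAllHypotheses.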

moses/src/SearchNormalBatch.cpp (new file, 218 lines)
@@ -0,0 +1,218 @@
#include "SearchNormalBatch.h"
|
||||
#include "LM/Base.h"
|
||||
#include "Manager.h"
|
||||
|
||||
#include <google/profiler.h>
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace Moses
|
||||
{
|
||||
SearchNormalBatch::SearchNormalBatch(Manager& manager, const InputType &source, const TranslationOptionCollection &transOptColl)
|
||||
:SearchNormal(manager, source, transOptColl)
|
||||
,m_batch_size(10000)
|
||||
{
|
||||
|
||||
// Split the feature functions into sets of stateless, stateful
|
||||
// distributed lm, and stateful non-distributed.
|
||||
const vector<const StatefulFeatureFunction*>& ffs =
|
||||
m_manager.GetTranslationSystem()->GetStatefulFeatureFunctions();
|
||||
for (unsigned i = 0; i < ffs.size(); ++i) {
|
||||
if (ffs[i]->GetScoreProducerDescription() == "DLM_5gram") {
|
||||
m_dlm_ffs[i] = const_cast<LanguageModel*>(static_cast<const LanguageModel* const>(ffs[i]));
|
||||
m_dlm_ffs[i]->SetFFStateIdx(i);
|
||||
}
|
||||
else {
|
||||
m_stateful_ffs[i] = const_cast<StatefulFeatureFunction*>(ffs[i]);
|
||||
}
|
||||
}
|
||||
m_stateless_ffs = const_cast< vector<const StatelessFeatureFunction*>& >(m_manager.GetTranslationSystem()->GetStatelessFeatureFunctions());
|
||||
|
||||
}
|
||||
|
||||
SearchNormalBatch::~SearchNormalBatch() {
|
||||
}
|
||||
|
||||
/**
|
||||
* Main decoder loop that translates a sentence by expanding
|
||||
* hypotheses stack by stack, until the end of the sentence.
|
||||
*/
|
||||
void SearchNormalBatch::ProcessSentence()
|
||||
{
|
||||
const StaticData &staticData = StaticData::Instance();
|
||||
SentenceStats &stats = m_manager.GetSentenceStats();
|
||||
clock_t t=0; // used to track time for steps
|
||||
|
||||
// initial seed hypothesis: nothing translated, no words produced
|
||||
Hypothesis *hypo = Hypothesis::Create(m_manager,m_source, m_initialTargetPhrase);
|
||||
m_hypoStackColl[0]->AddPrune(hypo);
|
||||
|
||||
// go through each stack
|
||||
std::vector < HypothesisStack* >::iterator iterStack;
|
||||
for (iterStack = m_hypoStackColl.begin() ; iterStack != m_hypoStackColl.end() ; ++iterStack) {
|
||||
// check if decoding ran out of time
|
||||
double _elapsed_time = GetUserTime();
|
||||
if (_elapsed_time > staticData.GetTimeoutThreshold()) {
|
||||
VERBOSE(1,"Decoding is out of time (" << _elapsed_time << "," << staticData.GetTimeoutThreshold() << ")" << std::endl);
|
||||
interrupted_flag = 1;
|
||||
return;
|
||||
}
|
||||
HypothesisStackNormal &sourceHypoColl = *static_cast<HypothesisStackNormal*>(*iterStack);
|
||||
|
||||
// the stack is pruned before processing (lazy pruning):
|
||||
VERBOSE(3,"processing hypothesis from next stack");
|
||||
IFVERBOSE(2) {
|
||||
t = clock();
|
||||
}
|
||||
sourceHypoColl.PruneToSize(staticData.GetMaxHypoStackSize());
|
||||
VERBOSE(3,std::endl);
|
||||
sourceHypoColl.CleanupArcList();
|
||||
IFVERBOSE(2) {
|
||||
stats.AddTimeStack( clock()-t );
|
||||
}
|
||||
|
||||
// go through each hypothesis on the stack and try to expand it
|
||||
HypothesisStackNormal::const_iterator iterHypo;
|
||||
for (iterHypo = sourceHypoColl.begin() ; iterHypo != sourceHypoColl.end() ; ++iterHypo) {
|
||||
Hypothesis &hypothesis = **iterHypo;
|
||||
ProcessOneHypothesis(hypothesis); // expand the hypothesis
|
||||
}
|
||||
|
||||
// some logging
|
||||
IFVERBOSE(2) {
|
||||
OutputHypoStackSize();
|
||||
}
|
||||
|
||||
// this stack is fully expanded;
|
||||
actual_hypoStack = &sourceHypoColl;
|
||||
}
|
||||
|
||||
// some more logging
|
||||
IFVERBOSE(2) {
|
||||
m_manager.GetSentenceStats().SetTimeTotal( clock()-m_start );
|
||||
}
|
||||
VERBOSE(2, m_manager.GetSentenceStats());
|
||||
}
|
||||
|
||||
/**
|
||||
* Expand one hypothesis with a translation option.
|
||||
* this involves initial creation, scoring and adding it to the proper stack
|
||||
* \param hypothesis hypothesis to be expanded upon
|
||||
* \param transOpt translation option (phrase translation)
|
||||
* that is applied to create the new hypothesis
|
||||
* \param expectedScore base score for early discarding
|
||||
* (base hypothesis score plus future score estimation)
|
||||
*/
|
||||
void SearchNormalBatch::ExpandHypothesis(const Hypothesis &hypothesis, const TranslationOption &transOpt, float expectedScore)
|
||||
{
|
||||
// Check if the number of partial hypotheses exceeds the batch size.
|
||||
if (m_partial_hypos.size() >= m_batch_size) {
|
||||
EvalAndMergePartialHypos();
|
||||
}
|
||||
|
||||
const StaticData &staticData = StaticData::Instance();
|
||||
SentenceStats &stats = m_manager.GetSentenceStats();
|
||||
clock_t t=0; // used to track time for steps
|
||||
|
||||
Hypothesis *newHypo;
|
||||
if (! staticData.UseEarlyDiscarding()) {
|
||||
// simple build, no questions asked
|
||||
IFVERBOSE(2) {
|
||||
t = clock();
|
||||
}
|
||||
newHypo = hypothesis.CreateNext(transOpt, m_constraint);
|
||||
IFVERBOSE(2) {
|
||||
stats.AddTimeBuildHyp( clock()-t );
|
||||
}
|
||||
if (newHypo==NULL) return;
|
||||
//newHypo->CalcScore(m_transOptColl.GetFutureScore());
|
||||
|
||||
// Issue DLM requests for new hypothesis and put into the list of
|
||||
// partial hypotheses.
|
||||
std::map<int, LanguageModel*>::iterator dlm_iter;
|
||||
for (dlm_iter = m_dlm_ffs.begin();
|
||||
dlm_iter != m_dlm_ffs.end();
|
||||
++dlm_iter) {
|
||||
const FFState* input_state = newHypo->GetPrevHypo() ? newHypo->GetPrevHypo()->GetFFState((*dlm_iter).first) : NULL;
|
||||
(*dlm_iter).second->IssueRequestsFor(*newHypo, input_state);
|
||||
}
|
||||
m_partial_hypos.push_back(newHypo);
|
||||
}
|
||||
else {
|
||||
std::cerr << "can't use early discarding with batch decoding!" << std::endl;
|
||||
abort();
|
||||
}
|
||||
}
|
||||
|
||||
void SearchNormalBatch::EvalAndMergePartialHypos() {
|
||||
std::vector<Hypothesis*>::iterator partial_hypo_iter;
|
||||
for (partial_hypo_iter = m_partial_hypos.begin();
|
||||
partial_hypo_iter != m_partial_hypos.end();
|
||||
++partial_hypo_iter) {
|
||||
Hypothesis* hypo = *partial_hypo_iter;
|
||||
|
||||
// Incorporate the translation option scores.
|
||||
hypo->IncorporateTransOptScores();
|
||||
|
||||
// Evaluate with other ffs.
|
||||
std::map<int, StatefulFeatureFunction*>::iterator sfff_iter;
|
||||
for (sfff_iter = m_stateful_ffs.begin();
|
||||
sfff_iter != m_stateful_ffs.end();
|
||||
++sfff_iter) {
|
||||
hypo->EvaluateWith((*sfff_iter).second, (*sfff_iter).first);
|
||||
}
|
||||
std::vector<const StatelessFeatureFunction*>::iterator slff_iter;
|
||||
for (slff_iter = m_stateless_ffs.begin();
|
||||
slff_iter != m_stateless_ffs.end();
|
||||
++slff_iter) {
|
||||
hypo->EvaluateWith(*slff_iter);
|
||||
}
|
||||
|
||||
// Calculate future score.
|
||||
hypo->CalculateFutureScore(m_transOptColl.GetFutureScore());
|
||||
}
|
||||
|
||||
// Wait for all requests from the distributed LM to come back.
|
||||
std::map<int, LanguageModel*>::iterator dlm_iter;
|
||||
for (dlm_iter = m_dlm_ffs.begin();
|
||||
dlm_iter != m_dlm_ffs.end();
|
||||
++dlm_iter) {
|
||||
(*dlm_iter).second->sync();
|
||||
}
|
||||
|
||||
// Incorporate the DLM scores into all hypotheses and put into their
|
||||
// stacks.
|
||||
for (partial_hypo_iter = m_partial_hypos.begin();
|
||||
partial_hypo_iter != m_partial_hypos.end();
|
||||
++partial_hypo_iter) {
|
||||
Hypothesis* hypo = *partial_hypo_iter;
|
||||
|
||||
// Calculate DLM scores.
|
||||
std::map<int, LanguageModel*>::iterator dlm_iter;
|
||||
for (dlm_iter = m_dlm_ffs.begin();
|
||||
dlm_iter != m_dlm_ffs.end();
|
||||
++dlm_iter) {
|
||||
hypo->EvaluateWith((*dlm_iter).second, (*dlm_iter).first);
|
||||
}
|
||||
|
||||
// Calculate the final score.
|
||||
hypo->CalculateFinalScore();
|
||||
|
||||
// Put completed hypothesis onto its stack.
|
||||
size_t wordsTranslated = hypo->GetWordsBitmap().GetNumWordsCovered();
|
||||
m_hypoStackColl[wordsTranslated]->AddPrune(hypo);
|
||||
}
|
||||
m_partial_hypos.clear();
|
||||
|
||||
|
||||
std::vector < HypothesisStack* >::iterator stack_iter;
|
||||
HypothesisStackNormal* stack;
|
||||
for (stack_iter = m_hypoStackColl.begin();
|
||||
stack_iter != m_hypoStackColl.end();
|
||||
++stack_iter) {
|
||||
stack = static_cast<HypothesisStackNormal*>(*stack_iter);
|
||||
stack->PruneToSize(m_max_stack_size);
|
||||
}
|
||||
}
|
||||
|
||||
}
|

moses/src/SearchNormalBatch.h (new file, 40 lines)
@@ -0,0 +1,40 @@
#ifndef moses_SearchNormalBatch_h
#define moses_SearchNormalBatch_h

#include "SearchNormal.h"

namespace Moses
{

class Manager;
class InputType;
class TranslationOptionCollection;

class SearchNormalBatch: public SearchNormal
{
protected:

// Added for asynclm decoding.
std::vector<const StatelessFeatureFunction*> m_stateless_ffs;
std::map<int, LanguageModel*> m_dlm_ffs;
std::map<int, StatefulFeatureFunction*> m_stateful_ffs;
std::vector<Hypothesis*> m_partial_hypos;
int m_batch_size;
int m_max_stack_size;
void EvalAndMerge();

// functions for creating hypotheses
void ExpandHypothesis(const Hypothesis &hypothesis,const TranslationOption &transOpt, float expectedScore);
void EvalAndMergePartialHypos();

public:
SearchNormalBatch(Manager& manager, const InputType &source, const TranslationOptionCollection &transOptColl);
~SearchNormalBatch();

void ProcessSentence();

};

}

#endif
@@ -170,6 +170,7 @@ enum SearchAlgorithm {
,CubePruning = 1
,CubeGrowing = 2
,ChartDecoding= 3
,NormalBatch = 4
};

enum SourceLabelOverlap {