added LM updates to mosesserver (only for LanguageModelORLM)

git-svn-id: https://mosesdecoder.svn.sourceforge.net/svnroot/mosesdecoder/trunk@4005 1f5c12ca-751b-0410-a591-d2e778427230
This commit is contained in:
leven101 2011-06-09 17:27:48 +00:00
parent 4bf85266d8
commit 894b49a5b2
9 changed files with 95 additions and 16 deletions

View File

@ -341,6 +341,7 @@ TargetPhrase* BilingualDynSuffixArray::GetMosesFactorIDs(const SAPhrase& phrase)
void BilingualDynSuffixArray::GetTargetPhrasesByLexicalWeight(const Phrase& src, std::vector< std::pair<Scores, TargetPhrase*> > & target) const
{
//cerr << "phrase is \"" << src << endl;
size_t sourceSize = src.GetSize();
SAPhrase localIDs(sourceSize);
if(!GetLocalVocabIDs(src, localIDs)) return;
@ -348,11 +349,9 @@ void BilingualDynSuffixArray::GetTargetPhrasesByLexicalWeight(const Phrase& src,
std::map<SAPhrase, int> phraseCounts;
std::map<SAPhrase, pair<float, float> > lexicalWeights;
std::map<SAPhrase, pair<float, float> >::iterator itrLexW;
std::vector<unsigned> wrdIndices(0);
std::vector<unsigned> wrdIndices;
// extract sentence IDs from SA and return rightmost index of phrases
if(!m_srcSA->GetCorpusIndex(&(localIDs.words), &wrdIndices)) return;
//if(wrdIndices.size() > m_maxSampleSize) // select only first samples
//wrdIndices = SampleSelection(wrdIndices);
std::vector<int> sntIndexes = GetSntIndexes(wrdIndices, sourceSize);
// for each sentence with this phrase
for(size_t snt = 0; snt < sntIndexes.size(); ++snt) {

View File

@ -30,7 +30,6 @@ public:
cerr << "Initialzing auxillary bit filters...\n";
bPrefix_ = new BitFilter(this->cells_);
bHit_ = new BitFilter(this->cells_);
//bConsensus_ = new BitFilter(this->cells_);
}
OnlineRLM(FileHandler* fin, count_t order):
PerfectHash<T>(fin), bAdapting_(true), order_(order), corpusSize_(0) {
@ -39,7 +38,6 @@ public:
alpha_ = new float[order_ + 1];
for(count_t i = 0; i <= order_; ++i)
alpha_[i] = i * log10(0.4);
//bConsensus_ = new BitFilter(this->cells_);
}
~OnlineRLM() {
if(alpha_) delete[] alpha_;
@ -48,7 +46,6 @@ public:
if(cache_) delete cache_;
delete bPrefix_;
delete bHit_;
//delete bConsensus_;
}
float getProb(const wordID_t* ngram, int len, const void** state);
//float getProb2(const wordID_t* ngram, int len, const void** state);
@ -88,7 +85,6 @@ private:
Cache<float>* cache_;
BitFilter* bPrefix_;
BitFilter* bHit_;
//BitFilter* bConsensus_;
};
template<typename T>
bool OnlineRLM<T>::insert(const std::vector<string>& ngram, const int value) {
@ -113,6 +109,7 @@ bool OnlineRLM<T>::update(const std::vector<string>& ngram, const int value) {
wordID_t wrdIDs[len];
uint64_t index(this->cells_ + 1);
hpdEntry_t hpdItr;
vocab_->MakeOpen();
for(int i = 0; i < len; ++i)
wrdIDs[i] = vocab_->GetWordID(ngram[i]);
// if updating, minimize false positives by pre-checking if context already in model
@ -120,10 +117,9 @@ bool OnlineRLM<T>::update(const std::vector<string>& ngram, const int value) {
if(value > 1 && len < (int)order_)
bIncluded = markPrefix(wrdIDs, ngram.size(), true); // mark context
if(bIncluded) { // if context found
bIncluded = PerfectHash<T>::update(wrdIDs, len, value, hpdItr, index);
bIncluded = PerfectHash<T>::update2(wrdIDs, len, value, hpdItr, index);
if(index < this->cells_) {
markQueried(index);
//bConsensus_->setBit(index); // update implies 2nd stream
}
else if(hpdItr != this->dict_.end()) markQueried(hpdItr);
}
@ -207,7 +203,6 @@ count_t OnlineRLM<T>::heurDelete(count_t num2del, count_t order) {
last = first + this->bucketRange_;
for(uint64_t row = first; row < last; ++row) { // check each row
if(!(bHit_->testBit(row) || bPrefix_->testBit(row) )) {
//|| bConsensus_->testBit(row))) { // if not hit or prefix or consensus
if(this->filter_->read(row) != 0) {
PerfectHash<T>::remove(row); // remove from filter
++deleted;
@ -377,7 +372,6 @@ void OnlineRLM<T>::clearMarkings() {
value = &itr->second;
*value -= ((*value & this->hitMask_) != 0) ? this->hitMask_ : 0;
}
//bConsensus_->reset();
}
template<typename T>
void OnlineRLM<T>::save(FileHandler* fout) {
@ -412,8 +406,8 @@ void OnlineRLM<T>::removeNonMarked() {
cerr << "deleting all unused events\n";
int deleted(0);
for(uint64_t i = 0; i < this->cells_; ++i) {
if(!(bHit_->testBit(i) || bPrefix_->testBit(i) /*||
bConsensus_->testBit(i)*/) && (this->filter_->read(i) != 0)) {
if(!(bHit_->testBit(i) || bPrefix_->testBit(i))
&& (this->filter_->read(i) != 0)) {
PerfectHash<T>::remove(i);
++deleted;
}

View File

@ -400,7 +400,7 @@ bool PerfectHash<T>::update2(const wordID_t* IDs, const int len,
}
++index;
}
// could add if it gets here.
// add if it gets here.
insert(IDs, len, value);
return false;
}

View File

@ -206,6 +206,8 @@ bool DynSuffixArray::GetCorpusIndex(const vuint_t* phrase, vuint_t* indices)
// bounds holds first and (last + 1) index of phrase[0] in m_SA
size_t lwrBnd = size_t(bounds.first - m_F->begin());
size_t uprBnd = size_t(bounds.second - m_F->begin());
//cerr << "phrasesize = " << phrasesize << "\tuprBnd = " << uprBnd << "\tlwrBnd = " << lwrBnd;
//cerr << "\tcorpus size = " << m_corpus->size() << endl;
if(uprBnd - lwrBnd == 0) return false; // not found
if(phrasesize == 1) {
for(size_t i=lwrBnd; i < uprBnd; ++i) {

View File

@ -125,6 +125,15 @@ public:
const FFState* prev_state,
ScoreComponentCollection* accumulator) const;
// Accessor for the wrapped LM implementation object (the m_implementation
// member). The return type depends on the build configuration: under
// WITH_THREADS the implementation is held in a boost::shared_ptr, otherwise
// it is a plain pointer. Callers must guard their use with the same #ifdef.
#ifdef WITH_THREADS
// if multi-threaded return boost ptr
boost::shared_ptr<LanguageModelImplementation>
#else // return normal LM ptr
LanguageModelImplementation*
#endif
GetLMImplementation() const {
return m_implementation;
}
};
}

View File

@ -71,12 +71,15 @@ LMResult LanguageModelORLM::GetValue(const std::vector<const Word*> &contextFact
State* finalState) const {
FactorType factorType = GetFactorType();
// set up context
//std::vector<long unsigned int> factor(1,0);
//std::vector<string> sngram;
wordID_t ngram[MAX_NGRAM_SIZE];
int count = contextFactor.size();
for (int i = 0; i < count; i++) {
ngram[i] = GetLmID((*contextFactor[i])[factorType]);
//sngram.push_back(contextFactor[i]->GetString(factor, false));
}
//float logprob = FloorScore(TransformLMScore(m_lm->getProb(&ngram[0], count, finalState)));
//float logprob = FloorScore(TransformLMScore(lm_->getProb(sngram, count, finalState)));
LMResult ret;
ret.score = FloorScore(TransformLMScore(m_lm->getProb(&ngram[0], count, finalState)));
ret.unknown = count && (ngram[count - 1] == m_oov_id);
@ -87,4 +90,14 @@ LMResult LanguageModelORLM::GetValue(const std::vector<const Word*> &contextFact
*/
return ret;
}
bool LanguageModelORLM::UpdateORLM(const std::vector<string>& ngram, const int value) {
  // Apply a single ngram count update to the online LM. The vocabulary is
  // opened first so words not yet in the model can receive fresh IDs, and
  // closed again once the update is done. Returns whatever the underlying
  // OnlineRLM::update reports.
  m_lm->vocab_->MakeOpen();
  const bool updated = m_lm->update(ngram, value);
  m_lm->vocab_->MakeClosed();
  return updated;
}
}

View File

@ -35,6 +35,7 @@ public:
//m_lm->initThreadSpecificData(); // Creates thread specific data iff
// compiled with multithreading.
}
bool UpdateORLM(const std::vector<string>& ngram, const int value);
protected:
OnlineRLM<T>* m_lm;
//MultiOnlineRLM<T>* m_lm;

View File

@ -420,6 +420,9 @@ public:
SearchAlgorithm GetSearchAlgorithm() const {
return m_searchAlgorithm;
}
// Read-only access to the configured language models. Returned by const
// reference to avoid copying the whole list on every call; callers that
// previously received a copy can still take one by value.
const LMList& GetLMList() const {
  return m_languageModel;
}
size_t GetNumInputScores() const {
return m_numInputScores;
}

View File

@ -11,6 +11,10 @@
#include "StaticData.h"
#include "PhraseDictionaryDynSuffixArray.h"
#include "TranslationSystem.h"
#include "LMList.h"
#ifdef LM_ORLM
# include "LanguageModelORLM.h"
#endif
using namespace Moses;
using namespace std;
@ -49,6 +53,9 @@ public:
PhraseDictionaryDynSuffixArray* pdsa = (PhraseDictionaryDynSuffixArray*) pdf->GetDictionary();
cerr << "Inserting into address " << pdsa << endl;
pdsa->insertSnt(source_, target_, alignment_);
if(add2ORLM_) {
updateORLM();
}
cerr << "Done inserting\n";
//PhraseDictionary* pdsa = (PhraseDictionary*) pdf->GetDictionary(*dummy);
map<string, xmlrpc_c::value> retData;
@ -58,7 +65,56 @@ public:
*retvalP = xmlrpc_c::value_string("Phrase table updated");
}
string source_, target_, alignment_;
bool bounded_;
bool bounded_, add2ORLM_;
// Push the newly inserted target sentence's ngrams into the online LM
// (ORLM) so the dynamically added training data is also reflected in LM
// scores. No-op unless compiled with LM_ORLM. Currently assumes the first
// (and only) LM in the system is an ORLM; anything else is detected via a
// checked cast and reported with a warning.
void updateORLM() {
#ifdef LM_ORLM
  vector<string> vl;
  map<vector<string>, int> ngSet;
  const LMList& lms = StaticData::Instance().GetLMList(); // get LM
  LMList::const_iterator lmIter = lms.begin();
  const LanguageModel* lm = *lmIter;
  /* currently assumes a single LM that is a ORLM */
#ifdef WITH_THREADS
  boost::shared_ptr<LanguageModelORLM> orlm;
  orlm = boost::dynamic_pointer_cast<LanguageModelORLM>(lm->GetLMImplementation());
#else
  // Use a checked cast: the previous C-style cast could never yield 0, so
  // the guard below was dead and a non-ORLM first LM caused undefined
  // behavior instead of a warning.
  LanguageModelORLM* orlm =
      dynamic_cast<LanguageModelORLM*>(lm->GetLMImplementation());
#endif
  if(!orlm) {
    cerr << "WARNING: Unable to add target sentence to ORLM\n";
    return;
  }
  // break out new ngrams from sentence
  const int ngOrder(orlm->GetNGramOrder());
  const std::string sBOS = orlm->GetSentenceStart()->GetString();
  const std::string sEOS = orlm->GetSentenceEnd()->GetString();
  Utils::splitToStr(target_, vl, " ");
  // insert BOS and EOS
  vl.insert(vl.begin(), sBOS);
  vl.insert(vl.end(), sEOS);
  // Collect every ngram of length <= ngOrder with its occurrence count.
  // Counters stay signed (the inner loop decrements t down to i, which may
  // be 0); the size is cast once to avoid signed/unsigned comparisons.
  const int numWords = static_cast<int>(vl.size());
  for(int j = 0; j < numWords; ++j) {
    int i = (j < ngOrder) ? 0 : j - ngOrder + 1;
    for(int t = j; t >= i; --t) {
      vector<string> ngVec;
      for(int s = t; s <= j; ++s) {
        ngVec.push_back(vl[s]);
      }
      ngSet[ngVec]++;
    }
  }
  // insert into LM in order from 1grams up (for LM well-formedness)
  cerr << "Inserting " << ngSet.size() << " ngrams into ORLM...\n";
  for(int i = 1; i <= ngOrder; ++i) {
    iterate(ngSet, it) {
      if(static_cast<int>(it->first.size()) == i)
        orlm->UpdateORLM(it->first, it->second);
    }
  }
#endif
}
void breakOutParams(const params_t& params) {
params_t::const_iterator si = params.find("source");
if(si == params.end())
@ -77,6 +133,8 @@ public:
cerr << "alignment = " << alignment_ << endl;
si = params.find("bounded");
bounded_ = (si != params.end());
si = params.find("updateORLM");
add2ORLM_ = (si != params.end());
}
};