added CacheFreqWords() to speed up decoding with suffix array PTs

git-svn-id: https://mosesdecoder.svn.sourceforge.net/svnroot/mosesdecoder/trunk@4064 1f5c12ca-751b-0410-a591-d2e778427230
This commit is contained in:
leven101 2011-07-01 14:36:28 +00:00
parent 126739f3f1
commit e0174b413c
2 changed files with 59 additions and 36 deletions

View File

@ -56,13 +56,16 @@ bool BilingualDynSuffixArray::Load(
m_srcSA = new DynSuffixArray(m_srcCorpus);
if(!m_srcSA) return false;
cerr << "Building Target Suffix Array...\n";
m_trgSA = new DynSuffixArray(m_trgCorpus);
if(!m_trgSA) return false;
//m_trgSA = new DynSuffixArray(m_trgCorpus);
//if(!m_trgSA) return false;
cerr << "\t(Skipped. Not used)\n";
InputFileStream alignStrme(alignments);
cerr << "Loading Alignment File...\n";
LoadRawAlignments(alignStrme);
//LoadAlignments(alignStrme);
cerr << "Building frequent word cache...\n";
CacheFreqWords();
return true;
}
@ -169,7 +172,7 @@ bool BilingualDynSuffixArray::ExtractPhrases(const int& sntIndex, const int& wor
void BilingualDynSuffixArray::CleanUp()
{
m_wordPairCache.clear();
//m_wordPairCache.clear();
}
int BilingualDynSuffixArray::LoadCorpus(InputFileStream& corpus, const FactorList& factors,
@ -218,7 +221,7 @@ bool BilingualDynSuffixArray::GetLocalVocabIDs(const Phrase& src, SAPhrase &outp
pair<float, float> BilingualDynSuffixArray::GetLexicalWeight(const PhrasePair& phrasepair) const
{
//return pair<float, float>(0, 0);
//return pair<float, float>(1, 1);
float srcLexWeight(1.0), trgLexWeight(1.0);
std::map<pair<wordID_t, wordID_t>, float> targetProbs; // collect sum of target probs given source words
//const SentenceAlignment& alignment = m_alignments[phrasepair.m_sntIndex];
@ -229,8 +232,9 @@ pair<float, float> BilingualDynSuffixArray::GetLexicalWeight(const PhrasePair& p
float srcSumPairProbs(0);
wordID_t srcWord = m_srcCorpus->at(srcIdx + m_srcSntBreaks[phrasepair.m_sntIndex]); // localIDs
const std::vector<int>& srcWordAlignments = alignment.alignedList.at(srcIdx);
// for each target word aligned to this source word in this alignment
if(srcWordAlignments.size() == 0) { // get p(NULL|src)
pair<wordID_t, wordID_t> wordpair = std::make_pair(srcWord, m_srcVocab->GetkOOVWordID());
pair<wordID_t, wordID_t> wordpair = make_pair(srcWord, m_srcVocab->GetkOOVWordID());
itrCache = m_wordPairCache.find(wordpair);
if(itrCache == m_wordPairCache.end()) { // if not in cache
CacheWordProbs(srcWord);
@ -245,10 +249,10 @@ pair<float, float> BilingualDynSuffixArray::GetLexicalWeight(const PhrasePair& p
int trgIdx = srcWordAlignments[i];
wordID_t trgWord = m_trgCorpus->at(trgIdx + m_trgSntBreaks[phrasepair.m_sntIndex]);
// get probability of this source->target word pair
pair<wordID_t, wordID_t> wordpair = std::make_pair(srcWord, trgWord);
pair<wordID_t, wordID_t> wordpair = make_pair(srcWord, trgWord);
itrCache = m_wordPairCache.find(wordpair);
if(itrCache == m_wordPairCache.end()) { // if not in cache
CacheWordProbs(srcWord);
CacheWordProbs(srcWord);
itrCache = m_wordPairCache.find(wordpair); // search cache again
}
assert(itrCache != m_wordPairCache.end());
@ -275,14 +279,35 @@ pair<float, float> BilingualDynSuffixArray::GetLexicalWeight(const PhrasePair& p
// TODO::Need to get p(NULL|trg)
return pair<float, float>(srcLexWeight, trgLexWeight);
}
void BilingualDynSuffixArray::CacheFreqWords() const {
std::multimap<int, wordID_t> wordCnts;
// for each source word in vocab
Vocab::Word2Id::const_iterator it;
for(it = m_srcVocab->VocabStart(); it != m_srcVocab->VocabEnd(); ++it) {
// get its frequency
wordID_t srcWord = it->second;
std::vector<wordID_t> sword(1, srcWord), wrdIndices;
m_srcSA->GetCorpusIndex(&sword, &wrdIndices);
if(wrdIndices.size() >= 1000) { // min count
wordCnts.insert(make_pair(wrdIndices.size(), srcWord));
}
}
int numSoFar(0);
std::multimap<int, wordID_t>::reverse_iterator ritr;
for(ritr = wordCnts.rbegin(); ritr != wordCnts.rend(); ++ritr) {
m_freqWordsCached.insert(ritr->second);
CacheWordProbs(ritr->second);
if(++numSoFar == 50) break; // get top counts
}
cerr << "\tCached " << m_freqWordsCached.size() << " source words\n";
}
void BilingualDynSuffixArray::CacheWordProbs(wordID_t srcWord) const
{
std::map<wordID_t, int> counts;
std::vector<wordID_t> vword(1, srcWord), wrdIndices;
bool ret = m_srcSA->GetCorpusIndex(&vword, &wrdIndices);
std::vector<wordID_t> sword(1, srcWord), wrdIndices;
bool ret = m_srcSA->GetCorpusIndex(&sword, &wrdIndices);
assert(ret);
std::vector<int> sntIndexes = GetSntIndexes(wrdIndices, 1);
std::vector<int> sntIndexes = GetSntIndexes(wrdIndices, 1, m_srcSntBreaks);
float denom(0);
// for each occurrence of this word
for(size_t snt = 0; snt < sntIndexes.size(); ++snt) {
@ -290,7 +315,6 @@ void BilingualDynSuffixArray::CacheWordProbs(wordID_t srcWord) const
assert(sntIdx != -1);
int srcWrdSntIdx = wrdIndices.at(snt) - m_srcSntBreaks.at(sntIdx); // get word index in sentence
const std::vector<int> srcAlg = GetSentenceAlignment(sntIdx).alignedList.at(srcWrdSntIdx); // list of target words for this source word
//const std::vector<int>& srcAlg = m_alignments.at(sntIdx).alignedList.at(srcWrdSntIdx); // list of target words for this source word
if(srcAlg.size() == 0) {
++counts[m_srcVocab->GetkOOVWordID()]; // if not alligned then align to NULL word
++denom;
@ -307,7 +331,7 @@ void BilingualDynSuffixArray::CacheWordProbs(wordID_t srcWord) const
// get probs and cache all pairs
for(std::map<wordID_t, int>::const_iterator itrCnt = counts.begin();
itrCnt != counts.end(); ++itrCnt) {
pair<wordID_t, wordID_t> wordPair = std::make_pair(srcWord, itrCnt->first);
pair<wordID_t, wordID_t> wordPair = make_pair(srcWord, itrCnt->first);
float srcTrgPrb = float(itrCnt->second) / float(denom); // gives p(src->trg)
float trgSrcPrb = float(itrCnt->second) / float(counts.size()); // gives p(trg->src)
m_wordPairCache[wordPair] = pair<float, float>(srcTrgPrb, trgSrcPrb);
@ -347,12 +371,14 @@ void BilingualDynSuffixArray::GetTargetPhrasesByLexicalWeight(const Phrase& src,
if(!GetLocalVocabIDs(src, localIDs)) return;
float totalTrgPhrases(0);
std::map<SAPhrase, int> phraseCounts;
//std::map<SAPhrase, PhrasePair> phraseColl; // (one of) the word indexes this phrase was taken from
std::map<SAPhrase, pair<float, float> > lexicalWeights;
std::map<SAPhrase, pair<float, float> >::iterator itrLexW;
std::vector<unsigned> wrdIndices;
// extract sentence IDs from SA and return rightmost index of phrases
if(!m_srcSA->GetCorpusIndex(&(localIDs.words), &wrdIndices)) return;
std::vector<int> sntIndexes = GetSntIndexes(wrdIndices, sourceSize);
SampleSelection(wrdIndices);
std::vector<int> sntIndexes = GetSntIndexes(wrdIndices, sourceSize, m_srcSntBreaks);
// for each sentence with this phrase
for(size_t snt = 0; snt < sntIndexes.size(); ++snt) {
std::vector<PhrasePair*> phrasePairs; // to store all phrases possible from current sentence
@ -365,8 +391,8 @@ void BilingualDynSuffixArray::GetTargetPhrasesByLexicalWeight(const Phrase& src,
for (iterPhrasePair = phrasePairs.begin(); iterPhrasePair != phrasePairs.end(); ++iterPhrasePair) {
SAPhrase phrase = TrgPhraseFromSntIdx(**iterPhrasePair);
phraseCounts[phrase]++; // count each unique phrase
// TODO::Inefficient to extract lexical weight here. Should do it later
// once the top phrases have been chosen by phrase prob p(e|f)
// NOTE::Correct but slow to extract lexical weight here. could do
// it later for only the top phrases chosen by phrase prob p(e|f)
pair<float, float> lexWeight = GetLexicalWeight(**iterPhrasePair); // get lexical weighting for this phrase pair
itrLexW = lexicalWeights.find(phrase); // check if phrase already has lexical weight attached
if((itrLexW != lexicalWeights.end()) && (itrLexW->second.first < lexWeight.first))
@ -388,29 +414,28 @@ void BilingualDynSuffixArray::GetTargetPhrasesByLexicalWeight(const Phrase& src,
scoreVector[0] = trg2SrcMLE;
scoreVector[1] = itrLexW->second.first;
scoreVector[2] = 2.718; // exp(1);
phraseScores.insert(pair<Scores, const SAPhrase*>(scoreVector, &iterPhrases->first));
phraseScores.insert(make_pair(scoreVector, &iterPhrases->first));
}
// return top scoring phrases
std::multimap<Scores, const SAPhrase*, ScoresComp>::reverse_iterator ritr;
for(ritr = phraseScores.rbegin(); ritr != phraseScores.rend(); ++ritr) {
Scores scoreVector = ritr->first;
TargetPhrase *targetPhrase = GetMosesFactorIDs(*ritr->second);
target.push_back( make_pair( scoreVector, targetPhrase));
target.push_back(make_pair( scoreVector, targetPhrase));
if(target.size() == m_maxSampleSize) break;
}
return;
}
std::vector<int> BilingualDynSuffixArray::GetSntIndexes(std::vector<unsigned>& wrdIndices,
const int sourceSize) const
const int sourceSize, const std::vector<unsigned>& sntBreaks) const
{
std::vector<unsigned>::const_iterator vit;
std::vector<int> sntIndexes;
for(size_t i=0; i < wrdIndices.size(); ++i) {
vit = std::upper_bound(m_srcSntBreaks.begin(), m_srcSntBreaks.end(), wrdIndices[i]);
int index = int(vit - m_srcSntBreaks.begin()) - 1;
vit = std::upper_bound(sntBreaks.begin(), sntBreaks.end(), wrdIndices[i]);
int index = int(vit - sntBreaks.begin()) - 1;
// check for phrases that cross sentence boundaries
if(wrdIndices[i] - sourceSize + 1 < m_srcSntBreaks.at(index))
if(wrdIndices[i] - sourceSize + 1 < sntBreaks.at(index))
sntIndexes.push_back(-1); // set bad flag
else
sntIndexes.push_back(index); // store the index of the sentence in the corpus
@ -418,17 +443,13 @@ std::vector<int> BilingualDynSuffixArray::GetSntIndexes(std::vector<unsigned>& w
return sntIndexes;
}
std::vector<unsigned> BilingualDynSuffixArray::SampleSelection(std::vector<unsigned> sample) const
int BilingualDynSuffixArray::SampleSelection(std::vector<unsigned>& sample,
int sampleSize) const
{
//sample.erase(wrdIndices.begin()+m_maxSampleSize, wrdIndices.end());
//return sample;
int size = sample.size();
//if(size < m_maxSampleSize) return sample;
std::vector<unsigned> subSample;
int jump = size / m_maxSampleSize;
for(int i=0; i < size; i+=jump)
subSample.push_back(sample.at(i));
return subSample;
// only use top 'sampleSize' number of samples
if(sample.size() > sampleSize)
sample.erase(sample.begin()+sampleSize, sample.end());
return sample.size();
}
void BilingualDynSuffixArray::addSntPair(string& source, string& target, string& alignment) {
@ -458,7 +479,7 @@ void BilingualDynSuffixArray::addSntPair(string& source, string& target, string&
}
m_trgSntBreaks.push_back(oldTrgCrpSize);
m_srcSA->Insert(&srcFactor, oldSrcCrpSize);
m_trgSA->Insert(&trgFactor, oldTrgCrpSize);
//m_trgSA->Insert(&trgFactor, oldTrgCrpSize);
LoadRawAlignments(alignment);
m_trgVocab->MakeClosed();
}

View File

@ -105,6 +105,7 @@ private:
std::vector<std::vector<short> > m_rawAlignments;
mutable std::map<std::pair<wordID_t, wordID_t>, std::pair<float, float> > m_wordPairCache;
mutable std::set<wordID_t> m_freqWordsCached;
const size_t m_maxPhraseLength, m_maxSampleSize;
int LoadCorpus(InputFileStream&, const std::vector<FactorType>& factors,
@ -116,13 +117,14 @@ private:
bool ExtractPhrases(const int&, const int&, const int&, std::vector<PhrasePair*>&, bool=false) const;
SentenceAlignment GetSentenceAlignment(const int, bool=false) const;
std::vector<unsigned> SampleSelection(std::vector<unsigned>) const;
int SampleSelection(std::vector<unsigned>&, int = 300) const;
std::vector<int> GetSntIndexes(std::vector<unsigned>&, const int) const;
std::vector<int> GetSntIndexes(std::vector<unsigned>&, int, const std::vector<unsigned>&) const;
TargetPhrase* GetMosesFactorIDs(const SAPhrase&) const;
SAPhrase TrgPhraseFromSntIdx(const PhrasePair&) const;
bool GetLocalVocabIDs(const Phrase&, SAPhrase &) const;
void CacheWordProbs(wordID_t) const;
void CacheFreqWords() const;
std::pair<float, float> GetLexicalWeight(const PhrasePair&) const;
int GetSourceSentenceSize(size_t sentenceId) const