mosesdecoder/contrib/other-builds/moses2/TranslationModel/ProbingPT.cpp

198 lines
5.3 KiB
C++
Raw Normal View History

2015-11-03 18:04:26 +03:00
/*
* ProbingPT.cpp
*
* Created on: 3 Nov 2015
* Author: hieu
*/
#include "ProbingPT.h"
2015-11-04 16:09:53 +03:00
#include "../System.h"
#include "../Scores.h"
#include "../FF/FeatureFunctions.h"
#include "../Search/Manager.h"
2015-11-03 18:04:26 +03:00
#include "moses/FactorCollection.h"
#include "moses/TranslationModel/ProbingPT/quering.hh"
using namespace std;
/**
 * Construct the phrase table from its configuration line.
 *
 * @param startInd  forwarded unchanged to the PhraseTable base constructor
 * @param line      configuration line, forwarded to the PhraseTable base constructor
 *
 * The query engine is only created later, in Load(). m_engine is set to NULL
 * here so that the destructor's `delete m_engine` is a harmless no-op when the
 * object is destroyed without Load() ever having run (previously the pointer
 * was left indeterminate).
 */
ProbingPT::ProbingPT(size_t startInd, const std::string &line)
  :PhraseTable(startInd, line)
  ,m_engine(NULL)
{
  ReadParameters();
}
2015-11-04 03:37:35 +03:00
/** Release the query engine that Load() allocated with `new`. */
ProbingPT::~ProbingPT()
{
  delete m_engine;
}
/**
 * Open the probing phrase table found at m_path and register its vocabularies.
 *
 * Every source and target word string reported by the query engine is added to
 * the system-wide factor collection, and the resulting (factor, probing id)
 * pair is recorded in m_sourceVocabMap / m_vocabMap respectively.
 *
 * @param system  owner of the factor collection (system.vocab) that receives the words
 */
void ProbingPT::Load(System &system)
{
  m_engine = new QueryEngine(m_path.c_str());

  // Value GetSourceProbingId() hands back for a factor that has no source-side id.
  // NOTE(review): presumably chosen so that it never equals a real probing id - confirm.
  m_unkId = 456456546456;

  FactorCollection &factors = system.vocab;

  // source side: probing ids are 64 bit
  typedef std::map<uint64_t, std::string> SrcStrMap;
  const SrcStrMap &srcWords = m_engine->getSourceVocab();
  for (SrcStrMap::const_iterator src = srcWords.begin(); src != srcWords.end(); ++src) {
    const Factor *srcFactor = factors.AddFactor(src->second);
    m_sourceVocabMap.insert(SourceVocabMap::value_type(srcFactor, src->first));
  }

  // target side: probing ids are unsigned int
  typedef std::map<unsigned int, std::string> TgtStrMap;
  const TgtStrMap &tgtWords = m_engine->getVocab();
  for (TgtStrMap::const_iterator tgt = tgtWords.begin(); tgt != tgtWords.end(); ++tgt) {
    const Factor *tgtFactor = factors.AddFactor(tgt->second);
    m_vocabMap.insert(TargetVocabMap::value_type(tgtFactor, tgt->first));
  }
}
2015-11-13 01:51:13 +03:00
/**
 * Translate a target-side probing id into the factor registered for it in Load().
 *
 * @param probingId  id as stored in the probing table's target phrases
 * @return the matching factor, or NULL when the id has no entry in m_vocabMap
 */
const Factor *ProbingPT::GetTargetFactor(uint64_t probingId) const
{
  TargetVocabMap::right_map::const_iterator found = m_vocabMap.right.find(probingId);
  if (found == m_vocabMap.right.end()) {
    // id was never registered, so there is no factor to hand back
    return NULL;
  }
  return found->second;
}
2015-11-13 01:51:13 +03:00
uint64_t ProbingPT::GetSourceProbingId(const Factor *factor) const
2015-11-03 18:04:26 +03:00
{
SourceVocabMap::left_map::const_iterator iter;
iter = m_sourceVocabMap.left.find(factor);
if (iter != m_sourceVocabMap.left.end()) {
return iter->second;
} else {
// not in mapping. Must be UNK
return m_unkId;
}
}
2015-11-03 19:09:49 +03:00
/**
 * Fetch the translation options for the source span held by an input path.
 *
 * @param mgr        supplies the memory pool (mgr.GetPool()) and the system (mgr.system)
 * @param inputPath  its subPhrase member is the source phrase that gets looked up
 * @return whatever CreateTargetPhrase() produces for that phrase, as a pointer-to-const;
 *         this is an empty pointer when the table has nothing for the phrase
 */
TargetPhrases::shared_const_ptr ProbingPT::Lookup(const Manager &mgr, InputPath &inputPath) const
{
  return CreateTargetPhrase(mgr.GetPool(), mgr.system, inputPath.subPhrase);
}
/**
 * Query the probing table for a source phrase and build its target phrase collection.
 *
 * @param pool          memory pool the individual TargetPhrase objects are allocated from
 * @param system        passed through to target phrase construction and scoring
 * @param sourcePhrase  phrase to translate; must not be empty (asserted)
 * @return an empty pointer when a source word has no probing id or the query finds
 *         nothing; otherwise a collection that has been sorted and pruned to m_tableLimit
 */
TargetPhrases::shared_ptr ProbingPT::CreateTargetPhrase(MemPool &pool, const System &system, const Phrase &sourcePhrase) const
{
  assert(sourcePhrase.GetSize());

  TargetPhrases::shared_ptr collection;

  bool allKnown;
  vector<uint64_t> sourceIds = ConvertToProbingSourcePhrase(sourcePhrase, allKnown);
  if (!allKnown) {
    // at least one source word is unknown to the table, so it cannot hold a translation
    return collection;
  }

  // the actual table lookup; .first says whether anything was found
  std::pair<bool, std::vector<target_text> > lookup;
  lookup = m_engine->query(sourceIds);
  if (!lookup.first) {
    return collection;
  }
  //m_engine->printTargetInfo(lookup.second);

  collection.reset(new TargetPhrases());

  const std::vector<target_text> &candidates = lookup.second;
  const size_t numCandidates = candidates.size();
  for (size_t c = 0; c < numCandidates; ++c) {
    TargetPhrase *candidate = CreateTargetPhrase(pool, system, sourcePhrase, candidates[c]);
    collection->AddTargetPhrase(*candidate);
  }

  collection->SortAndPrune(m_tableLimit);
  return collection;
}
2015-11-03 19:09:49 +03:00
/**
 * Build one TargetPhrase from a single entry returned by the probing table.
 *
 * @param pool                 memory pool; the TargetPhrase is placement-constructed in it
 * @param system               passed to the TargetPhrase, score update and feature functions
 * @param sourcePhrase         source side, handed to the feature functions
 * @param probingTargetPhrase  table entry: target word ids (target_phrase) and scores (prob)
 * @return the newly built target phrase, living in storage obtained from `pool`
 */
TargetPhrase *ProbingPT::CreateTargetPhrase(MemPool &pool, const System &system, const Phrase &sourcePhrase, const target_text &probingTargetPhrase) const
{
  const std::vector<unsigned int> &targetIds = probingTargetPhrase.target_phrase;
  const size_t numWords = targetIds.size();

  TargetPhrase *phrase = new (pool.Allocate<TargetPhrase>()) TargetPhrase(pool, system, numWords);

  // words: factor 0 of each position is the factor registered for the probing id
  for (size_t pos = 0; pos < numWords; ++pos) {
    const uint64_t id = targetIds[pos];
    const Factor *factor = GetTargetFactor(id);
    assert(factor);
    (*phrase)[pos][0] = factor;
  }

  // score for this phrase table: copy the stored values, then run each through TransformScore
  vector<SCORE> tableScores = probingTargetPhrase.prob;
  std::transform(tableScores.begin(), tableScores.end(), tableScores.begin(), Moses::TransformScore);
  phrase->GetScores().PlusEquals(system, *this, tableScores);

  // Word alignment is not filled in here. Disabled draft kept for reference:
  //   const std::vector<unsigned char> &alignments = probingTargetPhrase.word_all1;
  //
  //   AlignmentInfo &aligns = tp->GetAlignTerm();
  //   for (size_t i = 0; i < alignS.size(); i += 2 ) {
  //     aligns.Add((size_t) alignments[i], (size_t) alignments[i+1]);
  //   }

  // score of all other ff when this rule is being loaded
  system.featureFunctions.EvaluateInIsolation(pool, system, sourcePhrase, *phrase);

  return phrase;
}
std::vector<uint64_t> ProbingPT::ConvertToProbingSourcePhrase(const Phrase &sourcePhrase, bool &ok) const
{
size_t size = sourcePhrase.GetSize();
std::vector<uint64_t> ret(size);
for (size_t i = 0; i < size; ++i) {
2015-11-13 01:51:13 +03:00
const Factor *factor = sourcePhrase[i][0];
2015-11-03 18:04:26 +03:00
uint64_t probingId = GetSourceProbingId(factor);
if (probingId == m_unkId) {
ok = false;
return ret;
} else {
ret[i] = probingId;
}
}
ok = true;
return ret;
}