mosesdecoder/contrib/other-builds/moses2/TranslationModel/ProbingPT.cpp

310 lines
7.7 KiB
C++
Raw Normal View History

2015-11-03 18:04:26 +03:00
/*
* ProbingPT.cpp
*
* Created on: 3 Nov 2015
* Author: hieu
*/
2015-12-07 21:30:17 +03:00
#include <boost/foreach.hpp>
2015-11-03 18:04:26 +03:00
#include "ProbingPT.h"
2015-11-04 16:09:53 +03:00
#include "../System.h"
#include "../Scores.h"
2016-01-15 21:18:33 +03:00
#include "../Phrase.h"
#include "../legacy/InputFileStream.h"
#include "../legacy/ProbingPT/probing_hash_utils.hh"
2015-11-04 16:09:53 +03:00
#include "../FF/FeatureFunctions.h"
#include "../Search/Manager.h"
2015-11-13 03:05:54 +03:00
#include "../legacy/FactorCollection.h"
#include "../legacy/ProbingPT/quering.hh"
2015-11-13 13:40:55 +03:00
#include "../legacy/Util2.h"
2015-11-03 18:04:26 +03:00
using namespace std;
2015-12-10 23:49:30 +03:00
namespace Moses2
{
2015-11-03 18:04:26 +03:00
/**
 * Construct the phrase table from its configuration line.
 *
 * The query engine itself is only created by Load(). The pointer is cleared
 * here so that the destructor's delete is well-defined even if the object is
 * destroyed before Load() has run.
 *
 * @param startInd  forwarded to the PhraseTable base class
 * @param line      forwarded to the PhraseTable base class
 */
ProbingPT::ProbingPT(size_t startInd, const std::string &line)
  :PhraseTable(startInd, line)
{
  m_engine = NULL;
  ReadParameters();
}

/** Release the query engine allocated by Load(), if there is one. */
ProbingPT::~ProbingPT()
{
  // deleting a null pointer is a no-op, so this is safe without Load()
  delete m_engine;
}
/**
 * Load the probing phrase table stored under m_path.
 *
 * Steps, in order:
 *  - create the query engine over m_path,
 *  - build m_sourceVocab: factor id -> the engine's id for that source word,
 *  - build m_targetVocab from TargetVocab.dat: engine target id -> factor,
 *  - open TargetColl.dat and remember the start of its contents in 'data',
 *  - pre-populate the lookup cache via CreateCache().
 *
 * @param system  supplies the vocabulary the words are registered in; also
 *                passed on to CreateCache().
 * @throws const char*  when a TargetVocab.dat line has fewer than two
 *                tab-separated columns, or TargetColl.dat is not open after
 *                the open call.
 */
void ProbingPT::Load(System &system)
{
  m_engine = new QueryEngine(m_path.c_str());

  // filler value for m_sourceVocab slots that have no engine id
  m_unkId = 456456546456;

  FactorCollection &vocab = system.GetVocab();

  // source vocab
  const std::map<uint64_t, std::string> &sourceVocab = m_engine->getSourceVocab();
  std::map<uint64_t, std::string>::const_iterator iterSource;
  for (iterSource = sourceVocab.begin(); iterSource != sourceVocab.end(); ++iterSource) {
    const string &wordStr = iterSource->second;
    const Factor *factor = vocab.AddFactor(wordStr, system);

    uint64_t probingId = iterSource->first;
    size_t factorId = factor->GetId();

    // grow on demand; slots created by the resize start out as "unknown"
    if (factorId >= m_sourceVocab.size()) {
      m_sourceVocab.resize(factorId + 1, m_unkId);
    }
    m_sourceVocab[factorId] = probingId;
  }

  // target vocab: each line is "<word>\t<engine id>"
  InputFileStream targetVocabStrme(m_path + "/TargetVocab.dat");
  string line;
  while (getline(targetVocabStrme, line)) {
    vector<string> toks = Tokenize(line, "\t");
    // both columns are read below, so both must be present
    if (toks.size() < 2) {
      throw "Malformed line in TargetVocab.dat";
    }

    const Factor *factor = vocab.AddFactor(toks[0], system);
    uint32_t probingId = Scan<uint32_t>(toks[1]);

    if (probingId >= m_targetVocab.size()) {
      m_targetVocab.resize(probingId + 1, NULL);
    }
    m_targetVocab[probingId] = factor;
  }

  // memory mapped file to tps
  string filePath = m_path + "/TargetColl.dat";
  file.open(filePath.c_str());
  if (!file.is_open()) {
    throw "Couldn't open file ";
  }

  data = file.data();

  // cache
  CreateCache(system);
}
2015-12-07 21:30:17 +03:00
void ProbingPT::Lookup(const Manager &mgr, InputPaths &inputPaths) const
{
2016-01-13 20:18:40 +03:00
BOOST_FOREACH(InputPath *path, inputPaths) {
TargetPhrases *tpsPtr;
2016-01-21 14:43:12 +03:00
tpsPtr = Lookup(mgr, mgr.GetPool(), *path);
2016-01-13 20:18:40 +03:00
path->AddTargetPhrases(*this, tpsPtr);
2015-12-07 21:30:17 +03:00
}
}
2016-01-01 19:20:37 +03:00
/**
 * Look up the translations of a single input path.
 *
 * The source phrase is first turned into the engine's key. The cache is
 * consulted before the table itself is queried.
 * (An earlier variant that skipped paths whose prefix path had no
 * translations is currently disabled.)
 *
 * @param mgr        supplies the System used when building target phrases
 * @param pool       memory pool for newly created target phrases
 * @param inputPath  path whose subPhrase is looked up
 * @return the target phrases, or NULL when the phrase contains a word the
 *         table does not know or the table has no entry for it
 */
TargetPhrases* ProbingPT::Lookup(const Manager &mgr,
                                 MemPool &pool,
                                 InputPath &inputPath) const
{
  const Phrase &source = inputPath.subPhrase;

  // first = every word known to the table, second = key for the phrase
  const std::pair<bool, uint64_t> probe = GetSourceProbingId(source);
  if (!probe.first) {
    return NULL;
  }

  // cache hit: reuse the stored target phrases
  const Cache::const_iterator cached = m_cache.find(probe.second);
  if (cached != m_cache.end()) {
    return cached->second;
  }

  // cache miss: query the phrase table
  return CreateTargetPhrase(pool, mgr.system, source, probe.second);
}
2016-01-16 02:58:28 +03:00
std::pair<bool, uint64_t> ProbingPT::GetSourceProbingId(const Phrase &sourcePhrase) const
2015-11-03 19:09:49 +03:00
{
2016-01-16 02:58:28 +03:00
std::pair<bool, uint64_t> ret;
2015-11-03 18:04:26 +03:00
// create a target phrase from the 1st word of the source, prefix with 'ProbingPT:'
2015-12-11 01:01:43 +03:00
size_t sourceSize = sourcePhrase.GetSize();
2016-01-13 19:46:23 +03:00
assert(sourceSize);
2015-11-03 18:04:26 +03:00
2015-12-11 01:01:43 +03:00
uint64_t probingSource[sourceSize];
2016-01-16 02:58:28 +03:00
ConvertToProbingSourcePhrase(sourcePhrase, ret.first, probingSource);
if (!ret.first) {
2016-01-16 02:30:52 +03:00
// source phrase contains a word unknown in the pt.
// We know immediately there's no translation for it
2015-11-03 18:04:26 +03:00
}
2016-01-16 02:58:28 +03:00
else {
ret.second = m_engine->getKey(probingSource, sourceSize);
}
return ret;
2015-11-03 18:04:26 +03:00
2016-01-16 02:30:52 +03:00
}
/**
 * Query the table for a key and build all target phrases stored for it.
 *
 * The entry starts with a uint64_t count, followed by that many serialized
 * target phrases. Each phrase is scored in isolation as it is read; the
 * collection is then passed to SortAndPrune(m_tableLimit) and to the feature
 * functions' EvaluateAfterTablePruning().
 *
 * @param pool          memory pool the result is allocated from
 * @param system        provides the feature functions
 * @param sourcePhrase  source side, handed to the feature functions
 * @param key           engine key of the source phrase
 * @return the target phrases, or NULL when the engine has no entry for key
 */
TargetPhrases *ProbingPT::CreateTargetPhrase(
  MemPool &pool,
  const System &system,
  const Phrase &sourcePhrase,
  uint64_t key) const
{
  // first = found, second = offset of the entry from the start of 'data'
  const std::pair<bool, uint64_t> hit = m_engine->query(key);
  if (!hit.first) {
    return NULL;
  }

  // the entry begins with the number of target phrases
  const char *offset = data + hit.second;
  const uint64_t numTP = *((uint64_t*) offset);
  offset += sizeof(uint64_t);

  TargetPhrases *tps = new (pool.Allocate<TargetPhrases>()) TargetPhrases(pool, numTP);
  const FeatureFunctions &ffs = system.featureFunctions;

  for (size_t i = 0; i < numTP; ++i) {
    // advances 'offset' past the phrase it has just read
    TargetPhrase *tp = CreateTargetPhrase(pool, system, offset);
    assert(tp);

    ffs.EvaluateInIsolation(pool, system, sourcePhrase, *tp);
    tps->AddTargetPhrase(*tp);
  }

  tps->SortAndPrune(m_tableLimit);
  ffs.EvaluateAfterTablePruning(pool, *tps, sourcePhrase);
  //cerr << *tps << endl;

  return tps;
}
/**
 * Deserialize one target phrase starting at 'offset'.
 *
 * Layout read here: a TargetPhraseInfo header, then
 * (num_scores + num_lex_scores) SCORE values, then numWords uint32_t target
 * word ids.
 *
 * @param pool    memory pool for the phrase and any copied scores
 * @param system  passed to the phrase constructor and the score update
 * @param offset  in: start of the serialized phrase; out: first byte after it
 * @return the newly created target phrase
 */
TargetPhrase *ProbingPT::CreateTargetPhrase(
  MemPool &pool,
  const System &system,
  const char *&offset) const
{
  // header
  TargetPhraseInfo *header = (TargetPhraseInfo*) offset;
  TargetPhrase *tp = new (pool.Allocate<TargetPhrase>()) TargetPhrase(pool, *this, system, header->numWords);
  offset += sizeof(TargetPhraseInfo);

  // scores
  SCORE *rawScores = (SCORE*) offset;
  size_t totalNumScores = m_engine->num_scores + m_engine->num_lex_scores;

  if (!m_engine->logProb) {
    // stored values need converting: take the floored log of each first
    SCORE logScores[totalNumScores];
    for (size_t s = 0; s < totalNumScores; ++s) {
      logScores[s] = FloorScore(TransformScore(rawScores[s]));
    }

    // set pt score for rule
    tp->GetScores().PlusEquals(system, *this, logScores);

    // save scores for other FF, eg. lex RO. logScores is a local, so the
    // values are copied into pool memory
    tp->scoreProperties = pool.Allocate<SCORE>(m_engine->num_lex_scores);
    for (size_t s = 0; s < m_engine->num_lex_scores; ++s) {
      tp->scoreProperties[s] = logScores[s + m_engine->num_scores];
    }
  }
  else {
    // stored values are used as they are
    // set pt score for rule
    tp->GetScores().PlusEquals(system, *this, rawScores);

    // save scores for other FF, eg. lex RO. Just point into the table data
    if (m_engine->num_lex_scores) {
      tp->scoreProperties = rawScores + m_engine->num_scores;
    }
  }

  offset += sizeof(SCORE) * totalNumScores;

  // words
  for (size_t pos = 0; pos < header->numWords; ++pos) {
    uint32_t *probingId = (uint32_t*) offset;
    offset += sizeof(uint32_t);

    const Factor *factor = GetTargetFactor(*probingId);
    assert(factor);

    (*tp)[pos][0] = factor;
  }

  // properties TODO

  return tp;
}
2015-12-11 01:01:43 +03:00
void ProbingPT::ConvertToProbingSourcePhrase(const Phrase &sourcePhrase, bool &ok, uint64_t probingSource[]) const
2015-11-03 18:04:26 +03:00
{
size_t size = sourcePhrase.GetSize();
for (size_t i = 0; i < size; ++i) {
2015-11-13 01:51:13 +03:00
const Factor *factor = sourcePhrase[i][0];
2015-11-03 18:04:26 +03:00
uint64_t probingId = GetSourceProbingId(factor);
if (probingId == m_unkId) {
ok = false;
2015-12-11 01:01:43 +03:00
return;
2015-11-03 18:04:26 +03:00
} else {
2015-12-11 01:01:43 +03:00
probingSource[i] = probingId;
2015-11-03 18:04:26 +03:00
}
}
ok = true;
}
2016-01-15 21:18:33 +03:00
/**
 * Pre-populate m_cache from the file "<m_path>/cache".
 *
 * The first line of the file is read and discarded. Every following line is
 * expected to hold two tab-separated columns, the second being a source
 * phrase. The phrase is looked up in the table and its target phrases are
 * stored under the phrase's key. At most m_maxCacheSize entries are cached;
 * a limit of 0 disables the cache.
 *
 * @param system  supplies the system memory pool and the vocabulary
 */
void ProbingPT::CreateCache(System &system)
{
  if (m_maxCacheSize == 0) {
    return;
  }

  string filePath = m_path + "/cache";
  InputFileStream strme(filePath);

  // skip the first line
  string line;
  getline(strme, line);
  //float totalCount = Scan<float>(line);

  MemPool &pool = system.GetSystemPool();
  FactorCollection &vocab = system.GetVocab();

  // number of entries cached so far; test the limit before reading so no
  // line is consumed once the cache is full
  size_t lineCount = 0;
  while (lineCount < m_maxCacheSize && getline(strme, line)) {
    vector<string> toks = Tokenize(line, "\t");
    assert(toks.size() == 2);
    if (toks.size() < 2) {
      // the cache is only an optimisation: skip lines without a phrase
      // column instead of reading past the end of toks
      continue;
    }

    PhraseImpl *sourcePhrase = PhraseImpl::CreateFromString(pool, vocab, system, toks[1]);

    std::pair<bool, uint64_t> retStruct = GetSourceProbingId(*sourcePhrase);
    if (!retStruct.first) {
      // NOTE(review): a phrase with an unknown word ends cache loading
      // altogether - confirm this is intended rather than skipping the line
      return;
    }

    TargetPhrases *tps = CreateTargetPhrase(pool, system, *sourcePhrase, retStruct.second);
    assert(tps);

    m_cache[retStruct.second] = tps;
    ++lineCount;
  }
}
2015-12-28 23:07:47 +03:00
2015-12-10 23:49:30 +03:00
}