mosesdecoder/moses/LM/oxlm/SourceOxLM.cpp

151 lines
4.1 KiB
C++
Raw Permalink Normal View History

#include "moses/LM/oxlm/SourceOxLM.h"
2014-09-27 18:16:28 +04:00
#include <boost/archive/binary_iarchive.hpp>
#include <boost/archive/binary_oarchive.hpp>
#include <boost/filesystem.hpp>
2015-05-22 00:17:43 +03:00
#include "moses/TypeDef.h"
#include "moses/TranslationTask.h"
2014-09-27 18:16:28 +04:00
using namespace std;
using namespace oxlm;
2015-01-14 14:07:42 +03:00
namespace Moses
{
2014-09-27 18:16:28 +04:00
/// Constructs the feature function from its configuration line.
///
/// All OxLM-specific options start out disabled (no POS back-off, no
/// persistent cache) with zeroed cache statistics. The "<unk>" factor is
/// registered with the global FactorCollection and stored as factor 0 of
/// NULL_word, the word returned by getNullWord().
///
/// @param line  configuration line, forwarded unchanged to BilingualLM.
SourceOxLM::SourceOxLM(const string &line)
  : BilingualLM(line),
    posBackOff(false),
    posFactorType(1),
    persistentCache(false),
    cacheHits(0),
    totalHits(0)
{
  // Register the factor that backs the null word.
  const Factor* unkFactor = FactorCollection::Instance().AddFactor("<unk>");
  NULL_word.SetFactor(0, unkFactor);
}
/// Reports the persistent-cache hit ratio (as a percentage) on stderr when
/// the persistent cache is enabled; otherwise does nothing.
SourceOxLM::~SourceOxLM()
{
  if (persistentCache) {
    // With no scored queries the ratio would be 0/0 and print "nan";
    // report 0 instead.
    const double cache_hit_ratio =
      totalHits > 0 ? 100.0 * cacheHits / totalHits : 0.0;
    cerr << "Cache hit ratio: " << cache_hit_ratio << endl;
  }
}
/// Scores the last entry of target_words given the remaining target words
/// and the source window.
///
/// OxLM expects the context in the following format:
/// [t_{n-1}, t_{n-2}, ..., t_{n-m}, s_{a_n-sm}, s_{a_n-sm+1}, ..., s_{a_n+sm}]
/// where n is the index for the current target word, m is the target order,
/// a_n is t_n's affiliation and sm is the source order.
///
/// @param source_words  source-side word ids, appended to the context as is.
/// @param target_words  target-side word ids; the last one is the word being
///                      scored, the others are its history. Must be non-empty.
/// @return model.getLogProb(word, context), served from the persistent cache
///         when that is enabled and already holds the query.
float SourceOxLM::Score(
  vector<int>& source_words,
  vector<int>& target_words) const
{
  // Split off the predicted word; the history goes in newest-first.
  int word = target_words.back();
  vector<int> context(target_words.rbegin() + 1, target_words.rend());
  context.insert(context.end(), source_words.begin(), source_words.end());

  // TODO(pauldb): Return OOV count too.
  if (!persistentCache) {
    return model.getLogProb(word, context);
  }

  // The cache object is created on first use.
  if (!cache.get()) {
    cache.reset(new QueryCache());
  }

  ++totalHits;
  NGram query(word, context);
  pair<double, bool> lookup = cache->get(query);
  if (lookup.second) {
    ++cacheHits;
    return lookup.first;
  }

  // Miss: ask the model and remember the answer.
  float score = model.getLogProb(word, context);
  cache->put(query, score);
  return score;
}
2015-01-14 14:07:42 +03:00
int SourceOxLM::getNeuralLMId(const Word& word, bool is_source_word) const
{
2014-09-29 20:25:26 +04:00
return is_source_word ? mapper->convertSource(word) : mapper->convert(word);
}
2015-01-14 14:07:42 +03:00
/// Returns the null word set up in the constructor (factor 0 is "<unk>").
const Word& SourceOxLM::getNullWord() const
{
  return NULL_word;
}
2015-01-14 14:07:42 +03:00
void SourceOxLM::loadModel()
{
2014-09-26 19:25:48 +04:00
model.load(m_filePath);
boost::shared_ptr<ModelData> config = model.getConfig();
source_ngrams = 2 * config->source_order - 1;
target_ngrams = config->ngram_order - 1;
2014-09-27 02:20:26 +04:00
boost::shared_ptr<Vocabulary> vocab = model.getVocab();
2014-09-29 20:25:26 +04:00
mapper = boost::make_shared<OxLMParallelMapper>(
2015-01-14 14:07:42 +03:00
vocab, posBackOff, posFactorType);
}
2015-01-14 14:07:42 +03:00
void SourceOxLM::SetParameter(const string& key, const string& value)
{
2014-09-26 19:25:48 +04:00
if (key == "persistent-cache") {
persistentCache = Scan<bool>(value);
2014-09-29 20:25:26 +04:00
} else if (key == "pos-back-off") {
posBackOff = Scan<bool>(value);
} else if (key == "pos-factor-type") {
posFactorType = Scan<FactorType>(value);
2014-09-26 19:25:48 +04:00
} else {
BilingualLM::SetParameter(key, value);
}
}
/// Prepares the feature for a new input sentence.
///
/// Runs the base-class initialization, then, if the persistent cache is
/// enabled, makes sure a cache object exists and tries to fill it from
/// "<m_filePath>.<translation id>.cache.bin". A missing file is reported on
/// stderr and is not an error.
///
/// @param ttask  translation task whose source sentence is being initialized.
void SourceOxLM::InitializeForInput(ttasksptr const& ttask)
{
  const InputType& source = *ttask->GetSource();
  BilingualLM::InitializeForInput(ttask);

  if (persistentCache) {
    if (!cache.get()) {
      cache.reset(new QueryCache());
    }

    int sentence_id = source.GetTranslationId();
    string cacheFile = m_filePath + "." + to_string(sentence_id) + ".cache.bin";
    if (boost::filesystem::exists(cacheFile)) {
      // Binary archives must be read from a stream opened in binary mode;
      // a text-mode stream may translate bytes on some platforms.
      ifstream fin(cacheFile.c_str(), ios::in | ios::binary);
      boost::archive::binary_iarchive iar(fin);
      cerr << "Loading n-gram probability cache from " << cacheFile << endl;
      iar >> *cache;
      cerr << "Done loading " << cache->size()
           << " n-gram probabilities..." << endl;
    } else {
      cerr << "Cache file not found!" << endl;
    }
  }
}
2015-01-14 14:07:42 +03:00
void SourceOxLM::CleanUpAfterSentenceProcessing(const InputType& source)
{
2014-09-27 18:16:28 +04:00
// Thread safe: the model cache is thread specific.
model.clearCache();
if (persistentCache) {
int sentence_id = source.GetTranslationId();
string cacheFile = m_filePath + "." + to_string(sentence_id) + ".cache.bin";
ofstream fout(cacheFile);
boost::archive::binary_oarchive oar(fout);
cerr << "Saving persistent cache to " << cacheFile << endl;
oar << *cache;
cerr << "Done saving " << cache->size()
<< " n-gram probabilities..." << endl;
2014-09-27 18:16:28 +04:00
cache->clear();
}
}
} // namespace Moses