mosesdecoder/moses/LM/LDHT.cpp

421 lines
13 KiB
C++
Raw Normal View History

//
// Oliver Wilson <oliver.wilson@ed.ac.uk>
//
2013-09-27 12:35:24 +04:00
// This file should be compiled only when the LM_RAND flag is enabled.
2013-08-29 20:56:25 +04:00
//
2013-09-27 12:35:24 +04:00
// The following ifdef prevents XCode and other non-bjam build systems
2013-08-29 20:56:25 +04:00
// from attempting to compile this file when LM_RAND is disabled.
//
#ifdef LM_RAND
#include "LM/Base.h"
#include "LM/LDHT.h"
2012-11-14 18:18:53 +04:00
#include "moses/FFState.h"
#include "moses/TypeDef.h"
#include "moses/Hypothesis.h"
#include "moses/StaticData.h"
2014-01-13 18:37:05 +04:00
#include "util/exception.hh"
#include <LDHT/Client.h>
#include <LDHT/ClientLocal.h>
#include <LDHT/NewNgram.h>
#include <LDHT/FactoryCollection.h>
#include <boost/thread/tss.hpp>
2013-05-29 21:16:15 +04:00
namespace Moses
{
/// Per-hypothesis feature-function state for the LDHT language model.
///
/// Carries the gram fingerprints accumulated so far, the tags of the ngram
/// requests issued on behalf of the owning hypothesis, and a flag recording
/// whether those requests have already been folded into a score.
struct LDHTLMState : public FFState {
  LDHT::NewNgram gram_fingerprints;  ///< Fingerprints of the grams seen so far.
  bool finalised;                    ///< Set once the request tags were scored.
  std::vector<int> request_tags;     ///< Tags of requests awaiting scoring.

  /// A fresh state starts out unscored, with no grams and no pending tags.
  LDHTLMState()
    : finalised(false)
  {
  }

  /// Flag this state as having been scored.
  void setFinalised()
  {
    finalised = true;
  }

  /// Remember the tag of a request issued for this state.
  void appendRequestTag(int tag)
  {
    request_tags.push_back(tag);
  }

  /// Forget every remembered request tag.
  void clearRequestTags()
  {
    request_tags.clear();
  }

  /// Iterator to the first remembered request tag.
  std::vector<int>::iterator requestTagsBegin()
  {
    return request_tags.begin();
  }

  /// Iterator one past the last remembered request tag.
  std::vector<int>::iterator requestTagsEnd()
  {
    return request_tags.end();
  }

  /// Order two states by their gram fingerprints only.
  ///
  /// The argument is downcast without a runtime check, so it must really be
  /// an LDHTLMState. The finalised flag is not consulted (a check on it used
  /// to sit here but is disabled).
  int Compare(const FFState& uncast_other) const
  {
    const LDHTLMState& rhs = static_cast<const LDHTLMState&>(uncast_other);
    return gram_fingerprints.compareMoses(rhs.gram_fingerprints);
  }

  /// Take over the gram fingerprints of another state.
  ///
  /// Request tags are not copied, and the receiving state is marked as not
  /// yet finalised.
  void copyFrom(const LDHTLMState& other)
  {
    gram_fingerprints.copyFrom(other.gram_fingerprints);
    finalised = false;
  }
};
2013-05-29 21:16:15 +04:00
class LanguageModelLDHT : public LanguageModel
{
public:
2013-05-29 21:16:15 +04:00
LanguageModelLDHT();
LanguageModelLDHT(const std::string& path,
ScoreIndexManager& manager,
FactorType factorType);
LanguageModelLDHT(ScoreIndexManager& manager,
LanguageModelLDHT& copyFrom);
LDHT::Client* getClientUnsafe() const;
LDHT::Client* getClientSafe();
LDHT::Client* initTSSClient();
virtual ~LanguageModelLDHT();
virtual void InitializeForInput(ttasksptr const& ttask);
2013-05-29 21:16:15 +04:00
virtual void CleanUpAfterSentenceProcessing(const InputType &source);
virtual const FFState* EmptyHypothesisState(const InputType& input) const;
virtual void CalcScore(const Phrase& phrase,
float& fullScore,
float& ngramScore,
std::size_t& oovCount) const;
virtual void CalcScoreFromCache(const Phrase& phrase,
float& fullScore,
float& ngramScore,
std::size_t& oovCount) const;
FFState* Evaluate(const Hypothesis& hypo,
const FFState* input_state,
ScoreComponentCollection* score_output) const;
FFState* EvaluateWhenApplied(const ChartHypothesis& hypo,
2015-01-14 14:07:42 +03:00
int featureID,
ScoreComponentCollection* accumulator) const;
2013-05-29 21:16:15 +04:00
virtual void IssueRequestsFor(Hypothesis& hypo,
const FFState* input_state);
float calcScoreFromState(LDHTLMState* hypo) const;
void sync();
void SetFFStateIdx(int state_idx);
2012-07-02 18:57:54 +04:00
protected:
2013-05-29 21:16:15 +04:00
boost::thread_specific_ptr<LDHT::Client> m_client;
std::string m_configPath;
FactorType m_factorType;
int m_state_idx;
int m_calc_score_count;
uint64_t m_start_tick;
};
/// Factory entry point: allocate an LDHT language model with `new` and hand
/// it back through the generic LanguageModel interface.
///
/// @param path        Config path forwarded to the LanguageModelLDHT constructor.
/// @param manager     Score index manager forwarded to the constructor.
/// @param factorType  Factor type forwarded to the constructor.
/// @return The newly allocated model (presumably owned by the caller).
LanguageModel* ConstructLDHTLM(const std::string& path,
                               ScoreIndexManager& manager,
                               FactorType factorType)
{
  LanguageModel* const model = new LanguageModelLDHT(path, manager, factorType);
  return model;
}
2013-05-29 21:16:15 +04:00
/// Default constructor: no config path, OOV feature disabled, and all
/// counters and indices zeroed.
LanguageModelLDHT::LanguageModelLDHT()
  : LanguageModel()
    // Default-construct the thread-specific pointer. Writing `m_client(NULL)`
    // does not set an initial value: NULL binds to the constructor's
    // cleanup-function parameter, which per the Boost.Thread documentation
    // means stored clients are never deleted on reset() or thread exit.
  , m_client()
  , m_factorType()
    // Placeholder until SetFFStateIdx() is called.
  , m_state_idx(0)
    // CalcScore() increments and compares this, so it must start defined.
  , m_calc_score_count(0)
  , m_start_tick(0)
{
  m_enableOOVFeature = false;
}
/// Build a model that shares the configuration of an existing one.
///
/// Only the config path and factor type are taken from `copyFrom`. The
/// per-thread client is not shared (each thread creates its own on demand),
/// and the state index, call counter and start tick begin at zero instead of
/// being left uninitialised.
///
/// @param manager   Forwarded to Init().
/// @param copyFrom  Model whose config path and factor type are reused.
LanguageModelLDHT::LanguageModelLDHT(ScoreIndexManager& manager,
                                     LanguageModelLDHT& copyFrom)
  : m_configPath(copyFrom.m_configPath)
  , m_factorType(copyFrom.m_factorType)
    // Placeholder until SetFFStateIdx() is called.
  , m_state_idx(0)
  , m_calc_score_count(0)
  , m_start_tick(0)
{
  Init(manager);
}
/// Build a model from a config file path.
///
/// @param path        File whose first two lines name the LDHT and LDHT-LM
///                    XML configuration files (read lazily, when a thread
///                    first needs a client).
/// @param manager     Forwarded to Init().
/// @param factorType  Factor of each word that is sent to the LM.
LanguageModelLDHT::LanguageModelLDHT(const std::string& path,
                                     ScoreIndexManager& manager,
                                     FactorType factorType)
  : m_configPath(path)
  , m_factorType(factorType)
    // Placeholder until SetFFStateIdx() is called.
  , m_state_idx(0)
    // CalcScore() increments and compares this counter; leaving it
    // uninitialised made that read undefined behaviour.
  , m_calc_score_count(0)
  , m_start_tick(0)
{
  Init(manager);
}
2013-05-29 21:16:15 +04:00
/// Destructor. Performs no explicit client cleanup.
LanguageModelLDHT::~LanguageModelLDHT()
{
  // TODO(wilson): each thread's client should be cleaned up individually.
  // A former attempt is kept here, disabled:
  //   delete getClientSafe();
}
// Check that there is a TSS Client instance, and instantiate one if
// there isn't.
2013-05-29 21:16:15 +04:00
LDHT::Client* LanguageModelLDHT::getClientSafe()
{
if (m_client.get() == NULL)
m_client.reset(initTSSClient());
return m_client.get();
}
// Do not check that there is a TSS Client instance.
2013-05-29 21:16:15 +04:00
LDHT::Client* LanguageModelLDHT::getClientUnsafe() const
{
return m_client.get();
}
2013-05-29 21:16:15 +04:00
/// Create and configure a new LDHT client from the model's config file.
///
/// The config file is expected to hold two lines: the path of the LDHT XML
/// config, then the path of the LDHT-LM XML config. A file that cannot be
/// opened, or that lacks either line, is reported with an exception instead
/// of silently passing empty paths to the client.
///
/// @return A newly allocated client; the caller takes ownership.
/// @throws util::Exception if the config file is unreadable or incomplete.
LDHT::Client* LanguageModelLDHT::initTSSClient()
{
  std::ifstream config_file(m_configPath.c_str());
  if (!config_file) {
    UTIL_THROW2("could not open LDHT LM config file: " << m_configPath);
  }

  std::string ldht_config_path;
  if (!std::getline(config_file, ldht_config_path) || ldht_config_path.empty()) {
    UTIL_THROW2("missing LDHT config path (line 1) in: " << m_configPath);
  }

  std::string ldhtlm_config_path;
  if (!std::getline(config_file, ldhtlm_config_path) || ldhtlm_config_path.empty()) {
    UTIL_THROW2("missing LDHT LM config path (line 2) in: " << m_configPath);
  }

  LDHT::FactoryCollection* factory_collection =
    LDHT::FactoryCollection::createDefaultFactoryCollection();

  // A networked client is used; LDHT::ClientLocal is the (disabled)
  // alternative.
  LDHT::Client* client = new LDHT::Client();
  client->fromXmlFiles(*factory_collection,
                       ldht_config_path,
                       ldhtlm_config_path);
  return client;
}
/// Prepare the calling thread for a new input: make sure it has a client,
/// clear that client's cache, and record the current tick so that the latency
/// share can be reported once the sentence is done.
///
/// @param ttask  Unused.
void LanguageModelLDHT::InitializeForInput(ttasksptr const& ttask)
{
  LDHT::Client* const client = getClientSafe();
  client->clearCache();
  m_start_tick = LDHT::Util::rdtsc();
}
2013-05-29 21:16:15 +04:00
void LanguageModelLDHT::CleanUpAfterSentenceProcessing(const InputType &source)
{
LDHT::Client* client = getClientSafe();
std::cerr << "LDHT sentence stats:" << std::endl;
std::cerr << " ngrams submitted: " << client->getNumNgramsSubmitted() << std::endl
<< " ngrams requested: " << client->getNumNgramsRequested() << std::endl
<< " ngrams not found: " << client->getKeyNotFoundCount() << std::endl
<< " cache hits: " << client->getCacheHitCount() << std::endl
<< " inferences: " << client->getInferenceCount() << std::endl
<< " pcnt latency: " << (float)client->getLatencyTicks() / (float)(LDHT::Util::rdtsc() - m_start_tick) * 100.0 << std::endl;
m_start_tick = 0;
client->resetLatencyTicks();
client->resetNumNgramsSubmitted();
client->resetNumNgramsRequested();
client->resetInferenceCount();
client->resetCacheHitCount();
client->resetKeyNotFoundCount();
}
/// The empty hypothesis carries no LDHT state: always returns NULL.
/// IssueRequestsFor() treats a NULL input state as the start of a sentence.
///
/// @param input  Unused.
const FFState* LanguageModelLDHT::EmptyHypothesisState(
  const InputType& input) const
{
  const FFState* const no_state = NULL;
  return no_state;
}
void LanguageModelLDHT::CalcScore(const Phrase& phrase,
float& fullScore,
float& ngramScore,
2013-05-29 21:16:15 +04:00
std::size_t& oovCount) const
{
const_cast<LanguageModelLDHT*>(this)->m_calc_score_count++;
if (m_calc_score_count > 10000) {
const_cast<LanguageModelLDHT*>(this)->m_calc_score_count = 0;
const_cast<LanguageModelLDHT*>(this)->sync();
}
2013-05-29 21:16:15 +04:00
// TODO(wilson): handle nonterminal words.
LDHT::Client* client = getClientUnsafe();
// Score the first order - 1 words of the phrase.
int order = LDHT::NewNgram::k_max_order;
int prefix_start = 0;
int prefix_end = std::min(phrase.GetSize(), static_cast<size_t>(order - 1));
LDHT::NewNgram ngram;
for (int word_idx = prefix_start; word_idx < prefix_end; ++word_idx) {
ngram.appendGram(phrase.GetWord(word_idx)
.GetFactor(m_factorType)->GetString().c_str());
client->requestNgram(ngram);
}
// Now score all subsequent ngrams to end of phrase.
int internal_start = prefix_end;
int internal_end = phrase.GetSize();
for (int word_idx = internal_start; word_idx < internal_end; ++word_idx) {
ngram.appendGram(phrase.GetWord(word_idx)
.GetFactor(m_factorType)->GetString().c_str());
client->requestNgram(ngram);
}
2013-05-29 21:16:15 +04:00
fullScore = 0;
ngramScore = 0;
oovCount = 0;
}
void LanguageModelLDHT::CalcScoreFromCache(const Phrase& phrase,
2013-05-29 21:16:15 +04:00
float& fullScore,
float& ngramScore,
std::size_t& oovCount) const
{
// Issue requests for phrase internal ngrams.
// Sync if necessary. (or autosync).
const_cast<LanguageModelLDHT*>(this)->sync();
// TODO(wilson): handle nonterminal words.
LDHT::Client* client = getClientUnsafe();
// Score the first order - 1 words of the phrase.
int order = LDHT::NewNgram::k_max_order;
int prefix_start = 0;
int prefix_end = std::min(phrase.GetSize(), static_cast<size_t>(order - 1));
LDHT::NewNgram ngram;
std::deque<int> full_score_tags;
for (int word_idx = prefix_start; word_idx < prefix_end; ++word_idx) {
ngram.appendGram(phrase.GetWord(word_idx)
.GetFactor(m_factorType)->GetString().c_str());
full_score_tags.push_back(client->requestNgram(ngram));
}
// Now score all subsequent ngrams to end of phrase.
int internal_start = prefix_end;
int internal_end = phrase.GetSize();
std::deque<int> internal_score_tags;
for (int word_idx = internal_start; word_idx < internal_end; ++word_idx) {
ngram.appendGram(phrase.GetWord(word_idx)
.GetFactor(m_factorType)->GetString().c_str());
internal_score_tags.push_back(client->requestNgram(ngram));
}
2013-05-29 21:16:15 +04:00
// Wait for resposes from the servers.
//client->awaitResponses();
2013-05-29 21:16:15 +04:00
// Calculate the full phrase score, and the internal score.
fullScore = 0.0;
while (!full_score_tags.empty()) {
fullScore += client->getNgramScore(full_score_tags.front());
full_score_tags.pop_front();
}
ngramScore = 0.0;
while (!internal_score_tags.empty()) {
float score = client->getNgramScore(internal_score_tags.front());
internal_score_tags.pop_front();
fullScore += score;
ngramScore += score;
}
fullScore = TransformLMScore(fullScore);
ngramScore = TransformLMScore(ngramScore);
oovCount = 0;
}
2012-07-02 18:57:54 +04:00
void LanguageModelLDHT::IssueRequestsFor(Hypothesis& hypo,
2013-05-29 21:16:15 +04:00
const FFState* input_state)
{
// TODO(wilson): handle nonterminal words.
LDHT::Client* client = getClientUnsafe();
// Create a new state and copy the contents of the input_state if
// supplied.
LDHTLMState* new_state = new LDHTLMState();
if (input_state == NULL) {
if (hypo.GetCurrTargetWordsRange().GetStartPos() != 0) {
2014-01-13 22:32:22 +04:00
UTIL_THROW2("got a null state but not at start of sentence");
}
2013-05-29 21:16:15 +04:00
new_state->gram_fingerprints.appendGram(BOS_);
} else {
if (hypo.GetCurrTargetWordsRange().GetStartPos() == 0) {
2014-01-13 22:32:22 +04:00
UTIL_THROW2("got a non null state but at start of sentence");
}
2013-05-29 21:16:15 +04:00
new_state->copyFrom(static_cast<const LDHTLMState&>(*input_state));
}
2013-05-29 21:16:15 +04:00
// Score ngrams that overlap with the previous phrase.
int order = LDHT::NewNgram::k_max_order;
int phrase_start = hypo.GetCurrTargetWordsRange().GetStartPos();
int phrase_end = hypo.GetCurrTargetWordsRange().GetEndPos() + 1;
int overlap_start = phrase_start;
int overlap_end = std::min(phrase_end, phrase_start + order - 1);
int word_idx = overlap_start;
LDHT::NewNgram& ngram = new_state->gram_fingerprints;
for (; word_idx < overlap_end; ++word_idx) {
ngram.appendGram(
hypo.GetFactor(word_idx, m_factorType)->GetString().c_str());
new_state->appendRequestTag(client->requestNgram(ngram));
}
// No need to score phrase internal ngrams, but keep track of them
// in the state (which in this case is the NewNgram containing the
// hashes of the individual grams).
for (; word_idx < phrase_end; ++word_idx) {
ngram.appendGram(
hypo.GetFactor(word_idx, m_factorType)->GetString().c_str());
}
// If this is the last phrase in the sentence, score the last ngram
// with the end of sentence marker on it.
if (hypo.IsSourceCompleted()) {
ngram.appendGram(EOS_);
//request_tags.push_back(client->requestNgram(ngram));
new_state->appendRequestTag(client->requestNgram(ngram));
}
hypo.SetFFState(m_state_idx, new_state);
2012-07-02 18:57:54 +04:00
}
2013-05-29 21:16:15 +04:00
void LanguageModelLDHT::sync()
{
m_calc_score_count = 0;
getClientUnsafe()->awaitResponses();
2012-07-02 18:57:54 +04:00
}
2013-05-29 21:16:15 +04:00
void LanguageModelLDHT::SetFFStateIdx(int state_idx)
{
m_state_idx = state_idx;
2012-07-02 18:57:54 +04:00
}
2012-07-02 18:57:54 +04:00
/// Score a hypothesis from the request tags stored in its own state.
///
/// The previous hypothesis' state is ignored. The requests for `hypo` are
/// expected to have been issued already through IssueRequestsFor(), and the
/// LM synced, so that the tags in the hypothesis' state can be resolved to
/// scores.
///
/// @param hypo                 Hypothesis whose stored state is scored.
/// @param input_state_ignored  Unused.
/// @param score_output         Receives the floored, transformed score.
/// @return The state already stored on `hypo` (not a new object).
FFState* LanguageModelLDHT::Evaluate(
  const Hypothesis& hypo,
  const FFState* input_state_ignored,
  ScoreComponentCollection* score_output) const
{
  // The stored state is const from the hypothesis' point of view, but
  // scoring clears its tags and marks it finalised.
  const FFState* const stored = hypo.GetFFState(m_state_idx);
  LDHTLMState* const state =
    const_cast<LDHTLMState*>(static_cast<const LDHTLMState*>(stored));

  const float raw_score = calcScoreFromState(state);
  const float final_score = FloorScore(TransformLMScore(raw_score));
  score_output->PlusEquals(this, final_score);
  return state;
}
/// Chart-decoding hook. Does nothing and returns NULL.
///
/// @param hypo         Unused.
/// @param featureID    Unused.
/// @param accumulator  Unused.
FFState* LanguageModelLDHT::EvaluateWhenApplied(
  const ChartHypothesis& hypo,
  int featureID,
  ScoreComponentCollection* accumulator) const
{
  FFState* const no_state = NULL;
  return no_state;
}
2013-05-29 21:16:15 +04:00
/// Add up the scores of every request tag held by a state, then clear the
/// tags and mark the state as finalised.
///
/// Uses the calling thread's existing client without checking that one
/// exists.
///
/// @param state  State whose tags are scored; modified as described above.
/// @return Sum of the ngram scores for the state's tags (0 if it has none).
float LanguageModelLDHT::calcScoreFromState(LDHTLMState* state) const
{
  LDHT::Client* const client = getClientUnsafe();

  float total = 0.0;
  std::vector<int>::iterator tag = state->requestTagsBegin();
  while (tag != state->requestTagsEnd()) {
    total += client->getNgramScore(*tag);
    ++tag;
  }

  state->clearRequestTags();
  state->setFinalised();
  return total;
}
} // namespace Moses.
2013-08-29 20:56:25 +04:00
#endif