improved interface towards IRSTLM

Nicola Bertoldi 2015-01-04 22:44:13 +01:00
parent 3d5e642156
commit 906fc7bdf4
2 changed files with 174 additions and 15 deletions

moses/LM/IRST.cpp

@@ -29,6 +29,7 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
using namespace irstlm;
#include "IRST.h"
#include "moses/LM/PointerState.h"
#include "moses/TypeDef.h"
#include "moses/Util.h"
#include "moses/FactorCollection.h"
@@ -42,9 +43,28 @@ using namespace std;
namespace Moses
{
class IRSTLMState : public PointerState
{
public:
IRSTLMState():PointerState(NULL) {}
IRSTLMState(const void* lms):PointerState(lms) {}
IRSTLMState(const IRSTLMState& copy_from):PointerState(copy_from.lmstate) {}
IRSTLMState& operator=( const IRSTLMState& rhs )
{
lmstate = rhs.lmstate;
return *this;
}
const void* GetState() const
{
return lmstate;
}
};
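The new IRSTLMState simply wraps the raw state pointer that IRSTLM hands back (a cache suffix pointer) in Moses' PointerState, so two hypotheses recombine exactly when they carry the same pointer. A minimal standalone sketch of that idea, with illustrative names (MiniState, can_recombine) rather than the actual Moses API:

#include <cstdio>

// Illustrative stand-in for a pointer-based LM state: two hypotheses can be
// recombined exactly when their states point at the same suffix inside the
// LM cache (pointer identity, as in Moses' PointerState).
struct MiniState {
  const void* lmstate;
  explicit MiniState(const void* p) : lmstate(p) {}
};

static bool can_recombine(const MiniState& a, const MiniState& b) {
  return a.lmstate == b.lmstate;
}

int main() {
  int suffix_a = 0, suffix_b = 0;
  MiniState s1(&suffix_a), s2(&suffix_a), s3(&suffix_b);
  std::printf("%d %d\n", can_recombine(s1, s2), can_recombine(s1, s3)); // prints "1 0"
  return 0;
}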
LanguageModelIRST::LanguageModelIRST(const std::string &line)
:LanguageModelSingleFactor(line)
,m_lmtb_dub(0)
,m_lmtb_dub(0), m_lmtb_size(0)
{
const StaticData &staticData = StaticData::Instance();
int threadCount = staticData.ThreadCount();
@@ -54,6 +74,10 @@ LanguageModelIRST::LanguageModelIRST(const std::string &line)
ReadParameters();
VERBOSE(4, GetScoreProducerDescription() << " LanguageModelIRST::LanguageModelIRST() m_lmtb_dub:|" << m_lmtb_dub << "|" << std::endl);
VERBOSE(4, GetScoreProducerDescription() << " LanguageModelIRST::LanguageModelIRST() m_filePath:|" << m_filePath << "|" << std::endl);
VERBOSE(4, GetScoreProducerDescription() << " LanguageModelIRST::LanguageModelIRST() m_factorType:|" << m_factorType << "|" << std::endl);
VERBOSE(4, GetScoreProducerDescription() << " LanguageModelIRST::LanguageModelIRST() m_lmtb_size:|" << m_lmtb_size << "|" << std::endl);
}
LanguageModelIRST::~LanguageModelIRST()
@@ -70,17 +94,15 @@ LanguageModelIRST::~LanguageModelIRST()
void LanguageModelIRST::Load()
{
cerr << "In LanguageModelIRST::Load: nGramOrder = " << m_nGramOrder << "\n";
FactorCollection &factorCollection = FactorCollection::Instance();
m_lmtb = m_lmtb->CreateLanguageModel(m_filePath);
m_lmtb->setMaxLoadedLevel(1000);
if (m_lmtb_size > 0) m_lmtb->setMaxLoadedLevel(m_lmtb_size);
m_lmtb->load(m_filePath);
d=m_lmtb->getDict();
d->incflag(1);
m_lmtb_size=m_lmtb->maxlevel();
m_nGramOrder = m_lmtb_size = m_lmtb->maxlevel();
// LM can be ok, just outputs warnings
// Mauro: in the original, the following two instructions are wrongly switched:
@@ -89,7 +111,7 @@ void LanguageModelIRST::Load()
CreateFactors(factorCollection);
VERBOSE(1, "IRST: m_unknownId=" << m_unknownId << std::endl);
VERBOSE(1, GetScoreProducerDescription() << " LanguageModelIRST::Load() m_unknownId=" << m_unknownId << std::endl);
//install caches to save time (only if PS_CACHE_ENABLE is defined through compilation flags)
m_lmtb->init_caches(m_lmtb_size>2?m_lmtb_size-1:2);
@@ -117,6 +139,8 @@ void LanguageModelIRST::CreateFactors(FactorCollection &factorCollection)
m_sentenceStart = factorCollection.AddFactor(Output, m_factorType, BOS_);
factorId = m_sentenceStart->GetId();
const std::string bs = BOS_;
const std::string es = EOS_;
m_lmtb_sentenceStart=lmIdMap[factorId] = GetLmID(BOS_);
maxFactorId = (factorId > maxFactorId) ? factorId : maxFactorId;
m_sentenceStartWord[m_factorType] = m_sentenceStart;
@@ -142,6 +166,11 @@ int LanguageModelIRST::GetLmID( const std::string &str ) const
return d->encode( str.c_str() ); // at the level of micro tags
}
int LanguageModelIRST::GetLmID( const Word &word ) const
{
return GetLmID( word.GetFactor(m_factorType) );
}
int LanguageModelIRST::GetLmID( const Factor *factor ) const
{
size_t factorId = factor->GetId();
@@ -157,7 +186,7 @@ int LanguageModelIRST::GetLmID( const Factor *factor ) const
///whose target code has not yet been computed may nonetheless
///have a known factorID (and hence one smaller than m_lmIdLookup.size()).
///It is therefore necessary to identify these cases where the
///target code is still undetermined. Currently, this check is implemented
///by setting to m_empty all terms that have not yet
//received an actual target code
///////////
@@ -198,6 +227,102 @@ int LanguageModelIRST::GetLmID( const Factor *factor ) const
}
}
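The comment above describes marking factors whose target code is still unknown with m_empty and resolving them lazily; the body of GetLmID(const Factor *) itself is cut off by this hunk. A hypothetical, simplified sketch of that lookup-and-cache pattern (toy_encode and lookup_or_encode are made-up names, not the elided code):

#include <cstdio>
#include <string>
#include <vector>

// Hypothetical encoder standing in for dictionary::encode().
static int toy_encode(const std::string& surface) { return (int) surface.size(); }

// Entries still equal to `empty` have not received a real target code yet and
// are resolved (and cached) on first use.
static int lookup_or_encode(std::vector<int>& idMap, size_t factorId,
                            const std::string& surface, int empty) {
  if (factorId >= idMap.size())
    idMap.resize(factorId + 1, empty);   // unseen factorIDs start out as empty
  if (idMap[factorId] == empty)
    idMap[factorId] = toy_encode(surface);
  return idMap[factorId];
}

int main() {
  const int empty = -1;
  std::vector<int> idMap;
  std::printf("%d\n", lookup_or_encode(idMap, 7, "hello", empty)); // encoded once
  std::printf("%d\n", lookup_or_encode(idMap, 7, "hello", empty)); // served from the cache
  return 0;
}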
void LanguageModelIRST::CalcScore(const Phrase &phrase, float &fullScore, float &ngramScore, size_t &oovCount) const
{
fullScore = 0;
ngramScore = 0;
oovCount = 0;
if ( !phrase.GetSize() ) return;
if ( m_lmtb_size > (int) phrase.GetSize()) return;
int codes[m_lmtb_size];
int idx = 0;
int position = 0;
for (; position < m_lmtb_size; ++position)
{
codes[idx] = GetLmID(phrase.GetWord(position));
if (codes[idx] == m_unknownId) ++oovCount;
++idx;
}
char* msp = NULL;
ngramScore = m_lmtb->clprob(codes,idx,NULL,NULL,&msp);
int end_loop = (int) phrase.GetSize();
for (; position < end_loop; ++position) {
for (idx = 1; idx < m_lmtb_size; ++idx)
{
codes[idx-1] = codes[idx];
}
codes[idx-1] = GetLmID(phrase.GetWord(position));
if (codes[idx-1] == m_unknownId) ++oovCount;
ngramScore += m_lmtb->clprob(codes,idx,NULL,NULL,&msp);
}
ngramScore = TransformLMScore(ngramScore);
fullScore = ngramScore;
}
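CalcScore first scores one full n-gram of size m_lmtb_size and then slides the window over the rest of the phrase, shifting the codes array left by one and appending the next word id. A self-contained sketch of that window mechanics, using a dummy scorer in place of m_lmtb->clprob (illustrative values only):

#include <cstdio>

// Dummy scorer standing in for m_lmtb->clprob(); returns a fake log-prob.
static float fake_clprob(const int* codes, int n) { (void)codes; return -0.1f * n; }

int main() {
  const int order = 3;
  int words[] = {11, 22, 33, 44, 55};   // already-mapped LM word codes
  const int len = 5;
  int codes[order];
  float total = 0.0f;

  for (int i = 0; i < order; ++i) codes[i] = words[i];   // first full n-gram
  total += fake_clprob(codes, order);

  for (int pos = order; pos < len; ++pos) {              // slide the window
    for (int i = 1; i < order; ++i) codes[i - 1] = codes[i];
    codes[order - 1] = words[pos];
    total += fake_clprob(codes, order);
  }
  std::printf("accumulated score: %f\n", total);         // 3 n-grams for 5 words
  return 0;
}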
FFState* LanguageModelIRST::EvaluateWhenApplied(const Hypothesis &hypo, const FFState *ps, ScoreComponentCollection *out) const
{
if (!hypo.GetCurrTargetLength()) {
std::auto_ptr<IRSTLMState> ret(new IRSTLMState(ps));
return ret.release();
}
const std::size_t begin = hypo.GetCurrTargetWordsRange().GetStartPos();
//[begin, end) in STL-like fashion.
const std::size_t end = hypo.GetCurrTargetWordsRange().GetEndPos() + 1;
const std::size_t adjust_end = std::min(end, begin + m_lmtb_size - 1);
// set up context
int codes[m_lmtb_size];
int idx=m_lmtb_size-1;
int position = adjust_end-1;
//fill the farthest positions with at most ONE sentenceEnd symbol and at most ONE sentenceStart symbol, if "empty" positions are available
//so that the vector looks like "</s> <s> context_word context_word" for a two-word context and an LM of order 5
while (position >= (const int) begin) {
codes[idx] = GetLmID(hypo.GetWord(position));
--idx;
--position;
}
if (idx == 1){
codes[1] = m_lmtb_sentenceStart;
codes[0] = m_lmtb_sentenceStart;
}
else if (idx == 0)
{
codes[0] = m_lmtb_sentenceEnd;
}
char* msp = NULL;
float score = m_lmtb->clprob(codes,m_lmtb_size,NULL,NULL,&msp);
score = TransformLMScore(score);
out->PlusEquals(this, score);
if (adjust_end < end)
{
idx=m_lmtb_size-1;
position = end-1;
while (idx>=0)
{
codes[idx] = GetLmID(hypo.GetWord(position));
--idx;
--position;
}
msp = (char *) m_lmtb->cmaxsuffptr(codes,m_lmtb_size);
}
std::auto_ptr<IRSTLMState> ret(new IRSTLMState(msp));
return ret.release();
}
LMResult LanguageModelIRST::GetValue(const vector<const Word*> &contextFactor, State* finalState) const
{
FactorType factorType = GetFactorType();
@@ -219,14 +344,16 @@ LMResult LanguageModelIRST::GetValue(const vector<const Word*> &contextFactor, S
if (count < (size_t) m_lmtb_size) codes[idx++] = m_lmtb_sentenceStart;
for (size_t i = 0 ; i < count ; i++) {
codes[idx++] = GetLmID((*contextFactor[i])[factorType]);
//codes[idx] = GetLmID((*contextFactor[i])[factorType]);
codes[idx] = GetLmID(*contextFactor[i]);
++idx;
}
LMResult result;
result.unknown = (codes[idx - 1] == m_unknownId);
char* msp = NULL;
unsigned int ilen;
result.score = m_lmtb->clprob(codes,idx,NULL,NULL,&msp,&ilen);
result.score = m_lmtb->clprob(codes,idx,NULL,NULL,&msp);
if (finalState) *finalState=(State *) msp;
@@ -235,7 +362,7 @@ LMResult LanguageModelIRST::GetValue(const vector<const Word*> &contextFactor, S
}
bool LMCacheCleanup(size_t sentences_done, size_t m_lmcache_cleanup_threshold)
bool LMCacheCleanup(const int sentences_done, const size_t m_lmcache_cleanup_threshold)
{
if (sentences_done==-1) return true;
if (m_lmcache_cleanup_threshold)
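The remainder of LMCacheCleanup is cut off by the hunk. Presumably the threshold branch fires every m_lmcache_cleanup_threshold sentences; a hedged sketch of that trigger logic, under that assumption:

// Hedged reconstruction of the trigger: force cleanup when sentences_done is -1,
// otherwise clean up every `threshold` sentences. The real elided body may differ.
bool should_cleanup(const int sentences_done, const size_t threshold) {
  if (sentences_done == -1) return true;
  if (threshold) return (sentences_done % (int) threshold) == 0;
  return false;
}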
@@ -266,5 +393,15 @@ void LanguageModelIRST::CleanUpAfterSentenceProcessing(const InputType& source)
}
}
void LanguageModelIRST::SetParameter(const std::string& key, const std::string& value)
{
if (key == "dub") {
m_lmtb_dub = Scan<unsigned int>(value);
} else {
LanguageModelSingleFactor::SetParameter(key, value);
}
m_lmtb_size = m_nGramOrder;
}
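With this change the dictionary upper bound can be passed on the feature line via the dub key; any other key falls through to LanguageModelSingleFactor::SetParameter. Assuming the feature is registered under the name IRSTLM and using the standard keys inherited from the base classes, a moses.ini feature line might look roughly like this (path and values are placeholders):

[feature]
IRSTLM name=LM0 factor=0 order=5 path=/path/to/model.lm dub=10000000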
}

moses/LM/IRST.h

@@ -24,10 +24,13 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
#include <string>
#include <vector>
#include "moses/Factor.h"
#include "moses/LM/SingleFactor.h"
#include "moses/Hypothesis.h"
#include "moses/TypeDef.h"
#include "moses/Util.h"
#include "SingleFactor.h"
//this is required because:
//- IRSTLM package uses the namespace irstlm
@@ -44,6 +47,9 @@ class dictionary;
namespace Moses
{
//class LanguageModel;
class FFState;
class Phrase;
/** Implementation of single factor LM using IRST's code.
@@ -59,31 +65,47 @@ protected:
int m_empty; //code of an empty position
int m_lmtb_sentenceStart; //lmtb symbol to initialize ngram with
int m_lmtb_sentenceEnd; //lmtb symbol to initialize ngram with
int m_lmtb_size; //max ngram stored in the table
int m_lmtb_dub; //dictionary upperbound
int m_lmtb_size; //max ngram stored in the table
dictionary* d;
std::string m_mapFilePath;
void CreateFactors(FactorCollection &factorCollection);
int GetLmID( const Word &word ) const;
int GetLmID( const std::string &str ) const;
int GetLmID( const Factor *factor ) const;
dictionary* d;
public:
LanguageModelIRST(const std::string &line);
~LanguageModelIRST();
void SetParameter(const std::string& key, const std::string& value);
void Load();
virtual LMResult GetValue(const std::vector<const Word*> &contextFactor, State* finalState = NULL) const;
virtual FFState *EvaluateWhenApplied(const Hypothesis &hypo, const FFState *ps, ScoreComponentCollection *out) const;
virtual void CalcScore(const Phrase &phrase, float &fullScore, float &ngramScore, size_t &oovCount) const;
/*
virtual FFState *EvaluateWhenApplied(const ChartHypothesis& cur_hypo, int featureID, ScoreComponentCollection *accumulator) const;
virtual FFState *EvaluateWhenApplied(const Syntax::SHyperedge& hyperedge, int featureID, ScoreComponentCollection *accumulator) const;
*/
void InitializeForInput(InputType const& source);
void CleanUpAfterSentenceProcessing(const InputType& source);
void set_dictionary_upperbound(int dub) {
m_lmtb_size=dub ;
//m_lmtb->set_dictionary_upperbound(dub);
};
};