2013-11-05 18:37:56 +04:00
|
|
|
|
2013-11-14 22:31:46 +04:00
|
|
|
#include <boost/functional/hash.hpp>
|
2013-12-16 18:17:56 +04:00
|
|
|
#include "moses/FF/FFState.h"
|
2013-11-21 23:19:34 +04:00
|
|
|
#include "DALMWrapper.h"
|
2013-11-11 22:27:15 +04:00
|
|
|
#include "logger.h"
|
2013-11-17 09:42:13 +04:00
|
|
|
#include "dalm.h"
|
2013-11-11 23:49:00 +04:00
|
|
|
#include "vocabulary.h"
|
2013-12-16 18:17:56 +04:00
|
|
|
#include "moses/FactorTypeSet.h"
|
2013-11-14 22:31:46 +04:00
|
|
|
#include "moses/FactorCollection.h"
|
2013-11-18 17:54:40 +04:00
|
|
|
#include "moses/InputFileStream.h"
|
|
|
|
#include "util/exception.hh"
|
2014-02-17 19:00:28 +04:00
|
|
|
#include "moses/ChartHypothesis.h"
|
|
|
|
#include "moses/ChartManager.h"
|
2013-11-05 18:37:56 +04:00
|
|
|
|
|
|
|
using namespace std;
|
|
|
|
|
2013-11-11 18:39:53 +04:00
|
|
|
/////////////////////////
|
2013-11-17 09:42:13 +04:00
|
|
|
/**
 * Parse a DALM ini file consisting of KEY=VALUE lines.
 *
 * The keys MODEL, WORDS and WORDSTXT are recognised; their values are stored
 * in the corresponding output argument. Any other key is ignored, as is any
 * line that contains no '='. An output argument whose key never appears is
 * left untouched, and an unreadable file leaves all three untouched.
 *
 * @param inifile  path of the ini file to read
 * @param model    receives the value of MODEL
 * @param words    receives the value of WORDS
 * @param wordstxt receives the value of WORDSTXT
 */
void read_ini(const char *inifile, std::string &model, std::string &words, std::string &wordstxt){
  std::ifstream ifs(inifile);
  std::string line;

  while(std::getline(ifs, line)){
    // Tolerate ini files saved with CRLF line endings.
    if(!line.empty() && line[line.size()-1] == '\r'){
      line.erase(line.size()-1);
    }

    // size_type (not unsigned int) so that npos is representable.
    const std::string::size_type pos = line.find('=');
    if(pos == std::string::npos) continue; // not a KEY=VALUE line

    const std::string key = line.substr(0, pos);
    const std::string value = line.substr(pos+1);

    if(key=="MODEL"){
      model = value;
    }else if(key=="WORDS"){
      words = value;
    }else if(key=="WORDSTXT"){
      wordstxt = value;
    }
  }
}
|
|
|
|
|
|
|
|
/////////////////////////
|
|
|
|
|
2013-11-11 23:49:00 +04:00
|
|
|
namespace Moses
|
|
|
|
{
|
2013-12-16 18:17:56 +04:00
|
|
|
|
|
|
|
class DALMState : public FFState
|
|
|
|
{
|
|
|
|
private:
|
|
|
|
DALM::State *state;
|
|
|
|
|
|
|
|
public:
|
|
|
|
DALMState(unsigned short order){
|
|
|
|
state = new DALM::State(order);
|
|
|
|
}
|
|
|
|
|
|
|
|
DALMState(const DALMState &from){
|
|
|
|
state = new DALM::State(*from.state);
|
|
|
|
}
|
|
|
|
|
|
|
|
virtual ~DALMState(){
|
|
|
|
delete state;
|
|
|
|
}
|
|
|
|
|
2014-02-17 19:00:28 +04:00
|
|
|
void reset(const DALMState &from){
|
|
|
|
delete state;
|
|
|
|
state = new DALM::State(*from.state);
|
|
|
|
}
|
|
|
|
|
2014-02-28 08:38:35 +04:00
|
|
|
void reset(DALM::State *s){
|
|
|
|
delete state;
|
|
|
|
state = s;
|
|
|
|
}
|
|
|
|
|
2013-12-16 18:17:56 +04:00
|
|
|
virtual int Compare(const FFState& other) const{
|
|
|
|
const DALMState &o = static_cast<const DALMState &>(other);
|
|
|
|
if(state->get_count() < o.state->get_count()) return -1;
|
|
|
|
else if(state->get_count() > o.state->get_count()) return 1;
|
2013-12-17 09:36:19 +04:00
|
|
|
else return state->compare(o.state);
|
2013-12-16 18:17:56 +04:00
|
|
|
}
|
|
|
|
|
|
|
|
DALM::State *get_state() const{
|
|
|
|
return state;
|
|
|
|
}
|
|
|
|
|
|
|
|
void refresh(){
|
|
|
|
state->refresh();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
2014-02-17 19:00:28 +04:00
|
|
|
class DALMChartState : public FFState
|
|
|
|
{
|
|
|
|
private:
|
2014-02-28 08:38:35 +04:00
|
|
|
const ChartHypothesis &hypo;
|
|
|
|
DALM::Fragment *prefixFragments;
|
|
|
|
unsigned short prefixLength;
|
2014-02-17 19:00:28 +04:00
|
|
|
float prefixScore;
|
|
|
|
DALMState *rightContext;
|
|
|
|
bool isLarge;
|
|
|
|
|
|
|
|
public:
|
2014-02-19 12:49:42 +04:00
|
|
|
DALMChartState(
|
|
|
|
const ChartHypothesis &hypo,
|
2014-02-28 08:38:35 +04:00
|
|
|
DALM::Fragment *prefixFragments,
|
|
|
|
unsigned short prefixLength,
|
2014-02-19 12:49:42 +04:00
|
|
|
float prefixScore,
|
|
|
|
DALMState *rightContext,
|
|
|
|
bool isLarge)
|
2014-02-28 08:38:35 +04:00
|
|
|
: hypo(hypo),
|
|
|
|
prefixFragments(prefixFragments),
|
2014-02-19 12:49:42 +04:00
|
|
|
prefixLength(prefixLength),
|
|
|
|
prefixScore(prefixScore),
|
|
|
|
rightContext(rightContext),
|
|
|
|
isLarge(isLarge)
|
2014-02-17 19:00:28 +04:00
|
|
|
{}
|
|
|
|
|
|
|
|
virtual ~DALMChartState(){
|
2014-02-28 08:38:35 +04:00
|
|
|
delete [] prefixFragments;
|
|
|
|
delete rightContext;
|
2014-02-17 19:00:28 +04:00
|
|
|
}
|
|
|
|
|
2014-02-28 08:38:35 +04:00
|
|
|
unsigned short GetPrefixLength() const{
|
2014-02-17 19:00:28 +04:00
|
|
|
return prefixLength;
|
|
|
|
}
|
|
|
|
|
2014-02-28 08:38:35 +04:00
|
|
|
const DALM::Fragment *GetPrefixFragments() const{
|
|
|
|
return prefixFragments;
|
2014-02-17 19:00:28 +04:00
|
|
|
}
|
|
|
|
|
|
|
|
float GetPrefixScore() const{
|
|
|
|
return prefixScore;
|
|
|
|
}
|
|
|
|
|
|
|
|
const DALMState *GetRightContext() const{
|
|
|
|
return rightContext;
|
|
|
|
}
|
|
|
|
|
|
|
|
bool LargeEnough() const{
|
|
|
|
return isLarge;
|
|
|
|
}
|
|
|
|
|
|
|
|
virtual int Compare(const FFState& other) const{
|
|
|
|
const DALMChartState &o = static_cast<const DALMChartState &>(other);
|
|
|
|
// prefix
|
2014-02-28 08:38:35 +04:00
|
|
|
if (hypo.GetCurrSourceRange().GetStartPos() > 0) { // not for "<s> ..."
|
2014-02-19 12:49:42 +04:00
|
|
|
if (prefixLength != o.prefixLength){
|
2014-02-17 19:00:28 +04:00
|
|
|
return (prefixLength < o.prefixLength)?-1:1;
|
2014-02-19 12:49:42 +04:00
|
|
|
} else {
|
2014-02-28 08:38:35 +04:00
|
|
|
if(prefixLength > 0){
|
|
|
|
DALM::Fragment &f = prefixFragments[prefixLength-1];
|
|
|
|
DALM::Fragment &of = o.prefixFragments[prefixLength-1];
|
|
|
|
int ret = DALM::compare_fragments(f, of);
|
|
|
|
if(ret != 0) return ret;
|
|
|
|
}
|
2014-02-17 19:00:28 +04:00
|
|
|
}
|
|
|
|
}
|
2014-02-19 12:49:42 +04:00
|
|
|
|
2014-02-17 19:00:28 +04:00
|
|
|
// suffix
|
2014-02-28 08:38:35 +04:00
|
|
|
size_t inputSize = hypo.GetManager().GetSource().GetSize();
|
|
|
|
if (hypo.GetCurrSourceRange().GetEndPos() < inputSize - 1) { // not for "... </s>"
|
2014-02-19 12:49:42 +04:00
|
|
|
int ret = o.rightContext->Compare(*rightContext);
|
2014-02-17 19:00:28 +04:00
|
|
|
if (ret != 0) return ret;
|
|
|
|
}
|
|
|
|
return 0;
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
2013-11-05 18:37:56 +04:00
|
|
|
/**
 * Construct the feature from its configuration line.
 *
 * Reads the parameters (dispatched to SetParameter) and falls back to factor
 * 0 when no factor type was configured. The model itself is loaded later by
 * Load().
 */
LanguageModelDALM::LanguageModelDALM(const std::string &line)
  :LanguageModel(line)
{
  // These are only allocated in Load(), but the destructor deletes them
  // unconditionally; start from NULL so that destroying an object that was
  // never loaded is well defined.
  m_logger = NULL;
  m_vocab = NULL;
  m_lm = NULL;

  ReadParameters();

  if (m_factorType == NOT_FOUND) {
    m_factorType = 0;
  }
}
|
|
|
|
|
|
|
|
/**
 * Release the objects allocated by Load().
 *
 * They are destroyed in reverse order of construction: Load() builds m_lm
 * from references to m_vocab and m_logger, and m_vocab from a reference to
 * m_logger, so each object must outlive the ones constructed from it in case
 * their destructors still use those references.
 */
LanguageModelDALM::~LanguageModelDALM()
{
  delete m_lm;
  delete m_vocab;
  delete m_logger;
}
|
|
|
|
|
2013-11-11 18:39:53 +04:00
|
|
|
void LanguageModelDALM::Load()
|
|
|
|
{
|
|
|
|
/////////////////////
|
|
|
|
// READING INIFILE //
|
|
|
|
/////////////////////
|
2014-02-17 19:00:28 +04:00
|
|
|
string inifile= m_filePath + "/dalm.ini";
|
2013-12-31 20:52:47 +04:00
|
|
|
|
2013-11-11 18:39:53 +04:00
|
|
|
string model; // Path to the double-array file.
|
|
|
|
string words; // Path to the vocabulary file.
|
2013-11-17 09:42:13 +04:00
|
|
|
string wordstxt; //Path to the vocabulary file in text format.
|
2013-12-31 20:52:47 +04:00
|
|
|
read_ini(inifile.c_str(), model, words, wordstxt);
|
|
|
|
|
|
|
|
model = m_filePath + "/" + model;
|
2014-02-17 19:00:28 +04:00
|
|
|
words = m_filePath + "/" + words;
|
|
|
|
wordstxt = m_filePath + "/" + wordstxt;
|
2013-11-11 18:39:53 +04:00
|
|
|
|
2013-11-18 17:54:40 +04:00
|
|
|
UTIL_THROW_IF(model.empty() || words.empty() || wordstxt.empty(),
|
|
|
|
util::FileOpenException,
|
|
|
|
"Failed to read DALM ini file " << m_filePath << ". Probably doesn't exist");
|
|
|
|
|
2013-11-11 18:39:53 +04:00
|
|
|
////////////////
|
|
|
|
// LOADING LM //
|
|
|
|
////////////////
|
|
|
|
|
|
|
|
// Preparing a logger object.
|
2013-11-11 21:19:44 +04:00
|
|
|
m_logger = new DALM::Logger(stderr);
|
|
|
|
m_logger->setLevel(DALM::LOGGER_INFO);
|
|
|
|
|
|
|
|
// Load the vocabulary file.
|
|
|
|
m_vocab = new DALM::Vocabulary(words, *m_logger);
|
|
|
|
|
|
|
|
// Load the language model.
|
|
|
|
m_lm = new DALM::LM(model, *m_vocab, *m_logger);
|
2013-12-16 18:17:56 +04:00
|
|
|
|
2013-11-11 23:49:00 +04:00
|
|
|
wid_start = m_vocab->lookup(BOS_);
|
|
|
|
wid_end = m_vocab->lookup(EOS_);
|
2013-11-18 17:54:40 +04:00
|
|
|
|
|
|
|
// vocab mapping
|
|
|
|
CreateVocabMapping(wordstxt);
|
2013-12-16 18:17:56 +04:00
|
|
|
|
|
|
|
FactorCollection &collection = FactorCollection::Instance();
|
|
|
|
m_beginSentenceFactor = collection.AddFactor(BOS_);
|
2013-11-11 18:39:53 +04:00
|
|
|
}
|
|
|
|
|
2013-12-16 18:17:56 +04:00
|
|
|
/**
 * Build the LM state of the empty hypothesis: a new DALMState of order
 * m_nGramOrder whose wrapped state has been passed through
 * DALM::LM::init_state. The caller receives ownership of the returned object.
 */
const FFState *LanguageModelDALM::EmptyHypothesisState(const InputType &/*input*/) const{
  DALMState *initial = new DALMState(m_nGramOrder);
  DALM::State *raw = initial->get_state();
  m_lm->init_state(*raw);
  return initial;
}
|
2013-11-11 23:49:00 +04:00
|
|
|
|
2013-12-16 18:17:56 +04:00
|
|
|
void LanguageModelDALM::CalcScore(const Phrase &phrase, float &fullScore, float &ngramScore, size_t &oovCount) const{
|
|
|
|
fullScore = 0;
|
|
|
|
ngramScore = 0;
|
|
|
|
|
|
|
|
oovCount = 0;
|
|
|
|
|
|
|
|
size_t phraseSize = phrase.GetSize();
|
|
|
|
if (!phraseSize) return;
|
|
|
|
|
|
|
|
size_t currPos = 0;
|
|
|
|
size_t hist_count = 0;
|
2014-02-14 15:22:53 +04:00
|
|
|
DALMState *dalm_state = new DALMState(m_nGramOrder);
|
|
|
|
DALM::State *state = dalm_state->get_state();
|
|
|
|
|
|
|
|
if(phrase.GetWord(0).GetFactor(m_factorType) == m_beginSentenceFactor){
|
|
|
|
m_lm->init_state(*state);
|
|
|
|
currPos++;
|
|
|
|
hist_count++;
|
|
|
|
}
|
2013-12-16 18:17:56 +04:00
|
|
|
|
|
|
|
while (currPos < phraseSize) {
|
|
|
|
const Word &word = phrase.GetWord(currPos);
|
|
|
|
hist_count++;
|
|
|
|
|
|
|
|
if (word.IsNonTerminal()) {
|
2014-02-14 15:22:53 +04:00
|
|
|
state->refresh();
|
2013-12-16 18:17:56 +04:00
|
|
|
hist_count = 0;
|
|
|
|
} else {
|
2014-02-14 15:22:53 +04:00
|
|
|
DALM::VocabId wid = GetVocabId(word.GetFactor(m_factorType));
|
|
|
|
float score = m_lm->query(wid, *state);
|
|
|
|
fullScore += score;
|
|
|
|
if (hist_count >= m_nGramOrder) ngramScore += score;
|
|
|
|
if (wid==m_vocab->unk()) ++oovCount;
|
2013-12-16 18:17:56 +04:00
|
|
|
}
|
|
|
|
|
|
|
|
currPos++;
|
2013-11-11 23:49:00 +04:00
|
|
|
}
|
|
|
|
|
2014-02-14 15:22:53 +04:00
|
|
|
fullScore = TransformLMScore(fullScore);
|
|
|
|
ngramScore = TransformLMScore(ngramScore);
|
|
|
|
delete dalm_state;
|
2013-12-16 18:17:56 +04:00
|
|
|
}
|
|
|
|
|
|
|
|
/**
 * Score the words a phrase-based hypothesis has just added.
 *
 * Only the n-grams that overlap the phrase boundary are scored here (at most
 * the first m_nGramOrder-1 new words); phrase-internal scores come from the
 * translation option. When the source is fully covered the end-of-sentence
 * word is scored as well.
 *
 * @param hypo hypothesis being extended
 * @param ps   previous state; treated as a DALMState
 * @param out  score collection that receives the transformed score
 * @return newly allocated state for the extended hypothesis; for an empty
 *         target phrase, a copy of @p ps (or NULL if @p ps is NULL)
 */
FFState *LanguageModelDALM::Evaluate(const Hypothesis &hypo, const FFState *ps, ScoreComponentCollection *out) const{
  const DALMState *previous = static_cast<const DALMState *>(ps);

  // Nothing was added: hand back a copy of the incoming state.
  if (hypo.GetCurrTargetLength() == 0){
    if (!previous) return NULL;
    return new DALMState(*previous);
  }

  // Half-open range [first, last) of the newly added target words.
  const std::size_t first = hypo.GetCurrTargetWordsRange().GetStartPos();
  const std::size_t last = hypo.GetCurrTargetWordsRange().GetEndPos() + 1;
  // Words past this point no longer overlap the phrase boundary.
  const std::size_t boundaryEnd = std::min(last, first + m_nGramOrder - 1);

  DALMState *result = new DALMState(*previous);
  DALM::State *lmState = result->get_state();

  float score = 0.0;
  std::size_t pos = first;
  while (pos < boundaryEnd) {
    const DALM::VocabId wid = GetVocabId(hypo.GetWord(pos).GetFactor(m_factorType));
    score += m_lm->query(wid, *lmState);
    ++pos;
  }

  const bool sourceCompleted = hypo.IsSourceCompleted();
  if (sourceCompleted || boundaryEnd < last) {
    // Rebuild the state from the last words of the hypothesis; this is needed
    // before scoring the end of sentence and after adding a long phrase.
    std::vector<DALM::VocabId> ids(m_nGramOrder-1);
    DALM::VocabId *idsBegin = &ids.front();
    const DALM::VocabId *idsEnd = LastIDs(hypo, idsBegin);
    m_lm->set_state(idsBegin, (idsEnd-idsBegin), *lmState);

    if (sourceCompleted) {
      // Score end of sentence.
      score += m_lm->query(wid_end, *lmState);
    }
  }

  score = TransformLMScore(score);
  if (OOVFeatureEnabled()) {
    // Second component (the OOV slot) is left at zero here.
    std::vector<float> scores(2);
    scores[0] = score;
    scores[1] = 0.0;
    out->PlusEquals(this, scores);
  } else {
    out->PlusEquals(this, score);
  }

  return result;
}
|
|
|
|
|
|
|
|
/**
 * Score a chart hypothesis and build its DALMChartState.
 *
 * Walks the target side of the applied rule. Terminals are queried against
 * the running LM state; while the hypothesis is not yet "large" each query
 * also records a prefix fragment and adds to the prefix score. For a
 * non-terminal the prefix fragments of the underlying hypothesis are
 * re-scored in the current context and its right context becomes the running
 * state.
 *
 * @param hypo      hypothesis to score
 * @param featureID index used to fetch this feature's state from the
 *                  underlying hypotheses
 * @param out       score collection; the transformed total is assigned to it
 * @return newly allocated DALMChartState owning the prefix fragments and the
 *         final LM state
 */
FFState *LanguageModelDALM::EvaluateChart(const ChartHypothesis& hypo, int featureID, ScoreComponentCollection *out) const{
  // initialize language model context state
  DALMState *context = new DALMState(m_nGramOrder);
  DALM::State *lmState = context->get_state();

  const size_t contextSize = m_nGramOrder-1;
  DALM::Fragment *prefixFragments = new DALM::Fragment[contextSize];
  unsigned short prefixLength = 0;
  bool isLarge = false;

  // initial language model scores
  float prefixScore = 0.0; // not yet final for initial words (lack context)
  float hypoScore = 0.0;   // total hypothesis score.

  const TargetPhrase &rule = hypo.GetCurrTargetPhrase();
  const size_t ruleSize = rule.GetSize();

  // get index map for underlying hypotheses
  const AlignmentInfo::NonTermIndexMap &nonTermIndexMap =
    rule.GetAlignNonTerm().GetNonTermIndexMap();

  size_t pos = 0;

  // beginning of sentence.
  if(ruleSize > 0){
    const Word &firstWord = rule.GetWord(0);
    if(firstWord.IsNonTerminal()){
      // special case: rule starts with non-terminal -> copy everything
      const size_t nonTermIndex = nonTermIndexMap[0];
      const ChartHypothesis *prevHypo = hypo.GetPrevHypo(nonTermIndex);
      const DALMChartState *prevState =
        static_cast<const DALMChartState*>(prevHypo->GetFFState(featureID));

      // get prefixScore and hypoScore
      prefixScore = prevState->GetPrefixScore();
      hypoScore = UntransformLMScore(prevHypo->GetScoreBreakdown().GetScoresForProducer(this)[0]);

      // get language model state
      context->reset(*prevState->GetRightContext());
      lmState = context->get_state();

      // copy the prefix fragments
      prefixLength = prevState->GetPrefixLength();
      const DALM::Fragment *prevFragments = prevState->GetPrefixFragments();
      std::memcpy(prefixFragments, prevFragments, sizeof(DALM::Fragment)*prefixLength);
      isLarge = prevState->LargeEnough();
    }else{
      const Factor *factor = firstWord.GetFactor(m_factorType);
      if(factor == m_beginSentenceFactor){
        m_lm->init_state(*lmState);
        // state is finalized.
        isLarge = true;
      }else{
        // isLarge is still false here, so the first word always starts the prefix.
        const DALM::VocabId wid = GetVocabId(factor);
        float score = m_lm->query(wid, *lmState, prefixFragments[prefixLength]);

        prefixScore += score;
        hypoScore += score;
        prefixLength++;
        if(prefixLength >= contextSize) isLarge = true;
      }
    }
    pos++;
  }

  // loop over rule
  while (pos < ruleSize) {
    // consult rule for either word or non-terminal
    const Word &word = rule.GetWord(pos);

    if (!word.IsNonTerminal()) {
      // regular word
      const DALM::VocabId wid = GetVocabId(word.GetFactor(m_factorType));
      if (isLarge) {
        hypoScore += m_lm->query(wid, *lmState);
      }else{
        float score = m_lm->query(wid, *lmState, prefixFragments[prefixLength]);
        prefixScore += score;
        hypoScore += score;
        prefixLength++;
        if(prefixLength >= contextSize) isLarge = true;
      }
    } else {
      // internal non-terminal: add phrase from underlying hypothesis

      // look up underlying hypothesis
      const size_t nonTermIndex = nonTermIndexMap[pos];
      const ChartHypothesis *prevHypo = hypo.GetPrevHypo(nonTermIndex);
      const DALMChartState *prevState =
        static_cast<const DALMChartState*>(prevHypo->GetFFState(featureID));

      const size_t prevPrefixLength = prevState->GetPrefixLength();
      const DALM::Fragment *prevFragments = prevState->GetPrefixFragments();
      DALM::Gap gap(*lmState);

      // score its prefix
      for(size_t i = 0; i < prevPrefixLength; i++) {
        const DALM::Fragment &fragment = prevFragments[i];

        if (isLarge) {
          hypoScore += m_lm->query(fragment, *lmState, gap);
        } else {
          float score = m_lm->query(fragment, *lmState, gap, prefixFragments[prefixLength]);
          prefixScore += score;
          hypoScore += score;
          prefixLength++;
          if(prefixLength >= contextSize) isLarge = true;
        }
        gap.succ();
      }

      // check if we are dealing with a large sub-phrase
      if (prevState->LargeEnough()) {
        // add its language model score
        hypoScore += UntransformLMScore(prevHypo->GetScoreBreakdown().GetScoresForProducer(this)[0]);
        hypoScore -= prevState->GetPrefixScore(); // remove the overlapping score.
        // copy language model state
        context->reset(*prevState->GetRightContext());
      } else {
        // Combine a copy of the sub-phrase's right context with the current
        // state via set_state, then let the wrapper take ownership of it.
        DALM::State *merged = new DALM::State(*prevState->GetRightContext()->get_state());
        m_lm->set_state(*merged, *lmState, gap);
        context->reset(merged);
      }
      // Either way the wrapper now holds a different DALM::State.
      lmState = context->get_state();
    }

    pos++;
  }

  // assign combined score to score breakdown
  out->Assign(this, TransformLMScore(hypoScore));

  return new DALMChartState(hypo, prefixFragments, prefixLength, prefixScore, context, isLarge);
}
|
|
|
|
|
|
|
|
bool LanguageModelDALM::IsUseable(const FactorMask &mask) const
|
|
|
|
{
|
2014-02-17 19:00:28 +04:00
|
|
|
return mask[m_factorType];
|
2013-11-05 18:37:56 +04:00
|
|
|
}
|
|
|
|
|
2013-11-18 17:54:40 +04:00
|
|
|
void LanguageModelDALM::CreateVocabMapping(const std::string &wordstxt)
|
|
|
|
{
|
|
|
|
InputFileStream vocabStrm(wordstxt);
|
|
|
|
|
2014-02-14 15:22:53 +04:00
|
|
|
std::vector< std::pair<std::size_t, DALM::VocabId> > vlist;
|
2013-11-18 17:54:40 +04:00
|
|
|
string line;
|
2014-02-14 15:22:53 +04:00
|
|
|
std::size_t max_fid = 0;
|
2013-11-18 17:54:40 +04:00
|
|
|
while(getline(vocabStrm, line)) {
|
|
|
|
const Factor *factor = FactorCollection::Instance().AddFactor(line);
|
2014-02-14 15:22:53 +04:00
|
|
|
std::size_t fid = factor->GetId();
|
2013-11-18 17:54:40 +04:00
|
|
|
DALM::VocabId wid = m_vocab->lookup(line.c_str());
|
|
|
|
|
2014-02-14 15:22:53 +04:00
|
|
|
vlist.push_back(std::pair<std::size_t, DALM::VocabId>(fid, wid));
|
|
|
|
if(max_fid < fid) max_fid = fid;
|
2013-11-18 17:54:40 +04:00
|
|
|
}
|
|
|
|
|
2014-02-14 15:22:53 +04:00
|
|
|
for(std::size_t i = 0; i < m_vocabMap.size(); i++){
|
|
|
|
m_vocabMap[i] = m_vocab->unk();
|
|
|
|
}
|
|
|
|
|
|
|
|
m_vocabMap.resize(max_fid+1, m_vocab->unk());
|
|
|
|
std::vector< std::pair<std::size_t, DALM::VocabId> >::iterator it = vlist.begin();
|
|
|
|
while(it != vlist.end()){
|
|
|
|
std::pair<std::size_t, DALM::VocabId> &entry = *it;
|
|
|
|
m_vocabMap[entry.first] = entry.second;
|
|
|
|
|
|
|
|
++it;
|
|
|
|
}
|
2013-11-18 17:54:40 +04:00
|
|
|
}
|
|
|
|
|
2013-11-11 23:49:00 +04:00
|
|
|
/**
 * Translate a Moses factor into a DALM vocabulary id.
 *
 * @param factor factor to translate; must not be NULL
 * @return the id stored in m_vocabMap for the factor's id, or the
 *         unknown-word id when the factor id lies outside the table
 */
DALM::VocabId LanguageModelDALM::GetVocabId(const Factor *factor) const
{
  const std::size_t fid = factor->GetId();
  if (fid < m_vocabMap.size()) {
    return m_vocabMap[fid];
  }
  return m_vocab->unk();
}
|
|
|
|
|
2013-12-16 18:17:56 +04:00
|
|
|
void LanguageModelDALM::SetParameter(const std::string& key, const std::string& value)
|
|
|
|
{
|
|
|
|
if (key == "factor") {
|
|
|
|
m_factorType = Scan<FactorType>(value);
|
|
|
|
} else if (key == "order") {
|
|
|
|
m_nGramOrder = Scan<size_t>(value);
|
|
|
|
} else if (key == "path") {
|
|
|
|
m_filePath = value;
|
|
|
|
} else {
|
|
|
|
LanguageModel::SetParameter(key, value);
|
|
|
|
}
|
2013-11-05 18:37:56 +04:00
|
|
|
}
|
|
|
|
|
2013-12-16 18:17:56 +04:00
|
|
|
}
|