mosesdecoder/moses/LM/BilingualLM.h

137 lines
3.9 KiB
C
Raw Permalink Normal View History

2014-08-21 19:10:20 +04:00
#pragma once
#include <string>
#include "moses/FF/StatefulFeatureFunction.h"
#include "moses/FF/FFState.h"
2014-08-21 20:46:48 +04:00
#include <boost/thread/tss.hpp>
2014-08-26 15:31:35 +04:00
#include "moses/Hypothesis.h"
#include "moses/ChartHypothesis.h"
2014-08-26 15:31:35 +04:00
#include "moses/InputPath.h"
#include "moses/Manager.h"
#include "moses/ChartManager.h"
2014-08-29 16:43:31 +04:00
#include "moses/FactorCollection.h"
2014-08-21 20:05:06 +04:00
2014-08-21 19:10:20 +04:00
namespace Moses
{
class BilingualLMState : public FFState
{
2014-09-02 19:07:18 +04:00
size_t m_hash;
std::vector<int> word_alignments; //Carry the word alignments. For hierarchical
std::vector<int> neuralLM_ids; //Carry the neuralLMids of the previous target phrase to avoid calling GetWholePhrase. Hiero only.
2014-08-21 19:10:20 +04:00
public:
2014-09-02 19:07:18 +04:00
BilingualLMState(size_t hash)
2015-01-14 14:07:42 +03:00
:m_hash(hash) {
}
BilingualLMState(size_t hash, std::vector<int>& word_alignments_vec, std::vector<int>& neural_ids)
:m_hash(hash)
, word_alignments(word_alignments_vec)
2015-01-14 14:07:42 +03:00
, neuralLM_ids(neural_ids) {
}
2014-09-02 20:05:46 +04:00
const std::vector<int>& GetWordAlignmentVector() const {
return word_alignments;
2014-09-02 20:05:46 +04:00
}
2014-08-21 19:10:20 +04:00
const std::vector<int>& GetWordIdsVector() const {
return neuralLM_ids;
}
virtual size_t hash() const {
2015-10-16 15:53:33 +03:00
return m_hash;
}
2015-10-16 15:53:33 +03:00
virtual bool operator==(const FFState& other) const {
const BilingualLMState &otherState = static_cast<const BilingualLMState&>(other);
return m_hash == otherState.m_hash;
}
2014-08-21 19:10:20 +04:00
};
2015-01-14 14:07:42 +03:00
class BilingualLM : public StatefulFeatureFunction
{
private:
2014-09-27 02:20:26 +04:00
virtual float Score(std::vector<int>& source_words, std::vector<int>& target_words) const = 0;
2014-08-21 20:05:06 +04:00
2014-09-27 02:20:26 +04:00
virtual int getNeuralLMId(const Word& word, bool is_source_word) const = 0;
virtual void loadModel() = 0;
2014-10-14 21:22:04 +04:00
virtual const Word& getNullWord() const = 0;
size_t selectMiddleAlignment(const std::set<size_t>& alignment_links) const;
void getSourceWords(
2015-01-14 14:07:42 +03:00
const TargetPhrase &targetPhrase,
int targetWordIdx,
const Sentence &source_sent,
2015-10-25 16:37:59 +03:00
const Range &sourceWordRange,
2015-01-14 14:07:42 +03:00
std::vector<int> &words) const;
void appendSourceWordsToVector(const Sentence &source_sent, std::vector<int> &words, int source_word_mid_idx) const;
void getTargetWords(
2015-01-14 14:07:42 +03:00
const Hypothesis &cur_hypo,
const TargetPhrase &targetPhrase,
int current_word_index,
std::vector<int> &words) const;
size_t getState(const Hypothesis &cur_hypo) const;
void requestPrevTargetNgrams(const Hypothesis &cur_hypo, int amount, std::vector<int> &words) const;
2014-08-28 21:09:47 +04:00
//Chart decoder
void getTargetWordsChart(
std::vector<int>& neuralLMids,
int current_word_index,
2014-10-15 20:00:09 +04:00
std::vector<int>& words,
bool sentence_begin) const;
size_t getStateChart(std::vector<int>& neuralLMids) const;
//Get a vector of all target words IDs in the beginning of calculating NeuralLMids for the current phrase.
void getAllTargetIdsChart(const ChartHypothesis& cur_hypo, size_t featureID, std::vector<int>& wordIds) const;
//Get a vector of all alignments (mid_idx word)
void getAllAlignments(const ChartHypothesis& cur_hypo, size_t featureID, std::vector<int>& alignemnts) const;
2014-09-02 20:05:46 +04:00
2014-08-21 20:05:06 +04:00
protected:
// big data (vocab, weights, cache) shared among threads
std::string m_filePath;
2014-08-22 15:41:13 +04:00
int target_ngrams;
int source_ngrams;
//NeuralLM lookup
2014-09-03 02:27:16 +04:00
FactorType word_factortype;
FactorType pos_factortype;
2014-08-29 16:43:31 +04:00
const Factor* BOS_factor;
const Factor* EOS_factor;
2014-10-16 15:36:19 +04:00
mutable Word BOS_word;
mutable Word EOS_word;
2014-08-21 19:10:20 +04:00
public:
BilingualLM(const std::string &line);
bool IsUseable(const FactorMask &mask) const {
return true;
}
virtual const FFState* EmptyHypothesisState(const InputType &input) const {
return new BilingualLMState(0);
}
2015-12-10 06:17:36 +03:00
void Load(AllOptions::ptr const& opts);
2014-08-21 20:05:06 +04:00
2014-08-21 19:10:20 +04:00
FFState* EvaluateWhenApplied(
2015-01-14 14:07:42 +03:00
const Hypothesis& cur_hypo,
const FFState* prev_state,
ScoreComponentCollection* accumulator) const;
2014-09-26 19:25:48 +04:00
2014-08-21 19:10:20 +04:00
FFState* EvaluateWhenApplied(
2015-01-14 14:07:42 +03:00
const ChartHypothesis& cur_hypo ,
int featureID, /* - used to index the state in the previous hypotheses */
ScoreComponentCollection* accumulator) const;
2014-08-21 19:10:20 +04:00
void SetParameter(const std::string& key, const std::string& value);
};
}