mosesdecoder/moses/LM/RDLM.h
2015-12-11 01:09:22 +00:00

290 lines
9.0 KiB
C++

#include <cstddef>
#include <map>
#include <string>
#include <utility>
#include <vector>
#include "moses/FF/StatefulFeatureFunction.h"
#include "moses/FF/FFState.h"
#include "moses/FF/InternalTree.h"
#include "moses/Word.h"
#include <boost/thread/tss.hpp>
#include <boost/array.hpp>
#ifdef WITH_THREADS
#include <boost/thread/shared_mutex.hpp>
#endif
// relational dependency language model, described in:
// Sennrich, Rico (2015). Modelling and Optimizing on Syntactic N-Grams for Statistical Machine Translation. Transactions of the Association for Computational Linguistics.
// see 'scripts/training/rdlm' for training scripts
namespace nplm
{
class neuralTM;
}
namespace Moses
{
namespace rdlm
{
// we re-use some short-lived objects to reduce the number of allocations;
// each thread gets its own instance to prevent collision
// [could be replaced with thread_local keyword in C++11]
// Bundle of reusable scratch containers plus two language-model pointers.
// One instance is meant to exist per thread, so that short-lived buffers can be
// reused between calls without threads colliding on them.
// All members are public; the constructor and destructor are defined outside this header.
class ThreadLocal
{
public:
// Reusable vectors of integer IDs. How each one is filled is decided by code
// outside this header; the names suggest ancestor context, the n-gram being
// scored, and child heads/labels in input and output vocabularies — TODO confirm.
std::vector<int> ancestor_heads;
std::vector<int> ancestor_labels;
std::vector<int> ngram;
std::vector<int> heads;
std::vector<int> labels;
std::vector<int> heads_output;
std::vector<int> labels_output;
// Reusable stack of (node, child iterator) pairs. Its type matches the
// persistent_stack parameter taken by RDLM::UnbinarizedChildren.
std::vector<std::pair<InternalTree*,std::vector<TreePointer>::const_iterator> > stack;
// Head and label model pointers.
// NOTE(review): presumably per-thread instances derived from the base instances
// passed to the constructor and released by the destructor — confirm in the .cpp.
nplm::neuralTM* lm_head;
nplm::neuralTM* lm_label;
// NOTE(review): a destructor is declared but no copy constructor or copy assignment;
// if the destructor frees lm_head/lm_label, copying this object would be unsafe — confirm.
ThreadLocal(nplm::neuralTM *lm_head_base_instance_, nplm::neuralTM *lm_label_base_instance_, bool normalizeHeadLM, bool normalizeLabelLM, int cacheSize);
~ThreadLocal();
};
}
class RDLMState : public TreeState
{
float m_approx_head; //score that was approximated due to lack of context
float m_approx_label;
size_t m_hash;
public:
RDLMState(TreePointer tree, float approx_head, float approx_label, size_t hash)
: TreeState(tree)
, m_approx_head(approx_head)
, m_approx_label(approx_label)
, m_hash(hash)
{}
float GetApproximateScoreHead() const {
return m_approx_head;
}
float GetApproximateScoreLabel() const {
return m_approx_label;
}
size_t GetHash() const {
return m_hash;
}
int Compare(const FFState& other) const {
if (m_hash == static_cast<const RDLMState*>(&other)->GetHash()) return 0;
else if (m_hash > static_cast<const RDLMState*>(&other)->GetHash()) return 1;
else return -1;
}
};
// Feature function for the relational dependency language model (RDLM).
// Only the constructor, a few trivial overrides and the UnbinarizedChildren
// helper are defined inline here; everything else is declared and defined elsewhere.
class RDLM : public StatefulFeatureFunction
{
  // Maps a tree node to a TreePointer. UnbinarizedChildren uses it to replace a
  // leaf non-terminal by the tree stored under that node's address.
  typedef std::map<InternalTree*,TreePointer> TreePointerMap;

  // Base model instances. NULL after construction; assigned by code outside this header.
  nplm::neuralTM* lm_head_base_instance_;
  nplm::neuralTM* lm_label_base_instance_;
  // Per-thread rdlm::ThreadLocal storage.
  mutable boost::thread_specific_ptr<rdlm::ThreadLocal> thread_objects_backend_;

  // Text from which m_glueSymbol is built; defaults to "Q".
  std::string m_glueSymbolString;
  // Special words created in the constructor.
  Word dummy_head;
  Word m_glueSymbol;
  Word m_startSymbol;
  Word m_endSymbol;
  Word m_endTag;

  // Configuration values. Defaults are set in the constructor's initializer list;
  // presumably overridden through SetParameter — confirm in the .cpp.
  std::string m_path_head_lm;
  std::string m_path_label_lm;
  bool m_isPretermBackoff;
  size_t m_context_left;
  size_t m_context_right;
  size_t m_context_up;
  bool m_premultiply;
  bool m_rerank;
  bool m_normalizeHeadLM;
  bool m_normalizeLabelLM;
  bool m_sharedVocab;
  std::string m_debugPath; // score all trees in the provided file, then exit
  int m_binarized;
  int m_cacheSize;

  // Values not set by the constructor; they are assigned by code outside this header.
  size_t offset_up_head;
  size_t offset_up_label;
  size_t size_head;
  size_t size_label;
  std::vector<int> static_label_null;
  std::vector<int> static_head_null;
  int static_dummy_head;
  int static_start_head;
  int static_start_label;
  int static_stop_head;
  int static_stop_label;
  int static_head_head;
  int static_head_label;
  int static_root_head;
  int static_root_label;
  int static_head_label_output;
  int static_stop_label_output;
  int static_start_label_output;

  FactorType m_factorType;

  // Presumably the model_type values accepted by Factor2ID — TODO confirm.
  static const int LABEL_INPUT = 0;
  static const int LABEL_OUTPUT = 1;
  static const int HEAD_INPUT = 2;
  static const int HEAD_OUTPUT = 3;

  // Lookup tables that const methods may modify (declared mutable).
  mutable std::vector<int> factor2id_label_input;
  mutable std::vector<int> factor2id_label_output;
  mutable std::vector<int> factor2id_head_input;
  mutable std::vector<int> factor2id_head_output;

#ifdef WITH_THREADS
  //reader-writer lock
  mutable boost::shared_mutex m_accessLock;
#endif

public:
  // Builds the feature from its configuration line.
  // The base class receives 2 and the line (presumably the number of score components).
  // ReadParameters() runs before the special words are created, so a glue symbol
  // set through it is presumably the one used for m_glueSymbol.
  RDLM(const std::string &line)
    : StatefulFeatureFunction(2, line)
    // Give the model pointers a defined value: a destructor is declared for this
    // class, and it must never see indeterminate pointers.
    , lm_head_base_instance_(NULL)
    , lm_label_base_instance_(NULL)
    , m_glueSymbolString("Q")
    , m_isPretermBackoff(true)
    , m_context_left(3)
    , m_context_right(0)
    , m_context_up(2)
    , m_premultiply(true)
    , m_rerank(false)
    , m_normalizeHeadLM(false)
    , m_normalizeLabelLM(false)
    , m_sharedVocab(false)
    , m_binarized(0)
    , m_cacheSize(1000000)
    , m_factorType(0) {
    ReadParameters();
    std::vector<FactorType> factors;
    factors.push_back(0);
    // The final flag is true for glue/start/end symbols and false for the dummy
    // head and the end tag (presumably "is non-terminal" — TODO confirm).
    dummy_head.CreateFromString(Output, factors, "<dummy_head>", false);
    m_glueSymbol.CreateFromString(Output, factors, m_glueSymbolString, true);
    m_startSymbol.CreateFromString(Output, factors, "SSTART", true);
    m_endSymbol.CreateFromString(Output, factors, "SEND", true);
    m_endTag.CreateFromString(Output, factors, "</s>", false);
  }

  ~RDLM();

  // Returns a newly allocated state with a null tree, zero approximate scores and hash 0.
  virtual const FFState* EmptyHypothesisState(const InputType &input) const {
    return new RDLMState(TreePointer(), 0, 0, 0);
  }

  void Score(InternalTree* root, const TreePointerMap & back_pointers, boost::array<float,4> &score, size_t &boundary_hash, rdlm::ThreadLocal &thread_objects, int num_virtual = 0, int rescoring_levels = 0) const;
  bool GetHead(InternalTree* root, const TreePointerMap & back_pointers, std::pair<int,int> & IDs) const;
  void GetChildHeadsAndLabels(InternalTree *root, const TreePointerMap & back_pointers, int reached_end, rdlm::ThreadLocal &thread_objects) const;
  void GetIDs(const Word & head, const Word & preterminal, std::pair<int,int> & IDs) const;
  int Factor2ID(const Factor * const factor, int model_type) const;
  void ScoreFile(std::string &path); //for debugging
  void PrintInfo(std::vector<int> &ngram, nplm::neuralTM* lm) const; //for debugging
  TreePointerMap AssociateLeafNTs(InternalTree* root, const std::vector<TreePointer> &previous) const;

  // Accepts every factor mask.
  bool IsUseable(const FactorMask &mask) const {
    return true;
  }

  void SetParameter(const std::string& key, const std::string& value);

  // Hypothesis-based evaluation is not supported: always throws util::Exception.
  FFState* EvaluateWhenApplied(
    const Hypothesis& cur_hypo,
    const FFState* prev_state,
    ScoreComponentCollection* accumulator) const {
    UTIL_THROW(util::Exception, "Not implemented");
  }

  FFState* EvaluateWhenApplied(
    const ChartHypothesis& /* cur_hypo */,
    int /* featureID - used to index the state in the previous hypotheses */,
    ScoreComponentCollection* accumulator) const;

  void Load(AllOptions::ptr const& opts);

  // Iterator-class that yields all children of a node; if child is virtual node of binarized tree, its children are yielded instead.
  //
  // Usage: read begin(), then call operator++ until ended() returns true.
  // The caller supplies the stack storage (persistent_stack), which is cleared
  // on construction and must outlive this object.
  class UnbinarizedChildren
  {
  private:
    std::vector<TreePointer>::const_iterator iter;   // current position
    std::vector<TreePointer>::const_iterator _begin; // first yielded child
    bool _ended;                                     // true once iteration is exhausted
    InternalTree* current;                           // node whose children iter walks
    const TreePointerMap & back_pointers;
    bool binarized;                                  // expand virtual nodes only if true
    // Positions to resume at after leaving an expanded virtual node.
    std::vector<std::pair<InternalTree*,std::vector<TreePointer>::const_iterator> > &stack;

    // A node counts as virtual when its label string (factor 0) is non-empty
    // and begins with '^'.
    static bool IsVirtualNode(const TreePointer &node) {
      return !node->GetLabel().GetString(0).empty()
             && node->GetLabel().GetString(0).data()[0] == '^';
    }

    // While iter points at a virtual node (and the tree is binarized), remember
    // the current position on the stack and descend to that node's first child.
    // Precondition: iter is dereferenceable. A leaf non-terminal must have an
    // entry in back_pointers; a missing entry is not checked here.
    void ExpandVirtualNodes() {
      while (binarized && IsVirtualNode(*iter)) {
        stack.push_back(std::make_pair(current, iter));
        // also go through trees or previous hypotheses to rescore nodes for which more context has become available
        if ((*iter)->IsLeafNT()) {
          current = back_pointers.find(iter->get())->second.get();
        } else {
          current = iter->get();
        }
        iter = current->GetChildren().begin();
      }
    }

  public:
    // Positions the iterator on the first (unbinarized) child of root.
    // If root has no children, ended() is true immediately and begin() must
    // not be dereferenced.
    UnbinarizedChildren(InternalTree* root, const TreePointerMap & pointers, bool binary, std::vector<std::pair<InternalTree*,std::vector<TreePointer>::const_iterator> > & persistent_stack):
      current(root),
      back_pointers(pointers),
      binarized(binary),
      stack(persistent_stack) {
      stack.resize(0);
      _ended = current->GetChildren().empty();
      iter = current->GetChildren().begin();
      // With no children, iter equals end() and must not be dereferenced,
      // so expansion is only attempted on a non-empty child list.
      if (!_ended) {
        ExpandVirtualNodes();
      }
      _begin = iter;
    }

    // Iterator to the first yielded child, as determined at construction.
    std::vector<TreePointer>::const_iterator begin() const {
      return _begin;
    }

    // True once there are no more children to yield.
    bool ended() const {
      return _ended;
    }

    // Advances to the next child and returns the new position. When all
    // children are exhausted, ended() becomes true and the returned iterator
    // must not be dereferenced.
    std::vector<TreePointer>::const_iterator operator++() {
      ++iter;
      if (iter == current->GetChildren().end()) {
        // Finished one child list: resume in enclosing nodes, skipping any
        // that are exhausted as well.
        while (!stack.empty()) {
          std::pair<InternalTree*,std::vector<TreePointer>::const_iterator> & active = stack.back();
          current = active.first;
          iter = ++active.second;
          stack.pop_back();
          if (iter != current->GetChildren().end()) {
            break;
          }
        }
        if (iter == current->GetChildren().end()) {
          _ended = true;
          return iter;
        }
      }
      ExpandVirtualNodes();
      return iter;
    }
  };
};
}