From bec950cf728993ab6b7f846cc57f11d34d3f41f1 Mon Sep 17 00:00:00 2001 From: Rico Sennrich Date: Fri, 17 Jul 2015 17:27:31 +0100 Subject: [PATCH] support factors in InternalTree --- mert/HwcmScorer.h | 5 +- mert/InternalTree.cpp | 113 ++++++++++++++++++++++++ mert/InternalTree.h | 77 +++++++++++++++++ mert/Jamfile | 2 +- moses/FF/InternalTree.cpp | 137 ++++++------------------------ moses/FF/InternalTree.h | 50 +++-------- moses/FF/TreeStructureFeature.cpp | 27 +----- moses/FF/TreeStructureFeature.h | 10 ++- moses/LM/RDLM.cpp | 30 ++++--- moses/LM/RDLM.h | 36 +++++--- 10 files changed, 274 insertions(+), 213 deletions(-) create mode 100644 mert/InternalTree.cpp create mode 100644 mert/InternalTree.h diff --git a/mert/HwcmScorer.h b/mert/HwcmScorer.h index 16d563424..2e52f0be9 100644 --- a/mert/HwcmScorer.h +++ b/mert/HwcmScorer.h @@ -5,10 +5,7 @@ #include #include "StatisticsBasedScorer.h" -#include "moses/FF/InternalTree.h" - -using Moses::TreePointer; -using Moses::InternalTree; +#include "InternalTree.h" namespace MosesTuning { diff --git a/mert/InternalTree.cpp b/mert/InternalTree.cpp new file mode 100644 index 000000000..d82fbcc72 --- /dev/null +++ b/mert/InternalTree.cpp @@ -0,0 +1,113 @@ +#include "InternalTree.h" + +namespace MosesTuning +{ + +InternalTree::InternalTree(const std::string & line, const bool terminal): + m_isTerminal(terminal) + { + + size_t found = line.find_first_of("[] "); + + if (found == line.npos) { + m_value = line; + } + + else { + AddSubTree(line, 0); + } +} + +size_t InternalTree::AddSubTree(const std::string & line, size_t pos) { + + std::string value; + char token = 0; + + while (token != ']' && pos != std::string::npos) + { + size_t oldpos = pos; + pos = line.find_first_of("[] ", pos); + if (pos == std::string::npos) break; + token = line[pos]; + value = line.substr(oldpos,pos-oldpos); + + if (token == '[') { + if (m_value.size() > 0) { + m_children.push_back(boost::make_shared(value,false)); + pos = m_children.back()->AddSubTree(line, pos+1); + } + else { + if (value.size() > 0) { + m_value = value; + } + pos = AddSubTree(line, pos+1); + } + } + else if (token == ' ' || token == ']') { + if (value.size() > 0 && !(m_value.size() > 0)) { + m_value = value; + } + else if (value.size() > 0) { + m_isTerminal = false; + m_children.push_back(boost::make_shared(value,true)); + } + if (token == ' ') { + pos++; + } + } + + if (m_children.size() > 0) { + m_isTerminal = false; + } + } + + if (pos == std::string::npos) { + return line.size(); + } + return std::min(line.size(),pos+1); + +} + +std::string InternalTree::GetString(bool start) const { + + std::string ret = ""; + if (!start) { + ret += " "; + } + + if (!m_isTerminal) { + ret += "["; + } + + ret += m_value; + for (std::vector::const_iterator it = m_children.begin(); it != m_children.end(); ++it) + { + ret += (*it)->GetString(false); + } + + if (!m_isTerminal) { + ret += "]"; + } + return ret; + +} + + +void InternalTree::Combine(const std::vector &previous) { + + std::vector::iterator it; + bool found = false; + leafNT next_leafNT(this); + for (std::vector::const_iterator it_prev = previous.begin(); it_prev != previous.end(); ++it_prev) { + found = next_leafNT(it); + if (found) { + *it = *it_prev; + } + else { + std::cerr << "Warning: leaf nonterminal not found in rule; why did this happen?\n"; + } + } +} + + +} diff --git a/mert/InternalTree.h b/mert/InternalTree.h new file mode 100644 index 000000000..f8416101c --- /dev/null +++ b/mert/InternalTree.h @@ -0,0 +1,77 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include "util/generator.hh" +#include "util/exception.hh" + +namespace MosesTuning +{ + +class InternalTree; +typedef boost::shared_ptr TreePointer; +typedef int NTLabel; + +class InternalTree +{ +std::string m_value; +std::vector m_children; +bool m_isTerminal; +public: + InternalTree(const std::string & line, const bool terminal = false); + InternalTree(const InternalTree & tree): + m_value(tree.m_value), + m_isTerminal(tree.m_isTerminal) { + const std::vector & children = tree.m_children; + for (std::vector::const_iterator it = children.begin(); it != children.end(); it++) { + m_children.push_back(boost::make_shared(**it)); + } + } + size_t AddSubTree(const std::string & line, size_t start); + + std::string GetString(bool start = true) const; + void Combine(const std::vector &previous); + const std::string & GetLabel() const { + return m_value; + } + + size_t GetLength() const { + return m_children.size(); + } + std::vector & GetChildren() { + return m_children; + } + + bool IsTerminal() const { + return m_isTerminal; + } + + bool IsLeafNT() const { + return (!m_isTerminal && m_children.size() == 0); + } +}; + +// Python-like generator that yields next nonterminal leaf on every call +$generator(leafNT) { + std::vector::iterator it; + InternalTree* tree; + leafNT(InternalTree* root = 0): tree(root) {} + $emit(std::vector::iterator) + for (it = tree->GetChildren().begin(); it !=tree->GetChildren().end(); ++it) { + if (!(*it)->IsTerminal() && (*it)->GetLength() == 0) { + $yield(it); + } + else if ((*it)->GetLength() > 0) { + if ((*it).get()) { // normal pointer to same object that TreePointer points to + $restart(tree = (*it).get()); + } + } + } + $stop; +}; + +} \ No newline at end of file diff --git a/mert/Jamfile b/mert/Jamfile index 7a8d98bae..e5adce76e 100644 --- a/mert/Jamfile +++ b/mert/Jamfile @@ -30,7 +30,7 @@ InterpolatedScorer.cpp Point.cpp PerScorer.cpp HwcmScorer.cpp -../moses/FF/InternalTree.cpp +InternalTree.cpp Scorer.cpp ScorerFactory.cpp Optimizer.cpp diff --git a/moses/FF/InternalTree.cpp b/moses/FF/InternalTree.cpp index 4a01ea1b2..c38fc5747 100644 --- a/moses/FF/InternalTree.cpp +++ b/moses/FF/InternalTree.cpp @@ -1,27 +1,24 @@ #include "InternalTree.h" +#include "moses/StaticData.h" namespace Moses { -InternalTree::InternalTree(const std::string & line, size_t start, size_t len, const bool terminal): - m_value_nt(0), - m_isTerminal(terminal) +InternalTree::InternalTree(const std::string & line, size_t start, size_t len, const bool nonterminal) { if (len > 0) { - m_value.assign(line, start, len); + m_value.CreateFromString(Output, StaticData::Instance().GetOutputFactorOrder(), StringPiece(line).substr(start, len), nonterminal); } } -InternalTree::InternalTree(const std::string & line, const bool terminal): - m_value_nt(0), - m_isTerminal(terminal) +InternalTree::InternalTree(const std::string & line, const bool nonterminal) { size_t found = line.find_first_of("[] "); if (found == line.npos) { - m_value = line; + m_value.CreateFromString(Output, StaticData::Instance().GetOutputFactorOrder(), line, nonterminal); } else { AddSubTree(line, 0); } @@ -32,6 +29,7 @@ size_t InternalTree::AddSubTree(const std::string & line, size_t pos) char token = 0; size_t len = 0; + bool has_value = false; while (token != ']' && pos != std::string::npos) { size_t oldpos = pos; @@ -41,30 +39,27 @@ size_t InternalTree::AddSubTree(const std::string & line, size_t pos) len = pos-oldpos; if (token == '[') { - if (!m_value.empty()) { - m_children.push_back(boost::make_shared(line, oldpos, len, false)); + if (has_value) { + m_children.push_back(boost::make_shared(line, oldpos, len, true)); pos = m_children.back()->AddSubTree(line, pos+1); } else { if (len > 0) { - m_value.assign(line, oldpos, len); + m_value.CreateFromString(Output, StaticData::Instance().GetOutputFactorOrder(), StringPiece(line).substr(oldpos, len), false); + has_value = true; } pos = AddSubTree(line, pos+1); } } else if (token == ' ' || token == ']') { - if (len > 0 && m_value.empty()) { - m_value.assign(line, oldpos, len); + if (len > 0 && !has_value) { + m_value.CreateFromString(Output, StaticData::Instance().GetOutputFactorOrder(), StringPiece(line).substr(oldpos, len), true); + has_value = true; } else if (len > 0) { - m_isTerminal = false; - m_children.push_back(boost::make_shared(line, oldpos, len, true)); + m_children.push_back(boost::make_shared(line, oldpos, len, false)); } if (token == ' ') { pos++; } } - - if (!m_children.empty()) { - m_isTerminal = false; - } } if (pos == std::string::npos) { @@ -82,16 +77,16 @@ std::string InternalTree::GetString(bool start) const ret += " "; } - if (!m_isTerminal) { + if (!IsTerminal()) { ret += "["; } - ret += m_value; + ret += m_value.GetString(StaticData::Instance().GetOutputFactorOrder(), false); for (std::vector::const_iterator it = m_children.begin(); it != m_children.end(); ++it) { ret += (*it)->GetString(false); } - if (!m_isTerminal) { + if (!IsTerminal()) { ret += "]"; } return ret; @@ -120,13 +115,13 @@ void InternalTree::Unbinarize() { // nodes with virtual label cannot be unbinarized - if (m_value.empty() || m_value[0] == '^') { + if (m_value.GetString(0).empty() || m_value.GetString(0).as_string()[0] == '^') { return; } //if node has child that is virtual node, get unbinarized list of children for (std::vector::iterator it = m_children.begin(); it != m_children.end(); ++it) { - if (!(*it)->IsTerminal() && (*it)->GetLabel()[0] == '^') { + if (!(*it)->IsTerminal() && (*it)->GetLabel().GetString(0).as_string()[0] == '^') { std::vector new_children; GetUnbinarizedChildren(new_children); m_children = new_children; @@ -144,8 +139,8 @@ void InternalTree::Unbinarize() void InternalTree::GetUnbinarizedChildren(std::vector &ret) const { for (std::vector::const_iterator itx = m_children.begin(); itx != m_children.end(); ++itx) { - const std::string &label = (*itx)->GetLabel(); - if (!label.empty() && label[0] == '^') { + const StringPiece label = (*itx)->GetLabel().GetString(0); + if (!label.empty() && label.as_string()[0] == '^') { (*itx)->GetUnbinarizedChildren(ret); } else { ret.push_back(*itx); @@ -153,7 +148,7 @@ void InternalTree::GetUnbinarizedChildren(std::vector &ret) const } } -bool InternalTree::FlatSearch(const std::string & label, std::vector::const_iterator & it) const +bool InternalTree::FlatSearch(const Word & label, std::vector::const_iterator & it) const { for (it = m_children.begin(); it != m_children.end(); ++it) { if ((*it)->GetLabel() == label) { @@ -163,7 +158,7 @@ bool InternalTree::FlatSearch(const std::string & label, std::vector::const_iterator & it) const +bool InternalTree::RecursiveSearch(const Word & label, std::vector::const_iterator & it) const { for (it = m_children.begin(); it != m_children.end(); ++it) { if ((*it)->GetLabel() == label) { @@ -178,7 +173,7 @@ bool InternalTree::RecursiveSearch(const std::string & label, std::vector::const_iterator & it, InternalTree const* &parent) const +bool InternalTree::RecursiveSearch(const Word & label, std::vector::const_iterator & it, InternalTree const* &parent) const { for (it = m_children.begin(); it != m_children.end(); ++it) { if ((*it)->GetLabel() == label) { @@ -194,88 +189,4 @@ bool InternalTree::RecursiveSearch(const std::string & label, std::vector::const_iterator & it) const -{ - for (it = m_children.begin(); it != m_children.end(); ++it) { - if ((*it)->GetNTLabel() == label) { - return true; - } - } - return false; -} - -bool InternalTree::RecursiveSearch(const NTLabel & label, std::vector::const_iterator & it) const -{ - for (it = m_children.begin(); it != m_children.end(); ++it) { - if ((*it)->GetNTLabel() == label) { - return true; - } - std::vector::const_iterator it2; - if ((*it)->RecursiveSearch(label, it2)) { - it = it2; - return true; - } - } - return false; -} - -bool InternalTree::RecursiveSearch(const NTLabel & label, std::vector::const_iterator & it, InternalTree const* &parent) const -{ - for (it = m_children.begin(); it != m_children.end(); ++it) { - if ((*it)->GetNTLabel() == label) { - parent = this; - return true; - } - std::vector::const_iterator it2; - if ((*it)->RecursiveSearch(label, it2, parent)) { - it = it2; - return true; - } - } - return false; -} - - -bool InternalTree::FlatSearch(const std::vector & labels, std::vector::const_iterator & it) const -{ - for (it = m_children.begin(); it != m_children.end(); ++it) { - if (std::binary_search(labels.begin(), labels.end(), (*it)->GetNTLabel())) { - return true; - } - } - return false; -} - -bool InternalTree::RecursiveSearch(const std::vector & labels, std::vector::const_iterator & it) const -{ - for (it = m_children.begin(); it != m_children.end(); ++it) { - if (std::binary_search(labels.begin(), labels.end(), (*it)->GetNTLabel())) { - return true; - } - std::vector::const_iterator it2; - if ((*it)->RecursiveSearch(labels, it2)) { - it = it2; - return true; - } - } - return false; -} - -bool InternalTree::RecursiveSearch(const std::vector & labels, std::vector::const_iterator & it, InternalTree const* &parent) const -{ - for (it = m_children.begin(); it != m_children.end(); ++it) { - if (std::binary_search(labels.begin(), labels.end(), (*it)->GetNTLabel())) { - parent = this; - return true; - } - std::vector::const_iterator it2; - if ((*it)->RecursiveSearch(labels, it2, parent)) { - it = it2; - return true; - } - } - return false; -} - } \ No newline at end of file diff --git a/moses/FF/InternalTree.h b/moses/FF/InternalTree.h index 8f982c6aa..a3db3487e 100644 --- a/moses/FF/InternalTree.h +++ b/moses/FF/InternalTree.h @@ -5,30 +5,28 @@ #include #include #include "FFState.h" +#include "moses/Word.h" #include #include #include "util/generator.hh" #include "util/exception.hh" +#include "util/string_piece.hh" namespace Moses { class InternalTree; typedef boost::shared_ptr TreePointer; -typedef int NTLabel; class InternalTree { - std::string m_value; - NTLabel m_value_nt; + Word m_value; std::vector m_children; - bool m_isTerminal; public: InternalTree(const std::string & line, size_t start, size_t len, const bool terminal); - InternalTree(const std::string & line, const bool terminal = false); + InternalTree(const std::string & line, const bool nonterminal = true); InternalTree(const InternalTree & tree): - m_value(tree.m_value), - m_isTerminal(tree.m_isTerminal) { + m_value(tree.m_value) { const std::vector & children = tree.m_children; for (std::vector::const_iterator it = children.begin(); it != children.end(); it++) { m_children.push_back(boost::make_shared(**it)); @@ -40,20 +38,10 @@ public: void Combine(const std::vector &previous); void Unbinarize(); void GetUnbinarizedChildren(std::vector &children) const; - const std::string & GetLabel() const { + const Word & GetLabel() const { return m_value; } - // optionally identify label by int instead of string; - // allows abstraction if multiple nonterminal strings should map to same label. - const NTLabel & GetNTLabel() const { - return m_value_nt; - } - - void SetNTLabel(NTLabel value) { - m_value_nt = value; - } - size_t GetLength() const { return m_children.size(); } @@ -62,38 +50,22 @@ public: } bool IsTerminal() const { - return m_isTerminal; + return !m_value.IsNonTerminal(); } bool IsLeafNT() const { - return (!m_isTerminal && m_children.size() == 0); + return (m_value.IsNonTerminal() && m_children.size() == 0); } // different methods to search a tree (either just direct children (FlatSearch) or all children (RecursiveSearch)) for constituents. // can be used for formulating syntax constraints. // if found, 'it' is iterator to first tree node that matches search string - bool FlatSearch(const std::string & label, std::vector::const_iterator & it) const; - bool RecursiveSearch(const std::string & label, std::vector::const_iterator & it) const; + bool FlatSearch(const Word & label, std::vector::const_iterator & it) const; + bool RecursiveSearch(const Word & label, std::vector::const_iterator & it) const; // if found, 'it' is iterator to first tree node that matches search string, and 'parent' to its parent node - bool RecursiveSearch(const std::string & label, std::vector::const_iterator & it, InternalTree const* &parent) const; - - // use NTLabel for search to reduce number of string comparisons / deal with synonymous labels - // if found, 'it' is iterator to first tree node that matches search string - bool FlatSearch(const NTLabel & label, std::vector::const_iterator & it) const; - bool RecursiveSearch(const NTLabel & label, std::vector::const_iterator & it) const; - - // if found, 'it' is iterator to first tree node that matches search string, and 'parent' to its parent node - bool RecursiveSearch(const NTLabel & label, std::vector::const_iterator & it, InternalTree const* &parent) const; - - // pass vector of possible labels to search - // if found, 'it' is iterator to first tree node that matches search string - bool FlatSearch(const std::vector & labels, std::vector::const_iterator & it) const; - bool RecursiveSearch(const std::vector & labels, std::vector::const_iterator & it) const; - - // if found, 'it' is iterator to first tree node that matches search string, and 'parent' to its parent node - bool RecursiveSearch(const std::vector & labels, std::vector::const_iterator & it, InternalTree const* &parent) const; + bool RecursiveSearch(const Word & label, std::vector::const_iterator & it, InternalTree const* &parent) const; // Python-like generator that yields next nonterminal leaf on every call $generator(leafNT) { diff --git a/moses/FF/TreeStructureFeature.cpp b/moses/FF/TreeStructureFeature.cpp index fc1fcdc5b..108c99143 100644 --- a/moses/FF/TreeStructureFeature.cpp +++ b/moses/FF/TreeStructureFeature.cpp @@ -13,33 +13,12 @@ void TreeStructureFeature::Load() // syntactic constraints can be hooked in here. m_constraints = NULL; - m_labelset = NULL; StaticData &staticData = StaticData::InstanceNonConst(); staticData.SetTreeStructure(this); } -// define NT labels (ints) that are mapped from strings for quicker comparison. -void TreeStructureFeature::AddNTLabels(TreePointer root) const -{ - std::string label = root->GetLabel(); - - if (root->IsTerminal()) { - return; - } - - std::map::const_iterator it = m_labelset->string_to_label.find(label); - if (it != m_labelset->string_to_label.end()) { - root->SetNTLabel(it->second); - } - - std::vector children = root->GetChildren(); - for (std::vector::const_iterator it2 = children.begin(); it2 != children.end(); ++it2) { - AddNTLabels(*it2); - } -} - FFState* TreeStructureFeature::EvaluateWhenApplied(const ChartHypothesis& cur_hypo , int featureID /* used to index the state in the previous hypotheses */ , ScoreComponentCollection* accumulator) const @@ -48,10 +27,6 @@ FFState* TreeStructureFeature::EvaluateWhenApplied(const ChartHypothesis& cur_hy const std::string *tree = property->GetValueString(); TreePointer mytree (boost::make_shared(*tree)); - if (m_labelset) { - AddNTLabels(mytree); - } - //get subtrees (in target order) std::vector previous_trees; for (size_t pos = 0; pos < cur_hypo.GetCurrTargetPhrase().GetSize(); ++pos) { @@ -70,7 +45,7 @@ FFState* TreeStructureFeature::EvaluateWhenApplied(const ChartHypothesis& cur_hy } mytree->Combine(previous_trees); - bool full_sentence = (mytree->GetChildren().back()->GetLabel() == "" || (mytree->GetChildren().back()->GetLabel() == "SEND" && mytree->GetChildren().back()->GetChildren().back()->GetLabel() == "")); + bool full_sentence = (mytree->GetChildren().back()->GetLabel() == m_send || (mytree->GetChildren().back()->GetLabel() == m_send_nt && mytree->GetChildren().back()->GetChildren().back()->GetLabel() == m_send)); if (m_binarized && full_sentence) { mytree->Unbinarize(); } diff --git a/moses/FF/TreeStructureFeature.h b/moses/FF/TreeStructureFeature.h index ecb2ce7cb..cef87e7ee 100644 --- a/moses/FF/TreeStructureFeature.h +++ b/moses/FF/TreeStructureFeature.h @@ -4,6 +4,7 @@ #include #include "StatefulFeatureFunction.h" #include "FFState.h" +#include "moses/Word.h" #include "InternalTree.h" namespace Moses @@ -35,11 +36,18 @@ class TreeStructureFeature : public StatefulFeatureFunction SyntaxConstraints* m_constraints; LabelSet* m_labelset; bool m_binarized; + Word m_send; + Word m_send_nt; + public: TreeStructureFeature(const std::string &line) :StatefulFeatureFunction(0, line) , m_binarized(false) { ReadParameters(); + std::vector factors; + factors.push_back(0); + m_send.CreateFromString(Output, factors, "", false); + m_send_nt.CreateFromString(Output, factors, "SEND", true); } ~TreeStructureFeature() { delete m_constraints; @@ -49,8 +57,6 @@ public: return new TreeState(TreePointer()); } - void AddNTLabels(TreePointer root) const; - bool IsUseable(const FactorMask &mask) const { return true; } diff --git a/moses/LM/RDLM.cpp b/moses/LM/RDLM.cpp index 1e9f2b4d3..33bdc9c55 100644 --- a/moses/LM/RDLM.cpp +++ b/moses/LM/RDLM.cpp @@ -70,7 +70,7 @@ void RDLM::Load() static_label_null[i] = lm_label_base_instance_->lookup_input_word(numstr); } - static_dummy_head = lm_head_base_instance_->lookup_input_word(dummy_head); + static_dummy_head = lm_head_base_instance_->lookup_input_word(dummy_head.GetString(0).as_string()); static_start_head = lm_head_base_instance_->lookup_input_word(""); static_start_label = lm_head_base_instance_->lookup_input_word(""); @@ -211,7 +211,7 @@ void RDLM::Score(InternalTree* root, const TreePointerMap & back_pointers, boost } // ignore virtual nodes (in binarization; except if it's the root) - if (m_binarized && root->GetLabel()[0] == '^' && !ancestor_heads.empty()) { + if (m_binarized && root->GetLabel().GetString(0).as_string()[0] == '^' && !ancestor_heads.empty()) { // recursion if (root->IsLeafNT() && m_context_up > 1 && ancestor_heads.size()) { root = back_pointers.find(root)->second.get(); @@ -241,9 +241,9 @@ void RDLM::Score(InternalTree* root, const TreePointerMap & back_pointers, boost // root of tree: score without context if (ancestor_heads.empty() || (ancestor_heads.size() == m_context_up && ancestor_heads.back() == static_root_head)) { std::vector ngram_head_null (static_head_null); - ngram_head_null.back() = lm_head->lookup_output_word(root->GetChildren()[0]->GetLabel()); + ngram_head_null.back() = lm_head->lookup_output_word(root->GetChildren()[0]->GetLabel().GetString(m_factorType).as_string()); if (m_isPretermBackoff && ngram_head_null.back() == 0) { - ngram_head_null.back() = lm_head->lookup_output_word(root->GetLabel()); + ngram_head_null.back() = lm_head->lookup_output_word(root->GetLabel().GetString(m_factorType).as_string()); } if (ancestor_heads.size() == m_context_up && ancestor_heads.back() == static_root_head) { std::vector::iterator it = ngram_head_null.begin(); @@ -296,7 +296,7 @@ void RDLM::Score(InternalTree* root, const TreePointerMap & back_pointers, boost } size_t context_up_nonempty = std::min(m_context_up, ancestor_heads.size()); - const std::string & head_label = root->GetLabel(); + const std::string & head_label = root->GetLabel().GetString(0).as_string(); bool virtual_head = false; int reached_end = 0; int label_idx, label_idx_out; @@ -527,7 +527,7 @@ bool RDLM::GetHead(InternalTree* root, const TreePointerMap & back_pointers, std tree = it->get(); } - if (m_binarized && tree->GetLabel()[0] == '^') { + if (m_binarized && tree->GetLabel().GetString(0).as_string()[0] == '^') { bool found = GetHead(tree, back_pointers, IDs); if (found) { return true; @@ -597,8 +597,8 @@ void RDLM::GetChildHeadsAndLabels(InternalTree *root, const TreePointerMap & bac child_ids = std::make_pair(static_dummy_head, static_dummy_head); } - labels[j] = lm_head->lookup_input_word(child->GetLabel()); - labels_output[j] = lm_label->lookup_output_word(child->GetLabel()); + labels[j] = lm_head->lookup_input_word(child->GetLabel().GetString(0).as_string()); + labels_output[j] = lm_label->lookup_output_word(child->GetLabel().GetString(0).as_string()); heads[j] = child_ids.first; heads_output[j] = child_ids.second; j++; @@ -613,18 +613,18 @@ void RDLM::GetChildHeadsAndLabels(InternalTree *root, const TreePointerMap & bac } -void RDLM::GetIDs(const std::string & head, const std::string & preterminal, std::pair & IDs) const +void RDLM::GetIDs(const Word & head, const Word & preterminal, std::pair & IDs) const { - IDs.first = lm_head_base_instance_->lookup_input_word(head); + IDs.first = lm_head_base_instance_->lookup_input_word(head.GetString(m_factorType).as_string()); if (m_isPretermBackoff && IDs.first == 0) { - IDs.first = lm_head_base_instance_->lookup_input_word(preterminal); + IDs.first = lm_head_base_instance_->lookup_input_word(preterminal.GetString(0).as_string()); } if (m_sharedVocab) { IDs.second = IDs.first; } else { - IDs.second = lm_head_base_instance_->lookup_output_word(head); + IDs.second = lm_head_base_instance_->lookup_output_word(head.GetString(m_factorType).as_string()); if (m_isPretermBackoff && IDs.second == 0) { - IDs.second = lm_head_base_instance_->lookup_output_word(preterminal); + IDs.second = lm_head_base_instance_->lookup_output_word(preterminal.GetString(0).as_string()); } } } @@ -718,7 +718,9 @@ void RDLM::SetParameter(const std::string& key, const std::string& value) else UTIL_THROW(util::Exception, "Unknown value for argument " << key << "=" << value); } else if (key == "glue_symbol") { - m_glueSymbol = value; + m_glueSymbolString = value; + } else if (key == "factor") { + m_factorType = Scan(value); } else if (key == "cache_size") { m_cacheSize = Scan(value); } else { diff --git a/moses/LM/RDLM.h b/moses/LM/RDLM.h index c5480b6c4..3d8c62f7e 100644 --- a/moses/LM/RDLM.h +++ b/moses/LM/RDLM.h @@ -3,6 +3,7 @@ #include "moses/FF/StatefulFeatureFunction.h" #include "moses/FF/FFState.h" #include "moses/FF/InternalTree.h" +#include "moses/Word.h" #include #include @@ -61,11 +62,12 @@ class RDLM : public StatefulFeatureFunction nplm::neuralTM* lm_label_base_instance_; mutable boost::thread_specific_ptr lm_label_backend_; - std::string dummy_head; - std::string m_glueSymbol; - std::string m_startSymbol; - std::string m_endSymbol; - std::string m_endTag; + std::string m_glueSymbolString; + Word dummy_head; + Word m_glueSymbol; + Word m_startSymbol; + Word m_endSymbol; + Word m_endTag; std::string m_path_head_lm; std::string m_path_label_lm; bool m_isPretermBackoff; @@ -102,14 +104,12 @@ class RDLM : public StatefulFeatureFunction int static_stop_label_output; int static_start_label_output; + FactorType m_factorType; + public: RDLM(const std::string &line) : StatefulFeatureFunction(2, line) - , dummy_head("") - , m_glueSymbol("Q") - , m_startSymbol("SSTART") - , m_endSymbol("SEND") - , m_endTag("") + , m_glueSymbolString("Q") , m_isPretermBackoff(true) , m_context_left(3) , m_context_right(0) @@ -120,8 +120,16 @@ public: , m_normalizeLabelLM(false) , m_sharedVocab(false) , m_binarized(0) - , m_cacheSize(1000000) { + , m_cacheSize(1000000) + , m_factorType(0) { ReadParameters(); + std::vector factors; + factors.push_back(0); + dummy_head.CreateFromString(Output, factors, "", false); + m_glueSymbol.CreateFromString(Output, factors, m_glueSymbolString, true); + m_startSymbol.CreateFromString(Output, factors, "SSTART", true); + m_endSymbol.CreateFromString(Output, factors, "SEND", true); + m_endTag.CreateFromString(Output, factors, "", false); } ~RDLM(); @@ -133,7 +141,7 @@ public: void Score(InternalTree* root, const TreePointerMap & back_pointers, boost::array &score, std::vector &ancestor_heads, std::vector &ancestor_labels, size_t &boundary_hash, int num_virtual = 0, int rescoring_levels = 0) const; bool GetHead(InternalTree* root, const TreePointerMap & back_pointers, std::pair & IDs) const; void GetChildHeadsAndLabels(InternalTree *root, const TreePointerMap & back_pointers, int reached_end, const nplm::neuralTM *lm_head, const nplm::neuralTM *lm_labels, std::vector & heads, std::vector & labels, std::vector & heads_output, std::vector & labels_output) const; - void GetIDs(const std::string & head, const std::string & preterminal, std::pair & IDs) const; + void GetIDs(const Word & head, const Word & preterminal, std::pair & IDs) const; void ScoreFile(std::string &path); //for debugging void PrintInfo(std::vector &ngram, nplm::neuralTM* lm) const; //for debugging @@ -190,7 +198,7 @@ public: _end = current->GetChildren().end(); iter = current->GetChildren().begin(); // expand virtual node - while (binarized && !(*iter)->GetLabel().empty() && (*iter)->GetLabel()[0] == '^') { + while (binarized && !(*iter)->GetLabel().GetString(0).empty() && (*iter)->GetLabel().GetString(0).data()[0] == '^') { stack.push_back(std::make_pair(current, iter)); // also go through trees or previous hypotheses to rescore nodes for which more context has become available if ((*iter)->IsLeafNT()) { @@ -227,7 +235,7 @@ public: } } // expand virtual node - while (binarized && !(*iter)->GetLabel().empty() && (*iter)->GetLabel()[0] == '^') { + while (binarized && !(*iter)->GetLabel().GetString(0).empty() && (*iter)->GetLabel().GetString(0).data()[0] == '^') { stack.push_back(std::make_pair(current, iter)); // also go through trees or previous hypotheses to rescore nodes for which more context has become available if ((*iter)->IsLeafNT()) {