support factors in InternalTree

This commit is contained in:
Rico Sennrich 2015-07-17 17:27:31 +01:00
parent e85f353898
commit bec950cf72
10 changed files with 274 additions and 213 deletions

View File

@ -5,10 +5,7 @@
#include <vector>
#include "StatisticsBasedScorer.h"
#include "moses/FF/InternalTree.h"
using Moses::TreePointer;
using Moses::InternalTree;
#include "InternalTree.h"
namespace MosesTuning
{

113
mert/InternalTree.cpp Normal file
View File

@ -0,0 +1,113 @@
#include "InternalTree.h"
namespace MosesTuning
{
/**
 * Build a tree node from its string representation.
 *
 * @param line      either a bare label (no '[', ']' or ' ' characters) or a
 *                  bracketed tree string that is handed to AddSubTree().
 * @param terminal  initial value of the terminal flag; AddSubTree() clears it
 *                  again as soon as the node receives children.
 */
InternalTree::InternalTree(const std::string & line, const bool terminal):
  m_isTerminal(terminal)
{
  // No structural characters at all: the whole string is this node's label.
  const bool isBareLabel = (line.find_first_of("[] ") == std::string::npos);
  if (isBareLabel) {
    m_value = line;
    return;
  }

  // Otherwise parse the bracketed structure starting at the first character.
  AddSubTree(line, 0);
}
size_t InternalTree::AddSubTree(const std::string & line, size_t pos) {
std::string value;
char token = 0;
while (token != ']' && pos != std::string::npos)
{
size_t oldpos = pos;
pos = line.find_first_of("[] ", pos);
if (pos == std::string::npos) break;
token = line[pos];
value = line.substr(oldpos,pos-oldpos);
if (token == '[') {
if (m_value.size() > 0) {
m_children.push_back(boost::make_shared<InternalTree>(value,false));
pos = m_children.back()->AddSubTree(line, pos+1);
}
else {
if (value.size() > 0) {
m_value = value;
}
pos = AddSubTree(line, pos+1);
}
}
else if (token == ' ' || token == ']') {
if (value.size() > 0 && !(m_value.size() > 0)) {
m_value = value;
}
else if (value.size() > 0) {
m_isTerminal = false;
m_children.push_back(boost::make_shared<InternalTree>(value,true));
}
if (token == ' ') {
pos++;
}
}
if (m_children.size() > 0) {
m_isTerminal = false;
}
}
if (pos == std::string::npos) {
return line.size();
}
return std::min(line.size(),pos+1);
}
/**
 * Serialise this node and all of its descendants.
 *
 * Nonterminal nodes are wrapped in '[' ... ']'; terminal nodes are written
 * as their bare label. Children follow the label in order.
 *
 * @param start  true for the outermost call; the recursive calls made here
 *               pass false, which prefixes the node with a single space.
 * @return the serialised subtree.
 */
std::string InternalTree::GetString(bool start) const {
  std::string out(start ? "" : " ");

  const bool bracketed = !m_isTerminal;
  if (bracketed) {
    out += "[";
  }

  out += m_value;
  for (size_t i = 0; i < m_children.size(); ++i) {
    out += m_children[i]->GetString(false);
  }

  if (bracketed) {
    out += "]";
  }
  return out;
}
void InternalTree::Combine(const std::vector<TreePointer> &previous) {
std::vector<TreePointer>::iterator it;
bool found = false;
leafNT next_leafNT(this);
for (std::vector<TreePointer>::const_iterator it_prev = previous.begin(); it_prev != previous.end(); ++it_prev) {
found = next_leafNT(it);
if (found) {
*it = *it_prev;
}
else {
std::cerr << "Warning: leaf nonterminal not found in rule; why did this happen?\n";
}
}
}
}

77
mert/InternalTree.h Normal file
View File

@ -0,0 +1,77 @@
#pragma once
#include <iostream>
#include <string>
#include <map>
#include <vector>
#include <boost/shared_ptr.hpp>
#include <boost/make_shared.hpp>
#include "util/generator.hh"
#include "util/exception.hh"
namespace MosesTuning
{
class InternalTree;
typedef boost::shared_ptr<InternalTree> TreePointer;
typedef int NTLabel;
class InternalTree
{
std::string m_value;
std::vector<TreePointer> m_children;
bool m_isTerminal;
public:
InternalTree(const std::string & line, const bool terminal = false);
InternalTree(const InternalTree & tree):
m_value(tree.m_value),
m_isTerminal(tree.m_isTerminal) {
const std::vector<TreePointer> & children = tree.m_children;
for (std::vector<TreePointer>::const_iterator it = children.begin(); it != children.end(); it++) {
m_children.push_back(boost::make_shared<InternalTree>(**it));
}
}
size_t AddSubTree(const std::string & line, size_t start);
std::string GetString(bool start = true) const;
void Combine(const std::vector<TreePointer> &previous);
const std::string & GetLabel() const {
return m_value;
}
size_t GetLength() const {
return m_children.size();
}
std::vector<TreePointer> & GetChildren() {
return m_children;
}
bool IsTerminal() const {
return m_isTerminal;
}
bool IsLeafNT() const {
return (!m_isTerminal && m_children.size() == 0);
}
};
// Python-like generator that yields next nonterminal leaf on every call.
// Usage as seen in InternalTree::Combine: construct it with a root node, then
// call the object with an iterator argument; a true result means the argument
// now refers to the slot (inside a node's child vector) that holds the leaf,
// so the caller can overwrite that slot.
// NOTE(review): $generator/$emit/$yield/$restart/$stop are presumably the
// coroutine macros from util/generator.hh; their expansion is not visible
// here, so confirm the resulting traversal order there before reordering any
// statement in this body.
$generator(leafNT) {
// position inside the child vector of the node currently being scanned
std::vector<TreePointer>::iterator it;
// node whose children are currently being scanned (raw pointer; nothing in
// this struct deletes it)
InternalTree* tree;
// root defaults to null, but the loop below dereferences tree unconditionally,
// so a real node has to be supplied before the first call
leafNT(InternalTree* root = 0): tree(root) {}
// the value handed out per call is an iterator into a child vector
$emit(std::vector<TreePointer>::iterator)
for (it = tree->GetChildren().begin(); it !=tree->GetChildren().end(); ++it) {
// same condition as InternalTree::IsLeafNT(): nonterminal without children
if (!(*it)->IsTerminal() && (*it)->GetLength() == 0) {
$yield(it);
}
// child has children of its own: continue the search from that child
else if ((*it)->GetLength() > 0) {
// NOTE(review): *it has already been dereferenced by the conditions above,
// so this null check cannot protect those earlier accesses
if ((*it).get()) { // normal pointer to same object that TreePointer points to
$restart(tree = (*it).get());
}
}
}
$stop;
};
}

View File

@ -30,7 +30,7 @@ InterpolatedScorer.cpp
Point.cpp
PerScorer.cpp
HwcmScorer.cpp
../moses/FF/InternalTree.cpp
InternalTree.cpp
Scorer.cpp
ScorerFactory.cpp
Optimizer.cpp

View File

@ -1,27 +1,24 @@
#include "InternalTree.h"
#include "moses/StaticData.h"
namespace Moses
{
InternalTree::InternalTree(const std::string & line, size_t start, size_t len, const bool terminal):
m_value_nt(0),
m_isTerminal(terminal)
InternalTree::InternalTree(const std::string & line, size_t start, size_t len, const bool nonterminal)
{
if (len > 0) {
m_value.assign(line, start, len);
m_value.CreateFromString(Output, StaticData::Instance().GetOutputFactorOrder(), StringPiece(line).substr(start, len), nonterminal);
}
}
InternalTree::InternalTree(const std::string & line, const bool terminal):
m_value_nt(0),
m_isTerminal(terminal)
InternalTree::InternalTree(const std::string & line, const bool nonterminal)
{
size_t found = line.find_first_of("[] ");
if (found == line.npos) {
m_value = line;
m_value.CreateFromString(Output, StaticData::Instance().GetOutputFactorOrder(), line, nonterminal);
} else {
AddSubTree(line, 0);
}
@ -32,6 +29,7 @@ size_t InternalTree::AddSubTree(const std::string & line, size_t pos)
char token = 0;
size_t len = 0;
bool has_value = false;
while (token != ']' && pos != std::string::npos) {
size_t oldpos = pos;
@ -41,30 +39,27 @@ size_t InternalTree::AddSubTree(const std::string & line, size_t pos)
len = pos-oldpos;
if (token == '[') {
if (!m_value.empty()) {
m_children.push_back(boost::make_shared<InternalTree>(line, oldpos, len, false));
if (has_value) {
m_children.push_back(boost::make_shared<InternalTree>(line, oldpos, len, true));
pos = m_children.back()->AddSubTree(line, pos+1);
} else {
if (len > 0) {
m_value.assign(line, oldpos, len);
m_value.CreateFromString(Output, StaticData::Instance().GetOutputFactorOrder(), StringPiece(line).substr(oldpos, len), false);
has_value = true;
}
pos = AddSubTree(line, pos+1);
}
} else if (token == ' ' || token == ']') {
if (len > 0 && m_value.empty()) {
m_value.assign(line, oldpos, len);
if (len > 0 && !has_value) {
m_value.CreateFromString(Output, StaticData::Instance().GetOutputFactorOrder(), StringPiece(line).substr(oldpos, len), true);
has_value = true;
} else if (len > 0) {
m_isTerminal = false;
m_children.push_back(boost::make_shared<InternalTree>(line, oldpos, len, true));
m_children.push_back(boost::make_shared<InternalTree>(line, oldpos, len, false));
}
if (token == ' ') {
pos++;
}
}
if (!m_children.empty()) {
m_isTerminal = false;
}
}
if (pos == std::string::npos) {
@ -82,16 +77,16 @@ std::string InternalTree::GetString(bool start) const
ret += " ";
}
if (!m_isTerminal) {
if (!IsTerminal()) {
ret += "[";
}
ret += m_value;
ret += m_value.GetString(StaticData::Instance().GetOutputFactorOrder(), false);
for (std::vector<TreePointer>::const_iterator it = m_children.begin(); it != m_children.end(); ++it) {
ret += (*it)->GetString(false);
}
if (!m_isTerminal) {
if (!IsTerminal()) {
ret += "]";
}
return ret;
@ -120,13 +115,13 @@ void InternalTree::Unbinarize()
{
// nodes with virtual label cannot be unbinarized
if (m_value.empty() || m_value[0] == '^') {
if (m_value.GetString(0).empty() || m_value.GetString(0).as_string()[0] == '^') {
return;
}
//if node has child that is virtual node, get unbinarized list of children
for (std::vector<TreePointer>::iterator it = m_children.begin(); it != m_children.end(); ++it) {
if (!(*it)->IsTerminal() && (*it)->GetLabel()[0] == '^') {
if (!(*it)->IsTerminal() && (*it)->GetLabel().GetString(0).as_string()[0] == '^') {
std::vector<TreePointer> new_children;
GetUnbinarizedChildren(new_children);
m_children = new_children;
@ -144,8 +139,8 @@ void InternalTree::Unbinarize()
void InternalTree::GetUnbinarizedChildren(std::vector<TreePointer> &ret) const
{
for (std::vector<TreePointer>::const_iterator itx = m_children.begin(); itx != m_children.end(); ++itx) {
const std::string &label = (*itx)->GetLabel();
if (!label.empty() && label[0] == '^') {
const StringPiece label = (*itx)->GetLabel().GetString(0);
if (!label.empty() && label.as_string()[0] == '^') {
(*itx)->GetUnbinarizedChildren(ret);
} else {
ret.push_back(*itx);
@ -153,7 +148,7 @@ void InternalTree::GetUnbinarizedChildren(std::vector<TreePointer> &ret) const
}
}
bool InternalTree::FlatSearch(const std::string & label, std::vector<TreePointer>::const_iterator & it) const
bool InternalTree::FlatSearch(const Word & label, std::vector<TreePointer>::const_iterator & it) const
{
for (it = m_children.begin(); it != m_children.end(); ++it) {
if ((*it)->GetLabel() == label) {
@ -163,7 +158,7 @@ bool InternalTree::FlatSearch(const std::string & label, std::vector<TreePointer
return false;
}
bool InternalTree::RecursiveSearch(const std::string & label, std::vector<TreePointer>::const_iterator & it) const
bool InternalTree::RecursiveSearch(const Word & label, std::vector<TreePointer>::const_iterator & it) const
{
for (it = m_children.begin(); it != m_children.end(); ++it) {
if ((*it)->GetLabel() == label) {
@ -178,7 +173,7 @@ bool InternalTree::RecursiveSearch(const std::string & label, std::vector<TreePo
return false;
}
bool InternalTree::RecursiveSearch(const std::string & label, std::vector<TreePointer>::const_iterator & it, InternalTree const* &parent) const
bool InternalTree::RecursiveSearch(const Word & label, std::vector<TreePointer>::const_iterator & it, InternalTree const* &parent) const
{
for (it = m_children.begin(); it != m_children.end(); ++it) {
if ((*it)->GetLabel() == label) {
@ -194,88 +189,4 @@ bool InternalTree::RecursiveSearch(const std::string & label, std::vector<TreePo
return false;
}
bool InternalTree::FlatSearch(const NTLabel & label, std::vector<TreePointer>::const_iterator & it) const
{
for (it = m_children.begin(); it != m_children.end(); ++it) {
if ((*it)->GetNTLabel() == label) {
return true;
}
}
return false;
}
bool InternalTree::RecursiveSearch(const NTLabel & label, std::vector<TreePointer>::const_iterator & it) const
{
for (it = m_children.begin(); it != m_children.end(); ++it) {
if ((*it)->GetNTLabel() == label) {
return true;
}
std::vector<TreePointer>::const_iterator it2;
if ((*it)->RecursiveSearch(label, it2)) {
it = it2;
return true;
}
}
return false;
}
bool InternalTree::RecursiveSearch(const NTLabel & label, std::vector<TreePointer>::const_iterator & it, InternalTree const* &parent) const
{
for (it = m_children.begin(); it != m_children.end(); ++it) {
if ((*it)->GetNTLabel() == label) {
parent = this;
return true;
}
std::vector<TreePointer>::const_iterator it2;
if ((*it)->RecursiveSearch(label, it2, parent)) {
it = it2;
return true;
}
}
return false;
}
bool InternalTree::FlatSearch(const std::vector<NTLabel> & labels, std::vector<TreePointer>::const_iterator & it) const
{
for (it = m_children.begin(); it != m_children.end(); ++it) {
if (std::binary_search(labels.begin(), labels.end(), (*it)->GetNTLabel())) {
return true;
}
}
return false;
}
bool InternalTree::RecursiveSearch(const std::vector<NTLabel> & labels, std::vector<TreePointer>::const_iterator & it) const
{
for (it = m_children.begin(); it != m_children.end(); ++it) {
if (std::binary_search(labels.begin(), labels.end(), (*it)->GetNTLabel())) {
return true;
}
std::vector<TreePointer>::const_iterator it2;
if ((*it)->RecursiveSearch(labels, it2)) {
it = it2;
return true;
}
}
return false;
}
bool InternalTree::RecursiveSearch(const std::vector<NTLabel> & labels, std::vector<TreePointer>::const_iterator & it, InternalTree const* &parent) const
{
for (it = m_children.begin(); it != m_children.end(); ++it) {
if (std::binary_search(labels.begin(), labels.end(), (*it)->GetNTLabel())) {
parent = this;
return true;
}
std::vector<TreePointer>::const_iterator it2;
if ((*it)->RecursiveSearch(labels, it2, parent)) {
it = it2;
return true;
}
}
return false;
}
}

View File

@ -5,30 +5,28 @@
#include <map>
#include <vector>
#include "FFState.h"
#include "moses/Word.h"
#include <boost/shared_ptr.hpp>
#include <boost/make_shared.hpp>
#include "util/generator.hh"
#include "util/exception.hh"
#include "util/string_piece.hh"
namespace Moses
{
class InternalTree;
typedef boost::shared_ptr<InternalTree> TreePointer;
typedef int NTLabel;
class InternalTree
{
std::string m_value;
NTLabel m_value_nt;
Word m_value;
std::vector<TreePointer> m_children;
bool m_isTerminal;
public:
InternalTree(const std::string & line, size_t start, size_t len, const bool terminal);
InternalTree(const std::string & line, const bool terminal = false);
InternalTree(const std::string & line, const bool nonterminal = true);
InternalTree(const InternalTree & tree):
m_value(tree.m_value),
m_isTerminal(tree.m_isTerminal) {
m_value(tree.m_value) {
const std::vector<TreePointer> & children = tree.m_children;
for (std::vector<TreePointer>::const_iterator it = children.begin(); it != children.end(); it++) {
m_children.push_back(boost::make_shared<InternalTree>(**it));
@ -40,20 +38,10 @@ public:
void Combine(const std::vector<TreePointer> &previous);
void Unbinarize();
void GetUnbinarizedChildren(std::vector<TreePointer> &children) const;
const std::string & GetLabel() const {
const Word & GetLabel() const {
return m_value;
}
// optionally identify label by int instead of string;
// allows abstraction if multiple nonterminal strings should map to same label.
const NTLabel & GetNTLabel() const {
return m_value_nt;
}
void SetNTLabel(NTLabel value) {
m_value_nt = value;
}
size_t GetLength() const {
return m_children.size();
}
@ -62,38 +50,22 @@ public:
}
bool IsTerminal() const {
return m_isTerminal;
return !m_value.IsNonTerminal();
}
bool IsLeafNT() const {
return (!m_isTerminal && m_children.size() == 0);
return (m_value.IsNonTerminal() && m_children.size() == 0);
}
// different methods to search a tree (either just direct children (FlatSearch) or all children (RecursiveSearch)) for constituents.
// can be used for formulating syntax constraints.
// if found, 'it' is iterator to first tree node that matches search string
bool FlatSearch(const std::string & label, std::vector<TreePointer>::const_iterator & it) const;
bool RecursiveSearch(const std::string & label, std::vector<TreePointer>::const_iterator & it) const;
bool FlatSearch(const Word & label, std::vector<TreePointer>::const_iterator & it) const;
bool RecursiveSearch(const Word & label, std::vector<TreePointer>::const_iterator & it) const;
// if found, 'it' is iterator to first tree node that matches search string, and 'parent' to its parent node
bool RecursiveSearch(const std::string & label, std::vector<TreePointer>::const_iterator & it, InternalTree const* &parent) const;
// use NTLabel for search to reduce number of string comparisons / deal with synonymous labels
// if found, 'it' is iterator to first tree node that matches search string
bool FlatSearch(const NTLabel & label, std::vector<TreePointer>::const_iterator & it) const;
bool RecursiveSearch(const NTLabel & label, std::vector<TreePointer>::const_iterator & it) const;
// if found, 'it' is iterator to first tree node that matches search string, and 'parent' to its parent node
bool RecursiveSearch(const NTLabel & label, std::vector<TreePointer>::const_iterator & it, InternalTree const* &parent) const;
// pass vector of possible labels to search
// if found, 'it' is iterator to first tree node that matches search string
bool FlatSearch(const std::vector<NTLabel> & labels, std::vector<TreePointer>::const_iterator & it) const;
bool RecursiveSearch(const std::vector<NTLabel> & labels, std::vector<TreePointer>::const_iterator & it) const;
// if found, 'it' is iterator to first tree node that matches search string, and 'parent' to its parent node
bool RecursiveSearch(const std::vector<NTLabel> & labels, std::vector<TreePointer>::const_iterator & it, InternalTree const* &parent) const;
bool RecursiveSearch(const Word & label, std::vector<TreePointer>::const_iterator & it, InternalTree const* &parent) const;
// Python-like generator that yields next nonterminal leaf on every call
$generator(leafNT) {

View File

@ -13,33 +13,12 @@ void TreeStructureFeature::Load()
// syntactic constraints can be hooked in here.
m_constraints = NULL;
m_labelset = NULL;
StaticData &staticData = StaticData::InstanceNonConst();
staticData.SetTreeStructure(this);
}
// define NT labels (ints) that are mapped from strings for quicker comparison.
void TreeStructureFeature::AddNTLabels(TreePointer root) const
{
std::string label = root->GetLabel();
if (root->IsTerminal()) {
return;
}
std::map<std::string, NTLabel>::const_iterator it = m_labelset->string_to_label.find(label);
if (it != m_labelset->string_to_label.end()) {
root->SetNTLabel(it->second);
}
std::vector<TreePointer> children = root->GetChildren();
for (std::vector<TreePointer>::const_iterator it2 = children.begin(); it2 != children.end(); ++it2) {
AddNTLabels(*it2);
}
}
FFState* TreeStructureFeature::EvaluateWhenApplied(const ChartHypothesis& cur_hypo
, int featureID /* used to index the state in the previous hypotheses */
, ScoreComponentCollection* accumulator) const
@ -48,10 +27,6 @@ FFState* TreeStructureFeature::EvaluateWhenApplied(const ChartHypothesis& cur_hy
const std::string *tree = property->GetValueString();
TreePointer mytree (boost::make_shared<InternalTree>(*tree));
if (m_labelset) {
AddNTLabels(mytree);
}
//get subtrees (in target order)
std::vector<TreePointer> previous_trees;
for (size_t pos = 0; pos < cur_hypo.GetCurrTargetPhrase().GetSize(); ++pos) {
@ -70,7 +45,7 @@ FFState* TreeStructureFeature::EvaluateWhenApplied(const ChartHypothesis& cur_hy
}
mytree->Combine(previous_trees);
bool full_sentence = (mytree->GetChildren().back()->GetLabel() == "</s>" || (mytree->GetChildren().back()->GetLabel() == "SEND" && mytree->GetChildren().back()->GetChildren().back()->GetLabel() == "</s>"));
bool full_sentence = (mytree->GetChildren().back()->GetLabel() == m_send || (mytree->GetChildren().back()->GetLabel() == m_send_nt && mytree->GetChildren().back()->GetChildren().back()->GetLabel() == m_send));
if (m_binarized && full_sentence) {
mytree->Unbinarize();
}

View File

@ -4,6 +4,7 @@
#include <map>
#include "StatefulFeatureFunction.h"
#include "FFState.h"
#include "moses/Word.h"
#include "InternalTree.h"
namespace Moses
@ -35,11 +36,18 @@ class TreeStructureFeature : public StatefulFeatureFunction
SyntaxConstraints* m_constraints;
LabelSet* m_labelset;
bool m_binarized;
Word m_send;
Word m_send_nt;
public:
TreeStructureFeature(const std::string &line)
:StatefulFeatureFunction(0, line)
, m_binarized(false) {
ReadParameters();
std::vector<FactorType> factors;
factors.push_back(0);
m_send.CreateFromString(Output, factors, "</s>", false);
m_send_nt.CreateFromString(Output, factors, "SEND", true);
}
~TreeStructureFeature() {
delete m_constraints;
@ -49,8 +57,6 @@ public:
return new TreeState(TreePointer());
}
void AddNTLabels(TreePointer root) const;
bool IsUseable(const FactorMask &mask) const {
return true;
}

View File

@ -70,7 +70,7 @@ void RDLM::Load()
static_label_null[i] = lm_label_base_instance_->lookup_input_word(numstr);
}
static_dummy_head = lm_head_base_instance_->lookup_input_word(dummy_head);
static_dummy_head = lm_head_base_instance_->lookup_input_word(dummy_head.GetString(0).as_string());
static_start_head = lm_head_base_instance_->lookup_input_word("<start_head>");
static_start_label = lm_head_base_instance_->lookup_input_word("<start_label>");
@ -211,7 +211,7 @@ void RDLM::Score(InternalTree* root, const TreePointerMap & back_pointers, boost
}
// ignore virtual nodes (in binarization; except if it's the root)
if (m_binarized && root->GetLabel()[0] == '^' && !ancestor_heads.empty()) {
if (m_binarized && root->GetLabel().GetString(0).as_string()[0] == '^' && !ancestor_heads.empty()) {
// recursion
if (root->IsLeafNT() && m_context_up > 1 && ancestor_heads.size()) {
root = back_pointers.find(root)->second.get();
@ -241,9 +241,9 @@ void RDLM::Score(InternalTree* root, const TreePointerMap & back_pointers, boost
// root of tree: score without context
if (ancestor_heads.empty() || (ancestor_heads.size() == m_context_up && ancestor_heads.back() == static_root_head)) {
std::vector<int> ngram_head_null (static_head_null);
ngram_head_null.back() = lm_head->lookup_output_word(root->GetChildren()[0]->GetLabel());
ngram_head_null.back() = lm_head->lookup_output_word(root->GetChildren()[0]->GetLabel().GetString(m_factorType).as_string());
if (m_isPretermBackoff && ngram_head_null.back() == 0) {
ngram_head_null.back() = lm_head->lookup_output_word(root->GetLabel());
ngram_head_null.back() = lm_head->lookup_output_word(root->GetLabel().GetString(m_factorType).as_string());
}
if (ancestor_heads.size() == m_context_up && ancestor_heads.back() == static_root_head) {
std::vector<int>::iterator it = ngram_head_null.begin();
@ -296,7 +296,7 @@ void RDLM::Score(InternalTree* root, const TreePointerMap & back_pointers, boost
}
size_t context_up_nonempty = std::min(m_context_up, ancestor_heads.size());
const std::string & head_label = root->GetLabel();
const std::string & head_label = root->GetLabel().GetString(0).as_string();
bool virtual_head = false;
int reached_end = 0;
int label_idx, label_idx_out;
@ -527,7 +527,7 @@ bool RDLM::GetHead(InternalTree* root, const TreePointerMap & back_pointers, std
tree = it->get();
}
if (m_binarized && tree->GetLabel()[0] == '^') {
if (m_binarized && tree->GetLabel().GetString(0).as_string()[0] == '^') {
bool found = GetHead(tree, back_pointers, IDs);
if (found) {
return true;
@ -597,8 +597,8 @@ void RDLM::GetChildHeadsAndLabels(InternalTree *root, const TreePointerMap & bac
child_ids = std::make_pair(static_dummy_head, static_dummy_head);
}
labels[j] = lm_head->lookup_input_word(child->GetLabel());
labels_output[j] = lm_label->lookup_output_word(child->GetLabel());
labels[j] = lm_head->lookup_input_word(child->GetLabel().GetString(0).as_string());
labels_output[j] = lm_label->lookup_output_word(child->GetLabel().GetString(0).as_string());
heads[j] = child_ids.first;
heads_output[j] = child_ids.second;
j++;
@ -613,18 +613,18 @@ void RDLM::GetChildHeadsAndLabels(InternalTree *root, const TreePointerMap & bac
}
void RDLM::GetIDs(const std::string & head, const std::string & preterminal, std::pair<int,int> & IDs) const
void RDLM::GetIDs(const Word & head, const Word & preterminal, std::pair<int,int> & IDs) const
{
IDs.first = lm_head_base_instance_->lookup_input_word(head);
IDs.first = lm_head_base_instance_->lookup_input_word(head.GetString(m_factorType).as_string());
if (m_isPretermBackoff && IDs.first == 0) {
IDs.first = lm_head_base_instance_->lookup_input_word(preterminal);
IDs.first = lm_head_base_instance_->lookup_input_word(preterminal.GetString(0).as_string());
}
if (m_sharedVocab) {
IDs.second = IDs.first;
} else {
IDs.second = lm_head_base_instance_->lookup_output_word(head);
IDs.second = lm_head_base_instance_->lookup_output_word(head.GetString(m_factorType).as_string());
if (m_isPretermBackoff && IDs.second == 0) {
IDs.second = lm_head_base_instance_->lookup_output_word(preterminal);
IDs.second = lm_head_base_instance_->lookup_output_word(preterminal.GetString(0).as_string());
}
}
}
@ -718,7 +718,9 @@ void RDLM::SetParameter(const std::string& key, const std::string& value)
else
UTIL_THROW(util::Exception, "Unknown value for argument " << key << "=" << value);
} else if (key == "glue_symbol") {
m_glueSymbol = value;
m_glueSymbolString = value;
} else if (key == "factor") {
m_factorType = Scan<FactorType>(value);
} else if (key == "cache_size") {
m_cacheSize = Scan<int>(value);
} else {

View File

@ -3,6 +3,7 @@
#include "moses/FF/StatefulFeatureFunction.h"
#include "moses/FF/FFState.h"
#include "moses/FF/InternalTree.h"
#include "moses/Word.h"
#include <boost/thread/tss.hpp>
#include <boost/array.hpp>
@ -61,11 +62,12 @@ class RDLM : public StatefulFeatureFunction
nplm::neuralTM* lm_label_base_instance_;
mutable boost::thread_specific_ptr<nplm::neuralTM> lm_label_backend_;
std::string dummy_head;
std::string m_glueSymbol;
std::string m_startSymbol;
std::string m_endSymbol;
std::string m_endTag;
std::string m_glueSymbolString;
Word dummy_head;
Word m_glueSymbol;
Word m_startSymbol;
Word m_endSymbol;
Word m_endTag;
std::string m_path_head_lm;
std::string m_path_label_lm;
bool m_isPretermBackoff;
@ -102,14 +104,12 @@ class RDLM : public StatefulFeatureFunction
int static_stop_label_output;
int static_start_label_output;
FactorType m_factorType;
public:
RDLM(const std::string &line)
: StatefulFeatureFunction(2, line)
, dummy_head("<dummy_head>")
, m_glueSymbol("Q")
, m_startSymbol("SSTART")
, m_endSymbol("SEND")
, m_endTag("</s>")
, m_glueSymbolString("Q")
, m_isPretermBackoff(true)
, m_context_left(3)
, m_context_right(0)
@ -120,8 +120,16 @@ public:
, m_normalizeLabelLM(false)
, m_sharedVocab(false)
, m_binarized(0)
, m_cacheSize(1000000) {
, m_cacheSize(1000000)
, m_factorType(0) {
ReadParameters();
std::vector<FactorType> factors;
factors.push_back(0);
dummy_head.CreateFromString(Output, factors, "<dummy_head>", false);
m_glueSymbol.CreateFromString(Output, factors, m_glueSymbolString, true);
m_startSymbol.CreateFromString(Output, factors, "SSTART", true);
m_endSymbol.CreateFromString(Output, factors, "SEND", true);
m_endTag.CreateFromString(Output, factors, "</s>", false);
}
~RDLM();
@ -133,7 +141,7 @@ public:
void Score(InternalTree* root, const TreePointerMap & back_pointers, boost::array<float,4> &score, std::vector<int> &ancestor_heads, std::vector<int> &ancestor_labels, size_t &boundary_hash, int num_virtual = 0, int rescoring_levels = 0) const;
bool GetHead(InternalTree* root, const TreePointerMap & back_pointers, std::pair<int,int> & IDs) const;
void GetChildHeadsAndLabels(InternalTree *root, const TreePointerMap & back_pointers, int reached_end, const nplm::neuralTM *lm_head, const nplm::neuralTM *lm_labels, std::vector<int> & heads, std::vector<int> & labels, std::vector<int> & heads_output, std::vector<int> & labels_output) const;
void GetIDs(const std::string & head, const std::string & preterminal, std::pair<int,int> & IDs) const;
void GetIDs(const Word & head, const Word & preterminal, std::pair<int,int> & IDs) const;
void ScoreFile(std::string &path); //for debugging
void PrintInfo(std::vector<int> &ngram, nplm::neuralTM* lm) const; //for debugging
@ -190,7 +198,7 @@ public:
_end = current->GetChildren().end();
iter = current->GetChildren().begin();
// expand virtual node
while (binarized && !(*iter)->GetLabel().empty() && (*iter)->GetLabel()[0] == '^') {
while (binarized && !(*iter)->GetLabel().GetString(0).empty() && (*iter)->GetLabel().GetString(0).data()[0] == '^') {
stack.push_back(std::make_pair(current, iter));
// also go through trees or previous hypotheses to rescore nodes for which more context has become available
if ((*iter)->IsLeafNT()) {
@ -227,7 +235,7 @@ public:
}
}
// expand virtual node
while (binarized && !(*iter)->GetLabel().empty() && (*iter)->GetLabel()[0] == '^') {
while (binarized && !(*iter)->GetLabel().GetString(0).empty() && (*iter)->GetLabel().GetString(0).data()[0] == '^') {
stack.push_back(std::make_pair(current, iter));
// also go through trees or previous hypotheses to rescore nodes for which more context has become available
if ((*iter)->IsLeafNT()) {