mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-07-14 23:00:29 +03:00
support factors in InternalTree
This commit is contained in:
parent
e85f353898
commit
bec950cf72
@ -5,10 +5,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "StatisticsBasedScorer.h"
|
||||
#include "moses/FF/InternalTree.h"
|
||||
|
||||
using Moses::TreePointer;
|
||||
using Moses::InternalTree;
|
||||
#include "InternalTree.h"
|
||||
|
||||
namespace MosesTuning
|
||||
{
|
||||
|
113
mert/InternalTree.cpp
Normal file
113
mert/InternalTree.cpp
Normal file
@ -0,0 +1,113 @@
|
||||
#include "InternalTree.h"
|
||||
|
||||
namespace MosesTuning
|
||||
{
|
||||
|
||||
// Builds a node from `line`.
// If `line` contains none of the characters '[', ']' or ' ', the whole string
// becomes this node's label. Otherwise parsing is delegated to AddSubTree(),
// starting at offset 0. `terminal` is the initial value of the terminal flag.
InternalTree::InternalTree(const std::string & line, const bool terminal)
  : m_isTerminal(terminal)
{
  const bool has_structure = line.find_first_of("[] ") != std::string::npos;

  if (has_structure) {
    AddSubTree(line, 0);
    return;
  }

  // plain token without brackets or spaces: use it verbatim as the label
  m_value = line;
}
|
||||
|
||||
size_t InternalTree::AddSubTree(const std::string & line, size_t pos) {
|
||||
|
||||
std::string value;
|
||||
char token = 0;
|
||||
|
||||
while (token != ']' && pos != std::string::npos)
|
||||
{
|
||||
size_t oldpos = pos;
|
||||
pos = line.find_first_of("[] ", pos);
|
||||
if (pos == std::string::npos) break;
|
||||
token = line[pos];
|
||||
value = line.substr(oldpos,pos-oldpos);
|
||||
|
||||
if (token == '[') {
|
||||
if (m_value.size() > 0) {
|
||||
m_children.push_back(boost::make_shared<InternalTree>(value,false));
|
||||
pos = m_children.back()->AddSubTree(line, pos+1);
|
||||
}
|
||||
else {
|
||||
if (value.size() > 0) {
|
||||
m_value = value;
|
||||
}
|
||||
pos = AddSubTree(line, pos+1);
|
||||
}
|
||||
}
|
||||
else if (token == ' ' || token == ']') {
|
||||
if (value.size() > 0 && !(m_value.size() > 0)) {
|
||||
m_value = value;
|
||||
}
|
||||
else if (value.size() > 0) {
|
||||
m_isTerminal = false;
|
||||
m_children.push_back(boost::make_shared<InternalTree>(value,true));
|
||||
}
|
||||
if (token == ' ') {
|
||||
pos++;
|
||||
}
|
||||
}
|
||||
|
||||
if (m_children.size() > 0) {
|
||||
m_isTerminal = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (pos == std::string::npos) {
|
||||
return line.size();
|
||||
}
|
||||
return std::min(line.size(),pos+1);
|
||||
|
||||
}
|
||||
|
||||
std::string InternalTree::GetString(bool start) const {
|
||||
|
||||
std::string ret = "";
|
||||
if (!start) {
|
||||
ret += " ";
|
||||
}
|
||||
|
||||
if (!m_isTerminal) {
|
||||
ret += "[";
|
||||
}
|
||||
|
||||
ret += m_value;
|
||||
for (std::vector<TreePointer>::const_iterator it = m_children.begin(); it != m_children.end(); ++it)
|
||||
{
|
||||
ret += (*it)->GetString(false);
|
||||
}
|
||||
|
||||
if (!m_isTerminal) {
|
||||
ret += "]";
|
||||
}
|
||||
return ret;
|
||||
|
||||
}
|
||||
|
||||
|
||||
void InternalTree::Combine(const std::vector<TreePointer> &previous) {
|
||||
|
||||
std::vector<TreePointer>::iterator it;
|
||||
bool found = false;
|
||||
leafNT next_leafNT(this);
|
||||
for (std::vector<TreePointer>::const_iterator it_prev = previous.begin(); it_prev != previous.end(); ++it_prev) {
|
||||
found = next_leafNT(it);
|
||||
if (found) {
|
||||
*it = *it_prev;
|
||||
}
|
||||
else {
|
||||
std::cerr << "Warning: leaf nonterminal not found in rule; why did this happen?\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
77
mert/InternalTree.h
Normal file
77
mert/InternalTree.h
Normal file
@ -0,0 +1,77 @@
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <boost/shared_ptr.hpp>
|
||||
#include <boost/make_shared.hpp>
|
||||
#include "util/generator.hh"
|
||||
#include "util/exception.hh"
|
||||
|
||||
namespace MosesTuning
|
||||
{
|
||||
|
||||
class InternalTree;
|
||||
typedef boost::shared_ptr<InternalTree> TreePointer;
|
||||
typedef int NTLabel;
|
||||
|
||||
class InternalTree
|
||||
{
|
||||
std::string m_value;
|
||||
std::vector<TreePointer> m_children;
|
||||
bool m_isTerminal;
|
||||
public:
|
||||
InternalTree(const std::string & line, const bool terminal = false);
|
||||
InternalTree(const InternalTree & tree):
|
||||
m_value(tree.m_value),
|
||||
m_isTerminal(tree.m_isTerminal) {
|
||||
const std::vector<TreePointer> & children = tree.m_children;
|
||||
for (std::vector<TreePointer>::const_iterator it = children.begin(); it != children.end(); it++) {
|
||||
m_children.push_back(boost::make_shared<InternalTree>(**it));
|
||||
}
|
||||
}
|
||||
size_t AddSubTree(const std::string & line, size_t start);
|
||||
|
||||
std::string GetString(bool start = true) const;
|
||||
void Combine(const std::vector<TreePointer> &previous);
|
||||
const std::string & GetLabel() const {
|
||||
return m_value;
|
||||
}
|
||||
|
||||
size_t GetLength() const {
|
||||
return m_children.size();
|
||||
}
|
||||
std::vector<TreePointer> & GetChildren() {
|
||||
return m_children;
|
||||
}
|
||||
|
||||
bool IsTerminal() const {
|
||||
return m_isTerminal;
|
||||
}
|
||||
|
||||
bool IsLeafNT() const {
|
||||
return (!m_isTerminal && m_children.size() == 0);
|
||||
}
|
||||
};
|
||||
|
||||
// Python-like generator that yields next nonterminal leaf on every call
|
||||
$generator(leafNT) {
|
||||
std::vector<TreePointer>::iterator it;
|
||||
InternalTree* tree;
|
||||
leafNT(InternalTree* root = 0): tree(root) {}
|
||||
$emit(std::vector<TreePointer>::iterator)
|
||||
for (it = tree->GetChildren().begin(); it !=tree->GetChildren().end(); ++it) {
|
||||
if (!(*it)->IsTerminal() && (*it)->GetLength() == 0) {
|
||||
$yield(it);
|
||||
}
|
||||
else if ((*it)->GetLength() > 0) {
|
||||
if ((*it).get()) { // normal pointer to same object that TreePointer points to
|
||||
$restart(tree = (*it).get());
|
||||
}
|
||||
}
|
||||
}
|
||||
$stop;
|
||||
};
|
||||
|
||||
}
|
@ -30,7 +30,7 @@ InterpolatedScorer.cpp
|
||||
Point.cpp
|
||||
PerScorer.cpp
|
||||
HwcmScorer.cpp
|
||||
../moses/FF/InternalTree.cpp
|
||||
InternalTree.cpp
|
||||
Scorer.cpp
|
||||
ScorerFactory.cpp
|
||||
Optimizer.cpp
|
||||
|
@ -1,27 +1,24 @@
|
||||
#include "InternalTree.h"
|
||||
#include "moses/StaticData.h"
|
||||
|
||||
namespace Moses
|
||||
{
|
||||
|
||||
InternalTree::InternalTree(const std::string & line, size_t start, size_t len, const bool terminal):
|
||||
m_value_nt(0),
|
||||
m_isTerminal(terminal)
|
||||
InternalTree::InternalTree(const std::string & line, size_t start, size_t len, const bool nonterminal)
|
||||
{
|
||||
|
||||
if (len > 0) {
|
||||
m_value.assign(line, start, len);
|
||||
m_value.CreateFromString(Output, StaticData::Instance().GetOutputFactorOrder(), StringPiece(line).substr(start, len), nonterminal);
|
||||
}
|
||||
}
|
||||
|
||||
InternalTree::InternalTree(const std::string & line, const bool terminal):
|
||||
m_value_nt(0),
|
||||
m_isTerminal(terminal)
|
||||
InternalTree::InternalTree(const std::string & line, const bool nonterminal)
|
||||
{
|
||||
|
||||
size_t found = line.find_first_of("[] ");
|
||||
|
||||
if (found == line.npos) {
|
||||
m_value = line;
|
||||
m_value.CreateFromString(Output, StaticData::Instance().GetOutputFactorOrder(), line, nonterminal);
|
||||
} else {
|
||||
AddSubTree(line, 0);
|
||||
}
|
||||
@ -32,6 +29,7 @@ size_t InternalTree::AddSubTree(const std::string & line, size_t pos)
|
||||
|
||||
char token = 0;
|
||||
size_t len = 0;
|
||||
bool has_value = false;
|
||||
|
||||
while (token != ']' && pos != std::string::npos) {
|
||||
size_t oldpos = pos;
|
||||
@ -41,30 +39,27 @@ size_t InternalTree::AddSubTree(const std::string & line, size_t pos)
|
||||
len = pos-oldpos;
|
||||
|
||||
if (token == '[') {
|
||||
if (!m_value.empty()) {
|
||||
m_children.push_back(boost::make_shared<InternalTree>(line, oldpos, len, false));
|
||||
if (has_value) {
|
||||
m_children.push_back(boost::make_shared<InternalTree>(line, oldpos, len, true));
|
||||
pos = m_children.back()->AddSubTree(line, pos+1);
|
||||
} else {
|
||||
if (len > 0) {
|
||||
m_value.assign(line, oldpos, len);
|
||||
m_value.CreateFromString(Output, StaticData::Instance().GetOutputFactorOrder(), StringPiece(line).substr(oldpos, len), false);
|
||||
has_value = true;
|
||||
}
|
||||
pos = AddSubTree(line, pos+1);
|
||||
}
|
||||
} else if (token == ' ' || token == ']') {
|
||||
if (len > 0 && m_value.empty()) {
|
||||
m_value.assign(line, oldpos, len);
|
||||
if (len > 0 && !has_value) {
|
||||
m_value.CreateFromString(Output, StaticData::Instance().GetOutputFactorOrder(), StringPiece(line).substr(oldpos, len), true);
|
||||
has_value = true;
|
||||
} else if (len > 0) {
|
||||
m_isTerminal = false;
|
||||
m_children.push_back(boost::make_shared<InternalTree>(line, oldpos, len, true));
|
||||
m_children.push_back(boost::make_shared<InternalTree>(line, oldpos, len, false));
|
||||
}
|
||||
if (token == ' ') {
|
||||
pos++;
|
||||
}
|
||||
}
|
||||
|
||||
if (!m_children.empty()) {
|
||||
m_isTerminal = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (pos == std::string::npos) {
|
||||
@ -82,16 +77,16 @@ std::string InternalTree::GetString(bool start) const
|
||||
ret += " ";
|
||||
}
|
||||
|
||||
if (!m_isTerminal) {
|
||||
if (!IsTerminal()) {
|
||||
ret += "[";
|
||||
}
|
||||
|
||||
ret += m_value;
|
||||
ret += m_value.GetString(StaticData::Instance().GetOutputFactorOrder(), false);
|
||||
for (std::vector<TreePointer>::const_iterator it = m_children.begin(); it != m_children.end(); ++it) {
|
||||
ret += (*it)->GetString(false);
|
||||
}
|
||||
|
||||
if (!m_isTerminal) {
|
||||
if (!IsTerminal()) {
|
||||
ret += "]";
|
||||
}
|
||||
return ret;
|
||||
@ -120,13 +115,13 @@ void InternalTree::Unbinarize()
|
||||
{
|
||||
|
||||
// nodes with virtual label cannot be unbinarized
|
||||
if (m_value.empty() || m_value[0] == '^') {
|
||||
if (m_value.GetString(0).empty() || m_value.GetString(0).as_string()[0] == '^') {
|
||||
return;
|
||||
}
|
||||
|
||||
//if node has child that is virtual node, get unbinarized list of children
|
||||
for (std::vector<TreePointer>::iterator it = m_children.begin(); it != m_children.end(); ++it) {
|
||||
if (!(*it)->IsTerminal() && (*it)->GetLabel()[0] == '^') {
|
||||
if (!(*it)->IsTerminal() && (*it)->GetLabel().GetString(0).as_string()[0] == '^') {
|
||||
std::vector<TreePointer> new_children;
|
||||
GetUnbinarizedChildren(new_children);
|
||||
m_children = new_children;
|
||||
@ -144,8 +139,8 @@ void InternalTree::Unbinarize()
|
||||
void InternalTree::GetUnbinarizedChildren(std::vector<TreePointer> &ret) const
|
||||
{
|
||||
for (std::vector<TreePointer>::const_iterator itx = m_children.begin(); itx != m_children.end(); ++itx) {
|
||||
const std::string &label = (*itx)->GetLabel();
|
||||
if (!label.empty() && label[0] == '^') {
|
||||
const StringPiece label = (*itx)->GetLabel().GetString(0);
|
||||
if (!label.empty() && label.as_string()[0] == '^') {
|
||||
(*itx)->GetUnbinarizedChildren(ret);
|
||||
} else {
|
||||
ret.push_back(*itx);
|
||||
@ -153,7 +148,7 @@ void InternalTree::GetUnbinarizedChildren(std::vector<TreePointer> &ret) const
|
||||
}
|
||||
}
|
||||
|
||||
bool InternalTree::FlatSearch(const std::string & label, std::vector<TreePointer>::const_iterator & it) const
|
||||
bool InternalTree::FlatSearch(const Word & label, std::vector<TreePointer>::const_iterator & it) const
|
||||
{
|
||||
for (it = m_children.begin(); it != m_children.end(); ++it) {
|
||||
if ((*it)->GetLabel() == label) {
|
||||
@ -163,7 +158,7 @@ bool InternalTree::FlatSearch(const std::string & label, std::vector<TreePointer
|
||||
return false;
|
||||
}
|
||||
|
||||
bool InternalTree::RecursiveSearch(const std::string & label, std::vector<TreePointer>::const_iterator & it) const
|
||||
bool InternalTree::RecursiveSearch(const Word & label, std::vector<TreePointer>::const_iterator & it) const
|
||||
{
|
||||
for (it = m_children.begin(); it != m_children.end(); ++it) {
|
||||
if ((*it)->GetLabel() == label) {
|
||||
@ -178,7 +173,7 @@ bool InternalTree::RecursiveSearch(const std::string & label, std::vector<TreePo
|
||||
return false;
|
||||
}
|
||||
|
||||
bool InternalTree::RecursiveSearch(const std::string & label, std::vector<TreePointer>::const_iterator & it, InternalTree const* &parent) const
|
||||
bool InternalTree::RecursiveSearch(const Word & label, std::vector<TreePointer>::const_iterator & it, InternalTree const* &parent) const
|
||||
{
|
||||
for (it = m_children.begin(); it != m_children.end(); ++it) {
|
||||
if ((*it)->GetLabel() == label) {
|
||||
@ -194,88 +189,4 @@ bool InternalTree::RecursiveSearch(const std::string & label, std::vector<TreePo
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
bool InternalTree::FlatSearch(const NTLabel & label, std::vector<TreePointer>::const_iterator & it) const
|
||||
{
|
||||
for (it = m_children.begin(); it != m_children.end(); ++it) {
|
||||
if ((*it)->GetNTLabel() == label) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool InternalTree::RecursiveSearch(const NTLabel & label, std::vector<TreePointer>::const_iterator & it) const
|
||||
{
|
||||
for (it = m_children.begin(); it != m_children.end(); ++it) {
|
||||
if ((*it)->GetNTLabel() == label) {
|
||||
return true;
|
||||
}
|
||||
std::vector<TreePointer>::const_iterator it2;
|
||||
if ((*it)->RecursiveSearch(label, it2)) {
|
||||
it = it2;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool InternalTree::RecursiveSearch(const NTLabel & label, std::vector<TreePointer>::const_iterator & it, InternalTree const* &parent) const
|
||||
{
|
||||
for (it = m_children.begin(); it != m_children.end(); ++it) {
|
||||
if ((*it)->GetNTLabel() == label) {
|
||||
parent = this;
|
||||
return true;
|
||||
}
|
||||
std::vector<TreePointer>::const_iterator it2;
|
||||
if ((*it)->RecursiveSearch(label, it2, parent)) {
|
||||
it = it2;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
bool InternalTree::FlatSearch(const std::vector<NTLabel> & labels, std::vector<TreePointer>::const_iterator & it) const
|
||||
{
|
||||
for (it = m_children.begin(); it != m_children.end(); ++it) {
|
||||
if (std::binary_search(labels.begin(), labels.end(), (*it)->GetNTLabel())) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool InternalTree::RecursiveSearch(const std::vector<NTLabel> & labels, std::vector<TreePointer>::const_iterator & it) const
|
||||
{
|
||||
for (it = m_children.begin(); it != m_children.end(); ++it) {
|
||||
if (std::binary_search(labels.begin(), labels.end(), (*it)->GetNTLabel())) {
|
||||
return true;
|
||||
}
|
||||
std::vector<TreePointer>::const_iterator it2;
|
||||
if ((*it)->RecursiveSearch(labels, it2)) {
|
||||
it = it2;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool InternalTree::RecursiveSearch(const std::vector<NTLabel> & labels, std::vector<TreePointer>::const_iterator & it, InternalTree const* &parent) const
|
||||
{
|
||||
for (it = m_children.begin(); it != m_children.end(); ++it) {
|
||||
if (std::binary_search(labels.begin(), labels.end(), (*it)->GetNTLabel())) {
|
||||
parent = this;
|
||||
return true;
|
||||
}
|
||||
std::vector<TreePointer>::const_iterator it2;
|
||||
if ((*it)->RecursiveSearch(labels, it2, parent)) {
|
||||
it = it2;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
}
|
@ -5,30 +5,28 @@
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "FFState.h"
|
||||
#include "moses/Word.h"
|
||||
#include <boost/shared_ptr.hpp>
|
||||
#include <boost/make_shared.hpp>
|
||||
#include "util/generator.hh"
|
||||
#include "util/exception.hh"
|
||||
#include "util/string_piece.hh"
|
||||
|
||||
namespace Moses
|
||||
{
|
||||
|
||||
class InternalTree;
|
||||
typedef boost::shared_ptr<InternalTree> TreePointer;
|
||||
typedef int NTLabel;
|
||||
|
||||
class InternalTree
|
||||
{
|
||||
std::string m_value;
|
||||
NTLabel m_value_nt;
|
||||
Word m_value;
|
||||
std::vector<TreePointer> m_children;
|
||||
bool m_isTerminal;
|
||||
public:
|
||||
InternalTree(const std::string & line, size_t start, size_t len, const bool terminal);
|
||||
InternalTree(const std::string & line, const bool terminal = false);
|
||||
InternalTree(const std::string & line, const bool nonterminal = true);
|
||||
InternalTree(const InternalTree & tree):
|
||||
m_value(tree.m_value),
|
||||
m_isTerminal(tree.m_isTerminal) {
|
||||
m_value(tree.m_value) {
|
||||
const std::vector<TreePointer> & children = tree.m_children;
|
||||
for (std::vector<TreePointer>::const_iterator it = children.begin(); it != children.end(); it++) {
|
||||
m_children.push_back(boost::make_shared<InternalTree>(**it));
|
||||
@ -40,20 +38,10 @@ public:
|
||||
void Combine(const std::vector<TreePointer> &previous);
|
||||
void Unbinarize();
|
||||
void GetUnbinarizedChildren(std::vector<TreePointer> &children) const;
|
||||
const std::string & GetLabel() const {
|
||||
const Word & GetLabel() const {
|
||||
return m_value;
|
||||
}
|
||||
|
||||
// optionally identify label by int instead of string;
|
||||
// allows abstraction if multiple nonterminal strings should map to same label.
|
||||
const NTLabel & GetNTLabel() const {
|
||||
return m_value_nt;
|
||||
}
|
||||
|
||||
void SetNTLabel(NTLabel value) {
|
||||
m_value_nt = value;
|
||||
}
|
||||
|
||||
size_t GetLength() const {
|
||||
return m_children.size();
|
||||
}
|
||||
@ -62,38 +50,22 @@ public:
|
||||
}
|
||||
|
||||
bool IsTerminal() const {
|
||||
return m_isTerminal;
|
||||
return !m_value.IsNonTerminal();
|
||||
}
|
||||
|
||||
bool IsLeafNT() const {
|
||||
return (!m_isTerminal && m_children.size() == 0);
|
||||
return (m_value.IsNonTerminal() && m_children.size() == 0);
|
||||
}
|
||||
|
||||
// different methods to search a tree (either just direct children (FlatSearch) or all children (RecursiveSearch)) for constituents.
|
||||
// can be used for formulating syntax constraints.
|
||||
|
||||
// if found, 'it' is iterator to first tree node that matches search string
|
||||
bool FlatSearch(const std::string & label, std::vector<TreePointer>::const_iterator & it) const;
|
||||
bool RecursiveSearch(const std::string & label, std::vector<TreePointer>::const_iterator & it) const;
|
||||
bool FlatSearch(const Word & label, std::vector<TreePointer>::const_iterator & it) const;
|
||||
bool RecursiveSearch(const Word & label, std::vector<TreePointer>::const_iterator & it) const;
|
||||
|
||||
// if found, 'it' is iterator to first tree node that matches search string, and 'parent' to its parent node
|
||||
bool RecursiveSearch(const std::string & label, std::vector<TreePointer>::const_iterator & it, InternalTree const* &parent) const;
|
||||
|
||||
// use NTLabel for search to reduce number of string comparisons / deal with synonymous labels
|
||||
// if found, 'it' is iterator to first tree node that matches search string
|
||||
bool FlatSearch(const NTLabel & label, std::vector<TreePointer>::const_iterator & it) const;
|
||||
bool RecursiveSearch(const NTLabel & label, std::vector<TreePointer>::const_iterator & it) const;
|
||||
|
||||
// if found, 'it' is iterator to first tree node that matches search string, and 'parent' to its parent node
|
||||
bool RecursiveSearch(const NTLabel & label, std::vector<TreePointer>::const_iterator & it, InternalTree const* &parent) const;
|
||||
|
||||
// pass vector of possible labels to search
|
||||
// if found, 'it' is iterator to first tree node that matches search string
|
||||
bool FlatSearch(const std::vector<NTLabel> & labels, std::vector<TreePointer>::const_iterator & it) const;
|
||||
bool RecursiveSearch(const std::vector<NTLabel> & labels, std::vector<TreePointer>::const_iterator & it) const;
|
||||
|
||||
// if found, 'it' is iterator to first tree node that matches search string, and 'parent' to its parent node
|
||||
bool RecursiveSearch(const std::vector<NTLabel> & labels, std::vector<TreePointer>::const_iterator & it, InternalTree const* &parent) const;
|
||||
bool RecursiveSearch(const Word & label, std::vector<TreePointer>::const_iterator & it, InternalTree const* &parent) const;
|
||||
|
||||
// Python-like generator that yields next nonterminal leaf on every call
|
||||
$generator(leafNT) {
|
||||
|
@ -13,33 +13,12 @@ void TreeStructureFeature::Load()
|
||||
|
||||
// syntactic constraints can be hooked in here.
|
||||
m_constraints = NULL;
|
||||
m_labelset = NULL;
|
||||
|
||||
StaticData &staticData = StaticData::InstanceNonConst();
|
||||
staticData.SetTreeStructure(this);
|
||||
}
|
||||
|
||||
|
||||
// define NT labels (ints) that are mapped from strings for quicker comparison.
|
||||
void TreeStructureFeature::AddNTLabels(TreePointer root) const
|
||||
{
|
||||
std::string label = root->GetLabel();
|
||||
|
||||
if (root->IsTerminal()) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::map<std::string, NTLabel>::const_iterator it = m_labelset->string_to_label.find(label);
|
||||
if (it != m_labelset->string_to_label.end()) {
|
||||
root->SetNTLabel(it->second);
|
||||
}
|
||||
|
||||
std::vector<TreePointer> children = root->GetChildren();
|
||||
for (std::vector<TreePointer>::const_iterator it2 = children.begin(); it2 != children.end(); ++it2) {
|
||||
AddNTLabels(*it2);
|
||||
}
|
||||
}
|
||||
|
||||
FFState* TreeStructureFeature::EvaluateWhenApplied(const ChartHypothesis& cur_hypo
|
||||
, int featureID /* used to index the state in the previous hypotheses */
|
||||
, ScoreComponentCollection* accumulator) const
|
||||
@ -48,10 +27,6 @@ FFState* TreeStructureFeature::EvaluateWhenApplied(const ChartHypothesis& cur_hy
|
||||
const std::string *tree = property->GetValueString();
|
||||
TreePointer mytree (boost::make_shared<InternalTree>(*tree));
|
||||
|
||||
if (m_labelset) {
|
||||
AddNTLabels(mytree);
|
||||
}
|
||||
|
||||
//get subtrees (in target order)
|
||||
std::vector<TreePointer> previous_trees;
|
||||
for (size_t pos = 0; pos < cur_hypo.GetCurrTargetPhrase().GetSize(); ++pos) {
|
||||
@ -70,7 +45,7 @@ FFState* TreeStructureFeature::EvaluateWhenApplied(const ChartHypothesis& cur_hy
|
||||
}
|
||||
mytree->Combine(previous_trees);
|
||||
|
||||
bool full_sentence = (mytree->GetChildren().back()->GetLabel() == "</s>" || (mytree->GetChildren().back()->GetLabel() == "SEND" && mytree->GetChildren().back()->GetChildren().back()->GetLabel() == "</s>"));
|
||||
bool full_sentence = (mytree->GetChildren().back()->GetLabel() == m_send || (mytree->GetChildren().back()->GetLabel() == m_send_nt && mytree->GetChildren().back()->GetChildren().back()->GetLabel() == m_send));
|
||||
if (m_binarized && full_sentence) {
|
||||
mytree->Unbinarize();
|
||||
}
|
||||
|
@ -4,6 +4,7 @@
|
||||
#include <map>
|
||||
#include "StatefulFeatureFunction.h"
|
||||
#include "FFState.h"
|
||||
#include "moses/Word.h"
|
||||
#include "InternalTree.h"
|
||||
|
||||
namespace Moses
|
||||
@ -35,11 +36,18 @@ class TreeStructureFeature : public StatefulFeatureFunction
|
||||
SyntaxConstraints* m_constraints;
|
||||
LabelSet* m_labelset;
|
||||
bool m_binarized;
|
||||
Word m_send;
|
||||
Word m_send_nt;
|
||||
|
||||
public:
|
||||
TreeStructureFeature(const std::string &line)
|
||||
:StatefulFeatureFunction(0, line)
|
||||
, m_binarized(false) {
|
||||
ReadParameters();
|
||||
std::vector<FactorType> factors;
|
||||
factors.push_back(0);
|
||||
m_send.CreateFromString(Output, factors, "</s>", false);
|
||||
m_send_nt.CreateFromString(Output, factors, "SEND", true);
|
||||
}
|
||||
~TreeStructureFeature() {
|
||||
delete m_constraints;
|
||||
@ -49,8 +57,6 @@ public:
|
||||
return new TreeState(TreePointer());
|
||||
}
|
||||
|
||||
void AddNTLabels(TreePointer root) const;
|
||||
|
||||
bool IsUseable(const FactorMask &mask) const {
|
||||
return true;
|
||||
}
|
||||
|
@ -70,7 +70,7 @@ void RDLM::Load()
|
||||
static_label_null[i] = lm_label_base_instance_->lookup_input_word(numstr);
|
||||
}
|
||||
|
||||
static_dummy_head = lm_head_base_instance_->lookup_input_word(dummy_head);
|
||||
static_dummy_head = lm_head_base_instance_->lookup_input_word(dummy_head.GetString(0).as_string());
|
||||
|
||||
static_start_head = lm_head_base_instance_->lookup_input_word("<start_head>");
|
||||
static_start_label = lm_head_base_instance_->lookup_input_word("<start_label>");
|
||||
@ -211,7 +211,7 @@ void RDLM::Score(InternalTree* root, const TreePointerMap & back_pointers, boost
|
||||
}
|
||||
|
||||
// ignore virtual nodes (in binarization; except if it's the root)
|
||||
if (m_binarized && root->GetLabel()[0] == '^' && !ancestor_heads.empty()) {
|
||||
if (m_binarized && root->GetLabel().GetString(0).as_string()[0] == '^' && !ancestor_heads.empty()) {
|
||||
// recursion
|
||||
if (root->IsLeafNT() && m_context_up > 1 && ancestor_heads.size()) {
|
||||
root = back_pointers.find(root)->second.get();
|
||||
@ -241,9 +241,9 @@ void RDLM::Score(InternalTree* root, const TreePointerMap & back_pointers, boost
|
||||
// root of tree: score without context
|
||||
if (ancestor_heads.empty() || (ancestor_heads.size() == m_context_up && ancestor_heads.back() == static_root_head)) {
|
||||
std::vector<int> ngram_head_null (static_head_null);
|
||||
ngram_head_null.back() = lm_head->lookup_output_word(root->GetChildren()[0]->GetLabel());
|
||||
ngram_head_null.back() = lm_head->lookup_output_word(root->GetChildren()[0]->GetLabel().GetString(m_factorType).as_string());
|
||||
if (m_isPretermBackoff && ngram_head_null.back() == 0) {
|
||||
ngram_head_null.back() = lm_head->lookup_output_word(root->GetLabel());
|
||||
ngram_head_null.back() = lm_head->lookup_output_word(root->GetLabel().GetString(m_factorType).as_string());
|
||||
}
|
||||
if (ancestor_heads.size() == m_context_up && ancestor_heads.back() == static_root_head) {
|
||||
std::vector<int>::iterator it = ngram_head_null.begin();
|
||||
@ -296,7 +296,7 @@ void RDLM::Score(InternalTree* root, const TreePointerMap & back_pointers, boost
|
||||
}
|
||||
|
||||
size_t context_up_nonempty = std::min(m_context_up, ancestor_heads.size());
|
||||
const std::string & head_label = root->GetLabel();
|
||||
const std::string & head_label = root->GetLabel().GetString(0).as_string();
|
||||
bool virtual_head = false;
|
||||
int reached_end = 0;
|
||||
int label_idx, label_idx_out;
|
||||
@ -527,7 +527,7 @@ bool RDLM::GetHead(InternalTree* root, const TreePointerMap & back_pointers, std
|
||||
tree = it->get();
|
||||
}
|
||||
|
||||
if (m_binarized && tree->GetLabel()[0] == '^') {
|
||||
if (m_binarized && tree->GetLabel().GetString(0).as_string()[0] == '^') {
|
||||
bool found = GetHead(tree, back_pointers, IDs);
|
||||
if (found) {
|
||||
return true;
|
||||
@ -597,8 +597,8 @@ void RDLM::GetChildHeadsAndLabels(InternalTree *root, const TreePointerMap & bac
|
||||
child_ids = std::make_pair(static_dummy_head, static_dummy_head);
|
||||
}
|
||||
|
||||
labels[j] = lm_head->lookup_input_word(child->GetLabel());
|
||||
labels_output[j] = lm_label->lookup_output_word(child->GetLabel());
|
||||
labels[j] = lm_head->lookup_input_word(child->GetLabel().GetString(0).as_string());
|
||||
labels_output[j] = lm_label->lookup_output_word(child->GetLabel().GetString(0).as_string());
|
||||
heads[j] = child_ids.first;
|
||||
heads_output[j] = child_ids.second;
|
||||
j++;
|
||||
@ -613,18 +613,18 @@ void RDLM::GetChildHeadsAndLabels(InternalTree *root, const TreePointerMap & bac
|
||||
}
|
||||
|
||||
|
||||
void RDLM::GetIDs(const std::string & head, const std::string & preterminal, std::pair<int,int> & IDs) const
|
||||
void RDLM::GetIDs(const Word & head, const Word & preterminal, std::pair<int,int> & IDs) const
|
||||
{
|
||||
IDs.first = lm_head_base_instance_->lookup_input_word(head);
|
||||
IDs.first = lm_head_base_instance_->lookup_input_word(head.GetString(m_factorType).as_string());
|
||||
if (m_isPretermBackoff && IDs.first == 0) {
|
||||
IDs.first = lm_head_base_instance_->lookup_input_word(preterminal);
|
||||
IDs.first = lm_head_base_instance_->lookup_input_word(preterminal.GetString(0).as_string());
|
||||
}
|
||||
if (m_sharedVocab) {
|
||||
IDs.second = IDs.first;
|
||||
} else {
|
||||
IDs.second = lm_head_base_instance_->lookup_output_word(head);
|
||||
IDs.second = lm_head_base_instance_->lookup_output_word(head.GetString(m_factorType).as_string());
|
||||
if (m_isPretermBackoff && IDs.second == 0) {
|
||||
IDs.second = lm_head_base_instance_->lookup_output_word(preterminal);
|
||||
IDs.second = lm_head_base_instance_->lookup_output_word(preterminal.GetString(0).as_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -718,7 +718,9 @@ void RDLM::SetParameter(const std::string& key, const std::string& value)
|
||||
else
|
||||
UTIL_THROW(util::Exception, "Unknown value for argument " << key << "=" << value);
|
||||
} else if (key == "glue_symbol") {
|
||||
m_glueSymbol = value;
|
||||
m_glueSymbolString = value;
|
||||
} else if (key == "factor") {
|
||||
m_factorType = Scan<FactorType>(value);
|
||||
} else if (key == "cache_size") {
|
||||
m_cacheSize = Scan<int>(value);
|
||||
} else {
|
||||
|
@ -3,6 +3,7 @@
|
||||
#include "moses/FF/StatefulFeatureFunction.h"
|
||||
#include "moses/FF/FFState.h"
|
||||
#include "moses/FF/InternalTree.h"
|
||||
#include "moses/Word.h"
|
||||
|
||||
#include <boost/thread/tss.hpp>
|
||||
#include <boost/array.hpp>
|
||||
@ -61,11 +62,12 @@ class RDLM : public StatefulFeatureFunction
|
||||
nplm::neuralTM* lm_label_base_instance_;
|
||||
mutable boost::thread_specific_ptr<nplm::neuralTM> lm_label_backend_;
|
||||
|
||||
std::string dummy_head;
|
||||
std::string m_glueSymbol;
|
||||
std::string m_startSymbol;
|
||||
std::string m_endSymbol;
|
||||
std::string m_endTag;
|
||||
std::string m_glueSymbolString;
|
||||
Word dummy_head;
|
||||
Word m_glueSymbol;
|
||||
Word m_startSymbol;
|
||||
Word m_endSymbol;
|
||||
Word m_endTag;
|
||||
std::string m_path_head_lm;
|
||||
std::string m_path_label_lm;
|
||||
bool m_isPretermBackoff;
|
||||
@ -102,14 +104,12 @@ class RDLM : public StatefulFeatureFunction
|
||||
int static_stop_label_output;
|
||||
int static_start_label_output;
|
||||
|
||||
FactorType m_factorType;
|
||||
|
||||
public:
|
||||
RDLM(const std::string &line)
|
||||
: StatefulFeatureFunction(2, line)
|
||||
, dummy_head("<dummy_head>")
|
||||
, m_glueSymbol("Q")
|
||||
, m_startSymbol("SSTART")
|
||||
, m_endSymbol("SEND")
|
||||
, m_endTag("</s>")
|
||||
, m_glueSymbolString("Q")
|
||||
, m_isPretermBackoff(true)
|
||||
, m_context_left(3)
|
||||
, m_context_right(0)
|
||||
@ -120,8 +120,16 @@ public:
|
||||
, m_normalizeLabelLM(false)
|
||||
, m_sharedVocab(false)
|
||||
, m_binarized(0)
|
||||
, m_cacheSize(1000000) {
|
||||
, m_cacheSize(1000000)
|
||||
, m_factorType(0) {
|
||||
ReadParameters();
|
||||
std::vector<FactorType> factors;
|
||||
factors.push_back(0);
|
||||
dummy_head.CreateFromString(Output, factors, "<dummy_head>", false);
|
||||
m_glueSymbol.CreateFromString(Output, factors, m_glueSymbolString, true);
|
||||
m_startSymbol.CreateFromString(Output, factors, "SSTART", true);
|
||||
m_endSymbol.CreateFromString(Output, factors, "SEND", true);
|
||||
m_endTag.CreateFromString(Output, factors, "</s>", false);
|
||||
}
|
||||
|
||||
~RDLM();
|
||||
@ -133,7 +141,7 @@ public:
|
||||
void Score(InternalTree* root, const TreePointerMap & back_pointers, boost::array<float,4> &score, std::vector<int> &ancestor_heads, std::vector<int> &ancestor_labels, size_t &boundary_hash, int num_virtual = 0, int rescoring_levels = 0) const;
|
||||
bool GetHead(InternalTree* root, const TreePointerMap & back_pointers, std::pair<int,int> & IDs) const;
|
||||
void GetChildHeadsAndLabels(InternalTree *root, const TreePointerMap & back_pointers, int reached_end, const nplm::neuralTM *lm_head, const nplm::neuralTM *lm_labels, std::vector<int> & heads, std::vector<int> & labels, std::vector<int> & heads_output, std::vector<int> & labels_output) const;
|
||||
void GetIDs(const std::string & head, const std::string & preterminal, std::pair<int,int> & IDs) const;
|
||||
void GetIDs(const Word & head, const Word & preterminal, std::pair<int,int> & IDs) const;
|
||||
void ScoreFile(std::string &path); //for debugging
|
||||
void PrintInfo(std::vector<int> &ngram, nplm::neuralTM* lm) const; //for debugging
|
||||
|
||||
@ -190,7 +198,7 @@ public:
|
||||
_end = current->GetChildren().end();
|
||||
iter = current->GetChildren().begin();
|
||||
// expand virtual node
|
||||
while (binarized && !(*iter)->GetLabel().empty() && (*iter)->GetLabel()[0] == '^') {
|
||||
while (binarized && !(*iter)->GetLabel().GetString(0).empty() && (*iter)->GetLabel().GetString(0).data()[0] == '^') {
|
||||
stack.push_back(std::make_pair(current, iter));
|
||||
// also go through trees or previous hypotheses to rescore nodes for which more context has become available
|
||||
if ((*iter)->IsLeafNT()) {
|
||||
@ -227,7 +235,7 @@ public:
|
||||
}
|
||||
}
|
||||
// expand virtual node
|
||||
while (binarized && !(*iter)->GetLabel().empty() && (*iter)->GetLabel()[0] == '^') {
|
||||
while (binarized && !(*iter)->GetLabel().GetString(0).empty() && (*iter)->GetLabel().GetString(0).data()[0] == '^') {
|
||||
stack.push_back(std::make_pair(current, iter));
|
||||
// also go through trees or previous hypotheses to rescore nodes for which more context has become available
|
||||
if ((*iter)->IsLeafNT()) {
|
||||
|
Loading…
Reference in New Issue
Block a user