rename SyntaxConstraintFeature to TreeStructureFeature

(The new name makes it clearer what the feature does: it builds and prints the internal tree structure in string-to-tree decoding.)

No longer rely on the name of the feature function (FF) when printing internal trees.
This commit is contained in:
Rico Sennrich 2014-03-03 13:56:19 +00:00
parent 01bc3c111e
commit c27ecbe5c6
8 changed files with 411 additions and 200 deletions

View File

@ -50,7 +50,7 @@ POSSIBILITY OF SUCH DAMAGE.
#include "moses/FeatureVector.h"
#include "moses/FF/StatefulFeatureFunction.h"
#include "moses/FF/StatelessFeatureFunction.h"
#include "moses/FF/SyntaxConstraintFeature.h"
#include "moses/FF/TreeStructureFeature.h"
#include "util/exception.hh"
using namespace std;
@ -395,14 +395,16 @@ void IOWrapper::OutputDetailedTreeFragmentsTranslationReport(
UTIL_THROW_IF2(m_detailTreeFragmentsOutputCollector == NULL,
"No output file for tree fragments specified");
//Tree of full sentence (to stderr)
const vector<const StatefulFeatureFunction*>& sff = StatefulFeatureFunction::GetStatefulFeatureFunctions();
for( size_t i=0; i<sff.size(); i++ ) {
const StatefulFeatureFunction *ff = sff[i];
if (ff->GetScoreProducerDescription() == "SyntaxConstraintFeature0") {
const TreeState* tree = dynamic_cast<const TreeState*>(hypo->GetFFState(i));
out << "Full Tree " << translationId << ": " << tree->GetTree()->GetString() << "\n";
break;
//Tree of full sentence
const StatefulFeatureFunction* treeStructure = StaticData::Instance().GetTreeStructure();
if (treeStructure != NULL) {
const vector<const StatefulFeatureFunction*>& sff = StatefulFeatureFunction::GetStatefulFeatureFunctions();
for( size_t i=0; i<sff.size(); i++ ) {
if (sff[i] == treeStructure) {
const TreeState* tree = dynamic_cast<const TreeState*>(hypo->GetFFState(i));
out << "Full Tree " << translationId << ": " << tree->GetTree()->GetString() << "\n";
break;
}
}
}

View File

@ -97,7 +97,7 @@ void ChartParserUnknown::Process(const Word &sourceWord, const WordsRange &range
targetPhrase->SetTargetLHS(targetLHS);
targetPhrase->SetAlignmentInfo("0-0");
if (staticData.IsDetailedTreeFragmentsTranslationReportingEnabled()) {
if (staticData.IsDetailedTreeFragmentsTranslationReportingEnabled() || staticData.GetTreeStructure() != NULL) {
targetPhrase->SetProperty("Tree","[ " + (*targetLHS)[0]->GetString().as_string() + " "+sourceWord[0]->GetString().as_string()+" ]");
}

View File

@ -34,7 +34,7 @@
#include "moses/FF/ExternalFeature.h"
#include "moses/FF/ConstrainedDecoding.h"
#include "moses/FF/CoveredReferenceFeature.h"
#include "moses/FF/SyntaxConstraintFeature.h"
#include "moses/FF/TreeStructureFeature.h"
#include "moses/FF/SoftMatchingFeature.h"
#include "moses/FF/HyperParameterAsWeight.h"
@ -174,7 +174,7 @@ FeatureRegistry::FeatureRegistry()
MOSES_FNAME(ConstrainedDecoding);
MOSES_FNAME(CoveredReferenceFeature);
MOSES_FNAME(ExternalFeature);
MOSES_FNAME(SyntaxConstraintFeature);
MOSES_FNAME(TreeStructureFeature);
MOSES_FNAME(SoftMatchingFeature);
MOSES_FNAME(HyperParameterAsWeight);

View File

@ -1,186 +0,0 @@
#include "SyntaxConstraintFeature.h"
#include "moses/ScoreComponentCollection.h"
#include "moses/Hypothesis.h"
#include "moses/ChartHypothesis.h"
#include "moses/TargetPhrase.h"
#include <boost/shared_ptr.hpp>
#include <vector>
using namespace std;
namespace Moses
{
// Build a tree node from its string form.
// A string containing none of '[', ']' or ' ' is stored verbatim as the node
// value; any other string is parsed by AddSubTree() starting at offset 0.
// 'terminal' is the initial terminal flag (AddSubTree may clear it).
InternalTree::InternalTree(const std::string & line, const bool terminal)
  : m_isTerminal(terminal)
{
  const bool hasStructure = line.find_first_of("[] ") != std::string::npos;
  if (hasStructure) {
    AddSubTree(line, 0);
  } else {
    m_value = line;
  }
}
// Parse the bracketed tree description in 'line', starting at offset 'pos',
// into this node.
//
// The text is split at the delimiters '[', ']' and ' '. The first non-empty
// token seen while this node has no value becomes its value; later tokens
// become terminal children, and a '[' seen once the node has a value opens a
// non-terminal child that parses its own content recursively. Parsing of this
// node stops at the first ']' or at the end of the string.
//
// Returns the offset just past the last delimiter consumed, clamped to
// line.size(); returns line.size() when the end of the string was reached.
size_t InternalTree::AddSubTree(const std::string & line, size_t pos) {
  std::string value;
  char token = 0;

  while (token != ']' && pos != std::string::npos)
  {
    const size_t oldpos = pos;
    pos = line.find_first_of("[] ", pos);
    if (pos == std::string::npos) break;
    token = line[pos];
    value = line.substr(oldpos,pos-oldpos);

    if (token == '[') {
      if (!m_value.empty()) {
        // this node is already labelled, so the bracket opens a child constituent
        TreePointer child(new InternalTree(value, false));
        m_children.push_back(child);
        pos = child->AddSubTree(line, pos+1);
      }
      else {
        // still unlabelled: keep parsing into this same node
        if (!value.empty()) {
          m_value = value;
        }
        pos = AddSubTree(line, pos+1);
      }
    }
    else if (token == ' ' || token == ']') {
      // was "! m_value.size() > 0", which parses as "(!size) > 0"; spelled out
      // as an explicit emptiness test so the intent does not hinge on precedence
      if (!value.empty() && m_value.empty()) {
        m_value = value;
      }
      else if (!value.empty()) {
        m_isTerminal = false;
        TreePointer child(new InternalTree(value, true));
        m_children.push_back(child);
      }
      if (token == ' ') {
        pos++;
      }
    }

    if (!m_children.empty()) {
      m_isTerminal = false;
    }
  }

  if (pos == std::string::npos) {
    return line.size();
  }
  return min(line.size(),pos+1);
}
// Serialise this subtree to a string.
// Every node contributes a leading space followed by its value; a
// non-terminal node additionally wraps its value and the serialisation of
// all its children in '[' ... ']'.
std::string InternalTree::GetString() const {
  const bool bracketed = !m_isTerminal;

  std::string out(" ");
  if (bracketed) {
    out += "[";
  }
  out += m_value;
  for (size_t i = 0; i < m_children.size(); ++i) {
    out += m_children[i]->GetString();
  }
  if (bracketed) {
    out += "]";
  }
  return out;
}
void InternalTree::Combine(const std::vector<TreePointer> &previous) {
std::vector<TreePointer>::iterator it;
bool found = false;
leafNT next_leafNT(this);
for (std::vector<TreePointer>::const_iterator it_prev = previous.begin(); it_prev != previous.end(); ++it_prev) {
found = next_leafNT(it);
if (found) {
*it = *it_prev;
}
else {
std::cerr << "Warning: leaf nonterminal not found in rule; why did this happen?\n";
}
}
}
// Look through the direct children only for the first node whose string
// label equals 'label'.
// Returns true with 'it' positioned on that child; otherwise returns false
// with 'it' equal to m_children.end().
bool InternalTree::FlatSearch(const std::string & label, std::vector<TreePointer>::const_iterator & it) const {
  it = m_children.begin();
  while (it != m_children.end()) {
    if ((*it)->GetLabel() == label) {
      return true;
    }
    ++it;
  }
  return false;
}
bool InternalTree::RecursiveSearch(const std::string & label, std::vector<TreePointer>::const_iterator & it) const {
for (it = m_children.begin(); it != m_children.end(); ++it) {
if ((*it)->GetLabel() == label) {
return true;
}
std::vector<TreePointer>::const_iterator it2;
if ((*it)->RecursiveSearch(label, it2)) {
it = it2;
return true;
}
}
return false;
}
bool InternalTree::RecursiveSearch(const std::string & label, std::vector<TreePointer>::const_iterator & it, InternalTree const* &parent) const {
for (it = m_children.begin(); it != m_children.end(); ++it) {
if ((*it)->GetLabel() == label) {
parent = this;
return true;
}
std::vector<TreePointer>::const_iterator it2;
if ((*it)->RecursiveSearch(label, it2, parent)) {
it = it2;
return true;
}
}
return false;
}
// Build the tree state for a chart hypothesis.
// The rule's "Tree" property is parsed into an InternalTree, the trees stored
// in the states of the antecedent hypotheses are collected in target order,
// and Combine() plugs them into the rule tree. Returns a newly allocated
// TreeState holding the result. 'accumulator' is not used.
FFState* SyntaxConstraintFeature::EvaluateChart(const ChartHypothesis& cur_hypo
    , int featureID /* used to index the state in the previous hypotheses */
    , ScoreComponentCollection* accumulator) const
{
  const TargetPhrase &rule = cur_hypo.GetCurrTargetPhrase();

  // whether the property was actually present is not checked here
  bool hasTree = 0;
  std::string fragment;
  rule.GetProperty("Tree", fragment, hasTree);
  TreePointer ruleTree(new InternalTree(fragment));

  //get subtrees (in target order)
  std::vector<TreePointer> childTrees;
  const size_t ruleSize = rule.GetSize();
  for (size_t pos = 0; pos != ruleSize; ++pos) {
    if (!rule.GetWord(pos).IsNonTerminal()) {
      continue;
    }
    const size_t nonTermInd = rule.GetAlignNonTerm().GetNonTermIndexMap()[pos];
    const ChartHypothesis *antecedent = cur_hypo.GetPrevHypo(nonTermInd);
    const TreeState* antecedentState = dynamic_cast<const TreeState*>(antecedent->GetFFState(featureID));
    childTrees.push_back(antecedentState->GetTree());
  }

  ruleTree->Combine(childTrees);
  return new TreeState(ruleTree);
}
}

View File

@ -0,0 +1,315 @@
#include "TreeStructureFeature.h"
#include "moses/StaticData.h"
#include "moses/ScoreComponentCollection.h"
#include "moses/Hypothesis.h"
#include "moses/ChartHypothesis.h"
#include "moses/TargetPhrase.h"
#include <boost/shared_ptr.hpp>
#include <vector>
using namespace std;
namespace Moses
{
// Build a tree node from its string form.
// The integer label starts at 0 and the terminal flag at 'terminal'.
// A string containing none of '[', ']' or ' ' is stored verbatim as the node
// value; any other string is parsed by AddSubTree() starting at offset 0.
InternalTree::InternalTree(const std::string & line, const bool terminal):
  m_value_nt(0),
  m_isTerminal(terminal)
{
  const bool hasStructure = line.find_first_of("[] ") != std::string::npos;
  if (hasStructure) {
    AddSubTree(line, 0);
  } else {
    m_value = line;
  }
}
// Parse the bracketed tree description in 'line', starting at offset 'pos',
// into this node.
//
// The text is split at the delimiters '[', ']' and ' '. The first non-empty
// token seen while this node has no value becomes its value; later tokens
// become terminal children, and a '[' seen once the node has a value opens a
// non-terminal child that parses its own content recursively. Parsing of this
// node stops at the first ']' or at the end of the string.
//
// Returns the offset just past the last delimiter consumed, clamped to
// line.size(); returns line.size() when the end of the string was reached.
size_t InternalTree::AddSubTree(const std::string & line, size_t pos) {
  std::string value;
  char token = 0;

  while (token != ']' && pos != std::string::npos)
  {
    const size_t oldpos = pos;
    pos = line.find_first_of("[] ", pos);
    if (pos == std::string::npos) break;
    token = line[pos];
    value = line.substr(oldpos,pos-oldpos);

    if (token == '[') {
      if (!m_value.empty()) {
        // this node is already labelled, so the bracket opens a child constituent
        TreePointer child(new InternalTree(value, false));
        m_children.push_back(child);
        pos = child->AddSubTree(line, pos+1);
      }
      else {
        // still unlabelled: keep parsing into this same node
        if (!value.empty()) {
          m_value = value;
        }
        pos = AddSubTree(line, pos+1);
      }
    }
    else if (token == ' ' || token == ']') {
      // was "! m_value.size() > 0", which parses as "(!size) > 0"; spelled out
      // as an explicit emptiness test so the intent does not hinge on precedence
      if (!value.empty() && m_value.empty()) {
        m_value = value;
      }
      else if (!value.empty()) {
        m_isTerminal = false;
        TreePointer child(new InternalTree(value, true));
        m_children.push_back(child);
      }
      if (token == ' ') {
        pos++;
      }
    }

    if (!m_children.empty()) {
      m_isTerminal = false;
    }
  }

  if (pos == std::string::npos) {
    return line.size();
  }
  return min(line.size(),pos+1);
}
// Serialise this subtree to a string.
// Every node contributes a leading space followed by its value; a
// non-terminal node additionally wraps its value and the serialisation of
// all its children in '[' ... ']'.
std::string InternalTree::GetString() const {
  const bool bracketed = !m_isTerminal;

  std::string out(" ");
  if (bracketed) {
    out += "[";
  }
  out += m_value;
  for (size_t i = 0; i < m_children.size(); ++i) {
    out += m_children[i]->GetString();
  }
  if (bracketed) {
    out += "]";
  }
  return out;
}
void InternalTree::Combine(const std::vector<TreePointer> &previous) {
std::vector<TreePointer>::iterator it;
bool found = false;
leafNT next_leafNT(this);
for (std::vector<TreePointer>::const_iterator it_prev = previous.begin(); it_prev != previous.end(); ++it_prev) {
found = next_leafNT(it);
if (found) {
*it = *it_prev;
}
else {
std::cerr << "Warning: leaf nonterminal not found in rule; why did this happen?\n";
}
}
}
// Look through the direct children only for the first node whose string
// label equals 'label'.
// Returns true with 'it' positioned on that child; otherwise returns false
// with 'it' equal to m_children.end().
bool InternalTree::FlatSearch(const std::string & label, std::vector<TreePointer>::const_iterator & it) const {
  it = m_children.begin();
  while (it != m_children.end()) {
    if ((*it)->GetLabel() == label) {
      return true;
    }
    ++it;
  }
  return false;
}
bool InternalTree::RecursiveSearch(const std::string & label, std::vector<TreePointer>::const_iterator & it) const {
for (it = m_children.begin(); it != m_children.end(); ++it) {
if ((*it)->GetLabel() == label) {
return true;
}
std::vector<TreePointer>::const_iterator it2;
if ((*it)->RecursiveSearch(label, it2)) {
it = it2;
return true;
}
}
return false;
}
bool InternalTree::RecursiveSearch(const std::string & label, std::vector<TreePointer>::const_iterator & it, InternalTree const* &parent) const {
for (it = m_children.begin(); it != m_children.end(); ++it) {
if ((*it)->GetLabel() == label) {
parent = this;
return true;
}
std::vector<TreePointer>::const_iterator it2;
if ((*it)->RecursiveSearch(label, it2, parent)) {
it = it2;
return true;
}
}
return false;
}
// Look through the direct children only for the first node whose integer
// label (GetNTLabel) equals 'label'.
// Returns true with 'it' positioned on that child; otherwise returns false
// with 'it' equal to m_children.end().
bool InternalTree::FlatSearch(const NTLabel & label, std::vector<TreePointer>::const_iterator & it) const {
  it = m_children.begin();
  while (it != m_children.end()) {
    if ((*it)->GetNTLabel() == label) {
      return true;
    }
    ++it;
  }
  return false;
}
bool InternalTree::RecursiveSearch(const NTLabel & label, std::vector<TreePointer>::const_iterator & it) const {
for (it = m_children.begin(); it != m_children.end(); ++it) {
if ((*it)->GetNTLabel() == label) {
return true;
}
std::vector<TreePointer>::const_iterator it2;
if ((*it)->RecursiveSearch(label, it2)) {
it = it2;
return true;
}
}
return false;
}
bool InternalTree::RecursiveSearch(const NTLabel & label, std::vector<TreePointer>::const_iterator & it, InternalTree const* &parent) const {
for (it = m_children.begin(); it != m_children.end(); ++it) {
if ((*it)->GetNTLabel() == label) {
parent = this;
return true;
}
std::vector<TreePointer>::const_iterator it2;
if ((*it)->RecursiveSearch(label, it2, parent)) {
it = it2;
return true;
}
}
return false;
}
// Look through the direct children only for the first node whose integer
// label (GetNTLabel) is contained in 'labels'.
// Membership is tested with std::binary_search, so 'labels' must be sorted
// in ascending order.
// NOTE(review): std::binary_search lives in <algorithm>, which this file does
// not include directly.
// Returns true with 'it' positioned on that child; otherwise returns false
// with 'it' equal to m_children.end().
bool InternalTree::FlatSearch(const std::vector<NTLabel> & labels, std::vector<TreePointer>::const_iterator & it) const {
  it = m_children.begin();
  while (it != m_children.end()) {
    const bool wanted = std::binary_search(labels.begin(), labels.end(), (*it)->GetNTLabel());
    if (wanted) {
      return true;
    }
    ++it;
  }
  return false;
}
bool InternalTree::RecursiveSearch(const std::vector<NTLabel> & labels, std::vector<TreePointer>::const_iterator & it) const {
for (it = m_children.begin(); it != m_children.end(); ++it) {
if (std::binary_search(labels.begin(), labels.end(), (*it)->GetNTLabel())) {
return true;
}
std::vector<TreePointer>::const_iterator it2;
if ((*it)->RecursiveSearch(labels, it2)) {
it = it2;
return true;
}
}
return false;
}
bool InternalTree::RecursiveSearch(const std::vector<NTLabel> & labels, std::vector<TreePointer>::const_iterator & it, InternalTree const* &parent) const {
for (it = m_children.begin(); it != m_children.end(); ++it) {
if (std::binary_search(labels.begin(), labels.end(), (*it)->GetNTLabel())) {
parent = this;
return true;
}
std::vector<TreePointer>::const_iterator it2;
if ((*it)->RecursiveSearch(labels, it2, parent)) {
it = it2;
return true;
}
}
return false;
}
void TreeStructureFeature::Load() {
// syntactic constraints can be hooked in here.
m_constraints = NULL;
m_labelset = NULL;
StaticData &staticData = StaticData::InstanceNonConst();
staticData.SetTreeStructure(this);
}
// define NT labels (ints) that are mapped from strings for quicker comparison.
// Walk the tree rooted at 'root' and, for every non-terminal node whose
// string label has an entry in m_labelset->string_to_label, store the mapped
// integer label on the node via SetNTLabel(). Nodes without an entry are left
// unchanged. Recursion stops at terminal nodes.
//
// Precondition: m_labelset is non-NULL (it is dereferenced unconditionally).
void TreeStructureFeature::AddNTLabels(TreePointer root) const {
  if (root->IsTerminal()) {
    return;
  }

  // bind by reference: this runs for every node of every rule tree, so the
  // label string and the child vector are not copied
  const std::string &label = root->GetLabel();
  std::map<std::string, NTLabel>::const_iterator match = m_labelset->string_to_label.find(label);
  if (match != m_labelset->string_to_label.end()) {
    root->SetNTLabel(match->second);
  }

  const std::vector<TreePointer> &children = root->GetChildren();
  for (std::vector<TreePointer>::const_iterator child = children.begin(); child != children.end(); ++child) {
    AddNTLabels(*child);
  }
}
// Build the tree state for a chart hypothesis.
//
// The rule's "Tree" property is parsed into an InternalTree (and annotated
// with integer labels when a label set is installed). The trees stored in the
// states of the antecedent hypotheses are collected in target order; if a
// constraint object is installed it is consulted on the rule tree and the
// antecedent trees before they are merged, and each string it returns is
// added to 'accumulator' as a sparse feature with value 1. Combine() then
// plugs the antecedent trees into the rule tree.
//
// Returns a newly allocated TreeState holding the combined tree.
// Throws (via UTIL_THROW2) if the rule has no "Tree" property, or if an
// antecedent hypothesis does not carry a TreeState at 'featureID'.
FFState* TreeStructureFeature::EvaluateChart(const ChartHypothesis& cur_hypo
    , int featureID /* used to index the state in the previous hypotheses */
    , ScoreComponentCollection* accumulator) const
{
  const TargetPhrase &targetPhrase = cur_hypo.GetCurrTargetPhrase();

  std::string tree;
  bool found = false;
  targetPhrase.GetProperty("Tree", tree, found);
  if (!found) {
    UTIL_THROW2("Error: TreeStructureFeature active, but no internal tree structure found");
  }

  TreePointer mytree (new InternalTree(tree));
  if (m_labelset) {
    AddNTLabels(mytree);
  }

  //get subtrees (in target order)
  std::vector<TreePointer> previous_trees;
  for (size_t pos = 0; pos < targetPhrase.GetSize(); ++pos) {
    const Word &word = targetPhrase.GetWord(pos);
    if (!word.IsNonTerminal()) {
      continue;
    }
    size_t nonTermInd = targetPhrase.GetAlignNonTerm().GetNonTermIndexMap()[pos];
    const ChartHypothesis *prevHypo = cur_hypo.GetPrevHypo(nonTermInd);
    const TreeState* prev = dynamic_cast<const TreeState*>(prevHypo->GetFFState(featureID));
    if (prev == NULL) {
      // a failed dynamic_cast yields NULL; report it rather than dereference it
      UTIL_THROW2("Error: TreeStructureFeature: state of previous hypothesis is not a TreeState");
    }
    previous_trees.push_back(prev->GetTree());
  }

  // constraints must see the rule tree before the subtrees are merged into it
  std::vector<std::string> sparse_features;
  if (m_constraints) {
    sparse_features = m_constraints->SyntacticRules(mytree, previous_trees);
  }
  mytree->Combine(previous_trees);

  //sparse scores
  for (std::vector<std::string>::const_iterator feature=sparse_features.begin(); feature != sparse_features.end(); ++feature) {
    accumulator->PlusEquals(this, *feature, 1);
  }
  return new TreeState(mytree);
}
}

View File

@ -1,6 +1,7 @@
#pragma once
#include <string>
#include <map>
#include "StatefulFeatureFunction.h"
#include "FFState.h"
#include <boost/shared_ptr.hpp>
@ -12,14 +13,25 @@ namespace Moses
class InternalTree;
typedef boost::shared_ptr<InternalTree> TreePointer;
typedef int NTLabel;
class InternalTree
{
std::string m_value;
NTLabel m_value_nt;
std::vector<TreePointer> m_children;
bool m_isTerminal;
public:
InternalTree(const std::string & line, const bool terminal = false);
InternalTree(const InternalTree & tree):
m_value(tree.m_value),
m_isTerminal(tree.m_isTerminal) {
const std::vector<TreePointer> & children = tree.m_children;
for (std::vector<TreePointer>::const_iterator it = children.begin(); it != children.end(); it++) {
TreePointer child (new InternalTree(**it));
m_children.push_back(child);
}
}
size_t AddSubTree(const std::string & line, size_t start);
std::string GetString() const;
@ -27,6 +39,17 @@ public:
// Returns a reference to this node's string value (m_value).
const std::string & GetLabel() const {
return m_value;
}
// optionally identify label by int instead of string;
// allows abstraction if multiple nonterminal strings should map to same label.
// Returns a reference to the integer label held in m_value_nt (see SetNTLabel).
const NTLabel & GetNTLabel() const {
return m_value_nt;
}
// Stores 'value' as this node's integer label (m_value_nt).
void SetNTLabel(NTLabel value) {
m_value_nt = value;
}
// Returns the number of direct children of this node.
size_t GetLength() const {
return m_children.size();
}
@ -45,6 +68,8 @@ public:
return (!m_isTerminal && m_children.size() == 0);
}
// different methods to search a tree (either just direct children (FlatSearch) or all children (RecursiveSearch)) for constituents.
// can be used for formulating syntax constraints.
// if found, 'it' is iterator to first tree node that matches search string
bool FlatSearch(const std::string & label, std::vector<TreePointer>::const_iterator & it) const;
@ -53,6 +78,41 @@ public:
// if found, 'it' is iterator to first tree node that matches search string, and 'parent' to its parent node
bool RecursiveSearch(const std::string & label, std::vector<TreePointer>::const_iterator & it, InternalTree const* &parent) const;
// use NTLabel for search to reduce number of string comparisons / deal with synonymous labels
// if found, 'it' is iterator to first tree node that matches search string
bool FlatSearch(const NTLabel & label, std::vector<TreePointer>::const_iterator & it) const;
bool RecursiveSearch(const NTLabel & label, std::vector<TreePointer>::const_iterator & it) const;
// if found, 'it' is iterator to first tree node that matches search string, and 'parent' to its parent node
bool RecursiveSearch(const NTLabel & label, std::vector<TreePointer>::const_iterator & it, InternalTree const* &parent) const;
// pass vector of possible labels to search
// if found, 'it' is iterator to first tree node that matches search string
bool FlatSearch(const std::vector<NTLabel> & labels, std::vector<TreePointer>::const_iterator & it) const;
bool RecursiveSearch(const std::vector<NTLabel> & labels, std::vector<TreePointer>::const_iterator & it) const;
// if found, 'it' is iterator to first tree node that matches search string, and 'parent' to its parent node
bool RecursiveSearch(const std::vector<NTLabel> & labels, std::vector<TreePointer>::const_iterator & it, InternalTree const* &parent) const;
};
// Lookup table from a nonterminal's string label to its integer NTLabel.
// Several distinct strings may be given the same NTLabel, which lets callers
// treat them as one category.
struct LabelSet
{
  std::map<std::string, NTLabel> string_to_label;
};
// Interface for language-specific syntactic constraints.
// Implementations override SyntacticRules(), which receives a tree and a list
// of previously built subtrees and returns one string per constraint
// violation; each returned string is turned into a sparse feature.
class SyntaxConstraints
{
public:
  virtual std::vector<std::string> SyntacticRules(TreePointer root, const std::vector<TreePointer> &previous) = 0;

  virtual ~SyntaxConstraints() {}
};
@ -71,18 +131,23 @@ public:
// Always returns 0, i.e. 'other' is never distinguished from this state.
int Compare(const FFState& other) const {return 0;};
};
class SyntaxConstraintFeature : public StatefulFeatureFunction
class TreeStructureFeature : public StatefulFeatureFunction
{
SyntaxConstraints* m_constraints;
LabelSet* m_labelset;
public:
SyntaxConstraintFeature(const std::string &line)
// Construct the feature from its configuration line and read its parameters.
// Both optional pointers start out NULL so that the destructor's delete of
// m_constraints is well defined even if Load() has not been called.
TreeStructureFeature(const std::string &line)
  :StatefulFeatureFunction(0, line)
  ,m_constraints(NULL)
  ,m_labelset(NULL) {
  ReadParameters();
}
// Releases the constraint object.
// NOTE(review): m_labelset is not deleted here — confirm who owns it.
~TreeStructureFeature() {delete m_constraints;};
// Returns a newly allocated TreeState built from a default-constructed
// (empty) TreePointer; 'input' is not used.
virtual const FFState* EmptyHypothesisState(const InputType &input) const {
return new TreeState(TreePointer());
}
void AddNTLabels(TreePointer root) const;
// Always returns true, whatever 'mask' contains.
bool IsUseable(const FactorMask &mask) const {
return true;
}
@ -105,6 +170,7 @@ public:
int /* featureID - used to index the state in the previous hypotheses */,
ScoreComponentCollection* accumulator) const;
void Load();
};
// Python-like generator that yields next nonterminal leaf on every call

View File

@ -66,6 +66,7 @@ StaticData::StaticData()
,m_lmEnableOOVFeature(false)
,m_isAlwaysCreateDirectTranslationOption(false)
,m_currentWeightSetting("default")
,m_treeStructure(NULL)
{
m_xmlBrackets.first="<";
m_xmlBrackets.second=">";

View File

@ -221,6 +221,8 @@ protected:
std::map<Word, std::set<Word> > m_soft_matches_map;
std::map<Word, std::set<Word> > m_soft_matches_map_reverse;
const StatefulFeatureFunction* m_treeStructure;
public:
bool IsAlwaysCreateDirectTranslationOption() const {
@ -759,6 +761,17 @@ public:
void ResetWeights(const std::string &denseWeights, const std::string &sparseFile);
// need global access for output of tree structure
// Returns the feature function last passed to SetTreeStructure(); the member
// is initialised to NULL in the StaticData constructor.
const StatefulFeatureFunction* GetTreeStructure() const {
return m_treeStructure;
}
// Records 'treeStructure' as the feature function returned by GetTreeStructure().
// Only the pointer is stored; nothing is copied or deleted here.
void SetTreeStructure(const StatefulFeatureFunction* treeStructure) {
m_treeStructure = treeStructure;
}
};
}