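Code simplification: remove the unused, language-specific "ptkvz" hack (the special rule for German dependency trees that concatenated a separable verb prefix with its verb).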

This commit is contained in:
Rico Sennrich 2015-07-17 14:45:38 +01:00
parent c1142741a1
commit e85f353898
4 changed files with 18 additions and 70 deletions

View File

@ -290,8 +290,8 @@ void RDLM::Score(InternalTree* root, const TreePointerMap & back_pointers, boost
}
std::pair<int,int> head_ids;
InternalTree* found = GetHead(root, back_pointers, head_ids);
if (found == NULL) {
bool found = GetHead(root, back_pointers, head_ids);
if (!found) {
head_ids = std::make_pair(static_dummy_head, static_dummy_head);
}
@ -516,7 +516,7 @@ void RDLM::Score(InternalTree* root, const TreePointerMap & back_pointers, boost
ancestor_labels.pop_back();
}
InternalTree* RDLM::GetHead(InternalTree* root, const TreePointerMap & back_pointers, std::pair<int,int> & IDs, InternalTree* head_ptr) const
bool RDLM::GetHead(InternalTree* root, const TreePointerMap & back_pointers, std::pair<int,int> & IDs) const
{
InternalTree *tree;
@ -528,51 +528,27 @@ InternalTree* RDLM::GetHead(InternalTree* root, const TreePointerMap & back_poin
}
if (m_binarized && tree->GetLabel()[0] == '^') {
head_ptr = GetHead(tree, back_pointers, IDs, head_ptr);
if (head_ptr != NULL && !m_isPTKVZ) {
return head_ptr;
bool found = GetHead(tree, back_pointers, IDs);
if (found) {
return true;
}
}
// assumption (only true for dependency parse): each constituent has a preterminal label, and corresponding terminal is head
// if constituent has multiple preterminals, first one is picked; if it has no preterminals, dummy_head is returned
else if (tree->GetLength() == 1 && tree->GetChildren()[0]->IsTerminal() && head_ptr == NULL) {
head_ptr = tree;
if (!m_isPTKVZ) {
GetIDs(head_ptr->GetChildren()[0]->GetLabel(), head_ptr->GetLabel(), IDs);
return head_ptr;
}
}
// add PTKVZ to lemma of verb
else if (m_isPTKVZ && head_ptr && tree->GetLabel() == "avz") {
InternalTree *tree2;
for (std::vector<TreePointer>::const_iterator it2 = tree->GetChildren().begin(); it2 != tree->GetChildren().end(); ++it2) {
if ((*it2)->IsLeafNT()) {
tree2 = back_pointers.find(it2->get())->second.get();
} else {
tree2 = it2->get();
}
if (tree2->GetLabel() == "PTKVZ" && tree2->GetLength() == 1 && tree2->GetChildren()[0]->IsTerminal()) {
std::string verb = tree2->GetChildren()[0]->GetLabel() + head_ptr->GetChildren()[0]->GetLabel();
GetIDs(verb, head_ptr->GetLabel(), IDs);
return head_ptr;
}
}
else if (tree->GetLength() == 1 && tree->GetChildren()[0]->IsTerminal()) {
GetIDs(tree->GetChildren()[0]->GetLabel(), tree->GetLabel(), IDs);
return true;
}
}
if (head_ptr != NULL) {
GetIDs(head_ptr->GetChildren()[0]->GetLabel(), head_ptr->GetLabel(), IDs);
}
return head_ptr;
return false;
}
void RDLM::GetChildHeadsAndLabels(InternalTree *root, const TreePointerMap & back_pointers, int reached_end, const nplm::neuralTM *lm_head, const nplm::neuralTM *lm_label, std::vector<int> & heads, std::vector<int> & labels, std::vector<int> & heads_output, std::vector<int> & labels_output) const
{
std::pair<int,int> child_ids;
InternalTree* found;
size_t j = 0;
// score start label (if enabled) for all nonterminal nodes (but not for terminal or preterminal nodes)
@ -616,8 +592,8 @@ void RDLM::GetChildHeadsAndLabels(InternalTree *root, const TreePointerMap & bac
continue;
}
found = GetHead(child, back_pointers, child_ids);
if (found == NULL) {
bool found = GetHead(child, back_pointers, child_ids);
if (!found) {
child_ids = std::make_pair(static_dummy_head, static_dummy_head);
}
@ -714,8 +690,6 @@ void RDLM::SetParameter(const std::string& key, const std::string& value)
m_path_head_lm = value;
} else if (key == "path_label_lm") {
m_path_label_lm = value;
} else if (key == "ptkvz") {
m_isPTKVZ = Scan<bool>(value);
} else if (key == "backoff") {
m_isPretermBackoff = Scan<bool>(value);
} else if (key == "context_up") {

View File

@ -68,7 +68,6 @@ class RDLM : public StatefulFeatureFunction
std::string m_endTag;
std::string m_path_head_lm;
std::string m_path_label_lm;
bool m_isPTKVZ;
bool m_isPretermBackoff;
size_t m_context_left;
size_t m_context_right;
@ -111,7 +110,6 @@ public:
, m_startSymbol("SSTART")
, m_endSymbol("SEND")
, m_endTag("</s>")
, m_isPTKVZ(false)
, m_isPretermBackoff(true)
, m_context_left(3)
, m_context_right(0)
@ -133,7 +131,7 @@ public:
}
void Score(InternalTree* root, const TreePointerMap & back_pointers, boost::array<float,4> &score, std::vector<int> &ancestor_heads, std::vector<int> &ancestor_labels, size_t &boundary_hash, int num_virtual = 0, int rescoring_levels = 0) const;
InternalTree* GetHead(InternalTree* root, const TreePointerMap & back_pointers, std::pair<int,int> & IDs, InternalTree * head_ptr=NULL) const;
bool GetHead(InternalTree* root, const TreePointerMap & back_pointers, std::pair<int,int> & IDs) const;
void GetChildHeadsAndLabels(InternalTree *root, const TreePointerMap & back_pointers, int reached_end, const nplm::neuralTM *lm_head, const nplm::neuralTM *lm_labels, std::vector<int> & heads, std::vector<int> & labels, std::vector<int> & heads_output, std::vector<int> & labels_output) const;
void GetIDs(const std::string & head, const std::string & preterminal, std::pair<int,int> & IDs) const;
void ScoreFile(std::string &path); //for debugging

View File

@ -89,11 +89,6 @@ def create_parser():
help=(
"Sentence end symbol. Will be skipped during extraction "
"(default: %(default)s)"))
parser.add_argument(
'--ptkvz', action='store_true',
help=(
"Special rule for German dependency trees: "
"concatenate separable verb prefix and verb."))
return parser
@ -107,22 +102,15 @@ def escape_text(s):
return s
def get_head(xml, add_ptkvz):
def get_head(xml):
"""Deterministic heuristic to get head of subtree."""
head = None
preterminal = None
for child in xml:
if not len(child):
if head is not None:
continue
preterminal = child.get('label')
head = escape_text(child.text.strip())
elif add_ptkvz and head and child.get('label') == 'avz':
for grandchild in child:
if grandchild.get('label') == 'PTKVZ':
head = escape_text(grandchild.text.strip()) + head
break
return head, preterminal
return head, preterminal
@ -159,7 +147,7 @@ def get_syntactic_ngrams(xml, options, vocab, output_vocab,
parent_labels = (
[vocab.get('<root_label>', 0)] * options.up_context)
head, preterminal = get_head(xml, options.ptkvz)
head, preterminal = get_head(xml)
if not head:
head = '<dummy_head>'
preterminal = head
@ -222,7 +210,7 @@ def get_syntactic_ngrams(xml, options, vocab, output_vocab,
preterminal_child = head_child
child_label = '<head_label>'
else:
head_child, preterminal_child = get_head(child, options.ptkvz)
head_child, preterminal_child = get_head(child)
child_label = child.get('label')
if head_child is None:

View File

@ -46,11 +46,6 @@ def create_parser():
parser.add_argument(
'--output', '-o', type=str, default='vocab', metavar='PREFIX',
help="Output prefix (default: 'vocab')")
parser.add_argument(
'--ptkvz', action="store_true",
help=(
"Special rule for German dependency trees: attach separable "
"verb prefixes to verb."))
return parser
@ -70,16 +65,9 @@ def get_head(xml, args):
preterminal = None
for child in xml:
if not len(child):
if head is not None:
continue
preterminal = child.get('label')
head = escape_text(child.text.strip())
elif args.ptkvz and head and child.get('label') == 'avz':
for grandchild in child:
if grandchild.get('label') == 'PTKVZ':
head = escape_text(grandchild.text.strip()) + head
break
return head, preterminal
return head, preterminal