filter-rule-table: support for "hierarchical" and "s2t" model types

Output should match filter-rule-table.py, but filtering is faster.  Some rough
timings ("That" = filter-rule-table.py, "This" = this implementation):

             That        This
  System A    0h 13m     0h 04m
  System B   18h 03m     0h 51m

System A is WMT14, en-de, string-to-tree (32M rules, 3,000 test sentences)
System B is WMT14, cs-en, string-to-tree (293M rules, 13,071 test sentences)
This commit is contained in:
Phil Williams 2015-02-10 15:11:10 +00:00
parent 0de206f359
commit 02f5ada680
4 changed files with 532 additions and 4 deletions

View File

@ -0,0 +1,29 @@
#pragma once
#include <istream>
#include <ostream>
#include <string>
#include <vector>
namespace MosesTraining
{
namespace Syntax
{
namespace FilterRuleTable
{
// Abstract interface shared by StringCfgFilter and TreeCfgFilter, the two
// filters for rule tables whose source-side is CFG.
class CfgFilter
{
public:
  // Virtual so that a concrete filter can be destroyed through a CfgFilter
  // pointer or reference.
  virtual ~CfgFilter() {}

  // Read a rule table from 'in' and filter it according to the test sentences.
  // Implementations receive 'out' as the stream for the filtered table.
  virtual void Filter(std::istream &in, std::ostream &out) = 0;
};
} // namespace FilterRuleTable
} // namespace Syntax
} // namespace MosesTraining

View File

@ -9,6 +9,7 @@
#include <sstream>
#include <vector>
#include <boost/make_shared.hpp>
#include <boost/program_options.hpp>
#include "util/string_piece.hh"
@ -22,6 +23,7 @@
#include "ForestTsgFilter.h"
#include "Options.h"
#include "StringCfgFilter.h"
#include "StringForest.h"
#include "StringForestParser.h"
#include "TreeTsgFilter.h"
@ -77,8 +79,11 @@ int FilterRuleTable::Main(int argc, char *argv[])
// Read the test sentences then set up and run the filter.
if (testSentenceFormat == kString) {
// TODO Implement ReadTestSet for strings and StringCfgFilter
Error("string-based filtering not supported yet");
assert(sourceSideRuleFormat == kCfg);
std::vector<boost::shared_ptr<std::string> > testStrings;
ReadTestSet(testStream, testStrings);
StringCfgFilter filter(testStrings);
filter.Filter(std::cin, std::cout);
} else if (testSentenceFormat == kTree) {
std::vector<boost::shared_ptr<StringTree> > testTrees;
ReadTestSet(testStream, testTrees);
@ -106,8 +111,24 @@ void FilterRuleTable::ReadTestSet(
std::istream &input,
std::vector<boost::shared_ptr<std::string> > &sentences)
{
// Reads test sentences from 'input', one per line, and appends each non-empty
// line to 'sentences' in a normalised form: a single space before the first
// token, between tokens, and after the last token.
//
// NOTE(review): this text is a rendered diff hunk (old side 8 lines, new side
// 24 lines), so the "TODO" / assert(false) pair below looks like the removed
// side of the hunk rather than part of the new body -- confirm against the
// committed file.
// TODO
assert(false);
// Tokens are separated by any run of spaces and/or tabs.
const util::AnyCharacter symbolDelimiter(" \t");
// 1-based line counter, used only in the warning message below.
int lineNum = 0;
std::string line;
while (std::getline(input, line)) {
++lineNum;
// Only completely empty lines are skipped; they are reported on stderr and
// do not get an entry in 'sentences'.
if (line.empty()) {
std::cerr << "skipping blank test sentence at line " << lineNum
<< std::endl;
continue;
}
// Rebuild the sentence with exactly one space around every token.
std::ostringstream tmp;
tmp << " ";
// NOTE(review): the 'true' template argument presumably makes TokenIter skip
// empty tokens -- confirm in util/tokenize_piece.hh.
for (util::TokenIter<util::AnyCharacter, true> p(line, symbolDelimiter);
p; ++p) {
tmp << *p << " ";
}
sentences.push_back(boost::make_shared<std::string>(tmp.str()));
}
}
void FilterRuleTable::ReadTestSet(

View File

@ -0,0 +1,327 @@
#include "StringCfgFilter.h"
#include <algorithm>
#include "util/string_piece_hash.hh"
namespace MosesTraining
{
namespace Syntax
{
namespace FilterRuleTable
{
// Longest n-gram (in tokens) that AddSentenceNGrams indexes for the test
// sentences.  GeneratePattern splits longer runs of terminals into consecutive
// subpatterns of at most this length so that every subpattern can be looked
// up in the index.
const std::size_t StringCfgFilter::kMaxNGramLength = 5;
// Builds the search structures for the given test sentences: the test
// vocabulary, the per-sentence lengths (and their maximum), and the n-gram
// coordinate map.
StringCfgFilter::StringCfgFilter(
  const std::vector<boost::shared_ptr<std::string> > &sentences)
  : m_maxSentenceLength(-1)
{
  // Pass 1: map every sentence to vocabulary IDs, index its n-grams (this
  // fills everything in m_ngramCoordinateMap except the CoordinateTables'
  // sentence vectors) and record its length.
  const util::AnyCharacter tokenDelimiter(" \t");
  const std::size_t numSentences = sentences.size();
  m_sentenceLengths.reserve(numSentences);
  std::vector<Vocabulary::IdType> wordIds;
  for (std::size_t sentNum = 0; sentNum < numSentences; ++sentNum) {
    wordIds.clear();
    util::TokenIter<util::AnyCharacter, true> token(*sentences[sentNum],
        tokenDelimiter);
    while (token) {
      std::string word;
      token->CopyToString(&word);
      wordIds.push_back(m_testVocab.Insert(word));
      ++token;
    }
    AddSentenceNGrams(wordIds, sentNum);
    const int length = static_cast<int>(wordIds.size());
    m_sentenceLengths.push_back(length);
    if (length > m_maxSentenceLength) {
      m_maxSentenceLength = length;
    }
  }

  // Pass 2: for every n-gram, derive the ascending list of sentence IDs from
  // the keys of its position map.
  typedef boost::unordered_map<int, PositionSeq>::const_iterator PositionIter;
  for (NGramCoordinateMap::iterator entry = m_ngramCoordinateMap.begin();
       entry != m_ngramCoordinateMap.end(); ++entry) {
    CoordinateTable &table = entry->second;
    std::vector<int> &sentenceIds = table.sentences;
    sentenceIds.reserve(table.intraSentencePositions.size());
    for (PositionIter q = table.intraSentencePositions.begin();
         q != table.intraSentencePositions.end(); ++q) {
      sentenceIds.push_back(q->first);
    }
    std::sort(sentenceIds.begin(), sentenceIds.end());
  }
}
void StringCfgFilter::Filter(std::istream &in, std::ostream &out)
{
const util::MultiCharacter fieldDelimiter("|||");
const util::AnyCharacter symbolDelimiter(" \t");
std::string line;
std::string prevLine;
StringPiece source;
std::vector<StringPiece> symbols;
Pattern pattern;
bool keep = true;
int lineNum = 0;
while (std::getline(in, line)) {
++lineNum;
// Read the source-side of the rule.
util::TokenIter<util::MultiCharacter> it(line, fieldDelimiter);
// Check if this rule has the same source-side as the previous rule. If
// it does then we already know whether or not to keep the rule. This
// optimisation is based on the assumption that the rule table is sorted
// (which is the case in the standard Moses training pipeline).
if (*it == source) {
if (keep) {
out << line << std::endl;
}
continue;
}
// The source-side is different from the previous rule's.
source = *it;
// Tokenize the source-side.
symbols.clear();
for (util::TokenIter<util::AnyCharacter, true> p(source, symbolDelimiter);
p; ++p) {
symbols.push_back(*p);
}
keep = GeneratePattern(symbols, pattern) && MatchPattern(pattern);
if (keep) {
out << line << std::endl;
}
// Retain line for the next iteration (in order that the source StringPiece
// remains valid).
prevLine.swap(line);
}
}
// Records, for sentence number sentNum, every n-gram of length 1 up to
// kMaxNGramLength in s together with the position at which it starts.
void StringCfgFilter::AddSentenceNGrams(
  const std::vector<Vocabulary::IdType> &s, std::size_t sentNum)
{
  const std::size_t numTokens = s.size();
  NGram ngram;
  for (std::size_t start = 0; start < numTokens; ++start) {
    // The longest n-gram starting here is capped by kMaxNGramLength and by
    // the number of tokens remaining in the sentence.
    const std::size_t longest = std::min(kMaxNGramLength, numTokens - start);
    // Grow the n-gram one token at a time; after each extension it is the
    // n-gram of length offset+1 that starts at 'start'.
    ngram.clear();
    for (std::size_t offset = 0; offset < longest; ++offset) {
      ngram.push_back(s[start + offset]);
      CoordinateTable &table = m_ngramCoordinateMap[ngram];
      table.intraSentencePositions[sentNum].push_back(start);
    }
  }
}
bool StringCfgFilter::GeneratePattern(const std::vector<StringPiece> &symbols,
Pattern &pattern) const
{
pattern.subpatterns.clear();
pattern.minGapWidths.clear();
int gapWidth = 0;
// The first symbol is handled as a special case because there is always a
// leading gap / non-gap.
if (IsNonTerminal(symbols[0])) {
++gapWidth;
} else {
pattern.minGapWidths.push_back(0);
// Add the symbol to the first n-gram.
Vocabulary::IdType vocabId =
m_testVocab.Lookup(symbols[0], StringPieceCompatibleHash(),
StringPieceCompatibleEquals());
if (vocabId == Vocabulary::NullId()) {
return false;
}
pattern.subpatterns.push_back(NGram(1, vocabId));
}
// Process the remaining symbols (except the last which is the RHS).
for (std::size_t i = 1; i < symbols.size()-1; ++i) {
// Is current symbol a non-terminal?
if (IsNonTerminal(symbols[i])) {
++gapWidth;
continue;
}
// Does the current terminal follow a non-terminal?
if (gapWidth > 0) {
pattern.minGapWidths.push_back(gapWidth);
gapWidth = 0;
pattern.subpatterns.resize(pattern.subpatterns.size()+1);
// Is the current n-gram full?
} else if (pattern.subpatterns.back().size() == kMaxNGramLength) {
pattern.minGapWidths.push_back(0);
pattern.subpatterns.resize(pattern.subpatterns.size()+1);
}
// Add the symbol to the current n-gram.
Vocabulary::IdType vocabId =
m_testVocab.Lookup(symbols[i], StringPieceCompatibleHash(),
StringPieceCompatibleEquals());
if (vocabId == Vocabulary::NullId()) {
return false;
}
pattern.subpatterns.back().push_back(vocabId);
}
// Add the final gap width value (0 if the last symbol was a terminal).
pattern.minGapWidths.push_back(gapWidth);
return true;
}
// Returns true if the symbol is at least three characters long, starts with
// '[' and ends with ']'.
bool StringCfgFilter::IsNonTerminal(const StringPiece &symbol) const
{
  if (symbol.size() < 3) {
    return false;
  }
  const bool opensWithBracket = (symbol[0] == '[');
  return opensWithBracket && symbol[symbol.size()-1] == ']';
}
bool StringCfgFilter::MatchPattern(const Pattern &pattern) const
{
// Step 0: If the pattern is just a single gap (i.e. the original rule
// was fully non-lexical) then the pattern matches unless the
// minimum gap width is wider than any sentence.
if (pattern.subpatterns.empty()) {
assert(pattern.minGapWidths.size() == 1);
return pattern.minGapWidths[0] <= m_maxSentenceLength;
}
// Step 1: Look up all of the subpatterns in m_ngramCoordinateMap and record
// pointers to their CoordinateTables.
std::vector<const CoordinateTable *> tables;
for (std::vector<NGram>::const_iterator p = pattern.subpatterns.begin();
p != pattern.subpatterns.end(); ++p) {
NGramCoordinateMap::const_iterator q = m_ngramCoordinateMap.find(*p);
// If a subpattern doesn't appear in m_ngramCoordinateMap then the match
// has already failed.
if (q == m_ngramCoordinateMap.end()) {
return false;
}
tables.push_back(&(q->second));
}
// Step 2: Intersect the CoordinateTables' sentence sets to find the set of
// test set sentences in which all subpatterns occur.
std::vector<int> intersection = tables[0]->sentences;
std::vector<int> tmp(intersection.size());
for (std::size_t i = 1; i < tables.size(); ++i) {
std::vector<int>::iterator p = std::set_intersection(
intersection.begin(), intersection.end(), tables[i]->sentences.begin(),
tables[i]->sentences.end(), tmp.begin());
tmp.resize(p-tmp.begin());
if (tmp.empty()) {
return false;
}
intersection.swap(tmp);
}
// Step 3: For each sentence in the intersection, construct a trellis
// with a column of intra-sentence positions for each subpattern.
// If there is a consistent path of position values through the
// trellis then there is a match ('consistent' here means that the
// subpatterns occur in the right order and are separated by at
// least the minimum widths required by the pattern's gaps).
for (std::vector<int>::const_iterator p = intersection.begin();
p != intersection.end(); ++p) {
const int sentenceId = *p;
const int sentenceLength = m_sentenceLengths[sentenceId];
SentenceTrellis trellis;
// For each subpattern's CoordinateTable:
for (std::vector<const CoordinateTable *>::const_iterator
q = tables.begin(); q != tables.end(); ++q) {
const CoordinateTable &table = **q;
// Add the intra-sentence position sequence as a column of the trellis.
boost::unordered_map<int, PositionSeq>::const_iterator r =
table.intraSentencePositions.find(sentenceId);
assert(r != table.intraSentencePositions.end());
trellis.columns.push_back(&(r->second));
}
// Search the trellis for a consistent sequence of position values.
if (MatchPattern(trellis, sentenceLength, pattern)) {
return true;
}
}
return false;
}
bool StringCfgFilter::MatchPattern(const SentenceTrellis &trellis,
int sentenceLength,
const Pattern &pattern) const
{
// In the for loop below, we need to know the set of start position ranges
// where subpattern i is allowed to occur (rangeSet) and we are generating
// the ranges for subpattern i+1 (nextRangeSet).
// TODO Merge ranges if subpattern i follows a non-zero gap.
std::vector<Range> rangeSet;
std::vector<Range> nextRangeSet;
// Calculate the range for the first subpattern.
int minStart = pattern.minGapWidths[0];
int maxStart = sentenceLength - MinWidth(pattern, 0);
rangeSet.push_back(Range(minStart, maxStart));
// Attempt to match subpatterns.
for (int i = 0; i < pattern.subpatterns.size(); ++i) {
const PositionSeq &col = *trellis.columns[i];
for (PositionSeq::const_iterator p = col.begin(); p != col.end(); ++p) {
bool inRange = false;
for (std::vector<Range>::const_iterator q = rangeSet.begin();
q != rangeSet.end(); ++q) {
if (*p >= q->first && *p <= q->second) {
inRange = true;
break;
}
}
if (!inRange) {
continue;
}
// If this is the last subpattern then we're done.
if (i+1 == pattern.subpatterns.size()) {
return true;
}
nextRangeSet.push_back(CalcNextRange(pattern, i, *p, sentenceLength));
}
if (nextRangeSet.empty()) {
return false;
}
rangeSet.swap(nextRangeSet);
nextRangeSet.clear();
}
return true;
}
// Computes the range of start positions allowed for subpattern i+1 given that
// subpattern i starts at position x.
StringCfgFilter::Range StringCfgFilter::CalcNextRange(
  const Pattern &pattern, int i, int x, int sentenceLength) const
{
  assert(i+1 < pattern.subpatterns.size());
  const int gap = pattern.minGapWidths[i+1];
  // First position after the end of subpattern i.
  const int end = x + static_cast<int>(pattern.subpatterns[i].size());
  if (gap == 0) {
    // The next subpattern follows this one without a gap, so there is exactly
    // one possible start position.
    return Range(end, end);
  }
  // Otherwise the next subpattern starts at least 'gap' positions later and
  // must still leave room for the remainder of the pattern.
  // TODO MinWidth should only be computed once per subpattern.
  return Range(end + gap, sentenceLength - MinWidth(pattern, i+1));
}
// Returns the minimum width of the pattern suffix that starts at subpattern
// i: the lengths of subpatterns i, i+1, ... plus the minimum width of the gap
// that follows each of them.
int StringCfgFilter::MinWidth(const Pattern &pattern, int i) const
{
  const std::size_t numSubpatterns = pattern.subpatterns.size();
  int total = 0;
  for (std::size_t k = i; k < numSubpatterns; ++k) {
    total += static_cast<int>(pattern.subpatterns[k].size());
    total += pattern.minGapWidths[k+1];
  }
  return total;
}
} // namespace FilterRuleTable
} // namespace Syntax
} // namespace MosesTraining

View File

@ -0,0 +1,151 @@
#pragma once
#include <string>
#include <vector>
#include "syntax-common/numbered_set.h"
#include <boost/shared_ptr.hpp>
#include <boost/unordered_map.hpp>
#include "util/string_piece.hh"
#include "util/tokenize_piece.hh"
#include "CfgFilter.h"
namespace MosesTraining
{
namespace Syntax
{
namespace FilterRuleTable
{
// Filters a rule table, discarding rules that cannot be applied to a given
// test set. The rule table must have a CFG source-side and the test sentences
// must be strings.
class StringCfgFilter : public CfgFilter {
public:
// Initialize the filter for a given set of test sentences.  Each sentence is
// tokenized on spaces and tabs.
StringCfgFilter(const std::vector<boost::shared_ptr<std::string> > &);
// Read a rule table from 'in', one rule per line, and write to 'out' the
// rules that are kept.  Consecutive rules with the same source-side share a
// single keep/discard decision.
void Filter(std::istream &in, std::ostream &out);
private:
// Filtering works by converting the source LHSs of translation rules to
// patterns containing variable length gaps and then pattern matching
// against the test set.
//
// The algorithm is vaguely similar to Algorithm 1 from Rahman et al. (2006),
// but with a slightly different definition of a pattern and designed for a
// text containing sentence boundaries. Here the text is assumed to be
// short (a few thousand sentences) and the number of patterns is assumed to
// be large (tens of millions of rules).
//
// M. Sohel Rahman, Costas S. Iliopoulos, Inbok Lee, Manal Mohamed, and
// William F. Smyth
// "Finding Patterns with Variable Length Gaps or Don't Cares"
// In proceedings of COCOON, 2006
// Max NGram length.  N-grams up to this length are indexed for the test
// sentences, and no subpattern of a Pattern is longer than this.
static const std::size_t kMaxNGramLength;
// Maps symbols (terminals and non-terminals) from strings to integers.
typedef NumberedSet<std::string, std::size_t> Vocabulary;
// A NGram is a sequence of words.
typedef std::vector<Vocabulary::IdType> NGram;
// A pattern is an alternating sequence of gaps and NGram subpatterns,
// starting and ending with a gap. Every gap has a minimum width, which
// can be any integer >= 0 (a gap of width 0 is really a non-gap).
//
// The source LHSs of translation rules are converted to patterns where each
// sequence of m consecutive non-terminals is converted to a gap with minimum
// width m. For example, if a rule has the source LHS:
//
// [NP] and all the king 's men could n't [VB] [NP] together again
//
// and kMaxNGramLength is set to 5 then the following pattern is used:
//
// * <and all the king 's> * <men could n't> * <together again> *
//
// where the gaps have minimum widths of 1, 0, 2, and 0.
//
struct Pattern
{
// The n-grams, in the order in which they occur in the rule.
std::vector<NGram> subpatterns;
// minGapWidths[i] is the minimum width of the gap before subpatterns[i];
// the final element is the gap after the last subpattern.
std::vector<int> minGapWidths;
};
// A sorted (ascending) sequence of start positions.
typedef std::vector<int> PositionSeq;
// A range of start positions (both ends are treated as inclusive when
// matching).
typedef std::pair<int, int> Range;
// A SentenceTrellis holds the positions at which each of a pattern's
// subpatterns occur in a single sentence.
struct SentenceTrellis
{
// columns[i] points to the start positions of subpattern i in the sentence.
std::vector<const PositionSeq *> columns;
};
// A CoordinateTable records the set of sentences in which a single
// n-gram occurs and, for each of those sentences, the start positions at
// which the n-gram occurs.
struct CoordinateTable {
// Sentences IDs (ascending). This contains the same values as the key set
// from intraSentencePositions but sorted into ascending order.
std::vector<int> sentences;
// Map from sentence ID to set of intra-sentence start positions.
boost::unordered_map<int, PositionSeq> intraSentencePositions;
};
// NGramCoordinateMap is the main search structure. It maps a NGram to
// a CoordinateTable containing the positions that the n-gram occurs at
// in the test set.
typedef boost::unordered_map<NGram, CoordinateTable> NGramCoordinateMap;
// Add all n-grams and coordinates for a single sentence s with index i.
void AddSentenceNGrams(const std::vector<Vocabulary::IdType> &s,
std::size_t i);
// Calculate the range of possible start positions for subpattern i+1
// assuming that subpattern i has position x.
Range CalcNextRange(const Pattern &p, int i, int x, int sentenceLength) const;
// Generate the pattern corresponding to the given source-side of a rule.
// This will fail if the rule's source-side contains any terminals that
// do not occur in the test sentence vocabulary.
bool GeneratePattern(const std::vector<StringPiece> &, Pattern &) const;
// Calculate the minimum width of the pattern suffix starting
// at subpattern i.
int MinWidth(const Pattern &p, int i) const;
// Test whether a symbol is a non-terminal, i.e. at least three characters
// long and enclosed in square brackets.
bool IsNonTerminal(const StringPiece &symbol) const;
// Try to match the pattern p against any sentence in the test set.
bool MatchPattern(const Pattern &p) const;
// Try to match the pattern p against the SentenceTrellis t of a single
// sentence.
bool MatchPattern(const SentenceTrellis &t, int sentenceLength,
const Pattern &p) const;
// The main search structure constructed from the test set sentences.
NGramCoordinateMap m_ngramCoordinateMap;
// The lengths of the test sentences.
std::vector<int> m_sentenceLengths;
// The maximum length of any test sentence (-1 if there are no test
// sentences).
int m_maxSentenceLength;
// The symbol vocabulary of the test sentences.
Vocabulary m_testVocab;
};
} // namespace FilterRuleTable
} // namespace Syntax
} // namespace MosesTraining