Various optimizations that make the CYK+ parser several times faster and reduce its memory use.

The speed-up of decoding depends on how much time is spent in the parser:
speed-ups of 10-50% were observed for string-to-tree systems (larger for long sentences and for high max-chart-span settings).

If you only use hiero or string-to-tree models (i.e. no models with source syntax), build with the compile option --unlabelled-source for (small) efficiency gains.
This commit is contained in:
Rico Sennrich 2014-03-21 10:53:15 +00:00
parent 1c6061e781
commit 45630a5851
36 changed files with 802 additions and 612 deletions

View File

@ -64,6 +64,8 @@
#
# --max-factors maximum number of factors (default 4)
#
# --unlabelled-source ignore source labels (redundant in hiero or string-to-tree system)
# for better performance
#CONTROLLING THE BUILD
#-a to build from scratch
#-j$NCPUS to compile in parallel
@ -110,6 +112,7 @@ requirements += [ option.get "notrace" : <define>TRACE_ENABLE=1 ] ;
requirements += [ option.get "enable-boost-pool" : : <define>USE_BOOST_POOL ] ;
requirements += [ option.get "with-mm" : : <define>PT_UG ] ;
requirements += [ option.get "with-mm" : : <define>MAX_NUM_FACTORS=4 ] ;
requirements += [ option.get "unlabelled-source" : : <define>UNLABELLED_SOURCE ] ;
if [ option.get "with-cmph" ] {
requirements += <define>HAVE_CMPH ;

View File

@ -109,7 +109,7 @@ void Word::ConvertToMoses(
for (std::vector<Moses::FactorType>::const_iterator t = outputFactorsVec.begin(); t != outputFactorsVec.end(); ++t, ++tok) {
UTIL_THROW_IF2(!tok, "Too few factors in \"" << vocab.GetString(m_vocabId) << "\"; was expecting " << outputFactorsVec.size());
overwrite.SetFactor(*t, factorColl.AddFactor(*tok));
overwrite.SetFactor(*t, factorColl.AddFactor(*tok, m_isNonTerminal));
}
UTIL_THROW_IF2(tok, "Too many factors in \"" << vocab.GetString(m_vocabId) << "\"; was expecting " << outputFactorsVec.size());
}

View File

@ -21,6 +21,7 @@
#include "ChartCellLabel.h"
#include "NonTerminal.h"
#include "moses/FactorCollection.h"
#include <boost/functional/hash.hpp>
#include <boost/unordered_map.hpp>
@ -36,20 +37,23 @@ class ChartHypothesisCollection;
class ChartCellLabelSet
{
private:
#if defined(BOOST_VERSION) && (BOOST_VERSION >= 104200)
typedef boost::unordered_map<Word, ChartCellLabel,
NonTerminalHasher, NonTerminalEqualityPred
> MapType;
#else
typedef std::map<Word, ChartCellLabel> MapType;
#endif
typedef std::vector<ChartCellLabel*> MapType;
public:
typedef MapType::const_iterator const_iterator;
typedef MapType::iterator iterator;
ChartCellLabelSet(const WordsRange &coverage) : m_coverage(coverage) {}
// Creates an empty label set for the given span.
// m_map receives one NULL slot per non-terminal currently registered in
// FactorCollection, and the count of occupied slots (m_size) starts at 0.
// m_coverage is a reference member, so the WordsRange passed in must
// outlive this object.
ChartCellLabelSet(const WordsRange &coverage)
: m_coverage(coverage)
, m_map(FactorCollection::Instance().GetNumNonTerminals(), NULL)
, m_size(0) { }
// Passes m_map to RemoveAllInColl for cleanup.
// NOTE(review): presumably this deletes every stored ChartCellLabel (the
// labels are allocated with new inside this class) -- confirm against Util.h.
~ChartCellLabelSet() {
RemoveAllInColl(m_map);
}
// TODO: skip empty elements when iterating, or deprecate this
const_iterator begin() const {
return m_map.begin();
}
@ -65,36 +69,72 @@ public:
}
void AddWord(const Word &w) {
m_map.insert(std::make_pair(w, ChartCellLabel(m_coverage, w)));
size_t idx = w[0]->GetId();
if (! ChartCellExists(idx)) {
m_size++;
m_map[idx] = new ChartCellLabel(m_coverage, w);
}
}
// Stack is a HypoList or whatever the search algorithm uses.
void AddConstituent(const Word &w, const HypoList *stack) {
ChartCellLabel::Stack s;
s.cube = stack;
m_map.insert(std::make_pair(w, ChartCellLabel(m_coverage, w, s)));
size_t idx = w[0]->GetId();
if (ChartCellExists(idx)) {
ChartCellLabel::Stack & s = m_map[idx]->MutableStack();
s.cube = stack;
}
else {
ChartCellLabel::Stack s;
s.cube = stack;
m_size++;
m_map[idx] = new ChartCellLabel(m_coverage, w, s);
}
}
// grow vector if necessary
bool ChartCellExists(size_t idx) {
try {
if (m_map.at(idx) != NULL) {
return true;
}
}
catch (const std::out_of_range& oor) {
m_map.resize(FactorCollection::Instance().GetNumNonTerminals(), NULL);
}
return false;
}
bool Empty() const {
return m_map.empty();
return m_size == 0;
}
size_t GetSize() const {
return m_map.size();
return m_size;
}
const ChartCellLabel *Find(const Word &w) const {
MapType::const_iterator p = m_map.find(w);
return p == m_map.end() ? 0 : &(p->second);
size_t idx = w[0]->GetId();
try {
return m_map.at(idx);
}
catch (const std::out_of_range& oor) {
return NULL;
}
}
ChartCellLabel::Stack &FindOrInsert(const Word &w) {
return m_map.insert(std::make_pair(w, ChartCellLabel(m_coverage, w))).first->second.MutableStack();
size_t idx = w[0]->GetId();
if (! ChartCellExists(idx)) {
m_size++;
m_map[idx] = new ChartCellLabel(m_coverage, w);
}
return m_map[idx]->MutableStack();
}
private:
const WordsRange &m_coverage;
MapType m_map;
size_t m_size;
};
}

View File

@ -77,8 +77,8 @@ void ChartManager::ProcessSentence()
// MAIN LOOP
size_t size = m_source.GetSize();
for (size_t width = 1; width <= size; ++width) {
for (size_t startPos = 0; startPos <= size-width; ++startPos) {
for (int startPos = size-1; startPos >= 0; --startPos) {
for (size_t width = 1; width <= size-startPos; ++width) {
size_t endPos = startPos + width - 1;
WordsRange range(startPos, endPos);

View File

@ -181,8 +181,12 @@ void ChartParser::Create(const WordsRange &wordsRange, ChartParserCallback &to)
assert(decodeGraph.GetSize() == 1);
ChartRuleLookupManager &ruleLookupManager = **iterRuleLookupManagers;
size_t maxSpan = decodeGraph.GetMaxChartSpan();
size_t last = m_source.GetSize()-1;
if (maxSpan != 0) {
last = min(last, wordsRange.GetStartPos()+maxSpan);
}
if (maxSpan == 0 || wordsRange.GetNumWordsCovered() <= maxSpan) {
ruleLookupManager.GetChartRuleCollection(wordsRange, to);
ruleLookupManager.GetChartRuleCollection(wordsRange, last, to);
}
}

View File

@ -25,6 +25,9 @@ public:
virtual void AddPhraseOOV(TargetPhrase &phrase, std::list<TargetPhraseCollection*> &waste_memory, const WordsRange &range) = 0;
virtual void Evaluate(const InputType &input, const InputPath &inputPath) = 0;
virtual float CalcEstimateOfBestScore(const TargetPhraseCollection &, const StackVec &) const = 0;
};
} // namespace Moses

View File

@ -66,6 +66,7 @@ public:
*/
virtual void GetChartRuleCollection(
const WordsRange &range,
size_t lastPos, // last position to consider if using lookahead
ChartParserCallback &outColl) = 0;
private:

View File

@ -60,6 +60,10 @@ public:
return m_size == 0;
}
// Delegates to ChartTranslationOptions::CalcEstimateOfBestScore with the same
// target phrase collection and stack vector, and returns its result.
float CalcEstimateOfBestScore(const TargetPhraseCollection & tpc, const StackVec & stackVec) const {
return ChartTranslationOptions::CalcEstimateOfBestScore(tpc, stackVec);
}
void Clear();
void ApplyThreshold();
void Evaluate(const InputType &input, const InputPath &inputPath);

View File

@ -4,12 +4,15 @@
#include "moses/ChartHypothesis.h"
#include "moses/StaticData.h"
#include "moses/InputFileStream.h"
#include "moses/FactorCollection.h"
#include "moses/Util.h"
namespace Moses
{
SoftMatchingFeature::SoftMatchingFeature(const std::string &line)
: StatelessFeatureFunction(0, line)
, m_softMatches(moses_MaxNumNonterminals)
{
ReadParameters();
}
@ -49,12 +52,11 @@ bool SoftMatchingFeature::Load(const std::string& filePath)
LHS.CreateFromString(Output, staticData.GetOutputFactorOrder(), tokens[0], true);
RHS.CreateFromString(Output, staticData.GetOutputFactorOrder(), tokens[1], true);
m_soft_matches[LHS].insert(RHS);
m_soft_matches_reverse[RHS].insert(LHS);
m_softMatches[RHS[0]->GetId()].push_back(LHS);
GetOrSetFeatureName(RHS, LHS);
}
staticData.Set_Soft_Matches(Get_Soft_Matches());
staticData.Set_Soft_Matches_Reverse(Get_Soft_Matches_Reverse());
staticData.SetSoftMatches(m_softMatches);
return true;
}
@ -78,37 +80,50 @@ void SoftMatchingFeature::EvaluateChart(const ChartHypothesis& hypo,
const ChartHypothesis* prevHypo = hypo.GetPrevHypo(nonTermInd);
const Word& prevLHS = prevHypo->GetTargetLHS();
const std::string name = GetFeatureName(prevLHS, word);
const std::string &name = GetOrSetFeatureName(word, prevLHS);
accumulator->PlusEquals(this,name,1);
}
}
}
//caching feature names because string conversion is slow
const std::string& SoftMatchingFeature::GetFeatureName(const Word& LHS, const Word& RHS) const
{
// when loading, or when we notice that non-terminals have been added after loading, we resize vectors
void SoftMatchingFeature::ResizeCache() const {
FactorCollection& fc = FactorCollection::Instance();
size_t numNonTerminals = fc.GetNumNonTerminals();
const NonTerminalMapKey key(LHS, RHS);
{
m_nameCache.resize(numNonTerminals);
for (size_t i = 0; i < numNonTerminals; i++) {
m_nameCache[i].resize(numNonTerminals);
}
}
const std::string& SoftMatchingFeature::GetOrSetFeatureName(const Word& RHS, const Word& LHS) const {
try {
#ifdef WITH_THREADS //try read-only lock
boost::shared_lock<boost::shared_mutex> read_lock(m_accessLock);
#endif // WITH_THREADS
NonTerminalSoftMatchingMap::const_iterator i = m_soft_matching_cache.find(key);
if (i != m_soft_matching_cache.end()) return i->second;
boost::shared_lock<boost::shared_mutex> read_lock(m_accessLock);
#endif
const std::string &name = m_nameCache.at(RHS[0]->GetId()).at(LHS[0]->GetId());
if (!name.empty()) {
return name;
}
}
catch (const std::out_of_range& oor) {
#ifdef WITH_THREADS //need to resize cache; write lock
boost::unique_lock<boost::shared_mutex> lock(m_accessLock);
#endif
ResizeCache();
}
#ifdef WITH_THREADS //need to update cache; write lock
boost::unique_lock<boost::shared_mutex> lock(m_accessLock);
#endif // WITH_THREADS
const std::vector<FactorType> &outputFactorOrder = StaticData::Instance().GetOutputFactorOrder();
std::string LHS_string = LHS.GetString(outputFactorOrder, false);
std::string RHS_string = RHS.GetString(outputFactorOrder, false);
const std::string name = LHS_string + "->" + RHS_string;
m_soft_matching_cache[key] = name;
return m_soft_matching_cache.find(key)->second;
}
boost::unique_lock<boost::shared_mutex> lock(m_accessLock);
#endif
std::string &name = m_nameCache[RHS[0]->GetId()][LHS[0]->GetId()];
const std::vector<FactorType> &outputFactorOrder = StaticData::Instance().GetOutputFactorOrder();
std::string LHS_string = LHS.GetString(outputFactorOrder, false);
std::string RHS_string = RHS.GetString(outputFactorOrder, false);
name = LHS_string + "->" + RHS_string;
return name;
}
}

View File

@ -1,10 +1,7 @@
#pragma once
#include <stdexcept>
#include "moses/Util.h"
#include "moses/Word.h"
#include "StatelessFeatureFunction.h"
#include "moses/TranslationModel/PhraseDictionaryNodeMemory.h"
#ifdef WITH_THREADS
#include <boost/thread/shared_mutex.hpp>
@ -39,33 +36,19 @@ public:
bool Load(const std::string &filePath);
std::map<Word, std::set<Word> >& Get_Soft_Matches() {
return m_soft_matches;
std::vector<std::vector<Word> >& GetSoftMatches() {
return m_softMatches;
}
std::map<Word, std::set<Word> >& Get_Soft_Matches_Reverse() {
return m_soft_matches_reverse;
}
void ResizeCache() const;
const std::string& GetFeatureName(const Word& LHS, const Word& RHS) const;
const std::string& GetOrSetFeatureName(const Word& RHS, const Word& LHS) const;
void SetParameter(const std::string& key, const std::string& value);
private:
std::map<Word, std::set<Word> > m_soft_matches; // map LHS of old rule to RHS of new rle
std::map<Word, std::set<Word> > m_soft_matches_reverse; // map RHS of new rule to LHS of old rule
typedef std::pair<Word, Word> NonTerminalMapKey;
#if defined(BOOST_VERSION) && (BOOST_VERSION >= 104200)
typedef boost::unordered_map<NonTerminalMapKey,
std::string,
NonTerminalMapKeyHasher,
NonTerminalMapKeyEqualityPred> NonTerminalSoftMatchingMap;
#else
typedef std::map<NonTerminalMapKey, std::string> NonTerminalSoftMatchingMap;
#endif
mutable NonTerminalSoftMatchingMap m_soft_matching_cache;
mutable std::vector<std::vector<Word> > m_softMatches; // map RHS of new rule to list of possible LHS of old rule (subtree)
mutable std::vector<std::vector<std::string> > m_nameCache;
#ifdef WITH_THREADS
//reader-writer lock

View File

@ -35,27 +35,34 @@ namespace Moses
{
FactorCollection FactorCollection::s_instance;
const Factor *FactorCollection::AddFactor(const StringPiece &factorString)
const Factor *FactorCollection::AddFactor(const StringPiece &factorString, bool isNonTerminal)
{
FactorFriend to_ins;
to_ins.in.m_string = factorString;
to_ins.in.m_id = m_factorId;
to_ins.in.m_id = (isNonTerminal) ? m_factorIdNonTerminal : m_factorId;
Set & set = (isNonTerminal) ? m_set : m_setNonTerminal;
// If we're threaded, hope a read-only lock is sufficient.
#ifdef WITH_THREADS
{
// read=lock scope
boost::shared_lock<boost::shared_mutex> read_lock(m_accessLock);
Set::const_iterator i = m_set.find(to_ins);
if (i != m_set.end()) return &i->in;
Set::const_iterator i = set.find(to_ins);
if (i != set.end()) return &i->in;
}
boost::unique_lock<boost::shared_mutex> lock(m_accessLock);
#endif // WITH_THREADS
std::pair<Set::iterator, bool> ret(m_set.insert(to_ins));
std::pair<Set::iterator, bool> ret(set.insert(to_ins));
if (ret.second) {
ret.first->in.m_string.set(
memcpy(m_string_backing.Allocate(factorString.size()), factorString.data(), factorString.size()),
factorString.size());
m_factorId++;
if (isNonTerminal) {
m_factorIdNonTerminal++;
UTIL_THROW_IF2(m_factorIdNonTerminal >= moses_MaxNumNonterminals, "Number of non-terminals exceeds maximum size reserved. Adjust parameter moses_MaxNumNonterminals, then recompile");
}
else {
m_factorId++;
}
}
return &ret.first->in;
}

View File

@ -22,6 +22,11 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
#ifndef moses_FactorCollection_h
#define moses_FactorCollection_h
// reserve space for non-terminal symbols (ensuring consecutive numbering, and allowing quick lookup by ID)
#ifndef moses_MaxNumNonterminals
#define moses_MaxNumNonterminals 10000
#endif
#ifdef WITH_THREADS
#include <boost/thread/shared_mutex.hpp>
#endif
@ -74,6 +79,7 @@ class FactorCollection
};
typedef boost::unordered_set<FactorFriend, HashFactor, EqualsFactor> Set;
Set m_set;
Set m_setNonTerminal;
util::Pool m_string_backing;
@ -83,11 +89,13 @@ class FactorCollection
mutable boost::shared_mutex m_accessLock;
#endif
size_t m_factorId; /**< unique, contiguous ids, starting from 0, for each factor */
size_t m_factorIdNonTerminal; /**< unique, contiguous ids, starting from 0, for each non-terminal factor */
size_t m_factorId; /**< unique, contiguous ids, starting from moses_MaxNumNonterminals, for each terminal factor */
//! constructor. only the 1 static variable can be created
FactorCollection()
:m_factorId(0) {
: m_factorIdNonTerminal(0)
, m_factorId(moses_MaxNumNonterminals) {
}
public:
@ -100,11 +108,15 @@ public:
/** returns a factor with the same direction, factorType and factorString.
* If a factor already exists in the collection, return the existing factor; if not, create a new one
*/
const Factor *AddFactor(const StringPiece &factorString);
const Factor *AddFactor(const StringPiece &factorString, bool isNonTerminal = false);
// Number of non-terminal factors registered so far; this is also the id that
// AddFactor assigns to the next new non-terminal.
//
// The by-value return is no longer declared const: a top-level const on a
// returned size_t is ignored by the language and only produces
// -Wignored-qualifiers warnings.
//
// NOTE(review): m_factorIdNonTerminal is read here without m_accessLock,
// while AddFactor increments it under the unique lock -- confirm that an
// unsynchronized read is acceptable for all callers.
size_t GetNumNonTerminals() {
  return m_factorIdNonTerminal;
}
// TODO: remove calls to this function, replacing them with the simpler AddFactor(factorString)
const Factor *AddFactor(FactorDirection /*direction*/, FactorType /*factorType*/, const StringPiece &factorString) {
return AddFactor(factorString);
const Factor *AddFactor(FactorDirection /*direction*/, FactorType /*factorType*/, const StringPiece &factorString, bool isNonTerminal = false) {
return AddFactor(factorString, isNonTerminal);
}
TO_STRING();

View File

@ -50,7 +50,10 @@ public:
void FinishedSearch() {
for (ChartCellLabelSet::iterator i(out_.mutable_begin()); i != out_.mutable_end(); ++i) {
ChartCellLabel::Stack &stack = i->second.MutableStack();
if ((*i) == NULL) {
continue;
}
ChartCellLabel::Stack &stack = (*i)->MutableStack();
Gen *gen = static_cast<Gen*>(stack.incr_generator);
gen->FinishedSearch();
stack.incr = &gen->Generating();
@ -80,6 +83,8 @@ public:
void AddPhraseOOV(TargetPhrase &phrase, std::list<TargetPhraseCollection*> &waste_memory, const WordsRange &range);
float CalcEstimateOfBestScore(const TargetPhraseCollection & tpc, const StackVec & stackVec) const;
bool Empty() const {
return edges_.Empty();
}
@ -112,7 +117,7 @@ private:
const search::Score oov_weight_;
};
template <class Model> void Fill<Model>::Add(const TargetPhraseCollection &targets, const StackVec &nts, const WordsRange &)
template <class Model> void Fill<Model>::Add(const TargetPhraseCollection &targets, const StackVec &nts, const WordsRange &range)
{
std::vector<search::PartialVertex> vertices;
vertices.reserve(nts.size());
@ -173,6 +178,17 @@ template <class Model> void Fill<Model>::AddPhraseOOV(TargetPhrase &phrase, std:
edges_.AddEdge(edge);
}
// Score estimate used for early pruning.
// Adds up the bound of the root alternate of every child stack in nts, then
// adds the future score of the first target phrase in targets.
// NOTE(review): targets.begin() is dereferenced unconditionally -- presumably
// callers never pass an empty collection; confirm.
template <class Model> float Fill<Model>::CalcEstimateOfBestScore(const TargetPhraseCollection &targets, const StackVec &nts) const
{
  float childBound = 0.0;
  StackVec::const_iterator child = nts.begin();
  const StackVec::const_iterator childEnd = nts.end();
  while (child != childEnd) {
    childBound += (*child)->GetStack().incr->RootAlternate().Bound();
    ++child;
  }
  const TargetPhrase &firstTarget = **(targets.begin());
  return firstTarget.GetFutureScore() + childBound;
}
// TODO: factors (but chart doesn't seem to support factors anyway).
template <class Model> lm::WordIndex Fill<Model>::Convert(const Word &word) const
{
@ -209,8 +225,12 @@ template <class Model, class Best> search::History Manager::PopulateBest(const M
size_t size = source_.GetSize();
boost::object_pool<search::Vertex> vertex_pool(std::max<size_t>(size * size / 2, 32));
for (size_t width = 1; width < size; ++width) {
for (size_t startPos = 0; startPos <= size-width; ++startPos) {
for (int startPos = size-1; startPos >= 0; --startPos) {
for (size_t width = 1; width <= size-startPos; ++width) {
// full range uses RootSearch
if (startPos == 0 && startPos + width == size) {
break;
}
WordsRange range(startPos, startPos + width - 1);
Fill<Model> filler(context, words, oov_weight);
parser_.Create(range, filler);

View File

@ -19,8 +19,14 @@ InputPath(const Phrase &phrase, const NonTerminalSet &sourceNonTerms,
,m_range(range)
,m_inputScore(inputScore)
,m_sourceNonTerms(sourceNonTerms)
,m_sourceNonTermArray(FactorCollection::Instance().GetNumNonTerminals(), false)
,m_nextNode(1)
{
for (NonTerminalSet::const_iterator iter = sourceNonTerms.begin(); iter != sourceNonTerms.end(); ++iter) {
size_t idx = (*iter)[0]->GetId();
m_sourceNonTermArray[idx] = true;
}
//cerr << "phrase=" << phrase << " m_inputScore=" << *m_inputScore << endl;
}

View File

@ -6,6 +6,7 @@
#include "Phrase.h"
#include "WordsRange.h"
#include "NonTerminal.h"
#include "moses/FactorCollection.h"
namespace Moses
{
@ -45,6 +46,7 @@ protected:
// for syntax model only
mutable std::vector<std::vector<const Word*> > m_ruleSourceFromInputPath;
const NonTerminalSet m_sourceNonTerms;
std::vector<bool> m_sourceNonTermArray;
public:
@ -65,6 +67,9 @@ public:
const NonTerminalSet &GetNonTerminalSet() const {
return m_sourceNonTerms;
}
// Read-only access to the per-id flag vector filled in by the constructor:
// element i is true when the source non-terminal set passed at construction
// contains a label whose factor-0 id is i.
const std::vector<bool> &GetNonTerminalArray() const {
return m_sourceNonTermArray;
}
const WordsRange &GetWordsRange() const {
return m_range;
}

View File

@ -616,11 +616,11 @@ void StaticData::LoadNonTerminals()
FactorCollection &factorCollection = FactorCollection::Instance();
m_inputDefaultNonTerminal.SetIsNonTerminal(true);
const Factor *sourceFactor = factorCollection.AddFactor(Input, 0, defaultNonTerminals);
const Factor *sourceFactor = factorCollection.AddFactor(Input, 0, defaultNonTerminals, true);
m_inputDefaultNonTerminal.SetFactor(0, sourceFactor);
m_outputDefaultNonTerminal.SetIsNonTerminal(true);
const Factor *targetFactor = factorCollection.AddFactor(Output, 0, defaultNonTerminals);
const Factor *targetFactor = factorCollection.AddFactor(Output, 0, defaultNonTerminals, true);
m_outputDefaultNonTerminal.SetFactor(0, targetFactor);
// for unknown words
@ -638,6 +638,7 @@ void StaticData::LoadNonTerminals()
"Incorrect unknown LHS format: " << line);
UnknownLHSEntry entry(tokens[0], Scan<float>(tokens[1]));
m_unknownLHS.push_back(entry);
const Factor *targetFactor = factorCollection.AddFactor(Output, 0, tokens[0], true);
}
}

View File

@ -218,11 +218,13 @@ protected:
std::string m_binPath;
// soft NT lookup for chart models
std::map<Word, std::set<Word> > m_soft_matches_map;
std::map<Word, std::set<Word> > m_soft_matches_map_reverse;
std::vector<std::vector<Word> > m_softMatchesMap;
const StatefulFeatureFunction* m_treeStructure;
// number of nonterminal labels
// size_t m_nonTerminalSize;
public:
bool IsAlwaysCreateDirectTranslationOption() const {
@ -740,21 +742,14 @@ public:
return m_useLegacyPT;
}
void Set_Soft_Matches(std::map<Word, std::set<Word> >& soft_matches_map) {
m_soft_matches_map = soft_matches_map;
void SetSoftMatches(std::vector<std::vector<Word> >& softMatchesMap) {
m_softMatchesMap = softMatchesMap;
}
const std::map<Word, std::set<Word> >* Get_Soft_Matches() const {
return &m_soft_matches_map;
const std::vector< std::vector<Word> >& GetSoftMatches() const {
return m_softMatchesMap;
}
void Set_Soft_Matches_Reverse(std::map<Word, std::set<Word> >& soft_matches_map) {
m_soft_matches_map_reverse = soft_matches_map;
}
const std::map<Word, std::set<Word> >* Get_Soft_Matches_Reverse() const {
return &m_soft_matches_map_reverse;
}
bool AdjacentOnly() const
{ return m_adjacentOnly; }

View File

@ -19,7 +19,6 @@
#include <iostream>
#include "ChartRuleLookupManagerMemory.h"
#include "DotChartInMemory.h"
#include "moses/ChartParser.h"
#include "moses/InputType.h"
@ -40,328 +39,191 @@ ChartRuleLookupManagerMemory::ChartRuleLookupManagerMemory(
const PhraseDictionaryMemory &ruleTable)
: ChartRuleLookupManagerCYKPlus(parser, cellColl)
, m_ruleTable(ruleTable)
, m_softMatchingMap(StaticData::Instance().GetSoftMatches())
{
UTIL_THROW_IF2(m_dottedRuleColls.size() != 0,
"Dotted rule collection not correctly initialized");
size_t sourceSize = parser.GetSize();
m_dottedRuleColls.resize(sourceSize);
const PhraseDictionaryNodeMemory &rootNode = m_ruleTable.GetRootNode();
m_completedRules.resize(sourceSize);
// permissible soft nonterminal matches (target side)
const StaticData &staticData = StaticData::Instance();
m_soft_matches_map = staticData.Get_Soft_Matches();
m_soft_matches_map_reverse = staticData.Get_Soft_Matches_Reverse();
m_soft_matching = !m_soft_matches_map->empty();
for (size_t ind = 0; ind < m_dottedRuleColls.size(); ++ind) {
#ifdef USE_BOOST_POOL
DottedRuleInMemory *initDottedRule = m_dottedRulePool.malloc();
new (initDottedRule) DottedRuleInMemory(rootNode);
#else
DottedRuleInMemory *initDottedRule = new DottedRuleInMemory(rootNode);
#endif
DottedRuleColl *dottedRuleColl = new DottedRuleColl(sourceSize - ind + 1);
dottedRuleColl->Add(0, initDottedRule); // init rule. stores the top node in tree
m_dottedRuleColls[ind] = dottedRuleColl;
}
}
ChartRuleLookupManagerMemory::~ChartRuleLookupManagerMemory()
{
RemoveAllInColl(m_dottedRuleColls);
m_isSoftMatching = !m_softMatchingMap.empty();
}
void ChartRuleLookupManagerMemory::GetChartRuleCollection(
const WordsRange &range,
size_t lastPos,
ChartParserCallback &outColl)
{
size_t relEndPos = range.GetEndPos() - range.GetStartPos();
size_t startPos = range.GetStartPos();
size_t absEndPos = range.GetEndPos();
// MAIN LOOP. create list of nodes of target phrases
m_lastPos = lastPos;
m_stackVec.clear();
m_outColl = &outColl;
m_unaryPos = absEndPos-1; // rules ending in this position are unary and should not be added to collection
// get list of all rules that apply to spans at same starting position
DottedRuleColl &dottedRuleCol = *m_dottedRuleColls[range.GetStartPos()];
const DottedRuleList &expandableDottedRuleList = dottedRuleCol.GetExpandableDottedRuleList();
DottedRuleMap &expandableDottedRuleListTerminalsOnly = dottedRuleCol.GetExpandableDottedRuleListTerminalsOnly();
const PhraseDictionaryNodeMemory &rootNode = m_ruleTable.GetRootNode();
const ChartCellLabel &sourceWordLabel = GetSourceAt(absEndPos);
// size-1 terminal rules
if (startPos == absEndPos) {
const Word &sourceWord = GetSourceAt(absEndPos).GetLabel();
const PhraseDictionaryNodeMemory *child = rootNode.GetChild(sourceWord);
// loop through the rules
// (note that expandableDottedRuleList can be expanded as the loop runs
// through calls to ExtendPartialRuleApplication())
for (size_t ind = 0; ind < expandableDottedRuleList.size(); ++ind) {
// rule we are about to extend
const DottedRuleInMemory &prevDottedRule = *expandableDottedRuleList[ind];
// we will now try to extend it, starting after where it ended
size_t startPos = prevDottedRule.IsRoot()
? range.GetStartPos()
: prevDottedRule.GetWordsRange().GetEndPos() + 1;
// search for terminal symbol
// (if only one more word position needs to be covered)
if (startPos == absEndPos) {
// look up in rule dictionary, if the current rule can be extended
// with the source word in the last position
const Word &sourceWord = sourceWordLabel.GetLabel();
const PhraseDictionaryNodeMemory *node = prevDottedRule.GetLastNode().GetChild(sourceWord);
// if we found a new rule -> create it and add it to the list
if (node != NULL) {
// create the rule
#ifdef USE_BOOST_POOL
DottedRuleInMemory *dottedRule = m_dottedRulePool.malloc();
new (dottedRule) DottedRuleInMemory(*node, sourceWordLabel,
prevDottedRule);
#else
DottedRuleInMemory *dottedRule = new DottedRuleInMemory(*node,
sourceWordLabel,
prevDottedRule);
#endif
dottedRuleCol.Add(relEndPos+1, dottedRule);
}
// if we found a new rule -> directly add it to the out collection
if (child != NULL) {
const TargetPhraseCollection &tpc = child->GetTargetPhraseCollection();
outColl.Add(tpc, m_stackVec, range);
}
// search for non-terminals
size_t endPos, stackInd;
// span is already complete covered? nothing can be done
if (startPos > absEndPos)
continue;
else if (startPos == range.GetStartPos() && range.GetEndPos() > range.GetStartPos()) {
// We're at the root of the prefix tree so won't try to cover the full
// span (i.e. we don't allow non-lexical unary rules). However, we need
// to match non-unary rules that begin with a non-terminal child, so we
// do that in two steps: during this iteration we search for non-terminals
// that cover all but the last source word in the span (there won't
// already be running nodes for these because that would have required a
// non-lexical unary rule match for an earlier span). Any matches will
// result in running nodes being appended to the list and on subsequent
// iterations (for this same span), we'll extend them to cover the final
// word.
endPos = absEndPos - 1;
stackInd = relEndPos;
} else {
endPos = absEndPos;
stackInd = relEndPos + 1;
}
// all rules starting with nonterminal
else if (absEndPos > startPos) {
GetNonTerminalExtension(&rootNode, startPos, absEndPos-1);
// all (non-unary) rules starting with terminal
if (absEndPos == startPos+1) {
GetTerminalExtension(&rootNode, absEndPos-1);
}
ExtendPartialRuleApplication(prevDottedRule, startPos, endPos, stackInd,
dottedRuleCol);
}
// search for terminal symbol
// (if only one more word position needs to be covered)
DottedRuleMap::iterator it = expandableDottedRuleListTerminalsOnly.find(absEndPos);
if (it != expandableDottedRuleListTerminalsOnly.end()) {
for (size_t ind = 0; ind < it->second.size(); ++ind) {
// rule we are about to extend
const DottedRuleInMemory &prevDottedRule = *it->second[ind];
// look up in rule dictionary, if the current rule can be extended
// with the source word in the last position
const Word &sourceWord = sourceWordLabel.GetLabel();
const PhraseDictionaryNodeMemory *node = prevDottedRule.GetLastNode().GetChild(sourceWord);
// if we found a new rule -> create it and add it to the list
if (node != NULL) {
// create the rule
#ifdef USE_BOOST_POOL
DottedRuleInMemory *dottedRule = m_dottedRulePool.malloc();
new (dottedRule) DottedRuleInMemory(*node, sourceWordLabel,
prevDottedRule);
#else
DottedRuleInMemory *dottedRule = new DottedRuleInMemory(*node,
sourceWordLabel,
prevDottedRule);
#endif
dottedRuleCol.Add(relEndPos+1, dottedRule);
}
}
// we only need to check once if a terminal matches the input at a given position.
expandableDottedRuleListTerminalsOnly.erase(it);
// copy temporarily stored rules to out collection
CompletedRuleCollection rules = m_completedRules[absEndPos];
for (vector<CompletedRule*>::const_iterator iter = rules.begin(); iter != rules.end(); ++iter) {
outColl.Add((*iter)->GetTPC(), (*iter)->GetStackVector(), range);
}
// list of rules that that cover the entire span
DottedRuleList &rules = dottedRuleCol.Get(relEndPos + 1);
m_completedRules[absEndPos].Clear();
// look up target sides for the rules
DottedRuleList::const_iterator iterRule;
for (iterRule = rules.begin(); iterRule != rules.end(); ++iterRule) {
const DottedRuleInMemory &dottedRule = **iterRule;
const PhraseDictionaryNodeMemory &node = dottedRule.GetLastNode();
// look up target sides
const TargetPhraseCollection &tpc = node.GetTargetPhraseCollection();
// add the fully expanded rule (with lexical target side)
AddCompletedRule(dottedRule, tpc, range, outColl);
}
dottedRuleCol.Clear(relEndPos+1);
}
// Given a partial rule application ending at startPos-1 and given the sets of
// source and target non-terminals covering the span [startPos, endPos],
// determines the full or partial rule applications that can be produced through
// extending the current rule application by a single non-terminal.
void ChartRuleLookupManagerMemory::ExtendPartialRuleApplication(
const DottedRuleInMemory &prevDottedRule,
size_t startPos,
size_t endPos,
size_t stackInd,
DottedRuleColl & dottedRuleColl)
{
// source non-terminal labels for the remainder
const InputPath &inputPath = GetParser().GetInputPath(startPos, endPos);
const NonTerminalSet &sourceNonTerms = inputPath.GetNonTerminalSet();
// if a (partial) rule matches, add it to the list of completed rules (if non-unary and non-empty), and try to find expansions that have this partial rule as prefix.
void ChartRuleLookupManagerMemory::AddAndExtend(
const PhraseDictionaryNodeMemory *node,
size_t endPos,
const ChartCellLabel *cellLabel) {
// target non-terminal labels for the remainder
const ChartCellLabelSet &targetNonTerms = GetTargetLabelSet(startPos, endPos);
// add backpointer
if (cellLabel != NULL) {
m_stackVec.push_back(cellLabel);
}
// note where it was found in the prefix tree of the rule dictionary
const PhraseDictionaryNodeMemory &node = prevDottedRule.GetLastNode();
const TargetPhraseCollection &tpc = node->GetTargetPhraseCollection();
// add target phrase collection (except if rule is empty or unary)
if (!tpc.IsEmpty() && endPos != m_unaryPos) {
m_completedRules[endPos].Add(tpc, m_stackVec, *m_outColl);
}
const PhraseDictionaryNodeMemory::NonTerminalMap & nonTermMap =
node.GetNonTerminalMap();
const size_t numChildren = nonTermMap.size();
if (numChildren == 0) {
return;
}
const size_t numSourceNonTerms = sourceNonTerms.size();
const size_t numTargetNonTerms = targetNonTerms.GetSize();
const size_t numCombinations = numSourceNonTerms * numTargetNonTerms;
// We can search by either:
// 1. Enumerating all possible source-target NT pairs that are valid for
// the span and then searching for matching children in the node,
// or
// 2. Iterating over all the NT children in the node, searching
// for each source and target NT in the span's sets.
// We'll do whichever minimises the number of lookups:
if (numCombinations <= numChildren*2) {
// loop over possible source non-terminal labels (as found in input tree)
NonTerminalSet::const_iterator p = sourceNonTerms.begin();
NonTerminalSet::const_iterator sEnd = sourceNonTerms.end();
for (; p != sEnd; ++p) {
const Word & sourceNonTerm = *p;
// loop over possible target non-terminal labels (as found in chart)
ChartCellLabelSet::const_iterator q = targetNonTerms.begin();
ChartCellLabelSet::const_iterator tEnd = targetNonTerms.end();
for (; q != tEnd; ++q) {
const ChartCellLabel &cellLabel = q->second;
//soft matching of NTs
const Word& targetNonTerm = cellLabel.GetLabel();
if (m_soft_matching && m_soft_matches_map->find(targetNonTerm) != m_soft_matches_map->end()) {
const std::set<Word>& softMatches = m_soft_matches_map->find(targetNonTerm)->second;
for (std::set<Word>::const_iterator softMatch = softMatches.begin(); softMatch != softMatches.end(); ++softMatch) {
// try to match both source and target non-terminal
const PhraseDictionaryNodeMemory * child =
node.GetChild(sourceNonTerm, *softMatch);
// nothing found? then we are done
if (child == NULL) {
continue;
}
// create new rule
#ifdef USE_BOOST_POOL
DottedRuleInMemory *rule = m_dottedRulePool.malloc();
new (rule) DottedRuleInMemory(*child, cellLabel, prevDottedRule);
#else
DottedRuleInMemory *rule = new DottedRuleInMemory(*child, cellLabel,
prevDottedRule);
#endif
dottedRuleColl.Add(stackInd, rule);
}
} // end of soft matching
// try to match both source and target non-terminal
const PhraseDictionaryNodeMemory * child =
node.GetChild(sourceNonTerm, targetNonTerm);
// nothing found? then we are done
if (child == NULL) {
continue;
// get all further extensions of rule (until reaching end of sentence or max-chart-span)
if (endPos < m_lastPos) {
if (!node->GetTerminalMap().empty()) {
GetTerminalExtension(node, endPos+1);
}
if (!node->GetNonTerminalMap().empty()) {
for (size_t newEndPos = endPos+1; newEndPos <= m_lastPos; newEndPos++) {
GetNonTerminalExtension(node, endPos+1, newEndPos);
}
// create new rule
#ifdef USE_BOOST_POOL
DottedRuleInMemory *rule = m_dottedRulePool.malloc();
new (rule) DottedRuleInMemory(*child, cellLabel, prevDottedRule);
#else
DottedRuleInMemory *rule = new DottedRuleInMemory(*child, cellLabel,
prevDottedRule);
#endif
dottedRuleColl.Add(stackInd, rule);
}
}
} else {
// remove backpointer
if (cellLabel != NULL) {
m_stackVec.pop_back();
}
}
// search all possible terminal extensions of a partial rule (pointed at by node) at a given position
// recursively try to expand partial rules into full rules up to m_lastPos.
void ChartRuleLookupManagerMemory::GetTerminalExtension(
const PhraseDictionaryNodeMemory *node,
size_t pos) {
const Word &sourceWord = GetSourceAt(pos).GetLabel();
const PhraseDictionaryNodeMemory::TerminalMap & terminals = node->GetTerminalMap();
// if node has small number of terminal edges, test word equality for each.
if (terminals.size() < 5) {
for (PhraseDictionaryNodeMemory::TerminalMap::const_iterator iter = terminals.begin(); iter != terminals.end(); ++iter) {
const Word & word = iter->first;
if (word == sourceWord) {
const PhraseDictionaryNodeMemory *child = & iter->second;
AddAndExtend(child, pos, NULL);
}
}
}
// else, do hash lookup
else {
const PhraseDictionaryNodeMemory *child = node->GetChild(sourceWord);
if (child != NULL) {
AddAndExtend(child, pos, NULL);
}
}
}
// search all nonterminal possible nonterminal extensions of a partial rule (pointed at by node) for a given span (StartPos, endPos).
// recursively try to expand partial rules into full rules up to m_lastPos.
void ChartRuleLookupManagerMemory::GetNonTerminalExtension(
const PhraseDictionaryNodeMemory *node,
size_t startPos,
size_t endPos) {
// target non-terminal labels for the span
const ChartCellLabelSet &targetNonTerms = GetTargetLabelSet(startPos, endPos);
if (targetNonTerms.GetSize() == 0) {
return;
}
#if !defined(UNLABELLED_SOURCE)
// source non-terminal labels for the span
const InputPath &inputPath = GetParser().GetInputPath(startPos, endPos);
const std::vector<bool> &sourceNonTermArray = inputPath.GetNonTerminalArray();
// can this ever be true? Moses seems to pad the non-terminal set of the input with [X]
if (inputPath.GetNonTerminalSet.size() == 0) {
return;
}
#endif
// non-terminal labels in phrase dictionary node
const PhraseDictionaryNodeMemory::NonTerminalMap & nonTermMap = node->GetNonTerminalMap();
// loop over possible expansions of the rule
PhraseDictionaryNodeMemory::NonTerminalMap::const_iterator p;
PhraseDictionaryNodeMemory::NonTerminalMap::const_iterator end =
nonTermMap.end();
PhraseDictionaryNodeMemory::NonTerminalMap::const_iterator end = nonTermMap.end();
for (p = nonTermMap.begin(); p != end; ++p) {
// does it match possible source and target non-terminals?
#if defined(UNLABELLED_SOURCE)
const Word &targetNonTerm = p->first;
#else
const PhraseDictionaryNodeMemory::NonTerminalMapKey &key = p->first;
const Word &sourceNonTerm = key.first;
if (sourceNonTerms.find(sourceNonTerm) == sourceNonTerms.end()) {
// check if source label matches
if (! sourceNonTermArray[sourceNonTerm[0]->GetId()]) {
continue;
}
const Word &targetNonTerm = key.second;
#endif
//soft matching of NTs
if (m_soft_matching && m_soft_matches_map_reverse->find(targetNonTerm) != m_soft_matches_map_reverse->end()) {
const std::set<Word>& softMatches = m_soft_matches_map_reverse->find(targetNonTerm)->second;
for (std::set<Word>::const_iterator softMatch = softMatches.begin(); softMatch != softMatches.end(); ++softMatch) {
if (m_isSoftMatching && !m_softMatchingMap[targetNonTerm[0]->GetId()].empty()) {
const std::vector<Word>& softMatches = m_softMatchingMap[targetNonTerm[0]->GetId()];
for (std::vector<Word>::const_iterator softMatch = softMatches.begin(); softMatch != softMatches.end(); ++softMatch) {
const ChartCellLabel *cellLabel = targetNonTerms.Find(*softMatch);
if (!cellLabel) {
if (cellLabel == NULL) {
continue;
}
// create new rule
const PhraseDictionaryNodeMemory &child = p->second;
#ifdef USE_BOOST_POOL
DottedRuleInMemory *rule = m_dottedRulePool.malloc();
new (rule) DottedRuleInMemory(child, *cellLabel, prevDottedRule);
#else
DottedRuleInMemory *rule = new DottedRuleInMemory(child, *cellLabel,
prevDottedRule);
#endif
dottedRuleColl.Add(stackInd, rule);
AddAndExtend(&child, endPos, cellLabel);
}
} // end of soft matches lookup
const ChartCellLabel *cellLabel = targetNonTerms.Find(targetNonTerm);
if (!cellLabel) {
if (cellLabel == NULL) {
continue;
}
// create new rule
const PhraseDictionaryNodeMemory &child = p->second;
#ifdef USE_BOOST_POOL
DottedRuleInMemory *rule = m_dottedRulePool.malloc();
new (rule) DottedRuleInMemory(child, *cellLabel, prevDottedRule);
#else
DottedRuleInMemory *rule = new DottedRuleInMemory(child, *cellLabel,
prevDottedRule);
#endif
dottedRuleColl.Add(stackInd, rule);
AddAndExtend(&child, endPos, cellLabel);
}
}
}
} // namespace Moses

View File

@ -23,12 +23,8 @@
#include <vector>
#ifdef USE_BOOST_POOL
#include <boost/pool/object_pool.hpp>
#endif
#include "ChartRuleLookupManagerCYKPlus.h"
#include "DotChartInMemory.h"
#include "CompletedRuleCollection.h"
#include "moses/NonTerminal.h"
#include "moses/TranslationModel/PhraseDictionaryMemory.h"
#include "moses/TranslationModel/PhraseDictionaryNodeMemory.h"
@ -38,7 +34,6 @@ namespace Moses
{
class ChartParserCallback;
class DottedRuleColl;
class WordsRange;
//! Implementation of ChartRuleLookupManager for in-memory rule tables.
@ -49,33 +44,44 @@ public:
const ChartCellCollectionBase &cellColl,
const PhraseDictionaryMemory &ruleTable);
~ChartRuleLookupManagerMemory();
~ChartRuleLookupManagerMemory() {};
virtual void GetChartRuleCollection(
const WordsRange &range,
size_t lastPos, // last position to consider if using lookahead
ChartParserCallback &outColl);
private:
void ExtendPartialRuleApplication(
const DottedRuleInMemory &prevDottedRule,
size_t startPos,
size_t endPos,
size_t stackInd,
DottedRuleColl &dottedRuleColl);
std::vector<DottedRuleColl*> m_dottedRuleColls;
void GetTerminalExtension(
const PhraseDictionaryNodeMemory *node,
size_t pos);
void GetNonTerminalExtension(
const PhraseDictionaryNodeMemory *node,
size_t startPos,
size_t endPos);
void AddAndExtend(
const PhraseDictionaryNodeMemory *node,
size_t endPos,
const ChartCellLabel *cellLabel);
const PhraseDictionaryMemory &m_ruleTable;
#ifdef USE_BOOST_POOL
// Use an object pool to allocate the dotted rules for this sentence. We
// allocate a lot of them and this has been seen to significantly improve
// performance, especially for multithreaded decoding.
boost::object_pool<DottedRuleInMemory> m_dottedRulePool;
#endif
// permissible soft nonterminal matches (target side)
bool m_soft_matching;
const std::map<Word, std::set<Word> >* m_soft_matches_map;
const std::map<Word, std::set<Word> >* m_soft_matches_map_reverse;
bool m_isSoftMatching;
const std::vector<std::vector<Word> >& m_softMatchingMap;
// temporary storage of completed rules (one collection per end position; all rules collected consecutively start from the same position)
std::vector<CompletedRuleCollection> m_completedRules;
size_t m_lastPos;
size_t m_unaryPos;
StackVec m_stackVec;
ChartParserCallback* m_outColl;
};
} // namespace Moses

View File

@ -17,16 +17,18 @@
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
#include <iostream>
#include "ChartRuleLookupManagerMemoryPerSentence.h"
#include "DotChartInMemory.h"
#include "moses/TranslationModel/RuleTable/PhraseDictionaryFuzzyMatch.h"
#include "moses/ChartParser.h"
#include "moses/InputType.h"
#include "moses/ChartParserCallback.h"
#include "moses/StaticData.h"
#include "moses/NonTerminal.h"
#include "moses/ChartCellCollection.h"
#include "moses/ChartParser.h"
#include "moses/TranslationModel/RuleTable/PhraseDictionaryFuzzyMatch.h"
using namespace std;
namespace Moses
{
@ -37,235 +39,192 @@ ChartRuleLookupManagerMemoryPerSentence::ChartRuleLookupManagerMemoryPerSentence
const PhraseDictionaryFuzzyMatch &ruleTable)
: ChartRuleLookupManagerCYKPlus(parser, cellColl)
, m_ruleTable(ruleTable)
, m_softMatchingMap(StaticData::Instance().GetSoftMatches())
{
UTIL_THROW_IF2(m_dottedRuleColls.size() != 0, "Dotted rule collection not correctly initialized");
size_t sourceSize = parser.GetSize();
m_dottedRuleColls.resize(sourceSize);
const PhraseDictionaryNodeMemory &rootNode = m_ruleTable.GetRootNode(parser.GetTranslationId());
m_completedRules.resize(sourceSize);
for (size_t ind = 0; ind < m_dottedRuleColls.size(); ++ind) {
#ifdef USE_BOOST_POOL
DottedRuleInMemory *initDottedRule = m_dottedRulePool.malloc();
new (initDottedRule) DottedRuleInMemory(rootNode);
#else
DottedRuleInMemory *initDottedRule = new DottedRuleInMemory(rootNode);
#endif
DottedRuleColl *dottedRuleColl = new DottedRuleColl(sourceSize - ind + 1);
dottedRuleColl->Add(0, initDottedRule); // init rule. stores the top node in tree
m_dottedRuleColls[ind] = dottedRuleColl;
}
}
ChartRuleLookupManagerMemoryPerSentence::~ChartRuleLookupManagerMemoryPerSentence()
{
RemoveAllInColl(m_dottedRuleColls);
m_isSoftMatching = !m_softMatchingMap.empty();
}
void ChartRuleLookupManagerMemoryPerSentence::GetChartRuleCollection(
const WordsRange &range,
size_t lastPos,
ChartParserCallback &outColl)
{
size_t relEndPos = range.GetEndPos() - range.GetStartPos();
size_t startPos = range.GetStartPos();
size_t absEndPos = range.GetEndPos();
// MAIN LOOP. create list of nodes of target phrases
m_lastPos = lastPos;
m_stackVec.clear();
m_outColl = &outColl;
m_unaryPos = absEndPos-1; // rules ending in this position are unary and should not be added to collection
// get list of all rules that apply to spans at same starting position
DottedRuleColl &dottedRuleCol = *m_dottedRuleColls[range.GetStartPos()];
const DottedRuleList &expandableDottedRuleList = dottedRuleCol.GetExpandableDottedRuleList();
const PhraseDictionaryNodeMemory &rootNode = m_ruleTable.GetRootNode(GetParser().GetTranslationId());
// loop through the rules
// (note that expandableDottedRuleList can be expanded as the loop runs
// through calls to ExtendPartialRuleApplication())
for (size_t ind = 0; ind < expandableDottedRuleList.size(); ++ind) {
// rule we are about to extend
const DottedRuleInMemory &prevDottedRule = *expandableDottedRuleList[ind];
// we will now try to extend it, starting after where it ended
size_t startPos = prevDottedRule.IsRoot()
? range.GetStartPos()
: prevDottedRule.GetWordsRange().GetEndPos() + 1;
// size-1 terminal rules
if (startPos == absEndPos) {
const Word &sourceWord = GetSourceAt(absEndPos).GetLabel();
const PhraseDictionaryNodeMemory *child = rootNode.GetChild(sourceWord);
// search for terminal symbol
// (if only one more word position needs to be covered)
if (startPos == absEndPos) {
// look up in rule dictionary, if the current rule can be extended
// with the source word in the last position
const ChartCellLabel &sourceWordLabel = GetSourceAt(absEndPos);
const Word &sourceWord = sourceWordLabel.GetLabel();
const PhraseDictionaryNodeMemory *node = prevDottedRule.GetLastNode().GetChild(sourceWord);
// if we found a new rule -> create it and add it to the list
if (node != NULL) {
// create the rule
#ifdef USE_BOOST_POOL
DottedRuleInMemory *dottedRule = m_dottedRulePool.malloc();
new (dottedRule) DottedRuleInMemory(*node, sourceWordLabel,
prevDottedRule);
#else
DottedRuleInMemory *dottedRule = new DottedRuleInMemory(*node,
sourceWordLabel,
prevDottedRule);
#endif
dottedRuleCol.Add(relEndPos+1, dottedRule);
}
// if we found a new rule -> directly add it to the out collection
if (child != NULL) {
const TargetPhraseCollection &tpc = child->GetTargetPhraseCollection();
outColl.Add(tpc, m_stackVec, range);
}
// search for non-terminals
size_t endPos, stackInd;
// span is already complete covered? nothing can be done
if (startPos > absEndPos)
continue;
else if (startPos == range.GetStartPos() && range.GetEndPos() > range.GetStartPos()) {
// We're at the root of the prefix tree so won't try to cover the full
// span (i.e. we don't allow non-lexical unary rules). However, we need
// to match non-unary rules that begin with a non-terminal child, so we
// do that in two steps: during this iteration we search for non-terminals
// that cover all but the last source word in the span (there won't
// already be running nodes for these because that would have required a
// non-lexical unary rule match for an earlier span). Any matches will
// result in running nodes being appended to the list and on subsequent
// iterations (for this same span), we'll extend them to cover the final
// word.
endPos = absEndPos - 1;
stackInd = relEndPos;
} else {
endPos = absEndPos;
stackInd = relEndPos + 1;
}
// all rules starting with nonterminal
else if (absEndPos > startPos) {
GetNonTerminalExtension(&rootNode, startPos, absEndPos-1);
// all (non-unary) rules starting with terminal
if (absEndPos == startPos+1) {
GetTerminalExtension(&rootNode, absEndPos-1);
}
ExtendPartialRuleApplication(prevDottedRule, startPos, endPos, stackInd,
dottedRuleCol);
}
// list of rules that that cover the entire span
DottedRuleList &rules = dottedRuleCol.Get(relEndPos + 1);
// look up target sides for the rules
DottedRuleList::const_iterator iterRule;
for (iterRule = rules.begin(); iterRule != rules.end(); ++iterRule) {
const DottedRuleInMemory &dottedRule = **iterRule;
const PhraseDictionaryNodeMemory &node = dottedRule.GetLastNode();
// look up target sides
const TargetPhraseCollection &tpc = node.GetTargetPhraseCollection();
// add the fully expanded rule (with lexical target side)
AddCompletedRule(dottedRule, tpc, range, outColl);
// copy temporarily stored rules to out collection
CompletedRuleCollection rules = m_completedRules[absEndPos];
for (vector<CompletedRule*>::const_iterator iter = rules.begin(); iter != rules.end(); ++iter) {
outColl.Add((*iter)->GetTPC(), (*iter)->GetStackVector(), range);
}
dottedRuleCol.Clear(relEndPos+1);
m_completedRules[absEndPos].Clear();
}
// Given a partial rule application ending at startPos-1 and given the sets of
// source and target non-terminals covering the span [startPos, endPos],
// determines the full or partial rule applications that can be produced through
// extending the current rule application by a single non-terminal.
void ChartRuleLookupManagerMemoryPerSentence::ExtendPartialRuleApplication(
const DottedRuleInMemory &prevDottedRule,
size_t startPos,
size_t endPos,
size_t stackInd,
DottedRuleColl & dottedRuleColl)
{
// source non-terminal labels for the remainder
const NonTerminalSet &sourceNonTerms = GetParser().GetInputPath(startPos, endPos).GetNonTerminalSet();
// if a (partial) rule matches, add it to list completed rules (if non-unary and non-empty), and try find expansions that have this partial rule as prefix.
void ChartRuleLookupManagerMemoryPerSentence::AddAndExtend(
const PhraseDictionaryNodeMemory *node,
size_t endPos,
const ChartCellLabel *cellLabel) {
// target non-terminal labels for the remainder
const ChartCellLabelSet &targetNonTerms = GetTargetLabelSet(startPos, endPos);
// add backpointer
if (cellLabel != NULL) {
m_stackVec.push_back(cellLabel);
}
// note where it was found in the prefix tree of the rule dictionary
const PhraseDictionaryNodeMemory &node = prevDottedRule.GetLastNode();
const TargetPhraseCollection &tpc = node->GetTargetPhraseCollection();
// add target phrase collection (except if rule is empty or unary)
if (!tpc.IsEmpty() && endPos != m_unaryPos) {
m_completedRules[endPos].Add(tpc, m_stackVec, *m_outColl);
}
const PhraseDictionaryNodeMemory::NonTerminalMap & nonTermMap =
node.GetNonTerminalMap();
const size_t numChildren = nonTermMap.size();
if (numChildren == 0) {
return;
}
const size_t numSourceNonTerms = sourceNonTerms.size();
const size_t numTargetNonTerms = targetNonTerms.GetSize();
const size_t numCombinations = numSourceNonTerms * numTargetNonTerms;
// We can search by either:
// 1. Enumerating all possible source-target NT pairs that are valid for
// the span and then searching for matching children in the node,
// or
// 2. Iterating over all the NT children in the node, searching
// for each source and target NT in the span's sets.
// We'll do whichever minimises the number of lookups:
if (numCombinations <= numChildren*2) {
// loop over possible source non-terminal labels (as found in input tree)
NonTerminalSet::const_iterator p = sourceNonTerms.begin();
NonTerminalSet::const_iterator sEnd = sourceNonTerms.end();
for (; p != sEnd; ++p) {
const Word & sourceNonTerm = *p;
// loop over possible target non-terminal labels (as found in chart)
ChartCellLabelSet::const_iterator q = targetNonTerms.begin();
ChartCellLabelSet::const_iterator tEnd = targetNonTerms.end();
for (; q != tEnd; ++q) {
const ChartCellLabel &cellLabel = q->second;
// try to match both source and target non-terminal
const PhraseDictionaryNodeMemory * child =
node.GetChild(sourceNonTerm, cellLabel.GetLabel());
// nothing found? then we are done
if (child == NULL) {
continue;
// get all further extensions of rule (until reaching end of sentence or max-chart-span)
if (endPos < m_lastPos) {
if (!node->GetTerminalMap().empty()) {
GetTerminalExtension(node, endPos+1);
}
if (!node->GetNonTerminalMap().empty()) {
for (size_t newEndPos = endPos+1; newEndPos <= m_lastPos; newEndPos++) {
GetNonTerminalExtension(node, endPos+1, newEndPos);
}
// create new rule
#ifdef USE_BOOST_POOL
DottedRuleInMemory *rule = m_dottedRulePool.malloc();
new (rule) DottedRuleInMemory(*child, cellLabel, prevDottedRule);
#else
DottedRuleInMemory *rule = new DottedRuleInMemory(*child, cellLabel,
prevDottedRule);
#endif
dottedRuleColl.Add(stackInd, rule);
}
}
} else {
// remove backpointer
if (cellLabel != NULL) {
m_stackVec.pop_back();
}
}
// search all possible terminal extensions of a partial rule (pointed at by node) at a given position
// recursively try to expand partial rules into full rules up to m_lastPos.
void ChartRuleLookupManagerMemoryPerSentence::GetTerminalExtension(
const PhraseDictionaryNodeMemory *node,
size_t pos) {
const Word &sourceWord = GetSourceAt(pos).GetLabel();
const PhraseDictionaryNodeMemory::TerminalMap & terminals = node->GetTerminalMap();
// if node has small number of terminal edges, test word equality for each.
if (terminals.size() < 5) {
for (PhraseDictionaryNodeMemory::TerminalMap::const_iterator iter = terminals.begin(); iter != terminals.end(); ++iter) {
const Word & word = iter->first;
if (word == sourceWord) {
const PhraseDictionaryNodeMemory *child = & iter->second;
AddAndExtend(child, pos, NULL);
}
}
}
// else, do hash lookup
else {
const PhraseDictionaryNodeMemory *child = node->GetChild(sourceWord);
if (child != NULL) {
AddAndExtend(child, pos, NULL);
}
}
}
// search all nonterminal possible nonterminal extensions of a partial rule (pointed at by node) for a given span (StartPos, endPos).
// recursively try to expand partial rules into full rules up to m_lastPos.
void ChartRuleLookupManagerMemoryPerSentence::GetNonTerminalExtension(
const PhraseDictionaryNodeMemory *node,
size_t startPos,
size_t endPos) {
// target non-terminal labels for the span
const ChartCellLabelSet &targetNonTerms = GetTargetLabelSet(startPos, endPos);
if (targetNonTerms.GetSize() == 0) {
return;
}
#if !defined(UNLABELLED_SOURCE)
// source non-terminal labels for the span
const InputPath &inputPath = GetParser().GetInputPath(startPos, endPos);
const std::vector<bool> &sourceNonTermArray = inputPath.GetNonTerminalArray();
// can this ever be true? Moses seems to pad the non-terminal set of the input with [X]
if (inputPath.GetNonTerminalSet.size() == 0) {
return;
}
#endif
// non-terminal labels in phrase dictionary node
const PhraseDictionaryNodeMemory::NonTerminalMap & nonTermMap = node->GetNonTerminalMap();
// loop over possible expansions of the rule
PhraseDictionaryNodeMemory::NonTerminalMap::const_iterator p;
PhraseDictionaryNodeMemory::NonTerminalMap::const_iterator end =
nonTermMap.end();
PhraseDictionaryNodeMemory::NonTerminalMap::const_iterator end = nonTermMap.end();
for (p = nonTermMap.begin(); p != end; ++p) {
// does it match possible source and target non-terminals?
#if defined(UNLABELLED_SOURCE)
const Word &targetNonTerm = p->first;
#else
const PhraseDictionaryNodeMemory::NonTerminalMapKey &key = p->first;
const Word &sourceNonTerm = key.first;
if (sourceNonTerms.find(sourceNonTerm) == sourceNonTerms.end()) {
// check if source label matches
if (! sourceNonTermArray[sourceNonTerm[0]->GetId()]) {
continue;
}
const Word &targetNonTerm = key.second;
#endif
//soft matching of NTs
if (m_isSoftMatching && !m_softMatchingMap[targetNonTerm[0]->GetId()].empty()) {
const std::vector<Word>& softMatches = m_softMatchingMap[targetNonTerm[0]->GetId()];
for (std::vector<Word>::const_iterator softMatch = softMatches.begin(); softMatch != softMatches.end(); ++softMatch) {
const ChartCellLabel *cellLabel = targetNonTerms.Find(*softMatch);
if (cellLabel == NULL) {
continue;
}
// create new rule
const PhraseDictionaryNodeMemory &child = p->second;
AddAndExtend(&child, endPos, cellLabel);
}
} // end of soft matches lookup
const ChartCellLabel *cellLabel = targetNonTerms.Find(targetNonTerm);
if (!cellLabel) {
if (cellLabel == NULL) {
continue;
}
// create new rule
const PhraseDictionaryNodeMemory &child = p->second;
#ifdef USE_BOOST_POOL
DottedRuleInMemory *rule = m_dottedRulePool.malloc();
new (rule) DottedRuleInMemory(child, *cellLabel, prevDottedRule);
#else
DottedRuleInMemory *rule = new DottedRuleInMemory(child, *cellLabel,
prevDottedRule);
#endif
dottedRuleColl.Add(stackInd, rule);
AddAndExtend(&child, endPos, cellLabel);
}
}
}
} // namespace Moses

View File

@ -18,17 +18,13 @@
***********************************************************************/
#pragma once
#ifndef moses_ChartRuleLookupManagerMemory_h
#define moses_ChartRuleLookupManagerMemory_h
#ifndef moses_ChartRuleLookupManagerMemoryPerSentence_h
#define moses_ChartRuleLookupManagerMemoryPerSentence_h
#include <vector>
#ifdef USE_BOOST_POOL
#include <boost/pool/object_pool.hpp>
#endif
#include "ChartRuleLookupManagerCYKPlus.h"
#include "DotChartInMemory.h"
#include "CompletedRuleCollection.h"
#include "moses/NonTerminal.h"
#include "moses/TranslationModel/PhraseDictionaryMemory.h"
#include "moses/TranslationModel/PhraseDictionaryNodeMemory.h"
@ -38,7 +34,6 @@ namespace Moses
{
class ChartParserCallback;
class DottedRuleColl;
class WordsRange;
//! Implementation of ChartRuleLookupManager for in-memory rule tables.
@ -46,31 +41,47 @@ class ChartRuleLookupManagerMemoryPerSentence : public ChartRuleLookupManagerCYK
{
public:
ChartRuleLookupManagerMemoryPerSentence(const ChartParser &parser,
const ChartCellCollectionBase &cellColl,
const PhraseDictionaryFuzzyMatch &ruleTable);
const ChartCellCollectionBase &cellColl,
const PhraseDictionaryFuzzyMatch &ruleTable);
~ChartRuleLookupManagerMemoryPerSentence();
~ChartRuleLookupManagerMemoryPerSentence() {};
virtual void GetChartRuleCollection(
const WordsRange &range,
size_t lastPos, // last position to consider if using lookahead
ChartParserCallback &outColl);
private:
void ExtendPartialRuleApplication(
const DottedRuleInMemory &prevDottedRule,
size_t startPos,
size_t endPos,
size_t stackInd,
DottedRuleColl &dottedRuleColl);
std::vector<DottedRuleColl*> m_dottedRuleColls;
void GetTerminalExtension(
const PhraseDictionaryNodeMemory *node,
size_t pos);
void GetNonTerminalExtension(
const PhraseDictionaryNodeMemory *node,
size_t startPos,
size_t endPos);
void AddAndExtend(
const PhraseDictionaryNodeMemory *node,
size_t endPos,
const ChartCellLabel *cellLabel);
const PhraseDictionaryFuzzyMatch &m_ruleTable;
#ifdef USE_BOOST_POOL
// Use an object pool to allocate the dotted rules for this sentence. We
// allocate a lot of them and this has been seen to significantly improve
// performance, especially for multithreaded decoding.
boost::object_pool<DottedRuleInMemory> m_dottedRulePool;
#endif
// permissible soft nonterminal matches (target side)
bool m_isSoftMatching;
const std::vector<std::vector<Word> >& m_softMatchingMap;
// temporary storage of completed rules (one collection per end position; all rules collected consecutively start from the same position)
std::vector<CompletedRuleCollection> m_completedRules;
size_t m_lastPos;
size_t m_unaryPos;
StackVec m_stackVec;
ChartParserCallback* m_outColl;
};
} // namespace Moses

View File

@ -78,6 +78,7 @@ ChartRuleLookupManagerOnDisk::~ChartRuleLookupManagerOnDisk()
void ChartRuleLookupManagerOnDisk::GetChartRuleCollection(
const WordsRange &range,
size_t lastPos,
ChartParserCallback &outColl)
{
const StaticData &staticData = StaticData::Instance();
@ -168,7 +169,10 @@ void ChartRuleLookupManagerOnDisk::GetChartRuleCollection(
// go through each TARGET lhs
ChartCellLabelSet::const_iterator iterChartNonTerm;
for (iterChartNonTerm = chartNonTermSet.begin(); iterChartNonTerm != chartNonTermSet.end(); ++iterChartNonTerm) {
const ChartCellLabel &cellLabel = iterChartNonTerm->second;
if (*iterChartNonTerm == NULL) {
continue;
}
const ChartCellLabel &cellLabel = **iterChartNonTerm;
//cerr << sourceLHS << " " << defaultSourceNonTerm << " " << chartNonTerm << " " << defaultTargetNonTerm << endl;

View File

@ -47,6 +47,7 @@ public:
~ChartRuleLookupManagerOnDisk();
virtual void GetChartRuleCollection(const WordsRange &range,
size_t last,
ChartParserCallback &outColl);
private:

View File

@ -53,6 +53,7 @@ ChartRuleLookupManagerSkeleton::~ChartRuleLookupManagerSkeleton()
void ChartRuleLookupManagerSkeleton::GetChartRuleCollection(
const WordsRange &range,
size_t last,
ChartParserCallback &outColl)
{
//m_tpColl.push_back(TargetPhraseCollection());

View File

@ -42,6 +42,7 @@ public:
virtual void GetChartRuleCollection(
const WordsRange &range,
size_t last,
ChartParserCallback &outColl);
private:

View File

@ -0,0 +1,75 @@
/***********************************************************************
Moses - factored phrase-based language decoder
Copyright (C) 2014 University of Edinburgh
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
#include <iostream>
#include "CompletedRuleCollection.h"
#include "moses/StaticData.h"
using namespace std;
namespace Moses
{
// Start empty: the score threshold begins at +infinity and the rule limit is
// read once from the global StaticData instance.
CompletedRuleCollection::CompletedRuleCollection()
  : m_scoreThreshold(numeric_limits<float>::infinity())
  , m_ruleLimit(StaticData::Instance().GetRuleLimit())
{
}
// Store a completed rule, pruning the collection when it grows too large.
// Copies some functionality (pruning) from ChartTranslationOptionList::Add.
//
// tpc      - target phrases of the rule; empty collections are ignored
// stackVec - backpointers of the rule's non-terminals
// outColl  - used only to estimate the best score of (tpc, stackVec)
//
// The collection may temporarily hold up to 2 * m_ruleLimit rules; when that
// size is reached it is cut back to the m_ruleLimit best-scoring ones.
void CompletedRuleCollection::Add(const TargetPhraseCollection &tpc,
                                  const StackVec &stackVec,
                                  const ChartParserCallback &outColl)
{
  if (tpc.IsEmpty()) {
    return;
  }

  const float score = outColl.CalcEstimateOfBestScore(tpc, stackVec);

  // With a rule limit of 0 the threshold is never lowered from +infinity, so
  // the unguarded check below would reject every rule after the first one.
  // A limit of 0 is therefore treated as "no limit".
  // NOTE(review): confirm this matches how rule-limit 0 is interpreted
  // elsewhere in the decoder.
  const bool isLimited = (m_ruleLimit != 0);

  // If the rule limit has already been reached then don't add the option
  // unless it is better than at least one existing option.
  if (isLimited && m_collection.size() > m_ruleLimit && score < m_scoreThreshold) {
    return;
  }

  CompletedRule *completedRule = new CompletedRule(tpc, stackVec, score);
  m_collection.push_back(completedRule);

  // If the rule limit hasn't been exceeded then update the threshold
  // (it tracks the lowest score among the first m_ruleLimit rules).
  if (m_collection.size() <= m_ruleLimit) {
    m_scoreThreshold = (score < m_scoreThreshold) ? score : m_scoreThreshold;
  }

  // Prune if bursting: keep the m_ruleLimit best rules, delete the rest.
  if (isLimited && m_collection.size() == m_ruleLimit * 2) {
    NTH_ELEMENT4(m_collection.begin(),
                 m_collection.begin() + m_ruleLimit - 1,
                 m_collection.end(),
                 CompletedRuleOrdered());
    // the worst rule that survives defines the new threshold
    m_scoreThreshold = m_collection[m_ruleLimit-1]->GetScoreEstimate();
    for (size_t i = 0 + m_ruleLimit; i < m_collection.size(); i++) {
      delete m_collection[i];
    }
    m_collection.resize(m_ruleLimit);
  }
}
}

View File

@ -0,0 +1,117 @@
/***********************************************************************
Moses - factored phrase-based language decoder
Copyright (C) 2014 University of Edinburgh
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
#pragma once
#ifndef moses_CompletedRuleCollectionS_h
#define moses_CompletedRuleCollectionS_h
#include <limits>
#include <vector>
#include "moses/StackVec.h"
#include "moses/TargetPhraseCollection.h"
#include "moses/ChartTranslationOptions.h"
#include "moses/ChartCellLabel.h"
#include "moses/ChartParserCallback.h"
namespace Moses
{
// temporary storage for a completed rule (because we use lookahead to find rules before ChartManager wants us to)
//
// Holds its own copy of the stack vector, but only a reference to the target
// phrase collection: the referenced collection must outlive this object.
struct CompletedRule
{
public:
  // tpc      - target phrases of the rule (stored by reference)
  // stackVec - backpointers of the rule's non-terminals (copied)
  // score    - estimated best score of the rule
  CompletedRule(const TargetPhraseCollection &tpc,
                const StackVec &stackVec,
                const float score)
    : m_stackVec(stackVec)
    , m_tpc(tpc)
    , m_score(score) {}

  const TargetPhraseCollection & GetTPC() const {
    return m_tpc;
  }

  const StackVec & GetStackVector() const {
    return m_stackVec;
  }

  // returned by value; a top-level const on the return type would be ignored
  float GetScoreEstimate() const {
    return m_score;
  }

private:
  const StackVec m_stackVec;
  const TargetPhraseCollection &m_tpc;
  const float m_score;
};
// Ordering predicate for CompletedRule pointers: true when the first rule has
// a strictly higher score estimate than the second (best rules sort first).
class CompletedRuleOrdered
{
public:
  bool operator()(const CompletedRule* itemA, const CompletedRule* itemB) const {
    const float scoreA = itemA->GetScoreEstimate();
    const float scoreB = itemB->GetScoreEstimate();
    return scoreB < scoreA;
  }
};
// Collection of heap-allocated CompletedRule objects, pruned by Add()
// according to m_ruleLimit and m_scoreThreshold.
struct CompletedRuleCollection
{
public:
  CompletedRuleCollection();

  // Shallow copy: the copy shares the CompletedRule pointers with `old`, so
  // Clear() must be called on only one of the two.
  CompletedRuleCollection(const CompletedRuleCollection &old)
    : m_collection(old.m_collection)
    , m_scoreThreshold(old.m_scoreThreshold)
    , m_ruleLimit(old.m_ruleLimit) {}

  // Shallow assignment, see the copy constructor. Rules previously held by
  // this object are not released here.
  CompletedRuleCollection & operator=(const CompletedRuleCollection &old) {
    m_collection = old.m_collection;
    m_scoreThreshold = old.m_scoreThreshold;
    m_ruleLimit = old.m_ruleLimit;
    return *this;
  }

  std::vector<CompletedRule*>::const_iterator begin() const {
    return m_collection.begin();
  }

  std::vector<CompletedRule*>::const_iterator end() const {
    return m_collection.end();
  }

  // Release the stored rules and return to the freshly constructed state.
  // The threshold is reset as well; otherwise a reused collection would
  // prune against scores of rules that are no longer stored.
  void Clear() {
    RemoveAllInColl(m_collection);
    m_scoreThreshold = std::numeric_limits<float>::infinity();
  }

  // Store a rule (tpc, stackVec), scored via outColl; defined in the .cpp.
  void Add(const TargetPhraseCollection &tpc,
           const StackVec &stackVec,
           const ChartParserCallback &outColl);

private:
  std::vector<CompletedRule*> m_collection;  // stored rules
  float m_scoreThreshold;                    // score a rule must reach once the limit is exceeded
  size_t m_ruleLimit;                        // number of rules kept when pruning
};
} // namespace Moses
#endif

View File

@ -105,8 +105,11 @@ PhraseDictionaryNodeMemory &PhraseDictionaryMemory::GetOrCreateNode(const Phrase
size_t targetNonTermInd = iterAlign->second;
++iterAlign;
const Word &targetNonTerm = target.GetWord(targetNonTermInd);
#if defined(UNLABELLED_SOURCE)
currNode = currNode->GetOrCreateNonTerminalChild(targetNonTerm);
#else
currNode = currNode->GetOrCreateChild(sourceNonTerm, targetNonTerm);
#endif
} else {
currNode = currNode->GetOrCreateChild(word);
}
@ -181,8 +184,13 @@ ostream& operator<<(ostream& out, const PhraseDictionaryMemory& phraseDict)
const PhraseDictionaryNodeMemory &coll = phraseDict.m_collection;
for (NonTermMap::const_iterator p = coll.m_nonTermMap.begin(); p != coll.m_nonTermMap.end(); ++p) {
#if defined(UNLABELLED_SOURCE)
const Word &targetNonTerm = p->first;
out << targetNonTerm;
#else
const Word &sourceNonTerm = p->first.first;
out << sourceNonTerm;
#endif
}
for (TermMap::const_iterator p = coll.m_sourceTermMap.begin(); p != coll.m_sourceTermMap.end(); ++p) {
const Word &sourceTerm = p->first;

View File

@ -61,6 +61,15 @@ PhraseDictionaryNodeMemory *PhraseDictionaryNodeMemory::GetOrCreateChild(const W
return &m_sourceTermMap[sourceTerm];
}
#if defined(UNLABELLED_SOURCE)
// Returns the child node keyed by the given target non-terminal, creating it
// when absent. UTIL_THROW_IF2 guards against a word that is not flagged as a
// non-terminal.
PhraseDictionaryNodeMemory *PhraseDictionaryNodeMemory::GetOrCreateNonTerminalChild(const Word &targetNonTerm)
{
  UTIL_THROW_IF2(!targetNonTerm.IsNonTerminal(),
                 "Not a non-terminal: " << targetNonTerm);

  // operator[] default-constructs the mapped node if the key is missing.
  PhraseDictionaryNodeMemory &childNode = m_nonTermMap[targetNonTerm];
  return &childNode;
}
#else
PhraseDictionaryNodeMemory *PhraseDictionaryNodeMemory::GetOrCreateChild(const Word &sourceNonTerm, const Word &targetNonTerm)
{
UTIL_THROW_IF2(!sourceNonTerm.IsNonTerminal(),
@ -71,6 +80,7 @@ PhraseDictionaryNodeMemory *PhraseDictionaryNodeMemory::GetOrCreateChild(const W
NonTerminalMapKey key(sourceNonTerm, targetNonTerm);
return &m_nonTermMap[NonTerminalMapKey(sourceNonTerm, targetNonTerm)];
}
#endif
const PhraseDictionaryNodeMemory *PhraseDictionaryNodeMemory::GetChild(const Word &sourceTerm) const
{
@ -81,6 +91,16 @@ const PhraseDictionaryNodeMemory *PhraseDictionaryNodeMemory::GetChild(const Wor
return (p == m_sourceTermMap.end()) ? NULL : &p->second;
}
#if defined(UNLABELLED_SOURCE)
// Read-only lookup of the child node keyed by the given target non-terminal.
// Returns NULL when no such child exists; never inserts. UTIL_THROW_IF2
// guards against a word that is not flagged as a non-terminal.
const PhraseDictionaryNodeMemory *PhraseDictionaryNodeMemory::GetNonTerminalChild(const Word &targetNonTerm) const
{
  UTIL_THROW_IF2(!targetNonTerm.IsNonTerminal(),
                 "Not a non-terminal: " << targetNonTerm);

  const NonTerminalMap::const_iterator match = m_nonTermMap.find(targetNonTerm);
  if (match == m_nonTermMap.end()) {
    return NULL;
  }
  return &match->second;
}
#else
const PhraseDictionaryNodeMemory *PhraseDictionaryNodeMemory::GetChild(const Word &sourceNonTerm, const Word &targetNonTerm) const
{
UTIL_THROW_IF2(!sourceNonTerm.IsNonTerminal(),
@ -92,6 +112,7 @@ const PhraseDictionaryNodeMemory *PhraseDictionaryNodeMemory::GetChild(const Wor
NonTerminalMap::const_iterator p = m_nonTermMap.find(key);
return (p == m_nonTermMap.end()) ? NULL : &p->second;
}
#endif
void PhraseDictionaryNodeMemory::Remove()
{

View File

@ -29,6 +29,7 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
#include "moses/Word.h"
#include "moses/TargetPhraseCollection.h"
#include "moses/Terminal.h"
#include "moses/NonTerminal.h"
#include <boost/functional/hash.hpp>
#include <boost/unordered_map.hpp>
@ -102,14 +103,25 @@ public:
TerminalHasher,
TerminalEqualityPred> TerminalMap;
#if defined(UNLABELLED_SOURCE)
typedef boost::unordered_map<Word,
PhraseDictionaryNodeMemory,
NonTerminalHasher,
NonTerminalEqualityPred> NonTerminalMap;
#else
typedef boost::unordered_map<NonTerminalMapKey,
PhraseDictionaryNodeMemory,
NonTerminalMapKeyHasher,
NonTerminalMapKeyEqualityPred> NonTerminalMap;
#endif
#else
typedef std::map<Word, PhraseDictionaryNodeMemory> TerminalMap;
#if defined(UNLABELLED_SOURCE)
typedef std::map<Word, PhraseDictionaryNodeMemory> NonTerminalMap;
#else
typedef std::map<NonTerminalMapKey, PhraseDictionaryNodeMemory> NonTerminalMap;
#endif
#endif
private:
friend std::ostream& operator<<(std::ostream&, const PhraseDictionaryMemory&);
@ -131,9 +143,14 @@ public:
void Prune(size_t tableLimit);
void Sort(size_t tableLimit);
PhraseDictionaryNodeMemory *GetOrCreateChild(const Word &sourceTerm);
PhraseDictionaryNodeMemory *GetOrCreateChild(const Word &sourceNonTerm, const Word &targetNonTerm);
const PhraseDictionaryNodeMemory *GetChild(const Word &sourceTerm) const;
#if defined(UNLABELLED_SOURCE)
PhraseDictionaryNodeMemory *GetOrCreateNonTerminalChild(const Word &targetNonTerm);
const PhraseDictionaryNodeMemory *GetNonTerminalChild(const Word &targetNonTerm) const;
#else
PhraseDictionaryNodeMemory *GetOrCreateChild(const Word &sourceNonTerm, const Word &targetNonTerm);
const PhraseDictionaryNodeMemory *GetChild(const Word &sourceNonTerm, const Word &targetNonTerm) const;
#endif
const TargetPhraseCollection &GetTargetPhraseCollection() const {
return m_targetPhraseCollection;

View File

@ -340,7 +340,11 @@ PhraseDictionaryNodeMemory &PhraseDictionaryFuzzyMatch::GetOrCreateNode(PhraseDi
++iterAlign;
const Word &targetNonTerm = target.GetWord(targetNonTermInd);
#if defined(UNLABELLED_SOURCE)
currNode = currNode->GetOrCreateNonTerminalChild(targetNonTerm);
#else
currNode = currNode->GetOrCreateChild(sourceNonTerm, targetNonTerm);
#endif
} else {
currNode = currNode->GetOrCreateChild(word);
}

View File

@ -40,6 +40,7 @@ namespace Moses
void Scope3Parser::GetChartRuleCollection(
const WordsRange &range,
size_t last,
ChartParserCallback &outColl)
{
const size_t start = range.GetStartPos();

View File

@ -59,6 +59,7 @@ public:
void GetChartRuleCollection(
const WordsRange &range,
size_t last,
ChartParserCallback &outColl);
private:

View File

@ -313,7 +313,7 @@ void TreeInput::AddChartLabel(size_t startPos, size_t endPos, const string &labe
, const std::vector<FactorType>& factorOrder)
{
Word word(true);
const Factor *factor = FactorCollection::Instance().AddFactor(Input, factorOrder[0], label); // TODO - no factors
const Factor *factor = FactorCollection::Instance().AddFactor(Input, factorOrder[0], label, true); // TODO - no factors
word.SetFactor(0, factor);
AddChartLabel(startPos, endPos, word, factorOrder);

View File

@ -107,7 +107,7 @@ void Word::CreateFromString(FactorDirection direction
util::TokenIter<util::MultiCharacter> fit(str, StaticData::Instance().GetFactorDelimiter());
for (size_t ind = 0; ind < factorOrder.size() && fit; ++ind, ++fit) {
m_factorArray[factorOrder[ind]] = factorCollection.AddFactor(*fit);
m_factorArray[factorOrder[ind]] = factorCollection.AddFactor(*fit, isNonTerminal);
}
UTIL_THROW_IF(fit, StrayFactorException, "You have configured " << factorOrder.size() << " factors but the word " << str << " contains factor delimiter " << StaticData::Instance().GetFactorDelimiter() << " too many times.");
@ -119,16 +119,18 @@ void Word::CreateUnknownWord(const Word &sourceWord)
{
FactorCollection &factorCollection = FactorCollection::Instance();
m_isNonTerminal = sourceWord.IsNonTerminal();
for (unsigned int currFactor = 0 ; currFactor < MAX_NUM_FACTORS ; currFactor++) {
FactorType factorType = static_cast<FactorType>(currFactor);
const Factor *sourceFactor = sourceWord[currFactor];
if (sourceFactor == NULL)
SetFactor(factorType, factorCollection.AddFactor(Output, factorType, UNKNOWN_FACTOR));
SetFactor(factorType, factorCollection.AddFactor(Output, factorType, UNKNOWN_FACTOR, m_isNonTerminal));
else
SetFactor(factorType, factorCollection.AddFactor(Output, factorType, sourceFactor->GetString()));
SetFactor(factorType, factorCollection.AddFactor(Output, factorType, sourceFactor->GetString(), m_isNonTerminal));
}
m_isNonTerminal = sourceWord.IsNonTerminal();
m_isOOV = true;
}

View File

@ -363,7 +363,7 @@ bool ProcessAndStripXMLTags(string &line, vector<XmlOption*> &res, ReorderingCon
// lhs
const UnknownLHSList &lhsList = staticData.GetUnknownLHS();
if (!lhsList.empty()) {
const Factor *factor = FactorCollection::Instance().AddFactor(lhsList[0].first);
const Factor *factor = FactorCollection::Instance().AddFactor(lhsList[0].first, true);
Word *targetLHS = new Word(true);
targetLHS->SetFactor(0, factor); // TODO - other factors too?
targetPhrase.SetTargetLHS(targetLHS);