soft matching of target-side nonterminals

This commit is contained in:
Rico Sennrich 2014-01-16 18:34:33 +00:00
parent 4e75911331
commit ed25bb2b99
5 changed files with 262 additions and 0 deletions

View File

@ -34,6 +34,7 @@
#include "moses/FF/ExternalFeature.h"
#include "moses/FF/ConstrainedDecoding.h"
#include "moses/FF/CoveredReferenceFeature.h"
#include "moses/FF/SoftMatchingFeature.h"
#include "moses/FF/SkeletonStatelessFF.h"
#include "moses/FF/SkeletonStatefulFF.h"
@ -171,6 +172,7 @@ FeatureRegistry::FeatureRegistry()
MOSES_FNAME(ConstrainedDecoding);
MOSES_FNAME(CoveredReferenceFeature);
MOSES_FNAME(ExternalFeature);
MOSES_FNAME(SoftMatchingFeature);
MOSES_FNAME(SkeletonStatelessFF);
MOSES_FNAME(SkeletonStatefulFF);

View File

@ -0,0 +1,108 @@
#include "SoftMatchingFeature.h"
#include "moses/AlignmentInfo.h"
#include "moses/TargetPhrase.h"
#include "moses/ChartHypothesis.h"
#include "moses/StaticData.h"
#include "moses/InputFileStream.h"
namespace Moses
{
/**
 * Builds the feature from its configuration line.
 *
 * The line is forwarded to the StatelessFeatureFunction base (together with
 * 0 as its first argument). Afterwards every entry of m_args whose first
 * element is "path" has its second element handed to Load().
 */
SoftMatchingFeature::SoftMatchingFeature(const std::string &line)
  : StatelessFeatureFunction(0, line)
{
  std::cerr << "Initializing SoftMatchingFeature.." << std::endl;

  for (size_t argIdx = 0; argIdx < m_args.size(); ++argIdx) {
    const std::vector<std::string> &keyValue = m_args[argIdx];

    // only the "path" argument is handled here
    if (!(keyValue[0] == "path")) {
      continue;
    }

    // value of the "path" argument: file listing the permitted soft matches
    Load(keyValue[1]);
  }
}
bool SoftMatchingFeature::Load(const std::string& filePath)
{
StaticData &staticData = StaticData::InstanceNonConst();
InputFileStream inStream(filePath);
std::string line;
while(getline(inStream, line)) {
std::vector<std::string> tokens = Tokenize(line);
UTIL_THROW_IF2(tokens.size() != 2, "Error: wrong format of SoftMatching file: must have two nonterminals per line");
// no soft matching necessary if LHS and RHS are the same
if (tokens[0] == tokens[1]) {
continue;
}
Word LHS, RHS;
LHS.CreateFromString(Output, staticData.GetOutputFactorOrder(), tokens[0], true);
RHS.CreateFromString(Output, staticData.GetOutputFactorOrder(), tokens[1], true);
m_soft_matches[LHS].insert(RHS);
m_soft_matches_reverse[RHS].insert(LHS);
}
staticData.Set_Soft_Matches(Get_Soft_Matches());
staticData.Set_Soft_Matches_Reverse(Get_Soft_Matches_Reverse());
return true;
}
/**
 * Scores the rule application represented by a chart hypothesis.
 *
 * For every nonterminal in the hypothesis' current target phrase, the sparse
 * feature named "<LHS of the sub-hypothesis filling that slot>-><nonterminal
 * label in the rule>" is incremented by 1.
 *
 * @param hypo         hypothesis whose current target phrase is inspected
 * @param accumulator  score collection that receives the sparse feature counts
 */
void SoftMatchingFeature::EvaluateChart(const ChartHypothesis& hypo,
                                        ScoreComponentCollection* accumulator) const
{
  const TargetPhrase& target = hypo.GetCurrTargetPhrase();

  const AlignmentInfo::NonTermIndexMap &nonTermIndexMap = target.GetAlignNonTerm().GetNonTermIndexMap();

  // loop over the rule that is being applied
  for (size_t pos = 0; pos < target.GetSize(); ++pos) {
    const Word& word = target.GetWord(pos);

    // terminals do not trigger the feature
    if (!word.IsNonTerminal()) {
      continue;
    }

    // for non-terminals, trigger the feature mapping the LHS of the previous hypo to the RHS of this hypo
    const size_t nonTermInd = nonTermIndexMap[pos];
    const ChartHypothesis* prevHypo = hypo.GetPrevHypo(nonTermInd);
    const Word& prevLHS = prevHypo->GetTargetLHS();

    // Bind to the cached string instead of copying it: GetFeatureName() returns
    // a reference into its name cache, whose entries are never erased in this
    // file, so the reference outlives this call. Copying here would pay for a
    // string allocation on every nonterminal and defeat the cache.
    const std::string &name = GetFeatureName(prevLHS, word);
    accumulator->PlusEquals(this, name, 1);
  }
}
//caching feature names because string conversion is slow
const std::string& SoftMatchingFeature::GetFeatureName(const Word& LHS, const Word& RHS) const
{
const NonTerminalMapKey key(LHS, RHS);
{
#ifdef WITH_THREADS //try read-only lock
boost::shared_lock<boost::shared_mutex> read_lock(m_accessLock);
#endif // WITH_THREADS
NonTerminalSoftMatchingMap::const_iterator i = m_soft_matching_cache.find(key);
if (i != m_soft_matching_cache.end()) return i->second;
}
#ifdef WITH_THREADS //need to update cache; write lock
boost::unique_lock<boost::shared_mutex> lock(m_accessLock);
#endif // WITH_THREADS
const std::vector<FactorType> &outputFactorOrder = StaticData::Instance().GetOutputFactorOrder();
std::string LHS_string = LHS.GetString(outputFactorOrder, false);
std::string RHS_string = RHS.GetString(outputFactorOrder, false);
const std::string name = LHS_string + "->" + RHS_string;
m_soft_matching_cache[key] = name;
return m_soft_matching_cache.find(key)->second;
}
}

View File

@ -0,0 +1,77 @@
#pragma once
#include <stdexcept>
#include "moses/Util.h"
#include "moses/Word.h"
#include "StatelessFeatureFunction.h"
#include "moses/TranslationModel/PhraseDictionaryNodeMemory.h"
#ifdef WITH_THREADS
#include <boost/thread/shared_mutex.hpp>
#endif
namespace Moses
{
/**
 * Feature function for soft matching of target-side nonterminals.
 *
 * Holds two tables of permitted nonterminal substitutions (loaded from a file
 * by Load()) and, during chart decoding, fires one sparse feature per
 * nonterminal of the applied rule, named after the pair of labels involved
 * (see GetFeatureName()).
 */
class SoftMatchingFeature : public StatelessFeatureFunction
{
public:
// Constructs the feature from its configuration line; every "path" argument
// found in the line names a file that is passed to Load().
SoftMatchingFeature(const std::string &line);
// Always reports the feature as usable; the factor mask is not inspected.
bool IsUseable(const FactorMask &mask) const {
return true;
}
// For each nonterminal in the current target phrase of hypo, adds 1 to the
// sparse feature whose name is GetFeatureName(target LHS of the previous
// hypothesis for that nonterminal, the nonterminal itself).
virtual void EvaluateChart(const ChartHypothesis& hypo,
ScoreComponentCollection* accumulator) const;
// Empty body: this overload contributes no score.
void Evaluate(const Phrase &source
, const TargetPhrase &targetPhrase
, ScoreComponentCollection &scoreBreakdown
, ScoreComponentCollection &estimatedFutureScore) const {};
// Empty body: this overload contributes no score.
void Evaluate(const InputType &input
, const InputPath &inputPath
, const TargetPhrase &targetPhrase
, ScoreComponentCollection &scoreBreakdown
, ScoreComponentCollection *estimatedFutureScore = NULL) const {};
// Empty body: this overload contributes no score.
void Evaluate(const Hypothesis& hypo,
ScoreComponentCollection* accumulator) const {};
// Reads the soft-matching file at filePath. Each line must contain exactly
// two nonterminal labels (an error is raised otherwise); pairs with two
// identical labels are skipped. The remaining pairs are stored in both
// tables below, which are then handed to StaticData. Returns true.
bool Load(const std::string &filePath);
// Mutable access to the forward table (first label of a line -> second labels).
std::map<Word, std::set<Word> >& Get_Soft_Matches() {
return m_soft_matches;
}
// Mutable access to the reverse table (second label of a line -> first labels).
std::map<Word, std::set<Word> >& Get_Soft_Matches_Reverse() {
return m_soft_matches_reverse;
}
// Returns the cached feature name "<LHS>-><RHS>" for the given pair,
// building and caching it on the first request.
const std::string& GetFeatureName(const Word& LHS, const Word& RHS) const;
private:
std::map<Word, std::set<Word> > m_soft_matches; // map LHS of old rule to RHS of new rule
std::map<Word, std::set<Word> > m_soft_matches_reverse; // map RHS of new rule to LHS of old rule
// Cache key: the (LHS, RHS) pair of nonterminals.
typedef std::pair<Word, Word> NonTerminalMapKey;
// Cache type: a hash map when BOOST_VERSION is defined and at least 104200,
// otherwise an ordered map. NonTerminalMapKeyHasher and
// NonTerminalMapKeyEqualityPred are not declared in this header —
// presumably they come from PhraseDictionaryNodeMemory.h; TODO confirm.
#if defined(BOOST_VERSION) && (BOOST_VERSION >= 104200)
typedef boost::unordered_map<NonTerminalMapKey,
std::string,
NonTerminalMapKeyHasher,
NonTerminalMapKeyEqualityPred> NonTerminalSoftMatchingMap;
#else
typedef std::map<NonTerminalMapKey, std::string> NonTerminalSoftMatchingMap;
#endif
// Feature-name cache; mutable because the const GetFeatureName() fills it.
mutable NonTerminalSoftMatchingMap m_soft_matching_cache;
#ifdef WITH_THREADS
//reader-writer lock
// Taken shared for cache lookups and exclusive for cache insertions.
mutable boost::shared_mutex m_accessLock;
#endif
};
}

View File

@ -214,6 +214,9 @@ protected:
bool m_continuePartialTranslation;
std::string m_binPath;
// soft NT lookup for chart models
std::map<Word, std::set<Word> > m_soft_matches_map;
std::map<Word, std::set<Word> > m_soft_matches_map_reverse;
public:
@ -731,6 +734,22 @@ public:
return m_useLegacyPT;
}
/**
 * Stores a copy of the given soft-match table.
 *
 * The argument is only read, so it is taken by const reference; callers that
 * pass a non-const map keep working unchanged.
 *
 * @param soft_matches_map  table to copy into m_soft_matches_map
 */
void Set_Soft_Matches(const std::map<Word, std::set<Word> >& soft_matches_map) {
  m_soft_matches_map = soft_matches_map;
}
const std::map<Word, std::set<Word> >* Get_Soft_Matches() const {
return &m_soft_matches_map;
}
/**
 * Stores a copy of the given reverse soft-match table.
 *
 * The argument is only read, so it is taken by const reference; callers that
 * pass a non-const map keep working unchanged.
 *
 * @param soft_matches_map  table to copy into m_soft_matches_map_reverse
 */
void Set_Soft_Matches_Reverse(const std::map<Word, std::set<Word> >& soft_matches_map) {
  m_soft_matches_map_reverse = soft_matches_map;
}
const std::map<Word, std::set<Word> >* Get_Soft_Matches_Reverse() const {
return &m_soft_matches_map_reverse;
}
};
}

View File

@ -226,6 +226,11 @@ void ChartRuleLookupManagerMemory::ExtendPartialRuleApplication(
const PhraseDictionaryNodeMemory::NonTerminalMap & nonTermMap =
node.GetNonTerminalMap();
// permissible soft nonterminal matches (target side)
const StaticData &staticData = StaticData::Instance();
const std::map<Word, std::set<Word> >* m_soft_matches_map = staticData.Get_Soft_Matches();
const std::map<Word, std::set<Word> >* m_soft_matches_map_reverse = staticData.Get_Soft_Matches_Reverse();
const size_t numChildren = nonTermMap.size();
if (numChildren == 0) {
return;
@ -255,6 +260,34 @@ void ChartRuleLookupManagerMemory::ExtendPartialRuleApplication(
for (; q != tEnd; ++q) {
const ChartCellLabel &cellLabel = q->second;
//soft matching of NTs
const Word& targetNonTerm = cellLabel.GetLabel();
if (m_soft_matches_map->find(targetNonTerm) != m_soft_matches_map->end()) {
const std::set<Word>& softMatches = m_soft_matches_map->find(targetNonTerm)->second;
for (std::set<Word>::const_iterator softMatch = softMatches.begin(); softMatch != softMatches.end(); ++softMatch) {
// try to match both source and target non-terminal
const PhraseDictionaryNodeMemory * child =
node.GetChild(sourceNonTerm, *softMatch);
// nothing found? then we are done
if (child == NULL) {
continue;
}
// create new rule
#ifdef USE_BOOST_POOL
DottedRuleInMemory *rule = m_dottedRulePool.malloc();
new (rule) DottedRuleInMemory(*child, cellLabel, prevDottedRule);
#else
DottedRuleInMemory *rule = new DottedRuleInMemory(*child, cellLabel,
prevDottedRule);
#endif
dottedRuleColl.Add(stackInd, rule);
}
} // end of soft matching
// try to match both source and target non-terminal
const PhraseDictionaryNodeMemory * child =
node.GetChild(sourceNonTerm, cellLabel.GetLabel());
@ -288,6 +321,29 @@ void ChartRuleLookupManagerMemory::ExtendPartialRuleApplication(
continue;
}
const Word &targetNonTerm = key.second;
//soft matching of NTs
if (m_soft_matches_map_reverse->find(targetNonTerm) != m_soft_matches_map_reverse->end()) {
const std::set<Word>& softMatches = m_soft_matches_map_reverse->find(targetNonTerm)->second;
for (std::set<Word>::const_iterator softMatch = softMatches.begin(); softMatch != softMatches.end(); ++softMatch) {
const ChartCellLabel *cellLabel = targetNonTerms.Find(*softMatch);
if (!cellLabel) {
continue;
}
// create new rule
const PhraseDictionaryNodeMemory &child = p->second;
#ifdef USE_BOOST_POOL
DottedRuleInMemory *rule = m_dottedRulePool.malloc();
new (rule) DottedRuleInMemory(child, *cellLabel, prevDottedRule);
#else
DottedRuleInMemory *rule = new DottedRuleInMemory(child, *cellLabel,
prevDottedRule);
#endif
dottedRuleColl.Add(stackInd, rule);
}
} // end of soft matches lookup
const ChartCellLabel *cellLabel = targetNonTerms.Find(targetNonTerm);
if (!cellLabel) {
continue;