mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-12-25 12:52:29 +03:00
soft matching of target-side nonterminals
This commit is contained in:
parent
4e75911331
commit
ed25bb2b99
@ -34,6 +34,7 @@
|
||||
#include "moses/FF/ExternalFeature.h"
|
||||
#include "moses/FF/ConstrainedDecoding.h"
|
||||
#include "moses/FF/CoveredReferenceFeature.h"
|
||||
#include "moses/FF/SoftMatchingFeature.h"
|
||||
|
||||
#include "moses/FF/SkeletonStatelessFF.h"
|
||||
#include "moses/FF/SkeletonStatefulFF.h"
|
||||
@ -171,6 +172,7 @@ FeatureRegistry::FeatureRegistry()
|
||||
MOSES_FNAME(ConstrainedDecoding);
|
||||
MOSES_FNAME(CoveredReferenceFeature);
|
||||
MOSES_FNAME(ExternalFeature);
|
||||
MOSES_FNAME(SoftMatchingFeature);
|
||||
|
||||
MOSES_FNAME(SkeletonStatelessFF);
|
||||
MOSES_FNAME(SkeletonStatefulFF);
|
||||
|
108
moses/FF/SoftMatchingFeature.cpp
Normal file
108
moses/FF/SoftMatchingFeature.cpp
Normal file
@ -0,0 +1,108 @@
|
||||
#include "SoftMatchingFeature.h"
|
||||
#include "moses/AlignmentInfo.h"
|
||||
#include "moses/TargetPhrase.h"
|
||||
#include "moses/ChartHypothesis.h"
|
||||
#include "moses/StaticData.h"
|
||||
#include "moses/InputFileStream.h"
|
||||
|
||||
namespace Moses
|
||||
{
|
||||
|
||||
// Constructs the feature from its configuration line.
//
// The base class is initialised with (0, line); afterwards every parsed
// argument in m_args is inspected and each "path" argument is handed to
// Load(). Arguments with any other key are ignored here.
// NOTE(review): assumes each m_args entry holds a key and a value
// (at least two elements) -- confirm against the base-class parser.
SoftMatchingFeature::SoftMatchingFeature(const std::string &line)
  : StatelessFeatureFunction(0, line)
{
  std::cerr << "Initializing SoftMatchingFeature.." << std::endl;

  const size_t numArgs = m_args.size();
  for (size_t argInd = 0; argInd != numArgs; ++argInd) {
    const std::vector<std::string> &keyValue = m_args[argInd];

    // only the "path" argument is handled by this feature
    if (keyValue[0] != "path") {
      continue;
    }

    Load(keyValue[1]);
  }
}
|
||||
|
||||
bool SoftMatchingFeature::Load(const std::string& filePath)
|
||||
{
|
||||
|
||||
StaticData &staticData = StaticData::InstanceNonConst();
|
||||
|
||||
InputFileStream inStream(filePath);
|
||||
std::string line;
|
||||
while(getline(inStream, line)) {
|
||||
std::vector<std::string> tokens = Tokenize(line);
|
||||
UTIL_THROW_IF2(tokens.size() != 2, "Error: wrong format of SoftMatching file: must have two nonterminals per line");
|
||||
|
||||
// no soft matching necessary if LHS and RHS are the same
|
||||
if (tokens[0] == tokens[1]) {
|
||||
continue;
|
||||
}
|
||||
|
||||
Word LHS, RHS;
|
||||
LHS.CreateFromString(Output, staticData.GetOutputFactorOrder(), tokens[0], true);
|
||||
RHS.CreateFromString(Output, staticData.GetOutputFactorOrder(), tokens[1], true);
|
||||
|
||||
m_soft_matches[LHS].insert(RHS);
|
||||
m_soft_matches_reverse[RHS].insert(LHS);
|
||||
}
|
||||
|
||||
staticData.Set_Soft_Matches(Get_Soft_Matches());
|
||||
staticData.Set_Soft_Matches_Reverse(Get_Soft_Matches_Reverse());
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void SoftMatchingFeature::EvaluateChart(const ChartHypothesis& hypo,
|
||||
ScoreComponentCollection* accumulator) const
|
||||
{
|
||||
|
||||
const TargetPhrase& target = hypo.GetCurrTargetPhrase();
|
||||
|
||||
const AlignmentInfo::NonTermIndexMap &nonTermIndexMap = target.GetAlignNonTerm().GetNonTermIndexMap();
|
||||
|
||||
// loop over the rule that is being applied
|
||||
for (size_t pos = 0; pos < target.GetSize(); ++pos) {
|
||||
const Word& word = target.GetWord(pos);
|
||||
|
||||
// for non-terminals, trigger the feature mapping the LHS of the previous hypo to the RHS of this hypo
|
||||
if (word.IsNonTerminal()) {
|
||||
size_t nonTermInd = nonTermIndexMap[pos];
|
||||
|
||||
const ChartHypothesis* prevHypo = hypo.GetPrevHypo(nonTermInd);
|
||||
const Word& prevLHS = prevHypo->GetTargetLHS();
|
||||
|
||||
const std::string name = GetFeatureName(prevLHS, word);
|
||||
accumulator->PlusEquals(this,name,1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//caching feature names because string conversion is slow
|
||||
const std::string& SoftMatchingFeature::GetFeatureName(const Word& LHS, const Word& RHS) const
|
||||
{
|
||||
|
||||
const NonTerminalMapKey key(LHS, RHS);
|
||||
{
|
||||
#ifdef WITH_THREADS //try read-only lock
|
||||
boost::shared_lock<boost::shared_mutex> read_lock(m_accessLock);
|
||||
#endif // WITH_THREADS
|
||||
NonTerminalSoftMatchingMap::const_iterator i = m_soft_matching_cache.find(key);
|
||||
if (i != m_soft_matching_cache.end()) return i->second;
|
||||
}
|
||||
#ifdef WITH_THREADS //need to update cache; write lock
|
||||
boost::unique_lock<boost::shared_mutex> lock(m_accessLock);
|
||||
#endif // WITH_THREADS
|
||||
const std::vector<FactorType> &outputFactorOrder = StaticData::Instance().GetOutputFactorOrder();
|
||||
|
||||
std::string LHS_string = LHS.GetString(outputFactorOrder, false);
|
||||
std::string RHS_string = RHS.GetString(outputFactorOrder, false);
|
||||
|
||||
const std::string name = LHS_string + "->" + RHS_string;
|
||||
|
||||
m_soft_matching_cache[key] = name;
|
||||
return m_soft_matching_cache.find(key)->second;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
77
moses/FF/SoftMatchingFeature.h
Normal file
77
moses/FF/SoftMatchingFeature.h
Normal file
@ -0,0 +1,77 @@
|
||||
#pragma once
|
||||
|
||||
#include <stdexcept>
|
||||
#include "moses/Util.h"
|
||||
#include "moses/Word.h"
|
||||
#include "StatelessFeatureFunction.h"
|
||||
#include "moses/TranslationModel/PhraseDictionaryNodeMemory.h"
|
||||
|
||||
#ifdef WITH_THREADS
|
||||
#include <boost/thread/shared_mutex.hpp>
|
||||
#endif
|
||||
|
||||
namespace Moses
|
||||
{
|
||||
|
||||
class SoftMatchingFeature : public StatelessFeatureFunction
|
||||
{
|
||||
public:
|
||||
SoftMatchingFeature(const std::string &line);
|
||||
|
||||
bool IsUseable(const FactorMask &mask) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
virtual void EvaluateChart(const ChartHypothesis& hypo,
|
||||
ScoreComponentCollection* accumulator) const;
|
||||
|
||||
void Evaluate(const Phrase &source
|
||||
, const TargetPhrase &targetPhrase
|
||||
, ScoreComponentCollection &scoreBreakdown
|
||||
, ScoreComponentCollection &estimatedFutureScore) const {};
|
||||
void Evaluate(const InputType &input
|
||||
, const InputPath &inputPath
|
||||
, const TargetPhrase &targetPhrase
|
||||
, ScoreComponentCollection &scoreBreakdown
|
||||
, ScoreComponentCollection *estimatedFutureScore = NULL) const {};
|
||||
void Evaluate(const Hypothesis& hypo,
|
||||
ScoreComponentCollection* accumulator) const {};
|
||||
|
||||
bool Load(const std::string &filePath);
|
||||
|
||||
std::map<Word, std::set<Word> >& Get_Soft_Matches() {
|
||||
return m_soft_matches;
|
||||
}
|
||||
|
||||
std::map<Word, std::set<Word> >& Get_Soft_Matches_Reverse() {
|
||||
return m_soft_matches_reverse;
|
||||
}
|
||||
|
||||
const std::string& GetFeatureName(const Word& LHS, const Word& RHS) const;
|
||||
|
||||
private:
|
||||
std::map<Word, std::set<Word> > m_soft_matches; // map LHS of old rule to RHS of new rle
|
||||
std::map<Word, std::set<Word> > m_soft_matches_reverse; // map RHS of new rule to LHS of old rule
|
||||
|
||||
typedef std::pair<Word, Word> NonTerminalMapKey;
|
||||
|
||||
#if defined(BOOST_VERSION) && (BOOST_VERSION >= 104200)
|
||||
typedef boost::unordered_map<NonTerminalMapKey,
|
||||
std::string,
|
||||
NonTerminalMapKeyHasher,
|
||||
NonTerminalMapKeyEqualityPred> NonTerminalSoftMatchingMap;
|
||||
#else
|
||||
typedef std::map<NonTerminalMapKey, std::string> NonTerminalSoftMatchingMap;
|
||||
#endif
|
||||
|
||||
mutable NonTerminalSoftMatchingMap m_soft_matching_cache;
|
||||
|
||||
#ifdef WITH_THREADS
|
||||
//reader-writer lock
|
||||
mutable boost::shared_mutex m_accessLock;
|
||||
#endif
|
||||
|
||||
};
|
||||
|
||||
}
|
||||
|
@ -214,6 +214,9 @@ protected:
|
||||
bool m_continuePartialTranslation;
|
||||
std::string m_binPath;
|
||||
|
||||
// soft NT lookup for chart models
|
||||
std::map<Word, std::set<Word> > m_soft_matches_map;
|
||||
std::map<Word, std::set<Word> > m_soft_matches_map_reverse;
|
||||
|
||||
public:
|
||||
|
||||
@ -731,6 +734,22 @@ public:
|
||||
return m_useLegacyPT;
|
||||
}
|
||||
|
||||
// Stores a copy of the forward soft-match table.
// The argument is only read, so it is taken by const reference; existing
// callers that pass a non-const map are unaffected.
void Set_Soft_Matches(const std::map<Word, std::set<Word> >& soft_matches_map) {
  m_soft_matches_map = soft_matches_map;
}
|
||||
|
||||
const std::map<Word, std::set<Word> >* Get_Soft_Matches() const {
|
||||
return &m_soft_matches_map;
|
||||
}
|
||||
|
||||
// Stores a copy of the reverse soft-match table.
// The argument is only read, so it is taken by const reference; existing
// callers that pass a non-const map are unaffected.
void Set_Soft_Matches_Reverse(const std::map<Word, std::set<Word> >& soft_matches_map) {
  m_soft_matches_map_reverse = soft_matches_map;
}
|
||||
|
||||
const std::map<Word, std::set<Word> >* Get_Soft_Matches_Reverse() const {
|
||||
return &m_soft_matches_map_reverse;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
}
|
||||
|
@ -226,6 +226,11 @@ void ChartRuleLookupManagerMemory::ExtendPartialRuleApplication(
|
||||
const PhraseDictionaryNodeMemory::NonTerminalMap & nonTermMap =
|
||||
node.GetNonTerminalMap();
|
||||
|
||||
// permissible soft nonterminal matches (target side)
|
||||
const StaticData &staticData = StaticData::Instance();
|
||||
const std::map<Word, std::set<Word> >* m_soft_matches_map = staticData.Get_Soft_Matches();
|
||||
const std::map<Word, std::set<Word> >* m_soft_matches_map_reverse = staticData.Get_Soft_Matches_Reverse();
|
||||
|
||||
const size_t numChildren = nonTermMap.size();
|
||||
if (numChildren == 0) {
|
||||
return;
|
||||
@ -255,6 +260,34 @@ void ChartRuleLookupManagerMemory::ExtendPartialRuleApplication(
|
||||
for (; q != tEnd; ++q) {
|
||||
const ChartCellLabel &cellLabel = q->second;
|
||||
|
||||
//soft matching of NTs
|
||||
const Word& targetNonTerm = cellLabel.GetLabel();
|
||||
if (m_soft_matches_map->find(targetNonTerm) != m_soft_matches_map->end()) {
|
||||
const std::set<Word>& softMatches = m_soft_matches_map->find(targetNonTerm)->second;
|
||||
|
||||
for (std::set<Word>::const_iterator softMatch = softMatches.begin(); softMatch != softMatches.end(); ++softMatch) {
|
||||
|
||||
// try to match both source and target non-terminal
|
||||
const PhraseDictionaryNodeMemory * child =
|
||||
node.GetChild(sourceNonTerm, *softMatch);
|
||||
|
||||
// nothing found? then we are done
|
||||
if (child == NULL) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// create new rule
|
||||
#ifdef USE_BOOST_POOL
|
||||
DottedRuleInMemory *rule = m_dottedRulePool.malloc();
|
||||
new (rule) DottedRuleInMemory(*child, cellLabel, prevDottedRule);
|
||||
#else
|
||||
DottedRuleInMemory *rule = new DottedRuleInMemory(*child, cellLabel,
|
||||
prevDottedRule);
|
||||
#endif
|
||||
dottedRuleColl.Add(stackInd, rule);
|
||||
}
|
||||
} // end of soft matching
|
||||
|
||||
// try to match both source and target non-terminal
|
||||
const PhraseDictionaryNodeMemory * child =
|
||||
node.GetChild(sourceNonTerm, cellLabel.GetLabel());
|
||||
@ -288,6 +321,29 @@ void ChartRuleLookupManagerMemory::ExtendPartialRuleApplication(
|
||||
continue;
|
||||
}
|
||||
const Word &targetNonTerm = key.second;
|
||||
|
||||
//soft matching of NTs
|
||||
if (m_soft_matches_map_reverse->find(targetNonTerm) != m_soft_matches_map_reverse->end()) {
|
||||
const std::set<Word>& softMatches = m_soft_matches_map_reverse->find(targetNonTerm)->second;
|
||||
for (std::set<Word>::const_iterator softMatch = softMatches.begin(); softMatch != softMatches.end(); ++softMatch) {
|
||||
const ChartCellLabel *cellLabel = targetNonTerms.Find(*softMatch);
|
||||
if (!cellLabel) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// create new rule
|
||||
const PhraseDictionaryNodeMemory &child = p->second;
|
||||
#ifdef USE_BOOST_POOL
|
||||
DottedRuleInMemory *rule = m_dottedRulePool.malloc();
|
||||
new (rule) DottedRuleInMemory(child, *cellLabel, prevDottedRule);
|
||||
#else
|
||||
DottedRuleInMemory *rule = new DottedRuleInMemory(child, *cellLabel,
|
||||
prevDottedRule);
|
||||
#endif
|
||||
dottedRuleColl.Add(stackInd, rule);
|
||||
}
|
||||
} // end of soft matches lookup
|
||||
|
||||
const ChartCellLabel *cellLabel = targetNonTerms.Find(targetNonTerm);
|
||||
if (!cellLabel) {
|
||||
continue;
|
||||
|
Loading…
Reference in New Issue
Block a user