Review notes (free text ahead of the first "diff --git" header, outside every hunk):
- Several template argument lists and two #include targets are missing from this
  listing (e.g. "std::vector &args", "std::map >", "#include" with no header).
  They look lost in extraction rather than absent upstream -- TODO confirm and
  restore from the original commit before applying.
- Hunk line counts describe the lines as listed here; the "index" blob ids are
  the original ones and predate the edits to the two new files.

diff --git a/moses/FF/Factory.cpp b/moses/FF/Factory.cpp
index a633e4076..165683cbb 100644
--- a/moses/FF/Factory.cpp
+++ b/moses/FF/Factory.cpp
@@ -34,6 +34,7 @@
 #include "moses/FF/ExternalFeature.h"
 #include "moses/FF/ConstrainedDecoding.h"
 #include "moses/FF/CoveredReferenceFeature.h"
+#include "moses/FF/SoftMatchingFeature.h"
 
 #include "moses/FF/SkeletonStatelessFF.h"
 #include "moses/FF/SkeletonStatefulFF.h"
@@ -171,6 +172,7 @@ FeatureRegistry::FeatureRegistry()
   MOSES_FNAME(ConstrainedDecoding);
   MOSES_FNAME(CoveredReferenceFeature);
   MOSES_FNAME(ExternalFeature);
+  MOSES_FNAME(SoftMatchingFeature);
 
   MOSES_FNAME(SkeletonStatelessFF);
   MOSES_FNAME(SkeletonStatefulFF);
diff --git a/moses/FF/SoftMatchingFeature.cpp b/moses/FF/SoftMatchingFeature.cpp
new file mode 100644
index 000000000..a21b848ba
--- /dev/null
+++ b/moses/FF/SoftMatchingFeature.cpp
@@ -0,0 +1,121 @@
+#include "SoftMatchingFeature.h"
+#include "moses/AlignmentInfo.h"
+#include "moses/TargetPhrase.h"
+#include "moses/ChartHypothesis.h"
+#include "moses/StaticData.h"
+#include "moses/InputFileStream.h"
+
+namespace Moses
+{
+
+// Loads the soft-matching table named by every "path" argument in m_args.
+SoftMatchingFeature::SoftMatchingFeature(const std::string &line)
+  : StatelessFeatureFunction(0, line)
+{
+  std::cerr << "Initializing SoftMatchingFeature.." << std::endl;
+
+  for (size_t i = 0; i < m_args.size(); ++i) {
+    const std::vector &args = m_args[i];
+    if (args[0] == "path") {
+      const std::string filePath = args[1];
+      Load(filePath);
+    }
+  } // for
+}
+
+// Reads filePath line by line; each line must tokenize into exactly two
+// non-terminal labels (UTIL_THROW_IF2 fires otherwise). Pairs whose labels
+// differ are stored in m_soft_matches and, inverted, in
+// m_soft_matches_reverse; both maps are then passed to StaticData.
+// Always returns true when it returns.
+bool SoftMatchingFeature::Load(const std::string& filePath)
+{
+
+  StaticData &staticData = StaticData::InstanceNonConst();
+
+  InputFileStream inStream(filePath);
+  std::string line;
+  while(getline(inStream, line)) {
+    std::vector tokens = Tokenize(line);
+    UTIL_THROW_IF2(tokens.size() != 2, "Error: wrong format of SoftMatching file: must have two nonterminals per line");
+
+    // no soft matching necessary if LHS and RHS are the same
+    if (tokens[0] == tokens[1]) {
+      continue;
+    }
+
+    Word LHS, RHS;
+    LHS.CreateFromString(Output, staticData.GetOutputFactorOrder(), tokens[0], true);
+    RHS.CreateFromString(Output, staticData.GetOutputFactorOrder(), tokens[1], true);
+
+    m_soft_matches[LHS].insert(RHS);
+    m_soft_matches_reverse[RHS].insert(LHS);
+  }
+
+  staticData.Set_Soft_Matches(Get_Soft_Matches());
+  staticData.Set_Soft_Matches_Reverse(Get_Soft_Matches_Reverse());
+
+  return true;
+}
+
+// For every non-terminal in hypo's current target phrase, adds 1 to the feature
+// named "<LHS of the matching previous hypothesis>-><that non-terminal>".
+void SoftMatchingFeature::EvaluateChart(const ChartHypothesis& hypo,
+    ScoreComponentCollection* accumulator) const
+{
+
+  const TargetPhrase& target = hypo.GetCurrTargetPhrase();
+
+  const AlignmentInfo::NonTermIndexMap &nonTermIndexMap = target.GetAlignNonTerm().GetNonTermIndexMap();
+
+  // loop over the rule that is being applied
+  for (size_t pos = 0; pos < target.GetSize(); ++pos) {
+    const Word& word = target.GetWord(pos);
+
+    // for non-terminals, trigger the feature mapping the LHS of the previous hypo to the RHS of this hypo
+    if (word.IsNonTerminal()) {
+      size_t nonTermInd = nonTermIndexMap[pos];
+
+      const ChartHypothesis* prevHypo = hypo.GetPrevHypo(nonTermInd);
+      const Word& prevLHS = prevHypo->GetTargetLHS();
+
+      // bind by reference: a by-value copy would redo the string copy the cache exists to avoid
+      const std::string &name = GetFeatureName(prevLHS, word);
+      accumulator->PlusEquals(this,name,1);
+    }
+  }
+}
+
+//caching feature names because string conversion is slow
+// Returns the cached "LHS->RHS" name for the pair, building and storing it on
+// the first request. The returned reference points into m_soft_matching_cache.
+const std::string& SoftMatchingFeature::GetFeatureName(const Word& LHS, const Word& RHS) const
+{
+
+  const NonTerminalMapKey key(LHS, RHS);
+  {
+#ifdef WITH_THREADS //try read-only lock
+    boost::shared_lock read_lock(m_accessLock);
+#endif // WITH_THREADS
+    NonTerminalSoftMatchingMap::const_iterator i = m_soft_matching_cache.find(key);
+    if (i != m_soft_matching_cache.end()) return i->second;
+  }
+#ifdef WITH_THREADS //need to update cache; write lock
+  boost::unique_lock lock(m_accessLock);
+#endif // WITH_THREADS
+  const std::vector &outputFactorOrder = StaticData::Instance().GetOutputFactorOrder();
+
+  std::string LHS_string = LHS.GetString(outputFactorOrder, false);
+  std::string RHS_string = RHS.GetString(outputFactorOrder, false);
+
+  const std::string name = LHS_string + "->" + RHS_string;
+
+  // Another thread may have stored this key between the two locks. insert()
+  // leaves such an entry untouched, so a reference already handed out is
+  // never assigned to while its holder reads it without the lock.
+  return m_soft_matching_cache.insert(NonTerminalSoftMatchingMap::value_type(key, name)).first->second;
+}
+
+
+}
+
diff --git a/moses/FF/SoftMatchingFeature.h b/moses/FF/SoftMatchingFeature.h
new file mode 100644
index 000000000..351ef3a93
--- /dev/null
+++ b/moses/FF/SoftMatchingFeature.h
@@ -0,0 +1,79 @@
+#pragma once
+
+#include
+#include "moses/Util.h"
+#include "moses/Word.h"
+#include "StatelessFeatureFunction.h"
+#include "moses/TranslationModel/PhraseDictionaryNodeMemory.h"
+
+#ifdef WITH_THREADS
+#include
+#endif
+
+namespace Moses
+{
+
+/** Stateless feature for chart decoding: EvaluateChart() fires one named feature
+ *  per (child LHS, rule non-terminal) pair; Load() reads the soft-match table.
+ */
+class SoftMatchingFeature : public StatelessFeatureFunction
+{
+public:
+  SoftMatchingFeature(const std::string &line);
+
+  bool IsUseable(const FactorMask &mask) const {
+    return true;
+  }
+
+  virtual void EvaluateChart(const ChartHypothesis& hypo,
+                             ScoreComponentCollection* accumulator) const;
+
+  // the other Evaluate overloads are no-ops (empty bodies)
+  void Evaluate(const Phrase &source
+                , const TargetPhrase &targetPhrase
+                , ScoreComponentCollection &scoreBreakdown
+                , ScoreComponentCollection &estimatedFutureScore) const {};
+  void Evaluate(const InputType &input
+                , const InputPath &inputPath
+                , const TargetPhrase &targetPhrase
+                , ScoreComponentCollection &scoreBreakdown
+                , ScoreComponentCollection *estimatedFutureScore = NULL) const {};
+  void Evaluate(const Hypothesis& hypo,
+                ScoreComponentCollection* accumulator) const {};
+
+  bool Load(const std::string &filePath);
+
+  std::map >& Get_Soft_Matches() {
+    return m_soft_matches;
+  }
+
+  std::map >& Get_Soft_Matches_Reverse() {
+    return m_soft_matches_reverse;
+  }
+
+  const std::string& GetFeatureName(const Word& LHS, const Word& RHS) const;
+
+private:
+  std::map > m_soft_matches; // map LHS of old rule to RHS of new rule
+  std::map > m_soft_matches_reverse; // map RHS of new rule to LHS of old rule
+
+  typedef std::pair NonTerminalMapKey;
+
+#if defined(BOOST_VERSION) && (BOOST_VERSION >= 104200)
+  typedef boost::unordered_map NonTerminalSoftMatchingMap;
+#else
+  typedef std::map NonTerminalSoftMatchingMap;
+#endif
+
+  // feature-name cache; mutable because the const GetFeatureName() fills it
+  mutable NonTerminalSoftMatchingMap m_soft_matching_cache;
+
+#ifdef WITH_THREADS
+  //reader-writer lock
+  mutable boost::shared_mutex m_accessLock;
+#endif
+
+};
+
+}
+
diff --git a/moses/StaticData.h b/moses/StaticData.h
index 55ecebe6b..ce6debfd8 100644
--- a/moses/StaticData.h
+++ b/moses/StaticData.h
@@ -214,6 +214,9 @@ protected:
   bool m_continuePartialTranslation;
 
   std::string m_binPath;
+  // soft NT lookup for chart models
+  std::map > m_soft_matches_map;
+  std::map > m_soft_matches_map_reverse;
 
 public:
 
@@ -731,6 +734,22 @@ public:
     return m_useLegacyPT;
   }
 
+  void Set_Soft_Matches(std::map >& soft_matches_map) {
+    m_soft_matches_map = soft_matches_map;
+  }
+
+  const std::map >* Get_Soft_Matches() const {
+    return &m_soft_matches_map;
+  }
+
+  void Set_Soft_Matches_Reverse(std::map >& soft_matches_map) {
+    m_soft_matches_map_reverse = soft_matches_map;
+  }
+
+  const std::map >* Get_Soft_Matches_Reverse() const {
+    return &m_soft_matches_map_reverse;
+  }
+
 };
 
 }
diff --git a/moses/TranslationModel/CYKPlusParser/ChartRuleLookupManagerMemory.cpp b/moses/TranslationModel/CYKPlusParser/ChartRuleLookupManagerMemory.cpp
index 64d3d87b4..53cdbe541 100644
--- a/moses/TranslationModel/CYKPlusParser/ChartRuleLookupManagerMemory.cpp
+++ b/moses/TranslationModel/CYKPlusParser/ChartRuleLookupManagerMemory.cpp
@@ -226,6 +226,11 @@ void ChartRuleLookupManagerMemory::ExtendPartialRuleApplication(
   const PhraseDictionaryNodeMemory::NonTerminalMap & nonTermMap =
     node.GetNonTerminalMap();
 
+  // permissible soft nonterminal matches (target side)
+  const StaticData &staticData = StaticData::Instance();
+  const std::map >* m_soft_matches_map = staticData.Get_Soft_Matches();
+  const std::map >* m_soft_matches_map_reverse = staticData.Get_Soft_Matches_Reverse();
+
   const size_t numChildren = nonTermMap.size();
   if (numChildren == 0) {
     return;
@@ -255,6 +260,34 @@ void ChartRuleLookupManagerMemory::ExtendPartialRuleApplication(
       for (; q != tEnd; ++q) {
         const ChartCellLabel &cellLabel = q->second;
 
+        //soft matching of NTs
+        const Word& targetNonTerm = cellLabel.GetLabel();
+        if (m_soft_matches_map->find(targetNonTerm) != m_soft_matches_map->end()) {
+          const std::set& softMatches = m_soft_matches_map->find(targetNonTerm)->second;
+
+          for (std::set::const_iterator softMatch = softMatches.begin(); softMatch != softMatches.end(); ++softMatch) {
+
+            // try to match both source and target non-terminal
+            const PhraseDictionaryNodeMemory * child =
+              node.GetChild(sourceNonTerm, *softMatch);
+
+            // nothing found? then we are done
+            if (child == NULL) {
+              continue;
+            }
+
+            // create new rule
+#ifdef USE_BOOST_POOL
+            DottedRuleInMemory *rule = m_dottedRulePool.malloc();
+            new (rule) DottedRuleInMemory(*child, cellLabel, prevDottedRule);
+#else
+            DottedRuleInMemory *rule = new DottedRuleInMemory(*child, cellLabel,
+                prevDottedRule);
+#endif
+            dottedRuleColl.Add(stackInd, rule);
+          }
+        } // end of soft matching
+
         // try to match both source and target non-terminal
         const PhraseDictionaryNodeMemory * child =
           node.GetChild(sourceNonTerm, cellLabel.GetLabel());
@@ -288,6 +321,29 @@ void ChartRuleLookupManagerMemory::ExtendPartialRuleApplication(
         continue;
       }
       const Word &targetNonTerm = key.second;
+
+      //soft matching of NTs
+      if (m_soft_matches_map_reverse->find(targetNonTerm) != m_soft_matches_map_reverse->end()) {
+        const std::set& softMatches = m_soft_matches_map_reverse->find(targetNonTerm)->second;
+        for (std::set::const_iterator softMatch = softMatches.begin(); softMatch != softMatches.end(); ++softMatch) {
+          const ChartCellLabel *cellLabel = targetNonTerms.Find(*softMatch);
+          if (!cellLabel) {
+            continue;
+          }
+
+          // create new rule
+          const PhraseDictionaryNodeMemory &child = p->second;
+#ifdef USE_BOOST_POOL
+          DottedRuleInMemory *rule = m_dottedRulePool.malloc();
+          new (rule) DottedRuleInMemory(child, *cellLabel, prevDottedRule);
+#else
+          DottedRuleInMemory *rule = new DottedRuleInMemory(child, *cellLabel,
+              prevDottedRule);
+#endif
+          dottedRuleColl.Add(stackInd, rule);
+        }
+      } // end of soft matches lookup
+
       const ChartCellLabel *cellLabel = targetNonTerms.Find(targetNonTerm);
       if (!cellLabel) {
         continue;