Moved caching of lexical reordering scores from class TranslationOption to class TargetPhrase.

This was done so that phrase tables can attach this information (if available) as extra annotation
to TargetPhrases, in preparation for providing lexical reordering models with sampling phrase tables.
This commit is contained in:
Ulrich Germann 2015-03-09 00:30:01 +00:00
parent ddea89312e
commit c1d2313a66
7 changed files with 146 additions and 33 deletions

View File

@ -3,10 +3,12 @@
#include <boost/foreach.hpp>
#include "moses/FF/FFState.h"
#include "moses/TranslationOptionList.h"
#include "LexicalReordering.h"
#include "LexicalReorderingState.h"
#include "moses/StaticData.h"
#include "moses/Util.h"
#include "moses/InputPath.h"
using namespace std;
using namespace boost::algorithm;
@ -124,5 +126,25 @@ IsUseable(const FactorMask &mask) const
}
return true;
}
/** Computes the reordering scores for a single translation option and
 *  hands them to the option for caching.
 *
 *  The scores are obtained from GetProb() for the phrase of the option's
 *  input path (source side) and the option's target phrase.
 *
 *  @param to translation option that receives the scores
 */
void
LexicalReordering::
SetCache(TranslationOption& to) const
{
  Phrase const& src = to.GetInputPath().GetPhrase();
  Phrase const& trg = to.GetTargetPhrase();
  Scores const probs = this->GetProb(src, trg);
  to.CacheLexReorderingScores(*this, probs);
}
/** Caches reordering scores for every translation option in a list by
 *  calling the single-option SetCache() on each element.
 *
 *  @param tol list of translation options to process
 */
void
LexicalReordering::
SetCache(TranslationOptionList& tol) const
{
  BOOST_FOREACH(TranslationOption* opt, tol) {
    SetCache(*opt);
  }
}
}

View File

@ -9,6 +9,7 @@
#include "moses/TypeDef.h"
#include "moses/Util.h"
#include "moses/WordsRange.h"
#include "moses/TranslationOption.h"
#include "moses/FF/StatefulFeatureFunction.h"
#include "util/exception.hh"
@ -95,6 +96,14 @@ public:
return m_defaultScores[i];
}
// Looks up the reordering scores (via GetProb) for the source/target
// phrase pair of a single translation option and passes them to the
// option's CacheLexReorderingScores().
virtual
void
SetCache(TranslationOption& to) const;
// Convenience overload: applies SetCache() to each translation option
// contained in the given list.
virtual
void
SetCache(TranslationOptionList& tol) const;
private:
bool DecodeCondition(std::string s);
bool DecodeDirection(std::string s);

View File

@ -33,6 +33,7 @@
#include "AlignmentInfoCollection.h"
#include "InputPath.h"
#include "moses/TranslationModel/PhraseDictionary.h"
#include <boost/foreach.hpp>
using namespace std;
@ -83,6 +84,7 @@ TargetPhrase::TargetPhrase(const Phrase &phrase, const PhraseDictionary *pt)
TargetPhrase::TargetPhrase(const TargetPhrase &copy)
: Phrase(copy)
, m_cached_scores(copy.m_cached_scores)
, m_fullScore(copy.m_fullScore)
, m_futureScore(copy.m_futureScore)
, m_scoreBreakdown(copy.m_scoreBreakdown)
@ -221,14 +223,66 @@ void TargetPhrase::SetSparseScore(const FeatureFunction* translationScoreProduce
m_scoreBreakdown.Assign(translationScoreProducer, sparseString.as_string());
}
void TargetPhrase::Merge(const TargetPhrase &copy, const std::vector<FactorType>& factorVec)
/** Merges two (possibly absent) cached score vectors.
 *
 *  - If only one of the two vectors exists, that vector is returned as is
 *    (no copy is made); if neither exists, a null pointer is returned.
 *  - If both exist but differ in size, a null pointer is returned.
 *  - Otherwise a newly allocated vector is returned: it starts as a copy
 *    of *a, and every position where *a holds 0 is filled in from *b.
 *
 *  Neither input vector is modified. The inputs may be shared with other
 *  owners through shared_ptr copies, so the merge result must go into the
 *  fresh copy rather than into *a.
 *
 *  @param a first score vector (may be null)
 *  @param b second score vector (may be null)
 *  @return merged scores, or null as described above
 *  @throws via UTIL_THROW_IF2 if both vectors hold different non-zero
 *          values at the same position
 */
boost::shared_ptr<Scores>
mergescores(boost::shared_ptr<Scores> const& a,
            boost::shared_ptr<Scores> const& b)
{
  boost::shared_ptr<Scores> ret;
  if (!a) return b; // b may itself be null, which is the "nothing to merge" result
  if (!b) return a;
  if (a->size() != b->size()) return ret;
  ret.reset(new Scores(*a));
  for (size_t i = 0; i < a->size(); ++i) {
    if ((*a)[i] == 0) {
      // Write into the copy, not into *a: the original code assigned to
      // (*a)[i], leaving the returned vector unmerged and mutating the input.
      (*ret)[i] = (*b)[i];
    } else if ((*b)[i]) {
      UTIL_THROW_IF2((*a)[i] != (*b)[i], "can't merge feature vectors");
    }
  }
  return ret;
}
void
TargetPhrase::
Merge(const TargetPhrase &copy, const std::vector<FactorType>& factorVec)
{
Phrase::MergeFactors(copy, factorVec);
m_scoreBreakdown.Merge(copy.GetScoreBreakdown());
m_futureScore += copy.m_futureScore;
m_fullScore += copy.m_fullScore;
typedef ScoreCache_t::iterator iter;
typedef ScoreCache_t::value_type item;
BOOST_FOREACH(item const& s, copy.m_cached_scores)
{
pair<iter,bool> foo = m_cached_scores.insert(s);
if (foo.second == false)
foo.first->second = mergescores(foo.first->second, s.second);
}
}
/** Read-only access to the complete cache of extra scores, i.e. the map
 *  from feature function to its cached score vector.
 */
TargetPhrase::ScoreCache_t const&
TargetPhrase::GetExtraScores() const
{
  ScoreCache_t const& cache = m_cached_scores;
  return cache;
}
/** Looks up the extra scores cached for one feature function.
 *
 *  @param ff feature function used as the cache key
 *  @return pointer to the cached scores, or NULL if nothing is cached
 *          for ff (the pointer is also NULL if a null entry was stored)
 */
Scores const*
TargetPhrase::
GetExtraScores(FeatureFunction const* ff) const
{
  ScoreCache_t::const_iterator hit = m_cached_scores.find(ff);
  if (hit == m_cached_scores.end()) {
    return NULL;
  }
  return hit->second.get();
}
/** Stores (or replaces) the extra scores cached for a feature function.
 *
 *  @param ff feature function used as the cache key
 *  @param s  shared score vector to associate with ff
 */
void
TargetPhrase::
SetExtraScores(FeatureFunction const* ff,
               boost::shared_ptr<Scores> const& s)
{
  m_cached_scores[ff] = s;
}
void TargetPhrase::SetProperties(const StringPiece &str)
{
if (str.size() == 0) {
@ -287,6 +341,7 @@ void swap(TargetPhrase &first, TargetPhrase &second)
std::swap(first.m_alignTerm, second.m_alignTerm);
std::swap(first.m_alignNonTerm, second.m_alignNonTerm);
std::swap(first.m_lhsTarget, second.m_lhsTarget);
std::swap(first.m_cached_scores, second.m_cached_scores);
}
TO_STRING_BODY(TargetPhrase);
@ -325,5 +380,7 @@ std::ostream& operator<<(std::ostream& os, const TargetPhrase& tp)
return os;
}
}

View File

@ -1,3 +1,4 @@
// -*- c++ -*-
// $Id$
/***********************************************************************
@ -50,6 +51,17 @@ class PhraseDictionary;
*/
class TargetPhrase: public Phrase
{
public:
// Cache of extra scores attached to this target phrase, keyed by the
// feature function the scores belong to. Values are shared pointers, so
// copies of a TargetPhrase share the same score vectors.
typedef std::map<FeatureFunction const*, boost::shared_ptr<Scores> >
ScoreCache_t;
// Returns the whole cache (read-only).
ScoreCache_t const& GetExtraScores() const;
// Returns the scores cached for ff, or NULL if there is no entry for ff.
Scores const* GetExtraScores(FeatureFunction const* ff) const;
// Stores 'scores' as the cache entry for ff, replacing any existing entry.
void SetExtraScores(FeatureFunction const* ff,
boost::shared_ptr<Scores> const& scores);
private:
// Backing store for the extra-score accessors above.
ScoreCache_t m_cached_scores;
private:
friend std::ostream& operator<<(std::ostream&, const TargetPhrase&);
friend void swap(TargetPhrase &first, TargetPhrase &second);
@ -186,6 +198,8 @@ public:
return found->second;
}
// To be set by the FF that needs it, by default the rule source = NULL
// make a copy of the source side of the rule
void SetRuleSource(const Phrase &ruleSource) const;

View File

@ -25,6 +25,7 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
#include "GenerationDictionary.h"
#include "StaticData.h"
#include "InputType.h"
#include "moses/FF/LexicalReordering/LexicalReordering.h"
using namespace std;
@ -35,8 +36,7 @@ TranslationOption::TranslationOption()
:m_targetPhrase(NULL)
,m_inputPath(NULL)
,m_sourceWordsRange(NOT_FOUND, NOT_FOUND)
{
}
{ }
//TODO this should be a factory function!
TranslationOption::TranslationOption(const WordsRange &wordsRange
@ -66,9 +66,14 @@ bool TranslationOption::Overlap(const Hypothesis &hypothesis) const
return bitmap.Overlap(GetSourceWordsRange());
}
void TranslationOption::CacheLexReorderingScores(const LexicalReordering &producer, const Scores &score)
void
TranslationOption::
CacheLexReorderingScores(const LexicalReordering &producer, const Scores &score)
{
m_lexReorderingScores[&producer] = score;
if (score.empty()) return;
boost::shared_ptr<Scores> stored(new Scores(score));
m_targetPhrase.SetExtraScores(&producer,stored);
// m_lexReorderingScores[&producer] = score;
}
void TranslationOption::EvaluateWithSourceContext(const InputType &input)
@ -104,6 +109,19 @@ ostream& operator<<(ostream& out, const TranslationOption& possibleTranslation)
return out;
}
/** Returns the lexical reordering scores cached for the given score
 *  producer.
 *
 *  The scores are no longer kept in a map on the translation option
 *  itself; the lookup is delegated to the target phrase's extra-score
 *  cache, and its result (including a null pointer when nothing is
 *  cached) is passed through unchanged.
 *
 *  @param scoreProducer lexical reordering model used as the cache key
 */
const Scores*
TranslationOption::
GetLexReorderingScores(LexicalReordering const* scoreProducer) const
{
  Scores const* cached = m_targetPhrase.GetExtraScores(scoreProducer);
  return cached;
}
}

View File

@ -35,7 +35,6 @@ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
#include "TypeDef.h"
#include "ScoreComponentCollection.h"
#include "StaticData.h"
namespace Moses
{
@ -71,8 +70,11 @@ protected:
const WordsRange m_sourceWordsRange; /*< word position in the input that are covered by this translation option */
float m_futureScore; /*< estimate of total cost when using this translation option, includes language model probabilities */
typedef std::map<const LexicalReordering*, Scores> _ScoreCacheMap;
_ScoreCacheMap m_lexReorderingScores;
// typedef std::map<const LexicalReordering*, Scores> _ScoreCacheMap;
// _ScoreCacheMap m_lexReorderingScores;
// m_lexReorderingScores was moved to TargetPhrase.h so that phrase tables
// can add information (such as lexical reordering scores) to target phrases
// during lookup.
public:
struct Better {
@ -154,15 +156,15 @@ public:
}
/** returns cached scores */
inline const Scores *GetLexReorderingScores(const LexicalReordering *scoreProducer) const {
_ScoreCacheMap::const_iterator it = m_lexReorderingScores.find(scoreProducer);
if(it == m_lexReorderingScores.end())
return NULL;
else
return &(it->second);
}
// inline
const Scores*
GetLexReorderingScores(const LexicalReordering *scoreProducer) const;
// {
// return m_targetPhrase.GetExtraScores(scoreProducer);
// }
void CacheLexReorderingScores(const LexicalReordering &scoreProducer, const Scores &score);
void CacheLexReorderingScores(const LexicalReordering &scoreProducer,
const Scores &score);
TO_STRING();

View File

@ -622,25 +622,16 @@ void
TranslationOptionCollection::
CacheLexReordering()
{
typedef StatefulFeatureFunction sfFF;
std::vector<const sfFF*> const& all_sfff
= sfFF::GetStatefulFeatureFunctions();
size_t const stop = m_source.GetSize();
BOOST_FOREACH(sfFF const* ff, all_sfff) {
if (typeid(*ff) != typeid(LexicalReordering)) continue;
LexicalReordering const& lr = static_cast<const LexicalReordering&>(*ff);
for (size_t s = 0 ; s < stop ; s++) {
BOOST_FOREACH(TranslationOptionList const& tol, m_collection[s]) {
BOOST_FOREACH(TranslationOption* to, tol) {
Phrase const& sphrase = to->GetInputPath().GetPhrase();
Phrase const& tphrase = to->GetTargetPhrase();
Scores score = lr.GetProb(sphrase,tphrase);
if (!score.empty()) to->CacheLexReorderingScores(lr, score);
}
}
typedef StatefulFeatureFunction sfFF;
BOOST_FOREACH(sfFF const* ff, sfFF::GetStatefulFeatureFunctions())
{
if (typeid(*ff) != typeid(LexicalReordering)) continue;
LexicalReordering const& lr = static_cast<const LexicalReordering&>(*ff);
for (size_t s = 0 ; s < stop ; s++)
BOOST_FOREACH(TranslationOptionList& tol, m_collection[s])
lr.SetCache(tol);
}
}
}
//! list of trans opt for a particular span