Added VWFeatureSource and VWFeatureTarget

This commit is contained in:
Marcin Junczys-Dowmunt 2015-01-07 12:18:01 +01:00
parent 8f898a0a48
commit cb50517f7d
6 changed files with 108 additions and 18 deletions

View File

@ -8,8 +8,6 @@
#include "Classifier.h"
#include "VWFeatureBase.h"
#include "VWFeatureBase.h"
namespace Moses
{
@ -51,7 +49,18 @@ public:
Discriminative::Classifier *classifier = m_train
? m_trainer
: (Discriminative::Classifier *)m_predictorFactory->Acquire();
const std::vector<VWFeatureBase*>& features = VWFeatureBase::GetFeatures(GetScoreProducerDescription());
UTIL_THROW_IF2(translationOptionList.size() == 0, "There are not translation options.");
const std::vector<VWFeatureBase*>& sourceFeatures = VWFeatureBase::GetSourceFeatures(GetScoreProducerDescription());
const WordsRange &sourceRange = translationOptionList.Get(0)->GetSourceWordsRange();
const InputPath &inputPath = translationOptionList.Get(0)->GetInputPath();
for(size_t i = 0; i < sourceFeatures.size(); ++i)
(*sourceFeatures[i])(input, inputPath, sourceRange, classifier);
const std::vector<VWFeatureBase*>& targetFeatures = VWFeatureBase::GetTargetFeatures(GetScoreProducerDescription());
std::vector<float> losses(translationOptionList.size());
@ -61,8 +70,8 @@ public:
iterTransOpt != translationOptionList.end() ; ++iterTransOpt, ++iterLoss) {
TranslationOption &transOpt = **iterTransOpt;
for(size_t i = 0; i < features.size(); ++i)
(*features[i])(input, transOpt.GetInputPath(), transOpt.GetTargetPhrase(), classifier);
for(size_t i = 0; i < targetFeatures.size(); ++i)
(*targetFeatures[i])(input, inputPath, transOpt.GetTargetPhrase(), classifier);
*iterLoss = classifier->Predict("DUMMY"); // VW does not use the label!!
// TODO handle training somehow

View File

@ -1,24 +1,24 @@
#pragma once
#include <string>
#include "VWFeatureBase.h"
#include "VWFeatureSource.h"
namespace Moses
{
class VWFeatureBagOfWords : public VWFeatureBase
class VWFeatureBagOfWords : public VWFeatureSource
{
public:
VWFeatureBagOfWords(const std::string &line)
: VWFeatureBase(line)
: VWFeatureSource(line)
{}
void operator()(const InputType &input
, const InputPath &inputPath
, const TargetPhrase &targetPhrase
, const WordsRange &sourceRange
, Discriminative::Classifier *classifier) const
{
std::cerr << GetScoreProducerDescription() << " got TargetPhrase: " << targetPhrase << std::endl;
std::cerr << GetScoreProducerDescription() << " got Phrase: " << sourceRange << std::endl;
}
};

View File

@ -6,5 +6,7 @@
namespace Moses
{
// Out-of-class definitions of the static feature registries declared in
// VWFeatureBase. Each map is keyed by a name taken from a feature's
// m_usedBy list (e.g. "VW0") and is filled by the VWFeatureBase constructor.

// Every registered feature, regardless of side.
std::map<std::string, std::vector<VWFeatureBase*> > VWFeatureBase::s_features;
// Features constructed with isSource == true.
std::map<std::string, std::vector<VWFeatureBase*> > VWFeatureBase::s_sourceFeatures;
// Features constructed with isSource == false.
std::map<std::string, std::vector<VWFeatureBase*> > VWFeatureBase::s_targetFeatures;
}

View File

@ -12,7 +12,7 @@ namespace Moses
class VWFeatureBase : public StatelessFeatureFunction
{
public:
VWFeatureBase(const std::string &line)
VWFeatureBase(const std::string &line, bool isSource = true)
:StatelessFeatureFunction(0, line)
{
ReadParameters();
@ -21,8 +21,13 @@ class VWFeatureBase : public StatelessFeatureFunction
m_usedBy.push_back("VW0");
for(std::vector<std::string>::const_iterator it = m_usedBy.begin();
it != m_usedBy.end(); it++)
it != m_usedBy.end(); it++) {
s_features[*it].push_back(this);
if(isSource)
s_sourceFeatures[*it].push_back(this);
else
s_targetFeatures[*it].push_back(this);
}
}
bool IsUseable(const FactorMask &mask) const {
@ -62,16 +67,31 @@ class VWFeatureBase : public StatelessFeatureFunction
}
}
virtual void operator()(const InputType &input
, const InputPath &inputPath
, const TargetPhrase &targetPhrase
, Discriminative::Classifier *classifier) const = 0;
static const std::vector<VWFeatureBase*>& GetFeatures(std::string name = "VW0") {
UTIL_THROW_IF2(s_features.count(name) == 0, "No features registered for parent classifier: " + name);
return s_features[name];
}
static const std::vector<VWFeatureBase*>& GetSourceFeatures(std::string name = "VW0") {
UTIL_THROW_IF2(s_sourceFeatures.count(name) == 0, "No source features registered for parent classifier: " + name);
return s_sourceFeatures[name];
}
static const std::vector<VWFeatureBase*>& GetTargetFeatures(std::string name = "VW0") {
UTIL_THROW_IF2(s_targetFeatures.count(name) == 0, "No target features registered for parent classifier: " + name);
return s_targetFeatures[name];
}
// Source-side hook: called with the input, an input path, the source word
// range and the classifier to operate on. In the calling code visible in this
// change it is invoked once per translation option list, using the input path
// and source range of the first option.
virtual void operator()(const InputType &input
, const InputPath &inputPath
, const WordsRange &sourceRange
, Discriminative::Classifier *classifier) const = 0;
// Target-side hook: called with the input, an input path, a target phrase and
// the classifier to operate on. In the calling code visible in this change it
// is invoked once per translation option, with that option's target phrase.
virtual void operator()(const InputType &input
, const InputPath &inputPath
, const TargetPhrase &targetPhrase
, Discriminative::Classifier *classifier) const = 0;
protected:
std::vector<FactorType> m_sourceFactors, m_targetFactors;
@ -79,7 +99,7 @@ class VWFeatureBase : public StatelessFeatureFunction
void ParseFactorDefinition(const std::string &list, /* out */ std::vector<FactorType> &out)
{
std::vector<std::string> split = Tokenize(list, ",");
Scan<int>(out, split);
Scan<FactorType>(out, split);
}
void ParseUsedBy(const std::string &usedBy) {
@ -88,6 +108,8 @@ class VWFeatureBase : public StatelessFeatureFunction
std::vector<std::string> m_usedBy;
static std::map<std::string, std::vector<VWFeatureBase*> > s_features;
static std::map<std::string, std::vector<VWFeatureBase*> > s_sourceFeatures;
static std::map<std::string, std::vector<VWFeatureBase*> > s_targetFeatures;
};
}

View File

@ -0,0 +1,28 @@
#pragma once
#include <string>
#include "VWFeatureBase.h"
namespace Moses
{
/**
 * Intermediate base class for source-side VW features.
 *
 * Constructs VWFeatureBase with isSource == true, leaves the source-range
 * hook pure virtual for subclasses to implement, and supplies an empty
 * implementation of the target-phrase hook.
 */
class VWFeatureSource : public VWFeatureBase
{
public:
  /// \param line  feature configuration line, forwarded to VWFeatureBase.
  VWFeatureSource(const std::string &line) : VWFeatureBase(line, true) {}

  /// Source-side hook; must be provided by the concrete feature.
  virtual void operator()(const InputType &input,
                          const InputPath &inputPath,
                          const WordsRange &sourceRange,
                          Discriminative::Classifier *classifier) const = 0;

  /// Target-side hook; intentionally does nothing for source features.
  virtual void operator()(const InputType &input,
                          const InputPath &inputPath,
                          const TargetPhrase &targetPhrase,
                          Discriminative::Classifier *classifier) const
  {
  }
};
}

View File

@ -0,0 +1,29 @@
#pragma once
#include <string>
#include "VWFeatureBase.h"
namespace Moses
{
class VWFeatureTarget : public VWFeatureBase
{
public:
VWFeatureTarget(const std::string &line)
: VWFeatureBase(line, false)
{}
virtual void operator()(const InputType &input
, const InputPath &inputPath
, const TargetPhrase &targetPhrase
, Discriminative::Classifier *classifier) const = 0;
virtual void operator()(const InputType &input
, const InputPath &inputPath
, const Phrase &sourcePhrase
, Discriminative::Classifier *classifier) const
{}
};
}