mosesdecoder/vw/VWPredictor.cpp

122 lines
3.8 KiB
C++
Raw Normal View History

2015-01-07 17:34:02 +03:00
#include <iostream>
#include "Classifier.h"
#include "vw.h"
#include "ezexample.h"
2015-01-07 17:34:02 +03:00
#include "../moses/Util.h"
2015-01-14 14:07:42 +03:00
namespace Discriminative
{
using namespace std;
VWPredictor::VWPredictor(const string &modelFile, const string &vwOptions)
{
2015-01-07 15:28:49 +03:00
m_VWInstance = VW::initialize(VW_DEFAULT_OPTIONS + " -i " + modelFile + vwOptions);
m_VWParser = VW::initialize(VW_DEFAULT_PARSER_OPTIONS + vwOptions + " --noop");
m_sharedVwInstance = false;
m_ex = new ::ezexample(m_VWInstance, false, m_VWParser);
m_isFirstSource = m_isFirstTarget = true;
}
VWPredictor::VWPredictor(vw *instance, const string &vwOptions)
{
m_VWInstance = instance;
m_VWParser = VW::initialize(vwOptions + " --noop");
m_sharedVwInstance = true;
m_ex = new ::ezexample(m_VWInstance, false, m_VWParser);
m_isFirstSource = m_isFirstTarget = true;
}
VWPredictor::~VWPredictor()
{
delete m_ex;
VW::finish(*m_VWParser);
if (!m_sharedVwInstance)
VW::finish(*m_VWInstance);
}
FeatureType VWPredictor::AddLabelIndependentFeature(const StringPiece &name, float value)
{
2015-01-06 17:54:15 +03:00
// label-independent features are kept in a different feature namespace ('s' = source)
if (m_isFirstSource) {
2015-01-06 17:54:15 +03:00
// the first feature of a new example => create the source namespace for
// label-independent features to live in
m_isFirstSource = false;
2015-12-09 19:35:28 +03:00
m_ex->finish();
m_ex->addns('s');
2015-01-07 17:34:02 +03:00
if (DEBUG) std::cerr << "VW :: Setting source namespace\n";
}
return AddFeature(name, value); // namespace 's' is set up, add the feature
}
FeatureType VWPredictor::AddLabelDependentFeature(const StringPiece &name, float value)
{
2015-01-06 17:54:15 +03:00
// VW does not use the label directly, instead, we do a Cartesian product between source and target feature
// namespaces, where the source namespace ('s') contains label-independent features and the target
// namespace ('t') contains label-dependent features
if (m_isFirstTarget) {
2015-01-06 17:54:15 +03:00
// the first target-side feature => create namespace 't'
m_isFirstTarget = false;
m_ex->addns('t');
2015-01-07 17:34:02 +03:00
if (DEBUG) std::cerr << "VW :: Setting target namespace\n";
}
return AddFeature(name, value);
}
2016-03-10 15:41:56 +03:00
void VWPredictor::AddLabelIndependentFeatureVector(const FeatureVector &features)
{
if (m_isFirstSource) {
// the first feature of a new example => create the source namespace for
// label-independent features to live in
m_isFirstSource = false;
m_ex->finish();
m_ex->addns('s');
if (DEBUG) std::cerr << "VW :: Setting source namespace\n";
}
// add each feature index using this "low level" call to VW
for (FeatureVector::const_iterator it = features.begin(); it != features.end(); it++)
m_ex->addf(it->first, it->second);
}
void VWPredictor::AddLabelDependentFeatureVector(const FeatureVector &features)
{
if (m_isFirstTarget) {
// the first target-side feature => create namespace 't'
m_isFirstTarget = false;
m_ex->addns('t');
if (DEBUG) std::cerr << "VW :: Setting target namespace\n";
}
// add each feature index using this "low level" call to VW
for (FeatureVector::const_iterator it = features.begin(); it != features.end(); it++)
m_ex->addf(it->first, it->second);
}
void VWPredictor::Train(const StringPiece &label, float loss)
{
throw logic_error("Trying to train during prediction!");
}
float VWPredictor::Predict(const StringPiece &label)
{
m_ex->set_label(label.as_string());
m_isFirstSource = true;
m_isFirstTarget = true;
2015-01-07 18:59:38 +03:00
float loss = m_ex->predict_partial();
2015-01-07 17:34:02 +03:00
if (DEBUG) std::cerr << "VW :: Predicted loss: " << loss << "\n";
m_ex->remns(); // remove target namespace
return loss;
}
FeatureType VWPredictor::AddFeature(const StringPiece &name, float value)
{
2015-01-07 18:59:38 +03:00
if (DEBUG) std::cerr << "VW :: Adding feature: " << EscapeSpecialChars(name.as_string()) << ":" << value << "\n";
return std::make_pair(m_ex->addf(EscapeSpecialChars(name.as_string()), value), value);
}
} // namespace Discriminative