// mosesdecoder/vw/Classifier.h
// 2015-12-10 13:36:22 +01:00
//
// 182 lines
// 5.0 KiB
// C++

#ifndef moses_Classifier_h
#define moses_Classifier_h
#include <iostream>
#include <string>
#include <fstream>
#include <sstream>
#include <deque>
#include <vector>
#include <boost/shared_ptr.hpp>
#include <boost/noncopyable.hpp>
#include <boost/thread/condition_variable.hpp>
#include <boost/thread/locks.hpp>
#include <boost/thread/mutex.hpp>
#include <boost/iostreams/filtering_stream.hpp>
#include <boost/iostreams/filter/gzip.hpp>
#include "../util/string_piece.hh"
#include "../moses/Util.h"
// Opaque VW types. They are only forward-declared here (instead of including
// VW's own headers) so that code including this header does not depend on VW.
class ezexample;
struct vw;
namespace Discriminative
{
/**
 * Abstract classifier interface; implemented in this header by VWTrainer
 * (writes training data) and VWPredictor (predicts with the VW library).
 *
 * Features come in two flavours: label-independent ones describe the example
 * as a whole, label-dependent ones describe the class currently being trained
 * or scored. Train() and Predict() both drop the label-dependent features once
 * they are done, so the features of the next label/class can then be added.
 */
class Classifier
{
public:
  /**
   * Add a feature whose value does not depend on the class (label).
   */
  virtual void AddLabelIndependentFeature(const StringPiece &name, float value) = 0;

  /**
   * Add a feature that belongs to the specific class (label) at hand.
   */
  virtual void AddLabelDependentFeature(const StringPiece &name, float value) = 0;

  /**
   * Train on the current example for `label`; `loss` separates positive from
   * negative training examples. Discards the current label-dependent features
   * afterwards (so that features for another label/class can now be set).
   */
  virtual void Train(const StringPiece &label, float loss) = 0;

  /**
   * Return the predicted loss (the inverse of a score) of the current example
   * for `label`. Discards the current label-dependent features afterwards (so
   * that features for another label/class can now be set).
   */
  virtual float Predict(const StringPiece &label) = 0;

  /// Indicator-feature shorthand: label-independent feature with value 1.
  void AddLabelIndependentFeature(const StringPiece &name) {
    const float indicatorValue = 1.0f;
    AddLabelIndependentFeature(name, indicatorValue);
  }

  /// Indicator-feature shorthand: label-dependent feature with value 1.
  void AddLabelDependentFeature(const StringPiece &name) {
    const float indicatorValue = 1.0f;
    AddLabelDependentFeature(name, indicatorValue);
  }

  virtual ~Classifier() {}

protected:
  /**
   * Escape special characters in a unified way: '\' -> "_/_", '|' -> "\/",
   * ':' -> "\;", ' ' -> "\_". NOTE(review): these are presumably the
   * delimiters of VW's text input format — confirm.
   */
  static std::string EscapeSpecialChars(const std::string &str) {
    // Backslashes are rewritten first so that the backslashes introduced by
    // the later substitutions are not themselves rewritten.
    const std::string withoutBackslashes = Moses::Replace(str, "\\", "_/_");
    const std::string withoutPipes = Moses::Replace(withoutBackslashes, "|", "\\/");
    const std::string withoutColons = Moses::Replace(withoutPipes, ":", "\\;");
    return Moses::Replace(withoutColons, " ", "\\_");
  }

  static const bool DEBUG = false;
};
// VW settings that our scenario always needs, hence hard-coded here
// (e.g. "-q st": quadratic source X target feature interactions).
// Both strings are padded with a space on each side — presumably so they can be
// concatenated with further option strings; confirm at the use sites.
const std::string VW_DEFAULT_OPTIONS(" --hash all --noconstant -q st -t --ldf_override sc ");
const std::string VW_DEFAULT_PARSER_OPTIONS(" --quiet --hash all --noconstant -q st -t --csoaa_ldf sc ");
/**
* Produce VW training file (does not use the VW library!)
*/
class VWTrainer : public Classifier
{
public:
VWTrainer(const std::string &outputFile);
virtual ~VWTrainer();
virtual void AddLabelIndependentFeature(const StringPiece &name, float value);
virtual void AddLabelDependentFeature(const StringPiece &name, float value);
virtual void Train(const StringPiece &label, float loss);
virtual float Predict(const StringPiece &label);
protected:
void AddFeature(const StringPiece &name, float value);
bool m_isFirstSource, m_isFirstTarget, m_isFirstExample;
private:
boost::iostreams::filtering_ostream m_bfos;
std::deque<std::string> m_outputBuffer;
void WriteBuffer();
};
/**
* Predict using VW library.
*/
class VWPredictor : public Classifier, private boost::noncopyable
{
public:
VWPredictor(const std::string &modelFile, const std::string &vwOptions);
virtual ~VWPredictor();
virtual void AddLabelIndependentFeature(const StringPiece &name, float value);
virtual void AddLabelDependentFeature(const StringPiece &name, float value);
virtual void Train(const StringPiece &label, float loss);
virtual float Predict(const StringPiece &label);
friend class ClassifierFactory;
protected:
void AddFeature(const StringPiece &name, float values);
::vw *m_VWInstance, *m_VWParser;
::ezexample *m_ex;
// if true, then the VW instance is owned by an external party and should NOT be
// deleted at end; if false, then we own the VW instance and must clean up after it.
bool m_sharedVwInstance;
bool m_isFirstSource, m_isFirstTarget;
private:
// instantiation by classifier factory
VWPredictor(vw * instance, const std::string &vwOption);
};
/**
* Provider for classifier instances to be used by individual threads.
*/
class ClassifierFactory : private boost::noncopyable
{
public:
typedef boost::shared_ptr<Classifier> ClassifierPtr;
/**
* Creates VWPredictor instances to be used by individual threads.
*/
ClassifierFactory(const std::string &modelFile, const std::string &vwOptions);
/**
* Creates VWTrainer instances (which write features to a file).
*/
ClassifierFactory(const std::string &modelFilePrefix);
// return VWPredictor or VWTrainer instance depending on whether we're in training mode
ClassifierPtr operator()();
~ClassifierFactory();
private:
std::string m_vwOptions;
::vw *m_VWInstance;
int m_lastId;
std::string m_modelFilePrefix;
bool m_gzip;
boost::mutex m_mutex;
const bool m_train;
};
} // namespace Discriminative
#endif // moses_Classifier_h