mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-12-27 05:55:02 +03:00
182 lines
5.0 KiB
C++
182 lines
5.0 KiB
C++
#ifndef moses_Classifier_h
|
|
#define moses_Classifier_h
|
|
|
|
#include <iostream>
|
|
#include <string>
|
|
#include <fstream>
|
|
#include <sstream>
|
|
#include <deque>
|
|
#include <vector>
|
|
#include <boost/shared_ptr.hpp>
|
|
|
|
#include <boost/noncopyable.hpp>
|
|
#include <boost/thread/condition_variable.hpp>
|
|
#include <boost/thread/locks.hpp>
|
|
#include <boost/thread/mutex.hpp>
|
|
#include <boost/iostreams/filtering_stream.hpp>
|
|
#include <boost/iostreams/filter/gzip.hpp>
|
|
#include "../util/string_piece.hh"
|
|
#include "../moses/Util.h"
|
|
|
|
// The Vowpal Wabbit types are only forward-declared: this header uses them
// solely through pointers, so it does not need to include any VW header.
class ezexample;
struct vw;
|
|
|
|
namespace Discriminative
|
|
{
|
|
|
|
/**
|
|
* Abstract class to be implemented by classifiers.
|
|
*/
|
|
class Classifier
|
|
{
|
|
public:
|
|
/**
|
|
* Add a feature that does not depend on the class (label).
|
|
*/
|
|
virtual void AddLabelIndependentFeature(const StringPiece &name, float value) = 0;
|
|
|
|
/**
|
|
* Add a feature that is specific for the given class.
|
|
*/
|
|
virtual void AddLabelDependentFeature(const StringPiece &name, float value) = 0;
|
|
|
|
/**
|
|
* Train using current example. Use loss to distinguish positive and negative training examples.
|
|
* Throws away current label-dependent features (so that features for another label/class can now be set).
|
|
*/
|
|
virtual void Train(const StringPiece &label, float loss) = 0;
|
|
|
|
/**
|
|
* Predict the loss (inverse of score) of current example.
|
|
* Throws away current label-dependent features (so that features for another label/class can now be set).
|
|
*/
|
|
virtual float Predict(const StringPiece &label) = 0;
|
|
|
|
// helper methods for indicator features
|
|
void AddLabelIndependentFeature(const StringPiece &name) {
|
|
AddLabelIndependentFeature(name, 1.0);
|
|
}
|
|
|
|
void AddLabelDependentFeature(const StringPiece &name) {
|
|
AddLabelDependentFeature(name, 1.0);
|
|
}
|
|
|
|
virtual ~Classifier() {}
|
|
|
|
protected:
|
|
/**
|
|
* Escape special characters in a unified way.
|
|
*/
|
|
static std::string EscapeSpecialChars(const std::string &str) {
|
|
std::string out;
|
|
out = Moses::Replace(str, "\\", "_/_");
|
|
out = Moses::Replace(out, "|", "\\/");
|
|
out = Moses::Replace(out, ":", "\\;");
|
|
out = Moses::Replace(out, " ", "\\_");
|
|
return out;
|
|
}
|
|
|
|
const static bool DEBUG = false;
|
|
|
|
};
|
|
|
|
// some of VW settings are hard-coded because they are always needed in our scenario
|
|
// (e.g. quadratic source X target features)
|
|
const std::string VW_DEFAULT_OPTIONS = " --hash all --noconstant -q st -t --ldf_override s ";
|
|
const std::string VW_DEFAULT_PARSER_OPTIONS = " --quiet --hash all --noconstant -q st -t --csoaa_ldf s ";
|
|
|
|
/**
|
|
* Produce VW training file (does not use the VW library!)
|
|
*/
|
|
class VWTrainer : public Classifier
|
|
{
|
|
public:
|
|
VWTrainer(const std::string &outputFile);
|
|
virtual ~VWTrainer();
|
|
|
|
virtual void AddLabelIndependentFeature(const StringPiece &name, float value);
|
|
virtual void AddLabelDependentFeature(const StringPiece &name, float value);
|
|
virtual void Train(const StringPiece &label, float loss);
|
|
virtual float Predict(const StringPiece &label);
|
|
|
|
protected:
|
|
void AddFeature(const StringPiece &name, float value);
|
|
|
|
bool m_isFirstSource, m_isFirstTarget, m_isFirstExample;
|
|
|
|
private:
|
|
boost::iostreams::filtering_ostream m_bfos;
|
|
std::deque<std::string> m_outputBuffer;
|
|
|
|
void WriteBuffer();
|
|
};
|
|
|
|
/**
|
|
* Predict using VW library.
|
|
*/
|
|
class VWPredictor : public Classifier, private boost::noncopyable
|
|
{
|
|
public:
|
|
VWPredictor(const std::string &modelFile, const std::string &vwOptions);
|
|
virtual ~VWPredictor();
|
|
|
|
virtual void AddLabelIndependentFeature(const StringPiece &name, float value);
|
|
virtual void AddLabelDependentFeature(const StringPiece &name, float value);
|
|
virtual void Train(const StringPiece &label, float loss);
|
|
virtual float Predict(const StringPiece &label);
|
|
|
|
friend class ClassifierFactory;
|
|
|
|
protected:
|
|
void AddFeature(const StringPiece &name, float values);
|
|
|
|
::vw *m_VWInstance, *m_VWParser;
|
|
::ezexample *m_ex;
|
|
// if true, then the VW instance is owned by an external party and should NOT be
|
|
// deleted at end; if false, then we own the VW instance and must clean up after it.
|
|
bool m_sharedVwInstance;
|
|
bool m_isFirstSource, m_isFirstTarget;
|
|
|
|
private:
|
|
// instantiation by classifier factory
|
|
VWPredictor(vw * instance, const std::string &vwOption);
|
|
};
|
|
|
|
/**
|
|
* Provider for classifier instances to be used by individual threads.
|
|
*/
|
|
class ClassifierFactory : private boost::noncopyable
|
|
{
|
|
public:
|
|
typedef boost::shared_ptr<Classifier> ClassifierPtr;
|
|
|
|
/**
|
|
* Creates VWPredictor instances to be used by individual threads.
|
|
*/
|
|
ClassifierFactory(const std::string &modelFile, const std::string &vwOptions);
|
|
|
|
/**
|
|
* Creates VWTrainer instances (which write features to a file).
|
|
*/
|
|
ClassifierFactory(const std::string &modelFilePrefix);
|
|
|
|
// return VWPredictor or VWTrainer instance depending on whether we're in training mode
|
|
ClassifierPtr operator()();
|
|
|
|
~ClassifierFactory();
|
|
|
|
private:
|
|
std::string m_vwOptions;
|
|
::vw *m_VWInstance;
|
|
int m_lastId;
|
|
std::string m_modelFilePrefix;
|
|
bool m_gzip;
|
|
boost::mutex m_mutex;
|
|
const bool m_train;
|
|
};
|
|
|
|
} // namespace Discriminative
|
|
|
|
#endif // moses_Classifier_h
|