replace classifier factory with Marcin's TSS

This commit is contained in:
Ales Tamchyna 2015-01-09 11:03:27 +01:00
parent 7d5cf002ee
commit 9694448e6a
7 changed files with 109 additions and 168 deletions

View File

@ -1,6 +0,0 @@
#include "VW.h"
namespace Moses
{
// Out-of-class definition of the static member VW::m_targetSentenceMap,
// a boost::thread_specific_ptr holding a VWTargetSentenceMap.
boost::thread_specific_ptr<VWTargetSentenceMap> VW::m_targetSentenceMap;
}

View File

@ -2,7 +2,6 @@
#include <string>
#include <map>
#include <boost/thread/tss.hpp>
#include "moses/FF/StatelessFeatureFunction.h"
#include "moses/TranslationOptionList.h"
@ -17,35 +16,52 @@
#include "Classifier.h"
#include "VWFeatureBase.h"
#include "TabbedSentence.h"
#include "ThreadLocalByFeatureStorage.h"
namespace Moses
{
const std::string VW_DUMMY_LABEL = "1111"; // VW does not use the actual label, other classifiers might
/**
* VW thread-specific data about target sentence.
*/
struct VWTargetSentence {
VWTargetSentence(const Phrase &sentence, const AlignmentInfo &alignment)
: m_sentence(sentence), m_alignment(alignment)
{}
VWTargetSentence() : m_sentence(NULL), m_alignment(NULL) {}
Phrase m_sentence;
AlignmentInfo m_alignment;
void Clear()
{
if (m_sentence) delete m_sentence;
if (m_alignment) delete m_alignment;
}
~VWTargetSentence()
{
Clear();
}
Phrase *m_sentence;
AlignmentInfo *m_alignment;
};
typedef std::map<std::string, VWTargetSentence> VWTargetSentenceMap;
typedef ThreadLocalByFeatureStorage<Discriminative::Classifier, Discriminative::ClassifierFactory &> TLSClassifier;
typedef ThreadLocalByFeatureStorage<VWTargetSentence> TLSTargetSentence;
class VW : public StatelessFeatureFunction
, public TLSTargetSentence
{
public:
VW(const std::string &line)
:StatelessFeatureFunction(1, line), m_train(false)
: StatelessFeatureFunction(1, line)
, TLSTargetSentence(this)
, m_train(false)
{
ReadParameters();
if (m_train) {
m_trainer = new Discriminative::VWTrainer(m_modelPath);
} else {
m_predictorFactory = new Discriminative::VWPredictorFactory(m_modelPath, m_vwOptions);
}
Discriminative::ClassifierFactory *classifierFactory = m_train
? new Discriminative::ClassifierFactory(m_modelPath)
: new Discriminative::ClassifierFactory(m_modelPath, m_vwOptions);
m_tlsClassifier = new TLSClassifier(this, *classifierFactory);
if (! m_normalizer) {
VERBOSE(1, "VW :: No loss function specified, assuming logistic loss.\n");
@ -74,9 +90,7 @@ public:
void EvaluateTranslationOptionListWithSourceContext(const InputType &input
, const TranslationOptionList &translationOptionList) const
{
Discriminative::Classifier *classifier = m_train
? m_trainer
: (Discriminative::Classifier *)m_predictorFactory->Acquire();
Discriminative::Classifier *classifier = m_tlsClassifier->GetStored();
if (translationOptionList.size() == 0)
return; // nothing to do
@ -112,9 +126,6 @@ public:
}
}
if (!m_train)
m_predictorFactory->Release(static_cast<Discriminative::VWPredictor *>(classifier));
(*m_normalizer)(losses);
for(iterTransOpt = translationOptionList.begin(), iterLoss = losses.begin() ;
@ -167,23 +178,23 @@ public:
const TabbedSentence& tabbedSentence = static_cast<const TabbedSentence&>(source);
UTIL_THROW_IF2(tabbedSentence.GetColumns().size() < 2, "TabbedSentence must contain target<tab>alignment");
if (! m_targetSentenceMap.get())
m_targetSentenceMap.reset(new VWTargetSentenceMap());
// target sentence represented as a phrase
Phrase target;
target.CreateFromString(
Output,
StaticData::Instance().GetOutputFactorOrder(),
tabbedSentence.GetColumns()[0],
NULL);
Phrase *target = new Phrase();
target->CreateFromString(
Output
, StaticData::Instance().GetOutputFactorOrder()
, tabbedSentence.GetColumns()[0]
, NULL);
// word alignment between source and target sentence
// we don't store alignment info in AlignmentInfoCollection because we keep alignments of whole
// sentences, not phrases
AlignmentInfo alignment(tabbedSentence.GetColumns()[1]);
AlignmentInfo *alignment = new AlignmentInfo(tabbedSentence.GetColumns()[1]);
(*m_targetSentenceMap).insert(std::make_pair(GetScoreProducerDescription(), VWTargetSentence(target, alignment)));
VWTargetSentence &targetSent = *GetStored();
targetSent.Clear();
targetSent.m_sentence = target;
targetSent.m_alignment = alignment;
}
@ -195,11 +206,11 @@ private:
bool IsCorrectTranslationOption(const TranslationOption &topt) const {
size_t sourceStart = topt.GetSourceWordsRange().GetStartPos();
const VWTargetSentence &targetSentence = m_targetSentenceMap->find(GetScoreProducerDescription())->second;
const VWTargetSentence &targetSentence = *GetStored();
// get the left-most alignment point within sourceRange
std::set<size_t> aligned;
while ((aligned = targetSentence.m_alignment.GetAlignmentsForSource(sourceStart)).empty())
while ((aligned = targetSentence.m_alignment->GetAlignmentsForSource(sourceStart)).empty())
sourceStart++;
size_t targetSentOffset = *aligned.begin(); // index of first aligned target word covered in source span
@ -216,7 +227,7 @@ private:
size_t startAt = targetSentOffset - toptOffset;
bool matches = true;
for (size_t i = 0; i < tphrase.GetSize(); i++) {
if (tphrase.GetWord(i) != targetSentence.m_sentence.GetWord(startAt + i)) {
if (tphrase.GetWord(i) != targetSentence.m_sentence->GetWord(startAt + i)) {
matches = false;
break;
}
@ -229,10 +240,7 @@ private:
std::string m_modelPath;
std::string m_vwOptions;
Discriminative::Normalizer *m_normalizer = NULL;
Discriminative::Classifier *m_trainer = NULL;
Discriminative::VWPredictorFactory *m_predictorFactory = NULL;
static boost::thread_specific_ptr<VWTargetSentenceMap> m_targetSentenceMap;
TLSClassifier *m_tlsClassifier;
};
}

View File

@ -126,7 +126,7 @@ public:
virtual void Train(const StringPiece &label, float loss);
virtual float Predict(const StringPiece &label);
friend class VWPredictorFactory;
friend class ClassifierFactory;
protected:
void AddFeature(const StringPiece &name, float values);
@ -137,45 +137,42 @@ protected:
// if true, then the VW instance is owned by an external party and should NOT be
// deleted at end; if false, then we own the VW instance and must clean up after it.
bool m_sharedVwInstance;
int m_index;
bool m_isFirstSource, m_isFirstTarget;
~VWPredictor();
private:
// instantiation by VWPredictorFactory
VWPredictor(vw * instance, int index, const std::string &vwOption);
// instantiation by classifier factory
VWPredictor(vw * instance, const std::string &vwOption);
};
/**
* Object pool of VWPredictors.
*/
class VWPredictorFactory : private boost::noncopyable
* Provider for classifier instances to be used by individual threads.
*/
class ClassifierFactory : private boost::noncopyable
{
public:
VWPredictorFactory(const std::string &modelFile, const std::string &vwOptions, const int poolSize = DEFAULT_POOL_SIZE);
/**
* Creates VWPredictor instances to be used by individual threads.
*/
ClassifierFactory(const std::string &modelFile, const std::string &vwOptions);
/**
* Get an instance of VWPredictor from the pool.
*/
VWPredictor * Acquire();
* Creates VWTrainer instances (which write features to a file).
*/
ClassifierFactory(const std::string &modelFilePrefix);
/**
* Release a VWPredictor instance.
*/
void Release(VWPredictor *vwpred);
~VWPredictorFactory();
// return VWPredictor or VWTrainer instance depending on whether we're in training mode
Classifier *operator()();
private:
std::string m_vwOptions;
::vw *m_VWInstance;
int m_firstFree;
std::vector<int> m_nextFree;
std::vector<VWPredictor *> m_predictors;
int m_lastId;
std::string m_modelFilePrefix;
bool m_gzip;
boost::mutex m_mutex;
boost::condition_variable m_cond;
const static int DEFAULT_POOL_SIZE = 128;
const bool m_train;
};
} // namespace Discriminative

37
vw/ClassifierFactory.cpp Normal file
View File

@ -0,0 +1,37 @@
#include "Classifier.h"
#include "vw.h"
#include "../moses/Util.h"
#include <iostream>
namespace Discriminative
{
/**
 * Prediction-mode constructor: initialises one VW instance that operator()
 * later passes to every VWPredictor it creates.
 *
 * @param modelFile  path handed to VW after the "-i" option
 * @param vwOptions  additional options, appended verbatim to the VW command line
 */
ClassifierFactory::ClassifierFactory(const std::string &modelFile, const std::string &vwOptions)
  : m_vwOptions(vwOptions)
  , m_VWInstance(NULL)
  , m_lastId(0)     // training-only members are still initialised so that
  , m_gzip(false)   // no member of the factory is left indeterminate
  , m_train(false)
{
  // NOTE(review): modelFile and vwOptions are concatenated without a separator;
  // presumably vwOptions is expected to start with a space -- confirm with callers.
  m_VWInstance = VW::initialize(VW_DEFAULT_OPTIONS + " -i " + modelFile + vwOptions);
}
/**
 * Training-mode constructor: remembers the output file prefix that operator()
 * uses to name the file of each VWTrainer it creates.
 *
 * A trailing ".gz" is stripped from the prefix and remembered in m_gzip so
 * that it can be re-appended after the per-trainer numeric id.
 *
 * @param modelFilePrefix  output path, optionally ending in ".gz"
 */
ClassifierFactory::ClassifierFactory(const std::string &modelFilePrefix)
  : m_VWInstance(NULL)  // no VW instance exists in training mode; never leave the pointer indeterminate
  , m_lastId(0)
  , m_gzip(false)
  , m_train(true)
{
  const std::string::size_type len = modelFilePrefix.size();
  // strictly longer than the suffix: a name that is exactly ".gz" is kept as-is
  if (len > 3 && modelFilePrefix.compare(len - 3, 3, ".gz") == 0) {
    m_modelFilePrefix = modelFilePrefix.substr(0, len - 3);
    m_gzip = true;
  } else {
    m_modelFilePrefix = modelFilePrefix;
  }
}
/**
 * Creates a new classifier with `new` and returns it.
 *
 * Prediction mode: a VWPredictor built on the shared VW instance.
 * Training mode:   a VWTrainer writing to "<prefix>.<id>[.gz]", where <id> is
 *                  a counter incremented on every call.
 */
Classifier *ClassifierFactory::operator()()
{
  if (!m_train) {
    return new VWPredictor(m_VWInstance, VW_DEFAULT_PARSER_OPTIONS + m_vwOptions);
  }

  // m_lastId is shared by all callers; hold the mutex while it is read and bumped
  boost::unique_lock<boost::mutex> guard(m_mutex);
  const std::string extension = m_gzip ? ".gz" : "";
  const std::string trainerPath = m_modelFilePrefix + "." + Moses::SPrint(m_lastId++) + extension;
  return new VWTrainer(trainerPath);
}
}

View File

@ -8,10 +8,12 @@ local with-vw = [ option.get "with-vw" ] ;
if $(with-vw) {
lib vw : : <search>$(with-vw)/lib ;
lib allreduce : : <search>$(with-vw)/lib ;
obj VWPredictorFactory.o : VWPredictorFactory.cpp headers : <include>$(with-vw)/include/vowpalwabbit ;
obj ClassifierFactory.o : ClassifierFactory.cpp headers : <include>$(with-vw)/include/vowpalwabbit ;
obj VWPredictor.o : VWPredictor.cpp headers : <include>$(with-vw)/include/vowpalwabbit ;
alias vw_objects : VWPredictor.o VWPredictorFactory.o vw allreduce : : : <library>boost_program_options ;
lib classifier : [ glob *.cpp : VWPredictor.cpp VWPredictorFactory.cpp ] vw_objects headers ;
alias vw_objects : VWPredictor.o ClassifierFactory.o vw allreduce : : : <library>boost_program_options ;
lib classifier : [ glob *.cpp : VWPredictor.cpp ClassifierFactory.cpp ] vw_objects headers ;
exe vwtrainer : MainVW deps ;
echo "Linking with Vowpal Wabbit" ;

View File

@ -18,13 +18,12 @@ VWPredictor::VWPredictor(const string &modelFile, const string &vwOptions)
m_isFirstSource = m_isFirstTarget = true;
}
VWPredictor::VWPredictor(vw *instance, int index, const string &vwOptions)
VWPredictor::VWPredictor(vw *instance, const string &vwOptions)
{
m_VWInstance = instance;
m_VWParser = VW::initialize(vwOptions + " --noop");
m_sharedVwInstance = true;
m_ex = new ::ezexample(m_VWInstance, false, m_VWParser);
m_index = index;
m_isFirstSource = m_isFirstTarget = true;
}

View File

@ -1,96 +0,0 @@
#include "Classifier.h"
#include "vw.h"
#include <iostream>
using namespace std;
namespace Discriminative
{
// Free-list terminator: value of m_firstFree / m_nextFree[i] when no further free slot follows.
const int EMPTY_LIST = -1;
// Marker written over free-list links by the destructor while it walks the list;
// encountering it during the walk is reported as a corrupted free list.
const int BAD_LIST_POINTER = -2;
/**
 * Initialises one VW instance and fills a fixed-size pool of VWPredictor
 * objects that share it. All pool slots start out on the free list.
 *
 * @param modelFile  path handed to VW after the "-i" option
 * @param vwOptions  additional options, appended verbatim to the VW command lines
 * @param poolSize   number of predictors to create; must be at least 1
 * @throws std::runtime_error if poolSize is smaller than 1
 */
VWPredictorFactory::VWPredictorFactory(
  const string &modelFile,
  const string &vwOptions,
  const int poolSize)
{
  // Validate before VW is initialised: the destructor does not run when a
  // constructor throws, so failing later would leak the VW instance.
  if (poolSize < 1)
    throw runtime_error("VWPredictorFactory pool size must be greater than zero!");

  m_VWInstance = VW::initialize(VW_DEFAULT_OPTIONS + " -i " + modelFile + vwOptions);

  m_predictors.reserve(poolSize);
  m_nextFree.reserve(poolSize);

  // Build the free list back to front: slot i links to slot i-1, slot 0 ends the list.
  int lastFree = EMPTY_LIST;
  if (VWPredictor::DEBUG) std::cerr << "VW :: filling VWPredictor pool: ";
  for (int i = 0; i < poolSize; ++i) {
    m_predictors.push_back(new VWPredictor(m_VWInstance, i, VW_DEFAULT_PARSER_OPTIONS + vwOptions));
    m_nextFree.push_back(lastFree);
    lastFree = i;
    if (VWPredictor::DEBUG) std::cerr << ".";
  }
  if (VWPredictor::DEBUG) std::cerr << "done.\n";

  m_firstFree = lastFree;
}
// Checks that every predictor is back on the free list, then deletes the
// predictors and finishes the shared VW instance.
// NOTE(review): this destructor throws std::runtime_error on a failed check;
// if the build treats destructors as noexcept (the C++11 default) that ends in
// std::terminate, and in either case the predictors and the VW instance are
// not released -- confirm this is intended.
VWPredictorFactory::~VWPredictorFactory()
{
boost::unique_lock<boost::mutex> lock(m_mutex);
size_t count = 0;
int prev = EMPTY_LIST;
// Walk the free list, counting entries. The link that led to the current node
// is overwritten with BAD_LIST_POINTER one step behind the walk, so a list
// that loops back onto an already visited node yields BAD_LIST_POINTER as the
// next index and is caught by the check below.
for (int cur = m_firstFree; cur != EMPTY_LIST; cur = m_nextFree[cur])
{
if (cur == BAD_LIST_POINTER)
throw std::runtime_error("VWPredictorFactory::~VWPredictorFactory -- bad free list!");
++count;
if (prev == EMPTY_LIST)
m_firstFree = BAD_LIST_POINTER;
else
m_nextFree[prev] = BAD_LIST_POINTER;
prev = cur;
}
// overwrite the link of the last visited node as well
if (prev != EMPTY_LIST)
m_nextFree[prev] = BAD_LIST_POINTER;
// every slot of the pool must have been on the free list
if (count != m_nextFree.size())
throw std::runtime_error("VWPredictorFactory::~VWPredictorFactory -- not all consumers were returned to pool at destruction time!");
for (size_t s = 0; s < m_predictors.size(); ++s)
{
delete m_predictors[s];
m_predictors[s] = NULL;
}
m_predictors.clear();
// the VW instance is finished only after all predictors have been deleted
VW::finish(*m_VWInstance);
}
/**
 * Takes the predictor at the head of the free list and returns it.
 * Blocks on the condition variable while the free list is empty.
 */
VWPredictor *VWPredictorFactory::Acquire()
{
  boost::unique_lock<boost::mutex> guard(m_mutex);

  // re-check after every wake-up: the list may still (or again) be empty
  for (;;) {
    if (m_firstFree != EMPTY_LIST) break;
    m_cond.wait(guard);
  }

  // unlink the head slot from the free list
  const int slot = m_firstFree;
  m_firstFree = m_nextFree[slot];
  return m_predictors[slot];
}
/**
 * Puts a predictor back at the head of the free list and wakes one waiter.
 *
 * @param vwpred  predictor to return; its m_index must name the pool slot
 *                that holds this very pointer
 * @throws std::runtime_error if the index is out of range or the slot holds
 *         a different pointer
 */
void VWPredictorFactory::Release(VWPredictor *vwpred)
{
  {
    // the inner scope bounds the lifetime of the lock
    boost::unique_lock<boost::mutex> guard(m_mutex);

    const int slot = vwpred->m_index;
    const bool inRange = slot >= 0 && slot < (int)m_predictors.size();
    if (!inRange)
      throw std::runtime_error("bad index at VWPredictorFactory::Release");

    if (m_predictors[slot] != vwpred)
      throw std::runtime_error("mismatched pointer at VWPredictorFactory::Release");

    // push the slot onto the front of the free list
    m_nextFree[slot] = m_firstFree;
    m_firstFree = slot;
  }

  // notify only once the lock has been released
  m_cond.notify_one();
}
} // namespace Discriminative