// mosesdecoder/vw/ClassifierFactory.cpp

#include "Classifier.h"
#include "vw.h"
#include "../moses/Util.h"
#include <iostream>
#include <boost/algorithm/string/predicate.hpp>
using namespace boost::algorithm;
namespace Discriminative
{
2015-01-14 14:07:42 +03:00
ClassifierFactory::ClassifierFactory(const std::string &modelFile, const std::string &vwOptions)
: m_vwOptions(vwOptions), m_train(false)
{
m_VWInstance = VW::initialize(VW_DEFAULT_OPTIONS + " -i " + modelFile + vwOptions);
}
2015-01-14 14:07:42 +03:00
ClassifierFactory::ClassifierFactory(const std::string &modelFilePrefix)
: m_lastId(0), m_train(true)
2015-01-14 14:07:42 +03:00
{
if (ends_with(modelFilePrefix, ".gz")) {
m_modelFilePrefix = modelFilePrefix.substr(0, modelFilePrefix.size() - 3);
m_gzip = true;
} else {
m_modelFilePrefix = modelFilePrefix;
m_gzip = false;
}
}
2015-01-14 14:07:42 +03:00
ClassifierFactory::~ClassifierFactory()
2015-01-09 14:02:39 +03:00
{
if (! m_train)
VW::finish(*m_VWInstance);
}
2015-01-14 14:07:42 +03:00
ClassifierFactory::ClassifierPtr ClassifierFactory::operator()()
{
if (m_train) {
boost::unique_lock<boost::mutex> lock(m_mutex); // avoid possible race for m_lastId
return ClassifierFactory::ClassifierPtr(
2015-01-14 14:07:42 +03:00
new VWTrainer(m_modelFilePrefix + "." + Moses::SPrint(m_lastId++) + (m_gzip ? ".gz" : "")));
} else {
return ClassifierFactory::ClassifierPtr(
2015-01-14 14:07:42 +03:00
new VWPredictor(m_VWInstance, VW_DEFAULT_PARSER_OPTIONS + m_vwOptions));
}
}
}