This commit is contained in:
Hieu Hoang 2017-01-21 23:12:58 +00:00
commit 266f56f312
16 changed files with 161 additions and 72 deletions

View File

@ -41,9 +41,9 @@ The project is a standard Cmake out-of-source build:
cmake ..
make -j
If you want to compile only CPU version on a machine with CUDA, add `-DNOCUDA=ON` flag:
If you want to compile only the CPU version on a machine with CUDA, add the `-DCUDA=OFF` flag:
cmake -DNOCUDA=ON ..
cmake -DCUDA=OFF ..
## Vocabulary files
Vocabulary files (and all other config files) in AmuNMT are by default YAML files. AmuNMT also reads gzipped yml.gz files.

View File

@ -53,6 +53,11 @@
<link>
<name>plugin</name>
<type>2</type>
<locationURI>virtual:/virtual</locationURI>
</link>
<link>
<name>python</name>
<type>2</type>
<locationURI>PARENT-1-PROJECT_LOC/src/plugin</locationURI>
</link>
<link>
@ -70,5 +75,40 @@
<type>1</type>
<locationURI>PARENT-1-PROJECT_LOC/src/cnpy/cnpy.h</locationURI>
</link>
<link>
<name>plugin/nbest.cu</name>
<type>1</type>
<locationURI>PARENT-1-PROJECT_LOC/src/plugin/nbest.cu</locationURI>
</link>
<link>
<name>plugin/nbest.h</name>
<type>1</type>
<locationURI>PARENT-1-PROJECT_LOC/src/plugin/nbest.h</locationURI>
</link>
<link>
<name>plugin/neural_phrase.h</name>
<type>1</type>
<locationURI>PARENT-1-PROJECT_LOC/src/plugin/neural_phrase.h</locationURI>
</link>
<link>
<name>plugin/nmt.cu</name>
<type>1</type>
<locationURI>PARENT-1-PROJECT_LOC/src/plugin/nmt.cu</locationURI>
</link>
<link>
<name>plugin/nmt.h</name>
<type>1</type>
<locationURI>PARENT-1-PROJECT_LOC/src/plugin/nmt.h</locationURI>
</link>
<link>
<name>python/amunmt.cpp</name>
<type>1</type>
<locationURI>PARENT-1-PROJECT_LOC/src/python/amunmt.cpp</locationURI>
</link>
<link>
<name>python/test.py</name>
<type>1</type>
<locationURI>PARENT-1-PROJECT_LOC/src/python/test.py</locationURI>
</link>
</linkedResources>
</projectDescription>

View File

@ -15,8 +15,8 @@
#include "common/translation_task.h"
int main(int argc, char* argv[]) {
God* god = new God();
god->Init(argc, argv);
God god;
god.Init(argc, argv);
std::setvbuf(stdout, NULL, _IONBF, 0);
std::setvbuf(stdin, NULL, _IONBF, 0);
boost::timer::cpu_timer timer;
@ -25,22 +25,22 @@ int main(int argc, char* argv[]) {
std::size_t lineNum = 0;
std::size_t taskCounter = 0;
size_t bunchSize = god->Get<size_t>("bunch-size");
size_t maxBatchSize = god->Get<size_t>("batch-size");
std::cerr << "mode=" << god->Get("mode") << std::endl;
size_t bunchSize = god.Get<size_t>("bunch-size");
size_t maxBatchSize = god.Get<size_t>("batch-size");
std::cerr << "mode=" << god.Get("mode") << std::endl;
if (god->Get<bool>("wipo") || god->Get<size_t>("cpu-threads")) {
if (god.Get<bool>("wipo") || god.Get<size_t>("cpu-threads")) {
bunchSize = 1;
maxBatchSize = 1;
}
size_t cpuThreads = god->Get<size_t>("cpu-threads");
size_t cpuThreads = god.Get<size_t>("cpu-threads");
LOG(info) << "Setting CPU thread count to " << cpuThreads;
size_t totalThreads = cpuThreads;
#ifdef CUDA
size_t gpuThreads = god->Get<size_t>("gpu-threads");
auto devices = god->Get<std::vector<size_t>>("devices");
size_t gpuThreads = god.Get<size_t>("gpu-threads");
auto devices = god.Get<std::vector<size_t>>("devices");
LOG(info) << "Setting GPU thread count to " << gpuThreads;
totalThreads += gpuThreads * devices.size();
#endif
@ -53,12 +53,12 @@ int main(int argc, char* argv[]) {
std::shared_ptr<Sentences> sentences(new Sentences());
while (std::getline(god->GetInputStream(), in)) {
sentences->push_back(SentencePtr(new Sentence(*god, lineNum++, in)));
while (std::getline(god.GetInputStream(), in)) {
sentences->push_back(SentencePtr(new Sentence(god, lineNum++, in)));
if (sentences->size() >= maxBatchSize * bunchSize) {
pool->enqueue(
[=]{ return TranslationTask(*god, sentences, taskCounter, maxBatchSize); }
[=,&god]{ return TranslationTask(god, sentences, taskCounter, maxBatchSize); }
);
sentences.reset(new Sentences());
@ -69,15 +69,14 @@ int main(int argc, char* argv[]) {
if (sentences->size()) {
pool->enqueue(
[=]{ return TranslationTask(*god, sentences, taskCounter, maxBatchSize); }
[=,&god]{ return TranslationTask(god, sentences, taskCounter, maxBatchSize); }
);
}
delete pool;
LOG(info) << "Total time: " << timer.format();
god->CleanUp();
delete god;
god.CleanUp();
return 0;
}

View File

@ -18,9 +18,7 @@ namespace util {
Exception::Exception() throw() {}
Exception::~Exception() throw() {}
Exception::Exception(const Exception& o) throw() {
what_.str(o.what_.str());
}
Exception::Exception(const Exception& o) throw() : what_(o.what_) {}
void Exception::SetLocation(const char *file, unsigned int line, const char *func, const char *child_name, const char *condition) {
/* The child class might have set some text, but we want this to come first.
@ -28,24 +26,26 @@ void Exception::SetLocation(const char *file, unsigned int line, const char *fun
* then child classes would have to accept constructor arguments and pass
* them down.
*/
std::string old_text = what_.str();
what_.str(std::string());
what_ << file << ':' << line;
if (func) what_ << " in " << func << " threw ";
std::string old_text = what_;
std::swap(old_text, what_);
std::stringstream stream;
stream << file << ':' << line;
if (func) stream << " in " << func << " threw ";
if (child_name) {
what_ << child_name;
stream << child_name;
} else {
#ifdef __GXX_RTTI
what_ << typeid(this).name();
stream << typeid(this).name();
#else
what_ << "an exception";
stream << "an exception";
#endif
}
if (condition) {
what_ << " because `" << condition << '\'';
stream << " because `" << condition << '\'';
}
what_ << ".\n";
what_ << old_text;
stream << ".\n";
stream << old_text;
what_ = stream.str();
}
namespace {

View File

@ -17,7 +17,7 @@ class Exception : public std::exception {
virtual ~Exception() throw();
Exception(const Exception& o) throw();
const char *what() const throw() { return what_.str().c_str(); }
const char *what() const throw() { return what_.c_str(); }
// For use by the UTIL_THROW macros.
void SetLocation(
@ -35,7 +35,22 @@ class Exception : public std::exception {
typedef T Identity;
};
std::stringstream what_;
// Appends a NUL-terminated C string to the accumulated message text.
void Append(const char *data) {
  what_.append(data);
}
// Appends a std::string to the accumulated message text.
void Append(const std::string &data) {
  what_.append(data);
}
/* void Append(StringPiece data) {
what_.append(data.data(), data.size());
}*/
// Generic fallback: formats any streamable value through a temporary
// stringstream and appends the resulting text to the message.
// A fresh stream is built on every call.
template <class Data> void Append(const Data &data) {
  std::stringstream formatter;
  formatter << data;
  const std::string text = formatter.str();
  what_ += text;
}
std::string what_;
};
/* This implements the normal operator<< for Exception and all its children.
@ -43,10 +58,11 @@ class Exception : public std::exception {
* boost::enable_if.
*/
template <class Except, class Data> typename Except::template ExceptionTag<Except&>::Identity operator<<(Except &e, const Data &data) {
e.what_ << data;
e.Append(data);
return e;
}
#ifdef __GNUC__
#define UTIL_FUNC_NAME __PRETTY_FUNCTION__
#else

View File

@ -12,12 +12,20 @@
#include "common/filter.h"
#include "common/processor/bpe.h"
#include "common/utils.h"
#include "common/search.h"
#include "scorer.h"
#include "loader_factory.h"
// Default constructor: starts the numGPUThreads_ counter (read and
// incremented by GetSearch) at zero. All other members are
// default-initialised.
God::God() : numGPUThreads_(0) {}
// Empty destructor, defined out of line in this translation unit.
// NOTE(review): presumably kept here so members that hold forward-declared
// types (e.g. Search) are destroyed where those types are complete -- confirm.
God::~God()
{
}
God& God::Init(const std::string& options) {
std::vector<std::string> args = boost::program_options::split_unix(options);
int argc = args.size() + 1;
@ -169,12 +177,12 @@ OutputCollector& God::GetOutputCollector() {
return outputCollector_;
}
std::vector<ScorerPtr> God::GetScorers(size_t threadId) {
std::vector<ScorerPtr> God::GetScorers(DeviceType deviceType, size_t threadId) {
std::vector<ScorerPtr> scorers;
size_t cpuThreads = God::Get<size_t>("cpu-threads");
if (threadId < cpuThreads) {
if (deviceType == CPUDevice) {
for (auto&& loader : cpuLoaders_ | boost::adaptors::map_values)
scorers.emplace_back(loader->NewScorer(*this, threadId));
} else {
@ -232,3 +240,23 @@ void God::CleanUp() {
loader.reset(nullptr);
}
}
// Returns the Search instance belonging to the calling thread, creating it
// on first use.
//
// search_ is a boost::thread_specific_ptr, so each thread sees its own slot.
// When the slot is empty a new Search is built while holding an exclusive
// lock on m_accessLock, because the device decision reads and updates the
// shared numGPUThreads_ counter.
//
// taskCounter is forwarded to the Search constructor; it is only used when
// the calling thread has no Search yet.
//
// NOTE(review): numGPUThreads_ is incremented for every Search created, CPU
// ones included, so it effectively counts created Search objects. The first
// "gpu-threads" creations get GPUDevice and all later ones get CPUDevice.
Search &God::GetSearch(size_t taskCounter)
{
  // Fast path: this thread already has its Search.
  if (Search *existing = search_.get()) {
    return *existing;
  }

  // Slow path: serialise the device assignment across threads.
  boost::unique_lock<boost::shared_mutex> guard(m_accessLock);

  const size_t gpuThreadLimit = God::Get<size_t>("gpu-threads");
  const DeviceType assignedDevice =
      (numGPUThreads_ < gpuThreadLimit) ? GPUDevice : CPUDevice;
  ++numGPUThreads_;

  // Hand the new object to the thread-specific slot.
  Search *created = new Search(*this, assignedDevice, taskCounter);
  search_.reset(created);

  assert(created);
  return *created;
}

View File

@ -1,6 +1,8 @@
#pragma once
#include <memory>
#include <iostream>
#include <boost/thread/tss.hpp>
#include <boost/thread/shared_mutex.hpp>
#include "common/processor/processor.h"
#include "common/config.h"
@ -21,9 +23,11 @@ class Weights;
class Vocab;
class Filter;
class InputFileStream;
class Search;
class God {
public:
God();
virtual ~God();
God& Init(const std::string&);
@ -53,7 +57,7 @@ class God {
BestHypsBase &GetBestHyps(size_t threadId);
std::vector<ScorerPtr> GetScorers(size_t);
std::vector<ScorerPtr> GetScorers(DeviceType deviceType, size_t);
std::vector<std::string> GetScorerNames();
std::map<std::string, float>& GetScorerWeights();
@ -64,6 +68,8 @@ class God {
void LoadWeights(const std::string& path);
Search &GetSearch(size_t taskCounter);
private:
void LoadScorers();
void LoadFiltering();
@ -89,4 +95,8 @@ class God {
std::unique_ptr<InputFileStream> inputStream_;
OutputCollector outputCollector_;
mutable boost::shared_mutex m_accessLock;
mutable boost::thread_specific_ptr<Search> search_;
size_t numGPUThreads_;
};

View File

@ -2,7 +2,7 @@
std::vector<size_t> GetAlignment(const HypothesisPtr& hypothesis) {
std::vector<SoftAlignment> aligns;
HypothesisPtr last = hypothesis;
HypothesisPtr last = hypothesis->GetPrevHyp();
while (last->GetPrevHyp().get() != nullptr) {
aligns.push_back(*(last->GetAlignment(0)));
last = last->GetPrevHyp();
@ -21,3 +21,13 @@ std::vector<size_t> GetAlignment(const HypothesisPtr& hypothesis) {
return alignment;
}
// Renders an alignment vector as text.
//
// The result always starts with " |||". For every element it then appends
// " <index>-<value>", where <index> is the element's position in the vector
// and <value> is the number stored there. An empty vector yields " |||".
std::string GetAlignmentString(const std::vector<size_t>& alignment) {
  std::ostringstream rendered;
  rendered << " |||";

  size_t position = 0;
  for (const size_t aligned : alignment) {
    rendered << ' ' << position << '-' << aligned;
    ++position;
  }

  return rendered.str();
}

View File

@ -10,17 +10,16 @@
std::vector<size_t> GetAlignment(const HypothesisPtr& hypothesis);
std::string GetAlignmentString(const std::vector<size_t>& alignment);
template <class OStream>
void Printer(God &god, const History& history, OStream& out) {
auto bestTranslation = history.Top();
std::vector<std::string> bestSentenceWords = god.Postprocess(god.GetTargetVocab()(bestTranslation.first));
std::string best;
std::string best = Join(bestSentenceWords);
if (god.Get<bool>("return-alignment")) {
auto alignment = GetAlignment(bestTranslation.second);
best = Join(bestSentenceWords, alignment);
} else {
best = Join(bestSentenceWords);
best += GetAlignmentString(GetAlignment(bestTranslation.second));
}
LOG(progress) << "Best translation: " << best;
@ -38,12 +37,9 @@ void Printer(God &god, const History& history, OStream& out) {
if(god.Get<bool>("wipo")) {
out << "OUT: ";
}
std::string translation;
std::string translation = Join(god.Postprocess(god.GetTargetVocab()(words)));
if (god.Get<bool>("return-alignment")) {
auto alignment = GetAlignment(bestTranslation.second);
translation = Join(god.Postprocess(god.GetTargetVocab()(words)), alignment);
} else {
translation = Join(god.Postprocess(god.GetTargetVocab()(words)));
translation += GetAlignmentString(GetAlignment(bestTranslation.second));
}
out << history.GetLineNum() << " ||| " << translation << " |||";
@ -57,8 +53,7 @@ void Printer(God &god, const History& history, OStream& out) {
out << " ||| " << hypo->GetCost() << std::endl;
}
}
}
else {
} else {
out << best << std::endl;
}
}

View File

@ -9,8 +9,8 @@
using namespace std;
Search::Search(God &god, size_t threadId)
: scorers_(god.GetScorers(threadId)),
Search::Search(God &god, DeviceType deviceType, size_t threadId)
: scorers_(god.GetScorers(deviceType, threadId)),
bestHyps_(god.GetBestHyps(threadId)) {
}

View File

@ -9,7 +9,7 @@
class Search {
public:
Search(God &god, size_t threadId);
Search(God &god, DeviceType deviceType, size_t threadId);
std::shared_ptr<Histories> Decode(God &god, const Sentences& sentences);
private:

View File

@ -5,11 +5,7 @@
#include "printer.h"
void TranslationTask(God &god, std::shared_ptr<Sentences> sentences, size_t taskCounter, size_t maxBatchSize) {
thread_local std::unique_ptr<Search> search;
if(!search) {
LOG(info) << "Created Search for thread " << std::this_thread::get_id();
search.reset(new Search(god, taskCounter));
}
Search &search = god.GetSearch(taskCounter);
try {
Histories allHistories;
@ -23,7 +19,7 @@ void TranslationTask(God &god, std::shared_ptr<Sentences> sentences, size_t task
if (decodeSentences->size() >= maxBatchSize) {
assert(decodeSentences->size());
std::shared_ptr<Histories> histories = search->Decode(god, *decodeSentences);
std::shared_ptr<Histories> histories = search.Decode(god, *decodeSentences);
allHistories.Append(*histories.get());
decodeSentences.reset(new Sentences(taskCounter, bunchId++));
@ -31,7 +27,7 @@ void TranslationTask(God &god, std::shared_ptr<Sentences> sentences, size_t task
}
if (decodeSentences->size()) {
std::shared_ptr<Histories> histories = search->Decode(god, *decodeSentences);
std::shared_ptr<Histories> histories = search.Decode(god, *decodeSentences);
allHistories.Append(*histories.get());
}

View File

@ -10,3 +10,8 @@ typedef std::vector<Word> Words;
const Word EOS = 0;
const Word UNK = 1;
// Selects which kind of device a component should run on.
// Kept as an unscoped enum: the enumerators are used unqualified elsewhere.
enum DeviceType {
  CPUDevice = 0,  // run on the CPU
  GPUDevice = 1   // run on the GPU
};

View File

@ -5,11 +5,6 @@
#include "common/base_matrix.h"
#ifdef __APPLE__
#include <boost/thread/tss.hpp>
#include <boost/pool/object_pool.hpp>
#endif
#include "gpu/types-gpu.h"
namespace GPU {

View File

@ -19,7 +19,7 @@ void MosesPlugin::initGod(const std::string& configPath) {
god_ = new God();
god_->Init(configs);
scorers_ = god_->GetScorers(1);
scorers_ = god_->GetScorers(CPUDevice, 1);
bestHyps_ = &god_->GetBestHyps(1);
}

View File

@ -16,16 +16,11 @@
God *god_;
std::shared_ptr<Histories> TranslationTask(const std::string& in, size_t taskCounter) {
thread_local std::unique_ptr<Search> search;
if(!search) {
LOG(info) << "Created Search for thread " << std::this_thread::get_id();
search.reset(new Search(*god_, taskCounter));
}
Search &search = god_->GetSearch(taskCounter);
std::shared_ptr<Sentences> sentences(new Sentences());
sentences->push_back(SentencePtr(new Sentence(*god_, taskCounter, in)));
return search->Decode(*god_, *sentences);
return search.Decode(*god_, *sentences);
}
void init(const std::string& options) {