This commit is contained in:
Hieu Hoang 2017-02-20 15:45:32 +00:00
commit e309f4f369
8 changed files with 162 additions and 99 deletions

View File

@ -22,6 +22,7 @@ def handle_websocket():
while True:
try:
message = wsock.receive()
#print message
if message is not None:
trans = nmt.translate(message.split('\n'))
wsock.send('\n'.join(trans))

View File

@ -2,13 +2,37 @@
from websocket import create_connection
import time
import sys
with open("testfile.en") as f:
filePath = sys.argv[1]
batchSize = int(sys.argv[2])
#print filePath
#print batchSize
def translate( batch ):
ws = create_connection("ws://localhost:8080/translate")
batch = batch[:-1]
#print batch
ws.send(batch)
result=ws.recv()
result = result[:-1]
print(result)
ws.close()
#time.sleep(5)
with open(filePath) as f:
batchCount = 0
batch = ""
for line in f:
ws = create_connection("ws://localhost:8080/translate")
ws.send(line)
result=ws.recv()
print(result)
ws.close()
#time.sleep(5)
#print line
batchCount = batchCount + 1
batch = batch + line
if batchCount == batchSize:
translate(batch)
batchCount = 0
batch = ""
if batchCount:
translate(batch)

View File

@ -23,47 +23,47 @@ int main(int argc, char* argv[]) {
std::setvbuf(stdin, NULL, _IONBF, 0);
boost::timer::cpu_timer timer;
std::string in;
std::string line;
std::size_t lineNum = 0;
size_t miniSize = god.Get<size_t>("mini-batch");
size_t maxiSize = god.Get<size_t>("maxi-batch");
//std::cerr << "mode=" << god.Get("mode") << std::endl;
size_t cpuThreads = god.Get<size_t>("cpu-threads");
LOG(info) << "Setting CPU thread count to " << cpuThreads;
LOG(info) << "Reading input";
size_t totalThreads = cpuThreads;
#ifdef CUDA
size_t gpuThreads = god.Get<size_t>("gpu-threads");
auto devices = god.Get<std::vector<size_t>>("devices");
LOG(info) << "Setting GPU thread count to " << gpuThreads;
totalThreads += gpuThreads * devices.size();
#endif
SentencesPtr maxiBatch(new Sentences());
LOG(info) << "Total number of threads: " << totalThreads;
amunmt_UTIL_THROW_IF2(totalThreads == 0, "Total number of threads is 0");
while (std::getline(god.GetInputStream(), line)) {
maxiBatch->push_back(SentencePtr(new Sentence(god, lineNum++, line)));
{
ThreadPool pool(totalThreads, totalThreads);
LOG(info) << "Reading input";
if (maxiBatch->size() >= maxiSize) {
SentencesPtr maxiBatch(new Sentences());
while (std::getline(god.GetInputStream(), in)) {
maxiBatch->push_back(SentencePtr(new Sentence(god, lineNum++, in)));
if (maxiBatch->size() >= maxiSize) {
god.Enqueue(*maxiBatch, pool);
maxiBatch.reset(new Sentences());
maxiBatch->SortByLength();
while (maxiBatch->size()) {
SentencesPtr miniBatch = maxiBatch->NextMiniBatch(miniSize);
god.GetThreadPool().enqueue(
[&god,miniBatch]{ return TranslationTask(god, miniBatch); }
);
}
maxiBatch.reset(new Sentences());
}
god.Enqueue(*maxiBatch, pool);
}
// last batch
if (maxiBatch->size()) {
maxiBatch->SortByLength();
while (maxiBatch->size()) {
SentencesPtr miniBatch = maxiBatch->NextMiniBatch(miniSize);
god.GetThreadPool().enqueue(
[&god,miniBatch]{ return TranslationTask(god, miniBatch); }
);
}
}
god.Cleanup();
LOG(info) << "Total time: " << timer.format();
god.CleanUp();
//sleep(10);
return 0;
}

View File

@ -30,6 +30,7 @@ God::God()
God::~God()
{
Cleanup();
}
God& God::Init(const std::string& options) {
@ -86,9 +87,22 @@ God& God::Init(int argc, char** argv) {
LoadPrePostProcessing();
size_t totalThreads = GetTotalThreads();
LOG(info) << "Total number of threads: " << totalThreads;
amunmt_UTIL_THROW_IF2(totalThreads == 0, "Total number of threads is 0");
pool_.reset(new ThreadPool(totalThreads, totalThreads));
return *this;
}
void God::Cleanup()
{
pool_.reset();
cpuLoaders_.clear();
gpuLoaders_.clear();
}
void God::LoadScorers() {
LOG(info) << "Loading scorers...";
#ifdef CUDA
@ -236,15 +250,6 @@ std::vector<std::string> God::Postprocess(const std::vector<std::string>& input)
return processed;
}
void God::CleanUp() {
for (Loaders::value_type& loader : cpuLoaders_) {
loader.second.reset(nullptr);
}
for (Loaders::value_type& loader : gpuLoaders_) {
loader.second.reset(nullptr);
}
}
DeviceInfo God::GetNextDevice() const
{
DeviceInfo ret;
@ -283,18 +288,19 @@ Search &God::GetSearch() const
return obj;
}
void God::Enqueue(Sentences &maxiBatch, ThreadPool &pool)
size_t God::GetTotalThreads() const
{
size_t miniSize = Get<size_t>("mini-batch");
maxiBatch.SortByLength();
while (maxiBatch.size()) {
SentencesPtr miniBatch = maxiBatch.NextMiniBatch(miniSize);
pool.enqueue(
[this,miniBatch]{ return TranslationTask(*this, miniBatch); }
);
}
size_t cpuThreads = Get<size_t>("cpu-threads");
LOG(info) << "Setting CPU thread count to " << cpuThreads;
size_t totalThreads = cpuThreads;
#ifdef CUDA
size_t gpuThreads = Get<size_t>("gpu-threads");
auto devices = Get<std::vector<size_t>>("devices");
LOG(info) << "Setting GPU thread count to " << gpuThreads;
totalThreads += gpuThreads * devices.size();
#endif
return totalThreads;
}
}

View File

@ -35,6 +35,7 @@ class God {
God& Init(const std::string&);
God& Init(int argc, char** argv);
void Cleanup();
bool Has(const std::string& key) const {
return config_.Has(key);
@ -66,14 +67,15 @@ class God {
std::vector<std::string> Preprocess(size_t i, const std::vector<std::string>& input) const;
std::vector<std::string> Postprocess(const std::vector<std::string>& input) const;
void CleanUp();
void LoadWeights(const std::string& path);
DeviceInfo GetNextDevice() const;
Search &GetSearch() const;
void Enqueue(Sentences &maxiBatch, ThreadPool &pool);
size_t GetTotalThreads() const;
ThreadPool &GetThreadPool()
{ return *pool_; }
private:
void LoadScorers();
@ -104,6 +106,8 @@ class God {
mutable size_t threadIncr_;
mutable boost::shared_mutex accessLock_;
std::unique_ptr<ThreadPool> pool_;
};
}

View File

@ -8,21 +8,28 @@ using namespace std;
namespace amunmt {
void TranslationTask(const God &god, std::shared_ptr<Sentences> sentences) {
Search &search = god.GetSearch();
OutputCollector &outputCollector = god.GetOutputCollector();
std::shared_ptr<Histories> histories = TranslationTaskSync(god, sentences);
for (size_t i = 0; i < histories->size(); ++i) {
const History &history = *histories->at(i);
size_t lineNum = history.GetLineNum();
std::stringstream strm;
Printer(god, history, strm);
outputCollector.Write(lineNum, strm.str());
}
}
std::shared_ptr<Histories> TranslationTaskSync(const God &god, std::shared_ptr<Sentences> sentences) {
try {
Search &search = god.GetSearch();
std::shared_ptr<Histories> histories = search.Process(god, *sentences);
OutputCollector &outputCollector = god.GetOutputCollector();
for (size_t i = 0; i < histories->size(); ++i) {
const History &history = *histories->at(i);
size_t lineNum = history.GetLineNum();
std::stringstream strm;
Printer(god, history, strm);
outputCollector.Write(lineNum, strm.str());
}
//cerr << "histories=" << histories->size() << endl;
return histories;
}
#ifdef CUDA
catch(thrust::system_error &e)

View File

@ -7,6 +7,8 @@ namespace amunmt {
class God;
void TranslationTask(const God &god, std::shared_ptr<Sentences> sentences);
std::shared_ptr<Histories> TranslationTaskSync(const God &god, std::shared_ptr<Sentences> sentences);
}

View File

@ -11,21 +11,16 @@
#include "common/search.h"
#include "common/printer.h"
#include "common/sentence.h"
#include "common/sentences.h"
#include "common/exception.h"
#include "common/translation_task.h"
using namespace amunmt;
using namespace std;
God god_;
std::unique_ptr<ThreadPool> pool;
std::shared_ptr<Histories> TranslationTask(const std::string& in, size_t taskCounter) {
Search &search = god_.GetSearch();
std::shared_ptr<Sentences> sentences(new Sentences());
sentences->push_back(SentencePtr(new Sentence(god_, taskCounter, in)));
return search.Process(god_, *sentences);
}
void init(const std::string& options) {
god_.Init(options);
size_t totalThreads = god_.Get<size_t>("gpu-threads") + god_.Get<size_t>("cpu-threads");
@ -33,40 +28,64 @@ void init(const std::string& options) {
}
boost::python::list translate(boost::python::list& in) {
size_t cpuThreads = god_.Get<size_t>("cpu-threads");
LOG(info) << "Setting CPU thread count to " << cpuThreads;
size_t totalThreads = cpuThreads;
#ifdef CUDA
size_t gpuThreads = god_.Get<size_t>("gpu-threads");
auto devices = god_.Get<std::vector<size_t>>("devices");
LOG(info) << "Setting GPU thread count to " << gpuThreads;
totalThreads += gpuThreads * devices.size();
#endif
LOG(info) << "Total number of threads: " << totalThreads;
amunmt_UTIL_THROW_IF2(totalThreads == 0, "Total number of threads is 0");
boost::python::list translate(boost::python::list& in)
{
size_t miniSize = god_.Get<size_t>("mini-batch");
size_t maxiSize = god_.Get<size_t>("maxi-batch");
std::vector<std::future< std::shared_ptr<Histories> >> results;
SentencesPtr maxiBatch(new Sentences());
boost::python::list output;
for(int i = 0; i < boost::python::len(in); ++i) {
std::string s = boost::python::extract<std::string>(boost::python::object(in[i]));
results.emplace_back(
pool->enqueue(
[=]{ return TranslationTask(s, i); }
)
);
for(int lineNum = 0; lineNum < boost::python::len(in); ++lineNum) {
std::string line = boost::python::extract<std::string>(boost::python::object(in[lineNum]));
//cerr << "line=" << line << endl;
maxiBatch->push_back(SentencePtr(new Sentence(god_, lineNum, line)));
if (maxiBatch->size() >= maxiSize) {
maxiBatch->SortByLength();
while (maxiBatch->size()) {
SentencesPtr miniBatch = maxiBatch->NextMiniBatch(miniSize);
results.emplace_back(
god_.GetThreadPool().enqueue(
[&god_,miniBatch]{ return TranslationTaskSync(god_, miniBatch); }
)
);
}
maxiBatch.reset(new Sentences());
}
}
size_t lineCounter = 0;
// last batch
if (maxiBatch->size()) {
maxiBatch->SortByLength();
while (maxiBatch->size()) {
SentencesPtr miniBatch = maxiBatch->NextMiniBatch(miniSize);
results.emplace_back(
god_.GetThreadPool().enqueue(
[&god_,miniBatch]{ return TranslationTaskSync(god_, miniBatch); }
)
);
}
}
// resort batch into line number order
Histories allHistories;
for (auto&& result : results) {
std::stringstream ss;
Printer(god_, *result.get().get(), ss);
output.append(ss.str());
std::shared_ptr<Histories> histories = result.get();
allHistories.Append(*histories);
}
allHistories.SortByLineNum();
// output
std::stringstream ss;
Printer(god_, allHistories, ss);
string str = ss.str();
boost::python::list output;
output.append(str);
return output;
}