mirror of
https://github.com/marian-nmt/marian.git
synced 2024-12-02 16:39:38 +03:00
merge
This commit is contained in:
commit
e309f4f369
@ -22,6 +22,7 @@ def handle_websocket():
|
||||
while True:
|
||||
try:
|
||||
message = wsock.receive()
|
||||
#print message
|
||||
if message is not None:
|
||||
trans = nmt.translate(message.split('\n'))
|
||||
wsock.send('\n'.join(trans))
|
||||
|
@ -2,13 +2,37 @@
|
||||
|
||||
from websocket import create_connection
|
||||
import time
|
||||
import sys
|
||||
|
||||
with open("testfile.en") as f:
|
||||
filePath = sys.argv[1]
|
||||
batchSize = int(sys.argv[2])
|
||||
#print filePath
|
||||
#print batchSize
|
||||
|
||||
def translate( batch ):
|
||||
ws = create_connection("ws://localhost:8080/translate")
|
||||
|
||||
batch = batch[:-1]
|
||||
#print batch
|
||||
ws.send(batch)
|
||||
result=ws.recv()
|
||||
result = result[:-1]
|
||||
print(result)
|
||||
ws.close()
|
||||
#time.sleep(5)
|
||||
|
||||
with open(filePath) as f:
|
||||
batchCount = 0
|
||||
batch = ""
|
||||
for line in f:
|
||||
ws = create_connection("ws://localhost:8080/translate")
|
||||
ws.send(line)
|
||||
result=ws.recv()
|
||||
print(result)
|
||||
ws.close()
|
||||
#time.sleep(5)
|
||||
#print line
|
||||
batchCount = batchCount + 1
|
||||
batch = batch + line
|
||||
if batchCount == batchSize:
|
||||
translate(batch)
|
||||
|
||||
batchCount = 0
|
||||
batch = ""
|
||||
|
||||
if batchCount:
|
||||
translate(batch)
|
||||
|
@ -23,47 +23,47 @@ int main(int argc, char* argv[]) {
|
||||
std::setvbuf(stdin, NULL, _IONBF, 0);
|
||||
boost::timer::cpu_timer timer;
|
||||
|
||||
std::string in;
|
||||
std::string line;
|
||||
std::size_t lineNum = 0;
|
||||
|
||||
size_t miniSize = god.Get<size_t>("mini-batch");
|
||||
size_t maxiSize = god.Get<size_t>("maxi-batch");
|
||||
//std::cerr << "mode=" << god.Get("mode") << std::endl;
|
||||
|
||||
size_t cpuThreads = god.Get<size_t>("cpu-threads");
|
||||
LOG(info) << "Setting CPU thread count to " << cpuThreads;
|
||||
LOG(info) << "Reading input";
|
||||
|
||||
size_t totalThreads = cpuThreads;
|
||||
#ifdef CUDA
|
||||
size_t gpuThreads = god.Get<size_t>("gpu-threads");
|
||||
auto devices = god.Get<std::vector<size_t>>("devices");
|
||||
LOG(info) << "Setting GPU thread count to " << gpuThreads;
|
||||
totalThreads += gpuThreads * devices.size();
|
||||
#endif
|
||||
SentencesPtr maxiBatch(new Sentences());
|
||||
|
||||
LOG(info) << "Total number of threads: " << totalThreads;
|
||||
amunmt_UTIL_THROW_IF2(totalThreads == 0, "Total number of threads is 0");
|
||||
while (std::getline(god.GetInputStream(), line)) {
|
||||
maxiBatch->push_back(SentencePtr(new Sentence(god, lineNum++, line)));
|
||||
|
||||
{
|
||||
ThreadPool pool(totalThreads, totalThreads);
|
||||
LOG(info) << "Reading input";
|
||||
if (maxiBatch->size() >= maxiSize) {
|
||||
|
||||
SentencesPtr maxiBatch(new Sentences());
|
||||
|
||||
while (std::getline(god.GetInputStream(), in)) {
|
||||
maxiBatch->push_back(SentencePtr(new Sentence(god, lineNum++, in)));
|
||||
|
||||
if (maxiBatch->size() >= maxiSize) {
|
||||
god.Enqueue(*maxiBatch, pool);
|
||||
maxiBatch.reset(new Sentences());
|
||||
maxiBatch->SortByLength();
|
||||
while (maxiBatch->size()) {
|
||||
SentencesPtr miniBatch = maxiBatch->NextMiniBatch(miniSize);
|
||||
god.GetThreadPool().enqueue(
|
||||
[&god,miniBatch]{ return TranslationTask(god, miniBatch); }
|
||||
);
|
||||
}
|
||||
|
||||
maxiBatch.reset(new Sentences());
|
||||
}
|
||||
|
||||
god.Enqueue(*maxiBatch, pool);
|
||||
}
|
||||
|
||||
// last batch
|
||||
if (maxiBatch->size()) {
|
||||
maxiBatch->SortByLength();
|
||||
while (maxiBatch->size()) {
|
||||
SentencesPtr miniBatch = maxiBatch->NextMiniBatch(miniSize);
|
||||
god.GetThreadPool().enqueue(
|
||||
[&god,miniBatch]{ return TranslationTask(god, miniBatch); }
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
god.Cleanup();
|
||||
LOG(info) << "Total time: " << timer.format();
|
||||
god.CleanUp();
|
||||
//sleep(10);
|
||||
return 0;
|
||||
}
|
||||
|
@ -30,6 +30,7 @@ God::God()
|
||||
|
||||
God::~God()
|
||||
{
|
||||
Cleanup();
|
||||
}
|
||||
|
||||
God& God::Init(const std::string& options) {
|
||||
@ -86,9 +87,22 @@ God& God::Init(int argc, char** argv) {
|
||||
|
||||
LoadPrePostProcessing();
|
||||
|
||||
size_t totalThreads = GetTotalThreads();
|
||||
LOG(info) << "Total number of threads: " << totalThreads;
|
||||
amunmt_UTIL_THROW_IF2(totalThreads == 0, "Total number of threads is 0");
|
||||
|
||||
pool_.reset(new ThreadPool(totalThreads, totalThreads));
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
void God::Cleanup()
|
||||
{
|
||||
pool_.reset();
|
||||
cpuLoaders_.clear();
|
||||
gpuLoaders_.clear();
|
||||
}
|
||||
|
||||
void God::LoadScorers() {
|
||||
LOG(info) << "Loading scorers...";
|
||||
#ifdef CUDA
|
||||
@ -236,15 +250,6 @@ std::vector<std::string> God::Postprocess(const std::vector<std::string>& input)
|
||||
return processed;
|
||||
}
|
||||
|
||||
void God::CleanUp() {
|
||||
for (Loaders::value_type& loader : cpuLoaders_) {
|
||||
loader.second.reset(nullptr);
|
||||
}
|
||||
for (Loaders::value_type& loader : gpuLoaders_) {
|
||||
loader.second.reset(nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
DeviceInfo God::GetNextDevice() const
|
||||
{
|
||||
DeviceInfo ret;
|
||||
@ -283,18 +288,19 @@ Search &God::GetSearch() const
|
||||
return obj;
|
||||
}
|
||||
|
||||
void God::Enqueue(Sentences &maxiBatch, ThreadPool &pool)
|
||||
size_t God::GetTotalThreads() const
|
||||
{
|
||||
size_t miniSize = Get<size_t>("mini-batch");
|
||||
|
||||
maxiBatch.SortByLength();
|
||||
while (maxiBatch.size()) {
|
||||
SentencesPtr miniBatch = maxiBatch.NextMiniBatch(miniSize);
|
||||
pool.enqueue(
|
||||
[this,miniBatch]{ return TranslationTask(*this, miniBatch); }
|
||||
);
|
||||
}
|
||||
size_t cpuThreads = Get<size_t>("cpu-threads");
|
||||
LOG(info) << "Setting CPU thread count to " << cpuThreads;
|
||||
|
||||
size_t totalThreads = cpuThreads;
|
||||
#ifdef CUDA
|
||||
size_t gpuThreads = Get<size_t>("gpu-threads");
|
||||
auto devices = Get<std::vector<size_t>>("devices");
|
||||
LOG(info) << "Setting GPU thread count to " << gpuThreads;
|
||||
totalThreads += gpuThreads * devices.size();
|
||||
#endif
|
||||
return totalThreads;
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -35,6 +35,7 @@ class God {
|
||||
God& Init(const std::string&);
|
||||
God& Init(int argc, char** argv);
|
||||
|
||||
void Cleanup();
|
||||
|
||||
bool Has(const std::string& key) const {
|
||||
return config_.Has(key);
|
||||
@ -66,14 +67,15 @@ class God {
|
||||
std::vector<std::string> Preprocess(size_t i, const std::vector<std::string>& input) const;
|
||||
std::vector<std::string> Postprocess(const std::vector<std::string>& input) const;
|
||||
|
||||
void CleanUp();
|
||||
|
||||
void LoadWeights(const std::string& path);
|
||||
|
||||
DeviceInfo GetNextDevice() const;
|
||||
Search &GetSearch() const;
|
||||
|
||||
void Enqueue(Sentences &maxiBatch, ThreadPool &pool);
|
||||
size_t GetTotalThreads() const;
|
||||
ThreadPool &GetThreadPool()
|
||||
{ return *pool_; }
|
||||
|
||||
private:
|
||||
void LoadScorers();
|
||||
@ -104,6 +106,8 @@ class God {
|
||||
|
||||
mutable size_t threadIncr_;
|
||||
mutable boost::shared_mutex accessLock_;
|
||||
|
||||
std::unique_ptr<ThreadPool> pool_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@ -8,21 +8,28 @@ using namespace std;
|
||||
namespace amunmt {
|
||||
|
||||
void TranslationTask(const God &god, std::shared_ptr<Sentences> sentences) {
|
||||
Search &search = god.GetSearch();
|
||||
OutputCollector &outputCollector = god.GetOutputCollector();
|
||||
|
||||
std::shared_ptr<Histories> histories = TranslationTaskSync(god, sentences);
|
||||
|
||||
for (size_t i = 0; i < histories->size(); ++i) {
|
||||
const History &history = *histories->at(i);
|
||||
size_t lineNum = history.GetLineNum();
|
||||
|
||||
std::stringstream strm;
|
||||
Printer(god, history, strm);
|
||||
|
||||
outputCollector.Write(lineNum, strm.str());
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<Histories> TranslationTaskSync(const God &god, std::shared_ptr<Sentences> sentences) {
|
||||
try {
|
||||
Search &search = god.GetSearch();
|
||||
std::shared_ptr<Histories> histories = search.Process(god, *sentences);
|
||||
|
||||
OutputCollector &outputCollector = god.GetOutputCollector();
|
||||
for (size_t i = 0; i < histories->size(); ++i) {
|
||||
const History &history = *histories->at(i);
|
||||
size_t lineNum = history.GetLineNum();
|
||||
|
||||
std::stringstream strm;
|
||||
Printer(god, history, strm);
|
||||
|
||||
outputCollector.Write(lineNum, strm.str());
|
||||
}
|
||||
//cerr << "histories=" << histories->size() << endl;
|
||||
return histories;
|
||||
}
|
||||
#ifdef CUDA
|
||||
catch(thrust::system_error &e)
|
||||
|
@ -7,6 +7,8 @@ namespace amunmt {
|
||||
class God;
|
||||
|
||||
void TranslationTask(const God &god, std::shared_ptr<Sentences> sentences);
|
||||
std::shared_ptr<Histories> TranslationTaskSync(const God &god, std::shared_ptr<Sentences> sentences);
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
@ -11,21 +11,16 @@
|
||||
#include "common/search.h"
|
||||
#include "common/printer.h"
|
||||
#include "common/sentence.h"
|
||||
#include "common/sentences.h"
|
||||
#include "common/exception.h"
|
||||
#include "common/translation_task.h"
|
||||
|
||||
using namespace amunmt;
|
||||
using namespace std;
|
||||
|
||||
God god_;
|
||||
std::unique_ptr<ThreadPool> pool;
|
||||
|
||||
std::shared_ptr<Histories> TranslationTask(const std::string& in, size_t taskCounter) {
|
||||
Search &search = god_.GetSearch();
|
||||
|
||||
std::shared_ptr<Sentences> sentences(new Sentences());
|
||||
sentences->push_back(SentencePtr(new Sentence(god_, taskCounter, in)));
|
||||
return search.Process(god_, *sentences);
|
||||
}
|
||||
|
||||
void init(const std::string& options) {
|
||||
god_.Init(options);
|
||||
size_t totalThreads = god_.Get<size_t>("gpu-threads") + god_.Get<size_t>("cpu-threads");
|
||||
@ -33,40 +28,64 @@ void init(const std::string& options) {
|
||||
}
|
||||
|
||||
|
||||
boost::python::list translate(boost::python::list& in) {
|
||||
size_t cpuThreads = god_.Get<size_t>("cpu-threads");
|
||||
LOG(info) << "Setting CPU thread count to " << cpuThreads;
|
||||
|
||||
size_t totalThreads = cpuThreads;
|
||||
#ifdef CUDA
|
||||
size_t gpuThreads = god_.Get<size_t>("gpu-threads");
|
||||
auto devices = god_.Get<std::vector<size_t>>("devices");
|
||||
LOG(info) << "Setting GPU thread count to " << gpuThreads;
|
||||
totalThreads += gpuThreads * devices.size();
|
||||
#endif
|
||||
|
||||
LOG(info) << "Total number of threads: " << totalThreads;
|
||||
amunmt_UTIL_THROW_IF2(totalThreads == 0, "Total number of threads is 0");
|
||||
boost::python::list translate(boost::python::list& in)
|
||||
{
|
||||
size_t miniSize = god_.Get<size_t>("mini-batch");
|
||||
size_t maxiSize = god_.Get<size_t>("maxi-batch");
|
||||
|
||||
std::vector<std::future< std::shared_ptr<Histories> >> results;
|
||||
SentencesPtr maxiBatch(new Sentences());
|
||||
|
||||
boost::python::list output;
|
||||
for(int i = 0; i < boost::python::len(in); ++i) {
|
||||
std::string s = boost::python::extract<std::string>(boost::python::object(in[i]));
|
||||
results.emplace_back(
|
||||
pool->enqueue(
|
||||
[=]{ return TranslationTask(s, i); }
|
||||
)
|
||||
);
|
||||
for(int lineNum = 0; lineNum < boost::python::len(in); ++lineNum) {
|
||||
std::string line = boost::python::extract<std::string>(boost::python::object(in[lineNum]));
|
||||
//cerr << "line=" << line << endl;
|
||||
|
||||
maxiBatch->push_back(SentencePtr(new Sentence(god_, lineNum, line)));
|
||||
|
||||
if (maxiBatch->size() >= maxiSize) {
|
||||
|
||||
maxiBatch->SortByLength();
|
||||
while (maxiBatch->size()) {
|
||||
SentencesPtr miniBatch = maxiBatch->NextMiniBatch(miniSize);
|
||||
|
||||
results.emplace_back(
|
||||
god_.GetThreadPool().enqueue(
|
||||
[&god_,miniBatch]{ return TranslationTaskSync(god_, miniBatch); }
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
maxiBatch.reset(new Sentences());
|
||||
}
|
||||
}
|
||||
|
||||
size_t lineCounter = 0;
|
||||
// last batch
|
||||
if (maxiBatch->size()) {
|
||||
maxiBatch->SortByLength();
|
||||
while (maxiBatch->size()) {
|
||||
SentencesPtr miniBatch = maxiBatch->NextMiniBatch(miniSize);
|
||||
results.emplace_back(
|
||||
god_.GetThreadPool().enqueue(
|
||||
[&god_,miniBatch]{ return TranslationTaskSync(god_, miniBatch); }
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// resort batch into line number order
|
||||
Histories allHistories;
|
||||
for (auto&& result : results) {
|
||||
std::stringstream ss;
|
||||
Printer(god_, *result.get().get(), ss);
|
||||
output.append(ss.str());
|
||||
std::shared_ptr<Histories> histories = result.get();
|
||||
allHistories.Append(*histories);
|
||||
}
|
||||
allHistories.SortByLineNum();
|
||||
|
||||
// output
|
||||
std::stringstream ss;
|
||||
Printer(god_, allHistories, ss);
|
||||
string str = ss.str();
|
||||
boost::python::list output;
|
||||
output.append(str);
|
||||
|
||||
return output;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user