Add python bindings

This commit is contained in:
Tomasz Dwojak 2016-10-04 11:45:21 +00:00
parent 3a548b3aa0
commit f4c508c969
7 changed files with 202 additions and 61 deletions

View File

@ -25,7 +25,7 @@ else(CUDA_FOUND)
endif(CUDA_FOUND)
endif(NOCUDA)
find_package(Boost COMPONENTS system filesystem program_options timer iostreams thread python)
find_package(Boost COMPONENTS system filesystem program_options timer iostreams python thread)
if(Boost_FOUND)
include_directories(${Boost_INCLUDE_DIRS})
set(EXT_LIBS ${EXT_LIBS} ${Boost_LIBRARIES})

49
scripts/amunmt_server.py Executable file
View File

@ -0,0 +1,49 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys
import os
import argparse
sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../build/src')
import libamunmt as nmt
from bottle import request, Bottle, abort
app = Bottle()
@app.route('/translate')
def handle_websocket():
wsock = request.environ.get('wsgi.websocket')
if not wsock:
abort(400, 'Expected WebSocket request.')
while True:
try:
message = wsock.receive()
if message is not None:
trans = nmt.translate(message.split('\n'))
wsock.send('\n'.join(trans))
except WebSocketError:
break
def parse_args():
""" parse command arguments """
parser = argparse.ArgumentParser()
parser.add_argument("-c", dest="config")
parser.add_argument('-p', dest="port", default=8080, type=int)
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
nmt.init("-c {}".format(args.config))
from gevent.pywsgi import WSGIServer
from geventwebsocket import WebSocketError
from geventwebsocket.handler import WebSocketHandler
server = WSGIServer(("0.0.0.0", args.port), app,
handler_class=WebSocketHandler)
server.serve_forever()

View File

@ -3,6 +3,6 @@ include_directories(.)
add_subdirectory(yaml-cpp)
add_library(libcommon OBJECT
add_library(libcnpy OBJECT
cnpy/cnpy.cpp
)

View File

@ -2,33 +2,7 @@
include_directories(.)
include_directories(3rd_party)
add_subdirectory(3rd_party)
#add_library(librescorer OBJECT
# rescorer/nbest.cpp
#)
if(CUDA_FOUND)
cuda_add_executable(
amun
common/decoder_main.cpp
common/config.cpp
common/exception.cpp
common/loader_factory.cpp
common/logging.cpp
common/vocab.cpp
common/utils.cpp
common/god.cpp
common/history.cpp
common/loader.cpp
common/printer.cpp
common/scorer.cpp
common/search.cpp
common/sentence.cpp
common/processor/bpe.cpp
add_library(cpumode OBJECT
cpu/mblas/matrix.cpp
cpu/mblas/phoenix_functions.cpp
cpu/dl4mt/decoder.cpp
@ -36,6 +10,45 @@ cuda_add_executable(
cpu/dl4mt/gru.cpp
cpu/dl4mt/model.cpp
cpu/decoder/encoder_decoder.cpp
)
add_library(libcommon OBJECT
common/config.cpp
common/exception.cpp
common/god.cpp
common/history.cpp
common/loader.cpp
common/loader_factory.cpp
common/logging.cpp
common/printer.cpp
common/scorer.cpp
common/search.cpp
common/sentence.cpp
common/processor/bpe.cpp
common/utils.cpp
common/vocab.cpp
)
if(CUDA_FOUND)
cuda_add_executable(
amun
common/decoder_main.cpp
common/loader_factory.cu
gpu/decoder/ape_penalty.cu
gpu/decoder/encoder_decoder.cu
gpu/dl4mt/encoder.cu
gpu/dl4mt/gru.cu
gpu/mblas/matrix.cu
gpu/npz_converter.cu
$<TARGET_OBJECTS:libcommon>
$<TARGET_OBJECTS:cpumode>
$<TARGET_OBJECTS:libyaml-cpp>
$<TARGET_OBJECTS:libcnpy>
)
cuda_add_library(amunmt SHARED
python/amunmt.cpp
common/loader_factory.cu
gpu/decoder/ape_penalty.cu
gpu/decoder/encoder_decoder.cu
@ -44,50 +57,32 @@ cuda_add_executable(
gpu/dl4mt/gru.cu
gpu/npz_converter.cu
$<TARGET_OBJECTS:libcommon>
$<TARGET_OBJECTS:libcnpy>
$<TARGET_OBJECTS:cpumode>
$<TARGET_OBJECTS:libyaml-cpp>
)
else(CUDA_FOUND)
add_executable(
amun
common/decoder_main.cpp
common/config.cpp
common/exception.cpp
common/loader_factory.cpp
common/logging.cpp
common/vocab.cpp
common/utils.cpp
common/god.cpp
common/history.cpp
common/loader.cpp
common/printer.cpp
common/scorer.cpp
common/search.cpp
common/sentence.cpp
common/processor/bpe.cpp
cpu/mblas/matrix.cpp
cpu/mblas/phoenix_functions.cpp
cpu/dl4mt/decoder.cpp
cpu/dl4mt/encoder.cpp
cpu/dl4mt/gru.cpp
cpu/dl4mt/model.cpp
cpu/decoder/encoder_decoder.cpp
$<TARGET_OBJECTS:libcnpy>
$<TARGET_OBJECTS:cpumode>
$<TARGET_OBJECTS:libcommon>
$<TARGET_OBJECTS:libyaml-cpp>
)
add_library(amunmt SHARED
python/amunmt.cpp
$<TARGET_OBJECTS:libcnpy>
$<TARGET_OBJECTS:cpumode>
$<TARGET_OBJECTS:libcommon>
$<TARGET_OBJECTS:libyaml-cpp>
)
endif(CUDA_FOUND)
#cuda_add_executable(
# rescorer
# rescorer/rescorer_main.cu
# mblas/matrix.cu
# dl4mt/gru.cu
# $<TARGET_OBJECTS:librescorer>
# $<TARGET_OBJECTS:libcommon>
# $<TARGET_OBJECTS:libyaml-cpp>
#)
foreach(exec amun)
foreach(exec amun amunmt)
if(CUDA_FOUND)
target_link_libraries(${exec} ${EXT_LIBS} cuda)
cuda_add_cublas_to_target(${exec})
@ -98,3 +93,4 @@ foreach(exec amun)
endforeach(exec)
add_subdirectory(bpe)
add_subdirectory(3rd_party)

View File

@ -24,6 +24,16 @@ God::~God()
}
}
God& God::Init(const std::string& options) {
std::vector<std::string> args = boost::program_options::split_unix(options);
int argc = args.size() + 1;
char* argv[argc];
argv[0] = const_cast<char*>("bogus");
for(int i = 1; i < argc; i++)
argv[i] = const_cast<char*>(args[i-1].c_str());
return Init(argc, argv);
}
God& God::Init(int argc, char** argv) {
return Summon().NonStaticInit(argc, argv);
}

73
src/python/amunmt.cpp Normal file
View File

@ -0,0 +1,73 @@
#include <cstdlib>
#include <iostream>
#include <string>
#include <boost/timer/timer.hpp>
#include <boost/thread/tss.hpp>
#include <boost/python.hpp>
#include "common/god.h"
#include "common/logging.h"
#include "common/threadpool.h"
#include "common/search.h"
#include "common/printer.h"
#include "common/sentence.h"
History TranslationTask(const std::string& in, size_t taskCounter) {
#ifdef __APPLE__
static boost::thread_specific_ptr<Search> s_search;
Search *search = s_search.get();
if(search == NULL) {
LOG(info) << "Created Search for thread " << std::this_thread::get_id();
search = new Search(taskCounter);
s_search.reset(search);
}
#else
thread_local std::unique_ptr<Search> search;
if(!search) {
LOG(info) << "Created Search for thread " << std::this_thread::get_id();
search.reset(new Search(taskCounter));
}
#endif
return search->Decode(Sentence(taskCounter, in));
}
void init(const std::string& options) {
God::Init(options);
}
boost::python::list translate(boost::python::list& in) {
size_t threadCount = God::Get<size_t>("threads");
LOG(info) << "Setting number of threads to " << threadCount;
ThreadPool pool(threadCount);
std::vector<std::future<History>> results;
boost::python::list output;
for(int i = 0; i < boost::python::len(in); ++i) {
std::string s = boost::python::extract<std::string>(boost::python::object(in[i]));
results.emplace_back(
pool.enqueue(
[=]{ return TranslationTask(s, i); }
)
);
}
size_t lineCounter = 0;
for (auto&& result : results) {
std::stringstream ss;
Printer(result.get(), lineCounter++, ss);
output.append(ss.str());
}
return output;
}
BOOST_PYTHON_MODULE(libamunmt)
{
boost::python::def("init", init);
boost::python::def("translate", translate);
}

13
src/python/test.py Normal file
View File

@ -0,0 +1,13 @@
import libamunmt as nmt
import sys
nmt.init(sys.argv[1])
sentences = []
for line in sys.stdin:
sentences.append(line.rstrip())
output = nmt.translate(sentences)
for line in output:
sys.stdout.write(line)