mirror of
https://github.com/marian-nmt/marian.git
synced 2025-01-07 17:10:15 +03:00
Add python bindings
This commit is contained in:
parent
3a548b3aa0
commit
f4c508c969
@ -25,7 +25,7 @@ else(CUDA_FOUND)
|
||||
endif(CUDA_FOUND)
|
||||
endif(NOCUDA)
|
||||
|
||||
find_package(Boost COMPONENTS system filesystem program_options timer iostreams thread python)
|
||||
find_package(Boost COMPONENTS system filesystem program_options timer iostreams python thread)
|
||||
if(Boost_FOUND)
|
||||
include_directories(${Boost_INCLUDE_DIRS})
|
||||
set(EXT_LIBS ${EXT_LIBS} ${Boost_LIBRARIES})
|
||||
|
49
scripts/amunmt_server.py
Executable file
49
scripts/amunmt_server.py
Executable file
@ -0,0 +1,49 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import sys
|
||||
import os
|
||||
import argparse
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../build/src')
|
||||
import libamunmt as nmt
|
||||
|
||||
from bottle import request, Bottle, abort
|
||||
|
||||
app = Bottle()
|
||||
|
||||
|
||||
@app.route('/translate')
def handle_websocket():
    """Serve translation requests over a WebSocket on ``/translate``.

    Every message received is split on newlines, the resulting list is
    passed to ``nmt.translate`` and the returned strings are sent back
    joined with newlines. The request is aborted with HTTP 400 when the
    WSGI environment carries no ``wsgi.websocket`` entry.
    """
    # Imported locally so the name used in the ``except`` clause is bound
    # even when this module is not executed as a script; the only other
    # import of it lives inside the ``__main__`` block.
    from geventwebsocket import WebSocketError

    wsock = request.environ.get('wsgi.websocket')
    if not wsock:
        abort(400, 'Expected WebSocket request.')

    while True:
        try:
            message = wsock.receive()
            if message is not None:
                trans = nmt.translate(message.split('\n'))
                wsock.send('\n'.join(trans))
        except WebSocketError:
            # Any WebSocketError ends the loop for this connection.
            break
|
||||
|
||||
|
||||
def parse_args():
    """Parse the command-line arguments of the server script.

    Returns:
        argparse.Namespace with ``config`` (value of ``-c``) and ``port``
        (integer value of ``-p``, 8080 when not given).
    """
    parser = argparse.ArgumentParser()
    # The config value is formatted straight into the option string passed to
    # the decoder, so a missing value would be forwarded as the text "None".
    # Requiring the option makes argparse report the omission instead.
    parser.add_argument("-c", dest="config", required=True,
                        help="configuration file handed to the decoder")
    parser.add_argument('-p', dest="port", default=8080, type=int,
                        help="TCP port to listen on (default: 8080)")
    return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Initialise the translation engine first, from the config given on the
    # command line.
    args = parse_args()
    init_options = "-c {}".format(args.config)
    nmt.init(init_options)

    # The server-side imports run only after the engine is initialised.
    # WebSocketError is bound at module scope here, so code elsewhere in this
    # file can refer to it as a global.
    from gevent.pywsgi import WSGIServer
    from geventwebsocket import WebSocketError
    from geventwebsocket.handler import WebSocketHandler

    # Listen on all interfaces at the requested port and serve until killed.
    listen_addr = ("0.0.0.0", args.port)
    server = WSGIServer(
        listen_addr,
        app,
        handler_class=WebSocketHandler,
    )
    server.serve_forever()
|
2
src/3rd_party/CMakeLists.txt
vendored
2
src/3rd_party/CMakeLists.txt
vendored
@ -3,6 +3,6 @@ include_directories(.)
|
||||
|
||||
add_subdirectory(yaml-cpp)
|
||||
|
||||
add_library(libcommon OBJECT
|
||||
add_library(libcnpy OBJECT
|
||||
cnpy/cnpy.cpp
|
||||
)
|
||||
|
@ -2,33 +2,7 @@
|
||||
include_directories(.)
|
||||
include_directories(3rd_party)
|
||||
|
||||
add_subdirectory(3rd_party)
|
||||
|
||||
|
||||
|
||||
#add_library(librescorer OBJECT
|
||||
# rescorer/nbest.cpp
|
||||
#)
|
||||
|
||||
|
||||
if(CUDA_FOUND)
|
||||
cuda_add_executable(
|
||||
amun
|
||||
common/decoder_main.cpp
|
||||
common/config.cpp
|
||||
common/exception.cpp
|
||||
common/loader_factory.cpp
|
||||
common/logging.cpp
|
||||
common/vocab.cpp
|
||||
common/utils.cpp
|
||||
common/god.cpp
|
||||
common/history.cpp
|
||||
common/loader.cpp
|
||||
common/printer.cpp
|
||||
common/scorer.cpp
|
||||
common/search.cpp
|
||||
common/sentence.cpp
|
||||
common/processor/bpe.cpp
|
||||
add_library(cpumode OBJECT
|
||||
cpu/mblas/matrix.cpp
|
||||
cpu/mblas/phoenix_functions.cpp
|
||||
cpu/dl4mt/decoder.cpp
|
||||
@ -36,6 +10,45 @@ cuda_add_executable(
|
||||
cpu/dl4mt/gru.cpp
|
||||
cpu/dl4mt/model.cpp
|
||||
cpu/decoder/encoder_decoder.cpp
|
||||
)
|
||||
|
||||
add_library(libcommon OBJECT
|
||||
common/config.cpp
|
||||
common/exception.cpp
|
||||
common/god.cpp
|
||||
common/history.cpp
|
||||
common/loader.cpp
|
||||
common/loader_factory.cpp
|
||||
common/logging.cpp
|
||||
common/printer.cpp
|
||||
common/scorer.cpp
|
||||
common/search.cpp
|
||||
common/sentence.cpp
|
||||
common/processor/bpe.cpp
|
||||
common/utils.cpp
|
||||
common/vocab.cpp
|
||||
)
|
||||
|
||||
if(CUDA_FOUND)
|
||||
|
||||
cuda_add_executable(
|
||||
amun
|
||||
common/decoder_main.cpp
|
||||
common/loader_factory.cu
|
||||
gpu/decoder/ape_penalty.cu
|
||||
gpu/decoder/encoder_decoder.cu
|
||||
gpu/dl4mt/encoder.cu
|
||||
gpu/dl4mt/gru.cu
|
||||
gpu/mblas/matrix.cu
|
||||
gpu/npz_converter.cu
|
||||
$<TARGET_OBJECTS:libcommon>
|
||||
$<TARGET_OBJECTS:cpumode>
|
||||
$<TARGET_OBJECTS:libyaml-cpp>
|
||||
$<TARGET_OBJECTS:libcnpy>
|
||||
)
|
||||
|
||||
cuda_add_library(amunmt SHARED
|
||||
python/amunmt.cpp
|
||||
common/loader_factory.cu
|
||||
gpu/decoder/ape_penalty.cu
|
||||
gpu/decoder/encoder_decoder.cu
|
||||
@ -44,50 +57,32 @@ cuda_add_executable(
|
||||
gpu/dl4mt/gru.cu
|
||||
gpu/npz_converter.cu
|
||||
$<TARGET_OBJECTS:libcommon>
|
||||
$<TARGET_OBJECTS:libcnpy>
|
||||
$<TARGET_OBJECTS:cpumode>
|
||||
$<TARGET_OBJECTS:libyaml-cpp>
|
||||
)
|
||||
|
||||
else(CUDA_FOUND)
|
||||
|
||||
add_executable(
|
||||
amun
|
||||
common/decoder_main.cpp
|
||||
common/config.cpp
|
||||
common/exception.cpp
|
||||
common/loader_factory.cpp
|
||||
common/logging.cpp
|
||||
common/vocab.cpp
|
||||
common/utils.cpp
|
||||
common/god.cpp
|
||||
common/history.cpp
|
||||
common/loader.cpp
|
||||
common/printer.cpp
|
||||
common/scorer.cpp
|
||||
common/search.cpp
|
||||
common/sentence.cpp
|
||||
common/processor/bpe.cpp
|
||||
cpu/mblas/matrix.cpp
|
||||
cpu/mblas/phoenix_functions.cpp
|
||||
cpu/dl4mt/decoder.cpp
|
||||
cpu/dl4mt/encoder.cpp
|
||||
cpu/dl4mt/gru.cpp
|
||||
cpu/dl4mt/model.cpp
|
||||
cpu/decoder/encoder_decoder.cpp
|
||||
$<TARGET_OBJECTS:libcnpy>
|
||||
$<TARGET_OBJECTS:cpumode>
|
||||
$<TARGET_OBJECTS:libcommon>
|
||||
$<TARGET_OBJECTS:libyaml-cpp>
|
||||
)
|
||||
add_library(amunmt SHARED
|
||||
python/amunmt.cpp
|
||||
$<TARGET_OBJECTS:libcnpy>
|
||||
$<TARGET_OBJECTS:cpumode>
|
||||
$<TARGET_OBJECTS:libcommon>
|
||||
$<TARGET_OBJECTS:libyaml-cpp>
|
||||
)
|
||||
endif(CUDA_FOUND)
|
||||
|
||||
|
||||
#cuda_add_executable(
|
||||
# rescorer
|
||||
# rescorer/rescorer_main.cu
|
||||
# mblas/matrix.cu
|
||||
# dl4mt/gru.cu
|
||||
# $<TARGET_OBJECTS:librescorer>
|
||||
# $<TARGET_OBJECTS:libcommon>
|
||||
# $<TARGET_OBJECTS:libyaml-cpp>
|
||||
#)
|
||||
|
||||
foreach(exec amun)
|
||||
foreach(exec amun amunmt)
|
||||
if(CUDA_FOUND)
|
||||
target_link_libraries(${exec} ${EXT_LIBS} cuda)
|
||||
cuda_add_cublas_to_target(${exec})
|
||||
@ -98,3 +93,4 @@ foreach(exec amun)
|
||||
endforeach(exec)
|
||||
|
||||
add_subdirectory(bpe)
|
||||
add_subdirectory(3rd_party)
|
||||
|
@ -24,6 +24,16 @@ God::~God()
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Initialise God from a single command-line style option string.
 *
 * The string is tokenised with boost::program_options::split_unix and the
 * tokens are passed on to Init(int, char**) behind a placeholder program name.
 *
 * @param options  option string, e.g. "-c config.yml"
 * @return the reference returned by Init(int, char**)
 */
God& God::Init(const std::string& options) {
  std::vector<std::string> args = boost::program_options::split_unix(options);

  // A std::vector replaces the former variable-length array, which is a
  // compiler extension rather than standard C++. One extra slot is left null
  // so that argv[argc] == nullptr, as it would be for a real main().
  std::vector<char*> argv(args.size() + 2, nullptr);
  argv[0] = const_cast<char*>("bogus");  // stand-in for the program name
  for(size_t i = 0; i < args.size(); ++i) {
    // `args` owns the character storage and stays alive across the call below.
    argv[i + 1] = const_cast<char*>(args[i].c_str());
  }

  const int argc = static_cast<int>(args.size() + 1);
  return Init(argc, argv.data());
}
|
||||
|
||||
// Forwards the raw command-line arguments to NonStaticInit() on the object
// obtained from Summon() and returns what that call returns.
God& God::Init(int argc, char** argv) {
  auto&& instance = Summon();
  return instance.NonStaticInit(argc, argv);
}
|
||||
|
73
src/python/amunmt.cpp
Normal file
73
src/python/amunmt.cpp
Normal file
@ -0,0 +1,73 @@
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <boost/timer/timer.hpp>
|
||||
#include <boost/thread/tss.hpp>
|
||||
#include <boost/python.hpp>
|
||||
|
||||
#include "common/god.h"
|
||||
#include "common/logging.h"
|
||||
#include "common/threadpool.h"
|
||||
#include "common/search.h"
|
||||
#include "common/printer.h"
|
||||
#include "common/sentence.h"
|
||||
|
||||
// Decodes one input line and returns the resulting History.
//
// A Search object is created the first time a given thread runs this function
// and is kept for later calls on that thread. It is constructed with the
// taskCounter of that first call. On Apple platforms the per-thread storage
// uses boost::thread_specific_ptr; elsewhere it is a thread_local
// std::unique_ptr.
History TranslationTask(const std::string& in, size_t taskCounter) {
#ifdef __APPLE__
  static boost::thread_specific_ptr<Search> perThreadSearch;

  if(perThreadSearch.get() == NULL) {
    LOG(info) << "Created Search for thread " << std::this_thread::get_id();
    perThreadSearch.reset(new Search(taskCounter));
  }
  Search* decoder = perThreadSearch.get();
#else
  thread_local std::unique_ptr<Search> perThreadSearch;

  if(perThreadSearch.get() == nullptr) {
    LOG(info) << "Created Search for thread " << std::this_thread::get_id();
    perThreadSearch.reset(new Search(taskCounter));
  }
  Search* decoder = perThreadSearch.get();
#endif

  return decoder->Decode(Sentence(taskCounter, in));
}
|
||||
|
||||
// Python-facing initialiser: hands the option string to God::Init and
// discards the value it returns.
void init(const std::string& options) {
  static_cast<void>(God::Init(options));
}
|
||||
|
||||
// Translates a Python list of sentences and returns a Python list of strings.
//
// Each element of `in` is extracted as a std::string and queued on a
// ThreadPool sized by the "threads" option; its list index is passed to
// TranslationTask as the task counter. Results are collected in submission
// order and rendered through Printer into one string per input element.
boost::python::list translate(boost::python::list& in) {
  const size_t threadCount = God::Get<size_t>("threads");
  LOG(info) << "Setting number of threads to " << threadCount;

  ThreadPool pool(threadCount);
  std::vector<std::future<History>> results;

  // Submit one task per input sentence.
  const auto sentenceCount = boost::python::len(in);
  for(int i = 0; i < sentenceCount; ++i) {
    boost::python::object item(in[i]);
    std::string text = boost::python::extract<std::string>(item);
    results.emplace_back(
      pool.enqueue([text, i] { return TranslationTask(text, i); }));
  }

  // Wait for each future in order and print its History into a string.
  boost::python::list output;
  for(size_t line = 0; line < results.size(); ++line) {
    std::stringstream ss;
    Printer(results[line].get(), line, ss);
    output.append(ss.str());
  }

  return output;
}
|
||||
|
||||
// Defines the Python extension module "libamunmt", exposing the free
// functions init() and translate() under those same names.
BOOST_PYTHON_MODULE(libamunmt)
{
  namespace py = boost::python;

  py::def("init", &init);
  py::def("translate", &translate);
}
|
13
src/python/test.py
Normal file
13
src/python/test.py
Normal file
@ -0,0 +1,13 @@
|
||||
import libamunmt as nmt
|
||||
import sys
|
||||
|
||||
# Initialise the decoder with the option string given as the first argument.
nmt.init(sys.argv[1])

# Read all of stdin, dropping trailing whitespace from every line.
sentences = [line.rstrip() for line in sys.stdin]

output = nmt.translate(sentences)

# Write each returned string unchanged; no separator is added here.
for translated in output:
    sys.stdout.write(translated)
|
Loading…
Reference in New Issue
Block a user