Merge ../amunmt.hieu.neuralpt into neuralpt

This commit is contained in:
Hieu Hoang 2017-02-07 17:45:08 +00:00
commit 0cedad9ce0
4 changed files with 32 additions and 33 deletions

View File

@ -5,7 +5,14 @@ using namespace std;
namespace amunmt {
std::string AmunOutput::Debug() const
HypoState::HypoState()
{}
HypoState::~HypoState()
{
}
std::string HypoState::Debug() const
{
stringstream strm;
/*

View File

@ -10,25 +10,31 @@
namespace amunmt {
struct AmunOutput
struct HypoState
{
States states;
Beam prevHyps;
float score;
std::shared_ptr<Sentences> sentences;
HypoState();
~HypoState();
std::string Debug() const;
};
typedef std::vector<AmunOutput> AmunOutputs;
typedef std::vector<HypoState> HypoStates;
////////////////////////////////////////////////////////////////
struct AmunInput
struct AmunInput : public HypoState
{
States prevStates;
States nextStates;
Beam prevHyps;
AmunInput(const HypoState &hypoState)
:HypoState(hypoState)
{
}
Words phrase;
};

View File

@ -27,38 +27,25 @@ MosesPlugin::~MosesPlugin()
{
}
size_t MosesPlugin::GetDevices(size_t maxDevices) {
int num_gpus = 0; // number of CUDA GPUs
HANDLE_ERROR( cudaGetDeviceCount(&num_gpus));
std::cerr << "Number of CUDA devices: " << num_gpus << std::endl;
HypoState MosesPlugin::SetSource(const std::vector<size_t>& words) {
HypoState ret;
for (int i = 0; i < num_gpus; i++) {
cudaDeviceProp dprop;
HANDLE_ERROR( cudaGetDeviceProperties(&dprop, i));
std::cerr << i << ": " << dprop.name << std::endl;
}
return (size_t)std::min(num_gpus, (int)maxDevices);
}
AmunOutput MosesPlugin::SetSource(const std::vector<size_t>& words) {
AmunOutput ret;
amunmt::Sentences sentences;
sentences.push_back(SentencePtr(new Sentence(god_, 0, words)));
ret.sentences.reset(new Sentences());
ret.sentences->push_back(SentencePtr(new Sentence(god_, 0, words)));
// Encode
Search &search = god_.GetSearch();
size_t numScorers = search.GetScorers().size();
std::shared_ptr<Histories> histories(new Histories(god_, sentences));
std::shared_ptr<Histories> histories(new Histories(god_, *ret.sentences));
size_t batchSize = sentences.size();
size_t batchSize = ret.sentences->size();
Beam prevHyps(batchSize, HypothesisPtr(new Hypothesis()));
States states = search.NewStates();
search.PreProcess(god_, sentences, histories, prevHyps);
search.Encode(sentences, states);
search.PreProcess(god_, *ret.sentences, histories, prevHyps);
search.Encode(*ret.sentences, states);
// fill return info
ret.states = states;
@ -68,9 +55,9 @@ AmunOutput MosesPlugin::SetSource(const std::vector<size_t>& words) {
return ret;
}
AmunOutputs MosesPlugin::Score(const AmunInputs &inputs)
HypoStates MosesPlugin::Score(const AmunInputs &inputs)
{
AmunOutputs outputs(inputs.size());
HypoStates outputs(inputs.size());
// TODO

View File

@ -25,7 +25,6 @@ class MosesPlugin {
MosesPlugin();
~MosesPlugin();
static size_t GetDevices(size_t = 1);
void SetDevice();
size_t GetDevice();
const amunmt::God &GetGod() const
@ -33,9 +32,9 @@ class MosesPlugin {
void initGod(const std::string& configPath);
AmunOutput SetSource(const std::vector<size_t>& words);
HypoState SetSource(const std::vector<size_t>& words);
AmunOutputs Score(const AmunInputs &inputs);
HypoStates Score(const AmunInputs &inputs);
private:
amunmt::God god_;