mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-30 21:39:52 +03:00
Merge ../amunmt.hieu.neuralpt into neuralpt
This commit is contained in:
commit
0cedad9ce0
@ -5,7 +5,14 @@ using namespace std;
|
||||
|
||||
namespace amunmt {
|
||||
|
||||
std::string AmunOutput::Debug() const
|
||||
HypoState::HypoState()
|
||||
{}
|
||||
|
||||
HypoState::~HypoState()
|
||||
{
|
||||
}
|
||||
|
||||
std::string HypoState::Debug() const
|
||||
{
|
||||
stringstream strm;
|
||||
/*
|
||||
|
@ -10,25 +10,31 @@
|
||||
|
||||
namespace amunmt {
|
||||
|
||||
struct AmunOutput
|
||||
struct HypoState
|
||||
{
|
||||
States states;
|
||||
Beam prevHyps;
|
||||
|
||||
float score;
|
||||
|
||||
std::shared_ptr<Sentences> sentences;
|
||||
|
||||
HypoState();
|
||||
~HypoState();
|
||||
|
||||
std::string Debug() const;
|
||||
|
||||
};
|
||||
|
||||
typedef std::vector<AmunOutput> AmunOutputs;
|
||||
typedef std::vector<HypoState> HypoStates;
|
||||
|
||||
////////////////////////////////////////////////////////////////
|
||||
struct AmunInput
|
||||
struct AmunInput : public HypoState
|
||||
{
|
||||
States prevStates;
|
||||
States nextStates;
|
||||
Beam prevHyps;
|
||||
AmunInput(const HypoState &hypoState)
|
||||
:HypoState(hypoState)
|
||||
{
|
||||
}
|
||||
|
||||
Words phrase;
|
||||
};
|
||||
|
@ -27,38 +27,25 @@ MosesPlugin::~MosesPlugin()
|
||||
{
|
||||
}
|
||||
|
||||
size_t MosesPlugin::GetDevices(size_t maxDevices) {
|
||||
int num_gpus = 0; // number of CUDA GPUs
|
||||
HANDLE_ERROR( cudaGetDeviceCount(&num_gpus));
|
||||
std::cerr << "Number of CUDA devices: " << num_gpus << std::endl;
|
||||
HypoState MosesPlugin::SetSource(const std::vector<size_t>& words) {
|
||||
HypoState ret;
|
||||
|
||||
for (int i = 0; i < num_gpus; i++) {
|
||||
cudaDeviceProp dprop;
|
||||
HANDLE_ERROR( cudaGetDeviceProperties(&dprop, i));
|
||||
std::cerr << i << ": " << dprop.name << std::endl;
|
||||
}
|
||||
return (size_t)std::min(num_gpus, (int)maxDevices);
|
||||
}
|
||||
|
||||
AmunOutput MosesPlugin::SetSource(const std::vector<size_t>& words) {
|
||||
AmunOutput ret;
|
||||
|
||||
amunmt::Sentences sentences;
|
||||
sentences.push_back(SentencePtr(new Sentence(god_, 0, words)));
|
||||
ret.sentences.reset(new Sentences());
|
||||
ret.sentences->push_back(SentencePtr(new Sentence(god_, 0, words)));
|
||||
|
||||
// Encode
|
||||
Search &search = god_.GetSearch();
|
||||
size_t numScorers = search.GetScorers().size();
|
||||
|
||||
std::shared_ptr<Histories> histories(new Histories(god_, sentences));
|
||||
std::shared_ptr<Histories> histories(new Histories(god_, *ret.sentences));
|
||||
|
||||
size_t batchSize = sentences.size();
|
||||
size_t batchSize = ret.sentences->size();
|
||||
Beam prevHyps(batchSize, HypothesisPtr(new Hypothesis()));
|
||||
|
||||
States states = search.NewStates();
|
||||
|
||||
search.PreProcess(god_, sentences, histories, prevHyps);
|
||||
search.Encode(sentences, states);
|
||||
search.PreProcess(god_, *ret.sentences, histories, prevHyps);
|
||||
search.Encode(*ret.sentences, states);
|
||||
|
||||
// fill return info
|
||||
ret.states = states;
|
||||
@ -68,9 +55,9 @@ AmunOutput MosesPlugin::SetSource(const std::vector<size_t>& words) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
AmunOutputs MosesPlugin::Score(const AmunInputs &inputs)
|
||||
HypoStates MosesPlugin::Score(const AmunInputs &inputs)
|
||||
{
|
||||
AmunOutputs outputs(inputs.size());
|
||||
HypoStates outputs(inputs.size());
|
||||
|
||||
// TODO
|
||||
|
||||
|
@ -25,7 +25,6 @@ class MosesPlugin {
|
||||
MosesPlugin();
|
||||
~MosesPlugin();
|
||||
|
||||
static size_t GetDevices(size_t = 1);
|
||||
void SetDevice();
|
||||
size_t GetDevice();
|
||||
const amunmt::God &GetGod() const
|
||||
@ -33,9 +32,9 @@ class MosesPlugin {
|
||||
|
||||
void initGod(const std::string& configPath);
|
||||
|
||||
AmunOutput SetSource(const std::vector<size_t>& words);
|
||||
HypoState SetSource(const std::vector<size_t>& words);
|
||||
|
||||
AmunOutputs Score(const AmunInputs &inputs);
|
||||
HypoStates Score(const AmunInputs &inputs);
|
||||
|
||||
private:
|
||||
amunmt::God god_;
|
||||
|
Loading…
Reference in New Issue
Block a user