mirror of
https://github.com/marian-nmt/marian.git
synced 2024-12-11 09:54:22 +03:00
clean-up, get rid of GH models
This commit is contained in:
parent
1da629da20
commit
4e4ee95633
@ -1,5 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "bahdanau/model.h"
|
||||
#include "bahdanau/encoder.h"
|
||||
#include "bahdanau/decoder.h"
|
@ -1,256 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "mblas/matrix.h"
|
||||
#include "bahdanau/model.h"
|
||||
|
||||
// GRU-based target-side decoder with Bahdanau-style (additive) attention.
// Composed of four sub-modules (embeddings, conditional GRU, alignment,
// softmax readout) that each borrow their matrices from a shared Weights
// model; scratch matrices are kept as members and reused across steps.
class Decoder {
  private:
    // Target-word embedding lookup: gathers rows of E_ and adds bias EB_.
    template <class Weights>
    class Embeddings {
      public:
        Embeddings(const Weights& model)
        : w_(model)
        {}

        // Fills Rows with the embeddings for `ids` (one row per id).
        void Lookup(mblas::Matrix& Rows, const std::vector<size_t>& ids) {
          using namespace mblas;
          Assemble(Rows, w_.E_, ids);
          Broadcast(_1 + _2, Rows, w_.EB_);
        }

      private:
        const Weights& w_;  // borrowed model weights; not owned
    };

    // Conditional GRU cell: each step is additionally conditioned on the
    // attended source context (the C* matrices).
    template <class Weights>
    class RNN {
      public:
        RNN(const Weights& model)
        : w_(model) {}

        // Initial state from row 0 of SourceContext, replicated over the
        // batch: State = tanh(ctx * Ws + WsB).
        // NOTE(review): 1000 is a hard-coded hidden/offset size here --
        // confirm it matches the model configuration.
        void InitializeState(mblas::Matrix& State,
                             const mblas::Matrix& SourceContext,
                             const size_t batchSize = 1) {
          using namespace mblas;
          CopyRow(Temp1_, SourceContext, 0, 1000);
          Temp2_.Clear();
          Temp2_.Resize(batchSize, 1000, 0.0);
          Broadcast(_1 + _2, Temp2_, Temp1_);
          Prod(State, Temp2_, w_.Ws_);
          Broadcast(Tanh(_1 + _2), State, w_.WsB_);
        }

        // One GRU step: update gate Z, reset gate R, candidate S, then
        // State = (1 - Z) * PrevState + Z * S.  Returns State.
        mblas::Matrix& GetNextState(mblas::Matrix& State,
                                    const mblas::Matrix& Embd,
                                    const mblas::Matrix& PrevState,
                                    const mblas::Matrix& Context) {
          using namespace mblas;

          // Update gate: Z = logit(Embd*Wz + PrevState*Uz + Context*Cz)
          Prod(Z_, Embd, w_.Wz_);
          Prod(Temp1_, PrevState, w_.Uz_);
          Prod(Temp2_, Context, w_.Cz_);
          Element(Logit(_1 + _2 + _3),
                  Z_, Temp1_, Temp2_);

          // Reset gate: R = logit(Embd*Wr + PrevState*Ur + Context*Cr)
          Prod(R_, Embd, w_.Wr_);
          Prod(Temp1_, PrevState, w_.Ur_);
          Prod(Temp2_, Context, w_.Cr_);
          Element(Logit(_1 + _2 + _3),
                  R_, Temp1_, Temp2_);

          // Candidate: S = tanh(Embd*W + B + (R .* PrevState)*U + Context*C)
          Prod(S_, Embd, w_.W_);
          Broadcast(_1 + _2, S_, w_.B_); // Broadcasting row-wise
          Prod(Temp1_, Element(_1 * _2, R_, PrevState), w_.U_);
          Prod(Temp2_, Context, w_.C_);

          Element(Tanh(_1 + _2 + _3), S_, Temp1_, Temp2_);

          // Interpolate previous state and candidate via the update gate.
          Element((1.0 - _1) * _2 + _1 * _3,
                  Z_, PrevState, S_);

          State.Resize(Z_.Rows(), Z_.Cols());
          Swap(State, Z_);

          return State;
        }

      private:
        // Model matrices
        const Weights& w_;

        // reused to avoid allocation
        mblas::Matrix Z_;
        mblas::Matrix R_;
        mblas::Matrix S_;

        mblas::Matrix Temp1_;
        mblas::Matrix Temp2_;
    };

    // Additive attention: softmax(Va^T tanh(SourceContext*Ua + PrevState*Wa))
    // used to produce a weighted sum of encoder states.
    template <class Weights>
    class Alignment {
      public:
        Alignment(const Weights& model)
        : w_(model)
        {}

        // Computes attention weights over the source positions and the
        // resulting attended Context = A * SourceContext.
        void GetContext(mblas::Matrix& Context,
                        const mblas::Matrix& SourceContext,
                        const mblas::Matrix& PrevState) {
          using namespace mblas;

          Prod(Temp1_, SourceContext, w_.Ua_);
          Prod(Temp2_, PrevState, w_.Wa_);

          Broadcast(Tanh(_1 + _2), Temp1_, Temp2_);

          Prod(A_, w_.Va_, Temp1_, false, true);

          size_t rows1 = SourceContext.Rows();
          size_t rows2 = PrevState.Rows();
          A_.Reshape(rows2, rows1); // due to broadcasting above

          mblas::Softmax(A_);
          Prod(Context, A_, SourceContext);
        }

      private:
        const Weights& w_;  // borrowed model weights; not owned

        mblas::Matrix Temp1_;
        mblas::Matrix Temp2_;
        mblas::Matrix A_;   // attention weights (beam x source positions)

        // NOTE(review): Ones_/Sums_ appear unused in this class --
        // candidates for removal.
        mblas::Matrix Ones_;
        mblas::Matrix Sums_;
    };

    // Deep-output readout + softmax over the target vocabulary, with an
    // optional restriction to a filtered vocabulary subset.
    template <class Weights>
    class Softmax {
      public:
        Softmax(const Weights& model)
        : w_(model), filtered_(false)
        {}

        // Normalized word probabilities from previous state, previous
        // embedding and attended context; PairwiseReduce(Max) implements
        // the maxout over pairs of readout units.
        void GetProbs(mblas::Matrix& Probs,
                      const mblas::Matrix& PrevState,
                      const mblas::Matrix& PrevEmbd,
                      const mblas::Matrix& Context) {

          using namespace mblas;

          Prod(T_, PrevState, w_.Uo_);

          Prod(Temp1_, PrevEmbd, w_.Vo_);
          Prod(Temp2_, Context, w_.Co_);
          Element(_1 + _2 + _3, T_, Temp1_, Temp2_);
          Broadcast(_1 + _2, T_, w_.UoB_); // Broadcasting row-wise
          PairwiseReduce(Max(_1, _2), T_); // maxout over unit pairs

          if(filtered_) { // use only filtered vocabulary for SoftMax
            Prod(Probs, T_, FilteredWo_);
            Broadcast(_1 + _2, Probs, FilteredWoB_); // Broadcasting row-wise
          }
          else {
            Prod(Probs, T_, w_.Wo_);
            Broadcast(_1 + _2, Probs, w_.WoB_); // Broadcasting row-wise
          }
          mblas::Softmax(Probs);
        }

        // Restricts the output layer to `ids` by slicing columns of
        // Wo_/WoB_ (implemented as transpose + row gather + transpose).
        void Filter(const std::vector<size_t>& ids) {
          using namespace mblas;

          Matrix TempWo;
          Transpose(TempWo, w_.Wo_);
          Assemble(FilteredWo_, TempWo, ids);
          Transpose(FilteredWo_);

          Matrix TempWoB;
          Transpose(TempWoB, w_.WoB_);
          Assemble(FilteredWoB_, TempWoB, ids);
          Transpose(FilteredWoB_);

          filtered_ = true;
        }

      private:
        const Weights& w_;  // borrowed model weights; not owned

        bool filtered_;               // true once Filter() has been called
        mblas::Matrix FilteredWo_;
        mblas::Matrix FilteredWoB_;

        mblas::Matrix T_;
        mblas::Matrix Temp1_;
        mblas::Matrix Temp2_;

        // NOTE(review): Ones_/Sums_ appear unused in this class --
        // candidates for removal.
        mblas::Matrix Ones_;
        mblas::Matrix Sums_;
    };

  public:
    Decoder(const Weights& model)
    : embeddings_(model.decEmbeddings_),
      rnn_(model.decRnn_), alignment_(model.decAlignment_),
      softmax_(model.decSoftmax_)
    {}

    // Produces the initial decoder state for a batch of hypotheses.
    // NOTE(review): 1000 is the hard-coded decoder hidden size.
    void EmptyState(mblas::Matrix& State, const mblas::Matrix& SourceContext,
                    size_t batchSize = 1) {
      State.Resize(batchSize, 1000);
      rnn_.InitializeState(State, SourceContext, batchSize);
    }

    // All-zero "previous embedding" used before the first target word.
    // NOTE(review): 620 is the hard-coded embedding dimension.
    void EmptyEmbedding(mblas::Matrix& Embedding, size_t batchSize = 1) {
      Embedding.Clear();
      Embedding.Resize(batchSize, 620, 0);
    }

    // One full decoding step: probabilities for the current step, then
    // the embeddings of the chosen words, then the next hidden state.
    void MakeStep(mblas::Matrix& NextState,
                  mblas::Matrix& NextEmbeddings,
                  mblas::Matrix& Probs,
                  const std::vector<size_t>& batch,
                  const mblas::Matrix& State,
                  const mblas::Matrix& Embeddings,
                  const mblas::Matrix& SourceContext) {
      GetProbs(Probs, AlignedSourceContext_,
               State, Embeddings, SourceContext);
      Lookup(NextEmbeddings, batch);
      GetNextState(NextState, NextEmbeddings,
                   State, AlignedSourceContext_);
    }

    //private:

    // Restrict the softmax to a vocabulary subset (see Softmax::Filter).
    void Filter(const std::vector<size_t>& ids) {
      softmax_.Filter(ids);
    }

    // Attention + readout: fills AlignedSourceContext and Probs.
    void GetProbs(mblas::Matrix& Probs,
                  mblas::Matrix& AlignedSourceContext,
                  const mblas::Matrix& PrevState,
                  const mblas::Matrix& PrevEmbedding,
                  const mblas::Matrix& SourceContext) {
      alignment_.GetContext(AlignedSourceContext, SourceContext, PrevState);
      softmax_.GetProbs(Probs, PrevState, PrevEmbedding, AlignedSourceContext);
    }

    // Embedding lookup for a batch of word ids.
    void Lookup(mblas::Matrix& Embedding, const std::vector<size_t>& w) {
      embeddings_.Lookup(Embedding, w);
    }

    // Advance the recurrent state by one step.
    void GetNextState(mblas::Matrix& State,
                      const mblas::Matrix& Embedding,
                      const mblas::Matrix& PrevState,
                      const mblas::Matrix& AlignedSourceContext) {
      rnn_.GetNextState(State, Embedding, PrevState, AlignedSourceContext);
    }

  private:
    mblas::Matrix AlignedSourceContext_;  // scratch, reused across steps

    Embeddings<Weights::DecEmbeddings> embeddings_;
    RNN<Weights::DecRnn> rnn_;
    Alignment<Weights::DecAlignment> alignment_;
    Softmax<Weights::DecSoftmax> softmax_;
};
|
@ -1,118 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "mblas/matrix.h"
|
||||
#include "bahdanau/model.h"
|
||||
|
||||
// Bidirectional GRU encoder: embeds the source words and runs a forward
// and a backward RNN over them, writing both passes into one source
// context matrix (one row per word; backward states go into the second
// half of each row at column offset 1000).
class Encoder {
  private:
    // Source-word embedding lookup: copies row i of E_ and adds bias EB_.
    template <class Weights>
    class Embeddings {
      public:
        Embeddings(const Weights& model)
        : w_(model)
        {}

        // Fills Row with the embedding of word id i.
        void Lookup(mblas::Matrix& Row, size_t i) {
          using namespace mblas;
          CopyRow(Row, w_.E_, i);
          Element(_1 + _2,
                  Row, w_.EB_);
        }

      private:
        const Weights& w_;  // borrowed model weights; not owned
    };

    // Plain GRU cell, instantiated once per encoder direction.
    template <class Weights>
    class RNN {
      public:
        RNN(const Weights& model)
        : w_(model) {}

        // Resets the recurrent state to zeros.
        // NOTE(review): 1000 is the hard-coded hidden size -- confirm.
        void InitializeState(size_t batchSize = 1) {
          State_.Clear();
          State_.Resize(batchSize, 1000, 0.0);
        }

        // Standard GRU step: update gate Za, reset gate Ra, candidate Ha,
        // then State = (1 - Za) * PrevState + Za * Ha.
        void GetNextState(mblas::Matrix& State,
                          const mblas::Matrix& Embd,
                          const mblas::Matrix& PrevState) {
          using namespace mblas;

          Prod(Za_, Embd, w_.Wz_);
          Prod(Temp_, PrevState, w_.Uz_);
          Element(Logit(_1 + _2), Za_, Temp_);

          Prod(Ra_, Embd, w_.Wr_);
          Prod(Temp_, PrevState, w_.Ur_);
          Element(Logit(_1 + _2), Ra_, Temp_);

          Prod(Ha_, Embd, w_.W_);
          Prod(Temp_, Element(_1 * _2, Ra_, PrevState), w_.U_);
          Element(_1 + _2, Ha_, w_.B_); // Broadcasting row-wise
          Element(Tanh(_1 + _2), Ha_, Temp_);

          Element((1.0 - _1) * _2 + _1 * _3, Za_, PrevState, Ha_);

          Swap(State, Za_);
        }

        // Runs the cell over [it, end) and pastes each step's state into a
        // row of Context; with invert=true rows are written back-to-front
        // at column offset 1000 (the backward half of the context).
        template <class It>
        void GetContext(It it, It end,
                        mblas::Matrix& Context, bool invert) {
          InitializeState();

          size_t n = std::distance(it, end);
          size_t i = 0;
          while(it != end) {
            GetNextState(State_, *it++, State_);
            if(invert)
              mblas::PasteRow(Context, State_, n - i - 1, 1000);
            else
              mblas::PasteRow(Context, State_, i, 0);
            ++i;
          }
        }

      private:
        // Model matrices
        const Weights& w_;

        // reused to avoid allocation
        mblas::Matrix Za_;
        mblas::Matrix Ra_;
        mblas::Matrix Ha_;
        mblas::Matrix Temp_;
        mblas::Matrix State_;
    };

  public:
    Encoder(const Weights& model)
    : embeddings_(model.encEmbeddings_),
      forwardRnn_(model.encForwardRnn_),
      backwardRnn_(model.encBackwardRnn_)
    {}

    // Embeds `words` and fills Context (words.size() x 2000) with the
    // concatenated forward/backward encoder states.
    void GetContext(const std::vector<size_t>& words,
                    mblas::Matrix& Context) {
      std::vector<mblas::Matrix> embeddedWords;

      Context.Resize(words.size(), 2000);
      for(auto& w : words) {
        embeddedWords.emplace_back();
        embeddings_.Lookup(embeddedWords.back(), w);
      }

      forwardRnn_.GetContext(embeddedWords.begin(),
                             embeddedWords.end(),
                             Context, false);
      backwardRnn_.GetContext(embeddedWords.rbegin(),
                              embeddedWords.rend(),
                              Context, true);
    }

  private:
    Embeddings<Weights::EncEmbeddings> embeddings_;
    RNN<Weights::EncForwardRnn> forwardRnn_;
    RNN<Weights::EncBackwardRnn> backwardRnn_;
};
|
@ -1,169 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
#include "mblas/matrix.h"
|
||||
#include "npz_converter.h"
|
||||
|
||||
// All model parameters, loaded from a NumPy .npz archive via NpzConverter.
// Member names mirror the Groundhog parameter names stored in the archive.
// NOTE(review): model["name"] vs model("name", true) appear to select plain
// vs transposed/vector loading -- confirm against NpzConverter.
struct Weights {

  //////////////////////////////////////////////////////////////////////////////

  // Source embedding matrix and bias.
  struct EncEmbeddings {
    EncEmbeddings(const NpzConverter& model)
    : E_(model["W_0_enc_approx_embdr"]),
      EB_(model("b_0_enc_approx_embdr", true))
    {}

    const mblas::Matrix E_;
    const mblas::Matrix EB_;
  };

  // Forward encoder GRU parameters (input, update and reset projections).
  struct EncForwardRnn {
    EncForwardRnn(const NpzConverter& model)
    : W_(model["W_0_enc_input_embdr_0"]),
      B_(model("b_0_enc_input_embdr_0", true)),
      U_(model["W_enc_transition_0"]),
      Wz_(model["W_0_enc_update_embdr_0"]),
      Uz_(model["G_enc_transition_0"]),
      Wr_(model["W_0_enc_reset_embdr_0"]),
      Ur_(model["R_enc_transition_0"])
    {}

    const mblas::Matrix W_;
    const mblas::Matrix B_;
    const mblas::Matrix U_;
    const mblas::Matrix Wz_;
    const mblas::Matrix Uz_;
    const mblas::Matrix Wr_;
    const mblas::Matrix Ur_;
  };

  // Backward encoder GRU parameters (same layout as EncForwardRnn).
  struct EncBackwardRnn {
    EncBackwardRnn(const NpzConverter& model)
    : W_(model["W_0_back_enc_input_embdr_0"]),
      B_(model("b_0_back_enc_input_embdr_0", true)),
      U_(model["W_back_enc_transition_0"]),
      Wz_(model["W_0_back_enc_update_embdr_0"]),
      Uz_(model["G_back_enc_transition_0"]),
      Wr_(model["W_0_back_enc_reset_embdr_0"]),
      Ur_(model["R_back_enc_transition_0"])
    {}

    const mblas::Matrix W_;
    const mblas::Matrix B_;
    const mblas::Matrix U_;
    const mblas::Matrix Wz_;
    const mblas::Matrix Uz_;
    const mblas::Matrix Wr_;
    const mblas::Matrix Ur_;
  };

  //////////////////////////////////////////////////////////////////////////////

  // Target embedding matrix and bias.
  struct DecEmbeddings {
    DecEmbeddings(const NpzConverter& model)
    : E_(model["W_0_dec_approx_embdr"]),
      EB_(model("b_0_dec_approx_embdr", true))
    {}

    const mblas::Matrix E_;
    const mblas::Matrix EB_;
  };

  // Decoder conditional-GRU parameters: state initializer (Ws/WsB), input
  // projections (W/B/U/C) and gate projections (z = update, r = reset),
  // where the C* matrices project the attended source context.
  struct DecRnn {
    DecRnn(const NpzConverter& model)
    : Ws_(model["W_0_dec_initializer_0"]),
      WsB_(model("b_0_dec_initializer_0", true)),

      W_(model["W_0_dec_input_embdr_0"]),
      B_(model("b_0_dec_input_embdr_0", true)),
      U_(model["W_dec_transition_0"]),
      C_(model["W_0_dec_dec_inputter_0"]),

      Wz_(model["W_0_dec_update_embdr_0"]),
      Uz_(model["G_dec_transition_0"]),
      Cz_(model["W_0_dec_dec_updater_0"]),

      Wr_(model["W_0_dec_reset_embdr_0"]),
      Ur_(model["R_dec_transition_0"]),
      Cr_(model["W_0_dec_dec_reseter_0"])
    {}

    const mblas::Matrix Ws_;
    const mblas::Matrix WsB_;
    const mblas::Matrix W_;
    const mblas::Matrix B_;
    const mblas::Matrix U_;
    const mblas::Matrix C_;
    const mblas::Matrix Wz_;
    const mblas::Matrix Uz_;
    const mblas::Matrix Cz_;
    const mblas::Matrix Wr_;
    const mblas::Matrix Ur_;
    const mblas::Matrix Cr_;
  };

  // Additive attention parameters (Va, Wa, Ua).
  struct DecAlignment {
    DecAlignment(const NpzConverter& model)
    : Va_(model("D_dec_transition_0", true)),
      Wa_(model["B_dec_transition_0"]),
      Ua_(model["A_dec_transition_0"])
    {}

    const mblas::Matrix Va_;
    const mblas::Matrix Wa_;
    const mblas::Matrix Ua_;
  };

  // Readout/softmax parameters.  The output matrix Wo_ is not stored in
  // the archive directly: it is materialized here as the product of the
  // two deep-softmax factors W1 * W2.
  struct DecSoftmax {
    DecSoftmax(const NpzConverter& model)
    : WoB_(model("b_dec_deep_softmax", true)),
      Uo_(model["W_0_dec_hid_readout_0"]),
      UoB_(model("b_0_dec_hid_readout_0", true)),
      Vo_(model["W_0_dec_prev_readout_0"]),
      Co_(model["W_0_dec_repr_readout"])
    {
      // Wo_ is declared const but can only be computed after loading the
      // factors, hence the const_cast during construction.
      // NOTE(review): casting away const to write Wo_ here is safe only
      // because Wo_ is this object's own member still under construction.
      const mblas::Matrix Wo1_(model["W1_dec_deep_softmax"]);
      const mblas::Matrix Wo2_(model["W2_dec_deep_softmax"]);
      mblas::Prod(const_cast<mblas::Matrix&>(Wo_), Wo1_, Wo2_);
    }

    const mblas::Matrix Wo_;
    const mblas::Matrix WoB_;
    const mblas::Matrix Uo_;
    const mblas::Matrix UoB_;
    const mblas::Matrix Vo_;
    const mblas::Matrix Co_;
  };

  // Convenience overload: load the archive from disk first.
  Weights(const std::string& npzFile, size_t device = 0)
  : Weights(NpzConverter(npzFile), device)
  {}

  // Constructs every sub-module from the same converter.
  Weights(const NpzConverter& model, size_t device = 0)
  : encEmbeddings_(model),
    decEmbeddings_(model),
    encForwardRnn_(model),
    encBackwardRnn_(model),
    decRnn_(model),
    decAlignment_(model),
    decSoftmax_(model),
    device_(device)
  {}

  // CUDA device these weights are associated with.
  size_t GetDevice() {
    return device_;
  }

  const EncEmbeddings encEmbeddings_;
  const DecEmbeddings decEmbeddings_;
  const EncForwardRnn encForwardRnn_;
  const EncBackwardRnn encBackwardRnn_;
  const DecRnn decRnn_;
  const DecAlignment decAlignment_;
  const DecSoftmax decSoftmax_;

  const size_t device_;
};
|
@ -1,95 +0,0 @@
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <boost/timer/timer.hpp>
|
||||
#include <boost/program_options/options_description.hpp>
|
||||
#include <boost/program_options/parsers.hpp>
|
||||
#include <boost/program_options/variables_map.hpp>
|
||||
#include <boost/lexical_cast.hpp>
|
||||
|
||||
#include "bahdanau/model.h"
|
||||
#include "vocab.h"
|
||||
#include "decoder/nmt_decoder.h"
|
||||
|
||||
|
||||
void ProgramOptions(int argc, char *argv[],
|
||||
std::string& modelPath,
|
||||
std::string& svPath,
|
||||
std::string& tvPath,
|
||||
size_t& beamsize,
|
||||
size_t& device) {
|
||||
bool help = false;
|
||||
|
||||
namespace po = boost::program_options;
|
||||
po::options_description cmdline_options("Allowed options");
|
||||
cmdline_options.add_options()
|
||||
("beamsize,b", po::value(&beamsize)->default_value(10),
|
||||
"Beam size")
|
||||
("device,d", po::value(&device)->default_value(0),
|
||||
"CUDA Device")
|
||||
("model,m", po::value(&modelPath)->required(),
|
||||
"Path to a model")
|
||||
("source,s", po::value(&svPath)->required(),
|
||||
"Path to a source vocab file.")
|
||||
("target,t", po::value(&tvPath)->required(),
|
||||
"Path to a target vocab file.")
|
||||
("help,h", po::value(&help)->zero_tokens()->default_value(false),
|
||||
"Print this help message and exit.")
|
||||
;
|
||||
po::variables_map vm;
|
||||
try {
|
||||
po::store(po::command_line_parser(argc, argv).
|
||||
options(cmdline_options).run(), vm);
|
||||
po::notify(vm);
|
||||
} catch (std::exception& e) {
|
||||
std::cout << "Error: " << e.what() << std::endl << std::endl;
|
||||
|
||||
std::cout << "Usage: " + std::string(argv[0]) + " [options]" << std::endl;
|
||||
std::cout << cmdline_options << std::endl;
|
||||
exit(0);
|
||||
}
|
||||
|
||||
if (help) {
|
||||
std::cout << "Usage: " + std::string(argv[0]) + " [options]" << std::endl;
|
||||
std::cout << cmdline_options << std::endl;
|
||||
exit(0);
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
std::string modelPath, srcVocabPath, trgVocabPath;
|
||||
size_t device = 0;
|
||||
size_t beamsize = 10;
|
||||
ProgramOptions(argc, argv, modelPath, srcVocabPath, trgVocabPath, beamsize, device);
|
||||
std::cerr << "Using device GPU" << device << std::endl;;
|
||||
cudaSetDevice(device);
|
||||
std::cerr << "Loading model... ";
|
||||
std::shared_ptr<Weights> model(new Weights(modelPath));
|
||||
std::shared_ptr<Vocab> srcVocab(new Vocab(srcVocabPath));
|
||||
std::shared_ptr<Vocab> trgVocab(new Vocab(trgVocabPath));
|
||||
std::cerr << "done." << std::endl;
|
||||
|
||||
NMTDecoder decoder(model, srcVocab, trgVocab, beamsize);
|
||||
|
||||
std::cerr << "Start translating...\n";
|
||||
|
||||
std::ios_base::sync_with_stdio(false);
|
||||
|
||||
std::string line;
|
||||
boost::timer::cpu_timer timer;
|
||||
while(std::getline(std::cin, line)) {
|
||||
auto result = decoder.translate(line);
|
||||
for (auto it = result.rbegin(); it != result.rend(); ++it) {
|
||||
std::string word = (*trgVocab)[*it];
|
||||
if(it != result.rbegin())
|
||||
std::cout << " ";
|
||||
if(word != "</s>")
|
||||
std::cout << word;
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
std::cerr << timer.format() << std::endl;
|
||||
return 0;
|
||||
}
|
@ -1,32 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstddef>
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
|
||||
// One node of the beam-search lattice: the word chosen at this step, a
// back-pointer to the hypothesis it extends, and the accumulated score.
// Immutable after construction.
class Hypothesis {
  public:
    Hypothesis(size_t word, size_t prev, float cost)
      : prev_(prev), word_(word), cost_(cost) {}

    // Word id chosen at this step.
    size_t GetWord() const { return word_; }

    // Index of the predecessor hypothesis in the owning container.
    size_t GetPrevStateIndex() const { return prev_; }

    // Accumulated score (log-probability) of the partial translation.
    float GetCost() const { return cost_; }

  private:
    const size_t prev_;
    const size_t word_;
    const float cost_;
};
|
||||
|
@ -1,75 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
|
||||
#include "decoder/hypothesis.h"
|
||||
|
||||
class HypothesisManager {
|
||||
using Hypotheses = std::vector<Hypothesis>;
|
||||
public:
|
||||
HypothesisManager(size_t beamSize, size_t EOSIndex)
|
||||
: beamSize_(beamSize),
|
||||
EOSIndex_(EOSIndex),
|
||||
baseIndex_(0) {
|
||||
hypotheses_.emplace_back(0, 0, 0);
|
||||
}
|
||||
|
||||
void AddHypotheses(const Hypotheses& hypos) {
|
||||
size_t nextBaseIndex = hypotheses_.size();
|
||||
for (const auto& hypo : hypos) {
|
||||
if (hypo.GetWord() == EOSIndex_) {
|
||||
completedHypotheses_.emplace_back(hypo.GetWord(),
|
||||
hypo.GetPrevStateIndex() + baseIndex_,
|
||||
hypo.GetCost());
|
||||
} else {
|
||||
hypotheses_.emplace_back(hypo.GetWord(), hypo.GetPrevStateIndex() + baseIndex_,
|
||||
hypo.GetCost());
|
||||
}
|
||||
}
|
||||
baseIndex_ = nextBaseIndex;
|
||||
}
|
||||
|
||||
std::vector<size_t> GetBestTranslation() {
|
||||
size_t bestHypoId = 0;
|
||||
for (size_t i = 0; i < completedHypotheses_.size(); ++i) {
|
||||
if (completedHypotheses_[bestHypoId].GetCost()
|
||||
< completedHypotheses_[i].GetCost()) {
|
||||
bestHypoId = i;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// for (auto hypo : completedHypotheses_) {
|
||||
// std::vector<size_t> words;
|
||||
// words.push_back(hypo.GetWord());
|
||||
// size_t state = hypo.GetPrevStateIndex();
|
||||
// while (state > 0) {
|
||||
// words.push_back(hypotheses_[state].GetWord());
|
||||
// state = hypotheses_[state].GetPrevStateIndex();
|
||||
// }
|
||||
// for (auto it = words.rbegin(); it != words.rend(); ++it) std::cerr << *it << " ";
|
||||
// std::cerr << hypo.GetCost() << std::endl;
|
||||
// }
|
||||
|
||||
std::vector<size_t> bestSentence;
|
||||
bestSentence.push_back(completedHypotheses_[bestHypoId].GetWord());
|
||||
size_t state = completedHypotheses_[bestHypoId].GetPrevStateIndex();
|
||||
|
||||
while (state > 0) {
|
||||
bestSentence.push_back(hypotheses_[state].GetWord());
|
||||
state = hypotheses_[state].GetPrevStateIndex();
|
||||
}
|
||||
|
||||
return bestSentence;
|
||||
}
|
||||
|
||||
private:
|
||||
Hypotheses hypotheses_;
|
||||
size_t beamSize_;
|
||||
Hypotheses completedHypotheses_;
|
||||
const size_t EOSIndex_;
|
||||
size_t baseIndex_;
|
||||
};
|
||||
|
||||
|
@ -1,175 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
|
||||
#include <thrust/functional.h>
|
||||
#include <thrust/device_vector.h>
|
||||
#include <thrust/host_vector.h>
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/extrema.h>
|
||||
#include <thrust/sort.h>
|
||||
#include <thrust/sequence.h>
|
||||
|
||||
#include "common/vocab.h"
|
||||
#include "bahdanau/encoder.h"
|
||||
#include "bahdanau/decoder.h"
|
||||
#include "bahdanau/model.h"
|
||||
#include "common/utils.h"
|
||||
#include "mblas/matrix.h"
|
||||
#include "decoder/hypothesis_manager.h"
|
||||
|
||||
|
||||
using namespace thrust::placeholders;
|
||||
|
||||
// End-to-end beam-search translator: encodes a source sentence, then
// repeatedly expands the beam with the decoder until every hypothesis has
// produced </s> (or a length limit is reached), and backtracks the best
// completed hypothesis.
class NMTDecoder {
  using Words = std::vector<size_t>;
  using Hypotheses = std::vector<Hypothesis>;
  public:
    NMTDecoder(
      std::shared_ptr<Weights> model,
      std::shared_ptr<Vocab> srcVocab,
      std::shared_ptr<Vocab> trgVocab,
      const size_t beamSize=1)
    : model_(model),
      srcVocab_(srcVocab),
      trgVocab_(trgVocab),
      encoder_(new Encoder(*model_)),
      decoder_(new Decoder(*model_)),
      beamSize_(beamSize),
      Costs_() {
    }

    // Translates one whitespace-tokenized sentence; returns target word
    // ids in reverse order (the caller prints them back-to-front).
    Words translate(std::string& sentence) {
      size_t sourceSentenceLength = prepareSourceSentence(sentence);
      prepareDecoder();

      size_t batchSize = beamSize_;
      Costs_.Clear();
      Costs_.Resize(batchSize, 1, 0.0);
      HypothesisManager hypoManager(batchSize, (*trgVocab_)["</s>"]);

      mblas::Matrix Probs;

      // Hard cap of 3x the source length on the target length.
      for(size_t len = 0; len < 3 * sourceSentenceLength; ++len) {
        std::vector<size_t> bestWordIndices, bestWordHyps;
        decoder_->GetProbs(Probs, AlignedSourceContext_,
                           PrevState_, PrevEmbedding_, SourceContext_);

        // Moved here; maybe the decoder should be doing this itself.
        Element(Log(_1), Probs);

        // Ugly, but GH has this too; helps a bit with larger beam sizes,
        // though there is still a problem somewhere: forbid </s> before
        // half of the source length has been generated.
        if(len < sourceSentenceLength * 0.5) {
          size_t eol = (*trgVocab_)["</s>"];
          for(size_t i = 0; i < Probs.Rows(); ++i) {
            Probs.Set(i, eol, std::numeric_limits<float>::lowest());
          }
        }

        auto bestHypos = GetBestExtensions(Probs, batchSize);
        hypoManager.AddHypotheses(bestHypos);

        // Partition this step's winners into continuing hypotheses and
        // finished ones; finished ones shrink the active beam.
        size_t cidx = 0;
        std::vector<size_t> costIndeces;
        for (auto& best: bestHypos) {
          if (best.GetWord() != (*trgVocab_)["</s>"]) {
            bestWordIndices.push_back(best.GetWord());
            bestWordHyps.push_back(best.GetPrevStateIndex());
            costIndeces.push_back(cidx);
          } else {
            //std::cerr << "Finshed at " << Costs_(0, cidx) << std::endl;
            --batchSize;
          }
          cidx++;
        }

        // NOTE(review): batchSize is size_t, so `<= 0` only fires at
        // exactly 0; it cannot go negative but the comparison reads oddly.
        if (batchSize <= 0)
          break;

        // Make this conditional.
        mblas::Matrix CostsTemp;
        mblas::Assemble(CostsTemp, Costs_, costIndeces);
        mblas::Swap(Costs_, CostsTemp);
        //mblas::debug1(Costs_);

        // Advance the surviving hypotheses by one step.
        decoder_->Lookup(Embedding_, bestWordIndices);
        Assemble(BestState_, PrevState_, bestWordHyps);
        decoder_->GetNextState(State_, Embedding_,
                               BestState_, AlignedSourceContext_);

        mblas::Swap(State_, PrevState_);
        mblas::Swap(Embedding_, PrevEmbedding_);
      }

      return hypoManager.GetBestTranslation();
    }

  private:
    // Trims, tokenizes on spaces, encodes via the source vocab and runs
    // the encoder; returns the number of encoded tokens.
    size_t prepareSourceSentence(std::string& sentence) {
      Trim(sentence);
      std::vector<std::string> tokens;
      Split(sentence, tokens, " ");
      auto encoded_tokens = srcVocab_->Encode(tokens, true);
      encoder_->GetContext(encoded_tokens, SourceContext_);
      return encoded_tokens.size();
    }

    // Selects the globally best `batchSize` (word, previous-hypothesis)
    // extensions out of the batchSize x vocab score matrix.
    Hypotheses GetBestExtensions(mblas::Matrix& Probs, size_t batchSize) {
      Hypotheses hypos;

      // Should be one kernel -- certainly not transposing the huge
      // batchSize x vocab matrix twice.
      Costs_.Reshape(1, batchSize);
      Broadcast(_1 + _2, Transpose(Probs), Costs_);
      Costs_.Reshape(batchSize, 1);
      Transpose(Probs);

      size_t probSize = Probs.Cols() * Probs.Rows();
      thrust::device_vector<int> keys(probSize);
      thrust::sequence(keys.begin(), keys.end());

      // Sorting in reverse order would be worthwhile; it would save the
      // index gymnastics below.
      thrust::sort_by_key(Probs.begin(), Probs.end(), keys.begin());
      // OK, this presumably uses thrust::copy underneath? Verify.
      thrust::host_vector<int> bestKeys(keys.end() - batchSize, keys.end());

      // NOTE(review): this local manager appears unused -- candidate for
      // removal.
      HypothesisManager hypoManager(batchSize, (*trgVocab_)["</s>"]);

      // Use thrust::copy to make two copies, one into Costs_ and one into
      // a host vector, then do the second step on the CPU.
      for (size_t i = 0; i < bestKeys.size(); ++i) {
        Costs_.GetVec()[i] = Probs.GetVec()[probSize - batchSize + i];
        hypos.emplace_back(bestKeys[i] % Probs.Cols(), bestKeys[i] / Probs.Cols(), Probs.GetVec()[probSize - batchSize + i]);
      }

      return hypos;

    }

    // Resets the decoder state and the "previous word" embedding before a
    // new sentence.
    void prepareDecoder() {
      decoder_->EmptyState(PrevState_, SourceContext_, 1);
      decoder_->EmptyEmbedding(PrevEmbedding_, 1);
    }

  protected:
    std::shared_ptr<Weights> model_;
    std::shared_ptr<Vocab> srcVocab_;
    std::shared_ptr<Vocab> trgVocab_;
    std::shared_ptr<Encoder> encoder_;
    std::shared_ptr<Decoder> decoder_;
    const size_t beamSize_;            // requested beam width
    mblas::Matrix SourceContext_;      // encoder output for the sentence
    mblas::Matrix PrevState_;
    mblas::Matrix PrevEmbedding_;
    mblas::Matrix BestState_;
    mblas::Matrix Costs_;              // running score per live hypothesis

    mblas::Matrix AlignedSourceContext_;

    mblas::Matrix State_;
    mblas::Matrix Embedding_;

};
|
@ -1,15 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstddef>
|
||||
|
||||
// Plain value triple describing one search expansion: the index of the
// state it originated from, the chosen word id, and its score.
struct Result {
  Result(const size_t state, const size_t word, const float score)
    : state(state), word(word), score(score) {}

  size_t state;  // index of the originating hypothesis/state
  size_t word;   // chosen target word id
  float score;   // score of this expansion
};
|
@ -106,7 +106,9 @@ class Search {
|
||||
return history;
|
||||
}
|
||||
|
||||
void BestHyps(Beam& bestHyps, const Beam& prevHyps, std::vector<mblas::Matrix*>& ProbsEnsemble, const size_t beamSize) {
|
||||
void BestHyps(Beam& bestHyps, const Beam& prevHyps,
|
||||
std::vector<mblas::Matrix*>& ProbsEnsemble,
|
||||
const size_t beamSize) {
|
||||
using namespace mblas;
|
||||
|
||||
Matrix& Probs = *ProbsEnsemble[0];
|
||||
|
102
src/test/test.cu
102
src/test/test.cu
@ -1,102 +0,0 @@
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <algorithm>
|
||||
#include <boost/timer/timer.hpp>
|
||||
#include <boost/algorithm/string.hpp>
|
||||
|
||||
#include "mblas/matrix.h"
|
||||
#include "bahdanau.h"
|
||||
#include "vocab.h"
|
||||
|
||||
#include "states.h"
|
||||
|
||||
using namespace mblas;
|
||||
|
||||
// Ad-hoc scoring test: loads a fixed model and vocabularies, encodes a
// hard-coded source sentence, then force-decodes a hard-coded target
// sentence (replicated 1000x as a batch) and prints per-word log-probs
// and their sum to stderr.  Optional argv[1] of "1" or "2" selects the GPU.
int main(int argc, char** argv) {
  size_t device = 0;

  if(argc > 1) {
    if(std::string(argv[1]) == "1")
      device = 1;
    else if(std::string(argv[1]) == "2")
      device = 2;
  }

  std::cerr << device << std::endl;
  cudaSetDevice(device);

  std::string source = "thank you .";
  std::string target = "vielen dank .";
  //std::string source = "you know , one of the intense pleasures of travel and one of the delights of ethnographic research is the opportunity to live amongst those who have not forgotten the old ways , who still feel their past in the wind , touch it in stones polished by rain , taste it in the bitter leaves of plants .";
  //std::string target = "wissen sie , eine der intensiven freuden des reisens und eine der freuden der ethnografischen forschung ist die chance zu leben unter jenen , die die alten wege nicht vergessen haben , die immer noch ihre vergangenheit im wind spüren , berühren sie in steine poliert durch regen , schmecken sie in den bitteren blätter der pflanzen .";

  // NOTE(review): hard-coded developer paths -- this test only runs on the
  // original author's machine.
  std::cerr << "Loading model" << std::endl;
  Weights weights("/home/marcinj/Badania/best_nmt/search_model.npz", device);
  Vocab svcb("/home/marcinj/Badania/best_nmt/vocab/en_de.en.txt");
  Vocab tvcb("/home/marcinj/Badania/best_nmt/vocab/en_de.de.txt");

  std::cerr << "Creating encoder" << std::endl;
  Encoder encoder(weights);
  std::cerr << "Creating decoder" << std::endl;
  Decoder decoder(weights);

  // Tokenize and map the source sentence to vocab ids, appending </s>.
  std::vector<std::string> sourceSplit;
  boost::split(sourceSplit, source, boost::is_any_of(" "),
               boost::token_compress_on);

  std::cerr << "Source: " << std::endl;
  std::vector<size_t> sWords(sourceSplit.size());
  std::transform(sourceSplit.begin(), sourceSplit.end(), sWords.begin(),
                 [&](const std::string& w) { std::cerr << svcb[w] << ", "; return svcb[w]; });
  sWords.push_back(svcb["</s>"]);
  std::cerr << svcb["</s>"] << std::endl;

  typedef std::vector<size_t> Batch;

  // Tokenize the target sentence; each position becomes a batch of `bs`
  // copies of the same word id (a throughput stress test).
  std::vector<std::string> targetSplit;
  boost::split(targetSplit, target, boost::is_any_of(" "),
               boost::token_compress_on);

  std::cerr << "Target: " << std::endl;
  size_t bs = 1000;
  std::vector<std::vector<size_t>> tWordsBatch(targetSplit.size());
  std::transform(targetSplit.begin(), targetSplit.end(), tWordsBatch.begin(),
                 [&](const std::string& w) { std::cerr << tvcb[w] << ", "; return Batch(bs, tvcb[w]); });
  tWordsBatch.push_back(Batch(bs, tvcb["</s>"]));
  std::cerr << tvcb["</s>"] << std::endl;

  mblas::Matrix SourceContext;
  encoder.GetContext(sWords, SourceContext);

  mblas::Matrix State, NextState;
  mblas::Matrix Embeddings, NextEmbeddings;
  mblas::Matrix Probs;

  std::cerr << "Testing" << std::endl;
  boost::timer::auto_cpu_timer timer;
  size_t batchSize = tWordsBatch[0].size();

  for(size_t i = 0; i < 1; ++i) {
    decoder.EmptyState(State, SourceContext, batchSize);
    decoder.EmptyEmbedding(Embeddings, batchSize);

    // Force-decode the target, accumulating the log-prob of row 0.
    float sum = 0;
    for(auto batch : tWordsBatch) {
      decoder.MakeStep(NextState, NextEmbeddings, Probs,
                       batch, State, Embeddings, SourceContext);

      // NOTE(review): this inner `i` shadows the outer loop's `i`.
      for(size_t i = 0; i < 1; ++i) {
        float p = Probs(i, batch[i]);
        std:: cerr << log(p) << " ";
        if(i == 0) {
          sum += log(p);
        }
      }

      mblas::Swap(Embeddings, NextEmbeddings);
      mblas::Swap(State, NextState);
    }
    std::cerr << i << " " << sum << std::endl;
  }
}
|
@ -1,100 +0,0 @@
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <algorithm>
|
||||
#include <boost/timer/timer.hpp>
|
||||
#include <boost/algorithm/string.hpp>
|
||||
|
||||
#include "mblas/matrix.h"
|
||||
#include "dl4mt.h"
|
||||
#include "vocab.h"
|
||||
|
||||
using namespace mblas;
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
size_t device = 0;
|
||||
|
||||
if(argc > 1) {
|
||||
if(std::string(argv[1]) == "1")
|
||||
device = 1;
|
||||
else if(std::string(argv[1]) == "2")
|
||||
device = 2;
|
||||
}
|
||||
|
||||
std::cerr << device << std::endl;
|
||||
cudaSetDevice(device);
|
||||
|
||||
std::string source = "thank you .";
|
||||
std::string target = "vielen Dank .";
|
||||
|
||||
std::cerr << "Loading model" << std::endl;
|
||||
Weights weights("testmodel/model.npz", device);
|
||||
|
||||
Vocab svcb("testmodel/vocab.en.txt");
|
||||
Vocab tvcb("testmodel/vocab.de.txt");
|
||||
|
||||
std::cerr << "Creating encoder" << std::endl;
|
||||
Encoder encoder(weights);
|
||||
|
||||
std::cerr << "Creating decoder" << std::endl;
|
||||
Decoder decoder(weights);
|
||||
|
||||
std::vector<std::string> sourceSplit;
|
||||
boost::split(sourceSplit, source, boost::is_any_of(" "),
|
||||
boost::token_compress_on);
|
||||
|
||||
std::cerr << "Source: " << std::endl;
|
||||
std::vector<size_t> sWords(sourceSplit.size());
|
||||
std::transform(sourceSplit.begin(), sourceSplit.end(), sWords.begin(),
|
||||
[&](const std::string& w) { std::cerr << svcb[w] << ", "; return svcb[w]; });
|
||||
sWords.push_back(svcb["</s>"]);
|
||||
std::cerr << svcb["</s>"] << std::endl;
|
||||
|
||||
typedef std::vector<size_t> Batch;
|
||||
|
||||
std::vector<std::string> targetSplit;
|
||||
boost::split(targetSplit, target, boost::is_any_of(" "),
|
||||
boost::token_compress_on);
|
||||
|
||||
std::cerr << "Target: " << std::endl;
|
||||
size_t bs = 1000;
|
||||
|
||||
std::vector<std::vector<size_t>> tWordsBatch(targetSplit.size());
|
||||
std::transform(targetSplit.begin(), targetSplit.end(), tWordsBatch.begin(),
|
||||
[&](const std::string& w) { std::cerr << tvcb[w] << ", "; return Batch(bs, tvcb[w]); });
|
||||
tWordsBatch.push_back(Batch(bs, tvcb["</s>"]));
|
||||
std::cerr << tvcb["</s>"] << std::endl;
|
||||
|
||||
mblas::Matrix SourceContext;
|
||||
encoder.GetContext(sWords, SourceContext);
|
||||
|
||||
mblas::Matrix State, NextState;
|
||||
mblas::Matrix Embeddings, NextEmbeddings;
|
||||
mblas::Matrix Probs;
|
||||
|
||||
std::cerr << "Testing" << std::endl;
|
||||
boost::timer::auto_cpu_timer timer;
|
||||
size_t batchSize = tWordsBatch[0].size();
|
||||
|
||||
for(size_t i = 0; i < 1; ++i) {
|
||||
decoder.EmptyState(State, SourceContext, batchSize);
|
||||
decoder.EmptyEmbedding(Embeddings, batchSize);
|
||||
|
||||
float sum = 0;
|
||||
for(auto batch : tWordsBatch) {
|
||||
decoder.MakeStep(NextState, Probs,
|
||||
State, Embeddings, SourceContext);
|
||||
decoder.Lookup(NextEmbeddings, batch);
|
||||
for(size_t i = 0; i < 1; ++i) {
|
||||
float p = Probs(i, batch[i]);
|
||||
if(i == 0) {
|
||||
sum += log(p);
|
||||
}
|
||||
}
|
||||
|
||||
mblas::Swap(Embeddings, NextEmbeddings);
|
||||
mblas::Swap(State, NextState);
|
||||
}
|
||||
std::cerr << i << " " << sum << std::endl;
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user