mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-03 20:13:47 +03:00
some more refactoring
This commit is contained in:
parent
78c9c30aaf
commit
84c3121885
@ -17,7 +17,13 @@ set(CMAKE_CXX_FLAGS_DEBUG " -std=c++11 -g -O0 -fPIC -Wno-unused-result -Wno-depr
|
||||
set(CMAKE_CXX_FLAGS_PROFILE "${CMAKE_CXX_FLAGS_RELEASE} -g -pg")
|
||||
set(CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS_RELEASE})
|
||||
|
||||
LIST(APPEND CUDA_NVCC_FLAGS -std=c++11; --default-stream per-thread; -O3; --use_fast_math; -Xcompiler '-fPIC'; -arch=sm_35;)
|
||||
|
||||
if (CMAKE_BUILD_TYPE STREQUAL "Debug")
|
||||
LIST(APPEND CUDA_NVCC_FLAGS -std=c++11; --default-stream per-thread; -O0; -g; -Xcompiler '-fPIC'; -arch=sm_35;)
|
||||
else(CMAKE_BUILD_TYPE STREQUAL "Debug")
|
||||
LIST(APPEND CUDA_NVCC_FLAGS -std=c++11; --default-stream per-thread; -O3; --use_fast_math; -Xcompiler '-fPIC'; -arch=sm_35;)
|
||||
endif(CMAKE_BUILD_TYPE STREQUAL "Debug")
|
||||
|
||||
LIST(REMOVE_DUPLICATES CUDA_NVCC_FLAGS)
|
||||
SET(CUDA_PROPAGATE_HOST_FLAGS OFF)
|
||||
|
||||
|
101
src/common/options.h
Normal file
101
src/common/options.h
Normal file
@ -0,0 +1,101 @@
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include "common/definitions.h"
|
||||
#include "3rd_party/yaml-cpp/yaml.h"
|
||||
|
||||
namespace marian {
|
||||
|
||||
class Options {
|
||||
protected:
|
||||
YAML::Node options_;
|
||||
|
||||
public:
|
||||
YAML::Node& getOptions() {
|
||||
return options_;
|
||||
}
|
||||
|
||||
void parse(const std::string& yaml) {
|
||||
auto node = YAML::Load(yaml);
|
||||
for(auto it : node)
|
||||
options_[it.first.as<std::string>()] = it.second;
|
||||
}
|
||||
|
||||
void merge(YAML::Node& node) {
|
||||
for(auto it : node)
|
||||
options_[it.first.as<std::string>()] = it.second;
|
||||
}
|
||||
|
||||
std::string str() {
|
||||
std::stringstream ss;
|
||||
ss << options_;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void set(const std::string& key, T value) {
|
||||
options_[key] = value;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T get(const std::string& key) {
|
||||
return options_[key].as<T>();
|
||||
}
|
||||
};
|
||||
|
||||
template <class Obj>
|
||||
struct DefaultCreate {
|
||||
template <typename ...Args>
|
||||
static Ptr<Obj> create(Ptr<ExpressionGraph> graph, Ptr<Options> options, Args ...args) {
|
||||
return New<Obj>(graph, options, args...);
|
||||
}
|
||||
};
|
||||
|
||||
template <class Obj, class Create=DefaultCreate<Obj>>
|
||||
class Builder {
|
||||
protected:
|
||||
Ptr<Options> options_;
|
||||
Ptr<ExpressionGraph> graph_;
|
||||
|
||||
public:
|
||||
Builder(Ptr<ExpressionGraph> graph)
|
||||
: options_(New<Options>()), graph_(graph) {}
|
||||
|
||||
Ptr<Options> getOptions() {
|
||||
return options_;
|
||||
}
|
||||
|
||||
virtual std::string str() {
|
||||
return options_->str();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Builder& operator()(const std::string& key, T value) {
|
||||
options_->set(key, value);
|
||||
return *this;
|
||||
}
|
||||
|
||||
Builder& operator()(const std::string& yaml) {
|
||||
options_->parse(yaml);
|
||||
return *this;
|
||||
}
|
||||
|
||||
Builder& operator()(YAML::Node yaml) {
|
||||
options_->merge(yaml);
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T get(const std::string& key) {
|
||||
return options_->get<T>(key);
|
||||
}
|
||||
|
||||
template <typename ...Args>
|
||||
Ptr<Obj> create(Args ...args) {
|
||||
return Create::create(graph_, options_, args...);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
}
|
136
src/models/s2s.h
136
src/models/s2s.h
@ -3,6 +3,7 @@
|
||||
#include "rnn/attention.h"
|
||||
#include "rnn/rnn.h"
|
||||
#include "rnn/cells.h"
|
||||
#include "common/options.h"
|
||||
|
||||
#include "models/encdec.h"
|
||||
|
||||
@ -91,36 +92,43 @@ public:
|
||||
x = dropout(x, mask = srcWordDrop);
|
||||
}
|
||||
|
||||
auto type = options_->get<std::string>("cell-enc");
|
||||
UTIL_THROW_IF2(amun && type != "gru",
|
||||
auto cellType = options_->get<std::string>("cell-enc");
|
||||
UTIL_THROW_IF2(amun && cellType != "gru",
|
||||
"--type amun does not currently support other rnn cells than gru, "
|
||||
"use --type s2s");
|
||||
|
||||
auto cellFw = rnn::cell(type)(graph,
|
||||
prefix_ + "_bi",
|
||||
dimSrcEmb,
|
||||
dimEncState,
|
||||
normalize = layerNorm,
|
||||
dropout_prob = dropoutRnn);
|
||||
auto cellFw = rnn::cell(graph)
|
||||
("type", cellType)
|
||||
("prefix", prefix_ + "_bi")
|
||||
("dimInput", dimSrcEmb)
|
||||
("dimState", dimEncState)
|
||||
("dropout", dropoutRnn)
|
||||
("normalize", layerNorm)
|
||||
("final", false)
|
||||
.create();
|
||||
auto xFw = rnn::RNN(cellFw)(x);
|
||||
|
||||
auto cellBw = rnn::cell(type)(graph,
|
||||
prefix_ + "_bi_r",
|
||||
dimSrcEmb,
|
||||
dimEncState,
|
||||
normalize = layerNorm,
|
||||
dropout_prob = dropoutRnn);
|
||||
auto cellBw = rnn::cell(graph)
|
||||
("type", cellType)
|
||||
("prefix", prefix_ + "_bi_r")
|
||||
("dimInput", dimSrcEmb)
|
||||
("dimState", dimEncState)
|
||||
("dropout", dropoutRnn)
|
||||
("normalize", layerNorm)
|
||||
("final", false)
|
||||
.create();
|
||||
auto xBw = rnn::RNN(cellBw, direction = dir::backward)(x, xMask);
|
||||
|
||||
auto xContext = concatenate({xFw, xBw}, axis = 1);
|
||||
|
||||
if(encoderLayers > 1) {
|
||||
auto layerCells
|
||||
= rnn::cells(type, encoderLayers-1)(graph, prefix_,
|
||||
2 * dimEncState, dimEncState,
|
||||
normalize=layerNorm, dropout_prob=dropoutRnn);
|
||||
|
||||
xContext = rnn::MLRNN(layerCells, skip=skipDepth)(xContext);
|
||||
}
|
||||
//if(encoderLayers > 1) {
|
||||
// auto layerCells
|
||||
// = rnn::cells(cellType, encoderLayers-1)(graph, prefix_,
|
||||
// 2 * dimEncState, dimEncState,
|
||||
// normalize=layerNorm, dropout_prob=dropoutRnn);
|
||||
//
|
||||
// xContext = rnn::MLRNN(layerCells, skip=skipDepth)(xContext);
|
||||
//}
|
||||
|
||||
return New<EncoderStateS2S>(xContext, xMask, batch);
|
||||
|
||||
@ -205,25 +213,35 @@ public:
|
||||
}
|
||||
|
||||
if(!attCell_) {
|
||||
auto attCell = New<rnn::StackedCell>(dimTrgEmb, dimDecState);
|
||||
auto attCell = rnn::stacked_cell(graph).create();
|
||||
|
||||
auto cell1 = rnn::cell(graph)
|
||||
("type", cellType)
|
||||
("prefix", prefix_ + "_cell1")
|
||||
("dimInput", dimTrgEmb)
|
||||
("dimState", dimDecState)
|
||||
("dropout", dropoutRnn)
|
||||
("normalize", layerNorm)
|
||||
("final", false)
|
||||
.create();
|
||||
|
||||
auto attention = rnn::attention(graph)
|
||||
("prefix", prefix_)
|
||||
("dimState", dimDecState)
|
||||
("dropout", dropoutRnn)
|
||||
("normalize", layerNorm)
|
||||
.create(state->getEncoderState());
|
||||
|
||||
auto cell2 = rnn::cell(graph)
|
||||
("type", cellType)
|
||||
("prefix", prefix_ + "_cell2")
|
||||
("dimInput", attention->dimOutput())
|
||||
("dimState", dimDecState)
|
||||
("dropout", dropoutRnn)
|
||||
("normalize", layerNorm)
|
||||
("final", true)
|
||||
.create();
|
||||
|
||||
auto cell1 = rnn::cell(cellType)(graph,
|
||||
prefix_ + "_cell1",
|
||||
dimTrgEmb,
|
||||
dimDecState,
|
||||
dropout_prob = dropoutRnn,
|
||||
normalize = layerNorm);
|
||||
auto attention = New<rnn::Attention>(prefix_,
|
||||
state->getEncoderState(),
|
||||
dimDecState,
|
||||
dropout_prob = dropoutRnn,
|
||||
normalize = layerNorm);
|
||||
auto cell2 = rnn::cell(cellType)(graph,
|
||||
prefix_ + "_cell2",
|
||||
attention->dimOutput(),
|
||||
dimDecState,
|
||||
dropout_prob = dropoutRnn,
|
||||
normalize = layerNorm);
|
||||
attCell->push_back(cell1);
|
||||
attCell->push_back(attention);
|
||||
attCell->push_back(cell2);
|
||||
@ -242,25 +260,25 @@ public:
|
||||
keywords::axis = 2);
|
||||
|
||||
|
||||
if(decoderLayers > 1) {
|
||||
rnn::States statesIn;
|
||||
for(int i = 1; i < stateS2S->getStates().size(); ++i)
|
||||
statesIn.push_back(stateS2S->getStates()[i]);
|
||||
|
||||
if(!rnnLn) {
|
||||
auto layerCells
|
||||
= rnn::cells(cellType, decoderLayers - 1)(graph, prefix_,
|
||||
dimDecState, dimDecState,
|
||||
normalize = layerNorm,
|
||||
dropout_prob = dropoutRnn);
|
||||
|
||||
rnnLn = New<rnn::MLRNN>(layerCells, skip = skipDepth, skip_first = skipDepth);
|
||||
}
|
||||
|
||||
decContext = (*rnnLn)(decContext, statesIn);
|
||||
for(auto state : rnnLn->last())
|
||||
decStates.push_back(state);
|
||||
}
|
||||
//if(decoderLayers > 1) {
|
||||
// rnn::States statesIn;
|
||||
// for(int i = 1; i < stateS2S->getStates().size(); ++i)
|
||||
// statesIn.push_back(stateS2S->getStates()[i]);
|
||||
//
|
||||
// if(!rnnLn) {
|
||||
// auto layerCells
|
||||
// = rnn::cells(cellType, decoderLayers - 1)(graph, prefix_,
|
||||
// dimDecState, dimDecState,
|
||||
// normalize = layerNorm,
|
||||
// dropout_prob = dropoutRnn);
|
||||
//
|
||||
// rnnLn = New<rnn::MLRNN>(layerCells, skip = skipDepth, skip_first = skipDepth);
|
||||
// }
|
||||
//
|
||||
// decContext = (*rnnLn)(decContext, statesIn);
|
||||
// for(auto state : rnnLn->last())
|
||||
// decStates.push_back(state);
|
||||
//}
|
||||
|
||||
//// 2-layer feedforward network for outputs and cost
|
||||
auto logitsL1
|
||||
|
@ -26,27 +26,26 @@ private:
|
||||
std::vector<Expr> contexts_;
|
||||
std::vector<Expr> alignments_;
|
||||
bool layerNorm_;
|
||||
|
||||
float dropout_;
|
||||
|
||||
Expr contextDropped_;
|
||||
Expr dropMaskContext_;
|
||||
Expr dropMaskState_;
|
||||
|
||||
Expr cov_;
|
||||
|
||||
public:
|
||||
template <typename... Args>
|
||||
GlobalAttention(const std::string prefix,
|
||||
Ptr<EncoderState> encState,
|
||||
int dimDecState,
|
||||
Args... args)
|
||||
: encState_(encState),
|
||||
contextDropped_(encState->getContext()),
|
||||
layerNorm_(Get(keywords::normalize, false, args...)),
|
||||
cov_(Get(keywords::coverage, nullptr, args...)) {
|
||||
int dimEncState = encState_->getContext()->shape()[1];
|
||||
GlobalAttention(Ptr<ExpressionGraph> graph,
|
||||
Ptr<Options> options,
|
||||
Ptr<EncoderState> encState)
|
||||
: CellInput(options),
|
||||
encState_(encState),
|
||||
contextDropped_(encState->getContext()) {
|
||||
|
||||
auto graph = encState_->getContext()->graph();
|
||||
int dimDecState = options_->get<int>("dimState");
|
||||
dropout_ = options_->get<float>("dropout");
|
||||
layerNorm_ = options_->get<bool>("normalize");
|
||||
std::string prefix = options_->get<std::string>("prefix");
|
||||
|
||||
int dimEncState = encState_->getContext()->shape()[1];
|
||||
|
||||
Wa_ = graph->param(prefix + "_W_comb_att",
|
||||
{dimDecState, dimEncState},
|
||||
@ -60,7 +59,6 @@ public:
|
||||
ba_ = graph->param(
|
||||
prefix + "_b_att", {1, dimEncState}, keywords::init = inits::zeros);
|
||||
|
||||
dropout_ = Get(keywords::dropout_prob, 0.0f, args...);
|
||||
if(dropout_ > 0.0f) {
|
||||
dropMaskContext_ = graph->dropout(dropout_, {1, dimEncState});
|
||||
dropMaskState_ = graph->dropout(dropout_, {1, dimDecState});
|
||||
@ -130,6 +128,8 @@ public:
|
||||
|
||||
using Attention = GlobalAttention;
|
||||
|
||||
typedef Builder<Attention> attention;
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
137
src/rnn/cells.h
137
src/rnn/cells.h
@ -29,12 +29,15 @@ private:
|
||||
Expr dropMaskS_;
|
||||
|
||||
public:
|
||||
template <typename... Args>
|
||||
Tanh(Ptr<ExpressionGraph> graph,
|
||||
const std::string prefix,
|
||||
int dimInput,
|
||||
int dimState,
|
||||
Args... args) : Cell(dimInput, dimState) {
|
||||
Tanh(Ptr<ExpressionGraph> graph, Ptr<Options> options)
|
||||
: Cell(options) {
|
||||
|
||||
int dimInput = options_->get<int>("dimInput");
|
||||
int dimState = options_->get<int>("dimState");
|
||||
std::string prefix = options_->get<std::string>("prefix");
|
||||
layerNorm_ = options_->get<bool>("normalize");
|
||||
dropout_ = options_->get<float>("dropout");
|
||||
|
||||
U_ = graph->param(prefix + "_U",
|
||||
{dimState, dimState},
|
||||
keywords::init = inits::glorot_uniform);
|
||||
@ -44,9 +47,6 @@ public:
|
||||
b_ = graph->param(
|
||||
prefix + "_b", {1, dimState}, keywords::init = inits::zeros);
|
||||
|
||||
layerNorm_ = Get(keywords::normalize, false, args...);
|
||||
|
||||
dropout_ = Get(keywords::dropout_prob, 0.0f, args...);
|
||||
if(dropout_ > 0.0f) {
|
||||
dropMaskX_ = graph->dropout(dropout_, {1, dimInput});
|
||||
dropMaskS_ = graph->dropout(dropout_, {1, dimState});
|
||||
@ -129,13 +129,18 @@ protected:
|
||||
Expr dropMaskS_;
|
||||
|
||||
public:
|
||||
template <typename... Args>
|
||||
GRU(Ptr<ExpressionGraph> graph,
|
||||
const std::string prefix,
|
||||
int dimInput,
|
||||
int dimState,
|
||||
Args... args)
|
||||
: Cell(dimInput, dimState), prefix_{prefix} {
|
||||
Ptr<Options> options)
|
||||
: Cell(options) {
|
||||
|
||||
int dimInput = options_->get<int>("dimInput");
|
||||
int dimState = options_->get<int>("dimState");
|
||||
std::string prefix = options_->get<std::string>("prefix");
|
||||
layerNorm_ = options_->get<bool>("normalize");
|
||||
dropout_ = options_->get<float>("dropout");
|
||||
final_ = options_->get<bool>("final");
|
||||
|
||||
|
||||
auto U = graph->param(prefix + "_U",
|
||||
{dimState, 2 * dimState},
|
||||
keywords::init = inits::glorot_uniform);
|
||||
@ -165,10 +170,6 @@ public:
|
||||
// b_ = graph->param(prefix + "_b", {1, 3 * dimState},
|
||||
// keywords::init=inits::zeros);
|
||||
|
||||
final_ = Get(keywords::final, false, args...);
|
||||
layerNorm_ = Get(keywords::normalize, false, args...);
|
||||
|
||||
dropout_ = Get(keywords::dropout_prob, 0.0f, args...);
|
||||
if(dropout_ > 0.0f) {
|
||||
dropMaskX_ = graph->dropout(dropout_, {1, dimInput});
|
||||
dropMaskS_ = graph->dropout(dropout_, {1, dimState});
|
||||
@ -250,13 +251,16 @@ protected:
|
||||
Expr dropMaskS_;
|
||||
|
||||
public:
|
||||
template <typename... Args>
|
||||
FastLSTM(Ptr<ExpressionGraph> graph,
|
||||
const std::string prefix,
|
||||
int dimInput,
|
||||
int dimState,
|
||||
Args... args)
|
||||
: Cell(dimInput, dimState), prefix_{prefix} {
|
||||
Ptr<Options> options)
|
||||
: Cell(options) {
|
||||
|
||||
int dimInput = options_->get<int>("dimInput");
|
||||
int dimState = options_->get<int>("dimState");
|
||||
std::string prefix = options_->get<std::string>("prefix");
|
||||
layerNorm_ = options_->get<bool>("normalize");
|
||||
dropout_ = options_->get<float>("dropout");
|
||||
|
||||
|
||||
U_ = graph->param(prefix + "_U", {dimState, 4 * dimState},
|
||||
keywords::init=inits::glorot_uniform);
|
||||
@ -265,9 +269,6 @@ public:
|
||||
b_ = graph->param(prefix + "_b", {1, 4 * dimState},
|
||||
keywords::init=inits::zeros);
|
||||
|
||||
layerNorm_ = Get(keywords::normalize, false, args...);
|
||||
|
||||
dropout_ = Get(keywords::dropout_prob, 0.0f, args...);
|
||||
if(dropout_ > 0.0f) {
|
||||
dropMaskX_ = graph->dropout(dropout_, {1, dimInput});
|
||||
dropMaskS_ = graph->dropout(dropout_, {1, dimState});
|
||||
@ -343,19 +344,18 @@ using LSTM = FastLSTM;
|
||||
|
||||
template <class CellType>
|
||||
class Multiplicative : public CellType {
|
||||
private:
|
||||
protected:
|
||||
Expr Um_, Wm_, bm_;
|
||||
Expr gamma1m_, gamma2m_;
|
||||
|
||||
public:
|
||||
|
||||
template <typename... Args>
|
||||
Multiplicative(Ptr<ExpressionGraph> graph,
|
||||
const std::string prefix,
|
||||
int dimInput,
|
||||
int dimState,
|
||||
Args... args)
|
||||
: CellType(graph, prefix, dimInput, dimState, args...) {
|
||||
Ptr<Options> options)
|
||||
: CellType(graph, options) {
|
||||
|
||||
int dimInput = options->get<int>("dimInput");
|
||||
int dimState = options->get<int>("dimState");
|
||||
std::string prefix = options->get<std::string>("prefix");
|
||||
|
||||
Um_ = graph->param(prefix + "_Um", {dimState, dimState},
|
||||
keywords::init=inits::glorot_uniform);
|
||||
@ -415,7 +415,6 @@ using MGRU = Multiplicative<GRU>;
|
||||
|
||||
class SlowLSTM : public Cell {
|
||||
private:
|
||||
std::string prefix_;
|
||||
|
||||
Expr Uf_, Wf_, bf_;
|
||||
Expr Ui_, Wi_, bi_;
|
||||
@ -423,13 +422,13 @@ private:
|
||||
Expr Uc_, Wc_, bc_;
|
||||
|
||||
public:
|
||||
template <typename... Args>
|
||||
SlowLSTM(Ptr<ExpressionGraph> graph,
|
||||
const std::string prefix,
|
||||
int dimInput,
|
||||
int dimState,
|
||||
Args... args)
|
||||
: Cell(dimInput, dimState), prefix_{prefix} {
|
||||
Ptr<Options> options)
|
||||
: Cell(options) {
|
||||
|
||||
int dimInput = options_->get<int>("dimInput");
|
||||
int dimState = options_->get<int>("dimState");
|
||||
std::string prefix = options->get<std::string>("prefix");
|
||||
|
||||
Uf_ = graph->param(prefix + "_Uf", {dimState, dimState},
|
||||
keywords::init=inits::glorot_uniform);
|
||||
@ -512,18 +511,15 @@ public:
|
||||
|
||||
class TestLSTM : public Cell {
|
||||
private:
|
||||
std::string prefix_;
|
||||
|
||||
Expr U_, W_, b_;
|
||||
|
||||
public:
|
||||
template <typename... Args>
|
||||
TestLSTM(Ptr<ExpressionGraph> graph,
|
||||
const std::string prefix,
|
||||
int dimInput,
|
||||
int dimState,
|
||||
Args... args)
|
||||
: Cell(dimInput, dimState), prefix_{prefix} {
|
||||
TestLSTM(Ptr<ExpressionGraph> graph, Ptr<Options> options)
|
||||
: Cell(options) {
|
||||
|
||||
int dimInput = options_->get<int>("dimInput");
|
||||
int dimState = options_->get<int>("dimState");
|
||||
std::string prefix = options->get<std::string>("prefix");
|
||||
|
||||
auto Uf = graph->param(prefix + "_Uf", {dimState, dimState},
|
||||
keywords::init=inits::glorot_uniform);
|
||||
@ -602,30 +598,27 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class cell {
|
||||
private:
|
||||
std::string type_;
|
||||
|
||||
public:
|
||||
cell(const std::string& type)
|
||||
: type_(type) {}
|
||||
|
||||
struct CellCreate {
|
||||
template <typename ...Args>
|
||||
Ptr<Cell> operator()(Args&& ...args) {
|
||||
if(type_ == "gru")
|
||||
return New<GRU>(args...);
|
||||
if(type_ == "lstm")
|
||||
return New<LSTM>(args...);
|
||||
if(type_ == "mlstm")
|
||||
return New<MLSTM>(args...);
|
||||
if(type_ == "mgru")
|
||||
return New<MGRU>(args...);
|
||||
if(type_ == "tanh")
|
||||
return New<Tanh>(args...);
|
||||
return New<GRU>(args...);
|
||||
static Ptr<Cell> create(Ptr<ExpressionGraph> graph, Ptr<Options> options, Args ...args) {
|
||||
std::string type = options->get<std::string>("type");
|
||||
|
||||
if(type == "gru")
|
||||
return New<GRU>(graph, options, args...);
|
||||
if(type == "lstm")
|
||||
return New<LSTM>(graph, options, args...);
|
||||
if(type == "mlstm")
|
||||
return New<MLSTM>(graph, options, args...);
|
||||
if(type == "mgru")
|
||||
return New<MGRU>(graph, options, args...);
|
||||
if(type == "tanh")
|
||||
return New<Tanh>(graph, options, args...);
|
||||
return New<GRU>(graph, options, args...);
|
||||
}
|
||||
};
|
||||
|
||||
typedef Builder<Cell, CellCreate> cell;
|
||||
|
||||
class cells {
|
||||
private:
|
||||
std::string type_;
|
||||
|
@ -72,7 +72,7 @@ private:
|
||||
States apply(const Expr input, const Expr mask = nullptr) {
|
||||
auto graph = input->graph();
|
||||
int dimBatch = input->shape()[0];
|
||||
int dimState = cell_->dimState();
|
||||
int dimState = cell_->getOptions()->get<int>("dimState");
|
||||
|
||||
auto output = graph->zeros(keywords::shape = {dimBatch, dimState});
|
||||
Expr cell = output;
|
||||
|
@ -4,6 +4,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "common/definitions.h"
|
||||
#include "common/options.h"
|
||||
#include "graph/expression_graph.h"
|
||||
|
||||
namespace marian {
|
||||
@ -86,7 +87,13 @@ class States {
|
||||
class Cell;
|
||||
struct CellInput;
|
||||
|
||||
struct Stackable : std::enable_shared_from_this<Stackable> {
|
||||
class Stackable : public std::enable_shared_from_this<Stackable> {
|
||||
protected:
|
||||
Ptr<Options> options_;
|
||||
|
||||
public:
|
||||
Stackable(Ptr<Options> options) : options_(options) {}
|
||||
|
||||
// required for dynamic_pointer_cast to detect polymorphism
|
||||
virtual ~Stackable() {}
|
||||
|
||||
@ -99,29 +106,32 @@ struct Stackable : std::enable_shared_from_this<Stackable> {
|
||||
inline bool is() {
|
||||
return as<Cast>() != nullptr;
|
||||
}
|
||||
|
||||
Ptr<Options> getOptions() {
|
||||
return options_;
|
||||
}
|
||||
};
|
||||
|
||||
struct CellInput : public Stackable {
|
||||
// Change this to apply(State)
|
||||
class CellInput : public Stackable {
|
||||
public:
|
||||
CellInput(Ptr<Options> options)
|
||||
: Stackable(options) { }
|
||||
|
||||
virtual Expr apply(State) = 0;
|
||||
virtual int dimOutput() = 0;
|
||||
};
|
||||
|
||||
class Cell : public Stackable {
|
||||
protected:
|
||||
int dimInput_;
|
||||
int dimState_;
|
||||
|
||||
public:
|
||||
Cell(int dimInput, int dimState)
|
||||
: dimInput_(dimInput), dimState_(dimState) {}
|
||||
Cell(Ptr<Options> options)
|
||||
: Stackable(options) {
|
||||
|
||||
virtual int dimInput() {
|
||||
return dimInput_;
|
||||
}
|
||||
|
||||
virtual int dimState() {
|
||||
return dimState_;
|
||||
//options_->set("prefix", "");
|
||||
//options_->set("dimInput", 512);
|
||||
//options_->set("dimState", 1024);
|
||||
//options_->set("dropout", 0);
|
||||
//options_->set("normalize", false);
|
||||
//options_->set("final", false);
|
||||
}
|
||||
|
||||
State apply(std::vector<Expr> inputs, State state, Expr mask = nullptr) {
|
||||
@ -133,12 +143,12 @@ public:
|
||||
};
|
||||
|
||||
class MultiCellInput : public CellInput {
|
||||
private:
|
||||
protected:
|
||||
std::vector<Ptr<CellInput>> inputs_;
|
||||
|
||||
public:
|
||||
MultiCellInput(const std::vector<Ptr<CellInput>>& inputs)
|
||||
: inputs_(inputs) {}
|
||||
MultiCellInput(const std::vector<Ptr<CellInput>>& inputs, Ptr<Options> options)
|
||||
: CellInput(options), inputs_(inputs) {}
|
||||
|
||||
void push_back(Ptr<CellInput> input) {
|
||||
inputs_.push_back(input);
|
||||
@ -164,16 +174,16 @@ public:
|
||||
};
|
||||
|
||||
class StackedCell : public Cell {
|
||||
private:
|
||||
protected:
|
||||
std::vector<Ptr<Stackable>> stackables_;
|
||||
std::vector<Expr> lastInputs_;
|
||||
|
||||
public:
|
||||
StackedCell(int dimInput, int dimState) : Cell(dimInput, dimState) {}
|
||||
StackedCell(Ptr<ExpressionGraph>, Ptr<Options> options) : Cell(options) {}
|
||||
|
||||
StackedCell(int dimInput, int dimState,
|
||||
StackedCell(Ptr<ExpressionGraph>, Ptr<Options> options,
|
||||
const std::vector<Ptr<Stackable>>& stackables)
|
||||
: Cell(dimInput, dimState), stackables_(stackables) {}
|
||||
: Cell(options), stackables_(stackables) {}
|
||||
|
||||
void push_back(Ptr<Stackable> stackable) {
|
||||
stackables_.push_back(stackable);
|
||||
@ -212,5 +222,7 @@ public:
|
||||
|
||||
};
|
||||
|
||||
typedef Builder<rnn::StackedCell> stacked_cell;
|
||||
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user