some more refactoring

Marcin Junczys-Dowmunt 2017-06-17 21:22:27 +02:00
parent 78c9c30aaf
commit 84c3121885
7 changed files with 300 additions and 170 deletions

View File

@@ -17,7 +17,13 @@ set(CMAKE_CXX_FLAGS_DEBUG " -std=c++11 -g -O0 -fPIC -Wno-unused-result -Wno-depr
set(CMAKE_CXX_FLAGS_PROFILE "${CMAKE_CXX_FLAGS_RELEASE} -g -pg")
set(CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS_RELEASE})
-LIST(APPEND CUDA_NVCC_FLAGS -std=c++11; --default-stream per-thread; -O3; --use_fast_math; -Xcompiler '-fPIC'; -arch=sm_35;)
+if (CMAKE_BUILD_TYPE STREQUAL "Debug")
+LIST(APPEND CUDA_NVCC_FLAGS -std=c++11; --default-stream per-thread; -O0; -g; -Xcompiler '-fPIC'; -arch=sm_35;)
+else(CMAKE_BUILD_TYPE STREQUAL "Debug")
+LIST(APPEND CUDA_NVCC_FLAGS -std=c++11; --default-stream per-thread; -O3; --use_fast_math; -Xcompiler '-fPIC'; -arch=sm_35;)
+endif(CMAKE_BUILD_TYPE STREQUAL "Debug")
LIST(REMOVE_DUPLICATES CUDA_NVCC_FLAGS)
SET(CUDA_PROPAGATE_HOST_FLAGS OFF)

src/common/options.h (new file, 101 lines)
View File

@@ -0,0 +1,101 @@
#pragma once
#include <string>
#include <sstream>
#include "common/definitions.h"
#include "3rd_party/yaml-cpp/yaml.h"
namespace marian {
class Options {
protected:
YAML::Node options_;
public:
YAML::Node& getOptions() {
return options_;
}
void parse(const std::string& yaml) {
auto node = YAML::Load(yaml);
for(auto it : node)
options_[it.first.as<std::string>()] = it.second;
}
void merge(YAML::Node& node) {
for(auto it : node)
options_[it.first.as<std::string>()] = it.second;
}
std::string str() {
std::stringstream ss;
ss << options_;
return ss.str();
}
template <typename T>
void set(const std::string& key, T value) {
options_[key] = value;
}
template <typename T>
T get(const std::string& key) {
return options_[key].as<T>();
}
};
template <class Obj>
struct DefaultCreate {
template <typename ...Args>
static Ptr<Obj> create(Ptr<ExpressionGraph> graph, Ptr<Options> options, Args ...args) {
return New<Obj>(graph, options, args...);
}
};
template <class Obj, class Create=DefaultCreate<Obj>>
class Builder {
protected:
Ptr<Options> options_;
Ptr<ExpressionGraph> graph_;
public:
Builder(Ptr<ExpressionGraph> graph)
: options_(New<Options>()), graph_(graph) {}
Ptr<Options> getOptions() {
return options_;
}
virtual std::string str() {
return options_->str();
}
template <typename T>
Builder& operator()(const std::string& key, T value) {
options_->set(key, value);
return *this;
}
Builder& operator()(const std::string& yaml) {
options_->parse(yaml);
return *this;
}
Builder& operator()(YAML::Node yaml) {
options_->merge(yaml);
return *this;
}
template <typename T>
T get(const std::string& key) {
return options_->get<T>(key);
}
template <typename ...Args>
Ptr<Obj> create(Args ...args) {
return Create::create(graph_, options_, args...);
}
};
}
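
The Options/Builder pair above is the backbone of this commit: Options is a YAML-backed key-value store with typed set/get, and Builder accumulates options through chained calls before handing them to a factory. A minimal usage sketch of Options on its own (keys and values are illustrative, not taken from the commit):

#include "common/options.h"
using namespace marian;

void optionsSketch() {
  Options opts;
  opts.set("dimState", 1024);                     // typed set, stored as a YAML node
  opts.parse("{dropout: 0.2, normalize: true}");  // merge keys from a YAML string
  int dim = opts.get<int>("dimState");            // typed get -> 1024
  float drop = opts.get<float>("dropout");        // -> 0.2f
}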

View File

@@ -3,6 +3,7 @@
#include "rnn/attention.h"
#include "rnn/rnn.h"
#include "rnn/cells.h"
+#include "common/options.h"
#include "models/encdec.h"
@@ -91,36 +92,43 @@ public:
x = dropout(x, mask = srcWordDrop);
}
-auto type = options_->get<std::string>("cell-enc");
-UTIL_THROW_IF2(amun && type != "gru",
+auto cellType = options_->get<std::string>("cell-enc");
+UTIL_THROW_IF2(amun && cellType != "gru",
"--type amun does not currently support other rnn cells than gru, "
"use --type s2s");
-auto cellFw = rnn::cell(type)(graph,
-prefix_ + "_bi",
-dimSrcEmb,
-dimEncState,
-normalize = layerNorm,
-dropout_prob = dropoutRnn);
+auto cellFw = rnn::cell(graph)
+("type", cellType)
+("prefix", prefix_ + "_bi")
+("dimInput", dimSrcEmb)
+("dimState", dimEncState)
+("dropout", dropoutRnn)
+("normalize", layerNorm)
+("final", false)
+.create();
auto xFw = rnn::RNN(cellFw)(x);
-auto cellBw = rnn::cell(type)(graph,
-prefix_ + "_bi_r",
-dimSrcEmb,
-dimEncState,
-normalize = layerNorm,
-dropout_prob = dropoutRnn);
+auto cellBw = rnn::cell(graph)
+("type", cellType)
+("prefix", prefix_ + "_bi_r")
+("dimInput", dimSrcEmb)
+("dimState", dimEncState)
+("dropout", dropoutRnn)
+("normalize", layerNorm)
+("final", false)
+.create();
auto xBw = rnn::RNN(cellBw, direction = dir::backward)(x, xMask);
auto xContext = concatenate({xFw, xBw}, axis = 1);
-if(encoderLayers > 1) {
-auto layerCells
-= rnn::cells(type, encoderLayers-1)(graph, prefix_,
-2 * dimEncState, dimEncState,
-normalize=layerNorm, dropout_prob=dropoutRnn);
-xContext = rnn::MLRNN(layerCells, skip=skipDepth)(xContext);
-}
+//if(encoderLayers > 1) {
+// auto layerCells
+// = rnn::cells(cellType, encoderLayers-1)(graph, prefix_,
+// 2 * dimEncState, dimEncState,
+// normalize=layerNorm, dropout_prob=dropoutRnn);
+//
+// xContext = rnn::MLRNN(layerCells, skip=skipDepth)(xContext);
+//}
return New<EncoderStateS2S>(xContext, xMask, batch);
@@ -205,25 +213,35 @@ public:
}
if(!attCell_) {
-auto attCell = New<rnn::StackedCell>(dimTrgEmb, dimDecState);
+auto attCell = rnn::stacked_cell(graph).create();
+auto cell1 = rnn::cell(graph)
+("type", cellType)
+("prefix", prefix_ + "_cell1")
+("dimInput", dimTrgEmb)
+("dimState", dimDecState)
+("dropout", dropoutRnn)
+("normalize", layerNorm)
+("final", false)
+.create();
+auto attention = rnn::attention(graph)
+("prefix", prefix_)
+("dimState", dimDecState)
+("dropout", dropoutRnn)
+("normalize", layerNorm)
+.create(state->getEncoderState());
+auto cell2 = rnn::cell(graph)
+("type", cellType)
+("prefix", prefix_ + "_cell2")
+("dimInput", attention->dimOutput())
+("dimState", dimDecState)
+("dropout", dropoutRnn)
+("normalize", layerNorm)
+("final", true)
+.create();
-auto cell1 = rnn::cell(cellType)(graph,
-prefix_ + "_cell1",
-dimTrgEmb,
-dimDecState,
-dropout_prob = dropoutRnn,
-normalize = layerNorm);
-auto attention = New<rnn::Attention>(prefix_,
-state->getEncoderState(),
-dimDecState,
-dropout_prob = dropoutRnn,
-normalize = layerNorm);
-auto cell2 = rnn::cell(cellType)(graph,
-prefix_ + "_cell2",
-attention->dimOutput(),
-dimDecState,
-dropout_prob = dropoutRnn,
-normalize = layerNorm);
attCell->push_back(cell1);
attCell->push_back(attention);
attCell->push_back(cell2);
@@ -242,25 +260,25 @@ public:
keywords::axis = 2);
-if(decoderLayers > 1) {
-rnn::States statesIn;
-for(int i = 1; i < stateS2S->getStates().size(); ++i)
-statesIn.push_back(stateS2S->getStates()[i]);
-if(!rnnLn) {
-auto layerCells
-= rnn::cells(cellType, decoderLayers - 1)(graph, prefix_,
-dimDecState, dimDecState,
-normalize = layerNorm,
-dropout_prob = dropoutRnn);
-rnnLn = New<rnn::MLRNN>(layerCells, skip = skipDepth, skip_first = skipDepth);
-}
-decContext = (*rnnLn)(decContext, statesIn);
-for(auto state : rnnLn->last())
-decStates.push_back(state);
-}
+//if(decoderLayers > 1) {
+// rnn::States statesIn;
+// for(int i = 1; i < stateS2S->getStates().size(); ++i)
+// statesIn.push_back(stateS2S->getStates()[i]);
+//
+// if(!rnnLn) {
+// auto layerCells
+// = rnn::cells(cellType, decoderLayers - 1)(graph, prefix_,
+// dimDecState, dimDecState,
+// normalize = layerNorm,
+// dropout_prob = dropoutRnn);
+//
+// rnnLn = New<rnn::MLRNN>(layerCells, skip = skipDepth, skip_first = skipDepth);
+// }
+//
+// decContext = (*rnnLn)(decContext, statesIn);
+// for(auto state : rnnLn->last())
+// decStates.push_back(state);
+//}
//// 2-layer feedforward network for outputs and cost
auto logitsL1

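Each chained ("key", value) call in the new construction code above is just Options::set on the builder's option store, and create() hands the accumulated options to the factory. A sketch of an equivalent call with every key the GRU cell reads, assuming the Builder from src/common/options.h (values illustrative):

auto cell = rnn::cell(graph)       // Builder<Cell, CellCreate>
    ("type", "gru")                // each call does options_->set(key, value)
    ("prefix", "enc_bi")           // ...and returns the builder, so calls chain
    ("dimInput", 512)
    ("dimState", 1024)
    ("dropout", 0.0f)
    ("normalize", false)
    ("final", false)
    .create();                     // -> CellCreate::create(graph, options)
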
View File

@@ -26,27 +26,26 @@ private:
std::vector<Expr> contexts_;
std::vector<Expr> alignments_;
bool layerNorm_;
float dropout_;
Expr contextDropped_;
Expr dropMaskContext_;
Expr dropMaskState_;
Expr cov_;
public:
-template <typename... Args>
-GlobalAttention(const std::string prefix,
-Ptr<EncoderState> encState,
-int dimDecState,
-Args... args)
-: encState_(encState),
-contextDropped_(encState->getContext()),
-layerNorm_(Get(keywords::normalize, false, args...)),
-cov_(Get(keywords::coverage, nullptr, args...)) {
-int dimEncState = encState_->getContext()->shape()[1];
+GlobalAttention(Ptr<ExpressionGraph> graph,
+Ptr<Options> options,
+Ptr<EncoderState> encState)
+: CellInput(options),
+encState_(encState),
+contextDropped_(encState->getContext()) {
-auto graph = encState_->getContext()->graph();
+int dimDecState = options_->get<int>("dimState");
+dropout_ = options_->get<float>("dropout");
+layerNorm_ = options_->get<bool>("normalize");
+std::string prefix = options_->get<std::string>("prefix");
+int dimEncState = encState_->getContext()->shape()[1];
Wa_ = graph->param(prefix + "_W_comb_att",
{dimDecState, dimEncState},
@@ -60,7 +59,6 @@ public:
ba_ = graph->param(
prefix + "_b_att", {1, dimEncState}, keywords::init = inits::zeros);
-dropout_ = Get(keywords::dropout_prob, 0.0f, args...);
if(dropout_ > 0.0f) {
dropMaskContext_ = graph->dropout(dropout_, {1, dimEncState});
dropMaskState_ = graph->dropout(dropout_, {1, dimDecState});
@@ -130,6 +128,8 @@ public:
using Attention = GlobalAttention;
+typedef Builder<Attention> attention;
}
}
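
Note how the decoder's .create(state->getEncoderState()) reaches the constructor above: Builder::create forwards extra arguments through DefaultCreate, so GlobalAttention keeps its Ptr<EncoderState> parameter while still being configured from options. A sketch of the call, with the four option keys the constructor reads (encState is assumed to be a Ptr<EncoderState> in scope; values illustrative):

auto att = rnn::attention(graph)
    ("prefix", "dec")
    ("dimState", 1024)
    ("dropout", 0.0f)
    ("normalize", false)
    .create(encState);
// expands to DefaultCreate<GlobalAttention>::create(graph, options, encState),
// i.e. New<GlobalAttention>(graph, options, encState)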

View File

@@ -29,12 +29,15 @@ private:
Expr dropMaskS_;
public:
-template <typename... Args>
-Tanh(Ptr<ExpressionGraph> graph,
-const std::string prefix,
-int dimInput,
-int dimState,
-Args... args) : Cell(dimInput, dimState) {
+Tanh(Ptr<ExpressionGraph> graph, Ptr<Options> options)
+: Cell(options) {
+int dimInput = options_->get<int>("dimInput");
+int dimState = options_->get<int>("dimState");
+std::string prefix = options_->get<std::string>("prefix");
+layerNorm_ = options_->get<bool>("normalize");
+dropout_ = options_->get<float>("dropout");
U_ = graph->param(prefix + "_U",
{dimState, dimState},
keywords::init = inits::glorot_uniform);
@@ -44,9 +47,6 @@ public:
b_ = graph->param(
prefix + "_b", {1, dimState}, keywords::init = inits::zeros);
-layerNorm_ = Get(keywords::normalize, false, args...);
-dropout_ = Get(keywords::dropout_prob, 0.0f, args...);
if(dropout_ > 0.0f) {
dropMaskX_ = graph->dropout(dropout_, {1, dimInput});
dropMaskS_ = graph->dropout(dropout_, {1, dimState});
@@ -129,13 +129,18 @@ protected:
Expr dropMaskS_;
public:
-template <typename... Args>
GRU(Ptr<ExpressionGraph> graph,
-const std::string prefix,
-int dimInput,
-int dimState,
-Args... args)
-: Cell(dimInput, dimState), prefix_{prefix} {
+Ptr<Options> options)
+: Cell(options) {
+int dimInput = options_->get<int>("dimInput");
+int dimState = options_->get<int>("dimState");
+std::string prefix = options_->get<std::string>("prefix");
+layerNorm_ = options_->get<bool>("normalize");
+dropout_ = options_->get<float>("dropout");
+final_ = options_->get<bool>("final");
auto U = graph->param(prefix + "_U",
{dimState, 2 * dimState},
keywords::init = inits::glorot_uniform);
@@ -165,10 +170,6 @@ public:
// b_ = graph->param(prefix + "_b", {1, 3 * dimState},
// keywords::init=inits::zeros);
-final_ = Get(keywords::final, false, args...);
-layerNorm_ = Get(keywords::normalize, false, args...);
-dropout_ = Get(keywords::dropout_prob, 0.0f, args...);
if(dropout_ > 0.0f) {
dropMaskX_ = graph->dropout(dropout_, {1, dimInput});
dropMaskS_ = graph->dropout(dropout_, {1, dimState});
@@ -250,13 +251,16 @@ protected:
Expr dropMaskS_;
public:
-template <typename... Args>
FastLSTM(Ptr<ExpressionGraph> graph,
-const std::string prefix,
-int dimInput,
-int dimState,
-Args... args)
-: Cell(dimInput, dimState), prefix_{prefix} {
+Ptr<Options> options)
+: Cell(options) {
+int dimInput = options_->get<int>("dimInput");
+int dimState = options_->get<int>("dimState");
+std::string prefix = options_->get<std::string>("prefix");
+layerNorm_ = options_->get<bool>("normalize");
+dropout_ = options_->get<float>("dropout");
U_ = graph->param(prefix + "_U", {dimState, 4 * dimState},
keywords::init=inits::glorot_uniform);
@@ -265,9 +269,6 @@ public:
b_ = graph->param(prefix + "_b", {1, 4 * dimState},
keywords::init=inits::zeros);
-layerNorm_ = Get(keywords::normalize, false, args...);
-dropout_ = Get(keywords::dropout_prob, 0.0f, args...);
if(dropout_ > 0.0f) {
dropMaskX_ = graph->dropout(dropout_, {1, dimInput});
dropMaskS_ = graph->dropout(dropout_, {1, dimState});
@@ -343,19 +344,18 @@ using LSTM = FastLSTM;
template <class CellType>
class Multiplicative : public CellType {
-private:
+protected:
Expr Um_, Wm_, bm_;
Expr gamma1m_, gamma2m_;
public:
-template <typename... Args>
Multiplicative(Ptr<ExpressionGraph> graph,
-const std::string prefix,
-int dimInput,
-int dimState,
-Args... args)
-: CellType(graph, prefix, dimInput, dimState, args...) {
+Ptr<Options> options)
+: CellType(graph, options) {
+int dimInput = options->get<int>("dimInput");
+int dimState = options->get<int>("dimState");
+std::string prefix = options->get<std::string>("prefix");
Um_ = graph->param(prefix + "_Um", {dimState, dimState},
keywords::init=inits::glorot_uniform);
@@ -415,7 +415,6 @@ using MGRU = Multiplicative<GRU>;
class SlowLSTM : public Cell {
private:
-std::string prefix_;
Expr Uf_, Wf_, bf_;
Expr Ui_, Wi_, bi_;
@@ -423,13 +422,13 @@ private:
Expr Uc_, Wc_, bc_;
public:
-template <typename... Args>
SlowLSTM(Ptr<ExpressionGraph> graph,
-const std::string prefix,
-int dimInput,
-int dimState,
-Args... args)
-: Cell(dimInput, dimState), prefix_{prefix} {
+Ptr<Options> options)
+: Cell(options) {
+int dimInput = options_->get<int>("dimInput");
+int dimState = options_->get<int>("dimState");
+std::string prefix = options->get<std::string>("prefix");
Uf_ = graph->param(prefix + "_Uf", {dimState, dimState},
keywords::init=inits::glorot_uniform);
@@ -512,18 +511,15 @@ public:
class TestLSTM : public Cell {
private:
-std::string prefix_;
Expr U_, W_, b_;
public:
-template <typename... Args>
-TestLSTM(Ptr<ExpressionGraph> graph,
-const std::string prefix,
-int dimInput,
-int dimState,
-Args... args)
-: Cell(dimInput, dimState), prefix_{prefix} {
+TestLSTM(Ptr<ExpressionGraph> graph, Ptr<Options> options)
+: Cell(options) {
+int dimInput = options_->get<int>("dimInput");
+int dimState = options_->get<int>("dimState");
+std::string prefix = options->get<std::string>("prefix");
auto Uf = graph->param(prefix + "_Uf", {dimState, dimState},
keywords::init=inits::glorot_uniform);
@@ -602,30 +598,27 @@ public:
}
};
-class cell {
-private:
-std::string type_;
-public:
-cell(const std::string& type)
-: type_(type) {}
+struct CellCreate {
template <typename ...Args>
-Ptr<Cell> operator()(Args&& ...args) {
-if(type_ == "gru")
-return New<GRU>(args...);
-if(type_ == "lstm")
-return New<LSTM>(args...);
-if(type_ == "mlstm")
-return New<MLSTM>(args...);
-if(type_ == "mgru")
-return New<MGRU>(args...);
-if(type_ == "tanh")
-return New<Tanh>(args...);
-return New<GRU>(args...);
+static Ptr<Cell> create(Ptr<ExpressionGraph> graph, Ptr<Options> options, Args ...args) {
+std::string type = options->get<std::string>("type");
+if(type == "gru")
+return New<GRU>(graph, options, args...);
+if(type == "lstm")
+return New<LSTM>(graph, options, args...);
+if(type == "mlstm")
+return New<MLSTM>(graph, options, args...);
+if(type == "mgru")
+return New<MGRU>(graph, options, args...);
+if(type == "tanh")
+return New<Tanh>(graph, options, args...);
+return New<GRU>(graph, options, args...);
}
};
+typedef Builder<Cell, CellCreate> cell;
class cells {
private:
std::string type_;

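CellCreate replaces the old cell functor: the cell type is no longer a string captured at construction time but an ordinary option, read back inside create(), so the choice of RNN cell can come from the same key-value store as every other hyperparameter. A hedged sketch (values illustrative; unknown types fall back to GRU, as in the code above):

auto lstm = rnn::cell(graph)
    ("type", "lstm")               // dispatches to New<LSTM>(graph, options)
    ("prefix", "dec_cell")
    ("dimInput", 512)
    ("dimState", 1024)
    ("dropout", 0.1f)
    ("normalize", false)
    ("final", false)
    .create();
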
View File

@@ -72,7 +72,7 @@ private:
States apply(const Expr input, const Expr mask = nullptr) {
auto graph = input->graph();
int dimBatch = input->shape()[0];
-int dimState = cell_->dimState();
+int dimState = cell_->getOptions()->get<int>("dimState");
auto output = graph->zeros(keywords::shape = {dimBatch, dimState});
Expr cell = output;

View File

@@ -4,6 +4,7 @@
#include <vector>
#include "common/definitions.h"
+#include "common/options.h"
#include "graph/expression_graph.h"
namespace marian {
@@ -86,7 +87,13 @@ class States {
class Cell;
struct CellInput;
-struct Stackable : std::enable_shared_from_this<Stackable> {
+class Stackable : public std::enable_shared_from_this<Stackable> {
+protected:
+Ptr<Options> options_;
+public:
+Stackable(Ptr<Options> options) : options_(options) {}
// required for dynamic_pointer_cast to detect polymorphism
virtual ~Stackable() {}
@@ -99,29 +106,32 @@ struct Stackable : std::enable_shared_from_this<Stackable> {
inline bool is() {
return as<Cast>() != nullptr;
}
+Ptr<Options> getOptions() {
+return options_;
+}
};
-struct CellInput : public Stackable {
// Change this to apply(State)
+class CellInput : public Stackable {
+public:
+CellInput(Ptr<Options> options)
+: Stackable(options) { }
virtual Expr apply(State) = 0;
virtual int dimOutput() = 0;
};
class Cell : public Stackable {
protected:
-int dimInput_;
-int dimState_;
public:
-Cell(int dimInput, int dimState)
-: dimInput_(dimInput), dimState_(dimState) {}
+Cell(Ptr<Options> options)
+: Stackable(options) {
-virtual int dimInput() {
-return dimInput_;
-}
-virtual int dimState() {
-return dimState_;
+//options_->set("prefix", "");
+//options_->set("dimInput", 512);
+//options_->set("dimState", 1024);
+//options_->set("dropout", 0);
+//options_->set("normalize", false);
+//options_->set("final", false);
}
State apply(std::vector<Expr> inputs, State state, Expr mask = nullptr) {
@@ -133,12 +143,12 @@
};
class MultiCellInput : public CellInput {
-private:
+protected:
std::vector<Ptr<CellInput>> inputs_;
public:
-MultiCellInput(const std::vector<Ptr<CellInput>>& inputs)
-: inputs_(inputs) {}
+MultiCellInput(const std::vector<Ptr<CellInput>>& inputs, Ptr<Options> options)
+: CellInput(options), inputs_(inputs) {}
void push_back(Ptr<CellInput> input) {
inputs_.push_back(input);
@@ -164,16 +174,16 @@
};
class StackedCell : public Cell {
-private:
+protected:
std::vector<Ptr<Stackable>> stackables_;
std::vector<Expr> lastInputs_;
public:
-StackedCell(int dimInput, int dimState) : Cell(dimInput, dimState) {}
+StackedCell(Ptr<ExpressionGraph>, Ptr<Options> options) : Cell(options) {}
-StackedCell(int dimInput, int dimState,
+StackedCell(Ptr<ExpressionGraph>, Ptr<Options> options,
const std::vector<Ptr<Stackable>>& stackables)
-: Cell(dimInput, dimState), stackables_(stackables) {}
+: Cell(options), stackables_(stackables) {}
void push_back(Ptr<Stackable> stackable) {
stackables_.push_back(stackable);
@@ -212,5 +222,7 @@
};
+typedef Builder<rnn::StackedCell> stacked_cell;
}
}
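
With Stackable now carrying its own Ptr<Options>, stacked cells compose through the same builder machinery as plain cells. A condensed sketch of the decoder pattern from the s2s diff above, plus the as<>()/is<>() helpers Stackable provides (cell1, attention, cell2 and the members vector are assumed to be in scope; this is an illustration, not the commit's code):

auto stacked = rnn::stacked_cell(graph).create();
stacked->push_back(cell1);      // an rnn::cell product (a Cell)
stacked->push_back(attention);  // an rnn::attention product (a CellInput)
stacked->push_back(cell2);

// as<>() wraps dynamic_pointer_cast and is<>() tests it against nullptr,
// which is how a stack can tell plain cells from attention-style inputs:
int dimAttended = 0;
for(auto s : members)           // members: some std::vector<Ptr<Stackable>>
  if(auto input = s->as<CellInput>())
    dimAttended += input->dimOutput();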