Attempts at ReLU and dropout nodes

Marcin Junczys-Dowmunt 2016-09-19 12:06:04 +02:00
parent a9b1872f4e
commit cd7dbfc0f6
10 changed files with 126 additions and 37 deletions

View File

@@ -50,15 +50,15 @@ from keras.optimizers import Adam, SGD
def baseline_model(pixels_count, classes_count):
    model = Sequential()
    model.add(Dropout(0.2, input_shape=(pixels_count,)))
-   # model.add(Dense(100, input_dim=pixels_count, init='uniform', activation='tanh'))
-   model.add(Dense(2048, input_dim=pixels_count, init='uniform', activation='tanh'))
-   model.add(Dense(1024, init='uniform', activation='tanh'))
-   model.add(Dense(512, init='uniform', activation='tanh'))
-   model.add(Dense(256, init='uniform', activation='tanh'))
-   model.add(Dense(128, init='uniform', activation='tanh'))
+   model.add(Dense(2048, input_dim=pixels_count, init='uniform', activation='relu'))
+   model.add(Dropout(0.5))
+   model.add(Dense(2048, init='uniform', activation='relu'))
+   model.add(Dropout(0.5))
    model.add(Dense(classes_count, init='uniform', activation='softmax'))
-   opt = Adam();
+   opt = Adam(lr=0.0001);
    model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])
    return model
@@ -99,7 +99,7 @@ if __name__ == "__main__":
    # Fit the model
    start = time.time();
-   model.fit(X_train, y_train, nb_epoch=10, batch_size=200, verbose=2, shuffle=True)
+   model.fit(X_train, y_train, nb_epoch=50, batch_size=50, verbose=2, shuffle=True)
    print "Time elapsed", time.time() - start, "s"
    # Final evaluation of the model

View File

@@ -40,7 +40,7 @@ target_link_libraries(validate_mnist_batch marian_lib)
target_link_libraries(validate_encoder_decoder marian_lib)

foreach(exec marian mnist_benchmark validate_mnist_batch validate_encoder_decoder)
- target_link_libraries(${exec} ${EXT_LIBS} cuda cudnn)
+ target_link_libraries(${exec} ${EXT_LIBS} cuda cudnn curand)
  cuda_add_cublas_to_target(${exec})
  set_target_properties(${exec} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")
endforeach(exec)
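curand joins the link line because the dropout code added below uses the cuRAND host API (curandCreateGenerator, curandGenerateUniform) and the device-side curand_uniform.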

View File

@@ -37,6 +37,14 @@ Expr tanh(Expr a) {
  return Expr(a.graph(), new TanhNodeOp(a));
}

+ Expr relu(Expr a) {
+   return Expr(a.graph(), new ReLUNodeOp(a));
+ }
+
+ Expr dropout(Expr a) {
+   return Expr(a.graph(), new DropoutNodeOp(a));
+ }
+
Expr log(Expr a) {
  return Expr(a.graph(), new LogNodeOp(a));
};
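For orientation, here is how these builders compose into a layer, mirroring the benchmark change further down; a minimal sketch, assuming an existing graph with input x, weight w, and bias b (those names are illustrative, not part of the commit):

// Hypothetical composition of the new ops (sketch):
Expr hidden = relu(dot(x, w) + b);  // affine transform followed by ReLUNodeOp
Expr masked = dropout(hidden);      // DropoutNodeOp; p is hard-wired to 0.5 in this commit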

View File

@@ -31,6 +31,10 @@ Expr logit(Expr a);
Expr tanh(Expr a);
+ Expr relu(Expr a);
+ Expr dropout(Expr a);
Expr log(Expr a);
Expr exp(Expr a);

View File

@@ -32,7 +32,7 @@ ExpressionGraph build_graph(const std::vector<int>& dims) {
      layers.emplace_back(x);
    }
    else {
-     layers.emplace_back(tanh(dot(layers.back(), weights.back()) + biases.back()));
+     layers.emplace_back(relu(dot(layers.back(), weights.back()) + biases.back()));
    }

    weights.emplace_back(
@@ -113,14 +113,13 @@ int main(int argc, char** argv) {
  std::vector<float> testLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", testRows, LABEL_SIZE);
  std::cerr << "Done." << std::endl;

- ExpressionGraph g = build_graph({IMAGE_SIZE, 2048, 1048, 512, 256, 128, LABEL_SIZE});
- //ExpressionGraph g = build_graph({IMAGE_SIZE, 300, 100, LABEL_SIZE});
+ ExpressionGraph g = build_graph({IMAGE_SIZE, 2048, 2048, LABEL_SIZE});

  Tensor xt({BATCH_SIZE, IMAGE_SIZE});
  Tensor yt({BATCH_SIZE, LABEL_SIZE});

  boost::timer::cpu_timer total;
- Adam opt;
+ Adam opt(0.0002);
  for(int i = 1; i <= 10; ++i) {
    boost::timer::cpu_timer timer;
    shuffle(trainImages, trainLabels, IMAGE_SIZE, LABEL_SIZE);

View File

@@ -168,6 +168,58 @@ struct TanhNodeOp : public UnaryNodeOp {
};

+ struct ReLUNodeOp : public UnaryNodeOp {
+   template <typename ...Args>
+   ReLUNodeOp(Args ...args)
+     : UnaryNodeOp(args...) { }
+
+   void forward() {
+     Element(_1 = Max(0.0f * _2, _2), // @TODO: fix 0 * _2
+             val_, a_->val());
+   }
+
+   void backward() {
+     Element(_1 += _2 * (_3 > 0.0f),
+             a_->grad(), adj_, val_);
+   }
+
+   virtual std::string graphviz() {
+     std::stringstream ss;
+     ss << "\"" << this << "\" [shape=\"box\", label=\"ReLU\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
+     ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
+     return ss.str();
+   };
+ };
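The forward functor writes max(0, x) per element (the 0.0f * _2 looks like a workaround to obtain a zero of matching expression type, which the @TODO acknowledges), and the backward functor passes the upstream gradient through exactly where the saved output is positive. A scalar sketch of the same semantics (illustrative, not part of the commit):

// relu(x) = max(0, x); its derivative is 1 for x > 0, else 0.
inline float relu_fwd(float x) { return x > 0.0f ? x : 0.0f; }
// val is the saved forward output: val > 0 exactly when the input was > 0.
inline float relu_bwd(float adj, float val) { return adj * (val > 0.0f ? 1.0f : 0.0f); }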
+ struct DropoutNodeOp : public UnaryNodeOp {
+   template <typename ...Args>
+   DropoutNodeOp(Args ...args)
+     : UnaryNodeOp(args...),
+       p_(0.5), seed_(time(0)) { }
+
+   void forward() {
+     Dropout(val_, a_->val(), p_, seed_++);
+   }
+
+   void backward() {
+     Element(_1 += _2 * (_3 != 0.0f), // transform non-zero to 1
+             a_->grad(), adj_, val_);
+   }
+
+   virtual std::string graphviz() {
+     std::stringstream ss;
+     ss << "\"" << this << "\" [shape=\"box\", label=\"Dropout(" << p_ << ")\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
+     ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
+     return ss.str();
+   };
+
+ private:
+   float p_;
+   int seed_;
+ };
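The backward pass reconstructs the dropout mask from the saved forward output: dropped units are exactly zero there, so (_3 != 0.0f) recovers the keep-mask (units whose input happened to be exactly zero are indistinguishable from dropped ones). Scalar sketch (illustrative):

// Gradient flows only through units the forward mask kept.
inline float dropout_bwd(float adj, float val) { return adj * (val != 0.0f ? 1.0f : 0.0f); }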
struct SoftmaxNodeOp : public UnaryNodeOp {
  template <typename ...Args>
  SoftmaxNodeOp(Args ...args)

View File

@@ -1,26 +1,5 @@
#pragma once

- // This file is part of the Marian toolkit.
- // Marian is copyright (c) 2016 Marcin Junczys-Dowmunt.
- //
- // Permission is hereby granted, free of charge, to any person obtaining a copy
- // of this software and associated documentation files (the "Software"), to deal
- // in the Software without restriction, including without limitation the rights
- // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- // copies of the Software, and to permit persons to whom the Software is
- // furnished to do so, subject to the following conditions:
- //
- // The above copyright notice and this permission notice shall be included in
- // all copies or substantial portions of the Software.
- //
- // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- // SOFTWARE.

#include <map>
#include <boost/any.hpp>

#include "tensor_operators.h"
@@ -62,9 +41,9 @@ class Adagrad {
    auto gtIt = gt_.begin();

    for(auto& param : graph.params()) {
-     Element(_1 += _2 * _2,
+     Element(_1 += (_2 * _2),
              *gtIt, param.grad());
-     Element(_1 -= eta_ / (Sqrt(_2) + eps_) * _3,
+     Element(_1 -= (eta_ / (Sqrt(_2) + eps_)) * _3,
              param.val(), *gtIt, param.grad());
      gtIt++;
    }
@@ -102,9 +81,9 @@ class Adam {
    auto vtIt = vt_.begin();

    for(auto& param : graph.params()) {
-     Element(_1 = beta1_ * _1 + (1 - beta1_) * _2,
+     Element(_1 = (beta1_ * _1) + ((1 - beta1_) * _2),
              *mtIt, param.grad());
-     Element(_1 = beta2_ * _1 + (1 - beta2_) * _2 * _2,
+     Element(_1 = (beta2_ * _1) + ((1 - beta2_) * (_2 * _2)),
              *vtIt, param.grad());
      Element(_1 -= eta_ * (_2 / denom1) / (Sqrt(_3 / denom2) + eps_),
              param.val(), *mtIt, *vtIt);
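The added parentheses only make the evaluation order explicit; per element this is the standard Adam update. A scalar sketch, assuming denom1 and denom2 (defined outside this hunk) are the usual bias-correction terms 1 - beta1^t and 1 - beta2^t:

// One Adam step for a single weight w with gradient g (sketch):
m = beta1 * m + (1.0f - beta1) * g;      // first-moment (mean) estimate
v = beta2 * v + (1.0f - beta2) * g * g;  // second-moment (uncentered variance) estimate
w -= eta * (m / denom1) / (sqrtf(v / denom2) + eps);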

View File

@@ -19,6 +19,8 @@
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

+ #include <curand_kernel.h>
+
#include "tensor_operators.h"

using namespace std;
@@ -33,6 +35,48 @@ static cublasHandle_t create_handle() {
}

cublasHandle_t cublasHandle = create_handle();

+ __global__ void gDropout(float* out, const float* in,
+                          int seed, const float p, int rows, int cols) {
+   int shift = blockIdx.x * cols + threadIdx.x;
+   curandState state;
+   curand_init(seed, shift, 0, &state);
+
+   for(int bid = 0; bid < rows; bid += gridDim.x) {
+     int j = bid + blockIdx.x;
+     if(j < rows) {
+       float* rowOut = out + j * cols;
+       const float* rowIn = in + j * cols;
+       for(int tid = 0; tid < cols; tid += blockDim.x) {
+         int i = tid + threadIdx.x;
+         if(i < cols) {
+           //int offset = i;
+           float dropout = (curand_uniform(&state) >= p);
+           rowOut[i] = dropout * rowIn[i];
+         }
+       }
+     }
+   }
+ }
+ // Slow!!!
+ void Dropout(Tensor out, Tensor in, float p, int seed) {
+   int m = in.shape()[0];
+   int n = in.shape()[1];
+
+   curandGenerator_t prng;
+   curandCreateGenerator(&prng, CURAND_RNG_PSEUDO_XORWOW);
+   curandSetPseudoRandomGeneratorSeed(prng, (unsigned long long) seed);
+   curandGenerateUniform(prng, out.data(), m * n);
+
+   Element(_1 = (_1 > p), out);
+   Element(_1 = _1 * _2, out, in);
+
+   //int blocks = std::min(MAX_BLOCKS, m);
+   //int threads = std::min(MAX_THREADS, k);
+   //gDropout<<<blocks, threads>>>(out.data(), in.data(), seed, p, m, k);
+   //cudaStreamSynchronize(0);
+ }
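The live path makes three full passes over the tensor: fill out with uniform randoms, threshold them into a 0/1 keep-mask, then multiply the mask into the input; hence the "Slow!!!" note, with the single-pass gDropout kernel above left commented out as the intended replacement. Per element this computes (illustrative):

// u    = uniform(0, 1);            // curandGenerateUniform
// mask = (u > p) ? 1.0f : 0.0f;    // Element(_1 = (_1 > p), out)
// out  = mask * in;                // Element(_1 = _1 * _2, out, in)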
__global__ void gSoftmaxGrad(float* grad, const float* adj, const float* val,
                             const int rows, const int cols) {
  for(int bid = 0; bid < rows; bid += gridDim.x) {

View File

@@ -163,6 +163,8 @@ void Element(Functor functor,
  cudaStreamSynchronize(0);
}

+ void Dropout(Tensor Out, Tensor in, float p, int seed);
+
void SubtractMax(Tensor* Out);
void Softmax(Tensor* Out);

View File

@@ -112,6 +112,7 @@ namespace thrust
make_actor(_1),
make_actor(_2));
}
}
}
}