mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
attempts at relu and dropout nodes
This commit is contained in:
parent
a9b1872f4e
commit
cd7dbfc0f6
@ -50,15 +50,15 @@ from keras.optimizers import Adam, SGD
|
||||
|
||||
def baseline_model(pixels_count, classes_count):
|
||||
model = Sequential()
|
||||
model.add(Dropout(0.2, input_shape=(pixels_count,)))
|
||||
# model.add(Dense(100, input_dim=pixels_count, init='uniform', activation='tanh'))
|
||||
model.add(Dense(2048, input_dim=pixels_count, init='uniform', activation='tanh'))
|
||||
model.add(Dense(1024, init='uniform', activation='tanh'))
|
||||
model.add(Dense(512, init='uniform', activation='tanh'))
|
||||
model.add(Dense(256, init='uniform', activation='tanh'))
|
||||
model.add(Dense(128, init='uniform', activation='tanh'))
|
||||
model.add(Dense(2048, input_dim=pixels_count, init='uniform', activation='relu'))
|
||||
model.add(Dropout(0.5))
|
||||
model.add(Dense(2048, init='uniform', activation='relu'))
|
||||
model.add(Dropout(0.5))
|
||||
model.add(Dense(classes_count, init='uniform', activation='softmax'))
|
||||
|
||||
opt = Adam();
|
||||
opt = Adam(lr=0.0001);
|
||||
model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])
|
||||
return model
|
||||
|
||||
@ -99,7 +99,7 @@ if __name__ == "__main__":
|
||||
# Fit the model
|
||||
|
||||
start = time.time();
|
||||
model.fit(X_train, y_train, nb_epoch=10, batch_size=200, verbose=2, shuffle=True)
|
||||
model.fit(X_train, y_train, nb_epoch=50, batch_size=50, verbose=2, shuffle=True)
|
||||
|
||||
print "Time elapsed", time.time() - start, "s"
|
||||
# Final evaluation of the model
|
||||
|
@ -40,7 +40,7 @@ target_link_libraries(validate_mnist_batch marian_lib)
|
||||
target_link_libraries(validate_encoder_decoder marian_lib)
|
||||
|
||||
foreach(exec marian mnist_benchmark validate_mnist_batch validate_encoder_decoder)
|
||||
target_link_libraries(${exec} ${EXT_LIBS} cuda cudnn)
|
||||
target_link_libraries(${exec} ${EXT_LIBS} cuda cudnn curand)
|
||||
cuda_add_cublas_to_target(${exec})
|
||||
set_target_properties(${exec} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")
|
||||
endforeach(exec)
|
||||
|
@ -37,6 +37,14 @@ Expr tanh(Expr a) {
|
||||
return Expr(a.graph(), new TanhNodeOp(a));
|
||||
}
|
||||
|
||||
Expr relu(Expr a) {
|
||||
return Expr(a.graph(), new ReLUNodeOp(a));
|
||||
}
|
||||
|
||||
Expr dropout(Expr a) {
|
||||
return Expr(a.graph(), new DropoutNodeOp(a));
|
||||
}
|
||||
|
||||
Expr log(Expr a) {
|
||||
return Expr(a.graph(), new LogNodeOp(a));
|
||||
};
|
||||
|
@ -31,6 +31,10 @@ Expr logit(Expr a);
|
||||
|
||||
Expr tanh(Expr a);
|
||||
|
||||
Expr relu(Expr a);
|
||||
|
||||
Expr dropout(Expr a);
|
||||
|
||||
Expr log(Expr a);
|
||||
|
||||
Expr exp(Expr a);
|
||||
|
@ -32,7 +32,7 @@ ExpressionGraph build_graph(const std::vector<int>& dims) {
|
||||
layers.emplace_back(x);
|
||||
}
|
||||
else {
|
||||
layers.emplace_back(tanh(dot(layers.back(), weights.back()) + biases.back()));
|
||||
layers.emplace_back(relu(dot(layers.back(), weights.back()) + biases.back()));
|
||||
}
|
||||
|
||||
weights.emplace_back(
|
||||
@ -113,14 +113,13 @@ int main(int argc, char** argv) {
|
||||
std::vector<float> testLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", testRows, LABEL_SIZE);
|
||||
std::cerr << "Done." << std::endl;
|
||||
|
||||
ExpressionGraph g = build_graph({IMAGE_SIZE, 2048, 1048, 512, 256, 128, LABEL_SIZE});
|
||||
//ExpressionGraph g = build_graph({IMAGE_SIZE, 300, 100, LABEL_SIZE});
|
||||
ExpressionGraph g = build_graph({IMAGE_SIZE, 2048, 2048, LABEL_SIZE});
|
||||
|
||||
Tensor xt({BATCH_SIZE, IMAGE_SIZE});
|
||||
Tensor yt({BATCH_SIZE, LABEL_SIZE});
|
||||
|
||||
boost::timer::cpu_timer total;
|
||||
Adam opt;
|
||||
Adam opt(0.0002);
|
||||
for(int i = 1; i <= 10; ++i) {
|
||||
boost::timer::cpu_timer timer;
|
||||
shuffle(trainImages, trainLabels, IMAGE_SIZE, LABEL_SIZE);
|
||||
|
@ -168,6 +168,58 @@ struct TanhNodeOp : public UnaryNodeOp {
|
||||
|
||||
};
|
||||
|
||||
struct ReLUNodeOp : public UnaryNodeOp {
|
||||
template <typename ...Args>
|
||||
ReLUNodeOp(Args ...args)
|
||||
: UnaryNodeOp(args...) { }
|
||||
|
||||
void forward() {
|
||||
Element(_1 = Max(0.0f * _2, _2), // @TODO: fix 0 * _2
|
||||
val_, a_->val());
|
||||
}
|
||||
|
||||
void backward() {
|
||||
Element(_1 += _2 * (_3 > 0.0f),
|
||||
a_->grad(), adj_, val_);
|
||||
}
|
||||
|
||||
virtual std::string graphviz() {
|
||||
std::stringstream ss;
|
||||
ss << "\"" << this << "\" [shape=\"box\", label=\"ReLU\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
};
|
||||
|
||||
struct DropoutNodeOp : public UnaryNodeOp {
|
||||
template <typename ...Args>
|
||||
DropoutNodeOp(Args ...args)
|
||||
: UnaryNodeOp(args...),
|
||||
p_(0.5), seed_(time(0)) { }
|
||||
|
||||
void forward() {
|
||||
Dropout(val_, a_->val(), p_, seed_++);
|
||||
}
|
||||
|
||||
void backward() {
|
||||
Element(_1 += _2 * (_3 != 0.0f), // transform non-zero to 1
|
||||
a_->grad(), adj_, val_);
|
||||
}
|
||||
|
||||
virtual std::string graphviz() {
|
||||
std::stringstream ss;
|
||||
ss << "\"" << this << "\" [shape=\"box\", label=\"Dropout(" << p_ << ")\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
private:
|
||||
float p_;
|
||||
int seed_;
|
||||
};
|
||||
|
||||
|
||||
struct SoftmaxNodeOp : public UnaryNodeOp {
|
||||
template <typename ...Args>
|
||||
SoftmaxNodeOp(Args ...args)
|
||||
|
@ -1,26 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
// This file is part of the Marian toolkit.
|
||||
// Marian is copyright (c) 2016 Marcin Junczys-Dowmunt.
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
// of this software and associated documentation files (the "Software"), to deal
|
||||
// in the Software without restriction, including without limitation the rights
|
||||
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
// copies of the Software, and to permit persons to whom the Software is
|
||||
// furnished to do so, subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be included in
|
||||
// all copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
// SOFTWARE.
|
||||
|
||||
#include <map>
|
||||
#include <boost/any.hpp>
|
||||
#include "tensor_operators.h"
|
||||
@ -62,9 +41,9 @@ class Adagrad {
|
||||
|
||||
auto gtIt = gt_.begin();
|
||||
for(auto& param : graph.params()) {
|
||||
Element(_1 += _2 * _2,
|
||||
Element(_1 += (_2 * _2),
|
||||
*gtIt, param.grad());
|
||||
Element(_1 -= eta_ / (Sqrt(_2) + eps_) * _3,
|
||||
Element(_1 -= (eta_ / (Sqrt(_2) + eps_)) * _3,
|
||||
param.val(), *gtIt, param.grad());
|
||||
gtIt++;
|
||||
}
|
||||
@ -102,9 +81,9 @@ class Adam {
|
||||
auto vtIt = vt_.begin();
|
||||
|
||||
for(auto& param : graph.params()) {
|
||||
Element(_1 = beta1_ * _1 + (1 - beta1_) * _2,
|
||||
Element(_1 = (beta1_ * _1) + ((1 - beta1_) * _2),
|
||||
*mtIt, param.grad());
|
||||
Element(_1 = beta2_ * _1 + (1 - beta2_) * _2 * _2,
|
||||
Element(_1 = (beta2_ * _1) + ((1 - beta2_) * (_2 * _2)),
|
||||
*vtIt, param.grad());
|
||||
Element(_1 -= eta_ * (_2 / denom1) / (Sqrt(_3 / denom2) + eps_),
|
||||
param.val(), *mtIt, *vtIt);
|
||||
|
@ -19,6 +19,8 @@
|
||||
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
// SOFTWARE.
|
||||
|
||||
#include <curand_kernel.h>
|
||||
|
||||
#include "tensor_operators.h"
|
||||
|
||||
using namespace std;
|
||||
@ -33,6 +35,48 @@ static cublasHandle_t create_handle() {
|
||||
}
|
||||
cublasHandle_t cublasHandle = create_handle();
|
||||
|
||||
__global__ void gDropout(float* out, const float* in,
|
||||
int seed, const float p, int rows, int cols) {
|
||||
|
||||
int shift = blockIdx.x * cols + threadIdx.x;
|
||||
curandState state;
|
||||
curand_init(seed, shift, 0, &state);
|
||||
for(int bid = 0; bid < rows; bid += gridDim.x) {
|
||||
int j = bid + blockIdx.x;
|
||||
if(j < rows) {
|
||||
Float* rowOut = out + j * cols;
|
||||
const Float* rowIn = in + j * cols;
|
||||
|
||||
for(int tid = 0; tid < cols; tid += blockDim.x) {
|
||||
int i = tid + threadIdx.x;
|
||||
if(i < cols) {
|
||||
//int offset = i;
|
||||
float dropout = (curand_uniform(&state) >= p);
|
||||
rowOut[i] = dropout * rowIn[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Slow!!!
|
||||
void Dropout(Tensor out, Tensor in, float p, int seed) {
|
||||
int m = in.shape()[0];
|
||||
int n = in.shape()[1];
|
||||
|
||||
curandGenerator_t prng;
|
||||
curandCreateGenerator(&prng, CURAND_RNG_PSEUDO_XORWOW);
|
||||
curandSetPseudoRandomGeneratorSeed(prng, (unsigned long long) seed);
|
||||
curandGenerateUniform(prng, out.data(), m * n);
|
||||
Element(_1 = (_1 > p), out);
|
||||
Element(_1 = _1 * _2, out, in);
|
||||
//int blocks = std::min(MAX_BLOCKS, m);
|
||||
//int threads = std::min(MAX_THREADS, k);
|
||||
//gDropout<<<blocks, threads>>>(out.data(), in.data(), seed, p, m, k);
|
||||
//cudaStreamSynchronize(0);
|
||||
}
|
||||
|
||||
|
||||
__global__ void gSoftmaxGrad(float* grad, const float* adj, const float* val,
|
||||
const int rows, const int cols) {
|
||||
for(int bid = 0; bid < rows; bid += gridDim.x) {
|
||||
|
@ -163,6 +163,8 @@ void Element(Functor functor,
|
||||
cudaStreamSynchronize(0);
|
||||
}
|
||||
|
||||
void Dropout(Tensor Out, Tensor in, float p, int seed);
|
||||
|
||||
void SubtractMax(Tensor* Out);
|
||||
|
||||
void Softmax(Tensor* Out);
|
||||
|
@ -112,6 +112,7 @@ namespace thrust
|
||||
make_actor(_1),
|
||||
make_actor(_2));
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user