avoid reallocation of temporary tensors

Marcin Junczys-Dowmunt 2016-09-21 02:38:12 +02:00
parent ffd2dcd513
commit 828a0db8bc
5 changed files with 73 additions and 19 deletions

View File

@@ -51,13 +51,13 @@ from keras.optimizers import Adam, SGD
def baseline_model(pixels_count, classes_count):
model = Sequential()
model.add(Dropout(0.2, input_shape=(pixels_count,)))
model.add(Dense(2048, input_dim=pixels_count, init='uniform', activation='tanh'))
model.add(Dense(2048, input_dim=pixels_count, init='uniform', activation='relu'))
# model.add(Dense(2048, init='uniform', activation='relu'))
model.add(Dropout(0.5))
# model.add(Dense(2048, init='uniform', activation='relu'))
# model.add(Dense(2048, init='uniform', activation='relu'))
# model.add(Dense(2048, init='uniform', activation='relu'))
model.add(Dense(2048, init='uniform', activation='tanh'))
model.add(Dense(2048, init='uniform', activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(classes_count, init='uniform', activation='softmax'))
@@ -102,7 +102,7 @@ if __name__ == "__main__":
# Fit the model
start = time.time();
model.fit(X_train, y_train, nb_epoch=20, batch_size=200, verbose=2, shuffle=True)
model.fit(X_train, y_train, nb_epoch=50, batch_size=200, verbose=2, shuffle=True)
print "Time elapsed", time.time() - start, "s"
# Final evaluation of the model

View File

@@ -115,14 +115,14 @@ int main(int argc, char** argv) {
std::cerr << "Done." << std::endl;
ExpressionGraph g = build_graph({IMAGE_SIZE, 2048, 2048, LABEL_SIZE});
//std::cout << g.graphviz() << std::endl;
std::cout << g.graphviz() << std::endl;
Tensor xt({BATCH_SIZE, IMAGE_SIZE});
Tensor yt({BATCH_SIZE, LABEL_SIZE});
boost::timer::cpu_timer total;
Adam opt(0.0002);
for(int i = 1; i <= 10; ++i) {
for(int i = 1; i <= 30; ++i) {
boost::timer::cpu_timer timer;
shuffle(trainImages, trainLabels, IMAGE_SIZE, LABEL_SIZE);
float cost = 0;

View File

@@ -251,9 +251,10 @@ struct CrossEntropyNodeOp : public BinaryNodeOp {
}
thrust::copy(a_->val().begin(), a_->val().end(), probs_.begin());
Softmax(&probs_); // Safe version of softmax.
Tensor result(a_->val().shape());
Element(_1 = -_2 * Log(_3), result, b_->val(), probs_);
SumRowwise(result, val_);
if(!result_)
result_.allocate(a_->val().shape());
Element(_1 = -_2 * Log(_3), result_, b_->val(), probs_);
SumRowwise(result_, val_);
}
// @TODO: In most cases it's wasteful to compute the derivative with respect
@@ -265,17 +266,18 @@ struct CrossEntropyNodeOp : public BinaryNodeOp {
// where y is the gold label distribution (e.g. one hot vector) and
// p is the softmax output (probabilities).
// The second input derivative is -adj*log(p).
Tensor result(probs_.shape());
if(!result_)
result_.allocate(probs_.shape());
// Compute first input derivative.
Element(_1 = _2 - _3, result, probs_, b_->val());
ScaleRowwise(result, adj_);
Element(_1 += _2, a_->grad(), result);
Element(_1 = _2 - _3, result_, probs_, b_->val());
ScaleRowwise(result_, adj_);
Element(_1 += _2, a_->grad(), result_);
// Compute second input derivative.
Element(_1 = -Log(_2), result, probs_); // @TODO: use a cached log here.
ScaleRowwise(result, adj_);
Element(_1 += _2, b_->grad(), result);
Element(_1 = -Log(_2), result_, probs_); // @TODO: use a cached log here.
ScaleRowwise(result_, adj_);
Element(_1 += _2, b_->grad(), result_);
}
virtual std::string graphviz() {
@@ -289,7 +291,7 @@ struct CrossEntropyNodeOp : public BinaryNodeOp {
protected:
Tensor probs_;
Tensor result_;
};
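
Note on the change above: the per-call temporary Tensor result(...) is replaced by the cached member result_, which is allocated on the first call and reused by both the forward and backward passes; this is the reallocation the commit title avoids. The gradients themselves are unchanged: the first input derivative is adj * (p - y) and the second is -adj * log(p), as the comment in backward() states. A minimal standalone sketch of the same allocate-once pattern, with std::vector standing in for the project's Tensor class (the names below are illustrative, not from the commit):

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Allocate-once sketch: the work buffer lives in the node and is sized lazily,
    // so repeated forward passes reuse it instead of reallocating.
    struct CachedTemporary {
      std::vector<float> result_;                // empty until the first call

      void forward(const std::vector<float>& probs,
                   const std::vector<float>& labels) {
        if (result_.empty())                     // mirrors: if(!result_) result_.allocate(shape)
          result_.resize(probs.size());          // one allocation for the node's lifetime
        for (std::size_t i = 0; i < probs.size(); ++i)
          result_[i] = -labels[i] * std::log(probs[i]);  // -y * log(p), as in the forward pass above
      }
    };

The same cached buffer can then serve the element-wise passes in backward(), exactly as result_ does above.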

View File

@@ -134,6 +134,59 @@ void SubtractMax(Tensor* Out) {
cudaStreamSynchronize(0);
}
///////////////////////////////////////////////////////
//template <class T>
//__global__ void gClipNorm(T t, float threshold) {
//  int rows = t.rows();
//  int cols = t.cols();
//
//  for(int bid = 0; bid < rows; bid += gridDim.x) {
//    int i = bid + blockIdx.x;
//    if(i < rows) {
//      // accumulate the squared values of row i in shared memory
//      extern __shared__ float _share[];
//      float* _sum = _share + blockDim.x;
//      _sum[threadIdx.x] = 0.0;
//      for(int tid = 0; tid < cols; tid += blockDim.x) {
//        int j = tid + threadIdx.x;
//        if(j < cols)
//          _sum[threadIdx.x] += powf(t(i,j), 2.0f);
//      }
//      __syncthreads();
//      // tree reduction: _sum[0] ends up holding the squared L2 norm of row i
//      int len = blockDim.x;
//      while(len != 1) {
//        __syncthreads();
//        int skip = (len + 1) >> 1;
//        if(threadIdx.x < (len >> 1))
//          _sum[threadIdx.x] += _sum[threadIdx.x + skip];
//        len = (len + 1) >> 1;
//      }
//      __syncthreads();
//      // assumed semantics: rescale the row only if its L2 norm exceeds the threshold
//      float norm = sqrtf(_sum[0]);
//      if(norm > threshold) {
//        float scale = threshold / norm;
//        for(int tid = 0; tid < cols; tid += blockDim.x) {
//          int j = tid + threadIdx.x;
//          if(j < cols)
//            t(i,j) *= scale;
//        }
//      }
//    }
//  }
//}
//
//void ClipNorm(Tensor out, float threshold) {
//  size_t m = out.shape()[0];
//  size_t k = out.shape()[1];
//
//  int blocks = std::min(MAX_BLOCKS, (int) m);
//  int threads = std::min(MAX_THREADS, (int) k);
//  int shared = sizeof(float) * threads * 2;
//  gClipNorm<<<blocks, threads, shared>>>(out.gpu(), threshold);
//  cudaStreamSynchronize(0);
//}
///////////////////////////////////////////////////////
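
For reference, the commented-out kernel above sketches per-row L2-norm clipping; a host-side version of the same operation, written here as plain C++ over a row-major buffer (not part of the committed code), looks like this:

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Clip every row of a row-major (rows x cols) buffer to an L2 norm of at most
    // `threshold`; rows already within the threshold are left untouched.
    void clipNormRows(std::vector<float>& data, std::size_t rows, std::size_t cols,
                      float threshold) {
      for (std::size_t i = 0; i < rows; ++i) {
        float sumSq = 0.0f;
        for (std::size_t j = 0; j < cols; ++j)
          sumSq += data[i * cols + j] * data[i * cols + j];
        float norm = std::sqrt(sumSq);
        if (norm > threshold) {
          float scale = threshold / norm;        // shrink the row back onto the threshold ball
          for (std::size_t j = 0; j < cols; ++j)
            data[i * cols + j] *= scale;
        }
      }
    }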
__global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols) {
for(int bid = 0; bid < rows; bid += gridDim.x) {

View File

@@ -151,6 +151,8 @@ void Element(Functor functor,
cudaStreamSynchronize(0);
}
void ClipNorm(Tensor out, float threshold);
void SubtractMax(Tensor* Out);
void Softmax(Tensor* Out);
@@ -169,9 +171,6 @@ Tensor SumRowwise(cublasHandle_t handle, const Tensor A, Tensor result);
Tensor SumRowwise(const Tensor A, Tensor result);
__global__ void gScaleRowwise(Float* out, const Float* scalingFactors,
size_t rows, size_t cols);
void ScaleRowwise(Tensor Out, const Tensor ScalingFactors);
}