diff --git a/CMakeLists.txt b/CMakeLists.txt
index 995bf902..4e7e1758 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -2,8 +2,8 @@ cmake_minimum_required(VERSION 3.5.1)
 set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
 project(marian CXX)
 
-SET(CMAKE_CXX_FLAGS " -std=c++11 -g -O3 -funroll-loops -Wno-unused-result -Wno-deprecated")
-LIST(APPEND CUDA_NVCC_FLAGS --default-stream per-thread; -std=c++11; -g; -O3; -arch=sm_35; -lineinfo; --use_fast_math; -Xcompiler '-fPIC')
+SET(CMAKE_CXX_FLAGS " -std=c++11 -g -O3 -Ofast -Wno-unused-result -Wno-deprecated -fPIC")
+LIST(APPEND CUDA_NVCC_FLAGS --default-stream per-thread; -std=c++11; -g; -O3; -arch=sm_61; --use_fast_math; -Xcompiler '-fPIC')
 
 add_definitions(-DCUDA_API_PER_THREAD_DEFAULT_STREAM)
 SET(CUDA_PROPAGATE_HOST_FLAGS OFF)
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 490fbd20..4fed5b1d 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -2,17 +2,18 @@
 include_directories(.)
 include_directories(3rd_party)
 
-cuda_add_library(marian_lib
-  3rd_party/cnpy/cnpy.cpp
+cuda_add_library(marian_lib
+#  3rd_party/cnpy/cnpy.cpp
   3rd_party/exception.cpp
-  expression_graph.cu
+  expression_graph.cu
   expression_operators.cu
   node.cu
+  node_operators.cu
   node_operators_unary.cu
   node_operators_binary.cu
-  tensor.cu
+  tensors/tensor.cu
   tensor_operators.cu
-  vocab.cpp
+#  vocab.cpp
 )
 
 target_link_libraries(marian_lib)
@@ -23,13 +24,13 @@ cuda_add_executable(
 )
 
 cuda_add_executable(
-  mnist_benchmark
-  mnist_benchmark.cu
+  tensor_test
+  tensor_test.cu
 )
 
 cuda_add_executable(
-  xor
-  xor.cu
+  mnist_benchmark
+  mnist_benchmark.cu
 )
 
 #cuda_add_executable(
@@ -37,18 +38,25 @@ cuda_add_executable(
 #  validate_encoder_decoder.cu
 #)
 
-cuda_add_executable(
-  test_nodes
-  test_nodes.cu
-)
+#cuda_add_executable(
+#  test_nodes
+#  test_nodes.cu
+#)
 
-#target_link_libraries(softmax_benchmark marian_lib)
-#target_link_libraries(mnist_benchmark marian_lib)
+target_link_libraries(softmax_benchmark marian_lib)
+target_link_libraries(tensor_test marian_lib)
+target_link_libraries(mnist_benchmark marian_lib)
 #target_link_libraries(validate_encoder_decoder marian_lib)
 #target_link_libraries(test_nodes marian_lib)
 
-foreach(exec mnist_benchmark softmax_benchmark xor test_nodes )
-  target_link_libraries(${exec} marian_lib ${EXT_LIBS} cuda cudnn curand)
+#foreach(exec mnist_benchmark softmax_benchmark test_nodes tensor_test)
+#  target_link_libraries(${exec} ${EXT_LIBS} cuda cudnn curand)
+#  cuda_add_cublas_to_target(${exec})
+#  set_target_properties(${exec} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")
+#endforeach(exec)
+
+foreach(exec mnist_benchmark tensor_test softmax_benchmark)
+  target_link_libraries(${exec} ${EXT_LIBS} cuda cudnn curand)
   cuda_add_cublas_to_target(${exec})
   set_target_properties(${exec} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")
 endforeach(exec)
diff --git a/src/batch_generator.h b/src/batch_generator.h
index aca47d10..a648504a 100644
--- a/src/batch_generator.h
+++ b/src/batch_generator.h
@@ -131,9 +131,10 @@ class BatchGenerator {
       return currentBatch_;
     }
 
-    void prepare() {
+    void prepare(bool shuffle=true) {
       //boost::timer::cpu_timer total;
-      data_->shuffle();
+      if(shuffle)
+        data_->shuffle();
       //std::cerr << "shuffle: " << total.format(5, "%ws") << std::endl;
       current_ = data_->begin();
       fillBatches();
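The new shuffle flag lets callers keep evaluation data in a fixed order while still reshuffling training data every epoch. A minimal stand-alone sketch of the same pattern (ToyGenerator is an invented name for illustration, not part of Marian):

    #include <algorithm>
    #include <random>
    #include <vector>

    // Illustrative stand-in for the prepare(bool shuffle) pattern above:
    // training epochs reshuffle, validation keeps a stable order.
    struct ToyGenerator {
      std::vector<int> data;
      std::mt19937 rng{42};

      void prepare(bool shuffle = true) {
        if(shuffle)  // same guard as BatchGenerator::prepare
          std::shuffle(data.begin(), data.end(), rng);
        // reset iteration state, fill batches, ...
      }
    };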
diff --git a/src/chainable.h b/src/chainable.h
index 359b3966..120575b5 100644
--- a/src/chainable.h
+++ b/src/chainable.h
@@ -100,15 +100,8 @@ struct Chainable {
     virtual ExpressionGraphPtr graph() = 0;
     virtual const Shape& shape() = 0;
 
-    virtual DataType val() = 0;
-    virtual DataType grad() = 0;
-
-    virtual void setVal(DataType t) {
-      UTIL_THROW2("Tensors can only be assigned to input and parameter nodes");
-    };
-    virtual void setGrad(DataType t) {
-      UTIL_THROW2("Gradients can only be assigned to parameter nodes");
-    };
+    virtual DataType& val() = 0;
+    virtual DataType& grad() = 0;
 };
 
 /** @brief Defines a convenience type to represent a shared pointer to a Chainable object. */
diff --git a/src/dataset.h b/src/dataset.h
index 20b3919d..141738a8 100644
--- a/src/dataset.h
+++ b/src/dataset.h
@@ -46,8 +46,8 @@ typedef std::shared_ptr<Example> ExamplePtr;
 /** @brief Defines a convenience type to represent an ordered collection of ::ExamplePtr objects. */
 typedef std::vector<ExamplePtr> Examples;
 
-/**
- * @brief Defines a convenience type to represent a const_iterator over the ::ExamplePtr objects
+/**
+ * @brief Defines a convenience type to represent a const_iterator over the ::ExamplePtr objects
  * stored in an ::Examples object.
  */
 typedef Examples::const_iterator ExampleIterator;
@@ -65,7 +65,7 @@ class Input {
     /** @brief Constructs a new Input object with the specified Shape */
     Input(const Shape& shape)
     : shape_(shape),
-      data_(new Data(shape_.totalSize(), 0.0f)) {}
+      data_(new Data(shape_.elements(), 0.0f)) {}
 
     /** @brief Gets an iterator pointing to the beginning of this object's ::Data */
     Data::iterator begin() {
@@ -105,17 +105,17 @@ class Input {
 
 class DataBase {
   public:
-
+
     /** @brief Returns an iterator pointing to the beginning of this object's underlying data. */
     virtual ExampleIterator begin() const = 0;
 
     /** @brief Returns an iterator pointing to the end of this object's underlying data. */
     virtual ExampleIterator end() const = 0;
-
+
     /** @brief Randomly shuffles the elements of this object's underlying data. */
     virtual void shuffle() = 0;
 
-    /**
+    /**
      * @brief Returns the size of the i-th dimension of the data.
      *
      * When an individual data point from this DataSet is used in the construction of an ExpressionGraph,
@@ -136,7 +136,7 @@ class DataBase {
 /** @brief Defines a convenience type to represent a shared pointer to a DataBase object. */
 typedef std::shared_ptr<DataBase> DataBasePtr;
 
-/**
+/**
  * @brief Convenience function to construct a new DataBase object and return a shared pointer to that object.
  *
  * The template parameters for this function specify two main pieces of information:
diff --git a/src/definitions.h b/src/definitions.h
index 45e90462..d6c061d6 100644
--- a/src/definitions.h
+++ b/src/definitions.h
@@ -25,6 +25,9 @@
 #include
 #include
 #include
+#include
+
+#include "shape.h"
 
 namespace marian {
   /** @brief Creates shared_ptr of any type, passes all arguments to any available constructor */
@@ -33,7 +36,6 @@ namespace marian {
     return std::shared_ptr<T>(new T(std::forward<Args>(args)...));
   }
 
-  const size_t SHAPE_SIZE = 2;
 
   typedef float Float;
 
@@ -44,113 +46,14 @@ namespace marian {
    * In that case, this placeholder would be used to specify that the batch size value will be defined at some later point.
    */
   const int whatevs{-1};
-
-  /**
-   * @brief Represents the size of each dimension in a tensor.
-   *
-   * Note: this class currently is hard-coded to 2 dimensions.
-   * This is likely to change.
-   */
-  class Shape {
-    private:
-      int shape_[SHAPE_SIZE];
-
-    public:
-
-      /**
-       * @brief Constructs a default shape.
-       *
-       * This default shape has two dimensions.
-       * The size of each dimension is 1.
-       */
-      Shape() : shape_{1, 1} { }
-
-      /**
-       * @brief Constructs a shape.
-       *
-       * @param i A list of integers representing the size of each dimension.
-       */
-      Shape(std::initializer_list<int> il) {
-        std::copy(il.begin(), il.end(), begin());
-      }
-
-      /**
-       * @brief Gets a reference to the int representing the size of the ith dimension represented by this object.
-       *
-       * @return a reference to the int representing the size of the ith dimension represented by this object
-       */
-      int& operator[](int i) {
-        return shape_[i];
-      }
-
-      /**
-       * @brief Gets the size of the ith dimension represented by this object.
-       *
-       * @return the size of the ith dimension represented by this object
-       */
-      const int& operator[](int i) const {
-        return shape_[i];
-      }
-
-      /**
-       * @brief Gets the number of dimensions represented by this object
-       *
-       * @return the number of dimensions represented by this object
-       */
-      size_t size() const {
-        return SHAPE_SIZE;
-      }
-
-      /**
-       * @brief Gets the total number of elements in a tensor of this shape.
-       *
-       * For example, if this shape represents a 5x100 tensor, this method would return 500.
-       *
-       * @return the total number of elements in a tensor of this shape
-       */
-      size_t totalSize() const {
-        size_t s = 1;
-        for(int i = 0; i < size(); ++i)
-          s *= shape_[i];
-        return s;
-      }
-
-      /** @brief Gets a pointer to an int that specifies the size of the first dimension represented by this object */
-      int* begin() { return shape_; }
-
-      /** @brief Gets a pointer to an int that specifies the size of the last dimension represented by this object */
-      int* end() { return shape_ + SHAPE_SIZE; }
-
-      /** @brief Gets a const pointer to an int that specifies the size of the first dimension represented by this object */
-      const int* begin() const { return shape_; }
-
-      /** @brief Gets a const pointer to an int that specifies the size of the last dimension represented by this object */
-      const int* end() const { return shape_ + SHAPE_SIZE; }
-
-      /**
-       * @brief Tests this object for equality against another Shape object.
-       *
-       * @return true if the size of each dimension in this object
-       * is equal to the size of the corresponding dimension in the other object,
-       * false otherwise
-       */
-      bool operator==(const Shape& other) const {
-        return std::equal(begin(), end(), other.begin());
-      }
-
-      /**
-       * @brief Tests this object for inequality against another Shape object.
-       */
-      bool operator!=(const Shape& other) const {
-        return !(*this == other);
-      }
-  };
 }
 
+
 #include "keywords.h"
 
 namespace marian {
-  class Tensor;
+  class TensorBase;
+  typedef std::shared_ptr<TensorBase> Tensor;
 
   class OptimizerBase;
   typedef std::shared_ptr<OptimizerBase> OptimizerBasePtr;
@@ -173,7 +76,7 @@ namespace marian {
     KEY(value, float)
     KEY(lazy_shape, std::function<Shape()>)
     KEY(lazy_value, std::function<Float()>)
-    KEY(init, std::function<void(Tensor)>)
+    KEY(init, std::function<void(Tensor&)>)
 
     KEY(optimizer, OptimizerBasePtr)
     KEY(batch_size, int)
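With `Tensor` now a typedef for `std::shared_ptr<TensorBase>`, a tensor is a nullable, cheaply copyable handle rather than a value type, which is why code below tests `if(!t)` before asking the graph to allocate. A minimal sketch of the pattern, with a stub TensorBase standing in for the real class in src/tensors/tensor.h:

    #include <memory>

    // Stand-in for the pattern introduced above; the real TensorBase
    // carries a shape, device memory, set()/get(), etc.
    struct TensorBase { /* shape, device buffer, ... */ };
    typedef std::shared_ptr<TensorBase> Tensor;

    int main() {
      Tensor t;                              // empty handle: nothing allocated yet
      if(!t)                                 // nodes test this before allocating
        t = std::make_shared<TensorBase>();
      Tensor alias = t;                      // cheap copy, shared ownership
    }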
diff --git a/src/expression_graph.h b/src/expression_graph.h
index c702f71b..5199c28f 100644
--- a/src/expression_graph.h
+++ b/src/expression_graph.h
@@ -28,16 +28,74 @@
 #include "definitions.h"
 #include "chainable.h"
 #include "node_operators.h"
-#include "tensor.h"
 #include "batch_generator.h"
+#include "tensors/tensor_allocator.h"
+#include "tensors/tensor_gpu.h"
 
 namespace marian {
 
-// Forward declaration of ExpressionGraph class; this enables it to be used in the following typedef of ExpressionGraphPtr
-class ExpressionGraph;
+class Parameters {
+  private:
+    /** @brief List of all parameter nodes of this expression graph. */
+    std::vector<Expr> params_;
+    TensorAllocator vals_;
+    TensorAllocator grads_;
 
-/** @brief A pointer to an expression graph. */
-typedef std::shared_ptr<ExpressionGraph> ExpressionGraphPtr;
+  public:
+    Parameters()
+    : vals_(newTensorAllocator()),
+      grads_(newTensorAllocator())
+    {}
+
+    auto begin() -> decltype(params_.begin()) {
+      return params_.begin();
+    }
+
+    auto end() -> decltype(params_.begin()) {
+      return params_.end();
+    }
+
+    size_t size() {
+      return params_.size();
+    }
+
+    size_t totalSize() {
+      size_t sum = 0;
+      for(auto p : params_)
+        sum += p->shape().elements();
+      return sum;
+    }
+
+    void add(Expr p) {
+      params_.push_back(p);
+    }
+
+    void allocateForward() {
+      if(vals_->capacity() == 0) {
+        vals_->reserveExact(totalSize());
+        for(auto p: params_)
+          if(!p->val())
+            vals_->allocate(p->val(), p->shape());
+      }
+    }
+
+    void allocateBackward() {
+      if(grads_->capacity() == 0) {
+        grads_->reserveExact(totalSize());
+        for(auto p: params_)
+          if(!p->grad())
+            grads_->allocate(p->grad(), p->shape());
+      }
+    }
+
+    Tensor vals() {
+      return vals_->asTensor();
+    }
+
+    Tensor grads() {
+      return grads_->asTensor();
+    }
+};
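Parameters::allocateForward()/allocateBackward() reserve one contiguous buffer sized for all parameters and hand each node a slice of it, so vals()/grads() can expose everything as one flat tensor for the fused optimizer updates in optimizers.h below. A rough host-side sketch of the carving idea (View and carve are invented names; the real work happens inside TensorAllocator):

    #include <cstddef>
    #include <vector>

    // Illustrative only: one flat float buffer carved into per-parameter
    // views, mimicking what reserveExact() + allocate() achieve above.
    struct View { float* data; size_t size; };

    std::vector<View> carve(float* buffer, const std::vector<size_t>& sizes) {
      std::vector<View> views;
      size_t offset = 0;
      for(size_t s : sizes) {
        views.push_back({buffer + offset, s});  // each parameter gets a slice
        offset += s;                            // slices are contiguous
      }
      return views;  // the buffer itself can still be treated as one vector
    }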
 
 template <class T, typename ...Args>
 Expr Expression(Args&& ... args);
 
@@ -50,7 +108,7 @@ class ExpressionGraph : public std::enable_shared_from_this<ExpressionGraph> {
     /** @brief Constructs a new expression graph
      * Constructor is private to force use of New() */
-    ExpressionGraph() {}
+    ExpressionGraph() : tensors_(newTensorAllocator()) {}
 
     // delete copy and move constructors
     ExpressionGraph(const ExpressionGraph&) = delete;
@@ -69,8 +127,8 @@ class ExpressionGraph : public std::enable_shared_from_this<ExpressionGraph> {
 
       for(int i = 0; i < gInputs.size(); ++i) {
         if(!gInputs[i]->val())
-          gInputs[i]->setVal(Tensor(bInputs[i].shape()));
-        gInputs[i]->val().set(bInputs[i].begin(), bInputs[i].end());
+          tensor(gInputs[i]->val(), bInputs[i].shape());
+        gInputs[i]->val()->set(bInputs[i].data());
       }
     }
 
@@ -80,7 +138,7 @@ class ExpressionGraph : public std::enable_shared_from_this<ExpressionGraph> {
      * Backpropagation is implemented by performing first the forward pass
      * and then the backward pass of algorithmic differentiation (AD) on the nodes of the graph.
      *
-     * @param batchSize XXX Marcin, could you provide a description of this param?
+     * @param batch A batch of training data
      */
     void backprop(data::BatchPtr batch) {
       forward(batch);
@@ -103,6 +161,8 @@ class ExpressionGraph : public std::enable_shared_from_this<ExpressionGraph> {
      * @param batchSize XXX Marcin, could you provide a description of this param?
      */
     void forward(data::BatchPtr batch) {
+      params_.allocateForward();
+
       for(auto&& v : tape_)
         if(!v->skipped_training())
           v->allocate(batch->dim());
@@ -144,6 +204,8 @@ class ExpressionGraph : public std::enable_shared_from_this<ExpressionGraph> {
       UTIL_THROW_IF2(topNodes_.size() > 1,
         "There are more than one top most node for backward step");
 
+      params_.allocateBackward();
+
       for(auto&& v : tape_)
         if(!v->skipped_training())
           v->set_zero_adjoint();
@@ -241,7 +303,7 @@ class ExpressionGraph : public std::enable_shared_from_this<ExpressionGraph> {
     template <typename ...Args>
     inline Expr param(Args ...args) {
       auto e = Expression<ParamNode>(shared_from_this(), args...);
-      params_.emplace_back(e);
+      params_.add(e);
       return e;
     }
@@ -318,7 +380,7 @@ class ExpressionGraph : public std::enable_shared_from_this<ExpressionGraph> {
      *
      * @return the list of all parameter nodes of this expression graph
      */
-    std::vector<Expr>& params() {
+    Parameters& params() {
      return params_;
     }
@@ -346,14 +408,6 @@ class ExpressionGraph : public std::enable_shared_from_this<ExpressionGraph> {
       named_.emplace(name, e);
     }
 
-    /**
-     * @brief Returns a pointer to the list of items contained in this graph.
-     *
-     * The items in the list will be in the order they were created.
-     *
-     * @return a pointer to the list of items contained in this graph
-     */
-
     void add(Expr node) {
       tape_.push_back(node);
       if(!node->skipped_training())
@@ -364,24 +418,34 @@ class ExpressionGraph : public std::enable_shared_from_this<ExpressionGraph> {
       topNodes_.erase(node);
     }
 
+
+
+    template <typename ...Args>
+    void tensor(Tensor& t, Args&&... args) {
+      tensors_->allocate(t, args...);
+    }
+
   private:
 
-    /** @brief Pointer to the list of nodes */
+    /** @brief The full list of nodes */
    Tape tape_;
 
    /** @brief Maps from name to expression node. */
    std::map<std::string, Expr> named_;
 
-    /** @brief List of all parameter nodes of this expression graph. */
-    std::vector<Expr> params_;
-
    /** @brief List of all input nodes of this expression graph. */
    std::vector<Expr> inputs_;
 
    /** @brief Contains all nodes with regard to which we want to calculate derivatives */
    std::unordered_set<Expr> topNodes_;
+
+    Parameters params_;
+    TensorAllocator tensors_;
 };
 
+/** @brief A pointer to an expression graph. */
+typedef std::shared_ptr<ExpressionGraph> ExpressionGraphPtr;
+
 template <class T, typename ...Args>
 Expr Expression(Args&& ... args) {
   auto e = Expr(new T(std::forward<Args>(args)...));
@@ -389,5 +453,4 @@ Expr Expression(Args&& ... args) {
   return e;
 }
 
-
 }
diff --git a/src/expression_operators.h b/src/expression_operators.h
index 8557b0a6..70eb1e49 100644
--- a/src/expression_operators.h
+++ b/src/expression_operators.h
@@ -92,39 +92,9 @@ Expr reluplus(Expr a, Expr b);
 
 /*********************************************************/
 
-// inefficient
 template <typename ...Args>
-inline Expr sum(Expr a, Args ...args) {
-  using namespace keywords;
-  Keywords params(args...);
-  int ax = params.Get<int>(axis, whatevs);
-
-  if(ax == 0) {
-    auto lshape = [a]() -> Shape {
-      int rows = a->val().shape()[0];
-      return {1, rows};
-    };
-    Expr one = a->graph()->ones(shape={1, a->shape()[0]},
-                                lazy_shape=lshape);
-    return dot(one, a);
-  }
-  else if(ax == 1) {
-    auto lshape = [a]() -> Shape {
-      int cols = a->val().shape()[1];
-      //std::cerr << "Shape will be " << cols << " by 1." << std::endl;
-      return {cols, 1};
-    };
-    Expr one = a->graph()->ones(shape={a->shape()[1], 1},
-                                lazy_shape=lshape);
-    return dot(a, one);
-  }
-  else if(ax == 2) {
-    UTIL_THROW2("Not implemented");
-  }
-  else if(ax == 3) {
-    UTIL_THROW2("Not implemented");
-  }
-  return sum(sum(a, axis=0), axis=1);
+Expr sum(Expr a, Args ...args) {
+  return Expression<SumNodeOp>(a, args...);
 }
 
 Expr softmax(Expr a);
@@ -133,36 +103,11 @@ Expr logsoftmax(Expr a);
 
 Expr argmax(Expr a);
 
-// inefficient
 template <typename ...Args>
-inline Expr mean(Expr a, Args ...args) {
-  using namespace keywords;
-  Keywords params(args...);
-  size_t ax = params.Get<int>(axis, whatevs);
-
-  switch (ax) {
-    case 0:
-      return sum(a, axis=0) / a->graph()->constant(shape={1, 1},
-                                                   lazy_value=[a]() -> Float {
-                                                     return a->val().shape()[0];
-                                                   });
-    case 1:
-      return sum(a, axis=1) / a->graph()->constant(shape={1, 1},
-                                                   lazy_value=[a]() -> Float {
-                                                     return a->val().shape()[1];
-                                                   });
-    case 2:
-      UTIL_THROW2("Not implemented");
-    case 3:
-      UTIL_THROW2("Not implemented");
-    default:
-      return sum(a) / a->graph()->constant(shape={1, 1},
-                                           lazy_value=[a]() -> Float {
-                                             return a->val().size();
-                                           });
-  }
+Expr mean(Expr a, Args ...args) {
+  return Expression<MeanNodeOp>(a, args...);
 }
 
-Expr cross_entropy(Expr a, Expr b);
+Expr cross_entropy(Expr a, Expr b);
 
 }
diff --git a/src/mnist_benchmark.cu b/src/mnist_benchmark.cu
index 1a8899b5..f94ddc9d 100644
--- a/src/mnist_benchmark.cu
+++ b/src/mnist_benchmark.cu
@@ -10,6 +10,10 @@
 #include "trainer.h"
 #include "models/feedforward.h"
 
+#include "tensors/tensor.h"
+#include "tensors/tensor_gpu.h"
+#include "tensors/tensor_allocator.h"
+
 using namespace marian;
 using namespace keywords;
 using namespace data;
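The removed inline sum/mean built reductions out of dot products with a ones vector; the new SumNodeOp/MeanNodeOp (node_operators_unary.h below) keep the same algebra but delegate to Sum/SumBackward in tensor_operators.cu. The identity: summing an r-by-c matrix A over axis 0 is the product of a 1-by-r row of ones with A, and the mean just scales the ones by 1/r. A host-side illustration of that identity:

    #include <cstdio>
    #include <vector>

    // Sum over axis 0 as a ones-vector product:
    // out[j] = sum_i A[i][j] == (1 x r row of ones) * A.
    int main() {
      const int r = 2, c = 3;
      float A[r][c] = {{1, 2, 3}, {4, 5, 6}};
      std::vector<float> out(c, 0.f);
      for(int j = 0; j < c; ++j)
        for(int i = 0; i < r; ++i)
          out[j] += 1.f * A[i][j];   // the "ones" factor made explicit
      // mean over axis 0 scales the ones by 1/r, as in MeanNodeOp.
      printf("%g %g %g\n", out[0], out[1], out[2]);  // prints: 5 7 9
    }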
diff --git a/src/models/feedforward.h b/src/models/feedforward.h
index 631577fc..117d9419 100644
--- a/src/models/feedforward.h
+++ b/src/models/feedforward.h
@@ -39,7 +39,7 @@ ExpressionGraphPtr FeedforwardClassifier(const std::vector<int>& dims) {
   // Because calculating over one observed data point at a time can be inefficient,
   // it is customary to operate over a batch of observed data points at once.
   //
-  // At this point, we do not know the batch size:
+  // At this point, we do not know the batch size:
   // whatevs therefore serves as a placeholder for the batch size, which will be specified later
   //
   // Once the batch size is known, "x" will represent a matrix with dimensions [batch_size, dims.front()].
@@ -48,7 +48,7 @@ ExpressionGraphPtr FeedforwardClassifier(const std::vector<int>& dims) {
                  "x");
 
   // Construct an input node called "y" and add it to the expression graph.
-  //
+  //
   // For each observed data point, this input will hold the ground truth label for that data point.
   // dims.back() specifies the size of this vector
   //
@@ -58,7 +58,7 @@ ExpressionGraphPtr FeedforwardClassifier(const std::vector<int>& dims) {
   // Because calculating over one observed data point at a time can be inefficient,
   // it is customary to operate over a batch of observed data points at once.
   //
-  // At this point, we do not know the batch size:
+  // At this point, we do not know the batch size:
   // whatevs therefore serves as a placeholder for the batch size, which will be specified later
   //
   // Once the batch size is known, "y" will represent a matrix with dimensions [batch_size, dims.front()].
@@ -72,7 +72,7 @@ ExpressionGraphPtr FeedforwardClassifier(const std::vector<int>& dims) {
     int out = dims[i+1];
 
     if(i == 0) {
-      // Create a dropout node as the parent of x,
+      // Create a dropout node as the parent of x,
       // and place that dropout node as the value of layers[0]
       layers.emplace_back(dropout(x, value=0.2));
     } else {
@@ -88,7 +88,7 @@ ExpressionGraphPtr FeedforwardClassifier(const std::vector<int>& dims) {
     weights.emplace_back(
       name(g->param(shape={in, out}, init=uniform()), "W" + std::to_string(i)));
-
+
     // Construct a bias node. By definition, a bias node stores the value 1.
     // Therefore, we don't actually store the 1.
     // Instead, the bias node object stores the weights on the connections
diff --git a/src/node.cu b/src/node.cu
index addbad9a..730c08f9 100644
--- a/src/node.cu
+++ b/src/node.cu
@@ -9,67 +9,96 @@ void Node::skip_training() {
   graph_->remove_top_node(shared_from_this());
 }
 
+void Node::allocate(size_t batchSize) {
+  auto it1 = shape_.begin();
+  auto it2 = givenShape_.begin();
+  while(it1 != shape_.end()) {
+    if(*it2 == whatevs)
+      *it1 = batchSize;
+    it1++; it2++;
+  }
+
+  graph_->tensor(val_, shape_);
+
+  if(Has(keywords::value))
+    val_->set(Get(keywords::value, 0));
+}
+
+void Node::init_dependent() {
+  if(!adj_)
+    graph_->tensor(adj_, shape_);
+  adj_->set(1);
+}
+
+void Node::set_zero_adjoint() {
+  if(!adj_) {
+    graph_->tensor(adj_, shape_);
+  }
+  adj_->set(0);
+}
+
+
 // GPU
 void Node::calc_numeric_grad(Float delta, Tensor input, Tensor grad) {
   using namespace std;
 
-  size_t inputSize = GetTotalSize(input.shape());
-  size_t valSize = GetTotalSize(val_.shape());
-
-  UTIL_THROW_IF2(inputSize != GetTotalSize(grad.shape()),
-             "inputSize != gradSize:" << inputSize << "!=" << GetTotalSize(grad.shape()));
-  UTIL_THROW_IF2(valSize != GetTotalSize(adj_.shape()),
-             "valSize != adjSize :" << valSize << "!=" << GetTotalSize(adj_.shape()));
-
-  cerr << "inputSize=grad=" << Debug(input.shape()) << "=" << inputSize << " "
-       << "valSize=adj_=" << Debug(val_.shape()) << "=" << valSize
-       << endl;
-
-  //cerr << "input=" << input.Debug() << endl;
-  //cerr << "adj_=" << adj_.Debug() << endl;
-
-  std::vector<float> prevCalcGrad;
-  prevCalcGrad << grad;
-  //cerr << "origGrad=" << grad.Debug() << endl;
-  //output("diffGrad", diffGrad);
-
-  //output("prevCalcGrad", prevCalcGrad.begin(), prevCalcGrad.end());
-
-  Tensor newValTensor(input.shape());
-
-  // LOOP thru each element in input & add delta
-  for (size_t inputInd = 0; inputInd < inputSize; ++inputInd) {
-    input.incr(inputInd, delta);
-    //output("input", input.begin(), input.end());
-
-    forward();
-
-    val_.sum(newValTensor, inputInd);
-    //cudaDeviceSynchronize();
-
-    input.incr(inputInd, -delta);
-  }
-
-  std::vector<float> newVal;
-  newVal << newValTensor;
-  //cudaDeviceSynchronize();
-
-  // orig value
-  forward();
-
-  float sumValOrig = val_.sum();
-  //float sumValOrig = thrust::reduce(val_.begin(), val_.end(), (float) 0.0f, thrust::plus<float>());
-  //cudaDeviceSynchronize();
-
-  //output("newVal", newVal.begin(), newVal.end());
-
-  // calc gradient
-  Tensor prevGradTensor(input.shape());
-  thrust::copy(grad.begin(), grad.end(), prevGradTensor.begin());
-
-  Tensor gradTensor(input.shape());
-  Element(_1 = (_2 - sumValOrig) / delta, gradTensor, newValTensor);
-  Element(_1 = _2 * _3 + _4, grad, adj_, gradTensor, prevGradTensor);
+//  size_t inputSize = GetTotalSize(input->shape());
+//  size_t valSize = GetTotalSize(val_->shape());
+//
+//  UTIL_THROW_IF2(inputSize != GetTotalSize(grad->shape()),
+//             "inputSize != gradSize:" << inputSize << "!=" << grad->shape()->elements());
+//  UTIL_THROW_IF2(valSize != GetTotalSize(adj_->shape()),
+//             "valSize != adjSize :" << valSize << "!=" << adj_->shape()->elements());
+//
+//  cerr << "inputSize=grad=" << Debug(input->shape()) << "=" << inputSize << " "
+//       << "valSize=adj_=" << Debug(val_->shape()) << "=" << valSize
+//       << endl;
+//
+//  //cerr << "input=" << input.Debug() << endl;
+//  //cerr << "adj_=" << adj_.Debug() << endl;
+//
+//  std::vector<float> prevCalcGrad;
+//  prevCalcGrad << grad;
+//  //cerr << "origGrad=" << grad.Debug() << endl;
+//  //output("diffGrad", diffGrad);
+//
+//  //output("prevCalcGrad", prevCalcGrad.begin(), prevCalcGrad.end());
+//
+//  Tensor newValTensor(input.shape());
+//
+//  // LOOP thru each element in input & add delta
+//  for (size_t inputInd = 0; inputInd < inputSize; ++inputInd) {
+//    input.incr(inputInd, delta);
+//    //output("input", input.begin(), input.end());
+//
+//    forward();
+//
+//    val_.sum(newValTensor, inputInd);
+//    //cudaDeviceSynchronize();
+//
+//    input.incr(inputInd, -delta);
+//  }
+//
+//  std::vector<float> newVal;
+//  newVal << newValTensor;
+//  //cudaDeviceSynchronize();
+//
+//  // orig value
+//  forward();
+//
+//  float sumValOrig = val_.sum();
+//  //float sumValOrig = thrust::reduce(val_.begin(), val_.end(), (float) 0.0f, thrust::plus<float>());
+//  //cudaDeviceSynchronize();
+//
+//  //output("newVal", newVal.begin(), newVal.end());
+//
+//  // calc gradient
+//  Tensor prevGradTensor(input.shape());
+//  thrust::copy(grad.begin(), grad.end(), prevGradTensor.begin());
+//
+//  Tensor gradTensor(input.shape());
+//  Element(_1 = (_2 - sumValOrig) / delta, gradTensor, newValTensor);
+//  Element(_1 = _2 * _3 + _4, grad, adj_, gradTensor, prevGradTensor);
 }
 
 /*
diff --git a/src/node.h b/src/node.h
index f8fe1def..8cb06065 100644
--- a/src/node.h
+++ b/src/node.h
@@ -22,9 +22,10 @@
 // SOFTWARE.
 
 #include
+#include
 
 #include "keywords.h"
-#include "tensor.h"
+#include "tensors/tensor.h"
 #include "chainable.h"
 
 namespace marian {
@@ -60,52 +61,17 @@ class Node : public Chainable<Tensor>,
     virtual bool skipped_training() {
       return skipTraining_;
     }
 
-    virtual void allocate(size_t batchSize) {
-      auto it1 = shape_.begin();
-      auto it2 = givenShape_.begin();
-      while(it1 != shape_.end()) {
-        if(*it2 == whatevs)
-          *it1 = batchSize;
-        it1++; it2++;
-      }
+    virtual void allocate(size_t batchSize);
 
-      if(Has(keywords::lazy_shape)) {
-        auto defaultShape = [this]() -> Shape { return shape_; };
-        shape_ = Get(keywords::lazy_shape, defaultShape)();
-      }
-      if(Has(keywords::lazy_value))
-        val_.allocate(shape_, Get(
-          keywords::lazy_value, []()->Float{return 0.f;})());
-      else if(Has(keywords::value))
-        val_.allocate(shape_, Get(keywords::value, 0));
-      else
-        val_.allocate(shape_);
-    }
+    virtual void init_dependent();
 
-    virtual void init_dependent() {
-      if(adj_) {
-        adj_.set(1);
-      }
-      else {
-        adj_.allocate(shape_, 1);
-      }
-    }
+    virtual void set_zero_adjoint();
 
-    virtual void set_zero_adjoint() {
-      if(adj_) {
-        adj_.set(0);
-      }
-      else {
-        adj_.allocate(shape_, 0);
-      }
-    }
-
-    virtual Tensor val() {
+    virtual Tensor& val() {
       return val_;
     };
 
-    virtual Tensor grad() {
-      //UTIL_THROW_IF2(!adj_, "Tensor has not been allocated");
+    virtual Tensor& grad() {
       return adj_;
     };
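Node::allocate() now first resolves the shape, replacing every dimension given as the placeholder whatevs (-1) with the runtime batch size, and only then requests memory from the graph's allocator. A self-contained sketch of that substitution step:

    #include <cassert>

    const int whatevs = -1;  // placeholder dimension, as in definitions.h

    // Replace every placeholder dimension with the runtime batch size,
    // mirroring the loop at the top of Node::allocate() above.
    void resolveShape(int* shape, const int* givenShape, int dims, int batchSize) {
      for(int i = 0; i < dims; ++i)
        if(givenShape[i] == whatevs)
          shape[i] = batchSize;
    }

    int main() {
      int given[2] = {whatevs, 784};
      int shape[2] = {whatevs, 784};
      resolveShape(shape, given, 2, 64);
      assert(shape[0] == 64 && shape[1] == 784);
    }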
diff --git a/src/node_operators.cu b/src/node_operators.cu
new file mode 100644
index 00000000..912a822f
--- /dev/null
+++ b/src/node_operators.cu
@@ -0,0 +1,37 @@
+// This file is part of the Marian toolkit.
+// Marian is copyright (c) 2016 Marcin Junczys-Dowmunt.
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+// SOFTWARE.
+
+#include "node_operators.h"
+#include "expression_graph.h"
+
+namespace marian {
+
+  void ParamNode::allocate(size_t batchSize) {
+    // @TODO params
+    graph()->tensor(val_, shape_);
+    if(!initialized_) {
+      init_(val_);
+      initialized_ = true;
+    }
+  }
+
+
+}
diff --git a/src/node_operators.h b/src/node_operators.h
index e1f6416e..9ee6ccc9 100644
--- a/src/node_operators.h
+++ b/src/node_operators.h
@@ -39,12 +39,6 @@ struct InputNode : public Node {
 
   ~InputNode() {}
 
-  virtual void setVal(Tensor t) {
-    val_ = t;
-    shape_ = t.shape();
-    //@todo, shape checking
-  }
-
   void forward() {}
   void backward() {}
 
@@ -82,7 +76,7 @@ struct ParamNode : public Node {
   template <typename ...Args>
   ParamNode(Args ...args)
   : Node(args...),
-    init_(Get(keywords::init, [](Tensor){ })),
+    init_(Get(keywords::init, [](Tensor&){ })),
     initialized_(false)
   {
     UTIL_THROW_IF2(!Has(keywords::shape) &&
@@ -90,30 +84,12 @@ struct ParamNode : public Node {
                    "Param items require shape information");
   }
 
-  virtual void setVal(Tensor t) {
-    val_ = t;
-    shape_ = t.shape();
-    //@todo, shape checking
-  };
-
   ~ParamNode() {}
 
-  virtual void setGrad(Tensor t) {
-    adj_ = t;
-    shape_ = t.shape();
-    //@todo, shape checking
-  };
-
   void forward() {}
   void backward() {}
 
-  virtual void allocate(size_t batchSize) {
-    val_.allocate(shape_);
-    if(!initialized_) {
-      init_(val_);
-      initialized_ = true;
-    }
-  }
+  virtual void allocate(size_t batchSize);
 
   virtual std::string graphviz() {
     std::stringstream ss;
@@ -125,7 +101,7 @@ struct ParamNode : public Node {
 
   private:
-    std::function<void(Tensor)> init_;
+    std::function<void(Tensor&)> init_;
     bool initialized_;
 };
diff --git a/src/node_operators_binary.cu b/src/node_operators_binary.cu
index aed5aa2f..320f1d10 100644
--- a/src/node_operators_binary.cu
+++ b/src/node_operators_binary.cu
@@ -6,4 +6,21 @@ namespace marian {
     graph_->remove_top_node(a_);
     graph_->remove_top_node(b_);
   }
+
+  // We're caching the logsoftmax probabilities here because we'll need them for
+  // the backward computation.
+  void CrossEntropyNodeOp::forward() {
+    // C = sum(-B * logsoftmax(A))
+    if(!probs_)
+      graph_->tensor(probs_, a_->val()->shape());
+
+    CudnnLogSoftmax(probs_, a_->val());
+
+    if(!result_)
+      graph_->tensor(result_, a_->val()->shape());
+    Element(_1 = -_2 * _3, result_, b_->val(), probs_);
+    Sum(val_, result_, 1);
+  }
+
+
 }
diff --git a/src/node_operators_binary.h b/src/node_operators_binary.h
index 0fcd5552..a51792bf 100644
--- a/src/node_operators_binary.h
+++ b/src/node_operators_binary.h
@@ -1,6 +1,7 @@
 #pragma once
 
 #include "node.h"
+#include "thrust_functions.h"
 #include "tensor_operators.h"
 
 namespace marian {
@@ -12,16 +13,16 @@ struct BinaryNodeOp : public Node {
     template <typename ...Args>
     BinaryNodeOp(Expr a, Expr b, Args ...args)
     : Node(a->graph(),
-           keywords::shape=keywords::Get(keywords::shape, a->shape(), args...),
-           keywords::no_inference=a->skipped_inference()
-             || b->skipped_inference()
-             || keywords::Get(keywords::no_inference, false, args...),
+           keywords::shape=keywords::Get(keywords::shape, a->shape(), args...),
+           keywords::no_inference=a->skipped_inference()
+             || b->skipped_inference()
+             || keywords::Get(keywords::no_inference, false, args...),
            keywords::no_training=a->skipped_training()
-             || b->skipped_training()
-             || keywords::Get(keywords::no_training, false, args...),
-           args...), a_(a), b_(b)
+             || b->skipped_training()
+             || keywords::Get(keywords::no_training, false, args...),
+           args...), a_(a), b_(b)
     {
-      remove_children_from_top_nodes();
+      remove_children_from_top_nodes();
     }
 
     ~BinaryNodeOp() {}
@@ -29,53 +30,53 @@ struct BinaryNodeOp : public Node {
     void remove_children_from_top_nodes();
 
     void backward_debug(Float delta) {
-      using namespace std;
-
-      cerr << "BinaryNodeOp::" << typeid(*this).name() << "::backward_debug()" << endl;
-
-      std::vector<float> preCalcGradA, diffGradA, numericalGradA;
-      preCalcGradA << a_->grad();
-      //output("preCalcGradA", preCalcGradA);
-
-      std::vector<float> preCalcGradB, diffGradB, numericalGradB;
-      preCalcGradB << b_->grad();
-      //output("preCalcGradB", preCalcGradB);
-
-      // use df/dx to calc grad
-      backward();
-      cerr << "orig a_->grad()=" << a_->grad().Debug() << endl;
-      cerr << "orig b_->grad()=" << b_->grad().Debug() << endl;
-
-      diffGradA << a_->grad();
-      diffGradB << b_->grad();
-
-      //cerr << "orig a_->grad()=" << a_->grad().Debug() << endl;
-      //cerr << "orig b_->grad()=" << b_->grad().Debug() << endl;
-
-      cerr << "TENSOR A:" << endl;
-      a_->grad().set(preCalcGradA);
-      b_->grad().set(preCalcGradB);
-
-      calc_numeric_grad(delta, a_->val(), a_->grad());
-      cerr << "numerical a_->grad()=" << a_->grad().Debug() << endl;
-
-      numericalGradA << a_->grad();
-      outputL2Norm("TENSOR A", diffGradA, numericalGradA);
-
-
-      cerr << "TENSOR B:" << endl;
-      a_->grad().set(preCalcGradA);
-      b_->grad().set(preCalcGradB);
-
-      calc_numeric_grad(delta, b_->val(), b_->grad());
-      cerr << "numerical b_->grad()=" << b_->grad().Debug() << endl;
-
-      numericalGradB << b_->grad();
-      outputL2Norm("TENSOR B", diffGradB, numericalGradB);
-
-      // reset to diff grad
-      a_->grad().set(diffGradA);
-      b_->grad().set(diffGradB);
+      //using namespace std;
+      //
+      //cerr << "BinaryNodeOp::" << typeid(*this).name() << "::backward_debug()" << endl;
+      //
+      //std::vector<float> preCalcGradA, diffGradA, numericalGradA;
+      //preCalcGradA << a_->grad();
+      ////output("preCalcGradA", preCalcGradA);
+      //
+      //std::vector<float> preCalcGradB, diffGradB, numericalGradB;
+      //preCalcGradB << b_->grad();
+      ////output("preCalcGradB", preCalcGradB);
+      //
+      //// use df/dx to calc grad
+      //backward();
+      //cerr << "orig a_->grad()=" << a_->grad().Debug() << endl;
+      //cerr << "orig b_->grad()=" << b_->grad().Debug() << endl;
+      //
+      //diffGradA << a_->grad();
+      //diffGradB << b_->grad();
+      //
+      ////cerr << "orig a_->grad()=" << a_->grad().Debug() << endl;
+      ////cerr << "orig b_->grad()=" << b_->grad().Debug() << endl;
+      //
+      //cerr << "TENSOR A:" << endl;
+      //a_->grad().set(preCalcGradA);
+      //b_->grad().set(preCalcGradB);
+      //
+      //calc_numeric_grad(delta, a_->val(), a_->grad());
+      //cerr << "numerical a_->grad()=" << a_->grad().Debug() << endl;
+      //
+      //numericalGradA << a_->grad();
+      //outputL2Norm("TENSOR A", diffGradA, numericalGradA);
+      //
+      //
+      //cerr << "TENSOR B:" << endl;
+      //a_->grad().set(preCalcGradA);
+      //b_->grad().set(preCalcGradB);
+      //
+      //calc_numeric_grad(delta, b_->val(), b_->grad());
+      //cerr << "numerical b_->grad()=" << b_->grad().Debug() << endl;
+      //
+      //numericalGradB << b_->grad();
+      //outputL2Norm("TENSOR B", diffGradB, numericalGradB);
+      //
+      //// reset to diff grad
+      //a_->grad().set(diffGradA);
+      //b_->grad().set(diffGradB);
     }
@@ -231,7 +232,7 @@ struct MultNodeOp : public BinaryNodeOp {
 
   virtual std::string graphviz() {
     std::stringstream ss;
-    ss << "\"" << this << "\" [shape=\"box\", label=" << label("x")
+    ss << "\"" << this << "\" [shape=\"box\", label=" << label("×")
        << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
     ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
     ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
@@ -285,29 +286,14 @@ struct CrossEntropyNodeOp : public BinaryNodeOp {
     return shape1;
   }
 
-  // We're caching the logsoftmax probabilities here because we'll need them for
-  // the backward computation.
-  void forward() {
-    // C = -dot(B, logsoftmax(A)).
-    if (probs_) {
-      probs_.set(0.0);
-    } else {
-      probs_.allocate(a_->val().shape(), 0.0);
-    }
-
-    CudnnLogSoftmax(probs_, a_->val());
-    if(!result_)
-      result_.allocate(a_->val().shape());
-    Element(_1 = -_2 * _3, result_, b_->val(), probs_);
-    SumRowwise(result_, val_);
-  }
+  void forward();
 
   // @TODO: In most cases it's wasteful to compute the derivative with respect
   // to the second input which is typically an input node in the computation
   // graph. In general the backward functions can skip the computation of
   // gradients wrt input nodes.
   void backward() {
-    // We are using logsoftmax for this and cached probs are logs.
+    // We are using logsoftmax for this and cached probs are logs.
     // For each row, the first input derivative is given by adj * (exp(p) - y),
     // where y is the gold label distribution (e.g. one hot vector) and
     // p is the softmax output (probabilities).
@@ -315,18 +301,18 @@ struct CrossEntropyNodeOp : public BinaryNodeOp {
     // Compute first input derivative.
     Element(_1 += _2 * (Exp(_3) - _4),
-            a_->grad(), adj_, probs_, b_->val());
+            a_->grad(), adj_, probs_, b_->val());
 
     // Compute second input derivative.
-    Element(_1 -= _2 * _3, b_->grad(),
-            adj_, probs_);
+    Element(_1 -= _2 * _3,
+            b_->grad(), adj_, probs_);
   }
 
   virtual std::string graphviz() {
     std::stringstream ss;
     ss << "\"" << this << "\" [shape=\"box\", label=" << label("x-ent")
        << ", style=\"filled\", fillcolor=\"orange\"]" << std::endl;
-    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
+    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
    ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return ss.str();
   };
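CrossEntropyNodeOp caches the log-softmax output p because both passes need it: forward computes C = sum(-y * p) row-wise, and backward uses dC/da = adj * (exp(p) - y), exactly as the comments above state. A one-row host-side sketch of both formulas:

    #include <cmath>
    #include <cstdio>

    // One-row illustration of the forward/backward math used above.
    // a: logits, y: one-hot gold distribution, p: cached log-softmax.
    int main() {
      const int k = 3;
      float a[k] = {1.f, 2.f, 0.5f}, y[k] = {0.f, 1.f, 0.f}, p[k];

      float max = a[0], sum = 0.f;
      for(int i = 1; i < k; ++i) max = std::fmax(max, a[i]);
      for(int i = 0; i < k; ++i) sum += std::exp(a[i] - max);
      for(int i = 0; i < k; ++i) p[i] = a[i] - max - std::log(sum);

      float cost = 0.f;
      for(int i = 0; i < k; ++i) cost += -y[i] * p[i];  // forward: C = sum(-y * p)

      float adj = 1.f, gradA[k];
      for(int i = 0; i < k; ++i)
        gradA[i] = adj * (std::exp(p[i]) - y[i]);       // backward: adj * (softmax - y)

      printf("cost=%f grad0=%f\n", cost, gradA[0]);
    }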
diff --git a/src/node_operators_unary.h b/src/node_operators_unary.h
index 334ee415..345f5925 100644
--- a/src/node_operators_unary.h
+++ b/src/node_operators_unary.h
@@ -1,7 +1,9 @@
 #pragma once
 
 #include "node.h"
+#include "tensors/tensor.h"
 #include "tensor_operators.h"
+#include "thrust_functions.h"
 
 namespace marian {
 
@@ -25,30 +27,30 @@ struct UnaryNodeOp : public Node {
     void remove_children_from_top_nodes();
 
     void backward_debug(Float delta) {
-      using namespace std;
-
-      cerr << "UnaryNodeOp::" << typeid(*this).name() << "::backward_numeric()" << endl;
-
-      std::vector<float> preCalcGradA, diffGradA, numericalGradA;
-      preCalcGradA << a_->grad();
-      //output("preCalcGradA", preCalcGradA);
-
-      // use df/dx to calc grad
-      backward();
-      cerr << "orig a_->grad()=" << a_->grad().Debug() << endl;
-      diffGradA << a_->grad();
-
-      a_->grad().set(preCalcGradA);
-
-      calc_numeric_grad(delta, a_->val(), a_->grad());
-      cerr << "numerical a_->grad()=" << a_->grad().Debug() << endl;
-
-      numericalGradA << a_->grad();
-
-      outputL2Norm("", diffGradA, numericalGradA);
-
-      // reset to diff grad
-      a_->grad().set(diffGradA);
+      using namespace std;
+      //
+      //cerr << "UnaryNodeOp::" << typeid(*this).name() << "::backward_numeric()" << endl;
+      //
+      //std::vector<float> preCalcGradA, diffGradA, numericalGradA;
+      //a_->grad() >> preCalcGradA ;
+      ////output("preCalcGradA", preCalcGradA);
+      //
+      //// use df/dx to calc grad
+      //backward();
+      //cerr << "orig a_->grad()=" << a_->grad().Debug() << endl;
+      //a_->grad() >> diffGradA;
+      //
+      //a_->grad()->set(preCalcGradA);
+      //
+      //calc_numeric_grad(delta, a_->val(), a_->grad());
+      ////cerr << "numerical a_->grad()=" << a_->grad()->Debug() << endl;
+      //
+      //a_->grad() >> numericalGradA;
+      //
+      //outputL2Norm("", diffGradA, numericalGradA);
+      //
+      //// reset to diff grad
+      //a_->grad()->set(diffGradA);
     }
 };
@@ -68,10 +70,6 @@ struct LogitNodeOp : public UnaryNodeOp {
             a_->grad(), adj_, val_);
   }
 
-  void check() {
-
-  }
-
   virtual std::string graphviz() {
     std::stringstream ss;
     ss << "\"" << this << "\" [shape=\"box\", label=" << label("logit")
@@ -210,7 +208,7 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
     : UnaryNodeOp(args...) { }
 
   void forward() {
-    CudnnSoftmax(val_, a_->val());
+    Softmax(val_, a_->val());
   }
 
   void backward() {
@@ -292,12 +290,94 @@ struct ArgmaxNodeOp : public UnaryNodeOp {
 };
 
+struct SumNodeOp : public UnaryNodeOp {
+  template <typename ...Args>
+  SumNodeOp(Expr a, Args ...args)
+    : UnaryNodeOp(a, keywords::shape=newShape(a, args...), args...) { }
+
+  void forward() {
+    Sum(val_, a_->val(), Get(keywords::axis, -1));
+  }
+
+  void backward() {
+    SumBackward(a_->grad(), adj_, Get(keywords::axis, -1));
+  }
+
+  template <typename ...Args>
+  Shape newShape(Expr a, Args ...args) {
+    int ax = keywords::Get(keywords::axis, -1, args...);
+    Shape shape = a->shape();
+    if(ax == 0) {
+      shape[0] = 1;
+    }
+    else if(ax == 1) {
+      shape[1] = 1;
+    }
+    else {
+      shape[0] = 1;
+      shape[1] = 1;
+    }
+    return shape;
+  }
+
+  virtual std::string graphviz() {
+    std::stringstream ss;
+    ss << "\"" << this << "\" [shape=\"box\", label="
+       << label("sum") << ", style=\"filled\", fillcolor=\"orange\"]" << std::endl;
+    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
+    return ss.str();
+  };
+
+};
+
+struct MeanNodeOp : public UnaryNodeOp {
+  template <typename ...Args>
+  MeanNodeOp(Expr a, Args ...args)
+    : UnaryNodeOp(a, keywords::shape=newShape(a, args...), args...) { }
+
+  void forward() {
+    Sum(val_, a_->val(), Get(keywords::axis, -1), true);
+  }
+
+  void backward() {
+    SumBackward(a_->grad(), adj_, Get(keywords::axis, -1), true);
+  }
+
+  template <typename ...Args>
+  Shape newShape(Expr a, Args ...args) {
+    int ax = keywords::Get(keywords::axis, -1, args...);
+    Shape shape = a->shape();
+    if(ax == 0) {
+      shape[0] = 1;
+    }
+    else if(ax == 1) {
+      shape[1] = 1;
+    }
+    else {
+      shape[0] = 1;
+      shape[1] = 1;
+    }
+    return shape;
+  }
+
+  virtual std::string graphviz() {
+    std::stringstream ss;
+    ss << "\"" << this << "\" [shape=\"box\", label="
+       << label("mean") << ", style=\"filled\", fillcolor=\"orange\"]" << std::endl;
+    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
+    return ss.str();
+  };
+
+};
+
+
 struct LogNodeOp : public UnaryNodeOp {
   template <typename ...Args>
   LogNodeOp(Args ...args)
   : UnaryNodeOp(args...) {}
 
   void forward() {
+    std::cerr << val_.get() << " <-> " << a_->val().get() << std::endl;
     Element(_1 = Log(_2), val_, a_->val());
   }
diff --git a/src/optimizers.h b/src/optimizers.h
index 17756a00..a4460612 100644
--- a/src/optimizers.h
+++ b/src/optimizers.h
@@ -38,29 +38,35 @@ class Sgd : public OptimizerBase {
 class Adagrad : public OptimizerBase {
   public:
     Adagrad(float eta=0.01, float eps=1e-8)
-    : eta_(eta), eps_(eps) {}
+    : eta_(eta), eps_(eps),
+      alloc_(newTensorAllocator())
+    {}
 
     void update(ExpressionGraphPtr graph, data::BatchPtr batch) {
       graph->backprop(batch);
 
-      if(gt_.size() < graph->params().size())
-        for(auto& param : graph->params())
-          gt_.emplace_back(Tensor(param->grad().shape(), 0));
-
-      auto gtIt = gt_.begin();
-      for(auto& param : graph->params()) {
-        Element(_1 += (_2 * _2),
-                *gtIt, param->grad());
-        Element(_1 -= (eta_ / (Sqrt(_2) + eps_)) * _3,
-                param->val(), *gtIt, param->grad());
-        gtIt++;
+      if(!gt_) {
+        int totalSize = graph->params().totalSize();
+        alloc_->reserveExact(totalSize);
+        alloc_->allocate(gt_, {1, totalSize});
+        gt_->set(0);
       }
+
+      Tensor pVals = graph->params().vals();
+      Tensor pGrads = graph->params().grads();
+
+      ElementVec(_1 += (_2 * _2),
+                 gt_, pGrads);
+
+      ElementVec(_1 -= (eta_ / (Sqrt(_2) + eps_)) * _3,
+                 pVals, gt_, pGrads);
     }
 
   private:
     float eta_;
     float eps_;
-    std::vector<Tensor> gt_;
+    TensorAllocator alloc_;
+    Tensor gt_;
 };
 
@@ -69,34 +75,39 @@ class Adagrad : public OptimizerBase {
 class Adam : public OptimizerBase {
   public:
     Adam(float eta=0.001, float beta1=0.9, float beta2=0.999, float eps=1e-8)
-    : eta_(eta), beta1_(beta1), beta2_(beta2), eps_(eps), t_(0) {}
+    : eta_(eta), beta1_(beta1), beta2_(beta2), eps_(eps), t_(0),
+      mtAlloc_(newTensorAllocator()),
+      vtAlloc_(newTensorAllocator())
+    {}
 
     void update(ExpressionGraphPtr graph, data::BatchPtr batch) {
       graph->backprop(batch);
 
-      if(mt_.size() < graph->params().size()) {
-        for(auto& param : graph->params()) {
-          mt_.emplace_back(Tensor(param->grad().shape(), 0));
-          vt_.emplace_back(Tensor(param->grad().shape(), 0));
-        }
+      if(!mt_) {
+        int totalSize = graph->params().totalSize();
+        mtAlloc_->reserveExact(totalSize);
+        mtAlloc_->allocate(mt_, {1, totalSize});
+        mt_->set(0);
+
+        vtAlloc_->reserveExact(totalSize);
+        vtAlloc_->allocate(vt_, {1, totalSize});
+        vt_->set(0);
       }
 
       t_++;
       float denom1 = 1 - pow(beta1_, t_);
       float denom2 = 1 - pow(beta2_, t_);
 
-      auto mtIt = mt_.begin();
-      auto vtIt = vt_.begin();
+      Tensor pVals = graph->params().vals();
+      Tensor pGrads = graph->params().grads();
 
-      for(auto& param : graph->params()) {
-        Element(_1 = (beta1_ * _1) + ((1 - beta1_) * _2),
-                *mtIt, param->grad());
-        Element(_1 = (beta2_ * _1) + ((1 - beta2_) * (_2 * _2)),
-                *vtIt, param->grad());
-        Element(_1 -= eta_ * (_2 / denom1) / (Sqrt(_3 / denom2) + eps_),
-                param->val(), *mtIt, *vtIt);
-        mtIt++; vtIt++;
-      }
+      ElementVec(_1 = (beta1_ * _1) + ((1 - beta1_) * _2),
+                 mt_, pGrads);
+      ElementVec(_1 = (beta2_ * _1) + ((1 - beta2_) * (_2 * _2)),
+                 vt_, pGrads);
+
+      ElementVec(_1 -= eta_ * (_2 / denom1) / (Sqrt(_3 / denom2) + eps_),
+                 pVals, mt_, vt_);
     }
 
   private:
@@ -105,8 +116,11 @@ class Adam : public OptimizerBase {
     float beta2_;
     float eps_;
     size_t t_;
-    std::vector<Tensor> mt_;
-    std::vector<Tensor> vt_;
+
+    TensorAllocator mtAlloc_;
+    Tensor mt_;
+    TensorAllocator vtAlloc_;
+    Tensor vt_;
 };
 
 template
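With all parameter values and gradients living in two flat tensors, Adam's per-parameter loops collapse into three ElementVec calls. The update applies the usual bias corrections 1 − β1^t and 1 − β2^t. A plain-C++ sketch of one fused step over flat arrays (function and variable names are illustrative):

    #include <cmath>
    #include <vector>

    // Fused Adam step over one flat parameter vector, mirroring the three
    // ElementVec calls above (denom1/denom2 are the bias corrections).
    void adamStep(std::vector<float>& p, const std::vector<float>& g,
                  std::vector<float>& mt, std::vector<float>& vt,
                  float eta, float beta1, float beta2, float eps, int t) {
      float denom1 = 1 - std::pow(beta1, t);
      float denom2 = 1 - std::pow(beta2, t);
      for(size_t i = 0; i < p.size(); ++i) {
        mt[i] = beta1 * mt[i] + (1 - beta1) * g[i];          // 1st moment
        vt[i] = beta2 * vt[i] + (1 - beta2) * g[i] * g[i];   // 2nd moment
        p[i] -= eta * (mt[i] / denom1) / (std::sqrt(vt[i] / denom2) + eps);
      }
    }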
diff --git a/src/param_initializers.h b/src/param_initializers.h
index 99013432..f40b3e0b 100644
--- a/src/param_initializers.h
+++ b/src/param_initializers.h
@@ -27,7 +27,7 @@
 #include
 #include
 
-#include "tensor.h"
+#include "tensors/tensor.h"
 
 namespace marian {
 
@@ -47,60 +47,60 @@ float xor128() {
 // Use a constant seed for deterministic behaviour.
 std::default_random_engine engine(42);
 
-void zeros(Tensor t) {
-  t.set(0.f);
+void zeros(Tensor& t) {
+  t->set(0.f);
 }
 
-void ones(Tensor t) {
-  t.set(1.0f);
+void ones(Tensor& t) {
+  t->set(1.0f);
 }
 
-template <class Distribution>
-void distribution(Tensor t, float a, float b) {
+template <class Distribution>
+void distribution(Tensor& t, float a, float b) {
   //std::random_device device;
   //std::default_random_engine engine(device());
   Distribution dist(a, b);
   auto gen = std::bind(dist, engine);
 
-  std::vector<float> vals(t.size());
+  std::vector<float> vals(t->size());
   std::generate(begin(vals), end(vals), gen);
 
   t << vals;
 }
 
-std::function<void(Tensor)> normal(float mean = 0.0, float std = 0.05) {
-  return [mean, std](Tensor t) {
+std::function<void(Tensor&)> normal(float mean = 0.0, float std = 0.05) {
+  return [mean, std](Tensor& t) {
     distribution<std::normal_distribution<float>>(t, mean, std);
-  };
+  };
 }
 
-std::function<void(Tensor)> uniform(float a = -0.05, float b = 0.05) {
-  return [a, b](Tensor t) {
+std::function<void(Tensor&)> uniform(float a = -0.05, float b = 0.05) {
+  return [a, b](Tensor& t) {
     distribution<std::uniform_real_distribution<float>>(t, a, b);
-  };
+  };
 }
 
-void glorot_uniform(Tensor t) {
-  float b = sqrtf( 6.0f / (t.shape()[0] + t.shape()[1]) );
+void glorot_uniform(Tensor& t) {
+  float b = sqrtf( 6.0f / (t->shape()[0] + t->shape()[1]) );
   distribution<std::uniform_real_distribution<float>>(t, -b, b);
 }
 
-void xorshift(Tensor t) {
-  std::vector<float> vals(t.size());
+void xorshift(Tensor& t) {
+  std::vector<float> vals(t->size());
   for(auto&& v : vals)
     v = xor128();
   t << vals;
 }
 
-void glorot_normal(Tensor t) {
-  float b = sqrtf( 2.0f / (t.shape()[0] + t.shape()[1]) );
+void glorot_normal(Tensor& t) {
+  float b = sqrtf( 2.0f / (t->shape()[0] + t->shape()[1]) );
   distribution<std::normal_distribution<float>>(t, -b, b);
 }
 
-std::function<void(Tensor)> from_vector(const std::vector<float>& v) {
-  return [v](Tensor t) {
+std::function<void(Tensor&)> from_vector(const std::vector<float>& v) {
+  return [v](Tensor& t) {
     t << v;
   };
 }
-
+
 } // namespace marian
diff --git a/src/shape.h b/src/shape.h
new file mode 100644
index 00000000..36d07ffc
--- /dev/null
+++ b/src/shape.h
@@ -0,0 +1,136 @@
+#pragma once
+
+// This file is part of the Marian toolkit.
+// Marian is copyright (c) 2016 Marcin Junczys-Dowmunt.
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+// SOFTWARE.
+
+#include
+
+#include "exception.h"
+
+namespace marian {
+
+  /**
+   * @brief Represents the size of each dimension in a tensor.
+   *
+   * Note: this class currently is hard-coded to two dimensions.
+   */
+
+  const size_t SHAPE_SIZE = 2;
+
+
+  struct Shape {
+    int shape_[SHAPE_SIZE];
+
+    /**
+     * @brief Constructs a default shape.
+     *
+     * This default shape has two dimensions.
+     * The size of each dimension is 1.
+     */
+    Shape() : shape_{1, 1} { }
+
+    /**
+     * @brief Constructs a shape.
+     *
+     * @param i A list of integers representing the size of each dimension.
+     */
+    Shape(std::initializer_list<int> il) {
+      std::copy(il.begin(), il.end(), begin());
+    }
+
+    Shape(const Shape& shape) {
+      std::copy(shape.begin(), shape.end(), begin());
+    }
+
+    /**
+     * @brief Gets a reference to the int representing the size of the ith dimension represented by this object.
+     *
+     * @return a reference to the int representing the size of the ith dimension represented by this object
+     */
+    __host__ __device__
+    int& operator[](int i) {
+      return shape_[i];
+    }
+
+    /**
+     * @brief Gets the size of the ith dimension represented by this object.
+     *
+     * @return the size of the ith dimension represented by this object
+     */
+    __host__ __device__
+    const int& operator[](int i) const {
+      return shape_[i];
+    }
+
+    /**
+     * @brief Gets the number of dimensions represented by this object
+     *
+     * @return the number of dimensions represented by this object
+     */
+    size_t size() const {
+      return SHAPE_SIZE;
+    }
+
+    /**
+     * @brief Gets the total number of elements in a tensor of this shape.
+     *
+     * For example, if this shape represents a 5x100 tensor, this method would return 500.
+     *
+     * @return the total number of elements in a tensor of this shape
+     */
+    size_t elements() const {
+      size_t s = 1;
+      for(int i = 0; i < size(); ++i)
+        s *= shape_[i];
+      return s;
+    }
+
+    /** @brief Gets a pointer to an int that specifies the size of the first dimension represented by this object */
+    int* begin() { return shape_; }
+
+    /** @brief Gets a pointer to an int that specifies the size of the last dimension represented by this object */
+    int* end() { return shape_ + SHAPE_SIZE; }
+
+    /** @brief Gets a const pointer to an int that specifies the size of the first dimension represented by this object */
+    const int* begin() const { return shape_; }
+
+    /** @brief Gets a const pointer to an int that specifies the size of the last dimension represented by this object */
+    const int* end() const { return shape_ + SHAPE_SIZE; }
+
+    /**
+     * @brief Tests this object for equality against another Shape object.
+     *
+     * @return true if the size of each dimension in this object
+     * is equal to the size of the corresponding dimension in the other object,
+     * false otherwise
+     */
+    bool operator==(const Shape& other) const {
+      return std::equal(begin(), end(), other.begin());
+    }
+
+    /**
+     * @brief Tests this object for inequality against another Shape object.
+     */
+    bool operator!=(const Shape& other) const {
+      return !(*this == other);
+    }
+  };
+}
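Shape now lives in its own header and its accessors are marked __host__ __device__, so CUDA kernels can index a Shape passed by value; elements() replaces the old totalSize(). A small usage sketch (values illustrative; compile with nvcc, since the accessors carry CUDA qualifiers):

    #include <cassert>
    #include "shape.h"

    int main() {
      marian::Shape s = {80, 500};
      assert(s[0] == 80 && s[1] == 500);
      assert(s.size() == 2);           // SHAPE_SIZE is currently 2
      assert(s.elements() == 40000);   // product of all dimensions
      marian::Shape t = s;             // copy constructor added in this diff
      assert(t == s);
    }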
diff --git a/src/softmax_benchmark.cu b/src/softmax_benchmark.cu
index 5889c465..4172bd1c 100644
--- a/src/softmax_benchmark.cu
+++ b/src/softmax_benchmark.cu
@@ -2,10 +2,13 @@
 #include
 #include
 #include
+#include
 
 #include
 
-#include "tensor.h"
+#include "tensors/tensor_allocator.h"
+#include "tensors/tensor_gpu.h"
+
 #include "tensor_operators.h"
 #include "param_initializers.h"
 
@@ -15,9 +18,13 @@ template <class F>
 void testForward(F f, size_t l,
                  const Shape& shape,
                  const std::string& desc) {
-  Tensor in(shape);
-  Tensor out(shape);
-
+
+  auto ta = newTensorAllocator();
+
+  Tensor in, out;
+  ta->allocate(in, shape);
+  ta->allocate(out, shape);
+
   uniform(-5, 5)(in);
 
   std::cout << desc << ": " << std::flush;
@@ -34,10 +41,15 @@ template <class F>
 void testBackward(F f, size_t l,
                   const Shape& shape,
                   const std::string& desc) {
-  Tensor in(shape);
-  Tensor adj(shape, 1);
-  Tensor grad(shape);
-
+
+  auto ta = newTensorAllocator();
+
+  Tensor in, adj, grad;
+  ta->allocate(in, shape);
+  ta->allocate(adj, shape);
+  adj->set(1);
+  ta->allocate(grad, shape);
+
   uniform(-5, 5)(in);
 
   std::cout << desc << ": " << std::flush;
@@ -52,30 +64,30 @@ void testBackward(F f, size_t l,
 
 int main() {
   int l = 1000;
-
+
   std::vector<Shape> shapes = {
     {1000, 1000},
     {80, 50000},
     {50000, 80},
   };
-
+
   for(auto& shape : shapes) {
-    std::cout << "Testing shape: " << shape[0] << "x" << shape[1] << std::endl << std::endl;
-
+    std::cout << "Testing shape: " << shape[0] << "x" << shape[1] << std::endl << std::endl;
+
     std::cout << "Softmax forward" << std::endl;
     testForward(CudnnSoftmax, l, shape, "CuDNN ");
     testForward(Softmax, l, shape, "Marian");
     std::cout << std::endl;
-
+
     std::cout << "Softmax backward" << std::endl;
     testBackward(CudnnSoftmaxGrad, l, shape, "CuDNN ");
     testBackward(SoftmaxGrad, l, shape, "Marian");
     std::cout << std::endl;
-
+
     std::cout << "Log-softmax backward" << std::endl;
     testBackward(CudnnLogSoftmaxGrad, l, shape, "CuDNN ");
     testBackward(LogSoftmaxGrad, l, shape, "Marian");
     std::cout << std::endl;
   }
   return 0;
-}
\ No newline at end of file
+}
diff --git a/src/tensor_operators.cu b/src/tensor_operators.cu
index a66b7b22..b3280d70 100644
--- a/src/tensor_operators.cu
+++ b/src/tensor_operators.cu
@@ -39,63 +39,73 @@ static cudnnHandle_t create_handle_dnn() {
 cublasHandle_t cublasHandle = create_handle();
 cudnnHandle_t cudnnHandle = create_handle_dnn();
 
-void CudnnSoftmax(Tensor out, Tensor in) {
+void CudnnSoftmax(Tensor& out, Tensor& in) {
   float alpha = 1, beta = 0;
+  auto inGpu = static_cast<TensorGPU*>(in.get());
+  auto outGpu = static_cast<TensorGPU*>(out.get());
   cudnnSoftmaxForward(cudnnHandle,
                       CUDNN_SOFTMAX_ACCURATE,
                       CUDNN_SOFTMAX_MODE_CHANNEL,
                       &alpha,
-                      in.cudnn(),
-                      in.data(),
+                      inGpu->cudnn(),
+                      inGpu->data(),
                       &beta,
-                      out.cudnn(),
-                      out.data());
+                      outGpu->cudnn(),
+                      outGpu->data());
   cudaDeviceSynchronize();
 }
 
-void CudnnLogSoftmax(Tensor out, Tensor in) {
+void CudnnLogSoftmax(Tensor& out, Tensor& in) {
   float alpha = 1, beta = 0;
+  auto inGpu = static_cast<TensorGPU*>(in.get());
+  auto outGpu = static_cast<TensorGPU*>(out.get());
   cudnnSoftmaxForward(cudnnHandle,
                       CUDNN_SOFTMAX_LOG,
                       CUDNN_SOFTMAX_MODE_CHANNEL,
                       &alpha,
-                      in.cudnn(),
-                      in.data(),
+                      inGpu->cudnn(),
+                      inGpu->data(),
                       &beta,
-                      out.cudnn(),
-                      out.data());
+                      outGpu->cudnn(),
+                      outGpu->data());
   cudaDeviceSynchronize();
 }
 
-void CudnnSoftmaxGrad(Tensor grad, Tensor adj, Tensor val) {
+void CudnnSoftmaxGrad(Tensor& grad, Tensor& adj, Tensor& val) {
   float alpha = 1, beta = 0;
+  auto valGpu = static_cast<TensorGPU*>(val.get());
+  auto adjGpu = static_cast<TensorGPU*>(adj.get());
+  auto gradGpu = static_cast<TensorGPU*>(grad.get());
   cudnnSoftmaxBackward(cudnnHandle,
                        CUDNN_SOFTMAX_ACCURATE,
                        CUDNN_SOFTMAX_MODE_CHANNEL,
                        &alpha,
-                       val.cudnn(),
-                       val.data(),
-                       adj.cudnn(),
-                       adj.data(),
+                       valGpu->cudnn(),
+                       valGpu->data(),
+                       adjGpu->cudnn(),
+                       adjGpu->data(),
                        &beta,
-                       grad.cudnn(),
-                       grad.data());
+                       gradGpu->cudnn(),
+                       gradGpu->data());
   cudaDeviceSynchronize();
 }
 
-void CudnnLogSoftmaxGrad(Tensor grad, Tensor adj, Tensor val) {
+void CudnnLogSoftmaxGrad(Tensor& grad, Tensor& adj, Tensor& val) {
   float alpha = 1, beta = 0;
+  auto valGpu = static_cast<TensorGPU*>(val.get());
+  auto adjGpu = static_cast<TensorGPU*>(adj.get());
+  auto gradGpu = static_cast<TensorGPU*>(grad.get());
   cudnnSoftmaxBackward(cudnnHandle,
                        CUDNN_SOFTMAX_LOG,
                        CUDNN_SOFTMAX_MODE_CHANNEL,
                        &alpha,
-                       val.cudnn(),
-                       val.data(),
-                       adj.cudnn(),
-                       adj.data(),
+                       valGpu->cudnn(),
+                       valGpu->data(),
+                       adjGpu->cudnn(),
+                       adjGpu->data(),
                        &beta,
-                       grad.cudnn(),
-                       grad.data());
+                       gradGpu->cudnn(),
+                       gradGpu->data());
   cudaDeviceSynchronize();
 }
@@ -137,18 +147,18 @@ __global__ void gSubtractMax(float* out, const float* in,
   }
 }
 
-void SubtractMax(Tensor out, Tensor in) {
+void SubtractMax(Tensor& out, Tensor& in) {
   // Out is a m-by-k matrix, passed as input.
   // The max element of each row of Out is computed and subtracted from Out.
   // Out is both input and output.
-  size_t m = out.shape()[0];
-  size_t k = out.shape()[1];
+  size_t m = out->shape()[0];
+  size_t k = out->shape()[1];
 
   int blocks = std::min(MAX_BLOCKS, (int) m);
   int threads = std::min(MAX_THREADS, (int) k);
   int shared = sizeof(float) * threads * 2;
-  gSubtractMax<<<blocks, threads, shared>>>(out.data(),
-                                            in.data(), m, k);
+  gSubtractMax<<<blocks, threads, shared>>>(out->data(),
+                                            in->data(), m, k);
   cudaStreamSynchronize(0);
 }
@@ -186,18 +196,18 @@ __global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols) {
   }
 }
 
-void Softmax(Tensor out, Tensor in) {
-  size_t m = out.shape()[0];
-  size_t k = out.shape()[1];
+void Softmax(Tensor& out, Tensor& in) {
+  size_t m = out->shape()[0];
+  size_t k = out->shape()[1];
 
   int blocks = std::min(MAX_BLOCKS, (int) m);
   int threads = std::min(MAX_THREADS, (int) k);
   int shared = sizeof(float) * threads * 2;
   // Subtract the max rowwise for numerical stability (safe softmax).
-  gSubtractMax<<<blocks, threads, shared>>>(out.data(),
-                                            in.data(), m, k);
+  gSubtractMax<<<blocks, threads, shared>>>(out->data(),
+                                            in->data(), m, k);
   cudaStreamSynchronize(0);
-  gSoftMax<<<blocks, threads, shared>>>(out.data(), m, k);
+  gSoftMax<<<blocks, threads, shared>>>(out->data(), m, k);
   cudaStreamSynchronize(0);
 }
@@ -240,18 +250,18 @@ __global__ void gSoftmaxGrad(float* grad, const float* adj, const float* val,
   }
 }
 
-void SoftmaxGrad(Tensor grad, Tensor adj, Tensor val) {
+void SoftmaxGrad(Tensor& grad, Tensor& adj, Tensor& val) {
   // grad and val are both m-by-k matrices, passed as input.
   // A weighted average of each row of grad (according to the weights
   // specified in val) is computed and subtracted from Out.
   // adj is multiplied for each element to get backward step in autodiff
-  int m = grad.shape()[0];
-  int k = grad.shape()[1];
+  int m = grad->shape()[0];
+  int k = grad->shape()[1];
   int blocks = std::min(MAX_BLOCKS, m);
   int threads = std::min(MAX_THREADS, k);
   int shared = sizeof(float) * threads * 2;
-  gSoftmaxGrad<<<blocks, threads, shared>>>(grad.data(), adj.data(), val.data(),
+  gSoftmaxGrad<<<blocks, threads, shared>>>(grad->data(), adj->data(), val->data(),
                                             m, k);
   cudaStreamSynchronize(0);
 }
@@ -293,19 +303,19 @@ __global__ void gLogSoftmaxGrad(float* grad, const float* adj, const float* val,
   }
 }
 
-void LogSoftmaxGrad(Tensor grad, Tensor adj, Tensor val) {
+void LogSoftmaxGrad(Tensor& grad, Tensor& adj, Tensor& val) {
   // grad and val are both m-by-k matrices, passed as input.
   // A weighted average of each row of grad (according to the weights
   // specified in val) is computed and subtracted from Out.
   // adj is multiplied for each element to get backward step in autodiff
-  int m = grad.shape()[0];
-  int k = grad.shape()[1];
+  int m = grad->shape()[0];
+  int k = grad->shape()[1];
   int blocks = std::min(MAX_BLOCKS, m);
   int threads = std::min(MAX_THREADS, k);
   int shared = sizeof(float) * threads * 2;
-  gLogSoftmaxGrad<<<blocks, threads, shared>>>(grad.data(),
-                                               adj.data(), val.data(),
+  gLogSoftmaxGrad<<<blocks, threads, shared>>>(grad->data(),
+                                               adj->data(), val->data(),
                                                m, k);
   cudaStreamSynchronize(0);
 }
@@ -327,80 +337,154 @@ __global__ void gArgmax(float *out, const float *data, size_t rows, size_t cols)
   out[row] = maxInd;
 }
 
-void Argmax(Tensor* Out, const Tensor* In) {
-  size_t m = In->shape()[0];
-  size_t k = In->shape()[1];
-
-  int blocks = m; //std::min(MAX_BLOCKS, (int) m);
-  int threads = k; //std::min(MAX_THREADS, (int) k);
-  //int shared = sizeof(float) * threads * 2;
-  gArgmax<<<blocks, threads>>>(Out->data(), In->data(), m, k);
-  cudaStreamSynchronize(0);
-}
+//void Argmax(Tensor* Out, const Tensor* In) {
+//  size_t m = In->shape()[0];
+//  size_t k = In->shape()[1];
+//
+//  int blocks = m; //std::min(MAX_BLOCKS, (int) m);
+//  int threads = k; //std::min(MAX_THREADS, (int) k);
+//  //int shared = sizeof(float) * threads * 2;
+//  gArgmax<<<blocks, threads>>>(Out->data(), In->data(), m, k);
+//  cudaStreamSynchronize(0);
+//}
 
 ///////////////////////////////////////////////////////
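Marian's own softmax first runs gSubtractMax so that each row's maximum is subtracted before exponentiation, the usual "safe softmax" trick; since softmax(x) = softmax(x − m), this changes nothing mathematically but prevents exp() from overflowing. A scalar sketch of the per-row computation:

    #include <cmath>
    #include <cstdio>

    // Safe softmax for one row: subtract the row max before exponentiating,
    // as gSubtractMax + gSoftMax do above, then normalize.
    void safeSoftmaxRow(float* row, int cols) {
      float m = row[0];
      for(int j = 1; j < cols; ++j) m = std::fmax(m, row[j]);
      float sum = 0.f;
      for(int j = 0; j < cols; ++j) { row[j] = std::exp(row[j] - m); sum += row[j]; }
      for(int j = 0; j < cols; ++j) row[j] /= sum;
    }

    int main() {
      float row[3] = {1000.f, 1001.f, 1002.f};  // naive exp() would overflow
      safeSoftmaxRow(row, 3);
      printf("%f %f %f\n", row[0], row[1], row[2]);
    }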

-Tensor Prod(cublasHandle_t handle, Tensor C, const Tensor A, const Tensor B,
+void Prod(cublasHandle_t handle, Tensor& C, const Tensor& A, const Tensor& B,
           bool transA, bool transB, Float beta) {
   Float alpha = 1.0;
 
-  size_t m = A.shape()[0];
-  size_t k = A.shape()[1];
+  size_t m = A->shape()[0];
+  size_t k = A->shape()[1];
   if(transA)
     std::swap(m, k);
 
-  size_t l = B.shape()[0];
-  size_t n = B.shape()[1];
+  size_t l = B->shape()[0];
+  size_t n = B->shape()[1];
   if(transB)
     std::swap(l, n);
 
-  size_t lda = A.shape()[1];
-  size_t ldb = B.shape()[1];
-  size_t ldc = B.shape()[1];
+  size_t lda = A->shape()[1];
+  size_t ldb = B->shape()[1];
+  size_t ldc = B->shape()[1];
 
   if(transB)
-    ldc = B.shape()[0];
+    ldc = B->shape()[0];
 
   cublasOperation_t opA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
   cublasOperation_t opB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
 
   cublasSgemm(handle, opB, opA,
-              n, m, k, &alpha, B.data(), ldb, A.data(), lda, &beta, C.data(), ldc);
-  return C;
+              n, m, k, &alpha, B->data(), ldb, A->data(), lda, &beta, C->data(), ldc);
 }
 
-Tensor Prod(Tensor C, const Tensor A, const Tensor B,
-            bool transA, bool transB, Float beta) {
+void Prod(Tensor& C, const Tensor& A, const Tensor& B,
+          bool transA, bool transB, Float beta) {
 
-  Tensor temp = Prod(cublasHandle, C, A, B, transA, transB, beta);
-  return temp;
+  Prod(cublasHandle, C, A, B, transA, transB, beta);
 }
 
-Tensor SumRowwise(cublasHandle_t handle, const Tensor A, Tensor result) {
-  size_t rows = A.shape()[0];
-  size_t cols = A.shape()[1];
-  thrust::device_vector<float> d_ones(cols, 1.f);
-  Float alpha = 1.f;
-  Float beta = 0.f;
-  cublasSgemv(handle, CUBLAS_OP_T, cols, rows, &alpha,
-              A.data(), cols,
-              thrust::raw_pointer_cast(d_ones.data()), 1, &beta,
-              result.data(), 1);
-  return result;
+void Sum(Tensor& out, const Tensor& in, int axis, bool mean) {
+  int rows = in->shape()[0];
+  int cols = in->shape()[1];
+
+  if(axis == 0) {
+    float scale = 1.f;
+    if(mean)
+      scale = 1.f / rows;
+
+    thrust::device_vector<float> d_ones(rows, scale);
+    Tensor ones(new TensorGPU(thrust::raw_pointer_cast(d_ones.data()),
+                              {1, rows}));
+    Prod(out, ones, in, false, false);
+  }
+  else if(axis == 1) {
+    float scale = 1.f;
+    if(mean)
+      scale = 1.f / cols;
+
+    thrust::device_vector<float> d_ones(cols, scale);
+    Tensor ones(new TensorGPU(thrust::raw_pointer_cast(d_ones.data()),
+                              {cols, 1}));
+    Prod(out, in, ones, false, false);
+  }
+  else {
+    float scale1 = 1.f;
+    float scale2 = 1.f;
+    if(mean) {
+      scale1 = 1.f / rows;
+      scale2 = 1.f / cols;
+    }
+    thrust::device_vector<float> d_ones1(rows, scale1);
+    Tensor ones1(new TensorGPU(thrust::raw_pointer_cast(d_ones1.data()),
+                               {1, rows}));
+    thrust::device_vector<float> d_ones2(cols, scale2);
+    Tensor ones2(new TensorGPU(thrust::raw_pointer_cast(d_ones2.data()),
+                               {cols, 1}));
+    thrust::device_vector<float> d_temp(cols, 0.f);
+    Tensor temp(new TensorGPU(thrust::raw_pointer_cast(d_temp.data()),
+                              {1, cols}));
+
+    Prod(temp, ones1, in, false, false);
+    Prod(out, temp, ones2, false, false);
+  }
 }
 
-Tensor SumRowwise(const Tensor A, Tensor result) {
-  Tensor temp = SumRowwise(cublasHandle, A, result);
-  return temp;
+void SumBackward(Tensor& out, const Tensor& in, int axis, bool mean) {
+  int rows = out->shape()[0];
+  int cols = out->shape()[1];
+
+  if(axis == 0) {
+    float scale = 1.f;
+    if(mean)
+      scale = 1.f / rows;
+
+    thrust::device_vector<float> d_ones(rows, scale);
+    Tensor ones(new TensorGPU(thrust::raw_pointer_cast(d_ones.data()),
+                              {rows, 1}));
+    Prod(out, ones, in, false, false);
+  }
+  else if(axis == 1) {
+    float scale = 1.f;
+    if(mean)
+      scale = 1.f / cols;
+
+    thrust::device_vector<float> d_ones(cols, scale);
+    Tensor ones(new TensorGPU(thrust::raw_pointer_cast(d_ones.data()),
+                              {1, cols}));
+    Prod(out, in, ones, false, false);
+  }
+  else {
+    float scale1 = 1.f;
+    float scale2 = 1.f;
+    if(mean) {
+      scale1 = 1.f / rows;
+      scale2 = 1.f / cols;
+    }
+    thrust::device_vector<float> d_ones1(rows, scale1);
+    Tensor ones1(new TensorGPU(thrust::raw_pointer_cast(d_ones1.data()),
+                               {rows, 1}));
+    thrust::device_vector<float> d_ones2(cols, scale2);
+    Tensor ones2(new TensorGPU(thrust::raw_pointer_cast(d_ones2.data()),
+                               {1, cols}));
+    thrust::device_vector<float> d_temp(rows, 0.f);
+    Tensor temp(new TensorGPU(thrust::raw_pointer_cast(d_temp.data()),
+                              {rows, 1}));
+
+    Prod(temp, ones1, in, false, false);
+    Prod(out, temp, ones2, false, false);
+  }
 }
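The new Sum/SumBackward express reductions as products with (scaled) vectors of ones instead of the old cublasSgemv path: axis 0 multiplies ones(1 x rows) from the left for column sums, axis 1 multiplies ones(cols x 1) from the right for row sums, and any other axis chains both for a grand total; filling the ones with 1/n turns sums into means. A host-only sketch of the underlying identity (hypothetical values, no marian types):

    // A row vector of ones times M yields column sums; scaling the ones
    // by 1/rows would yield column means instead.
    #include <cstdio>

    int main() {
      const int rows = 2, cols = 3;
      float M[rows][cols] = {{1, 2, 3}, {4, 5, 6}};
      float colSum[cols] = {0};           // axis == 0: ones(1 x rows) * M
      for (int j = 0; j < cols; ++j)
        for (int i = 0; i < rows; ++i)
          colSum[j] += 1.f * M[i][j];     // the "1.f" plays the ones vector
      printf("%g %g %g\n", colSum[0], colSum[1], colSum[2]);  // 5 7 9
      return 0;
    }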

-void CudnnDropoutPrepare(Tensor in, float p,
+void CudnnDropoutPrepare(Tensor& in, float p,
                          cudnnDropoutDescriptor_t* dropDesc,
                          void** space, size_t* spaceSize,
                          void** states, size_t seed) {
   size_t statesSize;
   cudnnDropoutGetStatesSize(cudnnHandle, &statesSize);
-  cudnnDropoutGetReserveSpaceSize(in.cudnn(), spaceSize);
+  auto inGpu = static_cast<TensorGPU*>(in.get());
+  cudnnDropoutGetReserveSpaceSize(inGpu->cudnn(), spaceSize);
 
   cudaMalloc((void**)states, statesSize);
   cudaMalloc((void**)space, *spaceSize);
@@ -423,26 +507,30 @@ void CudnnDropoutDestroy(cudnnDropoutDescriptor_t dropDesc,
 
 void CudnnDropoutForward(cudnnDropoutDescriptor_t dropoutDesc,
                          void* space, size_t spaceSize,
-                         Tensor out, Tensor in) {
+                         Tensor& out, Tensor& in) {
+  auto inGpu = static_cast<TensorGPU*>(in.get());
+  auto outGpu = static_cast<TensorGPU*>(out.get());
   cudnnDropoutForward(cudnnHandle,
                       dropoutDesc,
-                      in.cudnn(),
-                      in.data(),
-                      out.cudnn(),
-                      out.data(),
+                      inGpu->cudnn(),
+                      inGpu->data(),
+                      outGpu->cudnn(),
+                      outGpu->data(),
                       space, spaceSize);
 }
 
 void CudnnDropoutBackward(cudnnDropoutDescriptor_t dropoutDesc,
                           void* space, size_t spaceSize,
-                          Tensor out, Tensor in) {
+                          Tensor& out, Tensor& in) {
+  auto inGpu = static_cast<TensorGPU*>(in.get());
+  auto outGpu = static_cast<TensorGPU*>(out.get());
   cudnnDropoutBackward(cudnnHandle,
                        dropoutDesc,
-                       in.cudnn(),
-                       in.data(),
-                       out.cudnn(),
-                       out.data(),
+                       inGpu->cudnn(),
+                       inGpu->data(),
+                       outGpu->cudnn(),
+                       outGpu->data(),
                        space, spaceSize);
 }
diff --git a/src/tensor_operators.h b/src/tensor_operators.h
index b7870b48..6a7f5745 100644
--- a/src/tensor_operators.h
+++ b/src/tensor_operators.h
@@ -21,7 +21,10 @@
 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 // SOFTWARE.
 
-#include "tensor.h"
+#include <cublas_v2.h>
+#include <cudnn.h>
+
+#include "tensors/tensor_gpu.h"
 
 namespace marian {
 
@@ -29,11 +32,79 @@ using namespace thrust::placeholders;
 #define MAX_THREADS 512
 #define MAX_BLOCKS 65535
 
+template <class Functor>
+__global__ void gElementVec(Functor functor,
+                            float* out, const float* in,
+                            int length) {
+  for(int bid = 0; bid < length; bid += blockDim.x * gridDim.x) {
+    int noColumn = bid + blockDim.x * blockIdx.x + threadIdx.x;
+    if (noColumn < length) {
+      out[noColumn] = functor(out[noColumn], in[noColumn]);
+    }
+  }
+}
+
+template <class Functor, class T1, class T2>
+void ElementVec(Functor functor,
+                T1 out, T2 in) {
+
+  int rows = out->shape()[0];
+  int cols = out->shape()[1];
+
+  int length = rows * cols;
+
+  float* d_out = out->data();
+  float* d_in = in->data();
+
+  int threads = std::min(MAX_THREADS, length);
+  int blocks = std::min(MAX_BLOCKS, length / threads + (length % threads != 0));
+
+  gElementVec<<<blocks, threads>>>(functor, d_out, d_in, length);
+  cudaStreamSynchronize(0);
+}
+
+template <class Functor>
+__global__ void gElementVec(Functor functor,
+                            float* out,
+                            const float* in1,
+                            const float* in2,
+                            int length) {
+  for(int bid = 0; bid < length; bid += blockDim.x * gridDim.x) {
+    int noColumn = bid + blockDim.x * blockIdx.x + threadIdx.x;
+    if (noColumn < length) {
+      out[noColumn] = functor(out[noColumn],
+                              in1[noColumn],
+                              in2[noColumn]);
+    }
+  }
+}
+
+template <class Functor, class T1, class T2, class T3>
+void ElementVec(Functor functor,
+                T1 out, T2 in1, T3 in2) {
+
+  int rows = out->shape()[0];
+  int cols = out->shape()[1];
+
+  int length = rows * cols;
+
+  float* d_out = out->data();
+  float* d_in1 = in1->data();
+  float* d_in2 = in2->data();
+
+  int threads = std::min(MAX_THREADS, (int)length);
+  int blocks = std::min(MAX_BLOCKS, length / threads + (length % threads != 0));
+
+  gElementVec<<<blocks, threads>>>(functor, d_out, d_in1, d_in2, length);
+  cudaStreamSynchronize(0);
+}
+
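The ElementVec pair above is the fast path taken when all shapes match exactly: the tensors are treated as flat arrays and every thread strides over the full length, avoiding the row/column indexing of the gElement kernels that follow. A self-contained sketch of the same grid-stride pattern in its more common formulation (a hand-written functor stands in for the thrust placeholder expressions marian passes in):

    #include <cstdio>
    #include <cuda_runtime.h>

    struct Plus {
      __device__ float operator()(float a, float b) const { return a + b; }
    };

    // One thread handles indices id, id + stride, id + 2*stride, ...
    template <class Functor>
    __global__ void elementVec(Functor f, float* out, const float* in, int length) {
      for (int id = blockIdx.x * blockDim.x + threadIdx.x;
           id < length;
           id += blockDim.x * gridDim.x)
        out[id] = f(out[id], in[id]);
    }

    int main() {
      const int n = 8;
      float h[n] = {0, 1, 2, 3, 4, 5, 6, 7};
      float *dOut, *dIn;
      cudaMalloc(&dOut, n * sizeof(float));
      cudaMalloc(&dIn, n * sizeof(float));
      cudaMemcpy(dOut, h, n * sizeof(float), cudaMemcpyHostToDevice);
      cudaMemcpy(dIn, h, n * sizeof(float), cudaMemcpyHostToDevice);
      elementVec<<<1, 4>>>(Plus(), dOut, dIn, n);   // 4 threads, grid-stride
      cudaMemcpy(h, dOut, n * sizeof(float), cudaMemcpyDeviceToHost);
      for (float v : h) printf("%g ", v);           // 0 2 4 6 8 10 12 14
      printf("\n");
      cudaFree(dOut); cudaFree(dIn);
      return 0;
    }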
 template <class Functor, class T>
 __global__ void gElement(Functor functor,
                          T out) {
-  int rows = out.rows();
-  int cols = out.cols();
+  int rows = out.shape()[0];
+  int cols = out.shape()[1];
+
   for(int bid = 0; bid < rows; bid += gridDim.x) {
     int i = bid + blockIdx.x;
     if(i < rows) {
@@ -47,14 +118,17 @@ __global__ void gElement(Functor functor,
 }
 
 template <class Functor, class T>
-void Element(Functor functor, T out) {
+void Element(Functor functor, T& out) {
 
-  int m = out.shape()[0];
-  int n = out.shape()[1];
+  int m = out->shape()[0];
+  int n = out->shape()[1];
 
   int blocks  = std::min(MAX_BLOCKS, m);
   int threads = std::min(MAX_THREADS, n);
-  gElement<<<blocks, threads>>>(functor, out.gpu());
+
+  auto outGpu = static_cast<TensorGPU*>(out.get());
+
+  gElement<<<blocks, threads>>>(functor, outGpu->access());
   cudaStreamSynchronize(0);
 }
 
@@ -62,15 +136,17 @@ void Element(Functor functor, T out) {
 template <class Functor, class T1, class T2>
 __global__ void gElement(Functor functor,
                          T1 out, T2 in) {
-  int rows = out.rows();
-  int cols = out.cols();
+  int rows = out.shape()[0];
+  int cols = out.shape()[1];
+
   for(int bid = 0; bid < rows; bid += gridDim.x) {
     int i = bid + blockIdx.x;
     if(i < rows) {
       for(int tid = 0; tid < cols; tid += blockDim.x) {
         int j = tid + threadIdx.x;
-        if(j < cols)
+        if(j < cols) {
           out(i, j) = functor(out(i, j), in(i, j));
+        }
       }
     }
   }
@@ -78,22 +154,32 @@ __global__ void gElement(Functor functor,
 
 template <class Functor, class T1, class T2>
 void Element(Functor functor,
-             T1 out, T2 in) {
+             T1& out, T2& in) {
 
-  int m = out.shape()[0];
-  int n = out.shape()[1];
+  if(out->shape() == in->shape()) {
+    ElementVec(functor, out, in);
+  }
+  else {
+    int m = out->shape()[0];
+    int n = out->shape()[1];
 
-  int blocks  = std::min(MAX_BLOCKS, m);
-  int threads = std::min(MAX_THREADS, n);
-  gElement<<<blocks, threads>>>(functor, out.gpu(), in.gpu());
-  cudaStreamSynchronize(0);
+    int blocks  = std::min(MAX_BLOCKS, m);
+    int threads = std::min(MAX_THREADS, n);
+
+    auto inGpu = static_cast<TensorGPU*>(in.get());
+    auto outGpu = static_cast<TensorGPU*>(out.get());
+
+    gElement<<<blocks, threads>>>(functor,
+                                  outGpu->access(), inGpu->access());
+    cudaStreamSynchronize(0);
+  }
 }
 
 template <class Functor, class T1, class T2, class T3>
 __global__ void gElement(Functor functor,
                          T1 out, T2 in1, T3 in2) {
-  int rows = out.rows();
-  int cols = out.cols();
+  int rows = out.shape()[0];
+  int cols = out.shape()[1];
   for(int bid = 0; bid < rows; bid += gridDim.x) {
     int i = bid + blockIdx.x;
     if(i < rows) {
@@ -108,23 +194,35 @@ __global__ void gElement(Functor functor,
 
 template <class Functor, class T1, class T2, class T3>
 void Element(Functor functor,
-             T1 out, T2 in1, T3 in2) {
+             T1& out, T2& in1, T3& in2) {
 
-  int m = out.shape()[0];
-  int n = out.shape()[1];
+  if(out->shape() == in1->shape() && in1->shape() == in2->shape()) {
+    ElementVec(functor, out, in1, in2);
+  }
+  else {
+    auto in1Gpu = static_cast<TensorGPU*>(in1.get());
+    auto in2Gpu = static_cast<TensorGPU*>(in2.get());
+    auto outGpu = static_cast<TensorGPU*>(out.get());
 
-  int blocks  = std::min(MAX_BLOCKS, m);
-  int threads = std::min(MAX_THREADS, n);
-  gElement<<<blocks, threads>>>(functor, out.gpu(),
-                                in1.gpu(), in2.gpu());
-  cudaStreamSynchronize(0);
+    int m = out->shape()[0];
+    int n = out->shape()[1];
+
+    int blocks  = std::min(MAX_BLOCKS, m);
+    int threads = std::min(MAX_THREADS, n);
+    gElement<<<blocks, threads>>>(functor,
+                                  outGpu->access(),
+                                  in1Gpu->access(),
+                                  in2Gpu->access());
+    cudaStreamSynchronize(0);
+  }
 }
 
 template <class Functor, class T1, class T2, class T3, class T4>
 __global__ void gElement(Functor functor,
                          T1 out, T2 in1, T3 in2, T4 in3) {
-  int rows = out.rows();
-  int cols = out.cols();
+  int rows = out.shape()[0];
+  int cols = out.shape()[1];
+
   for(int bid = 0; bid < rows; bid += gridDim.x) {
     int i = bid + blockIdx.x;
     if(i < rows) {
@@ -139,48 +237,53 @@ __global__ void gElement(Functor functor,
 
 template <class Functor, class T1, class T2, class T3, class T4>
 void Element(Functor functor,
-             T1 out, T2 in1, T3 in2, T4 in3) {
+             T1& out, T2& in1, T3& in2, T4& in3) {
 
-  int m = out.shape()[0];
-  int n = out.shape()[1];
+  auto in1Gpu = static_cast<TensorGPU*>(in1.get());
+  auto in2Gpu = static_cast<TensorGPU*>(in2.get());
+  auto in3Gpu = static_cast<TensorGPU*>(in3.get());
+  auto outGpu = static_cast<TensorGPU*>(out.get());
+
+  int m = outGpu->shape()[0];
+  int n = outGpu->shape()[1];
 
   int blocks  = std::min(MAX_BLOCKS, m);
   int threads = std::min(MAX_THREADS, n);
-  gElement<<<blocks, threads>>>(functor, out.gpu(),
-                                in1.gpu(), in2.gpu(), in3.gpu());
+  gElement<<<blocks, threads>>>(functor,
+                                outGpu->access(),
+                                in1Gpu->access(),
+                                in2Gpu->access(),
+                                in3Gpu->access());
   cudaStreamSynchronize(0);
 }
 
-void ClipNorm(Tensor out, float threshold);
+void ClipNorm(Tensor& out, float threshold);
 
-void SubtractMax(Tensor out, Tensor in);
+void SubtractMax(Tensor& out, Tensor& in);
 
-void Softmax(Tensor out, Tensor in);
+void Softmax(Tensor& out, Tensor& in);
 
-void SoftmaxGrad(Tensor grad, Tensor adj, Tensor val);
-void LogSoftmaxGrad(Tensor grad, Tensor adj, Tensor val);
+void SoftmaxGrad(Tensor& grad, Tensor& adj, Tensor& val);
+void LogSoftmaxGrad(Tensor& grad, Tensor& adj, Tensor& val);
 
-void CudnnSoftmax(Tensor out, Tensor in);
-void CudnnSoftmaxGrad(Tensor grad, Tensor adj, Tensor val);
+void CudnnSoftmax(Tensor& out, Tensor& in);
+void CudnnSoftmaxGrad(Tensor& grad, Tensor& adj, Tensor& val);
 
-void CudnnLogSoftmax(Tensor out, Tensor in);
-void CudnnLogSoftmaxGrad(Tensor grad, Tensor adj, Tensor val);
+void CudnnLogSoftmax(Tensor& out, Tensor& in);
+void CudnnLogSoftmaxGrad(Tensor& grad, Tensor& adj, Tensor& val);
 
-void Argmax(Tensor* Out, const Tensor* In);
+void Argmax(Tensor& Out, const Tensor& In);
 
-Tensor Prod(cublasHandle_t handle, Tensor C, const Tensor A, const Tensor B,
+void Prod(cublasHandle_t handle, Tensor& C, const Tensor& A, const Tensor& B,
           bool transA, bool transB, Float beta);
 
-Tensor Prod(Tensor C, const Tensor A, const Tensor B,
+void Prod(Tensor& C, const Tensor& A, const Tensor& B,
           bool transA, bool transB, Float beta = 0);
 
-Tensor SumRowwise(cublasHandle_t handle, const Tensor A, Tensor result);
+void Sum(Tensor& out, const Tensor& in, int axis=-1, bool mean=false);
+void SumBackward(Tensor& out, const Tensor& in, int axis=-1, bool mean=false);
 
-Tensor SumRowwise(const Tensor A, Tensor result);
-
-void ScaleRowwise(Tensor Out, const Tensor ScalingFactors);
-
-void CudnnDropoutPrepare(Tensor in, float p,
+void CudnnDropoutPrepare(Tensor& in, float p,
                          cudnnDropoutDescriptor_t* dropDesc,
                          void** space, size_t* spaceSize,
                          void** states, size_t seed);
@@ -190,11 +293,11 @@ void CudnnDropoutDestroy(cudnnDropoutDescriptor_t dropDesc,
 
 void CudnnDropoutForward(cudnnDropoutDescriptor_t dropoutDesc,
                          void* space, size_t spaceSize,
-                         Tensor out, Tensor in);
+                         Tensor& out, Tensor& in);
 
 void CudnnDropoutBackward(cudnnDropoutDescriptor_t dropoutDesc,
                           void* space, size_t spaceSize,
-                          Tensor out, Tensor in);
+                          Tensor& out, Tensor& in);
 
 }
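With Tensor now a shared_ptr handle (see tensors/tensor.h below), call sites allocate through a TensorAllocator and pass tensors by reference; results land in a preallocated output instead of being returned. A hypothetical call site under those assumptions (shapes and values invented for illustration):

    #include <iostream>
    #include "tensors/tensor_allocator.h"
    #include "tensors/tensor_gpu.h"
    #include "tensor_operators.h"

    using namespace marian;

    int main() {
      TensorAllocator alloc = newTensorAllocator<DeviceGPU>();
      Tensor a, b, c;
      alloc->allocate(a, {2, 3});
      alloc->allocate(b, {3, 2});
      alloc->allocate(c, {2, 2});
      a->set(1.f);
      b->set(2.f);
      Prod(c, a, b, false, false);   // c = a * b; every entry equals 3 * 2 = 6
      std::cerr << c->debug();
      return 0;
    }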
diff --git a/src/tensor_test.cu b/src/tensor_test.cu
new file mode 100644
index 00000000..f894742a
--- /dev/null
+++ b/src/tensor_test.cu
@@ -0,0 +1,28 @@
+#include <iostream>
+
+#include "tensors/tensor_allocator.h"
+#include "tensors/tensor_gpu.h"
+#include "tensor_operators.h"
+
+using namespace marian;
+
+int main() {
+  TensorAllocator params = newTensorAllocator<DeviceGPU>();
+  //params->allocate(100000000);
+
+  std::vector<Tensor> tensors;
+  for (int i = 0; i < 200; ++i) {
+    std::cerr << i << std::endl;
+    tensors.emplace_back();
+    params->allocate(tensors.back(), {784,2048});
+    std::cerr << tensors.back()->size() << std::endl;
+    std::cerr << params->capacity() << " " << params->size() << std::endl;
+  }
+
+  for(int i = 0; i < 200; i++) {
+    tensors[i]->set(0, 3.14 * i);
+    std::cerr << tensors[i]->get(0) << std::endl;
+  }
+
+  return 0;
+}
diff --git a/src/tensor.cu b/src/tensors/bac/tensor.cu
similarity index 100%
rename from src/tensor.cu
rename to src/tensors/bac/tensor.cu
diff --git a/src/tensor.h b/src/tensors/bac/tensor.h
similarity index 99%
rename from src/tensor.h
rename to src/tensors/bac/tensor.h
index 5fc9d09d..f4b7ddf7 100644
--- a/src/tensor.h
+++ b/src/tensors/bac/tensor.h
@@ -236,8 +236,7 @@ class TensorImpl {
    *
    * @return Vector in string form.
    */
-  std::string Debug() const
-  {
+  std::string Debug() const {
     std::stringstream strm;
     assert(shape_.size());
     strm << "shape=" << marian::Debug(shape_) << std::endl;
diff --git a/src/tensors/tensor.cu b/src/tensors/tensor.cu
new file mode 100644
index 00000000..5d5feb7b
--- /dev/null
+++ b/src/tensors/tensor.cu
@@ -0,0 +1,36 @@
+// This file is part of the Marian toolkit.
+// Marian is copyright (c) 2016 Marcin Junczys-Dowmunt.
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+// SOFTWARE.
+
+#include "tensors/tensor.h"
+
+namespace marian {
+
+Tensor& operator<<(Tensor& t, const std::vector<float>& v) {
+  t->set(v);
+  return t;
+}
+
+Tensor& operator>>(Tensor& t, std::vector<float>& v) {
+  t->get(v);
+  return t;
+}
+
+}
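The two stream-style helpers just added give bulk host-to-device and device-to-host copies an idiomatic spelling on top of the virtual set/get. A minimal usage sketch (hypothetical, assuming a GPU allocator as elsewhere in this commit):

    #include <vector>
    #include "tensors/tensor_allocator.h"
    #include "tensors/tensor_gpu.h"

    using namespace marian;

    int main() {
      TensorAllocator alloc = newTensorAllocator<DeviceGPU>();
      Tensor t;
      alloc->allocate(t, {2, 2});
      std::vector<float> values = {1.f, 2.f, 3.f, 4.f};
      t << values;                 // host -> device
      std::vector<float> check;
      t >> check;                  // device -> host; check now equals values
      return 0;
    }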
diff --git a/src/tensors/tensor.h b/src/tensors/tensor.h
new file mode 100644
index 00000000..3a91db61
--- /dev/null
+++ b/src/tensors/tensor.h
@@ -0,0 +1,76 @@
+#pragma once
+
+// This file is part of the Marian toolkit.
+// Marian is copyright (c) 2016 Marcin Junczys-Dowmunt.
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+// SOFTWARE.
+
+#include <memory>
+#include <vector>
+
+#include "definitions.h"
+
+namespace marian {
+
+class TensorBase {
+  public:
+    TensorBase(float* data, Shape shape)
+    : data_(data), shape_(shape)
+    { }
+
+    virtual ~TensorBase() {}
+
+    virtual void reset(float* data) {
+      data_ = data;
+    }
+
+    virtual float* data() {
+      return data_;
+    }
+
+    virtual Shape& shape() {
+      return shape_;
+    }
+
+    virtual size_t size() {
+      return shape_.elements();
+    }
+
+    virtual float get(size_t) = 0;
+    virtual void set(size_t, float) = 0;
+
+    virtual void set(float) = 0;
+
+    virtual void get(std::vector<float> &v) = 0;
+    virtual void set(const std::vector<float> &v) = 0;
+
+    virtual std::string debug() = 0;
+
+  protected:
+    float* data_;
+    Shape shape_;
+};
+
+typedef std::shared_ptr<TensorBase> Tensor;
+
+Tensor& operator<<(Tensor& t, const std::vector<float>& v);
+
+Tensor& operator>>(Tensor& t, std::vector<float>& v);
+
+}
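Everything in TensorBase above is virtual, so code written against the Tensor handle stays device-agnostic; the CPU and GPU subclasses that follow only differ in how get/set move the bytes. A hypothetical helper illustrating the point (not part of the commit):

    #include <vector>
    #include "tensors/tensor.h"

    namespace marian {

    // Works for TensorCPU and TensorGPU alike: only virtual calls are used.
    inline float sumAll(Tensor t) {
      std::vector<float> v;
      t->get(v);                 // the device-specific copy happens in here
      float s = 0.f;
      for (float x : v) s += x;
      return s;
    }

    }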
diff --git a/src/tensors/tensor_allocator.h b/src/tensors/tensor_allocator.h
new file mode 100644
index 00000000..839cf037
--- /dev/null
+++ b/src/tensors/tensor_allocator.h
@@ -0,0 +1,137 @@
+#pragma once
+
+// This file is part of the Marian toolkit.
+// Marian is copyright (c) 2016 Marcin Junczys-Dowmunt.
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+// SOFTWARE.
+
+#include <iostream>
+
+#include "definitions.h"
+#include "tensors/tensor.h"
+
+namespace marian {
+
+class TensorAllocatorBase {
+  public:
+    virtual ~TensorAllocatorBase() {};
+    virtual void reserve(size_t) = 0;
+    virtual void reserveExact(size_t) = 0;
+    virtual void clear() = 0;
+    virtual void allocate(Tensor&, Shape) = 0;
+    virtual size_t capacity() = 0;
+    virtual size_t size() = 0;
+    virtual Tensor asTensor() = 0;
+};
+
+template <class Device>
+class TensorAllocatorDerived : public TensorAllocatorBase {
+  private:
+    const size_t CHUNK  = 128;
+    const size_t MBYTE  = 1024 * 1024;
+    const size_t FLOATS = CHUNK * MBYTE / sizeof(float);
+
+    Device device_;
+    std::vector<Tensor> allocated_;
+
+    void reset(Tensor t, float* start) {
+      t->reset(start);
+    }
+
+    void resetAllocated() {
+      float* start = device_.data();
+      for(auto t : allocated_) {
+        reset(t, start);
+        start += t->size();
+      }
+    }
+
+    void checkSpace(Shape shape) {
+      float* start = device_.data();
+      if(!allocated_.empty()) {
+        start = allocated_.back()->data() + allocated_.back()->size();
+      }
+
+      size_t available = device_.data() + device_.capacity() - start;
+      if(shape.elements() > available) {
+        reserve(device_.capacity() - available + shape.elements());
+      }
+    }
+
+  public:
+    void reserve(size_t elements = 0) {
+      float mult = elements / FLOATS + 1;
+      std::cerr << "Reserving " << mult * CHUNK << " MB" << std::endl;
+      device_.reserve(mult * FLOATS);
+      resetAllocated();
+    }
+
+    void reserveExact(size_t elements = 0) {
+      size_t mbytes = (elements * sizeof(float)) / MBYTE;
+      std::cerr << "Reserving space for " << elements
+                << " floats (" << mbytes << " MB)" << std::endl;
+      device_.reserve(elements);
+      resetAllocated();
+    }
+
+    void clear() {
+      allocated_.clear();
+    }
+
+    void allocate(Tensor &t, Shape shape) {
+      if(!t || t->shape() != shape) {
+        checkSpace(shape);
+
+        float* start = device_.data();
+        if(!allocated_.empty()) {
+          start = allocated_.back()->data() + allocated_.back()->size();
+        }
+
+        t.reset(new typename Device::tensor_type(start, shape));
+        allocated_.push_back(t);
+      }
+    }
+
+    Tensor asTensor() {
+      float* start = device_.data();
+      return Tensor(new typename Device::tensor_type(start, {1, (int)size()}));
+    }
+
+    size_t capacity() {
+      return device_.capacity();
+    }
+
+    size_t size() {
+      float* start = device_.data();
+      float* end = start;
+      if(!allocated_.empty())
+        end = allocated_.back()->data() + allocated_.back()->size();
+
+      return end - start;
+    }
+};
+
+typedef std::shared_ptr<TensorAllocatorBase> TensorAllocator;
+
+template <class Device>
+TensorAllocator newTensorAllocator() {
+  return TensorAllocator(new TensorAllocatorDerived<Device>());
+}
+
+}
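TensorAllocatorDerived reserves device memory in 128 MB chunks: reserve() rounds a request in floats up to whole chunks and then rebases every live tensor onto the new block through resetAllocated(). A back-of-the-envelope check of the chunk arithmetic (plain host code; note that an integer chunk count, as used here, sidesteps the float rounding in the member function above):

    #include <cstdio>

    int main() {
      const size_t CHUNK = 128, MBYTE = 1024 * 1024;
      const size_t FLOATS = CHUNK * MBYTE / sizeof(float);  // 33,554,432 floats per chunk
      size_t requested = 784 * 2048;                        // one MNIST-sized layer
      size_t mult = requested / FLOATS + 1;                 // whole chunks needed
      printf("%zu chunk(s), %zu MB reserved\n", mult, mult * CHUNK);  // 1 chunk, 128 MB
      return 0;
    }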
diff --git a/src/tensors/tensor_cpu.h b/src/tensors/tensor_cpu.h
new file mode 100644
index 00000000..ffabfc1c
--- /dev/null
+++ b/src/tensors/tensor_cpu.h
@@ -0,0 +1,95 @@
+#pragma once
+
+// This file is part of the Marian toolkit.
+// Marian is copyright (c) 2016 Marcin Junczys-Dowmunt.
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+// SOFTWARE.
+
+#include <cstring>
+
+#include "tensors/tensor.h"
+
+namespace marian {
+
+class TensorCPU : public TensorBase {
+  public:
+    TensorCPU(float* data, Shape shape)
+    : TensorBase(data, shape) {}
+
+    float get(size_t i) {
+      return data_[i];
+    }
+
+    void set(size_t i, float value) {
+      data_[i] = value;
+    }
+
+    void get(std::vector<float> &v) {
+      v.resize(size());
+      std::copy(data_, data_ + size(), v.begin());
+    }
+
+    void set(float value) {
+      std::fill(data_, data_ + size(), value);
+    }
+
+    void set(const std::vector<float> &v) {
+      std::copy(v.begin(), v.end(), data_);
+    }
+};
+
+class DeviceCPU {
+  private:
+    float* data_;
+    size_t size_;
+
+  public:
+    DeviceCPU()
+    : data_(0), size_(0) {}
+
+    ~DeviceCPU() {
+      if(data_)
+        delete[] data_;
+    }
+
+    typedef TensorCPU tensor_type;
+
+    void reserve(size_t size) {
+      UTIL_THROW_IF2(size < size_, "New size must be larger than old size");
+      float* temp = new float[size];
+
+      if(data_) {
+        std::memcpy(temp, data_, size_ * sizeof(float));
+        delete[] data_;
+      }
+
+      data_ = temp;
+      size_ = size;
+    }
+
+    float* data() {
+      return data_;
+    }
+
+    size_t capacity() {
+      return size_;
+    }
+};
+
+}
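DeviceCPU::reserve grows by allocate-copy-free, which is what lets the allocator rebase live tensors without losing their contents; DeviceGPU below does the same with cudaMalloc/cudaMemcpy. The pattern in isolation (plain C++, hypothetical sizes):

    // Minimal grow-and-copy sketch mirroring DeviceCPU::reserve:
    // old contents survive the reallocation.
    #include <cstring>
    #include <cstdio>

    int main() {
      size_t size = 4;
      float* data = new float[size];
      for (size_t i = 0; i < size; ++i) data[i] = float(i);

      size_t newSize = 8;                       // grow, as reserve() requires
      float* temp = new float[newSize];
      std::memcpy(temp, data, size * sizeof(float));
      delete[] data;
      data = temp;
      size = newSize;

      printf("%g %g %g %g\n", data[0], data[1], data[2], data[3]);  // 0 1 2 3
      delete[] data;
      return 0;
    }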
diff --git a/src/tensors/tensor_gpu.h b/src/tensors/tensor_gpu.h
new file mode 100644
index 00000000..d81c6233
--- /dev/null
+++ b/src/tensors/tensor_gpu.h
@@ -0,0 +1,202 @@
+#pragma once
+
+// This file is part of the Marian toolkit.
+// Marian is copyright (c) 2016 Marcin Junczys-Dowmunt.
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+// SOFTWARE.
+
+#include <cuda.h>
+#include <cudnn.h>
+#include <thrust/device_ptr.h>
+#include <thrust/fill.h>
+#include <sstream>
+
+#include "exception.h"
+#include "definitions.h"
+#include "tensors/tensor.h"
+
+namespace marian {
+
+#define CUDA_CHECK(ans) { gpuAssert((ans), __FILE__, __LINE__); }
+
+inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true)
+{
+  if (code != cudaSuccess)
+  {
+    fprintf(stderr,"GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
+    if (abort) exit(code);
+  }
+}
+
+struct Access {
+  float* data_;
+  Shape shape_;
+
+  Access(float* data, const Shape& shape)
+  : data_(data), shape_(shape) {}
+
+  __device__
+  inline float& operator()(size_t i, size_t j) {
+    int rows = shape_[0];
+    int cols = shape_[1];
+    if(rows != 1 && cols != 1)
+      return data_[i * cols + j];
+    if(rows != 1 && cols == 1)
+      return data_[i];
+    if(rows == 1 && cols != 1)
+      return data_[j];
+    return data_[0];
+  }
+
+  __device__ __host__
+  float* data() {
+    return data_;
+  }
+
+  __device__ __host__
+  Shape& shape() {
+    return shape_;
+  }
+
+  //Access* toDevice() {
+  //  Access* ptr;
+  //  cudaMalloc(&ptr, sizeof(Access));
+  //  cudaMemcpy(ptr, this, sizeof(Access), cudaMemcpyHostToDevice);
+  //  return ptr;
+  //}
+};
+
+class TensorGPU : public TensorBase {
+  private:
+    // cuDNN stuff
+    cudnnTensorDescriptor_t cudnnDesc_;
+
+  public:
+    TensorGPU(float* data, Shape shape)
+    : TensorBase(data, shape) {
+      cudnnCreateTensorDescriptor(&cudnnDesc_);
+      cudnnSetTensor4dDescriptorEx(cudnnDesc_, CUDNN_DATA_FLOAT,
+                                   shape_[0], shape_[1], 1, 1,
+                                   shape_[1], 1, 1, 1);
+    }
+
+    ~TensorGPU() {
+      cudnnDestroyTensorDescriptor(cudnnDesc_);
+    }
+
+    float get(size_t i) {
+      float temp;
+      CUDA_CHECK(cudaMemcpy(&temp, data_ + i, sizeof(float),
+                 cudaMemcpyDeviceToHost));
+      return temp;
+    }
+
+    void set(size_t i, float value) {
+      CUDA_CHECK(cudaMemcpy(data_ + i, &value, sizeof(float),
+                 cudaMemcpyHostToDevice));
+    }
+
+    void get(std::vector<float> &v) {
+      v.resize(size());
+      CUDA_CHECK(cudaMemcpy(v.data(), data_, size() * sizeof(float),
+                 cudaMemcpyDeviceToHost));
+    }
+
+    void set(float value) {
+      thrust::fill(thrust::device_ptr<float>(data_),
+                   thrust::device_ptr<float>(data_ + size()), value);
+    }
+
+    void set(const std::vector<float> &v) {
+      CUDA_CHECK(cudaMemcpy(data_, v.data(), v.size() * sizeof(float),
+                 cudaMemcpyHostToDevice));
+    }
+
+    cudnnTensorDescriptor_t& cudnn() {
+      return cudnnDesc_;
+    }
+
+    Access access() {
+      return Access(data_, shape_);
+    }
+
+    std::string debug() {
+      std::stringstream strm;
+      assert(shape_.size());
+      strm << "shape=" << shape_[0] << "x" << shape_[1] << std::endl;
+
+      // values
+      size_t totSize = shape_.elements();
+      std::vector<float> values(totSize);
+      get(values);
+
+      size_t ind = 0;
+      for (size_t i = 0; i < shape()[0]; ++i) {
+        for (size_t j = 0; j < shape()[1]; ++j) {
+          strm << values[ind] << " ";
+          ++ind;
+        }
+        strm << std::endl;
+      }
+      return strm.str();
+    }
+};
+
+class DeviceGPU {
+  private:
+    float* data_;
+    size_t size_;
+
+  public:
+    DeviceGPU()
+    : data_(0), size_(0) {}
+
+    ~DeviceGPU() {
+      if(data_)
+        CUDA_CHECK(cudaFree(data_));
+    }
+
+    typedef TensorGPU tensor_type;
+
+    void reserve(size_t size) {
+      UTIL_THROW_IF2(size < size_, "New size must be larger than old size");
+      float *temp;
+      CUDA_CHECK(cudaMalloc(&temp, size * sizeof(float)));
+
+      if(data_) {
+        CUDA_CHECK(cudaMemcpy(temp, data_, size_ * sizeof(float),
+                   cudaMemcpyDeviceToDevice));
+        CUDA_CHECK(cudaFree(data_));
+      }
+
+      data_ = temp;
+      size_ = size;
+    }
+
+    float* data() {
+      return data_;
+    }
+
+    size_t capacity() {
+      return size_;
+    }
+};
+
+}
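Access::operator() above encodes the broadcasting rule the gElement kernels rely on: a {1, cols} or {rows, 1} tensor is indexed as if it were repeated along the missing dimension, which is how a bias row gets added to every row of a matrix. A host-side sketch of that rule (hypothetical values):

    #include <cstdio>

    // Same dispatch as Access::operator(), minus the __device__ qualifier.
    float& at(float* data, int rows, int cols, int i, int j) {
      if (rows != 1 && cols != 1) return data[i * cols + j];
      if (rows != 1)              return data[i];   // cols == 1: column vector
      if (cols != 1)              return data[j];   // rows == 1: row vector
      return data[0];                               // 1x1: scalar
    }

    int main() {
      float bias[3] = {10, 20, 30};                 // shape {1, 3}
      for (int i = 0; i < 2; ++i)                   // "broadcast" over 2 rows
        for (int j = 0; j < 3; ++j)
          printf("%g%c", at(bias, 1, 3, i, j), j == 2 ? '\n' : ' ');
      // prints: 10 20 30 / 10 20 30
      return 0;
    }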
diff --git a/src/trainer.h b/src/trainer.h
index ea080a96..cb7b6882 100644
--- a/src/trainer.h
+++ b/src/trainer.h
@@ -57,7 +57,7 @@ class Trainer : public RunBase,
       while(bg) {
         BatchPtr batch = bg.next();
         opt->update(graph_, batch);
-        cost += (*graph_)["cost"]->val()[0] * batch->dim();
+        cost += (*graph_)["cost"]->val()->get(0) * batch->dim();
         totalExamples += batch->dim();
         update++;
       }
@@ -115,7 +115,7 @@ class Validator : public RunBase,
       BatchGenerator bg(dataset_, batchSize);
 
       size_t update = 0;
-      bg.prepare();
+      bg.prepare(false);
 
       float total = 0;
       float cor = 0;
@@ -123,7 +123,7 @@ class Validator : public RunBase,
         BatchPtr batch = bg.next();
         graph_->inference(batch);
         std::vector<float> scores;
-        scores << (*graph_)["scores"]->val();
+        (*graph_)["scores"]->val()->get(scores);
 
         cor += correct(scores, batch->inputs()[1].data());
         total += batch->dim();
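The trainer and validator changes above switch to the new read-out idiom: node values are Tensor handles, so scalars come via get(i) and whole vectors via get(std::vector<float>&) rather than operator[] and a stream operator. A standalone sketch of both reads against the new interface (hypothetical shapes and values, no expression graph):

    #include <vector>
    #include "tensors/tensor_allocator.h"
    #include "tensors/tensor_gpu.h"

    using namespace marian;

    int main() {
      TensorAllocator alloc = newTensorAllocator<DeviceGPU>();
      Tensor val;
      alloc->allocate(val, {1, 3});
      std::vector<float> probs = {0.1f, 0.7f, 0.2f};
      val->set(probs);

      float first = val->get(0);     // what the trainer does for "cost"
      std::vector<float> scores;
      val->get(scores);              // what the validator does for "scores"
      return first < 1.f ? 0 : 1;
    }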