mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-04 14:04:24 +03:00
Add initial SGD
This commit is contained in:
parent
ff3a7dd010
commit
8047efa8d9
@ -5,8 +5,8 @@
|
||||
#include <functional>
|
||||
|
||||
namespace marian {
|
||||
typedef float Float;
|
||||
typedef std::vector<int> Shape;
|
||||
typedef float Float;
|
||||
typedef std::vector<size_t> Shape;
|
||||
const int whatevs{-1};
|
||||
}
|
||||
|
||||
|
67
src/sgd.h
Normal file
67
src/sgd.h
Normal file
@ -0,0 +1,67 @@
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <iostream>
|
||||
|
||||
#include "expressions.h"
|
||||
|
||||
namespace marian {
|
||||
|
||||
class SGD {
|
||||
public:
|
||||
SGD(Expr& cost_func, Expr& inX, Expr& inY, float eta, std::vector<std::vector<float>> &xData,
|
||||
std::vector<float> &yData, size_t numClasses, size_t epochs, size_t batchSize)
|
||||
: cost_function_(&cost_func),
|
||||
inX_(&inX),
|
||||
inY_(&inY),
|
||||
eta_(eta),
|
||||
xData_(xData),
|
||||
yData_(yData),
|
||||
epochs_(epochs),
|
||||
batchSize_(batchSize),
|
||||
numClasses_(numClasses) {}
|
||||
|
||||
void run() {
|
||||
auto numExamples = xData_[0].size();
|
||||
Tensor xt({(int)batchSize_, (int)numExamples}, 0.0f);
|
||||
Tensor yt({(int)batchSize_, (int)numClasses_}, 0.0f);
|
||||
for (size_t numEpoch = 0; numEpoch < epochs_; ++numEpoch) {
|
||||
std::cerr << "Starting epoch #" << numEpoch << std::endl;
|
||||
size_t startId = 0;
|
||||
size_t endId = startId + batchSize_;
|
||||
|
||||
while (endId < numExamples) {
|
||||
prepareBatch(startId, xt, yt);
|
||||
*inX_ = xt;
|
||||
*inY_ = yt;
|
||||
|
||||
cost_function_->forward(batchSize_);
|
||||
cost_function_->backward();
|
||||
|
||||
updateModel();
|
||||
|
||||
startId += batchSize_;
|
||||
endId += batchSize_;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void prepareBatch(const size_t index, Tensor& xt, Tensor& yt) {
|
||||
}
|
||||
|
||||
void updateModel() {
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<Expr> cost_function_;
|
||||
std::shared_ptr<Expr> inX_;
|
||||
std::shared_ptr<Expr> inY_;
|
||||
const float eta_;
|
||||
std::vector<std::vector<float>> &xData_;
|
||||
std::vector<float> &yData_;
|
||||
const size_t epochs_;
|
||||
const size_t batchSize_;
|
||||
const size_t numClasses_;
|
||||
};
|
||||
|
||||
} // namespace marian
|
78
src/tensor.h
78
src/tensor.h
@ -16,16 +16,16 @@ namespace marian {
|
||||
struct Handles {
|
||||
cudnnHandle_t cudnnHandle;
|
||||
cublasHandle_t cublasHandle;
|
||||
|
||||
cudnnOpTensorDescriptor_t add;
|
||||
|
||||
|
||||
cudnnOpTensorDescriptor_t add;
|
||||
|
||||
Handles() {
|
||||
cudnnCreate(&cudnnHandle);
|
||||
cublasCreate(&cublasHandle);
|
||||
cudnnCreateOpTensorDescriptor(&add);
|
||||
cudnnSetOpTensorDescriptor(add, CUDNN_OP_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN);
|
||||
}
|
||||
|
||||
|
||||
~Handles() {
|
||||
cudnnDestroy(cudnnHandle);
|
||||
cublasDestroy(cublasHandle);
|
||||
@ -63,7 +63,7 @@ class TensorImpl {
|
||||
cudnnTensorDescriptor_t desc_;
|
||||
size_t tno_;
|
||||
static size_t tensorCounter;
|
||||
|
||||
|
||||
cudnnDataType_t dataType() {
|
||||
switch(sizeof(Float)) {
|
||||
case 2: return CUDNN_DATA_HALF;
|
||||
@ -74,15 +74,15 @@ class TensorImpl {
|
||||
|
||||
public:
|
||||
typedef Float value_type;
|
||||
|
||||
|
||||
TensorImpl(const Shape& shape, value_type value = 0)
|
||||
: shape_(shape), tno_(tensorCounter++)
|
||||
{
|
||||
|
||||
// @TODO:
|
||||
|
||||
// @TODO:
|
||||
UTIL_THROW_IF2(shape_.size() != 2,
|
||||
"For now, only 2D Tensors, will be fixed later.");
|
||||
|
||||
|
||||
UTIL_THROW_IF2(shape_.size() < 1 || shape_.size() > 4,
|
||||
"Wrong number of dimensions: " << shape_.size());
|
||||
|
||||
@ -106,54 +106,54 @@ class TensorImpl {
|
||||
shape_[0], shape_[1], shape_[2], shape_[3]); break;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
TensorImpl(const TensorImpl&) = delete;
|
||||
TensorImpl(TensorImpl&&) = delete;
|
||||
|
||||
|
||||
~TensorImpl() {
|
||||
cudnnDestroyTensorDescriptor(desc_);
|
||||
}
|
||||
|
||||
|
||||
value_type operator[](size_t i) const {
|
||||
return data_[i];
|
||||
}
|
||||
|
||||
|
||||
auto begin() -> decltype( data_.begin() ) {
|
||||
return data_.begin();
|
||||
}
|
||||
|
||||
|
||||
auto begin() const -> decltype( data_.begin() ) {
|
||||
return data_.begin();
|
||||
}
|
||||
|
||||
|
||||
auto end() -> decltype( data_.end() ) {
|
||||
return data_.end();
|
||||
}
|
||||
|
||||
|
||||
auto end() const -> decltype( data_.end() ) {
|
||||
return data_.end();
|
||||
}
|
||||
|
||||
|
||||
const Shape& shape() const {
|
||||
return shape_;
|
||||
}
|
||||
|
||||
|
||||
size_t size() const {
|
||||
return data_.size();
|
||||
}
|
||||
|
||||
|
||||
value_type* data() {
|
||||
return thrust::raw_pointer_cast(data_.data());
|
||||
}
|
||||
|
||||
|
||||
cudnnTensorDescriptor_t desc() const {
|
||||
return desc_;
|
||||
}
|
||||
|
||||
|
||||
size_t id() const {
|
||||
return tno_;
|
||||
}
|
||||
|
||||
|
||||
void set(value_type value) {
|
||||
thrust::fill(data_.begin(), data_.end(), value);
|
||||
}
|
||||
@ -194,70 +194,70 @@ size_t TensorImpl<Type>::tensorCounter = 0;
|
||||
class Tensor {
|
||||
private:
|
||||
std::shared_ptr<TensorImpl<Float>> pimpl_;
|
||||
|
||||
|
||||
public:
|
||||
typedef TensorImpl<Float>::value_type value_type;
|
||||
|
||||
|
||||
Tensor() {}
|
||||
Tensor(Shape shape, value_type value = 0) {
|
||||
allocate(shape, value);
|
||||
}
|
||||
|
||||
|
||||
~Tensor() {}
|
||||
|
||||
|
||||
void allocate(Shape shape, value_type value = 0) {
|
||||
if(!pimpl_)
|
||||
pimpl_.reset(new TensorImpl<Float>(shape, value));
|
||||
}
|
||||
|
||||
|
||||
value_type operator[](size_t i) const {
|
||||
return (*pimpl_)[i];
|
||||
}
|
||||
|
||||
|
||||
size_t size() const {
|
||||
return pimpl_->size();
|
||||
}
|
||||
|
||||
|
||||
value_type* data() {
|
||||
return pimpl_->data();
|
||||
}
|
||||
|
||||
|
||||
const value_type* data() const {
|
||||
return pimpl_->data();
|
||||
}
|
||||
|
||||
|
||||
auto begin() -> decltype( pimpl_->begin() ) {
|
||||
return pimpl_->begin();
|
||||
}
|
||||
|
||||
|
||||
auto begin() const -> decltype( pimpl_->begin() ) {
|
||||
return pimpl_->begin();
|
||||
}
|
||||
|
||||
|
||||
auto end() -> decltype( pimpl_->begin() ) {
|
||||
return pimpl_->begin();
|
||||
}
|
||||
|
||||
|
||||
auto end() const -> decltype( pimpl_->begin() ) {
|
||||
return pimpl_->begin();
|
||||
}
|
||||
|
||||
|
||||
const Shape& shape() const {
|
||||
return pimpl_->shape();
|
||||
}
|
||||
|
||||
|
||||
cudnnTensorDescriptor_t desc() const {
|
||||
return pimpl_->desc();
|
||||
}
|
||||
|
||||
|
||||
void set(value_type value) {
|
||||
pimpl_->set(value);
|
||||
}
|
||||
|
||||
|
||||
size_t id() const {
|
||||
return pimpl_->id();
|
||||
}
|
||||
|
||||
|
||||
operator bool() {
|
||||
return pimpl_ != nullptr;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user