Mirror of https://github.com/marian-nmt/marian.git (synced 2024-11-04 14:04:24 +03:00)

Commit d18d299009: merge
@@ -1,5 +1,10 @@
Marian
======

[![Join the chat at https://gitter.im/MarianNMT/Lobby](https://badges.gitter.im/MarianNMT/Lobby.svg)](https://gitter.im/MarianNMT/Lobby?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)

Google group for commit messages: https://groups.google.com/forum/#!forum/mariannmt

A C++ gpu-specific parallel automatic differentiation library
with operator overloading.
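As an illustration of the operator-overloading style described above, here is a minimal sketch assembled only from calls that appear in the files changed in this commit (ExpressionGraph, input, param, named, dot, cross_entropy, mean, forward, backward); the variable names and shapes are illustrative, not a fixed API:

    ExpressionGraph g;
    Expr x = g.input(shape={batch_size, input_size});   // data is bound later via a Tensor
    Expr y = g.input(shape={batch_size, output_size});  // one-hot labels
    Expr w = named(g.param(shape={input_size, output_size}, init=uniform()), "w");
    Expr b = named(g.param(shape={1, output_size}, init=zeros), "b");
    auto cost = named(mean(cross_entropy(dot(x, w) + b, y), axis=0), "cost");
    g.forward(batch_size);  // evaluates every node in the graph
    g.backward();           // accumulates gradients into w.grad() and b.grad()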
@@ -4,25 +4,65 @@ import sys
import os
import numpy as np
import time
import theano

np.set_printoptions(threshold=np.inf, linewidth=np.inf, suppress=True)
np.random.seed(42)

from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
from keras import backend as K
from keras.optimizers import Adam, SGD

def softmax(x):
    return np.exp(x) / np.sum(np.exp(x), axis=1)[:, None]
#class Adam2(SGD):
#    def get_gradients(self, loss, params):
#        print "lalala"
#        grads = K.gradients(loss, params)
#        if hasattr(self, 'clipnorm') and self.clipnorm > 0:
#            norm = K.sqrt(sum([K.sum(K.square(g)) for g in grads]))
#            grads = [clip_norm(g, self.clipnorm, norm) for g in grads]
#        if hasattr(self, 'clipvalue') and self.clipvalue > 0:
#            grads = [K.clip(g, -self.clipvalue, self.clipvalue) for g in grads]
#        grads = [theano.printing.Print('Gradient')(g) for g in grads]
#        return grads
#
#
#X = 123456789
#Y = 362436069
#Z = 521288629
#W = 88675123
#
#def xorshift():
#    global X, Y, Z, W
#    t = (X ^ (X << 11)) % 1000
#    X = Y
#    Y = Z
#    Z = W
#    W = (W ^ (W >> 19) ^ t ^ (t >> 8)) % 1000
#    return 0.1 * ((W % 1000)/1000.0) - 0.05

#def xorshift_init(shape, name=None):
#    init = np.array([xorshift() for i in range(shape[0] * shape[1])]).reshape(shape)
#    return K.variable(init, name=name)

def baseline_model(pixels_count, classes_count):
    model = Sequential()
    model.add(Dense(2000, input_dim=pixels_count, init='normal', activation='tanh'))
    model.add(Dense(2000, init='normal', activation='tanh'))
    model.add(Dense(2000, init='normal', activation='tanh'))
    model.add(Dense(2000, init='normal', activation='tanh'))
    model.add(Dense(2000, init='normal', activation='tanh'))
    model.add(Dense(classes_count, input_dim=100, init='normal', activation='softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
#    model.add(Dropout(0.2, input_shape=(pixels_count,)))
    model.add(Dense(2048, input_dim=pixels_count, init='uniform', activation='relu'))
#    model.add(Dense(2048, init='uniform', activation='relu'))
#    model.add(Dropout(0.5))
    model.add(Dense(2048, init='uniform', activation='relu'))
    model.add(Dense(2048, init='uniform', activation='relu'))
    model.add(Dense(2048, init='uniform', activation='relu'))
    model.add(Dense(2048, init='uniform', activation='relu'))
#    model.add(Dropout(0.5))
    model.add(Dense(classes_count, init='uniform', activation='softmax'))

    opt = Adam(lr=0.0002);
    model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])
    return model

@@ -41,8 +81,8 @@ if __name__ == "__main__":

    ### Normalize data to (0, 1)

    X_train = X_train / 255
    X_test = X_test / 255
    X_train = X_train / 255.0
    X_test = X_test / 255.0

    ### Change classes to one hot encoding matrixes

@@ -56,10 +96,14 @@ if __name__ == "__main__":

    # Build the model
    model = baseline_model(pixels_count, classes_count)

    #for layer in model.layers:
    #    print layer.get_weights()
    # Fit the model

    start = time.time();
    model.fit(X_train, y_train, nb_epoch=10, batch_size=200, verbose=2)
    model.fit(X_train, y_train, nb_epoch=10, batch_size=200, verbose=2, shuffle=True)

    print "Time elapsed", time.time() - start, "s"
    # Final evaluation of the model
    scores = model.evaluate(X_test, y_test, verbose=0)
@@ -45,7 +45,7 @@ target_link_libraries(validate_encoder_decoder marian_lib)
target_link_libraries(test_nodes marian_lib)

foreach(exec marian mnist_benchmark validate_mnist_batch validate_encoder_decoder test_nodes)
  target_link_libraries(${exec} ${EXT_LIBS} cuda cudnn)
  target_link_libraries(${exec} ${EXT_LIBS} cuda cudnn curand)
  cuda_add_cublas_to_target(${exec})
  set_target_properties(${exec} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")
endforeach(exec)
@@ -42,7 +42,9 @@ struct Chainable {

    virtual void allocate(size_t) = 0;
    virtual std::string graphviz() = 0;
    virtual void set_name(const std::string&) = 0;
    virtual const std::string &name() const = 0;
    virtual const std::string label(const std::string& type) = 0;

    virtual const Shape& shape() = 0;
    virtual DataType &val() = 0;
@@ -25,6 +25,7 @@
namespace marian {

Expr named(Expr a, const std::string& name) {
  a.node()->set_name(name);
  a.graph()->add_named_node(a, name);
  return a;
}

@@ -37,6 +38,14 @@ Expr tanh(Expr a) {
  return Expr(a.graph(), new TanhNodeOp(a));
}

Expr relu(Expr a) {
  return Expr(a.graph(), new ReLUNodeOp(a));
}

Expr dropout(Expr a) {
  return Expr(a.graph(), new DropoutNodeOp(a));
}

Expr log(Expr a) {
  return Expr(a.graph(), new LogNodeOp(a));
};
@@ -59,92 +68,30 @@ Expr argmax(Expr a) {
|
||||
|
||||
/*********************************************************/
|
||||
|
||||
static Shape newShape(ChainPtr a, ChainPtr b) {
|
||||
size_t dimsA = a->shape().size();
|
||||
size_t dimsB = b->shape().size();
|
||||
UTIL_THROW_IF2(dimsA != dimsB,
|
||||
"Tensors have different numbers of dimensions");
|
||||
Shape shape(dimsA);
|
||||
for(size_t i = 0; i < dimsA; ++i) {
|
||||
int dimA = a->shape()[i];
|
||||
int dimB = b->shape()[i];
|
||||
bool broadcastable = (dimA == dimB || dimA == 1 || dimB == 1);
|
||||
UTIL_THROW_IF2(!broadcastable, "Different dimensions in elementwise "
|
||||
<< "operation cannot be broadcasted: " << dimA << " != " << dimB);
|
||||
shape[i] = std::max(dimA, dimB);
|
||||
if(dimA == whatevs || dimB == whatevs)
|
||||
shape[i] = whatevs;
|
||||
}
|
||||
return shape;
|
||||
}
|
||||
|
||||
Expr broadcast(Shape bShape, Expr a) {
|
||||
const Shape& aShape = a.node()->shape();
|
||||
if(aShape == bShape) {
|
||||
return a;
|
||||
}
|
||||
else {
|
||||
size_t dimsA = aShape.size();
|
||||
size_t dimsB = bShape.size();
|
||||
UTIL_THROW_IF2(dimsA != dimsB,
|
||||
"Tensor and shape have different number of dimensions");
|
||||
for(size_t i = 0; i < dimsA; ++i) {
|
||||
int dimA = aShape[i];
|
||||
int dimB = bShape[i];
|
||||
bool broadcastable = (dimA == dimB || dimA == 1);
|
||||
UTIL_THROW_IF2(!broadcastable,
|
||||
"Cannot broadcast tensor dimension "
|
||||
<< dimA << " to " << dimB);
|
||||
if(dimA == 1 && dimB != 1) {
|
||||
if(i == 0) {
|
||||
Expr one = a.graph()->ones(keywords::shape={bShape[0], 1});
|
||||
a = dot(one, a);
|
||||
}
|
||||
else if(i == 1) {
|
||||
Expr one = a.graph()->ones(keywords::shape={1, bShape[1]});
|
||||
a = dot(a, one);
|
||||
}
|
||||
else {
|
||||
UTIL_THROW2("Not implemented");
|
||||
}
|
||||
}
|
||||
}
|
||||
return a;
|
||||
}
|
||||
}
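For clarity, the dot products with a ones tensor above implement broadcasting by explicit replication:

\[
\mathbf{1}_{m\times 1}\, a_{1\times n}
= \begin{pmatrix} a \\ \vdots \\ a \end{pmatrix}_{m\times n},
\qquad
a_{m\times 1}\, \mathbf{1}_{1\times n}
= \begin{pmatrix} a & \cdots & a \end{pmatrix}_{m\times n},
\]

so a tensor whose first or second dimension is 1 is expanded to bShape before the elementwise operation is applied.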
|
||||
|
||||
Expr operator+(Expr a, Expr b) {
|
||||
Shape shape = newShape(a, b);
|
||||
Expr cast_a = broadcast(shape, a);
|
||||
Expr cast_b = broadcast(shape, b);
|
||||
return Expr(a.graph(), new PlusNodeOp(cast_a, cast_b));
|
||||
return Expr(a.graph(), new PlusNodeOp(a, b));
|
||||
}
|
||||
|
||||
Expr operator-(Expr a, Expr b) {
|
||||
Shape shape = newShape(a, b);
|
||||
Expr cast_a = broadcast(shape, a);
|
||||
Expr cast_b = broadcast(shape, b);
|
||||
return Expr(a.graph(), new MinusNodeOp(cast_a, cast_b));
|
||||
return Expr(a.graph(), new MinusNodeOp(a, b));
|
||||
}
|
||||
|
||||
Expr operator*(Expr a, Expr b) {
|
||||
Shape shape = newShape(a, b);
|
||||
Expr cast_a = broadcast(shape, a);
|
||||
Expr cast_b = broadcast(shape, b);
|
||||
return Expr(a.graph(), new MultNodeOp(cast_a, cast_b));
|
||||
return Expr(a.graph(), new MultNodeOp(a, b));
|
||||
}
|
||||
|
||||
Expr operator/(Expr a, Expr b) {
|
||||
Shape shape = newShape(a, b);
|
||||
Expr cast_a = broadcast(shape, a);
|
||||
Expr cast_b = broadcast(shape, b);
|
||||
return Expr(a.graph(), new DivNodeOp(cast_a, cast_b));
|
||||
return Expr(a.graph(), new DivNodeOp(a, b));
|
||||
}
|
||||
|
||||
Expr dot(Expr a, Expr b) {
|
||||
return Expr(a.graph(), new DotNodeOp(a, b));
|
||||
}
|
||||
|
||||
Expr reluplus(Expr a, Expr b) {
|
||||
return Expr(a.graph(), new ReLUPlusNodeOp(a, b));
|
||||
}
|
||||
|
||||
Expr cross_entropy(Expr a, Expr b) {
|
||||
return Expr(a.graph(), new CrossEntropyNodeOp(a, b));
|
||||
}
|
||||
|
@@ -31,6 +31,10 @@ Expr logit(Expr a);

Expr tanh(Expr a);

Expr relu(Expr a);

Expr dropout(Expr a);

Expr log(Expr a);

Expr exp(Expr a);

@@ -49,9 +53,8 @@ Expr operator/(Expr a, Expr b);

Expr dot(Expr a, Expr b);

/******************************************************/
Expr reluplus(Expr a, Expr b);

Expr broadcast(Shape bShape, Expr a);

/*********************************************************/
@@ -1,5 +1,7 @@
#include <algorithm>
#include <chrono>
#include <iomanip>
#include <cstdio>
#include <boost/timer/timer.hpp>

#include "marian.h"
@@ -31,31 +33,28 @@ ExpressionGraph build_graph(const std::vector<int>& dims) {
      layers.emplace_back(x);
    }
    else {
      layers.emplace_back(tanh(dot(layers.back(), weights.back())) + biases.back());
      layers.emplace_back(reluplus(dot(layers.back(), weights.back()), biases.back()));
      //layers.emplace_back(relu(dot(layers.back(), weights.back()) + biases.back()));
    }

    weights.emplace_back(
      g.param(shape={in, out},
              init=normal()));
      named(g.param(shape={in, out}, init=uniform()), "W" + std::to_string(i)));
    biases.emplace_back(
      g.param(shape={1, out},
              init=normal()));
      named(g.param(shape={1, out}, init=zeros), "b" + std::to_string(i)));
  }

  Expr scores = named(dot(layers.back(), weights.back()) + biases.back(),
  auto scores = named(dot(layers.back(), weights.back()) + biases.back(),
                      "scores");

  Expr cost = mean(cross_entropy(scores, y), axis=0);

  auto cost = mean(cross_entropy(scores, y), axis=0);
  //auto cost = mean(-sum(y * log(softmax(scores)), axis=1), axis=0);
  Expr costreg = named(
  auto costreg = named(
    cost, "cost"
  );

  // If we uncomment the line below, this will just horribly diverge.
  // auto dummy_probs = named(softmax(scores), "dummy_probs");

  std::cout << g.graphviz() << std::endl;

  std::cerr << timer.format(5, "%ws") << std::endl;
  return g;
}
@@ -84,8 +83,26 @@ void shuffle(std::vector<float>& x, std::vector<float>& y, size_t dimx, size_t d
|
||||
|
||||
}
|
||||
|
||||
float accuracy(const std::vector<float> pred, const std::vector<float> labels) {
|
||||
size_t acc = 0;
|
||||
for (size_t i = 0; i < labels.size(); i += LABEL_SIZE) {
|
||||
size_t correct = 0;
|
||||
size_t proposed = 0;
|
||||
for (size_t j = 0; j < LABEL_SIZE; ++j) {
|
||||
if (labels[i + j])
|
||||
correct = j;
|
||||
if (pred[i + j] > pred[i + proposed])
|
||||
proposed = j;
|
||||
}
|
||||
acc += (correct == proposed);
|
||||
}
|
||||
return float(acc) / (labels.size() / LABEL_SIZE);
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
|
||||
std::cerr << std::setprecision(4) << std::fixed;
|
||||
|
||||
int trainRows, testRows;
|
||||
|
||||
std::cerr << "Loading train set...";
|
||||
@@ -98,18 +115,19 @@ int main(int argc, char** argv) {
|
||||
std::vector<float> testLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", testRows, LABEL_SIZE);
|
||||
std::cerr << "Done." << std::endl;
|
||||
|
||||
ExpressionGraph g = build_graph({IMAGE_SIZE, 2000, 2000, 2000, 2000, 2000, LABEL_SIZE});
|
||||
ExpressionGraph g = build_graph({IMAGE_SIZE, 2048, 2048, 2048, 2048, 2048, LABEL_SIZE});
|
||||
std::cout << g.graphviz() << std::endl;
|
||||
|
||||
Tensor xt({BATCH_SIZE, IMAGE_SIZE});
|
||||
Tensor yt({BATCH_SIZE, LABEL_SIZE});
|
||||
|
||||
|
||||
boost::timer::cpu_timer total;
|
||||
Adam opt;
|
||||
Adam opt(0.0002);
|
||||
for(int i = 1; i <= 10; ++i) {
|
||||
boost::timer::cpu_timer timer;
|
||||
shuffle(trainImages, trainLabels, IMAGE_SIZE, LABEL_SIZE);
|
||||
float cost = 0;
|
||||
float acc = 0;
|
||||
for(int j = 0; j < trainRows / BATCH_SIZE; j++) {
|
||||
size_t xBatch = IMAGE_SIZE * BATCH_SIZE;
|
||||
auto xbegin = trainImages.begin() + j * xBatch;
|
||||
@@ -119,19 +137,26 @@ int main(int argc, char** argv) {
|
||||
size_t yBatch = LABEL_SIZE * BATCH_SIZE;
|
||||
auto ybegin = trainLabels.begin() + j * yBatch;
|
||||
auto yend = ybegin + yBatch;
|
||||
yt.set(ybegin, yend);
|
||||
std::vector<float> ytv(ybegin, yend);
|
||||
yt.set(ytv);
|
||||
|
||||
g["x"] = xt;
|
||||
g["y"] = yt;
|
||||
|
||||
opt(g, BATCH_SIZE);
|
||||
cost += g["cost"].val()[0];
|
||||
|
||||
cost += (g["cost"].val()[0] * BATCH_SIZE) / trainRows;
|
||||
|
||||
std::vector<float> bResults;
|
||||
bResults << g["scores"].val();
|
||||
|
||||
acc += (accuracy(bResults, ytv) * BATCH_SIZE) / trainRows;
|
||||
}
|
||||
std::cerr << "Epoch: " << i << " - Cost: " << cost / trainRows * BATCH_SIZE << " - " << timer.format(3, "%ws") << std::endl;
|
||||
std::cerr << "Epoch: " << i << " - Cost: " << cost << " - Accuracy: " << acc << " - " << timer.format(3, "%ws") << std::endl;
|
||||
}
|
||||
std::cerr << "Total: " << total.format(3, "%ws") << std::endl;
|
||||
|
||||
std::vector<float> results;
|
||||
std::vector<float> results;
|
||||
for(int j = 0; j < testRows / BATCH_SIZE; j++) {
|
||||
size_t xBatch = IMAGE_SIZE * BATCH_SIZE;
|
||||
auto xbegin = testImages.begin() + j * xBatch;
|
||||
@@ -149,19 +174,7 @@ int main(int argc, char** argv) {
|
||||
results.insert(results.end(), bResults.begin(), bResults.end());
|
||||
}
|
||||
|
||||
size_t acc = 0;
|
||||
for (size_t i = 0; i < testLabels.size(); i += LABEL_SIZE) {
|
||||
size_t correct = 0;
|
||||
size_t proposed = 0;
|
||||
for (size_t j = 0; j < LABEL_SIZE; ++j) {
|
||||
if (testLabels[i + j])
|
||||
correct = j;
|
||||
if (results[i + j] > results[i + proposed])
|
||||
proposed = j;
|
||||
}
|
||||
acc += (correct == proposed);
|
||||
}
|
||||
std::cerr << "Accuracy: " << float(acc) / testRows << std::endl;
|
||||
std::cerr << "Accuracy: " << accuracy(results, testLabels) << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
src/node.h (14 changed lines)
@@ -89,8 +89,22 @@ class Node : public Chainable<Tensor>,
|
||||
return shape_;
|
||||
}
|
||||
|
||||
void set_name(const std::string& name) {
|
||||
name_ = name;
|
||||
}
|
||||
|
||||
const std::string &name() const { return name_; }
|
||||
|
||||
virtual const std::string label(const std::string& type) {
|
||||
std::stringstream label;
|
||||
label << "<" << type;
|
||||
if(name_ != "none") {
|
||||
label << "<br/>" << "\"" << name_ << "\"";
|
||||
}
|
||||
label << ">";
|
||||
return label.str();
|
||||
}
|
||||
|
||||
protected:
|
||||
Shape shape_;
|
||||
std::string name_;
|
||||
|
@@ -22,8 +22,6 @@
|
||||
// SOFTWARE.
|
||||
|
||||
#include "node.h"
|
||||
#include "node_operators_unary.h"
|
||||
#include "node_operators_binary.h"
|
||||
#include "tensor_operators.h"
|
||||
|
||||
namespace marian {
|
||||
@@ -48,7 +46,7 @@ struct InputNode : public Node {
|
||||
|
||||
virtual std::string graphviz() {
|
||||
std::stringstream ss;
|
||||
ss << "\"" << this << "\" [shape=\"parallelogram\", label=\"input\", style=\"filled\", fillcolor=\"lawngreen\"]" << std::endl << std::endl;
|
||||
ss << "\"" << this << "\" [shape=\"circle\", label=" << label("input") << ", style=\"filled\", fillcolor=\"lawngreen\"]" << std::endl << std::endl;
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
@@ -68,7 +66,7 @@ struct ConstantNode : public Node {
|
||||
|
||||
virtual std::string graphviz() {
|
||||
std::stringstream ss;
|
||||
ss << "\"" << this << "\" [shape=\"diamond\", label=\"const\"]" << std::endl << std::endl;
|
||||
ss << "\"" << this << "\" [shape=\"diamond\", label=" << label("const") << "]" << std::endl << std::endl;
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
@@ -99,7 +97,9 @@ struct ParamNode : public Node {
|
||||
|
||||
virtual std::string graphviz() {
|
||||
std::stringstream ss;
|
||||
ss << "\"" << this << "\" [shape=\"hexagon\", label=\"param\", style=\"filled\", fillcolor=\"orangered\"]" << std::endl << std::endl;
|
||||
ss << "\"" << this << "\" [shape=\"hexagon\", label=" << label("param")
|
||||
<< ", style=\"filled\", fillcolor=\"orangered\"]"
|
||||
<< std::endl << std::endl;
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
@@ -109,8 +109,524 @@ struct ParamNode : public Node {
|
||||
bool initialized_;
|
||||
};
|
||||
|
||||
struct UnaryNodeOp : public Node {
|
||||
ChainPtr a_;
|
||||
|
||||
template <typename ...Args>
|
||||
UnaryNodeOp(ChainPtr a, Args ...args)
|
||||
: Node(keywords::shape=a->shape(), //@TODO: Check keywords?
|
||||
args...), a_(a) {}
|
||||
};
|
||||
|
||||
struct LogitNodeOp : public UnaryNodeOp {
|
||||
template <typename ...Args>
|
||||
LogitNodeOp(Args ...args)
|
||||
: UnaryNodeOp(args...) { }
|
||||
|
||||
void forward() {
|
||||
Element(_1 = Sigma(_2),
|
||||
val_, a_->val());
|
||||
}
|
||||
|
||||
void backward() {
|
||||
Element(_1 += _2 * _3 * (1.0f - _3),
|
||||
a_->grad(), adj_, val_);
|
||||
}
|
||||
|
||||
void check() {
|
||||
|
||||
}
|
||||
|
||||
virtual std::string graphviz() {
|
||||
std::stringstream ss;
|
||||
ss << "\"" << this << "\" [shape=\"box\", label=" << label("logit")
|
||||
<< ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
};
|
||||
|
||||
struct TanhNodeOp : public UnaryNodeOp {
|
||||
template <typename ...Args>
|
||||
TanhNodeOp(Args ...args)
|
||||
: UnaryNodeOp(args...) { }
|
||||
|
||||
void forward() {
|
||||
Element(_1 = Tanh(_2),
|
||||
val_, a_->val());
|
||||
}
|
||||
|
||||
void backward() {
|
||||
Element(_1 += _2 * (1.0f - (_3 * _3)),
|
||||
a_->grad(), adj_, val_);
|
||||
}
|
||||
|
||||
virtual std::string graphviz() {
|
||||
std::stringstream ss;
|
||||
ss << "\"" << this << "\" [shape=\"box\", label=" << label("tanh")
|
||||
<< ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
};
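Both unary backward passes above rely on the cached forward value val_ together with the standard derivative identities

\[
\sigma'(x) = \sigma(x)\,\bigl(1-\sigma(x)\bigr),
\qquad
\tanh'(x) = 1 - \tanh^2(x),
\]

which is why the gradient expressions read _3 * (1.0f - _3) and (1.0f - (_3 * _3)) with _3 bound to val_.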
|
||||
|
||||
struct ReLUNodeOp : public UnaryNodeOp {
|
||||
template <typename ...Args>
|
||||
ReLUNodeOp(Args ...args)
|
||||
: UnaryNodeOp(args...) { }
|
||||
|
||||
void forward() {
|
||||
Element(_1 = ReLU(_2),
|
||||
val_, a_->val());
|
||||
}
|
||||
|
||||
void backward() {
|
||||
Element(_1 += _2 * ReLUback(_3),
|
||||
a_->grad(), adj_, a_->val());
|
||||
}
|
||||
|
||||
virtual std::string graphviz() {
|
||||
std::stringstream ss;
|
||||
ss << "\"" << this << "\" [shape=\"box\", label=" << label("ReLU")
|
||||
<< ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
};
|
||||
|
||||
// @TODO: slow and probably buggy
|
||||
struct DropoutNodeOp : public UnaryNodeOp {
|
||||
template <typename ...Args>
|
||||
DropoutNodeOp(Args ...args)
|
||||
: UnaryNodeOp(args...),
|
||||
p_(0.5), seed_(time(0)) { }
|
||||
|
||||
void forward() {
|
||||
//Element(_1 = Bernoulli(p_, (size_t)this) * _2,
|
||||
// val_, a_->val())
|
||||
Dropout(val_, a_->val(), p_, seed_++);
|
||||
}
|
||||
|
||||
void backward() {
|
||||
Element(_1 += _2 * (_3 != 0.0f), // transform non-zero to 1
|
||||
a_->grad(), adj_, val_);
|
||||
}
|
||||
|
||||
virtual std::string graphviz() {
|
||||
std::stringstream ss;
|
||||
ss << "\"" << this << "\" [shape=\"box\", label=" << label("dropout")
|
||||
<< ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
private:
|
||||
float p_;
|
||||
int seed_;
|
||||
};
|
||||
|
||||
|
||||
struct SoftmaxNodeOp : public UnaryNodeOp {
|
||||
template <typename ...Args>
|
||||
SoftmaxNodeOp(Args ...args)
|
||||
: UnaryNodeOp(args...) { }
|
||||
|
||||
void forward() {
|
||||
// B = softmax(A).
|
||||
thrust::copy(a_->val().begin(), a_->val().end(), val_.begin());
|
||||
// Safe version of softmax.
|
||||
Softmax(&val_);
|
||||
}
|
||||
|
||||
void backward() {
|
||||
// For each row, the Jacobian times vector is given by:
|
||||
// J * dy = p .* (dy - avg*1)
|
||||
// where avg = p'*dy and p is the softmax output (probabilities).
|
||||
//
|
||||
// For more information, see sec. 2.5 of the following reference:
|
||||
// André F. T. Martins and Ramon Astudillo.
|
||||
// "From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label
|
||||
// Classification." ICML 2016.
|
||||
// http://jmlr.org/proceedings/papers/v48/martins16.pdf
|
||||
|
||||
SoftmaxGrad(a_->grad(), adj_, val_);
|
||||
}
|
||||
|
||||
virtual std::string graphviz() {
|
||||
std::stringstream ss;
|
||||
ss << "\"" << this << "\" [shape=\"box\", label=" << label("softmax")
|
||||
<< ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
return ss.str();
|
||||
};
|
||||
};
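The rule quoted in the comment follows from the softmax Jacobian: with p = softmax(x), the Jacobian entries are p_j(δ_ij - p_i), so for an incoming adjoint dy

\[
(J\,dy)_j = \sum_i p_j\,(\delta_{ij} - p_i)\,dy_i
          = p_j\Bigl(dy_j - \sum_i p_i\,dy_i\Bigr)
          = p_j\,(dy_j - \mathrm{avg}),
\qquad
\mathrm{avg} = p^{\top}dy,
\]

which is exactly the p .* (dy - avg*1) computed by SoftmaxGrad.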
|
||||
|
||||
struct ArgmaxNodeOp : public UnaryNodeOp {
|
||||
template <typename ...Args>
|
||||
ArgmaxNodeOp(ChainPtr a, Args ...args)
|
||||
: UnaryNodeOp(a, keywords::shape=newShape(a), args...) { }
|
||||
|
||||
void forward() {
|
||||
// B = softmax(A).
|
||||
Argmax(&val_, &a_->val());
|
||||
}
|
||||
|
||||
void backward() {
|
||||
}
|
||||
|
||||
Shape newShape(ChainPtr a) {
|
||||
Shape shape = a->shape();
|
||||
shape[1] = 1;
|
||||
return shape;
|
||||
}
|
||||
|
||||
|
||||
virtual std::string graphviz() {
|
||||
std::stringstream ss;
|
||||
ss << "\"" << this << "\" [shape=\"box\", label="
|
||||
<< label("argmax") << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
};
|
||||
|
||||
struct LogNodeOp : public UnaryNodeOp {
|
||||
template <typename ...Args>
|
||||
LogNodeOp(Args ...args)
|
||||
: UnaryNodeOp(args...) {}
|
||||
|
||||
void forward() {
|
||||
Element(_1 = Log(_2), val_, a_->val());
|
||||
}
|
||||
|
||||
void backward() {
|
||||
Element(_1 += _2 * (1.f / _3),
|
||||
a_->grad(), adj_, a_->val());
|
||||
}
|
||||
|
||||
virtual std::string graphviz() {
|
||||
std::stringstream ss;
|
||||
ss << "\"" << this << "\" [shape=\"box\", label="
|
||||
<< label("log") << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
};
|
||||
|
||||
struct ExpNodeOp : public UnaryNodeOp {
|
||||
template <typename ...Args>
|
||||
ExpNodeOp(Args ...args)
|
||||
: UnaryNodeOp(args...) { }
|
||||
|
||||
void forward() {
|
||||
Element(_1 = Exp(_2), val_, a_->val());
|
||||
}
|
||||
|
||||
void backward() {
|
||||
Element(_1 += _2 * Exp(_3),
|
||||
a_->grad(), adj_, a_->val());
|
||||
}
|
||||
|
||||
virtual std::string graphviz() {
|
||||
std::stringstream ss;
|
||||
ss << "\"" << this << "\" [shape=\"box\", label=" << label("exp")
|
||||
<< ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
};
|
||||
|
||||
struct NegNodeOp : public UnaryNodeOp {
|
||||
template <typename ...Args>
|
||||
NegNodeOp(Args ...args)
|
||||
: UnaryNodeOp(args...) { }
|
||||
|
||||
void forward() {
|
||||
Element(_1 = -_2, val_, a_->val());
|
||||
}
|
||||
|
||||
void backward() {
|
||||
Element(_1 += -_2, a_->grad(), adj_);
|
||||
}
|
||||
|
||||
virtual std::string graphviz() {
|
||||
std::stringstream ss;
|
||||
ss << "\"" << this << "\" [shape=\"box\", label="
|
||||
<< label("-") << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
};
|
||||
|
||||
/******************************************************/
|
||||
|
||||
struct BinaryNodeOp : public Node {
|
||||
ChainPtr a_;
|
||||
ChainPtr b_;
|
||||
|
||||
template <typename ...Args>
|
||||
BinaryNodeOp(ChainPtr a, ChainPtr b, Args ...args)
|
||||
: Node(args...), a_(a), b_(b) {}
|
||||
};
|
||||
|
||||
/*** Matrix Product ***/
|
||||
|
||||
struct DotNodeOp : public BinaryNodeOp {
|
||||
template <typename ...Args>
|
||||
DotNodeOp(ChainPtr a, ChainPtr b, Args ...args)
|
||||
: BinaryNodeOp(a, b,
|
||||
keywords::shape=newShape(a, b),
|
||||
args...) { }
|
||||
|
||||
Shape newShape(ChainPtr a, ChainPtr b) {
|
||||
Shape shape1 = a->shape();
|
||||
Shape shape2 = b->shape();
|
||||
UTIL_THROW_IF2(shape1[1] != shape2[0],
|
||||
"matrix product requires dimensions to match");
|
||||
shape1[1] = shape2[1];
|
||||
return shape1;
|
||||
}
|
||||
|
||||
void forward() {
|
||||
// C = A*B
|
||||
Prod(val_, a_->val(), b_->val(), false, false);
|
||||
}
|
||||
|
||||
void backward() {
|
||||
// D is the adjoint, the matrix of derivatives
|
||||
// df/dA += D*B.T
|
||||
// df/dB += A.T*D
|
||||
// beta set to 1.0 in gemm, C = dot(A,B) + beta * C
|
||||
// to sum gradients from different graph parts
|
||||
Prod(a_->grad(), adj_, b_->val(), false, true, 1.0);
|
||||
Prod(b_->grad(), a_->val(), adj_, true, false, 1.0);
|
||||
}
|
||||
|
||||
virtual std::string graphviz() {
|
||||
std::stringstream ss;
|
||||
ss << "\"" << this << "\" [shape=\"box\", label=" << label("×")
|
||||
<< ", style=\"filled\", fillcolor=\"orange\"]" << std::endl;
|
||||
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
|
||||
ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
};
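The two Prod calls in backward() are the usual matrix-product gradients. For C = AB and a scalar loss f, the identity C_{ij} = Σ_k A_{ik} B_{kj} gives

\[
\frac{\partial f}{\partial A} \;{+}{=}\; \frac{\partial f}{\partial C}\,B^{\top},
\qquad
\frac{\partial f}{\partial B} \;{+}{=}\; A^{\top}\,\frac{\partial f}{\partial C},
\]

and passing beta = 1.0 to the underlying gemm performs the += accumulation into gradients that may already hold contributions from other parts of the graph.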
|
||||
|
||||
struct PlusNodeOp : public BinaryNodeOp {
|
||||
template <typename ...Args>
|
||||
PlusNodeOp(ChainPtr a, ChainPtr b, Args ...args)
|
||||
: BinaryNodeOp(a, b, keywords::shape=a->shape(), args...) { }
|
||||
|
||||
void forward() {
|
||||
Element(_1 = _2 + _3,
|
||||
val_, a_->val(), b_->val());
|
||||
}
|
||||
|
||||
void backward() {
|
||||
Element(_1 += _2,
|
||||
a_->grad(), adj_);
|
||||
Element(_1 += _2,
|
||||
b_->grad(), adj_);
|
||||
}
|
||||
|
||||
virtual std::string graphviz() {
|
||||
std::stringstream ss;
|
||||
ss << "\"" << this << "\" [shape=\"box\", label=" << label("+")
|
||||
<< ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
|
||||
ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
};
|
||||
|
||||
struct ReLUPlusNodeOp : public BinaryNodeOp {
|
||||
template <typename ...Args>
|
||||
ReLUPlusNodeOp(ChainPtr a, ChainPtr b, Args ...args)
|
||||
: BinaryNodeOp(a, b, keywords::shape=a->shape(), args...) { }
|
||||
|
||||
void forward() {
|
||||
Element(_1 = ReLU(_2 + _3),
|
||||
val_, a_->val(), b_->val());
|
||||
}
|
||||
|
||||
void backward() {
|
||||
Element(_1 += _2 * ReLUback(_3 + _4),
|
||||
a_->grad(), adj_, a_->val(), b_->val());
|
||||
Element(_1 += _2 * ReLUback(_3 + _4),
|
||||
b_->grad(), adj_, a_->val(), b_->val());
|
||||
}
|
||||
|
||||
virtual std::string graphviz() {
|
||||
std::stringstream ss;
|
||||
ss << "\"" << this << "\" [shape=\"box\", label=" << label("ReLU<br/>+")
|
||||
<< ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
|
||||
ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
};
|
||||
|
||||
struct MinusNodeOp : public BinaryNodeOp {
|
||||
template <typename ...Args>
|
||||
MinusNodeOp(ChainPtr a, ChainPtr b, Args ...args)
|
||||
: BinaryNodeOp(a, b, keywords::shape=a->shape(), args...) { }
|
||||
|
||||
void forward() {
|
||||
Element(_1 = _2 - _3,
|
||||
val_, a_->val(), b_->val());
|
||||
}
|
||||
|
||||
void backward() {
|
||||
Element(_1 += _2,
|
||||
a_->grad(), adj_);
|
||||
Element(_1 -= _2,
|
||||
b_->grad(), adj_);
|
||||
}
|
||||
|
||||
virtual std::string graphviz() {
|
||||
std::stringstream ss;
|
||||
ss << "\"" << this << "\" [shape=\"box\", label=" << label("-")
|
||||
<< ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
|
||||
ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
};
|
||||
|
||||
struct MultNodeOp : public BinaryNodeOp {
|
||||
template <typename ...Args>
|
||||
MultNodeOp(ChainPtr a, ChainPtr b, Args ...args)
|
||||
: BinaryNodeOp(a, b, keywords::shape=a->shape(), args...) { }
|
||||
|
||||
void forward() {
|
||||
Element(_1 = _2 * _3,
|
||||
val_, a_->val(), b_->val());
|
||||
}
|
||||
|
||||
void backward() {
|
||||
Element(_1 += _2 * _3,
|
||||
a_->grad(), adj_, b_->val());
|
||||
Element(_1 += _2 * _3,
|
||||
b_->grad(), adj_, a_->val());
|
||||
}
|
||||
|
||||
virtual std::string graphviz() {
|
||||
std::stringstream ss;
|
||||
ss << "\"" << this << "\" [shape=\"box\", label=" << label("•")
|
||||
<< ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
|
||||
ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
};
|
||||
|
||||
struct DivNodeOp : public BinaryNodeOp {
|
||||
template <typename ...Args>
|
||||
DivNodeOp(ChainPtr a, ChainPtr b, Args ...args)
|
||||
: BinaryNodeOp(a, b, keywords::shape=a->shape(), args...) { }
|
||||
|
||||
void forward() {
|
||||
Element(_1 = _2 / _3,
|
||||
val_, a_->val(), b_->val());
|
||||
}
|
||||
|
||||
void backward() {
|
||||
Element(_1 += _2 * 1.0f / _3,
|
||||
a_->grad(), adj_, b_->val());
|
||||
Element(_1 -= _2 * _3 / (_4 * _4),
|
||||
b_->grad(), adj_, a_->val(), b_->val());
|
||||
}
|
||||
|
||||
virtual std::string graphviz() {
|
||||
std::stringstream ss;
|
||||
ss << "\"" << this << "\" [shape=\"box\", label=" << label("÷")
|
||||
<< ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
|
||||
ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
};
|
||||
|
||||
// Cross-entropy node. It computes -b*log(softmax(a)), summing rowwise.
|
||||
struct CrossEntropyNodeOp : public BinaryNodeOp {
|
||||
template <typename ...Args>
|
||||
CrossEntropyNodeOp(ChainPtr a, ChainPtr b, Args ...args)
|
||||
: BinaryNodeOp(a, b,
|
||||
keywords::shape=newShape(a, b),
|
||||
args...) { }
|
||||
|
||||
Shape newShape(ChainPtr a, ChainPtr b) {
|
||||
Shape shape1 = a->shape();
|
||||
Shape shape2 = b->shape();
|
||||
UTIL_THROW_IF2(shape1[0] != shape2[0] || shape1[1] != shape2[1],
|
||||
"cross entropy requires dimensions to match");
|
||||
shape1[1] = 1;
|
||||
return shape1;
|
||||
}
|
||||
|
||||
// We're caching the softmax probabilities here because we'll need them for
|
||||
// the backward computation.
|
||||
void forward() {
|
||||
// C = -dot(B, log(softmax(A))).
|
||||
if (probs_) {
|
||||
probs_.set(0.0);
|
||||
} else {
|
||||
probs_.allocate(a_->val().shape(), 0.0);
|
||||
}
|
||||
thrust::copy(a_->val().begin(), a_->val().end(), probs_.begin());
|
||||
Softmax(&probs_); // Safe version of softmax.
|
||||
Tensor result(a_->val().shape());
|
||||
Element(_1 = -_2 * Log(_3), result, b_->val(), probs_);
|
||||
SumRowwise(result, val_);
|
||||
}
|
||||
|
||||
// @TODO: In most cases it's wasteful to compute the derivative with respect
|
||||
// to the second input which is typically an input node in the computation
|
||||
// graph. In general the backward functions can skip the computation of
|
||||
// gradients wrt input nodes.
|
||||
void backward() {
|
||||
// For each row, the first input derivative is given by adj * (p - y),
|
||||
// where y is the gold label distribution (e.g. one hot vector) and
|
||||
// p is the softmax output (probabilities).
|
||||
// The second input derivative is -adj*log(p).
|
||||
Tensor result(probs_.shape());
|
||||
|
||||
// Compute first input derivative.
|
||||
Element(_1 = _2 - _3, result, probs_, b_->val());
|
||||
ScaleRowwise(result, adj_);
|
||||
Element(_1 += _2, a_->grad(), result);
|
||||
|
||||
// Compute second input derivative.
|
||||
Element(_1 = -Log(_2), result, probs_); // @TODO: use a cached log here.
|
||||
ScaleRowwise(result, adj_);
|
||||
Element(_1 += _2, b_->grad(), result);
|
||||
}
|
||||
|
||||
virtual std::string graphviz() {
|
||||
std::stringstream ss;
|
||||
ss << "\"" << this << "\" [shape=\"box\", label=" << label("x-ent")
|
||||
<< ", style=\"filled\", fillcolor=\"orange\"]" << std::endl;
|
||||
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
protected:
|
||||
Tensor probs_;
|
||||
|
||||
};
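The comments in backward() can be checked directly. Per row, with p = softmax(a) and label distribution y (the second input), the loss is L = -Σ_i y_i log p_i, so

\[
\frac{\partial L}{\partial a_j}
= \sum_i y_i\,(p_j - \delta_{ij})
= p_j\sum_i y_i - y_j
= p_j - y_j
\quad\text{when } \sum_i y_i = 1 \text{ (e.g. one-hot)},
\qquad
\frac{\partial L}{\partial y_j} = -\log p_j,
\]

matching the adj * (p - y) and -adj * log(p) expressions above.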
|
||||
|
||||
}
|
||||
|
@@ -1,37 +1,14 @@
|
||||
#pragma once
|
||||
|
||||
// This file is part of the Marian toolkit.
|
||||
// Marian is copyright (c) 2016 Marcin Junczys-Dowmunt.
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
// of this software and associated documentation files (the "Software"), to deal
|
||||
// in the Software without restriction, including without limitation the rights
|
||||
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
// copies of the Software, and to permit persons to whom the Software is
|
||||
// furnished to do so, subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be included in
|
||||
// all copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
// SOFTWARE.
|
||||
|
||||
#include <map>
|
||||
#include <boost/any.hpp>
|
||||
#include "tensor_operators.h"
|
||||
|
||||
namespace marian {
|
||||
|
||||
|
||||
// @TODO: modify computation graph to group all paramters in single matrix object.
|
||||
// This will allow to perform a single large SGD update per batch. Currently there
|
||||
// are as many updates as different paramters.
|
||||
|
||||
// @TODO: Implement Element(...) with multiple functors for compacting of calls.
|
||||
// are as many updates as different parameters.
|
||||
|
||||
class Sgd {
|
||||
public:
|
||||
@@ -64,9 +41,9 @@ class Adagrad {
|
||||
|
||||
auto gtIt = gt_.begin();
|
||||
for(auto& param : graph.params()) {
|
||||
Element(_1 += _2 * _2,
|
||||
Element(_1 += (_2 * _2),
|
||||
*gtIt, param.grad());
|
||||
Element(_1 -= eta_ / (Sqrt(_2) + eps_) * _3,
|
||||
Element(_1 -= (eta_ / (Sqrt(_2) + eps_)) * _3,
|
||||
param.val(), *gtIt, param.grad());
|
||||
gtIt++;
|
||||
}
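Written out per scalar parameter, the two Element calls above are the standard Adagrad update. A sketch for illustration only; the helper name is hypothetical and eta/eps are passed in here, whereas eta_ and eps_ are members set elsewhere in this class:

    #include <cmath>

    // gt accumulates the squared gradients seen so far for this parameter.
    void adagrad_step(float& w, float g, float& gt, float eta, float eps) {
      gt += g * g;
      w  -= eta / (std::sqrt(gt) + eps) * g;
    }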
|
||||
@@ -78,6 +55,7 @@ class Adagrad {
|
||||
std::vector<Tensor> gt_;
|
||||
};
|
||||
|
||||
|
||||
// @TODO: Add serialization for historic gradients and parameters
|
||||
// https://arxiv.org/pdf/1412.6980v8.pdf
|
||||
class Adam {
|
||||
@@ -94,18 +72,19 @@ class Adam {
|
||||
vt_.emplace_back(Tensor(param.grad().shape(), 0));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
t_++;
|
||||
float denom1 = 1 - pow(beta1_, t_);
|
||||
float denom2 = 1 - pow(beta2_, t_);
|
||||
|
||||
auto mtIt = mt_.begin();
|
||||
auto vtIt = vt_.begin();
|
||||
|
||||
for(auto& param : graph.params()) {
|
||||
Element(_1 = beta1_ * _2 + (1 - beta1_) * _3,
|
||||
*mtIt, *mtIt, param.grad());
|
||||
Element(_1 = beta2_ * _2 + (1 - beta2_) * _3 * _3,
|
||||
*vtIt, *vtIt, param.grad());
|
||||
Element(_1 = (beta1_ * _1) + ((1 - beta1_) * _2),
|
||||
*mtIt, param.grad());
|
||||
Element(_1 = (beta2_ * _1) + ((1 - beta2_) * (_2 * _2)),
|
||||
*vtIt, param.grad());
|
||||
Element(_1 -= eta_ * (_2 / denom1) / (Sqrt(_3 / denom2) + eps_),
|
||||
param.val(), *mtIt, *vtIt);
|
||||
mtIt++; vtIt++;
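For reference, the loop above implements the Adam rule from the paper linked in the comment, with denom1 and denom2 providing the 1/(1 - beta^t) bias correction. A scalar sketch; the helper name is hypothetical and the beta/eps defaults are the paper's, not values read from this class (eta = 0.0002 matches the Adam opt(0.0002) used in the MNIST benchmark):

    #include <cmath>

    void adam_step(float& w, float g, float& m, float& v, int t,
                   float eta = 0.0002f, float beta1 = 0.9f,
                   float beta2 = 0.999f, float eps = 1e-8f) {
      m = beta1 * m + (1.0f - beta1) * g;             // first-moment estimate
      v = beta2 * v + (1.0f - beta2) * g * g;         // second-moment estimate
      float mhat = m / (1.0f - std::pow(beta1, t));   // bias-corrected (denom1)
      float vhat = v / (1.0f - std::pow(beta2, t));   // bias-corrected (denom2)
      w -= eta * mhat / (std::sqrt(vhat) + eps);
    }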
|
||||
|
@@ -25,11 +25,25 @@
|
||||
#include <algorithm>
|
||||
#include <iterator>
|
||||
#include <functional>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "tensor.h"
|
||||
|
||||
namespace marian {
|
||||
|
||||
float xor128() {
|
||||
static uint64_t x = 123456789;
|
||||
static uint64_t y = 362436069;
|
||||
static uint64_t z = 521288629;
|
||||
static uint64_t w = 88675123;
|
||||
uint64_t t;
|
||||
|
||||
t = (x ^ (x << 11)) % 1000;
|
||||
x = y; y = z; z = w;
|
||||
w = (w ^ (w >> 19) ^ t ^ (t >> 8)) % 1000;
|
||||
return 0.1 * ((w % 1000) / 1000.f) - 0.05;
|
||||
}
|
||||
|
||||
// Use a constant seed for deterministic behaviour.
|
||||
std::default_random_engine engine(42);
|
||||
|
||||
@@ -60,12 +74,29 @@ std::function<void(Tensor)> normal(float mean = 0.0, float std = 0.05) {
|
||||
};
|
||||
}
|
||||
|
||||
std::function<void(Tensor)> uniform(float a = 0.0, float b = 0.05) {
|
||||
std::function<void(Tensor)> uniform(float a = -0.05, float b = 0.05) {
|
||||
return [a, b](Tensor t) {
|
||||
distribution<std::uniform_real_distribution<float>>(t, a, b);
|
||||
};
|
||||
}
|
||||
|
||||
void glorot_uniform(Tensor t) {
|
||||
float b = sqrtf( 6.0f / (t.shape()[0] + t.shape()[1]) );
|
||||
distribution<std::uniform_real_distribution<float>>(t, -b, b);
|
||||
}
|
||||
|
||||
void xorshift(Tensor t) {
|
||||
std::vector<float> vals(t.size());
|
||||
for(auto&& v : vals)
|
||||
v = xor128();
|
||||
t << vals;
|
||||
}
|
||||
|
||||
void glorot_normal(Tensor t) {
|
||||
float b = sqrtf( 2.0f / (t.shape()[0] + t.shape()[1]) );
|
||||
distribution<std::uniform_real_distribution<float>>(t, -b, b);
|
||||
}
|
||||
|
||||
std::function<void(Tensor)> from_vector(const std::vector<float>& v) {
|
||||
return [v](Tensor t) {
|
||||
t << v;
|
||||
|
@@ -19,6 +19,8 @@
|
||||
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
// SOFTWARE.
|
||||
|
||||
#include <curand_kernel.h>
|
||||
|
||||
#include "tensor_operators.h"
|
||||
|
||||
using namespace std;
|
||||
@@ -33,6 +35,48 @@ static cublasHandle_t create_handle() {
|
||||
}
|
||||
cublasHandle_t cublasHandle = create_handle();
|
||||
|
||||
__global__ void gDropout(float* out, const float* in,
|
||||
int seed, const float p, int rows, int cols) {
|
||||
|
||||
int shift = blockIdx.x * cols + threadIdx.x;
|
||||
curandState state;
|
||||
curand_init(seed, shift, 0, &state);
|
||||
for(int bid = 0; bid < rows; bid += gridDim.x) {
|
||||
int j = bid + blockIdx.x;
|
||||
if(j < rows) {
|
||||
Float* rowOut = out + j * cols;
|
||||
const Float* rowIn = in + j * cols;
|
||||
|
||||
for(int tid = 0; tid < cols; tid += blockDim.x) {
|
||||
int i = tid + threadIdx.x;
|
||||
if(i < cols) {
|
||||
//int offset = i;
|
||||
float dropout = (curand_uniform(&state) >= p);
|
||||
rowOut[i] = dropout * rowIn[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Slow!!!
|
||||
void Dropout(Tensor out, Tensor in, float p, int seed) {
|
||||
int m = in.shape()[0];
|
||||
int n = in.shape()[1];
|
||||
|
||||
curandGenerator_t prng;
|
||||
curandCreateGenerator(&prng, CURAND_RNG_PSEUDO_XORWOW);
|
||||
curandSetPseudoRandomGeneratorSeed(prng, (unsigned long long) seed);
|
||||
curandGenerateUniform(prng, out.data(), m * n);
|
||||
Element(_1 = (_1 > p), out);
|
||||
Element(_1 = _1 * _2, out, in);
|
||||
//int blocks = std::min(MAX_BLOCKS, m);
|
||||
//int threads = std::min(MAX_THREADS, k);
|
||||
//gDropout<<<blocks, threads>>>(out.data(), in.data(), seed, p, m, k);
|
||||
//cudaStreamSynchronize(0);
|
||||
}
|
||||
|
||||
|
||||
__global__ void gSoftmaxGrad(float* grad, const float* adj, const float* val,
|
||||
const int rows, const int cols) {
|
||||
for(int bid = 0; bid < rows; bid += gridDim.x) {
|
||||
|
@@ -29,140 +29,176 @@ using namespace thrust::placeholders;
|
||||
#define MAX_THREADS 512
|
||||
#define MAX_BLOCKS 65535
|
||||
|
||||
template <class Functor>
|
||||
__global__ void gElement(Functor functor, Float* out,
|
||||
size_t rows, size_t cols) {
|
||||
for(int bid = 0; bid < rows; bid += gridDim.x) {
|
||||
int j = bid + blockIdx.x;
|
||||
if(j < rows) {
|
||||
Float* rowOut = out + j * cols;
|
||||
for(int tid = 0; tid < cols; tid += blockDim.x) {
|
||||
int i = tid + threadIdx.x;
|
||||
if(i < cols)
|
||||
rowOut[i] = functor(rowOut[i]);;
|
||||
}
|
||||
class TensorView {
|
||||
private:
|
||||
float* data_;
|
||||
int rows_;
|
||||
int cols_;
|
||||
|
||||
public:
|
||||
TensorView(Tensor t)
|
||||
: data_(t.data()), rows_(t.shape()[0]), cols_(t.shape()[1]) {}
|
||||
|
||||
__device__ float& operator()(int i, int j) {
|
||||
if(rows_ != 1 && cols_ != 1)
|
||||
return data_[i * cols_ + j];
|
||||
if(rows_ != 1 && cols_ == 1)
|
||||
return data_[i];
|
||||
if(rows_ == 1 && cols_ != 1)
|
||||
return data_[j];
|
||||
return data_[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ int rows() {
|
||||
return rows_;
|
||||
}
|
||||
|
||||
__device__ int cols() {
|
||||
return cols_;
|
||||
}
|
||||
};
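The index operator above is what gives the rewritten Element kernels their row, column and scalar broadcasting. The rule, spelled out (comments only, for illustration):

    // For a TensorView addressed with logical index (i, j):
    //   shape (m, n): data[i * cols + j]   ordinary row-major access
    //   shape (m, 1): data[i]              the single column is reused for every j
    //   shape (1, n): data[j]              the single row is reused for every i
    //   shape (1, 1): data[0]              a scalar is broadcast everywhere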
|
||||
|
||||
//template <class Functor>
|
||||
//__global__ void gElement(Functor functor) {
|
||||
// int rows = out.rows();
|
||||
// int cols = out.cols();
|
||||
// for(int bid = 0; bid < rows; bid += gridDim.x) {
|
||||
// int i = bid + blockIdx.x;
|
||||
// if(i < rows) {
|
||||
// for(int tid = 0; tid < cols; tid += blockDim.x) {
|
||||
// int j = tid + threadIdx.x;
|
||||
// if(j < cols)
|
||||
// functor(i, j);
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
//}
|
||||
|
||||
template <class Functor>
|
||||
__global__ void gElement(Functor functor,
|
||||
Float* out, const Float* in,
|
||||
size_t rows, size_t cols) {
|
||||
TensorView out) {
|
||||
int rows = out.rows();
|
||||
int cols = out.cols();
|
||||
for(int bid = 0; bid < rows; bid += gridDim.x) {
|
||||
int j = bid + blockIdx.x;
|
||||
if(j < rows) {
|
||||
Float* rowOut = out + j * cols;
|
||||
const Float* rowIn = in + j * cols;
|
||||
|
||||
int i = bid + blockIdx.x;
|
||||
if(i < rows) {
|
||||
for(int tid = 0; tid < cols; tid += blockDim.x) {
|
||||
int i = tid + threadIdx.x;
|
||||
if(i < cols)
|
||||
rowOut[i] = functor(rowOut[i], rowIn[i]);;
|
||||
int j = tid + threadIdx.x;
|
||||
if(j < cols)
|
||||
out(i, j) = functor(out(i, j));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class Functor>
|
||||
__global__ void gElement(Functor functor,
|
||||
Float* out, const Float* in1, const Float* in2,
|
||||
size_t rows, size_t cols) {
|
||||
for(int bid = 0; bid < rows; bid += gridDim.x) {
|
||||
int j = bid + blockIdx.x;
|
||||
if(j < rows) {
|
||||
Float* rowOut = out + j * cols;
|
||||
const Float* rowIn1 = in1 + j * cols;
|
||||
const Float* rowIn2 = in2 + j * cols;
|
||||
|
||||
for(int tid = 0; tid < cols; tid += blockDim.x) {
|
||||
int i = tid + threadIdx.x;
|
||||
if(i < cols)
|
||||
rowOut[i] = functor(rowOut[i], rowIn1[i], rowIn2[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class Functor>
|
||||
__global__ void gElement(Functor functor,
|
||||
Float* out, const Float* in1,
|
||||
const Float* in2, const Float* in3,
|
||||
size_t rows, size_t cols) {
|
||||
for(int bid = 0; bid < rows; bid += gridDim.x) {
|
||||
int j = bid + blockIdx.x;
|
||||
if(j < rows) {
|
||||
Float* rowOut = out + j * cols;
|
||||
const Float* rowIn1 = in1 + j * cols;
|
||||
const Float* rowIn2 = in2 + j * cols;
|
||||
const Float* rowIn3 = in3 + j * cols;
|
||||
|
||||
for(int tid = 0; tid < cols; tid += blockDim.x) {
|
||||
int i = tid + threadIdx.x;
|
||||
if(i < cols)
|
||||
rowOut[i] = functor(rowOut[i], rowIn1[i], rowIn2[i], rowIn3[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// @TODO add broadcasting
|
||||
|
||||
template <class Functor>
|
||||
void Element(Functor functor, Tensor Out) {
|
||||
Float* d_out = Out.data();
|
||||
int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]);
|
||||
int threads = std::min(MAX_THREADS, (int)Out.shape()[1]);
|
||||
gElement<<<blocks, threads>>>(functor, d_out,
|
||||
Out.shape()[0], Out.shape()[1]);
|
||||
cudaStreamSynchronize(0);
|
||||
}
|
||||
|
||||
template <class Functor>
|
||||
void Element(Functor functor,
|
||||
Tensor Out, const Tensor In) {
|
||||
Float* d_out = Out.data();
|
||||
const Float* d_in = In.data();
|
||||
Tensor out) {
|
||||
|
||||
int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]);
|
||||
int threads = std::min(MAX_THREADS, (int)Out.shape()[1]);
|
||||
gElement<<<blocks, threads>>>(functor, d_out, d_in,
|
||||
Out.shape()[0], Out.shape()[1]);
|
||||
int m = out.shape()[0];
|
||||
int n = out.shape()[1];
|
||||
|
||||
int blocks = std::min(MAX_BLOCKS, m);
|
||||
int threads = std::min(MAX_THREADS, n);
|
||||
gElement<<<blocks, threads>>>(functor, TensorView(out));
|
||||
cudaStreamSynchronize(0);
|
||||
}
|
||||
|
||||
|
||||
template <class Functor>
|
||||
__global__ void gElement(Functor functor,
|
||||
TensorView out, TensorView in) {
|
||||
int rows = out.rows();
|
||||
int cols = out.cols();
|
||||
for(int bid = 0; bid < rows; bid += gridDim.x) {
|
||||
int i = bid + blockIdx.x;
|
||||
if(i < rows) {
|
||||
for(int tid = 0; tid < cols; tid += blockDim.x) {
|
||||
int j = tid + threadIdx.x;
|
||||
if(j < cols)
|
||||
out(i, j) = functor(out(i, j), in(i, j));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class Functor>
|
||||
void Element(Functor functor,
|
||||
Tensor Out, const Tensor In1, const Tensor In2) {
|
||||
Tensor out, Tensor in) {
|
||||
|
||||
int m = out.shape()[0];
|
||||
int n = out.shape()[1];
|
||||
|
||||
Float* d_out = Out.data();
|
||||
const Float* d_in1 = In1.data();
|
||||
const Float* d_in2 = In2.data();
|
||||
|
||||
int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]);
|
||||
int threads = std::min(MAX_THREADS, (int)Out.shape()[1]);
|
||||
gElement<<<blocks, threads>>>(functor, d_out, d_in1, d_in2,
|
||||
Out.shape()[0], Out.shape()[1]);
|
||||
int blocks = std::min(MAX_BLOCKS, m);
|
||||
int threads = std::min(MAX_THREADS, n);
|
||||
gElement<<<blocks, threads>>>(functor, TensorView(out), TensorView(in));
|
||||
cudaStreamSynchronize(0);
|
||||
}
|
||||
|
||||
template <class Functor>
|
||||
__global__ void gElement(Functor functor,
|
||||
TensorView out, TensorView in1, TensorView in2) {
|
||||
int rows = out.rows();
|
||||
int cols = out.cols();
|
||||
for(int bid = 0; bid < rows; bid += gridDim.x) {
|
||||
int i = bid + blockIdx.x;
|
||||
if(i < rows) {
|
||||
for(int tid = 0; tid < cols; tid += blockDim.x) {
|
||||
int j = tid + threadIdx.x;
|
||||
if(j < cols)
|
||||
out(i, j) = functor(out(i, j), in1(i, j), in2(i, j));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class Functor>
|
||||
void Element(Functor functor,
|
||||
Tensor Out, const Tensor In1,
|
||||
const Tensor In2, const Tensor In3) {
|
||||
Tensor out, Tensor in1, Tensor in2) {
|
||||
|
||||
int m = out.shape()[0];
|
||||
int n = out.shape()[1];
|
||||
|
||||
Float* d_out = Out.data();
|
||||
const Float* d_in1 = In1.data();
|
||||
const Float* d_in2 = In2.data();
|
||||
const Float* d_in3 = In3.data();
|
||||
|
||||
int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]);
|
||||
int threads = std::min(MAX_THREADS, (int)Out.shape()[1]);
|
||||
gElement<<<blocks, threads>>>(functor, d_out, d_in1, d_in2, d_in3,
|
||||
Out.shape()[0], Out.shape()[1]);
|
||||
int blocks = std::min(MAX_BLOCKS, m);
|
||||
int threads = std::min(MAX_THREADS, n);
|
||||
gElement<<<blocks, threads>>>(functor, TensorView(out),
|
||||
TensorView(in1), TensorView(in2));
|
||||
cudaStreamSynchronize(0);
|
||||
}
|
||||
|
||||
template <class Functor>
|
||||
__global__ void gElement(Functor functor,
|
||||
TensorView out, TensorView in1, TensorView in2, TensorView in3) {
|
||||
int rows = out.rows();
|
||||
int cols = out.cols();
|
||||
for(int bid = 0; bid < rows; bid += gridDim.x) {
|
||||
int i = bid + blockIdx.x;
|
||||
if(i < rows) {
|
||||
for(int tid = 0; tid < cols; tid += blockDim.x) {
|
||||
int j = tid + threadIdx.x;
|
||||
if(j < cols)
|
||||
out(i, j) = functor(out(i, j), in1(i, j), in2(i, j), in3(i, j));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class Functor>
|
||||
void Element(Functor functor, Tensor out,
|
||||
Tensor in1, Tensor in2, Tensor in3) {
|
||||
|
||||
int m = out.shape()[0];
|
||||
int n = out.shape()[1];
|
||||
|
||||
int blocks = std::min(MAX_BLOCKS, m);
|
||||
int threads = std::min(MAX_THREADS, n);
|
||||
gElement<<<blocks, threads>>>(functor, TensorView(out),
|
||||
TensorView(in1), TensorView(in2), TensorView(in3));
|
||||
cudaStreamSynchronize(0);
|
||||
}
|
||||
|
||||
void Dropout(Tensor Out, Tensor in, float p, int seed);
|
||||
|
||||
void SubtractMax(Tensor* Out);
|
||||
|
||||
void Softmax(Tensor* Out);
|
||||
|
src/test.cu (151 changed lines)
@@ -20,156 +20,25 @@
|
||||
// SOFTWARE.
|
||||
|
||||
#include <fstream>
|
||||
#include <boost/timer/timer.hpp>
|
||||
|
||||
#include "marian.h"
|
||||
#include "mnist.h"
|
||||
#include "vocab.h"
|
||||
|
||||
#include "tensor_operators.h"
|
||||
|
||||
using namespace std;
|
||||
using namespace marian;
|
||||
using namespace keywords;
|
||||
|
||||
///////////////////////////////////////////////////////
|
||||
string output(const std::vector<float> &vec)
|
||||
{
|
||||
stringstream strm;
|
||||
for (size_t i = 0; i < vec.size(); ++i) {
|
||||
strm << vec[i] << " ";
|
||||
}
|
||||
return strm.str();
|
||||
}
|
||||
|
||||
//void testArgMax()
|
||||
//{
|
||||
// using namespace std;
|
||||
// using namespace marian;
|
||||
//
|
||||
// std::vector<float> hVec({29,19, 49,39, 79,99, 79,39});
|
||||
// cerr << "hVec =" << output(hVec) << endl;
|
||||
//
|
||||
// thrust::device_vector<float> dVec(8);
|
||||
// thrust::copy(hVec.begin(), hVec.end(), dVec.begin());
|
||||
// float *data = thrust::raw_pointer_cast(dVec.data());
|
||||
//
|
||||
// thrust::device_vector<float> dLabel(4);
|
||||
// float *labelPtr = thrust::raw_pointer_cast(dLabel.data());
|
||||
//
|
||||
// gArgMax<<<4, 1, sizeof(float)>>>(labelPtr, data, 4, 2);
|
||||
//
|
||||
// std::vector<float> hVec2(8);
|
||||
// thrust::copy(dVec.begin(), dVec.end(), hVec2.begin());
|
||||
// cerr << "hVec2=" << output(hVec2) << endl;
|
||||
//
|
||||
// std::vector<float> hLabel(4);
|
||||
// thrust::copy(dLabel.begin(), dLabel.end(), hLabel.begin());
|
||||
// cerr << "hLabel=" << output(hLabel) << endl;
|
||||
//
|
||||
// exit(0);
|
||||
//}
|
||||
|
||||
///////////////////////////////////////////////////////
|
||||
int main(int argc, char** argv) {
|
||||
//testArgMax();
|
||||
|
||||
using namespace std;
|
||||
using namespace marian;
|
||||
using namespace keywords;
|
||||
|
||||
Vocab sourceVocab, targetVocab;
|
||||
|
||||
int input_size = 10;
|
||||
int output_size = 2;
|
||||
int batch_size = 25;
|
||||
int hidden_size = 5;
|
||||
int num_inputs = 8;
|
||||
|
||||
std::vector<Expr> inExpr;
|
||||
std::vector<Expr> outExpr;
|
||||
std::vector<Expr> hiddenExpr;
|
||||
|
||||
ExpressionGraph g;
|
||||
|
||||
for (int t = 0; t < num_inputs; ++t) {
|
||||
inExpr.emplace_back(g.input(shape={batch_size, input_size}));
|
||||
outExpr.emplace_back(g.input(shape={batch_size, output_size}));
|
||||
}
|
||||
|
||||
Expr Wxh = g.param(shape={input_size, hidden_size}, name="Wxh");
|
||||
Expr Whh = g.param(shape={hidden_size, hidden_size}, name="Whh");
|
||||
Expr bh = g.param(shape={1, hidden_size}, name="bh");
|
||||
Expr h0 = g.param(shape={1, hidden_size}, name="h0");
|
||||
|
||||
// read parallel corpus from file
|
||||
std::fstream sourceFile("../examples/mt/dev/newstest2013.de");
|
||||
std::fstream targetFile("../examples/mt/dev/newstest2013.en");
|
||||
|
||||
string sourceLine, targetLine;
|
||||
while (getline(sourceFile, sourceLine)) {
|
||||
getline(targetFile, targetLine);
|
||||
std::vector<size_t> sourceIds = sourceVocab.ProcessSentence(sourceLine);
|
||||
std::vector<size_t> targetIds = sourceVocab.ProcessSentence(targetLine);
|
||||
}
|
||||
|
||||
std::cerr << "Building RNN..." << std::endl;
|
||||
hiddenExpr.emplace_back(tanh(dot(inExpr[0], Wxh) + dot(h0, Whh) + bh));
|
||||
for (int t = 1; t < num_inputs; ++t) {
|
||||
hiddenExpr.emplace_back(tanh(dot(inExpr[t], Wxh) + dot(hiddenExpr[t-1], Whh) + bh));
|
||||
}
|
||||
|
||||
Expr Why = g.param(shape={hidden_size, output_size}, name="Why");
|
||||
Expr by = g.param(shape={1, output_size}, name="by");
|
||||
|
||||
std::cerr << "Building output layer..." << std::endl;
|
||||
std::vector<Expr> Yp;
|
||||
|
||||
Yp.emplace_back(softmax(dot(hiddenExpr[0], Why) + by));
|
||||
Expr cross_entropy = sum(outExpr[0] * log(Yp[0]), axis=1);
|
||||
for (int t = 1; t < num_inputs; ++t) {
|
||||
Yp.emplace_back(softmax(dot(hiddenExpr[t], Why) + by));
|
||||
cross_entropy = cross_entropy + sum(outExpr[t] * log(Yp[t]), axis=1);
|
||||
}
|
||||
Expr graph = -mean(cross_entropy, axis=0, name="cost");
|
||||
|
||||
for (int t = 0; t < num_inputs; ++t) {
|
||||
Tensor Xt({batch_size, input_size});
|
||||
Tensor Yt({batch_size, output_size});
|
||||
|
||||
float max = 1.;
|
||||
std::vector<float> values(batch_size * input_size);
|
||||
std::vector<float> classes(batch_size * output_size, 0.0);
|
||||
int k = 0;
|
||||
int l = 0;
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
for (int j = 0; j < input_size; ++j, ++k) {
|
||||
values[k] = max * (2.0*static_cast<float>(rand()) / RAND_MAX - 1.0);
|
||||
}
|
||||
int gold = output_size * static_cast<float>(rand()) / RAND_MAX;
|
||||
classes[l + gold] = 1.0;
|
||||
l += output_size;
|
||||
}
|
||||
|
||||
thrust::copy(values.begin(), values.end(), Xt.begin());
|
||||
thrust::copy(classes.begin(), classes.end(), Yt.begin());
|
||||
|
||||
inExpr[t] = Xt;
|
||||
outExpr[t] = Yt;
|
||||
}
|
||||
|
||||
std::cout << g.graphviz() << std::endl;
|
||||
Tensor a({1000, 1000}, 3);
|
||||
Tensor b({1, 1}, 2);
|
||||
|
||||
g.forward(batch_size);
|
||||
g.backward();
|
||||
|
||||
std::cerr << graph.val().Debug() << std::endl;
|
||||
|
||||
std::cerr << "inExpr[0]=" << inExpr[0].val().Debug() << std::endl;
|
||||
std::cerr << "outExpr[0]=" << outExpr[0].val().Debug() << std::endl;
|
||||
|
||||
std::cerr << "Whh.grad=" << Whh.grad().Debug() << std::endl;
|
||||
std::cerr << "bh.grad=" << bh.grad().Debug() << std::endl;
|
||||
std::cerr << "Why.grad=" << Why.grad().Debug() << std::endl;
|
||||
std::cerr << "by.grad=" << by.grad().Debug() << std::endl;
|
||||
std::cerr << "Wxh.grad=" << Wxh.grad().Debug() << std::endl;
|
||||
std::cerr << "h0.grad=" << h0.grad().Debug() << std::endl;
|
||||
|
||||
boost::timer::cpu_timer timer;
|
||||
for(int i = 0; i < 1000; ++i)
|
||||
Element(_1 += _1 * _2, a, b);
|
||||
std::cerr << timer.format(5, "%ws") << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
@@ -37,13 +37,7 @@ namespace thrust
|
||||
struct unary_exp : public thrust::unary_function<T,T> {
|
||||
__host__ __device__
|
||||
T operator()(const T &x) const {
|
||||
float x2 = x;
|
||||
float clip = 16;
|
||||
if(x2 > clip)
|
||||
x2 = clip;
|
||||
if(x2 < -clip)
|
||||
x2 = -clip;
|
||||
return expf(x2);
|
||||
return expf(x);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -58,10 +52,7 @@ namespace thrust
|
||||
struct unary_log : public thrust::unary_function<T,T> {
|
||||
__host__ __device__
|
||||
T operator()(const T &x) const {
|
||||
float x2 = x;
|
||||
if(x2 < 10e-10)
|
||||
x2 = 10e-10;
|
||||
return logf(x2);
|
||||
return logf(x);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -76,13 +67,7 @@ namespace thrust
|
||||
struct unary_sigma : public thrust::unary_function<T,T> {
|
||||
__host__ __device__
|
||||
T operator()(const T &x) const {
|
||||
float x2 = x;
|
||||
float clip = 16;
|
||||
if(x2 > clip)
|
||||
x2 = clip;
|
||||
if(x2 < -clip)
|
||||
x2 = -clip;
|
||||
return 1.0 / (1.0 + expf(-x2));
|
||||
return 1.0 / (1.0 + expf(-x));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -127,6 +112,33 @@ namespace thrust
|
||||
make_actor(_1),
|
||||
make_actor(_2));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
struct unary_relu : public thrust::unary_function<T,T> {
|
||||
__host__ __device__
|
||||
T operator()(const T &x) const { return x > 0.0f ? x : 0.0f; }
|
||||
};
|
||||
|
||||
template<typename Eval>
|
||||
__host__ __device__
|
||||
actor<composite<unary_operator<unary_relu>, actor<Eval>>>
|
||||
ReLU(const actor<Eval> &_1) {
|
||||
return compose(unary_operator<unary_relu>(), _1);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
struct unary_reluback : public thrust::unary_function<T,T> {
|
||||
__host__ __device__
|
||||
T operator()(const T &x) const { return x > 0.0f ? 1.0f : 0.0f; }
|
||||
};
|
||||
|
||||
template<typename Eval>
|
||||
__host__ __device__
|
||||
actor<composite<unary_operator<unary_reluback>, actor<Eval>>>
|
||||
ReLUback(const actor<Eval> &_1) {
|
||||
return compose(unary_operator<unary_reluback>(), _1);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|