mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
Corrected bug in softmax; implemented cross-entropy node along with a few CUDA functions.
This commit is contained in:
parent
20c80e4bbc
commit
c9dc0cf934
@ -125,4 +125,8 @@ Expr dot(Expr a, Expr b) {
|
||||
return Expr(a.graph(), new DotNodeOp(a, b));
|
||||
}
|
||||
|
||||
Expr cross_entropy(Expr a, Expr b) {
|
||||
return Expr(a.graph(), new CrossEntropyNodeOp(a, b));
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -112,4 +112,6 @@ inline Expr mean(Expr a, Args ...args) {
|
||||
}
|
||||
}
|
||||
|
||||
Expr cross_entropy(Expr a, Expr b);
|
||||
|
||||
}
|
||||
|
@ -43,12 +43,11 @@ ExpressionGraph build_graph(const std::vector<int>& dims) {
|
||||
|
||||
}
|
||||
|
||||
auto probs = named(
|
||||
softmax(dot(layers.back(), weights.back()) + biases.back()),
|
||||
"probs"
|
||||
);
|
||||
auto scores = named(dot(layers.back(), weights.back()) + biases.back(),
|
||||
"scores");
|
||||
|
||||
auto cost = -mean(sum(y * log(probs), axis=1), axis=0);
|
||||
auto cost = mean(cross_entropy(scores, y), axis=0);
|
||||
//auto cost = mean(-sum(y * log(softmax(scores)), axis=1), axis=0);
|
||||
auto costreg = named(
|
||||
cost, "cost"
|
||||
);
|
||||
@ -142,7 +141,7 @@ int main(int argc, char** argv) {
|
||||
g.forward(BATCH_SIZE);
|
||||
|
||||
std::vector<float> bResults;
|
||||
bResults << g["probs"].val();
|
||||
bResults << g["scores"].val();
|
||||
results.insert(results.end(), bResults.begin(), bResults.end());
|
||||
}
|
||||
|
||||
|
@ -151,7 +151,7 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
|
||||
|
||||
void forward() {
|
||||
// B = softmax(A).
|
||||
val_ = a_->val();
|
||||
thrust::copy(a_->val().begin(), a_->val().end(), val_.begin());
|
||||
// Safe version of softmax.
|
||||
Softmax(&val_);
|
||||
}
|
||||
@ -441,4 +441,71 @@ struct DivNodeOp : public BinaryNodeOp {
|
||||
|
||||
};
|
||||
|
||||
// Cross-entropy node. It computes -b*log(softmax(a)), summing rowwise.
|
||||
struct CrossEntropyNodeOp : public BinaryNodeOp {
|
||||
template <typename ...Args>
|
||||
CrossEntropyNodeOp(ChainPtr a, ChainPtr b, Args ...args)
|
||||
: BinaryNodeOp(a, b,
|
||||
keywords::shape=newShape(a, b),
|
||||
args...) { }
|
||||
|
||||
Shape newShape(ChainPtr a, ChainPtr b) {
|
||||
Shape shape1 = a->shape();
|
||||
Shape shape2 = b->shape();
|
||||
UTIL_THROW_IF2(shape1[0] != shape2[0] || shape1[1] != shape2[1],
|
||||
"cross entropy requires dimensions to match");
|
||||
shape1[1] = 1;
|
||||
return shape1;
|
||||
}
|
||||
|
||||
// We're caching the softmax probabilities here because we'll need them for
|
||||
// the backward computation.
|
||||
void forward() {
|
||||
// C = -dot(B, log(softmax(A))).
|
||||
if (probs_) {
|
||||
probs_.set(0.0);
|
||||
} else {
|
||||
probs_.allocate(a_->val().shape(), 0.0);
|
||||
}
|
||||
thrust::copy(a_->val().begin(), a_->val().end(), probs_.begin());
|
||||
Softmax(&probs_); // Safe version of softmax.
|
||||
Tensor result(a_->val().shape());
|
||||
Element(_1 = -_2 * Log(_3), result, b_->val(), probs_);
|
||||
SumRowwise(result, val_);
|
||||
}
|
||||
|
||||
// @TODO: In most cases it's wasteful to compute the derivative with respect
|
||||
// to the second input which is typically an input node in the computation
|
||||
// graph. In general the backward functions can skip the computation of
|
||||
// gradients wrt input nodes.
|
||||
void backward() {
|
||||
// For each row, the first input derivative is given by adj * (p - y),
|
||||
// where y is the gold label distribution (e.g. one hot vector) and
|
||||
// p is the softmax output (probabilities).
|
||||
// The second input derivative is -adj*log(p).
|
||||
Tensor result(probs_.shape());
|
||||
|
||||
// Compute first input derivative.
|
||||
Element(_1 = _2 - _3, result, probs_, b_->val());
|
||||
ScaleRowwise(result, adj_);
|
||||
Element(_1 += _2, a_->grad(), result);
|
||||
|
||||
// Compute second input derivative.
|
||||
Element(_1 = -Log(_2), result, probs_); // @TODO: use a cached log here.
|
||||
ScaleRowwise(result, adj_);
|
||||
Element(_1 += _2, b_->grad(), result);
|
||||
}
|
||||
|
||||
virtual std::string graphviz() {
|
||||
std::stringstream ss;
|
||||
ss << "\"" << this << "\" [shape=\"box\", label=\"cross_entropy\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
protected:
|
||||
Tensor probs_;
|
||||
|
||||
};
|
||||
|
||||
}
|
||||
|
@ -227,4 +227,48 @@ Tensor Prod(Tensor C, const Tensor A, const Tensor B,
|
||||
return temp;
|
||||
}
|
||||
|
||||
Tensor SumRowwise(cublasHandle_t handle, const Tensor A, Tensor result) {
|
||||
size_t rows = A.shape()[0];
|
||||
size_t cols = A.shape()[1];
|
||||
thrust::device_vector<float> d_ones(cols, 1.f);
|
||||
Float alpha = 1.f;
|
||||
Float beta = 0.f;
|
||||
cublasSgemv(handle, CUBLAS_OP_T, cols, rows, &alpha,
|
||||
A.data(), cols,
|
||||
thrust::raw_pointer_cast(d_ones.data()), 1, &beta,
|
||||
result.data(), 1);
|
||||
return result;
|
||||
}
|
||||
|
||||
Tensor SumRowwise(const Tensor A, Tensor result) {
|
||||
Tensor temp = SumRowwise(cublasHandle, A, result);
|
||||
return temp;
|
||||
}
|
||||
|
||||
// @TODO: replace this by something else when broadcast elementwise operations
|
||||
// are ready.
|
||||
__global__ void gScaleRowwise(Float* out, const Float* scalingFactors,
|
||||
size_t rows, size_t cols) {
|
||||
for(int bid = 0; bid < rows; bid += gridDim.x) {
|
||||
int j = bid + blockIdx.x;
|
||||
if(j < rows) {
|
||||
Float* rowOut = out + j * cols;
|
||||
for(int tid = 0; tid < cols; tid += blockDim.x) {
|
||||
int i = tid + threadIdx.x;
|
||||
if(i < cols) rowOut[i] *= scalingFactors[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ScaleRowwise(Tensor Out, const Tensor ScalingFactors) {
|
||||
Float* d_out = Out.data();
|
||||
const Float* d_in = ScalingFactors.data();
|
||||
int blocks = std::min(MAX_BLOCKS, (int)Out.shape()[0]);
|
||||
int threads = std::min(MAX_THREADS, (int)Out.shape()[1]);
|
||||
gScaleRowwise<<<blocks, threads>>>(d_out, d_in,
|
||||
Out.shape()[0], Out.shape()[1]);
|
||||
cudaStreamSynchronize(0);
|
||||
}
|
||||
|
||||
}
|
@ -165,4 +165,13 @@ Tensor Prod(cublasHandle_t handle, Tensor C, const Tensor A, const Tensor B,
|
||||
Tensor Prod(Tensor C, const Tensor A, const Tensor B,
|
||||
bool transA, bool transB, Float beta = 0);
|
||||
|
||||
Tensor SumRowwise(cublasHandle_t handle, const Tensor A, Tensor result);
|
||||
|
||||
Tensor SumRowwise(const Tensor A, Tensor result);
|
||||
|
||||
__global__ void gScaleRowwise(Float* out, const Float* scalingFactors,
|
||||
size_t rows, size_t cols);
|
||||
|
||||
void ScaleRowwise(Tensor Out, const Tensor ScalingFactors);
|
||||
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user