mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-04 14:04:24 +03:00
some clean-up
This commit is contained in:
parent
3ccdc15263
commit
6f7f0d77f1
@ -94,7 +94,6 @@ Expr broadcast(Shape bShape, Expr a) {
|
||||
"Cannot broadcast tensor dimension "
|
||||
<< dimA << " to " << dimB);
|
||||
if(dimA == 1 && dimB != 1) {
|
||||
std::cerr << "Broadcasting dim " << i << " from " << dimA << " to " << dimB << std::endl;
|
||||
if(i == 0) {
|
||||
Expr one = ones(keywords::shape={bShape[0], 1});
|
||||
a = dot(one, a);
|
||||
|
@ -132,15 +132,21 @@ struct ArgmaxOp : public UnaryNodeOp {
|
||||
|
||||
void forward() {
|
||||
//val_ = Argmax(a_->val(), axis_);
|
||||
UTIL_THROW2("Not implemented");
|
||||
}
|
||||
|
||||
void backward() {}
|
||||
void backward() {
|
||||
UTIL_THROW2("Not implemented");
|
||||
}
|
||||
|
||||
private:
|
||||
int axis_;
|
||||
};
|
||||
|
||||
|
||||
// @TODO, make this numerically safe(r):
|
||||
// softmax(X) = softmax_safe(X - max(X, axis=1))
|
||||
// Probably best to do this directly in Softmax
|
||||
// function.
|
||||
struct SoftmaxNodeOp : public UnaryNodeOp {
|
||||
template <typename ...Args>
|
||||
SoftmaxNodeOp(ChainPtr a, Args ...args)
|
||||
|
@ -52,8 +52,6 @@ class TensorImpl {
|
||||
UTIL_THROW_IF2(shape_.size() < 1 || shape_.size() > 4,
|
||||
"Wrong number of dimensions: " << shape_.size());
|
||||
|
||||
std::cerr << "Allocating : " << shape[0] << " " << shape[1] << std::endl;
|
||||
|
||||
int size = GetTotalSize(shape_);
|
||||
data_.resize(size, value);
|
||||
}
|
||||
|
@ -40,8 +40,8 @@ int main(int argc, char** argv) {
|
||||
auto b = param(shape={1, LABEL_SIZE},
|
||||
init=[bData](Tensor t) { t.set(bData); });
|
||||
|
||||
auto predict = softmax(dot(x, w) + b, axis=1);
|
||||
auto graph = -mean(sum(y * log(predict), axis=1), axis=0);
|
||||
auto probs = softmax(dot(x, w) + b, axis=1);
|
||||
auto graph = -mean(sum(y * log(probs), axis=1), axis=0);
|
||||
|
||||
std::cerr << "Done." << std::endl;
|
||||
|
||||
@ -52,9 +52,24 @@ int main(int argc, char** argv) {
|
||||
y = yt << testLabels;
|
||||
|
||||
graph.forward(BATCH_SIZE);
|
||||
auto results = probs.val();
|
||||
std::vector<float> resultsv(results.size());
|
||||
resultsv << results;
|
||||
|
||||
size_t acc = 0;
|
||||
for (size_t i = 0; i < testLabels.size(); i += LABEL_SIZE) {
|
||||
size_t correct = 0;
|
||||
size_t probsed = 0;
|
||||
for (size_t j = 0; j < LABEL_SIZE; ++j) {
|
||||
if (testLabels[i+j]) correct = j;
|
||||
if (resultsv[i + j] > resultsv[i + probsed]) probsed = j;
|
||||
}
|
||||
acc += (correct == probsed);
|
||||
}
|
||||
std::cerr << "Cost: " << graph.val()[0] << " - Accuracy: " << float(acc) / BATCH_SIZE << std::endl;
|
||||
|
||||
float eta = 0.1;
|
||||
for (size_t j = 0; j < 100; ++j) {
|
||||
for (size_t j = 0; j < 10; ++j) {
|
||||
for(size_t i = 0; i < 60; ++i) {
|
||||
graph.backward();
|
||||
|
||||
@ -65,21 +80,21 @@ int main(int argc, char** argv) {
|
||||
graph.forward(BATCH_SIZE);
|
||||
}
|
||||
std::cerr << "Epoch: " << j << std::endl;
|
||||
auto results = predict.val();
|
||||
auto results = probs.val();
|
||||
std::vector<float> resultsv(results.size());
|
||||
resultsv << results;
|
||||
|
||||
size_t acc = 0;
|
||||
for (size_t i = 0; i < testLabels.size(); i += LABEL_SIZE) {
|
||||
size_t correct = 0;
|
||||
size_t predicted = 0;
|
||||
size_t probsed = 0;
|
||||
for (size_t j = 0; j < LABEL_SIZE; ++j) {
|
||||
if (testLabels[i+j]) correct = j;
|
||||
if (resultsv[i + j] > resultsv[i + predicted]) predicted = j;
|
||||
if (resultsv[i + j] > resultsv[i + probsed]) probsed = j;
|
||||
}
|
||||
acc += (correct == predicted);
|
||||
acc += (correct == probsed);
|
||||
}
|
||||
std::cerr << "Accuracy: " << float(acc) / BATCH_SIZE << std::endl;
|
||||
std::cerr << "Cost: " << graph.val()[0] << " - Accuracy: " << float(acc) / BATCH_SIZE << std::endl;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user