numerically safe softmax in VW

This commit is contained in:
Ales Tamchyna 2016-03-10 17:37:21 +01:00
parent 3e5c0e8667
commit 7b527006c8

View File

@ -2,6 +2,7 @@
#define moses_Normalizer_h
#include <vector>
#include <algorithm>
#include "Util.h"
namespace Discriminative
@ -45,16 +46,25 @@ public:
virtual ~SquaredLossNormalizer() {}
};
// safe softmax
class LogisticLossNormalizer : public Normalizer
{
public:
virtual void operator()(std::vector<float> &losses) const {
float sum = 0;
std::vector<float>::iterator it;
float sum = 0;
float max = 0;
for (it = losses.begin(); it != losses.end(); it++) {
*it = exp(-*it);
*it = -*it;
max = std::max(max, *it);
}
for (it = losses.begin(); it != losses.end(); it++) {
*it = exp(*it - max);
sum += *it;
}
for (it = losses.begin(); it != losses.end(); it++) {
*it /= sum;
}