mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-12-25 12:52:29 +03:00
numerically safe softmax in VW
This commit is contained in:
parent
3e5c0e8667
commit
7b527006c8
@ -2,6 +2,7 @@
|
||||
#define moses_Normalizer_h
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "Util.h"
|
||||
|
||||
namespace Discriminative
|
||||
@ -45,16 +46,25 @@ public:
|
||||
virtual ~SquaredLossNormalizer() {}
|
||||
};
|
||||
|
||||
// safe softmax
|
||||
class LogisticLossNormalizer : public Normalizer
|
||||
{
|
||||
public:
|
||||
virtual void operator()(std::vector<float> &losses) const {
|
||||
float sum = 0;
|
||||
std::vector<float>::iterator it;
|
||||
|
||||
float sum = 0;
|
||||
float max = 0;
|
||||
for (it = losses.begin(); it != losses.end(); it++) {
|
||||
*it = exp(-*it);
|
||||
*it = -*it;
|
||||
max = std::max(max, *it);
|
||||
}
|
||||
|
||||
for (it = losses.begin(); it != losses.end(); it++) {
|
||||
*it = exp(*it - max);
|
||||
sum += *it;
|
||||
}
|
||||
|
||||
for (it = losses.begin(); it != losses.end(); it++) {
|
||||
*it /= sum;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user