2015-01-06 19:49:23 +03:00
|
|
|
#ifndef moses_Normalizer_h
|
|
|
|
#define moses_Normalizer_h
|
|
|
|
|
|
|
|
#include <vector>
|
2016-03-10 19:37:21 +03:00
|
|
|
#include <algorithm>
#include <cmath>
|
2015-01-06 19:49:23 +03:00
|
|
|
#include "Util.h"
|
|
|
|
|
|
|
|
namespace Discriminative
|
|
|
|
{
|
|
|
|
|
2015-01-14 14:07:42 +03:00
|
|
|
/// Abstract interface for objects that rewrite a vector of per-candidate
/// losses in place (the subclasses in this header turn them into a
/// probability distribution).
class Normalizer {
public:
  /// Rewrites `losses` in place; supplied by each concrete subclass.
  virtual void operator()(std::vector<float> &losses) const = 0;

  /// Virtual so that deleting a subclass through a Normalizer pointer
  /// runs the subclass destructor.
  virtual ~Normalizer()
  {
  }
};
|
|
|
|
|
2015-01-14 14:07:42 +03:00
|
|
|
class SquaredLossNormalizer : public Normalizer
|
|
|
|
{
|
2015-01-06 19:49:23 +03:00
|
|
|
public:
|
2015-01-14 14:07:42 +03:00
|
|
|
virtual void operator()(std::vector<float> &losses) const {
|
2015-01-06 19:49:23 +03:00
|
|
|
// This is (?) a good choice for sqrt loss (default loss function in VW)
|
|
|
|
|
|
|
|
float sum = 0;
|
|
|
|
|
|
|
|
// clip to [0,1] and take 1-Z as non-normalized prob
|
|
|
|
std::vector<float>::iterator it;
|
|
|
|
for (it = losses.begin(); it != losses.end(); it++) {
|
|
|
|
if (*it <= 0.0) *it = 1.0;
|
|
|
|
else if (*it >= 1.0) *it = 0.0;
|
|
|
|
else *it = 1.0 - *it;
|
|
|
|
sum += *it;
|
|
|
|
}
|
|
|
|
|
|
|
|
if (! Moses::Equals(sum, 0)) {
|
|
|
|
// normalize
|
|
|
|
for (it = losses.begin(); it != losses.end(); it++)
|
|
|
|
*it /= sum;
|
|
|
|
} else {
|
|
|
|
// sum of non-normalized probs is 0, then take uniform probs
|
|
|
|
for (it = losses.begin(); it != losses.end(); it++)
|
|
|
|
*it = 1.0 / losses.size();
|
|
|
|
}
|
|
|
|
}
|
2015-01-09 14:14:17 +03:00
|
|
|
|
|
|
|
virtual ~SquaredLossNormalizer() {}
|
2015-01-06 19:49:23 +03:00
|
|
|
};
|
|
|
|
|
2016-03-10 19:37:21 +03:00
|
|
|
// safe softmax
|
2015-01-14 14:07:42 +03:00
|
|
|
class LogisticLossNormalizer : public Normalizer
|
|
|
|
{
|
2015-01-06 19:49:23 +03:00
|
|
|
public:
|
2015-01-14 14:07:42 +03:00
|
|
|
virtual void operator()(std::vector<float> &losses) const {
|
2015-01-06 19:49:23 +03:00
|
|
|
std::vector<float>::iterator it;
|
2016-03-10 19:37:21 +03:00
|
|
|
|
|
|
|
float sum = 0;
|
|
|
|
float max = 0;
|
2015-01-06 19:49:23 +03:00
|
|
|
for (it = losses.begin(); it != losses.end(); it++) {
|
2016-03-10 19:37:21 +03:00
|
|
|
*it = -*it;
|
|
|
|
max = std::max(max, *it);
|
|
|
|
}
|
|
|
|
|
|
|
|
for (it = losses.begin(); it != losses.end(); it++) {
|
|
|
|
*it = exp(*it - max);
|
2015-01-06 19:49:23 +03:00
|
|
|
sum += *it;
|
|
|
|
}
|
2016-03-10 19:37:21 +03:00
|
|
|
|
2015-01-06 19:49:23 +03:00
|
|
|
for (it = losses.begin(); it != losses.end(); it++) {
|
|
|
|
*it /= sum;
|
|
|
|
}
|
|
|
|
}
|
2015-01-09 14:14:17 +03:00
|
|
|
|
|
|
|
virtual ~LogisticLossNormalizer() {}
|
2015-01-06 19:49:23 +03:00
|
|
|
};
|
|
|
|
|
|
|
|
} // namespace Discriminative
|
|
|
|
|
|
|
|
#endif // moses_Normalizer_h
|