mosesdecoder/vw/Normalizer.h

79 lines
1.7 KiB
C
Raw Normal View History

2015-01-06 19:49:23 +03:00
#ifndef moses_Normalizer_h
#define moses_Normalizer_h
#include <vector>
2016-03-10 19:37:21 +03:00
#include <algorithm>
2015-01-06 19:49:23 +03:00
#include "Util.h"
namespace Discriminative
{
2015-01-14 14:07:42 +03:00
/**
 * @brief Interface for functors that rewrite a vector of per-candidate
 *        losses in place.
 *
 * Implementations are invoked through operator() and may overwrite every
 * element of the vector they are given.
 */
class Normalizer
{
public:
  /**
   * @brief Rewrite @p losses in place.
   * @param losses One value per candidate; overwritten by the implementation.
   */
  virtual void operator()(std::vector<float> &losses) const = 0;

  /// Virtual so that a concrete normalizer can be destroyed through a
  /// Normalizer pointer or reference.
  virtual ~Normalizer()
  {
  }
};
2015-01-14 14:07:42 +03:00
class SquaredLossNormalizer : public Normalizer
{
2015-01-06 19:49:23 +03:00
public:
2015-01-14 14:07:42 +03:00
virtual void operator()(std::vector<float> &losses) const {
2015-01-06 19:49:23 +03:00
// This is (?) a good choice for sqrt loss (default loss function in VW)
float sum = 0;
// clip to [0,1] and take 1-Z as non-normalized prob
std::vector<float>::iterator it;
for (it = losses.begin(); it != losses.end(); it++) {
if (*it <= 0.0) *it = 1.0;
else if (*it >= 1.0) *it = 0.0;
else *it = 1.0 - *it;
sum += *it;
}
if (! Moses::Equals(sum, 0)) {
// normalize
for (it = losses.begin(); it != losses.end(); it++)
*it /= sum;
} else {
// sum of non-normalized probs is 0, then take uniform probs
for (it = losses.begin(); it != losses.end(); it++)
*it = 1.0 / losses.size();
}
}
2015-01-09 14:14:17 +03:00
virtual ~SquaredLossNormalizer() {}
2015-01-06 19:49:23 +03:00
};
2016-03-10 19:37:21 +03:00
// safe softmax
2015-01-14 14:07:42 +03:00
/**
 * @brief Turns losses into probabilities with a numerically stable softmax
 *        over the negated losses.
 *
 *   p_i = exp(-l_i - m) / sum_j exp(-l_j - m),   m = max_j(-l_j)
 *
 * A lower loss therefore gets a higher probability. Subtracting the true
 * maximum m does not change the result mathematically. It does guarantee that
 * the largest term is exp(0) == 1, so exp() can neither overflow nor underflow
 * every term to zero.
 */
class LogisticLossNormalizer : public Normalizer
{
public:
  /**
   * @brief Replace @p losses with softmax(-losses), in place.
   * @param losses Per-candidate losses. They are overwritten with
   *               probabilities summing to 1 (for finite input). An empty
   *               vector is left unchanged.
   */
  virtual void operator()(std::vector<float> &losses) const {
    if (losses.empty())
      return;  // nothing to normalize; also keeps front() below valid

    std::vector<float>::iterator it;

    // Negate (lower loss => higher score) and track the true maximum score.
    // The maximum must be seeded from the data, not from 0. Otherwise, when
    // all losses are positive, the shift is 0 and large losses underflow
    // exp() to 0 for every entry, yielding 0/0 = NaN.
    float max = -losses.front();
    for (it = losses.begin(); it != losses.end(); ++it) {
      *it = -*it;
      max = std::max(max, *it);
    }

    float sum = 0;
    for (it = losses.begin(); it != losses.end(); ++it) {
      *it = exp(*it - max);
      sum += *it;
    }

    // For finite input the maximal element contributes exp(0) == 1, so sum >= 1.
    for (it = losses.begin(); it != losses.end(); ++it) {
      *it /= sum;
    }
  }

  virtual ~LogisticLossNormalizer() {}
};
} // namespace Discriminative
#endif // moses_Normalizer_h