mosesdecoder/mert/MiraWeightVector.cpp

192 lines
3.6 KiB
C++
Raw Normal View History

#include "MiraWeightVector.h"
2012-06-26 19:40:16 +04:00
#include <cmath>
using namespace std;
namespace MosesTuning
{
2013-05-29 21:16:15 +04:00
/**
 * Default constructor, initializes to the zero vector.
 * The weight, total and last-updated vectors all start out empty and the
 * update counter starts at zero.
 */
MiraWeightVector::MiraWeightVector()
  : m_numUpdates(0)
{
  // m_weights, m_totals and m_lastUpdated are default-constructed (empty).
}
/**
 * Constructor with provided initial vector.
 * The running totals are seeded with the same values as the weights, and
 * every index is recorded as last updated at step zero.
 * \param init Initial feature values
 */
MiraWeightVector::MiraWeightVector(const vector<ValType>& init)
  : m_weights(init),
    m_totals(init),
    m_lastUpdated(init.size())   // value-initialized, i.e. all zeros
{
  m_numUpdates = 0;
}
/**
* Update a the model
* \param fv Feature vector to be added to the weights
* \param tau FV will be scaled by this value before update
*/
2013-05-29 21:16:15 +04:00
void MiraWeightVector::update(const MiraFeatureVector& fv, float tau)
{
m_numUpdates++;
2013-05-29 21:16:15 +04:00
for(size_t i=0; i<fv.size(); i++) {
update(fv.feat(i), fv.val(i)*tau);
}
}
/**
* Perform an empty update (affects averaging)
*/
2013-05-29 21:16:15 +04:00
void MiraWeightVector::tick()
{
m_numUpdates++;
}
/**
* Score a feature vector according to the model
* \param fv Feature vector to be scored
*/
2013-05-29 21:16:15 +04:00
ValType MiraWeightVector::score(const MiraFeatureVector& fv) const
{
ValType toRet = 0.0;
for(size_t i=0; i<fv.size(); i++) {
toRet += weight(fv.feat(i)) * fv.val(i);
}
return toRet;
}
/**
* Return an averaged view of this weight vector
*/
2013-05-29 21:16:15 +04:00
AvgWeightVector MiraWeightVector::avg()
{
this->fixTotals();
return AvgWeightVector(*this);
}
/**
* Updates a weight and lazily updates its total
*/
2013-05-29 21:16:15 +04:00
void MiraWeightVector::update(size_t index, ValType delta)
{
// Handle previously unseen weights
while(index>=m_weights.size()) {
m_weights.push_back(0.0);
m_totals.push_back(0.0);
m_lastUpdated.push_back(0);
}
// Book keeping for w = w + delta
m_totals[index] += (m_numUpdates - m_lastUpdated[index]) * m_weights[index] + delta;
m_weights[index] += delta;
m_lastUpdated[index] = m_numUpdates;
}
2015-01-14 14:07:42 +03:00
/**
 * Copy the current weights into a sparse vector.
 * Only weights whose magnitude exceeds 1e-8 are written; each one is
 * passed to sparse->set() together with its index.
 * \param sparse Destination vector; dereferenced without a null check
 */
void MiraWeightVector::ToSparse(SparseVector* sparse) const
{
  const size_t numWeights = m_weights.size();
  for (size_t idx = 0; idx < numWeights; ++idx) {
    const ValType w = m_weights[idx];
    if (abs(w) > 1e-8) {
      sparse->set(idx, w);
    }
  }
}
/**
* Make sure everyone's total is up-to-date
*/
2013-05-29 21:16:15 +04:00
void MiraWeightVector::fixTotals()
{
for(size_t i=0; i<m_weights.size(); i++) update(i,0);
}
/**
* Helper to handle out of range weights
*/
2013-05-29 21:16:15 +04:00
ValType MiraWeightVector::weight(size_t index) const
{
if(index < m_weights.size()) {
return m_weights[index];
2013-05-29 21:16:15 +04:00
} else {
return 0;
}
}
2013-05-29 21:16:15 +04:00
/**
 * Squared norm of the current weights.
 * \return Sum of the squares of all stored weights
 */
ValType MiraWeightVector::sqrNorm() const
{
  ValType sumSquares = 0;
  for (size_t idx = 0; idx < m_weights.size(); ++idx) {
    const ValType w = weight(idx);
    sumSquares += w * w;
  }
  return sumSquares;
}
/**
 * Build an averaged weight vector on top of an existing weight vector.
 * Initializes m_wv from wv; whether m_wv copies wv or refers to it is
 * determined by its declaration, which is not part of this file.
 * \param wv Weight vector whose totals and update count are read
 */
AvgWeightVector::AvgWeightVector(const MiraWeightVector& wv)
  : m_wv(wv)
{
}
2013-05-29 21:16:15 +04:00
/**
 * Write the non-negligible weights to a stream as space-separated
 * "index:value" pairs.
 * Weights whose magnitude does not exceed 1e-8 are omitted. The separator
 * is emitted only between printed pairs, so the output never starts with
 * a space, even when the weight at index 0 is omitted.
 * \param o Stream to write to
 * \param e Weight vector to print
 * \return o, so that insertions can be chained
 */
ostream& operator<<(ostream& o, const MiraWeightVector& e)
{
  // Tracks whether a pair has already been written; keying the separator
  // on the index instead would print a leading space whenever index 0 is
  // skipped.
  bool wroteAny = false;
  for (size_t i = 0; i < e.m_weights.size(); ++i) {
    if (abs(e.m_weights[i]) > 1e-8) {
      if (wroteAny) o << " ";
      o << i << ":" << e.m_weights[i];
      wroteAny = true;
    }
  }
  return o;
}
/**
 * Averaged weight for one index.
 * Reads the underlying totals as they currently are; this method does not
 * refresh them.
 * \param index Index of the weight to read
 * \return The underlying raw weight when no update has been counted yet;
 *         otherwise the stored total divided by the number of updates,
 *         or 0 when index lies beyond the stored totals
 */
ValType AvgWeightVector::weight(size_t index) const
{
  // Nothing to average over yet: fall back to the raw weight.
  if (m_wv.m_numUpdates == 0) {
    return m_wv.weight(index);
  }
  if (index >= m_wv.m_totals.size()) {
    return 0;
  }
  return m_wv.m_totals[index] / m_wv.m_numUpdates;
}
2013-05-29 21:16:15 +04:00
/**
 * Score a feature vector using the averaged weights.
 * \param fv Feature vector to be scored
 * \return Sum, over the entries of fv, of the averaged weight of the
 *         entry's feature multiplied by the entry's value
 */
ValType AvgWeightVector::score(const MiraFeatureVector& fv) const
{
  ValType dot = 0.0;
  const size_t numEntries = fv.size();
  for (size_t pos = 0; pos < numEntries; ++pos) {
    dot += weight(fv.feat(pos)) * fv.val(pos);
  }
  return dot;
}
2013-05-29 21:16:15 +04:00
/**
 * Number of weights held by the underlying weight vector.
 * \return Size of the underlying m_weights vector
 */
size_t AvgWeightVector::size() const
{
  const size_t numWeights = m_wv.m_weights.size();
  return numWeights;
}
2015-01-14 14:07:42 +03:00
/**
 * Copy the averaged weights into a sparse vector.
 * Only averaged weights whose magnitude exceeds 1e-8 are written; each
 * one is passed to sparse->set() together with its index.
 * \param sparse Destination vector; dereferenced without a null check
 */
void AvgWeightVector::ToSparse(SparseVector* sparse) const
{
  const size_t numWeights = size();
  for (size_t idx = 0; idx < numWeights; ++idx) {
    const ValType averaged = weight(idx);
    if (abs(averaged) > 1e-8) {
      sparse->set(idx, averaged);
    }
  }
}
// --Emacs trickery--
// Local Variables:
// mode:c++
// c-basic-offset:2
// End:
}