mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-12-28 14:32:38 +03:00
Changes to Mmsapt.
- added phrase table features: pfwd / pbwd weighted by each sample's weight (when used with biased sampling) - a few minor internal changes ({p|j}stats.add(...) reports back current joint count (intended for early stopping of sampling when evidence is sufficient; yet to be explored experimentally).
This commit is contained in:
parent
4c78e7c0b2
commit
c165e80e48
@ -32,8 +32,8 @@ namespace sapt
|
||||
indoc = other.indoc;
|
||||
for (int i = 0; i <= LRModel::NONE; i++)
|
||||
{
|
||||
ofwd[i] = other.ofwd[i];
|
||||
obwd[i] = other.obwd[i];
|
||||
ofwd[i] = other.ofwd[i];
|
||||
obwd[i] = other.obwd[i];
|
||||
}
|
||||
}
|
||||
|
||||
@ -53,7 +53,7 @@ namespace sapt
|
||||
return obwd[idx];
|
||||
}
|
||||
|
||||
void
|
||||
size_t
|
||||
jstats::
|
||||
add(float w, float b, std::vector<unsigned char> const& a, uint32_t const cnt2,
|
||||
uint32_t fwd_orient, uint32_t bwd_orient, int const docid)
|
||||
@ -65,24 +65,25 @@ namespace sapt
|
||||
my_bcnt += b;
|
||||
if (a.size())
|
||||
{
|
||||
size_t i = 0;
|
||||
while (i < my_aln.size() && my_aln[i].second != a) ++i;
|
||||
if (i == my_aln.size())
|
||||
my_aln.push_back(std::pair<size_t,std::vector<unsigned char> >(1,a));
|
||||
else
|
||||
my_aln[i].first++;
|
||||
if (my_aln[i].first > my_aln[i/2].first)
|
||||
push_heap(my_aln.begin(),my_aln.begin()+i+1);
|
||||
size_t i = 0;
|
||||
while (i < my_aln.size() && my_aln[i].second != a) ++i;
|
||||
if (i == my_aln.size())
|
||||
my_aln.push_back(std::pair<size_t,std::vector<unsigned char> >(1,a));
|
||||
else
|
||||
my_aln[i].first++;
|
||||
if (my_aln[i].first > my_aln[i/2].first)
|
||||
push_heap(my_aln.begin(),my_aln.begin()+i+1);
|
||||
}
|
||||
++ofwd[fwd_orient];
|
||||
++obwd[bwd_orient];
|
||||
if (docid >= 0)
|
||||
{
|
||||
// while (int(indoc.size()) <= docid) indoc.push_back(0);
|
||||
++indoc[docid];
|
||||
// while (int(indoc.size()) <= docid) indoc.push_back(0);
|
||||
++indoc[docid];
|
||||
}
|
||||
return my_rcnt;
|
||||
}
|
||||
|
||||
|
||||
std::vector<std::pair<size_t, std::vector<unsigned char> > > const&
|
||||
jstats::
|
||||
aln() const
|
||||
|
@ -39,7 +39,7 @@ namespace sapt
|
||||
|
||||
std::vector<std::pair<size_t, std::vector<unsigned char> > > const & aln() const;
|
||||
|
||||
void
|
||||
size_t
|
||||
add(float w, float b, std::vector<unsigned char> const& a, uint32_t const cnt2,
|
||||
uint32_t fwd_orient, uint32_t bwd_orient, int const docid);
|
||||
|
||||
|
@ -63,7 +63,7 @@ namespace sapt
|
||||
}
|
||||
}
|
||||
|
||||
bool
|
||||
size_t
|
||||
pstats::
|
||||
add(uint64_t pid, float const w, float const b,
|
||||
std::vector<unsigned char> const& a,
|
||||
@ -73,13 +73,13 @@ namespace sapt
|
||||
{
|
||||
boost::lock_guard<boost::mutex> guard(this->lock);
|
||||
jstats& entry = this->trg[pid];
|
||||
entry.add(w, b, a, cnt2, fwd_o, bwd_o, docid);
|
||||
size_t ret = entry.add(w, b, a, cnt2, fwd_o, bwd_o, docid);
|
||||
if (this->good < entry.rcnt())
|
||||
{
|
||||
UTIL_THROW(util::Exception, "more joint counts than good counts:"
|
||||
<< entry.rcnt() << "/" << this->good << "!");
|
||||
UTIL_THROW(util::Exception, "more joint counts than good counts:"
|
||||
<< entry.rcnt() << "/" << this->good << "!");
|
||||
}
|
||||
return true;
|
||||
return ret;
|
||||
}
|
||||
|
||||
void
|
||||
|
@ -41,7 +41,7 @@ namespace sapt
|
||||
void register_worker();
|
||||
size_t count_workers() { return in_progress; }
|
||||
|
||||
bool
|
||||
size_t
|
||||
add(uint64_t const pid, // target phrase id
|
||||
float const w, // sample weight (1./(# of phrases extractable))
|
||||
float const b, // sample bias score
|
||||
|
@ -45,7 +45,7 @@ BitextSampler : public Moses::reference_counter
|
||||
{
|
||||
typedef Bitext<Token> bitext;
|
||||
typedef TSA<Token> tsa;
|
||||
typedef SamplingBias bias;
|
||||
typedef SamplingBias bias_t;
|
||||
typedef typename Bitext<Token>::iter tsa_iter;
|
||||
mutable boost::condition_variable m_ready;
|
||||
mutable boost::mutex m_lock;
|
||||
@ -59,7 +59,7 @@ BitextSampler : public Moses::reference_counter
|
||||
char const* m_next; // current position
|
||||
char const* m_stop; // end of search range
|
||||
sampling_method const m_method; // look at all/random/ranked samples
|
||||
SPTR<bias const> const m_bias; // bias over candidates
|
||||
SPTR<bias_t const> const m_bias; // bias over candidates
|
||||
size_t const m_samples; // how many samples at most
|
||||
size_t const m_min_samples;
|
||||
// non-const members
|
||||
@ -67,20 +67,20 @@ BitextSampler : public Moses::reference_counter
|
||||
size_t m_ctr; // number of samples considered
|
||||
float m_total_bias; // for random sampling with bias
|
||||
bool m_finished;
|
||||
|
||||
size_t m_num_occurrences; // estimated number of phrase occurrences in corpus
|
||||
boost::taus88 m_rnd; // every job has its own pseudo random generator
|
||||
// double m_rnd_denom; // denominator for scaling random sampling
|
||||
double m_bias_total;
|
||||
|
||||
bool consider_sample(TokenPosition const& p);
|
||||
size_t consider_sample(TokenPosition const& p);
|
||||
size_t perform_random_sampling();
|
||||
size_t perform_full_phrase_extraction();
|
||||
|
||||
int check_sample_distribution(uint64_t const& sid, uint64_t const& offset);
|
||||
bool flip_coin(id_type & sid, ushort & offset);
|
||||
bool flip_coin(id_type const& sid, ushort const& offset, SamplingBias const* bias);
|
||||
|
||||
public:
|
||||
BitextSampler(BitextSampler const& other);
|
||||
BitextSampler const& operator=(BitextSampler const& other);
|
||||
// BitextSampler const& operator=(BitextSampler const& other);
|
||||
BitextSampler(SPTR<bitext const> const& bitext,
|
||||
typename bitext::iter const& phrase,
|
||||
SPTR<SamplingBias const> const& bias,
|
||||
@ -159,9 +159,9 @@ check_sample_distribution(uint64_t const& sid, uint64_t const& offset)
|
||||
template<typename Token>
|
||||
bool
|
||||
BitextSampler<Token>::
|
||||
flip_coin(id_type & sid, ushort & offset)
|
||||
flip_coin(id_type const& sid, ushort const& offset, bias_t const* bias)
|
||||
{
|
||||
int no_maybe_yes = m_bias ? check_sample_distribution(sid, offset) : 1;
|
||||
int no_maybe_yes = bias ? check_sample_distribution(sid, offset) : 1;
|
||||
if (no_maybe_yes == 0) return false; // no
|
||||
if (no_maybe_yes > 1) return true; // yes
|
||||
// ... maybe: flip a coin
|
||||
@ -170,8 +170,8 @@ flip_coin(id_type & sid, ushort & offset)
|
||||
size_t options_left = (options_total - m_ctr);
|
||||
size_t random_number = options_left * (m_rnd()/(m_rnd.max()+1.));
|
||||
size_t threshold;
|
||||
if (m_bias_total) // we have a bias and there are candidates with non-zero prob
|
||||
threshold = ((*m_bias)[sid]/m_bias_total * options_total * m_samples);
|
||||
if (bias && m_bias_total > 0) // we have a bias and there are candidates with non-zero prob
|
||||
threshold = ((*bias)[sid]/m_bias_total * options_total * m_samples);
|
||||
else // no bias, or all have prob 0 (can happen with a very opinionated bias)
|
||||
threshold = m_samples;
|
||||
return random_number + options_chosen < threshold;
|
||||
@ -199,13 +199,12 @@ BitextSampler(SPTR<Bitext<Token> const> const& bitext,
|
||||
, m_ctr(0)
|
||||
, m_total_bias(0)
|
||||
, m_finished(false)
|
||||
, m_num_occurrences(phrase.ca())
|
||||
, m_rnd(0)
|
||||
// , m_rnd_denom(m_rnd.max() + 1)
|
||||
{
|
||||
m_stats.reset(new pstats);
|
||||
m_stats->raw_cnt = phrase.ca();
|
||||
m_stats->register_worker();
|
||||
// cerr << phrase.str(bitext->V1.get()) << " [" << HERE << "]" << endl;
|
||||
}
|
||||
|
||||
template<typename Token>
|
||||
@ -221,8 +220,8 @@ BitextSampler(BitextSampler const& other)
|
||||
, m_bias(other.m_bias)
|
||||
, m_samples(other.m_samples)
|
||||
, m_min_samples(other.m_min_samples)
|
||||
, m_num_occurrences(other.m_num_occurrences)
|
||||
, m_rnd(0)
|
||||
// , m_rnd_denom(m_rnd.max() + 1)
|
||||
{
|
||||
// lock both instances
|
||||
boost::unique_lock<boost::mutex> mylock(m_lock);
|
||||
@ -235,6 +234,23 @@ BitextSampler(BitextSampler const& other)
|
||||
m_finished = other.m_finished;
|
||||
}
|
||||
|
||||
// Uniform sampling
|
||||
template<typename Token>
|
||||
size_t
|
||||
BitextSampler<Token>::
|
||||
perform_full_phrase_extraction()
|
||||
{
|
||||
if (m_next == m_stop) return m_ctr;
|
||||
for (sapt::tsa::ArrayEntry I(m_next); I.next < m_stop; ++m_ctr)
|
||||
{
|
||||
++m_ctr;
|
||||
m_root->readEntry(I.next, I);
|
||||
consider_sample(I);
|
||||
}
|
||||
return m_ctr;
|
||||
}
|
||||
|
||||
|
||||
// Uniform sampling
|
||||
template<typename Token>
|
||||
size_t
|
||||
@ -260,14 +276,14 @@ perform_random_sampling()
|
||||
{
|
||||
++m_ctr;
|
||||
m_root->readEntry(I.next,I);
|
||||
if (!flip_coin(I.sid, I.offset)) continue;
|
||||
if (!flip_coin(I.sid, I.offset, m_bias.get())) continue;
|
||||
consider_sample(I);
|
||||
}
|
||||
return m_ctr;
|
||||
}
|
||||
|
||||
template<typename Token>
|
||||
bool
|
||||
size_t
|
||||
BitextSampler<Token>::
|
||||
consider_sample(TokenPosition const& p)
|
||||
{
|
||||
@ -279,7 +295,7 @@ consider_sample(TokenPosition const& p)
|
||||
if (!m_bitext->find_trg_phr_bounds(rec))
|
||||
{ // no good, probably because phrase is not coherent
|
||||
m_stats->count_sample(docid, 0, rec.po_fwd, rec.po_bwd);
|
||||
return false;
|
||||
return 0;
|
||||
}
|
||||
|
||||
// all good: register this sample as valid
|
||||
@ -300,6 +316,7 @@ consider_sample(TokenPosition const& p)
|
||||
// pair once per source phrase occurrence, or else run the risk of
|
||||
// having more joint counts than marginal counts.
|
||||
|
||||
size_t max_evidence = 0;
|
||||
for (size_t s = rec.s1; s <= rec.s2; ++s)
|
||||
{
|
||||
TSA<Token> const& I = m_fwd ? *m_bitext->I2 : *m_bitext->I1;
|
||||
@ -313,8 +330,10 @@ consider_sample(TokenPosition const& p)
|
||||
continue; // don't over-count
|
||||
seen.push_back(tpid);
|
||||
size_t raw2 = b->approxOccurrenceCount();
|
||||
m_stats->add(tpid, sample_weight, m_bias ? (*m_bias)[p.sid] : 1,
|
||||
aln, raw2, rec.po_fwd, rec.po_bwd, docid);
|
||||
size_t evid = m_stats->add(tpid, sample_weight,
|
||||
m_bias ? (*m_bias)[p.sid] : 1,
|
||||
aln, raw2, rec.po_fwd, rec.po_bwd, docid);
|
||||
max_evidence = std::max(max_evidence, evid);
|
||||
bool ok = (i == rec.e2) || b->extend(o[i].id());
|
||||
UTIL_THROW_IF2(!ok, "Could not extend target phrase.");
|
||||
}
|
||||
@ -322,7 +341,7 @@ consider_sample(TokenPosition const& p)
|
||||
for (size_t k = 1; k < aln.size(); k += 2)
|
||||
--aln[k];
|
||||
}
|
||||
return true;
|
||||
return max_evidence;
|
||||
}
|
||||
|
||||
#ifndef MMT
|
||||
@ -333,7 +352,9 @@ operator()()
|
||||
{
|
||||
if (m_finished) return true;
|
||||
boost::unique_lock<boost::mutex> lock(m_lock);
|
||||
if (m_method == random_sampling)
|
||||
if (m_method == full_coverage)
|
||||
preform_full_phrase_extraction(); // consider all occurrences
|
||||
else if (m_method == random_sampling)
|
||||
perform_random_sampling();
|
||||
else UTIL_THROW2("Unsupported sampling method.");
|
||||
m_finished = true;
|
||||
|
@ -26,7 +26,7 @@ namespace sapt
|
||||
BOOST_FOREACH(char const& x, denom)
|
||||
{
|
||||
if (x == '+') { --checksum; continue; }
|
||||
if (x != 'g' && x != 's' && x != 'r') continue;
|
||||
if (x != 'g' && x != 's' && x != 'r' && x != 'b') continue;
|
||||
std::string s = (boost::format("pbwd-%c%.3f") % x % c).str();
|
||||
this->m_feature_names.push_back(s);
|
||||
}
|
||||
@ -48,9 +48,12 @@ namespace sapt
|
||||
BOOST_FOREACH(char const& x, denom)
|
||||
{
|
||||
uint32_t m2 = pp.raw2;
|
||||
if (x == 'g') m2 = round(m2 * float(pp.good1) / pp.raw1);
|
||||
if (x == 'g' || x == 'b') m2 = round(m2 * float(pp.good1) / pp.raw1);
|
||||
else if (x == 's') m2 = round(m2 * float(pp.sample1) / pp.raw1);
|
||||
(*dest)[i++] = log(lbop(std::max(m2, pp.joint), pp.joint,conf));
|
||||
|
||||
(*dest)[i] = log(lbop(std::max(m2, pp.joint), pp.joint,conf));
|
||||
if (x == 'b') (*dest)[i] += log(pp.cum_bias) - log(pp.joint);
|
||||
++i;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -27,7 +27,7 @@ namespace sapt
|
||||
BOOST_FOREACH(char const& x, denom)
|
||||
{
|
||||
if (x == '+') { --checksum; continue; }
|
||||
if (x != 'g' && x != 's' && x != 'r') continue;
|
||||
if (x != 'g' && x != 's' && x != 'r' && x != 'b') continue;
|
||||
std::string s = (boost::format("pfwd-%c%.3f") % x % c).str();
|
||||
this->m_feature_names.push_back(s);
|
||||
}
|
||||
@ -49,12 +49,16 @@ namespace sapt
|
||||
// cerr<<pp.joint<<"/"<<pp.good1<<"/"<<pp.raw2<<endl;
|
||||
}
|
||||
size_t i = this->m_index;
|
||||
float g = log(lbop(pp.good1, pp.joint, conf));;
|
||||
BOOST_FOREACH(char const& c, this->denom)
|
||||
{
|
||||
switch (c)
|
||||
{
|
||||
case 'b':
|
||||
(*dest)[i++] = g + log(pp.cum_bias) - log(pp.joint);
|
||||
break;
|
||||
case 'g':
|
||||
(*dest)[i++] = log(lbop(pp.good1, pp.joint, conf));
|
||||
(*dest)[i++] = g;
|
||||
break;
|
||||
case 's':
|
||||
(*dest)[i++] = log(lbop(pp.sample1, pp.joint, conf));
|
||||
|
2
regtest
2
regtest
@ -1 +1 @@
|
||||
Subproject commit f69e79f5fc92d993354fa775de197b029d321175
|
||||
Subproject commit e07a00c9733e0fecb8433f1c9d5805d3f0b35c6f
|
Loading…
Reference in New Issue
Block a user