Changes to Mmsapt.

- added phrase table features: pfwd / pbwd weighted by each sample's
  weight (when used with biased sampling)
- a few minor internal changes: {p|j}stats.add(...) now reports back the
  current joint count (intended for early stopping of sampling when
  evidence is sufficient; yet to be explored experimentally).
This commit is contained in:
Ulrich Germann 2015-11-29 18:13:25 +00:00
parent 4c78e7c0b2
commit c165e80e48
8 changed files with 77 additions and 48 deletions

View File

@ -32,8 +32,8 @@ namespace sapt
indoc = other.indoc;
for (int i = 0; i <= LRModel::NONE; i++)
{
ofwd[i] = other.ofwd[i];
obwd[i] = other.obwd[i];
ofwd[i] = other.ofwd[i];
obwd[i] = other.obwd[i];
}
}
@ -53,7 +53,7 @@ namespace sapt
return obwd[idx];
}
void
size_t
jstats::
add(float w, float b, std::vector<unsigned char> const& a, uint32_t const cnt2,
uint32_t fwd_orient, uint32_t bwd_orient, int const docid)
@ -65,24 +65,25 @@ namespace sapt
my_bcnt += b;
if (a.size())
{
size_t i = 0;
while (i < my_aln.size() && my_aln[i].second != a) ++i;
if (i == my_aln.size())
my_aln.push_back(std::pair<size_t,std::vector<unsigned char> >(1,a));
else
my_aln[i].first++;
if (my_aln[i].first > my_aln[i/2].first)
push_heap(my_aln.begin(),my_aln.begin()+i+1);
size_t i = 0;
while (i < my_aln.size() && my_aln[i].second != a) ++i;
if (i == my_aln.size())
my_aln.push_back(std::pair<size_t,std::vector<unsigned char> >(1,a));
else
my_aln[i].first++;
if (my_aln[i].first > my_aln[i/2].first)
push_heap(my_aln.begin(),my_aln.begin()+i+1);
}
++ofwd[fwd_orient];
++obwd[bwd_orient];
if (docid >= 0)
{
// while (int(indoc.size()) <= docid) indoc.push_back(0);
++indoc[docid];
// while (int(indoc.size()) <= docid) indoc.push_back(0);
++indoc[docid];
}
return my_rcnt;
}
std::vector<std::pair<size_t, std::vector<unsigned char> > > const&
jstats::
aln() const

View File

@ -39,7 +39,7 @@ namespace sapt
std::vector<std::pair<size_t, std::vector<unsigned char> > > const & aln() const;
void
size_t
add(float w, float b, std::vector<unsigned char> const& a, uint32_t const cnt2,
uint32_t fwd_orient, uint32_t bwd_orient, int const docid);

View File

@ -63,7 +63,7 @@ namespace sapt
}
}
bool
size_t
pstats::
add(uint64_t pid, float const w, float const b,
std::vector<unsigned char> const& a,
@ -73,13 +73,13 @@ namespace sapt
{
boost::lock_guard<boost::mutex> guard(this->lock);
jstats& entry = this->trg[pid];
entry.add(w, b, a, cnt2, fwd_o, bwd_o, docid);
size_t ret = entry.add(w, b, a, cnt2, fwd_o, bwd_o, docid);
if (this->good < entry.rcnt())
{
UTIL_THROW(util::Exception, "more joint counts than good counts:"
<< entry.rcnt() << "/" << this->good << "!");
UTIL_THROW(util::Exception, "more joint counts than good counts:"
<< entry.rcnt() << "/" << this->good << "!");
}
return true;
return ret;
}
void

View File

@ -41,7 +41,7 @@ namespace sapt
void register_worker();
size_t count_workers() { return in_progress; }
bool
size_t
add(uint64_t const pid, // target phrase id
float const w, // sample weight (1./(# of phrases extractable))
float const b, // sample bias score

View File

@ -45,7 +45,7 @@ BitextSampler : public Moses::reference_counter
{
typedef Bitext<Token> bitext;
typedef TSA<Token> tsa;
typedef SamplingBias bias;
typedef SamplingBias bias_t;
typedef typename Bitext<Token>::iter tsa_iter;
mutable boost::condition_variable m_ready;
mutable boost::mutex m_lock;
@ -59,7 +59,7 @@ BitextSampler : public Moses::reference_counter
char const* m_next; // current position
char const* m_stop; // end of search range
sampling_method const m_method; // look at all/random/ranked samples
SPTR<bias const> const m_bias; // bias over candidates
SPTR<bias_t const> const m_bias; // bias over candidates
size_t const m_samples; // how many samples at most
size_t const m_min_samples;
// non-const members
@ -67,20 +67,20 @@ BitextSampler : public Moses::reference_counter
size_t m_ctr; // number of samples considered
float m_total_bias; // for random sampling with bias
bool m_finished;
size_t m_num_occurrences; // estimated number of phrase occurrences in corpus
boost::taus88 m_rnd; // every job has its own pseudo random generator
// double m_rnd_denom; // denominator for scaling random sampling
double m_bias_total;
bool consider_sample(TokenPosition const& p);
size_t consider_sample(TokenPosition const& p);
size_t perform_random_sampling();
size_t perform_full_phrase_extraction();
int check_sample_distribution(uint64_t const& sid, uint64_t const& offset);
bool flip_coin(id_type & sid, ushort & offset);
bool flip_coin(id_type const& sid, ushort const& offset, SamplingBias const* bias);
public:
BitextSampler(BitextSampler const& other);
BitextSampler const& operator=(BitextSampler const& other);
// BitextSampler const& operator=(BitextSampler const& other);
BitextSampler(SPTR<bitext const> const& bitext,
typename bitext::iter const& phrase,
SPTR<SamplingBias const> const& bias,
@ -159,9 +159,9 @@ check_sample_distribution(uint64_t const& sid, uint64_t const& offset)
template<typename Token>
bool
BitextSampler<Token>::
flip_coin(id_type & sid, ushort & offset)
flip_coin(id_type const& sid, ushort const& offset, bias_t const* bias)
{
int no_maybe_yes = m_bias ? check_sample_distribution(sid, offset) : 1;
int no_maybe_yes = bias ? check_sample_distribution(sid, offset) : 1;
if (no_maybe_yes == 0) return false; // no
if (no_maybe_yes > 1) return true; // yes
// ... maybe: flip a coin
@ -170,8 +170,8 @@ flip_coin(id_type & sid, ushort & offset)
size_t options_left = (options_total - m_ctr);
size_t random_number = options_left * (m_rnd()/(m_rnd.max()+1.));
size_t threshold;
if (m_bias_total) // we have a bias and there are candidates with non-zero prob
threshold = ((*m_bias)[sid]/m_bias_total * options_total * m_samples);
if (bias && m_bias_total > 0) // we have a bias and there are candidates with non-zero prob
threshold = ((*bias)[sid]/m_bias_total * options_total * m_samples);
else // no bias, or all have prob 0 (can happen with a very opinionated bias)
threshold = m_samples;
return random_number + options_chosen < threshold;
@ -199,13 +199,12 @@ BitextSampler(SPTR<Bitext<Token> const> const& bitext,
, m_ctr(0)
, m_total_bias(0)
, m_finished(false)
, m_num_occurrences(phrase.ca())
, m_rnd(0)
// , m_rnd_denom(m_rnd.max() + 1)
{
m_stats.reset(new pstats);
m_stats->raw_cnt = phrase.ca();
m_stats->register_worker();
// cerr << phrase.str(bitext->V1.get()) << " [" << HERE << "]" << endl;
}
template<typename Token>
@ -221,8 +220,8 @@ BitextSampler(BitextSampler const& other)
, m_bias(other.m_bias)
, m_samples(other.m_samples)
, m_min_samples(other.m_min_samples)
, m_num_occurrences(other.m_num_occurrences)
, m_rnd(0)
// , m_rnd_denom(m_rnd.max() + 1)
{
// lock both instances
boost::unique_lock<boost::mutex> mylock(m_lock);
@ -235,6 +234,23 @@ BitextSampler(BitextSampler const& other)
m_finished = other.m_finished;
}
// Full phrase extraction: walk over every entry in [m_next, m_stop) and
// hand each one to consider_sample(); no sampling decision (flip_coin) is
// made here.
//
// Returns m_ctr, the number of samples considered so far. If the search
// range is empty, m_ctr is returned unchanged.
template<typename Token>
size_t
BitextSampler<Token>::
perform_full_phrase_extraction()
{
  if (m_next == m_stop) return m_ctr;
  // m_ctr is incremented exactly once per entry, in the loop body (the same
  // convention perform_random_sampling() uses). The for-header deliberately
  // has no iteration expression: incrementing there as well would count
  // every entry twice.
  // NOTE(review): loop progress relies on readEntry() advancing I.next --
  // confirm against the TSA implementation.
  for (sapt::tsa::ArrayEntry I(m_next); I.next < m_stop; )
    {
      ++m_ctr;
      m_root->readEntry(I.next, I);
      consider_sample(I);
    }
  return m_ctr;
}
// Uniform sampling
template<typename Token>
size_t
@ -260,14 +276,14 @@ perform_random_sampling()
{
++m_ctr;
m_root->readEntry(I.next,I);
if (!flip_coin(I.sid, I.offset)) continue;
if (!flip_coin(I.sid, I.offset, m_bias.get())) continue;
consider_sample(I);
}
return m_ctr;
}
template<typename Token>
bool
size_t
BitextSampler<Token>::
consider_sample(TokenPosition const& p)
{
@ -279,7 +295,7 @@ consider_sample(TokenPosition const& p)
if (!m_bitext->find_trg_phr_bounds(rec))
{ // no good, probably because phrase is not coherent
m_stats->count_sample(docid, 0, rec.po_fwd, rec.po_bwd);
return false;
return 0;
}
// all good: register this sample as valid
@ -300,6 +316,7 @@ consider_sample(TokenPosition const& p)
// pair once per source phrase occurrence, or else run the risk of
// having more joint counts than marginal counts.
size_t max_evidence = 0;
for (size_t s = rec.s1; s <= rec.s2; ++s)
{
TSA<Token> const& I = m_fwd ? *m_bitext->I2 : *m_bitext->I1;
@ -313,8 +330,10 @@ consider_sample(TokenPosition const& p)
continue; // don't over-count
seen.push_back(tpid);
size_t raw2 = b->approxOccurrenceCount();
m_stats->add(tpid, sample_weight, m_bias ? (*m_bias)[p.sid] : 1,
aln, raw2, rec.po_fwd, rec.po_bwd, docid);
size_t evid = m_stats->add(tpid, sample_weight,
m_bias ? (*m_bias)[p.sid] : 1,
aln, raw2, rec.po_fwd, rec.po_bwd, docid);
max_evidence = std::max(max_evidence, evid);
bool ok = (i == rec.e2) || b->extend(o[i].id());
UTIL_THROW_IF2(!ok, "Could not extend target phrase.");
}
@ -322,7 +341,7 @@ consider_sample(TokenPosition const& p)
for (size_t k = 1; k < aln.size(); k += 2)
--aln[k];
}
return true;
return max_evidence;
}
#ifndef MMT
@ -333,7 +352,9 @@ operator()()
{
if (m_finished) return true;
boost::unique_lock<boost::mutex> lock(m_lock);
if (m_method == random_sampling)
if (m_method == full_coverage)
preform_full_phrase_extraction(); // consider all occurrences
else if (m_method == random_sampling)
perform_random_sampling();
else UTIL_THROW2("Unsupported sampling method.");
m_finished = true;

View File

@ -26,7 +26,7 @@ namespace sapt
BOOST_FOREACH(char const& x, denom)
{
if (x == '+') { --checksum; continue; }
if (x != 'g' && x != 's' && x != 'r') continue;
if (x != 'g' && x != 's' && x != 'r' && x != 'b') continue;
std::string s = (boost::format("pbwd-%c%.3f") % x % c).str();
this->m_feature_names.push_back(s);
}
@ -48,9 +48,12 @@ namespace sapt
BOOST_FOREACH(char const& x, denom)
{
uint32_t m2 = pp.raw2;
if (x == 'g') m2 = round(m2 * float(pp.good1) / pp.raw1);
if (x == 'g' || x == 'b') m2 = round(m2 * float(pp.good1) / pp.raw1);
else if (x == 's') m2 = round(m2 * float(pp.sample1) / pp.raw1);
(*dest)[i++] = log(lbop(std::max(m2, pp.joint), pp.joint,conf));
(*dest)[i] = log(lbop(std::max(m2, pp.joint), pp.joint,conf));
if (x == 'b') (*dest)[i] += log(pp.cum_bias) - log(pp.joint);
++i;
}
}
};

View File

@ -27,7 +27,7 @@ namespace sapt
BOOST_FOREACH(char const& x, denom)
{
if (x == '+') { --checksum; continue; }
if (x != 'g' && x != 's' && x != 'r') continue;
if (x != 'g' && x != 's' && x != 'r' && x != 'b') continue;
std::string s = (boost::format("pfwd-%c%.3f") % x % c).str();
this->m_feature_names.push_back(s);
}
@ -49,12 +49,16 @@ namespace sapt
// cerr<<pp.joint<<"/"<<pp.good1<<"/"<<pp.raw2<<endl;
}
size_t i = this->m_index;
float g = log(lbop(pp.good1, pp.joint, conf));;
BOOST_FOREACH(char const& c, this->denom)
{
switch (c)
{
case 'b':
(*dest)[i++] = g + log(pp.cum_bias) - log(pp.joint);
break;
case 'g':
(*dest)[i++] = log(lbop(pp.good1, pp.joint, conf));
(*dest)[i++] = g;
break;
case 's':
(*dest)[i++] = log(lbop(pp.sample1, pp.joint, conf));

@ -1 +1 @@
Subproject commit f69e79f5fc92d993354fa775de197b029d321175
Subproject commit e07a00c9733e0fecb8433f1c9d5805d3f0b35c6f