Update incremental search, cuts runtime by a third

This commit is contained in:
Kenneth Heafield 2013-02-14 13:11:53 +00:00
parent 10012fac15
commit 8ef095e8fa
8 changed files with 272 additions and 180 deletions

View File

@ -40,7 +40,7 @@ class ChartCellLabel
public:
union Stack {
const HypoList *cube; // cube pruning
const search::Vertex *incr; // incremental search after filling.
search::Vertex *incr; // incremental search after filling.
void *incr_generator; // incremental search during filling.
};

View File

@ -105,7 +105,7 @@ template <class Model> void Fill<Model>::Add(const TargetPhraseCollection &targe
vertices.reserve(nts.size());
float below_score = 0.0;
for (StackVec::const_iterator i(nts.begin()); i != nts.end(); ++i) {
vertices.push_back((*i)->GetStack().incr->RootPartial());
vertices.push_back((*i)->GetStack().incr->RootAlternate());
if (vertices.back().Empty()) return;
below_score += vertices.back().Bound();
}

View File

@ -1 +1 @@
fakelib search : edge_generator.cc nbest.cc rule.cc vertex.cc vertex_generator.cc ../lm//kenlm ../util//kenutil /top//boost_system : : : <include>.. ;
fakelib search : edge_generator.cc nbest.cc rule.cc vertex.cc ../lm//kenlm ../util//kenutil /top//boost_system : : : <include>.. ;

View File

@ -54,20 +54,20 @@ template <class Model> PartialEdge EdgeGenerator::Pop(Context<Model> &context) {
Arity victim = 0;
Arity victim_completed;
Arity incomplete;
unsigned char lowest_niceness = 255;
// Select victim or return if complete.
{
Arity completed = 0;
unsigned char lowest_length = 255;
for (Arity i = 0; i != arity; ++i) {
if (top_nt[i].Complete()) {
++completed;
} else if (top_nt[i].Length() < lowest_length) {
lowest_length = top_nt[i].Length();
} else if (top_nt[i].Niceness() < lowest_niceness) {
lowest_niceness = top_nt[i].Niceness();
victim = i;
victim_completed = completed;
}
}
if (lowest_length == 255) {
if (lowest_niceness == 255) {
return top;
}
incomplete = arity - completed;
@ -92,10 +92,14 @@ template <class Model> PartialEdge EdgeGenerator::Pop(Context<Model> &context) {
generate_.push(alternate);
}
#ifndef NDEBUG
Score before = top.GetScore();
#endif
// top is now the continuation.
FastScore(context, victim, victim - victim_completed, incomplete, old_value, top);
// TODO: dedupe?
generate_.push(top);
assert(lowest_niceness != 254 || top.GetScore() == before);
// Invalid indicates no new hypothesis generated.
return PartialEdge();

View File

@ -2,6 +2,8 @@
#include "search/context.hh"
#include <boost/unordered_map.hpp>
#include <algorithm>
#include <functional>
@ -11,45 +13,193 @@ namespace search {
namespace {
struct GreaterByBound : public std::binary_function<const VertexNode *, const VertexNode *, bool> {
bool operator()(const VertexNode *first, const VertexNode *second) const {
return first->Bound() > second->Bound();
}
const uint64_t kCompleteAdd = static_cast<uint64_t>(-1);
class DivideLeft {
public:
explicit DivideLeft(unsigned char index)
: index_(index) {}
uint64_t operator()(const lm::ngram::ChartState &state) const {
return (index_ < state.left.length) ?
state.left.pointers[index_] :
(kCompleteAdd - state.left.full);
}
private:
unsigned char index_;
};
class DivideRight {
public:
explicit DivideRight(unsigned char index)
: index_(index) {}
uint64_t operator()(const lm::ngram::ChartState &state) const {
return (index_ < state.right.length) ?
static_cast<uint64_t>(state.right.words[index_]) :
(kCompleteAdd - state.left.full);
}
private:
unsigned char index_;
};
template <class Divider> void Split(const Divider &divider, const std::vector<HypoState> &hypos, std::vector<VertexNode> &extend) {
// Map from divider to index in extend.
typedef boost::unordered_map<uint64_t, std::size_t> Lookup;
Lookup lookup;
for (std::vector<HypoState>::const_iterator i = hypos.begin(); i != hypos.end(); ++i) {
uint64_t key = divider(i->state);
std::pair<Lookup::iterator, bool> res(lookup.insert(std::make_pair(key, extend.size())));
if (res.second) {
extend.resize(extend.size() + 1);
extend.back().AppendHypothesis(*i);
} else {
extend[res.first->second].AppendHypothesis(*i);
}
}
//assert((extend.size() != 1) || (hypos.size() == 1));
}
lm::WordIndex Identify(const lm::ngram::Right &right, unsigned char index) {
return right.words[index];
}
uint64_t Identify(const lm::ngram::Left &left, unsigned char index) {
return left.pointers[index];
}
template <class Side> class DetermineSame {
public:
DetermineSame(const Side &side, unsigned char guaranteed)
: side_(side), guaranteed_(guaranteed), shared_(side.length), complete_(true) {}
void Consider(const Side &other) {
if (shared_ != other.length) {
complete_ = false;
if (shared_ > other.length)
shared_ = other.length;
}
for (unsigned char i = guaranteed_; i < shared_; ++i) {
if (Identify(side_, i) != Identify(other, i)) {
shared_ = i;
complete_ = false;
return;
}
}
}
unsigned char Shared() const { return shared_; }
bool Complete() const { return complete_; }
private:
const Side &side_;
unsigned char guaranteed_, shared_;
bool complete_;
};
// Custom enum to save memory: valid values of policy_.
// Alternate and there is still alternation to do.
const unsigned char kPolicyAlternate = 0;
// Branch based on left state only, because right ran out or this is a left tree.
const unsigned char kPolicyOneLeft = 1;
// Branch based on right state only.
const unsigned char kPolicyOneRight = 2;
// Reveal everything in the next branch. Used to terminate the left/right policies.
// static const unsigned char kPolicyEverything = 3;
} // namespace
void VertexNode::RecursiveSortAndSet(ContextBase &context, VertexNode *&parent_ptr) {
if (Complete()) {
assert(end_);
assert(extend_.empty());
return;
namespace {
struct GreaterByScore : public std::binary_function<const HypoState &, const HypoState &, bool> {
bool operator()(const HypoState &first, const HypoState &second) const {
return first.score > second.score;
}
if (extend_.size() == 1) {
parent_ptr = extend_[0];
extend_[0]->RecursiveSortAndSet(context, parent_ptr);
context.DeleteVertexNode(this);
return;
};
} // namespace
void VertexNode::FinishRoot() {
std::sort(hypos_.begin(), hypos_.end(), GreaterByScore());
extend_.clear();
// HACK: extend to one hypo so that root can be blank.
state_.left.full = false;
state_.left.length = 0;
state_.right.length = 0;
right_full_ = false;
niceness_ = 0;
policy_ = kPolicyAlternate;
if (hypos_.size() == 1) {
extend_.resize(1);
extend_.front().AppendHypothesis(hypos_.front());
extend_.front().FinishedAppending(0, 0);
}
for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) {
(*i)->RecursiveSortAndSet(context, *i);
if (hypos_.empty()) {
bound_ = -INFINITY;
} else {
bound_ = hypos_.front().score;
}
std::sort(extend_.begin(), extend_.end(), GreaterByBound());
bound_ = extend_.front()->Bound();
}
void VertexNode::SortAndSet(ContextBase &context) {
// This is the root. The root might be empty.
if (extend_.empty()) {
bound_ = -INFINITY;
return;
void VertexNode::FinishedAppending(const unsigned char common_left, const unsigned char common_right) {
assert(!hypos_.empty());
assert(extend_.empty());
bound_ = hypos_.front().score;
state_ = hypos_.front().state;
bool all_full = state_.left.full;
bool all_non_full = !state_.left.full;
DetermineSame<lm::ngram::Left> left(state_.left, common_left);
DetermineSame<lm::ngram::Right> right(state_.right, common_right);
for (std::vector<HypoState>::const_iterator i = hypos_.begin() + 1; i != hypos_.end(); ++i) {
all_full &= i->state.left.full;
all_non_full &= !i->state.left.full;
left.Consider(i->state.left);
right.Consider(i->state.right);
}
// The root cannot be replaced. There's always one transition.
for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) {
(*i)->RecursiveSortAndSet(context, *i);
state_.left.full = all_full && left.Complete();
right_full_ = all_full && right.Complete();
state_.left.length = left.Shared();
state_.right.length = right.Shared();
if (!all_full && !all_non_full) {
policy_ = kPolicyAlternate;
} else if (left.Complete()) {
policy_ = kPolicyOneRight;
} else if (right.Complete()) {
policy_ = kPolicyOneLeft;
} else {
policy_ = kPolicyAlternate;
}
niceness_ = state_.left.length + state_.right.length;
}
void VertexNode::BuildExtend() {
// Already built.
if (!extend_.empty()) return;
// Nothing to build since this is a leaf.
if (hypos_.size() <= 1) return;
bool left_branch = true;
switch (policy_) {
case kPolicyAlternate:
left_branch = (state_.left.length <= state_.right.length);
break;
case kPolicyOneLeft:
left_branch = true;
break;
case kPolicyOneRight:
left_branch = false;
break;
}
if (left_branch) {
Split(DivideLeft(state_.left.length), hypos_, extend_);
} else {
Split(DivideRight(state_.right.length), hypos_, extend_);
}
for (std::vector<VertexNode>::iterator i = extend_.begin(); i != extend_.end(); ++i) {
// TODO: provide more here for branching?
i->FinishedAppending(state_.left.length, state_.right.length);
}
std::sort(extend_.begin(), extend_.end(), GreaterByBound());
bound_ = extend_.front()->Bound();
}
} // namespace search

View File

@ -16,59 +16,74 @@ namespace search {
class ContextBase;
struct HypoState {
History history;
lm::ngram::ChartState state;
Score score;
};
class VertexNode {
public:
VertexNode() : end_() {}
VertexNode() {}
void InitRoot() {
extend_.clear();
state_.left.full = false;
state_.left.length = 0;
state_.right.length = 0;
right_full_ = false;
end_ = History();
void InitRoot() { hypos_.clear(); }
/* The steps of building a VertexNode:
* 1. Default construct.
* 2. AppendHypothesis at least once, possibly multiple times.
* 3. FinishAppending with the number of words on left and right guaranteed
* to be common.
* 4. If !Complete(), call BuildExtend to construct the extensions
*/
// Must default construct, call AppendHypothesis 1 or more times then do FinishedAppending.
void AppendHypothesis(const NBestComplete &best) {
assert(hypos_.empty() || !(hypos_.front().state == *best.state));
HypoState hypo;
hypo.history = best.history;
hypo.state = *best.state;
hypo.score = best.score;
hypos_.push_back(hypo);
}
void AppendHypothesis(const HypoState &hypo) {
hypos_.push_back(hypo);
}
lm::ngram::ChartState &MutableState() { return state_; }
bool &MutableRightFull() { return right_full_; }
// Sort hypotheses for the root.
void FinishRoot();
void AddExtend(VertexNode *next) {
extend_.push_back(next);
}
void FinishedAppending(const unsigned char common_left, const unsigned char common_right);
void SetEnd(History end, Score score) {
assert(!end_);
end_ = end;
bound_ = score;
}
void SortAndSet(ContextBase &context);
void BuildExtend();
// Should only happen to a root node when the entire vertex is empty.
bool Empty() const {
return !end_ && extend_.empty();
return hypos_.empty() && extend_.empty();
}
bool Complete() const {
return end_;
// HACK: prevent root from being complete. TODO: allow root to be complete.
return hypos_.size() == 1 && extend_.empty();
}
const lm::ngram::ChartState &State() const { return state_; }
bool RightFull() const { return right_full_; }
// Priority relative to other non-terminals. 0 is highest.
unsigned char Niceness() const { return niceness_; }
Score Bound() const {
return bound_;
}
unsigned char Length() const {
return state_.left.length + state_.right.length;
// Will be invalid unless this is a leaf.
const History End() const {
assert(hypos_.size() == 1);
return hypos_.front().history;
}
// Will be invalid unless this is a leaf.
const History End() const { return end_; }
const VertexNode &operator[](size_t index) const {
return *extend_[index];
VertexNode &operator[](size_t index) {
assert(!extend_.empty());
return extend_[index];
}
size_t Size() const {
@ -76,22 +91,26 @@ class VertexNode {
}
private:
void RecursiveSortAndSet(ContextBase &context, VertexNode *&parent);
// Hypotheses to be split.
std::vector<HypoState> hypos_;
std::vector<VertexNode*> extend_;
std::vector<VertexNode> extend_;
lm::ngram::ChartState state_;
bool right_full_;
unsigned char niceness_;
unsigned char policy_;
Score bound_;
History end_;
};
class PartialVertex {
public:
PartialVertex() {}
explicit PartialVertex(const VertexNode &back) : back_(&back), index_(0) {}
explicit PartialVertex(VertexNode &back) : back_(&back), index_(0) {}
bool Empty() const { return back_->Empty(); }
@ -100,17 +119,14 @@ class PartialVertex {
const lm::ngram::ChartState &State() const { return back_->State(); }
bool RightFull() const { return back_->RightFull(); }
Score Bound() const { return Complete() ? back_->Bound() : (*back_)[index_].Bound(); }
Score Bound() const { return index_ ? (*back_)[index_].Bound() : back_->Bound(); }
unsigned char Length() const { return back_->Length(); }
bool HasAlternative() const {
return index_ + 1 < back_->Size();
}
unsigned char Niceness() const { return back_->Niceness(); }
// Split into continuation and alternative, rendering this the continuation.
bool Split(PartialVertex &alternative) {
assert(!Complete());
back_->BuildExtend();
bool ret;
if (index_ + 1 < back_->Size()) {
alternative.index_ = index_ + 1;
@ -129,7 +145,7 @@ class PartialVertex {
}
private:
const VertexNode *back_;
VertexNode *back_;
unsigned int index_;
};
@ -139,10 +155,21 @@ class Vertex {
public:
Vertex() {}
PartialVertex RootPartial() const { return PartialVertex(root_); }
//PartialVertex RootFirst() const { return PartialVertex(right_); }
PartialVertex RootAlternate() { return PartialVertex(root_); }
//PartialVertex RootLast() const { return PartialVertex(left_); }
const History BestChild() const {
PartialVertex top(RootPartial());
bool Empty() const {
return root_.Empty();
}
Score Bound() const {
return root_.Bound();
}
const History BestChild() {
// left_ and right_ are not set at the root.
PartialVertex top(RootAlternate());
if (top.Empty()) {
return History();
} else {
@ -158,6 +185,12 @@ class Vertex {
template <class Output> friend class VertexGenerator;
template <class Output> friend class RootVertexGenerator;
VertexNode root_;
// These will not be set for the root vertex.
// Branches only on left state.
//VertexNode left_;
// Branches only on right state.
//VertexNode right_;
};
} // namespace search

View File

@ -1,68 +0,0 @@
#include "search/vertex_generator.hh"
#include "lm/left.hh"
#include "search/context.hh"
#include "search/edge.hh"
#include <boost/unordered_map.hpp>
#include <boost/version.hpp>
#include <stdint.h>
namespace search {
#if BOOST_VERSION > 104200
namespace {
const uint64_t kCompleteAdd = static_cast<uint64_t>(-1);
Trie &FindOrInsert(ContextBase &context, Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) {
Trie &next = node.extend[added];
if (!next.under) {
next.under = context.NewVertexNode();
lm::ngram::ChartState &writing = next.under->MutableState();
writing = state;
writing.left.full &= left_full && state.left.full;
next.under->MutableRightFull() = right_full && state.left.full;
writing.left.length = left;
writing.right.length = right;
node.under->AddExtend(next.under);
}
return next;
}
} // namespace
void AddHypothesis(ContextBase &context, Trie &root, const NBestComplete &end) {
const lm::ngram::ChartState &state = *end.state;
unsigned char left = 0, right = 0;
Trie *node = &root;
while (true) {
if (left == state.left.length) {
node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, left, true, right, false);
for (; right < state.right.length; ++right) {
node = &FindOrInsert(context, *node, state.right.words[right], state, left, true, right + 1, false);
}
break;
}
node = &FindOrInsert(context, *node, state.left.pointers[left], state, left + 1, false, right, false);
left++;
if (right == state.right.length) {
node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, left, false, right, true);
for (; left < state.left.length; ++left) {
node = &FindOrInsert(context, *node, state.left.pointers[left], state, left + 1, false, right, true);
}
break;
}
node = &FindOrInsert(context, *node, state.right.words[right], state, left, false, right + 1, false);
right++;
}
node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true);
node->under->SetEnd(end.history, end.score);
}
#endif // BOOST_VERSION
} // namespace search

View File

@ -5,13 +5,6 @@
#include "search/types.hh"
#include "search/vertex.hh"
#include <boost/unordered_map.hpp>
#include <boost/version.hpp>
#if BOOST_VERSION <= 104200
#include "util/exception.hh"
#endif
namespace lm {
namespace ngram {
class ChartState;
@ -22,45 +15,25 @@ namespace search {
class ContextBase;
#if BOOST_VERSION > 104200
// Parallel structure to VertexNode.
struct Trie {
Trie() : under(NULL) {}
VertexNode *under;
boost::unordered_map<uint64_t, Trie> extend;
};
void AddHypothesis(ContextBase &context, Trie &root, const NBestComplete &end);
#endif // BOOST_VERSION
// Output makes the single-best or n-best list.
template <class Output> class VertexGenerator {
public:
VertexGenerator(ContextBase &context, Vertex &gen, Output &nbest) : context_(context), gen_(gen), nbest_(nbest) {
gen.root_.InitRoot();
}
VertexGenerator(ContextBase &context, Vertex &gen, Output &nbest) : context_(context), gen_(gen), nbest_(nbest) {}
void NewHypothesis(PartialEdge partial) {
nbest_.Add(existing_[hash_value(partial.CompletedState())], partial);
}
void FinishedSearch() {
#if BOOST_VERSION > 104200
Trie root;
root.under = &gen_.root_;
gen_.root_.InitRoot();
for (typename Existing::iterator i(existing_.begin()); i != existing_.end(); ++i) {
AddHypothesis(context_, root, nbest_.Complete(i->second));
gen_.root_.AppendHypothesis(nbest_.Complete(i->second));
}
existing_.clear();
root.under->SortAndSet(context_);
#else
UTIL_THROW(util::Exception, "Upgrade Boost to >= 1.42.0 to use incremental search.");
#endif
gen_.root_.FinishRoot();
}
const Vertex &Generating() const { return gen_; }
Vertex &Generating() { return gen_; }
private:
ContextBase &context_;
@ -87,8 +60,8 @@ template <class Output> class RootVertexGenerator {
void FinishedSearch() {
gen_.root_.InitRoot();
NBestComplete completed(out_.Complete(combine_));
gen_.root_.SetEnd(completed.history, completed.score);
gen_.root_.AppendHypothesis(out_.Complete(combine_));
gen_.root_.FinishRoot();
}
private: