mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-03 20:13:47 +03:00
virtual destructor
This commit is contained in:
parent
1672201450
commit
5be82498ae
@ -21,10 +21,12 @@ const T* get(const void*& current, size_t num = 1) {
|
||||
Shortlist::Shortlist(const std::vector<WordIndex>& indices)
|
||||
: indices_(indices) {}
|
||||
|
||||
const std::vector<WordIndex>& Shortlist::indices() const { return indices_; }
|
||||
WordIndex Shortlist::reverseMap(size_t batchIdx, size_t beamIdx, int idx) const { return indices_[idx]; }
|
||||
Shortlist::~Shortlist() {}
|
||||
|
||||
WordIndex Shortlist::tryForwardMap(size_t batchIdx, size_t beamIdx, WordIndex wIdx) const {
|
||||
const std::vector<WordIndex>& Shortlist::indices() const { return indices_; }
|
||||
WordIndex Shortlist::reverseMap(size_t , size_t , int idx) const { return indices_[idx]; }
|
||||
|
||||
WordIndex Shortlist::tryForwardMap(size_t , size_t , WordIndex wIdx) const {
|
||||
auto first = std::lower_bound(indices_.begin(), indices_.end(), wIdx);
|
||||
if(first != indices_.end() && *first == wIdx) // check if element not less than wIdx has been found and if equal to wIdx
|
||||
return (int)std::distance(indices_.begin(), first); // return coordinate if found
|
||||
@ -39,7 +41,7 @@ void Shortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Exp
|
||||
//std::cerr << "currBeamSize=" << currBeamSize << std::endl;
|
||||
//std::cerr << "batchSize=" << batchSize << std::endl;
|
||||
|
||||
auto forward = [this](Expr out, const std::vector<Expr>& inputs) {
|
||||
auto forward = [this](Expr out, const std::vector<Expr>& ) {
|
||||
out->val()->set(indices_);
|
||||
};
|
||||
|
||||
@ -151,7 +153,7 @@ WordIndex LSHShortlist::reverseMap(size_t batchIdx, size_t beamIdx, int idx) con
|
||||
return indices_[idx];
|
||||
}
|
||||
|
||||
WordIndex LSHShortlist::tryForwardMap(size_t batchIdx, size_t beamIdx, WordIndex wIdx) const {
|
||||
WordIndex LSHShortlist::tryForwardMap(size_t , size_t , WordIndex wIdx) const {
|
||||
//utils::Debug(indices_, "LSHShortlist::tryForwardMap indices_");
|
||||
auto first = std::lower_bound(indices_.begin(), indices_.end(), wIdx);
|
||||
bool found = first != indices_.end();
|
||||
@ -203,7 +205,7 @@ void LSHShortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW,
|
||||
|
||||
indices_.clear();
|
||||
for(auto id : ids) {
|
||||
indices_.push_back(id);
|
||||
indices_.push_back((WordIndex)id);
|
||||
}
|
||||
|
||||
for (size_t hypoIdx = 0; hypoIdx < numHypos; ++hypoIdx) {
|
||||
|
@ -41,7 +41,8 @@ public:
|
||||
static constexpr WordIndex npos{std::numeric_limits<WordIndex>::max()}; // used to identify invalid shortlist entries similar to std::string::npos
|
||||
|
||||
Shortlist(const std::vector<WordIndex>& indices);
|
||||
|
||||
virtual ~Shortlist();
|
||||
|
||||
virtual WordIndex reverseMap(size_t batchIdx, size_t beamIdx, int idx) const;
|
||||
virtual WordIndex tryForwardMap(size_t batchIdx, size_t beamIdx, WordIndex wIdx) const;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user