This commit is contained in:
Hieu Hoang 2021-04-30 09:52:56 -07:00
parent 5be82498ae
commit 84d498756b
2 changed files with 9 additions and 9 deletions

View File

@ -24,9 +24,9 @@ Shortlist::Shortlist(const std::vector<WordIndex>& indices)
Shortlist::~Shortlist() {}
const std::vector<WordIndex>& Shortlist::indices() const { return indices_; }
WordIndex Shortlist::reverseMap(size_t , size_t , int idx) const { return indices_[idx]; }
WordIndex Shortlist::reverseMap(int , int , int idx) const { return indices_[idx]; }
WordIndex Shortlist::tryForwardMap(size_t , size_t , WordIndex wIdx) const {
WordIndex Shortlist::tryForwardMap(int , int , WordIndex wIdx) const {
auto first = std::lower_bound(indices_.begin(), indices_.end(), wIdx);
if(first != indices_.end() && *first == wIdx) // check if element not less than wIdx has been found and if equal to wIdx
return (int)std::distance(indices_.begin(), first); // return coordinate if found
@ -45,7 +45,7 @@ void Shortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Exp
out->val()->set(indices_);
};
int k = indices_.size();
int k = (int) indices_.size();
Shape kShape({k});
indicesExpr_ = lambda({input, weights}, kShape, Type::uint32, forward);
@ -141,7 +141,7 @@ LSHShortlist::LSHShortlist(int k, int nbits)
#define BLAS_FOUND 1
WordIndex LSHShortlist::reverseMap(size_t batchIdx, size_t beamIdx, int idx) const {
WordIndex LSHShortlist::reverseMap(int batchIdx, int beamIdx, int idx) const {
std::cerr << "\nbatchIdx=" << batchIdx << " beamIdx=" << beamIdx << " idx=" << idx << std::endl;
//std::cerr << "indicesExpr_=" << indicesExpr_->shape() << std::endl;
int currBeamSize = indicesExpr_->shape()[1];
@ -153,7 +153,7 @@ WordIndex LSHShortlist::reverseMap(size_t batchIdx, size_t beamIdx, int idx) con
return indices_[idx];
}
WordIndex LSHShortlist::tryForwardMap(size_t , size_t , WordIndex wIdx) const {
WordIndex LSHShortlist::tryForwardMap(int , int , WordIndex wIdx) const {
//utils::Debug(indices_, "LSHShortlist::tryForwardMap indices_");
auto first = std::lower_bound(indices_.begin(), indices_.end(), wIdx);
bool found = first != indices_.end();

View File

@ -43,8 +43,8 @@ public:
Shortlist(const std::vector<WordIndex>& indices);
virtual ~Shortlist();
virtual WordIndex reverseMap(size_t batchIdx, size_t beamIdx, int idx) const;
virtual WordIndex tryForwardMap(size_t batchIdx, size_t beamIdx, WordIndex wIdx) const;
virtual WordIndex reverseMap(int batchIdx, int beamIdx, int idx) const;
virtual WordIndex tryForwardMap(int batchIdx, int beamIdx, WordIndex wIdx) const;
virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt);
virtual Expr getIndicesExpr(int batchSize, int currBeamSize) const;
@ -76,8 +76,8 @@ private:
public:
LSHShortlist(int k, int nbits);
virtual WordIndex reverseMap(size_t batchIdx, size_t beamIdx, int idx) const override;
virtual WordIndex tryForwardMap(size_t batchIdx, size_t beamIdx, WordIndex wIdx) const override;
virtual WordIndex reverseMap(int batchIdx, int beamIdx, int idx) const override;
virtual WordIndex tryForwardMap(int batchIdx, int beamIdx, WordIndex wIdx) const override;
virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) override;
virtual Expr getIndicesExpr(int batchSize,int currBeamSize) const override;