mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-03 20:13:47 +03:00
start lsh shortlist
This commit is contained in:
parent
b36d0bbbab
commit
415769fb2f
@ -133,16 +133,30 @@ Ptr<Shortlist> QuicksandShortlistGenerator::generate(Ptr<data::CorpusBatch> batc
|
||||
return New<Shortlist>(indices);
|
||||
}
|
||||
|
||||
// Stub constructor for LSHlistGenerator.
// Both arguments are accepted but neither is stored nor used yet; the body is
// intentionally empty.
LSHlistGenerator::LSHlistGenerator(int k, int nbits) {}
|
||||
|
||||
/**
 * Generate a shortlist for the given batch.
 *
 * Not implemented yet. An empty body would let control flow off the end of a
 * function that returns Ptr<Shortlist>, which is undefined behaviour, so the
 * function fails loudly instead of handing the caller an indeterminate value.
 *
 * @param batch the batch a shortlist is requested for (currently unused)
 * @return never a usable shortlist at this stage; see the abort below
 */
Ptr<Shortlist> LSHlistGenerator::generate(Ptr<data::CorpusBatch> batch) const {
  ABORT_IF(true, "LSHlistGenerator::generate() is not implemented yet");
  // Explicit return so every control path yields a value even if ABORT_IF is
  // not seen as non-returning by the compiler.
  // NOTE(review): assumes Ptr<Shortlist> is constructible from nullptr - confirm.
  return nullptr;
}
|
||||
|
||||
|
||||
Ptr<ShortlistGenerator> createShortlistGenerator(Ptr<Options> options,
|
||||
Ptr<const Vocab> srcVocab,
|
||||
Ptr<const Vocab> trgVocab,
|
||||
size_t srcIdx,
|
||||
size_t trgIdx,
|
||||
const std::vector<int> &lshOpts,
|
||||
bool shared) {
|
||||
std::cerr << "lshOpts=" << lshOpts.size() << std::endl;
|
||||
std::vector<std::string> vals = options->get<std::vector<std::string>>("shortlist");
|
||||
ABORT_IF(vals.empty(), "No path to shortlist given");
|
||||
std::string fname = vals[0];
|
||||
if(filesystem::Path(fname).extension().string() == ".bin") {
|
||||
if (lshOpts.size() == 2) {
|
||||
return New<LSHlistGenerator>(lshOpts[0], lshOpts[1]);
|
||||
}
|
||||
else if(filesystem::Path(fname).extension().string() == ".bin") {
|
||||
return New<QuicksandShortlistGenerator>(options, srcVocab, trgVocab, srcIdx, trgIdx, shared);
|
||||
} else {
|
||||
return New<LexicalShortlistGenerator>(options, srcVocab, trgVocab, srcIdx, trgIdx, shared);
|
||||
|
@ -328,6 +328,13 @@ public:
|
||||
virtual Ptr<Shortlist> generate(Ptr<data::CorpusBatch> batch) const override;
|
||||
};
|
||||
|
||||
class LSHlistGenerator : public ShortlistGenerator {
|
||||
private:
|
||||
|
||||
public:
|
||||
LSHlistGenerator(int k, int nbits);
|
||||
};
|
||||
|
||||
/*
|
||||
Shortlist factory to create the correct type of shortlist. Currently assumes everything is a text shortlist
|
||||
unless the extension is *.bin, for which the Microsoft legacy binary shortlist is used.
|
||||
@ -337,6 +344,7 @@ Ptr<ShortlistGenerator> createShortlistGenerator(Ptr<Options> options,
|
||||
Ptr<const Vocab> trgVocab,
|
||||
size_t srcIdx = 0,
|
||||
size_t trgIdx = 1,
|
||||
const std::vector<int> &lshOpts,
|
||||
bool shared = false);
|
||||
|
||||
} // namespace data
|
||||
|
@ -62,8 +62,9 @@ public:
|
||||
trgVocab_->load(vocabs.back());
|
||||
auto srcVocab = corpus_->getVocabs()[0];
|
||||
|
||||
if(options_->hasAndNotEmpty("shortlist"))
|
||||
shortlistGenerator_ = data::createShortlistGenerator(options_, srcVocab, trgVocab_, 0, 1, vocabs.front() == vocabs.back());
|
||||
std::vector<int> lshOpts = options_->get<std::vector<int>>("output-approx-knn");
|
||||
if(lshOpts.size() == 2 || options_->hasAndNotEmpty("shortlist"))
|
||||
shortlistGenerator_ = data::createShortlistGenerator(options_, srcVocab, trgVocab_, 0, 1, lshOpts, vocabs.front() == vocabs.back());
|
||||
|
||||
auto devices = Config::getDevices(options_);
|
||||
numDevices_ = devices.size();
|
||||
|
Loading…
Reference in New Issue
Block a user