mirror of
https://github.com/marian-nmt/marian.git
synced 2024-12-03 03:54:56 +03:00
create PreProcess() method
This commit is contained in:
parent
ac2217f37e
commit
3a12bbd233
@ -45,40 +45,22 @@ std::shared_ptr<Histories> Search::Process(const God &god, const Sentences& sent
|
||||
std::shared_ptr<Histories> ret(new Histories(god, sentences));
|
||||
|
||||
size_t batchSize = sentences.size();
|
||||
std::vector<size_t> beamSizes(batchSize, 1);
|
||||
size_t vocabSize = scorers_[0]->GetVocabSize();
|
||||
|
||||
Beam prevHyps(batchSize, HypothesisPtr(new Hypothesis()));
|
||||
for (size_t i = 0; i < ret->size(); ++i) {
|
||||
History &history = *ret->at(i).get();
|
||||
history.Add(prevHyps);
|
||||
}
|
||||
|
||||
PreProcess(god, sentences, ret, prevHyps);
|
||||
|
||||
// Encode
|
||||
States states(scorers_.size());
|
||||
States nextStates(scorers_.size());
|
||||
|
||||
size_t vocabSize = scorers_[0]->GetVocabSize();
|
||||
|
||||
bool filter = god.Get<std::vector<std::string>>("softmax-filter").size();
|
||||
if (filter) {
|
||||
std::set<Word> srcWords;
|
||||
for (size_t i = 0; i < sentences.size(); ++i) {
|
||||
const Sentence &sentence = *sentences.at(i);
|
||||
for (const auto& srcWord : sentence.GetWords()) {
|
||||
srcWords.insert(srcWord);
|
||||
}
|
||||
}
|
||||
vocabSize = MakeFilter(god, srcWords, vocabSize);
|
||||
}
|
||||
|
||||
size_t maxLength = 0;
|
||||
for (size_t i = 0; i < sentences.size(); ++i) {
|
||||
const Sentence &sentence = *sentences.at(i);
|
||||
maxLength = std::max(maxLength, sentence.GetWords().size());
|
||||
}
|
||||
|
||||
Encode(sentences, states, nextStates);
|
||||
|
||||
for (size_t decoderStep = 0; decoderStep < 3 * maxLength; ++decoderStep) {
|
||||
// Decode
|
||||
std::vector<size_t> beamSizes(batchSize, 1);
|
||||
|
||||
for (size_t decoderStep = 0; decoderStep < 3 * sentences.GetMaxLength(); ++decoderStep) {
|
||||
for (size_t i = 0; i < scorers_.size(); i++) {
|
||||
Scorer &scorer = *scorers_[i];
|
||||
State &state = *states[i];
|
||||
@ -135,5 +117,33 @@ std::shared_ptr<Histories> Search::Process(const God &god, const Sentences& sent
|
||||
return ret;
|
||||
}
|
||||
|
||||
void Search::PreProcess(
|
||||
const God &god,
|
||||
const Sentences& sentences,
|
||||
std::shared_ptr<Histories> ret,
|
||||
Beam &prevHyps)
|
||||
{
|
||||
size_t vocabSize = scorers_[0]->GetVocabSize();
|
||||
|
||||
for (size_t i = 0; i < ret->size(); ++i) {
|
||||
History &history = *ret->at(i).get();
|
||||
history.Add(prevHyps);
|
||||
}
|
||||
|
||||
bool filter = god.Get<std::vector<std::string>>("softmax-filter").size();
|
||||
if (filter) {
|
||||
std::set<Word> srcWords;
|
||||
for (size_t i = 0; i < sentences.size(); ++i) {
|
||||
const Sentence &sentence = *sentences.at(i);
|
||||
for (const auto& srcWord : sentence.GetWords()) {
|
||||
srcWords.insert(srcWord);
|
||||
}
|
||||
}
|
||||
vocabSize = MakeFilter(god, srcWords, vocabSize);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
@ -13,6 +13,12 @@ class Search {
|
||||
/// Constructs a Search from the supplied God instance.
Search(const God &god);
|
||||
/// Processes `sentences` and returns the Histories object created for them.
std::shared_ptr<Histories> Process(const God &god, const Sentences& sentences);
|
||||
|
||||
/// Adds `prevHyps` to every history in `ret`; when the "softmax-filter"
/// option is non-empty, also passes the set of source words found in
/// `sentences` to MakeFilter().
void PreProcess(
|
||||
const God &god,
|
||||
const Sentences& sentences,
|
||||
std::shared_ptr<Histories> ret,
|
||||
Beam &prevHyps);
|
||||
|
||||
/// Returns a read-only reference to the deviceInfo_ member.
/// NOTE(review): the accessor itself is not const-qualified -- confirm whether
/// it should be callable on a const Search.
const DeviceInfo &GetDeviceInfo()
|
||||
{ return deviceInfo_; }
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user