create PreProcess() method

This commit is contained in:
Hieu Hoang 2017-01-30 16:25:03 +00:00
parent ac2217f37e
commit 3a12bbd233
2 changed files with 42 additions and 26 deletions

View File

@ -45,40 +45,22 @@ std::shared_ptr<Histories> Search::Process(const God &god, const Sentences& sent
std::shared_ptr<Histories> ret(new Histories(god, sentences));
size_t batchSize = sentences.size();
std::vector<size_t> beamSizes(batchSize, 1);
size_t vocabSize = scorers_[0]->GetVocabSize();
Beam prevHyps(batchSize, HypothesisPtr(new Hypothesis()));
for (size_t i = 0; i < ret->size(); ++i) {
History &history = *ret->at(i).get();
history.Add(prevHyps);
}
PreProcess(god, sentences, ret, prevHyps);
// Encode
States states(scorers_.size());
States nextStates(scorers_.size());
size_t vocabSize = scorers_[0]->GetVocabSize();
bool filter = god.Get<std::vector<std::string>>("softmax-filter").size();
if (filter) {
std::set<Word> srcWords;
for (size_t i = 0; i < sentences.size(); ++i) {
const Sentence &sentence = *sentences.at(i);
for (const auto& srcWord : sentence.GetWords()) {
srcWords.insert(srcWord);
}
}
vocabSize = MakeFilter(god, srcWords, vocabSize);
}
size_t maxLength = 0;
for (size_t i = 0; i < sentences.size(); ++i) {
const Sentence &sentence = *sentences.at(i);
maxLength = std::max(maxLength, sentence.GetWords().size());
}
Encode(sentences, states, nextStates);
for (size_t decoderStep = 0; decoderStep < 3 * maxLength; ++decoderStep) {
// Decode
std::vector<size_t> beamSizes(batchSize, 1);
for (size_t decoderStep = 0; decoderStep < 3 * sentences.GetMaxLength(); ++decoderStep) {
for (size_t i = 0; i < scorers_.size(); i++) {
Scorer &scorer = *scorers_[i];
State &state = *states[i];
@ -135,5 +117,33 @@ std::shared_ptr<Histories> Search::Process(const God &god, const Sentences& sent
return ret;
}
void Search::PreProcess(
const God &god,
const Sentences& sentences,
std::shared_ptr<Histories> ret,
Beam &prevHyps)
{
size_t vocabSize = scorers_[0]->GetVocabSize();
for (size_t i = 0; i < ret->size(); ++i) {
History &history = *ret->at(i).get();
history.Add(prevHyps);
}
bool filter = god.Get<std::vector<std::string>>("softmax-filter").size();
if (filter) {
std::set<Word> srcWords;
for (size_t i = 0; i < sentences.size(); ++i) {
const Sentence &sentence = *sentences.at(i);
for (const auto& srcWord : sentence.GetWords()) {
srcWords.insert(srcWord);
}
}
vocabSize = MakeFilter(god, srcWords, vocabSize);
}
}
}

View File

@ -13,6 +13,12 @@ class Search {
Search(const God &god);
std::shared_ptr<Histories> Process(const God &god, const Sentences& sentences);
void PreProcess(
const God &god,
const Sentences& sentences,
std::shared_ptr<Histories> ret,
Beam &prevHyps);
const DeviceInfo &GetDeviceInfo()
{ return deviceInfo_; }