towards n-best-lists

Marcin Junczys-Dowmunt 2016-04-14 20:49:38 +02:00
parent 436f0bd52e
commit 19e71f8061
6 changed files with 134 additions and 87 deletions

View File

@@ -2,9 +2,9 @@ cmake_minimum_required(VERSION 3.1.0)
set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
project(amunn CXX)
SET(CMAKE_CXX_FLAGS " -std=c++11 -g -O3 -fPIC -funroll-loops -Wno-unused-result -Wno-deprecated")
SET(CMAKE_CXX_FLAGS " -std=c++11 -g -O3 -funroll-loops -Wno-unused-result -Wno-deprecated")
#SET(CUDA_PROPAGATE_HOST_FLAGS OFF)
SET(CUDA_NVCC_FLAGS " -std=c++11 -g -O3 -fPIC -arch=sm_35 -lineinfo --use_fast_math")
SET(CUDA_NVCC_FLAGS " -std=c++11 -g -O3 -arch=sm_35 -lineinfo --use_fast_math")
#SET(CUDA_VERBOSE_BUILD ON)
include_directories(${amunn_SOURCE_DIR})

View File

@@ -4,6 +4,10 @@
#include <string>
#include <vector>
#include <fstream>
#include <sstream>
#include "types.h"
#include "utils.h"
class Vocab {
public:
@@ -15,8 +19,6 @@ class Vocab {
str2id_[line] = c++;
id2str_.push_back(line);
}
//str2id_["</s>"] = c;
//id2str_.push_back("</s>");
}
size_t operator[](const std::string& word) const {
@@ -27,15 +29,32 @@ class Vocab {
return 1;
}
inline std::vector<size_t> Encode(const std::vector<std::string>& sentence, bool addEOS=false) const {
std::vector<size_t> indexes;
for (auto& word: sentence) {
indexes.push_back((*this)[word]);
Sentence operator()(const std::vector<std::string>& lineTokens, bool addEOS = true) const {
Sentence words(lineTokens.size());
std::transform(lineTokens.begin(), lineTokens.end(), words.begin(),
[&](const std::string& w) { return (*this)[w]; });
if(addEOS)
words.push_back(EOS);
return words;
}
Sentence operator()(const std::string& line, bool addEOS = true) const {
std::vector<std::string> lineTokens;
Split(line, lineTokens, " ");
return (*this)(lineTokens, addEOS);
}
std::string operator()(const Sentence& sentence, bool ignoreEOS = true) const {
std::stringstream line;
for(size_t i = 0; i < sentence.size(); ++i) {
if(sentence[i] != EOS || !ignoreEOS) {
if(i > 0) {
line << " ";
}
line << (*this)[sentence[i]];
}
}
if (addEOS) {
indexes.push_back((*this)["</s>"]);
}
return indexes;
return line.str();
}
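A minimal usage sketch of the new call operators (the file name and example sentence are placeholders, not part of this commit): the string overload tokenizes on spaces and appends EOS, the Sentence overload maps ids back to words and drops EOS.

Vocab vocab("vocab.src.txt");               // hypothetical vocabulary file
Sentence ids  = vocab("ein kleines haus");  // tokens -> ids, EOS appended by default
std::string s = vocab(ids);                 // ids -> "ein kleines haus", EOS ignored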

View File

@@ -57,6 +57,33 @@ void ProgramOptions(int argc, char *argv[],
}
}
class BPE {
public:
BPE(const std::string& sep = "@@ ")
: sep_(sep) {}
std::string split(const std::string& line) {
return line;
}
std::string unsplit(const std::string& line) {
std::string joined = line;
size_t pos = joined.find(sep_);
while(pos != std::string::npos) {
joined.erase(pos, sep_.size());
pos = joined.find(sep_, pos);
}
return joined;
}
operator bool() const {
return true;
}
private:
std::string sep_;
};
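For reference, a small example of what unsplit does with the default "@@ " separator; split is still an identity stub in this commit:

BPE bpe;
bpe.unsplit("every@@ thing is dis@@ connected"); // -> "everything is disconnected"
bpe.split("everything is disconnected");         // -> returned unchanged for now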
int main(int argc, char* argv[]) {
std::string modelPath, srcVocabPath, trgVocabPath;
size_t device = 0;
@@ -70,17 +97,23 @@ int main(int argc, char* argv[]) {
Vocab trgVocab(trgVocabPath);
std::cerr << "done." << std::endl;
Search search(model, srcVocab, trgVocab);
Search search(model);
std::cerr << "Translating...\n";
std::ios_base::sync_with_stdio(false);
std::string line;
BPE bpe;
boost::timer::cpu_timer timer;
while(std::getline(std::cin, line)) {
auto result = search.Decode(line, beamSize);
std::cout << result << std::endl;
std::string in;
while(std::getline(std::cin, in)) {
Sentence sentence = bpe ? srcVocab(bpe.split(in)) : srcVocab(in);
History history = search.Decode(sentence, beamSize);
std::string out = trgVocab(history.Top().first);
if(bpe)
out = bpe.unsplit(out);
std::cout << out << std::endl;
}
std::cerr << timer.format() << std::endl;
return 0;
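Once History::NBest is implemented (its body is still empty in this commit), the same loop could print an n-best list instead of only the top hypothesis. A speculative sketch using the NBestList type from search.h; the Moses-style "|||" output format is an assumption:

// hypothetical extension of the translation loop, not part of this commit
for(const Result& r : history.NBest(beamSize)) {
  std::string hyp = trgVocab(r.first);
  if(bpe)
    hyp = bpe.unsplit(hyp);
  std::cout << hyp << " ||| " << r.second.GetCost() << std::endl;
}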

View File

@@ -6,6 +6,7 @@
#include <algorithm>
#include <limits>
#include <sstream>
#include <queue>
#include <boost/timer/timer.hpp>
#include <thrust/functional.h>
@@ -16,24 +17,70 @@
#include <thrust/sort.h>
#include <thrust/sequence.h>
#include "types.h"
#include "matrix.h"
#include "dl4mt.h"
#include "vocab.h"
#include "hypothesis.h"
#include "utils.h"
#define EOL "</s>"
typedef std::vector<Hypothesis> Beam;
typedef std::pair<Sentence, Hypothesis> Result;
typedef std::vector<Result> NBestList;
class History {
private:
struct HypothesisCoord {
bool operator<(const HypothesisCoord& hc) const {
return cost < hc.cost;
}
size_t i;
size_t j;
float cost;
};
public:
void Add(const Beam& beam, bool last = false) {
for(size_t j = 0; j < beam.size(); ++j)
if(beam[j].GetWord() == EOS || last)
topHyps_.push({ history_.size(), j, beam[j].GetCost() });
history_.push_back(beam);
}
size_t size() const {
return history_.size();
}
NBestList NBest(size_t n) {
}
Result Top() const {
Sentence targetWords;
auto bestHypCoord = topHyps_.top();
size_t start = bestHypCoord.i;
size_t j = bestHypCoord.j;
for(int i = start; i >= 0; i--) {
auto& bestHyp = history_[i][j];
targetWords.push_back(bestHyp.GetWord());
j = bestHyp.GetPrevStateIndex();
}
std::reverse(targetWords.begin(), targetWords.end());
return Result(targetWords, history_[bestHypCoord.i][bestHypCoord.j]);
}
private:
std::vector<Beam> history_;
std::priority_queue<HypothesisCoord> topHyps_;
};
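NBest is declared but left empty here ("towards n-best-lists"). A possible completion that reuses the back-tracking from Top() could look like the following; a sketch under that assumption, not the author's implementation:

// hypothetical body for History::NBest
NBestList NBest(size_t n) {
  NBestList nbest;
  auto topHyps = topHyps_;                    // copy so the member queue is not consumed
  while(nbest.size() < n && !topHyps.empty()) {
    auto coord = topHyps.top();
    topHyps.pop();
    Sentence targetWords;
    size_t j = coord.j;
    for(int i = coord.i; i >= 0; i--) {       // walk back through the stored beams
      auto& hyp = history_[i][j];
      targetWords.push_back(hyp.GetWord());
      j = hyp.GetPrevStateIndex();
    }
    std::reverse(targetWords.begin(), targetWords.end());
    nbest.emplace_back(targetWords, history_[coord.i][coord.j]);
  }
  return nbest;
}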
class Search {
typedef std::vector<Hypothesis> Beam;
typedef std::vector<Beam> History;
private:
const Weights& model_;
Encoder encoder_;
Decoder decoder_;
const Vocab svcb_;
const Vocab tvcb_;
mblas::Matrix State_, NextState_, BeamState_;
mblas::Matrix Embeddings_, NextEmbeddings_;
@@ -41,24 +88,14 @@ class Search {
mblas::Matrix SourceContext_;
public:
Search(const Weights& model, const Vocab& svcb, const Vocab tvcb)
Search(const Weights& model)
: model_(model),
encoder_(model_),
decoder_(model_),
svcb_(svcb), tvcb_(tvcb)
decoder_(model_)
{}
std::string Decode(const std::string& source, size_t beamSize = 12) {
// this should happen somewhere else
std::vector<std::string> sourceSplit;
Split(source, sourceSplit, " ");
std::vector<size_t> sourceWords(sourceSplit.size());
std::transform(sourceSplit.begin(), sourceSplit.end(), sourceWords.begin(),
[&](const std::string& w) { return svcb_[w]; });
sourceWords.push_back(svcb_[EOL]);
History Decode(const Sentence sourceWords, size_t beamSize = 12) {
encoder_.GetContext(sourceWords, SourceContext_);
decoder_.EmptyState(State_, SourceContext_, 1);
decoder_.EmptyEmbedding(Embeddings_, 1);
@@ -72,13 +109,13 @@ class Search {
Beam hyps;
BestHyps(hyps, prevHyps, Probs_, beamSize);
history.push_back(hyps);
history.Add(hyps, history.size() + 1 == sourceWords.size() * 3);
Beam survivors;
std::vector<size_t> beamWords;
std::vector<size_t> beamStateIds;
for(auto& h : hyps) {
if(h.GetWord() != tvcb_[EOL]) {
if(h.GetWord() != EOS) {
survivors.push_back(h);
beamWords.push_back(h.GetWord());
beamStateIds.push_back(h.GetPrevStateIndex());
@@ -98,7 +135,7 @@
} while(history.size() < sourceWords.size() * 3);
return FindBest(history);
return history;
}
void BestHyps(Beam& bestHyps, const Beam& prevHyps, mblas::Matrix& Probs, const size_t beamSize) {
@@ -131,44 +168,4 @@
bestHyps.emplace_back(wordIndex, hypIndex, cost);
}
}
std::string FindBest(const History& history) {
std::vector<size_t> targetWords;
size_t best = 0;
size_t beamSize = 0;
float bestCost = std::numeric_limits<float>::lowest();
for(auto b = history.rbegin(); b != history.rend(); b++) {
if(b->size() > beamSize) {
beamSize = b->size();
for(size_t i = 0; i < beamSize; ++i) {
if(b == history.rbegin() || (*b)[i].GetWord() == tvcb_[EOL]) {
if((*b)[i].GetCost() > bestCost) {
best = i;
bestCost = (*b)[i].GetCost();
targetWords.clear();
}
}
}
}
auto& bestHyp = (*b)[best];
targetWords.push_back(bestHyp.GetWord());
best = bestHyp.GetPrevStateIndex();
}
std::reverse(targetWords.begin(), targetWords.end());
std::stringstream translation;
for(size_t i = 0; i < targetWords.size(); ++i) {
if(tvcb_[targetWords[i]] != EOL) {
if(i > 0) {
translation << " ";
}
translation << tvcb_[targetWords[i]];
}
}
return translation.str();
}
};

View File

@@ -39,9 +39,7 @@ std::vector<std::string> NBest::GetTokens(const size_t index) const {
}
std::vector<size_t> NBest::GetEncodedTokens(const size_t index) const {
std::vector<std::string> tokens;
Split(srcSentences_[index], tokens);
return srcVocab_->Encode(tokens, true);
return (*srcVocab_)(srcSentences_[index]);
}
void NBest::Parse_(const std::string& path) {
@@ -76,10 +74,10 @@ inline std::vector< std::vector< std::string > > NBest::SplitBatch(std::vector<s
return splittedBatch;
}
inline Batch NBest::EncodeBatch(const std::vector<std::vector<std::string>>& batch) const {
inline Batch NBest::EncodeBatch(const std::vector<std::string>& batch) const {
Batch encodedBatch;
for (auto& sentence: batch) {
encodedBatch.push_back(trgVocab_->Encode(sentence, true));
encodedBatch.push_back((*trgVocab_)(sentence));
}
return encodedBatch;
}
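The old and new calls should line up roughly as follows, assuming Encode appended EOS when asked and the new call operator defaults addEOS to true:

// before: caller tokenized, then trgVocab_->Encode(tokens, /*addEOS=*/true)
// after:  (*trgVocab_)(sentence) splits on " " and appends EOS itself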
@@ -103,7 +101,7 @@ inline Batch NBest::MaskAndTransposeBatch(const Batch& batch) const {
Batch NBest::ProcessBatch(std::vector<std::string>& batch) const {
return MaskAndTransposeBatch(EncodeBatch(SplitBatch(batch)));
return MaskAndTransposeBatch(EncodeBatch(batch));
}
std::vector<Batch> NBest::GetBatches(const size_t index) const {

View File

@@ -41,7 +41,7 @@ class NBest {
std::vector<std::vector<std::string>> SplitBatch(std::vector<std::string>& batch) const;
void ParseInputFile(const std::string& path);
Batch EncodeBatch(const std::vector<std::vector<std::string>>& batch) const;
Batch EncodeBatch(const std::vector<std::string>& batch) const;
Batch MaskAndTransposeBatch(const Batch& batch) const;