mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-30 21:39:52 +03:00
towards n-best-lists
This commit is contained in:
parent
436f0bd52e
commit
19e71f8061
@ -2,9 +2,9 @@ cmake_minimum_required(VERSION 3.1.0)
|
|||||||
set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
|
set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
|
||||||
|
|
||||||
project(amunn CXX)
|
project(amunn CXX)
|
||||||
SET(CMAKE_CXX_FLAGS " -std=c++11 -g -O3 -fPIC -funroll-loops -Wno-unused-result -Wno-deprecated")
|
SET(CMAKE_CXX_FLAGS " -std=c++11 -g -O3 -funroll-loops -Wno-unused-result -Wno-deprecated")
|
||||||
#SET(CUDA_PROPAGATE_HOST_FLAGS OFF)
|
#SET(CUDA_PROPAGATE_HOST_FLAGS OFF)
|
||||||
SET(CUDA_NVCC_FLAGS " -std=c++11 -g -O3 -fPIC -arch=sm_35 -lineinfo --use_fast_math")
|
SET(CUDA_NVCC_FLAGS " -std=c++11 -g -O3 -arch=sm_35 -lineinfo --use_fast_math")
|
||||||
#SET(CUDA_VERBOSE_BUILD ON)
|
#SET(CUDA_VERBOSE_BUILD ON)
|
||||||
|
|
||||||
include_directories(${amunn_SOURCE_DIR})
|
include_directories(${amunn_SOURCE_DIR})
|
||||||
|
@ -4,6 +4,10 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "types.h"
|
||||||
|
#include "utils.h"
|
||||||
|
|
||||||
class Vocab {
|
class Vocab {
|
||||||
public:
|
public:
|
||||||
@ -15,8 +19,6 @@ class Vocab {
|
|||||||
str2id_[line] = c++;
|
str2id_[line] = c++;
|
||||||
id2str_.push_back(line);
|
id2str_.push_back(line);
|
||||||
}
|
}
|
||||||
//str2id_["</s>"] = c;
|
|
||||||
//id2str_.push_back("</s>");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t operator[](const std::string& word) const {
|
size_t operator[](const std::string& word) const {
|
||||||
@ -27,15 +29,32 @@ class Vocab {
|
|||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline std::vector<size_t> Encode(const std::vector<std::string>& sentence, bool addEOS=false) const {
|
Sentence operator()(const std::vector<std::string>& lineTokens, bool addEOS = true) const {
|
||||||
std::vector<size_t> indexes;
|
Sentence words(lineTokens.size());
|
||||||
for (auto& word: sentence) {
|
std::transform(lineTokens.begin(), lineTokens.end(), words.begin(),
|
||||||
indexes.push_back((*this)[word]);
|
[&](const std::string& w) { return (*this)[w]; });
|
||||||
|
if(addEOS)
|
||||||
|
words.push_back(EOS);
|
||||||
|
return words;
|
||||||
|
}
|
||||||
|
|
||||||
|
Sentence operator()(const std::string& line, bool addEOS = true) const {
|
||||||
|
std::vector<std::string> lineTokens;
|
||||||
|
Split(line, lineTokens, " ");
|
||||||
|
return (*this)(lineTokens, addEOS);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string operator()(const Sentence& sentence, bool ignoreEOS = true) const {
|
||||||
|
std::stringstream line;
|
||||||
|
for(size_t i = 0; i < sentence.size(); ++i) {
|
||||||
|
if(sentence[i] != EOS || !ignoreEOS) {
|
||||||
|
if(i > 0) {
|
||||||
|
line << " ";
|
||||||
|
}
|
||||||
|
line << (*this)[sentence[i]];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (addEOS) {
|
return line.str();
|
||||||
indexes.push_back((*this)["</s>"]);
|
|
||||||
}
|
|
||||||
return indexes;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -57,6 +57,33 @@ void ProgramOptions(int argc, char *argv[],
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class BPE {
|
||||||
|
public:
|
||||||
|
BPE(const std::string& sep = "@@ ")
|
||||||
|
: sep_(sep) {}
|
||||||
|
|
||||||
|
std::string split(const std::string& line) {
|
||||||
|
return line;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string unsplit(const std::string& line) {
|
||||||
|
std::string joined = line;
|
||||||
|
size_t pos = joined.find(sep_);
|
||||||
|
while(pos != std::string::npos) {
|
||||||
|
joined.erase(pos, sep_.size());
|
||||||
|
pos = joined.find(sep_, pos);
|
||||||
|
}
|
||||||
|
return joined;
|
||||||
|
}
|
||||||
|
|
||||||
|
operator bool() const {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::string sep_;
|
||||||
|
};
|
||||||
|
|
||||||
int main(int argc, char* argv[]) {
|
int main(int argc, char* argv[]) {
|
||||||
std::string modelPath, srcVocabPath, trgVocabPath;
|
std::string modelPath, srcVocabPath, trgVocabPath;
|
||||||
size_t device = 0;
|
size_t device = 0;
|
||||||
@ -70,17 +97,23 @@ int main(int argc, char* argv[]) {
|
|||||||
Vocab trgVocab(trgVocabPath);
|
Vocab trgVocab(trgVocabPath);
|
||||||
std::cerr << "done." << std::endl;
|
std::cerr << "done." << std::endl;
|
||||||
|
|
||||||
Search search(model, srcVocab, trgVocab);
|
Search search(model);
|
||||||
|
|
||||||
std::cerr << "Translating...\n";
|
std::cerr << "Translating...\n";
|
||||||
|
|
||||||
std::ios_base::sync_with_stdio(false);
|
std::ios_base::sync_with_stdio(false);
|
||||||
|
|
||||||
std::string line;
|
BPE bpe;
|
||||||
|
|
||||||
boost::timer::cpu_timer timer;
|
boost::timer::cpu_timer timer;
|
||||||
while(std::getline(std::cin, line)) {
|
std::string in;
|
||||||
auto result = search.Decode(line, beamSize);
|
while(std::getline(std::cin, in)) {
|
||||||
std::cout << result << std::endl;
|
Sentence sentence = bpe ? srcVocab(bpe.split(in)) : srcVocab(in);
|
||||||
|
History history = search.Decode(sentence, beamSize);
|
||||||
|
std::string out = trgVocab(history.Top().first);
|
||||||
|
if(bpe)
|
||||||
|
out = bpe.unsplit(out);
|
||||||
|
std::cout << out << std::endl;
|
||||||
}
|
}
|
||||||
std::cerr << timer.format() << std::endl;
|
std::cerr << timer.format() << std::endl;
|
||||||
return 0;
|
return 0;
|
||||||
|
@ -6,6 +6,7 @@
|
|||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
#include <queue>
|
||||||
#include <boost/timer/timer.hpp>
|
#include <boost/timer/timer.hpp>
|
||||||
|
|
||||||
#include <thrust/functional.h>
|
#include <thrust/functional.h>
|
||||||
@ -16,24 +17,70 @@
|
|||||||
#include <thrust/sort.h>
|
#include <thrust/sort.h>
|
||||||
#include <thrust/sequence.h>
|
#include <thrust/sequence.h>
|
||||||
|
|
||||||
|
#include "types.h"
|
||||||
#include "matrix.h"
|
#include "matrix.h"
|
||||||
#include "dl4mt.h"
|
#include "dl4mt.h"
|
||||||
#include "vocab.h"
|
|
||||||
#include "hypothesis.h"
|
#include "hypothesis.h"
|
||||||
#include "utils.h"
|
#include "utils.h"
|
||||||
|
|
||||||
#define EOL "</s>"
|
typedef std::vector<Hypothesis> Beam;
|
||||||
|
typedef std::pair<Sentence, Hypothesis> Result;
|
||||||
|
typedef std::vector<Result> NBestList;
|
||||||
|
|
||||||
|
class History {
|
||||||
|
private:
|
||||||
|
struct HypothesisCoord {
|
||||||
|
bool operator<(const HypothesisCoord& hc) const {
|
||||||
|
return cost < hc.cost;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t i;
|
||||||
|
size_t j;
|
||||||
|
float cost;
|
||||||
|
};
|
||||||
|
|
||||||
|
public:
|
||||||
|
void Add(const Beam& beam, bool last = false) {
|
||||||
|
for(size_t j = 0; j < beam.size(); ++j)
|
||||||
|
if(beam[j].GetWord() == EOS || last)
|
||||||
|
topHyps_.push({ history_.size(), j, beam[j].GetCost() });
|
||||||
|
history_.push_back(beam);
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t size() const {
|
||||||
|
return history_.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
NBestList NBest(size_t n) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
Result Top() const {
|
||||||
|
Sentence targetWords;
|
||||||
|
auto bestHypCoord = topHyps_.top();
|
||||||
|
size_t start = bestHypCoord.i;
|
||||||
|
size_t j = bestHypCoord.j;
|
||||||
|
for(int i = start; i >= 0; i--) {
|
||||||
|
auto& bestHyp = history_[i][j];
|
||||||
|
targetWords.push_back(bestHyp.GetWord());
|
||||||
|
j = bestHyp.GetPrevStateIndex();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::reverse(targetWords.begin(), targetWords.end());
|
||||||
|
return Result(targetWords, history_[bestHypCoord.i][bestHypCoord.j]);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<Beam> history_;
|
||||||
|
std::priority_queue<HypothesisCoord> topHyps_;
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
class Search {
|
class Search {
|
||||||
typedef std::vector<Hypothesis> Beam;
|
|
||||||
typedef std::vector<Beam> History;
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const Weights& model_;
|
const Weights& model_;
|
||||||
Encoder encoder_;
|
Encoder encoder_;
|
||||||
Decoder decoder_;
|
Decoder decoder_;
|
||||||
const Vocab svcb_;
|
|
||||||
const Vocab tvcb_;
|
|
||||||
|
|
||||||
mblas::Matrix State_, NextState_, BeamState_;
|
mblas::Matrix State_, NextState_, BeamState_;
|
||||||
mblas::Matrix Embeddings_, NextEmbeddings_;
|
mblas::Matrix Embeddings_, NextEmbeddings_;
|
||||||
@ -41,24 +88,14 @@ class Search {
|
|||||||
mblas::Matrix SourceContext_;
|
mblas::Matrix SourceContext_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
Search(const Weights& model, const Vocab& svcb, const Vocab tvcb)
|
Search(const Weights& model)
|
||||||
: model_(model),
|
: model_(model),
|
||||||
encoder_(model_),
|
encoder_(model_),
|
||||||
decoder_(model_),
|
decoder_(model_)
|
||||||
svcb_(svcb), tvcb_(tvcb)
|
|
||||||
{}
|
{}
|
||||||
|
|
||||||
std::string Decode(const std::string& source, size_t beamSize = 12) {
|
History Decode(const Sentence sourceWords, size_t beamSize = 12) {
|
||||||
// this should happen somewhere else
|
|
||||||
std::vector<std::string> sourceSplit;
|
|
||||||
Split(source, sourceSplit, " ");
|
|
||||||
std::vector<size_t> sourceWords(sourceSplit.size());
|
|
||||||
std::transform(sourceSplit.begin(), sourceSplit.end(), sourceWords.begin(),
|
|
||||||
[&](const std::string& w) { return svcb_[w]; });
|
|
||||||
sourceWords.push_back(svcb_[EOL]);
|
|
||||||
|
|
||||||
encoder_.GetContext(sourceWords, SourceContext_);
|
encoder_.GetContext(sourceWords, SourceContext_);
|
||||||
|
|
||||||
decoder_.EmptyState(State_, SourceContext_, 1);
|
decoder_.EmptyState(State_, SourceContext_, 1);
|
||||||
decoder_.EmptyEmbedding(Embeddings_, 1);
|
decoder_.EmptyEmbedding(Embeddings_, 1);
|
||||||
|
|
||||||
@ -72,13 +109,13 @@ class Search {
|
|||||||
|
|
||||||
Beam hyps;
|
Beam hyps;
|
||||||
BestHyps(hyps, prevHyps, Probs_, beamSize);
|
BestHyps(hyps, prevHyps, Probs_, beamSize);
|
||||||
history.push_back(hyps);
|
history.Add(hyps, history.size() + 1 == sourceWords.size() * 3);
|
||||||
|
|
||||||
Beam survivors;
|
Beam survivors;
|
||||||
std::vector<size_t> beamWords;
|
std::vector<size_t> beamWords;
|
||||||
std::vector<size_t> beamStateIds;
|
std::vector<size_t> beamStateIds;
|
||||||
for(auto& h : hyps) {
|
for(auto& h : hyps) {
|
||||||
if(h.GetWord() != tvcb_[EOL]) {
|
if(h.GetWord() != EOS) {
|
||||||
survivors.push_back(h);
|
survivors.push_back(h);
|
||||||
beamWords.push_back(h.GetWord());
|
beamWords.push_back(h.GetWord());
|
||||||
beamStateIds.push_back(h.GetPrevStateIndex());
|
beamStateIds.push_back(h.GetPrevStateIndex());
|
||||||
@ -98,7 +135,7 @@ class Search {
|
|||||||
|
|
||||||
} while(history.size() < sourceWords.size() * 3);
|
} while(history.size() < sourceWords.size() * 3);
|
||||||
|
|
||||||
return FindBest(history);
|
return history;
|
||||||
}
|
}
|
||||||
|
|
||||||
void BestHyps(Beam& bestHyps, const Beam& prevHyps, mblas::Matrix& Probs, const size_t beamSize) {
|
void BestHyps(Beam& bestHyps, const Beam& prevHyps, mblas::Matrix& Probs, const size_t beamSize) {
|
||||||
@ -131,44 +168,4 @@ class Search {
|
|||||||
bestHyps.emplace_back(wordIndex, hypIndex, cost);
|
bestHyps.emplace_back(wordIndex, hypIndex, cost);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string FindBest(const History& history) {
|
|
||||||
std::vector<size_t> targetWords;
|
|
||||||
|
|
||||||
size_t best = 0;
|
|
||||||
size_t beamSize = 0;
|
|
||||||
float bestCost = std::numeric_limits<float>::lowest();
|
|
||||||
|
|
||||||
for(auto b = history.rbegin(); b != history.rend(); b++) {
|
|
||||||
if(b->size() > beamSize) {
|
|
||||||
beamSize = b->size();
|
|
||||||
for(size_t i = 0; i < beamSize; ++i) {
|
|
||||||
if(b == history.rbegin() || (*b)[i].GetWord() == tvcb_[EOL]) {
|
|
||||||
if((*b)[i].GetCost() > bestCost) {
|
|
||||||
best = i;
|
|
||||||
bestCost = (*b)[i].GetCost();
|
|
||||||
targetWords.clear();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
auto& bestHyp = (*b)[best];
|
|
||||||
targetWords.push_back(bestHyp.GetWord());
|
|
||||||
best = bestHyp.GetPrevStateIndex();
|
|
||||||
}
|
|
||||||
|
|
||||||
std::reverse(targetWords.begin(), targetWords.end());
|
|
||||||
std::stringstream translation;
|
|
||||||
for(size_t i = 0; i < targetWords.size(); ++i) {
|
|
||||||
if(tvcb_[targetWords[i]] != EOL) {
|
|
||||||
if(i > 0) {
|
|
||||||
translation << " ";
|
|
||||||
}
|
|
||||||
translation << tvcb_[targetWords[i]];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return translation.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
};
|
};
|
@ -39,9 +39,7 @@ std::vector<std::string> NBest::GetTokens(const size_t index) const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::vector<size_t> NBest::GetEncodedTokens(const size_t index) const {
|
std::vector<size_t> NBest::GetEncodedTokens(const size_t index) const {
|
||||||
std::vector<std::string> tokens;
|
return (*srcVocab_)(srcSentences_[index]);
|
||||||
Split(srcSentences_[index], tokens);
|
|
||||||
return srcVocab_->Encode(tokens, true);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void NBest::Parse_(const std::string& path) {
|
void NBest::Parse_(const std::string& path) {
|
||||||
@ -76,10 +74,10 @@ inline std::vector< std::vector< std::string > > NBest::SplitBatch(std::vector<s
|
|||||||
return splittedBatch;
|
return splittedBatch;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline Batch NBest::EncodeBatch(const std::vector<std::vector<std::string>>& batch) const {
|
inline Batch NBest::EncodeBatch(const std::vector<std::string>& batch) const {
|
||||||
Batch encodedBatch;
|
Batch encodedBatch;
|
||||||
for (auto& sentence: batch) {
|
for (auto& sentence: batch) {
|
||||||
encodedBatch.push_back(trgVocab_->Encode(sentence, true));
|
encodedBatch.push_back((*trgVocab_)(sentence));
|
||||||
}
|
}
|
||||||
return encodedBatch;
|
return encodedBatch;
|
||||||
}
|
}
|
||||||
@ -103,7 +101,7 @@ inline Batch NBest::MaskAndTransposeBatch(const Batch& batch) const {
|
|||||||
|
|
||||||
|
|
||||||
Batch NBest::ProcessBatch(std::vector<std::string>& batch) const {
|
Batch NBest::ProcessBatch(std::vector<std::string>& batch) const {
|
||||||
return MaskAndTransposeBatch(EncodeBatch(SplitBatch(batch)));
|
return MaskAndTransposeBatch(EncodeBatch(batch));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<Batch> NBest::GetBatches(const size_t index) const {
|
std::vector<Batch> NBest::GetBatches(const size_t index) const {
|
||||||
|
@ -41,7 +41,7 @@ class NBest {
|
|||||||
std::vector<std::vector<std::string>> SplitBatch(std::vector<std::string>& batch) const;
|
std::vector<std::vector<std::string>> SplitBatch(std::vector<std::string>& batch) const;
|
||||||
void ParseInputFile(const std::string& path);
|
void ParseInputFile(const std::string& path);
|
||||||
|
|
||||||
Batch EncodeBatch(const std::vector<std::vector<std::string>>& batch) const;
|
Batch EncodeBatch(const std::vector<std::string>& batch) const;
|
||||||
|
|
||||||
Batch MaskAndTransposeBatch(const Batch& batch) const;
|
Batch MaskAndTransposeBatch(const Batch& batch) const;
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user