mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-30 21:39:52 +03:00
towards n-best-lists
This commit is contained in:
parent
436f0bd52e
commit
19e71f8061
@ -2,9 +2,9 @@ cmake_minimum_required(VERSION 3.1.0)
|
||||
set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
|
||||
|
||||
project(amunn CXX)
|
||||
SET(CMAKE_CXX_FLAGS " -std=c++11 -g -O3 -fPIC -funroll-loops -Wno-unused-result -Wno-deprecated")
|
||||
SET(CMAKE_CXX_FLAGS " -std=c++11 -g -O3 -funroll-loops -Wno-unused-result -Wno-deprecated")
|
||||
#SET(CUDA_PROPAGATE_HOST_FLAGS OFF)
|
||||
SET(CUDA_NVCC_FLAGS " -std=c++11 -g -O3 -fPIC -arch=sm_35 -lineinfo --use_fast_math")
|
||||
SET(CUDA_NVCC_FLAGS " -std=c++11 -g -O3 -arch=sm_35 -lineinfo --use_fast_math")
|
||||
#SET(CUDA_VERBOSE_BUILD ON)
|
||||
|
||||
include_directories(${amunn_SOURCE_DIR})
|
||||
|
@ -4,6 +4,10 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
|
||||
#include "types.h"
|
||||
#include "utils.h"
|
||||
|
||||
class Vocab {
|
||||
public:
|
||||
@ -15,8 +19,6 @@ class Vocab {
|
||||
str2id_[line] = c++;
|
||||
id2str_.push_back(line);
|
||||
}
|
||||
//str2id_["</s>"] = c;
|
||||
//id2str_.push_back("</s>");
|
||||
}
|
||||
|
||||
size_t operator[](const std::string& word) const {
|
||||
@ -27,15 +29,32 @@ class Vocab {
|
||||
return 1;
|
||||
}
|
||||
|
||||
inline std::vector<size_t> Encode(const std::vector<std::string>& sentence, bool addEOS=false) const {
|
||||
std::vector<size_t> indexes;
|
||||
for (auto& word: sentence) {
|
||||
indexes.push_back((*this)[word]);
|
||||
Sentence operator()(const std::vector<std::string>& lineTokens, bool addEOS = true) const {
|
||||
Sentence words(lineTokens.size());
|
||||
std::transform(lineTokens.begin(), lineTokens.end(), words.begin(),
|
||||
[&](const std::string& w) { return (*this)[w]; });
|
||||
if(addEOS)
|
||||
words.push_back(EOS);
|
||||
return words;
|
||||
}
|
||||
|
||||
Sentence operator()(const std::string& line, bool addEOS = true) const {
|
||||
std::vector<std::string> lineTokens;
|
||||
Split(line, lineTokens, " ");
|
||||
return (*this)(lineTokens, addEOS);
|
||||
}
|
||||
|
||||
std::string operator()(const Sentence& sentence, bool ignoreEOS = true) const {
|
||||
std::stringstream line;
|
||||
for(size_t i = 0; i < sentence.size(); ++i) {
|
||||
if(sentence[i] != EOS || !ignoreEOS) {
|
||||
if(i > 0) {
|
||||
line << " ";
|
||||
}
|
||||
line << (*this)[sentence[i]];
|
||||
}
|
||||
}
|
||||
if (addEOS) {
|
||||
indexes.push_back((*this)["</s>"]);
|
||||
}
|
||||
return indexes;
|
||||
return line.str();
|
||||
}
|
||||
|
||||
|
||||
|
@ -57,6 +57,33 @@ void ProgramOptions(int argc, char *argv[],
|
||||
}
|
||||
}
|
||||
|
||||
class BPE {
|
||||
public:
|
||||
BPE(const std::string& sep = "@@ ")
|
||||
: sep_(sep) {}
|
||||
|
||||
std::string split(const std::string& line) {
|
||||
return line;
|
||||
}
|
||||
|
||||
std::string unsplit(const std::string& line) {
|
||||
std::string joined = line;
|
||||
size_t pos = joined.find(sep_);
|
||||
while(pos != std::string::npos) {
|
||||
joined.erase(pos, sep_.size());
|
||||
pos = joined.find(sep_, pos);
|
||||
}
|
||||
return joined;
|
||||
}
|
||||
|
||||
operator bool() const {
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
std::string sep_;
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
std::string modelPath, srcVocabPath, trgVocabPath;
|
||||
size_t device = 0;
|
||||
@ -70,17 +97,23 @@ int main(int argc, char* argv[]) {
|
||||
Vocab trgVocab(trgVocabPath);
|
||||
std::cerr << "done." << std::endl;
|
||||
|
||||
Search search(model, srcVocab, trgVocab);
|
||||
Search search(model);
|
||||
|
||||
std::cerr << "Translating...\n";
|
||||
|
||||
std::ios_base::sync_with_stdio(false);
|
||||
|
||||
std::string line;
|
||||
BPE bpe;
|
||||
|
||||
boost::timer::cpu_timer timer;
|
||||
while(std::getline(std::cin, line)) {
|
||||
auto result = search.Decode(line, beamSize);
|
||||
std::cout << result << std::endl;
|
||||
std::string in;
|
||||
while(std::getline(std::cin, in)) {
|
||||
Sentence sentence = bpe ? srcVocab(bpe.split(in)) : srcVocab(in);
|
||||
History history = search.Decode(sentence, beamSize);
|
||||
std::string out = trgVocab(history.Top().first);
|
||||
if(bpe)
|
||||
out = bpe.unsplit(out);
|
||||
std::cout << out << std::endl;
|
||||
}
|
||||
std::cerr << timer.format() << std::endl;
|
||||
return 0;
|
||||
|
@ -6,6 +6,7 @@
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <sstream>
|
||||
#include <queue>
|
||||
#include <boost/timer/timer.hpp>
|
||||
|
||||
#include <thrust/functional.h>
|
||||
@ -16,24 +17,70 @@
|
||||
#include <thrust/sort.h>
|
||||
#include <thrust/sequence.h>
|
||||
|
||||
#include "types.h"
|
||||
#include "matrix.h"
|
||||
#include "dl4mt.h"
|
||||
#include "vocab.h"
|
||||
#include "hypothesis.h"
|
||||
#include "utils.h"
|
||||
|
||||
#define EOL "</s>"
|
||||
typedef std::vector<Hypothesis> Beam;
|
||||
typedef std::pair<Sentence, Hypothesis> Result;
|
||||
typedef std::vector<Result> NBestList;
|
||||
|
||||
class History {
|
||||
private:
|
||||
struct HypothesisCoord {
|
||||
bool operator<(const HypothesisCoord& hc) const {
|
||||
return cost < hc.cost;
|
||||
}
|
||||
|
||||
size_t i;
|
||||
size_t j;
|
||||
float cost;
|
||||
};
|
||||
|
||||
public:
|
||||
void Add(const Beam& beam, bool last = false) {
|
||||
for(size_t j = 0; j < beam.size(); ++j)
|
||||
if(beam[j].GetWord() == EOS || last)
|
||||
topHyps_.push({ history_.size(), j, beam[j].GetCost() });
|
||||
history_.push_back(beam);
|
||||
}
|
||||
|
||||
size_t size() const {
|
||||
return history_.size();
|
||||
}
|
||||
|
||||
NBestList NBest(size_t n) {
|
||||
|
||||
}
|
||||
|
||||
Result Top() const {
|
||||
Sentence targetWords;
|
||||
auto bestHypCoord = topHyps_.top();
|
||||
size_t start = bestHypCoord.i;
|
||||
size_t j = bestHypCoord.j;
|
||||
for(int i = start; i >= 0; i--) {
|
||||
auto& bestHyp = history_[i][j];
|
||||
targetWords.push_back(bestHyp.GetWord());
|
||||
j = bestHyp.GetPrevStateIndex();
|
||||
}
|
||||
|
||||
std::reverse(targetWords.begin(), targetWords.end());
|
||||
return Result(targetWords, history_[bestHypCoord.i][bestHypCoord.j]);
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<Beam> history_;
|
||||
std::priority_queue<HypothesisCoord> topHyps_;
|
||||
|
||||
};
|
||||
|
||||
class Search {
|
||||
typedef std::vector<Hypothesis> Beam;
|
||||
typedef std::vector<Beam> History;
|
||||
|
||||
private:
|
||||
const Weights& model_;
|
||||
Encoder encoder_;
|
||||
Decoder decoder_;
|
||||
const Vocab svcb_;
|
||||
const Vocab tvcb_;
|
||||
|
||||
mblas::Matrix State_, NextState_, BeamState_;
|
||||
mblas::Matrix Embeddings_, NextEmbeddings_;
|
||||
@ -41,24 +88,14 @@ class Search {
|
||||
mblas::Matrix SourceContext_;
|
||||
|
||||
public:
|
||||
Search(const Weights& model, const Vocab& svcb, const Vocab tvcb)
|
||||
Search(const Weights& model)
|
||||
: model_(model),
|
||||
encoder_(model_),
|
||||
decoder_(model_),
|
||||
svcb_(svcb), tvcb_(tvcb)
|
||||
decoder_(model_)
|
||||
{}
|
||||
|
||||
std::string Decode(const std::string& source, size_t beamSize = 12) {
|
||||
// this should happen somewhere else
|
||||
std::vector<std::string> sourceSplit;
|
||||
Split(source, sourceSplit, " ");
|
||||
std::vector<size_t> sourceWords(sourceSplit.size());
|
||||
std::transform(sourceSplit.begin(), sourceSplit.end(), sourceWords.begin(),
|
||||
[&](const std::string& w) { return svcb_[w]; });
|
||||
sourceWords.push_back(svcb_[EOL]);
|
||||
|
||||
History Decode(const Sentence sourceWords, size_t beamSize = 12) {
|
||||
encoder_.GetContext(sourceWords, SourceContext_);
|
||||
|
||||
decoder_.EmptyState(State_, SourceContext_, 1);
|
||||
decoder_.EmptyEmbedding(Embeddings_, 1);
|
||||
|
||||
@ -72,13 +109,13 @@ class Search {
|
||||
|
||||
Beam hyps;
|
||||
BestHyps(hyps, prevHyps, Probs_, beamSize);
|
||||
history.push_back(hyps);
|
||||
history.Add(hyps, history.size() + 1 == sourceWords.size() * 3);
|
||||
|
||||
Beam survivors;
|
||||
std::vector<size_t> beamWords;
|
||||
std::vector<size_t> beamStateIds;
|
||||
for(auto& h : hyps) {
|
||||
if(h.GetWord() != tvcb_[EOL]) {
|
||||
if(h.GetWord() != EOS) {
|
||||
survivors.push_back(h);
|
||||
beamWords.push_back(h.GetWord());
|
||||
beamStateIds.push_back(h.GetPrevStateIndex());
|
||||
@ -98,7 +135,7 @@ class Search {
|
||||
|
||||
} while(history.size() < sourceWords.size() * 3);
|
||||
|
||||
return FindBest(history);
|
||||
return history;
|
||||
}
|
||||
|
||||
void BestHyps(Beam& bestHyps, const Beam& prevHyps, mblas::Matrix& Probs, const size_t beamSize) {
|
||||
@ -131,44 +168,4 @@ class Search {
|
||||
bestHyps.emplace_back(wordIndex, hypIndex, cost);
|
||||
}
|
||||
}
|
||||
|
||||
std::string FindBest(const History& history) {
|
||||
std::vector<size_t> targetWords;
|
||||
|
||||
size_t best = 0;
|
||||
size_t beamSize = 0;
|
||||
float bestCost = std::numeric_limits<float>::lowest();
|
||||
|
||||
for(auto b = history.rbegin(); b != history.rend(); b++) {
|
||||
if(b->size() > beamSize) {
|
||||
beamSize = b->size();
|
||||
for(size_t i = 0; i < beamSize; ++i) {
|
||||
if(b == history.rbegin() || (*b)[i].GetWord() == tvcb_[EOL]) {
|
||||
if((*b)[i].GetCost() > bestCost) {
|
||||
best = i;
|
||||
bestCost = (*b)[i].GetCost();
|
||||
targetWords.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto& bestHyp = (*b)[best];
|
||||
targetWords.push_back(bestHyp.GetWord());
|
||||
best = bestHyp.GetPrevStateIndex();
|
||||
}
|
||||
|
||||
std::reverse(targetWords.begin(), targetWords.end());
|
||||
std::stringstream translation;
|
||||
for(size_t i = 0; i < targetWords.size(); ++i) {
|
||||
if(tvcb_[targetWords[i]] != EOL) {
|
||||
if(i > 0) {
|
||||
translation << " ";
|
||||
}
|
||||
translation << tvcb_[targetWords[i]];
|
||||
}
|
||||
}
|
||||
return translation.str();
|
||||
}
|
||||
|
||||
};
|
@ -39,9 +39,7 @@ std::vector<std::string> NBest::GetTokens(const size_t index) const {
|
||||
}
|
||||
|
||||
std::vector<size_t> NBest::GetEncodedTokens(const size_t index) const {
|
||||
std::vector<std::string> tokens;
|
||||
Split(srcSentences_[index], tokens);
|
||||
return srcVocab_->Encode(tokens, true);
|
||||
return (*srcVocab_)(srcSentences_[index]);
|
||||
}
|
||||
|
||||
void NBest::Parse_(const std::string& path) {
|
||||
@ -76,10 +74,10 @@ inline std::vector< std::vector< std::string > > NBest::SplitBatch(std::vector<s
|
||||
return splittedBatch;
|
||||
}
|
||||
|
||||
inline Batch NBest::EncodeBatch(const std::vector<std::vector<std::string>>& batch) const {
|
||||
inline Batch NBest::EncodeBatch(const std::vector<std::string>& batch) const {
|
||||
Batch encodedBatch;
|
||||
for (auto& sentence: batch) {
|
||||
encodedBatch.push_back(trgVocab_->Encode(sentence, true));
|
||||
encodedBatch.push_back((*trgVocab_)(sentence));
|
||||
}
|
||||
return encodedBatch;
|
||||
}
|
||||
@ -103,7 +101,7 @@ inline Batch NBest::MaskAndTransposeBatch(const Batch& batch) const {
|
||||
|
||||
|
||||
Batch NBest::ProcessBatch(std::vector<std::string>& batch) const {
|
||||
return MaskAndTransposeBatch(EncodeBatch(SplitBatch(batch)));
|
||||
return MaskAndTransposeBatch(EncodeBatch(batch));
|
||||
}
|
||||
|
||||
std::vector<Batch> NBest::GetBatches(const size_t index) const {
|
||||
|
@ -41,7 +41,7 @@ class NBest {
|
||||
std::vector<std::vector<std::string>> SplitBatch(std::vector<std::string>& batch) const;
|
||||
void ParseInputFile(const std::string& path);
|
||||
|
||||
Batch EncodeBatch(const std::vector<std::vector<std::string>>& batch) const;
|
||||
Batch EncodeBatch(const std::vector<std::string>& batch) const;
|
||||
|
||||
Batch MaskAndTransposeBatch(const Batch& batch) const;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user