mirror of https://github.com/marian-nmt/marian.git
synced 2024-11-03 20:13:47 +03:00
bug fix: History::Add() should obey the actual EOS symbol from the given vocabulary;
SubBatch now holds the vocabulary, allowing debug-printing of sequences with word IDs; new operators: logsum, max, min; new overload operator/(float, Expr); shift() now takes a padding value; transformer.h is now a compiled .cpp file on Windows (in preparation for renaming it to transformer.cpp); fixed two MSVC warnings; bug fix: the logger should not output two \r characters at the end of a line; minor changes (const correctness, accessors)
This commit is contained in:
parent 8adde0787e
commit e87564c664
.gitignore (vendored, 5 lines changed)

@@ -4,6 +4,8 @@ src/common/version.h
 *.vcxproj.user
 /vs/x64
 pingme.txt
+/local
+# TODO: ^^ the correct solution for /local is to add that to some local git config, don't remember which one. Cf. mtmain.
 
 # Compiled Object files
 *.slo
@@ -52,3 +54,6 @@ examples/mnist/*ubyte
 # Contrib
 /.ycm_extra_conf.py
 /.vimrc
+/vs/MarianDll.sln
+/vs/MarianDll.VC.db
+/vs/MarianDll.VC.VC.opendb
src/3rd_party/spdlog/details/os.h (vendored, 14 lines changed)

@@ -112,17 +112,17 @@ inline bool operator!=(const std::tm& tm1, const std::tm& tm2)
     return !(tm1 == tm2);
 }
 
-#ifdef _WIN32
-inline const char* eol()
-{
-    return "\r\n";
-}
-#else
+//#ifdef _WIN32
+//inline const char* eol()
+//{
+//    return "\r\n";
+//}
+//#else
 constexpr inline const char* eol()
 {
     return "\n";
 }
-#endif
+//#endif
 
 #ifdef _WIN32
 inline unsigned short eol_size()
@@ -133,11 +133,11 @@ inline bool operator!=(const std::tm& tm1, const std::tm& tm2)
 
 // eol definition
 #if !defined (SPDLOG_EOL)
-#ifdef _WIN32
-#define SPDLOG_EOL "\r\n"
-#else
+//#ifdef _WIN32
+//#define SPDLOG_EOL "\r\n"
+//#else
 #define SPDLOG_EOL "\n"
-#endif
+//#endif
 #endif
 
 SPDLOG_CONSTEXPR static const char* eol = SPDLOG_EOL;
@@ -57,6 +57,9 @@ struct DeviceId {
   friend bool operator==(DeviceId id1, DeviceId id2) {
     return id1.no == id2.no && id1.type == id2.type;
   }
+  friend bool operator!=(DeviceId id1, DeviceId id2) {
+    return !(id1 == id2);
+  }
 };
 
 class TensorBase;
@@ -58,7 +58,9 @@ void createLoggers(const marian::Config* options) {
 
   if(options && options->has("log")) {
     generalLogs.push_back(options->get<std::string>("log"));
+#ifndef _WIN32 // can't open the same file twice in Windows for some reason
     validLogs.push_back(options->get<std::string>("log"));
+#endif
   }
 
   if(options && options->has("valid-log")) {
@@ -13,7 +13,7 @@
 namespace marian {
 
 struct Shape {
-public:
+public: // TODO: why public?
   std::vector<int> shape_;
 
 public:
@@ -24,6 +24,8 @@ public:
     std::copy(il.begin(), il.end(), begin());
   }
 
+  Shape(std::vector<int>&& shape) : shape_(std::move(shape)) { }
+
   void resize(size_t n) { shape_.resize(n, 1); }
 
   const int* data() const { return shape_.data(); }
@@ -61,6 +63,7 @@ public:
 
   inline int operator[](int i) const { return dim(i); }
 
   inline int back() const { return shape_.back(); }
+  inline int& back() { return shape_.back(); }
 
   inline int stride(int i) const {
@@ -132,7 +135,7 @@ public:
     return ss.str();
   }
 
-  int axis(int ax) {
+  int axis(int ax) const {
     if(ax < 0)
       return size() + ax;
     else
@@ -10,7 +10,6 @@
 #include "common/config.h"
 #include "data/batch_stats.h"
 #include "data/rng_engine.h"
-#include "data/vocab.h"
 #include "training/training_state.h"
 
 namespace marian {
@@ -73,8 +73,8 @@ public:
   }
 
   std::vector<Ptr<SubBatch>> subBatches;
-  for(auto m : maxDims) {
-    subBatches.emplace_back(New<SubBatch>(batchSize, m));
+  for (int j = 0; j < maxDims.size(); ++j) {
+    subBatches.emplace_back(New<SubBatch>(batchSize, maxDims[j], vocabs_[j]));
   }
 
   std::vector<size_t> words(maxDims.size(), 0);
@@ -107,6 +107,8 @@ class SubBatch {
 private:
   std::vector<Word> indices_;
   std::vector<float> mask_;
+  Ptr<Vocab> vocab_;
+  // ... TODO: add the length information (remember it)
 
   size_t size_;
   size_t width_;
@@ -119,12 +121,13 @@ public:
    * @param size Number of sentences
    * @param width Number of words in the longest sentence
    */
-  SubBatch(int size, int width)
+  SubBatch(int size, int width, const Ptr<Vocab>& vocab)
       : indices_(size * width, 0),
         mask_(size * width, 0),
         size_(size),
         width_(width),
-        words_(0) {}
+        words_(0),
+        vocab_(vocab){}
 
   /**
    * @brief Flat vector of word indices.
@@ -141,6 +144,11 @@ public:
    */
  std::vector<float>& mask() { return mask_; }
 
+  /**
+   * @brief Accessors to the vocab_ field.
+   */
+  const Ptr<Vocab>& vocab() const { return vocab_; }
+
  /**
   * @brief The number of sentences in the batch.
   */
@@ -173,7 +181,7 @@ public:
    for(int k = 0; k < n; ++k) {
      size_t __size__ = std::min(subSize, totSize);
 
-      auto sb = New<SubBatch>(__size__, width_);
+      auto sb = New<SubBatch>(__size__, width_, vocab_);
 
      size_t __words__ = 0;
      for(int j = 0; j < width_; ++j) {
@@ -204,12 +212,12 @@ public:
  */
 class CorpusBatch : public Batch {
 private:
-  std::vector<Ptr<SubBatch>> batches_;
+  std::vector<Ptr<SubBatch>> subBatches_;
   std::vector<float> guidedAlignment_;
   std::vector<float> dataWeights_;
 
 public:
-  CorpusBatch(const std::vector<Ptr<SubBatch>>& batches) : batches_(batches) {}
+  CorpusBatch(const std::vector<Ptr<SubBatch>>& subBatches) : subBatches_(subBatches) {}
 
  /**
   * @brief Access i-th subbatch storing a source or target sentence.
@@ -221,52 +229,52 @@ public:
   *
   * @return Pointer to the requested element.
   */
-  Ptr<SubBatch> operator[](size_t i) const { return batches_[i]; }
+  Ptr<SubBatch> operator[](size_t i) const { return subBatches_[i]; }
 
  /**
   * @brief Access the first subbatch, i.e. the source sentence.
   */
-  Ptr<SubBatch> front() { return batches_.front(); }
+  Ptr<SubBatch> front() { return subBatches_.front(); }
 
  /**
   * @brief Access the last subbatch, i.e. the target sentence.
   */
-  Ptr<SubBatch> back() { return batches_.back(); }
+  Ptr<SubBatch> back() { return subBatches_.back(); }
 
  /**
   * @brief The number of sentences in the batch.
   */
-  size_t size() const { return batches_[0]->batchSize(); }
+  size_t size() const { return subBatches_[0]->batchSize(); }
 
  /**
   * @brief The total number of words for the longest sentence in the batch plus one. Pass which=0 for source and -1 for target.
   */
-  size_t words(int which = 0) const { return batches_[which >= 0 ? which : which + (ptrdiff_t)batches_.size()]->batchWords(); }
+  size_t words(int which = 0) const { return subBatches_[which >= 0 ? which : which + (ptrdiff_t)subBatches_.size()]->batchWords(); }
 
  /**
   * @brief The width of the source mini-batch. Num words + padded?
   */
-  size_t width() const { return batches_[0]->batchWidth(); }
+  size_t width() const { return subBatches_[0]->batchWidth(); }
 
  /**
   * @brief The number of sentences in the batch, target words.
   */
-  size_t sizeTrg() const { return batches_.back()->batchSize(); }
+  size_t sizeTrg() const { return subBatches_.back()->batchSize(); }
 
  /**
   * @brief The number of words for the longest sentence in the batch plus one.
   */
-  size_t wordsTrg() const { return batches_.back()->batchWords(); };
+  size_t wordsTrg() const { return subBatches_.back()->batchWords(); };
 
  /**
   * @brief The width of the target mini-batch. Num words + padded?
   */
-  size_t widthTrg() const { return batches_.back()->batchWidth(); };
+  size_t widthTrg() const { return subBatches_.back()->batchWidth(); };
 
  /**
   * @brief The number of source and targets.
   */
-  size_t sets() const { return batches_.size(); }
+  size_t sets() const { return subBatches_.size(); }
 
  /**
   * @brief Creates a batch filled with fake data. Used to determine the size of
@@ -284,8 +292,10 @@ public:
   std::vector<Ptr<SubBatch>> batches;
 
   for(auto len : lengths) {
-    auto sb = New<SubBatch>(batchSize, len);
-    std::fill(sb->mask().begin(), sb->mask().end(), 1);
+    auto vocab = New<Vocab>();
+    vocab->createFake();
+    auto sb = New<SubBatch>(batchSize, len, vocab); // data: gets initialized to 0. No EOS symbol is distinguished.
+    std::fill(sb->mask().begin(), sb->mask().end(), 1); // mask: no items ask being masked out
 
    batches.push_back(sb);
  }
@@ -323,7 +333,7 @@ public:
  std::vector<Ptr<Batch>> split(size_t n) {
    // split each subbatch separately
    std::vector<std::vector<Ptr<SubBatch>>> subs(n);
-    for(auto subBatch : batches_) {
+    for(auto subBatch : subBatches_) {
      size_t i = 0;
      for(auto splitSubBatch : subBatch->split(n))
        subs[i++].push_back(splitSubBatch);
@@ -357,7 +367,7 @@ public:
      size_t width = 1;
      // There are more weights than sentences, i.e. these are word weights.
      if(dataWeights_.size() != oldSize)
-        width = batches_.back()->batchWidth();
+        width = subBatches_.back()->batchWidth();
 
      for(auto split : splits) {
        std::vector<float> ws(width * split->size(), 1.0f);
@@ -402,14 +412,16 @@ public:
    }
 
    size_t b = 0;
-    for(auto sb : batches_) {
+    for(auto sb : subBatches_) {
      std::cerr << "batch " << b++ << ": " << std::endl;
+      const auto& vocab = *sb->vocab();
      for(size_t i = 0; i < sb->batchWidth(); i++) {
        std::cerr << "\t w: ";
        for(size_t j = 0; j < sb->batchSize(); j++) {
          size_t idx = i * sb->batchSize() + j;
          Word w = sb->data()[idx];
-          std::cerr << w << " ";
+          const auto& s = vocab[w];
+          std::cerr << s << " ";
        }
        std::cerr << std::endl;
      }
@@ -64,8 +64,8 @@ public:
   }
 
   std::vector<Ptr<SubBatch>> subBatches;
-  for(auto m : maxDims) {
-    subBatches.emplace_back(New<SubBatch>(batchSize, m));
+  for (int j = 0; j < maxDims.size(); ++j) {
+    subBatches.emplace_back(New<SubBatch>(batchSize, maxDims[j], vocabs_[j]));
   }
 
   std::vector<size_t> words(maxDims.size(), 0);
@@ -83,8 +83,8 @@ public:
   }
 
   std::vector<Ptr<SubBatch>> subBatches;
-  for(auto m : maxDims) {
-    subBatches.emplace_back(New<SubBatch>(batchSize, m));
+  for (int j = 0; j < maxDims.size(); ++j) {
+    subBatches.emplace_back(New<SubBatch>(batchSize, maxDims[j], vocabs_[j]));
   }
 
   std::vector<size_t> words(maxDims.size(), 0);
@@ -54,6 +54,7 @@ public:
   iterator begin() { return iterator(*this); }
   iterator end() { return iterator(); }
 
+  // TODO: There are half dozen functions called toBatch(), which are very similar. Factor them.
   batch_ptr toBatch(const std::vector<sample>& batchVector) {
     int batchSize = batchVector.size();
 
@@ -71,8 +72,8 @@ public:
   }
 
   std::vector<Ptr<SubBatch>> subBatches;
-  for(auto m : maxDims) {
-    subBatches.emplace_back(New<SubBatch>(batchSize, m));
+  for (int j = 0; j < maxDims.size(); ++j) {
+    subBatches.emplace_back(New<SubBatch>(batchSize, maxDims[j], vocabs_[j]));
   }
 
   std::vector<size_t> words(maxDims.size(), 0);
@@ -81,6 +81,16 @@ int Vocab::loadOrCreate(const std::string& vocabPath,
   }
 }
 
+// helper to insert a word into str2id_[] and id2str_[]
+Word Vocab::insertWord(Word id, const std::string& str)
+{
+  str2id_[str] = id;
+  if (id >= id2str_.size())
+    id2str_.resize(id + 1);
+  id2str_[id] = str;
+  return id;
+};
+
 int Vocab::load(const std::string& vocabPath, int max) {
   bool isJson = regex::regex_search(vocabPath, regex::regex("\\.(json|yml)$"));
   LOG(info, "[data] Loading vocabulary from {} file {}", isJson ? "JSON/Yaml" : "text", vocabPath);
@@ -108,15 +118,6 @@ int Vocab::load(const std::string& vocabPath, int max) {
 
   std::unordered_set<Word> seenSpecial;
 
-  // helper to insert a word into str2id_[] and id2str_[]
-  auto insertWord = [&](Word id, const std::string& str)
-  {
-    str2id_[str] = id;
-    if(id >= id2str_.size())
-      id2str_.resize(id + 1);
-    id2str_[id] = str;
-  };
-
   id2str_.reserve(vocab.size());
   for(auto&& pair : vocab) {
     auto str = pair.first;
@@ -169,6 +170,12 @@ int Vocab::load(const std::string& vocabPath, int max) {
   return std::max((int)id2str_.size(), max);
 }
 
+void Vocab::createFake() // for fakeBatch()
+{
+  eosId_ = insertWord(DEFAULT_EOS_ID, DEFAULT_EOS_STR);
+  unkId_ = insertWord(DEFAULT_UNK_ID, DEFAULT_UNK_STR);
+}
+
 class Vocab::VocabFreqOrderer {
 private:
   std::unordered_map<std::string, size_t>& counter_;
@@ -1,12 +1,12 @@
 #pragma once
 
-#include "common/file_stream.h"
-#include "data/types.h"
-
 #include <map>
 #include <string>
 #include <vector>
 
+#include "common/file_stream.h"
+#include "data/types.h"
+
 namespace marian {
 
 class Vocab {
@@ -39,6 +39,11 @@ public:
   Word GetEosId() const { return eosId_; }
   Word GetUnkId() const { return unkId_; }
 
+  void createFake(); // for fakeBatch()
+
+private:
+  Word insertWord(Word id, const std::string& str);
+
 private:
   typedef std::map<std::string, size_t> Str2Id;
   Str2Id str2id_;
@@ -94,6 +94,15 @@ BINARY(Minus, operator-, x - y);
 BINARY(Mult, operator*, x* y);
 BINARY(Div, operator/, x / y);
 
+BINARY(LogSum,
+       logsum,
+       (/*if*/ (x < y) ? // Note: This may not be ideal for CUDA; cf. CNTK implementation
+        (y + log1pf(expf(x - y)))
+        /*else*/ :
+        (x + log1pf(expf(y - x)))));
+BINARY(Max, max, (x > y) ? x : y); // note: std::max not available on CUDA it seems
+BINARY(Min, min, (x < y) ? x : y);
+
 UNARY(Negate, operator!, !x);
 BINARY(Eq, operator==, x == y);
 BINARY(NEq, operator!=, x != y);
@@ -80,6 +80,19 @@ Expr operator/(Expr a, Expr b) {
   return Expression<DivNodeOp>(a, b);
 }
 
+// on names: stay close to Python/numpy?
+Expr logsum(Expr a, Expr b) { // TODO: haggle over the name (logplus, logadd, expAddLog)
+  return Expression<LogSumNodeOp>(a, b);
+}
+
+Expr max(Expr a, Expr b) { // TODO: haggle over the name (max vs. elementMax)
+  return Expression<MaxNodeOp>(a, b);
+}
+
+Expr min(Expr a, Expr b) { // TODO: haggle over the name
+  return Expression<MinNodeOp>(a, b);
+}
+
 /*********************************************************/
 
 Expr operator+(Expr a, float b) {
@@ -110,6 +123,11 @@ Expr operator/(Expr a, float b) {
   return Expression<ScalarMultNodeOp>(a, 1.f / b);
 }
 
+Expr operator/(float a, Expr b) { // TODO: efficient version of this without constant()
+  auto aExpr = b->graph()->constant({}, inits::from_value(a));
+  return aExpr / b;
+}
+
 // Expr pow(float a, Expr b) {
 //   return Expression<Scalar1PowNodeOp>(a, b);
 //
@@ -326,6 +344,7 @@ Expr affine(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
   }
 }
 
+// swap the last two axes
 Expr transpose(Expr a) {
   std::vector<int> axes(a->shape().size());
   for(int i = 0; i < axes.size(); ++i) {
@@ -435,8 +454,8 @@ Expr highway(const std::string prefix, Expr x) {
 //  return gamma * (xmmju / std);
 //}
 
-Expr shift(Expr a, Shape shift) {
-  return Expression<ShiftNodeOp>(a, shift);
+Expr shift(Expr a, Shape shift, float padValue) {
+  return Expression<ShiftNodeOp>(a, shift, padValue);
 }
 
 // Expr lexical_bias(Expr logits, Expr att, float eps, Ptr<sparse::CSR> lf) {
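For orientation, a minimal sketch of how the operators introduced above compose. This is not part of the commit; it assumes an ExpressionGraph set up as elsewhere in marian, and the exact constant()/initializer signature varies across marian versions:

// Sketch only (assumptions as stated above).
auto graph = New<ExpressionGraph>();
graph->setDevice(0);
graph->reserveWorkspaceMB(128);

std::vector<float> va = {1, 2, 3, 4};
std::vector<float> vb = {4, 3, 2, 1};
auto a = graph->constant({2, 2}, inits::from_vector(va));
auto b = graph->constant({2, 2}, inits::from_vector(vb));

auto hi  = max(a, b);              // element-wise maximum: {4, 3, 3, 4}
auto lo  = min(a, b);              // element-wise minimum: {1, 2, 2, 1}
auto ls  = logsum(a, b);           // log(exp(a) + exp(b)), computed stably
auto inv = 1.f / a;                // new overload: scalar numerator over an expression
auto sh  = shift(a, {1, 0}, -1.f); // shift by one along axis 0, padding with -1

graph->forward();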
@@ -7,7 +7,7 @@ Expr debug(Expr a, const std::string& message = "");
 
 Expr plus(const std::vector<Expr>&);
 
-Expr logit(Expr a);
+Expr logit(Expr a); // aka sigmoid
 Expr logit(const std::vector<Expr>&);
 
 Expr swish(Expr a);
@@ -60,6 +60,12 @@ Expr operator/(Expr a, float b);
 // Expr pow(float a, Expr b);
 // Expr pow(Expr a, float b);
 
+Expr logsum(Expr a, Expr b); // TODO: haggle over the name (logplus, logadd, expAddLog)
+
+Expr max(Expr a, Expr b); // TODO: haggle over the name (max vs. elementMax)
+
+Expr min(Expr a, Expr b); // TODO: haggle over the name
+
 Expr dot(Expr a,
          Expr b,
          bool transA = false,
@@ -141,7 +147,7 @@ static inline Expr dropout(Expr x, float prob) {
   return dropout(x, prob, x->shape());
 }
 
-Expr shift(Expr, Shape);
+Expr shift(Expr, Shape, float padValue = 0);
 
 Expr convert2cudnnFormat(Expr x);
 
|
@ -528,6 +528,66 @@ struct DivNodeOp : public ElementBinaryNodeOp {
|
||||
// const std::string type() { return "pow"; }
|
||||
//};
|
||||
|
||||
struct LogSumNodeOp : public ElementBinaryNodeOp {
|
||||
LogSumNodeOp(Expr a, Expr b) : ElementBinaryNodeOp(a, b) {}
|
||||
|
||||
NodeOps forwardOps() {
|
||||
using namespace functional;
|
||||
return{
|
||||
NodeOp(Element(_1 = logsum(_2, _3), val_, child(0)->val(), child(1)->val())) };
|
||||
}
|
||||
|
||||
NodeOps backwardOps() {
|
||||
using namespace functional;
|
||||
|
||||
// d/dx (ln( exp(x) + (exp(y)) = exp(x) / (exp(x) + exp(y)) = 1 / (1 + exp(y-x)) = sigmoid(x-y)
|
||||
return{ NodeOp(Add(_1 * logit(_2 - _3), child(0)->grad(), adj_, child(0)->val(), child(1)->val())),
|
||||
NodeOp(Add(_1 * logit(_3 - _2), child(1)->grad(), adj_, child(0)->val(), child(1)->val())) };
|
||||
}
|
||||
|
||||
// TODO: this is not a "type" (as in data type). It's an operator name.
|
||||
const std::string type() { return "logsum"; }
|
||||
};
|
||||
|
||||
struct MaxNodeOp : public ElementBinaryNodeOp {
|
||||
MaxNodeOp(Expr a, Expr b) : ElementBinaryNodeOp(a, b) {}
|
||||
|
||||
NodeOps forwardOps() {
|
||||
using namespace functional;
|
||||
return{
|
||||
NodeOp(Element(_1 = max(_2, _3), val_, child(0)->val(), child(1)->val())) };
|
||||
}
|
||||
|
||||
NodeOps backwardOps() {
|
||||
using namespace functional;
|
||||
|
||||
return{ NodeOp(Add((_2 > _3) * _1, child(0)->grad(), adj_, child(0)->val(), child(1)->val())),
|
||||
NodeOp(Add((_2 <= _3) * _1, child(1)->grad(), adj_, child(0)->val(), child(1)->val())) };
|
||||
}
|
||||
|
||||
const std::string type() { return "max"; }
|
||||
};
|
||||
|
||||
// TODO: lotsa code dup here!
|
||||
struct MinNodeOp : public ElementBinaryNodeOp {
|
||||
MinNodeOp(Expr a, Expr b) : ElementBinaryNodeOp(a, b) {}
|
||||
|
||||
NodeOps forwardOps() {
|
||||
using namespace functional;
|
||||
return{
|
||||
NodeOp(Element(_1 = min(_2, _3), val_, child(0)->val(), child(1)->val())) };
|
||||
}
|
||||
|
||||
NodeOps backwardOps() {
|
||||
using namespace functional;
|
||||
|
||||
return{ NodeOp(Add((_2 < _3) * _1, child(0)->grad(), adj_, child(0)->val(), child(1)->val())),
|
||||
NodeOp(Add((_2 >= _3) * _1, child(1)->grad(), adj_, child(0)->val(), child(1)->val())) };
|
||||
}
|
||||
|
||||
const std::string type() { return "min"; }
|
||||
};
|
||||
|
||||
// Cross-entropy node. It computes -b*log(softmax(a)), summing rowwise.
|
||||
struct CrossEntropyNodeOp : public NaryNodeOp {
|
||||
CrossEntropyNodeOp(Expr a, Expr b) : NaryNodeOp({a, b}, newShape(a)) {}
|
||||
|
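As a worked restatement (not part of the commit) of the gradient comment inside LogSumNodeOp::backwardOps():

% Stable forward evaluation, factoring out the larger argument (here x < y),
% which is exactly the functor's  y + log1pf(expf(x - y))  branch:
\log(e^x + e^y) \;=\; y + \log\!\bigl(1 + e^{x-y}\bigr)

% Gradient with respect to x, matching the logit (sigmoid) factors above:
\frac{\partial}{\partial x}\,\log(e^x + e^y)
  \;=\; \frac{e^x}{e^x + e^y}
  \;=\; \frac{1}{1 + e^{y-x}}
  \;=\; \sigma(x - y)
% and symmetrically d/dy = sigma(y - x): the two terms logit(_2 - _3) and logit(_3 - _2).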
@@ -550,7 +550,8 @@ struct LogNodeOp : public UnaryNodeOp {
   NodeOps backwardOps() {
     using namespace functional;
     return {
-        NodeOp(Add(_1 * (1.f / _2), child(0)->grad(), adj_, child(0)->val()))};
+        //NodeOp(Add(_1 * (1.f / _2), child(0)->grad(), adj_, child(0)->val()))};
+        NodeOp(Add(_1 / _2, child(0)->grad(), adj_, child(0)->val()))};
   }
 
   const std::string type() { return "log"; }
@@ -931,8 +932,9 @@ public:
     Shape outShape = a->shape();
 
     axis_ = outShape.axis(axis);
-    for(int i = 0; i <= axis_; ++i)
-      outShape.set(i, 1);
+    for(int i = 0; i < axis_; ++i)
+      ABORT_IF(outShape[i] != 1, "non-consecutive slices are presently not supported by step()");
+    outShape.set(axis_, 1);
 
     return outShape;
   }
@@ -993,15 +995,15 @@ public:
 };
 
 struct ShiftNodeOp : public UnaryNodeOp {
-  ShiftNodeOp(Expr a, Shape shift)
-      : UnaryNodeOp(a, a->shape()), shift_(shift) {}
+  ShiftNodeOp(Expr a, Shape shift, float padValue)
+      : UnaryNodeOp(a, a->shape()), shift_(shift), padValue_(padValue) {}
 
   NodeOps forwardOps() {
-    return {NodeOp(Shift(val_, child(0)->val(), shift_, false))};
+    return {NodeOp(Shift(val_, child(0)->val(), shift_, padValue_, /*invert=*/false))};
   }
 
   NodeOps backwardOps() {
-    return {NodeOp(Shift(child(0)->grad(), adj_, shift_, true))};
+    return {NodeOp(Shift(child(0)->grad(), adj_, shift_, /*padValue=*/0.f, /*invert=*/true))};
   }
 
   const std::string type() { return "shift"; }
@@ -1011,6 +1013,7 @@ struct ShiftNodeOp : public UnaryNodeOp {
     size_t seed = NaryNodeOp::hash();
     for(auto i : shift_)
       boost::hash_combine(seed, i);
+    boost::hash_combine(seed, padValue_);
     hash_ = seed;
   }
   return hash_;
@@ -1027,7 +1030,8 @@ struct ShiftNodeOp : public UnaryNodeOp {
     return true;
   }
 
-  Shape shift_;
+  Shape shift_;    // shift offsets in each dimension
+  float padValue_; // what value to shift in
 };
 
 // struct LexicalProbNodeOp : public NaryNodeOp {
@@ -125,7 +125,7 @@ public:
   virtual void setShortlist(Ptr<data::Shortlist> shortlist) { shortlist_ = shortlist; }
 
   template <typename T>
-  T opt(const std::string& key) {
+  T opt(const std::string& key) const {
     return options_->get<T>(key);
   }
 
@@ -44,7 +44,7 @@ public:
       = 0;
 
   template <typename T>
-  T opt(const std::string& key) {
+  T opt(const std::string& key) const {
     return options_->get<T>(key);
   }
 
@@ -8,7 +8,7 @@
 #include "models/hardatt.h"
 #include "models/nematus.h"
 #include "models/s2s.h"
-#include "models/transformer.h"
+#include "models/transformer_factory.h"
 
 #ifdef CUDNN
 #include "models/char_s2s.h"
@@ -34,7 +34,8 @@ Ptr<EncoderBase> EncoderFactory::construct() {
 #endif
 
   if(options_->get<std::string>("type") == "transformer")
-    return New<EncoderTransformer>(options_);
+    //return New<EncoderTransformer>(options_);
+    return NewEncoderTransformer(options_);
 
   ABORT("Unknown encoder type");
 }
@@ -43,7 +44,8 @@ Ptr<DecoderBase> DecoderFactory::construct() {
   if(options_->get<std::string>("type") == "s2s")
     return New<DecoderS2S>(options_);
   if(options_->get<std::string>("type") == "transformer")
-    return New<DecoderTransformer>(options_);
+    //return New<DecoderTransformer>(options_);
+    return NewDecoderTransformer(options_);
   if(options_->get<std::string>("type") == "hard-att")
     return New<DecoderHardAtt>(options_);
   if(options_->get<std::string>("type") == "hard-soft-att")
@@ -1,7 +1,11 @@
+// TODO: This is really a .CPP file now. I kept the .H name to minimize confusing git, until this is code-reviewed.
+// This is meant to speed-up builds, and to support Ctrl-F7 to rebuild
+
 #pragma once
 
 #include "marian.h"
 
+#include "models/transformer_factory.h"
 #include "models/encoder.h"
 #include "models/decoder.h"
 #include "models/states.h"
src/models/transformer_factory.h (new file, 18 lines)

@@ -0,0 +1,18 @@
+#pragma once
+
+#include "marian.h"
+
+#include "models/encoder.h"
+#include "models/decoder.h"
+//#include "models/states.h"
+//#include "layers/constructors.h"
+//#include "layers/factory.h"
+
+namespace marian {
+Ptr<EncoderBase> NewEncoderTransformer(Ptr<Options> options);
+Ptr<DecoderBase> NewDecoderTransformer(Ptr<Options> options);
+}
+
+#ifndef _WIN32 // TODO: remove this once I updated the Linux-side makefile
+#include "models/transformer.h"
+#endif
@@ -7,13 +7,13 @@
 namespace marian {
 namespace cpu {
 
-void Dropout(Tensor tensor, float p) {
+void Dropout(Tensor tensor, float dropProb) {
   auto cpuBackend
       = std::static_pointer_cast<cpu::Backend>(tensor->getBackend());
   auto &gen = cpuBackend->getRandomGenerator();
-  std::bernoulli_distribution dist(1.f - p);
+  std::bernoulli_distribution dist(1.f - dropProb);
   std::generate(tensor->data(), tensor->data() + tensor->size(), [&]() {
-    return dist(gen) / (1.f - p); // TODO: fix this warning C4804: '/': unsafe use of type 'bool' in operation
+    return (float)dist(gen) / (1.f - dropProb);
   });
 }
 }
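The MSVC warning C4804 fixed here arises because std::bernoulli_distribution yields bool. A standalone sketch (illustrative values, not marian code) of the inverted-dropout scaling with the explicit cast:

#include <algorithm>
#include <random>
#include <vector>

int main() {
  float dropProb = 0.1f;
  std::mt19937 gen(42);
  std::bernoulli_distribution dist(1.f - dropProb); // true = keep the unit
  std::vector<float> mask(8);
  std::generate(mask.begin(), mask.end(), [&]() {
    // dist(gen) is bool; without the cast, bool / float triggered C4804.
    // Survivors are scaled by 1/(1 - dropProb) so the expectation is unchanged.
    return (float)dist(gen) / (1.f - dropProb); // 0.f or ~1.111f
  });
}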
@@ -21,7 +21,7 @@ struct QuantizeNodeOp : public UnaryNodeOp {
 
   NodeOps backwardOps() {
     ABORT("Only used for inference");
-    return {NodeOp()};
+    return {NodeOp(0)};
   }
 
   const std::string type() { return "quantizeInt16"; }
@@ -62,7 +62,7 @@ public:
 
   NodeOps backwardOps() {
     ABORT("Only used for inference");
-    return {NodeOp()};
+    return {NodeOp(0)};
   }
 
   const std::string type() { return "dotInt16"; }
@@ -106,7 +106,7 @@ public:
 
   NodeOps backwardOps() {
     ABORT("Only used for inference");
-    return {NodeOp()};
+    return {NodeOp(0)};
   }
 
   const std::string type() { return "affineInt16"; }
@@ -878,7 +878,7 @@ void LayerNormalizationGrad(Tensor gradX_,
   }
 }
 
-void Shift(Tensor out_, Tensor in_, marian::Shape shift, bool invert) {
+void Shift(Tensor out_, Tensor in_, marian::Shape shift, float padValue, bool invert) {
   int offset = 0;
   for(int i = 0; i < shift.size(); ++i)
     offset += in_->shape().stride(i) * shift[i];
@@ -892,8 +892,9 @@ void Shift(Tensor out_, Tensor in_, marian::Shape shift, bool invert) {
   int length = out_->shape().elements();
 #pragma omp parallel for
   for(int i = 0; i < length; ++i) {
+    // BUGBUG: This logic is only correct for the outermost axis.
     if(i - offset < 0 || i - offset >= length) {
-      out[i] = 0.f;
+      out[i] = padValue;
     } else {
       out[i] = in[i - offset];
     }
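To make the flat-offset logic concrete, a standalone sketch (hypothetical helper, not marian code) of shifting a row-major matrix by whole rows, the one case the BUGBUG comment says is handled correctly:

#include <vector>

// Shifting by rowShift rows moves every element by a single flat offset of
// rowShift * cols (i.e. stride(0) * shift[0], as in Shift() above); slots with
// no source element receive padValue.
std::vector<float> shiftRows(const std::vector<float>& in,
                             int rows, int cols, int rowShift, float padValue) {
  int offset = rowShift * cols;
  int length = rows * cols;
  std::vector<float> out(length);
  for(int i = 0; i < length; ++i)
    out[i] = (i - offset < 0 || i - offset >= length) ? padValue
                                                      : in[i - offset];
  return out;
}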
@@ -15,3 +15,10 @@ template void Add<BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::sPR
 template void Add<BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::sReLUBack, Assignee<2>>>, marian::Tensor, marian::Tensor>(BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::sReLUBack, Assignee<2>>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
 template void Add<BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, BinaryFunctor<elem::Minus, Capture, Assignee<2>>>, marian::Tensor, marian::Tensor>(BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, BinaryFunctor<elem::Minus, Capture, Assignee<2>>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
 template void Add<BinaryFunctor<elem::Mult, BinaryFunctor<elem::Bump, Assignee<1>, Capture>, Assignee<2>>, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase> >(BinaryFunctor<elem::Mult, BinaryFunctor<elem::Bump, Assignee<1>, Capture>, Assignee<2> >, float, std::shared_ptr<marian::TensorBase>, marian::Tensor, marian::Tensor);
+template void Add<BinaryFunctor<elem::Mult, BinaryFunctor<elem::Lt, Assignee<2>, Assignee<3>>, Assignee<1>>, marian::Tensor, marian::Tensor, marian::Tensor>(BinaryFunctor<elem::Mult, BinaryFunctor<elem::Lt, Assignee<2>, Assignee<3>>, Assignee<1>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
+template void Add<BinaryFunctor<elem::Mult, BinaryFunctor<elem::Geq, Assignee<2>, Assignee<3>>, Assignee<1>>, marian::Tensor, marian::Tensor, marian::Tensor>(BinaryFunctor<elem::Mult, BinaryFunctor<elem::Geq, Assignee<2>, Assignee<3>>, Assignee<1>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
+template void Add<BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Logit, BinaryFunctor<elem::Minus, Assignee<3>, Assignee<2>>>>, marian::Tensor, marian::Tensor, marian::Tensor>(BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Logit, BinaryFunctor<elem::Minus, Assignee<3>, Assignee<2>>>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
+template void Add<BinaryFunctor<elem::Mult, BinaryFunctor<elem::Gt, Assignee<2>, Assignee<3>>, Assignee<1>>, marian::Tensor, marian::Tensor, marian::Tensor>(BinaryFunctor<elem::Mult, BinaryFunctor<elem::Gt, Assignee<2>, Assignee<3>>, Assignee<1>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
+template void Add<BinaryFunctor<elem::Mult, BinaryFunctor<elem::Leq, Assignee<2>, Assignee<3>>, Assignee<1>>, marian::Tensor, marian::Tensor, marian::Tensor>(BinaryFunctor<elem::Mult, BinaryFunctor<elem::Leq, Assignee<2>, Assignee<3>>, Assignee<1>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
+template void Add<BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Logit, BinaryFunctor<elem::Minus, Assignee<2>, Assignee<3>>>>, marian::Tensor, marian::Tensor, marian::Tensor>(BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Logit, BinaryFunctor<elem::Minus, Assignee<2>, Assignee<3>>>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
+template void Add<BinaryFunctor<elem::Div, Assignee<1>, Assignee<2> >, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase> >(BinaryFunctor<elem::Div, Assignee<1>, Assignee<2> >, float, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>);
@@ -38,3 +38,6 @@ template void Element<Assign<Var<1>, TernaryFunctor<elem::IfThenElse, BinaryFunc
 template void Element<Assign<Var<1>, TernaryFunctor<elem::IfThenElse, BinaryFunctor<elem::Leq, UnaryFunctor<elem::Abs, Assignee<1> >, Capture>, Capture, Assignee<1> > >>(Assign<Var<1>, TernaryFunctor<elem::IfThenElse, BinaryFunctor<elem::Leq, UnaryFunctor<elem::Abs, Assignee<1> >, Capture>, Capture, Assignee<1> > >, marian::Tensor);
 template void Element<Assign<Var<1>, TernaryFunctor<elem::IfThenElse, BinaryFunctor<elem::Leq, UnaryFunctor<elem::Abs, Assignee<2> >, Capture>, Capture, Capture> >, marian::Tensor >(Assign<Var<1>, TernaryFunctor<elem::IfThenElse, BinaryFunctor<elem::Leq, UnaryFunctor<elem::Abs, Assignee<2> >, Capture>, Capture, Capture> >, marian::Tensor, marian::Tensor);
 template void Element<Assign<Var<1>, BinaryFunctor<elem::Clip, Assignee<2>, Capture>>, marian::Tensor>(Assign<Var<1>, BinaryFunctor<elem::Clip, Assignee<2>, Capture>>, marian::Tensor, marian::Tensor);
+template void Element<Assign<Var<1>, BinaryFunctor<elem::LogSum, Assignee<2>, Assignee<3>>>, marian::Tensor, marian::Tensor>(Assign<Var<1>, BinaryFunctor<elem::LogSum, Assignee<2>, Assignee<3>>>, marian::Tensor, marian::Tensor, marian::Tensor);
+template void Element<Assign<Var<1>, BinaryFunctor<elem::Max, Assignee<2>, Assignee<3>>>, marian::Tensor, marian::Tensor>(Assign<Var<1>, BinaryFunctor<elem::Max, Assignee<2>, Assignee<3>>>, marian::Tensor, marian::Tensor, marian::Tensor);
+template void Element<Assign<Var<1>, BinaryFunctor<elem::Min, Assignee<2>, Assignee<3>>>, marian::Tensor, marian::Tensor>(Assign<Var<1>, BinaryFunctor<elem::Min, Assignee<2>, Assignee<3>>>, marian::Tensor, marian::Tensor, marian::Tensor);
@@ -1558,19 +1558,19 @@ void LayerNormalizationGrad(Tensor gradX,
       eps);
 }
 
-__global__ void gShift(float* out, const float* in, int length, int offset) {
+__global__ void gShift(float* out, const float* in, int length, int offset, float padValue) {
   for(int bid = 0; bid < length; bid += blockDim.x * gridDim.x) {
     int index = bid + blockDim.x * blockIdx.x + threadIdx.x;
     if(index < length) {
       if(index - offset < 0 || index - offset >= length)
-        out[index] = 0;
+        out[index] = padValue;
       else
         out[index] = in[index - offset];
     }
   }
 }
 
-void Shift(Tensor out, Tensor in, marian::Shape shift, bool invert) {
+void Shift(Tensor out, Tensor in, marian::Shape shift, float padValue, bool invert) {
   ABORT_IF(in->shape().size() != shift.size(), "bad dimensions");
 
   // BUGBUG: This can only shift along the first axis. Shifting, e.g., along the last axis cannot be implemented this way.
@@ -1588,7 +1588,7 @@ void Shift(Tensor out, Tensor in, marian::Shape shift, bool invert) {
   int threads = std::min(MAX_THREADS, length);
   int blocks = std::min(MAX_BLOCKS, length / threads + (length % threads != 0));
 
-  gShift<<<blocks, threads>>>(out->data(), in->data(), length, offset);
+  gShift<<<blocks, threads>>>(out->data(), in->data(), length, offset, padValue);
 }
 
 __global__ void gSetSparse(float* out,
@@ -78,7 +78,7 @@ void Reduce(Functor functor, marian::Tensor out, Tensors... tensors) {
 DISPATCH4(CrossEntropyPickBackward, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor)
 
 DISPATCH3(TransposeND, marian::Tensor, marian::Tensor, const std::vector<int>&)
-DISPATCH4(Shift, marian::Tensor, marian::Tensor, marian::Shape, bool)
+DISPATCH5(Shift, marian::Tensor, marian::Tensor, marian::Shape, float, bool)
 
 DISPATCH3(Concatenate, marian::Tensor, const std::vector<marian::Tensor>&, int)
 // clang-format on
@@ -2,7 +2,9 @@
 
 #include "common/config.h"
 #include "data/batch_generator.h"
+#if SQLITE_FOUND
 #include "data/corpus_sqlite.h"
+#endif
 #include "models/model_task.h"
 #include "training/scheduler.h"
 #include "training/validator.h"
@@ -21,9 +23,11 @@ public:
     using namespace data;
 
     Ptr<CorpusBase> dataset;
+#if SQLITE_FOUND
     if(!options_->get<std::string>("sqlite").empty())
       dataset = New<CorpusSQLite>(options_);
     else
+#endif
       dataset = New<Corpus>(options_);
 
     dataset->prepare();
@@ -132,7 +132,7 @@ public:
     bool final = false;
 
     for(int i = 0; i < dimBatch; ++i)
-      histories[i]->Add(beams[i]);
+      histories[i]->Add(beams[i], trgEosId_);
 
     std::vector<Ptr<ScorerState>> states;
 
@@ -181,6 +181,7 @@ public:
     //**********************************************************************
     // prepare costs for beam search
     auto totalCosts = prevCosts;
+    // BUGBUG: it's not cost but score (higher=better)
 
     for(int i = 0; i < scorers_.size(); ++i) {
       states[i] = scorers_[i]->step(
@@ -191,6 +192,7 @@ public:
             = totalCosts + scorers_[i]->getWeight() * states[i]->getProbs();
       else
         totalCosts = totalCosts + states[i]->getProbs();
+      // BUGBUG: getProbs() -> getLogProbs(); totalCosts -> totalScores (higher=better)
     }
 
     // make beams continuous
@@ -226,7 +228,7 @@ public:
       if(!beams[i].empty()) {
         final = final
                 || histories[i]->size() >= options_->get<float>("max-length-factor") * batch->front()->batchWidth();
-        histories[i]->Add(beams[i], prunedBeams[i].empty() || final);
+        histories[i]->Add(beams[i], trgEosId_, prunedBeams[i].empty() || final);
       }
     }
     beams = prunedBeams;
@@ -1,8 +1,9 @@
 #pragma once
 
-#include <queue>
-
 #include "hypothesis.h"
+#include "data/types.h"
+
+#include <queue>
 
 namespace marian {
 
@@ -22,10 +23,10 @@ public:
   float LengthPenalty(size_t length) { return std::pow((float)length, alpha_); }
   float WordPenalty(size_t length) { return wp_ * (float)length; }
 
-  void Add(const Beam& beam, bool last = false) {
+  void Add(const Beam& beam, Word trgEosId, bool last = false) {
     if(beam.back()->GetPrevHyp() != nullptr) {
       for(size_t j = 0; j < beam.size(); ++j)
-        if(beam[j]->GetWord() == 0 || last) {
+        if(beam[j]->GetWord() == trgEosId || last) {
          float cost = (beam[j]->GetCost() - WordPenalty(history_.size())) / LengthPenalty(history_.size());
          topHyps_.push({history_.size(), j, cost});
          // std::cerr << "Add " << history_.size() << " " << j << " " << cost
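Written out, the value pushed into topHyps_ above normalizes the accumulated cost by the word and length penalties (a restatement of the code, with n = history_.size()):

\text{score}(j) \;=\; \frac{\mathrm{GetCost}(h_j) - \mathrm{WordPenalty}(n)}{\mathrm{LengthPenalty}(n)}
  \;=\; \frac{\mathrm{GetCost}(h_j) - w_p\, n}{n^{\alpha}}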
@@ -22,7 +22,7 @@ void Printer(Ptr<Config> options,
   for(size_t i = 0; i < nbl.size(); ++i) {
     const auto& result = nbl[i];
     const auto& words = std::get<0>(result);
-      const auto& hypo = std::get<1>(result);
+    const auto& hypo = std::get<1>(result);
 
     float realCost = std::get<2>(result);
 
@@ -111,7 +111,11 @@
   </ItemDefinitionGroup>
   <ItemGroup>
     <ClCompile Include="..\src\3rd_party\yaml-cpp\yaml-node.cpp" />
-    <ClCompile Include="..\src\command\marian_decoder.cpp" />
+    <ClCompile Include="..\src\command\marian-main.cpp" />
+    <ClInclude Include="..\src\command\marian.cpp" />
+    <ClInclude Include="..\src\command\marian_decoder.cpp" />
+    <ClInclude Include="..\src\command\marian_scorer.cpp" />
+    <ClInclude Include="..\src\command\marian_vocab.cpp" />
     <ClCompile Include="..\src\common\utils.cpp" />
     <ClCompile Include="..\src\common\logging.cpp" />
     <ClCompile Include="..\src\common\config.cpp" />
@@ -386,8 +390,9 @@
     <ClInclude Include="..\src\models\nematus.h" />
     <ClInclude Include="..\src\models\s2s.h" />
     <ClInclude Include="..\src\models\states.h" />
-    <ClInclude Include="..\src\models\transformer.h" />
+    <ClCompile Include="..\src\models\transformer.h" />
     <ClInclude Include="..\src\models\experimental\lex_probs.h" />
+    <ClInclude Include="..\src\models\transformer_factory.h" />
     <ClInclude Include="..\src\optimizers\clippers.h" />
     <ClInclude Include="..\src\optimizers\optimizers.h" />
     <ClInclude Include="..\src\rescorer\rescorer.h" />
@@ -205,9 +205,6 @@
     <ClCompile Include="..\src\3rd_party\yaml-cpp\contrib\graphbuilderadapter.cpp">
       <Filter>3rd_party\yaml-cpp\contrib</Filter>
     </ClCompile>
-    <ClCompile Include="..\src\command\marian_decoder.cpp">
-      <Filter>command</Filter>
-    </ClCompile>
     <ClCompile Include="..\src\3rd_party\yaml-cpp\yaml-node.cpp">
       <Filter>3rd_party\yaml-cpp</Filter>
     </ClCompile>
@@ -217,6 +214,18 @@
     <ClCompile Include="..\src\tensors\cpu\sharp\sse_gemm.cpp">
       <Filter>tensors\cpu\sharp</Filter>
     </ClCompile>
+    <ClCompile Include="..\src\models\transformer.h">
+      <Filter>models</Filter>
+    </ClCompile>
+    <ClCompile Include="..\src\command\marian-main.cpp">
+      <Filter>command</Filter>
+    </ClCompile>
+    <ClCompile Include="..\src\command\marian_decoder.cpp">
+      <Filter>command</Filter>
+    </ClCompile>
+    <ClCompile Include="..\src\command\marian.cpp">
+      <Filter>command</Filter>
+    </ClCompile>
   </ItemGroup>
   <ItemGroup>
     <ClInclude Include="..\src\marian.h" />
@@ -805,9 +814,6 @@
     <ClInclude Include="..\src\models\states.h">
       <Filter>models</Filter>
     </ClInclude>
-    <ClInclude Include="..\src\models\transformer.h">
-      <Filter>models</Filter>
-    </ClInclude>
     <ClInclude Include="..\src\models\experimental\lex_probs.h">
       <Filter>models\experimental</Filter>
     </ClInclude>
@@ -967,6 +973,15 @@
     <ClInclude Include="..\src\tensors\cpu\sharp\int_gemm.h">
       <Filter>tensors\cpu\sharp</Filter>
     </ClInclude>
+    <ClInclude Include="..\src\models\transformer_factory.h">
+      <Filter>models</Filter>
+    </ClInclude>
+    <ClInclude Include="..\src\command\marian_scorer.cpp">
+      <Filter>command</Filter>
+    </ClInclude>
+    <ClInclude Include="..\src\command\marian_vocab.cpp">
+      <Filter>command</Filter>
+    </ClInclude>
   </ItemGroup>
   <ItemGroup>
     <Filter Include="3rd_party">