merged from master

Frank Seide 2019-03-13 09:14:49 -07:00
commit ff59ebd12b
26 changed files with 412 additions and 328 deletions

@ -1 +1 @@
Subproject commit febdc3f56f75929b1f7b5b38a4b9b96ea8f648e7
Subproject commit 142eadddbe04493c1024b42586030b72e9cb7ea2

2
src/common/options.h Normal file → Executable file

@ -111,7 +111,7 @@ public:
}
try {
return !options_[key].as<std::string>().empty();
} catch(const YAML::BadConversion& e) {
} catch(const YAML::BadConversion&) {
ABORT("Option '{}' is neither a sequence nor a text");
}
return false;

66
src/graph/expression_operators.cpp Normal file → Executable file

@ -135,31 +135,49 @@ Expr le(Expr a, float b) { return Expression<CmpNodeOp>(a, a->graph()->constant(
/*********************************************************/
Expr operator+(Expr a, float b) {
return Expression<ScalarAddNodeOp>(a, b);
if (b == 0)
return a;
else
return Expression<ScalarAddNodeOp>(a, b);
}
Expr operator+(float a, Expr b) {
return Expression<ScalarAddNodeOp>(b, a);
if (a == 0)
return b;
else
return Expression<ScalarAddNodeOp>(b, a);
}
Expr operator-(Expr a, float b) {
return Expression<ScalarAddNodeOp>(a, -b);
if (b == 0)
return a;
else
return Expression<ScalarAddNodeOp>(a, -b);
}
Expr operator-(float a, Expr b) {
return Expression<ScalarAddNodeOp>(-b, a);
if (a == 0)
return -b;
else
return Expression<ScalarAddNodeOp>(-b, a);
}
Expr operator*(float a, Expr b) {
return Expression<ScalarMultNodeOp>(b, a);
if (a == 1.0f)
return b;
else
return Expression<ScalarMultNodeOp>(b, a);
}
Expr operator*(Expr a, float b) {
return Expression<ScalarMultNodeOp>(a, b);
if (b == 1.0f)
return a;
else
return Expression<ScalarMultNodeOp>(a, b);
}
Expr operator/(Expr a, float b) {
return Expression<ScalarMultNodeOp>(a, 1.f / b);
return a * (1.f / b);
}
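// Usage sketch (illustration only): these identity shortcuts avoid creating no-op graph nodes.
// For some expression x:
//   auto y = x + 0.f;    // returns x itself; no ScalarAddNodeOp is added to the graph
//   auto z = 1.0f * x;   // likewise returns x unchanged
//   auto w = x / 4.f;    // is now built as x * 0.25f, i.e. a single ScalarMultNodeOp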
// TODO: efficient version of this without constant()
@ -195,6 +213,8 @@ Expr repeat(Expr a, size_t repeats, int ax) {
}
Expr reshape(Expr a, Shape shape) {
if (a->shape() == shape)
return a;
return Expression<ReshapeNodeOp>(a, shape);
}
@ -238,7 +258,7 @@ Expr flatten_2d(Expr a) {
Expr stopGradient(Expr a) {
// implemented as a dummy reshape that is not trainable
auto res = reshape(a, a->shape());
auto res = Expression<ReshapeNodeOp>(a, a->shape());
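// (presumably ReshapeNodeOp is constructed directly here because reshape() above now returns
// 'a' itself when the shape is unchanged, which would wrongly mark 'a' as non-trainable)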
res->setTrainable(false);
return res;
}
@ -254,7 +274,12 @@ Expr gather(Expr a, int axis, Expr indices) {
return Expression<GatherNodeOp>(a, axis, indices);
}
// index_select() -- gather arbitrary elements along an axis; unbatched (indices are specified as a 1D vector)
// index_select() -- gather arbitrary elements along an axis from an unbatched
// input 'a'. Indices are specified as a 1D vector.
// This is used e.g. for embedding lookup.
// Note: To use a batch of index vectors, reshape them into a single vector,
// call index_select(), then reshape the result back. Reshapes are cheap.
// This function has the same semantics as PyTorch operation of the same name.
Expr index_select(Expr a, int axis, Expr indices) {
ABORT_IF(indices->shape().size() != 1, "Indices must be a 1D tensor");
// We have specialized kernels for non-batched indexing of first or last axis of a 2D tensor.
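// Batched-usage sketch (hypothetical shapes, for illustration): to embed word indices of shape
// [beam, batch] via an embedding matrix E of shape [dimVoc, dimEmb]:
//   auto flat     = reshape(indices, {indices->shape().elements()});  // flatten to a 1D vector
//   auto rows     = index_select(E, /*axis=*/0, flat);                // [beam * batch, dimEmb]
//   auto embedded = reshape(rows, {beam, batch, E->shape()[-1]});     // restore batch structure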
@ -507,13 +532,28 @@ Expr transpose(Expr a, const std::vector<int>& axes) {
Expr swapAxes(Expr x, int axis1, int axis2)
{
axis1 = x->shape().axis(axis1);
axis2 = x->shape().axis(axis2);
const auto& shape = x->shape();
axis1 = shape.axis(axis1);
axis2 = shape.axis(axis2);
if (axis1 == axis2)
return x;
if (shape[axis1] == 1 || shape[axis2] == 1) { // can we use a reshape instead?
if (axis1 > axis2)
std::swap(axis1, axis2);
bool canReshape = true;
for (int ax = axis1 + 1; ax < axis2 && canReshape; ax++)
canReshape &= (shape[ax] == 1);
if (canReshape) {
auto newShape = shape;
newShape.set(axis1, shape[axis2]);
newShape.set(axis2, shape[axis1]);
//LOG(info, "SwapAxes() did a reshape from {} to {}", shape.toString(), newShape.toString());
return reshape(x, newShape);
}
}
// TODO: This is code dup from transpose(x). Implement transpose(x) as swapAxes(x, 0, 1)
std::vector<int> axes(x->shape().size());
for (int i = 0; i < axes.size(); ++i)
std::vector<int> axes(shape.size());
for (int i = 0; i < axes.size(); ++i) // @TODO: use std::iota()
axes[i] = i;
std::swap(axes[axis1], axes[axis2]);
return transpose(x, axes);
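// Example (illustration only): swapAxes(x, 0, 1) on a tensor of shape [1, 4, 8] merely relabels
// axes, so it is executed as reshape(x, {4, 1, 8}); swapping the first two axes of shape
// [4, 5, 8] instead falls through to a genuine transpose, since memory must be rearranged.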


@ -717,6 +717,8 @@ private:
public:
ReshapeNodeOp(Expr a, Shape shape) : UnaryNodeOp(a, shape, a->value_type()), reshapee_(a) {
ABORT_IF(a->shape().elements() != shape.elements(),
"Reshape must not change the number of elements (from {} to {})", a->shape().toString(), shape.toString());
Node::destroy_ = false;
}

4
src/microsoft/quicksand.cpp Normal file → Executable file

@ -129,7 +129,7 @@ public:
QSNBestBatch qsNbestBatch;
for(const auto& history : histories) { // loop over batch entries
QSNBest qsNbest;
NBestList nbestHyps = history->NBest(SIZE_MAX); // request as many N as we have
NBestList nbestHyps = history->nBest(SIZE_MAX); // request as many N as we have
for (const Result& result : nbestHyps) { // loop over N-best entries
// get hypothesis word sequence and normalized sentence score
auto words = std::get<0>(result);
@ -147,7 +147,7 @@ public:
else
alignmentThreshold = std::max(std::stof(alignment), 0.f);
auto hyp = std::get<1>(result);
data::WordAlignment align = data::ConvertSoftAlignToHardAlign(hyp->TracebackAlignment(), alignmentThreshold);
data::WordAlignment align = data::ConvertSoftAlignToHardAlign(hyp->tracebackAlignment(), alignmentThreshold);
// convert to QuickSAND format
alignmentSets.resize(words.size());
for (const auto& p : align) // @TODO: Does the feature_model param max_alignment_links apply here?

4
src/models/decoder.h Normal file → Executable file

@ -74,7 +74,7 @@ public:
auto yShifted = shift(y, {1, 0, 0});
state->setTargetEmbeddings(yShifted);
state->setTargetHistoryEmbeddings(yShifted);
state->setTargetMask(yMask);
state->setTargetWords(data);
}
@ -105,7 +105,7 @@ public:
selectedEmbs = yEmb->apply(words, {dimBeam, 1, dimBatch, dimTrgEmb});
}
state->setTargetEmbeddings(selectedEmbs);
state->setTargetHistoryEmbeddings(selectedEmbs);
}
virtual const std::vector<Expr> getAlignments(int /*i*/ = 0) { return {}; };

3
src/models/encoder_decoder.cpp Normal file → Executable file

@ -171,8 +171,7 @@ Ptr<DecoderState> EncoderDecoder::step(Ptr<ExpressionGraph> graph,
state = hypIndices.empty() ? state : state->select(hypIndices, beamSize);
// Fill state with embeddings based on last prediction
decoders_[0]->embeddingsFromPrediction(
graph, state, words, dimBatch, beamSize);
decoders_[0]->embeddingsFromPrediction(graph, state, words, dimBatch, beamSize);
auto nextState = decoders_[0]->step(graph, state);
return nextState;

2
src/models/s2s.h Normal file → Executable file

@ -283,7 +283,7 @@ public:
virtual Ptr<DecoderState> step(Ptr<ExpressionGraph> graph,
Ptr<DecoderState> state) override {
auto embeddings = state->getTargetEmbeddings();
auto embeddings = state->getTargetHistoryEmbeddings();
// dropout target words
float dropoutTrg = inference_ ? 0 : opt<float>("dropout-trg");

10
src/models/states.h Normal file → Executable file

@ -33,7 +33,7 @@ protected:
std::vector<Ptr<EncoderState>> encStates_;
Ptr<data::CorpusBatch> batch_;
Expr targetEmbeddings_;
Expr targetHistoryEmbeddings_; // decoder history (teacher-forced or from decoding), embedded
Expr targetMask_;
Words targetWords_;
@ -69,17 +69,13 @@ public:
virtual const rnn::States& getStates() const { return states_; }
virtual Expr getTargetEmbeddings() const { return targetEmbeddings_; };
virtual void setTargetEmbeddings(Expr targetEmbeddings) {
targetEmbeddings_ = targetEmbeddings;
}
virtual Expr getTargetHistoryEmbeddings() const { return targetHistoryEmbeddings_; };
virtual void setTargetHistoryEmbeddings(Expr targetEmbeddings) { targetHistoryEmbeddings_ = targetEmbeddings; }
virtual const Words& getTargetWords() const { return targetWords_; };
virtual void setTargetWords(const Words& targetWords) { targetWords_ = targetWords; }
virtual Expr getTargetMask() const { return targetMask_; };
virtual void setTargetMask(Expr targetMask) { targetMask_ = targetMask; }
virtual const Words& getSourceWords() const {


@ -688,8 +688,8 @@ public:
}
Ptr<DecoderState> step(Ptr<DecoderState> state) {
auto embeddings = state->getTargetEmbeddings(); // [-4: beam depth=1, -3: max length, -2: batch size, -1: vector dim]
auto decoderMask = state->getTargetMask(); // [max length, batch size, 1] --this is a hypothesis
auto embeddings = state->getTargetHistoryEmbeddings(); // [-4: beam depth=1, -3: max length, -2: batch size, -1: vector dim]
auto decoderMask = state->getTargetMask(); // [max length, batch size, 1] --this is a hypothesis
// dropout target words
float dropoutTrg = inference_ ? 0 : opt<float>("dropout-trg");

0
src/models/transformer_factory.h Normal file → Executable file


@ -1,6 +1,3 @@
// TODO: This is a wrapper around transformer.h. We kept the .H name to minimize confusing git, until this is code-reviewed.
// This is meant to speed-up builds, and to support Ctrl-F7 to rebuild.
#include "models/transformer.h"
namespace marian {

1
src/tensors/cpu/tensor_operators.cpp Normal file → Executable file

@ -15,6 +15,7 @@ namespace marian {
namespace cpu {
void IsNan(const Tensor in, Ptr<Allocator> allocator, bool& isNan, bool& isInf, bool zero) {
isNan; isInf; zero;
ABORT("Not implemented");
}

6
src/training/validator.h Normal file → Executable file

@ -535,7 +535,7 @@ public:
std::stringstream best1;
std::stringstream bestn;
printer->print(history, best1, bestn);
collector->Write((long)history->GetLineNum(),
collector->Write((long)history->getLineNum(),
best1.str(),
bestn.str(),
options_->get<bool>("n-best"));
@ -677,14 +677,14 @@ public:
size_t no = 0;
std::lock_guard<std::mutex> statsLock(mutex_);
for(auto history : histories) {
auto result = history->Top();
auto result = history->top();
const auto& words = std::get<0>(result);
updateStats(stats, words, batch, no, vocabs_.back()->getEosId());
std::stringstream best1;
std::stringstream bestn;
printer->print(history, best1, bestn);
collector->Write((long)history->GetLineNum(),
collector->Write((long)history->getLineNum(),
best1.str(),
bestn.str(),
/*nbest=*/ false);

341
src/translator/beam_search.h Normal file → Executable file

@ -18,6 +18,8 @@ private:
Word trgEosId_{Word::NONE};
Word trgUnkId_{Word::NONE};
static constexpr auto INVALID_PATH_SCORE = -9999;
public:
BeamSearch(Ptr<Options> options,
const std::vector<Ptr<Scorer>>& scorers,
@ -31,75 +33,72 @@ public:
trgEosId_(trgEosId),
trgUnkId_(trgUnkId) {}
Beams toHyps(const std::vector<unsigned int> keys,
const std::vector<float> pathScores,
size_t vocabSize,
// combine new expandedPathScores and previous beams into new set of beams
Beams toHyps(const std::vector<unsigned int>& nBestKeys, // [dimBatch, beamSize] flattened -> ((batchIdx, beamHypIdx) flattened, word idx) flattened
const std::vector<float>& nBestPathScores, // [dimBatch, beamSize] flattened
const size_t inputBeamSize, // for interpretation of nBestKeys
const size_t vocabSize, // ditto.
const Beams& beams,
std::vector<Ptr<ScorerState>>& states,
size_t beamSize,
bool first,
Ptr<data::CorpusBatch> batch) {
Beams newBeams(beams.size());
const std::vector<Ptr<ScorerState /*const*/>>& states,
Ptr<data::CorpusBatch /*const*/> batch) const {
std::vector<float> align;
if(options_->hasAndNotEmpty("alignment"))
// Use alignments from the first scorer, even if ensemble
align = scorers_[0]->getAlignment();
align = scorers_[0]->getAlignment(); // use alignments from the first scorer, even if ensemble
for(size_t i = 0; i < keys.size(); ++i) {
// Keys contains indices to vocab items in the entire beam.
// Values can be between 0 and beamSize * vocabSize.
Word embIdx = Word::fromWordIndex(keys[i] % vocabSize);
auto beamIdx = i / beamSize;
const auto dimBatch = beams.size();
Beams newBeams(dimBatch);
// Retrieve short list for final softmax (based on words aligned
// to source sentences). If short list has been set, map the indices
// in the sub-selected vocabulary matrix back to their original positions.
for(size_t i = 0; i < nBestKeys.size(); ++i) {
// Keys encode batchIdx, beamHypIdx, and word index into the tensor of expanded path scores.
// They range from 0 to (dimBatch * inputBeamSize * vocabSize) - 1.
const auto key = nBestKeys[i];
const float pathScore = nBestPathScores[i]; // expanded path score for (batchIdx, beamHypIdx, word)
// decompose key into individual indices (batchIdx, beamHypIdx, wordIdx)
const auto wordIdx = (Word)(key % vocabSize);
const auto beamHypIdx = (key / vocabSize) % inputBeamSize;
const auto batchIdx = (key / vocabSize) / inputBeamSize;
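// Worked example (illustration only): with vocabSize = 8 and inputBeamSize = 2, key = 37
// decomposes into wordIdx = 37 % 8 = 5, beamHypIdx = (37 / 8) % 2 = 0 and
// batchIdx = (37 / 8) / 2 = 2, i.e. the key was formed as
// ((batchIdx * inputBeamSize) + beamHypIdx) * vocabSize + wordIdx = (2 * 2 + 0) * 8 + 5.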
const auto& beam = beams[batchIdx];
auto& newBeam = newBeams[batchIdx];
if (newBeam.size() >= beam.size()) // @TODO: Why this condition? It does happen. Why?
continue;
if (pathScore <= INVALID_PATH_SCORE) // (unused slot)
continue;
ABORT_IF(beamHypIdx >= beam.size(), "Out of bounds beamHypIdx??");
// Map wordIdx to word
Word word;
// If short list has been set, then wordIdx is an index into the short-listed word set,
// rather than the true word index.
auto shortlist = scorers_[0]->getShortlist();
if(shortlist)
embIdx = Word::fromWordIndex(shortlist->reverseMap(embIdx.toWordIndex())); // @TODO: should reverseMap accept a size_t or a Word?
if (shortlist)
word = shortlist->reverseMap(wordIdx);
else
word = wordIdx;
if(newBeams[beamIdx].size() < beams[beamIdx].size()) {
auto& beam = beams[beamIdx];
auto& newBeam = newBeams[beamIdx];
auto hyp = New<Hypothesis>(beam[beamHypIdx], word, (IndexType)beamHypIdx, pathScore);
auto hypIdx = (IndexType)(keys[i] / vocabSize);
float pathScore = pathScores[i];
auto hypIdxTrans
= IndexType((hypIdx / beamSize) + (hypIdx % beamSize) * beams.size());
if(first)
hypIdxTrans = hypIdx;
size_t beamHypIdx = hypIdx % beamSize;
if(beamHypIdx >= (int)beam.size())
beamHypIdx = beamHypIdx % beam.size();
if(first)
beamHypIdx = 0;
auto hyp = New<Hypothesis>(beam[beamHypIdx], embIdx, hypIdxTrans, pathScore);
// Set score breakdown for n-best lists
if(options_->get<bool>("n-best")) {
std::vector<float> breakDown(states.size(), 0);
beam[beamHypIdx]->GetScoreBreakdown().resize(states.size(), 0);
for(size_t j = 0; j < states.size(); ++j) {
size_t key = embIdx.toWordIndex() + hypIdxTrans * vocabSize;
breakDown[j] = states[j]->breakDown(key)
+ beam[beamHypIdx]->GetScoreBreakdown()[j];
}
hyp->GetScoreBreakdown() = breakDown;
// Set score breakdown for n-best lists
if(options_->get<bool>("n-best")) {
std::vector<float> breakDown(states.size(), 0);
beam[beamHypIdx]->getScoreBreakdown().resize(states.size(), 0); // @TODO: Why? Can we just guard the read-out below, then make it const? Or getScoreBreakdown(j)?
for(size_t j = 0; j < states.size(); ++j) {
size_t flattenedLogitIndex = (beamHypIdx * dimBatch + batchIdx) * vocabSize + wordIdx; // (beam idx, batch idx, word idx); note: beam and batch are transposed, compared to 'key'
breakDown[j] = states[j]->breakDown(flattenedLogitIndex) + beam[beamHypIdx]->getScoreBreakdown()[j];
// @TODO: pass those 3 indices directly into breakDown (state knows the dimensions)
}
// Set alignments
if(!align.empty()) {
hyp->SetAlignment(
getAlignmentsForHypothesis(align, batch, (int)beamHypIdx, (int)beamIdx));
}
newBeam.push_back(hyp);
hyp->setScoreBreakdown(breakDown);
}
// Set alignments
if(!align.empty()) {
hyp->setAlignment(getAlignmentsForHypothesis(align, batch, (int)beamHypIdx, (int)batchIdx));
}
newBeam.push_back(hyp);
}
return newBeams;
}
@ -108,7 +107,7 @@ public:
const std::vector<float> alignAll,
Ptr<data::CorpusBatch> batch,
int beamHypIdx,
int beamIdx) {
int beamIdx) const {
// Let B be the beam size, N the number of batched sentences,
// and L the number of words in the longest sentence in the batch.
// The alignment vector:
@ -140,12 +139,13 @@ public:
return align;
}
Beams pruneBeam(const Beams& beams) {
// remove all beam entries that have reached EOS
Beams purgeBeams(const Beams& beams) {
Beams newBeams;
for(auto beam : beams) {
Beam newBeam;
for(auto hyp : beam) {
if(hyp->GetWord() != trgEosId_) {
if(hyp->getWord() != trgEosId_) {
newBeam.push_back(hyp);
}
}
@ -154,154 +154,171 @@ public:
return newBeams;
}
//**********************************************************************
// main decoding function
Histories search(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch) {
int dimBatch = (int)batch->size();
ABORT_IF(batch->back()->vocab() && batch->back()->vocab()->getEosId() != trgEosId_,
"Batch uses different EOS token than was passed to BeamSearch originally");
Histories histories;
for(int i = 0; i < dimBatch; ++i) {
size_t sentId = batch->getSentenceIds()[i];
auto history = New<History>(sentId,
options_->get<float>("normalize"),
options_->get<float>("word-penalty"));
histories.push_back(history);
}
const int dimBatch = (int)batch->size();
size_t localBeamSize = beamSize_; // max over beam sizes of active sentence hypotheses
auto getNBestList = createGetNBestListFn(localBeamSize, dimBatch, graph->getDeviceId());
Beams beams(dimBatch); // [batchIndex][beamIndex] is one sentence hypothesis
for(auto& beam : beams)
beam.resize(localBeamSize, New<Hypothesis>());
bool first = true;
bool final = false;
for(int i = 0; i < dimBatch; ++i)
histories[i]->Add(beams[i], trgEosId_);
std::vector<Ptr<ScorerState>> states;
auto getNBestList = createGetNBestListFn(beamSize_, dimBatch, graph->getDeviceId());
for(auto scorer : scorers_) {
scorer->clear(graph);
}
Histories histories(dimBatch);
for(int i = 0; i < dimBatch; ++i) {
size_t sentId = batch->getSentenceIds()[i];
histories[i] = New<History>(sentId,
options_->get<float>("normalize"),
options_->get<float>("word-penalty"));
}
// start states
std::vector<Ptr<ScorerState>> states;
for(auto scorer : scorers_) {
states.push_back(scorer->startState(graph, batch));
}
// main loop over output tokens
do {
Beams beams(dimBatch, Beam(beamSize_, New<Hypothesis>())); // array [dimBatch] of array [localBeamSize] of Hypothesis
//Beams beams(dimBatch); // array [dimBatch] of array [localBeamSize] of Hypothesis
//for(auto& beam : beams)
// beam.resize(beamSize_, New<Hypothesis>());
for(int i = 0; i < dimBatch; ++i)
histories[i]->add(beams[i], trgEosId_);
// the decoder updates the following state information in each output time step:
// - beams: array [dimBatch] of array [localBeamSize] of Hypothesis
// - current output time step's set of active hypotheses, aka active search space
// - states[.]: ScorerState
// - NN state; one per scorer, e.g. 2 for ensemble of 2
// and it forms the following return value
// - histories: array [dimBatch] of History
// with History: vector [t] of array [localBeamSize] of Hypothesis
// with Hypothesis: (last word, aggregate score, prev Hypothesis)
// main loop over output time steps
for (size_t t = 0; ; t++) {
ABORT_IF(dimBatch != beams.size(), "Lost a batch entry??");
// determine beam size for next output time step, as max over still-active sentences
// E.g. if all batch entries are down from beam 5 to no more than 4 surviving hyps, then
// switch to beam of 4 for all. If all are done, then beam ends up being 0, and we are done.
size_t localBeamSize = 0; // @TODO: is there some std::algorithm for this?
for(auto& beam : beams)
if(beam.size() > localBeamSize)
localBeamSize = beam.size();
// done if all batch entries have reached EOS on all beam entries
if (localBeamSize == 0)
break;
//**********************************************************************
// create constant containing previous path scores for current beam
// also create mapping of hyp indices, which are not 1:1 if sentences complete
std::vector<IndexType> hypIndices; // [beamIndex * activeBatchSize + batchIndex] backpointers, concatenated over beam positions. Used for reordering hypotheses
Words predWords;
Expr prevPathScores; // [beam, 1, 1, 1]
if(first) {
// no scores yet
// Also create mapping of hyp indices, for reordering the decoder-state tensors.
std::vector<IndexType> hypIndices; // [localBeamsize, 1, dimBatch, 1] (flattened) tensor index ((beamHypIdx, batchIdx), flattened) of prev hyp that a hyp originated from
std::vector<Word> prevWords; // [localBeamsize, 1, dimBatch, 1] (flattened) word that a hyp ended in, for advancing the decoder-model's history
Expr prevPathScores; // [localBeamSize, 1, dimBatch, 1], path score that a hyp ended in (last axis will broadcast into vocab size when adding expandedPathScores)
if(t == 0) { // no scores yet
prevPathScores = graph->constant({1, 1, 1, 1}, inits::from_value(0));
} else {
std::vector<float> beamScores;
dimBatch = (int)batch->size();
for(size_t i = 0; i < localBeamSize; ++i) {
for(size_t j = 0; j < beams.size(); ++j) { // loop over batch entries (active sentences)
auto& beam = beams[j];
if(i < beam.size()) {
auto hyp = beam[i];
hypIndices.push_back((IndexType)hyp->GetPrevStateIndex()); // backpointer
predWords.push_back(hyp->GetWord());
beamScores.push_back(hyp->GetPathScore());
} else { // dummy hypothesis
std::vector<float> prevScores;
for(size_t beamHypIdx = 0; beamHypIdx < localBeamSize; ++beamHypIdx) {
for(int batchIdx = 0; batchIdx < dimBatch; ++batchIdx) { // loop over batch entries (active sentences)
auto& beam = beams[batchIdx];
if(beamHypIdx < beam.size()) {
auto hyp = beam[beamHypIdx];
hypIndices.push_back((IndexType)(hyp->getPrevStateIndex() * dimBatch + batchIdx)); // (beamHypIdx, batchIdx), flattened, for index_select() operation
prevWords .push_back(hyp->getWord());
prevScores.push_back(hyp->getPathScore());
} else { // pad to localBeamSize (dummy hypothesis)
hypIndices.push_back(0);
predWords.push_back(Word::ZERO); // (unused)
beamScores.push_back(-9999);
prevWords.push_back(trgEosId_); // (unused, but let's use a valid value)
prevScores.push_back((float)INVALID_PATH_SCORE);
}
}
}
prevPathScores = graph->constant({(int)localBeamSize, 1, dimBatch, 1},
inits::from_vector(beamScores));
prevPathScores = graph->constant({(int)localBeamSize, 1, dimBatch, 1}, inits::from_vector(prevScores));
}
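// Layout note: hypIndices, prevWords and prevScores are filled beam-major, i.e. entry i
// corresponds to (beamHypIdx = i / dimBatch, batchIdx = i % dimBatch), matching the
// [localBeamSize, 1, dimBatch, 1] layout of prevPathScores; e.g. with dimBatch = 2 and
// localBeamSize = 3 the order is (h0,b0), (h0,b1), (h1,b0), (h1,b1), (h2,b0), (h2,b1).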
//**********************************************************************
// prepare scores for beam search
auto pathScores = prevPathScores;
// compute expanded path scores with word prediction probs from all scorers
auto expandedPathScores = prevPathScores; // will become [localBeamSize, 1, dimBatch, dimVocab]
Expr logProbs;
for(size_t i = 0; i < scorers_.size(); ++i) {
states[i] = scorers_[i]->step(
graph, states[i], hypIndices, predWords, dimBatch, (int)localBeamSize);
if(scorers_[i]->getWeight() != 1.f)
pathScores = pathScores + scorers_[i]->getWeight() * states[i]->getLogProbs();
else
pathScores = pathScores + states[i]->getLogProbs();
// compute output probabilities for current output time step
// - uses hypIndices[index in beam, 1, batch index, 1] to reorder scorer state to reflect the top-N in beams[][]
// - adds prevWords [index in beam, 1, batch index, 1] to the scorer's target history
// - performs one step of the scorer
// - returns new NN state for use in next output time step
// - returns vector of prediction probabilities over output vocab via newState
states[i] = scorers_[i]->step(graph, states[i], hypIndices, prevWords, dimBatch, (int)localBeamSize);
logProbs = states[i]->getLogProbs(); // [localBeamSize, 1, dimBatch, dimVocab]
// expand all hypotheses, [localBeamSize, 1, dimBatch, 1] -> [localBeamSize, 1, dimBatch, dimVocab]
expandedPathScores = expandedPathScores + scorers_[i]->getWeight() * logProbs;
}
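// In formula form, for an ensemble of K scorers with weights w_k:
//   expandedPathScores = prevPathScores + sum_k w_k * logProbs_k
// where prevPathScores [localBeamSize, 1, dimBatch, 1] broadcasts over the vocab axis.
// Thanks to the identity shortcuts in expression_operators.cpp, a weight of 1.0 adds no
// extra multiplication node to the graph.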
// make beams continuous
if(dimBatch > 1 && localBeamSize > 1)
pathScores = transpose(pathScores, {2, 1, 0, 3});
expandedPathScores = swapAxes(expandedPathScores, 0, 2); // -> [dimBatch, 1, localBeamSize, dimVocab]
if(first)
// perform NN computation
if(t == 0)
graph->forward();
else
graph->forwardNext();
//**********************************************************************
// suppress specific symbols if not at right positions
if(trgUnkId_ != Word::NONE && options_->has("allow-unk")
&& !options_->get<bool>("allow-unk"))
suppressWord(pathScores, trgUnkId_);
if(trgUnkId_ != Word::NONE && options_->has("allow-unk") && !options_->get<bool>("allow-unk"))
suppressWord(expandedPathScores, trgUnkId_);
for(auto state : states)
state->blacklist(pathScores, batch);
state->blacklist(expandedPathScores, batch);
//**********************************************************************
// perform beam search and pruning
std::vector<unsigned int> outKeys;
std::vector<float> outPathScores;
// perform beam search
std::vector<size_t> beamSizes(dimBatch, localBeamSize);
getNBestList(beamSizes, pathScores->val(), outPathScores, outKeys, first);
// find N best amongst the (localBeamSize * dimVocab) hypotheses
std::vector<unsigned int> nBestKeys; // [dimBatch, localBeamSize] flattened -> (batchIdx, beamHypIdx, word idx) flattened
std::vector<float> nBestPathScores; // [dimBatch, localBeamSize] flattened
getNBestList(/*in*/ expandedPathScores->val(), // [dimBatch, 1, localBeamSize, dimVocab or dimShortlist]
/*N=*/localBeamSize, // desired beam size
/*out*/ nBestPathScores, /*out*/ nBestKeys,
/*first=*/t == 0); // @TODO: this is only used for checking presently, and should be removed altogether
// Now, nBestPathScores contain N-best expandedPathScores for each batch and beam,
// and nBestKeys for each their original location (batchIdx, beamHypIdx, word).
int dimTrgVoc = pathScores->shape()[-1];
beams = toHyps(outKeys,
outPathScores,
dimTrgVoc,
// combine N-best sets with existing search space (beams) to updated search space
beams = toHyps(nBestKeys, nBestPathScores,
/*inputBeamSize*/expandedPathScores->shape()[-2], // used for interpretation of keys
/*vocabSize=*/expandedPathScores->shape()[-1], // used for interpretation of keys
beams,
states,
localBeamSize,
first,
batch);
states, // only used for keeping track of per-ensemble-member path score
batch); // only used for propagating alignment info
auto prunedBeams = pruneBeam(beams);
// remove all hyps that end in EOS
// The position of a hyp in the beam may change.
const auto purgedNewBeams = purgeBeams(beams);
// add updated search space (beams) to our return value
bool maxLengthReached = false;
for(int i = 0; i < dimBatch; ++i) {
// if this batch entry has surviving hyps then add them to the traceback grid
if(!beams[i].empty()) {
final = final
|| histories[i]->size()
>= options_->get<float>("max-length-factor")
* batch->front()->batchWidth();
histories[i]->Add(
beams[i], trgEosId_, prunedBeams[i].empty() || final);
if (histories[i]->size() >= options_->get<float>("max-length-factor") * batch->front()->batchWidth())
maxLengthReached = true;
histories[i]->add(beams[i], trgEosId_, purgedNewBeams[i].empty() || maxLengthReached);
}
}
beams = prunedBeams;
if (maxLengthReached) // early exit if max length limit was reached
break;
// determine beam size for next sentence, as max over still-active sentences
if(!first) {
size_t maxBeam = 0;
for(auto& beam : beams)
if(beam.size() > maxBeam)
maxBeam = beam.size();
localBeamSize = maxBeam;
}
first = false;
// this is the search space for the next output time step
beams = purgedNewBeams;
} // end of main loop over output time steps
} while(localBeamSize != 0 && !final); // end of main loop over output tokens
return histories;
return histories; // [dimBatch][t][N best hyps]
}
};
} // namespace marian

27
src/translator/history.h Normal file → Executable file

@ -19,18 +19,17 @@ private:
float normalizedPathScore; // length-normalized sentence score
};
float lengthPenalty(size_t length) { return std::pow((float)length, alpha_); }
float wordPenalty(size_t length) { return wp_ * (float)length; }
public:
History(size_t lineNo, float alpha = 1.f, float wp_ = 0.f);
float LengthPenalty(size_t length) { return std::pow((float)length, alpha_); }
float WordPenalty(size_t length) { return wp_ * (float)length; }
void Add(const Beam& beam, Word trgEosId, bool last = false) {
if(beam.back()->GetPrevHyp() != nullptr) {
void add(const Beam& beam, Word trgEosId, bool last = false) {
if(beam.back()->getPrevHyp() != nullptr) {
for(size_t j = 0; j < beam.size(); ++j)
if(beam[j]->GetWord() == trgEosId || last) {
float pathScore = (beam[j]->GetPathScore() - WordPenalty(history_.size()))
/ LengthPenalty(history_.size());
if(beam[j]->getWord() == trgEosId || last) {
float pathScore =
(beam[j]->getPathScore() - wordPenalty(history_.size())) / lengthPenalty(history_.size());
topHyps_.push({history_.size(), j, pathScore});
// std::cerr << "Add " << history_.size() << " " << j << " " << pathScore
// << std::endl;
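// Worked example (illustration only): with alpha_ = 0.6 and wp_ = 0, a hypothesis finished
// after 10 time steps with raw path score -4.0 is pushed with the length-normalized score
// -4.0 / 10^0.6 ≈ -1.005.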
@ -41,7 +40,7 @@ public:
size_t size() const { return history_.size(); } // number of time steps
NBestList NBest(size_t n) const {
NBestList nBest(size_t n) const {
NBestList nbest;
for (auto topHypsCopy = topHyps_; nbest.size() < n && !topHypsCopy.empty(); topHypsCopy.pop()) {
auto bestHypCoord = topHypsCopy.top();
@ -53,17 +52,17 @@ public:
// std::cerr << "h: " << start << " " << j << " " << c << std::endl;
// trace back best path
Words targetWords = bestHyp->TracebackWords();
Words targetWords = bestHyp->tracebackWords();
// note: bestHyp->GetPathScore() is not normalized, while bestHypCoord.normalizedPathScore is
// note: bestHyp->getPathScore() is not normalized, while bestHypCoord.normalizedPathScore is
nbest.emplace_back(targetWords, bestHyp, bestHypCoord.normalizedPathScore);
}
return nbest;
}
Result Top() const { return NBest(1)[0]; }
Result top() const { return nBest(1)[0]; }
size_t GetLineNum() const { return lineNo_; }
size_t getLineNum() const { return lineNo_; }
private:
std::vector<Beam> history_; // [time step][index into beam] search grid
@ -73,5 +72,5 @@ private:
float wp_;
};
typedef std::vector<Ptr<History>> Histories;
typedef std::vector<Ptr<History>> Histories; // [batchDim]
} // namespace marian

42
src/translator/hypothesis.h Normal file → Executable file

@ -6,36 +6,42 @@
namespace marian {
// one single (possibly partial) hypothesis in beam search
// key elements:
// - the word that this hyp ends with
// - the aggregate score up to and including the word
// - back pointer to previous hypothesis for traceback
class Hypothesis {
public:
Hypothesis() : prevHyp_(nullptr), prevIndex_(0), word_(Word::ZERO), pathScore_(0.0) {}
Hypothesis(const Ptr<Hypothesis> prevHyp,
Word word,
IndexType prevIndex,
IndexType prevIndex, // (beamHypIdx, batchIdx) flattened as beamHypIdx * dimBatch + batchIdx
float pathScore)
: prevHyp_(prevHyp), prevIndex_(prevIndex), word_(word), pathScore_(pathScore) {}
const Ptr<Hypothesis> GetPrevHyp() const { return prevHyp_; }
const Ptr<Hypothesis> getPrevHyp() const { return prevHyp_; }
Word GetWord() const { return word_; }
Word getWord() const { return word_; }
IndexType GetPrevStateIndex() const { return prevIndex_; }
IndexType getPrevStateIndex() const { return prevIndex_; }
float GetPathScore() const { return pathScore_; }
float getPathScore() const { return pathScore_; }
std::vector<float>& GetScoreBreakdown() { return scoreBreakdown_; }
std::vector<float>& GetAlignment() { return alignment_; }
std::vector<float>& getScoreBreakdown() { return scoreBreakdown_; }
void setScoreBreakdown(const std::vector<float>& scoreBreakdown) { scoreBreakdown_ = scoreBreakdown; }
void SetAlignment(const std::vector<float>& align) { alignment_ = align; };
const std::vector<float>& getAlignment() { return alignment_; }
void setAlignment(const std::vector<float>& align) { alignment_ = align; };
// helpers to trace back paths referenced from this hypothesis
Words TracebackWords()
Words tracebackWords()
{
Words targetWords;
for (auto hyp = this; hyp->GetPrevHyp(); hyp = hyp->GetPrevHyp().get()) {
targetWords.push_back(hyp->GetWord());
// std::cerr << hyp->GetWord() << " " << hyp << std::endl;
for (auto hyp = this; hyp->getPrevHyp(); hyp = hyp->getPrevHyp().get()) {
targetWords.push_back(hyp->getWord());
// std::cerr << hyp->getWord() << " " << hyp << std::endl;
}
std::reverse(targetWords.begin(), targetWords.end());
return targetWords;
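// Note: the loop above stops at the initial dummy hypothesis (whose prevHyp_ is nullptr),
// so that hypothesis' placeholder word is not part of the returned sequence.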
@ -43,11 +49,11 @@ public:
// get soft alignments for each target word starting from the hyp one
typedef data::SoftAlignment SoftAlignment;
SoftAlignment TracebackAlignment()
SoftAlignment tracebackAlignment()
{
SoftAlignment align;
for (auto hyp = this; hyp->GetPrevHyp(); hyp = hyp->GetPrevHyp().get()) {
align.push_back(hyp->GetAlignment());
for (auto hyp = this; hyp->getPrevHyp(); hyp = hyp->getPrevHyp().get()) {
align.push_back(hyp->getAlignment());
}
std::reverse(align.begin(), align.end());
return align;
@ -59,12 +65,12 @@ private:
const Word word_;
const float pathScore_;
std::vector<float> scoreBreakdown_;
std::vector<float> scoreBreakdown_; // [num scorers]
std::vector<float> alignment_;
};
typedef std::vector<Ptr<Hypothesis>> Beam; // Beam = vector of hypotheses
typedef std::vector<Beam> Beams; // Beams = vector of vector of hypotheses
typedef std::vector<Ptr<Hypothesis>> Beam; // Beam = vector [beamSize] of hypotheses
typedef std::vector<Beam> Beams; // Beams = vector [batchDim] of vector [beamSize] of hypotheses
typedef std::tuple<Words, Ptr<Hypothesis>, float> Result; // (word ids for hyp, hyp, normalized sentence score for hyp)
typedef std::vector<Result> NBestList; // sorted vector of (word ids, hyp, sent score) tuples
} // namespace marian

94
src/translator/nth_element.cpp Normal file → Executable file

@ -14,22 +14,16 @@ namespace marian {
class NthElementCPU {
std::vector<int> h_res_idx;
std::vector<float> h_res;
size_t lastN;
//size_t lastN_;
public:
NthElementCPU() = delete;
NthElementCPU() {}
NthElementCPU(const NthElementCPU& copy) = delete;
NthElementCPU(size_t maxBeamSize, size_t maxBatchSize) {
size_t maxSize = maxBeamSize * maxBatchSize;
h_res.resize(maxSize);
h_res_idx.resize(maxSize);
}
private:
void getNBestList(float* scores,
const std::vector<int>& batchFirstElementIdxs,
const std::vector<int>& cumulativeBeamSizes) {
void selectNBest(float* scores,
const std::vector<int>& batchFirstElementIdxs,
const std::vector<int>& cumulativeBeamSizes) {
/* For each batch, select the max N elements, where N is the beam size for
* this batch. Locally record these elements (their current value and index
* in 'scores') before updating each element to a large negative value, such
@ -49,7 +43,7 @@ private:
std::vector<int>::iterator middle = begin + beamSize;
std::vector<int>::iterator end = idxs.begin() + batchFirstElementIdxs[batchIdx + 1];
std::partial_sort(
begin, middle, end, [=](int a, int b) { return scores[a] > scores[b]; });
begin, middle, end, [&](int a, int b) { return scores[a] > scores[b]; });
while(begin != middle) {
int idx = *begin++;
@ -62,39 +56,57 @@ private:
}
public:
void getNBestList(const std::vector<size_t>& beamSizes,
Tensor scores,
std::vector<float>& outPathScores,
std::vector<unsigned>& outKeys,
const bool isFirst) {
std::vector<int> cumulativeBeamSizes(beamSizes.size() + 1, 0);
std::vector<int> batchFirstElementIdxs(beamSizes.size() + 1, 0);
void getNBestList(Tensor scores, // [dimBatch, 1, beamSize, dimVocab or dimShortlist]
size_t N,
std::vector<float>& outPathScores,
std::vector<unsigned>& outKeys,
const bool isFirst) {
const auto vocabSize = scores->shape()[-1];
const auto inputN = scores->shape()[-2];
const auto dimBatch = scores->shape()[-4];
ABORT_IF(inputN != (isFirst ? 1 : N), "Input tensor has wrong beam dim??"); // @TODO: Remove isFirst argument altogether
auto vocabSize = scores->shape()[-1];
for(int i = 0; i < beamSizes.size(); ++i) {
cumulativeBeamSizes[i + 1] = cumulativeBeamSizes[i] + (int)beamSizes[i];
batchFirstElementIdxs[i + 1]
+= (isFirst ? i + 1 : cumulativeBeamSizes[i + 1]) * vocabSize;
std::vector<int> cumulativeBeamSizes(dimBatch + 1, 0);
std::vector<int> batchFirstElementIdxs(dimBatch + 1, 0);
for(int batchIdx = 0; batchIdx < dimBatch; ++batchIdx) {
#if 1
cumulativeBeamSizes[batchIdx + 1] = (batchIdx + 1) * (int)N;
batchFirstElementIdxs[batchIdx + 1] += (batchIdx + 1) * inputN * vocabSize;
ABORT_IF(cumulativeBeamSizes[batchIdx + 1] != cumulativeBeamSizes[batchIdx] + (int)N, "cumulativeBeamSizes wrong??");
ABORT_IF((isFirst ? batchIdx + 1 : cumulativeBeamSizes[batchIdx + 1]) != (batchIdx + 1) * inputN, "inputN wrong??");
#else
cumulativeBeamSizes[batchIdx + 1] = cumulativeBeamSizes[batchIdx] + (int)N;
ABORT_IF(cumulativeBeamSizes[batchIdx + 1] != (batchIdx + 1) * N, "cumulativeBeamSizes wrong??");
batchFirstElementIdxs[batchIdx + 1]
+= (isFirst ? batchIdx + 1 : cumulativeBeamSizes[batchIdx + 1]) * vocabSize;
ABORT_IF((isFirst ? batchIdx + 1 : cumulativeBeamSizes[batchIdx + 1]) != (batchIdx + 1) * inputN, "inputN wrong??");
#endif
}
ABORT_IF(cumulativeBeamSizes.back() != dimBatch * N, "cumulativeBeamSizes.back() wrong??");
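// Worked example (illustration only): for dimBatch = 2, N = 4, vocabSize = 1000 and inputN = 4
// (not the first step), this yields cumulativeBeamSizes = {0, 4, 8} and
// batchFirstElementIdxs = {0, 4000, 8000}; on the first step inputN = 1, so the batch offsets
// shrink to {0, 1000, 2000} while the per-batch output beam size remains N.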
getNBestList(scores->data(), batchFirstElementIdxs, cumulativeBeamSizes);
getPairs(cumulativeBeamSizes.back(), outKeys, outPathScores);
size_t maxSize = N * dimBatch;
h_res.resize(maxSize);
h_res_idx.resize(maxSize);
selectNBest(scores->data(), batchFirstElementIdxs, cumulativeBeamSizes);
getPairs(/*cumulativeBeamSizes.back(),*/ outKeys, outPathScores);
}
private:
void getPairs(size_t number,
void getPairs(/*size_t number,*/
std::vector<unsigned>& outKeys,
std::vector<float>& outValues) {
std::copy(h_res_idx.begin(), h_res_idx.begin() + number, std::back_inserter(outKeys));
std::copy(h_res .begin(), h_res .begin() + number, std::back_inserter(outValues));
lastN = number;
std::copy(h_res_idx.begin(), h_res_idx.end(), std::back_inserter(outKeys));
std::copy(h_res .begin(), h_res .end(), std::back_inserter(outValues));
//lastN_ = number;
}
void getValueByKey(std::vector<float>& out, float* d_in) {
for(size_t i = 0; i < lastN; ++i) {
out[i] = d_in[h_res_idx[i]];
}
}
//void getValueByKey(std::vector<float>& out, float* d_in) {
// for(size_t i = 0; i < lastN_; ++i) {
// out[i] = d_in[h_res_idx[i]];
// }
//}
};
#ifdef CUDA_FOUND
@ -108,15 +120,11 @@ GetNBestListFn createGetNBestListFn(size_t beamSize, size_t dimBatch, DeviceId d
if(deviceId.type == DeviceType::gpu)
return createGetNBestListGPUFn(beamSize, dimBatch, deviceId);
#else
deviceId; // (unused)
deviceId; beamSize; dimBatch; // (unused)
#endif
auto nth = New<NthElementCPU>(beamSize, dimBatch);
return [nth](const std::vector<size_t>& beamSizes,
Tensor logProbs,
std::vector<float>& outCosts,
std::vector<unsigned>& outKeys,
const bool isFirst) {
return nth->getNBestList(beamSizes, logProbs, outCosts, outKeys, isFirst);
auto nth = New<NthElementCPU>();
return [nth](Tensor logProbs, size_t N, std::vector<float>& outCosts, std::vector<unsigned>& outKeys, const bool isFirst) {
return nth->getNBestList(logProbs, N, outCosts, outKeys, isFirst);
};
}

90
src/translator/nth_element.cu Normal file → Executable file

@ -279,6 +279,7 @@ public:
size_t maxBatchSize,
DeviceId deviceId)
: deviceId_(deviceId),
maxBeamSize_(maxBeamSize), maxBatchSize_(maxBatchSize),
NUM_BLOCKS(std::min(
500,
int(maxBeamSize* MAX_VOCAB_SIZE / (2 * BLOCK_SIZE))
@ -316,9 +317,9 @@ public:
}
private:
void getNBestList(float* probs,
const std::vector<int>& batchFirstElementIdxs,
const std::vector<int>& cummulatedBeamSizes) {
void selectNBest(float* probs,
const std::vector<int>& batchFirstElementIdxs,
const std::vector<int>& cumulativeBeamSizes) {
cudaSetDevice(deviceId_.no);
CUDA_CHECK(cudaMemcpyAsync(d_batchPosition,
batchFirstElementIdxs.data(),
@ -326,8 +327,8 @@ private:
cudaMemcpyHostToDevice,
/* stream_ */ 0));
CUDA_CHECK(cudaMemcpyAsync(d_cumBeamSizes,
cummulatedBeamSizes.data(),
cummulatedBeamSizes.size() * sizeof(int),
cumulativeBeamSizes.data(),
cumulativeBeamSizes.size() * sizeof(int),
cudaMemcpyHostToDevice,
/* stream_ */ 0));
@ -353,26 +354,43 @@ private:
}
public:
void getNBestList(const std::vector<size_t>& beamSizes,
Tensor Probs,
void getNBestList(Tensor scores,
size_t N,
std::vector<float>& outCosts,
std::vector<unsigned>& outKeys,
const bool isFirst) {
cudaSetDevice(deviceId_.no);
std::vector<int> cummulatedBeamSizes(beamSizes.size() + 1, 0);
const auto vocabSize = scores->shape()[-1];
const auto inputN = scores->shape()[-2];
const auto dimBatch = scores->shape()[-4];
ABORT_IF(inputN != (isFirst ? 1 : N), "Input tensor has wrong beam dim??"); // @TODO: Remove isFirst argument altogether
ABORT_IF(vocabSize > MAX_VOCAB_SIZE, "GetNBestList(): actual vocab size exceeds MAX_VOCAB_SIZE");
ABORT_IF(dimBatch > maxBatchSize_, "GetNBestList(): actual batch size exceeds initialization parameter");
ABORT_IF(N > maxBeamSize_, "GetNBestList(): actual beam size exceeds initialization parameter"); // @TODO: or inputN?
const std::vector<size_t> beamSizes(dimBatch, N);
std::vector<int> cumulativeBeamSizes(beamSizes.size() + 1, 0);
std::vector<int> batchFirstElementIdxs(beamSizes.size() + 1, 0);
const size_t vocabSize = Probs->shape()[-1];
for(size_t i = 0; i < beamSizes.size(); ++i) {
cummulatedBeamSizes[i + 1] = cummulatedBeamSizes[i] + beamSizes[i];
batchFirstElementIdxs[i + 1]
+= ((isFirst) ? (i + 1) : cummulatedBeamSizes[i + 1]) * vocabSize;
for(size_t batchIdx = 0; batchIdx < beamSizes.size(); ++batchIdx) {
#if 1
cumulativeBeamSizes[batchIdx + 1] = (batchIdx + 1) * (int)N;
batchFirstElementIdxs[batchIdx + 1] += (batchIdx + 1) * inputN * vocabSize;
ABORT_IF(cumulativeBeamSizes[batchIdx + 1] != cumulativeBeamSizes[batchIdx] + (int)N, "cumulativeBeamSizes wrong??");
ABORT_IF((isFirst ? batchIdx + 1 : cumulativeBeamSizes[batchIdx + 1]) != (batchIdx + 1) * inputN, "inputN wrong??");
#else
cumulativeBeamSizes[batchIdx + 1] = cumulativeBeamSizes[batchIdx] + beamSizes[batchIdx];
ABORT_IF(cumulativeBeamSizes[batchIdx + 1] != (batchIdx + 1) * N, "cumulativeBeamSizes wrong??");
batchFirstElementIdxs[batchIdx + 1]
+= ((isFirst) ? (batchIdx + 1) : cumulativeBeamSizes[batchIdx + 1]) * vocabSize;
ABORT_IF((isFirst ? batchIdx + 1 : cumulativeBeamSizes[batchIdx + 1]) != (batchIdx + 1) * inputN, "inputN wrong??");
#endif
}
getNBestList(Probs->data(), batchFirstElementIdxs, cummulatedBeamSizes);
getPairs(cummulatedBeamSizes.back(), outKeys, outCosts);
selectNBest(scores->data(), batchFirstElementIdxs, cumulativeBeamSizes);
getPairs(dimBatch * N, outKeys, outCosts);
ABORT_IF(cumulativeBeamSizes.back() != dimBatch * N, "cumulativeBeamSizes.back() wrong??");
}
private:
@ -397,26 +415,28 @@ private:
outValues.push_back(h_res[i]);
}
lastN = number;
//lastN = number;
}
void getValueByKey(std::vector<float>& out, float* d_in) {
cudaSetDevice(deviceId_.no);
gGetValueByKey<<<1, lastN, 0, /* stream_ */ 0>>>(
d_in, d_breakdown, h_res_idx, lastN);
CUDA_CHECK(cudaMemcpyAsync(out.data(),
d_breakdown,
lastN * sizeof(float),
cudaMemcpyDeviceToHost,
/* stream_ */ 0));
CUDA_CHECK(cudaStreamSynchronize(/* stream_ */ 0));
}
//void getValueByKey(std::vector<float>& out, float* d_in) {
// cudaSetDevice(deviceId_.no);
//
// gGetValueByKey<<<1, lastN, 0, /* stream_ */ 0>>>(
// d_in, d_breakdown, h_res_idx, lastN);
//
// CUDA_CHECK(cudaMemcpyAsync(out.data(),
// d_breakdown,
// lastN * sizeof(float),
// cudaMemcpyDeviceToHost,
// /* stream_ */ 0));
// CUDA_CHECK(cudaStreamSynchronize(/* stream_ */ 0));
//}
DeviceId deviceId_;
const int MAX_VOCAB_SIZE = 100000;
size_t maxBeamSize_;
size_t maxBatchSize_;
const int BLOCK_SIZE = 512;
const int NUM_BLOCKS;
@ -433,19 +453,15 @@ private:
float* d_breakdown;
int* d_batchPosition;
int* d_cumBeamSizes;
size_t lastN;
//size_t lastN;
};
// factory function
// Returns a lambda with the same signature as the getNBestList() function.
GetNBestListFn createGetNBestListGPUFn(size_t beamSize, size_t dimBatch, DeviceId deviceId) {
auto nth = New<NthElementGPU>(beamSize, dimBatch, deviceId);
return [nth](const std::vector<size_t>& beamSizes,
Tensor logProbs,
std::vector<float>& outCosts,
std::vector<unsigned>& outKeys,
const bool isFirst) {
return nth->getNBestList(beamSizes, logProbs, outCosts, outKeys, isFirst);
return [nth](Tensor logProbs, size_t N, std::vector<float>& outCosts, std::vector<unsigned>& outKeys, const bool isFirst) {
return nth->getNBestList(logProbs, N, outCosts, outKeys, isFirst);
};
}

4
src/translator/nth_element.h Normal file → Executable file

@ -10,8 +10,8 @@
namespace marian {
typedef std::function<void(const std::vector<size_t>& beamSizes,
Tensor logProbs,
typedef std::function<void(Tensor logProbs,
size_t N,
std::vector<float>& outCosts,
std::vector<unsigned>& outKeys,
const bool isFirst)> GetNBestListFn;
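// Usage sketch, mirroring beam_search.h: obtain the function object once per search, then call
// it at every output time step:
//   auto getNBestList = createGetNBestListFn(beamSize, dimBatch, graph->getDeviceId());
//   getNBestList(expandedPathScores->val(), /*N=*/beamSize, outPathScores, outKeys, /*isFirst=*/t == 0);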

6
src/translator/output_printer.cpp Normal file → Executable file

@ -6,9 +6,9 @@ std::string OutputPrinter::getAlignment(const Ptr<Hypothesis>& hyp) {
data::SoftAlignment align;
auto last = hyp;
// get soft alignments for each target word starting from the last one
while(last->GetPrevHyp().get() != nullptr) {
align.push_back(last->GetAlignment());
last = last->GetPrevHyp();
while(last->getPrevHyp().get() != nullptr) {
align.push_back(last->getAlignment());
last = last->getPrevHyp();
}
// reverse alignments

14
src/translator/output_printer.h Normal file → Executable file

@ -24,7 +24,7 @@ public:
template <class OStream>
void print(Ptr<History> history, OStream& best1, OStream& bestn) {
const auto& nbl = history->NBest(nbest_);
const auto& nbl = history->nBest(nbest_);
for(size_t i = 0; i < nbl.size(); ++i) {
const auto& result = nbl[i];
@ -35,17 +35,17 @@ public:
std::reverse(words.begin(), words.end());
std::string translation = vocab_->decode(words);
bestn << history->GetLineNum() << " ||| " << translation;
bestn << history->getLineNum() << " ||| " << translation;
if(!alignment_.empty())
bestn << " ||| " << getAlignment(hypo);
bestn << " |||";
if(hypo->GetScoreBreakdown().empty()) {
bestn << " F0=" << hypo->GetPathScore();
if(hypo->getScoreBreakdown().empty()) {
bestn << " F0=" << hypo->getPathScore();
} else {
for(size_t j = 0; j < hypo->GetScoreBreakdown().size(); ++j) {
bestn << " F" << j << "= " << hypo->GetScoreBreakdown()[j];
for(size_t j = 0; j < hypo->getScoreBreakdown().size(); ++j) {
bestn << " F" << j << "= " << hypo->getScoreBreakdown()[j];
}
}
@ -58,7 +58,7 @@ public:
bestn << std::flush;
}
auto result = history->Top();
auto result = history->top();
auto words = std::get<0>(result);
if(reverse_)

6
src/translator/scorers.h Normal file → Executable file

@ -9,9 +9,9 @@ namespace marian {
class ScorerState {
public:
virtual Expr getLogProbs() = 0;
virtual Expr getLogProbs() const = 0;
virtual float breakDown(size_t i) { return getLogProbs()->val()->get(i); }
float breakDown(size_t i) const { return getLogProbs()->val()->get(i); }
virtual void blacklist(Expr /*totalCosts*/, Ptr<data::CorpusBatch> /*batch*/){};
};
@ -57,7 +57,7 @@ public:
virtual Ptr<DecoderState> getState() { return state_; }
virtual Expr getLogProbs() override { return state_->getLogProbs(); };
virtual Expr getLogProbs() const override { return state_->getLogProbs(); };
virtual void blacklist(Expr totalCosts, Ptr<data::CorpusBatch> batch) override {
state_->blacklist(totalCosts, batch);

4
src/translator/translator.h Normal file → Executable file

@ -106,7 +106,7 @@ public:
std::stringstream best1;
std::stringstream bestn;
printer->print(history, best1, bestn);
collector->Write((long)history->GetLineNum(),
collector->Write((long)history->getLineNum(),
best1.str(),
bestn.str(),
options_->get<bool>("n-best"));
@ -211,7 +211,7 @@ public:
std::stringstream best1;
std::stringstream bestn;
printer->print(history, best1, bestn);
collector->add((long)history->GetLineNum(), best1.str(), bestn.str());
collector->add((long)history->getLineNum(), best1.str(), bestn.str());
}
};

0
vs/Marian.vcxproj Normal file → Executable file

3
vs/Marian.vcxproj.filters Normal file → Executable file

@ -1528,6 +1528,9 @@
<ClInclude Include="..\src\data\vocab_base.h">
<Filter>data</Filter>
</ClInclude>
<ClInclude Include="..\src\models\transformer.h">
<Filter>models</Filter>
</ClInclude>
</ItemGroup>
<ItemGroup>
<Filter Include="3rd_party">