mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
do not do dropout at inference
This commit is contained in:
parent
93cdfdcc9a
commit
147d9dc25b
@ -480,7 +480,8 @@ namespace marian {
|
||||
auto offsets = graph->constant({ (int)factoredData.offsets.size() }, inits::fromVector(factoredData.offsets), Type::uint32);
|
||||
// apply dropout
|
||||
// We apply it to the weights, i.e. factors get dropped out separately, but always as entire vectors.
|
||||
weights = dropout(weights, dropProb);
|
||||
if(!inference_)
|
||||
weights = dropout(weights, dropProb);
|
||||
// perform the product
|
||||
return csr_dot(factoredData.shape, weights, indices, offsets, E_);
|
||||
}
|
||||
@ -552,7 +553,8 @@ namespace marian {
|
||||
auto selectedEmbs = rows(E_, embIdxExpr); // [(B*W) x E]
|
||||
selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E]
|
||||
// @BUGBUG: We should not broadcast along dimBatch=[-2]. Then we can also dropout before reshape() (test that separately)
|
||||
selectedEmbs = dropout(selectedEmbs, options_->get<float>("dropout", 0.0f), { selectedEmbs->shape()[-3], 1, 1 });
|
||||
if(!inference_)
|
||||
selectedEmbs = dropout(selectedEmbs, options_->get<float>("dropout", 0.0f), { selectedEmbs->shape()[-3], 1, 1 });
|
||||
return selectedEmbs;
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user