dont project maske tokens for mlm loss (#859)

Summary:
This saves ~4-5gb gpu memory while training roberta large with `seq_len=512`.

I am able to fit `--max-sentences=16` on `volta32gb` for `roberta-large`
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/859

Differential Revision: D17435814

fbshipit-source-id: 2663909768fac0ef0102107613770ee01b1f8c00
This commit is contained in:
Naman Goyal 2019-09-18 10:05:01 -07:00 committed by Facebook Github Bot
parent 31dd13fa65
commit 718677ebb0
2 changed files with 16 additions and 9 deletions

View File

@ -30,8 +30,11 @@ class MaskedLmLoss(FairseqCriterion):
3) logging outputs to display while training
"""
# compute MLM loss
logits = model(**sample['net_input'], return_all_hiddens=False)[0]
masked_tokens = sample['target'].ne(self.padding_idx)
logits = model(**sample['net_input'], masked_tokens=masked_tokens)[0]
targets = model.get_targets(sample, [logits])
targets = targets[masked_tokens]
loss = F.nll_loss(
F.log_softmax(
logits.view(-1, logits.size(-1)),
@ -43,7 +46,7 @@ class MaskedLmLoss(FairseqCriterion):
ignore_index=self.padding_idx,
)
sample_size = targets.ne(self.padding_idx).int().sum().item()
sample_size = masked_tokens.int().sum().item()
logging_output = {
'loss': utils.item(loss.data) if reduce else loss.data,
@ -64,6 +67,7 @@ class MaskedLmLoss(FairseqCriterion):
agg_output = {
'loss': loss / sample_size / math.log(2),
'nll_loss': sum(log.get('nll_loss', 0) for log in logging_outputs) / sample_size / math.log(2) if ntokens > 0 else 0.,
'ntokens': ntokens,
'nsentences': nsentences,
'sample_size': sample_size,

View File

@ -201,14 +201,17 @@ class RobertaLMHead(nn.Module):
self.weight = weight
self.bias = nn.Parameter(torch.zeros(output_dim))
def forward(self, features, **kwargs):
def forward(self, features, masked_tokens=None, **kwargs):
# Only project the unmasked tokens while training,
# saves both memory and computation
if masked_tokens is not None:
features = features[masked_tokens, :]
x = self.dense(features)
x = self.activation_fn(x)
x = self.layer_norm(x)
# project back to size of vocabulary with bias
x = F.linear(x, self.weight) + self.bias
return x
@ -265,7 +268,7 @@ class RobertaEncoder(FairseqDecoder):
weight=self.sentence_encoder.embed_tokens.weight,
)
def forward(self, src_tokens, features_only=False, return_all_hiddens=False, **unused):
def forward(self, src_tokens, features_only=False, return_all_hiddens=False, masked_tokens=None, **unused):
"""
Args:
src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
@ -283,7 +286,7 @@ class RobertaEncoder(FairseqDecoder):
"""
x, extra = self.extract_features(src_tokens, return_all_hiddens)
if not features_only:
x = self.output_layer(x)
x = self.output_layer(x, masked_tokens=masked_tokens)
return x, extra
def extract_features(self, src_tokens, return_all_hiddens=False, **unused):
@ -293,8 +296,8 @@ class RobertaEncoder(FairseqDecoder):
features = inner_states[-1]
return features, {'inner_states': inner_states if return_all_hiddens else None}
def output_layer(self, features, **unused):
return self.lm_head(features)
def output_layer(self, features, masked_tokens=None, **unused):
return self.lm_head(features, masked_tokens)
def max_positions(self):
"""Maximum output length supported by the encoder."""