From 718677ebb044e27aaf1a30640c2f7ab6b8fa8509 Mon Sep 17 00:00:00 2001
From: Naman Goyal
Date: Wed, 18 Sep 2019 10:05:01 -0700
Subject: [PATCH] don't project unmasked tokens for MLM loss (#859)

Summary:
This saves ~4-5 GB of GPU memory while training RoBERTa-large with `seq_len=512`.

I am able to fit `--max-sentences=16` on `volta32gb` for `roberta-large`.
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/859

Differential Revision: D17435814

fbshipit-source-id: 2663909768fac0ef0102107613770ee01b1f8c00
---
 fairseq/criterions/masked_lm.py |  8 ++++++--
 fairseq/models/roberta/model.py | 17 ++++++++++-------
 2 files changed, 16 insertions(+), 9 deletions(-)

diff --git a/fairseq/criterions/masked_lm.py b/fairseq/criterions/masked_lm.py
index d8907eba5..4eae5c384 100644
--- a/fairseq/criterions/masked_lm.py
+++ b/fairseq/criterions/masked_lm.py
@@ -30,8 +30,11 @@ class MaskedLmLoss(FairseqCriterion):
         3) logging outputs to display while training
         """
         # compute MLM loss
-        logits = model(**sample['net_input'], return_all_hiddens=False)[0]
+        masked_tokens = sample['target'].ne(self.padding_idx)
+        logits = model(**sample['net_input'], masked_tokens=masked_tokens)[0]
         targets = model.get_targets(sample, [logits])
+        targets = targets[masked_tokens]
+
         loss = F.nll_loss(
             F.log_softmax(
                 logits.view(-1, logits.size(-1)),
@@ -43,7 +46,7 @@ class MaskedLmLoss(FairseqCriterion):
             ignore_index=self.padding_idx,
         )
 
-        sample_size = targets.ne(self.padding_idx).int().sum().item()
+        sample_size = masked_tokens.int().sum().item()
 
         logging_output = {
             'loss': utils.item(loss.data) if reduce else loss.data,
@@ -64,6 +67,7 @@ class MaskedLmLoss(FairseqCriterion):
 
         agg_output = {
             'loss': loss / sample_size / math.log(2),
+            'nll_loss': sum(log.get('nll_loss', 0) for log in logging_outputs) / sample_size / math.log(2) if ntokens > 0 else 0.,
             'ntokens': ntokens,
             'nsentences': nsentences,
             'sample_size': sample_size,
diff --git a/fairseq/models/roberta/model.py b/fairseq/models/roberta/model.py
index e5528dfc9..7b9cbba53 100644
--- a/fairseq/models/roberta/model.py
+++ b/fairseq/models/roberta/model.py
@@ -201,14 +201,17 @@ class RobertaLMHead(nn.Module):
             self.weight = weight
         self.bias = nn.Parameter(torch.zeros(output_dim))
 
-    def forward(self, features, **kwargs):
+    def forward(self, features, masked_tokens=None, **kwargs):
+        # Only project the masked tokens while training;
+        # this saves both memory and computation
+        if masked_tokens is not None:
+            features = features[masked_tokens, :]
+
         x = self.dense(features)
         x = self.activation_fn(x)
         x = self.layer_norm(x)
-
         # project back to size of vocabulary with bias
         x = F.linear(x, self.weight) + self.bias
-
         return x
 
 
@@ -265,7 +268,7 @@ class RobertaEncoder(FairseqDecoder):
             weight=self.sentence_encoder.embed_tokens.weight,
         )
 
-    def forward(self, src_tokens, features_only=False, return_all_hiddens=False, **unused):
+    def forward(self, src_tokens, features_only=False, return_all_hiddens=False, masked_tokens=None, **unused):
         """
         Args:
             src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
@@ -283,7 +286,7 @@ class RobertaEncoder(FairseqDecoder):
         """
         x, extra = self.extract_features(src_tokens, return_all_hiddens)
        if not features_only:
-            x = self.output_layer(x)
+            x = self.output_layer(x, masked_tokens=masked_tokens)
         return x, extra
 
     def extract_features(self, src_tokens, return_all_hiddens=False, **unused):
@@ -293,8 +296,8 @@ class RobertaEncoder(FairseqDecoder):
         features = inner_states[-1]
         return features, {'inner_states': inner_states if return_all_hiddens else None}
 
-    def output_layer(self, features, **unused):
-        return self.lm_head(features)
+    def output_layer(self, features, masked_tokens=None, **unused):
+        return self.lm_head(features, masked_tokens)
 
     def max_positions(self):
         """Maximum output length supported by the encoder."""
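Why this saves memory: the vocabulary projection in `RobertaLMHead` now runs only on the positions that carry an MLM target (roughly 15% of tokens), instead of on every position in the batch. For `roberta-large` with `--max-sentences=16`, `seq_len=512` and a ~50k-token vocabulary, the full fp32 logits tensor alone is about 16 * 512 * 50265 * 4 bytes ≈ 1.6 GB before counting the log-softmax output and gradients, which lines up with the ~4-5 GB reported above. Below is a minimal standalone sketch of the indexing trick with toy sizes; the variable names and sizes are illustrative, not fairseq code.

```python
# Illustrative sketch: project only masked positions to the vocabulary (toy sizes).
import torch
import torch.nn as nn

batch, seq_len, hidden, vocab, pad_idx = 2, 16, 8, 100, 1

features = torch.randn(batch, seq_len, hidden)                      # encoder output (B, T, H)
targets = torch.full((batch, seq_len), pad_idx, dtype=torch.long)   # pad_idx marks non-masked positions
targets[:, ::6] = 42                                                 # pretend a few positions were masked

masked_tokens = targets.ne(pad_idx)      # bool mask (B, T), True at masked positions
proj = nn.Linear(hidden, vocab)          # stand-in for the vocab projection

full_logits = proj(features)                    # (B, T, V): logits for every position
masked_logits = proj(features[masked_tokens])   # (num_masked, V): only what the loss needs

print(full_logits.shape, masked_logits.shape)
```

When `masked_tokens` is `None` (e.g. at inference or when `features_only=True` skips the head entirely), the head behaves exactly as before and projects every position.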
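On the criterion side, the logits returned by the model now cover only the masked positions, so `MaskedLmLoss` filters the targets with the same mask and uses the masked-token count as `sample_size`. A sketch of that loss computation under the same toy assumptions (it mirrors the shape of the patched `forward`, but is not the real class):

```python
# Illustrative sketch: compute the MLM loss only over masked positions (toy sizes).
import torch
import torch.nn.functional as F

pad_idx, vocab = 1, 100

targets = torch.tensor([[1, 1, 7, 1, 9, 1],
                        [3, 1, 1, 5, 1, 1]])       # pad_idx at non-masked positions
masked_tokens = targets.ne(pad_idx)                # same mask that is passed to the model
sample_size = masked_tokens.int().sum().item()     # number of masked tokens

# What the LM head now returns: one row of logits per masked position.
logits = torch.randn(sample_size, vocab, requires_grad=True)

loss = F.nll_loss(
    F.log_softmax(logits.view(-1, logits.size(-1)), dim=-1, dtype=torch.float32),
    targets[masked_tokens].view(-1),
    reduction='sum',
    ignore_index=pad_idx,
)
print(sample_size, loss.item())
```

Because every remaining target is a real masked token, `ignore_index` has nothing left to skip, and `sample_size` can be read directly off the mask instead of re-checking the padding index on the targets.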