only project masked tokens for the MLM loss (#859)

Summary:
This saves ~4-5 GB of GPU memory when training RoBERTa-large with `seq_len=512`.

With this change I am able to fit `--max-sentences=16` on `volta32gb` for `roberta-large`.
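Rough intuition for where the saving comes from: with BERT-style masking only about 15% of positions carry an MLM target, so materializing vocabulary-sized logits for every position wastes most of the output-layer memory (plus the matching log-softmax and gradient buffers). A back-of-the-envelope sketch, assuming RoBERTa's ~50k BPE vocabulary, a 15% masking rate, and fp32 activations; none of these numbers are measured from this run:

# Back-of-the-envelope: size of the MLM logits tensor with and without the change.
# Assumed numbers, not measurements: vocab ~50,265, 15% masking rate, fp32 activations.
batch, seq_len, vocab = 16, 512, 50265
bytes_per_float = 4

full_logits = batch * seq_len * vocab * bytes_per_float               # project every position
masked_only = int(batch * seq_len * 0.15) * vocab * bytes_per_float   # project only masked positions

print(f"full:        {full_logits / 2**30:.2f} GiB")   # ~1.53 GiB
print(f"masked only: {masked_only / 2**30:.2f} GiB")   # ~0.23 GiB

The observed 4-5 GB figure is larger than a single logits tensor because the log-softmax output and the corresponding gradient buffers scale the same way.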
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/859

Differential Revision: D17435814

fbshipit-source-id: 2663909768fac0ef0102107613770ee01b1f8c00
Naman Goyal, 2019-09-18 10:05:01 -07:00 (committed by Facebook Github Bot)
parent 31dd13fa65
commit 718677ebb0
2 changed files with 16 additions and 9 deletions

@@ -30,8 +30,11 @@ class MaskedLmLoss(FairseqCriterion):
         3) logging outputs to display while training
         """
         # compute MLM loss
-        logits = model(**sample['net_input'], return_all_hiddens=False)[0]
+        masked_tokens = sample['target'].ne(self.padding_idx)
+        logits = model(**sample['net_input'], masked_tokens=masked_tokens)[0]
         targets = model.get_targets(sample, [logits])
+        targets = targets[masked_tokens]
         loss = F.nll_loss(
             F.log_softmax(
                 logits.view(-1, logits.size(-1)),
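For reference, a self-contained toy version of the bookkeeping introduced above; the tensors, `padding_idx=1`, and the tiny vocabulary are made-up values for illustration, not fairseq code:

import torch
import torch.nn.functional as F

padding_idx = 1
# 'target' holds the original token id at masked positions and padding_idx everywhere else.
target = torch.tensor([[1, 27, 1, 1, 93],
                       [1, 1, 42, 1, 1]])
masked_tokens = target.ne(padding_idx)                 # bool mask, shape (batch, seq_len)

vocab = 100
# With the change, the model only returns logits for the masked positions.
logits = torch.randn(int(masked_tokens.sum()), vocab)  # shape (num_masked, vocab)
targets = target[masked_tokens]                        # shape (num_masked,)

loss = F.nll_loss(
    F.log_softmax(logits.view(-1, logits.size(-1)), dim=-1, dtype=torch.float32),
    targets.view(-1),
    reduction='sum',
    ignore_index=padding_idx,
)
sample_size = masked_tokens.int().sum().item()         # 3 masked tokens in this toy batch
print(logits.shape, targets.shape, sample_size, loss.item())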
@@ -43,7 +46,7 @@ class MaskedLmLoss(FairseqCriterion):
             ignore_index=self.padding_idx,
         )
-        sample_size = targets.ne(self.padding_idx).int().sum().item()
+        sample_size = masked_tokens.int().sum().item()
         logging_output = {
             'loss': utils.item(loss.data) if reduce else loss.data,
@@ -64,6 +67,7 @@ class MaskedLmLoss(FairseqCriterion):
         agg_output = {
             'loss': loss / sample_size / math.log(2),
+            'nll_loss': sum(log.get('nll_loss', 0) for log in logging_outputs) / sample_size / math.log(2) if ntokens > 0 else 0.,
             'ntokens': ntokens,
             'nsentences': nsentences,
             'sample_size': sample_size,
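The aggregated 'loss' (and the newly added 'nll_loss') divide the summed negative log-likelihood by the number of masked tokens and by ln 2, i.e. they report bits per masked token. A tiny numeric check with made-up logging outputs, not real training values:

import math

# Hypothetical per-batch logging outputs; all numbers are invented.
logging_outputs = [
    {'loss': 210.4, 'nll_loss': 210.4, 'ntokens': 8192, 'nsentences': 16, 'sample_size': 1228},
    {'loss': 198.7, 'nll_loss': 198.7, 'ntokens': 8192, 'nsentences': 16, 'sample_size': 1201},
]
loss = sum(log['loss'] for log in logging_outputs)
sample_size = sum(log['sample_size'] for log in logging_outputs)
ntokens = sum(log['ntokens'] for log in logging_outputs)

bits_per_masked_token = loss / sample_size / math.log(2)
nll_bits = sum(log.get('nll_loss', 0) for log in logging_outputs) / sample_size / math.log(2) if ntokens > 0 else 0.
print(round(bits_per_masked_token, 3), round(nll_bits, 3))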

@@ -201,14 +201,17 @@ class RobertaLMHead(nn.Module):
         self.weight = weight
         self.bias = nn.Parameter(torch.zeros(output_dim))

-    def forward(self, features, **kwargs):
+    def forward(self, features, masked_tokens=None, **kwargs):
+        # Only project the masked tokens while training,
+        # saves both memory and computation
+        if masked_tokens is not None:
+            features = features[masked_tokens, :]
+
         x = self.dense(features)
         x = self.activation_fn(x)
         x = self.layer_norm(x)
         # project back to size of vocabulary with bias
         x = F.linear(x, self.weight) + self.bias
         return x
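A stripped-down stand-in for the head, just to show the shape effect of the boolean index; the dimensions, the tanh activation, and the toy mask are assumptions for illustration:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyLMHead(nn.Module):
    """Minimal stand-in for the LM head above, only to illustrate shapes."""
    def __init__(self, embed_dim, vocab_size):
        super().__init__()
        self.dense = nn.Linear(embed_dim, embed_dim)
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.weight = nn.Parameter(torch.randn(vocab_size, embed_dim))
        self.bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, features, masked_tokens=None):
        # Only project the masked tokens: (batch, seq, dim) -> (num_masked, dim).
        if masked_tokens is not None:
            features = features[masked_tokens, :]
        x = torch.tanh(self.dense(features))   # the real head uses the configured activation_fn (e.g. gelu)
        x = self.layer_norm(x)
        # project back to vocabulary size with bias
        return F.linear(x, self.weight) + self.bias

head = TinyLMHead(embed_dim=8, vocab_size=50)
features = torch.randn(2, 5, 8)                   # (batch, seq_len, embed_dim)
masked_tokens = torch.zeros(2, 5, dtype=torch.bool)
masked_tokens[0, 1] = True                        # two masked positions in total
masked_tokens[1, 3] = True
print(head(features).shape)                       # torch.Size([2, 5, 50])
print(head(features, masked_tokens).shape)        # torch.Size([2, 50])

Without a mask the head still projects every position, so inference paths that pass `masked_tokens=None` keep their old behavior.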
@@ -265,7 +268,7 @@ class RobertaEncoder(FairseqDecoder):
             weight=self.sentence_encoder.embed_tokens.weight,
         )

-    def forward(self, src_tokens, features_only=False, return_all_hiddens=False, **unused):
+    def forward(self, src_tokens, features_only=False, return_all_hiddens=False, masked_tokens=None, **unused):
         """
         Args:
             src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
@@ -283,7 +286,7 @@ class RobertaEncoder(FairseqDecoder):
         """
         x, extra = self.extract_features(src_tokens, return_all_hiddens)
         if not features_only:
-            x = self.output_layer(x)
+            x = self.output_layer(x, masked_tokens=masked_tokens)
         return x, extra

     def extract_features(self, src_tokens, return_all_hiddens=False, **unused):
@@ -293,8 +296,8 @@ class RobertaEncoder(FairseqDecoder):
         features = inner_states[-1]
         return features, {'inner_states': inner_states if return_all_hiddens else None}

-    def output_layer(self, features, **unused):
-        return self.lm_head(features)
+    def output_layer(self, features, masked_tokens=None, **unused):
+        return self.lm_head(features, masked_tokens)

     def max_positions(self):
         """Maximum output length supported by the encoder."""