Speed improvements (#531)

Summary:
* Add FusedLayerNorm and FusedAdam
* Softmax and zero grad optimizations
Pull Request resolved: https://github.com/pytorch/fairseq/pull/531

Differential Revision: D14218457

Pulled By: myleott

fbshipit-source-id: 5656b2d0152cd85f77dc21ec0e1439ec04b9fa89
Myle Ott 2019-03-14 11:39:03 -07:00 committed by Facebook Github Bot
parent a24880bd10
commit 48d9afbeb3
14 changed files with 103 additions and 55 deletions

View File

@ -36,12 +36,12 @@ translation and language modeling datasets.
![Model](fairseq.gif)
# Requirements and Installation
* A [PyTorch installation](http://pytorch.org/)
* For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
* Python version 3.6
Currently fairseq requires PyTorch version >= 1.0.0.
Please follow the instructions here: https://github.com/pytorch/pytorch#installation.
* [PyTorch](http://pytorch.org/) version >= 1.0.0
* Python version >= 3.6
* For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
Please follow the instructions here to install PyTorch: https://github.com/pytorch/pytorch#installation.
If you use Docker, make sure to increase the shared memory size either with
`--ipc=host` or `--shm-size` as command-line options to `nvidia-docker run`.
@ -60,6 +60,12 @@ cd fairseq
pip install --editable .
```
**Improved training speed**
Training speed can be further improved by installing NVIDIA's
[apex](https://github.com/NVIDIA/apex) library with the `--cuda_ext` option.
fairseq will automatically switch to the faster modules provided by apex.
# Getting Started
The [full documentation](https://fairseq.readthedocs.io/) contains instructions
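
Note: the "automatic switch" described in the new README paragraph comes down to guarded imports, as the `FusedLayerNorm` and `FusedAdam` hunks further down show. A quick check for whether those apex extensions are importable (this helper is illustrative only, not part of fairseq):

```python
def apex_fused_available():
    """Return True if the apex modules fairseq probes for can be imported."""
    try:
        from apex.normalization import FusedLayerNorm  # fused LayerNorm kernel
        from apex.optimizers import FusedAdam          # fused Adam optimizer
        return True
    except ImportError:
        return False

print('apex fused modules available:', apex_fused_available())
```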

View File

@ -122,8 +122,10 @@ def all_gather_list(data, group=None, max_size=16384):
if not hasattr(all_gather_list, '_buffer') or \
all_gather_list._buffer.numel() < buffer_size:
all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size)
all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory()
buffer = all_gather_list._buffer
buffer.zero_()
cpu_buffer = all_gather_list._cpu_buffer
enc = pickle.dumps(data)
enc_size = len(enc)
@ -131,10 +133,12 @@ def all_gather_list(data, group=None, max_size=16384):
raise ValueError('encoded data exceeds max_size: {}'.format(enc_size + 2))
assert max_size < 255*256
buffer_rank = buffer[rank * max_size : (rank + 1) * max_size]
buffer_rank[0] = enc_size // 255 # this encoding works for max_size < 65k
buffer_rank[1] = enc_size % 255
buffer_rank[2:enc_size+2] = torch.ByteTensor(list(enc))
cpu_buffer[0] = enc_size // 255 # this encoding works for max_size < 65k
cpu_buffer[1] = enc_size % 255
cpu_buffer[2 : enc_size + 2] = torch.ByteTensor(list(enc))
start = rank * max_size
size = enc_size + 2
buffer[start : start + size].copy_(cpu_buffer[:size])
all_reduce(buffer, group=group)
@ -144,9 +148,7 @@ def all_gather_list(data, group=None, max_size=16384):
out_buffer = buffer[i * max_size : (i + 1) * max_size]
size = (255 * utils.item(out_buffer[0])) + utils.item(out_buffer[1])
if size > 0:
result.append(
pickle.loads(bytes(out_buffer[2:size+2].tolist()))
)
result.append(pickle.loads(bytes(out_buffer[2 : size + 2].tolist())))
return result
except pickle.UnpicklingError:
raise Exception(
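
Note: the two header bytes written into the pinned CPU buffer encode the pickled payload length as `255 * b0 + b1`, which is why the assertion above requires `max_size < 255*256`. A standalone round-trip sketch of that encoding (illustrative, not fairseq code):

```python
import pickle

import torch

def encode(data, max_size=16384):
    enc = pickle.dumps(data)
    assert len(enc) + 2 <= max_size and max_size < 255 * 256
    buf = torch.ByteTensor(max_size).zero_()
    buf[0] = len(enc) // 255  # high "digit" in base 255
    buf[1] = len(enc) % 255   # low "digit"
    buf[2:len(enc) + 2] = torch.ByteTensor(list(enc))
    return buf

def decode(buf):
    size = 255 * int(buf[0]) + int(buf[1])
    return pickle.loads(bytes(buf[2:size + 2].tolist()))

assert decode(encode({'loss': 1.23})) == {'loss': 1.23}
```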

View File

@ -6,7 +6,8 @@
# can be found in the PATENTS file in the same directory.
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
class FairseqDecoder(nn.Module):
@ -15,6 +16,7 @@ class FairseqDecoder(nn.Module):
def __init__(self, dictionary):
super().__init__()
self.dictionary = dictionary
self.onnx_trace = False
def forward(self, prev_output_tokens, encoder_out):
"""
@ -33,6 +35,9 @@ class FairseqDecoder(nn.Module):
"""
raise NotImplementedError
def prepare_for_onnx_export_(self):
self.onnx_trace = True
def get_normalized_probs(self, net_output, log_probs, sample):
"""Get normalized probabilities (or log probs) from a net's output."""
@ -45,11 +50,11 @@ class FairseqDecoder(nn.Module):
out = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
return out.exp_() if not log_probs else out
logits = net_output[0].float()
logits = net_output[0]
if log_probs:
return F.log_softmax(logits, dim=-1)
return utils.log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
else:
return F.softmax(logits, dim=-1)
return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
def max_positions(self):
"""Maximum input length supported by the decoder."""

View File

@ -13,8 +13,8 @@ import torch.nn as nn
import torch.nn.functional as F
from fairseq.modules import (
DownsampledMultiHeadAttention, GradMultiply, LearnedPositionalEmbedding,
LinearizedConvolution,
DownsampledMultiHeadAttention, GradMultiply, LayerNorm,
LearnedPositionalEmbedding, LinearizedConvolution,
)
from fairseq import utils
@ -351,13 +351,13 @@ class FConvDecoder(FairseqDecoder):
# pretrained and trained models are joined
self.joining = nn.Sequential(
Linear(out_embed_dim*2, out_embed_dim*2),
nn.LayerNorm(out_embed_dim*2),
LayerNorm(out_embed_dim*2),
nn.GLU(),
Linear(out_embed_dim, out_embed_dim*2),
nn.LayerNorm(out_embed_dim*2),
LayerNorm(out_embed_dim*2),
nn.GLU(),
Linear(out_embed_dim, out_embed_dim),
nn.LayerNorm(out_embed_dim)
LayerNorm(out_embed_dim)
)
# pretrained model contains an output layer that is nhid -> vocab size
# but the models are combined in their hidden state
@ -470,7 +470,7 @@ class SelfAttention(nn.Module):
self.in_proj_q = Linear(out_channels, embed_dim)
self.in_proj_k = Linear(out_channels, embed_dim)
self.in_proj_v = Linear(out_channels, embed_dim)
self.ln = nn.LayerNorm(out_channels)
self.ln = LayerNorm(out_channels)
def forward(self, x):
residual = x

View File

@ -11,17 +11,16 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import options
from fairseq import utils
from fairseq import options, utils
from fairseq.modules import (
AdaptiveInput, AdaptiveSoftmax, CharacterTokenEmbedder, LearnedPositionalEmbedding, MultiheadAttention,
SinusoidalPositionalEmbedding, DynamicConv1dTBC, LightweightConv1dTBC
AdaptiveInput, AdaptiveSoftmax, CharacterTokenEmbedder, LayerNorm,
LearnedPositionalEmbedding, MultiheadAttention, SinusoidalPositionalEmbedding,
DynamicConv1dTBC, LightweightConv1dTBC,
)
from . import (
FairseqIncrementalDecoder, FairseqEncoder, FairseqLanguageModel, FairseqModel, register_model,
register_model_architecture,
FairseqIncrementalDecoder, FairseqEncoder, FairseqLanguageModel,
FairseqModel, register_model, register_model_architecture,
)
@ -771,11 +770,6 @@ def Embedding(num_embeddings, embedding_dim, padding_idx):
return m
def LayerNorm(embedding_dim):
m = nn.LayerNorm(embedding_dim)
return m
def Linear(in_features, out_features, bias=True):
m = nn.Linear(in_features, out_features, bias)
nn.init.xavier_uniform_(m.weight)

View File

@ -11,17 +11,15 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import options
from fairseq import utils
from fairseq import options, utils
from fairseq.modules import (
AdaptiveInput, AdaptiveSoftmax, CharacterTokenEmbedder, LearnedPositionalEmbedding, MultiheadAttention,
SinusoidalPositionalEmbedding
AdaptiveInput, AdaptiveSoftmax, CharacterTokenEmbedder, LayerNorm,
LearnedPositionalEmbedding, MultiheadAttention, SinusoidalPositionalEmbedding,
)
from . import (
FairseqIncrementalDecoder, FairseqEncoder, FairseqLanguageModel, FairseqModel, register_model,
register_model_architecture,
FairseqIncrementalDecoder, FairseqEncoder, FairseqLanguageModel,
FairseqModel, register_model, register_model_architecture,
)
@ -766,11 +764,6 @@ def Embedding(num_embeddings, embedding_dim, padding_idx):
return m
def LayerNorm(embedding_dim):
m = nn.LayerNorm(embedding_dim)
return m
def Linear(in_features, out_features, bias=True):
m = nn.Linear(in_features, out_features, bias)
nn.init.xavier_uniform_(m.weight)

View File

@ -14,6 +14,7 @@ from .downsampled_multihead_attention import DownsampledMultiHeadAttention
from .dynamic_convolution import DynamicConv1dTBC
from .grad_multiply import GradMultiply
from .highway import Highway
from .layer_norm import LayerNorm
from .learned_positional_embedding import LearnedPositionalEmbedding
from .lightweight_convolution import LightweightConv1dTBC
from .linearized_convolution import LinearizedConvolution
@ -34,6 +35,7 @@ __all__ = [
'DynamicConv1dTBC',
'GradMultiply',
'Highway',
'LayerNorm',
'LearnedPositionalEmbedding',
'LightweightConv1dTBC',
'LinearizedConvolution',

View File

@ -0,0 +1,18 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
if torch.cuda.is_available():
try:
from apex.normalization import FusedLayerNorm
return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
except ImportError:
pass
return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
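
Note: the factory above is a drop-in replacement for `nn.LayerNorm`, so call sites never need to know whether apex is installed (the fused path is only attempted when CUDA is available). A minimal usage sketch:

```python
import torch

from fairseq.modules import LayerNorm  # the factory defined in this file

ln = LayerNorm(512)  # FusedLayerNorm when apex + CUDA are present, else nn.LayerNorm
x = torch.randn(10, 32, 512)
if torch.cuda.is_available():
    # the apex fused kernel operates on GPU tensors
    ln, x = ln.cuda(), x.cuda()
y = ln(x)
assert y.shape == x.shape
```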

View File

@ -6,6 +6,7 @@
# can be found in the PATENTS file in the same directory.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -121,6 +122,8 @@ class LightweightConv1dTBC(nn.Module):
self.reset_parameters()
self.onnx_trace = False
def reset_parameters(self):
nn.init.xavier_uniform_(self.weight)
if self.bias is not None:
@ -144,6 +147,9 @@ class LightweightConv1dTBC(nn.Module):
output = output + self.bias.view(1, 1, -1)
return output
def prepare_for_onnx_export_(self):
self.onnx_trace = True
def _forward_unfolded(self, x, incremental_state):
'''The conventional implementation of convolutions.
Unfolding the input by having a window shifting to the right.'''
@ -167,7 +173,7 @@ class LightweightConv1dTBC(nn.Module):
x_unfold = x_unfold.view(T*B*H, R, K)
if self.weight_softmax:
weight = F.softmax(weight.float(), dim=1).type_as(weight)
weight = utils.softmax(weight, dim=1, onnx_trace=self.onnx_trace).type_as(weight)
if incremental_state is not None:
weight = weight[:, -x_unfold.size(2):]
@ -192,7 +198,7 @@ class LightweightConv1dTBC(nn.Module):
weight = self.weight.view(H, K)
if self.weight_softmax:
weight = F.softmax(weight.float(), dim=1).type_as(weight)
weight = utils.softmax(weight, dim=1, onnx_trace=self.onnx_trace).type_as(weight)
weight = weight.view(1, H, K).expand(T*B, H, K).contiguous()
weight = weight.view(T, B*H, K).transpose(0, 1)

View File

@ -184,7 +184,9 @@ class MultiheadAttention(nn.Module):
).type_as(attn_weights) # FP16 support: cast to float and back
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(attn_weights)
attn_weights = utils.softmax(
attn_weights, dim=-1, onnx_trace=self.onnx_trace,
).type_as(attn_weights)
attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
attn = torch.bmm(attn_weights, v)

View File

@ -16,7 +16,11 @@ from . import FairseqOptimizer, register_optimizer
class FairseqAdam(FairseqOptimizer):
def __init__(self, args, params):
super().__init__(args, params)
self._optimizer = Adam(params, **self.optimizer_config)
try:
from apex.optimizers import FusedAdam
self._optimizer = FusedAdam(params, **self.optimizer_config)
except ImportError:
self._optimizer = Adam(params, **self.optimizer_config)
@staticmethod
def add_args(parser):

View File

@ -92,4 +92,7 @@ class FairseqOptimizer(object):
def zero_grad(self):
"""Clears the gradients of all optimized parameters."""
for group in self.optimizer.param_groups:
for p in group['params']:
p.grad = None
self.optimizer.zero_grad()
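
Note: setting `p.grad = None` instead of zeroing the gradient tensors in place skips a fill kernel per parameter and lets the next `backward()` allocate fresh gradients; newer PyTorch releases expose the same behavior as `optimizer.zero_grad(set_to_none=True)`. A minimal sketch of the pattern on a toy model (names are illustrative):

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

loss = model(torch.randn(4, 8)).sum()
loss.backward()
optimizer.step()

# Drop the gradients entirely rather than writing zeros into them;
# backward() will re-create the .grad tensors on the next iteration.
for group in optimizer.param_groups:
    for p in group['params']:
        p.grad = None
```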

View File

@ -205,11 +205,8 @@ class FP16Optimizer(optim.FairseqOptimizer):
def zero_grad(self):
"""Clears the gradients of all optimized parameters."""
self.fp32_optimizer.zero_grad()
for p in self.params:
if p.grad is not None:
p.grad.detach_()
p.grad.zero_()
p.grad = None
self._needs_sync = False

View File

@ -4,15 +4,17 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from collections import defaultdict, OrderedDict
import importlib.util
import logging
import os
import re
import sys
import traceback
from collections import defaultdict, OrderedDict
import torch
import torch.nn.functional as F
from torch.serialization import default_restore_location
@ -447,3 +449,17 @@ def import_user_module(args):
sys.path.insert(0, module_parent)
importlib.import_module(module_name)
sys.path.pop(0)
def softmax(x, dim, onnx_trace=False):
if onnx_trace:
return F.softmax(x.float(), dim=dim)
else:
return F.softmax(x, dim=dim, dtype=torch.float32)
def log_softmax(x, dim, onnx_trace=False):
if onnx_trace:
return F.log_softmax(x.float(), dim=dim)
else:
return F.log_softmax(x, dim=dim, dtype=torch.float32)
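
Note: these helpers compute the (log-)softmax in float32 regardless of the input dtype, which is what the FP16 call sites above previously did by hand with `x.float()`. A short usage sketch:

```python
import torch

from fairseq import utils

w = torch.randn(4, 8, dtype=torch.float16)
p = utils.softmax(w, dim=-1).type_as(w)  # accumulated in fp32, cast back to fp16
lp = utils.log_softmax(w, dim=-1)        # fp32 log-probs, numerically stable
assert torch.allclose(p.float().sum(-1), torch.ones(4), atol=1e-2)
```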