BASE layers (#1654)

Summary:

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1654

Reviewed By: myleott

Differential Revision: D27128074

Pulled By: shruti-bh

fbshipit-source-id: ac86d383cd53c9c9bdd946fea839a37b719d95e3
Michael Lewis 2021-03-29 18:02:07 -07:00 committed by Facebook GitHub Bot
parent 1c9738c6e9
commit 7dafb05754
16 changed files with 341 additions and 13 deletions

View File

@@ -115,7 +115,7 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
    if not end_of_epoch and cfg.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(
-            cfg.save_dir, pattern=r"checkpoint_\d+_(\d+)\.pt"
+            cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix)
        )
        for old_chk in checkpoints[cfg.keep_interval_updates :]:
            if os.path.lexists(old_chk):
@@ -123,7 +123,7 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
    if cfg.keep_last_epochs > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
-        checkpoints = checkpoint_paths(cfg.save_dir, pattern=r"checkpoint(\d+)\.pt")
+        checkpoints = checkpoint_paths(cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix))
        for old_chk in checkpoints[cfg.keep_last_epochs :]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
@@ -132,8 +132,8 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
        # only keep the best N checkpoints according to validation metric
        checkpoints = checkpoint_paths(
            cfg.save_dir,
-            pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format(
-                cfg.best_checkpoint_metric
+            pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
+                cfg.best_checkpoint_metric, suffix
            ),
        )
        if not cfg.maximize_best_checkpoint_metric:

View File

@@ -0,0 +1,95 @@
/**
 * Copyright 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */

/*
C++ code for solving the linear assignment problem.
Based on the Auction Algorithm from
https://dspace.mit.edu/bitstream/handle/1721.1/3265/P-2108-26912652.pdf
and the implementation from: https://github.com/bkj/auction-lap
Adapted to be more efficient when each worker is looking for k jobs instead of 1.
*/

#include <torch/extension.h>
#include <iostream>

using namespace torch::indexing;

torch::Tensor balanced_assignment(torch::Tensor job_and_worker_to_score) {
  int max_iterations = 100;
  torch::Tensor epsilon =
      (job_and_worker_to_score.max() - job_and_worker_to_score.min()) / 50;
  epsilon.clamp_min_(1e-04);
  torch::Tensor worker_and_job_to_score =
      job_and_worker_to_score.detach().transpose(0, 1).contiguous();
  int num_workers = worker_and_job_to_score.size(0);
  int num_jobs = worker_and_job_to_score.size(1);
  auto device = worker_and_job_to_score.device();
  int jobs_per_worker = num_jobs / num_workers;
  torch::Tensor value = worker_and_job_to_score.clone();
  int counter = 0;
  torch::Tensor max_value = worker_and_job_to_score.max();

  torch::Tensor bid_indices;
  torch::Tensor cost = worker_and_job_to_score.new_zeros({1, num_jobs});
  torch::Tensor bids = worker_and_job_to_score.new_empty({num_workers, num_jobs});
  torch::Tensor bid_increments =
      worker_and_job_to_score.new_empty({num_workers, jobs_per_worker});
  torch::Tensor top_values =
      worker_and_job_to_score.new_empty({num_workers, jobs_per_worker + 1});
  torch::Tensor high_bids = worker_and_job_to_score.new_empty({num_jobs});

  torch::Tensor top_index = top_values.to(torch::kLong);
  torch::Tensor high_bidders = top_index.new_empty({num_jobs});
  torch::Tensor have_bids = high_bidders.to(torch::kBool);
  torch::Tensor jobs_indices =
      torch::arange({num_jobs}, torch::dtype(torch::kLong).device(device));
  torch::Tensor true_tensor =
      torch::ones({1}, torch::dtype(torch::kBool).device(device));

  while (true) {
    bids.zero_();
    torch::topk_out(top_values, top_index, value, jobs_per_worker + 1, 1);

    // Each worker bids the difference in value between that job and the k+1th job
    torch::sub_out(
        bid_increments,
        top_values.index({Slice(None, None), Slice(0, jobs_per_worker)}),
        top_values.index({Slice(None, None), jobs_per_worker}).unsqueeze(1));
    bid_increments.add_(epsilon);
    bids.scatter_(
        1,
        top_index.index({Slice(None, None), Slice(0, jobs_per_worker)}),
        bid_increments);

    if (counter < max_iterations && counter > 0) {
      // Put in a minimal bid to retain items from the last round if no-one else bids for them this round
      bids.view(-1).index_put_({bid_indices}, epsilon);
    }

    // Find the highest bidding worker per job
    torch::max_out(high_bids, high_bidders, bids, 0);
    torch::gt_out(have_bids, high_bids, 0);

    if (have_bids.all().item<bool>()) {
      // All jobs were bid for
      break;
    }

    // Make popular items more expensive
    cost.add_(high_bids);
    torch::sub_out(value, worker_and_job_to_score, cost);
    bid_indices = ((high_bidders * num_jobs) + jobs_indices).index({have_bids});

    if (counter < max_iterations) {
      // Make sure that this item will be in the winning worker's top-k next time.
      value.view(-1).index_put_({bid_indices}, max_value);
    } else {
      // Suboptimal approximation that converges quickly from current solution
      value.view(-1).index_put_(
          {bid_indices}, worker_and_job_to_score.view(-1).index({bid_indices}));
    }
    counter += 1;
  }

  return top_index.index({Slice(None, None), Slice(0, jobs_per_worker)}).reshape(-1);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("balanced_assignment", &balanced_assignment, "Balanced Assignment");
}
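For reference, a minimal sketch of how this extension might be exercised once compiled into `fairseq.libbase` (the module name registered in setup.py at the end of this diff); shapes and values are illustrative only:

# Hypothetical smoke test for the auction-based assignment (not part of this commit).
import torch
from fairseq import libbase  # built via `python setup.py build_ext --inplace`

num_workers, jobs_per_worker = 4, 8
num_jobs = num_workers * jobs_per_worker  # must be divisible by num_workers

# Score of assigning each job (token) to each worker (expert).
scores = torch.randn(num_jobs, num_workers)

assignment = libbase.balanced_assignment(scores)
# Returns num_jobs job indices grouped by worker: the first jobs_per_worker
# entries are worker 0's jobs, the next jobs_per_worker are worker 1's, etc.
assert assignment.shape == (num_jobs,)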

View File

@@ -41,13 +41,16 @@ def collate_tokens(
    move_eos_to_beginning=False,
    pad_to_length=None,
    pad_to_multiple=1,
+    pad_to_bsz=None,
):
    """Convert a list of 1d tensors into a padded 2d tensor."""
    size = max(v.size(0) for v in values)
    size = size if pad_to_length is None else max(size, pad_to_length)
    if pad_to_multiple != 1 and size % pad_to_multiple != 0:
        size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple)
-    res = values[0].new(len(values), size).fill_(pad_idx)
+
+    batch_size = len(values) if pad_to_bsz is None else max(len(values), pad_to_bsz)
+    res = values[0].new(batch_size, size).fill_(pad_idx)

    def copy_tensor(src, dst):
        assert dst.numel() == src.numel()
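A toy illustration of the new pad_to_bsz argument, sketched against the signature above (import path assumed from this file's location; values hypothetical):

import torch
from fairseq.data.data_utils import collate_tokens  # location assumed

values = [torch.tensor([4, 5]), torch.tensor([6, 7, 8])]
batch = collate_tokens(values, pad_idx=1, pad_to_bsz=4)
# Rows 0-1 hold the right-padded inputs; rows 2-3 are pure padding, so the
# batch dimension stays fixed at 4 even when fewer samples are collated.
print(batch.shape)  # torch.Size([4, 3])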

View File

@@ -9,7 +9,7 @@ import torch
from . import FairseqDataset, data_utils

-def collate(samples, pad_idx, eos_idx):
+def collate(samples, pad_idx, eos_idx, fixed_pad_length=None, pad_to_bsz=None):
    if len(samples) == 0:
        return {}
@@ -23,6 +23,8 @@ def collate(samples, pad_idx, eos_idx):
                        pad_idx,
                        eos_idx,
                        left_pad=False,
+                        pad_to_length=fixed_pad_length,
+                        pad_to_bsz=pad_to_bsz,
                    )
                )
            return res
@@ -32,6 +34,8 @@ def collate(samples, pad_idx, eos_idx):
                pad_idx,
                eos_idx,
                left_pad=False,
+                pad_to_length=fixed_pad_length,
+                pad_to_bsz=pad_to_bsz,
            )

    src_tokens = merge("source")
@@ -75,6 +79,10 @@ class MonolingualDataset(FairseqDataset):
        shuffle=False,
        targets=None,
        add_bos_token=False,
+        fixed_pad_length=None,
+        pad_to_bsz=None,
+        src_lang_idx=None,
+        tgt_lang_idx=None,
    ):
        self.dataset = dataset
        self.sizes = np.array(sizes)
@@ -83,6 +91,10 @@ class MonolingualDataset(FairseqDataset):
        self.add_eos_for_other_targets = add_eos_for_other_targets
        self.shuffle = shuffle
        self.add_bos_token = add_bos_token
+        self.fixed_pad_length = fixed_pad_length
+        self.pad_to_bsz = pad_to_bsz
+        self.src_lang_idx = src_lang_idx
+        self.tgt_lang_idx = tgt_lang_idx

        assert targets is None or all(
            t in {"self", "future", "past"} for t in targets
@@ -165,6 +177,11 @@ class MonolingualDataset(FairseqDataset):
            target = torch.cat([target.new([self.tgt_vocab.bos()]), target])
        return source, target

+    def num_tokens_vec(self, indices):
+        """Return the number of tokens for a set of positions defined by indices.
+        This value is used to enforce ``--max-tokens`` during batching."""
+        return self.sizes[indices]
+
    def _filter_vocab(self, target):
        if len(self.tgt_vocab) != len(self.vocab):
@@ -200,7 +217,13 @@ class MonolingualDataset(FairseqDataset):
            target sentence of shape `(bsz, tgt_len)`. Padding will appear
            on the right.
        """
-        return collate(samples, self.vocab.pad(), self.vocab.eos())
+        return collate(
+            samples,
+            self.vocab.pad(),
+            self.vocab.eos(),
+            self.fixed_pad_length,
+            self.pad_to_bsz,
+        )

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to

View File

@@ -136,6 +136,11 @@ class LegacyDistributedDataParallel(nn.Module):
                    continue
                if param.grad is None:
                    param.grad = torch.zeros_like(param)
+
+                if hasattr(param, 'expert'):
+                    # Skip gradient sync for unshared parameters
+                    continue
+
                if param.grad.requires_grad:
                    raise RuntimeError(
                        "DistributedDataParallel only works "

View File

@@ -306,6 +306,9 @@ def distributed_init(cfg: FairseqConfig):
        model_part_number = get_model_parallel_rank()
        cfg.checkpoint.checkpoint_suffix += "-model_part-{0}".format(model_part_number)

+    if getattr(cfg.model, "base_layers", 0) > 0:
+        cfg.checkpoint.checkpoint_suffix = f"-rank-{cfg.distributed_training.distributed_rank}"
+
    return cfg.distributed_training.distributed_rank
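This is what makes the regex changes in checkpoint_utils above necessary: with base_layers > 0 every rank writes its own checkpoint shard, so the cleanup patterns must include the per-rank suffix. A quick check (rank number illustrative):

import re

suffix = "-rank-3"  # as set above for distributed rank 3
pattern = r"checkpoint_\d+_(\d+){}\.pt".format(suffix)
assert re.match(pattern, "checkpoint_2_5000-rank-3.pt")
assert not re.match(pattern, "checkpoint_2_5000.pt")  # unsuffixed names no longer match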

View File

@@ -19,6 +19,7 @@ from fairseq.models import (
)
from fairseq.modules import (
    AdaptiveSoftmax,
+    BaseLayer,
    FairseqDropout,
    LayerDropModuleList,
    LayerNorm,
@@ -751,6 +752,10 @@ class TransformerDecoder(FairseqIncrementalDecoder):
            nn.init.normal_(
                self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5
            )
+        num_base_layers = getattr(args, "base_layers", 0)
+        for i in range(num_base_layers):
+            self.layers.insert(((i+1) * args.decoder_layers) // (num_base_layers + 1), BaseLayer(args))
+
    def build_decoder_layer(self, args, no_encoder_attn=False):
        layer = TransformerDecoderLayer(args, no_encoder_attn)
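The insertion index spreads the BASE layers evenly through the decoder stack. For example (decoder depth illustrative):

# With 12 decoder layers and base_layers=2, BASE layers land at indices 4 and 8,
# i.e. one after every four transformer layers.
decoder_layers, num_base_layers = 12, 2
positions = [
    ((i + 1) * decoder_layers) // (num_base_layers + 1)
    for i in range(num_base_layers)
]
print(positions)  # [4, 8]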

View File

@@ -180,6 +180,16 @@ class TransformerLanguageModelConfig(FairseqDataclass):
            )
        }
    )
+    # config for "BASE Layers: Simplifying Training of Large, Sparse Models"
+    base_layers: Optional[int] = field(
+        default=0, metadata={"help": "number of BASE layers in total"}
+    )
+    base_sublayers: Optional[int] = field(
+        default=1, metadata={"help": "number of sublayers in each BASE layer"}
+    )
+    base_shuffle: Optional[int] = field(
+        default=1, metadata={"help": "shuffle tokens between workers before computing assignment"}
+    )
    # options from other parts of the config
    add_bos_token: bool = II("task.add_bos_token")
    tokens_per_sample: int = II("task.tokens_per_sample")
@@ -313,6 +323,10 @@ def base_lm_architecture(args):
    args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8)
    args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0)
+    args.base_layers = getattr(args, "base_layers", 0)
+    args.base_sublayers = getattr(args, "base_sublayers", 1)
+    args.base_shuffle = getattr(args, "base_shuffle", False)
+
    args.add_bos_token = getattr(args, "add_bos_token", False)
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False

View File

@@ -6,6 +6,7 @@
from .adaptive_input import AdaptiveInput
from .adaptive_softmax import AdaptiveSoftmax
+from .base_layer import BaseLayer
from .beamable_mm import BeamableMM
from .character_token_embedder import CharacterTokenEmbedder
from .conv_tbc import ConvTBC
@@ -39,6 +40,7 @@ from .vggblock import VGGBlock
__all__ = [
    "AdaptiveInput",
    "AdaptiveSoftmax",
+    "BaseLayer",
    "BeamableMM",
    "CharacterTokenEmbedder",
    "ConvTBC",

View File

@@ -0,0 +1,135 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn
import torch
import sys
from fairseq import utils
from fairseq.distributed import utils as distributed_utils
from fairseq.modules.layer_norm import LayerNorm


class BaseLayer(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.num_workers = distributed_utils.get_data_parallel_world_size()
        expert_centroids = torch.empty(self.num_workers, args.decoder_embed_dim)
        torch.nn.init.orthogonal_(expert_centroids, gain=0.1)
        self.register_parameter("expert_centroids", torch.nn.Parameter(expert_centroids))
        self.expert_network = nn.Sequential(*([BaseSublayer(args) for _ in range(args.base_sublayers)]))
        self.expert_id = distributed_utils.get_data_parallel_rank()
        self.shuffle = args.base_shuffle
        self.cpp = self.load_assignment()

        # Add a special attribute to the expert parameters, so we know not to sync their gradients
        for param in self.expert_network.parameters():
            param.expert = True

    def forward(self, input_features, *args, **kwargs):
        features = input_features.reshape(-1, input_features.size(-1))
        is_training = input_features.requires_grad

        if self.shuffle and is_training:
            # Send each token to a random worker, to break correlations within the batch
            shuffle_sort = torch.randperm(features.size(0), device=features.device)
            features = All2All.apply(features[shuffle_sort])

        with torch.no_grad():
            # Compute similarity of each token to each expert, for routing
            token_expert_affinities = features.matmul(self.expert_centroids.transpose(0, 1))

        # Compute which token goes to which expert
        sort_by_expert, input_splits, output_splits = self.balanced_assignment(token_expert_affinities) \
            if is_training else self.greedy_assignment(token_expert_affinities)
        # Swap these tokens for the right ones for our expert
        routed_features = All2All.apply(features[sort_by_expert], output_splits, input_splits)

        if routed_features.size(0) > 0:
            # Mix in the expert network based on how appropriate it is for these tokens
            alpha = torch.sigmoid(routed_features.mv(self.expert_centroids[self.expert_id])).unsqueeze(1)
            routed_features = alpha * self.expert_network(routed_features) + (1 - alpha) * routed_features
        # Return to original worker and ordering
        result = All2All.apply(routed_features, input_splits, output_splits)[self.inverse_sort(sort_by_expert)]

        if self.shuffle and is_training:
            # Undo shuffling
            result = All2All.apply(result)[self.inverse_sort(shuffle_sort)]

        # Return additional Nones for compatibility with TransformerDecoderLayer
        return result.view(input_features.size()), None, None

    def inverse_sort(self, order):
        # Creates an index that undoes a sort: xs==xs[order][inverse_sort(order)]
        return torch.empty_like(order).scatter_(0, order, torch.arange(0, order.size(0), device=order.device))

    def balanced_assignment(self, scores):
        ok = scores.isfinite()
        if not ok.all():
            # NaNs here can break the assignment algorithm
            scores[~ok] = scores[ok].min()
        return self.cpp.balanced_assignment(scores), None, None

    # Assigns each token to the top k experts
    def greedy_assignment(self, scores, k=1):
        token_to_workers = torch.topk(scores, dim=1, k=k, largest=True).indices.view(-1)
        token_to_workers, sort_ordering = torch.sort(token_to_workers)
        worker2token = sort_ordering // k

        # Find how many tokens we're sending to each other worker (being careful for sending 0 tokens to some workers)
        output_splits = torch.zeros((self.num_workers,), dtype=torch.long, device=scores.device)
        workers, counts = torch.unique_consecutive(token_to_workers, return_counts=True)
        output_splits[workers] = counts
        # Tell other workers how many tokens to expect from us
        input_splits = All2All.apply(output_splits)
        return worker2token, input_splits.tolist(), output_splits.tolist()

    def load_assignment(self):
        try:
            from fairseq import libbase

            return libbase
        except ImportError as e:
            sys.stderr.write(
                "ERROR: missing libbase. run `python setup.py build_ext --inplace`\n"
            )
            raise e


class BaseSublayer(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, 'activation_fn', 'relu') or "relu"
        )
        self.norm = LayerNorm(args.decoder_embed_dim, export=False)
        self.ff1 = torch.nn.Linear(args.decoder_embed_dim, args.decoder_ffn_embed_dim)
        self.ff2 = torch.nn.Linear(args.decoder_ffn_embed_dim, args.decoder_embed_dim)
        self.ff2.weight.data.zero_()

    def forward(self, xs):
        return xs + self.ff2(self.activation_fn(self.ff1(self.norm(xs))))


# Wraps torch.distributed.all_to_all_single as a function that supports autograd
class All2All(torch.autograd.Function):
    @staticmethod
    def forward(ctx, xs, input_splits=None, output_splits=None):
        ctx.input_splits = input_splits
        ctx.output_splits = output_splits

        ys = torch.empty_like(xs) if output_splits is None else \
            xs.new_empty(size=[sum(output_splits)] + list(xs.size()[1:]))
        torch.distributed.all_to_all_single(ys, xs, output_split_sizes=output_splits, input_split_sizes=input_splits)
        return ys

    @staticmethod
    def backward(ctx, grad_output):
        result = torch.empty_like(grad_output) if ctx.input_splits is None else \
            grad_output.new_empty(size=[sum(ctx.input_splits)] + list(grad_output.size()[1:]))
        torch.distributed.all_to_all_single(result, grad_output,
                                            output_split_sizes=ctx.input_splits, input_split_sizes=ctx.output_splits)
        return result, None, None
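A standalone check of the inverse_sort identity relied on above (xs == xs[order][inverse_sort(order)]); this sketch runs without any distributed setup:

import torch

xs = torch.tensor([3.0, 1.0, 2.0])
order = torch.argsort(xs)
inverse = torch.empty_like(order).scatter_(0, order, torch.arange(order.size(0)))
assert torch.equal(xs[order][inverse], xs)  # the sort has been undone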

View File

@@ -64,6 +64,8 @@ class _FP16OptimizerMixin(object):
            fp32_params = []
            for p in params:
                p32 = torch.nn.Parameter(p.data.float())
+                if hasattr(p, 'expert'):
+                    p32.expert = True
                p32.grad = torch.zeros_like(p32.data)
                if hasattr(p, "param_group"):
                    p32.param_group = p.param_group

View File

@@ -84,8 +84,17 @@ class LanguageModelingConfig(FairseqDataclass):
            'e.g., "train,valid" (default: all dataset splits)'
        },
    )
+    pad_to_fixed_length: Optional[bool] = field(
+        default=False, metadata={"help": "pad to fixed length"},
+    )
+    pad_to_fixed_bsz: Optional[bool] = field(
+        default=False, metadata={"help": "boolean to pad to fixed batch size"},
+    )
+
    # TODO common vars below add to parent
    seed: int = II("common.seed")
+    batch_size: Optional[int] = II("dataset.batch_size")
+    batch_size_valid: Optional[int] = II("dataset.batch_size_valid")
    dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II(
        "dataset.dataset_impl"
    )
@@ -232,6 +241,13 @@ class LanguageModelingTask(LegacyFairseqTask):
            self.args.sample_break_mode is not None
            and self.args.sample_break_mode != "none"
        )
+        fixed_pad_length = None
+        if self.args.pad_to_fixed_length:
+            fixed_pad_length = self.args.tokens_per_sample
+
+        pad_to_bsz = None
+        if self.args.pad_to_fixed_bsz:
+            pad_to_bsz = self.args.batch_size_valid if 'valid' in split else self.args.batch_size

        self.datasets[split] = MonolingualDataset(
            dataset=dataset,
@@ -242,6 +258,8 @@ class LanguageModelingTask(LegacyFairseqTask):
            shuffle=True,
            targets=self.targets,
            add_bos_token=self.args.add_bos_token,
+            fixed_pad_length=fixed_pad_length,
+            pad_to_bsz=pad_to_bsz,
        )

    def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs):
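A plausible reading of why these options land in this PR (inferred from base_layer.py and balanced_assignment.cpp above, not stated in the diff): balanced assignment divides the routed tokens among workers by integer division, so padding every batch to a fixed tokens_per_sample by batch_size shape keeps the token count per rank constant and cleanly divisible. Roughly:

# Illustrative numbers only.
tokens_per_sample, batch_size, num_workers = 512, 8, 16
num_jobs = tokens_per_sample * batch_size   # fixed per rank thanks to the padding options
jobs_per_worker = num_jobs // num_workers   # 256 tokens routed to each expert
assert num_jobs % num_workers == 0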

View File

@@ -195,7 +195,7 @@ class Trainer(object):
    @property
    def should_save_checkpoint_on_current_rank(self) -> bool:
        """Indicates whether to save checkpoints on the current DDP rank."""
-        if self.cfg.distributed_training.ddp_backend == "fully_sharded":
+        if self.cfg.distributed_training.ddp_backend == "fully_sharded" or getattr(self.cfg.model, "base_layers", 0) > 0:
            return True
        else:
            return self.is_data_parallel_master
@@ -415,6 +415,7 @@ class Trainer(object):
            or self.tpu
            # FSDP requires loading checkpoint shards on all ranks
            or self.cfg.distributed_training.ddp_backend == "fully_sharded"
+            or getattr(self.cfg.model, "base_layers", 0) > 0
        )

        if load_on_all_ranks or self.data_parallel_rank == 0:

View File

@@ -339,10 +339,14 @@ def multi_tensor_total_norm(grads, chunk_size=2048 * 32) -> torch.Tensor:
@torch.no_grad()
def clip_grad_norm_(params, max_norm, aggregate_norm_fn=None) -> torch.Tensor:
+    def grad_exists(p):
+        return p is not None and getattr(p, "grad", None) is not None
+
    if isinstance(params, torch.Tensor):
        params = [params]
    params = list(params)
-    grads = [p.grad.detach() for p in filter(lambda p: p.grad is not None, params)]
+    grads = [p.grad.detach() for p in params if grad_exists(p) and not hasattr(p, 'expert')]
+    expert_grads = [p.grad.detach() for p in params if grad_exists(p) and hasattr(p, 'expert')]
+
    if len(grads) == 0:
        if len(params) > 0:
            return params[0].new_tensor(0.0)
@@ -377,7 +381,7 @@ def clip_grad_norm_(params, max_norm, aggregate_norm_fn=None) -> torch.Tensor:
    if max_norm > 0:
        max_norm = float(max_norm)
        clip_coef = (max_norm / (total_norm + 1e-6)).clamp_(max=1)
-        for g in grads:
+        for g in grads + expert_grads:
            g.mul_(clip_coef)
    return total_norm
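The split above means the gradient norm is computed over shared parameters only (each rank holds different expert parameters, so a norm that included them would disagree across ranks), while the resulting clip coefficient still rescales the expert gradients. A single-process sketch:

import torch

shared_grads = [torch.tensor([3.0, 4.0])]  # norm 5
expert_grads = [torch.tensor([100.0])]     # excluded from the norm
total_norm = torch.norm(torch.stack([torch.norm(g) for g in shared_grads]))
max_norm = 1.0
clip_coef = (max_norm / (total_norm + 1e-6)).clamp_(max=1)
for g in shared_grads + expert_grads:
    g.mul_(clip_coef)  # expert grads are scaled by the same coefficient
assert abs(expert_grads[0].item() - 20.0) < 1e-3  # 100 * (1/5)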

View File

@@ -98,9 +98,16 @@ def main(cfg: FairseqConfig) -> None:
    logger.info("model: {}".format(model.__class__.__name__))
    logger.info("criterion: {}".format(criterion.__class__.__name__))
    logger.info(
-        "num. model params: {:,} (num. trained: {:,})".format(
-            sum(getattr(p, "_orig_size", p).numel() for p in model.parameters()),
-            sum(getattr(p, "_orig_size", p).numel() for p in model.parameters() if p.requires_grad),
+        "num. shared model params: {:,} (num. trained: {:,})".format(
+            sum(p.numel() for p in model.parameters() if not getattr(p, "expert", False)),
+            sum(p.numel() for p in model.parameters() if not getattr(p, "expert", False) and p.requires_grad)
        )
    )
+
+    logger.info(
+        "num. expert model params: {} (num. trained: {})".format(
+            sum(p.numel() for p in model.parameters() if getattr(p, "expert", False)),
+            sum(p.numel() for p in model.parameters() if getattr(p, "expert", False) and p.requires_grad),
+        )
+    )

View File

@@ -99,6 +99,17 @@ try:
    # torch is not available when generating docs
    from torch.utils import cpp_extension

+    extensions.extend(
+        [
+            cpp_extension.CppExtension(
+                "fairseq.libbase",
+                sources=[
+                    "fairseq/clib/libbase/balanced_assignment.cpp",
+                ],
+            )
+        ]
+    )
+
    extensions.extend(
        [
            cpp_extension.CppExtension(