From 7dafb05754fe268bb5f76a1c97cf3a14062f44e5 Mon Sep 17 00:00:00 2001
From: Michael Lewis
Date: Mon, 29 Mar 2021 18:02:07 -0700
Subject: [PATCH] BASE layers (#1654)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Summary:
# Before submitting

- [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?
Fixes # (issue).

## PR review
Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1654

Reviewed By: myleott

Differential Revision: D27128074

Pulled By: shruti-bh

fbshipit-source-id: ac86d383cd53c9c9bdd946fea839a37b719d95e3
---
 fairseq/checkpoint_utils.py                  |   8 +-
 fairseq/clib/libbase/balanced_assignment.cpp |  95 ++++++++++++
 fairseq/data/data_utils.py                   |   5 +-
 fairseq/data/monolingual_dataset.py          |  27 +++-
 .../legacy_distributed_data_parallel.py      |   5 +
 fairseq/distributed/utils.py                 |   3 +
 fairseq/models/transformer.py                |   5 +
 fairseq/models/transformer_lm.py             |  14 ++
 fairseq/modules/__init__.py                  |   2 +
 fairseq/modules/base_layer.py                | 135 ++++++++++++++++++
 fairseq/optim/fp16_optimizer.py              |   2 +
 fairseq/tasks/language_modeling.py           |  18 +++
 fairseq/trainer.py                           |   3 +-
 fairseq/utils.py                             |   8 +-
 fairseq_cli/train.py                         |  13 +-
 setup.py                                     |  11 ++
 16 files changed, 341 insertions(+), 13 deletions(-)
 create mode 100644 fairseq/clib/libbase/balanced_assignment.cpp
 create mode 100644 fairseq/modules/base_layer.py

diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py
index 86e00a771..7e1b8479d 100644
--- a/fairseq/checkpoint_utils.py
+++ b/fairseq/checkpoint_utils.py
@@ -115,7 +115,7 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
     if not end_of_epoch and cfg.keep_interval_updates > 0:
         # remove old checkpoints; checkpoints are sorted in descending order
         checkpoints = checkpoint_paths(
-            cfg.save_dir, pattern=r"checkpoint_\d+_(\d+)\.pt"
+            cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix)
         )
         for old_chk in checkpoints[cfg.keep_interval_updates :]:
             if os.path.lexists(old_chk):
@@ -123,7 +123,7 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
 
     if cfg.keep_last_epochs > 0:
         # remove old epoch checkpoints; checkpoints are sorted in descending order
-        checkpoints = checkpoint_paths(cfg.save_dir, pattern=r"checkpoint(\d+)\.pt")
+        checkpoints = checkpoint_paths(cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix))
         for old_chk in checkpoints[cfg.keep_last_epochs :]:
             if os.path.lexists(old_chk):
                 os.remove(old_chk)
@@ -132,8 +132,8 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
         # only keep the best N checkpoints according to validation metric
         checkpoints = checkpoint_paths(
             cfg.save_dir,
-            pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format(
-                cfg.best_checkpoint_metric
+            pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
+                cfg.best_checkpoint_metric, suffix
             ),
         )
         if not cfg.maximize_best_checkpoint_metric:
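Context for the hunks above (not part of the patch): when BASE layers are enabled, every data-parallel rank saves its own expert weights and a per-rank suffix is appended to checkpoint names (see the distributed_init and trainer changes later in this patch), so the cleanup regexes now have to include that suffix. A small sketch of the resulting pattern; the suffix value is illustrative, and re.match is used only to show what the pattern accepts:

    import re

    suffix = "-rank-3"  # e.g. what distributed_init sets when base_layers > 0
    pattern = r"checkpoint_\d+_(\d+){}\.pt".format(suffix)

    assert re.match(pattern, "checkpoint_5_10000-rank-3.pt")
    assert not re.match(pattern, "checkpoint_5_10000.pt")  # suffix-less names no longer match
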
diff --git a/fairseq/clib/libbase/balanced_assignment.cpp b/fairseq/clib/libbase/balanced_assignment.cpp
new file mode 100644
index 000000000..296f03b6a
--- /dev/null
+++ b/fairseq/clib/libbase/balanced_assignment.cpp
@@ -0,0 +1,95 @@
+/**
+ * Copyright 2017-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+/*
+C++ code for solving the linear assignment problem.
+Based on the Auction Algorithm from https://dspace.mit.edu/bitstream/handle/1721.1/3265/P-2108-26912652.pdf
+and the implementation from: https://github.com/bkj/auction-lap
+Adapted to be more efficient when each worker is looking for k jobs instead of 1.
+*/
+#include <torch/extension.h>
+#include <torch/torch.h>
+using namespace torch::indexing;
+torch::Tensor balanced_assignment(torch::Tensor job_and_worker_to_score) {
+  int max_iterations = 100;
+  torch::Tensor epsilon = (job_and_worker_to_score.max() - job_and_worker_to_score.min()) / 50;
+  epsilon.clamp_min_(1e-04);
+  torch::Tensor worker_and_job_to_score = job_and_worker_to_score.detach().transpose(0, 1).contiguous();
+  int num_workers = worker_and_job_to_score.size(0);
+  int num_jobs = worker_and_job_to_score.size(1);
+  auto device = worker_and_job_to_score.device();
+  int jobs_per_worker = num_jobs / num_workers;
+  torch::Tensor value = worker_and_job_to_score.clone();
+  int counter = 0;
+  torch::Tensor max_value = worker_and_job_to_score.max();
+
+  torch::Tensor bid_indices;
+  torch::Tensor cost = worker_and_job_to_score.new_zeros({1, num_jobs});
+  torch::Tensor bids = worker_and_job_to_score.new_empty({num_workers, num_jobs});
+  torch::Tensor bid_increments = worker_and_job_to_score.new_empty({num_workers, jobs_per_worker});
+  torch::Tensor top_values = worker_and_job_to_score.new_empty({num_workers, jobs_per_worker + 1});
+  torch::Tensor high_bids = worker_and_job_to_score.new_empty({num_jobs});
+
+  torch::Tensor top_index = top_values.to(torch::kLong);
+  torch::Tensor high_bidders = top_index.new_empty({num_jobs});
+  torch::Tensor have_bids = high_bidders.to(torch::kBool);
+  torch::Tensor jobs_indices = torch::arange({num_jobs}, torch::dtype(torch::kLong).device(device));
+  torch::Tensor true_tensor = torch::ones({1}, torch::dtype(torch::kBool).device(device));
+
+  while (true) {
+    bids.zero_();
+    torch::topk_out(top_values, top_index, value, jobs_per_worker + 1, 1);
+
+    // Each worker bids the difference in value between that job and the k+1th job
+    torch::sub_out(bid_increments,
+                   top_values.index({Slice(None, None), Slice(0, jobs_per_worker)}),
+                   top_values.index({Slice(None, None), jobs_per_worker}).unsqueeze(1));
+
+    bid_increments.add_(epsilon);
+    bids.scatter_(1,
+                  top_index.index({Slice(None, None), Slice(0, jobs_per_worker)}),
+                  bid_increments);
+
+    if (counter < max_iterations && counter > 0) {
+      // Put in a minimal bid to retain items from the last round if no-one else bids for them this round
+      bids.view(-1).index_put_({bid_indices}, epsilon);
+    }
+
+    // Find the highest bidding worker per job
+    torch::max_out(high_bids, high_bidders, bids, 0);
+    torch::gt_out(have_bids, high_bids, 0);
+
+    if (have_bids.all().item<bool>()) {
+      // All jobs were bid for
+      break;
+    }
+
+    // Make popular items more expensive
+    cost.add_(high_bids);
+    torch::sub_out(value, worker_and_job_to_score, cost);
+
+    bid_indices = ((high_bidders * num_jobs) + jobs_indices).index({have_bids});
+
+    if (counter < max_iterations) {
+      // Make sure that this item will be in the winning worker's top-k next time.
+      value.view(-1).index_put_({bid_indices}, max_value);
+    }
+    else {
+      // Suboptimal approximation that converges quickly from current solution
+      value.view(-1).index_put_({bid_indices}, worker_and_job_to_score.view(-1).index({bid_indices}));
+    }
+
+    counter += 1;
+  }
+
+  return top_index.index({Slice(None, None), Slice(0, jobs_per_worker)}).reshape(-1);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("balanced_assignment", &balanced_assignment, "Balanced Assignment");
+}
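For reference (not part of the patch): once the extension is built, for example with `python setup.py build_ext --inplace` as suggested by the error message in base_layer.py below, the binding can be exercised on its own. A minimal sketch with illustrative sizes; per the C++ above, the input is a (num_jobs, num_workers) score matrix and the number of jobs must be a multiple of the number of workers:

    import torch
    from fairseq import libbase  # compiled from the .cpp file above

    num_workers, jobs_per_worker = 4, 3
    scores = torch.randn(num_workers * jobs_per_worker, num_workers)  # (num_jobs, num_workers)

    assignment = libbase.balanced_assignment(scores)
    # Job indices grouped by worker: the first jobs_per_worker entries belong to
    # worker 0, the next jobs_per_worker to worker 1, and so on.
    print(assignment.view(num_workers, jobs_per_worker))
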
diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py
index 79df6f376..63c7fcd11 100644
--- a/fairseq/data/data_utils.py
+++ b/fairseq/data/data_utils.py
@@ -41,13 +41,16 @@ def collate_tokens(
     move_eos_to_beginning=False,
     pad_to_length=None,
     pad_to_multiple=1,
+    pad_to_bsz=None,
 ):
     """Convert a list of 1d tensors into a padded 2d tensor."""
     size = max(v.size(0) for v in values)
     size = size if pad_to_length is None else max(size, pad_to_length)
     if pad_to_multiple != 1 and size % pad_to_multiple != 0:
         size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple)
-    res = values[0].new(len(values), size).fill_(pad_idx)
+
+    batch_size = len(values) if pad_to_bsz is None else max(len(values), pad_to_bsz)
+    res = values[0].new(batch_size, size).fill_(pad_idx)
 
     def copy_tensor(src, dst):
         assert dst.numel() == src.numel()
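A small sketch of the new padding behaviour, assuming only what the hunk above adds: pad_to_bsz pads the batch dimension up to a fixed batch size, on top of the existing pad_to_length padding of the time dimension.

    import torch
    from fairseq.data.data_utils import collate_tokens

    values = [torch.tensor([5, 6, 2]), torch.tensor([7, 2])]
    batch = collate_tokens(values, pad_idx=1, pad_to_length=8, pad_to_bsz=4)
    # batch.shape == (4, 8): two extra all-padding rows from pad_to_bsz, and
    # every row right-padded with pad_idx up to pad_to_length.
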
""" - return collate(samples, self.vocab.pad(), self.vocab.eos()) + return collate( + samples, + self.vocab.pad(), + self.vocab.eos(), + self.fixed_pad_length, + self.pad_to_bsz, + ) def num_tokens(self, index): """Return the number of tokens in a sample. This value is used to diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index b586e76b7..f2308f87c 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -136,6 +136,11 @@ class LegacyDistributedDataParallel(nn.Module): continue if param.grad is None: param.grad = torch.zeros_like(param) + + if hasattr(param, 'expert'): + # Skip gradient sync for unshared parameters + continue + if param.grad.requires_grad: raise RuntimeError( "DistributedDataParallel only works " diff --git a/fairseq/distributed/utils.py b/fairseq/distributed/utils.py index 970b78491..b09e87fe0 100644 --- a/fairseq/distributed/utils.py +++ b/fairseq/distributed/utils.py @@ -306,6 +306,9 @@ def distributed_init(cfg: FairseqConfig): model_part_number = get_model_parallel_rank() cfg.checkpoint.checkpoint_suffix += "-model_part-{0}".format(model_part_number) + if getattr(cfg.model, "base_layers", 0) > 0: + cfg.checkpoint.checkpoint_suffix = f"-rank-{cfg.distributed_training.distributed_rank}" + return cfg.distributed_training.distributed_rank diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index 1e47d102f..8da5beb3a 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -19,6 +19,7 @@ from fairseq.models import ( ) from fairseq.modules import ( AdaptiveSoftmax, + BaseLayer, FairseqDropout, LayerDropModuleList, LayerNorm, @@ -751,6 +752,10 @@ class TransformerDecoder(FairseqIncrementalDecoder): nn.init.normal_( self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5 ) + num_base_layers = getattr(args, "base_layers", 0) + for i in range(num_base_layers): + self.layers.insert(((i+1) * args.decoder_layers) // (num_base_layers + 1), BaseLayer(args)) + def build_decoder_layer(self, args, no_encoder_attn=False): layer = TransformerDecoderLayer(args, no_encoder_attn) diff --git a/fairseq/models/transformer_lm.py b/fairseq/models/transformer_lm.py index b616a923d..a54677691 100644 --- a/fairseq/models/transformer_lm.py +++ b/fairseq/models/transformer_lm.py @@ -180,6 +180,16 @@ class TransformerLanguageModelConfig(FairseqDataclass): ) } ) + # config for "BASE Layers: Simplifying Training of Large, Sparse Models" + base_layers: Optional[int] = field( + default=0, metadata={"help": "number of BASE layers in total"} + ) + base_sublayers: Optional[int] = field( + default=1, metadata={"help": "number of sublayers in each BASE layer"} + ) + base_shuffle: Optional[int] = field( + default=1, metadata={"help": "shuffle tokens between workers before computing assignment"} + ) # options from other parts of the config add_bos_token: bool = II("task.add_bos_token") tokens_per_sample: int = II("task.tokens_per_sample") @@ -313,6 +323,10 @@ def base_lm_architecture(args): args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8) args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0) + args.base_layers = getattr(args, "base_layers", 0) + args.base_sublayers = getattr(args, "base_sublayers", 1) + args.base_shuffle = getattr(args, "base_shuffle", False) + args.add_bos_token = getattr(args, "add_bos_token", False) args.no_token_positional_embeddings 
diff --git a/fairseq/models/transformer_lm.py b/fairseq/models/transformer_lm.py
index b616a923d..a54677691 100644
--- a/fairseq/models/transformer_lm.py
+++ b/fairseq/models/transformer_lm.py
@@ -180,6 +180,16 @@ class TransformerLanguageModelConfig(FairseqDataclass):
             )
         }
     )
+    # config for "BASE Layers: Simplifying Training of Large, Sparse Models"
+    base_layers: Optional[int] = field(
+        default=0, metadata={"help": "number of BASE layers in total"}
+    )
+    base_sublayers: Optional[int] = field(
+        default=1, metadata={"help": "number of sublayers in each BASE layer"}
+    )
+    base_shuffle: Optional[int] = field(
+        default=1, metadata={"help": "shuffle tokens between workers before computing assignment"}
+    )
     # options from other parts of the config
     add_bos_token: bool = II("task.add_bos_token")
     tokens_per_sample: int = II("task.tokens_per_sample")
@@ -313,6 +323,10 @@ def base_lm_architecture(args):
     args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8)
     args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0)
 
+    args.base_layers = getattr(args, "base_layers", 0)
+    args.base_sublayers = getattr(args, "base_sublayers", 1)
+    args.base_shuffle = getattr(args, "base_shuffle", False)
+
     args.add_bos_token = getattr(args, "add_bos_token", False)
     args.no_token_positional_embeddings = getattr(
         args, "no_token_positional_embeddings", False
diff --git a/fairseq/modules/__init__.py b/fairseq/modules/__init__.py
index e2326ac6e..81930aa71 100644
--- a/fairseq/modules/__init__.py
+++ b/fairseq/modules/__init__.py
@@ -6,6 +6,7 @@
 
 from .adaptive_input import AdaptiveInput
 from .adaptive_softmax import AdaptiveSoftmax
+from .base_layer import BaseLayer
 from .beamable_mm import BeamableMM
 from .character_token_embedder import CharacterTokenEmbedder
 from .conv_tbc import ConvTBC
@@ -39,6 +40,7 @@ from .vggblock import VGGBlock
 __all__ = [
     "AdaptiveInput",
     "AdaptiveSoftmax",
+    "BaseLayer",
     "BeamableMM",
     "CharacterTokenEmbedder",
     "ConvTBC",
diff --git a/fairseq/modules/base_layer.py b/fairseq/modules/base_layer.py
new file mode 100644
index 000000000..e7ef155b2
--- /dev/null
+++ b/fairseq/modules/base_layer.py
@@ -0,0 +1,135 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch.nn as nn
+import torch
+import sys
+from fairseq import utils
+from fairseq.distributed import utils as distributed_utils
+from fairseq.modules.layer_norm import LayerNorm
+
+
+class BaseLayer(nn.Module):
+
+    def __init__(self, args):
+        super().__init__()
+        self.num_workers = distributed_utils.get_data_parallel_world_size()
+        expert_centroids = torch.empty(self.num_workers, args.decoder_embed_dim)
+        torch.nn.init.orthogonal_(expert_centroids, gain=0.1)
+        self.register_parameter("expert_centroids", torch.nn.Parameter(expert_centroids))
+        self.expert_network = nn.Sequential(*([BaseSublayer(args) for _ in range(args.base_sublayers)]))
+        self.expert_id = distributed_utils.get_data_parallel_rank()
+        self.shuffle = args.base_shuffle
+        self.cpp = self.load_assignment()
+
+        # Add a special attribute to the expert parameters, so we know not to sync their gradients
+        for param in self.expert_network.parameters():
+            param.expert = True
+
+    def forward(self, input_features, *args, **kwargs):
+        features = input_features.reshape(-1, input_features.size(-1))
+        is_training = input_features.requires_grad
+
+        if self.shuffle and is_training:
+            # Send each token to a random worker, to break correlations within the batch
+            shuffle_sort = torch.randperm(features.size(0), device=features.device)
+            features = All2All.apply(features[shuffle_sort])
+
+        with torch.no_grad():
+            # Compute similarity of each token to each expert, for routing
+            token_expert_affinities = features.matmul(self.expert_centroids.transpose(0, 1))
+
+        # Compute which token goes to which expert
+        sort_by_expert, input_splits, output_splits = self.balanced_assignment(token_expert_affinities) \
+            if is_training else self.greedy_assignment(token_expert_affinities)
+        # Swap these tokens for the right ones for our expert
+        routed_features = All2All.apply(features[sort_by_expert], output_splits, input_splits)
+
+        if routed_features.size(0) > 0:
+            # Mix in the expert network based on how appropriate it is for these tokens
+            alpha = torch.sigmoid(routed_features.mv(self.expert_centroids[self.expert_id])).unsqueeze(1)
+            routed_features = alpha * self.expert_network(routed_features) + (1 - alpha) * routed_features
+        # Return to original worker and ordering
+        result = All2All.apply(routed_features, input_splits, output_splits)[self.inverse_sort(sort_by_expert)]
+
+        if self.shuffle and is_training:
+            # Undo shuffling
+            result = All2All.apply(result)[self.inverse_sort(shuffle_sort)]
+
+        # Return additional Nones for compatibility with TransformerDecoderLayer
+        return result.view(input_features.size()), None, None
+
+    def inverse_sort(self, order):
+        # Creates an index that undoes a sort: xs==xs[order][inverse_sort(order)]
+        return torch.empty_like(order).scatter_(0, order, torch.arange(0, order.size(0), device=order.device))
+
+    def balanced_assignment(self, scores):
+        ok = scores.isfinite()
+        if not ok.all():
+            # NaNs here can break the assignment algorithm
+            scores[~ok] = scores[ok].min()
+        return self.cpp.balanced_assignment(scores), None, None
+
+    # Assigns each token to the top k experts
+    def greedy_assignment(self, scores, k=1):
+        token_to_workers = torch.topk(scores, dim=1, k=k, largest=True).indices.view(-1)
+        token_to_workers, sort_ordering = torch.sort(token_to_workers)
+        worker2token = sort_ordering // k
+
+        # Find how many tokens we're sending to each other worker (being careful for sending 0 tokens to some workers)
+        output_splits = torch.zeros((self.num_workers,), dtype=torch.long, device=scores.device)
+        workers, counts = torch.unique_consecutive(token_to_workers, return_counts=True)
+        output_splits[workers] = counts
+        # Tell other workers how many tokens to expect from us
+        input_splits = All2All.apply(output_splits)
+        return worker2token, input_splits.tolist(), output_splits.tolist()
+
+    def load_assignment(self):
+        try:
+            from fairseq import libbase
+
+            return libbase
+
+        except ImportError as e:
+            sys.stderr.write(
+                "ERROR: missing libbase. run `python setup.py build_ext --inplace`\n"
+            )
+            raise e
+
+
+class BaseSublayer(nn.Module):
+    def __init__(self, args):
+        super().__init__()
+        self.activation_fn = utils.get_activation_fn(
+            activation=getattr(args, 'activation_fn', 'relu') or "relu"
+        )
+        self.norm = LayerNorm(args.decoder_embed_dim, export=False)
+        self.ff1 = torch.nn.Linear(args.decoder_embed_dim, args.decoder_ffn_embed_dim)
+        self.ff2 = torch.nn.Linear(args.decoder_ffn_embed_dim, args.decoder_embed_dim)
+        self.ff2.weight.data.zero_()
+
+    def forward(self, xs):
+        return xs + self.ff2(self.activation_fn(self.ff1(self.norm(xs))))
+
+
+# Wraps torch.distributed.all_to_all_single as a function that supports autograd
+class All2All(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, xs, input_splits=None, output_splits=None):
+        ctx.input_splits = input_splits
+        ctx.output_splits = output_splits
+
+        ys = torch.empty_like(xs) if output_splits is None else \
+            xs.new_empty(size=[sum(output_splits)] + list(xs.size()[1:]))
+        torch.distributed.all_to_all_single(ys, xs, output_split_sizes=output_splits, input_split_sizes=input_splits)
+        return ys
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        result = torch.empty_like(grad_output) if ctx.input_splits is None else \
+            grad_output.new_empty(size=[sum(ctx.input_splits)] + list(grad_output.size()[1:]))
+        torch.distributed.all_to_all_single(result, grad_output,
+                                            output_split_sizes=ctx.input_splits, input_split_sizes=ctx.output_splits)
+        return result, None, None
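The inverse_sort comment above states the identity it relies on; a quick self-contained check of that identity (not part of the patch):

    import torch

    order = torch.randperm(5)
    inverse = torch.empty_like(order).scatter_(0, order, torch.arange(0, 5))
    xs = torch.randn(5)
    assert torch.equal(xs, xs[order][inverse])  # xs == xs[order][inverse_sort(order)]
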
diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py
index 00ea1bbb7..370a91010 100644
--- a/fairseq/optim/fp16_optimizer.py
+++ b/fairseq/optim/fp16_optimizer.py
@@ -64,6 +64,8 @@ class _FP16OptimizerMixin(object):
                 fp32_params = []
                 for p in params:
                     p32 = torch.nn.Parameter(p.data.float())
+                    if hasattr(p, 'expert'):
+                        p32.expert = True
                     p32.grad = torch.zeros_like(p32.data)
                     if hasattr(p, "param_group"):
                         p32.param_group = p.param_group
diff --git a/fairseq/tasks/language_modeling.py b/fairseq/tasks/language_modeling.py
index a3847733a..3069490fd 100644
--- a/fairseq/tasks/language_modeling.py
+++ b/fairseq/tasks/language_modeling.py
@@ -84,8 +84,17 @@ class LanguageModelingConfig(FairseqDataclass):
             'e.g., "train,valid" (default: all dataset splits)'
         },
     )
+    pad_to_fixed_length: Optional[bool] = field(
+        default=False, metadata={"help": "pad to fixed length"},
+    )
+    pad_to_fixed_bsz: Optional[bool] = field(
+        default=False, metadata={"help": "boolean to pad to fixed batch size"},
+    )
+
     # TODO common vars below add to parent
     seed: int = II("common.seed")
+    batch_size: Optional[int] = II("dataset.batch_size")
+    batch_size_valid: Optional[int] = II("dataset.batch_size_valid")
     dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II(
         "dataset.dataset_impl"
     )
@@ -232,6 +241,13 @@ class LanguageModelingTask(LegacyFairseqTask):
             self.args.sample_break_mode is not None
             and self.args.sample_break_mode != "none"
         )
+        fixed_pad_length = None
+        if self.args.pad_to_fixed_length:
+            fixed_pad_length = self.args.tokens_per_sample
+
+        pad_to_bsz = None
+        if self.args.pad_to_fixed_bsz:
+            pad_to_bsz = self.args.batch_size_valid if 'valid' in split else self.args.batch_size
 
         self.datasets[split] = MonolingualDataset(
             dataset=dataset,
@@ -242,6 +258,8 @@ class LanguageModelingTask(LegacyFairseqTask):
             shuffle=True,
             targets=self.targets,
             add_bos_token=self.args.add_bos_token,
+            fixed_pad_length=fixed_pad_length,
+            pad_to_bsz=pad_to_bsz,
         )
 
     def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs):
diff --git a/fairseq/trainer.py b/fairseq/trainer.py
index 4535e9bda..6195afb4a 100644
--- a/fairseq/trainer.py
+++ b/fairseq/trainer.py
@@ -195,7 +195,7 @@ class Trainer(object):
     @property
     def should_save_checkpoint_on_current_rank(self) -> bool:
         """Indicates whether to save checkpoints on the current DDP rank."""
-        if self.cfg.distributed_training.ddp_backend == "fully_sharded":
+        if self.cfg.distributed_training.ddp_backend == "fully_sharded" or getattr(self.cfg.model, "base_layers", 0) > 0:
             return True
         else:
             return self.is_data_parallel_master
@@ -415,6 +415,7 @@ class Trainer(object):
             or self.tpu
             # FSDP requires loading checkpoint shards on all ranks
             or self.cfg.distributed_training.ddp_backend == "fully_sharded"
+            or getattr(self.cfg.model, "base_layers", 0) > 0
         )
 
         if load_on_all_ranks or self.data_parallel_rank == 0:
diff --git a/fairseq/utils.py b/fairseq/utils.py
index 90bb8369f..03826d18d 100644
--- a/fairseq/utils.py
+++ b/fairseq/utils.py
@@ -339,10 +339,14 @@ def multi_tensor_total_norm(grads, chunk_size=2048 * 32) -> torch.Tensor:
 
 @torch.no_grad()
 def clip_grad_norm_(params, max_norm, aggregate_norm_fn=None) -> torch.Tensor:
+    def grad_exists(p):
+        return p is not None and getattr(p, "grad", None) is not None
     if isinstance(params, torch.Tensor):
         params = [params]
     params = list(params)
-    grads = [p.grad.detach() for p in filter(lambda p: p.grad is not None, params)]
+    grads = [p.grad.detach() for p in params if grad_exists(p) and not hasattr(p, 'expert')]
+    expert_grads = [p.grad.detach() for p in params if grad_exists(p) and hasattr(p, 'expert')]
+
     if len(grads) == 0:
         if len(params) > 0:
             return params[0].new_tensor(0.0)
@@ -377,7 +381,7 @@ def clip_grad_norm_(params, max_norm, aggregate_norm_fn=None) -> torch.Tensor:
     if max_norm > 0:
         max_norm = float(max_norm)
         clip_coef = (max_norm / (total_norm + 1e-6)).clamp_(max=1)
-        for g in grads:
+        for g in grads + expert_grads:
             g.mul_(clip_coef)
     return total_norm
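A small sketch of the behaviour the clip_grad_norm_ hunks introduce, using illustrative tensors: parameters tagged with the expert attribute are left out of the global norm (each rank holds different expert weights, so including them would presumably make the norm rank-dependent), but they are still scaled by the same clip coefficient.

    import torch
    from fairseq.utils import clip_grad_norm_

    shared = torch.nn.Parameter(torch.ones(4)); shared.grad = torch.ones(4)
    expert = torch.nn.Parameter(torch.ones(4)); expert.grad = 100 * torch.ones(4)
    expert.expert = True  # same marker BaseLayer puts on its expert parameters

    total_norm = clip_grad_norm_([shared, expert], max_norm=1.0)
    # total_norm reflects only the shared gradient (2.0 here); both gradients
    # are then multiplied by the same clip coefficient.
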
diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py
index 6924dfe5c..1cca64d98 100644
--- a/fairseq_cli/train.py
+++ b/fairseq_cli/train.py
@@ -98,9 +98,16 @@ def main(cfg: FairseqConfig) -> None:
     logger.info("model: {}".format(model.__class__.__name__))
     logger.info("criterion: {}".format(criterion.__class__.__name__))
     logger.info(
-        "num. model params: {:,} (num. trained: {:,})".format(
-            sum(getattr(p, "_orig_size", p).numel() for p in model.parameters()),
-            sum(getattr(p, "_orig_size", p).numel() for p in model.parameters() if p.requires_grad),
+        "num. shared model params: {:,} (num. trained: {:,})".format(
+            sum(p.numel() for p in model.parameters() if not getattr(p, "expert", False)),
+            sum(p.numel() for p in model.parameters() if not getattr(p, "expert", False) and p.requires_grad)
+        )
+    )
+
+    logger.info(
+        "num. expert model params: {} (num. trained: {})".format(
+            sum(p.numel() for p in model.parameters() if getattr(p, "expert", False)),
+            sum(p.numel() for p in model.parameters() if getattr(p, "expert", False) and p.requires_grad),
         )
     )
diff --git a/setup.py b/setup.py
index 3670ff3cf..51e555229 100644
--- a/setup.py
+++ b/setup.py
@@ -99,6 +99,17 @@ try:
     # torch is not available when generating docs
     from torch.utils import cpp_extension
 
+    extensions.extend(
+        [
+            cpp_extension.CppExtension(
+                "fairseq.libbase",
+                sources=[
+                    "fairseq/clib/libbase/balanced_assignment.cpp",
+                ],
+            )
+        ]
+    )
+
     extensions.extend(
         [
             cpp_extension.CppExtension(