diff --git a/fairseq/models/transformer/transformer_encoder.py b/fairseq/models/transformer/transformer_encoder.py index 0b7e6d837..c887c5afe 100644 --- a/fairseq/models/transformer/transformer_encoder.py +++ b/fairseq/models/transformer/transformer_encoder.py @@ -8,23 +8,22 @@ from typing import Dict, List, Optional import torch import torch.nn as nn +from torch import Tensor + from fairseq import utils from fairseq.distributed import fsdp_wrap from fairseq.models import FairseqEncoder +from fairseq.models.transformer import TransformerConfig from fairseq.modules import ( FairseqDropout, LayerDropModuleList, LayerNorm, PositionalEmbedding, SinusoidalPositionalEmbedding, + transformer_layer, ) -from fairseq.modules import transformer_layer from fairseq.modules.checkpoint_activations import checkpoint_wrapper from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_ -from torch import Tensor -from fairseq.models.transformer import ( - TransformerConfig, -) # rewrite name for backward compatibility in `make_generation_fast_` @@ -220,11 +219,79 @@ class TransformerEncoderBase(FairseqEncoder): if return_all_hiddens: encoder_states.append(x) - # encoder layers - for layer in self.layers: - lr = layer( - x, encoder_padding_mask=encoder_padding_mask if has_pads else None + # nested tensor and BT enable + layer = self.layers[0] + BT_flag = False + NT_flag = False + # torch version check, BT>=1.12.0 and NT>=1.13.0.dev20220613 + # internal format is '1.13.0a0+fb' + # external format is '1.13.0.dev20220613'(cpu&gpu) for nightly or "1.11.0"(cpu) or '1.11.0+cu102'(gpu) for stable + BT_version = False + NT_version = False + if "fb" in torch.__version__: + BT_version = True + NT_version = True + else: + if "+" in torch.__version__: + torch_version = torch.__version__.split("+")[0] + else: + torch_version = torch.__version__ + + torch_version = torch_version.split(".") + int_version = ( + int(torch_version[0]) * 1000 + + int(torch_version[1]) * 10 + + int(torch_version[2]) ) + if len(torch_version) == 3: + if int_version >= 1120: + BT_version = True + if int_version >= 1131: + NT_version = True + elif len(torch_version) == 4: + if int_version >= 1130: + BT_version = True + # Consider _nested_tensor_from_mask_left_aligned is landed after "20220613" + if int_version >= 1131 or ( + int_version == 1130 and torch_version[3][3:] >= "20220613" + ): + NT_version = True + + if ( + BT_version + and x.dim() == 3 + and layer.load_to_BT + and not layer.return_fc + and layer.can_use_fastpath + and not layer.training + and not layer.ever_training + and not layer.cfg_checkpoint_activations + ): + # Batch first can not be justified but needs user to make sure + x = x.transpose(0, 1) + # Check mask conditions for nested tensor + if NT_version: + if ( + encoder_padding_mask is not None + and torch._nested_tensor_from_mask_left_aligned( + x, encoder_padding_mask.logical_not() + ) + ): + if not torch.is_grad_enabled() or not x.requires_grad: + x = torch._nested_tensor_from_mask( + x, encoder_padding_mask.logical_not() + ) + NT_flag = True + BT_flag = True + + # encoder layers + if NT_flag: + processing_mask = None + else: + processing_mask = encoder_padding_mask + encoder_padding_mask_out = processing_mask if has_pads else None + for layer in self.layers: + lr = layer(x, encoder_padding_mask=encoder_padding_mask_out) if isinstance(lr, tuple) and len(lr) == 2: x, fc_result = lr @@ -237,6 +304,13 @@ class TransformerEncoderBase(FairseqEncoder): encoder_states.append(x) fc_results.append(fc_result) + # change 
back to non-nested and Batch second + if NT_flag: + x = x.to_padded_tensor(0.0) + + if NT_flag or BT_flag: + x = x.transpose(0, 1) + if self.layer_norm is not None: x = self.layer_norm(x) diff --git a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py index 2e687b94d..6589a5075 100644 --- a/fairseq/modules/transformer_layer.py +++ b/fairseq/modules/transformer_layer.py @@ -7,14 +7,13 @@ from typing import Dict, List, Optional import torch import torch.nn as nn +from torch import Tensor + from fairseq import utils +from fairseq.models.transformer import TransformerConfig from fairseq.modules import LayerNorm, MultiheadAttention from fairseq.modules.fairseq_dropout import FairseqDropout from fairseq.modules.quant_noise import quant_noise -from torch import Tensor -from fairseq.models.transformer import ( - TransformerConfig, -) class TransformerEncoderLayerBase(nn.Module): @@ -68,6 +67,103 @@ class TransformerEncoderLayerBase(nn.Module): self.final_layer_norm = LayerNorm(self.embed_dim, export=cfg.export) + self.num_heads = cfg.encoder.attention_heads + self.load_to_BT = False + self.ever_training = False + # For BT, we need continuous mem + self.in_proj_weight = torch.nn.Parameter( + torch.zeros( + self.self_attn.q_proj.weight.shape[0] * 3, + self.self_attn.q_proj.weight.shape[1], + ) + ) + self.in_proj_bias = torch.nn.Parameter( + torch.zeros(self.self_attn.q_proj.bias.shape[0] * 3) + ) + + if ( + self.activation_fn is torch.nn.functional.relu + or isinstance(self.activation_fn, torch.nn.ReLU) + or self.activation_fn == "relu" + ): + self.activation_relu_or_gelu = 1 + elif ( + self.activation_fn is torch.nn.functional.gelu + or isinstance(self.activation_fn, torch.nn.GELU) + or self.activation_fn == "gelu" + ): + self.activation_relu_or_gelu = 2 + else: + self.activation_relu_or_gelu = 0 + # Batch first can not be justified but needs user to make sure + self.can_use_fastpath = ( + not self.normalize_before + and self.activation_relu_or_gelu + and (self.self_attn_layer_norm.eps == self.final_layer_norm.eps) + ) + self.cfg_checkpoint_activations = self.cfg.checkpoint_activations + # torch version check + # make sure BT version is >=1.12.0 + self.BT_version = False + if "fb" in torch.__version__: + self.BT_version = True + else: + if "+" in torch.__version__: + self.torch_version = torch.__version__.split("+")[0] + else: + self.torch_version = torch.__version__ + + self.torch_version = self.torch_version.split(".") + self.int_version = ( + int(self.torch_version[0]) * 1000 + + int(self.torch_version[1]) * 10 + + int(self.torch_version[2]) + ) + if len(self.torch_version) == 3: + if self.int_version >= 1120: + self.BT_version = True + elif len(self.torch_version) == 4: + if self.int_version >= 1130: + self.BT_version = True + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + self.load_to_BT = True + + old_name = prefix + "self_attn." 
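+        # BT fastpath note: torch._transformer_encoder_layer_fwd consumes a single
+        # packed in-projection, so the separate q/k/v projection tensors stored in
+        # the checkpoint are read out here and concatenated below into the
+        # contiguous in_proj_weight / in_proj_bias parameters created in __init__.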
+ q_proj_weight = state_dict[old_name + "q_proj.weight"] + k_proj_weight = state_dict[old_name + "k_proj.weight"] + v_proj_weight = state_dict[old_name + "v_proj.weight"] + q_proj_bias = state_dict[old_name + "q_proj.bias"] + k_proj_bias = state_dict[old_name + "k_proj.bias"] + v_proj_bias = state_dict[old_name + "v_proj.bias"] + + new_name = prefix + state_dict[new_name + "in_proj_weight"] = torch.cat( + (q_proj_weight, k_proj_weight, v_proj_weight), dim=0 + ) + state_dict[new_name + "in_proj_bias"] = torch.cat( + (q_proj_bias, k_proj_bias, v_proj_bias), dim=0 + ) + + super(TransformerEncoderLayerBase, self)._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size): return quant_noise( nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size @@ -187,44 +283,83 @@ class TransformerEncoderLayerBase(nn.Module): # Note that we cannot use -inf here, because at some edge cases, # the attention weight (before softmax) for some padded element in query # will become -inf, which results in NaN in model parameters - if attn_mask is not None: - attn_mask = attn_mask.masked_fill( - attn_mask.to(torch.bool), -1e8 if x.dtype == torch.float32 else -1e4 + + if self.training: + self.ever_training = True + + if ( + self.BT_version + and x.dim() == 3 + and self.load_to_BT + and not self.return_fc + and self.can_use_fastpath + and not self.training + and not self.ever_training + and not self.cfg_checkpoint_activations + ): + # assume is Batch first and nested tensor + output = torch._transformer_encoder_layer_fwd( + x, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.self_attn.out_proj.weight, + self.self_attn.out_proj.bias, + self.activation_relu_or_gelu == 2, + False, # norm_first, currently not supported + self.self_attn_layer_norm.eps, + self.self_attn_layer_norm.weight, + self.self_attn_layer_norm.bias, + self.final_layer_norm.weight, + self.final_layer_norm.bias, + self.fc1.weight, + self.fc1.bias, + self.fc2.weight, + self.fc2.bias, + encoder_padding_mask if encoder_padding_mask is not None else attn_mask, ) + return output - residual = x - if self.normalize_before: - x = self.self_attn_layer_norm(x) - x, _ = self.self_attn( - query=x, - key=x, - value=x, - key_padding_mask=encoder_padding_mask, - need_weights=False, - attn_mask=attn_mask, - ) - x = self.dropout_module(x) - x = self.residual_connection(x, residual) - if not self.normalize_before: - x = self.self_attn_layer_norm(x) + else: + if attn_mask is not None: + attn_mask = attn_mask.masked_fill( + attn_mask.to(torch.bool), -1e8 if x.dtype == torch.float32 else -1e4 + ) - residual = x - if self.normalize_before: - x = self.final_layer_norm(x) - x = self.activation_fn(self.fc1(x)) - x = self.activation_dropout_module(x) - x = self.fc2(x) + residual = x + if self.normalize_before: + x = self.self_attn_layer_norm(x) + x, _ = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=encoder_padding_mask, + need_weights=False, + attn_mask=attn_mask, + ) + x = self.dropout_module(x) + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.self_attn_layer_norm(x) - fc_result = x + residual = x + if self.normalize_before: + x = self.final_layer_norm(x) + x = self.activation_fn(self.fc1(x)) + x = self.activation_dropout_module(x) + x = self.fc2(x) - x = self.dropout_module(x) - x = self.residual_connection(x, residual) - if not 
self.normalize_before: - x = self.final_layer_norm(x) + fc_result = x - if self.return_fc and not torch.jit.is_scripting(): - return x, fc_result - return x + x = self.dropout_module(x) + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.final_layer_norm(x) + + if self.return_fc and not torch.jit.is_scripting(): + return x, fc_result + return x # backward compatible with the legacy argparse format diff --git a/scripts/better_transformer.py b/scripts/better_transformer.py new file mode 100644 index 000000000..2bbf64c3c --- /dev/null +++ b/scripts/better_transformer.py @@ -0,0 +1,380 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import sys + +import click +import numpy as np +import torch +from fvcore.nn import FlopCountAnalysis + +from fairseq.models.transformer import TransformerConfig as FairseqTransformerConfig +from fairseq.models.transformer import TransformerEncoder as FairseqTransformerEncoder + +seed = 0 +torch.manual_seed(seed) +np.random.seed(seed) + + +def benchmark_torch_function(iters, f, *args, **kwargs): + f(*args, **kwargs) + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(iters): + f(*args, **kwargs) + end_event.record() + torch.cuda.synchronize() + return (start_event.elapsed_time(end_event) * 1.0e-3) / iters + + +def numerical_test(lengths, truth_tensors, test_list): + """ + truth_tensors is the source of truth. + test_dict looks like + [ + (name, out_tensors, atol, rtol), + ... + ] + """ + for name, out_tensors, rtol, atol in test_list: + n_failures = 0 + max_diff = 0 + for (length, truth, out) in zip(lengths, truth_tensors, out_tensors): + cut_truth = truth[:length] + cut_out = out[:length] + max_diff = max(max_diff, torch.max(torch.abs(cut_truth - cut_out))) + if not torch.allclose(cut_truth, cut_out, atol=atol, rtol=rtol): + n_failures += 1 + + if n_failures == 0: + print(f"{name} PASS") + else: + print(f"{name} FAIL {n_failures}/{len(lengths)}. Max diff is {max_diff}") + + +@click.group() +def cli(): + pass + + +@cli.command() +@click.option("--save", is_flag=True, default=False) +@click.option("--load", is_flag=True, default=False) +@click.option("--half", is_flag=True, default=False) +@click.option("--bt2fairseq", is_flag=True, default=False) +def transformer( + save, + load, + half, + bt2fairseq, +): + xlarge = False + large = False + DEFAULT_PADDING_IDX = 1 + avg_sequence_length = 128 + max_sequence_length = 256 + batch_size = 64 + + class FairseqEncoder(torch.nn.Module): + def __init__( + self, + embed_dim, + attention_heads, + ffn_embed_dim, + num_layers, + embedding_layer, # torch.nn.Embedding. 
Must have a padding_idx field + dropout=0, + normalize_before=False, + torch_encoder=None, # torch encoder that you can map weights from + activation="relu", + ): + super().__init__() + + cfg = FairseqTransformerConfig() + cfg.encoder.embed_dim = embed_dim + cfg.encoder.attention_heads = attention_heads + cfg.encoder.ffn_embed_dim = ffn_embed_dim + cfg.dropout = dropout + cfg.encoder.normalize_before = normalize_before + cfg.encoder.layers = num_layers + # make embedding behavior same as other encoders + cfg.no_token_positional_embeddings = True + cfg.no_scale_embedding = True + cfg.activation_fn = activation + dictionary = {} # TODO: verify what this is + + self.encoder = FairseqTransformerEncoder( + cfg, dictionary, embedding_layer, return_fc=False + ) + + if torch_encoder is not None: + for src_layer, dst_layer in zip( + torch_encoder.layers, self.encoder.layers + ): + w_q, w_k, w_v = src_layer.self_attn.in_proj_weight.chunk(3, dim=0) + b_q, b_k, b_v = src_layer.self_attn.in_proj_bias.chunk(3, dim=0) + + dst_layer.self_attn.q_proj.weight = torch.nn.Parameter(w_q) + dst_layer.self_attn.q_proj.bias = torch.nn.Parameter(b_q) + dst_layer.self_attn.k_proj.weight = torch.nn.Parameter(w_k) + dst_layer.self_attn.k_proj.bias = torch.nn.Parameter(b_k) + dst_layer.self_attn.v_proj.weight = torch.nn.Parameter(w_v) + dst_layer.self_attn.v_proj.bias = torch.nn.Parameter(b_v) + + dst_layer.self_attn.out_proj.weight = ( + src_layer.self_attn.out_proj.weight + ) + dst_layer.self_attn.out_proj.bias = ( + src_layer.self_attn.out_proj.bias + ) + + dst_layer.fc1.weight = src_layer.linear1.weight + dst_layer.fc1.bias = src_layer.linear1.bias + + # fairseq may use fusedlayernorm from nvidia apex - diff properties + dst_layer.self_attn_layer_norm.load_state_dict( + src_layer.norm1.state_dict() + ) + + dst_layer.fc2.weight = src_layer.linear2.weight + dst_layer.fc2.bias = src_layer.linear2.bias + + dst_layer.final_layer_norm.load_state_dict( + src_layer.norm2.state_dict() + ) + + # self.encoder = self.encoder.eval().cuda().half() + + def forward(self, tokens, src_lengths=None): + return self.encoder( + tokens, + src_lengths=src_lengths, + return_all_hiddens=False, + token_embeddings=None, + ) + + def get_layers_embedding_dim_num_heads_for_configuration(xlarge, large): + if xlarge: + # XLM-R extra large (no BERT-XL exists) + L = 24 # Layers + D = 2560 # Embedding Dim + H = 32 # Number of Heads + FD = 10240 # Feed-forward network dim + V = 30000 # Vocab Size + elif large: + # BERT-large + L = 24 + D = 1024 + H = 16 + FD = 4096 + V = 30000 + else: + # BERT-base + L = 12 + D = 768 + H = 12 + FD = 3072 + V = 30000 + + return (L, D, H, FD, V) + + # Better transformer + class PTTransformer(torch.nn.Module): + def __init__(self, transformer, embedding): + super().__init__() + self.transformer = transformer + self.embedding = embedding + self.padding_idx = DEFAULT_PADDING_IDX + + def forward(self, x): + padding_mask = None + if not x.is_nested: + padding_mask = x.eq(self.padding_idx) + x = self.embedding(x) + return self.transformer(x, src_key_padding_mask=padding_mask) + + def make_transformer(): + return ( + PTTransformer( + torch.nn.TransformerEncoder( + torch.nn.TransformerEncoderLayer( + d_model=D, + nhead=H, + dim_feedforward=FD, + batch_first=True, + activation="relu", + ), + num_layers=L, + enable_nested_tensor=False, + ), + embedding_layer, + ) + .eval() + .cuda() + ) + + def copy_weights(layers_fairseq, layers_bt): + for src_layer, dst_layer in zip(layers_fairseq, layers_bt): + w_q = 
src_layer.self_attn.q_proj.weight + b_q = src_layer.self_attn.q_proj.bias + w_k = src_layer.self_attn.k_proj.weight + b_k = src_layer.self_attn.k_proj.bias + w_v = src_layer.self_attn.v_proj.weight + b_v = src_layer.self_attn.v_proj.bias + dst_layer.self_attn.in_proj_weight = torch.nn.Parameter( + torch.cat((w_q, w_k, w_v), dim=0) + ) + dst_layer.self_attn.in_proj_bias = torch.nn.Parameter( + torch.cat((b_q, b_k, b_v), dim=0) + ) + + dst_layer.self_attn.out_proj.weight = src_layer.self_attn.out_proj.weight + dst_layer.self_attn.out_proj.bias = src_layer.self_attn.out_proj.bias + + dst_layer.linear1.weight = src_layer.fc1.weight + dst_layer.linear1.bias = src_layer.fc1.bias + dst_layer.linear2.weight = src_layer.fc2.weight + dst_layer.linear2.bias = src_layer.fc2.bias + + dst_layer.norm1.weight = src_layer.self_attn_layer_norm.weight + dst_layer.norm1.bias = src_layer.self_attn_layer_norm.bias + dst_layer.norm2.weight = src_layer.final_layer_norm.weight + dst_layer.norm2.bias = src_layer.final_layer_norm.bias + + (L, D, H, FD, V) = get_layers_embedding_dim_num_heads_for_configuration( + xlarge, large + ) + embedding_layer = torch.nn.Embedding(V, D, DEFAULT_PADDING_IDX) + # True means BT as source and fairseq is target, False means the other way + # mode1 = False + if bt2fairseq: + # BT as source and fairseq is target, copy BT's weight to fairseq + transformer = make_transformer() + fairseq_transformer = ( + FairseqEncoder( + D, + H, + FD, + L, + embedding_layer, + dropout=0, + normalize_before=False, + torch_encoder=transformer.transformer, + activation="relu", + ) + .eval() + .cuda() + ) + if half: + transformer.half() + fairseq_transformer.half() + if not bt2fairseq: + # the other way around, fairseq is source and BT is target,copy fairseq's weight to BT + transformer = make_transformer() + fairseq_transformer = ( + FairseqEncoder( + D, + H, + FD, + L, + embedding_layer, + dropout=0, + normalize_before=False, + torch_encoder=None, + activation="relu", + ) + .eval() + .cuda() + ) + # for the test where we need to load existing ckpt. 
It is tested that after loading + # the ckpt, the results between fairseq_transformer(BT kernel) equals BT + if half: + transformer.half() + fairseq_transformer.half() + if save: + torch.save(fairseq_transformer.state_dict(), "./fairseq.pt") + sys.exit(0) + if load: + fairseq_transformer.load_state_dict(torch.load("./fairseq.pt")) + # copy + copy_weights(fairseq_transformer.encoder.layers, transformer.transformer.layers) + + device = "cuda" + lengths = (avg_sequence_length,) * batch_size + tokens = torch.full( + (batch_size, max_sequence_length), + DEFAULT_PADDING_IDX, + device=device, + dtype=torch.long, + ) + for i in range(batch_size): + tokens[i, : lengths[i]] = torch.randint( + DEFAULT_PADDING_IDX + 1, + V - 1, + size=(lengths[i],), + device=device, + dtype=torch.long, + ) + # mask + if half: + lengths_tensor = torch.Tensor(lengths).cuda().half() + else: + lengths_tensor = torch.Tensor(lengths).cuda() + + with torch.inference_mode(): + fs_output = fairseq_transformer(tokens, lengths_tensor)["encoder_out"][0] + fs_output = fs_output.transpose(0, 1) + with torch.inference_mode(): + t_output = transformer(tokens) + test_lst = [ + # (name, output, relative tolerance, absolute tolerance) + ("FS", fs_output, 1e-4, 9e-3), + ] + numerical_test(lengths, t_output, test_lst) + + iters = 100 + t = benchmark_torch_function(iters, transformer, tokens) + + def bert_flops(B, T, D, L): + mlp = 2 * (B * T * D * 4 * D) + 2 * (B * T * D * 4 * D) + qkv = 3 * 2 * B * T * D * D + attn = 2 * B * D * T * T + 2 * B * D * T * T + 2 * B * T * D * D + return L * (mlp + qkv + attn) + + flops = bert_flops(batch_size, avg_sequence_length, D, L) + flops_e = ( + FlopCountAnalysis(transformer, (tokens[:, :avg_sequence_length])).total() * 2 + ) + with torch.inference_mode(): + bt = benchmark_torch_function(iters, transformer, tokens) + fst = benchmark_torch_function( + iters, fairseq_transformer, tokens, lengths_tensor + ) + + def metrics(tt, baseline=None): + if baseline: + return metrics(tt) + f", Speedup: {baseline / tt:.2f}x" + return f"{tt * 1.0e3:.2f} ms/iter, {flops_e / tt / 1.0e12:.2f} TFLOP/s" + + results = [ + f"Seed: {seed}", + f"Padded tokens: {(1-sum(lengths)/(tokens.numel()))*100:.2f}%", + f"Batch shape: {tokens.shape}", + f"Analytical flops per batch: {flops/ batch_size / 1e9:.2f} GFLOPS", + f"Empirical flops per batch: {flops_e/ batch_size / 1e9:.2f} GFLOPS", + f"B: {batch_size}", + f"T: {avg_sequence_length}", + f"TMax: {max_sequence_length}", + f"Eager Time: {metrics(t)}", + f"BetterTransformer: {metrics(bt, t)}", + f"FST: {metrics(fst, t)}", + ] + print("===========Speedup Results") + print("; ".join(results)) + + +if __name__ == "__main__": + cli() diff --git a/tests/test_export.py b/tests/test_export.py index 3e9a48d18..36fcc4455 100644 --- a/tests/test_export.py +++ b/tests/test_export.py @@ -92,8 +92,35 @@ class TestExportModels(unittest.TestCase): scripted = torch.jit.script(module) _test_save_and_load(scripted) + def version_check(): + # check Nested Tensor available. 
Make sure version >= '1.13.0.dev20220613' + if "fb" in torch.__version__: + return False + else: + if "+" in torch.__version__: + torch_version = torch.__version__.split("+")[0] + else: + torch_version = torch.__version__ + + torch_version = torch_version.split(".") + int_version = ( + int(torch_version[0]) * 1000 + + int(torch_version[1]) * 10 + + int(torch_version[2]) + ) + if len(torch_version) == 3: + if int_version >= 1131: + return False + elif len(torch_version) == 4: + if int_version >= 1131 or ( + int_version == 1130 and torch_version[3][3:] >= "20220613" + ): + return False + return True + @unittest.skipIf( - torch.__version__ < "1.6.0", "Targeting OSS scriptability for the 1.6 release" + version_check(), + "Targeting OSS scriptability for the 1.13.0.dev20220613 release", ) def test_export_transformer(self): task, parser = get_dummy_task_and_parser() @@ -104,7 +131,8 @@ class TestExportModels(unittest.TestCase): _test_save_and_load(scripted) @unittest.skipIf( - torch.__version__ < "1.6.0", "Targeting OSS scriptability for the 1.6 release" + version_check(), + "Targeting OSS scriptability for the 1.13.0.dev20220613 release", ) def test_export_transformer_no_token_pos_emb(self): task, parser = get_dummy_task_and_parser() diff --git a/tests/test_sequence_generator.py b/tests/test_sequence_generator.py index 2e42df0e5..a0b6e8934 100644 --- a/tests/test_sequence_generator.py +++ b/tests/test_sequence_generator.py @@ -22,6 +22,33 @@ from fairseq.tasks.fairseq_task import LegacyFairseqTask DEFAULT_TEST_VOCAB_SIZE = 100 +def version_check(): + # check Nested Tensor available. Make sure version >= '1.13.0.dev20220613' + if "fb" in torch.__version__: + return False + else: + if "+" in torch.__version__: + torch_version = torch.__version__.split("+")[0] + else: + torch_version = torch.__version__ + + torch_version = torch_version.split(".") + int_version = ( + int(torch_version[0]) * 1000 + + int(torch_version[1]) * 10 + + int(torch_version[2]) + ) + if len(torch_version) == 3: + if int_version >= 1131: + return False + elif len(torch_version) == 4: + if int_version >= 1131 or ( + int_version == 1130 and torch_version[3][3:] >= "20220613" + ): + return False + return True + + class DummyTask(LegacyFairseqTask): def __init__(self, args): super().__init__(args) @@ -113,7 +140,9 @@ class TestJitSequenceGeneratorBase(unittest.TestCase): JIT_MSG = "Targeting OSS scriptability for the 1.6 release" -@unittest.skipIf(torch.__version__ < "1.6.0", JIT_MSG) +@unittest.skipIf( + version_check(), "Targeting OSS scriptability for the 1.13.0.dev20220613 release" +) class TestJitSequenceGenerator(TestJitSequenceGeneratorBase): def test_export_transformer(self): model = self.transformer_model
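Reviewer note: the torch version gate in this patch is duplicated four times (transformer_encoder.py, transformer_layer.py, tests/test_export.py, tests/test_sequence_generator.py). The snippet below is a minimal consolidation of that logic for illustration only; the helper name bt_nt_support is hypothetical and is not part of the patch. It assumes the same version-string formats the patch handles: stable 'x.y.z' (optionally with a '+local' suffix such as '+cu102'), nightly 'x.y.z.devYYYYMMDD', and internal builds containing 'fb'.

import torch


def bt_nt_support():
    """Return (bt_ok, nt_ok): BetterTransformer fastpath needs torch >= 1.12.0,
    nested-tensor padding removal needs torch >= 1.13.0.dev20220613."""
    # Internal builds, e.g. '1.13.0a0+fb', are assumed to carry both features.
    if "fb" in torch.__version__:
        return True, True
    # Drop the local version suffix, e.g. '1.11.0+cu102' -> '1.11.0'.
    version = torch.__version__.split("+")[0]
    parts = version.split(".")
    int_version = int(parts[0]) * 1000 + int(parts[1]) * 10 + int(parts[2])
    if len(parts) == 3:
        # Stable release, e.g. '1.12.0' -> 1120.
        return int_version >= 1120, int_version >= 1131
    if len(parts) == 4:
        # Nightly, e.g. '1.13.0.dev20220613' -> parts[3] == 'dev20220613'.
        bt_ok = int_version >= 1130
        nt_ok = int_version >= 1131 or (
            int_version == 1130 and parts[3][3:] >= "20220613"
        )
        return bt_ok, nt_ok
    return False, False


# Example: gate the fastpath the same way the encoder and layer code do.
BT_version, NT_version = bt_nt_support()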