BT enablement on fairseq - fairseq change (#4480)

Summary:
Pull Request resolved: https://github.com/facebookresearch/fairseq/pull/4480

As titled; depends on D36057338.
Fork the inference path inside the forward function: if a checkpoint has been loaded and we are running inference, we deploy BT (BetterTransformer); otherwise, the original fairseq path is used.
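For orientation, a minimal sketch of that fork (this wrapper class and the `bt_forward` hook are illustrative assumptions, not code from the diff below):

import torch.nn as nn

class ForkedEncoderLayer(nn.Module):
    """Illustrative only: use BT at inference time, fall back to fairseq otherwise."""

    def __init__(self, fairseq_layer, bt_forward=None):
        super().__init__()
        self.fairseq_layer = fairseq_layer  # original fairseq TransformerEncoderLayerBase
        self.bt_forward = bt_forward  # fused BT kernel, available once a checkpoint is loaded

    def forward(self, x, encoder_padding_mask=None):
        # checkpoint loaded + inference -> BetterTransformer fastpath
        if self.bt_forward is not None and not self.training:
            return self.bt_forward(x, encoder_padding_mask)
        # otherwise fairseq takes over, unchanged
        return self.fairseq_layer(x, encoder_padding_mask=encoder_padding_mask)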

In summary:
Accuracy: there is a small accuracy loss due to fp16; the maximum diff is around 0.009. With fp32 there is no accuracy loss.
Perf: the current fairseq runs at a similar speed to the vanilla version. After the enablement, the speedup is similar to the standalone BT test.
With batch size = 64:
For V100, the speedup reaches 1.23x
For A100, the speedup reaches 1.38x

After enabling nested tensor,
For V100, the speedup reaches 2.46x

Reviewed By: mikekgfb

Differential Revision: D37082681

fbshipit-source-id: 984266f850fc30603e48be56e41ac2c67da080f5
Wei Wei 2022-06-15 21:48:41 -07:00 committed by Facebook GitHub Bot
parent d9c661bf4f
commit 3a757d7ab2
5 changed files with 694 additions and 48 deletions


@@ -8,23 +8,22 @@ from typing import Dict, List, Optional
import torch
import torch.nn as nn
from torch import Tensor
from fairseq import utils
from fairseq.distributed import fsdp_wrap
from fairseq.models import FairseqEncoder
from fairseq.models.transformer import TransformerConfig
from fairseq.modules import (
FairseqDropout,
LayerDropModuleList,
LayerNorm,
PositionalEmbedding,
SinusoidalPositionalEmbedding,
transformer_layer,
)
from fairseq.modules import transformer_layer
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
from torch import Tensor
from fairseq.models.transformer import (
TransformerConfig,
)
# rewrite name for backward compatibility in `make_generation_fast_`
@@ -220,11 +219,79 @@ class TransformerEncoderBase(FairseqEncoder):
if return_all_hiddens:
encoder_states.append(x)
# encoder layers
for layer in self.layers:
lr = layer(
x, encoder_padding_mask=encoder_padding_mask if has_pads else None
# nested tensor (NT) and BetterTransformer (BT) enablement
layer = self.layers[0]
BT_flag = False
NT_flag = False
# torch version check, BT>=1.12.0 and NT>=1.13.0.dev20220613
# internal format is '1.13.0a0+fb'
# external format is '1.13.0.dev20220613'(cpu&gpu) for nightly or "1.11.0"(cpu) or '1.11.0+cu102'(gpu) for stable
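# e.g. '1.12.0' -> int_version 1120; '1.13.0.dev20220613' -> 1130 with torch_version[3][3:] == '20220613'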
BT_version = False
NT_version = False
if "fb" in torch.__version__:
BT_version = True
NT_version = True
else:
if "+" in torch.__version__:
torch_version = torch.__version__.split("+")[0]
else:
torch_version = torch.__version__
torch_version = torch_version.split(".")
int_version = (
int(torch_version[0]) * 1000
+ int(torch_version[1]) * 10
+ int(torch_version[2])
)
if len(torch_version) == 3:
if int_version >= 1120:
BT_version = True
if int_version >= 1131:
NT_version = True
elif len(torch_version) == 4:
if int_version >= 1130:
BT_version = True
# Consider that _nested_tensor_from_mask_left_aligned landed after the "20220613" nightly
if int_version >= 1131 or (
int_version == 1130 and torch_version[3][3:] >= "20220613"
):
NT_version = True
if (
BT_version
and x.dim() == 3
and layer.load_to_BT
and not layer.return_fc
and layer.can_use_fastpath
and not layer.training
and not layer.ever_training
and not layer.cfg_checkpoint_activations
):
# Batch-first layout cannot be verified here; the user must ensure it
x = x.transpose(0, 1)
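# fairseq encoder activations are T x B x C; the BT fastpath expects batch-first (B x T x C)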
# Check mask conditions for nested tensor
if NT_version:
if (
encoder_padding_mask is not None
and torch._nested_tensor_from_mask_left_aligned(
x, encoder_padding_mask.logical_not()
)
):
if not torch.is_grad_enabled() or not x.requires_grad:
x = torch._nested_tensor_from_mask(
x, encoder_padding_mask.logical_not()
)
NT_flag = True
BT_flag = True
# encoder layers
if NT_flag:
processing_mask = None
else:
processing_mask = encoder_padding_mask
encoder_padding_mask_out = processing_mask if has_pads else None
for layer in self.layers:
lr = layer(x, encoder_padding_mask=encoder_padding_mask_out)
if isinstance(lr, tuple) and len(lr) == 2:
x, fc_result = lr
@@ -237,6 +304,13 @@ class TransformerEncoderBase(FairseqEncoder):
encoder_states.append(x)
fc_results.append(fc_result)
# change back to a non-nested (padded) tensor and batch-second (T x B x C) layout
if NT_flag:
x = x.to_padded_tensor(0.0)
if NT_flag or BT_flag:
x = x.transpose(0, 1)
if self.layer_norm is not None:
x = self.layer_norm(x)


@@ -7,14 +7,13 @@ from typing import Dict, List, Optional
import torch
import torch.nn as nn
from torch import Tensor
from fairseq import utils
from fairseq.models.transformer import TransformerConfig
from fairseq.modules import LayerNorm, MultiheadAttention
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise
from torch import Tensor
from fairseq.models.transformer import (
TransformerConfig,
)
class TransformerEncoderLayerBase(nn.Module):
@@ -68,6 +67,103 @@ class TransformerEncoderLayerBase(nn.Module):
self.final_layer_norm = LayerNorm(self.embed_dim, export=cfg.export)
self.num_heads = cfg.encoder.attention_heads
self.load_to_BT = False
self.ever_training = False
# For BT, we need the q/k/v projection parameters in contiguous memory
self.in_proj_weight = torch.nn.Parameter(
torch.zeros(
self.self_attn.q_proj.weight.shape[0] * 3,
self.self_attn.q_proj.weight.shape[1],
)
)
self.in_proj_bias = torch.nn.Parameter(
torch.zeros(self.self_attn.q_proj.bias.shape[0] * 3)
)
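# these packed buffers are filled from the q/k/v projections in _load_from_state_dict below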
if (
self.activation_fn is torch.nn.functional.relu
or isinstance(self.activation_fn, torch.nn.ReLU)
or self.activation_fn == "relu"
):
self.activation_relu_or_gelu = 1
elif (
self.activation_fn is torch.nn.functional.gelu
or isinstance(self.activation_fn, torch.nn.GELU)
or self.activation_fn == "gelu"
):
self.activation_relu_or_gelu = 2
else:
self.activation_relu_or_gelu = 0
# Batch-first layout cannot be verified here; the user must ensure it
self.can_use_fastpath = (
not self.normalize_before
and self.activation_relu_or_gelu
and (self.self_attn_layer_norm.eps == self.final_layer_norm.eps)
)
self.cfg_checkpoint_activations = self.cfg.checkpoint_activations
# torch version check
# make sure BT version is >=1.12.0
self.BT_version = False
if "fb" in torch.__version__:
self.BT_version = True
else:
if "+" in torch.__version__:
self.torch_version = torch.__version__.split("+")[0]
else:
self.torch_version = torch.__version__
self.torch_version = self.torch_version.split(".")
self.int_version = (
int(self.torch_version[0]) * 1000
+ int(self.torch_version[1]) * 10
+ int(self.torch_version[2])
)
if len(self.torch_version) == 3:
if self.int_version >= 1120:
self.BT_version = True
elif len(self.torch_version) == 4:
if self.int_version >= 1130:
self.BT_version = True
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
self.load_to_BT = True
old_name = prefix + "self_attn."
q_proj_weight = state_dict[old_name + "q_proj.weight"]
k_proj_weight = state_dict[old_name + "k_proj.weight"]
v_proj_weight = state_dict[old_name + "v_proj.weight"]
q_proj_bias = state_dict[old_name + "q_proj.bias"]
k_proj_bias = state_dict[old_name + "k_proj.bias"]
v_proj_bias = state_dict[old_name + "v_proj.bias"]
new_name = prefix
state_dict[new_name + "in_proj_weight"] = torch.cat(
(q_proj_weight, k_proj_weight, v_proj_weight), dim=0
)
state_dict[new_name + "in_proj_bias"] = torch.cat(
(q_proj_bias, k_proj_bias, v_proj_bias), dim=0
)
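# the packed [q; k; v] layout matches what torch._transformer_encoder_layer_fwd expects for in_proj_weight/in_proj_bias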
super(TransformerEncoderLayerBase, self)._load_from_state_dict(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
)
def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
return quant_noise(
nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size
@@ -187,44 +283,83 @@ class TransformerEncoderLayerBase(nn.Module):
# Note that we cannot use -inf here, because at some edge cases,
# the attention weight (before softmax) for some padded element in query
# will become -inf, which results in NaN in model parameters
if attn_mask is not None:
attn_mask = attn_mask.masked_fill(
attn_mask.to(torch.bool), -1e8 if x.dtype == torch.float32 else -1e4
if self.training:
self.ever_training = True
if (
self.BT_version
and x.dim() == 3
and self.load_to_BT
and not self.return_fc
and self.can_use_fastpath
and not self.training
and not self.ever_training
and not self.cfg_checkpoint_activations
):
# assumes batch-first input (possibly a nested tensor)
output = torch._transformer_encoder_layer_fwd(
x,
self.embed_dim,
self.num_heads,
self.in_proj_weight,
self.in_proj_bias,
self.self_attn.out_proj.weight,
self.self_attn.out_proj.bias,
self.activation_relu_or_gelu == 2,
False, # norm_first, currently not supported
self.self_attn_layer_norm.eps,
self.self_attn_layer_norm.weight,
self.self_attn_layer_norm.bias,
self.final_layer_norm.weight,
self.final_layer_norm.bias,
self.fc1.weight,
self.fc1.bias,
self.fc2.weight,
self.fc2.bias,
encoder_padding_mask if encoder_padding_mask is not None else attn_mask,
)
return output
residual = x
if self.normalize_before:
x = self.self_attn_layer_norm(x)
x, _ = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=encoder_padding_mask,
need_weights=False,
attn_mask=attn_mask,
)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
else:
if attn_mask is not None:
attn_mask = attn_mask.masked_fill(
attn_mask.to(torch.bool), -1e8 if x.dtype == torch.float32 else -1e4
)
residual = x
if self.normalize_before:
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
x = self.activation_dropout_module(x)
x = self.fc2(x)
residual = x
if self.normalize_before:
x = self.self_attn_layer_norm(x)
x, _ = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=encoder_padding_mask,
need_weights=False,
attn_mask=attn_mask,
)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
fc_result = x
residual = x
if self.normalize_before:
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
x = self.activation_dropout_module(x)
x = self.fc2(x)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.final_layer_norm(x)
fc_result = x
if self.return_fc and not torch.jit.is_scripting():
return x, fc_result
return x
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.final_layer_norm(x)
if self.return_fc and not torch.jit.is_scripting():
return x, fc_result
return x
# backward compatible with the legacy argparse format


@@ -0,0 +1,380 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import sys
import click
import numpy as np
import torch
from fvcore.nn import FlopCountAnalysis
from fairseq.models.transformer import TransformerConfig as FairseqTransformerConfig
from fairseq.models.transformer import TransformerEncoder as FairseqTransformerEncoder
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)
def benchmark_torch_function(iters, f, *args, **kwargs):
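# warm up with one call, then time `iters` calls with CUDA events; returns average seconds per iteration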
f(*args, **kwargs)
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for _ in range(iters):
f(*args, **kwargs)
end_event.record()
torch.cuda.synchronize()
return (start_event.elapsed_time(end_event) * 1.0e-3) / iters
def numerical_test(lengths, truth_tensors, test_list):
"""
truth_tensors is the source of truth.
test_list looks like
[
(name, out_tensors, rtol, atol),
...
]
"""
for name, out_tensors, rtol, atol in test_list:
n_failures = 0
max_diff = 0
for (length, truth, out) in zip(lengths, truth_tensors, out_tensors):
cut_truth = truth[:length]
cut_out = out[:length]
max_diff = max(max_diff, torch.max(torch.abs(cut_truth - cut_out)))
if not torch.allclose(cut_truth, cut_out, atol=atol, rtol=rtol):
n_failures += 1
if n_failures == 0:
print(f"{name} PASS")
else:
print(f"{name} FAIL {n_failures}/{len(lengths)}. Max diff is {max_diff}")
@click.group()
def cli():
pass
@cli.command()
@click.option("--save", is_flag=True, default=False)
@click.option("--load", is_flag=True, default=False)
@click.option("--half", is_flag=True, default=False)
@click.option("--bt2fairseq", is_flag=True, default=False)
def transformer(
save,
load,
half,
bt2fairseq,
):
xlarge = False
large = False
DEFAULT_PADDING_IDX = 1
avg_sequence_length = 128
max_sequence_length = 256
batch_size = 64
class FairseqEncoder(torch.nn.Module):
def __init__(
self,
embed_dim,
attention_heads,
ffn_embed_dim,
num_layers,
embedding_layer, # torch.nn.Embedding. Must have a padding_idx field
dropout=0,
normalize_before=False,
torch_encoder=None, # torch encoder that you can map weights from
activation="relu",
):
super().__init__()
cfg = FairseqTransformerConfig()
cfg.encoder.embed_dim = embed_dim
cfg.encoder.attention_heads = attention_heads
cfg.encoder.ffn_embed_dim = ffn_embed_dim
cfg.dropout = dropout
cfg.encoder.normalize_before = normalize_before
cfg.encoder.layers = num_layers
# make embedding behavior same as other encoders
cfg.no_token_positional_embeddings = True
cfg.no_scale_embedding = True
cfg.activation_fn = activation
dictionary = {} # TODO: verify what this is
self.encoder = FairseqTransformerEncoder(
cfg, dictionary, embedding_layer, return_fc=False
)
if torch_encoder is not None:
for src_layer, dst_layer in zip(
torch_encoder.layers, self.encoder.layers
):
w_q, w_k, w_v = src_layer.self_attn.in_proj_weight.chunk(3, dim=0)
b_q, b_k, b_v = src_layer.self_attn.in_proj_bias.chunk(3, dim=0)
dst_layer.self_attn.q_proj.weight = torch.nn.Parameter(w_q)
dst_layer.self_attn.q_proj.bias = torch.nn.Parameter(b_q)
dst_layer.self_attn.k_proj.weight = torch.nn.Parameter(w_k)
dst_layer.self_attn.k_proj.bias = torch.nn.Parameter(b_k)
dst_layer.self_attn.v_proj.weight = torch.nn.Parameter(w_v)
dst_layer.self_attn.v_proj.bias = torch.nn.Parameter(b_v)
dst_layer.self_attn.out_proj.weight = (
src_layer.self_attn.out_proj.weight
)
dst_layer.self_attn.out_proj.bias = (
src_layer.self_attn.out_proj.bias
)
dst_layer.fc1.weight = src_layer.linear1.weight
dst_layer.fc1.bias = src_layer.linear1.bias
# fairseq may use FusedLayerNorm from NVIDIA apex, which has different properties
dst_layer.self_attn_layer_norm.load_state_dict(
src_layer.norm1.state_dict()
)
dst_layer.fc2.weight = src_layer.linear2.weight
dst_layer.fc2.bias = src_layer.linear2.bias
dst_layer.final_layer_norm.load_state_dict(
src_layer.norm2.state_dict()
)
# self.encoder = self.encoder.eval().cuda().half()
def forward(self, tokens, src_lengths=None):
return self.encoder(
tokens,
src_lengths=src_lengths,
return_all_hiddens=False,
token_embeddings=None,
)
def get_layers_embedding_dim_num_heads_for_configuration(xlarge, large):
if xlarge:
# XLM-R extra large (no BERT-XL exists)
L = 24 # Layers
D = 2560 # Embedding Dim
H = 32 # Number of Heads
FD = 10240 # Feed-forward network dim
V = 30000 # Vocab Size
elif large:
# BERT-large
L = 24
D = 1024
H = 16
FD = 4096
V = 30000
else:
# BERT-base
L = 12
D = 768
H = 12
FD = 3072
V = 30000
return (L, D, H, FD, V)
# Better transformer
class PTTransformer(torch.nn.Module):
def __init__(self, transformer, embedding):
super().__init__()
self.transformer = transformer
self.embedding = embedding
self.padding_idx = DEFAULT_PADDING_IDX
def forward(self, x):
padding_mask = None
if not x.is_nested:
padding_mask = x.eq(self.padding_idx)
x = self.embedding(x)
return self.transformer(x, src_key_padding_mask=padding_mask)
def make_transformer():
return (
PTTransformer(
torch.nn.TransformerEncoder(
torch.nn.TransformerEncoderLayer(
d_model=D,
nhead=H,
dim_feedforward=FD,
batch_first=True,
activation="relu",
),
num_layers=L,
enable_nested_tensor=False,
),
embedding_layer,
)
.eval()
.cuda()
)
def copy_weights(layers_fairseq, layers_bt):
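# map fairseq's per-projection q/k/v weights into torch.nn.TransformerEncoderLayer's packed in_proj buffers,
# and fc1/fc2 and the two layer norms onto linear1/linear2 and norm1/norm2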
for src_layer, dst_layer in zip(layers_fairseq, layers_bt):
w_q = src_layer.self_attn.q_proj.weight
b_q = src_layer.self_attn.q_proj.bias
w_k = src_layer.self_attn.k_proj.weight
b_k = src_layer.self_attn.k_proj.bias
w_v = src_layer.self_attn.v_proj.weight
b_v = src_layer.self_attn.v_proj.bias
dst_layer.self_attn.in_proj_weight = torch.nn.Parameter(
torch.cat((w_q, w_k, w_v), dim=0)
)
dst_layer.self_attn.in_proj_bias = torch.nn.Parameter(
torch.cat((b_q, b_k, b_v), dim=0)
)
dst_layer.self_attn.out_proj.weight = src_layer.self_attn.out_proj.weight
dst_layer.self_attn.out_proj.bias = src_layer.self_attn.out_proj.bias
dst_layer.linear1.weight = src_layer.fc1.weight
dst_layer.linear1.bias = src_layer.fc1.bias
dst_layer.linear2.weight = src_layer.fc2.weight
dst_layer.linear2.bias = src_layer.fc2.bias
dst_layer.norm1.weight = src_layer.self_attn_layer_norm.weight
dst_layer.norm1.bias = src_layer.self_attn_layer_norm.bias
dst_layer.norm2.weight = src_layer.final_layer_norm.weight
dst_layer.norm2.bias = src_layer.final_layer_norm.bias
(L, D, H, FD, V) = get_layers_embedding_dim_num_heads_for_configuration(
xlarge, large
)
embedding_layer = torch.nn.Embedding(V, D, DEFAULT_PADDING_IDX)
# True means BT as source and fairseq is target, False means the other way
# mode1 = False
if bt2fairseq:
# BT as source and fairseq is target, copy BT's weight to fairseq
transformer = make_transformer()
fairseq_transformer = (
FairseqEncoder(
D,
H,
FD,
L,
embedding_layer,
dropout=0,
normalize_before=False,
torch_encoder=transformer.transformer,
activation="relu",
)
.eval()
.cuda()
)
if half:
transformer.half()
fairseq_transformer.half()
if not bt2fairseq:
# the other way around: fairseq is the source and BT is the target; copy fairseq's weights to BT
transformer = make_transformer()
fairseq_transformer = (
FairseqEncoder(
D,
H,
FD,
L,
embedding_layer,
dropout=0,
normalize_before=False,
torch_encoder=None,
activation="relu",
)
.eval()
.cuda()
)
# for the test where we need to load an existing ckpt; it has been verified that after loading
# the ckpt, the fairseq_transformer (BT kernel) results match BT
if half:
transformer.half()
fairseq_transformer.half()
if save:
torch.save(fairseq_transformer.state_dict(), "./fairseq.pt")
sys.exit(0)
if load:
fairseq_transformer.load_state_dict(torch.load("./fairseq.pt"))
# copy
copy_weights(fairseq_transformer.encoder.layers, transformer.transformer.layers)
device = "cuda"
lengths = (avg_sequence_length,) * batch_size
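# every sequence uses the same (average) length; positions beyond it stay at the padding index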
tokens = torch.full(
(batch_size, max_sequence_length),
DEFAULT_PADDING_IDX,
device=device,
dtype=torch.long,
)
for i in range(batch_size):
tokens[i, : lengths[i]] = torch.randint(
DEFAULT_PADDING_IDX + 1,
V - 1,
size=(lengths[i],),
device=device,
dtype=torch.long,
)
# mask
if half:
lengths_tensor = torch.Tensor(lengths).cuda().half()
else:
lengths_tensor = torch.Tensor(lengths).cuda()
with torch.inference_mode():
fs_output = fairseq_transformer(tokens, lengths_tensor)["encoder_out"][0]
fs_output = fs_output.transpose(0, 1)
with torch.inference_mode():
t_output = transformer(tokens)
test_lst = [
# (name, output, relative tolerance, absolute tolerance)
("FS", fs_output, 1e-4, 9e-3),
]
numerical_test(lengths, t_output, test_lst)
iters = 100
t = benchmark_torch_function(iters, transformer, tokens)
def bert_flops(B, T, D, L):
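# analytical FLOPs per batch: two FFN matmuls (D <-> 4D), the q/k/v projections, QK^T, attn x V,
# and the attention output projection, summed over L layers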
mlp = 2 * (B * T * D * 4 * D) + 2 * (B * T * D * 4 * D)
qkv = 3 * 2 * B * T * D * D
attn = 2 * B * D * T * T + 2 * B * D * T * T + 2 * B * T * D * D
return L * (mlp + qkv + attn)
flops = bert_flops(batch_size, avg_sequence_length, D, L)
flops_e = (
FlopCountAnalysis(transformer, (tokens[:, :avg_sequence_length])).total() * 2
)
with torch.inference_mode():
bt = benchmark_torch_function(iters, transformer, tokens)
fst = benchmark_torch_function(
iters, fairseq_transformer, tokens, lengths_tensor
)
def metrics(tt, baseline=None):
if baseline:
return metrics(tt) + f", Speedup: {baseline / tt:.2f}x"
return f"{tt * 1.0e3:.2f} ms/iter, {flops_e / tt / 1.0e12:.2f} TFLOP/s"
results = [
f"Seed: {seed}",
f"Padded tokens: {(1-sum(lengths)/(tokens.numel()))*100:.2f}%",
f"Batch shape: {tokens.shape}",
f"Analytical flops per batch: {flops/ batch_size / 1e9:.2f} GFLOPS",
f"Empirical flops per batch: {flops_e/ batch_size / 1e9:.2f} GFLOPS",
f"B: {batch_size}",
f"T: {avg_sequence_length}",
f"TMax: {max_sequence_length}",
f"Eager Time: {metrics(t)}",
f"BetterTransformer: {metrics(bt, t)}",
f"FST: {metrics(fst, t)}",
]
print("===========Speedup Results")
print("; ".join(results))
if __name__ == "__main__":
cli()


@@ -92,8 +92,35 @@ class TestExportModels(unittest.TestCase):
scripted = torch.jit.script(module)
_test_save_and_load(scripted)
def version_check():
# check Nested Tensor available. Make sure version >= '1.13.0.dev20220613'
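# returns True when the test should be skipped, i.e. when nested-tensor support is NOT available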
if "fb" in torch.__version__:
return False
else:
if "+" in torch.__version__:
torch_version = torch.__version__.split("+")[0]
else:
torch_version = torch.__version__
torch_version = torch_version.split(".")
int_version = (
int(torch_version[0]) * 1000
+ int(torch_version[1]) * 10
+ int(torch_version[2])
)
if len(torch_version) == 3:
if int_version >= 1131:
return False
elif len(torch_version) == 4:
if int_version >= 1131 or (
int_version == 1130 and torch_version[3][3:] >= "20220613"
):
return False
return True
@unittest.skipIf(
torch.__version__ < "1.6.0", "Targeting OSS scriptability for the 1.6 release"
version_check(),
"Targeting OSS scriptability for the 1.13.0.dev20220613 release",
)
def test_export_transformer(self):
task, parser = get_dummy_task_and_parser()
@@ -104,7 +131,8 @@ class TestExportModels(unittest.TestCase):
_test_save_and_load(scripted)
@unittest.skipIf(
torch.__version__ < "1.6.0", "Targeting OSS scriptability for the 1.6 release"
version_check(),
"Targeting OSS scriptability for the 1.13.0.dev20220613 release",
)
def test_export_transformer_no_token_pos_emb(self):
task, parser = get_dummy_task_and_parser()


@@ -22,6 +22,33 @@ from fairseq.tasks.fairseq_task import LegacyFairseqTask
DEFAULT_TEST_VOCAB_SIZE = 100
def version_check():
# check Nested Tensor available. Make sure version >= '1.13.0.dev20220613'
if "fb" in torch.__version__:
return False
else:
if "+" in torch.__version__:
torch_version = torch.__version__.split("+")[0]
else:
torch_version = torch.__version__
torch_version = torch_version.split(".")
int_version = (
int(torch_version[0]) * 1000
+ int(torch_version[1]) * 10
+ int(torch_version[2])
)
if len(torch_version) == 3:
if int_version >= 1131:
return False
elif len(torch_version) == 4:
if int_version >= 1131 or (
int_version == 1130 and torch_version[3][3:] >= "20220613"
):
return False
return True
class DummyTask(LegacyFairseqTask):
def __init__(self, args):
super().__init__(args)
@@ -113,7 +140,9 @@ class TestJitSequenceGeneratorBase(unittest.TestCase):
JIT_MSG = "Targeting OSS scriptability for the 1.6 release"
@unittest.skipIf(torch.__version__ < "1.6.0", JIT_MSG)
@unittest.skipIf(
version_check(), "Targeting OSS scriptability for the 1.13.0.dev20220613 release"
)
class TestJitSequenceGenerator(TestJitSequenceGeneratorBase):
def test_export_transformer(self):
model = self.transformer_model