Back out "BT enablement on fairseq - fairseq change"

Summary:
Context: https://fburl.com/7vdj7vhl

Backing out because this change breaks our TorchScript test:
```
RuntimeError:
method cannot be used as a value:
  File "/dev/shm/uid-30041/54641b26-seed-nspid4026533396_cgpid7154327-ns-4026533393/fairseq/modules/transformer_layer.py", line 307
                self.in_proj_weight,
                self.in_proj_bias,
                self.self_attn.out_proj.weight,
                ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
                self.self_attn.out_proj.bias,
                self.activation_relu_or_gelu == 2,

Stack trace:
Exception type: torch::jit::ErrorReport
```
https://fburl.com/sandcastle/4pzqemf5
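
For context: the broken check scripts the transformer with `torch.jit.script` and round-trips it through save/load (see the `TestExportModels` changes below). A rough, minimal sketch of that kind of check, with the fairseq-specific module construction elided:

```
import io

import torch


def check_torchscript_roundtrip(module: torch.nn.Module) -> None:
    # Rough stand-in for the failing test: torch.jit.script compiles forward(),
    # which is where the ErrorReport above is raised; the scripted module is
    # then round-tripped through save/load.
    scripted = torch.jit.script(module)
    buffer = io.BytesIO()
    torch.jit.save(scripted, buffer)
    buffer.seek(0)
    torch.jit.load(buffer)
```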

Original commit changeset: 984266f850fc

Original Phabricator Diff: D37082681 (3a757d7ab2)

Differential Revision: D37303846

fbshipit-source-id: 1757ea5dae98be5beb4d08f70b0c3001d6ea336f
Wei Ho 2022-06-21 17:27:50 -07:00 committed by Facebook GitHub Bot
parent 08fe88479f
commit 956fcf495b
5 changed files with 47 additions and 693 deletions

View File

@@ -8,22 +8,23 @@ from typing import Dict, List, Optional
import torch
import torch.nn as nn
from torch import Tensor
from fairseq import utils
from fairseq.distributed import fsdp_wrap
from fairseq.models import FairseqEncoder
from fairseq.models.transformer import TransformerConfig
from fairseq.modules import (
FairseqDropout,
LayerDropModuleList,
LayerNorm,
PositionalEmbedding,
SinusoidalPositionalEmbedding,
transformer_layer,
)
from fairseq.modules import transformer_layer
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
from torch import Tensor
from fairseq.models.transformer import (
TransformerConfig,
)
# rewrite name for backward compatibility in `make_generation_fast_`
@@ -219,79 +220,11 @@ class TransformerEncoderBase(FairseqEncoder):
if return_all_hiddens:
encoder_states.append(x)
# nested tensor and BT enable
layer = self.layers[0]
BT_flag = False
NT_flag = False
# torch version check, BT>=1.12.0 and NT>=1.13.0.dev20220613
# internal format is '1.13.0a0+fb'
# external format is '1.13.0.dev20220613'(cpu&gpu) for nightly or "1.11.0"(cpu) or '1.11.0+cu102'(gpu) for stable
BT_version = False
NT_version = False
if "fb" in torch.__version__:
BT_version = True
NT_version = True
else:
if "+" in torch.__version__:
torch_version = torch.__version__.split("+")[0]
else:
torch_version = torch.__version__
torch_version = torch_version.split(".")
int_version = (
int(torch_version[0]) * 1000
+ int(torch_version[1]) * 10
+ int(torch_version[2])
)
if len(torch_version) == 3:
if int_version >= 1120:
BT_version = True
if int_version >= 1131:
NT_version = True
elif len(torch_version) == 4:
if int_version >= 1130:
BT_version = True
# _nested_tensor_from_mask_left_aligned landed after the "20220613" nightly
if int_version >= 1131 or (
int_version == 1130 and torch_version[3][3:] >= "20220613"
):
NT_version = True
if (
BT_version
and x.dim() == 3
and layer.load_to_BT
and not layer.return_fc
and layer.can_use_fastpath
and not layer.training
and not layer.ever_training
and not layer.cfg_checkpoint_activations
):
# Batch-first layout cannot be verified here; the caller must ensure it
x = x.transpose(0, 1)
# Check mask conditions for nested tensor
if NT_version:
if (
encoder_padding_mask is not None
and torch._nested_tensor_from_mask_left_aligned(
x, encoder_padding_mask.logical_not()
)
):
if not torch.is_grad_enabled() or not x.requires_grad:
x = torch._nested_tensor_from_mask(
x, encoder_padding_mask.logical_not()
)
NT_flag = True
BT_flag = True
# encoder layers
if NT_flag:
processing_mask = None
else:
processing_mask = encoder_padding_mask
encoder_padding_mask_out = processing_mask if has_pads else None
for layer in self.layers:
lr = layer(x, encoder_padding_mask=encoder_padding_mask_out)
lr = layer(
x, encoder_padding_mask=encoder_padding_mask if has_pads else None
)
if isinstance(lr, tuple) and len(lr) == 2:
x, fc_result = lr
@@ -304,13 +237,6 @@ class TransformerEncoderBase(FairseqEncoder):
encoder_states.append(x)
fc_results.append(fc_result)
# change back to non-nested and Batch second
if NT_flag:
x = x.to_padded_tensor(0.0)
if NT_flag or BT_flag:
x = x.transpose(0, 1)
if self.layer_norm is not None:
x = self.layer_norm(x)
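
For reference, the gating removed above encodes `torch.__version__` as an integer (`major * 1000 + minor * 10 + patch`), so "1.12.0" maps to 1120 and a nightly such as "1.13.0.dev20220613" maps to 1130 with a fourth `devYYYYMMDD` component. A self-contained restatement of that check (the function name is mine; the logic mirrors the deleted code):

```
import torch


def bt_nt_available(version: str = torch.__version__):
    """Mirror of the removed gate: returns (BT_available, NT_available)."""
    if "fb" in version:  # internal builds, e.g. '1.13.0a0+fb'
        return True, True
    parts = version.split("+")[0].split(".")
    int_version = int(parts[0]) * 1000 + int(parts[1]) * 10 + int(parts[2])
    if len(parts) == 3:  # stable release, e.g. '1.11.0' or '1.12.0'
        return int_version >= 1120, int_version >= 1131
    # nightly, e.g. '1.13.0.dev20220613'; NT also needs the 20220613 build or newer
    bt = int_version >= 1130
    nt = int_version >= 1131 or (int_version == 1130 and parts[3][3:] >= "20220613")
    return bt, nt
```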

View File

@@ -7,13 +7,14 @@ from typing import Dict, List, Optional
import torch
import torch.nn as nn
from torch import Tensor
from fairseq import utils
from fairseq.models.transformer import TransformerConfig
from fairseq.modules import LayerNorm, MultiheadAttention
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise
from torch import Tensor
from fairseq.models.transformer import (
TransformerConfig,
)
class TransformerEncoderLayerBase(nn.Module):
@@ -67,103 +68,6 @@ class TransformerEncoderLayerBase(nn.Module):
self.final_layer_norm = LayerNorm(self.embed_dim, export=cfg.export)
self.num_heads = cfg.encoder.attention_heads
self.load_to_BT = False
self.ever_training = False
# For BT, we need contiguous memory
self.in_proj_weight = torch.nn.Parameter(
torch.zeros(
self.self_attn.q_proj.weight.shape[0] * 3,
self.self_attn.q_proj.weight.shape[1],
)
)
self.in_proj_bias = torch.nn.Parameter(
torch.zeros(self.self_attn.q_proj.bias.shape[0] * 3)
)
if (
self.activation_fn is torch.nn.functional.relu
or isinstance(self.activation_fn, torch.nn.ReLU)
or self.activation_fn == "relu"
):
self.activation_relu_or_gelu = 1
elif (
self.activation_fn is torch.nn.functional.gelu
or isinstance(self.activation_fn, torch.nn.GELU)
or self.activation_fn == "gelu"
):
self.activation_relu_or_gelu = 2
else:
self.activation_relu_or_gelu = 0
# Batch-first layout cannot be verified here; the caller must ensure it
self.can_use_fastpath = (
not self.normalize_before
and self.activation_relu_or_gelu
and (self.self_attn_layer_norm.eps == self.final_layer_norm.eps)
)
self.cfg_checkpoint_activations = self.cfg.checkpoint_activations
# torch version check
# make sure BT version is >=1.12.0
self.BT_version = False
if "fb" in torch.__version__:
self.BT_version = True
else:
if "+" in torch.__version__:
self.torch_version = torch.__version__.split("+")[0]
else:
self.torch_version = torch.__version__
self.torch_version = self.torch_version.split(".")
self.int_version = (
int(self.torch_version[0]) * 1000
+ int(self.torch_version[1]) * 10
+ int(self.torch_version[2])
)
if len(self.torch_version) == 3:
if self.int_version >= 1120:
self.BT_version = True
elif len(self.torch_version) == 4:
if self.int_version >= 1130:
self.BT_version = True
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
self.load_to_BT = True
old_name = prefix + "self_attn."
q_proj_weight = state_dict[old_name + "q_proj.weight"]
k_proj_weight = state_dict[old_name + "k_proj.weight"]
v_proj_weight = state_dict[old_name + "v_proj.weight"]
q_proj_bias = state_dict[old_name + "q_proj.bias"]
k_proj_bias = state_dict[old_name + "k_proj.bias"]
v_proj_bias = state_dict[old_name + "v_proj.bias"]
new_name = prefix
state_dict[new_name + "in_proj_weight"] = torch.cat(
(q_proj_weight, k_proj_weight, v_proj_weight), dim=0
)
state_dict[new_name + "in_proj_bias"] = torch.cat(
(q_proj_bias, k_proj_bias, v_proj_bias), dim=0
)
super(TransformerEncoderLayerBase, self)._load_from_state_dict(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
)
def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
return quant_noise(
nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size
@@ -283,83 +187,44 @@ class TransformerEncoderLayerBase(nn.Module):
# Note that we cannot use -inf here, because at some edge cases,
# the attention weight (before softmax) for some padded element in query
# will become -inf, which results in NaN in model parameters
if self.training:
self.ever_training = True
if (
self.BT_version
and x.dim() == 3
and self.load_to_BT
and not self.return_fc
and self.can_use_fastpath
and not self.training
and not self.ever_training
and not self.cfg_checkpoint_activations
):
# assumes the input is batch-first and may be a nested tensor
output = torch._transformer_encoder_layer_fwd(
x,
self.embed_dim,
self.num_heads,
self.in_proj_weight,
self.in_proj_bias,
self.self_attn.out_proj.weight,
self.self_attn.out_proj.bias,
self.activation_relu_or_gelu == 2,
False, # norm_first, currently not supported
self.self_attn_layer_norm.eps,
self.self_attn_layer_norm.weight,
self.self_attn_layer_norm.bias,
self.final_layer_norm.weight,
self.final_layer_norm.bias,
self.fc1.weight,
self.fc1.bias,
self.fc2.weight,
self.fc2.bias,
encoder_padding_mask if encoder_padding_mask is not None else attn_mask,
if attn_mask is not None:
attn_mask = attn_mask.masked_fill(
attn_mask.to(torch.bool), -1e8 if x.dtype == torch.float32 else -1e4
)
return output
else:
if attn_mask is not None:
attn_mask = attn_mask.masked_fill(
attn_mask.to(torch.bool), -1e8 if x.dtype == torch.float32 else -1e4
)
residual = x
if self.normalize_before:
x = self.self_attn_layer_norm(x)
x, _ = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=encoder_padding_mask,
need_weights=False,
attn_mask=attn_mask,
)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
residual = x
if self.normalize_before:
x = self.self_attn_layer_norm(x)
x, _ = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=encoder_padding_mask,
need_weights=False,
attn_mask=attn_mask,
)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
residual = x
if self.normalize_before:
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
x = self.activation_dropout_module(x)
x = self.fc2(x)
residual = x
if self.normalize_before:
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
x = self.activation_dropout_module(x)
x = self.fc2(x)
fc_result = x
fc_result = x
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.final_layer_norm(x)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.final_layer_norm(x)
if self.return_fc and not torch.jit.is_scripting():
return x, fc_result
return x
if self.return_fc and not torch.jit.is_scripting():
return x, fc_result
return x
# backward compatible with the legacy argparse format
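
For reference, the removed `_load_from_state_dict` hook above builds the fused projections that `torch._transformer_encoder_layer_fwd` consumes by concatenating fairseq's separate q/k/v parameters along dim 0. A standalone sketch of that rewrite (shapes are illustrative; `E` is a made-up embedding dim):

```
import torch

E = 8  # hypothetical embedding dim, for illustration only
q_w, k_w, v_w = (torch.randn(E, E) for _ in range(3))
q_b, k_b, v_b = (torch.randn(E) for _ in range(3))

# Fused parameters expected by the fast path: (3E, E) weight and (3E,) bias.
in_proj_weight = torch.cat((q_w, k_w, v_w), dim=0)
in_proj_bias = torch.cat((q_b, k_b, v_b), dim=0)
assert in_proj_weight.shape == (3 * E, E) and in_proj_bias.shape == (3 * E,)
```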

View File

@@ -1,380 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import sys
import click
import numpy as np
import torch
from fvcore.nn import FlopCountAnalysis
from fairseq.models.transformer import TransformerConfig as FairseqTransformerConfig
from fairseq.models.transformer import TransformerEncoder as FairseqTransformerEncoder
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)
def benchmark_torch_function(iters, f, *args, **kwargs):
f(*args, **kwargs)
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for _ in range(iters):
f(*args, **kwargs)
end_event.record()
torch.cuda.synchronize()
return (start_event.elapsed_time(end_event) * 1.0e-3) / iters
def numerical_test(lengths, truth_tensors, test_list):
"""
truth_tensors is the source of truth.
test_dict looks like
[
(name, out_tensors, atol, rtol),
...
]
"""
for name, out_tensors, rtol, atol in test_list:
n_failures = 0
max_diff = 0
for (length, truth, out) in zip(lengths, truth_tensors, out_tensors):
cut_truth = truth[:length]
cut_out = out[:length]
max_diff = max(max_diff, torch.max(torch.abs(cut_truth - cut_out)))
if not torch.allclose(cut_truth, cut_out, atol=atol, rtol=rtol):
n_failures += 1
if n_failures == 0:
print(f"{name} PASS")
else:
print(f"{name} FAIL {n_failures}/{len(lengths)}. Max diff is {max_diff}")
@click.group()
def cli():
pass
@cli.command()
@click.option("--save", is_flag=True, default=False)
@click.option("--load", is_flag=True, default=False)
@click.option("--half", is_flag=True, default=False)
@click.option("--bt2fairseq", is_flag=True, default=False)
def transformer(
save,
load,
half,
bt2fairseq,
):
xlarge = False
large = False
DEFAULT_PADDING_IDX = 1
avg_sequence_length = 128
max_sequence_length = 256
batch_size = 64
class FairseqEncoder(torch.nn.Module):
def __init__(
self,
embed_dim,
attention_heads,
ffn_embed_dim,
num_layers,
embedding_layer, # torch.nn.Embedding. Must have a padding_idx field
dropout=0,
normalize_before=False,
torch_encoder=None, # torch encoder that you can map weights from
activation="relu",
):
super().__init__()
cfg = FairseqTransformerConfig()
cfg.encoder.embed_dim = embed_dim
cfg.encoder.attention_heads = attention_heads
cfg.encoder.ffn_embed_dim = ffn_embed_dim
cfg.dropout = dropout
cfg.encoder.normalize_before = normalize_before
cfg.encoder.layers = num_layers
# make embedding behavior same as other encoders
cfg.no_token_positional_embeddings = True
cfg.no_scale_embedding = True
cfg.activation_fn = activation
dictionary = {} # TODO: verify what this is
self.encoder = FairseqTransformerEncoder(
cfg, dictionary, embedding_layer, return_fc=False
)
if torch_encoder is not None:
for src_layer, dst_layer in zip(
torch_encoder.layers, self.encoder.layers
):
w_q, w_k, w_v = src_layer.self_attn.in_proj_weight.chunk(3, dim=0)
b_q, b_k, b_v = src_layer.self_attn.in_proj_bias.chunk(3, dim=0)
dst_layer.self_attn.q_proj.weight = torch.nn.Parameter(w_q)
dst_layer.self_attn.q_proj.bias = torch.nn.Parameter(b_q)
dst_layer.self_attn.k_proj.weight = torch.nn.Parameter(w_k)
dst_layer.self_attn.k_proj.bias = torch.nn.Parameter(b_k)
dst_layer.self_attn.v_proj.weight = torch.nn.Parameter(w_v)
dst_layer.self_attn.v_proj.bias = torch.nn.Parameter(b_v)
dst_layer.self_attn.out_proj.weight = (
src_layer.self_attn.out_proj.weight
)
dst_layer.self_attn.out_proj.bias = (
src_layer.self_attn.out_proj.bias
)
dst_layer.fc1.weight = src_layer.linear1.weight
dst_layer.fc1.bias = src_layer.linear1.bias
# fairseq may use FusedLayerNorm from NVIDIA apex, which has different properties
dst_layer.self_attn_layer_norm.load_state_dict(
src_layer.norm1.state_dict()
)
dst_layer.fc2.weight = src_layer.linear2.weight
dst_layer.fc2.bias = src_layer.linear2.bias
dst_layer.final_layer_norm.load_state_dict(
src_layer.norm2.state_dict()
)
# self.encoder = self.encoder.eval().cuda().half()
def forward(self, tokens, src_lengths=None):
return self.encoder(
tokens,
src_lengths=src_lengths,
return_all_hiddens=False,
token_embeddings=None,
)
def get_layers_embedding_dim_num_heads_for_configuration(xlarge, large):
if xlarge:
# XLM-R extra large (no BERT-XL exists)
L = 24 # Layers
D = 2560 # Embedding Dim
H = 32 # Number of Heads
FD = 10240 # Feed-forward network dim
V = 30000 # Vocab Size
elif large:
# BERT-large
L = 24
D = 1024
H = 16
FD = 4096
V = 30000
else:
# BERT-base
L = 12
D = 768
H = 12
FD = 3072
V = 30000
return (L, D, H, FD, V)
# Better transformer
class PTTransformer(torch.nn.Module):
def __init__(self, transformer, embedding):
super().__init__()
self.transformer = transformer
self.embedding = embedding
self.padding_idx = DEFAULT_PADDING_IDX
def forward(self, x):
padding_mask = None
if not x.is_nested:
padding_mask = x.eq(self.padding_idx)
x = self.embedding(x)
return self.transformer(x, src_key_padding_mask=padding_mask)
def make_transformer():
return (
PTTransformer(
torch.nn.TransformerEncoder(
torch.nn.TransformerEncoderLayer(
d_model=D,
nhead=H,
dim_feedforward=FD,
batch_first=True,
activation="relu",
),
num_layers=L,
enable_nested_tensor=False,
),
embedding_layer,
)
.eval()
.cuda()
)
def copy_weights(layers_fairseq, layers_bt):
for src_layer, dst_layer in zip(layers_fairseq, layers_bt):
w_q = src_layer.self_attn.q_proj.weight
b_q = src_layer.self_attn.q_proj.bias
w_k = src_layer.self_attn.k_proj.weight
b_k = src_layer.self_attn.k_proj.bias
w_v = src_layer.self_attn.v_proj.weight
b_v = src_layer.self_attn.v_proj.bias
dst_layer.self_attn.in_proj_weight = torch.nn.Parameter(
torch.cat((w_q, w_k, w_v), dim=0)
)
dst_layer.self_attn.in_proj_bias = torch.nn.Parameter(
torch.cat((b_q, b_k, b_v), dim=0)
)
dst_layer.self_attn.out_proj.weight = src_layer.self_attn.out_proj.weight
dst_layer.self_attn.out_proj.bias = src_layer.self_attn.out_proj.bias
dst_layer.linear1.weight = src_layer.fc1.weight
dst_layer.linear1.bias = src_layer.fc1.bias
dst_layer.linear2.weight = src_layer.fc2.weight
dst_layer.linear2.bias = src_layer.fc2.bias
dst_layer.norm1.weight = src_layer.self_attn_layer_norm.weight
dst_layer.norm1.bias = src_layer.self_attn_layer_norm.bias
dst_layer.norm2.weight = src_layer.final_layer_norm.weight
dst_layer.norm2.bias = src_layer.final_layer_norm.bias
(L, D, H, FD, V) = get_layers_embedding_dim_num_heads_for_configuration(
xlarge, large
)
embedding_layer = torch.nn.Embedding(V, D, DEFAULT_PADDING_IDX)
# True means BT as source and fairseq is target, False means the other way
# mode1 = False
if bt2fairseq:
# BT as source and fairseq is target, copy BT's weight to fairseq
transformer = make_transformer()
fairseq_transformer = (
FairseqEncoder(
D,
H,
FD,
L,
embedding_layer,
dropout=0,
normalize_before=False,
torch_encoder=transformer.transformer,
activation="relu",
)
.eval()
.cuda()
)
if half:
transformer.half()
fairseq_transformer.half()
if not bt2fairseq:
# the other way around: fairseq is the source and BT is the target; copy fairseq's weights to BT
transformer = make_transformer()
fairseq_transformer = (
FairseqEncoder(
D,
H,
FD,
L,
embedding_layer,
dropout=0,
normalize_before=False,
torch_encoder=None,
activation="relu",
)
.eval()
.cuda()
)
# For the test where we need to load an existing ckpt: after loading the ckpt,
# the fairseq_transformer (BT kernel) results have been verified to match BT
if half:
transformer.half()
fairseq_transformer.half()
if save:
torch.save(fairseq_transformer.state_dict(), "./fairseq.pt")
sys.exit(0)
if load:
fairseq_transformer.load_state_dict(torch.load("./fairseq.pt"))
# copy
copy_weights(fairseq_transformer.encoder.layers, transformer.transformer.layers)
device = "cuda"
lengths = (avg_sequence_length,) * batch_size
tokens = torch.full(
(batch_size, max_sequence_length),
DEFAULT_PADDING_IDX,
device=device,
dtype=torch.long,
)
for i in range(batch_size):
tokens[i, : lengths[i]] = torch.randint(
DEFAULT_PADDING_IDX + 1,
V - 1,
size=(lengths[i],),
device=device,
dtype=torch.long,
)
# mask
if half:
lengths_tensor = torch.Tensor(lengths).cuda().half()
else:
lengths_tensor = torch.Tensor(lengths).cuda()
with torch.inference_mode():
fs_output = fairseq_transformer(tokens, lengths_tensor)["encoder_out"][0]
fs_output = fs_output.transpose(0, 1)
with torch.inference_mode():
t_output = transformer(tokens)
test_lst = [
# (name, output, relative tolerance, absolute tolerance)
("FS", fs_output, 1e-4, 9e-3),
]
numerical_test(lengths, t_output, test_lst)
iters = 100
t = benchmark_torch_function(iters, transformer, tokens)
def bert_flops(B, T, D, L):
mlp = 2 * (B * T * D * 4 * D) + 2 * (B * T * D * 4 * D)
qkv = 3 * 2 * B * T * D * D
attn = 2 * B * D * T * T + 2 * B * D * T * T + 2 * B * T * D * D
return L * (mlp + qkv + attn)
flops = bert_flops(batch_size, avg_sequence_length, D, L)
flops_e = (
FlopCountAnalysis(transformer, (tokens[:, :avg_sequence_length])).total() * 2
)
with torch.inference_mode():
bt = benchmark_torch_function(iters, transformer, tokens)
fst = benchmark_torch_function(
iters, fairseq_transformer, tokens, lengths_tensor
)
def metrics(tt, baseline=None):
if baseline:
return metrics(tt) + f", Speedup: {baseline / tt:.2f}x"
return f"{tt * 1.0e3:.2f} ms/iter, {flops_e / tt / 1.0e12:.2f} TFLOP/s"
results = [
f"Seed: {seed}",
f"Padded tokens: {(1-sum(lengths)/(tokens.numel()))*100:.2f}%",
f"Batch shape: {tokens.shape}",
f"Analytical flops per batch: {flops/ batch_size / 1e9:.2f} GFLOPS",
f"Empirical flops per batch: {flops_e/ batch_size / 1e9:.2f} GFLOPS",
f"B: {batch_size}",
f"T: {avg_sequence_length}",
f"TMax: {max_sequence_length}",
f"Eager Time: {metrics(t)}",
f"BetterTransformer: {metrics(bt, t)}",
f"FST: {metrics(fst, t)}",
]
print("===========Speedup Results")
print("; ".join(results))
if __name__ == "__main__":
cli()
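
For reference, the deleted benchmark above exposes a click CLI with a single `transformer` command and `--save/--load/--half/--bt2fairseq` flags. A hedged usage sketch (the import path is made up; point it at wherever the script actually lives):

```
from click.testing import CliRunner

# Hypothetical import path for the deleted benchmark script.
from scripts.benchmark_bt_transformer import cli

# Equivalent to running `python <script>.py transformer --half` on the command line;
# add --bt2fairseq to copy BT weights into fairseq instead of the default direction.
result = CliRunner().invoke(cli, ["transformer", "--half"])
print(result.output)
```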

View File

@@ -92,35 +92,8 @@ class TestExportModels(unittest.TestCase):
scripted = torch.jit.script(module)
_test_save_and_load(scripted)
def version_check():
# check Nested Tensor available. Make sure version >= '1.13.0.dev20220613'
if "fb" in torch.__version__:
return False
else:
if "+" in torch.__version__:
torch_version = torch.__version__.split("+")[0]
else:
torch_version = torch.__version__
torch_version = torch_version.split(".")
int_version = (
int(torch_version[0]) * 1000
+ int(torch_version[1]) * 10
+ int(torch_version[2])
)
if len(torch_version) == 3:
if int_version >= 1131:
return False
elif len(torch_version) == 4:
if int_version >= 1131 or (
int_version == 1130 and torch_version[3][3:] >= "20220613"
):
return False
return True
@unittest.skipIf(
version_check(),
"Targeting OSS scriptability for the 1.13.0.dev20220613 release",
torch.__version__ < "1.6.0", "Targeting OSS scriptability for the 1.6 release"
)
def test_export_transformer(self):
task, parser = get_dummy_task_and_parser()
@@ -131,8 +104,7 @@ class TestExportModels(unittest.TestCase):
_test_save_and_load(scripted)
@unittest.skipIf(
version_check(),
"Targeting OSS scriptability for the 1.13.0.dev20220613 release",
torch.__version__ < "1.6.0", "Targeting OSS scriptability for the 1.6 release"
)
def test_export_transformer_no_token_pos_emb(self):
task, parser = get_dummy_task_and_parser()

View File

@@ -22,33 +22,6 @@ from fairseq.tasks.fairseq_task import LegacyFairseqTask
DEFAULT_TEST_VOCAB_SIZE = 100
def version_check():
# check Nested Tensor available. Make sure version >= '1.13.0.dev20220613'
if "fb" in torch.__version__:
return False
else:
if "+" in torch.__version__:
torch_version = torch.__version__.split("+")[0]
else:
torch_version = torch.__version__
torch_version = torch_version.split(".")
int_version = (
int(torch_version[0]) * 1000
+ int(torch_version[1]) * 10
+ int(torch_version[2])
)
if len(torch_version) == 3:
if int_version >= 1131:
return False
elif len(torch_version) == 4:
if int_version >= 1131 or (
int_version == 1130 and torch_version[3][3:] >= "20220613"
):
return False
return True
class DummyTask(LegacyFairseqTask):
def __init__(self, args):
super().__init__(args)
@@ -140,9 +113,7 @@ class TestJitSequenceGeneratorBase(unittest.TestCase):
JIT_MSG = "Targeting OSS scriptability for the 1.6 release"
@unittest.skipIf(
version_check(), "Targeting OSS scriptability for the 1.13.0.dev20220613 release"
)
@unittest.skipIf(torch.__version__ < "1.6.0", JIT_MSG)
class TestJitSequenceGenerator(TestJitSequenceGeneratorBase):
def test_export_transformer(self):
model = self.transformer_model