Fix tlayer torch JIT export

Fix tlayer torch JIT export exception:
"Could not cast value of type NoneType to bool"
When torch jit exporting, self.need_attn is None.
Fix #4459
This commit is contained in:
William Tambellini 2022-06-27 11:12:03 -07:00
parent cba35cdbca
commit 2c685ca0b4

View File

@ -563,7 +563,6 @@ class TransformerDecoderLayerBase(nn.Module):
"""
if need_head_weights:
need_attn = True
residual = x
if self.normalize_before:
x = self.self_attn_layer_norm(x)
@ -638,6 +637,11 @@ class TransformerDecoderLayerBase(nn.Module):
assert incremental_state is not None
self.encoder_attn._set_input_buffer(incremental_state, saved_state)
need_weights = need_attn
# To prevent "Could not cast value of type NoneType to bool" when torchscript export:
if self.need_attn is not None and self.training is not None:
need_weights = need_attn or (not self.training and self.need_attn)
x, attn = self.encoder_attn(
query=x,
key=encoder_out,
@ -645,7 +649,7 @@ class TransformerDecoderLayerBase(nn.Module):
key_padding_mask=encoder_padding_mask,
incremental_state=incremental_state,
static_kv=True,
need_weights=need_attn or (not self.training and self.need_attn),
need_weights=need_weights,
need_head_weights=need_head_weights,
)
x = self.dropout_module(x)