Co-authored-by: moslehpour <moslehpour@meta.com>
This commit is contained in:
Mohsen 2022-10-17 13:04:51 -07:00 committed by GitHub
parent 66d713b4d0
commit 05625e3e6e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -97,7 +97,7 @@ class SinusoidalPositionalEmbedding(nn.Module):
if self.onnx_trace:
flat_embeddings = self.weights.detach().index_select(0, positions.view(-1))
embedding_shape = torch.cat(
(bsz.view(1), seq_len.view(1), torch.tensor([-1], dtype=torch.long))
(bsz, seq_len, torch.tensor([-1], dtype=torch.long))
)
embeddings = torch.onnx.operators.reshape_from_tensor_shape(
flat_embeddings, embedding_shape