mirror of
https://github.com/facebookresearch/fairseq.git
synced 2024-07-14 18:50:22 +03:00
Make RotaryPositionalEmbedding jit-compatible (#5237)
This commit is contained in:
parent
31fba013a0
commit
100cd91db1
@ -14,27 +14,26 @@ class RotaryPositionalEmbedding(torch.nn.Module):
|
||||
super().__init__()
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
||||
self.register_buffer("inv_freq", inv_freq)
|
||||
self.seq_len_cached = None
|
||||
self.cos_cached = None
|
||||
self.sin_cached = None
|
||||
self.seq_len_cached = 0
|
||||
self.cos_cached = torch.empty(self.seq_len_cached, 1, 1, dim)
|
||||
self.sin_cached = torch.empty(self.seq_len_cached, 1, 1, dim)
|
||||
self.precision = precision
|
||||
|
||||
def forward(self, x, seq_len=None):
|
||||
def forward(self, x, seq_len: int = 0):
|
||||
"""
|
||||
Args:
|
||||
x: Input x with T X B X C
|
||||
seq_len: Sequence length of input x
|
||||
"""
|
||||
if seq_len != self.seq_len_cached:
|
||||
if seq_len > self.seq_len_cached:
|
||||
self.seq_len_cached = seq_len
|
||||
t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
|
||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
|
||||
self.cos_cached = emb.cos()[:, None, None, :]
|
||||
self.sin_cached = emb.sin()[:, None, None, :]
|
||||
self.cos_cached = emb.cos().view(emb.size(0), 1, 1, emb.size(1))
|
||||
self.sin_cached = emb.sin().view(emb.size(0), 1, 1, emb.size(1))
|
||||
return self.cos_cached, self.sin_cached
|
||||
|
||||
|
||||
# rotary pos emb helpers:
|
||||
def rotate_half(x):
|
||||
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
|
||||
|
@ -59,6 +59,27 @@ class TestRotaryPositionalEmbedding(unittest.TestCase):
|
||||
)
|
||||
)
|
||||
|
||||
def test_jit_compile_rope_module(self):
|
||||
module_scripted = torch.jit.script(self.rope_pos_emd)
|
||||
apply_rotary_scripted = torch.jit.script(apply_rotary_pos_emb)
|
||||
# Test several different lengths
|
||||
for T in [3, 5, 10]:
|
||||
sample = torch.randn(T, self.B, self.C)
|
||||
# Run forward pass with the original module
|
||||
cos_original, sin_original = self.rope_pos_emd(sample, T)
|
||||
query = sample.view(T, self.B, 1, self.C)
|
||||
new_query, new_key = apply_rotary_pos_emb(query, query, cos_original, sin_original)
|
||||
|
||||
# Run forward pass with the scripted module
|
||||
cos_scripted, sin_scripted = module_scripted(sample, T)
|
||||
new_query_scripted, new_key_scripted = apply_rotary_scripted(query, query, cos_scripted, sin_scripted)
|
||||
|
||||
# Ensure the outputs are the same
|
||||
self.assertTrue(torch.allclose(cos_original, cos_scripted))
|
||||
self.assertTrue(torch.allclose(sin_original, sin_scripted))
|
||||
self.assertTrue(torch.allclose(new_query, new_query_scripted))
|
||||
self.assertTrue(torch.allclose(new_key, new_key_scripted))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user