Make RotaryPositionalEmbedding jit-compatible (#5237)

This commit is contained in:
Egor Lakomkin 2023-07-07 08:08:01 +02:00 committed by GitHub
parent 31fba013a0
commit 100cd91db1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 28 additions and 8 deletions

View File

@ -14,27 +14,26 @@ class RotaryPositionalEmbedding(torch.nn.Module):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
self.seq_len_cached = None
self.cos_cached = None
self.sin_cached = None
self.seq_len_cached = 0
self.cos_cached = torch.empty(self.seq_len_cached, 1, 1, dim)
self.sin_cached = torch.empty(self.seq_len_cached, 1, 1, dim)
self.precision = precision
def forward(self, x, seq_len=None):
def forward(self, x, seq_len: int = 0):
"""
Args:
x: Input x with T X B X C
seq_len: Sequence length of input x
"""
if seq_len != self.seq_len_cached:
if seq_len > self.seq_len_cached:
self.seq_len_cached = seq_len
t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.cos_cached = emb.cos()[:, None, None, :]
self.sin_cached = emb.sin()[:, None, None, :]
self.cos_cached = emb.cos().view(emb.size(0), 1, 1, emb.size(1))
self.sin_cached = emb.sin().view(emb.size(0), 1, 1, emb.size(1))
return self.cos_cached, self.sin_cached
# rotary pos emb helpers:
def rotate_half(x):
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]

View File

@ -59,6 +59,27 @@ class TestRotaryPositionalEmbedding(unittest.TestCase):
)
)
def test_jit_compile_rope_module(self):
module_scripted = torch.jit.script(self.rope_pos_emd)
apply_rotary_scripted = torch.jit.script(apply_rotary_pos_emb)
# Test several different lengths
for T in [3, 5, 10]:
sample = torch.randn(T, self.B, self.C)
# Run forward pass with the original module
cos_original, sin_original = self.rope_pos_emd(sample, T)
query = sample.view(T, self.B, 1, self.C)
new_query, new_key = apply_rotary_pos_emb(query, query, cos_original, sin_original)
# Run forward pass with the scripted module
cos_scripted, sin_scripted = module_scripted(sample, T)
new_query_scripted, new_key_scripted = apply_rotary_scripted(query, query, cos_scripted, sin_scripted)
# Ensure the outputs are the same
self.assertTrue(torch.allclose(cos_original, cos_scripted))
self.assertTrue(torch.allclose(sin_original, sin_scripted))
self.assertTrue(torch.allclose(new_query, new_query_scripted))
self.assertTrue(torch.allclose(new_key, new_key_scripted))
if __name__ == "__main__":
unittest.main()