diff --git a/fairseq/modules/rotary_positional_embedding.py b/fairseq/modules/rotary_positional_embedding.py
index 84b88984..b74028b0 100644
--- a/fairseq/modules/rotary_positional_embedding.py
+++ b/fairseq/modules/rotary_positional_embedding.py
@@ -14,27 +14,26 @@ class RotaryPositionalEmbedding(torch.nn.Module):
         super().__init__()
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
         self.register_buffer("inv_freq", inv_freq)
-        self.seq_len_cached = None
-        self.cos_cached = None
-        self.sin_cached = None
+        self.seq_len_cached = 0
+        self.cos_cached = torch.empty(self.seq_len_cached, 1, 1, dim)
+        self.sin_cached = torch.empty(self.seq_len_cached, 1, 1, dim)
         self.precision = precision
 
-    def forward(self, x, seq_len=None):
+    def forward(self, x, seq_len: int = 0):
         """
         Args:
             x: Input x with T X B X C
             seq_len: Sequence length of input x
         """
-        if seq_len != self.seq_len_cached:
+        if seq_len > self.seq_len_cached:
             self.seq_len_cached = seq_len
             t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
             freqs = torch.einsum("i,j->ij", t, self.inv_freq)
             emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
-            self.cos_cached = emb.cos()[:, None, None, :]
-            self.sin_cached = emb.sin()[:, None, None, :]
+            self.cos_cached = emb.cos().view(emb.size(0), 1, 1, emb.size(1))
+            self.sin_cached = emb.sin().view(emb.size(0), 1, 1, emb.size(1))
         return self.cos_cached, self.sin_cached
 
-
 # rotary pos emb helpers:
 def rotate_half(x):
     x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
diff --git a/tests/test_rotary_positional_embedding.py b/tests/test_rotary_positional_embedding.py
index ea9f0bee..7c44e86d 100644
--- a/tests/test_rotary_positional_embedding.py
+++ b/tests/test_rotary_positional_embedding.py
@@ -59,6 +59,27 @@ class TestRotaryPositionalEmbedding(unittest.TestCase):
             )
         )
 
+    def test_jit_compile_rope_module(self):
+        module_scripted = torch.jit.script(self.rope_pos_emd)
+        apply_rotary_scripted = torch.jit.script(apply_rotary_pos_emb)
+        # Test several different lengths
+        for T in [3, 5, 10]:
+            sample = torch.randn(T, self.B, self.C)
+            # Run forward pass with the original module
+            cos_original, sin_original = self.rope_pos_emd(sample, T)
+            query = sample.view(T, self.B, 1, self.C)
+            new_query, new_key = apply_rotary_pos_emb(query, query, cos_original, sin_original)
+
+            # Run forward pass with the scripted module
+            cos_scripted, sin_scripted = module_scripted(sample, T)
+            new_query_scripted, new_key_scripted = apply_rotary_scripted(query, query, cos_scripted, sin_scripted)
+
+            # Ensure the outputs are the same
+            self.assertTrue(torch.allclose(cos_original, cos_scripted))
+            self.assertTrue(torch.allclose(sin_original, sin_scripted))
+            self.assertTrue(torch.allclose(new_query, new_query_scripted))
+            self.assertTrue(torch.allclose(new_key, new_key_scripted))
+
 
 if __name__ == "__main__":
     unittest.main()
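
Note: the first hunk makes RotaryPositionalEmbedding compatible with torch.jit.script, which requires module attributes to keep one concrete type; the original None placeholders would be inferred as None/Optional and typically fail to script when reassigned to tensors. Below is a minimal, self-contained sketch of the same pattern for illustration only; the class name ScriptableRoPE and the dim/shape values are made up for this example and are not part of fairseq.

    import torch


    class ScriptableRoPE(torch.nn.Module):
        """Illustrative stand-alone copy of the patched caching pattern."""

        def __init__(self, dim: int, base: int = 10000):
            super().__init__()
            inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
            self.register_buffer("inv_freq", inv_freq)
            # Attributes keep one concrete type for TorchScript: an int
            # counter and tensors (seq_len_cached is 0 here, so the cache
            # tensors start out empty) instead of None placeholders.
            self.seq_len_cached = 0
            self.cos_cached = torch.empty(self.seq_len_cached, 1, 1, dim)
            self.sin_cached = torch.empty(self.seq_len_cached, 1, 1, dim)

        def forward(self, x: torch.Tensor, seq_len: int = 0):
            # Rebuild the cache only when a longer sequence arrives.
            if seq_len > self.seq_len_cached:
                self.seq_len_cached = seq_len
                t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
                freqs = torch.einsum("i,j->ij", t, self.inv_freq)
                emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
                # view() is the diff's replacement for the original
                # [:, None, None, :] indexing; it adds the two broadcast dims.
                self.cos_cached = emb.cos().view(emb.size(0), 1, 1, emb.size(1))
                self.sin_cached = emb.sin().view(emb.size(0), 1, 1, emb.size(1))
            return self.cos_cached, self.sin_cached


    scripted = torch.jit.script(ScriptableRoPE(dim=64))
    cos, sin = scripted(torch.randn(10, 2, 64), 10)
    print(cos.shape)  # torch.Size([10, 1, 1, 64])

With typed initial values and the view()-based reshape, torch.jit.script compiles the module, and the cached forward produces the same cos/sin tables in eager and scripted mode, which is what the new test asserts across several sequence lengths.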