@jingwangsg
Created October 7, 2024 21:09
rotaryemb_patch
import torch
import torch.nn as nn


class RotaryEmbeddingPatched(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Inverse frequencies for each even dimension: 1 / base^(2i / dim).
        inv_freq = 1.0 / (
            self.base
            ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        # Compute the cos/sin tables in full precision; autocast is disabled so
        # the angles are not computed in fp16/bf16.
        with torch.autocast(device_type=device.type, enabled=False):
            t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).float()
            freqs = torch.outer(t, self.inv_freq.float())
            # Different from the paper, but uses a different permutation
            # in order to obtain the same calculation.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        self.register_buffer("cos_cached", cos.to(dtype), persistent=False)
        self.register_buffer("sin_cached", sin.to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # Build the cache lazily on the first call, and rebuild it whenever
        # seq_len exceeds the cached length.
        if not hasattr(self, "cos_cached") or seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )