Differential Transformer Attention mechanism for Meta Llama 3.2 1B (Instruct).
'''
Differential Transformer Attention mechanism for Llama models.
This implementation replaces the standard LlamaSdpaAttention with a differential attention
mechanism that computes attention scores as the difference between two separate softmax
attention maps. This approach helps reduce noise and creates sparse attention patterns,
leading to improved performance in various NLP tasks.
Implementation based on research by:
Ye, T., Dong, L., Xia, Y., Sun, Y., Zhu, Y., Huang, G., & Wei, F. (2024)
"Differential Transformer"
Microsoft Research Technical Report MSR-TR-2024-42
Key Features:
- Dual softmax attention computation
- Noise cancellation through subtraction
- Enhanced sparse attention patterns
- Improved long-context modeling
- Better key information retrieval
- Reduced hallucination in generation tasks
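Mechanism (following the paper's formulation; d below is the per-map head dimension,
i.e. head_dim // 2 in this implementation, and layer_idx is 0-based):
    DiffAttn(X) = (softmax(Q1 K1^T / sqrt(d)) - lambda * softmax(Q2 K2^T / sqrt(d))) V
    lambda      = exp(lambda_q1 . lambda_k1) - exp(lambda_q2 . lambda_k2) + lambda_init
    lambda_init = 0.8 - 0.6 * exp(-0.3 * layer_idx)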
Args:
config (LlamaConfig): Model configuration containing attention parameters
layer_idx (int): Index of the current layer
Attributes:
    hidden_size (int): Dimension of the hidden states
    num_heads (int): Number of attention heads
    num_key_value_heads (int): Number of key/value heads (grouped-query attention)
    head_dim (int): Dimension of each attention head
    max_position_embeddings (int): Maximum sequence length
Methods:
    forward(
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]
Note:
    To swap the standard LlamaSdpaAttention for this implementation, use a helper such as
    the 'replace_attention_layers' function (a sketch is provided after the class below).
Example:
>>> config = LlamaConfig.from_pretrained("meta-llama/Llama-3.2-1B")
>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
>>> replace_attention_layers(model, LlamaDiffSdpaAttention)
References:
[1] Original implementation: https://github.com/microsoft/unilm/tree/master/Diff-Transformer
'''
import os
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaRMSNorm,
    LlamaRotaryEmbedding,
    LlamaForCausalLM,
    LlamaConfig,
    apply_rotary_pos_emb,
    repeat_kv,
    Cache,
)
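

# Layer-dependent lambda initialization from the Differential Transformer paper
# (lambda_init = 0.8 - 0.6 * exp(-0.3 * depth), with depth the 0-based layer index).
# This helper is not part of transformers; it mirrors the reference implementation [1]
# and is required by LlamaDiffSdpaAttention.__init__ below.
def lambda_init_fn(depth: int) -> float:
    return 0.8 - 0.6 * math.exp(-0.3 * depth)
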
class LlamaDiffSdpaAttention(LlamaAttention):
def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
super().__init__(config, layer_idx)
# Basic dimensions from Meta-Llama 3.2 1B config
self.hidden_size = config.hidden_size # 2048
self.num_heads = config.num_attention_heads # 32
self.num_key_value_heads = config.num_key_value_heads # 8
self.num_key_value_groups = self.num_heads // self.num_key_value_heads # 32 // 8 = 4
        self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)  # 64
self.scaling = self.head_dim ** -0.5
# Projections with correct dimensions
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
# Lambda parameters
self.layer_idx = layer_idx if layer_idx is not None else 0
self.lambda_init = lambda_init_fn(self.layer_idx)
self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim // 2, dtype=torch.bfloat16).normal_(mean=0, std=0.1))
self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim // 2, dtype=torch.bfloat16).normal_(mean=0, std=0.1))
self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim // 2, dtype=torch.bfloat16).normal_(mean=0, std=0.1))
self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim // 2, dtype=torch.bfloat16).normal_(mean=0, std=0.1))
        # RoPE embedding; rope scaling (Llama 3's rope_scaling dict) is picked up from the config
        self.rotary_emb = LlamaRotaryEmbedding(config=config)
        self.subln = LlamaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
cos = cos.to(dtype=hidden_states.dtype)
sin = sin.to(dtype=hidden_states.dtype)
        # Apply RoPE to the full head dimension; apply_rotary_pos_emb rotates q and k together
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
        if past_key_value is not None:
            # Cache the full (un-chunked) key/value states so the standard Cache layout is kept
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )
        # Split after RoPE (and after the cache update): each half of the head dim drives one softmax map
        q1, q2 = query_states.chunk(2, dim=-1)
        k1, k2 = key_states.chunk(2, dim=-1)
        k1 = repeat_kv(k1, self.num_key_value_groups)
        k2 = repeat_kv(k2, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        # SDPA's default scale is 1 / sqrt(E) with E = head_dim // 2 (the last dim of q1/q2),
        # which already matches the per-map scaling of the paper, so the queries are not
        # pre-scaled here (manual pre-scaling plus the default scale would double-scale).
        # SDPA also rejects an explicit attn_mask combined with is_causal=True, so the
        # causal flag is only used when no mask is provided.
        causal_mask = attention_mask
        if causal_mask is not None:
            causal_mask = causal_mask[:, :, :, : k1.shape[-2]]
        is_causal = causal_mask is None and q_len > 1
        attn1 = F.scaled_dot_product_attention(
            q1, k1, value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )
        attn2 = F.scaled_dot_product_attention(
            q2, k2, value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )
        # Reparameterized lambda from the paper: lambda = exp(lq1 . lk1) - exp(lq2 . lk2) + lambda_init
        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1)).to(hidden_states.dtype)
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1)).to(hidden_states.dtype)
        lambda_full = lambda_1 - lambda_2 + self.lambda_init
        # Differential attention: subtract the second map to cancel common-mode attention noise
        attn_output = attn1 - lambda_full * attn2
        # Per-head RMSNorm (subln), then rescale by (1 - lambda_init) as in the paper
        attn_output = self.subln(attn_output)
        attn_output = attn_output * (1 - self.lambda_init)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
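

# Minimal sketch of the 'replace_attention_layers' helper referenced in the docstring.
# It is not defined elsewhere in this gist; the version below swaps each decoder layer's
# self_attn module and copies over the pretrained projection weights, which assumes the
# replacement class keeps the q/k/v/o projection shapes of the original attention
# (as LlamaDiffSdpaAttention does). Treat it as an illustrative starting point.
def replace_attention_layers(model: LlamaForCausalLM, attn_cls) -> None:
    for layer_idx, layer in enumerate(model.model.layers):
        old_attn = layer.self_attn
        new_attn = attn_cls(model.config, layer_idx=layer_idx)
        # Reuse the pretrained projection weights so the swap starts from the base model
        new_attn.q_proj.load_state_dict(old_attn.q_proj.state_dict())
        new_attn.k_proj.load_state_dict(old_attn.k_proj.state_dict())
        new_attn.v_proj.load_state_dict(old_attn.v_proj.state_dict())
        new_attn.o_proj.load_state_dict(old_attn.o_proj.state_dict())
        layer.self_attn = new_attn.to(
            dtype=next(old_attn.parameters()).dtype,
            device=next(old_attn.parameters()).device,
        )


# Example usage (mirrors the docstring example). "meta-llama/Llama-3.2-1B" is a gated
# checkpoint, so Hub access to the Meta Llama weights is assumed.
if __name__ == "__main__":
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3.2-1B", torch_dtype=torch.bfloat16
    )
    replace_attention_layers(model, LlamaDiffSdpaAttention)
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
    inputs = tokenizer("Differential attention computes", return_tensors="pt")
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=20)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))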