Differential Transformer Attention mechanism for Meta Llama 3.2 1B (Instruct).
'''
Differential Transformer Attention mechanism for Llama models.
This implementation replaces the standard LlamaSdpaAttention with a differential attention
mechanism that computes attention scores as the difference between two separate softmax
attention maps. Subtracting the two maps cancels common-mode attention noise and yields
sparser attention patterns, which the paper reports improves performance on a range of
language-modeling tasks.
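Concretely, queries and keys are split into two halves (Q1, Q2) and (K1, K2) and the
per-head output is computed as (see [1]):
    DiffAttn(X) = (softmax(Q1 K1^T / sqrt(d)) - lambda * softmax(Q2 K2^T / sqrt(d))) V
    lambda      = exp(lambda_q1 . lambda_k1) - exp(lambda_q2 . lambda_k2) + lambda_init
where d is the per-half head dimension and lambda_init is a depth-dependent constant.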
Implementation based on research by:
    Ye, T., Dong, L., Xia, Y., Sun, Y., Zhu, Y., Huang, G., & Wei, F. (2024)
    "Differential Transformer"
    Microsoft Research Technical Report MSR-TR-2024-42
Key Features:
    - Dual softmax attention computation
    - Noise cancellation through subtraction
    - Enhanced sparse attention patterns
    - Improved long-context modeling
    - Better key information retrieval
    - Reduced hallucination in generation tasks
Args:
    config (LlamaConfig): Model configuration containing attention parameters
    layer_idx (int): Index of the current layer
Attributes:
    hidden_size (int): Hidden dimension of the model
    num_heads (int): Number of query attention heads
    num_key_value_heads (int): Number of key/value heads (grouped-query attention)
    head_dim (int): Dimension of each attention head
Methods:
    forward(
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]
Note:
    To swap this attention mechanism into a pretrained model, use the
    'replace_attention_layers' helper (sketched at the end of this file) to replace
    the standard LlamaSdpaAttention modules with this implementation.
Example:
    >>> config = LlamaConfig.from_pretrained("meta-llama/Llama-3.2-1B")
    >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
    >>> replace_attention_layers(model, LlamaDiffSdpaAttention)
References:
    [1] Original implementation: https://github.com/microsoft/unilm/tree/master/Diff-Transformer
'''
import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaRMSNorm,
    LlamaRotaryEmbedding,
    LlamaForCausalLM,
    LlamaConfig,
    apply_rotary_pos_emb,
    repeat_kv,
    Cache,
)
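

# `lambda_init_fn` is used in __init__ below but was not defined in this gist.
# The depth-dependent initialisation follows the reference implementation [1],
# with depth being the 0-based layer index: lambda_init = 0.8 - 0.6 * exp(-0.3 * depth).
def lambda_init_fn(depth: int) -> float:
    return 0.8 - 0.6 * math.exp(-0.3 * depth)
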
class LlamaDiffSdpaAttention(LlamaAttention):
    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
        super().__init__(config, layer_idx)
        # Basic dimensions from the Meta-Llama 3.2 1B config
        self.hidden_size = config.hidden_size  # 2048
        self.num_heads = config.num_attention_heads  # 32
        self.num_key_value_heads = config.num_key_value_heads  # 8
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads  # 32 // 8 = 4
        self.head_dim = config.head_dim  # 64
        self.scaling = self.head_dim ** -0.5

        # Projections with correct (grouped-query) dimensions
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        # Lambda parameters
        self.layer_idx = layer_idx if layer_idx is not None else 0
        self.lambda_init = lambda_init_fn(self.layer_idx)
        self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim // 2, dtype=torch.bfloat16).normal_(mean=0, std=0.1))
        self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim // 2, dtype=torch.bfloat16).normal_(mean=0, std=0.1))
        self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim // 2, dtype=torch.bfloat16).normal_(mean=0, std=0.1))
        self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim // 2, dtype=torch.bfloat16).normal_(mean=0, std=0.1))
        # RoPE embedding; when a config is passed, LlamaRotaryEmbedding reads the
        # rope_scaling settings (e.g. the Llama 3 "factor") from it directly.
        self.rotary_emb = LlamaRotaryEmbedding(config=config)

        # Per-head RMSNorm applied to the differential attention output ("subln" in [1])
        self.subln = LlamaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)

        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        cos, sin = self.rotary_emb(value_states, position_ids)
        cos = cos.to(dtype=hidden_states.dtype)
        sin = sin.to(dtype=hidden_states.dtype)

        # Apply RoPE to the full heads before splitting them into two halves;
        # apply_rotary_pos_emb expects (q, k, cos, sin) and returns both rotated tensors.
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Split each head into two halves after RoPE
        q1, q2 = query_states.chunk(2, dim=-1)
        k1, k2 = key_states.chunk(2, dim=-1)
        if past_key_value is not None:
            # Stack the two key halves along an extra dim so both share one cache slot;
            # DynamicCache concatenates along dim=-2 (the sequence dim), which still holds
            # for the stacked (bsz, kv_heads, 2, seq, head_dim // 2) layout.
            key_states = torch.stack((k1, k2), dim=2)
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )
            k1, k2 = key_states.unbind(dim=2)

        # Expand the grouped key/value heads to match the number of query heads
        k1 = repeat_kv(k1, self.num_key_value_groups)
        k2 = repeat_kv(k2, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        # SDPA already scales by 1/sqrt(q.size(-1)) = 1/sqrt(head_dim // 2), which is the
        # scaling used for the split heads, so no manual pre-scaling of q1/q2 is needed.
        # SDPA also rejects an explicit attn_mask together with is_causal=True, so rely on
        # is_causal only when no mask is provided (and only during prefill, q_len > 1).
        if attention_mask is not None:
            # Crop the prepared 4D causal mask to the current key length.
            attention_mask = attention_mask[:, :, :, : k1.shape[-2]]
        is_causal = attention_mask is None and q_len > 1

        attn1 = F.scaled_dot_product_attention(
            q1, k1, value_states,
            attn_mask=attention_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )
        attn2 = F.scaled_dot_product_attention(
            q2, k2, value_states,
            attn_mask=attention_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )
        # Re-parameterised lambda: exp(lambda_q1 . lambda_k1) - exp(lambda_q2 . lambda_k2) + lambda_init
        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1)).to(hidden_states.dtype)
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1)).to(hidden_states.dtype)
        lambda_full = lambda_1 - lambda_2 + self.lambda_init

        # Differential attention: subtract the scaled second attention map from the first
        attn_output = attn1 - lambda_full * attn2

        # Per-head RMSNorm followed by the fixed (1 - lambda_init) rescaling from the paper
        attn_output = self.subln(attn_output)
        attn_output = attn_output * (1 - self.lambda_init)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        # SDPA does not expose attention weights, so None is returned for them
        return attn_output, None, past_key_value
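

# The docstring's Note and Example refer to a `replace_attention_layers` helper that is
# not part of this gist. The sketch below is one possible implementation, assuming the
# standard LlamaForCausalLM layout where each decoder layer exposes a `self_attn`
# submodule; pretrained projection weights are copied so the swapped layers start from
# the original model's values, while the lambda parameters keep their random init.
def replace_attention_layers(model, attention_cls=LlamaDiffSdpaAttention):
    config = model.config
    for layer_idx, layer in enumerate(model.model.layers):
        old_attn = layer.self_attn
        new_attn = attention_cls(config, layer_idx=layer_idx).to(
            device=old_attn.q_proj.weight.device, dtype=old_attn.q_proj.weight.dtype
        )
        # Reuse the pretrained q/k/v/o projection weights.
        new_attn.q_proj.load_state_dict(old_attn.q_proj.state_dict())
        new_attn.k_proj.load_state_dict(old_attn.k_proj.state_dict())
        new_attn.v_proj.load_state_dict(old_attn.v_proj.state_dict())
        new_attn.o_proj.load_state_dict(old_attn.o_proj.state_dict())
        layer.self_attn = new_attn
    return model


# Example usage, mirroring the docstring:
#   model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
#   replace_attention_layers(model, LlamaDiffSdpaAttention)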