-
-
Save rust-play/0cfa5f6e6b0f693be8048e637be0f179 to your computer and use it in GitHub Desktop.
Code shared from the Rust Playground
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| use ndarray::{Array1, Array2, Axis}; // ndarray 0.16.1 | |
| use num_traits::Float; // num-traits 0.2.19 | |
| pub struct RmsNorm { | |
| pub weight: Array1<f32>, | |
| pub eps: f32, | |
| } | |
| impl RmsNorm { | |
| pub fn new(dim: usize, eps: f32) -> Self { | |
| Self { | |
| // Initialize weights to 1.0 | |
| weight: Array1::ones(dim), | |
| eps, | |
| } | |
| } | |
| /// Computes RMSNorm over the last axis of a 2D input (Batch, Dim) | |
| pub fn forward(&self, x: &Array2<f32>) -> Array2<f32> { | |
| // 1. Calculate the mean of squares: (1/n) * Σ x^2 | |
| // We sum along Axis(1) which is the hidden dimension | |
| let ms = x.mapv(|v| v.powi(2)).mean_axis(Axis(1)).unwrap(); | |
| // 2. Calculate 1 / sqrt(ms + eps) | |
| let rrms = ms.mapv(|v| 1.0 / (v + self.eps).sqrt()); | |
| // 3. Normalize: x * rrms | |
| // We need to insert an axis to multiply (Batch, Dim) by (Batch, 1) | |
| let mut normalized = x * &rrms.insert_axis(Axis(1)); | |
| // 4. Scale by learnable weight g | |
| normalized *= &self.weight; | |
| normalized | |
| } | |
| } | |
| fn main() { | |
| let dim = 4; | |
| let norm = RmsNorm::new(dim, 1e-5); | |
| // Example input: Batch of 2, Dim of 4 | |
| let input = ndarray::array![ | |
| [1.0, 2.0, 3.0, 4.0], | |
| [0.5, 0.5, 0.5, 0.5] | |
| ]; | |
| let output = norm.forward(&input); | |
| println!("Normalized output:\n{:?}", output); | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment