Skip to content

Instantly share code, notes, and snippets.

@rust-play
Created December 20, 2025 14:27
Show Gist options
  • Select an option

  • Save rust-play/0cfa5f6e6b0f693be8048e637be0f179 to your computer and use it in GitHub Desktop.

Select an option

Save rust-play/0cfa5f6e6b0f693be8048e637be0f179 to your computer and use it in GitHub Desktop.
Code shared from the Rust Playground
use ndarray::{Array1, Array2, Axis}; // ndarray 0.16.1
use num_traits::Float; // num-traits 0.2.19
/// Root Mean Square layer normalization (RMSNorm).
///
/// Rescales each row of a `(batch, dim)` input by the reciprocal of its
/// root-mean-square, then multiplies element-wise by `weight`.
#[derive(Debug, Clone)]
pub struct RmsNorm {
    /// Learnable per-feature scale; `RmsNorm::new` creates one entry per
    /// hidden unit, initialized to 1.0.
    pub weight: Array1<f32>,
    /// Constant added to the mean of squares before the square root, so an
    /// all-zero row does not cause a division by zero (when positive).
    pub eps: f32,
}
impl RmsNorm {
    /// Creates an RMSNorm layer over a hidden dimension of size `dim`.
    ///
    /// The learnable scale `weight` starts as all ones, so a freshly built
    /// layer only rescales each row by its reciprocal RMS. `eps` is added to
    /// the mean of squares before the square root in [`RmsNorm::forward`].
    pub fn new(dim: usize, eps: f32) -> Self {
        Self {
            // Initialize weights to 1.0 (identity scale).
            weight: Array1::ones(dim),
            eps,
        }
    }

    /// Computes RMSNorm over the last axis of a 2D input of shape `(batch, dim)`.
    ///
    /// For each row `x_i` this returns `x_i / sqrt(mean(x_i^2) + eps) * weight`,
    /// where the product with `weight` is element-wise.
    ///
    /// A zero-width input (`dim == 0`) yields an empty `(batch, 0)` output
    /// instead of panicking.
    ///
    /// # Panics
    ///
    /// Panics if `self.weight` cannot be broadcast against a row of `x`
    /// (i.e. its length is neither `x.ncols()` nor 1).
    pub fn forward(&self, x: &Array2<f32>) -> Array2<f32> {
        // 1. Mean of squares per row: (1/n) * Σ x^2, reduced over Axis(1),
        //    the hidden dimension. `mean_axis` returns `None` only when that
        //    axis has length zero; such rows contain no elements, so a zero
        //    placeholder keeps every later step shape-correct without
        //    affecting any output value.
        let ms = x
            .mapv(|v| v.powi(2))
            .mean_axis(Axis(1))
            .unwrap_or_else(|| Array1::zeros(x.nrows()));

        // 2. Reciprocal RMS per row: 1 / sqrt(ms + eps).
        let rrms = ms.mapv(|v| 1.0 / (v + self.eps).sqrt());

        // 3. Normalize: reshape rrms from (batch,) to (batch, 1) so it
        //    broadcasts across every column of its own row.
        let mut normalized = x * &rrms.insert_axis(Axis(1));

        // 4. Scale by the learnable per-feature weight g, broadcast across rows.
        normalized *= &self.weight;
        normalized
    }
}
fn main() {
    // Layer over a hidden size of 4; the epsilon is added under the sqrt.
    let hidden = 4;
    let layer = RmsNorm::new(hidden, 1e-5);

    // Two rows (batch = 2), each carrying `hidden` features.
    let batch = ndarray::array![[1.0, 2.0, 3.0, 4.0], [0.5, 0.5, 0.5, 0.5]];

    println!("Normalized output:\n{:?}", layer.forward(&batch));
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment