Skip to content

Instantly share code, notes, and snippets.

@Rexicon226
Last active February 2, 2026 14:55
Show Gist options
  • Select an option

  • Save Rexicon226/72b67f5dc80c4c65e6392316670c50f2 to your computer and use it in GitHub Desktop.

Select an option

Save Rexicon226/72b67f5dc80c4c65e6392316670c50f2 to your computer and use it in GitHub Desktop.
AVX-512 Keccak (mainly for parallel, not serial)
/// Eight 64-bit lanes (512 bits, the width of an AVX-512 register). Each row of the
/// 5x5 Keccak state occupies lanes 0..4 of one of these; lanes 5..7 are padding.
const V: type = @Vector(8, u64);
/// Hashes `bytes` with the Keccak sponge and writes the `r`-byte digest to `out`.
///
/// `r` is the digest length in bytes. The capacity is `2 * r` bytes, so the rate
/// (block size) is `200 - 2 * r` bytes. `delim` is the byte written immediately after
/// the message during padding (the padding/domain-separation byte).
///
/// Only `r = 64` (a 72-byte rate) is implemented; other values fail at compile time.
fn keccak(bytes: []const u8, comptime r: u32, comptime delim: u8, out: *[r]u8) void {
    const rsize = 200 - 2 * r;
    if (rsize != 72) @compileError("keccak: only r = 64 (a 72-byte rate) is implemented"); // TODO

    var x: [5]V = @splat(@splat(0));
    var input = bytes;

    // Absorb full rsize blocks.
    while (input.len >= rsize) {
        input = load(&x, input, rsize);
        keccakF(24, &x);
    }

    // Pad the final partial block: message tail, `delim`, zeros, and a final 0x80.
    // `state` starts zeroed, so no explicit zero fill is needed. When the tail is
    // exactly rsize - 1 bytes long, `delim` and 0x80 share the last byte. (Zero-filling
    // `state[input.len + 1 .. rsize - 1]` here would panic in that case, start > end.)
    var state: [rsize]u8 align(32) = @splat(0);
    @memcpy(state[0..input.len], input);
    state[input.len] = delim;
    state[rsize - 1] |= 0x80;

    // Absorb final block.
    _ = load(&x, &state, rsize);
    keccakF(24, &x);

    // Squeeze: the 64-byte digest is state lanes 0..7, i.e. x[0] lanes 0..4
    // (bytes 0..39) followed by x[1] lanes 0..2 (bytes 40..63). Storing a fourth
    // lane of x[1] would write 8 bytes past the end of `out`.
    mstore(out[0..40], x[0], .{ true, true, true, true, true, false, false, false });
    mstore(out[40..64], x[1], .{ true, true, true, false, false, false, false, false });
}
/// Writes the 64-bit lanes of `src` whose `mask` bit is set into `x` (lane i goes to
/// bytes 8*i .. 8*i+7), using LLVM's masked-store intrinsic. Unselected lanes are not
/// written.
///
/// The intrinsic stores through `x.ptr` directly, so Zig's slice bounds checks do not
/// apply. The assert below restores that check in safe build modes.
fn mstore(x: []u8, src: V, comptime mask: @Vector(8, bool)) void {
    const S = struct {
        extern fn @"llvm.masked.store.v8i64.p0"(V, [*]u8, i32, @Vector(8, bool)) void;
    };
    // Number of bytes the mask can touch: up to and including its highest enabled lane.
    const span = comptime blk: {
        const lanes: [8]bool = mask;
        var end: usize = 0;
        for (lanes, 0..) |enabled, lane| {
            if (enabled) end = (lane + 1) * @sizeOf(u64);
        }
        break :blk end;
    };
    std.debug.assert(x.len >= span);
    // Alignment argument 1: no alignment is assumed for `x`.
    S.@"llvm.masked.store.v8i64.p0"(src, x.ptr, 1, mask);
}
/// Reads the 64-bit lanes selected by `mask` from `x` (lane i comes from bytes
/// 8*i .. 8*i+7), using LLVM's masked-load intrinsic. Unselected lanes are 0
/// (the pass-through value).
///
/// The intrinsic reads through `x.ptr` directly, so Zig's slice bounds checks do not
/// apply. The assert below restores that check in safe build modes.
fn mload(x: []const u8, comptime mask: @Vector(8, bool)) V {
    const S = struct {
        extern fn @"llvm.masked.load.v8i64.p0"([*]const u8, i32, @Vector(8, bool), V) V;
    };
    // Number of bytes the mask can touch: up to and including its highest enabled lane.
    const span = comptime blk: {
        const lanes: [8]bool = mask;
        var end: usize = 0;
        for (lanes, 0..) |enabled, lane| {
            if (enabled) end = (lane + 1) * @sizeOf(u64);
        }
        break :blk end;
    };
    std.debug.assert(x.len >= span);
    // Alignment argument 1: no alignment is assumed for `x`; disabled lanes read as 0.
    return S.@"llvm.masked.load.v8i64.p0"(x.ptr, 1, mask, @splat(0));
}
/// XORs one `rsize`-byte block from the front of `bytes` into the state `x` and
/// returns the bytes that follow that block.
///
/// Byte 8*(i + 5*y) of the block lands in lane i of `x[y]`. Only rates fitting in
/// x[0] lanes 0..4 (40 bytes) or additionally x[1] lanes 0..3 (72 bytes) are
/// implemented. Any other rate panics instead of silently absorbing the wrong bytes.
fn load(x: *[5]V, bytes: []const u8, rsize: u32) []const u8 {
    if (rsize != 40 and rsize != 72) @panic("TODO, assumes r = 64 atm");
    // The masked loads below read through the raw pointer without bounds checks.
    std.debug.assert(bytes.len >= rsize);

    x[0] ^= mload(bytes, .{ true, true, true, true, true, false, false, false });
    if (rsize == 72) {
        x[1] ^= mload(bytes[40..], .{ true, true, true, true, false, false, false, false });
    }
    return bytes[rsize..];
}
/// Returns the @shuffle mask that applies π to the (already ρ-rotated) source row `y`.
///
/// π sends lane (x, y) to (x', y') = (y, 2x + 3y mod 5), so the whole row `y` lands in
/// destination column x' = y. Lane k of the result is the source lane x that ends up in
/// destination row k. Solving k = 2x + 3y (mod 5) gives x = 3k + y (mod 5), because 3 is
/// the inverse of 2 mod 5 and -9 ≡ 1 (mod 5). Lanes 5..7 are unused and set to 0.
fn pi(y: u32) [8]u32 {
    var out: [8]u32 = @splat(0);
    // Unsigned form of 3 * (k - 3y) mod 5: no signed/unsigned operand mixing, so this
    // also works at runtime, and no intermediate `3 * y` overflow.
    for (out[0..5], 0..) |*lane, k| {
        lane.* = @intCast((3 * k + y) % 5);
    }
    return out;
}
/// LLVM intrinsic for the AVX-512 per-lane variable rotate (`vprolvq` on 512-bit
/// vectors). In keccakF it rotates each 64-bit lane of the first operand left by the
/// amount in the matching lane of the second operand (the ρ offsets).
/// NOTE(review): presumably only lowers on targets with AVX-512F enabled — confirm build flags.
extern fn @"llvm.x86.avx512.prolv.q.512"(V, V) V;
/// Short alias for the per-lane variable rotate used by the ρ step.
const rolv = @"llvm.x86.avx512.prolv.q.512";
/// Keccak ι-step round constants, one per round of the 24-round Keccak-f[1600];
/// round `i` XORs RC[i] into lane (0, 0).
const RC = [24]u64{
    // rounds 0..3
    0x0000000000000001, 0x0000000000008082, 0x800000000000808a, 0x8000000080008000,
    // rounds 4..7
    0x000000000000808b, 0x0000000080000001, 0x8000000080008081, 0x8000000000008009,
    // rounds 8..11
    0x000000000000008a, 0x0000000000000088, 0x0000000080008009, 0x000000008000000a,
    // rounds 12..15
    0x000000008000808b, 0x800000000000008b, 0x8000000000008089, 0x8000000000008003,
    // rounds 16..19
    0x8000000000008002, 0x8000000000000080, 0x000000000000800a, 0x800000008000000a,
    // rounds 20..23
    0x8000000080008081, 0x8000000000008080, 0x0000000080000001, 0x8000000080008008,
};
/// Applies the Keccak-p[1600, `rounds`] permutation to the state `A` in place.
///
/// State layout on entry and exit: `A[y]` holds plane `y`, and vector lane `x` holds
/// state lane (x, y). Only vector lanes 0..4 carry state. Lanes 5..7 hold don't-care
/// values that none of the shuffles below ever route back into lanes 0..4.
///
/// Keccak-p[1600, n] runs the *last* n of the 24 rounds (FIPS 202, section 3.3), so a
/// reduced-round call uses RC[24 - n ..]. `rounds == 24` is the full Keccak-f[1600].
///
/// https://keccak.team/files/Keccak-reference-3.0.pdf
fn keccakF(comptime rounds: u32, A: *[5]V) void {
    comptime std.debug.assert(rounds <= RC.len);
    // Index of the first round constant used; 0 for the full 24-round permutation.
    const first_rc = RC.len - rounds;
    // zig fmt: off
    // A = Round[b](A, RC[first_rc + i])
    for (0..rounds) |i| {
        // θ step
        // C[x] = A[x,0] xor A[x,1] xor A[x,2] xor A[x,3] xor A[x,4], for x in 0…4
        const C = A[0] ^ A[1] ^ A[2] ^ A[3] ^ A[4];
        // D[x] = C[x-1] xor rot(C[x+1],1), for x in 0…4
        //
        // Note that "all the operations on the indices are done modulo 5".
        //
        // 1. We slide the C state "downwards" by one element to get the "-1" index.
        // 2. We slide the C state "upwards" by one element to get the "+1" index.
        // 3. We perform the rot + xor to compute the D[x] state.
        const down = @shuffle(u64, C, undefined, [_]i32{ 4, 0, 1, 2, 3, 0, 0, 0 });
        const up = @shuffle(u64, C, undefined, [_]i32{ 1, 2, 3, 4, 0, 0, 0, 0 });
        const D = down ^ std.math.rotl(V, up, 1);
        // A[x,y] = A[x,y] xor D[x], for (x,y) in (0…4,0…4)
        A[0] ^= D;
        A[1] ^= D;
        A[2] ^= D;
        A[3] ^= D;
        A[4] ^= D;
        // ρ and π steps
        //
        // B[y,2*x+3*y] = rot(A[x,y], r[x,y]), for (x,y) in (0…4,0…4)
        //
        // Each row A[y] is first rotated lane-wise by its ρ offsets r[0..4, y]. π then
        // moves lane (x, y) to (x', y') where
        //
        //   x' = y
        //   y' = (2x + 3y) mod 5
        //
        // so every lane of source row y lands in destination column x' = y. We therefore
        // build B[x'] as a *column*: lane y' of B[x'] holds B[x', y'].
        //
        // Working backwards from the destination lane y': with x' = y we get
        // y' = (2x + 3x') mod 5, i.e. 2x = y' - 3x' (mod 5). The inverse of 2 mod 5 is 3,
        // so x = 3y' + x' (mod 5). pi(x') returns exactly that source lane for each y'.
        const B = [5]V{
            @shuffle(u64, rolv(A[0], .{ 0, 1, 62, 28, 27, 0, 0, 0 }), undefined, pi(0)),
            @shuffle(u64, rolv(A[1], .{ 36, 44, 6, 55, 20, 0, 0, 0 }), undefined, pi(1)),
            @shuffle(u64, rolv(A[2], .{ 3, 10, 43, 25, 39, 0, 0, 0 }), undefined, pi(2)),
            @shuffle(u64, rolv(A[3], .{ 41, 45, 15, 21, 8, 0, 0, 0 }), undefined, pi(3)),
            @shuffle(u64, rolv(A[4], .{ 18, 2, 61, 56, 14, 0, 0, 0 }), undefined, pi(4)),
        };
        // χ step
        // A[x,y] = B[x,y] xor ((not B[x+1,y]) and B[x+2,y]), for (x,y) in (0…4,0…4)
        //
        // Because B is column-major (B[x] lane y), χ combines whole vectors, and the
        // result A[x] is also column-major until the transpose below.
        A[0] = B[0] ^ (~B[1] & B[2]);
        A[1] = B[1] ^ (~B[2] & B[3]);
        A[2] = B[2] ^ (~B[3] & B[4]);
        A[3] = B[3] ^ (~B[4] & B[0]);
        A[4] = B[4] ^ (~B[0] & B[1]);
        // The last thing we do is transpose the vectors back to row-major before the
        // next round. Label "ij" below means lane j of A[i]; "__" is a don't-care lane.
        //
        // We start with:
        // A[0] = [ 00, 01, 02, 03, 04 ]
        // A[1] = [ 10, 11, 12, 13, 14 ]
        // A[2] = [ 20, 21, 22, 23, 24 ]
        // A[3] = [ 30, 31, 32, 33, 34 ]
        // A[4] = [ 40, 41, 42, 43, 44 ]
        //
        // We end up with:
        // A[0] = [ 00, 10, 20, 30, 40 ]
        // A[1] = [ 01, 11, 21, 31, 41 ]
        // A[2] = [ 02, 12, 22, 32, 42 ]
        // A[3] = [ 03, 13, 23, 33, 43 ]
        // A[4] = [ 04, 14, 24, 34, 44 ]
        // [00, 10, 01, 11, 02, 12, 04, 03]
        const c1 = @shuffle(u64, A[0], A[1], [_]i32{ 0, -1, 1, -2, 2, -3, 4, 3 });
        // [21, 31, 20, 30, 23, 24, 22, 32]
        const c2 = @shuffle(u64, A[2], A[3], [_]i32{ 1, -2, 0, -1, 3, 4, 2, -3 });
        // [00, 10, 20, 30, 02, 12, 22, 32]
        const c3 = @select(u64, [_]bool{ true, true, false, false, true, true, false, false }, c1, c2);
        // [21, 31, 01, 11, 23, 24, 04, 03]
        const c4 = @select(u64, [_]bool{ false, false, true, true, false, false, true, true }, c1, c2);
        // [14, 34, 13, 33, __, __, __, __]
        const c5 = @shuffle(u64, A[1], A[3], [_]i32{ 4, -5, 3, -4, 0, 0, 0, 0 });
        // [21, 31, 01, 11, 44, 24, 04, __]
        const c6 = @select(u64, [_]bool{ false, false, false, false, true, false, false, false }, A[4], c4);
        // [14, 34, 13, 33, 23, 24, 04, 03]
        const c7 = @select(u64, [_]bool{ true, true, true, true, false, false, false, false }, c5, c4);
        // [00, 10, 20, 30, 40, 00, 43, 44]
        A[0] = @shuffle(u64, c3, A[4], [_]i32{ 0, 1, 2, 3, -1, 0, -4, -5 });
        // [01, 11, 21, 31, 41, __, __, __]
        A[1] = @shuffle(u64, c4, A[4], [_]i32{ 2, 3, 0, 1, -2, 0, 0, 0 });
        // [02, 12, 22, 32, 42, __, __, __]
        A[2] = @shuffle(u64, c3, A[4], [_]i32{ 4, 5, 6, 7, -3, 0, 0, 0 });
        // [03, 13, 23, 33, 43, __, __, __]
        A[3] = @shuffle(u64, c7, A[4], [_]i32{ 7, 2, 4, 3, -4, 0, 0, 0 });
        // [04, 14, 24, 34, 44, __, __, __]
        A[4] = @shuffle(u64, c6, c5, [_]i32{ 6, -1, 5, -2, 4, 0, 0, 0 });
        // ι step
        //
        // A[0,0] = A[0,0] xor RC[first_rc + i]
        //
        // ι does not depend on the transpose, since element 00 does not move. It is
        // done last so it does not stall all of the preceding permutes/blends.
        A[0] ^= .{ RC[first_rc + i], 0, 0, 0, 0, 0, 0, 0 };
    }
    // zig fmt: on
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment