Skip to content

Instantly share code, notes, and snippets.

@srcreigh
Last active February 18, 2026 18:43
Show Gist options
  • Select an option

  • Save srcreigh/1d05f1e74150e9c50938c1b6e7bcdff6 to your computer and use it in GitHub Desktop.

Select an option

Save srcreigh/1d05f1e74150e9c50938c1b6e7bcdff6 to your computer and use it in GitHub Desktop.
Zig union(enum)-based diagnostics
// =============================================================================
// diagnostics.zig: Generic union(enum) Diagnostics error
// payloads library
//
// MIT LICENSE (C) 2026 Shane Creighton-Young
//
// Project Link: https://gist.github.com/srcreigh/1d05f1e74150e9c50938c1b6e7bcdff6
//
// HOW TO USE
// - Read the article https://srcreigh.ca/posts/error-payloads-in-zig/
// - Read the "basic example usage" and "more example usage" tests
// - Copy the file into your project
//
// CHANGELOG
// Feb 13 2026: Initial version
// Feb 16 2026: Add tests
// Feb 18 2026: Add support for calling functions which take an optional pointer
// to Diagnostics. See "call works for funcs taking optional diag ptr"
// test.
//
// =============================================================================
// =============================================================================
// MIT License
//
// Copyright (c) 2026 Shane Creighton-Young
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
// =============================================================================
const std = @import("std");
/// Make a union(enum)-based Diagnostics type.
///
/// Every field of `Payload` becomes an error of the same name in the returned
/// type's `Error` set, and that field's type is the payload carried by the error.
pub fn FromUnion(comptime Payload: type) type {
    const Tags = std.meta.FieldEnum(Payload);
    const Errors = ErrorSetFromEnum(Tags);
    return FromUnionInternal(Errors, Payload);
}
inline fn FromUnionInternal(comptime E: type, comptime P: type) type {
    return struct {
        /// Payload of the most recently recorded error. Null when no error has been
        /// recorded yet, or when the last recorded error carried no payload.
        payload: ?Payload = null,

        /// The union(enum) type whose fields describe each error's payload.
        pub const Payload = P;
        /// The error set this Diagnostics type can report (one error per `Payload` field).
        pub const Error = E;

        /// Gets the payload for an error.
        ///
        /// Returns null when no payload is stored, or when the stored payload
        /// belongs to a different error.
        pub fn get(self: *const @This(), comptime err: @This().Error) ?ErrorPayload(err) {
            const field_name = @errorName(err);
            const payload = self.payload orelse return null;
            if (std.meta.activeTag(payload) == @field(std.meta.Tag(Payload), field_name)) {
                return @field(payload, field_name);
            } else {
                return null;
            }
        }

        /// Sets the payload for this error, and returns the error.
        ///
        /// A null `value` clears any previously stored payload. Without this, a
        /// reused Diagnostics could report stale context left over from an earlier
        /// failure of the same error.
        pub inline fn withContext(self: *@This(), comptime err: @This().Error, value: ?ErrorPayload(err)) @This().Error {
            self.payload = if (value) |v| @unionInit(Payload, @errorName(err), v) else null;
            return err;
        }

        /// Copies the payload for `err` from another Diagnostics and returns the error.
        ///
        /// If `other` holds no payload for `err`, this Diagnostics' payload is cleared
        /// rather than left holding stale data.
        pub inline fn forwardFrom(self: *@This(), other: anytype, comptime err: @This().Error) @This().Error {
            self.payload = if (other.get(@errorCast(err))) |v| @unionInit(Payload, @errorName(err), v) else null;
            return err;
        }

        /// Copies whichever payload `other` holds for `err` and returns the error.
        ///
        /// Every error in `other`'s error set must name a field of this `Payload`
        /// whose type accepts `other`'s payload for that error; otherwise this fails
        /// to compile. If `other` holds no payload for `err`, this Diagnostics'
        /// payload is cleared.
        pub inline fn forwardAll(
            self: *@This(),
            other: anytype,
            err: @TypeOf(other).Error,
        ) @TypeOf(other).Error {
            switch (err) {
                // One prong per error, so each can name its union field at comptime.
                inline else => |e| {
                    self.payload = if (other.get(e)) |v| @unionInit(Payload, @errorName(e), v) else null;
                    return e;
                },
            }
        }

        /// The payload type for `err`: the type of the same-named `Payload` field.
        pub fn ErrorPayload(comptime err: @This().Error) type {
            return @FieldType(Payload, @errorName(err));
        }

        /// Automatically instantiates a function's Diagnostics argument, calls the function,
        /// copies the error payload on failure, and returns the result of the function call.
        ///
        /// The function must take a *Diagnostics (or ?*Diagnostics) as its last argument.
        /// Specify the function's other arguments as a tuple, without the Diagnostics argument.
        ///
        /// For methods, `callMethod` may be more convenient.
        pub inline fn call(
            self: *@This(),
            comptime f: anytype,
            a: anytype,
        ) @TypeOf(@call(.auto, f, a ++ .{@constCast(&OfFunction(f){})})) {
            // TODO: inline call produces better profiling traces, but
            // puts all the Diagnostics on the same stack frame.. is that an issue?
            var diag: OfFunction(f) = .{};
            return @call(.auto, f, a ++ .{&diag}) catch |err| {
                return self.forwardAll(diag, err);
            };
        }

        /// Same as `call`, but for receiver functions. Looks up the method from the
        /// receiver's type and converts the receiver to the type the function expects.
        ///
        /// The receiver `r` must be passed as a pointer, even if the function expects a
        /// value type receiver. It is dereferenced once when it has exactly one more
        /// pointer layer than the method's first parameter.
        ///
        /// If `r` were a value type passed directly, it would be impossible for this
        /// function to know if it is mutable in the calling context, and so seems
        /// impossible to safely call a non-const receiver function.
        pub inline fn callMethod(
            self: *@This(),
            r: anytype,
            comptime m: MethodEnum(@TypeOf(r)),
            a: anytype,
        ) MethodReturnType(@TypeOf(r), m, @TypeOf(a)) {
            const v_info = @typeInfo(@TypeOf(r));
            if (v_info != .pointer) {
                @compileError("callMethod: receiver must be passed as a pointer (even to non-pointer receiver functions)");
            }
            const func = method(@TypeOf(r), m);
            var diag: OfFunction(func) = .{};
            comptime var needs_deref = false;
            comptime {
                const num_v_layers = numPointerLayers(@TypeOf(r));
                const num_param_layers = numPointerLayers(ParameterType(func, 0));
                if (num_v_layers == num_param_layers + 1) {
                    needs_deref = true;
                }
            }
            if (needs_deref) {
                return @call(.auto, func, .{r.*} ++ a ++ .{&diag}) catch |err| {
                    return self.forwardAll(diag, err);
                };
            } else {
                return @call(.auto, func, .{r} ++ a ++ .{&diag}) catch |err| {
                    return self.forwardAll(diag, err);
                };
            }
        }
    };
}
/// Returns the Diagnostics type that `func` expects.
///
/// `func`'s last parameter must be a `*Diagnostics` or a `?*Diagnostics`; the
/// pointee type is returned. Any other shape is a compile error.
pub fn OfFunction(comptime func: anytype) type {
    const param_count = numParameters(func);
    if (param_count == 0) @compileError("OfFunction: function must have at least one parameter (a *Diagnostics)");
    const Last = ParameterType(func, param_count - 1);
    // Look through one level of optional so `?*Diagnostics` is accepted too.
    const Ptr = switch (@typeInfo(Last)) {
        .optional => |opt| opt.child,
        else => Last,
    };
    return switch (@typeInfo(Ptr)) {
        .pointer => |ptr| ptr.child,
        else => @compileError("diagnostics.OfFunction: functions last parameter must be a pointer or optional pointer to a Diagnostics struct"),
    };
}
/// Returns the error set of the Diagnostics type that `D` points to.
///
/// `D` is the type of a function's diag parameter, i.e. `*Diagnostics` or
/// `?*Diagnostics`. This lets the function declare its error set from its own
/// diag argument:
///
///     diagnostics.Error(@TypeOf(diag))
pub fn Error(D: type) type {
    const Ptr = if (@typeInfo(D) == .optional) @typeInfo(D).optional.child else D;
    return @typeInfo(Ptr).pointer.child.Error;
}
test "basic example usage" {
    const Foo = struct {
        /// Fails with `error.WrongNumber` (payload: the correct number) unless `foo` is 42.
        fn myFunc(foo: i32, diag: *FromUnion(union(enum) {
            WrongNumber: struct {
                correct_number: i32,
            },
        })) Error(@TypeOf(diag))!void {
            if (foo == 42) return;
            return diag.withContext(error.WrongNumber, .{ .correct_number = 42 });
        }
    };

    var diag = OfFunction(Foo.myFunc){};
    // A wrong number must fail, and the payload must say which number was expected.
    try std.testing.expectError(error.WrongNumber, Foo.myFunc(0, &diag));
    const payload = diag.get(error.WrongNumber).?;
    try std.testing.expectEqual(42, payload.correct_number);
}
test "more example usage" {
    // Tells you how many bytes you can still get on OutOfMemory.
    //
    // A toy allocator over a fixed 100-byte buffer that hands out at most one
    // live allocation at a time.
    const ToyAllocator = struct {
        data: [100]u8 = std.mem.zeroes([100]u8),
        // The currently live allocation, or null when the buffer is free.
        ptr: ?[]u8 = null,
        // NOTE(review): the returned slice comes from the `[100]u8` buffer, so this
        // only type-checks for `T = u8`, which is all this test uses.
        pub fn alloc(
            self: *@This(),
            T: type,
            size: usize,
            diag: *FromUnion(union(enum) {
                OutOfMemory: struct {
                    // largest block still available to the allocator
                    num_available: usize,
                },
            }),
        ) Error(@TypeOf(diag))![]T {
            // An allocation is already live: nothing else is available.
            if (self.ptr != null) return diag.withContext(error.OutOfMemory, .{ .num_available = 0 });
            // Larger than the whole buffer: at most 100 bytes could be handed out.
            if (size > 100) return diag.withContext(error.OutOfMemory, .{ .num_available = 100 });
            const slice = self.data[0..size];
            self.ptr = slice;
            return slice;
        }
        // Releases the live allocation; asserts `ptr` is exactly that allocation.
        pub fn free(self: *@This(), ptr: []u8) void {
            std.debug.assert(self.ptr != null);
            std.debug.assert(self.ptr.?.ptr == ptr.ptr);
            std.debug.assert(self.ptr.?.len == ptr.len);
            self.ptr = null;
        }
    };
    // Direct calls to ToyAllocator.alloc, reusing one Diagnostics across calls.
    {
        var alloc = ToyAllocator{};
        var alloc_diag = OfFunction(ToyAllocator.alloc){};
        // Request larger than the buffer: OutOfMemory with 100 bytes available.
        {
            var failed = false;
            _ = alloc.alloc(u8, 1000, &alloc_diag) catch |err| switch (err) {
                error.OutOfMemory => {
                    try std.testing.expectEqual(100, alloc_diag.get(error.OutOfMemory).?.num_available);
                    failed = true;
                },
            };
            try std.testing.expect(failed);
        }
        // Second request while the first allocation is live: OutOfMemory with 0 available.
        {
            const foo = try alloc.alloc(u8, 10, &alloc_diag);
            defer alloc.free(foo);
            try std.testing.expectEqual(10, foo.len);
            var failed = false;
            _ = alloc.alloc(u8, 10, &alloc_diag) catch |err| switch (err) {
                error.OutOfMemory => {
                    try std.testing.expectEqual(0, alloc_diag.get(error.OutOfMemory).?.num_available);
                    failed = true;
                },
            };
            try std.testing.expect(failed);
        }
    }
    // Simple alloc formatter. Returns ToyAllocator payload unchanged.
    const Formatter = struct {
        // Formats `fmt`/`args` into a buffer obtained from `alloc`; the buffer is
        // released with `alloc.free`. The OutOfMemory payload type is borrowed from
        // ToyAllocator.alloc, so callMethod can forward the allocator's payload as-is.
        pub fn formatAlloc(
            alloc: *ToyAllocator,
            comptime fmt: []const u8,
            args: anytype,
            diag: *FromUnion(union(enum) {
                OutOfMemory: OfFunction(ToyAllocator.alloc).ErrorPayload(error.OutOfMemory),
            }),
        ) Error(@TypeOf(diag))![]u8 {
            // Size the buffer exactly, so the bufPrint below cannot run out of space.
            const sz = std.fmt.count(fmt, args);
            // `&alloc` is a `**ToyAllocator`; callMethod dereferences it once to match
            // alloc's `*@This()` receiver.
            const buf = try diag.callMethod(&alloc, .alloc, .{ u8, sz });
            _ = std.fmt.bufPrint(buf, fmt, args) catch |err| switch (err) {
                error.NoSpaceLeft => unreachable,
            };
            return buf;
        }
    };
    {
        var alloc = ToyAllocator{};
        var fmt_diag = OfFunction(Formatter.formatAlloc){};
        // Short message: fits in the buffer and succeeds.
        {
            const str = try Formatter.formatAlloc(&alloc, "Hello, {s}!", .{"world"}, &fmt_diag);
            defer alloc.free(str);
            try std.testing.expectEqualStrings("Hello, world!", str);
        }
        // Long message: the allocator's OutOfMemory payload reaches fmt_diag.
        {
            var failed = false;
            _ = Formatter.formatAlloc(&alloc, "This is a loooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooong string that won't fit in the allocator buffer: {s}", .{"oops"}, &fmt_diag) catch |err| switch (err) {
                error.OutOfMemory => {
                    try std.testing.expectEqual(100, fmt_diag.get(error.OutOfMemory).?.num_available);
                    failed = true;
                },
                else => unreachable,
            };
            try std.testing.expect(failed);
        }
    }
    // Simulate batch execution. Return first failure with context to retry remainder,
    // and include OOM context in case of query string OOM.
    const QueryExecutor = struct {
        // Index of the query that should fail next, or null for no failure.
        next_err_idx: ?usize,
        // Message reported for that failing query.
        next_err_msg: ?[]const u8,
        pub fn executeQueries(
            self: *@This(),
            alloc: *ToyAllocator,
            comptime fmt: []const u8,
            // slice of args for fmt
            args_args: anytype,
            out_buf: []usize,
            diag: *FromUnion(union(enum) {
                QueryTooLarge: struct {
                    query_idx: usize,
                    max_size: ?usize,
                },
                QueryFailed: struct {
                    query_idx: usize,
                    // backing storing for err_msg slice
                    errmsg_buf: [25]u8,
                    errmsg_len: usize,
                    pub fn slice(self2: *const @This()) []const u8 {
                        return self2.errmsg_buf[0..self2.errmsg_len];
                    }
                },
            }),
        ) Error(@TypeOf(diag))!void {
            // One query per element of args_args; stop at the first failure.
            inline for (args_args, 0..) |args, i| {
                var format_alloc_diag = OfFunction(Formatter.formatAlloc){};
                const q = Formatter.formatAlloc(alloc, fmt, args, &format_alloc_diag) catch |err| switch (err) {
                    // The query string didn't fit: report which query, plus the
                    // allocator's limit when its payload provides one.
                    error.OutOfMemory => {
                        const max_size = if (format_alloc_diag.get(error.OutOfMemory)) |oom| oom.num_available else null;
                        return diag.withContext(error.QueryTooLarge, .{
                            .query_idx = i,
                            .max_size = max_size,
                        });
                    },
                };
                defer alloc.free(q);
                if (i != self.next_err_idx) {
                    // Simulated success: record the query's index as its result.
                    out_buf[i] = i;
                } else {
                    std.debug.assert(self.next_err_msg != null);
                    // The simulated failure is one-shot: clear it once reported.
                    defer {
                        self.next_err_idx = null;
                        self.next_err_msg = null;
                    }
                    var payload = std.mem.zeroes(OfFunction(executeQueries).ErrorPayload(error.QueryFailed));
                    // Copy the message into the payload's fixed-size buffer.
                    if (std.fmt.bufPrint(&payload.errmsg_buf, "{s}", .{self.next_err_msg.?})) |slice| {
                        payload.errmsg_len = slice.len;
                    } else |err| switch (err) {
                        error.NoSpaceLeft => {
                            // full, truncate with "..."
                            // NOTE(review): relies on bufPrint having written the leading
                            // bytes before reporting NoSpaceLeft; the truncation test
                            // below checks for that prefix.
                            const len = payload.errmsg_buf.len;
                            payload.errmsg_len = len;
                            @memcpy(payload.errmsg_buf[len - 3 ..][0..3], "...");
                        },
                    }
                    payload.query_idx = i;
                    return diag.withContext(error.QueryFailed, payload);
                }
            }
        }
    };
    {
        var alloc: ToyAllocator = .{};
        var out_buf: [5]usize = std.mem.zeroes([5]usize);
        var qb = QueryExecutor{
            .next_err_idx = 2,
            .next_err_msg = "syntax error",
        };
        // Query 2 fails with a short message that fits the payload buffer.
        {
            var qb_diag = OfFunction(QueryExecutor.executeQueries){};
            var failed = false;
            const args = &.{ .{1}, .{2}, .{3}, .{4} };
            qb.executeQueries(
                &alloc,
                "SELECT * FROM table WHERE id = {d}",
                args,
                &out_buf,
                &qb_diag,
            ) catch |err| switch (err) {
                error.QueryFailed => {
                    const payload = qb_diag.get(error.QueryFailed).?;
                    try std.testing.expectEqual(2, payload.query_idx);
                    try std.testing.expectEqualStrings("syntax error", payload.slice());
                    failed = true;
                },
                else => unreachable,
            };
            try std.testing.expect(failed);
        }
        // Query 0 fails with a message longer than the 25-byte buffer: it is truncated.
        {
            qb.next_err_idx = 0;
            qb.next_err_msg = "long error message which should get truncated by the error handling code";
            var failed = false;
            var qb_diag = OfFunction(QueryExecutor.executeQueries){};
            qb.executeQueries(
                &alloc,
                "SELECT * FROM table WHERE id = {d}",
                &.{.{1}},
                &out_buf,
                &qb_diag,
            ) catch |err| switch (err) {
                error.QueryFailed => {
                    const payload = qb_diag.get(error.QueryFailed).?;
                    try std.testing.expectEqual(0, payload.query_idx);
                    try std.testing.expectEqualStrings("long error message whi...", payload.slice());
                    failed = true;
                },
                else => unreachable,
            };
            try std.testing.expect(failed);
        }
        // The query string itself is too large for the allocator: QueryTooLarge
        // carries the allocator's 100-byte limit.
        {
            var failed = false;
            var qb_diag = OfFunction(QueryExecutor.executeQueries){};
            qb.executeQueries(
                &alloc,
                "SELECT * FROM table WHERE name LIKE '{s}'",
                &.{.{"this is a very long pattern that won't fit in the allocator buffer"}},
                &out_buf,
                &qb_diag,
            ) catch |err| switch (err) {
                error.QueryTooLarge => {
                    const payload = qb_diag.get(error.QueryTooLarge).?;
                    try std.testing.expectEqual(0, payload.query_idx);
                    try std.testing.expectEqual(100, payload.max_size.?);
                    failed = true;
                },
                else => unreachable,
            };
            try std.testing.expect(failed);
        }
    }
}
test "call works for funcs taking optional diag ptr" {
    const Diag = FromUnion(union(enum) { Oops: i32 });
    const Foo = struct {
        /// Always fails with `error.Oops`; attaches 123 when a diag is supplied.
        pub fn foo(diag: ?*Diag) Error(@TypeOf(diag))!void {
            const d = diag orelse return error.Oops;
            return d.withContext(error.Oops, 123);
        }
    };
    // A separately declared Diagnostics type with the same `Oops` field; `call`
    // must forward Foo.foo's payload into it.
    const Diag2 = FromUnion(union(enum) { Oops: i32 });
    var diag: Diag2 = .{};
    try std.testing.expectError(error.Oops, diag.call(Foo.foo, .{}));
    try std.testing.expectEqual(123, diag.get(error.Oops).?);
}
// Computes the type of calling method `m` on a receiver of pointer type `V`,
// with extra arguments of tuple type `A` followed by the method's Diagnostics.
//
// The type is taken from @TypeOf(@call(...)) rather than from
// @typeInfo(func).@"fn".return_type, because the latter is sometimes null:
// this happens when a function has an inline diag type, uses
// diagnostics.Error(@TypeOf(diag)) in its return type, and is then called via
// diag.callMethod(...).
fn MethodReturnType(V: type, comptime m: MethodEnum(V), A: type) type {
    if (@typeInfo(V) != .pointer) {
        @compileError("callMethod: receiver must be passed as a pointer");
    }
    const func = method(V, m);
    // Dereference once when the receiver has exactly one more pointer layer than
    // the method's first parameter (e.g. `**S` handed to a `*S` receiver).
    const deref = comptime numPointerLayers(V) == numPointerLayers(ParameterType(func, 0)) + 1;
    const Receiver = if (deref) @typeInfo(V).pointer.child else V;
    // Placeholders only: they are used solely inside @TypeOf and never evaluated.
    const receiver: Receiver = undefined;
    const args: A = undefined;
    return @TypeOf(@call(
        .auto,
        func,
        .{receiver} ++ args ++ .{@constCast(&OfFunction(func){})},
    ));
}
/// Counts how many pointer types wrap `T` (0 when `T` is not a pointer).
/// Slices and many-item pointers count as pointer layers too, since
/// `@typeInfo` reports them as `.pointer`.
fn numPointerLayers(T: type) usize {
    return switch (@typeInfo(T)) {
        .pointer => |ptr| 1 + numPointerLayers(ptr.child),
        else => 0,
    };
}
/// Returns the method named by `meth` from the struct underlying `T`
/// (any pointer layers on `T` are stripped first).
fn method(T: type, comptime meth: MethodEnum(T)) MethodType(T, meth) {
    return @field(InnerStructType(T), @tagName(meth));
}
/// Type of the method named by `meth` on the struct underlying `T`.
fn MethodType(T: type, comptime meth: MethodEnum(T)) type {
    const Struct = InnerStructType(T);
    const name = @tagName(meth);
    return @TypeOf(@field(Struct, name));
}
/// Reports whether `func` can be called as a method on struct type `T`.
///
/// The first parameter must be `T` itself or a single-item pointer to it
/// (`*T` / `*const T`). Slices and many-item pointers of `T` (e.g. a static
/// `fn f(items: []const T)`) are not receivers, even though `@typeInfo`
/// classifies them as pointers.
fn isReceiverFunction(T: type, func: anytype) bool {
    const params = @typeInfo(@TypeOf(func)).@"fn".params;
    if (params.len == 0) return false;
    // Generic (`anytype`) first parameters have no concrete type.
    const FirstParam = params[0].type orelse return false;
    return switch (@typeInfo(FirstParam)) {
        .pointer => |ptr| ptr.size == .one and ptr.child == T,
        .@"struct" => FirstParam == T,
        else => false,
    };
}
/// Builds an exhaustive enum with one tag per receiver method of the struct
/// underlying `T` (any pointer layers on `T` are stripped first).
///
/// A declaration becomes a tag when it is a function that `isReceiverFunction`
/// accepts for that struct. Tag values are 0, 1, 2, ... in declaration order.
pub fn MethodEnum(comptime T: type) type {
    const Struct = InnerStructType(T);
    const decls = @typeInfo(Struct).@"struct".decls;
    // Sized for the worst case (every declaration is a method); trimmed below.
    comptime var tags: [decls.len]std.builtin.Type.EnumField = undefined;
    comptime var count: usize = 0;
    inline for (decls) |decl| {
        const value = @field(Struct, decl.name);
        if (@typeInfo(@TypeOf(value)) == .@"fn" and isReceiverFunction(Struct, value)) {
            tags[count] = .{ .name = decl.name, .value = count };
            count += 1;
        }
    }
    return @Type(.{ .@"enum" = .{
        .tag_type = usize,
        .fields = tags[0..count],
        .decls = &.{},
        .is_exhaustive = true,
    } });
}
/// Strips every pointer layer from `T` and returns the struct underneath.
/// Any non-pointer, non-struct type is a compile error.
fn InnerStructType(T: type) type {
    return switch (@typeInfo(T)) {
        .@"struct" => T,
        .pointer => |ptr| InnerStructType(ptr.child),
        else => @compileError("MethodEnum expects a struct or pointer to struct type"),
    };
}
/// Returns the declared type of parameter `i` (0-based) of `func`.
///
/// Compile error if `func` has fewer than `i + 1` parameters, or if that
/// parameter has no concrete type (e.g. it is declared `anytype`).
fn ParameterType(func: anytype, comptime i: usize) type {
    const params = @typeInfo(@TypeOf(func)).@"fn".params;
    if (i >= params.len) {
        @compileError(std.fmt.comptimePrint("ParameterType: function has fewer than {d} parameters", .{i + 1}));
    }
    return params[i].type orelse
        @compileError(std.fmt.comptimePrint("ParameterType: parameter {d} has no concrete type", .{i}));
}
/// Number of parameters declared by function `func`.
fn numParameters(func: anytype) usize {
    return @typeInfo(@TypeOf(func)).@"fn".params.len;
}
/// Builds an error set containing one error per field of enum `E`, with matching names.
fn ErrorSetFromEnum(comptime E: type) type {
    const enum_fields = switch (@typeInfo(E)) {
        .@"enum" => |info| info.fields,
        else => @compileError("expected an enum type"),
    };
    comptime var names: [enum_fields.len]std.builtin.Type.Error = undefined;
    inline for (enum_fields, 0..) |field, idx| {
        names[idx] = .{ .name = field.name };
    }
    return @Type(.{ .error_set = &names });
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment