Created
February 15, 2026 01:48
-
-
Save srcreigh/28cd1ed97904bb646248e95814d320ff to your computer and use it in GitHub Desktop.
Zig union(enum) diagnostics
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| const std = @import("std"); | |
| const Module = @This(); | |
/// Creates a Diagnostics type from the tagged union `Payload`.
///
/// The resulting error set has one error per union field (same names), and
/// each error's payload type is the type of the matching union field.
pub fn FromUnion(comptime Payload: type) type {
    // Delegate to a second function so the ZLS parameter labels match the
    // `Error` / `Payload` decls of the generated struct; declaring the struct
    // here directly would shadow this function's `Payload` parameter.
    return FromUnionInternal(ErrorSetFromEnum(std.meta.FieldEnum(Payload)), Payload);
}
/// Builds the Diagnostics struct for error set `E` and payload union `P`.
///
/// Kept separate from `FromUnion` so that its parameter names do not collide
/// with the `Error` / `Payload` decls of the returned struct.
inline fn FromUnionInternal(comptime E: type, comptime P: type) type {
    return struct {
        /// Payload recorded for the most recently reported error, if any.
        payload: ?Payload = null,

        pub const Error = E;
        pub const Payload = P;

        /// Gets the payload for `err`, or null if nothing is stored or the
        /// stored payload belongs to a different error.
        pub fn get(self: *const @This(), comptime err: Error) ?PayloadType(err) {
            return Module.get(Payload, self.payload, Error, err);
        }

        /// Sets the payload for `err` to `value` and returns `err`.
        ///
        /// A `null` value clears any previously stored payload, so a stale
        /// payload from an earlier failure is never reported for this one.
        pub inline fn withContext(self: *@This(), comptime err: Error, value: ?PayloadType(err)) Error {
            self.payload = if (value) |v| @unionInit(Payload, @errorName(err), v) else null;
            return err;
        }

        /// Copies the payload for `err` from another Diagnostics and returns `err`.
        ///
        /// `err` is `@errorCast` to `other`'s error set, so it must be a member of
        /// it. If `other` holds no payload for `err`, this payload is cleared.
        pub inline fn forwardFrom(self: *@This(), other: anytype, comptime err: Error) Error {
            self.payload = if (other.get(@errorCast(err))) |v|
                @unionInit(Payload, @errorName(err), v)
            else
                null;
            return err;
        }

        /// Copies the payload for the runtime error `err` from another
        /// Diagnostics and returns `err`.
        ///
        /// The `inline else` below instantiates `@unionInit` for every error in
        /// `@TypeOf(other).Error`, so this struct's `Payload` needs a compatible
        /// field for each of them. If `other` holds no payload for `err`, this
        /// payload is cleared.
        pub inline fn forwardAll(
            self: *@This(),
            other: anytype,
            err: @TypeOf(other).Error,
        ) @TypeOf(other).Error {
            switch (err) {
                // Turn the runtime error into a comptime one so it can name a union field.
                inline else => |e| {
                    self.payload = if (other.get(e)) |v|
                        @unionInit(Payload, @errorName(e), v)
                    else
                        null;
                    return e;
                },
            }
        }

        /// Returns the payload type associated with `err`.
        pub fn PayloadType(comptime err: Error) type {
            return @FieldType(Payload, @errorName(err));
        }

        /// Automatically instantiates a function's Diagnostics argument, calls the
        /// function, copies the error payload on failure, and returns the result of
        /// the function call.
        ///
        /// The function must take a *Diagnostics as its last argument. Specify the
        /// function's arguments as a tuple, without the *Diagnostics argument. The
        /// caught error is handed to `forwardAll`, whose `err` parameter is the
        /// callee Diagnostics' `Error` set, so the function's error set must fit it.
        ///
        ///     const dict = try dt.call(trainDictionary, .{ samples, sizes });
        ///
        ///     // (assuming this file is imported as `Diagnostics`)
        ///     pub const TrainDictionaryDiagnostics = Diagnostics.FromUnion(union(enum) {
        ///         ZstdError: c.ZSTD_ErrorCode,
        ///     });
        ///     pub fn trainDictionary(
        ///         samples: []const u8,
        ///         sizes: []const usize,
        ///         dt: *TrainDictionaryDiagnostics,
        ///     ) TrainDictionaryDiagnostics.Error![]const u8 {
        ///         // ...
        ///     }
        pub inline fn call(self: *@This(), comptime f: anytype, a: anytype) ReturnType(f) {
            // TODO: inline call and Module.call produces better profiling traces, but
            // puts all the Diagnostics on the same stack frame.. is that an issue?
            var diag: FnDiagnosticsType(f) = .{};
            return @call(.auto, f, tupleAppend(a, &diag)) catch |err| {
                return self.forwardAll(diag, err);
            };
        }
    };
}
/// Returns the declared return type of the function `func`.
fn ReturnType(func: anytype) type {
    const fn_info = @typeInfo(@TypeOf(func)).@"fn";
    return fn_info.return_type.?;
}
/// Returns the payload type that `func`'s Diagnostics associates with `err`.
///
/// `func` must take a pointer to its Diagnostics as its last parameter, and
/// `err` must name a field of that Diagnostics' `Payload` union.
pub fn FnErrorPayload(func: anytype, comptime err: anyerror) type {
    const Diag = FnDiagnosticsType(func);
    const field_name = @errorName(err);
    return @FieldType(Diag.Payload, field_name);
}
/// Returns the Diagnostics type expected by `func`: the child type of its last
/// parameter, which must be a single-item pointer to a struct (`*MyDiagnostics`).
///
/// Emits a compile error if `func` is variadic, has no parameters, has a
/// generic (`anytype`) last parameter, or if that parameter is not a
/// single-item pointer to a struct.
fn FnDiagnosticsType(func: anytype) type {
    const fn_info = @typeInfo(@TypeOf(func)).@"fn";
    if (fn_info.is_var_args)
        @compileError("FnDiagnosticsType: varargs functions are not supported");
    if (fn_info.params.len == 0)
        @compileError("FnDiagnosticsType: function has no parameters (expected trailing Diagnostics pointer)");
    const LastParamType = fn_info.params[fn_info.params.len - 1].type orelse
        @compileError("FnDiagnosticsType: last parameter has no concrete type");
    const ptr_info = switch (@typeInfo(LastParamType)) {
        .pointer => |p| p,
        else => @compileError("FnDiagnosticsType: last parameter must be a pointer to a Diagnostics struct"),
    };
    // Slices and many-item pointers also report `.pointer`; reject them so we
    // never return their element type as if it were the Diagnostics struct.
    if (ptr_info.size != .one or @typeInfo(ptr_info.child) != .@"struct")
        @compileError("FnDiagnosticsType: last parameter must be a pointer to a Diagnostics struct");
    return ptr_info.child;
}
/// Returns a new tuple holding every element of `tuple` followed by `extra`.
fn tupleAppend(tuple: anytype, extra: anytype) TupleAppend(@TypeOf(tuple), @TypeOf(extra)) {
    const Result = TupleAppend(@TypeOf(tuple), @TypeOf(extra));
    const src_fields = @typeInfo(@TypeOf(tuple)).@"struct".fields;
    var result: Result = undefined;
    // Copy the existing elements field-by-field (tuple fields are named "0", "1", ...).
    inline for (src_fields) |field| {
        @field(result, field.name) = @field(tuple, field.name);
    }
    // The appended element occupies the next index, `src_fields.len`.
    @field(result, comptime tupleIndexName(src_fields.len)) = extra;
    return result;
}
/// Returns a tuple type with the fields of `Tuple` followed by one extra,
/// non-comptime field of type `Extra`.
///
/// Emits a compile error if `Tuple` is not a tuple type.
fn TupleAppend(comptime Tuple: type, comptime Extra: type) type {
    const info = switch (@typeInfo(Tuple)) {
        .@"struct" => |st| if (st.is_tuple)
            st
        else
            @compileError("TupleAppend: first argument must be a tuple type"),
        else => @compileError("TupleAppend: first argument must be a tuple type"),
    };
    const n = info.fields.len;
    var new_fields: [n + 1]std.builtin.Type.StructField = undefined;
    // Existing fields carry over verbatim (name, type, default, comptime-ness, alignment).
    inline for (info.fields, 0..) |field, idx| {
        new_fields[idx] = field;
    }
    new_fields[n] = .{
        .name = comptime tupleIndexName(n),
        .type = Extra,
        .default_value_ptr = null,
        .is_comptime = false,
        .alignment = @alignOf(Extra),
    };
    return @Type(.{ .@"struct" = .{
        .layout = .auto,
        .fields = &new_fields,
        .decls = &.{},
        .is_tuple = true,
    } });
}
/// Returns the decimal field name (`"0"`, `"1"`, ...) for tuple index `idx`.
///
/// Built with `std.fmt.comptimePrint`, so the result is immutable comptime
/// data rather than a slice into a local `comptime var` buffer; it can be
/// used at runtime as well as in `@Type` / `@field`.
fn tupleIndexName(comptime idx: usize) [:0]const u8 {
    return std.fmt.comptimePrint("{d}", .{idx});
}
/// Returns the payload stored in `dt` for `err`, or null when `dt` is null or
/// holds the payload of a different union field.
///
/// `D` is the payload union and `Err` the matching error set; the union field
/// is looked up by `@errorName(err)`.
fn get(
    D: type,
    dt: ?D,
    Err: type,
    comptime err: Err,
) ?@FieldType(D, @errorName(err)) {
    const name = @errorName(err);
    const active = dt orelse return null;
    const wanted_tag = @field(std.meta.Tag(D), name);
    return if (std.meta.activeTag(active) == wanted_tag) @field(active, name) else null;
}
/// Builds an error set with one error per field of the enum `E`, using the
/// enum field names as the error names.
///
/// Emits a compile error if `E` is not an enum type.
fn ErrorSetFromEnum(comptime E: type) type {
    const enum_info = switch (@typeInfo(E)) {
        .@"enum" => |info| info,
        else => @compileError("expected an enum type"),
    };
    var error_list: [enum_info.fields.len]std.builtin.Type.Error = undefined;
    inline for (&error_list, enum_info.fields) |*slot, field| {
        slot.* = .{ .name = field.name };
    }
    return @Type(.{ .error_set = &error_list });
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment