Last active
February 9, 2026 20:57
-
-
Save sno2/d7b574f139fe2eaf966ca3bd3cebafc1 to your computer and use it in GitHub Desktop.
tiny Go math
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| package main | |
| import ( | |
| "fmt" | |
| "log" | |
| "math" | |
| "os" | |
| "strconv" | |
| ) | |
// Token identifies the kind of lexeme most recently scanned by a Lexer.
type Token uint8

const (
	// TokenEof marks the end of the input.
	TokenEof Token = iota
	// TokenInvalid is a byte that does not start any known token.
	TokenInvalid
	// TokenFloat is a number: digits, optionally followed by '.' and more
	// digits.
	TokenFloat
	// TokenPlus is '+'.
	TokenPlus
	// TokenAsterisk is '*'.
	TokenAsterisk
	// TokenSlash is '/'.
	TokenSlash
	// TokenCaret is '^'.
	TokenCaret
	// TokenLeftParen is '('.
	TokenLeftParen
	// TokenRightParen is ')'.
	TokenRightParen
)

// State is the scanner state used by Lexer.nextInner while it is in the
// middle of recognizing one token.
type State uint8

const (
	// StateInit: between tokens, nothing of the next token consumed yet.
	StateInit State = iota
	// StateInt: inside the digits before any '.'.
	StateInt
	// StateFloat: inside the digits after the '.'.
	StateFloat
)

// Lexer scans source one token at a time. After next returns, token holds the
// token kind and source[start:index] holds its text.
type Lexer struct {
	// source is the input. A trailing NUL byte is accepted as an explicit
	// end-of-input marker but is not required: running off the end of the
	// slice is treated exactly like reading a NUL.
	source []byte

	// start is the offset of the first byte of the current token.
	start int

	// index is the offset of the next unread byte.
	index int

	// token is the kind of the current token.
	token Token
}

// peek returns the next unread byte without consuming it, or 0 once the
// input is exhausted. This keeps the scanner from indexing past the end of a
// source that was not NUL-terminated.
func (lex *Lexer) peek() byte {
	if lex.index < len(lex.source) {
		return lex.source[lex.index]
	}
	return 0
}

// nextInner scans and returns the next token, skipping leading whitespace
// (space, tab, CR, LF). It leaves lex.start and lex.index bracketing the
// token's text. At end of input it returns TokenEof without advancing, so
// repeated calls keep returning TokenEof.
func (lex *Lexer) nextInner() Token {
	state := StateInit
	for {
		b := lex.peek()
		switch state {
		case StateInit:
			// Re-anchor on every pass so skipped whitespace is never part
			// of the token text.
			lex.start = lex.index
			switch {
			case b == 0:
				return TokenEof
			case b == ' ' || b == '\t' || b == '\r' || b == '\n':
				lex.index++
			case b >= '0' && b <= '9':
				lex.index++
				state = StateInt
			case b == '+':
				lex.index++
				return TokenPlus
			case b == '*':
				lex.index++
				return TokenAsterisk
			case b == '/':
				lex.index++
				return TokenSlash
			case b == '^':
				lex.index++
				return TokenCaret
			case b == '(':
				lex.index++
				return TokenLeftParen
			case b == ')':
				lex.index++
				return TokenRightParen
			default:
				lex.index++
				return TokenInvalid
			}
		case StateInt:
			switch {
			case b >= '0' && b <= '9':
				lex.index++
			case b == '.':
				lex.index++
				state = StateFloat
			default:
				// Any other byte (including end of input) ends the number
				// and is left unread for the next call.
				return TokenFloat
			}
		case StateFloat:
			switch {
			case b >= '0' && b <= '9':
				lex.index++
			default:
				return TokenFloat
			}
		}
	}
}

// next advances to the following token and records its kind in lex.token.
func (lex *Lexer) next() {
	lex.token = lex.nextInner()
}
// InstructionTag selects which operation an Instruction performs.
type InstructionTag uint8

const (
	// InstructionReturn is the terminator of an instruction list.
	InstructionReturn InstructionTag = iota
	// InstructionPushFloat carries a float64 operand in Instruction.arg.
	InstructionPushFloat
	// InstructionAdd is the '+' operation.
	InstructionAdd
	// InstructionMul is the '*' operation.
	InstructionMul
	// InstructionDiv is the '/' operation.
	InstructionDiv
	// InstructionPow is the '^' operation.
	InstructionPow
)

// Instruction is a tagged-union-like record: tag says what to do and arg
// carries the operand for tags that need one.
type Instruction struct {
	tag InstructionTag

	// arg is the operand. Change to [N]uint64 and bitcast if you have many
	// argument types.
	arg float64
}
| type Compiler struct { | |
| lex Lexer | |
| instructions []Instruction | |
| // an index into the source for mapping errors is usually stored in some | |
| // out-of-bound list | |
| // e.g. []uint32 where each instruction has a single index. | |
| // then, you can run your lexer to get a full token to highlight in an error | |
| // message | |
| } | |
| func (comp *Compiler) tokenSource() []byte { | |
| return comp.lex.source[comp.lex.start:comp.lex.index] | |
| } | |
| func (comp *Compiler) addInstruction(insn Instruction) { | |
| comp.instructions = append(comp.instructions, insn) | |
| } | |
| func (comp *Compiler) compilePrimaryExpression() { | |
| switch comp.lex.token { | |
| case TokenLeftParen: | |
| comp.lex.next() | |
| comp.compileExpression() | |
| if comp.lex.token != TokenRightParen { | |
| log.Fatalf("expected ')'") // lazy | |
| } | |
| comp.lex.next() | |
| case TokenFloat: | |
| value, _ := strconv.ParseFloat(string(comp.tokenSource()), 64) | |
| comp.lex.next() | |
| comp.addInstruction(Instruction{ | |
| tag: InstructionPushFloat, | |
| arg: value, | |
| }) | |
| default: | |
| log.Fatalln("expected primary expression") | |
| } | |
| } | |
// Assoc is the grouping direction for a chain of infix operators that share
// one precedence level.
type Assoc uint8

const (
	// AssocLeftToRight groups a op b op c as (a op b) op c.
	AssocLeftToRight Assoc = iota
	// AssocRightToLeft groups a op b op c as a op (b op c).
	AssocRightToLeft
)
| func (tok Token) infixOperator() (int8, Assoc, InstructionTag) { | |
| switch tok { | |
| case TokenPlus: | |
| return 1, AssocLeftToRight, InstructionAdd | |
| case TokenAsterisk: | |
| return 2, AssocLeftToRight, InstructionMul | |
| case TokenSlash: | |
| return 2, AssocLeftToRight, InstructionDiv | |
| case TokenCaret: | |
| return 3, AssocRightToLeft, InstructionPow | |
| default: | |
| return -1, 0, 0 | |
| } | |
| } | |
| func (comp *Compiler) compileExpressionPrec(minPrec int8) { | |
| comp.compilePrimaryExpression() | |
| for { | |
| // you may need to check for other types of expressions here | |
| // e.g. '(' if you want function calls or '.' if you want field access | |
| prec, assoc, tag := comp.lex.token.infixOperator() | |
| if assoc == AssocLeftToRight && prec <= minPrec || | |
| assoc == AssocRightToLeft && prec < minPrec { | |
| break | |
| } | |
| comp.lex.next() | |
| comp.compileExpressionPrec(prec) | |
| comp.addInstruction(Instruction{tag: tag}) | |
| } | |
| } | |
| func (comp *Compiler) compileExpression() { | |
| comp.compileExpressionPrec(0) | |
| } | |
| func (comp *Compiler) Compile() { | |
| comp.lex.next() | |
| comp.compileExpression() | |
| if comp.lex.token != TokenEof { | |
| log.Fatalf("expected eof, got %d\n", comp.lex.token) | |
| } | |
| comp.addInstruction(Instruction{tag: InstructionReturn}) | |
| } | |
| type Vm struct { | |
| instructions []Instruction | |
| stack []float64 | |
| index int | |
| } | |
| func (vm *Vm) pop() float64 { | |
| end := len(vm.stack) - 1 | |
| value := vm.stack[end] | |
| vm.stack = vm.stack[:end] | |
| return value | |
| } | |
| func (vm *Vm) Execute() float64 { | |
| for { | |
| insn := vm.instructions[vm.index] | |
| vm.index += 1 | |
| switch insn.tag { | |
| case InstructionReturn: | |
| if len(vm.stack) != 1 { | |
| log.Fatalln("bug: extra values left on stack") | |
| } | |
| return vm.pop() | |
| case InstructionPushFloat: | |
| vm.stack = append(vm.stack, insn.arg) | |
| case InstructionAdd: | |
| right, left := vm.pop(), vm.pop() | |
| vm.stack = append(vm.stack, left+right) | |
| case InstructionMul: | |
| right, left := vm.pop(), vm.pop() | |
| vm.stack = append(vm.stack, left*right) | |
| case InstructionDiv: | |
| right, left := vm.pop(), vm.pop() | |
| vm.stack = append(vm.stack, left/right) | |
| case InstructionPow: | |
| right, left := vm.pop(), vm.pop() | |
| vm.stack = append(vm.stack, math.Pow(left, right)) | |
| } | |
| } | |
| } | |
| func test(expected float64, expression string) { | |
| comp := Compiler{lex: Lexer{source: []byte(expression + "\x00")}} | |
| comp.Compile() | |
| vm := Vm{instructions: comp.instructions} | |
| result := vm.Execute() | |
| if result != expected { | |
| fmt.Printf("FAIL: expected %f, got %f for %s\n", expected, result, expression) | |
| os.Exit(1) | |
| } else { | |
| fmt.Printf("PASS: %8f == %s\n", result, expression) | |
| } | |
| } | |
| func main() { | |
| fmt.Println("========= basic expressions =========") | |
| test(1, "1") | |
| test(2.4, "2.4") | |
| test(2, "(((2)))") | |
| fmt.Println("========= left-to-right add =========") | |
| test(3, "1.5+1.5") | |
| test(8.5, "1.5+3+4") | |
| fmt.Println("======= left-to-right mul/div =======") | |
| test(3, "18/2/3") | |
| test(27, "18/(2/3)") | |
| test(2.5, "5/2") | |
| test(5, "2.5*2") | |
| test(4, "2*6/3") | |
| fmt.Println("========= right-to-left pow =========") | |
| test(3, "3^2^0") | |
| test(1024, "(2*2)^5") | |
| test(7, "3^2^0 + (1*4)") | |
| test(15, "3^2^0 * (1+4)") | |
| fmt.Println("============ miscellaneous ==========") | |
| test(52, "7+5*3^2") | |
| test(52, "7+3^2*5") | |
| test(52, "3^2*5+7") | |
| test(108, "3^2*(5+7)") | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment