diff --git a/.gitignore b/.gitignore index 5b90e79..6bd5d39 100644 --- a/.gitignore +++ b/.gitignore @@ -25,3 +25,4 @@ go.work.sum # env file .env +/test.mako \ No newline at end of file diff --git a/compiler/compiler.go b/compiler/compiler.go index cf95c7d..b01b7fe 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -110,6 +110,10 @@ func (c *compiler) compileExpression(expr parser.Expression) { constIndex := c.addConstant(e.Value) c.emit(types.OpConstant, constIndex) + case *parser.BooleanLiteral: + constIndex := c.addConstant(e.Value) + c.emit(types.OpConstant, constIndex) + case *parser.Identifier: nameIndex := c.addConstant(e.Value) @@ -147,7 +151,7 @@ func (c *compiler) compileExpression(expr parser.Expression) { c.compileExpression(e.Index) c.emit(types.OpGetIndex, 0) - // New expression types for arithmetic + // Arithmetic expressions case *parser.InfixExpression: // Compile left and right expressions c.compileExpression(e.Left) @@ -163,6 +167,18 @@ func (c *compiler) compileExpression(expr parser.Expression) { c.emit(types.OpMultiply, 0) case "/": c.emit(types.OpDivide, 0) + case "==": + c.emit(types.OpEqual, 0) + case "!=": + c.emit(types.OpNotEqual, 0) + case "<": + c.emit(types.OpLessThan, 0) + case ">": + c.emit(types.OpGreaterThan, 0) + case "<=": + c.emit(types.OpLessEqual, 0) + case ">=": + c.emit(types.OpGreaterEqual, 0) default: panic(fmt.Sprintf("Unknown infix operator: %s", e.Operator)) } @@ -182,6 +198,80 @@ func (c *compiler) compileExpression(expr parser.Expression) { case *parser.GroupedExpression: // Just compile the inner expression c.compileExpression(e.Expr) + + case *parser.IfExpression: + // Compile condition + c.compileExpression(e.Condition) + + // Emit jump-if-false with placeholder + jumpNotTruePos := len(c.instructions) + c.emit(types.OpJumpIfFalse, 0) // Will backpatch + + // Compile consequence (then block) + lastStmtIndex := len(e.Consequence.Statements) - 1 + for i, stmt := range e.Consequence.Statements { + if i == lastStmtIndex { + // For the last statement, we need to ensure it leaves a value + if exprStmt, ok := stmt.(*parser.ExpressionStatement); ok { + c.compileExpression(exprStmt.Expression) + } else { + c.compileStatement(stmt) + // Push null if not an expression statement + nullIndex := c.addConstant(nil) + c.emit(types.OpConstant, nullIndex) + } + } else { + c.compileStatement(stmt) + } + } + + // If no statements, push null + if len(e.Consequence.Statements) == 0 { + nullIndex := c.addConstant(nil) + c.emit(types.OpConstant, nullIndex) + } + + // Emit jump to skip else part + jumpPos := len(c.instructions) + c.emit(types.OpJump, 0) // Will backpatch + + // Backpatch jump-if-false to point to else + afterConsequencePos := len(c.instructions) + c.instructions[jumpNotTruePos].Operand = afterConsequencePos + + // Compile alternative (else block) + if e.Alternative != nil { + lastStmtIndex = len(e.Alternative.Statements) - 1 + for i, stmt := range e.Alternative.Statements { + if i == lastStmtIndex { + // For the last statement, we need to ensure it leaves a value + if exprStmt, ok := stmt.(*parser.ExpressionStatement); ok { + c.compileExpression(exprStmt.Expression) + } else { + c.compileStatement(stmt) + // Push null if not an expression statement + nullIndex := c.addConstant(nil) + c.emit(types.OpConstant, nullIndex) + } + } else { + c.compileStatement(stmt) + } + } + + // If no statements, push null + if len(e.Alternative.Statements) == 0 { + nullIndex := c.addConstant(nil) + c.emit(types.OpConstant, nullIndex) + } + } else { + // No else - push null + nullIndex := c.addConstant(nil) + c.emit(types.OpConstant, nullIndex) + } + + // Backpatch jump to point after else + afterAlternativePos := len(c.instructions) + c.instructions[jumpPos].Operand = afterAlternativePos } } diff --git a/lexer/lexer.go b/lexer/lexer.go index ace2f75..425914c 100644 --- a/lexer/lexer.go +++ b/lexer/lexer.go @@ -21,6 +21,18 @@ const ( TokenSlash TokenLeftParen TokenRightParen + TokenIf + TokenThen + TokenElse + TokenTrue + TokenFalse + TokenEqualEqual + TokenNotEqual + TokenLessThan + TokenGreaterThan + TokenLessEqual + TokenGreaterEqual + TokenEnd ) type Token struct { @@ -55,10 +67,37 @@ func (l *Lexer) NextToken() Token { var tok Token l.skipWhitespace() + l.skipComment() switch l.ch { case '=': - tok = Token{Type: TokenEqual, Value: "="} + if l.peekChar() == '=' { + l.readChar() // consume the current '=' + tok = Token{Type: TokenEqualEqual, Value: "=="} + } else { + tok = Token{Type: TokenEqual, Value: "="} + } + case '!': + if l.peekChar() == '=' { + l.readChar() // consume the current '!' + tok = Token{Type: TokenNotEqual, Value: "!="} + } else { + tok = Token{Type: TokenEOF, Value: ""} // Not supported yet + } + case '<': + if l.peekChar() == '=' { + l.readChar() // consume the current '<' + tok = Token{Type: TokenLessEqual, Value: "<="} + } else { + tok = Token{Type: TokenLessThan, Value: "<"} + } + case '>': + if l.peekChar() == '=' { + l.readChar() // consume the current '>' + tok = Token{Type: TokenGreaterEqual, Value: ">="} + } else { + tok = Token{Type: TokenGreaterThan, Value: ">"} + } case ';': tok = Token{Type: TokenSemicolon, Value: ";"} case '"': @@ -74,7 +113,6 @@ func (l *Lexer) NextToken() Token { tok = Token{Type: TokenRightBracket, Value: "]"} case ',': tok = Token{Type: TokenComma, Value: ","} - // New arithmetic operators case '+': tok = Token{Type: TokenPlus, Value: "+"} case '-': @@ -92,9 +130,22 @@ func (l *Lexer) NextToken() Token { default: if isLetter(l.ch) { tok.Value = l.readIdentifier() - if tok.Value == "echo" { + switch tok.Value { + case "echo": tok.Type = TokenEcho - } else { + case "if": + tok.Type = TokenIf + case "then": + tok.Type = TokenThen + case "else": + tok.Type = TokenElse + case "true": + tok.Type = TokenTrue + case "false": + tok.Type = TokenFalse + case "end": + tok.Type = TokenEnd + default: tok.Type = TokenIdentifier } return tok @@ -153,3 +204,23 @@ func isLetter(ch byte) bool { func isDigit(ch byte) bool { return '0' <= ch && ch <= '9' } + +func (l *Lexer) peekChar() byte { + if l.readPos >= len(l.input) { + return 0 + } + return l.input[l.readPos] +} + +func (l *Lexer) skipComment() { + if l.ch == '/' && l.peekChar() == '/' { + l.readChar() + l.readChar() + + for l.ch != '\n' && l.ch != 0 { + l.readChar() + } + + l.skipWhitespace() + } +} diff --git a/mako.go b/mako.go new file mode 100644 index 0000000..8371d38 --- /dev/null +++ b/mako.go @@ -0,0 +1,68 @@ +package main + +import ( + "bufio" + "fmt" + "os" + + "git.sharkk.net/Sharkk/Mako/compiler" + "git.sharkk.net/Sharkk/Mako/lexer" + "git.sharkk.net/Sharkk/Mako/parser" + "git.sharkk.net/Sharkk/Mako/vm" +) + +func RunRepl() { + scanner := bufio.NewScanner(os.Stdin) + virtualMachine := vm.New() + + fmt.Println("Mako REPL (type 'exit' to quit)") + for { + fmt.Print(">> ") + if !scanner.Scan() { + break + } + + input := scanner.Text() + if input == "exit" { + break + } + + ExecuteCode(input, virtualMachine) + } +} + +func ExecuteCode(code string, virtualMachine *vm.VM) { + lex := lexer.New(code) + p := parser.New(lex) + program := p.ParseProgram() + + if len(p.Errors()) > 0 { + for _, err := range p.Errors() { + fmt.Printf("Error: %s\n", err) + } + return + } + + bytecode := compiler.Compile(program) + virtualMachine.Run(bytecode) +} + +func main() { + args := os.Args[1:] + + // If there's a command line argument, try to execute it as a file + if len(args) > 0 { + filename := args[0] + fileContent, err := os.ReadFile(filename) + if err != nil { + fmt.Printf("Error reading file %s: %v\n", filename, err) + os.Exit(1) + } + + // Execute the file content + ExecuteCode(string(fileContent), vm.New()) + } else { + // No arguments, run the REPL + RunRepl() + } +} diff --git a/parser/ast.go b/parser/ast.go index 03902fd..0c82620 100644 --- a/parser/ast.go +++ b/parser/ast.go @@ -138,3 +138,21 @@ type GroupedExpression struct { func (ge *GroupedExpression) expressionNode() {} func (ge *GroupedExpression) TokenLiteral() string { return ge.Token.Value } + +type IfExpression struct { + Token lexer.Token // The 'if' token + Condition Expression + Consequence *BlockStatement + Alternative *BlockStatement // nil if no 'else' +} + +func (ie *IfExpression) expressionNode() {} +func (ie *IfExpression) TokenLiteral() string { return ie.Token.Value } + +type BooleanLiteral struct { + Token lexer.Token + Value bool +} + +func (bl *BooleanLiteral) expressionNode() {} +func (bl *BooleanLiteral) TokenLiteral() string { return bl.Token.Value } diff --git a/parser/parser.go b/parser/parser.go index 8a2e9b0..0fd2d75 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -18,11 +18,17 @@ const ( ) var precedences = map[lexer.TokenType]int{ - lexer.TokenPlus: SUM, - lexer.TokenMinus: SUM, - lexer.TokenStar: PRODUCT, - lexer.TokenSlash: PRODUCT, - lexer.TokenLeftBracket: INDEX, + lexer.TokenPlus: SUM, + lexer.TokenMinus: SUM, + lexer.TokenStar: PRODUCT, + lexer.TokenSlash: PRODUCT, + lexer.TokenLeftBracket: INDEX, + lexer.TokenEqualEqual: LOWEST + 1, + lexer.TokenNotEqual: LOWEST + 1, + lexer.TokenLessThan: LOWEST + 1, + lexer.TokenGreaterThan: LOWEST + 1, + lexer.TokenLessEqual: LOWEST + 1, + lexer.TokenGreaterEqual: LOWEST + 1, } type ( @@ -55,6 +61,9 @@ func New(l *lexer.Lexer) *Parser { p.registerPrefix(lexer.TokenLeftBrace, p.parseTableLiteral) p.registerPrefix(lexer.TokenMinus, p.parsePrefixExpression) p.registerPrefix(lexer.TokenLeftParen, p.parseGroupedExpression) + p.registerPrefix(lexer.TokenIf, p.parseIfExpression) // New + p.registerPrefix(lexer.TokenTrue, p.parseBooleanLiteral) // New + p.registerPrefix(lexer.TokenFalse, p.parseBooleanLiteral) // New // Initialize infix parse functions p.infixParseFns = make(map[lexer.TokenType]infixParseFn) @@ -64,6 +73,14 @@ func New(l *lexer.Lexer) *Parser { p.registerInfix(lexer.TokenSlash, p.parseInfixExpression) p.registerInfix(lexer.TokenLeftBracket, p.parseIndexExpression) + // Register comparison operators + p.registerInfix(lexer.TokenEqualEqual, p.parseInfixExpression) + p.registerInfix(lexer.TokenNotEqual, p.parseInfixExpression) + p.registerInfix(lexer.TokenLessThan, p.parseInfixExpression) + p.registerInfix(lexer.TokenGreaterThan, p.parseInfixExpression) + p.registerInfix(lexer.TokenLessEqual, p.parseInfixExpression) + p.registerInfix(lexer.TokenGreaterEqual, p.parseInfixExpression) + // Read two tokens, so curToken and peekToken are both set p.nextToken() p.nextToken() @@ -415,3 +432,58 @@ func (p *Parser) parseGroupedExpression() Expression { Expr: exp, } } + +func (p *Parser) parseBooleanLiteral() Expression { + return &BooleanLiteral{ + Token: p.curToken, + Value: p.curTokenIs(lexer.TokenTrue), + } +} + +func (p *Parser) parseIfExpression() Expression { + expression := &IfExpression{Token: p.curToken} + + p.nextToken() // Skip 'if' + + // Parse condition + expression.Condition = p.parseExpression(LOWEST) + + if !p.expectPeek(lexer.TokenThen) { + return nil + } + + p.nextToken() // Skip 'then' + + // Parse consequence (then block) + if p.curTokenIs(lexer.TokenLeftBrace) { + expression.Consequence = p.parseBlockStatement() + } else { + // For single statement without braces + stmt := &BlockStatement{Token: p.curToken} + stmt.Statements = []Statement{p.parseStatement()} + expression.Consequence = stmt + } + + // Check for 'else' + if p.peekTokenIs(lexer.TokenElse) { + p.nextToken() // Skip current token + p.nextToken() // Skip 'else' + + // Parse alternative (else block) + if p.curTokenIs(lexer.TokenLeftBrace) { + expression.Alternative = p.parseBlockStatement() + } else { + // For single statement without braces + stmt := &BlockStatement{Token: p.curToken} + stmt.Statements = []Statement{p.parseStatement()} + expression.Alternative = stmt + } + } + + // Check for 'end' if we had a then block without braces + if !p.curTokenIs(lexer.TokenRightBrace) && p.peekTokenIs(lexer.TokenEnd) { + p.nextToken() // Consume 'end' + } + + return expression +} diff --git a/repl.go b/repl.go deleted file mode 100644 index 80fe43f..0000000 --- a/repl.go +++ /dev/null @@ -1,45 +0,0 @@ -// File: cmd/main.go -package main - -import ( - "bufio" - "fmt" - "os" - - "git.sharkk.net/Sharkk/Mako/compiler" - "git.sharkk.net/Sharkk/Mako/lexer" - "git.sharkk.net/Sharkk/Mako/parser" - "git.sharkk.net/Sharkk/Mako/vm" -) - -func main() { - scanner := bufio.NewScanner(os.Stdin) - virtualMachine := vm.New() - - fmt.Println("Mako REPL (type 'exit' to quit)") - for { - fmt.Print(">> ") - if !scanner.Scan() { - break - } - - input := scanner.Text() - if input == "exit" { - break - } - - lex := lexer.New(input) - p := parser.New(lex) - program := p.ParseProgram() - - if len(p.Errors()) > 0 { - for _, err := range p.Errors() { - fmt.Printf("Error: %s\n", err) - } - continue - } - - bytecode := compiler.Compile(program) - virtualMachine.Run(bytecode) - } -} diff --git a/types/types.go b/types/types.go index ed5be32..b96ee07 100644 --- a/types/types.go +++ b/types/types.go @@ -31,6 +31,14 @@ const ( OpMultiply OpDivide OpNegate + OpJumpIfFalse + OpJump + OpEqual + OpNotEqual + OpLessThan + OpGreaterThan + OpLessEqual + OpGreaterEqual ) type Instruction struct { @@ -60,6 +68,10 @@ func NewNumber(n float64) Value { return Value{Type: TypeNumber, Data: n} } +func NewBoolean(b bool) Value { + return Value{Type: TypeBoolean, Data: b} +} + // TableEntry maintains insertion order type TableEntry struct { Key Value diff --git a/vm/vm.go b/vm/vm.go index 7b95788..f9fe05e 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -44,6 +44,10 @@ func (vm *VM) Run(bytecode *types.Bytecode) { vm.push(types.NewString(v)) case float64: vm.push(types.NewNumber(v)) + case bool: + vm.push(types.NewBoolean(v)) + case nil: + vm.push(types.NewNull()) } case types.OpSetLocal: @@ -164,6 +168,27 @@ func (vm *VM) Run(bytecode *types.Bytecode) { fmt.Println(vm.formatTable(value.Data.(*types.Table))) } + // Jump instructions + case types.OpJumpIfFalse: + condition := vm.pop() + // Consider falsy: false, null, 0 + shouldJump := false + + if condition.Type == types.TypeBoolean && !condition.Data.(bool) { + shouldJump = true + } else if condition.Type == types.TypeNull { + shouldJump = true + } else if condition.Type == types.TypeNumber && condition.Data.(float64) == 0 { + shouldJump = true + } + + if shouldJump { + ip = instruction.Operand - 1 // -1 because loop will increment + } + + case types.OpJump: + ip = instruction.Operand - 1 // -1 because loop will increment + // Arithmetic operations case types.OpAdd: right := vm.pop() @@ -233,16 +258,178 @@ func (vm *VM) Run(bytecode *types.Bytecode) { fmt.Println("Error: cannot negate non-number value") vm.push(types.NewNull()) } + + // Comparison operators with safer implementation + case types.OpEqual: + if vm.sp < 2 { + fmt.Println("Error: not enough operands for equality comparison") + vm.push(types.NewBoolean(false)) + continue + } + + right := vm.pop() + left := vm.pop() + + if left.Type != right.Type { + vm.push(types.NewBoolean(false)) + } else { + switch left.Type { + case types.TypeNumber: + vm.push(types.NewBoolean(left.Data.(float64) == right.Data.(float64))) + case types.TypeString: + vm.push(types.NewBoolean(left.Data.(string) == right.Data.(string))) + case types.TypeBoolean: + vm.push(types.NewBoolean(left.Data.(bool) == right.Data.(bool))) + case types.TypeNull: + vm.push(types.NewBoolean(true)) // null == null + default: + vm.push(types.NewBoolean(false)) + } + } + + case types.OpNotEqual: + if vm.sp < 2 { + fmt.Println("Error: not enough operands for inequality comparison") + vm.push(types.NewBoolean(true)) + continue + } + + right := vm.pop() + left := vm.pop() + + if left.Type != right.Type { + vm.push(types.NewBoolean(true)) + } else { + switch left.Type { + case types.TypeNumber: + vm.push(types.NewBoolean(left.Data.(float64) != right.Data.(float64))) + case types.TypeString: + vm.push(types.NewBoolean(left.Data.(string) != right.Data.(string))) + case types.TypeBoolean: + vm.push(types.NewBoolean(left.Data.(bool) != right.Data.(bool))) + case types.TypeNull: + vm.push(types.NewBoolean(false)) // null != null is false + default: + vm.push(types.NewBoolean(true)) + } + } + + case types.OpLessThan: + if vm.sp < 2 { + fmt.Println("Error: not enough operands for less-than comparison") + vm.push(types.NewBoolean(false)) + continue + } + + // Peek at values first before popping + right := vm.stack[vm.sp-1] + left := vm.stack[vm.sp-2] + + if left.Type == types.TypeNumber && right.Type == types.TypeNumber { + // Now pop them + vm.pop() + vm.pop() + vm.push(types.NewBoolean(left.Data.(float64) < right.Data.(float64))) + } else { + // Pop the values to maintain stack balance + vm.pop() + vm.pop() + fmt.Println("Error: cannot compare non-number values with <") + vm.push(types.NewBoolean(false)) + } + + case types.OpGreaterThan: + if vm.sp < 2 { + fmt.Println("Error: not enough operands for greater-than comparison") + vm.push(types.NewBoolean(false)) + continue + } + + // Peek at values first before popping + right := vm.stack[vm.sp-1] + left := vm.stack[vm.sp-2] + + if left.Type == types.TypeNumber && right.Type == types.TypeNumber { + // Now pop them + vm.pop() + vm.pop() + vm.push(types.NewBoolean(left.Data.(float64) > right.Data.(float64))) + } else { + // Pop the values to maintain stack balance + vm.pop() + vm.pop() + fmt.Println("Error: cannot compare non-number values with >") + vm.push(types.NewBoolean(false)) + } + + case types.OpLessEqual: + if vm.sp < 2 { + fmt.Println("Error: not enough operands for less-equal comparison") + vm.push(types.NewBoolean(false)) + continue + } + + // Peek at values first before popping + right := vm.stack[vm.sp-1] + left := vm.stack[vm.sp-2] + + if left.Type == types.TypeNumber && right.Type == types.TypeNumber { + // Now pop them + vm.pop() + vm.pop() + vm.push(types.NewBoolean(left.Data.(float64) <= right.Data.(float64))) + } else { + // Pop the values to maintain stack balance + vm.pop() + vm.pop() + fmt.Println("Error: cannot compare non-number values with <=") + vm.push(types.NewBoolean(false)) + } + + case types.OpGreaterEqual: + if vm.sp < 2 { + fmt.Println("Error: not enough operands for greater-equal comparison") + vm.push(types.NewBoolean(false)) + continue + } + + // Peek at values first before popping + right := vm.stack[vm.sp-1] + left := vm.stack[vm.sp-2] + + if left.Type == types.TypeNumber && right.Type == types.TypeNumber { + // Now pop them + vm.pop() + vm.pop() + vm.push(types.NewBoolean(left.Data.(float64) >= right.Data.(float64))) + } else { + // Pop the values to maintain stack balance + vm.pop() + vm.pop() + fmt.Println("Error: cannot compare non-number values with >=") + vm.push(types.NewBoolean(false)) + } } } } func (vm *VM) push(value types.Value) { + if vm.sp >= len(vm.stack) { + // Grow stack if needed + newStack := make([]types.Value, len(vm.stack)*2) + copy(newStack, vm.stack) + vm.stack = newStack + } vm.stack[vm.sp] = value vm.sp++ } func (vm *VM) pop() types.Value { + if vm.sp <= 0 { + // Return null instead of causing a panic when trying to pop from an empty stack + fmt.Println("Stack underflow error") + return types.NewNull() + } vm.sp-- return vm.stack[vm.sp] }