diff --git a/compiler/compiler.go b/compiler/compiler.go index de83107..cf95c7d 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -1,6 +1,8 @@ package compiler import ( + "fmt" + "git.sharkk.net/Sharkk/Mako/parser" "git.sharkk.net/Sharkk/Mako/types" ) @@ -144,6 +146,42 @@ func (c *compiler) compileExpression(expr parser.Expression) { c.compileExpression(e.Left) c.compileExpression(e.Index) c.emit(types.OpGetIndex, 0) + + // New expression types for arithmetic + case *parser.InfixExpression: + // Compile left and right expressions + c.compileExpression(e.Left) + c.compileExpression(e.Right) + + // Generate the appropriate operation + switch e.Operator { + case "+": + c.emit(types.OpAdd, 0) + case "-": + c.emit(types.OpSubtract, 0) + case "*": + c.emit(types.OpMultiply, 0) + case "/": + c.emit(types.OpDivide, 0) + default: + panic(fmt.Sprintf("Unknown infix operator: %s", e.Operator)) + } + + case *parser.PrefixExpression: + // Compile the operand + c.compileExpression(e.Right) + + // Generate the appropriate operation + switch e.Operator { + case "-": + c.emit(types.OpNegate, 0) + default: + panic(fmt.Sprintf("Unknown prefix operator: %s", e.Operator)) + } + + case *parser.GroupedExpression: + // Just compile the inner expression + c.compileExpression(e.Expr) } } diff --git a/lexer/lexer.go b/lexer/lexer.go index c692242..ace2f75 100644 --- a/lexer/lexer.go +++ b/lexer/lexer.go @@ -15,6 +15,12 @@ const ( TokenLeftBracket TokenRightBracket TokenComma + TokenPlus + TokenMinus + TokenStar + TokenSlash + TokenLeftParen + TokenRightParen ) type Token struct { @@ -58,8 +64,6 @@ func (l *Lexer) NextToken() Token { case '"': tok = Token{Type: TokenString, Value: l.readString()} return tok - case 0: - tok = Token{Type: TokenEOF, Value: ""} case '{': tok = Token{Type: TokenLeftBrace, Value: "{"} case '}': @@ -70,6 +74,21 @@ func (l *Lexer) NextToken() Token { tok = Token{Type: TokenRightBracket, Value: "]"} case ',': tok = Token{Type: TokenComma, Value: ","} + 
// New arithmetic operators + case '+': + tok = Token{Type: TokenPlus, Value: "+"} + case '-': + tok = Token{Type: TokenMinus, Value: "-"} + case '*': + tok = Token{Type: TokenStar, Value: "*"} + case '/': + tok = Token{Type: TokenSlash, Value: "/"} + case '(': + tok = Token{Type: TokenLeftParen, Value: "("} + case ')': + tok = Token{Type: TokenRightParen, Value: ")"} + case 0: + tok = Token{Type: TokenEOF, Value: ""} default: if isLetter(l.ch) { tok.Value = l.readIdentifier() diff --git a/parser/ast.go b/parser/ast.go index 9b7c66d..03902fd 100644 --- a/parser/ast.go +++ b/parser/ast.go @@ -106,3 +106,35 @@ type BlockStatement struct { func (bs *BlockStatement) statementNode() {} func (bs *BlockStatement) TokenLiteral() string { return bs.Token.Value } + +// New AST nodes for arithmetic expressions + +// InfixExpression represents binary operations like: a + b +type InfixExpression struct { + Token lexer.Token // The operator token, e.g. + + Left Expression + Operator string + Right Expression +} + +func (ie *InfixExpression) expressionNode() {} +func (ie *InfixExpression) TokenLiteral() string { return ie.Token.Value } + +// PrefixExpression represents unary operations like: -a +type PrefixExpression struct { + Token lexer.Token // The prefix token, e.g. 
- + Operator string + Right Expression +} + +func (pe *PrefixExpression) expressionNode() {} +func (pe *PrefixExpression) TokenLiteral() string { return pe.Token.Value } + +// GroupedExpression represents an expression in parentheses: (a + b) +type GroupedExpression struct { + Token lexer.Token // The '(' token + Expr Expression +} + +func (ge *GroupedExpression) expressionNode() {} +func (ge *GroupedExpression) TokenLiteral() string { return ge.Token.Value } diff --git a/parser/parser.go b/parser/parser.go index 7492f04..8a2e9b0 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -7,37 +7,129 @@ import ( "git.sharkk.net/Sharkk/Mako/lexer" ) +// Precedence levels for expression parsing +const ( + _ int = iota + LOWEST + SUM // +, - + PRODUCT // *, / + PREFIX // -X or !X + INDEX // array[index] +) + +var precedences = map[lexer.TokenType]int{ + lexer.TokenPlus: SUM, + lexer.TokenMinus: SUM, + lexer.TokenStar: PRODUCT, + lexer.TokenSlash: PRODUCT, + lexer.TokenLeftBracket: INDEX, +} + +type ( + prefixParseFn func() Expression + infixParseFn func(Expression) Expression +) + type Parser struct { - l *lexer.Lexer + l *lexer.Lexer + errors []string + curToken lexer.Token peekToken lexer.Token - errors []string + + prefixParseFns map[lexer.TokenType]prefixParseFn + infixParseFns map[lexer.TokenType]infixParseFn } func New(l *lexer.Lexer) *Parser { - p := &Parser{l: l, errors: []string{}} + p := &Parser{ + l: l, + errors: []string{}, + } + + // Initialize prefix parse functions + p.prefixParseFns = make(map[lexer.TokenType]prefixParseFn) + p.registerPrefix(lexer.TokenIdentifier, p.parseIdentifier) + p.registerPrefix(lexer.TokenString, p.parseStringLiteral) + p.registerPrefix(lexer.TokenNumber, p.parseNumberLiteral) + p.registerPrefix(lexer.TokenLeftBrace, p.parseTableLiteral) + p.registerPrefix(lexer.TokenMinus, p.parsePrefixExpression) + p.registerPrefix(lexer.TokenLeftParen, p.parseGroupedExpression) + + // Initialize infix parse functions + p.infixParseFns = 
make(map[lexer.TokenType]infixParseFn) + p.registerInfix(lexer.TokenPlus, p.parseInfixExpression) + p.registerInfix(lexer.TokenMinus, p.parseInfixExpression) + p.registerInfix(lexer.TokenStar, p.parseInfixExpression) + p.registerInfix(lexer.TokenSlash, p.parseInfixExpression) + p.registerInfix(lexer.TokenLeftBracket, p.parseIndexExpression) + + // Read two tokens, so curToken and peekToken are both set p.nextToken() p.nextToken() + return p } +func (p *Parser) registerPrefix(tokenType lexer.TokenType, fn prefixParseFn) { + p.prefixParseFns[tokenType] = fn +} + +func (p *Parser) registerInfix(tokenType lexer.TokenType, fn infixParseFn) { + p.infixParseFns[tokenType] = fn +} + func (p *Parser) nextToken() { p.curToken = p.peekToken p.peekToken = p.l.NextToken() } +func (p *Parser) curTokenIs(t lexer.TokenType) bool { + return p.curToken.Type == t +} + +func (p *Parser) peekTokenIs(t lexer.TokenType) bool { + return p.peekToken.Type == t +} + +func (p *Parser) expectPeek(t lexer.TokenType) bool { + if p.peekTokenIs(t) { + p.nextToken() + return true + } + p.peekError(t) + return false +} + +func (p *Parser) peekError(t lexer.TokenType) { + msg := fmt.Sprintf("expected next token to be %d, got %d instead", t, p.peekToken.Type) + p.errors = append(p.errors, msg) +} + func (p *Parser) Errors() []string { return p.errors } +func (p *Parser) peekPrecedence() int { + if p, ok := precedences[p.peekToken.Type]; ok { + return p + } + return LOWEST +} + +func (p *Parser) curPrecedence() int { + if p, ok := precedences[p.curToken.Type]; ok { + return p + } + return LOWEST +} + func (p *Parser) ParseProgram() *Program { program := &Program{Statements: []Statement{}} - for p.curToken.Type != lexer.TokenEOF { + for !p.curTokenIs(lexer.TokenEOF) { stmt := p.parseStatement() - if stmt != nil { - program.Statements = append(program.Statements, stmt) - } + program.Statements = append(program.Statements, stmt) p.nextToken() } @@ -47,19 +139,43 @@ func (p *Parser) ParseProgram() *Program 
{ func (p *Parser) parseStatement() Statement { switch p.curToken.Type { case lexer.TokenIdentifier: - if p.peekToken.Type == lexer.TokenEqual { + if p.peekTokenIs(lexer.TokenEqual) { return p.parseVariableStatement() - } else if p.peekToken.Type == lexer.TokenLeftBracket { + } else if p.peekTokenIs(lexer.TokenLeftBracket) { return p.parseIndexAssignmentStatement() } + return p.parseExpressionStatement() case lexer.TokenEcho: return p.parseEchoStatement() case lexer.TokenLeftBrace: return p.parseBlockStatement() + default: + return p.parseExpressionStatement() } - return nil } +// New method for expression statements +func (p *Parser) parseExpressionStatement() *ExpressionStatement { + stmt := &ExpressionStatement{Token: p.curToken} + + stmt.Expression = p.parseExpression(LOWEST) + + if p.peekTokenIs(lexer.TokenSemicolon) { + p.nextToken() + } + + return stmt +} + +// Add ExpressionStatement to ast.go +type ExpressionStatement struct { + Token lexer.Token + Expression Expression +} + +func (es *ExpressionStatement) statementNode() {} +func (es *ExpressionStatement) TokenLiteral() string { return es.Token.Value } + func (p *Parser) parseBlockStatement() *BlockStatement { block := &BlockStatement{Token: p.curToken} block.Statements = []Statement{} @@ -68,9 +184,7 @@ func (p *Parser) parseBlockStatement() *BlockStatement { for p.curToken.Type != lexer.TokenRightBrace && p.curToken.Type != lexer.TokenEOF { stmt := p.parseStatement() - if stmt != nil { - block.Statements = append(block.Statements, stmt) - } + block.Statements = append(block.Statements, stmt) p.nextToken() } @@ -82,12 +196,15 @@ func (p *Parser) parseVariableStatement() *VariableStatement { stmt.Name = &Identifier{Token: p.curToken, Value: p.curToken.Value} - p.nextToken() // Skip identifier - p.nextToken() // Skip = + if !p.expectPeek(lexer.TokenEqual) { + return nil + } - stmt.Value = p.parseExpression() + p.nextToken() // Skip the equals sign - if p.peekToken.Type == lexer.TokenSemicolon { + 
stmt.Value = p.parseExpression(LOWEST) + + if p.peekTokenIs(lexer.TokenSemicolon) { p.nextToken() } @@ -99,9 +216,9 @@ func (p *Parser) parseEchoStatement() *EchoStatement { p.nextToken() - stmt.Value = p.parseExpression() + stmt.Value = p.parseExpression(LOWEST) - if p.peekToken.Type == lexer.TokenSemicolon { + if p.peekTokenIs(lexer.TokenSemicolon) { p.nextToken() } @@ -115,57 +232,82 @@ func (p *Parser) parseIndexAssignmentStatement() *IndexAssignmentStatement { } p.nextToken() // Skip identifier + if !p.curTokenIs(lexer.TokenLeftBracket) { + return nil + } + p.nextToken() // Skip '[' + stmt.Index = p.parseExpression(LOWEST) - stmt.Index = p.parseExpression() - - if p.peekToken.Type != lexer.TokenRightBracket { - p.errors = append(p.errors, "expected ] after index expression") - return stmt + if !p.expectPeek(lexer.TokenRightBracket) { + return nil } - p.nextToken() // Skip index - p.nextToken() // Skip ']' - - // Fix: Check current token, not peek token - if p.curToken.Type != lexer.TokenEqual { - p.errors = append(p.errors, "expected = after index expression") - return stmt + if !p.expectPeek(lexer.TokenEqual) { + return nil } - p.nextToken() // Skip = + p.nextToken() // Skip '=' + stmt.Value = p.parseExpression(LOWEST) - stmt.Value = p.parseExpression() - - if p.peekToken.Type == lexer.TokenSemicolon { + if p.peekTokenIs(lexer.TokenSemicolon) { p.nextToken() } return stmt } -func (p *Parser) parseExpression() Expression { - switch p.curToken.Type { - case lexer.TokenString: - return &StringLiteral{Token: p.curToken, Value: p.curToken.Value} - case lexer.TokenNumber: - num, err := strconv.ParseFloat(p.curToken.Value, 64) - if err != nil { - p.errors = append(p.errors, fmt.Sprintf("could not parse %q as float", p.curToken.Value)) - } - return &NumberLiteral{Token: p.curToken, Value: num} - case lexer.TokenIdentifier: - if p.peekToken.Type == lexer.TokenLeftBracket { - return p.parseIndexExpression() - } - return &Identifier{Token: p.curToken, Value: 
p.curToken.Value} - case lexer.TokenLeftBrace: - return p.parseTableLiteral() +// Core expression parser with precedence climbing +func (p *Parser) parseExpression(precedence int) Expression { + prefix := p.prefixParseFns[p.curToken.Type] + if prefix == nil { + p.noPrefixParseFnError(p.curToken.Type) + return nil } - return nil + leftExp := prefix() + + for !p.peekTokenIs(lexer.TokenSemicolon) && precedence < p.peekPrecedence() { + infix := p.infixParseFns[p.peekToken.Type] + if infix == nil { + return leftExp + } + + p.nextToken() + leftExp = infix(leftExp) + } + + return leftExp } -func (p *Parser) parseTableLiteral() *TableLiteral { +func (p *Parser) noPrefixParseFnError(t lexer.TokenType) { + msg := fmt.Sprintf("no prefix parse function for %d found", t) + p.errors = append(p.errors, msg) +} + +// Expression parsing methods +func (p *Parser) parseIdentifier() Expression { + return &Identifier{Token: p.curToken, Value: p.curToken.Value} +} + +func (p *Parser) parseStringLiteral() Expression { + return &StringLiteral{Token: p.curToken, Value: p.curToken.Value} +} + +func (p *Parser) parseNumberLiteral() Expression { + lit := &NumberLiteral{Token: p.curToken} + + value, err := strconv.ParseFloat(p.curToken.Value, 64) + if err != nil { + msg := fmt.Sprintf("could not parse %q as float", p.curToken.Value) + p.errors = append(p.errors, msg) + return nil + } + + lit.Value = value + return lit +} + +func (p *Parser) parseTableLiteral() Expression { table := &TableLiteral{ Token: p.curToken, Pairs: make(map[Expression]Expression), @@ -173,75 +315,103 @@ func (p *Parser) parseTableLiteral() *TableLiteral { p.nextToken() // Skip '{' - if p.curToken.Type == lexer.TokenRightBrace { + if p.curTokenIs(lexer.TokenRightBrace) { return table // Empty table } // Parse the first key-value pair - key := p.parseExpression() + key := p.parseExpression(LOWEST) - if p.peekToken.Type != lexer.TokenEqual { - p.errors = append(p.errors, "expected = after table key") - return table + if 
!p.expectPeek(lexer.TokenEqual) { + return nil } - p.nextToken() // Skip key - p.nextToken() // Skip = - - value := p.parseExpression() + p.nextToken() // Skip '=' + value := p.parseExpression(LOWEST) table.Pairs[key] = value - p.nextToken() // Skip value - // Parse remaining key-value pairs - for p.curToken.Type == lexer.TokenComma { + for p.peekTokenIs(lexer.TokenComma) { + p.nextToken() // Skip current value p.nextToken() // Skip comma - if p.curToken.Type == lexer.TokenRightBrace { + if p.curTokenIs(lexer.TokenRightBrace) { break // Allow trailing comma } - key = p.parseExpression() + key = p.parseExpression(LOWEST) - if p.peekToken.Type != lexer.TokenEqual { - p.errors = append(p.errors, "expected = after table key") - return table + if !p.expectPeek(lexer.TokenEqual) { + return nil } - p.nextToken() // Skip key - p.nextToken() // Skip = - - value = p.parseExpression() + p.nextToken() // Skip '=' + value = p.parseExpression(LOWEST) table.Pairs[key] = value - - p.nextToken() // Skip value } - if p.curToken.Type != lexer.TokenRightBrace { - p.errors = append(p.errors, "expected } or , after table entry") + if !p.curTokenIs(lexer.TokenRightBrace) && !p.expectPeek(lexer.TokenRightBrace) { + return nil } return table } -func (p *Parser) parseIndexExpression() *IndexExpression { +func (p *Parser) parseIndexExpression(left Expression) Expression { exp := &IndexExpression{ Token: p.curToken, - Left: &Identifier{Token: p.curToken, Value: p.curToken.Value}, + Left: left, } - p.nextToken() // Skip identifier p.nextToken() // Skip '[' + exp.Index = p.parseExpression(LOWEST) - exp.Index = p.parseExpression() - - if p.peekToken.Type != lexer.TokenRightBracket { - p.errors = append(p.errors, "expected ] after index expression") - return exp + if !p.expectPeek(lexer.TokenRightBracket) { + return nil } - p.nextToken() // Skip index - p.nextToken() // Skip ']' - return exp } + +// New methods for arithmetic expressions +func (p *Parser) parsePrefixExpression() Expression { + expression := &PrefixExpression{ + Token: 
p.curToken, + Operator: p.curToken.Value, + } + + p.nextToken() // Skip the prefix token + expression.Right = p.parseExpression(PREFIX) + + return expression +} + +func (p *Parser) parseInfixExpression(left Expression) Expression { + expression := &InfixExpression{ + Token: p.curToken, + Operator: p.curToken.Value, + Left: left, + } + + precedence := p.curPrecedence() + p.nextToken() // Skip the operator + expression.Right = p.parseExpression(precedence) + + return expression +} + +func (p *Parser) parseGroupedExpression() Expression { + lparen := p.curToken; p.nextToken() // Skip '(' + + exp := p.parseExpression(LOWEST) + + if !p.expectPeek(lexer.TokenRightParen) { + return nil + } + + // Wrap in GroupedExpression to maintain the AST structure + return &GroupedExpression{ + Token: lparen, + Expr: exp, + } +} diff --git a/types/types.go b/types/types.go index dcf4d44..ed5be32 100644 --- a/types/types.go +++ b/types/types.go @@ -26,6 +26,11 @@ const ( OpPop OpEnterScope OpExitScope + OpAdd + OpSubtract + OpMultiply + OpDivide + OpNegate ) type Instruction struct { diff --git a/vm/vm.go b/vm/vm.go index 7548108..7b95788 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -163,6 +163,76 @@ func (vm *VM) Run(bytecode *types.Bytecode) { case types.TypeTable: fmt.Println(vm.formatTable(value.Data.(*types.Table))) } + + // Arithmetic operations + case types.OpAdd: + right := vm.pop() + left := vm.pop() + + if left.Type == types.TypeNumber && right.Type == types.TypeNumber { + result := left.Data.(float64) + right.Data.(float64) + vm.push(types.NewNumber(result)) + } else if left.Type == types.TypeString && right.Type == types.TypeString { + // String concatenation + result := left.Data.(string) + right.Data.(string) + vm.push(types.NewString(result)) + } else { + fmt.Println("Error: cannot add values of different types") + vm.push(types.NewNull()) + } + + case types.OpSubtract: + right := vm.pop() + left := vm.pop() + + if left.Type == types.TypeNumber && right.Type == types.TypeNumber { + result := 
left.Data.(float64) - right.Data.(float64) + vm.push(types.NewNumber(result)) + } else { + fmt.Println("Error: cannot subtract non-number values") + vm.push(types.NewNull()) + } + + case types.OpMultiply: + right := vm.pop() + left := vm.pop() + + if left.Type == types.TypeNumber && right.Type == types.TypeNumber { + result := left.Data.(float64) * right.Data.(float64) + vm.push(types.NewNumber(result)) + } else { + fmt.Println("Error: cannot multiply non-number values") + vm.push(types.NewNull()) + } + + case types.OpDivide: + right := vm.pop() + left := vm.pop() + + if left.Type == types.TypeNumber && right.Type == types.TypeNumber { + // Check for division by zero + if right.Data.(float64) == 0 { + fmt.Println("Error: division by zero") + vm.push(types.NewNull()) + } else { + result := left.Data.(float64) / right.Data.(float64) + vm.push(types.NewNumber(result)) + } + } else { + fmt.Println("Error: cannot divide non-number values") + vm.push(types.NewNull()) + } + + case types.OpNegate: + operand := vm.pop() + + if operand.Type == types.TypeNumber { + result := -operand.Data.(float64) + vm.push(types.NewNumber(result)) + } else { + fmt.Println("Error: cannot negate non-number value") + vm.push(types.NewNull()) + } } } }