package compiler import ( "fmt" "git.sharkk.net/Sharkk/Mako/parser" "git.sharkk.net/Sharkk/Mako/types" ) // Compile converts AST to bytecode func Compile(program *parser.Program) *types.Bytecode { c := &compiler{ constants: []any{}, instructions: []types.Instruction{}, scopes: []scope{}, } // Start in global scope c.enterScope() for _, stmt := range program.Statements { c.compileStatement(stmt) } c.exitScope() return &types.Bytecode{ Constants: c.constants, Instructions: c.instructions, } } type scope struct { variables map[string]bool } type compiler struct { constants []any instructions []types.Instruction scopes []scope } func (c *compiler) enterScope() { c.scopes = append(c.scopes, scope{ variables: make(map[string]bool), }) c.emit(types.OpEnterScope, 0) } func (c *compiler) exitScope() { c.scopes = c.scopes[:len(c.scopes)-1] c.emit(types.OpExitScope, 0) } func (c *compiler) declareVariable(name string) { if len(c.scopes) > 0 { c.scopes[len(c.scopes)-1].variables[name] = true } } func (c *compiler) isLocalVariable(name string) bool { for i := len(c.scopes) - 1; i >= 0; i-- { if _, ok := c.scopes[i].variables[name]; ok { return true } } return false } func (c *compiler) compileStatement(stmt parser.Statement) { switch s := stmt.(type) { case *parser.VariableStatement: c.compileExpression(s.Value) nameIndex := c.addConstant(s.Name.Value) // Use SetGlobal for top-level variables to persist between REPL lines if len(c.scopes) <= 1 { c.emit(types.OpSetGlobal, nameIndex) } else { c.declareVariable(s.Name.Value) c.emit(types.OpSetLocal, nameIndex) } case *parser.IndexAssignmentStatement: c.compileExpression(s.Left) c.compileExpression(s.Index) c.compileExpression(s.Value) c.emit(types.OpSetIndex, 0) case *parser.EchoStatement: c.compileExpression(s.Value) c.emit(types.OpEcho, 0) case *parser.BlockStatement: c.enterScope() for _, blockStmt := range s.Statements { c.compileStatement(blockStmt) } c.exitScope() } } func (c *compiler) compileExpression(expr parser.Expression) { switch e := expr.(type) { case *parser.StringLiteral: constIndex := c.addConstant(e.Value) c.emit(types.OpConstant, constIndex) case *parser.NumberLiteral: constIndex := c.addConstant(e.Value) c.emit(types.OpConstant, constIndex) case *parser.BooleanLiteral: constIndex := c.addConstant(e.Value) c.emit(types.OpConstant, constIndex) case *parser.Identifier: nameIndex := c.addConstant(e.Value) // Check if it's a local variable first if c.isLocalVariable(e.Value) { c.emit(types.OpGetLocal, nameIndex) } else { // Otherwise treat as global c.emit(types.OpGetGlobal, nameIndex) } case *parser.TableLiteral: c.emit(types.OpNewTable, 0) for key, value := range e.Pairs { c.emit(types.OpDup, 0) // Special handling for identifier keys in tables if ident, ok := key.(*parser.Identifier); ok { // Treat identifiers as string literals in table keys strIndex := c.addConstant(ident.Value) c.emit(types.OpConstant, strIndex) } else { // For other expressions, compile normally c.compileExpression(key) } c.compileExpression(value) c.emit(types.OpSetIndex, 0) c.emit(types.OpPop, 0) } case *parser.IndexExpression: c.compileExpression(e.Left) c.compileExpression(e.Index) c.emit(types.OpGetIndex, 0) // Arithmetic expressions case *parser.InfixExpression: // Compile left and right expressions c.compileExpression(e.Left) c.compileExpression(e.Right) // Generate the appropriate operation switch e.Operator { case "+": c.emit(types.OpAdd, 0) case "-": c.emit(types.OpSubtract, 0) case "*": c.emit(types.OpMultiply, 0) case "/": c.emit(types.OpDivide, 0) case "==": c.emit(types.OpEqual, 0) case "!=": c.emit(types.OpNotEqual, 0) case "<": c.emit(types.OpLessThan, 0) case ">": c.emit(types.OpGreaterThan, 0) case "<=": c.emit(types.OpLessEqual, 0) case ">=": c.emit(types.OpGreaterEqual, 0) default: panic(fmt.Sprintf("Unknown infix operator: %s", e.Operator)) } case *parser.PrefixExpression: // Compile the operand c.compileExpression(e.Right) // Generate the appropriate operation switch e.Operator { case "-": c.emit(types.OpNegate, 0) default: panic(fmt.Sprintf("Unknown prefix operator: %s", e.Operator)) } case *parser.GroupedExpression: // Just compile the inner expression c.compileExpression(e.Expr) case *parser.IfExpression: // Compile condition c.compileExpression(e.Condition) // Emit jump-if-false with placeholder jumpNotTruePos := len(c.instructions) c.emit(types.OpJumpIfFalse, 0) // Will backpatch // Compile consequence (then block) lastStmtIndex := len(e.Consequence.Statements) - 1 for i, stmt := range e.Consequence.Statements { if i == lastStmtIndex { // For the last statement, we need to ensure it leaves a value if exprStmt, ok := stmt.(*parser.ExpressionStatement); ok { c.compileExpression(exprStmt.Expression) } else { c.compileStatement(stmt) // Push null if not an expression statement nullIndex := c.addConstant(nil) c.emit(types.OpConstant, nullIndex) } } else { c.compileStatement(stmt) } } // If no statements, push null if len(e.Consequence.Statements) == 0 { nullIndex := c.addConstant(nil) c.emit(types.OpConstant, nullIndex) } // Emit jump to skip else part jumpPos := len(c.instructions) c.emit(types.OpJump, 0) // Will backpatch // Backpatch jump-if-false to point to else afterConsequencePos := len(c.instructions) c.instructions[jumpNotTruePos].Operand = afterConsequencePos // Compile alternative (else block) if e.Alternative != nil { lastStmtIndex = len(e.Alternative.Statements) - 1 for i, stmt := range e.Alternative.Statements { if i == lastStmtIndex { // For the last statement, we need to ensure it leaves a value if exprStmt, ok := stmt.(*parser.ExpressionStatement); ok { c.compileExpression(exprStmt.Expression) } else { c.compileStatement(stmt) // Push null if not an expression statement nullIndex := c.addConstant(nil) c.emit(types.OpConstant, nullIndex) } } else { c.compileStatement(stmt) } } // If no statements, push null if len(e.Alternative.Statements) == 0 { nullIndex := c.addConstant(nil) c.emit(types.OpConstant, nullIndex) } } else { // No else - push null nullIndex := c.addConstant(nil) c.emit(types.OpConstant, nullIndex) } // Backpatch jump to point after else afterAlternativePos := len(c.instructions) c.instructions[jumpPos].Operand = afterAlternativePos } } func (c *compiler) addConstant(value any) int { c.constants = append(c.constants, value) return len(c.constants) - 1 } func (c *compiler) emit(op types.Opcode, operand int) { instruction := types.Instruction{ Opcode: op, Operand: operand, } c.instructions = append(c.instructions, instruction) }