package compiler import ( "fmt" "git.sharkk.net/Sharkk/Mako/parser" "git.sharkk.net/Sharkk/Mako/types" ) // Compile converts AST to bytecode func Compile(program *parser.Program) *types.Bytecode { c := &compiler{ constants: []any{}, instructions: []types.Instruction{}, scopes: []scope{}, currentFunction: nil, } // Start in global scope c.enterScope() // Add nil check for program if program == nil { c.exitScope() return &types.Bytecode{ Constants: c.constants, Instructions: c.instructions, } } // Process each statement safely for _, stmt := range program.Statements { // Skip nil statements if stmt == nil { continue } c.compileStatement(stmt) } c.exitScope() return &types.Bytecode{ Constants: c.constants, Instructions: c.instructions, } } type scope struct { variables map[string]bool upvalues map[string]int } type compiler struct { constants []any instructions []types.Instruction scopes []scope currentFunction *functionCompiler } type functionCompiler struct { constants []any instructions []types.Instruction numParams int upvalues []upvalueInfo } type upvalueInfo struct { index int // Index in the upvalue list isLocal bool // Whether this is a local variable or an upvalue from an outer scope capturedFrom int // The scope level where this variable was captured from } func (c *compiler) enterScope() { c.scopes = append(c.scopes, scope{ variables: make(map[string]bool), upvalues: make(map[string]int), }) c.emit(types.OpEnterScope, 0) } func (c *compiler) exitScope() { c.scopes = c.scopes[:len(c.scopes)-1] c.emit(types.OpExitScope, 0) } func (c *compiler) declareVariable(name string) { if len(c.scopes) > 0 { c.scopes[len(c.scopes)-1].variables[name] = true } } func (c *compiler) isLocalVariable(name string) bool { for i := len(c.scopes) - 1; i >= 0; i-- { if _, ok := c.scopes[i].variables[name]; ok { return true } } return false } func (c *compiler) compileStatement(stmt parser.Statement) { if stmt == nil { return } switch s := stmt.(type) { case *parser.VariableStatement: c.compileExpression(s.Value) nameIndex := c.addConstant(s.Name.Value) // Use SetGlobal for top-level variables to persist between REPL lines if len(c.scopes) <= 1 { c.emit(types.OpSetGlobal, nameIndex) } else { c.declareVariable(s.Name.Value) c.emit(types.OpSetLocal, nameIndex) } case *parser.IndexAssignmentStatement: c.compileExpression(s.Left) c.compileExpression(s.Index) c.compileExpression(s.Value) c.emit(types.OpSetIndex, 0) case *parser.EchoStatement: c.compileExpression(s.Value) c.emit(types.OpEcho, 0) case *parser.ReturnStatement: if s.Value != nil { c.compileExpression(s.Value) } else { nullIndex := c.addConstant(nil) c.emit(types.OpConstant, nullIndex) } c.emit(types.OpReturn, 0) case *parser.FunctionStatement: // Use the dedicated function for function statements c.compileFunctionDeclaration(s) // BlockStatement now should only be used for keyword blocks like if-then-else-end case *parser.BlockStatement: for _, blockStmt := range s.Statements { c.compileStatement(blockStmt) } case *parser.ExpressionStatement: c.compileExpression(s.Expression) // Pop the value since we're not using it c.emit(types.OpPop, 0) } } func (c *compiler) compileExpression(expr parser.Expression) { switch e := expr.(type) { case *parser.StringLiteral: constIndex := c.addConstant(e.Value) c.emit(types.OpConstant, constIndex) case *parser.NumberLiteral: constIndex := c.addConstant(e.Value) c.emit(types.OpConstant, constIndex) case *parser.BooleanLiteral: constIndex := c.addConstant(e.Value) c.emit(types.OpConstant, constIndex) case *parser.NilLiteral: constIndex := c.addConstant(nil) c.emit(types.OpConstant, constIndex) case *parser.Identifier: nameIndex := c.addConstant(e.Value) // Check if it's a local variable first if c.isLocalVariable(e.Value) { c.emit(types.OpGetLocal, nameIndex) } else { // Otherwise treat as global c.emit(types.OpGetGlobal, nameIndex) } case *parser.TableLiteral: c.emit(types.OpNewTable, 0) for key, value := range e.Pairs { c.emit(types.OpDup, 0) // Special handling for identifier keys in tables if ident, ok := key.(*parser.Identifier); ok { // Treat identifiers as string literals in table keys strIndex := c.addConstant(ident.Value) c.emit(types.OpConstant, strIndex) } else { // For other expressions, compile normally c.compileExpression(key) } c.compileExpression(value) c.emit(types.OpSetIndex, 0) c.emit(types.OpPop, 0) } case *parser.IndexExpression: c.compileExpression(e.Left) c.compileExpression(e.Index) c.emit(types.OpGetIndex, 0) case *parser.FunctionLiteral: c.compileFunctionLiteral(e) case *parser.CallExpression: // Compile the function expression first c.compileExpression(e.Function) // Then compile the arguments for _, arg := range e.Arguments { c.compileExpression(arg) } // Emit the call instruction with the number of arguments c.emit(types.OpCall, len(e.Arguments)) // Arithmetic expressions case *parser.InfixExpression: switch e.Operator { case "and": // Compile left operand c.compileExpression(e.Left) // Duplicate to check condition c.emit(types.OpDup, 0) // Jump if false (short-circuit) jumpFalsePos := len(c.instructions) c.emit(types.OpJumpIfFalse, 0) // Will backpatch // Pop the duplicate since we'll replace it c.emit(types.OpPop, 0) // Compile right operand c.compileExpression(e.Right) // Jump target for short-circuit endPos := len(c.instructions) c.instructions[jumpFalsePos].Operand = endPos case "or": // Compile left operand c.compileExpression(e.Left) // Duplicate to check condition c.emit(types.OpDup, 0) // Need to check if it's truthy to short-circuit falseJumpPos := len(c.instructions) c.emit(types.OpJumpIfFalse, 0) // Jump to right eval if false // If truthy, jump to end trueJumpPos := len(c.instructions) c.emit(types.OpJump, 0) // Jump to end if true // Position for false case falsePos := len(c.instructions) c.instructions[falseJumpPos].Operand = falsePos // Pop the duplicate since we'll replace it c.emit(types.OpPop, 0) // Compile right operand c.compileExpression(e.Right) // End position endPos := len(c.instructions) c.instructions[trueJumpPos].Operand = endPos default: // Original infix expression compilation c.compileExpression(e.Left) c.compileExpression(e.Right) // Generate the appropriate operation switch e.Operator { case "+": c.emit(types.OpAdd, 0) case "-": c.emit(types.OpSubtract, 0) case "*": c.emit(types.OpMultiply, 0) case "/": c.emit(types.OpDivide, 0) case "==": c.emit(types.OpEqual, 0) case "!=": c.emit(types.OpNotEqual, 0) case "<": c.emit(types.OpLessThan, 0) case ">": c.emit(types.OpGreaterThan, 0) case "<=": c.emit(types.OpLessEqual, 0) case ">=": c.emit(types.OpGreaterEqual, 0) default: panic(fmt.Sprintf("Unknown infix operator: %s", e.Operator)) } } case *parser.PrefixExpression: // Compile the operand c.compileExpression(e.Right) // Generate the appropriate operation switch e.Operator { case "-": c.emit(types.OpNegate, 0) case "not": c.emit(types.OpNot, 0) default: panic(fmt.Sprintf("Unknown prefix operator: %s", e.Operator)) } case *parser.GroupedExpression: // Just compile the inner expression c.compileExpression(e.Expr) case *parser.IfExpression: // Compile condition c.compileExpression(e.Condition) // Emit jump-if-false with placeholder jumpNotTruePos := len(c.instructions) c.emit(types.OpJumpIfFalse, 0) // Will backpatch // Compile consequence (then block) if e.Consequence != nil { lastStmtIndex := len(e.Consequence.Statements) - 1 for i, stmt := range e.Consequence.Statements { if i == lastStmtIndex { // For the last statement, we need to ensure it leaves a value if exprStmt, ok := stmt.(*parser.ExpressionStatement); ok { c.compileExpression(exprStmt.Expression) } else { c.compileStatement(stmt) // Push null if not an expression statement nullIndex := c.addConstant(nil) c.emit(types.OpConstant, nullIndex) } } else { c.compileStatement(stmt) } } // If no statements, push null if len(e.Consequence.Statements) == 0 { nullIndex := c.addConstant(nil) c.emit(types.OpConstant, nullIndex) } } else { // No consequence block, push null nullIndex := c.addConstant(nil) c.emit(types.OpConstant, nullIndex) } // Emit jump to skip else part jumpPos := len(c.instructions) c.emit(types.OpJump, 0) // Will backpatch // Backpatch jump-if-false to point to else afterConsequencePos := len(c.instructions) c.instructions[jumpNotTruePos].Operand = afterConsequencePos // Compile alternative (else block) if e.Alternative != nil { lastStmtIndex := len(e.Alternative.Statements) - 1 for i, stmt := range e.Alternative.Statements { if i == lastStmtIndex { // For the last statement, we need to ensure it leaves a value if exprStmt, ok := stmt.(*parser.ExpressionStatement); ok { c.compileExpression(exprStmt.Expression) } else { c.compileStatement(stmt) // Push null if not an expression statement nullIndex := c.addConstant(nil) c.emit(types.OpConstant, nullIndex) } } else { c.compileStatement(stmt) } } // If no statements, push null if len(e.Alternative.Statements) == 0 { nullIndex := c.addConstant(nil) c.emit(types.OpConstant, nullIndex) } } else { // No else - push null nullIndex := c.addConstant(nil) c.emit(types.OpConstant, nullIndex) } // Backpatch jump to point after else afterAlternativePos := len(c.instructions) c.instructions[jumpPos].Operand = afterAlternativePos } } func (c *compiler) compileFunctionLiteral(fn *parser.FunctionLiteral) { // Save the current compiler state parentCompiler := c.currentFunction // Create a new function compiler fnCompiler := &functionCompiler{ constants: []any{}, instructions: []types.Instruction{}, numParams: len(fn.Parameters), upvalues: []upvalueInfo{}, } c.currentFunction = fnCompiler // Enter a new scope for the function body c.enterScope() // Declare parameters as local variables for _, param := range fn.Parameters { c.declareVariable(param.Value) paramIndex := c.addConstant(param.Value) c.emit(types.OpSetLocal, paramIndex) } // Compile the function body for _, stmt := range fn.Body.Statements { c.compileStatement(stmt) } // Ensure the function always returns a value // If the last instruction is not a return, add one if len(fnCompiler.instructions) == 0 || (len(fnCompiler.instructions) > 0 && fnCompiler.instructions[len(fnCompiler.instructions)-1].Opcode != types.OpReturn) { nullIndex := c.addConstant(nil) c.emit(types.OpConstant, nullIndex) c.emit(types.OpReturn, 0) } // Exit the function scope c.exitScope() // Restore the parent compiler c.currentFunction = parentCompiler // Extract upvalue information for closure creation upvalueIndexes := make([]int, len(fnCompiler.upvalues)) for i, upvalue := range fnCompiler.upvalues { upvalueIndexes[i] = upvalue.index } // Create a Function object and add it to the constants function := types.NewFunction( fnCompiler.instructions, fnCompiler.numParams, fnCompiler.constants, upvalueIndexes, ) functionIndex := c.addConstant(function) c.emit(types.OpFunction, functionIndex) } func (c *compiler) addConstant(value any) int { if c.currentFunction != nil { c.currentFunction.constants = append(c.currentFunction.constants, value) return len(c.currentFunction.constants) - 1 } c.constants = append(c.constants, value) return len(c.constants) - 1 } func (c *compiler) emit(op types.Opcode, operand int) { instruction := types.Instruction{ Opcode: op, Operand: operand, } if c.currentFunction != nil { c.currentFunction.instructions = append(c.currentFunction.instructions, instruction) } else { c.instructions = append(c.instructions, instruction) } } func (c *compiler) compileFunctionDeclaration(fn *parser.FunctionStatement) { // Save the current compiler state parentCompiler := c.currentFunction // Create a new function compiler fnCompiler := &functionCompiler{ constants: []any{}, instructions: []types.Instruction{}, numParams: len(fn.Parameters), upvalues: []upvalueInfo{}, } c.currentFunction = fnCompiler // Enter a new scope for the function body c.enterScope() // Declare parameters as local variables for _, param := range fn.Parameters { c.declareVariable(param.Value) paramIndex := c.addConstant(param.Value) c.emit(types.OpSetLocal, paramIndex) } // Compile the function body for _, stmt := range fn.Body.Statements { c.compileStatement(stmt) } // Ensure the function always returns a value // If the last instruction is not a return, add one if len(fnCompiler.instructions) == 0 || fnCompiler.instructions[len(fnCompiler.instructions)-1].Opcode != types.OpReturn { nullIndex := c.addConstant(nil) c.emit(types.OpConstant, nullIndex) c.emit(types.OpReturn, 0) } // Exit the function scope c.exitScope() // Restore the parent compiler c.currentFunction = parentCompiler // Extract upvalue information for closure creation upvalueIndexes := make([]int, len(fnCompiler.upvalues)) for i, upvalue := range fnCompiler.upvalues { upvalueIndexes[i] = upvalue.index } // Create a Function object and add it to the constants function := types.NewFunction( fnCompiler.instructions, fnCompiler.numParams, fnCompiler.constants, upvalueIndexes, ) functionIndex := c.addConstant(function) c.emit(types.OpFunction, functionIndex) // Store the function in a global variable nameIndex := c.addConstant(fn.Name.Value) // Use SetGlobal for top-level variables to persist between REPL lines if len(c.scopes) <= 1 { c.emit(types.OpSetGlobal, nameIndex) } else { c.declareVariable(fn.Name.Value) c.emit(types.OpSetLocal, nameIndex) } }