Mako/compiler/compiler.go
2025-05-06 23:20:10 -05:00

566 lines
14 KiB
Go

package compiler
import (
"fmt"
"git.sharkk.net/Sharkk/Mako/parser"
"git.sharkk.net/Sharkk/Mako/types"
)
// Compile lowers a parsed program into bytecode.
//
// Every top-level statement is compiled inside one outermost scope, so the
// resulting instruction stream is always bracketed by a scope enter/exit
// pair. A nil program produces bytecode containing only that pair, and nil
// entries in program.Statements are ignored.
func Compile(program *parser.Program) *types.Bytecode {
	comp := &compiler{
		constants:    []any{},
		instructions: []types.Instruction{},
		scopes:       []scope{},
	}

	// Open the outermost scope before anything else is emitted.
	comp.enterScope()
	if program != nil {
		for _, stmt := range program.Statements {
			if stmt != nil {
				comp.compileStatement(stmt)
			}
		}
	}
	comp.exitScope()

	return &types.Bytecode{
		Constants:    comp.constants,
		Instructions: comp.instructions,
	}
}
// scope is one entry of the compiler's scope stack; enterScope pushes one and
// exitScope pops one.
type scope struct {
// variables holds the names registered by declareVariable. isLocalVariable
// only tests for key presence, so the bool value itself carries no meaning.
variables map[string]bool
// upvalues is allocated by enterScope; no code visible in this file reads or
// writes it. NOTE(review): presumably reserved for closure capture - confirm.
upvalues map[string]int
}
// compiler carries the state for a single Compile call.
type compiler struct {
// constants is the top-level constant pool; addConstant appends here when
// currentFunction is nil.
constants []any
// instructions is the top-level instruction stream; emit appends here when
// currentFunction is nil.
instructions []types.Instruction
// scopes is the stack of open scopes, innermost last.
scopes []scope
// currentFunction is non-nil while a function body is being compiled and
// redirects addConstant and emit into that function's own pool and stream.
currentFunction *functionCompiler
}
// functionCompiler collects the output produced while one function body is
// being compiled; its fields are handed to types.NewFunction afterwards.
type functionCompiler struct {
// constants is the function's private constant pool.
constants []any
// instructions is the function's private instruction stream.
instructions []types.Instruction
// numParams is the number of declared parameters.
numParams int
// upvalues lists captured variables. The visible code initialises it empty
// and never appends to it; only each entry's index is read back.
upvalues []upvalueInfo
}
// upvalueInfo describes one variable captured by a function. Of its fields,
// only index is read by the code visible in this file.
type upvalueInfo struct {
index int // Index in the upvalue list
isLocal bool // Whether this is a local variable or an upvalue from an outer scope
capturedFrom int // The scope level where this variable was captured from
}
// enterScope pushes a fresh, empty scope onto the scope stack and emits
// OpEnterScope into the instruction stream that emit currently targets.
func (c *compiler) enterScope() {
	fresh := scope{
		variables: map[string]bool{},
		upvalues:  map[string]int{},
	}
	c.scopes = append(c.scopes, fresh)
	c.emit(types.OpEnterScope, 0)
}
// exitScope pops the innermost scope and emits the matching OpExitScope.
// It must be paired with an earlier enterScope: with an empty scope stack the
// slice expression below panics.
func (c *compiler) exitScope() {
	innermost := len(c.scopes) - 1
	c.scopes = c.scopes[:innermost]
	c.emit(types.OpExitScope, 0)
}
// declareVariable registers name in the innermost open scope. It does nothing
// when no scope is open.
func (c *compiler) declareVariable(name string) {
	depth := len(c.scopes)
	if depth == 0 {
		return
	}
	c.scopes[depth-1].variables[name] = true
}
// isLocalVariable reports whether name has been declared in any open scope,
// searching from the innermost scope outwards.
func (c *compiler) isLocalVariable(name string) bool {
	for depth := len(c.scopes); depth > 0; depth-- {
		_, declared := c.scopes[depth-1].variables[name]
		if declared {
			return true
		}
	}
	return false
}
// compileStatement emits the instructions for a single statement. A nil
// statement, or a statement type not listed below, emits nothing.
func (c *compiler) compileStatement(stmt parser.Statement) {
	switch s := stmt.(type) {
	case nil:
		return

	case *parser.VariableStatement:
		// The value is compiled first so its instructions precede the store.
		c.compileExpression(s.Value)
		slot := c.addConstant(s.Name.Value)
		if len(c.scopes) > 1 {
			c.declareVariable(s.Name.Value)
			c.emit(types.OpSetLocal, slot)
		} else {
			// Top-level variables are stored as globals so they persist
			// between REPL lines.
			c.emit(types.OpSetGlobal, slot)
		}

	case *parser.IndexAssignmentStatement:
		// Operand order: container, index, then the value being stored.
		c.compileExpression(s.Left)
		c.compileExpression(s.Index)
		c.compileExpression(s.Value)
		c.emit(types.OpSetIndex, 0)

	case *parser.EchoStatement:
		c.compileExpression(s.Value)
		c.emit(types.OpEcho, 0)

	case *parser.ReturnStatement:
		if s.Value == nil {
			// A bare return yields nil.
			c.emit(types.OpConstant, c.addConstant(nil))
		} else {
			c.compileExpression(s.Value)
		}
		c.emit(types.OpReturn, 0)

	case *parser.FunctionStatement:
		c.compileFunctionDeclaration(s)

	case *parser.BlockStatement:
		// Blocks are only expected for keyword constructs such as
		// if-then-else-end; their statements are compiled in order without
		// opening a new scope.
		for _, inner := range s.Statements {
			c.compileStatement(inner)
		}

	case *parser.ExpressionStatement:
		c.compileExpression(s.Expression)
		// The result of a bare expression is unused, so discard it.
		c.emit(types.OpPop, 0)
	}
}
// compileExpression emits the instructions that evaluate expr and leave its
// value on the stack. Expression types not listed in the switch (including a
// nil expr) emit nothing. An unknown infix or prefix operator panics.
//
// Jump placeholders for "and", "or" and if-expressions are located and
// backpatched in the instruction stream that emit is currently writing to.
// This matters inside function bodies, where emit targets the function's own
// stream rather than c.instructions: patching c.instructions there would
// either index out of range or overwrite an unrelated top-level instruction.
func (c *compiler) compileExpression(expr parser.Expression) {
	switch e := expr.(type) {
	case *parser.StringLiteral:
		c.emit(types.OpConstant, c.addConstant(e.Value))

	case *parser.NumberLiteral:
		c.emit(types.OpConstant, c.addConstant(e.Value))

	case *parser.BooleanLiteral:
		c.emit(types.OpConstant, c.addConstant(e.Value))

	case *parser.NilLiteral:
		c.emit(types.OpConstant, c.addConstant(nil))

	case *parser.Identifier:
		nameIndex := c.addConstant(e.Value)
		// Names declared in an open scope are locals; anything else is
		// looked up as a global.
		if c.isLocalVariable(e.Value) {
			c.emit(types.OpGetLocal, nameIndex)
		} else {
			c.emit(types.OpGetGlobal, nameIndex)
		}

	case *parser.TableLiteral:
		c.emit(types.OpNewTable, 0)
		// NOTE(review): if e.Pairs is a Go map, this iteration order (and so
		// the emitted bytecode) varies between runs - confirm whether that
		// matters to callers.
		for key, value := range e.Pairs {
			c.emit(types.OpDup, 0)
			if ident, ok := key.(*parser.Identifier); ok {
				// A bare identifier used as a key is a string key, not a
				// variable reference.
				c.emit(types.OpConstant, c.addConstant(ident.Value))
			} else {
				c.compileExpression(key)
			}
			c.compileExpression(value)
			c.emit(types.OpSetIndex, 0)
			c.emit(types.OpPop, 0)
		}

	case *parser.IndexExpression:
		c.compileExpression(e.Left)
		c.compileExpression(e.Index)
		c.emit(types.OpGetIndex, 0)

	case *parser.FunctionLiteral:
		c.compileFunctionLiteral(e)

	case *parser.CallExpression:
		// Callee first, then the arguments in source order; the operand of
		// OpCall is the argument count.
		c.compileExpression(e.Function)
		for _, arg := range e.Arguments {
			c.compileExpression(arg)
		}
		c.emit(types.OpCall, len(e.Arguments))

	case *parser.InfixExpression:
		switch e.Operator {
		case "and":
			// Short-circuit: evaluate the left operand, duplicate it for the
			// test, and skip the right operand when the test fails.
			c.compileExpression(e.Left)
			c.emit(types.OpDup, 0)
			skipRight := c.instructionCount()
			c.emit(types.OpJumpIfFalse, 0) // operand patched below
			// Left operand passed the test: discard it and use the right.
			c.emit(types.OpPop, 0)
			c.compileExpression(e.Right)
			c.patchJump(skipRight, c.instructionCount())

		case "or":
			// Short-circuit: evaluate the left operand, duplicate it for the
			// test, and only evaluate the right operand when the test fails.
			c.compileExpression(e.Left)
			c.emit(types.OpDup, 0)
			toRight := c.instructionCount()
			c.emit(types.OpJumpIfFalse, 0) // operand patched below
			toEnd := c.instructionCount()
			c.emit(types.OpJump, 0) // operand patched below
			c.patchJump(toRight, c.instructionCount())
			// Left operand failed the test: discard it and use the right.
			c.emit(types.OpPop, 0)
			c.compileExpression(e.Right)
			c.patchJump(toEnd, c.instructionCount())

		default:
			// Ordinary binary operator: both operands, then the opcode.
			c.compileExpression(e.Left)
			c.compileExpression(e.Right)
			switch e.Operator {
			case "+":
				c.emit(types.OpAdd, 0)
			case "-":
				c.emit(types.OpSubtract, 0)
			case "*":
				c.emit(types.OpMultiply, 0)
			case "/":
				c.emit(types.OpDivide, 0)
			case "==":
				c.emit(types.OpEqual, 0)
			case "!=":
				c.emit(types.OpNotEqual, 0)
			case "<":
				c.emit(types.OpLessThan, 0)
			case ">":
				c.emit(types.OpGreaterThan, 0)
			case "<=":
				c.emit(types.OpLessEqual, 0)
			case ">=":
				c.emit(types.OpGreaterEqual, 0)
			default:
				panic(fmt.Sprintf("Unknown infix operator: %s", e.Operator))
			}
		}

	case *parser.PrefixExpression:
		c.compileExpression(e.Right)
		switch e.Operator {
		case "-":
			c.emit(types.OpNegate, 0)
		case "not":
			c.emit(types.OpNot, 0)
		default:
			panic(fmt.Sprintf("Unknown prefix operator: %s", e.Operator))
		}

	case *parser.GroupedExpression:
		// Parentheses only affect parsing; compile the inner expression.
		c.compileExpression(e.Expr)

	case *parser.IfExpression:
		c.compileExpression(e.Condition)
		toElse := c.instructionCount()
		c.emit(types.OpJumpIfFalse, 0) // operand patched below

		// Then-branch. A missing or empty block evaluates to nil.
		if e.Consequence == nil || len(e.Consequence.Statements) == 0 {
			c.emit(types.OpConstant, c.addConstant(nil))
		} else {
			last := len(e.Consequence.Statements) - 1
			for i, stmt := range e.Consequence.Statements {
				if i == last {
					c.compileTailStatement(stmt)
				} else {
					c.compileStatement(stmt)
				}
			}
		}

		// Skip over the else-branch once the then-branch has run.
		toEnd := c.instructionCount()
		c.emit(types.OpJump, 0) // operand patched below
		c.patchJump(toElse, c.instructionCount())

		// Else-branch. A missing or empty block evaluates to nil.
		if e.Alternative == nil || len(e.Alternative.Statements) == 0 {
			c.emit(types.OpConstant, c.addConstant(nil))
		} else {
			last := len(e.Alternative.Statements) - 1
			for i, stmt := range e.Alternative.Statements {
				if i == last {
					c.compileTailStatement(stmt)
				} else {
					c.compileStatement(stmt)
				}
			}
		}
		c.patchJump(toEnd, c.instructionCount())
	}
}

// instructionCount returns the length of the instruction stream that emit is
// currently appending to, i.e. the index the next emitted instruction gets.
func (c *compiler) instructionCount() int {
	if c.currentFunction != nil {
		return len(c.currentFunction.instructions)
	}
	return len(c.instructions)
}

// patchJump sets the operand of the instruction at pos, in the stream emit is
// currently appending to, to target.
func (c *compiler) patchJump(pos, target int) {
	if c.currentFunction != nil {
		c.currentFunction.instructions[pos].Operand = target
		return
	}
	c.instructions[pos].Operand = target
}

// compileTailStatement compiles the final statement of an if/else branch so
// that the branch leaves exactly one value: an expression statement keeps its
// result (no OpPop), any other statement is followed by a nil constant.
func (c *compiler) compileTailStatement(stmt parser.Statement) {
	if exprStmt, ok := stmt.(*parser.ExpressionStatement); ok {
		c.compileExpression(exprStmt.Expression)
		return
	}
	c.compileStatement(stmt)
	c.emit(types.OpConstant, c.addConstant(nil))
}
// compileFunctionLiteral compiles an anonymous function into its own
// instruction stream and constant pool, then emits OpFunction in the
// enclosing stream, referencing the resulting function constant.
func (c *compiler) compileFunctionLiteral(fn *parser.FunctionLiteral) {
	// While the body is compiled, emit and addConstant are redirected to a
	// dedicated functionCompiler; the previous target is restored afterwards.
	enclosing := c.currentFunction
	body := &functionCompiler{
		constants:    []any{},
		instructions: []types.Instruction{},
		numParams:    len(fn.Parameters),
		upvalues:     []upvalueInfo{},
	}
	c.currentFunction = body
	c.enterScope()

	// Each parameter becomes a local, stored via OpSetLocal in declaration
	// order.
	for _, param := range fn.Parameters {
		c.declareVariable(param.Value)
		c.emit(types.OpSetLocal, c.addConstant(param.Value))
	}

	for _, stmt := range fn.Body.Statements {
		c.compileStatement(stmt)
	}

	// Append an implicit "return nil" unless the stream already ends with
	// OpReturn.
	if n := len(body.instructions); n == 0 || body.instructions[n-1].Opcode != types.OpReturn {
		c.emit(types.OpConstant, c.addConstant(nil))
		c.emit(types.OpReturn, 0)
	}

	c.exitScope()
	c.currentFunction = enclosing

	// Only the index of each recorded upvalue is passed on.
	captured := make([]int, 0, len(body.upvalues))
	for _, uv := range body.upvalues {
		captured = append(captured, uv.index)
	}

	compiled := types.NewFunction(
		body.instructions,
		body.numParams,
		body.constants,
		captured,
	)
	c.emit(types.OpFunction, c.addConstant(compiled))
}
// addConstant appends value to the active constant pool and returns its
// index. The active pool belongs to the function being compiled when there is
// one, and to the top-level program otherwise. Values are never deduplicated.
func (c *compiler) addConstant(value any) int {
	pool := &c.constants
	if c.currentFunction != nil {
		pool = &c.currentFunction.constants
	}
	*pool = append(*pool, value)
	return len(*pool) - 1
}
// emit appends one instruction to the active instruction stream: that of the
// function being compiled when there is one, otherwise the top-level stream.
func (c *compiler) emit(op types.Opcode, operand int) {
	ins := types.Instruction{Opcode: op, Operand: operand}
	if fc := c.currentFunction; fc != nil {
		fc.instructions = append(fc.instructions, ins)
		return
	}
	c.instructions = append(c.instructions, ins)
}
// compileFunctionDeclaration compiles a named function into its own
// instruction stream and constant pool, emits OpFunction in the enclosing
// stream, and then binds the result to the function's name: as a global at
// the top level, as a local anywhere deeper.
func (c *compiler) compileFunctionDeclaration(fn *parser.FunctionStatement) {
	// While the body is compiled, emit and addConstant are redirected to a
	// dedicated functionCompiler; the previous target is restored afterwards.
	enclosing := c.currentFunction
	body := &functionCompiler{
		constants:    []any{},
		instructions: []types.Instruction{},
		numParams:    len(fn.Parameters),
		upvalues:     []upvalueInfo{},
	}
	c.currentFunction = body
	c.enterScope()

	// Each parameter becomes a local, stored via OpSetLocal in declaration
	// order.
	for _, param := range fn.Parameters {
		c.declareVariable(param.Value)
		c.emit(types.OpSetLocal, c.addConstant(param.Value))
	}

	for _, stmt := range fn.Body.Statements {
		c.compileStatement(stmt)
	}

	// Append an implicit "return nil" unless the stream already ends with
	// OpReturn.
	if n := len(body.instructions); n == 0 || body.instructions[n-1].Opcode != types.OpReturn {
		c.emit(types.OpConstant, c.addConstant(nil))
		c.emit(types.OpReturn, 0)
	}

	c.exitScope()
	c.currentFunction = enclosing

	// Only the index of each recorded upvalue is passed on.
	captured := make([]int, 0, len(body.upvalues))
	for _, uv := range body.upvalues {
		captured = append(captured, uv.index)
	}

	compiled := types.NewFunction(
		body.instructions,
		body.numParams,
		body.constants,
		captured,
	)
	c.emit(types.OpFunction, c.addConstant(compiled))

	// Bind the function value to its name.
	nameSlot := c.addConstant(fn.Name.Value)
	if len(c.scopes) > 1 {
		c.declareVariable(fn.Name.Value)
		c.emit(types.OpSetLocal, nameSlot)
	} else {
		// Top-level functions are stored as globals so they persist between
		// REPL lines.
		c.emit(types.OpSetGlobal, nameSlot)
	}
}