566 lines
14 KiB
Go
566 lines
14 KiB
Go
package compiler
|
|
|
|
import (
|
|
"fmt"
|
|
|
|
"git.sharkk.net/Sharkk/Mako/parser"
|
|
"git.sharkk.net/Sharkk/Mako/types"
|
|
)
|
|
|
|
// Compile converts AST to bytecode
|
|
func Compile(program *parser.Program) *types.Bytecode {
|
|
c := &compiler{
|
|
constants: []any{},
|
|
instructions: []types.Instruction{},
|
|
scopes: []scope{},
|
|
currentFunction: nil,
|
|
}
|
|
|
|
// Start in global scope
|
|
c.enterScope()
|
|
|
|
// Add nil check for program
|
|
if program == nil {
|
|
c.exitScope()
|
|
return &types.Bytecode{
|
|
Constants: c.constants,
|
|
Instructions: c.instructions,
|
|
}
|
|
}
|
|
|
|
// Process each statement safely
|
|
for _, stmt := range program.Statements {
|
|
// Skip nil statements
|
|
if stmt == nil {
|
|
continue
|
|
}
|
|
c.compileStatement(stmt)
|
|
}
|
|
|
|
c.exitScope()
|
|
|
|
return &types.Bytecode{
|
|
Constants: c.constants,
|
|
Instructions: c.instructions,
|
|
}
|
|
}
|
|
|
|
// scope records the names declared at one level of lexical nesting.
type scope struct {
	// variables holds names declared at this level; declareVariable only ever
	// stores true.
	variables map[string]bool
	// upvalues is allocated by enterScope; nothing in this file reads or
	// writes it yet.
	upvalues map[string]int
}
|
|
|
|
type compiler struct {
|
|
constants []any
|
|
instructions []types.Instruction
|
|
scopes []scope
|
|
currentFunction *functionCompiler
|
|
}
|
|
|
|
type functionCompiler struct {
|
|
constants []any
|
|
instructions []types.Instruction
|
|
numParams int
|
|
upvalues []upvalueInfo
|
|
}
|
|
|
|
// upvalueInfo describes one variable captured by a function.
type upvalueInfo struct {
	index        int  // position in the function's upvalue list
	isLocal      bool // true for a captured local, false for an outer upvalue
	capturedFrom int  // scope level the variable was captured from
}
|
|
|
|
func (c *compiler) enterScope() {
|
|
c.scopes = append(c.scopes, scope{
|
|
variables: make(map[string]bool),
|
|
upvalues: make(map[string]int),
|
|
})
|
|
c.emit(types.OpEnterScope, 0)
|
|
}
|
|
|
|
func (c *compiler) exitScope() {
|
|
c.scopes = c.scopes[:len(c.scopes)-1]
|
|
c.emit(types.OpExitScope, 0)
|
|
}
|
|
|
|
func (c *compiler) declareVariable(name string) {
|
|
if len(c.scopes) > 0 {
|
|
c.scopes[len(c.scopes)-1].variables[name] = true
|
|
}
|
|
}
|
|
|
|
func (c *compiler) isLocalVariable(name string) bool {
|
|
for i := len(c.scopes) - 1; i >= 0; i-- {
|
|
if _, ok := c.scopes[i].variables[name]; ok {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (c *compiler) compileStatement(stmt parser.Statement) {
|
|
if stmt == nil {
|
|
return
|
|
}
|
|
|
|
switch s := stmt.(type) {
|
|
case *parser.VariableStatement:
|
|
c.compileExpression(s.Value)
|
|
nameIndex := c.addConstant(s.Name.Value)
|
|
|
|
// Use SetGlobal for top-level variables to persist between REPL lines
|
|
if len(c.scopes) <= 1 {
|
|
c.emit(types.OpSetGlobal, nameIndex)
|
|
} else {
|
|
c.declareVariable(s.Name.Value)
|
|
c.emit(types.OpSetLocal, nameIndex)
|
|
}
|
|
|
|
case *parser.IndexAssignmentStatement:
|
|
c.compileExpression(s.Left)
|
|
c.compileExpression(s.Index)
|
|
c.compileExpression(s.Value)
|
|
c.emit(types.OpSetIndex, 0)
|
|
|
|
case *parser.EchoStatement:
|
|
c.compileExpression(s.Value)
|
|
c.emit(types.OpEcho, 0)
|
|
|
|
case *parser.ReturnStatement:
|
|
if s.Value != nil {
|
|
c.compileExpression(s.Value)
|
|
} else {
|
|
nullIndex := c.addConstant(nil)
|
|
c.emit(types.OpConstant, nullIndex)
|
|
}
|
|
c.emit(types.OpReturn, 0)
|
|
|
|
case *parser.FunctionStatement:
|
|
// Use the dedicated function for function statements
|
|
c.compileFunctionDeclaration(s)
|
|
|
|
// BlockStatement now should only be used for keyword blocks like if-then-else-end
|
|
case *parser.BlockStatement:
|
|
for _, blockStmt := range s.Statements {
|
|
c.compileStatement(blockStmt)
|
|
}
|
|
|
|
case *parser.ExpressionStatement:
|
|
c.compileExpression(s.Expression)
|
|
// Pop the value since we're not using it
|
|
c.emit(types.OpPop, 0)
|
|
}
|
|
}
|
|
|
|
func (c *compiler) compileExpression(expr parser.Expression) {
|
|
switch e := expr.(type) {
|
|
case *parser.StringLiteral:
|
|
constIndex := c.addConstant(e.Value)
|
|
c.emit(types.OpConstant, constIndex)
|
|
|
|
case *parser.NumberLiteral:
|
|
constIndex := c.addConstant(e.Value)
|
|
c.emit(types.OpConstant, constIndex)
|
|
|
|
case *parser.BooleanLiteral:
|
|
constIndex := c.addConstant(e.Value)
|
|
c.emit(types.OpConstant, constIndex)
|
|
|
|
case *parser.NilLiteral:
|
|
constIndex := c.addConstant(nil)
|
|
c.emit(types.OpConstant, constIndex)
|
|
|
|
case *parser.Identifier:
|
|
nameIndex := c.addConstant(e.Value)
|
|
|
|
// Check if it's a local variable first
|
|
if c.isLocalVariable(e.Value) {
|
|
c.emit(types.OpGetLocal, nameIndex)
|
|
} else {
|
|
// Otherwise treat as global
|
|
c.emit(types.OpGetGlobal, nameIndex)
|
|
}
|
|
|
|
case *parser.TableLiteral:
|
|
c.emit(types.OpNewTable, 0)
|
|
|
|
for key, value := range e.Pairs {
|
|
c.emit(types.OpDup, 0)
|
|
|
|
// Special handling for identifier keys in tables
|
|
if ident, ok := key.(*parser.Identifier); ok {
|
|
// Treat identifiers as string literals in table keys
|
|
strIndex := c.addConstant(ident.Value)
|
|
c.emit(types.OpConstant, strIndex)
|
|
} else {
|
|
// For other expressions, compile normally
|
|
c.compileExpression(key)
|
|
}
|
|
|
|
c.compileExpression(value)
|
|
c.emit(types.OpSetIndex, 0)
|
|
c.emit(types.OpPop, 0)
|
|
}
|
|
|
|
case *parser.IndexExpression:
|
|
c.compileExpression(e.Left)
|
|
c.compileExpression(e.Index)
|
|
c.emit(types.OpGetIndex, 0)
|
|
|
|
case *parser.FunctionLiteral:
|
|
c.compileFunctionLiteral(e)
|
|
|
|
case *parser.CallExpression:
|
|
// Compile the function expression first
|
|
c.compileExpression(e.Function)
|
|
|
|
// Then compile the arguments
|
|
for _, arg := range e.Arguments {
|
|
c.compileExpression(arg)
|
|
}
|
|
|
|
// Emit the call instruction with the number of arguments
|
|
c.emit(types.OpCall, len(e.Arguments))
|
|
|
|
// Arithmetic expressions
|
|
case *parser.InfixExpression:
|
|
switch e.Operator {
|
|
case "and":
|
|
// Compile left operand
|
|
c.compileExpression(e.Left)
|
|
|
|
// Duplicate to check condition
|
|
c.emit(types.OpDup, 0)
|
|
|
|
// Jump if false (short-circuit)
|
|
jumpFalsePos := len(c.instructions)
|
|
c.emit(types.OpJumpIfFalse, 0) // Will backpatch
|
|
|
|
// Pop the duplicate since we'll replace it
|
|
c.emit(types.OpPop, 0)
|
|
|
|
// Compile right operand
|
|
c.compileExpression(e.Right)
|
|
|
|
// Jump target for short-circuit
|
|
endPos := len(c.instructions)
|
|
c.instructions[jumpFalsePos].Operand = endPos
|
|
|
|
case "or":
|
|
// Compile left operand
|
|
c.compileExpression(e.Left)
|
|
|
|
// Duplicate to check condition
|
|
c.emit(types.OpDup, 0)
|
|
|
|
// Need to check if it's truthy to short-circuit
|
|
falseJumpPos := len(c.instructions)
|
|
c.emit(types.OpJumpIfFalse, 0) // Jump to right eval if false
|
|
|
|
// If truthy, jump to end
|
|
trueJumpPos := len(c.instructions)
|
|
c.emit(types.OpJump, 0) // Jump to end if true
|
|
|
|
// Position for false case
|
|
falsePos := len(c.instructions)
|
|
c.instructions[falseJumpPos].Operand = falsePos
|
|
|
|
// Pop the duplicate since we'll replace it
|
|
c.emit(types.OpPop, 0)
|
|
|
|
// Compile right operand
|
|
c.compileExpression(e.Right)
|
|
|
|
// End position
|
|
endPos := len(c.instructions)
|
|
c.instructions[trueJumpPos].Operand = endPos
|
|
|
|
default:
|
|
// Original infix expression compilation
|
|
c.compileExpression(e.Left)
|
|
c.compileExpression(e.Right)
|
|
|
|
// Generate the appropriate operation
|
|
switch e.Operator {
|
|
case "+":
|
|
c.emit(types.OpAdd, 0)
|
|
case "-":
|
|
c.emit(types.OpSubtract, 0)
|
|
case "*":
|
|
c.emit(types.OpMultiply, 0)
|
|
case "/":
|
|
c.emit(types.OpDivide, 0)
|
|
case "==":
|
|
c.emit(types.OpEqual, 0)
|
|
case "!=":
|
|
c.emit(types.OpNotEqual, 0)
|
|
case "<":
|
|
c.emit(types.OpLessThan, 0)
|
|
case ">":
|
|
c.emit(types.OpGreaterThan, 0)
|
|
case "<=":
|
|
c.emit(types.OpLessEqual, 0)
|
|
case ">=":
|
|
c.emit(types.OpGreaterEqual, 0)
|
|
default:
|
|
panic(fmt.Sprintf("Unknown infix operator: %s", e.Operator))
|
|
}
|
|
}
|
|
|
|
case *parser.PrefixExpression:
|
|
// Compile the operand
|
|
c.compileExpression(e.Right)
|
|
|
|
// Generate the appropriate operation
|
|
switch e.Operator {
|
|
case "-":
|
|
c.emit(types.OpNegate, 0)
|
|
case "not":
|
|
c.emit(types.OpNot, 0)
|
|
default:
|
|
panic(fmt.Sprintf("Unknown prefix operator: %s", e.Operator))
|
|
}
|
|
|
|
case *parser.GroupedExpression:
|
|
// Just compile the inner expression
|
|
c.compileExpression(e.Expr)
|
|
|
|
case *parser.IfExpression:
|
|
// Compile condition
|
|
c.compileExpression(e.Condition)
|
|
|
|
// Emit jump-if-false with placeholder
|
|
jumpNotTruePos := len(c.instructions)
|
|
c.emit(types.OpJumpIfFalse, 0) // Will backpatch
|
|
|
|
// Compile consequence (then block)
|
|
if e.Consequence != nil {
|
|
lastStmtIndex := len(e.Consequence.Statements) - 1
|
|
for i, stmt := range e.Consequence.Statements {
|
|
if i == lastStmtIndex {
|
|
// For the last statement, we need to ensure it leaves a value
|
|
if exprStmt, ok := stmt.(*parser.ExpressionStatement); ok {
|
|
c.compileExpression(exprStmt.Expression)
|
|
} else {
|
|
c.compileStatement(stmt)
|
|
// Push null if not an expression statement
|
|
nullIndex := c.addConstant(nil)
|
|
c.emit(types.OpConstant, nullIndex)
|
|
}
|
|
} else {
|
|
c.compileStatement(stmt)
|
|
}
|
|
}
|
|
|
|
// If no statements, push null
|
|
if len(e.Consequence.Statements) == 0 {
|
|
nullIndex := c.addConstant(nil)
|
|
c.emit(types.OpConstant, nullIndex)
|
|
}
|
|
} else {
|
|
// No consequence block, push null
|
|
nullIndex := c.addConstant(nil)
|
|
c.emit(types.OpConstant, nullIndex)
|
|
}
|
|
|
|
// Emit jump to skip else part
|
|
jumpPos := len(c.instructions)
|
|
c.emit(types.OpJump, 0) // Will backpatch
|
|
|
|
// Backpatch jump-if-false to point to else
|
|
afterConsequencePos := len(c.instructions)
|
|
c.instructions[jumpNotTruePos].Operand = afterConsequencePos
|
|
|
|
// Compile alternative (else block)
|
|
if e.Alternative != nil {
|
|
lastStmtIndex := len(e.Alternative.Statements) - 1
|
|
for i, stmt := range e.Alternative.Statements {
|
|
if i == lastStmtIndex {
|
|
// For the last statement, we need to ensure it leaves a value
|
|
if exprStmt, ok := stmt.(*parser.ExpressionStatement); ok {
|
|
c.compileExpression(exprStmt.Expression)
|
|
} else {
|
|
c.compileStatement(stmt)
|
|
// Push null if not an expression statement
|
|
nullIndex := c.addConstant(nil)
|
|
c.emit(types.OpConstant, nullIndex)
|
|
}
|
|
} else {
|
|
c.compileStatement(stmt)
|
|
}
|
|
}
|
|
|
|
// If no statements, push null
|
|
if len(e.Alternative.Statements) == 0 {
|
|
nullIndex := c.addConstant(nil)
|
|
c.emit(types.OpConstant, nullIndex)
|
|
}
|
|
} else {
|
|
// No else - push null
|
|
nullIndex := c.addConstant(nil)
|
|
c.emit(types.OpConstant, nullIndex)
|
|
}
|
|
|
|
// Backpatch jump to point after else
|
|
afterAlternativePos := len(c.instructions)
|
|
c.instructions[jumpPos].Operand = afterAlternativePos
|
|
}
|
|
}
|
|
|
|
func (c *compiler) compileFunctionLiteral(fn *parser.FunctionLiteral) {
|
|
// Save the current compiler state
|
|
parentCompiler := c.currentFunction
|
|
|
|
// Create a new function compiler
|
|
fnCompiler := &functionCompiler{
|
|
constants: []any{},
|
|
instructions: []types.Instruction{},
|
|
numParams: len(fn.Parameters),
|
|
upvalues: []upvalueInfo{},
|
|
}
|
|
|
|
c.currentFunction = fnCompiler
|
|
|
|
// Enter a new scope for the function body
|
|
c.enterScope()
|
|
|
|
// Declare parameters as local variables
|
|
for _, param := range fn.Parameters {
|
|
c.declareVariable(param.Value)
|
|
paramIndex := c.addConstant(param.Value)
|
|
c.emit(types.OpSetLocal, paramIndex)
|
|
}
|
|
|
|
// Compile the function body
|
|
for _, stmt := range fn.Body.Statements {
|
|
c.compileStatement(stmt)
|
|
}
|
|
|
|
// Ensure the function always returns a value
|
|
// If the last instruction is not a return, add one
|
|
if len(fnCompiler.instructions) == 0 ||
|
|
(len(fnCompiler.instructions) > 0 && fnCompiler.instructions[len(fnCompiler.instructions)-1].Opcode != types.OpReturn) {
|
|
nullIndex := c.addConstant(nil)
|
|
c.emit(types.OpConstant, nullIndex)
|
|
c.emit(types.OpReturn, 0)
|
|
}
|
|
|
|
// Exit the function scope
|
|
c.exitScope()
|
|
|
|
// Restore the parent compiler
|
|
c.currentFunction = parentCompiler
|
|
|
|
// Extract upvalue information for closure creation
|
|
upvalueIndexes := make([]int, len(fnCompiler.upvalues))
|
|
for i, upvalue := range fnCompiler.upvalues {
|
|
upvalueIndexes[i] = upvalue.index
|
|
}
|
|
|
|
// Create a Function object and add it to the constants
|
|
function := types.NewFunction(
|
|
fnCompiler.instructions,
|
|
fnCompiler.numParams,
|
|
fnCompiler.constants,
|
|
upvalueIndexes,
|
|
)
|
|
|
|
functionIndex := c.addConstant(function)
|
|
c.emit(types.OpFunction, functionIndex)
|
|
}
|
|
|
|
func (c *compiler) addConstant(value any) int {
|
|
if c.currentFunction != nil {
|
|
c.currentFunction.constants = append(c.currentFunction.constants, value)
|
|
return len(c.currentFunction.constants) - 1
|
|
}
|
|
|
|
c.constants = append(c.constants, value)
|
|
return len(c.constants) - 1
|
|
}
|
|
|
|
func (c *compiler) emit(op types.Opcode, operand int) {
|
|
instruction := types.Instruction{
|
|
Opcode: op,
|
|
Operand: operand,
|
|
}
|
|
|
|
if c.currentFunction != nil {
|
|
c.currentFunction.instructions = append(c.currentFunction.instructions, instruction)
|
|
} else {
|
|
c.instructions = append(c.instructions, instruction)
|
|
}
|
|
}
|
|
|
|
func (c *compiler) compileFunctionDeclaration(fn *parser.FunctionStatement) {
|
|
// Save the current compiler state
|
|
parentCompiler := c.currentFunction
|
|
|
|
// Create a new function compiler
|
|
fnCompiler := &functionCompiler{
|
|
constants: []any{},
|
|
instructions: []types.Instruction{},
|
|
numParams: len(fn.Parameters),
|
|
upvalues: []upvalueInfo{},
|
|
}
|
|
|
|
c.currentFunction = fnCompiler
|
|
|
|
// Enter a new scope for the function body
|
|
c.enterScope()
|
|
|
|
// Declare parameters as local variables
|
|
for _, param := range fn.Parameters {
|
|
c.declareVariable(param.Value)
|
|
paramIndex := c.addConstant(param.Value)
|
|
c.emit(types.OpSetLocal, paramIndex)
|
|
}
|
|
|
|
// Compile the function body
|
|
for _, stmt := range fn.Body.Statements {
|
|
c.compileStatement(stmt)
|
|
}
|
|
|
|
// Ensure the function always returns a value
|
|
// If the last instruction is not a return, add one
|
|
if len(fnCompiler.instructions) == 0 || fnCompiler.instructions[len(fnCompiler.instructions)-1].Opcode != types.OpReturn {
|
|
nullIndex := c.addConstant(nil)
|
|
c.emit(types.OpConstant, nullIndex)
|
|
c.emit(types.OpReturn, 0)
|
|
}
|
|
|
|
// Exit the function scope
|
|
c.exitScope()
|
|
|
|
// Restore the parent compiler
|
|
c.currentFunction = parentCompiler
|
|
|
|
// Extract upvalue information for closure creation
|
|
upvalueIndexes := make([]int, len(fnCompiler.upvalues))
|
|
for i, upvalue := range fnCompiler.upvalues {
|
|
upvalueIndexes[i] = upvalue.index
|
|
}
|
|
|
|
// Create a Function object and add it to the constants
|
|
function := types.NewFunction(
|
|
fnCompiler.instructions,
|
|
fnCompiler.numParams,
|
|
fnCompiler.constants,
|
|
upvalueIndexes,
|
|
)
|
|
|
|
functionIndex := c.addConstant(function)
|
|
c.emit(types.OpFunction, functionIndex)
|
|
|
|
// Store the function in a global variable
|
|
nameIndex := c.addConstant(fn.Name.Value)
|
|
|
|
// Use SetGlobal for top-level variables to persist between REPL lines
|
|
if len(c.scopes) <= 1 {
|
|
c.emit(types.OpSetGlobal, nameIndex)
|
|
} else {
|
|
c.declareVariable(fn.Name.Value)
|
|
c.emit(types.OpSetLocal, nameIndex)
|
|
}
|
|
}
|