Mako/compiler/compiler.go
2025-05-07 09:45:50 -05:00

520 lines
12 KiB
Go

package compiler
import (
"git.sharkk.net/Sharkk/Mako/types"
)
// Compiler manages the state for compiling a single function body. Each
// nested function is compiled by its own Compiler, linked to the surrounding
// one through enclosing.
type Compiler struct {
	chunk           *types.Chunk       // bytecode and constant pool being emitted
	locals          []local            // locals currently in scope; a local's index is its slot operand
	scopeDepth      int                // block nesting depth: incremented by beginScope, decremented by endScope
	enclosing       *Compiler          // compiler of the surrounding function; nil for the outermost one
	upvalues        []types.Upvalue    // variables captured from enclosing functions
	errors          []*types.MakoError // errors recorded while compiling
	currentFunction *types.Function    // function object being built; its Chunk is set to chunk by New
}
// local describes one local-variable slot of the function being compiled.
type local struct {
	name       string // identifier the slot is resolved by
	depth      int    // value of scopeDepth when the local was declared
	isCaptured bool   // set when a nested function captures it; endScope then emits OP_CLOSE_UPVALUE instead of OP_POP
}
// New returns a Compiler for a function called name. enclosing is the
// compiler of the surrounding function, or nil for the outermost one.
//
// Local slot 0 is reserved up front for the function itself: it is
// registered under the function's own name at depth 0.
func New(name string, enclosing *Compiler) *Compiler {
	chunk := &types.Chunk{
		Code:      make([]types.Instruction, 0, 8),
		Constants: make([]types.Value, 0, 8),
	}

	locals := make([]local, 0, 8)
	locals = append(locals, local{name: name, depth: 0})

	return &Compiler{
		chunk:      chunk,
		locals:     locals,
		scopeDepth: 0,
		enclosing:  enclosing,
		currentFunction: &types.Function{
			Name:       name,
			Chunk:      chunk,
			Upvalues:   nil,
			LocalCount: 0,
		},
	}
}
// Compile emits bytecode for statements, terminates the chunk with an
// implicit `return nil`, and returns the finished function together with
// every error recorded during compilation.
func (c *Compiler) Compile(statements []types.Statement) (*types.Function, []*types.MakoError) {
	for i := range statements {
		c.statement(statements[i])
	}
	c.emitReturn() // implicit return at the end of the code

	fn := c.currentFunction
	fn.LocalCount = len(c.locals)
	fn.UpvalueCount = len(c.upvalues)
	fn.Upvalues = c.upvalues

	return fn, c.errors
}
// statement emits bytecode for a single statement. Statement kinds that are
// not listed below produce no code.
func (c *Compiler) statement(stmt types.Statement) {
	switch node := stmt.(type) {
	case types.ExpressionStmt:
		// Evaluate for side effects, then discard the result.
		c.expression(node.Expression)
		c.emit(types.OP_POP, nil)
	case types.AssignStmt:
		c.assignment(node)
	case types.FunctionStmt:
		c.function(node)
	case types.ReturnStmt:
		c.returnStmt(node)
	case types.IfStmt:
		c.ifStatement(node)
	case types.EchoStmt:
		c.expression(node.Value)
		c.emit(types.OP_PRINT, nil)
	case types.BlockStmt:
		// A block introduces its own scope; its locals are discarded on exit.
		c.beginScope()
		for _, inner := range node.Statements {
			c.statement(inner)
		}
		c.endScope()
	}
}
// assignment evaluates the right-hand side and stores it into the named
// target. Inside a nested scope, the innermost local with a matching name is
// used. Otherwise, or at top level, the value is stored into a global
// identified by a string constant holding the name.
//
// NOTE(review): enclosing-function variables (upvalues) are not consulted
// here, so assigning to a captured name writes a global — confirm intended.
func (c *Compiler) assignment(stmt types.AssignStmt) {
	c.expression(stmt.Value)
	target := stmt.Name.Lexeme

	if c.scopeDepth > 0 {
		for slot := len(c.locals) - 1; slot >= 0; slot-- {
			l := c.locals[slot]
			if l.name != target || l.depth > c.scopeDepth {
				continue
			}
			c.emit(types.OP_SET_LOCAL, []byte{byte(slot)})
			return
		}
	}

	c.emit(types.OP_SET_GLOBAL, []byte{c.makeConstant(types.StringValue{Value: target})})
}
// expression emits bytecode that leaves the value of expr on the stack.
// An expression kind this compiler does not handle is reported as an error.
func (c *Compiler) expression(expr types.Expression) {
	switch e := expr.(type) {
	case types.LiteralExpr:
		c.literal(e)
	case types.BinaryExpr:
		c.binary(e)
	case types.UnaryExpr:
		c.unary(e)
	case types.VariableExpr:
		c.variable(e)
	case types.CallExpr:
		c.call(e)
	case types.FunctionExpr:
		c.functionExpr(e)
	default:
		// Emitting nothing would leave the stack one value short and corrupt
		// whatever consumes the result (e.g. the OP_POP of an expression
		// statement), so fail loudly instead of silently miscompiling.
		c.error("Unsupported expression type.")
	}
}
// literal pushes a literal value. nil and booleans have dedicated opcodes.
// float64 and string values are stored in the constant pool and loaded with
// OP_CONSTANT. Any other Go type is reported as an error.
func (c *Compiler) literal(expr types.LiteralExpr) {
	switch v := expr.Value.(type) {
	case nil:
		c.emit(types.OP_NIL, nil)
	case bool:
		if v {
			c.emit(types.OP_TRUE, nil)
		} else {
			c.emit(types.OP_FALSE, nil)
		}
	case float64:
		idx := c.makeConstant(types.NumberValue{Value: v})
		c.emit(types.OP_CONSTANT, []byte{idx})
	case string:
		idx := c.makeConstant(types.StringValue{Value: v})
		c.emit(types.OP_CONSTANT, []byte{idx})
	default:
		// Pushing nothing would unbalance the stack for the consumer.
		c.error("Unsupported literal value.")
	}
}
// binary compiles a binary expression so that it leaves exactly one value on
// the stack.
//
// `and` / `or` short-circuit: the right operand is compiled after a
// conditional jump, so it is evaluated only when needed. Like ifStatement,
// this relies on OP_JUMP_IF_FALSE leaving its operand on the stack, which is
// why the left value is popped explicitly before evaluating the right one.
//
// Every other operator evaluates both operands and then emits the operation.
// `!=`, `<=` and `>=` are synthesized by negating OP_EQUAL, OP_GREATER and
// OP_LESS respectively. Unknown operators are reported as errors.
func (c *Compiler) binary(expr types.BinaryExpr) {
	switch expr.Operator.Type {
	case types.AND:
		// Falsy left: jump to the end, keeping left as the result.
		// Truthy left: drop it and let the right operand be the result.
		c.expression(expr.Left)
		endJump := c.emitJump(types.OP_JUMP_IF_FALSE)
		c.emit(types.OP_POP, nil)
		c.expression(expr.Right)
		c.patchJump(endJump)
		return
	case types.OR:
		// Truthy left: fall into the unconditional jump over the right
		// operand, keeping left. Falsy left: drop it and evaluate right.
		c.expression(expr.Left)
		elseJump := c.emitJump(types.OP_JUMP_IF_FALSE)
		endJump := c.emitJump(types.OP_JUMP)
		c.patchJump(elseJump)
		c.emit(types.OP_POP, nil)
		c.expression(expr.Right)
		c.patchJump(endJump)
		return
	}

	c.expression(expr.Left)
	c.expression(expr.Right)

	switch expr.Operator.Type {
	case types.PLUS:
		c.emit(types.OP_ADD, nil)
	case types.MINUS:
		c.emit(types.OP_SUBTRACT, nil)
	case types.STAR:
		c.emit(types.OP_MULTIPLY, nil)
	case types.SLASH:
		c.emit(types.OP_DIVIDE, nil)
	case types.EQUAL_EQUAL:
		c.emit(types.OP_EQUAL, nil)
	case types.BANG_EQUAL:
		c.emit(types.OP_EQUAL, nil)
		c.emit(types.OP_NOT, nil)
	case types.LESS:
		c.emit(types.OP_LESS, nil)
	case types.LESS_EQUAL:
		c.emit(types.OP_GREATER, nil)
		c.emit(types.OP_NOT, nil)
	case types.GREATER:
		c.emit(types.OP_GREATER, nil)
	case types.GREATER_EQUAL:
		c.emit(types.OP_LESS, nil)
		c.emit(types.OP_NOT, nil)
	default:
		// Both operands are on the stack; without an operation they would
		// leave two values where one is expected.
		c.error("Unsupported binary operator.")
	}
}
// unary compiles a prefix operator applied to expr.Right. Only numeric
// negation (`-`) is supported; any other operator is reported as an error.
func (c *Compiler) unary(expr types.UnaryExpr) {
	c.expression(expr.Right)
	switch expr.Operator.Type {
	case types.MINUS:
		c.emit(types.OP_NEGATE, nil)
	default:
		// Without this the operand would pass through unchanged, silently
		// dropping the operator.
		c.error("Unsupported unary operator.")
	}
}
// variable emits a load for a name reference. Resolution order:
//  1. the innermost matching local of this function,
//  2. a variable captured from an enclosing function (upvalue),
//  3. a global, looked up by a string constant holding the name.
func (c *Compiler) variable(expr types.VariableExpr) {
	name := expr.Name.Lexeme

	slot := -1
	for i := len(c.locals) - 1; i >= 0; i-- {
		if c.locals[i].name == name && c.locals[i].depth <= c.scopeDepth {
			slot = i
			break
		}
	}
	if slot >= 0 {
		c.emit(types.OP_GET_LOCAL, []byte{byte(slot)})
		return
	}

	up, captured := c.resolveUpvalue(name)
	if captured {
		c.emit(types.OP_GET_UPVALUE, []byte{byte(up)})
		return
	}

	global := c.makeConstant(types.StringValue{Value: name})
	c.emit(types.OP_GET_GLOBAL, []byte{global})
}
// call compiles a function call. It pushes the callee, then each argument
// left to right, then emits OP_CALL, whose single operand is the argument
// count. The count is encoded in one byte, so more than 255 arguments is an
// error.
func (c *Compiler) call(expr types.CallExpr) {
	c.expression(expr.Callee)
	if len(expr.Arguments) > 255 {
		// byte(len(...)) would wrap and the VM would see the wrong arity.
		c.error("Can't have more than 255 arguments.")
	}
	for _, arg := range expr.Arguments {
		c.expression(arg)
	}
	c.emit(types.OP_CALL, []byte{byte(len(expr.Arguments))})
}
// function compiles a named function declaration.
//
// At top level, the resulting closure is stored into a global named after the
// function. Inside a scope, the name is first declared as a local. No store is
// emitted, because the closure left on the stack by OP_CLOSURE occupies that
// local's slot.
func (c *Compiler) function(stmt types.FunctionStmt) {
	var global byte
	if c.scopeDepth > 0 {
		c.addLocal(stmt.Name.Lexeme)
	} else {
		global = c.makeConstant(types.StringValue{Value: stmt.Name.Lexeme})
	}

	// Compile the body with a fresh compiler; parameters become its locals.
	fc := New(stmt.Name.Lexeme, c)
	fc.beginScope()
	for _, param := range stmt.Params {
		fc.addLocal(param.Lexeme)
		fc.currentFunction.Arity++
	}
	fc.currentFunction.IsVariadic = stmt.IsVariadic
	for _, bodyStmt := range stmt.Body {
		fc.statement(bodyStmt)
	}

	c.finishFunction(fc)

	if c.scopeDepth == 0 {
		c.emit(types.OP_SET_GLOBAL, []byte{global})
	}
}

// functionExpr compiles an anonymous function expression, leaving the
// resulting closure on the stack.
func (c *Compiler) functionExpr(expr types.FunctionExpr) {
	// Compile the body with a fresh compiler; parameters become its locals.
	fc := New("", c)
	fc.beginScope()
	for _, param := range expr.Params {
		fc.addLocal(param.Lexeme)
		fc.currentFunction.Arity++
	}
	fc.currentFunction.IsVariadic = expr.IsVariadic
	for _, bodyStmt := range expr.Body {
		fc.statement(bodyStmt)
	}

	c.finishFunction(fc)
}

// finishFunction completes the function compiled by the nested compiler fc
// and emits, into c, the instructions that create a closure for it.
//
// It records the function's local and upvalue counts (as Compile does for the
// outermost function) and forwards fc's errors into c so they are not lost.
// It then adds the closure to c's constant pool and emits OP_CLOSURE. Each
// captured upvalue is followed by a pseudo-instruction (opcode 0) whose
// operands are {isLocal, index}, where isLocal is 1 for a local of the
// enclosing function and 0 for one of its upvalues.
func (c *Compiler) finishFunction(fc *Compiler) {
	fc.emitReturn() // implicit `return nil` in case the body did not return

	fn := fc.currentFunction
	fn.LocalCount = len(fc.locals)
	fn.Upvalues = fc.upvalues
	fn.UpvalueCount = len(fc.upvalues)

	// Errors found inside the nested body (e.g. too many locals) must reach
	// the caller of the outermost Compile.
	c.errors = append(c.errors, fc.errors...)

	idx := c.makeConstant(types.ClosureValue{
		Closure: &types.Closure{
			Function: fn,
			Upvalues: make([]*types.Upvalue, fn.UpvalueCount),
		},
	})
	c.emit(types.OP_CLOSURE, []byte{idx})

	for _, uv := range fc.upvalues {
		isLocal := byte(0)
		if uv.IsLocal {
			isLocal = 1
		}
		c.emit(0, []byte{isLocal, uv.Index})
	}
}
// returnStmt pushes the returned value, which is nil when the statement has
// no expression, and then emits OP_RETURN.
func (c *Compiler) returnStmt(stmt types.ReturnStmt) {
	if stmt.Value != nil {
		c.expression(stmt.Value)
	} else {
		c.emit(types.OP_NIL, nil)
	}
	c.emit(types.OP_RETURN, nil)
}
// ifStatement compiles an if / elseif / else chain with this layout
// (JIF = OP_JUMP_IF_FALSE):
//
//	       <cond>  JIF next1  POP  <then>   JMP end
//	next1: POP
//	       <cond2> JIF next2  POP  <body2>  JMP end
//	next2: POP
//	       ...
//	       <else>
//	end:
//
// The condition is popped explicitly on both the taken and fall-through
// paths, which assumes OP_JUMP_IF_FALSE does not consume it. Every branch
// that runs jumps to the same end label, so all of those jumps are collected
// and patched once the else branch has been emitted. Previously only the last
// one was patched, leaving earlier ones with the 0xFFFF placeholder.
func (c *Compiler) ifStatement(stmt types.IfStmt) {
	var endJumps []int

	c.expression(stmt.Condition)
	nextJump := c.emitJump(types.OP_JUMP_IF_FALSE)
	c.emit(types.OP_POP, nil) // condition was true
	for _, thenStmt := range stmt.ThenBranch {
		c.statement(thenStmt)
	}
	endJumps = append(endJumps, c.emitJump(types.OP_JUMP))
	c.patchJump(nextJump)
	c.emit(types.OP_POP, nil) // condition was false

	for _, elseif := range stmt.ElseIfs {
		c.expression(elseif.Condition)
		nextJump = c.emitJump(types.OP_JUMP_IF_FALSE)
		c.emit(types.OP_POP, nil) // condition was true
		for _, elseifStmt := range elseif.Body {
			c.statement(elseifStmt)
		}
		endJumps = append(endJumps, c.emitJump(types.OP_JUMP))
		c.patchJump(nextJump)
		c.emit(types.OP_POP, nil) // condition was false
	}

	for _, elseStmt := range stmt.ElseBranch {
		c.statement(elseStmt)
	}

	for _, j := range endJumps {
		c.patchJump(j)
	}
}
// Helper methods

// emit appends one instruction to the current chunk. Source positions are not
// tracked yet, so every instruction records line 0, column 0.
func (c *Compiler) emit(op types.OpCode, operands []byte) {
	// TODO: take the position from the token currently being compiled.
	c.chunk.Code = append(c.chunk.Code, types.Instruction{
		Op:       op,
		Operands: operands,
		Pos:      types.SourcePos{Line: 0, Column: 0},
	})
}
// emitJump emits a jump instruction with a placeholder offset of 0xFFFF and
// returns the instruction's index so patchJump can fill in the real offset.
func (c *Compiler) emitJump(op types.OpCode) int {
	at := len(c.chunk.Code)
	c.emit(op, []byte{0xFF, 0xFF})
	return at
}
// patchJump back-fills the placeholder offset of the jump emitted at
// jumpIndex so that it targets the next instruction to be emitted.
//
// Offsets count instructions, not bytes. The stored value is the number of
// instructions emitted after the jump itself, written big-endian (high byte
// first) into the jump's two operands. An offset that does not fit in 16
// bits is reported as an error.
//
// NOTE(review): this presumes the VM adds the offset to the index of the
// instruction following the jump — confirm against the VM.
func (c *Compiler) patchJump(jumpIndex int) {
	jumpDistance := len(c.chunk.Code) - jumpIndex - 1
	if jumpDistance > 0xFFFF {
		// The truncated value below would jump to the wrong place.
		c.error("Too much code to jump over.")
	}
	c.chunk.Code[jumpIndex].Operands = []byte{
		byte((jumpDistance >> 8) & 0xFF),
		byte(jumpDistance & 0xFF),
	}
}
// emitReturn emits an implicit `return nil`: it pushes nil and then returns.
func (c *Compiler) emitReturn() {
	for _, op := range [...]types.OpCode{types.OP_NIL, types.OP_RETURN} {
		c.emit(op, nil)
	}
}
// makeConstant appends value to the current chunk's constant pool and
// returns its index.
//
// Constant operands are a single byte, so a chunk can hold at most 256
// constants. Beyond that an error is recorded and 0 is returned. Previously
// the index silently wrapped around to an unrelated constant.
func (c *Compiler) makeConstant(value types.Value) byte {
	if len(c.chunk.Constants) >= 256 {
		c.error("Too many constants in one chunk.")
		return 0
	}
	c.chunk.Constants = append(c.chunk.Constants, value)
	return byte(len(c.chunk.Constants) - 1)
}
// beginScope enters a new, more deeply nested block scope. Locals declared
// from now on are tagged with the new depth and are discarded by the
// matching endScope.
func (c *Compiler) beginScope() {
	c.scopeDepth += 1
}
// endScope leaves the innermost block scope and discards every local declared
// in it, newest first. A local captured by a closure is discarded with
// OP_CLOSE_UPVALUE; any other local is discarded with OP_POP.
func (c *Compiler) endScope() {
	c.scopeDepth--

	n := len(c.locals)
	for n > 0 && c.locals[n-1].depth > c.scopeDepth {
		var op types.OpCode = types.OP_POP
		if c.locals[n-1].isCaptured {
			op = types.OP_CLOSE_UPVALUE
		}
		c.emit(op, nil)
		n--
	}
	c.locals = c.locals[:n]
}
// addLocal declares name as a new local at the current scope depth. Local
// slots are addressed by a single byte, so a function may hold at most 256
// of them. Past that an error is recorded and nothing is declared.
func (c *Compiler) addLocal(name string) {
	if len(c.locals) < 256 {
		c.locals = append(c.locals, local{name: name, depth: c.scopeDepth})
		return
	}
	c.error("Too many local variables in function.")
}
// resolveUpvalue looks name up in the enclosing functions and, when found,
// returns the index of the matching entry in c.upvalues.
//
// The immediately enclosing function's locals are searched newest first, so
// the innermost of several shadowing declarations wins (matching how
// variable resolves locals). A found local is marked captured, so its scope
// closes it with OP_CLOSE_UPVALUE. Otherwise the search recurses outward, and
// a hit becomes an upvalue of an upvalue (IsLocal == false).
func (c *Compiler) resolveUpvalue(name string) (int, bool) {
	if c.enclosing == nil {
		return 0, false
	}
	enc := c.enclosing

	for i := len(enc.locals) - 1; i >= 0; i-- {
		if enc.locals[i].name == name {
			enc.locals[i].isCaptured = true
			return c.addUpvalue(uint8(i), true), true
		}
	}

	if outer, found := enc.resolveUpvalue(name); found {
		return c.addUpvalue(uint8(outer), false), true
	}

	return 0, false
}

// addUpvalue returns the index of the upvalue (index, isLocal), reusing an
// existing entry so repeated references to one variable share a slot.
// Upvalue indices are emitted as a single byte, so a function may capture at
// most 256 variables. Past that an error is recorded and 0 is returned.
func (c *Compiler) addUpvalue(index uint8, isLocal bool) int {
	for i, uv := range c.upvalues {
		if uv.Index == index && uv.IsLocal == isLocal {
			return i
		}
	}
	if len(c.upvalues) >= 256 {
		c.error("Too many closure variables in function.")
		return 0
	}
	c.upvalues = append(c.upvalues, types.Upvalue{Index: index, IsLocal: isLocal})
	return len(c.upvalues) - 1
}
// error records a compile error without stopping compilation. Source
// positions are not tracked yet, so line and column are always 0.
func (c *Compiler) error(message string) {
	report := &types.MakoError{Message: message, Line: 0, Column: 0}
	c.errors = append(c.errors, report)
}