Mako/compiler/compiler.go
2025-06-11 21:50:55 -05:00

1103 lines
27 KiB
Go

package compiler
import (
"fmt"
"git.sharkk.net/Sharkk/Mako/parser"
)
// Compiler translates a parsed program (AST) into bytecode.
//
// It tracks the state of the function currently being compiled, the state of
// the directly enclosing function (consulted when resolving upvalues), and
// every compile error reported so far.
type Compiler struct {
current *CompilerState // State of the function (or top-level script) being compiled
enclosing *CompilerState // State of the enclosing function while a nested function or method is compiled; nil otherwise
errors []CompileError // Errors accumulated through addError
}
// NewCompiler returns a Compiler whose current state is a fresh top-level
// script state and whose error list is empty.
func NewCompiler() *Compiler {
	c := new(Compiler)
	c.current = NewCompilerState(FunctionTypeScript)
	c.errors = make([]CompileError, 0)
	return c
}
// Compile emits bytecode for every statement of program, terminates the chunk
// with OpReturnNil and runs the chunk-level optimisation passes.
//
// When any compile error was recorded the chunk is discarded and the errors
// are returned instead; otherwise the error slice is nil.
func (c *Compiler) Compile(program *parser.Program) (*Chunk, []CompileError) {
	for i := range program.Statements {
		c.compileStatement(program.Statements[i])
	}
	c.current.EmitInstruction(OpReturnNil)

	if len(c.errors) != 0 {
		return nil, c.errors
	}

	chunk := c.current.Chunk
	c.optimizeChunk(chunk)
	return chunk, nil
}
// compileStatement emits bytecode for a single statement.
//
// The current source line is updated first when the node reports a non-zero
// line. Expression statements discard their result with OpPop; statement
// kinds that are not recognised are recorded as compile errors.
func (c *Compiler) compileStatement(stmt parser.Statement) {
	line := c.getLineFromNode(stmt)
	if line != 0 {
		c.current.SetLine(line)
	}

	switch node := stmt.(type) {
	// Declarations
	case *parser.StructStatement:
		c.compileStructStatement(node)
	case *parser.MethodDefinition:
		c.compileMethodDefinition(node)

	// Simple statements
	case *parser.Assignment:
		c.compileAssignment(node)
	case *parser.ExpressionStatement:
		c.compileExpression(node.Expression)
		// The value of a bare expression is unused, so drop it.
		c.current.EmitInstruction(OpPop)
	case *parser.EchoStatement:
		c.compileExpression(node.Value)
		c.current.EmitInstruction(OpEcho)

	// Control flow
	case *parser.IfStatement:
		c.compileIfStatement(node)
	case *parser.WhileStatement:
		c.compileWhileStatement(node)
	case *parser.ForStatement:
		c.compileForStatement(node)
	case *parser.ForInStatement:
		c.compileForInStatement(node)
	case *parser.ReturnStatement:
		c.compileReturnStatement(node)
	case *parser.ExitStatement:
		c.compileExitStatement(node)
	case *parser.BreakStatement:
		c.current.EmitBreak()

	default:
		c.addError(fmt.Sprintf("unknown statement type: %T", stmt))
	}
}
// compileExpression emits bytecode that evaluates expr.
//
// Expressions that tryConstantFold can reduce to a compile-time value are
// emitted as a single constant load; everything else is dispatched on the
// node type. Unrecognised expression kinds are recorded as compile errors.
func (c *Compiler) compileExpression(expr parser.Expression) {
	line := c.getLineFromNode(expr)
	if line != 0 {
		c.current.SetLine(line)
	}

	// Constant folding gets the first chance at every expression.
	if folded := c.tryConstantFold(expr); folded != nil {
		c.emitConstant(*folded)
		return
	}

	switch node := expr.(type) {
	// Literals and names
	case *parser.Identifier:
		c.compileIdentifier(node)
	case *parser.NumberLiteral:
		c.compileNumberLiteral(node)
	case *parser.StringLiteral:
		c.compileStringLiteral(node)
	case *parser.BooleanLiteral:
		c.compileBooleanLiteral(node)
	case *parser.NilLiteral:
		c.compileNilLiteral(node)

	// Constructors
	case *parser.TableLiteral:
		c.compileTableLiteral(node)
	case *parser.StructConstructor:
		c.compileStructConstructor(node)
	case *parser.FunctionLiteral:
		c.compileFunctionLiteral(node)

	// Operators and access
	case *parser.CallExpression:
		c.compileCallExpression(node)
	case *parser.PrefixExpression:
		c.compilePrefixExpression(node)
	case *parser.InfixExpression:
		c.compileInfixExpression(node)
	case *parser.IndexExpression:
		c.compileIndexExpression(node)
	case *parser.DotExpression:
		c.compileDotExpression(node)
	case *parser.Assignment:
		c.compileAssignmentExpression(node)

	default:
		c.addError(fmt.Sprintf("unknown expression type: %T", expr))
	}
}
// tryConstantFold evaluates expr at compile time when possible.
//
// Literals fold to themselves; prefix and infix expressions fold when their
// operands do. It returns nil when expr cannot be reduced to a constant.
func (c *Compiler) tryConstantFold(expr parser.Expression) *Value {
	switch node := expr.(type) {
	case *parser.PrefixExpression:
		return c.foldPrefixExpression(node)
	case *parser.InfixExpression:
		return c.foldInfixExpression(node)
	case *parser.NilLiteral:
		return &Value{Type: ValueNil, Data: nil}
	case *parser.BooleanLiteral:
		return &Value{Type: ValueBool, Data: node.Value}
	case *parser.NumberLiteral:
		return &Value{Type: ValueNumber, Data: node.Value}
	case *parser.StringLiteral:
		return &Value{Type: ValueString, Data: node.Value}
	default:
		return nil
	}
}
// foldPrefixExpression folds a unary expression whose operand is constant.
//
// "-" folds only for numeric operands and "not" folds for any constant using
// isTruthy. It returns nil when the operand is not constant or the operator
// cannot be folded.
func (c *Compiler) foldPrefixExpression(expr *parser.PrefixExpression) *Value {
	operand := c.tryConstantFold(expr.Right)
	if operand == nil {
		return nil
	}

	if expr.Operator == "not" {
		return &Value{Type: ValueBool, Data: !c.isTruthy(*operand)}
	}
	if expr.Operator == "-" && operand.Type == ValueNumber {
		return &Value{Type: ValueNumber, Data: -operand.Data.(float64)}
	}
	return nil
}
// foldInfixExpression folds a binary expression whose operands are both
// constant.
//
// Number pairs fold for arithmetic and ordering operators (division is left
// unfolded when the divisor is zero). "==" and "!=" fold for any pair via
// valuesEqual, and "and"/"or" fold to whichever operand short-circuit
// evaluation would produce. It returns nil when nothing can be folded.
func (c *Compiler) foldInfixExpression(expr *parser.InfixExpression) *Value {
	lhs := c.tryConstantFold(expr.Left)
	rhs := c.tryConstantFold(expr.Right)
	if lhs == nil || rhs == nil {
		return nil
	}

	number := func(n float64) *Value { return &Value{Type: ValueNumber, Data: n} }
	boolean := func(b bool) *Value { return &Value{Type: ValueBool, Data: b} }

	if lhs.Type == ValueNumber && rhs.Type == ValueNumber {
		a := lhs.Data.(float64)
		b := rhs.Data.(float64)
		switch expr.Operator {
		case "+":
			return number(a + b)
		case "-":
			return number(a - b)
		case "*":
			return number(a * b)
		case "/":
			// Division by zero is left for the runtime to handle.
			if b != 0 {
				return number(a / b)
			}
		case "<":
			return boolean(a < b)
		case "<=":
			return boolean(a <= b)
		case ">":
			return boolean(a > b)
		case ">=":
			return boolean(a >= b)
		}
	}

	switch expr.Operator {
	case "==":
		return boolean(c.valuesEqual(*lhs, *rhs))
	case "!=":
		return boolean(!c.valuesEqual(*lhs, *rhs))
	case "and":
		if c.isTruthy(*lhs) {
			return rhs
		}
		return lhs
	case "or":
		if c.isTruthy(*lhs) {
			return lhs
		}
		return rhs
	}
	return nil
}
// isTruthy reports the truthiness of a constant: nil is false, booleans are
// their own value, and every other value is true.
func (c *Compiler) isTruthy(value Value) bool {
	if value.Type == ValueNil {
		return false
	}
	if value.Type == ValueBool {
		return value.Data.(bool)
	}
	return true
}
// valuesEqual compares two constants for equality.
//
// Values of different types are never equal. Nil equals nil; booleans,
// numbers and strings compare by value; any other type compares unequal.
func (c *Compiler) valuesEqual(a, b Value) bool {
	if a.Type != b.Type {
		return false
	}

	kind := a.Type
	if kind == ValueNil {
		return true
	}
	if kind == ValueBool {
		return a.Data.(bool) == b.Data.(bool)
	}
	if kind == ValueNumber {
		return a.Data.(float64) == b.Data.(float64)
	}
	if kind == ValueString {
		return a.Data.(string) == b.Data.(string)
	}
	return false
}
// emitConstant emits the cheapest instruction that loads value.
//
// nil, true, false, 0 and 1 have dedicated opcodes. Every other value is
// added to the constant pool and loaded with OpLoadConst; a full pool is
// reported as a compile error.
func (c *Compiler) emitConstant(value Value) {
	switch value.Type {
	case ValueNil:
		c.current.EmitInstruction(OpLoadNil)
		return
	case ValueBool:
		if value.Data.(bool) {
			c.current.EmitInstruction(OpLoadTrue)
		} else {
			c.current.EmitInstruction(OpLoadFalse)
		}
		return
	case ValueNumber:
		switch value.Data.(float64) {
		case 0:
			c.current.EmitInstruction(OpLoadZero)
			return
		case 1:
			c.current.EmitInstruction(OpLoadOne)
			return
		}
	}

	// General case: go through the constant pool.
	poolIndex := c.current.AddConstant(value)
	if poolIndex == -1 {
		c.addError("too many constants")
		return
	}
	c.current.EmitInstruction(OpLoadConst, uint16(poolIndex))
}
// compileNumberLiteral loads the literal's numeric value via emitConstant.
func (c *Compiler) compileNumberLiteral(node *parser.NumberLiteral) {
	c.emitConstant(Value{Type: ValueNumber, Data: node.Value})
}
// compileStringLiteral loads the literal's string value via emitConstant.
func (c *Compiler) compileStringLiteral(node *parser.StringLiteral) {
	c.emitConstant(Value{Type: ValueString, Data: node.Value})
}
// compileBooleanLiteral loads the literal's boolean value via emitConstant.
func (c *Compiler) compileBooleanLiteral(node *parser.BooleanLiteral) {
	c.emitConstant(Value{Type: ValueBool, Data: node.Value})
}
// compileNilLiteral loads nil. It goes through emitConstant, which emits
// OpLoadNil for nil values.
func (c *Compiler) compileNilLiteral(node *parser.NilLiteral) {
	c.emitConstant(Value{Type: ValueNil})
}
// compileIdentifier emits a load for a named variable.
//
// Resolution order is local slot, then upvalue, then global (looked up at
// runtime by name through the constant pool). A local that resolves to the -2
// sentinel is being read inside its own initializer and is reported as an
// error.
func (c *Compiler) compileIdentifier(node *parser.Identifier) {
	name := node.Value

	switch slot := c.current.ResolveLocal(name); {
	case slot == -2:
		c.addError("can't read local variable in its own initializer")
		return
	case slot != -1:
		c.emitLoadLocal(slot)
		return
	}

	if captured := c.resolveUpvalue(name); captured != -1 {
		c.current.EmitInstruction(OpGetUpvalue, uint16(captured))
		return
	}

	// Neither local nor captured: treat the name as a global.
	nameIndex := c.current.AddConstant(Value{Type: ValueString, Data: name})
	if nameIndex == -1 {
		c.addError("too many constants")
		return
	}
	c.current.EmitInstruction(OpLoadGlobal, uint16(nameIndex))
}
// emitLoadLocal emits a load of the local in slot, using the operand-free
// opcodes for slots 0, 1 and 2 and the generic OpLoadLocal otherwise.
func (c *Compiler) emitLoadLocal(slot int) {
	if slot == 0 {
		c.current.EmitInstruction(OpLoadLocal0)
		return
	}
	if slot == 1 {
		c.current.EmitInstruction(OpLoadLocal1)
		return
	}
	if slot == 2 {
		c.current.EmitInstruction(OpLoadLocal2)
		return
	}
	c.current.EmitInstruction(OpLoadLocal, uint16(slot))
}
// emitStoreLocal emits a store to the local in slot, using the operand-free
// opcodes for slots 0, 1 and 2 and the generic OpStoreLocal otherwise.
func (c *Compiler) emitStoreLocal(slot int) {
	if slot == 0 {
		c.current.EmitInstruction(OpStoreLocal0)
		return
	}
	if slot == 1 {
		c.current.EmitInstruction(OpStoreLocal1)
		return
	}
	if slot == 2 {
		c.current.EmitInstruction(OpStoreLocal2)
		return
	}
	c.current.EmitInstruction(OpStoreLocal, uint16(slot))
}
// compileAssignment emits bytecode for an assignment or declaration.
//
// The value expression is compiled first, then stored according to the
// target:
//   - identifier declaration: a global at script top level (scope depth 0),
//     otherwise a new local whose slot is the value just pushed;
//   - identifier assignment: local slot, then upvalue, then global;
//   - dot target: delegated to compileDotAssignment;
//   - index target: the container and key are compiled after the value and
//     OpSetIndex is emitted.
//
// A local that resolves to the -2 sentinel (declared but not yet initialised,
// as handled in compileIdentifier) is reported as an error instead of being
// emitted as a bogus slot number.
//
// NOTE(review): the index path pushes value, container, key before
// OpSetIndex, whereas compileTableLiteral pushes container, key, value before
// the same opcode — confirm which operand order the VM expects.
func (c *Compiler) compileAssignment(node *parser.Assignment) {
	c.compileExpression(node.Value)

	// storeGlobal emits a by-name global store through the constant pool.
	storeGlobal := func(name string) {
		index := c.current.AddConstant(Value{Type: ValueString, Data: name})
		if index == -1 {
			c.addError("too many constants")
			return
		}
		c.current.EmitInstruction(OpStoreGlobal, uint16(index))
	}

	switch target := node.Target.(type) {
	case *parser.Identifier:
		if node.IsDeclaration {
			if c.current.FunctionType == FunctionTypeScript && c.current.ScopeDepth == 0 {
				storeGlobal(target.Value)
				return
			}
			// Local declaration: the value already on the stack becomes the slot.
			if err := c.current.AddLocal(target.Value); err != nil {
				c.addError(err.Error())
				return
			}
			c.current.MarkInitialized()
			return
		}

		// Assignment to an existing variable.
		slot := c.current.ResolveLocal(target.Value)
		switch {
		case slot == -2:
			c.addError("can't assign to local variable in its own initializer")
		case slot != -1:
			c.emitStoreLocal(slot)
		default:
			if upvalue := c.resolveUpvalue(target.Value); upvalue != -1 {
				c.current.EmitInstruction(OpSetUpvalue, uint16(upvalue))
			} else {
				storeGlobal(target.Value)
			}
		}

	case *parser.DotExpression:
		c.compileDotAssignment(target)

	case *parser.IndexExpression:
		c.compileExpression(target.Left)
		c.compileExpression(target.Index)
		c.current.EmitInstruction(OpSetIndex)

	default:
		c.addError("invalid assignment target")
	}
}
// compileDotAssignment emits a field store for `left.key = value`; the value
// has already been compiled by the caller.
//
// When left is a local in slot 0-2 the combined OpSetLocalField instruction
// is used. Otherwise left is compiled and OpSetField is emitted with the
// field name taken from the constant pool.
//
// The fast path requires a real slot (>= 0): ResolveLocal's negative
// sentinels (-1 not found, -2 uninitialised) must not reach the operand,
// where they would wrap to a large uint16.
func (c *Compiler) compileDotAssignment(dot *parser.DotExpression) {
	if ident, ok := dot.Left.(*parser.Identifier); ok {
		slot := c.current.ResolveLocal(ident.Value)
		if slot >= 0 && slot <= 2 {
			index := c.current.AddConstant(Value{Type: ValueString, Data: dot.Key})
			if index == -1 {
				c.addError("too many constants")
				return
			}
			c.current.EmitInstruction(OpSetLocalField, uint16(slot), uint16(index))
			return
		}
	}

	// General path: evaluate the receiver, then store by field name.
	c.compileExpression(dot.Left)
	index := c.current.AddConstant(Value{Type: ValueString, Data: dot.Key})
	if index == -1 {
		c.addError("too many constants")
		return
	}
	c.current.EmitInstruction(OpSetField, uint16(index))
}
// compileAssignmentExpression compiles an assignment that appears in
// expression position by delegating to compileAssignment; no extra
// instructions are emitted here.
func (c *Compiler) compileAssignmentExpression(node *parser.Assignment) {
c.compileAssignment(node)
// NOTE(review): this relies on the store instructions leaving the assigned
// value on the stack — confirm against the VM's store semantics.
}
// compilePrefixExpression compiles the operand and then applies the unary
// operator: "-" emits OpNeg and "not" emits OpNot. Any other operator is a
// compile error (reported after the operand has been emitted).
func (c *Compiler) compilePrefixExpression(node *parser.PrefixExpression) {
	c.compileExpression(node.Right)

	if node.Operator == "-" {
		c.current.EmitInstruction(OpNeg)
		return
	}
	if node.Operator == "not" {
		c.current.EmitInstruction(OpNot)
		return
	}
	c.addError(fmt.Sprintf("unknown prefix operator: %s", node.Operator))
}
// compileInfixExpression emits bytecode for a binary expression.
//
// The increment/decrement fast path is tried first. "and" and "or" are
// compiled with conditional jumps so the right operand is only evaluated
// when needed. All remaining operators evaluate both operands and emit the
// matching opcode; an unknown operator is a compile error.
func (c *Compiler) compileInfixExpression(node *parser.InfixExpression) {
	if c.tryOptimizeIncDec(node) {
		return
	}

	switch node.Operator {
	case "and":
		// A falsy left operand skips the right one and stays as the result.
		c.compileExpression(node.Left)
		skipRight := c.current.EmitJump(OpJumpIfFalse)
		c.current.EmitInstruction(OpPop)
		c.compileExpression(node.Right)
		c.current.PatchJump(skipRight)
		return

	case "or":
		// A truthy left operand jumps straight to the end and stays as the
		// result; otherwise it is popped and the right operand is evaluated.
		c.compileExpression(node.Left)
		toRight := c.current.EmitJump(OpJumpIfFalse)
		toEnd := c.current.EmitJump(OpJump)
		c.current.PatchJump(toRight)
		c.current.EmitInstruction(OpPop)
		c.compileExpression(node.Right)
		c.current.PatchJump(toEnd)
		return
	}

	c.compileExpression(node.Left)
	c.compileExpression(node.Right)

	switch node.Operator {
	// Arithmetic
	case "+":
		c.current.EmitInstruction(OpAdd)
	case "-":
		c.current.EmitInstruction(OpSub)
	case "*":
		c.current.EmitInstruction(OpMul)
	case "/":
		c.current.EmitInstruction(OpDiv)

	// Equality
	case "==":
		c.current.EmitInstruction(OpEq)
	case "!=":
		c.current.EmitInstruction(OpNeq)

	// Ordering
	case "<":
		c.current.EmitInstruction(OpLt)
	case "<=":
		c.current.EmitInstruction(OpLte)
	case ">":
		c.current.EmitInstruction(OpGt)
	case ">=":
		c.current.EmitInstruction(OpGte)

	default:
		c.addError(fmt.Sprintf("unknown infix operator: %s", node.Operator))
	}
}
// tryOptimizeIncDec recognises `local + 1` and `local - 1` and compiles them
// as OpInc/OpDec on the local's slot followed by a load of that slot.
//
// It returns true when the fast path was emitted and false when the caller
// must compile the expression normally. The fast path is taken only for a
// real slot (>= 0): ResolveLocal's negative sentinels (-1 not found, -2
// uninitialised) fall back to the regular path, where compileIdentifier
// reports the appropriate error, instead of being wrapped into a uint16
// operand.
//
// NOTE(review): this runs for every matching infix expression, not only for
// `x = x + 1`. If OpInc/OpDec update the slot in place, a pure expression
// such as `y = x + 1` would also modify x — confirm the opcode semantics and
// restrict the pattern to self-assignment if so.
func (c *Compiler) tryOptimizeIncDec(node *parser.InfixExpression) bool {
	if node.Operator != "+" && node.Operator != "-" {
		return false
	}

	leftIdent, ok := node.Left.(*parser.Identifier)
	if !ok {
		return false
	}

	rightLit, ok := node.Right.(*parser.NumberLiteral)
	if !ok || rightLit.Value != 1 {
		return false
	}

	slot := c.current.ResolveLocal(leftIdent.Value)
	if slot < 0 {
		return false
	}

	if node.Operator == "+" {
		c.current.EmitInstruction(OpInc, uint16(slot))
	} else {
		c.current.EmitInstruction(OpDec, uint16(slot))
	}

	// Push the slot's value as the expression result.
	c.emitLoadLocal(slot)
	return true
}
// compileDotExpression emits a field read for `left.key`.
//
// When left is a local in slot 0-2 the combined OpGetLocalField instruction
// is used. Otherwise left is compiled and OpGetField is emitted with the
// field name taken from the constant pool.
//
// The fast path requires a real slot (>= 0): ResolveLocal's negative
// sentinels (-1 not found, -2 uninitialised) must not reach the operand,
// where they would wrap to a large uint16. They fall through to the general
// path, where compileIdentifier handles them.
func (c *Compiler) compileDotExpression(node *parser.DotExpression) {
	if ident, ok := node.Left.(*parser.Identifier); ok {
		slot := c.current.ResolveLocal(ident.Value)
		if slot >= 0 && slot <= 2 {
			index := c.current.AddConstant(Value{Type: ValueString, Data: node.Key})
			if index == -1 {
				c.addError("too many constants")
				return
			}
			c.current.EmitInstruction(OpGetLocalField, uint16(slot), uint16(index))
			return
		}
	}

	// General path: evaluate the receiver, then read by field name.
	c.compileExpression(node.Left)
	index := c.current.AddConstant(Value{Type: ValueString, Data: node.Key})
	if index == -1 {
		c.addError("too many constants")
		return
	}
	c.current.EmitInstruction(OpGetField, uint16(index))
}
// compileCallExpression emits a function call.
//
// When the callee is an identifier resolving to local slot 0 or 1, only the
// arguments are compiled and OpCallLocal0/OpCallLocal1 is emitted. In every
// other case the callee expression is compiled first, then the arguments,
// then OpCall. The argument count is the instruction's operand.
func (c *Compiler) compileCallExpression(node *parser.CallExpression) {
	argCount := uint16(len(node.Arguments))

	// -1 means "no fast path": the callee is not a local in slot 0 or 1.
	localSlot := -1
	if callee, isIdent := node.Function.(*parser.Identifier); isIdent {
		localSlot = c.current.ResolveLocal(callee.Value)
	}

	usesLocalCall := localSlot == 0 || localSlot == 1
	if !usesLocalCall {
		c.compileExpression(node.Function)
	}

	for i := range node.Arguments {
		c.compileExpression(node.Arguments[i])
	}

	switch {
	case localSlot == 0:
		c.current.EmitInstruction(OpCallLocal0, argCount)
	case localSlot == 1:
		c.current.EmitInstruction(OpCallLocal1, argCount)
	default:
		c.current.EmitInstruction(OpCall, argCount)
	}
}
// compileIfStatement emits an if / elseif / else chain.
//
// Each condition is followed by an OpJumpIfFalse over its body and an OpPop
// of the condition value on both the taken and the skipped path. Every body
// is compiled in its own scope and ends with an OpJump to the end of the
// whole statement; those jumps are patched once the else branch (if any) has
// been emitted.
func (c *Compiler) compileIfStatement(node *parser.IfStatement) {
	// Primary branch.
	c.compileExpression(node.Condition)
	skipBody := c.current.EmitJump(OpJumpIfFalse)
	c.current.EmitInstruction(OpPop)

	c.current.BeginScope()
	for i := range node.Body {
		c.compileStatement(node.Body[i])
	}
	c.current.EndScope()

	// Jumps to the end of the statement, in emission order.
	endJumps := []int{c.current.EmitJump(OpJump)}
	c.current.PatchJump(skipBody)
	c.current.EmitInstruction(OpPop)

	// Elseif branches follow the same shape as the primary branch.
	for _, branch := range node.ElseIfs {
		c.compileExpression(branch.Condition)
		skipBranch := c.current.EmitJump(OpJumpIfFalse)
		c.current.EmitInstruction(OpPop)

		c.current.BeginScope()
		for i := range branch.Body {
			c.compileStatement(branch.Body[i])
		}
		c.current.EndScope()

		endJumps = append(endJumps, c.current.EmitJump(OpJump))
		c.current.PatchJump(skipBranch)
		c.current.EmitInstruction(OpPop)
	}

	if len(node.Else) > 0 {
		c.current.BeginScope()
		for i := range node.Else {
			c.compileStatement(node.Else[i])
		}
		c.current.EndScope()
	}

	for _, pending := range endJumps {
		c.current.PatchJump(pending)
	}
}
// compileWhileStatement emits a while loop: the condition, an OpJumpIfFalse
// to the exit, the body in its own scope, and an OpLoopBack to the loop
// start recorded by EnterLoop. The condition value is popped on both the
// continue and the exit path.
func (c *Compiler) compileWhileStatement(node *parser.WhileStatement) {
	c.current.EnterLoop()

	c.compileExpression(node.Condition)
	exit := c.current.EmitJump(OpJumpIfFalse)
	c.current.EmitInstruction(OpPop)

	c.current.BeginScope()
	for i := range node.Body {
		c.compileStatement(node.Body[i])
	}
	c.current.EndScope()

	// Backward distance to the loop start; the +2 adjustment is carried over
	// unchanged (presumably accounting for the operand — TODO confirm).
	backOffset := len(c.current.Chunk.Code) - c.current.LoopStart + 2
	c.current.EmitInstruction(OpLoopBack, uint16(backOffset))

	c.current.PatchJump(exit)
	c.current.EmitInstruction(OpPop)
	c.current.ExitLoop()
}
// compileForStatement emits a numeric for loop.
//
// Three locals are created in a dedicated scope: the loop variable
// (initialised from Start), a hidden "__end" bound and a hidden "__step"
// (the Step expression, or 1 when absent). Each iteration checks
// `var <= end`, runs the body, then adds the step to the loop variable and
// jumps back to the check.
//
// The scope and loop opened here are closed on every exit path, including
// the early returns taken when a local cannot be declared, so a failed loop
// does not leave later statements compiling at the wrong scope depth or
// inside a phantom loop.
//
// NOTE(review): the condition is always OpLte, so a negative step never
// counts downward — confirm whether descending loops are meant to work.
func (c *Compiler) compileForStatement(node *parser.ForStatement) {
	state := c.current
	state.BeginScope()
	state.EnterLoop()
	defer func() {
		state.ExitLoop()
		state.EndScope()
	}()

	// Loop variable.
	c.compileExpression(node.Start)
	if err := c.current.AddLocal(node.Variable.Value); err != nil {
		c.addError(err.Error())
		return
	}
	c.current.MarkInitialized()
	loopVar := len(c.current.Locals) - 1

	// Hidden end bound.
	c.compileExpression(node.End)
	endSlot := len(c.current.Locals)
	if err := c.current.AddLocal("__end"); err != nil {
		c.addError(err.Error())
		return
	}
	c.current.MarkInitialized()

	// Hidden step, defaulting to 1.
	if node.Step != nil {
		c.compileExpression(node.Step)
	} else {
		c.current.EmitInstruction(OpLoadOne)
	}
	stepSlot := len(c.current.Locals)
	if err := c.current.AddLocal("__step"); err != nil {
		c.addError(err.Error())
		return
	}
	c.current.MarkInitialized()

	// Condition: loopVar <= end.
	conditionStart := len(c.current.Chunk.Code)
	c.emitLoadLocal(loopVar)
	c.emitLoadLocal(endSlot)
	c.current.EmitInstruction(OpLte)
	exitJump := c.current.EmitJump(OpJumpIfFalse)
	c.current.EmitInstruction(OpPop)

	for _, stmt := range node.Body {
		c.compileStatement(stmt)
	}

	// Increment: loopVar = loopVar + step.
	c.emitLoadLocal(loopVar)
	c.emitLoadLocal(stepSlot)
	c.current.EmitInstruction(OpAdd)
	c.emitStoreLocal(loopVar)

	jumpBack := len(c.current.Chunk.Code) - conditionStart + 2
	c.current.EmitInstruction(OpLoopBack, uint16(jumpBack))

	c.current.PatchJump(exitJump)
	c.current.EmitInstruction(OpPop)
	// ExitLoop and EndScope run in the deferred cleanup above.
}
// optimizeChunk runs the chunk-level optimisation passes on chunk, in order:
// peephole optimisation first, then dead-code elimination.
func (c *Compiler) optimizeChunk(chunk *Chunk) {
c.peepholeOptimize(chunk)
c.eliminateDeadCode(chunk)
}
// peepholeOptimize walks adjacent instruction pairs looking for local
// rewrite opportunities.
//
// It is currently a placeholder: the only recognised pattern (OpPop followed
// by a true/false/nil load) has no rewrite attached, so the chunk is never
// modified. The scan stops 6 bytes before the end of the code.
func (c *Compiler) peepholeOptimize(chunk *Chunk) {
	code := chunk.Code
	limit := len(code) - 6

	for pos := 0; pos < limit; {
		first, _, afterFirst := DecodeInstruction(code, pos)
		second, _, _ := DecodeInstruction(code, afterFirst)

		loadsSimpleConstant := second == OpLoadTrue || second == OpLoadFalse || second == OpLoadNil
		if first == OpPop && loadsSimpleConstant {
			// Candidate pattern; no rewrite is implemented yet.
		}

		pos = afterFirst
	}
}
// eliminateDeadCode overwrites instructions that follow an OpReturn,
// OpReturnNil or OpExit with OpNoop, stopping at the first offset that
// isJumpTarget reports as reachable.
//
// Every byte of a dead instruction is overwritten, operands included. If
// only the opcode byte were replaced, the leftover operand bytes would be
// decoded as instructions by any later pass over the chunk (including the
// rest of this scan) and by the VM.
func (c *Compiler) eliminateDeadCode(chunk *Chunk) {
	code := chunk.Code

	for i := 0; i < len(code); {
		op, _, next := DecodeInstruction(code, i)

		if op == OpReturn || op == OpReturnNil || op == OpExit {
			for j := next; j < len(code) && !c.isJumpTarget(chunk, j); {
				_, _, end := DecodeInstruction(code, j)
				if end <= j {
					// Defensive: always make progress on a malformed decode.
					end = j + 1
				}
				for k := j; k < end && k < len(code); k++ {
					code[k] = uint8(OpNoop)
				}
				j = end
			}
		}

		i = next
	}
}
// isJumpTarget reports whether a jump may land on offset in chunk.
//
// No control-flow analysis is implemented yet, so the answer is deliberately
// conservative: every offset is treated as a potential jump target.
// Answering false here would let eliminateDeadCode erase everything after
// the first return/exit in a chunk, including code that is reached by
// jumping over that return (for example the statements following an `if`
// whose body returns).
//
// TODO: collect the real targets of the chunk's jump instructions so that
// genuinely unreachable code can be removed.
func (c *Compiler) isJumpTarget(chunk *Chunk, offset int) bool {
	return true
}
// compileStructStatement registers a struct definition in the current chunk.
//
// Each declared field keeps its name, gets the value type derived from its
// type hint, and uses its declaration index as offset. The definition starts
// with an empty method table. No instructions are emitted.
func (c *Compiler) compileStructStatement(node *parser.StructStatement) {
	layout := make([]StructField, 0, len(node.Fields))
	for offset, decl := range node.Fields {
		layout = append(layout, StructField{
			Name:   decl.Name,
			Type:   c.typeInfoToValueType(decl.TypeHint),
			Offset: uint16(offset),
		})
	}

	chunk := c.current.Chunk
	chunk.Structs = append(chunk.Structs, Struct{
		Name:    node.Name,
		Fields:  layout,
		Methods: make(map[string]uint16),
		ID:      node.ID,
	})
}
// compileMethodDefinition compiles a struct method into its own Function and
// registers it in the enclosing chunk.
//
// The method body is compiled in a fresh method state whose first local is
// "self", followed by the declared parameters (so the arity is the parameter
// count plus one). The finished function is appended to the enclosing
// chunk's function list, and its index is stored in the method table of the
// struct whose ID matches node.StructID. If no such struct is registered in
// the enclosing chunk the function is still appended but not attached to any
// struct.
//
// The compiler's current/enclosing states are restored to their previous
// values on every exit path. Previously `enclosing` was reset to nil, which
// lost the enclosing state of an outer function still being compiled, and an
// error return left `current` pointing at the abandoned method state.
func (c *Compiler) compileMethodDefinition(node *parser.MethodDefinition) {
	outer, outerEnclosing := c.current, c.enclosing

	inner := NewCompilerState(FunctionTypeMethod)
	inner.parent = outer
	c.current, c.enclosing = inner, outer
	defer func() {
		c.current, c.enclosing = outer, outerEnclosing
	}()

	inner.BeginScope()

	// Slot 0 is the receiver.
	if err := inner.AddLocal("self"); err != nil {
		c.addError(err.Error())
		return
	}
	inner.MarkInitialized()

	for _, param := range node.Function.Parameters {
		if err := inner.AddLocal(param.Name); err != nil {
			c.addError(err.Error())
			return
		}
		inner.MarkInitialized()
	}

	for _, stmt := range node.Function.Body {
		c.compileStatement(stmt)
	}
	// Implicit `return nil` at the end of the body.
	inner.EmitInstruction(OpReturnNil)

	function := Function{
		Name:       node.MethodName,
		Arity:      len(node.Function.Parameters) + 1,
		Variadic:   node.Function.Variadic,
		LocalCount: len(inner.Locals),
		UpvalCount: len(inner.Upvalues),
		Chunk:      *inner.Chunk,
		Defaults:   []Value{},
	}

	functionIndex := len(outer.Chunk.Functions)
	outer.Chunk.Functions = append(outer.Chunk.Functions, function)

	for i := range outer.Chunk.Structs {
		if outer.Chunk.Structs[i].ID == node.StructID {
			outer.Chunk.Structs[i].Methods[node.MethodName] = uint16(functionIndex)
			break
		}
	}
}
// compileStructConstructor emits OpNewStruct for the struct ID and then one
// OpDup / value / OpSetProperty sequence per named field.
//
// Field keys must be identifiers or string literals naming a field of the
// struct; positional (keyless) entries, other key kinds and unknown field
// names are reported as compile errors and that entry is skipped.
func (c *Compiler) compileStructConstructor(node *parser.StructConstructor) {
	c.current.EmitInstruction(OpNewStruct, node.StructID)

	for _, entry := range node.Fields {
		if entry.Key == nil {
			c.addError("struct constructors require named field assignments")
			continue
		}

		// Keep the struct on the stack for the property store.
		c.current.EmitInstruction(OpDup)

		var name string
		switch key := entry.Key.(type) {
		case *parser.Identifier:
			name = key.Value
		case *parser.StringLiteral:
			name = key.Value
		default:
			c.addError("struct field names must be identifiers or strings")
			continue
		}

		position := c.findStructFieldIndex(node.StructID, name)
		if position == -1 {
			c.addError(fmt.Sprintf("struct has no field '%s'", name))
			continue
		}

		c.compileExpression(entry.Value)
		c.current.EmitInstruction(OpSetProperty, uint16(position))
	}
}
// compileTableLiteral emits OpNewTable followed by one insertion per entry:
// keyless entries compile the value and emit OpTableInsert; keyed entries
// emit OpDup, compile key then value, and emit OpSetIndex.
func (c *Compiler) compileTableLiteral(node *parser.TableLiteral) {
	c.current.EmitInstruction(OpNewTable)

	for _, entry := range node.Pairs {
		if entry.Key != nil {
			c.current.EmitInstruction(OpDup)
			c.compileExpression(entry.Key)
			c.compileExpression(entry.Value)
			c.current.EmitInstruction(OpSetIndex)
			continue
		}

		// Array-style entry.
		c.compileExpression(entry.Value)
		c.current.EmitInstruction(OpTableInsert)
	}
}
// compileIndexExpression emits an indexed read `left[index]`: the container
// expression is compiled first, then the index expression, then OpGetIndex.
func (c *Compiler) compileIndexExpression(node *parser.IndexExpression) {
c.compileExpression(node.Left)
c.compileExpression(node.Index)
c.current.EmitInstruction(OpGetIndex)
}
// compileFunctionLiteral compiles an anonymous function into its own
// Function, appends it to the enclosing chunk's function list and emits
// OpClosure (function index, upvalue count) in the enclosing code.
//
// The body is compiled in a fresh function state whose locals start with the
// declared parameters; an implicit OpReturnNil terminates it.
//
// The compiler's current/enclosing states are restored to their previous
// values on every exit path. Previously `enclosing` was reset to nil after a
// nested function, so an outer function that was itself nested could no
// longer resolve upvalues in the rest of its body, and an error return left
// `current` pointing at the abandoned function state. When a parameter
// cannot be declared, the error is recorded and no closure is emitted.
//
// NOTE(review): upvalue resolution still only looks one level up (see
// resolveUpvalue); capturing variables from more distant scopes is not
// implemented.
func (c *Compiler) compileFunctionLiteral(node *parser.FunctionLiteral) {
	outer, outerEnclosing := c.current, c.enclosing

	inner := NewCompilerState(FunctionTypeFunction)
	inner.parent = outer
	c.current, c.enclosing = inner, outer
	defer func() {
		c.current, c.enclosing = outer, outerEnclosing
	}()

	inner.BeginScope()

	for _, param := range node.Parameters {
		if err := inner.AddLocal(param.Name); err != nil {
			c.addError(err.Error())
			return
		}
		inner.MarkInitialized()
	}

	for _, stmt := range node.Body {
		c.compileStatement(stmt)
	}
	// Implicit `return nil` at the end of the body.
	inner.EmitInstruction(OpReturnNil)

	function := Function{
		Name:       "",
		Arity:      len(node.Parameters),
		Variadic:   node.Variadic,
		LocalCount: len(inner.Locals),
		UpvalCount: len(inner.Upvalues),
		Chunk:      *inner.Chunk,
		Defaults:   []Value{},
	}

	functionIndex := len(outer.Chunk.Functions)
	outer.Chunk.Functions = append(outer.Chunk.Functions, function)

	// The closure instruction belongs to the enclosing function's code.
	outer.EmitInstruction(OpClosure, uint16(functionIndex), uint16(function.UpvalCount))
}
// compileReturnStatement emits OpReturnNil for a bare `return`, or the value
// expression followed by OpReturn when a value is given.
func (c *Compiler) compileReturnStatement(node *parser.ReturnStatement) {
	if node.Value == nil {
		c.current.EmitInstruction(OpReturnNil)
		return
	}

	c.compileExpression(node.Value)
	c.current.EmitInstruction(OpReturn)
}
// compileExitStatement emits OpExit preceded by its operand: the given value
// expression, or zero (OpLoadZero) when the statement has no value.
func (c *Compiler) compileExitStatement(node *parser.ExitStatement) {
	if node.Value == nil {
		c.current.EmitInstruction(OpLoadZero)
	} else {
		c.compileExpression(node.Value)
	}
	c.current.EmitInstruction(OpExit)
}
// compileForInStatement emits the skeleton of a for-in loop: the iterable is
// compiled, the optional key local and the value local are declared in a
// dedicated scope, and the body is wrapped in a condition / OpLoopBack pair.
//
// The scope and loop opened here are closed on every exit path, including
// the early returns taken when a local cannot be declared, so a failed loop
// does not leave later statements compiling at the wrong scope depth or
// inside a phantom loop.
//
// NOTE(review): the loop condition is the constant `not nil`, and nothing
// advances an iterator or assigns the key/value locals, so as written the
// loop only ends through break/return/exit. Also only one value (the
// iterable) is pushed while up to two locals are declared. This looks like a
// placeholder awaiting real iteration support — confirm.
func (c *Compiler) compileForInStatement(node *parser.ForInStatement) {
	state := c.current
	state.BeginScope()
	state.EnterLoop()
	defer func() {
		state.ExitLoop()
		state.EndScope()
	}()

	c.compileExpression(node.Iterable)

	if node.Key != nil {
		if err := c.current.AddLocal(node.Key.Value); err != nil {
			c.addError(err.Error())
			return
		}
		c.current.MarkInitialized()
	}

	if err := c.current.AddLocal(node.Value.Value); err != nil {
		c.addError(err.Error())
		return
	}
	c.current.MarkInitialized()

	// Condition (placeholder): `not nil`.
	conditionStart := len(c.current.Chunk.Code)
	c.current.EmitInstruction(OpLoadNil)
	c.current.EmitInstruction(OpNot)
	exitJump := c.current.EmitJump(OpJumpIfFalse)
	c.current.EmitInstruction(OpPop)

	for _, stmt := range node.Body {
		c.compileStatement(stmt)
	}

	jumpBack := len(c.current.Chunk.Code) - conditionStart + 2
	c.current.EmitInstruction(OpLoopBack, uint16(jumpBack))

	c.current.PatchJump(exitJump)
	c.current.EmitInstruction(OpPop)
	// ExitLoop and EndScope run in the deferred cleanup above.
}
// resolveUpvalue resolves name as a variable captured from the enclosing
// function and returns its upvalue index in the current function, or -1 when
// name is not a captured variable (or there is no enclosing function).
//
// A local of the enclosing function is marked as captured and registered as
// a direct upvalue; otherwise the enclosing function's own upvalues are
// consulted through resolveUpvalueInEnclosing.
//
// ResolveLocal's -2 sentinel (declared but not yet initialised, as handled
// in compileIdentifier) is reported as an error. It previously passed the
// `!= -1` check and was used as a slice index, which panics.
func (c *Compiler) resolveUpvalue(name string) int {
	if c.enclosing == nil {
		return -1
	}

	local := c.enclosing.ResolveLocal(name)
	switch {
	case local == -2:
		c.addError("can't read local variable in its own initializer")
		return -1
	case local != -1:
		c.enclosing.Locals[local].IsCaptured = true
		return c.current.AddUpvalue(uint8(local), true)
	}

	if upvalue := c.resolveUpvalueInEnclosing(name); upvalue != -1 {
		return c.current.AddUpvalue(uint8(upvalue), false)
	}
	return -1
}
// resolveUpvalueInEnclosing is meant to look name up among the enclosing
// function's upvalues. It is a placeholder: it always returns -1 (not
// found), whether or not an enclosing function exists.
func (c *Compiler) resolveUpvalueInEnclosing(name string) int {
	return -1
}
// typeInfoToValueType maps a parser type hint to the runtime value type.
// Hints that are nil-typed or not recognised map to ValueNil.
func (c *Compiler) typeInfoToValueType(typeInfo parser.TypeInfo) ValueType {
	var kind ValueType = ValueNil

	switch typeInfo.Type {
	case parser.TypeNumber:
		kind = ValueNumber
	case parser.TypeString:
		kind = ValueString
	case parser.TypeBool:
		kind = ValueBool
	case parser.TypeTable:
		kind = ValueTable
	case parser.TypeFunction:
		kind = ValueFunction
	case parser.TypeStruct:
		kind = ValueStruct
	case parser.TypeNil:
		// Already ValueNil.
	}

	return kind
}
// findStructFieldIndex returns the index of fieldName within the struct
// identified by structID, or -1 when the struct or the field is unknown.
//
// Struct definitions are stored in the chunk of the state that compiled the
// struct statement, so the search starts at the current state and walks
// outwards through the parent chain. Searching only the current chunk made
// every struct constructor inside a function or method body fail with
// "struct has no field", because a nested function starts with an empty
// struct list.
func (c *Compiler) findStructFieldIndex(structID uint16, fieldName string) int {
	for state := c.current; state != nil; state = state.parent {
		for _, structDef := range state.Chunk.Structs {
			if structDef.ID != structID {
				continue
			}
			// Struct found: the answer is decided here, no further search.
			for i, field := range structDef.Fields {
				if field.Name == fieldName {
					return i
				}
			}
			return -1
		}
	}
	return -1
}
// addError records a compile error with the given message, tagged with the
// current state's line and column 0.
func (c *Compiler) addError(message string) {
	entry := CompileError{
		Message: message,
		Line:    c.current.CurrentLine,
		Column:  0,
	}
	c.errors = append(c.errors, entry)
}
// Errors returns the compile errors recorded so far. The returned slice is
// the compiler's own slice, not a copy.
func (c *Compiler) Errors() []CompileError { return c.errors }
// HasErrors reports whether at least one compile error has been recorded.
func (c *Compiler) HasErrors() bool {
	return len(c.errors) != 0
}
// getLineFromNode returns the source line of an AST node, or 0 when no line
// is available. It is currently a placeholder that always returns 0, so the
// callers in compileStatement and compileExpression never update the current
// line.
func (c *Compiler) getLineFromNode(node any) int {
return 0 // Placeholder
}