package compiler

import (
	"fmt"

	"git.sharkk.net/Sharkk/Mako/parser"
)

// Compiler holds the compilation state and compiles an AST to bytecode.
type Compiler struct {
	current   *CompilerState // State of the function (or script) currently being compiled
	enclosing *CompilerState // State directly enclosing current; consulted for upvalue resolution
	errors    []CompileError // Errors collected so far
}

// NewCompiler creates a compiler whose initial state is a top-level script.
func NewCompiler() *Compiler {
	return &Compiler{
		current: NewCompilerState(FunctionTypeScript),
		errors:  make([]CompileError, 0),
	}
}

// Compile compiles every statement of program, appends a final OpReturnNil and
// returns the resulting chunk. If any error was recorded, it returns a nil
// chunk together with the collected errors instead.
func (c *Compiler) Compile(program *parser.Program) (*Chunk, []CompileError) {
	for _, stmt := range program.Statements {
		c.compileStatement(stmt)
	}

	c.current.EmitInstruction(OpReturnNil)

	if len(c.errors) > 0 {
		return nil, c.errors
	}
	return c.current.Chunk, nil
}

// constantOperand adds value to the current chunk's constant pool and returns
// its index as an instruction operand. When the pool rejects the value
// (AddConstant returns -1) it records a "too many constants" error and
// reports false; callers must then emit nothing.
func (c *Compiler) constantOperand(value Value) (uint16, bool) {
	index := c.current.AddConstant(value)
	if index == -1 {
		c.addError("too many constants")
		return 0, false
	}
	return uint16(index), true
}

// Statement compilation

// compileStatement dispatches on the concrete statement type. Unknown
// statement types are recorded as compile errors.
func (c *Compiler) compileStatement(stmt parser.Statement) {
	// Update the current line when the node carries position info
	// (getLineFromNode returns 0 when it does not).
	if line := c.getLineFromNode(stmt); line != 0 {
		c.current.SetLine(line)
	}

	switch s := stmt.(type) {
	case *parser.StructStatement:
		c.compileStructStatement(s)
	case *parser.MethodDefinition:
		c.compileMethodDefinition(s)
	case *parser.Assignment:
		c.compileAssignment(s)
	case *parser.ExpressionStatement:
		c.compileExpression(s.Expression)
		c.current.EmitInstruction(OpPop) // Discard result
	case *parser.EchoStatement:
		c.compileExpression(s.Value)
		c.current.EmitInstruction(OpEcho)
	case *parser.IfStatement:
		c.compileIfStatement(s)
	case *parser.WhileStatement:
		c.compileWhileStatement(s)
	case *parser.ForStatement:
		c.compileForStatement(s)
	case *parser.ForInStatement:
		c.compileForInStatement(s)
	case *parser.ReturnStatement:
		c.compileReturnStatement(s)
	case *parser.ExitStatement:
		c.compileExitStatement(s)
	case *parser.BreakStatement:
		c.current.EmitBreak()
	default:
		c.addError(fmt.Sprintf("unknown statement type: %T", stmt))
	}
}

// Expression compilation

// compileExpression dispatches on the concrete expression type. Unknown
// expression types are recorded as compile errors.
func (c *Compiler) compileExpression(expr parser.Expression) {
	if line := c.getLineFromNode(expr); line != 0 {
		c.current.SetLine(line)
	}

	switch e := expr.(type) {
	case *parser.Identifier:
		c.compileIdentifier(e)
	case *parser.NumberLiteral:
		c.compileNumberLiteral(e)
	case *parser.StringLiteral:
		c.compileStringLiteral(e)
	case *parser.BooleanLiteral:
		c.compileBooleanLiteral(e)
	case *parser.NilLiteral:
		c.compileNilLiteral(e)
	case *parser.TableLiteral:
		c.compileTableLiteral(e)
	case *parser.StructConstructor:
		c.compileStructConstructor(e)
	case *parser.FunctionLiteral:
		c.compileFunctionLiteral(e)
	case *parser.CallExpression:
		c.compileCallExpression(e)
	case *parser.PrefixExpression:
		c.compilePrefixExpression(e)
	case *parser.InfixExpression:
		c.compileInfixExpression(e)
	case *parser.IndexExpression:
		c.compileIndexExpression(e)
	case *parser.DotExpression:
		c.compileDotExpression(e)
	case *parser.Assignment:
		c.compileAssignmentExpression(e)
	default:
		c.addError(fmt.Sprintf("unknown expression type: %T", expr))
	}
}

// Literal compilation

// compileNumberLiteral emits OpLoadConst for a number constant.
func (c *Compiler) compileNumberLiteral(node *parser.NumberLiteral) {
	if index, ok := c.constantOperand(Value{Type: ValueNumber, Data: node.Value}); ok {
		c.current.EmitInstruction(OpLoadConst, index)
	}
}

// compileStringLiteral emits OpLoadConst for a string constant.
func (c *Compiler) compileStringLiteral(node *parser.StringLiteral) {
	if index, ok := c.constantOperand(Value{Type: ValueString, Data: node.Value}); ok {
		c.current.EmitInstruction(OpLoadConst, index)
	}
}

// compileBooleanLiteral emits OpLoadConst for a boolean constant.
func (c *Compiler) compileBooleanLiteral(node *parser.BooleanLiteral) {
	if index, ok := c.constantOperand(Value{Type: ValueBool, Data: node.Value}); ok {
		c.current.EmitInstruction(OpLoadConst, index)
	}
}

// compileNilLiteral emits OpLoadConst for a nil constant.
func (c *Compiler) compileNilLiteral(node *parser.NilLiteral) {
	if index, ok := c.constantOperand(Value{Type: ValueNil, Data: nil}); ok {
		c.current.EmitInstruction(OpLoadConst, index)
	}
}

// Identifier compilation

// compileIdentifier emits a load for a name, trying in order: a local of the
// current state, an upvalue, and finally a global looked up by name constant.
func (c *Compiler) compileIdentifier(node *parser.Identifier) {
	// Try local variables first
	slot := c.current.ResolveLocal(node.Value)
	if slot != -1 {
		// -2 is ResolveLocal's marker for a local that is not yet initialized.
		if slot == -2 {
			c.addError("can't read local variable in its own initializer")
			return
		}
		c.current.EmitInstruction(OpLoadLocal, uint16(slot))
		return
	}

	// Try upvalues
	if upvalue := c.resolveUpvalue(node.Value); upvalue != -1 {
		c.current.EmitInstruction(OpGetUpvalue, uint16(upvalue))
		return
	}

	// Must be global
	if index, ok := c.constantOperand(Value{Type: ValueString, Data: node.Value}); ok {
		c.current.EmitInstruction(OpLoadGlobal, index)
	}
}

// Assignment compilation

// compileAssignment compiles the value expression first and then the store
// appropriate for the target: identifier (declaration or plain assignment),
// dotted field, or indexed element. Other targets are compile errors.
func (c *Compiler) compileAssignment(node *parser.Assignment) {
	c.compileExpression(node.Value)

	switch target := node.Target.(type) {
	case *parser.Identifier:
		if node.IsDeclaration {
			// A declaration at script level outside any scope is a global.
			if c.current.FunctionType == FunctionTypeScript && c.current.ScopeDepth == 0 {
				if index, ok := c.constantOperand(Value{Type: ValueString, Data: target.Value}); ok {
					c.current.EmitInstruction(OpStoreGlobal, index)
				}
				return
			}

			// Local variable declaration: no store is emitted here, the local
			// is only registered with the compiler state.
			if err := c.current.AddLocal(target.Value); err != nil {
				c.addError(err.Error())
				return
			}
			c.current.MarkInitialized()
			return
		}

		// Assignment to an existing variable: local, then upvalue, then global.
		slot := c.current.ResolveLocal(target.Value)
		if slot == -2 {
			// Uninitialized-local marker; converting it to uint16 would
			// silently address a bogus slot.
			c.addError("can't assign to local variable in its own initializer")
			return
		}
		if slot != -1 {
			c.current.EmitInstruction(OpStoreLocal, uint16(slot))
			return
		}
		if upvalue := c.resolveUpvalue(target.Value); upvalue != -1 {
			c.current.EmitInstruction(OpSetUpvalue, uint16(upvalue))
			return
		}
		if index, ok := c.constantOperand(Value{Type: ValueString, Data: target.Value}); ok {
			c.current.EmitInstruction(OpStoreGlobal, index)
		}

	case *parser.DotExpression:
		// table.field = value
		c.compileExpression(target.Left)
		if index, ok := c.constantOperand(Value{Type: ValueString, Data: target.Key}); ok {
			c.current.EmitInstruction(OpSetField, index)
		}

	case *parser.IndexExpression:
		// table[key] = value
		c.compileExpression(target.Left)
		c.compileExpression(target.Index)
		c.current.EmitInstruction(OpSetIndex)

	default:
		c.addError("invalid assignment target")
	}
}

// compileAssignmentExpression compiles an assignment used in expression
// position. It emits exactly what compileAssignment emits.
// NOTE(review): this presumes the store instructions leave the assigned value
// on the stack — confirm against the VM.
func (c *Compiler) compileAssignmentExpression(node *parser.Assignment) {
	c.compileAssignment(node)
}

// Operator compilation

// compilePrefixExpression compiles the operand and then the unary operator
// ("-" or "not").
func (c *Compiler) compilePrefixExpression(node *parser.PrefixExpression) {
	c.compileExpression(node.Right)

	switch node.Operator {
	case "-":
		c.current.EmitInstruction(OpNeg)
	case "not":
		c.current.EmitInstruction(OpNot)
	default:
		c.addError(fmt.Sprintf("unknown prefix operator: %s", node.Operator))
	}
}

// compileInfixExpression compiles a binary expression. "and" and "or" are
// compiled with jumps so the right operand is only evaluated when needed; all
// other operators evaluate both operands and emit a single opcode.
func (c *Compiler) compileInfixExpression(node *parser.InfixExpression) {
	// Handle short-circuit operators specially
	if node.Operator == "and" {
		c.compileExpression(node.Left)
		jump := c.current.EmitJump(OpJumpIfFalse)
		c.current.EmitInstruction(OpPop)
		c.compileExpression(node.Right)
		c.current.PatchJump(jump)
		return
	}

	if node.Operator == "or" {
		c.compileExpression(node.Left)
		elseJump := c.current.EmitJump(OpJumpIfFalse)
		endJump := c.current.EmitJump(OpJump)
		c.current.PatchJump(elseJump)
		c.current.EmitInstruction(OpPop)
		c.compileExpression(node.Right)
		c.current.PatchJump(endJump)
		return
	}

	// Regular binary operators
	c.compileExpression(node.Left)
	c.compileExpression(node.Right)

	switch node.Operator {
	case "+":
		c.current.EmitInstruction(OpAdd)
	case "-":
		c.current.EmitInstruction(OpSub)
	case "*":
		c.current.EmitInstruction(OpMul)
	case "/":
		c.current.EmitInstruction(OpDiv)
	case "==":
		c.current.EmitInstruction(OpEq)
	case "!=":
		c.current.EmitInstruction(OpNeq)
	case "<":
		c.current.EmitInstruction(OpLt)
	case "<=":
		c.current.EmitInstruction(OpLte)
	case ">":
		c.current.EmitInstruction(OpGt)
	case ">=":
		c.current.EmitInstruction(OpGte)
	default:
		c.addError(fmt.Sprintf("unknown infix operator: %s", node.Operator))
	}
}

// Control flow compilation

// compileIfStatement compiles an if / elseif / else chain. Each branch body
// runs in its own scope; each taken branch jumps to the end of the chain.
func (c *Compiler) compileIfStatement(node *parser.IfStatement) {
	c.compileExpression(node.Condition)

	// Jump over then branch if condition is false
	thenJump := c.current.EmitJump(OpJumpIfFalse)
	c.current.EmitInstruction(OpPop)

	// Compile then branch
	c.current.BeginScope()
	for _, stmt := range node.Body {
		c.compileStatement(stmt)
	}
	c.current.EndScope()

	// Jump over else branches
	elseJump := c.current.EmitJump(OpJump)
	c.current.PatchJump(thenJump)
	c.current.EmitInstruction(OpPop)

	// Compile elseif branches
	var elseifJumps []int
	for _, elseif := range node.ElseIfs {
		c.compileExpression(elseif.Condition)
		nextJump := c.current.EmitJump(OpJumpIfFalse)
		c.current.EmitInstruction(OpPop)

		c.current.BeginScope()
		for _, stmt := range elseif.Body {
			c.compileStatement(stmt)
		}
		c.current.EndScope()

		elseifJumps = append(elseifJumps, c.current.EmitJump(OpJump))
		c.current.PatchJump(nextJump)
		c.current.EmitInstruction(OpPop)
	}

	// Compile else branch
	if len(node.Else) > 0 {
		c.current.BeginScope()
		for _, stmt := range node.Else {
			c.compileStatement(stmt)
		}
		c.current.EndScope()
	}

	// Patch all jumps to end
	c.current.PatchJump(elseJump)
	for _, jump := range elseifJumps {
		c.current.PatchJump(jump)
	}
}

// compileWhileStatement compiles a while loop: condition, conditional exit,
// scoped body, and a jump back to the loop start.
func (c *Compiler) compileWhileStatement(node *parser.WhileStatement) {
	c.current.EnterLoop()

	c.compileExpression(node.Condition)
	exitJump := c.current.EmitJump(OpJumpIfFalse)
	c.current.EmitInstruction(OpPop)

	c.current.BeginScope()
	for _, stmt := range node.Body {
		c.compileStatement(stmt)
	}
	c.current.EndScope()

	// Jump back to condition.
	// NOTE(review): the back edge is emitted with OpJump and a positive
	// distance, while DisassembleInstruction treats OpJump as a forward jump —
	// confirm how the VM interprets this operand.
	jump := len(c.current.Chunk.Code) - c.current.LoopStart + 2
	c.current.EmitInstruction(OpJump, uint16(jump))

	c.current.PatchJump(exitJump)
	c.current.EmitInstruction(OpPop)

	c.current.ExitLoop()
}

// For loop compilation

// compileForStatement compiles a numeric for loop. It registers three locals
// in order — the loop variable, "__end" and "__step" (constant 1 when no step
// is given) — and loops while the variable is <= the end value, adding the
// step after each iteration.
func (c *Compiler) compileForStatement(node *parser.ForStatement) {
	c.current.BeginScope()
	c.current.EnterLoop()

	// Initialize loop variable
	c.compileExpression(node.Start)
	if err := c.current.AddLocal(node.Variable.Value); err != nil {
		c.addError(err.Error())
		return
	}
	c.current.MarkInitialized()
	loopVar := len(c.current.Locals) - 1

	// Compile end value
	c.compileExpression(node.End)
	endSlot := len(c.current.Locals)
	if err := c.current.AddLocal("__end"); err != nil {
		c.addError(err.Error())
		return
	}
	c.current.MarkInitialized()

	// Compile step value (default 1)
	if node.Step != nil {
		c.compileExpression(node.Step)
	} else {
		index, ok := c.constantOperand(Value{Type: ValueNumber, Data: float64(1)})
		if !ok {
			return
		}
		c.current.EmitInstruction(OpLoadConst, index)
	}
	stepSlot := len(c.current.Locals)
	if err := c.current.AddLocal("__step"); err != nil {
		c.addError(err.Error())
		return
	}
	c.current.MarkInitialized()

	// Loop condition: loop variable <= end (only correct for a positive step)
	conditionStart := len(c.current.Chunk.Code)
	c.current.EmitInstruction(OpLoadLocal, uint16(loopVar)) // Load loop var
	c.current.EmitInstruction(OpLoadLocal, uint16(endSlot)) // Load end
	c.current.EmitInstruction(OpLte)                        // Compare
	exitJump := c.current.EmitJump(OpJumpIfFalse)
	c.current.EmitInstruction(OpPop)

	// Loop body
	for _, stmt := range node.Body {
		c.compileStatement(stmt)
	}

	// Increment loop variable: var = var + step
	c.current.EmitInstruction(OpLoadLocal, uint16(loopVar))  // Load current value
	c.current.EmitInstruction(OpLoadLocal, uint16(stepSlot)) // Load step
	c.current.EmitInstruction(OpAdd)                         // Add
	c.current.EmitInstruction(OpStoreLocal, uint16(loopVar)) // Store back

	// Jump back to condition (see the NOTE in compileWhileStatement about
	// the direction of this jump).
	jumpBack := len(c.current.Chunk.Code) - conditionStart + 2
	c.current.EmitInstruction(OpJump, uint16(jumpBack))

	// Exit
	c.current.PatchJump(exitJump)
	c.current.EmitInstruction(OpPop)

	c.current.ExitLoop()
	c.current.EndScope()
}

// For-in loop compilation

// compileForInStatement compiles a for-in loop.
// NOTE(review): this is a placeholder. The emitted condition is "not nil" on
// a constant rather than real iterator logic, and only the iterable is pushed
// before one or two locals are registered.
func (c *Compiler) compileForInStatement(node *parser.ForInStatement) {
	c.current.BeginScope()
	c.current.EnterLoop()

	// Compile iterable and set up iterator
	c.compileExpression(node.Iterable)

	// Create iterator variables
	if node.Key != nil {
		if err := c.current.AddLocal(node.Key.Value); err != nil {
			c.addError(err.Error())
			return
		}
		c.current.MarkInitialized()
	}

	if err := c.current.AddLocal(node.Value.Value); err != nil {
		c.addError(err.Error())
		return
	}
	c.current.MarkInitialized()

	// Loop condition (simplified - would need actual iterator logic)
	conditionStart := len(c.current.Chunk.Code)

	// Placeholder condition: load nil and negate it.
	index, ok := c.constantOperand(Value{Type: ValueNil, Data: nil})
	if !ok {
		return
	}
	c.current.EmitInstruction(OpLoadConst, index)
	c.current.EmitInstruction(OpNot)

	exitJump := c.current.EmitJump(OpJumpIfFalse)
	c.current.EmitInstruction(OpPop)

	// Loop body
	for _, stmt := range node.Body {
		c.compileStatement(stmt)
	}

	// Jump back to condition
	jumpBack := len(c.current.Chunk.Code) - conditionStart + 2
	c.current.EmitInstruction(OpJump, uint16(jumpBack))

	// Exit
	c.current.PatchJump(exitJump)
	c.current.EmitInstruction(OpPop)

	c.current.ExitLoop()
	c.current.EndScope()
}

// Function compilation

// compileFunctionLiteral compiles an anonymous function into its own state,
// appends the resulting Function to the enclosing chunk's function table and
// emits OpClosure (function index, upvalue count) in the enclosing chunk.
//
// The previous values of c.current and c.enclosing are restored on every exit
// path, so that a function literal nested inside another function does not
// discard the outer function's enclosing state.
func (c *Compiler) compileFunctionLiteral(node *parser.FunctionLiteral) {
	// Create new compiler state for function
	enclosing := c.current
	previousEnclosing := c.enclosing
	c.current = NewCompilerState(FunctionTypeFunction)
	c.current.parent = enclosing
	c.enclosing = enclosing
	restore := func() {
		c.current = enclosing
		c.enclosing = previousEnclosing
	}

	// Begin function scope
	c.current.BeginScope()

	// Add parameters as local variables
	for _, param := range node.Parameters {
		if err := c.current.AddLocal(param.Name); err != nil {
			c.addError(err.Error())
			restore()
			return
		}
		c.current.MarkInitialized()
	}

	// Compile function body
	for _, stmt := range node.Body {
		c.compileStatement(stmt)
	}

	// Implicit return nil if no explicit return
	c.current.EmitInstruction(OpReturnNil)

	// Create function object
	function := Function{
		Name:       "", // Anonymous function
		Arity:      len(node.Parameters),
		Variadic:   node.Variadic,
		LocalCount: len(c.current.Locals),
		UpvalCount: len(c.current.Upvalues),
		Chunk:      *c.current.Chunk,
		Defaults:   []Value{}, // TODO: Handle default parameters
	}

	// Add function to parent chunk
	functionIndex := len(enclosing.Chunk.Functions)
	enclosing.Chunk.Functions = append(enclosing.Chunk.Functions, function)

	// Restore enclosing state before emitting into it
	restore()

	// Emit closure instruction
	c.current.EmitInstruction(OpClosure, uint16(functionIndex), uint16(function.UpvalCount))
}

// Struct compilation

// compileStructStatement records a struct definition in the current chunk.
// Field offsets are the fields' declaration indices.
func (c *Compiler) compileStructStatement(node *parser.StructStatement) {
	// Convert parser fields to compiler fields
	fields := make([]StructField, len(node.Fields))
	for i, field := range node.Fields {
		fields[i] = StructField{
			Name:   field.Name,
			Type:   c.typeInfoToValueType(field.TypeHint),
			Offset: uint16(i),
		}
	}

	// Create struct definition
	structDef := Struct{
		Name:    node.Name,
		Fields:  fields,
		Methods: make(map[string]uint16),
		ID:      node.ID,
	}

	// Add to chunk
	c.current.Chunk.Structs = append(c.current.Chunk.Structs, structDef)
}

// compileMethodDefinition compiles a method as a function whose first local
// is "self", appends it to the enclosing chunk's function table and records
// its index in the Methods map of the struct with the matching ID (searched
// in the enclosing chunk only). No instruction is emitted in the enclosing
// chunk.
//
// As in compileFunctionLiteral, c.current and c.enclosing are restored to
// their previous values on every exit path.
func (c *Compiler) compileMethodDefinition(node *parser.MethodDefinition) {
	// Compile method as a function
	enclosing := c.current
	previousEnclosing := c.enclosing
	c.current = NewCompilerState(FunctionTypeMethod)
	c.current.parent = enclosing
	c.enclosing = enclosing
	restore := func() {
		c.current = enclosing
		c.enclosing = previousEnclosing
	}

	// Begin function scope
	c.current.BeginScope()

	// Add 'self' parameter
	if err := c.current.AddLocal("self"); err != nil {
		c.addError(err.Error())
		restore()
		return
	}
	c.current.MarkInitialized()

	// Add method parameters
	for _, param := range node.Function.Parameters {
		if err := c.current.AddLocal(param.Name); err != nil {
			c.addError(err.Error())
			restore()
			return
		}
		c.current.MarkInitialized()
	}

	// Compile method body
	for _, stmt := range node.Function.Body {
		c.compileStatement(stmt)
	}

	// Implicit return nil
	c.current.EmitInstruction(OpReturnNil)

	// Create function object
	function := Function{
		Name:       node.MethodName,
		Arity:      len(node.Function.Parameters) + 1, // +1 for self
		Variadic:   node.Function.Variadic,
		LocalCount: len(c.current.Locals),
		UpvalCount: len(c.current.Upvalues),
		Chunk:      *c.current.Chunk,
		Defaults:   []Value{},
	}

	// Add to parent chunk
	functionIndex := len(enclosing.Chunk.Functions)
	enclosing.Chunk.Functions = append(enclosing.Chunk.Functions, function)

	// Find struct and add method reference
	for i := range enclosing.Chunk.Structs {
		if enclosing.Chunk.Structs[i].ID == node.StructID {
			enclosing.Chunk.Structs[i].Methods[node.MethodName] = uint16(functionIndex)
			break
		}
	}

	// Restore state
	restore()
}

// compileStructConstructor emits OpNewStruct followed, for every named field,
// by OpDup, the field's value expression and OpSetProperty with the field's
// index. Positional fields and unknown field names are compile errors.
func (c *Compiler) compileStructConstructor(node *parser.StructConstructor) {
	// Create new struct instance
	c.current.EmitInstruction(OpNewStruct, node.StructID)

	// Initialize fields
	for _, field := range node.Fields {
		if field.Key == nil {
			// Positional field assignment (not typical for structs)
			c.addError("struct constructors require named field assignments")
			continue
		}

		// Named field assignment
		c.current.EmitInstruction(OpDup) // Duplicate struct reference

		// Get field name
		var fieldName string
		switch key := field.Key.(type) {
		case *parser.Identifier:
			fieldName = key.Value
		case *parser.StringLiteral:
			fieldName = key.Value
		default:
			c.addError("struct field names must be identifiers or strings")
			continue
		}

		// Find field index in struct definition
		fieldIndex := c.findStructFieldIndex(node.StructID, fieldName)
		if fieldIndex == -1 {
			c.addError(fmt.Sprintf("struct has no field '%s'", fieldName))
			continue
		}

		c.compileExpression(field.Value)
		c.current.EmitInstruction(OpSetProperty, uint16(fieldIndex))
	}
}

// Table operations

// compileTableLiteral emits OpNewTable, then OpTableInsert for each keyless
// element and OpDup / key / value / OpSetIndex for each keyed pair.
func (c *Compiler) compileTableLiteral(node *parser.TableLiteral) {
	c.current.EmitInstruction(OpNewTable)

	for _, pair := range node.Pairs {
		if pair.Key == nil {
			// Array-style element
			c.compileExpression(pair.Value)
			c.current.EmitInstruction(OpTableInsert)
			continue
		}

		// Key-value pair
		c.current.EmitInstruction(OpDup) // Duplicate table reference
		c.compileExpression(pair.Key)
		c.compileExpression(pair.Value)
		c.current.EmitInstruction(OpSetIndex)
	}
}

// compileDotExpression compiles left.key as the left expression followed by
// OpGetField with the key's name constant.
func (c *Compiler) compileDotExpression(node *parser.DotExpression) {
	c.compileExpression(node.Left)

	if index, ok := c.constantOperand(Value{Type: ValueString, Data: node.Key}); ok {
		c.current.EmitInstruction(OpGetField, index)
	}
}

// compileIndexExpression compiles left[index] as both sub-expressions
// followed by OpGetIndex.
func (c *Compiler) compileIndexExpression(node *parser.IndexExpression) {
	c.compileExpression(node.Left)
	c.compileExpression(node.Index)
	c.current.EmitInstruction(OpGetIndex)
}

// Function calls

// compileCallExpression compiles the callee, then each argument in order,
// then OpCall with the argument count.
func (c *Compiler) compileCallExpression(node *parser.CallExpression) {
	c.compileExpression(node.Function)

	// Compile arguments
	for _, arg := range node.Arguments {
		c.compileExpression(arg)
	}

	c.current.EmitInstruction(OpCall, uint16(len(node.Arguments)))
}

// compileReturnStatement emits OpReturn after the value expression, or
// OpReturnNil when the statement has no value.
func (c *Compiler) compileReturnStatement(node *parser.ReturnStatement) {
	if node.Value == nil {
		c.current.EmitInstruction(OpReturnNil)
		return
	}

	c.compileExpression(node.Value)
	c.current.EmitInstruction(OpReturn)
}

// compileExitStatement emits OpExit after the exit value, which defaults to
// the number constant 0 when the statement has no value.
func (c *Compiler) compileExitStatement(node *parser.ExitStatement) {
	if node.Value != nil {
		c.compileExpression(node.Value)
	} else {
		// Default exit code 0
		index, ok := c.constantOperand(Value{Type: ValueNumber, Data: float64(0)})
		if !ok {
			return
		}
		c.current.EmitInstruction(OpLoadConst, index)
	}
	c.current.EmitInstruction(OpExit)
}

// Helper methods

// resolveUpvalue resolves name as an upvalue of the current function and
// returns the upvalue index, or -1 when the name is not found (or there is no
// enclosing state). A local of the directly enclosing state is marked as
// captured and registered as a local upvalue; otherwise the result of
// resolveUpvalueInEnclosing, if any, is registered as a non-local upvalue.
func (c *Compiler) resolveUpvalue(name string) int {
	if c.enclosing == nil {
		return -1
	}

	local := c.enclosing.ResolveLocal(name)
	if local == -2 {
		// Uninitialized-local marker: it must not be used as a slice index.
		c.addError("can't read local variable in its own initializer")
		return -1
	}
	if local != -1 {
		c.enclosing.Locals[local].IsCaptured = true
		return c.current.AddUpvalue(uint8(local), true)
	}

	if upvalue := c.resolveUpvalueInEnclosing(name); upvalue != -1 {
		return c.current.AddUpvalue(uint8(upvalue), false)
	}

	return -1
}

// resolveUpvalueInEnclosing is a stub for resolving name in states further
// out than the directly enclosing one; it currently always returns -1.
func (c *Compiler) resolveUpvalueInEnclosing(name string) int {
	if c.enclosing == nil {
		return -1
	}

	// This would recursively check enclosing scopes
	// Simplified for now
	return -1
}

// typeInfoToValueType maps a parser type hint to the corresponding ValueType.
// Unrecognized hints map to ValueNil.
func (c *Compiler) typeInfoToValueType(typeInfo parser.TypeInfo) ValueType {
	switch typeInfo.Type {
	case parser.TypeNumber:
		return ValueNumber
	case parser.TypeString:
		return ValueString
	case parser.TypeBool:
		return ValueBool
	case parser.TypeNil:
		return ValueNil
	case parser.TypeTable:
		return ValueTable
	case parser.TypeFunction:
		return ValueFunction
	case parser.TypeStruct:
		return ValueStruct
	default:
		return ValueNil
	}
}

// findStructFieldIndex returns the index of fieldName within the struct
// identified by structID, or -1 when the struct or the field is unknown.
//
// The current chunk is searched first and then each parent state's chunk, so
// that a constructor inside a function body can refer to a struct declared in
// an enclosing chunk.
func (c *Compiler) findStructFieldIndex(structID uint16, fieldName string) int {
	for state := c.current; state != nil; state = state.parent {
		structs := state.Chunk.Structs
		for i := range structs {
			if structs[i].ID != structID {
				continue
			}
			for j := range structs[i].Fields {
				if structs[i].Fields[j].Name == fieldName {
					return j
				}
			}
			// The struct exists but has no such field.
			return -1
		}
	}
	return -1
}

// Enhanced error reporting

// addError records a compile error at the current state's current line.
func (c *Compiler) addError(message string) {
	c.errors = append(c.errors, CompileError{
		Message: message,
		Line:    c.current.CurrentLine,
		Column:  0, // Column tracking would need more work
	})
}

// Error reporting

// Errors returns the compile errors recorded so far.
func (c *Compiler) Errors() []CompileError {
	return c.errors
}

// HasErrors reports whether any compile error has been recorded.
func (c *Compiler) HasErrors() bool {
	return len(c.errors) > 0
}

// getLineFromNode is a placeholder for extracting a source line from an AST
// node; it always returns 0, which callers treat as "no line information".
func (c *Compiler) getLineFromNode(node any) int {
	// Since AST nodes don't store position, we'd need to modify the parser
	// For now, track during compilation by passing line info through
	return 0 // Placeholder
}

// DisassembleInstruction prints the instruction at offset in chunk to
// standard output and returns the offset of the next instruction. An offset
// outside chunk.Code is returned unchanged and nothing is printed. When an
// instruction's operands run past the end of the code, " <truncated>" is
// printed and len(chunk.Code) is returned.
//
// Operands are decoded as two bytes each, low byte first, for every opcode
// this compiler emits with operands.
// NOTE(review): the operand width for opcodes other than the load/store-local,
// load-const and jump instructions is taken from how this file calls
// EmitInstruction — confirm against the encoder.
func DisassembleInstruction(chunk *Chunk, offset int) int {
	if offset < 0 || offset >= len(chunk.Code) {
		return offset
	}

	line := 0
	if offset < len(chunk.Lines) {
		line = chunk.Lines[offset]
	}

	fmt.Printf("%04d ", offset)
	// Only compare with the previous line entry when that entry exists.
	if offset > 0 && offset-1 < len(chunk.Lines) && line == chunk.Lines[offset-1] {
		fmt.Print(" | ")
	} else {
		fmt.Printf("%4d ", line)
	}

	op := Opcode(chunk.Code[offset])
	fmt.Printf("%-16s", opcodeNames[op])

	// operand decodes the index-th two-byte operand following the opcode.
	operand := func(index int) (uint16, bool) {
		pos := offset + 1 + 2*index
		if pos+1 >= len(chunk.Code) {
			return 0, false
		}
		return uint16(chunk.Code[pos]) | (uint16(chunk.Code[pos+1]) << 8), true
	}
	truncated := func() int {
		fmt.Print(" <truncated>")
		return len(chunk.Code)
	}

	switch op {
	case OpLoadConst, OpLoadGlobal, OpStoreGlobal, OpGetField, OpSetField:
		// Operand is a constant-pool index; show the constant when it exists.
		index, ok := operand(0)
		if !ok {
			return truncated()
		}
		fmt.Printf(" %4d", index)
		if int(index) < len(chunk.Constants) {
			fmt.Printf(" '")
			printValue(chunk.Constants[index])
			fmt.Printf("'")
		}
		return offset + 3

	case OpLoadLocal, OpStoreLocal, OpGetUpvalue, OpSetUpvalue,
		OpSetProperty, OpNewStruct, OpCall:
		value, ok := operand(0)
		if !ok {
			return truncated()
		}
		fmt.Printf(" %4d", value)
		return offset + 3

	case OpJump, OpJumpIfTrue, OpJumpIfFalse:
		jump, ok := operand(0)
		if !ok {
			return truncated()
		}
		fmt.Printf(" %4d -> %d", jump, offset+3+int(jump))
		return offset + 3

	case OpClosure:
		// Function index followed by upvalue count.
		functionIndex, ok := operand(0)
		if !ok {
			return truncated()
		}
		upvalueCount, ok := operand(1)
		if !ok {
			return truncated()
		}
		fmt.Printf(" %4d %4d", functionIndex, upvalueCount)
		return offset + 5

	default:
		return offset + 1
	}
}

// printValue prints a constant to standard output: "nil", the boolean, the
// number with two decimals, or the string; other types print as "<type>".
// A payload of an unexpected Go type is printed with default formatting
// instead of panicking.
func printValue(value Value) {
	switch value.Type {
	case ValueNil:
		fmt.Print("nil")
	case ValueBool, ValueString:
		fmt.Print(value.Data)
	case ValueNumber:
		if number, ok := value.Data.(float64); ok {
			fmt.Printf("%.2f", number)
		} else {
			fmt.Print(value.Data)
		}
	default:
		fmt.Printf("<%s>", valueTypeString(value.Type))
	}
}

// valueTypeString returns a short name for the non-scalar value types and
// "unknown" for everything else.
func valueTypeString(vt ValueType) string {
	switch vt {
	case ValueTable:
		return "table"
	case ValueFunction:
		return "function"
	case ValueStruct:
		return "struct"
	case ValueArray:
		return "array"
	case ValueUpvalue:
		return "upvalue"
	default:
		return "unknown"
	}
}