diff --git a/compiler/compiler.go b/compiler/compiler.go index 5b4fb37..703306d 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -1,290 +1,538 @@ package compiler -import "fmt" +import ( + "fmt" -// Constants for compiler limits -const ( - MaxLocals = 256 // Maximum local variables per function - MaxUpvalues = 256 // Maximum upvalues per function - MaxConstants = 65536 // Maximum constants per chunk + "git.sharkk.net/Sharkk/Mako/parser" ) -// CompilerState holds state during compilation -type CompilerState struct { - Chunk *Chunk // Current chunk being compiled - Constants map[string]int // Constant pool index mapping - Functions []Function // Compiled functions - Structs []Struct // Compiled structs - Locals []Local // Local variable stack - Upvalues []UpvalueRef // Upvalue definitions - ScopeDepth int // Current scope nesting level - FunctionType FunctionType // Type of function being compiled - BreakJumps []int // Break jump addresses for loops - ContinueJumps []int // Continue jump addresses for loops - LoopStart int // Start of current loop for continue - LoopDepth int // Current loop nesting depth +// Compiler holds the compilation state and compiles AST to bytecode +type Compiler struct { + current *CompilerState // Current compilation state + enclosing *CompilerState // Enclosing function state for closures + errors []CompileError // Compilation errors } -// Local represents a local variable during compilation -type Local struct { - Name string // Variable name - Depth int // Scope depth where declared - IsCaptured bool // Whether variable is captured by closure - Slot int // Stack slot index -} - -// UpvalueRef represents an upvalue reference during compilation -type UpvalueRef struct { - Index uint8 // Index in enclosing function's locals or upvalues - IsLocal bool // True if captures local, false if captures upvalue -} - -// FunctionType represents the type of function being compiled -type FunctionType uint8 - -const ( - FunctionTypeScript FunctionType = iota // Top-level script - FunctionTypeFunction // Regular function - FunctionTypeMethod // Struct method -) - -// CompileError represents a compilation error with location information -type CompileError struct { - Message string - Line int - Column int -} - -func (ce CompileError) Error() string { - return fmt.Sprintf("Compile error at line %d, column %d: %s", ce.Line, ce.Column, ce.Message) -} - -// NewCompilerState creates a new compiler state for compilation -func NewCompilerState(functionType FunctionType) *CompilerState { - return &CompilerState{ - Chunk: NewChunk(), - Constants: make(map[string]int), - Functions: make([]Function, 0), - Structs: make([]Struct, 0), - Locals: make([]Local, 0, MaxLocals), - Upvalues: make([]UpvalueRef, 0, MaxUpvalues), - ScopeDepth: 0, - FunctionType: functionType, - BreakJumps: make([]int, 0), - ContinueJumps: make([]int, 0), - LoopStart: -1, - LoopDepth: 0, +// NewCompiler creates a new compiler instance +func NewCompiler() *Compiler { + return &Compiler{ + current: NewCompilerState(FunctionTypeScript), + errors: make([]CompileError, 0), } } -// NewChunk creates a new bytecode chunk -func NewChunk() *Chunk { - return &Chunk{ - Code: make([]uint8, 0, 256), - Constants: make([]Value, 0, 64), - Lines: make([]int, 0, 256), - Functions: make([]Function, 0), - Structs: make([]Struct, 0), +// Compile compiles a program AST to bytecode +func (c *Compiler) Compile(program *parser.Program) (*Chunk, []CompileError) { + for _, stmt := range program.Statements { + c.compileStatement(stmt) } + + c.current.EmitInstruction(OpReturnNil) + + if len(c.errors) > 0 { + return nil, c.errors + } + + return c.current.Chunk, nil } -// Scope management methods -func (cs *CompilerState) BeginScope() { - cs.ScopeDepth++ -} - -func (cs *CompilerState) EndScope() { - cs.ScopeDepth-- - - // Remove locals that go out of scope - for len(cs.Locals) > 0 && cs.Locals[len(cs.Locals)-1].Depth > cs.ScopeDepth { - local := cs.Locals[len(cs.Locals)-1] - if local.IsCaptured { - // Emit close upvalue instruction - cs.EmitByte(uint8(OpCloseUpvalue)) - } else { - // Emit pop instruction - cs.EmitByte(uint8(OpPop)) - } - cs.Locals = cs.Locals[:len(cs.Locals)-1] - } -} - -// Local variable management -func (cs *CompilerState) AddLocal(name string) error { - if len(cs.Locals) >= MaxLocals { - return CompileError{ - Message: "too many local variables in function", - } - } - - local := Local{ - Name: name, - Depth: -1, // Mark as uninitialized - IsCaptured: false, - Slot: len(cs.Locals), - } - - cs.Locals = append(cs.Locals, local) - return nil -} - -func (cs *CompilerState) MarkInitialized() { - if len(cs.Locals) > 0 { - cs.Locals[len(cs.Locals)-1].Depth = cs.ScopeDepth - } -} - -func (cs *CompilerState) ResolveLocal(name string) int { - for i := len(cs.Locals) - 1; i >= 0; i-- { - local := &cs.Locals[i] - if local.Name == name { - if local.Depth == -1 { - // Variable used before initialization - return -2 - } - return i - } - } - return -1 -} - -// Upvalue management -func (cs *CompilerState) AddUpvalue(index uint8, isLocal bool) int { - upvalueCount := len(cs.Upvalues) - - // Check if upvalue already exists - for i := range upvalueCount { - upvalue := &cs.Upvalues[i] - if upvalue.Index == index && upvalue.IsLocal == isLocal { - return i - } - } - - if upvalueCount >= MaxUpvalues { - return -1 // Too many upvalues - } - - cs.Upvalues = append(cs.Upvalues, UpvalueRef{ - Index: index, - IsLocal: isLocal, - }) - - return upvalueCount -} - -// Constant pool management -func (cs *CompilerState) AddConstant(value Value) int { - // Check if constant already exists to avoid duplicates - key := cs.valueKey(value) - if index, exists := cs.Constants[key]; exists { - return index - } - - if len(cs.Chunk.Constants) >= MaxConstants { - return -1 // Too many constants - } - - index := len(cs.Chunk.Constants) - cs.Chunk.Constants = append(cs.Chunk.Constants, value) - cs.Constants[key] = index - return index -} - -// Generate unique key for value in constant pool -func (cs *CompilerState) valueKey(value Value) string { - switch value.Type { - case ValueNil: - return "nil" - case ValueBool: - if value.Data.(bool) { - return "bool:true" - } - return "bool:false" - case ValueNumber: - return fmt.Sprintf("number:%g", value.Data.(float64)) - case ValueString: - return fmt.Sprintf("string:%s", value.Data.(string)) +// Statement compilation +func (c *Compiler) compileStatement(stmt parser.Statement) { + switch s := stmt.(type) { + case *parser.StructStatement: + c.compileStructStatement(s) + case *parser.MethodDefinition: + c.compileMethodDefinition(s) + case *parser.Assignment: + c.compileAssignment(s) + case *parser.ExpressionStatement: + c.compileExpression(s.Expression) + c.current.EmitInstruction(OpPop) // Discard result + case *parser.EchoStatement: + c.compileExpression(s.Value) + c.current.EmitInstruction(OpEcho) + case *parser.IfStatement: + c.compileIfStatement(s) + case *parser.WhileStatement: + c.compileWhileStatement(s) + case *parser.ForStatement: + c.compileForStatement(s) + case *parser.ForInStatement: + c.compileForInStatement(s) + case *parser.ReturnStatement: + c.compileReturnStatement(s) + case *parser.ExitStatement: + c.compileExitStatement(s) + case *parser.BreakStatement: + c.current.EmitBreak() default: - // For complex types, use memory address as fallback - return fmt.Sprintf("%T:%p", value.Data, value.Data) + c.addError(fmt.Sprintf("unknown statement type: %T", stmt)) } } -// Bytecode emission methods -func (cs *CompilerState) EmitByte(byte uint8) { - cs.Chunk.Code = append(cs.Chunk.Code, byte) - cs.Chunk.Lines = append(cs.Chunk.Lines, 0) // Line will be set by caller -} - -func (cs *CompilerState) EmitBytes(bytes ...uint8) { - for _, b := range bytes { - cs.EmitByte(b) +// Expression compilation +func (c *Compiler) compileExpression(expr parser.Expression) { + switch e := expr.(type) { + case *parser.Identifier: + c.compileIdentifier(e) + case *parser.NumberLiteral: + c.compileNumberLiteral(e) + case *parser.StringLiteral: + c.compileStringLiteral(e) + case *parser.BooleanLiteral: + c.compileBooleanLiteral(e) + case *parser.NilLiteral: + c.compileNilLiteral(e) + case *parser.TableLiteral: + c.compileTableLiteral(e) + case *parser.StructConstructor: + c.compileStructConstructor(e) + case *parser.FunctionLiteral: + c.compileFunctionLiteral(e) + case *parser.CallExpression: + c.compileCallExpression(e) + case *parser.PrefixExpression: + c.compilePrefixExpression(e) + case *parser.InfixExpression: + c.compileInfixExpression(e) + case *parser.IndexExpression: + c.compileIndexExpression(e) + case *parser.DotExpression: + c.compileDotExpression(e) + case *parser.Assignment: + c.compileAssignmentExpression(e) + default: + c.addError(fmt.Sprintf("unknown expression type: %T", expr)) } } -func (cs *CompilerState) EmitInstruction(op Opcode, operands ...uint16) { - bytes := EncodeInstruction(op, operands...) - cs.EmitBytes(bytes...) +// Literal compilation +func (c *Compiler) compileNumberLiteral(node *parser.NumberLiteral) { + value := Value{Type: ValueNumber, Data: node.Value} + index := c.current.AddConstant(value) + if index == -1 { + c.addError("too many constants") + return + } + c.current.EmitInstruction(OpLoadConst, uint16(index)) } -func (cs *CompilerState) EmitJump(op Opcode) int { - cs.EmitByte(uint8(op)) - cs.EmitByte(0xFF) // Placeholder - cs.EmitByte(0xFF) // Placeholder - return len(cs.Chunk.Code) - 2 // Return offset of jump address +func (c *Compiler) compileStringLiteral(node *parser.StringLiteral) { + value := Value{Type: ValueString, Data: node.Value} + index := c.current.AddConstant(value) + if index == -1 { + c.addError("too many constants") + return + } + c.current.EmitInstruction(OpLoadConst, uint16(index)) } -func (cs *CompilerState) PatchJump(offset int) { - // Calculate jump distance - jump := len(cs.Chunk.Code) - offset - 2 +func (c *Compiler) compileBooleanLiteral(node *parser.BooleanLiteral) { + value := Value{Type: ValueBool, Data: node.Value} + index := c.current.AddConstant(value) + if index == -1 { + c.addError("too many constants") + return + } + c.current.EmitInstruction(OpLoadConst, uint16(index)) +} - if jump > 65535 { - // Jump too large - would need long jump instruction +func (c *Compiler) compileNilLiteral(node *parser.NilLiteral) { + value := Value{Type: ValueNil, Data: nil} + index := c.current.AddConstant(value) + if index == -1 { + c.addError("too many constants") + return + } + c.current.EmitInstruction(OpLoadConst, uint16(index)) +} + +// Identifier compilation +func (c *Compiler) compileIdentifier(node *parser.Identifier) { + // Try local variables first + slot := c.current.ResolveLocal(node.Value) + if slot != -1 { + if slot == -2 { + c.addError("can't read local variable in its own initializer") + return + } + c.current.EmitInstruction(OpLoadLocal, uint16(slot)) return } - cs.Chunk.Code[offset] = uint8(jump & 0xFF) - cs.Chunk.Code[offset+1] = uint8((jump >> 8) & 0xFF) -} - -// Loop management -func (cs *CompilerState) EnterLoop() { - cs.LoopStart = len(cs.Chunk.Code) - cs.LoopDepth++ -} - -func (cs *CompilerState) ExitLoop() { - cs.LoopDepth-- - if cs.LoopDepth == 0 { - cs.LoopStart = -1 + // Try upvalues + upvalue := c.resolveUpvalue(node.Value) + if upvalue != -1 { + c.current.EmitInstruction(OpGetUpvalue, uint16(upvalue)) + return } - // Patch break jumps - for _, jumpOffset := range cs.BreakJumps { - cs.PatchJump(jumpOffset) + // Must be global + value := Value{Type: ValueString, Data: node.Value} + index := c.current.AddConstant(value) + if index == -1 { + c.addError("too many constants") + return } - cs.BreakJumps = cs.BreakJumps[:0] + c.current.EmitInstruction(OpLoadGlobal, uint16(index)) +} - // Patch continue jumps - for _, jumpOffset := range cs.ContinueJumps { - jump := cs.LoopStart - jumpOffset - 2 - if jump < 65535 { - cs.Chunk.Code[jumpOffset] = uint8(jump & 0xFF) - cs.Chunk.Code[jumpOffset+1] = uint8((jump >> 8) & 0xFF) +// Assignment compilation +func (c *Compiler) compileAssignment(node *parser.Assignment) { + c.compileExpression(node.Value) + + switch target := node.Target.(type) { + case *parser.Identifier: + if node.IsDeclaration { + // Check if we're at global scope + if c.current.FunctionType == FunctionTypeScript && c.current.ScopeDepth == 0 { + // Global variable declaration - treat as global assignment + value := Value{Type: ValueString, Data: target.Value} + index := c.current.AddConstant(value) + if index == -1 { + c.addError("too many constants") + return + } + c.current.EmitInstruction(OpStoreGlobal, uint16(index)) + } else { + // Local variable declaration + if err := c.current.AddLocal(target.Value); err != nil { + c.addError(err.Error()) + return + } + c.current.MarkInitialized() + } + } else { + // Assignment to existing variable + slot := c.current.ResolveLocal(target.Value) + if slot != -1 { + c.current.EmitInstruction(OpStoreLocal, uint16(slot)) + } else { + upvalue := c.resolveUpvalue(target.Value) + if upvalue != -1 { + c.current.EmitInstruction(OpSetUpvalue, uint16(upvalue)) + } else { + // Global assignment + value := Value{Type: ValueString, Data: target.Value} + index := c.current.AddConstant(value) + if index == -1 { + c.addError("too many constants") + return + } + c.current.EmitInstruction(OpStoreGlobal, uint16(index)) + } + } + } + case *parser.DotExpression: + // table.field = value + c.compileExpression(target.Left) + value := Value{Type: ValueString, Data: target.Key} + index := c.current.AddConstant(value) + if index == -1 { + c.addError("too many constants") + return + } + c.current.EmitInstruction(OpSetField, uint16(index)) + case *parser.IndexExpression: + // table[key] = value + c.compileExpression(target.Left) + c.compileExpression(target.Index) + c.current.EmitInstruction(OpSetIndex) + default: + c.addError("invalid assignment target") + } +} + +func (c *Compiler) compileAssignmentExpression(node *parser.Assignment) { + c.compileAssignment(node) + // Assignment expressions leave the assigned value on stack +} + +// Operator compilation +func (c *Compiler) compilePrefixExpression(node *parser.PrefixExpression) { + c.compileExpression(node.Right) + + switch node.Operator { + case "-": + c.current.EmitInstruction(OpNeg) + case "not": + c.current.EmitInstruction(OpNot) + default: + c.addError(fmt.Sprintf("unknown prefix operator: %s", node.Operator)) + } +} + +func (c *Compiler) compileInfixExpression(node *parser.InfixExpression) { + // Handle short-circuit operators specially + if node.Operator == "and" { + c.compileExpression(node.Left) + jump := c.current.EmitJump(OpJumpIfFalse) + c.current.EmitInstruction(OpPop) + c.compileExpression(node.Right) + c.current.PatchJump(jump) + return + } + + if node.Operator == "or" { + c.compileExpression(node.Left) + elseJump := c.current.EmitJump(OpJumpIfFalse) + endJump := c.current.EmitJump(OpJump) + c.current.PatchJump(elseJump) + c.current.EmitInstruction(OpPop) + c.compileExpression(node.Right) + c.current.PatchJump(endJump) + return + } + + // Regular binary operators + c.compileExpression(node.Left) + c.compileExpression(node.Right) + + switch node.Operator { + case "+": + c.current.EmitInstruction(OpAdd) + case "-": + c.current.EmitInstruction(OpSub) + case "*": + c.current.EmitInstruction(OpMul) + case "/": + c.current.EmitInstruction(OpDiv) + case "==": + c.current.EmitInstruction(OpEq) + case "!=": + c.current.EmitInstruction(OpNeq) + case "<": + c.current.EmitInstruction(OpLt) + case "<=": + c.current.EmitInstruction(OpLte) + case ">": + c.current.EmitInstruction(OpGt) + case ">=": + c.current.EmitInstruction(OpGte) + default: + c.addError(fmt.Sprintf("unknown infix operator: %s", node.Operator)) + } +} + +// Control flow compilation +func (c *Compiler) compileIfStatement(node *parser.IfStatement) { + c.compileExpression(node.Condition) + + // Jump over then branch if condition is false + thenJump := c.current.EmitJump(OpJumpIfFalse) + c.current.EmitInstruction(OpPop) + + // Compile then branch + c.current.BeginScope() + for _, stmt := range node.Body { + c.compileStatement(stmt) + } + c.current.EndScope() + + // Jump over else branches + elseJump := c.current.EmitJump(OpJump) + c.current.PatchJump(thenJump) + c.current.EmitInstruction(OpPop) + + // Compile elseif branches + var elseifJumps []int + for _, elseif := range node.ElseIfs { + c.compileExpression(elseif.Condition) + nextJump := c.current.EmitJump(OpJumpIfFalse) + c.current.EmitInstruction(OpPop) + + c.current.BeginScope() + for _, stmt := range elseif.Body { + c.compileStatement(stmt) + } + c.current.EndScope() + + elseifJumps = append(elseifJumps, c.current.EmitJump(OpJump)) + c.current.PatchJump(nextJump) + c.current.EmitInstruction(OpPop) + } + + // Compile else branch + if len(node.Else) > 0 { + c.current.BeginScope() + for _, stmt := range node.Else { + c.compileStatement(stmt) + } + c.current.EndScope() + } + + // Patch all jumps to end + c.current.PatchJump(elseJump) + for _, jump := range elseifJumps { + c.current.PatchJump(jump) + } +} + +func (c *Compiler) compileWhileStatement(node *parser.WhileStatement) { + c.current.EnterLoop() + + c.compileExpression(node.Condition) + exitJump := c.current.EmitJump(OpJumpIfFalse) + c.current.EmitInstruction(OpPop) + + c.current.BeginScope() + for _, stmt := range node.Body { + c.compileStatement(stmt) + } + c.current.EndScope() + + // Jump back to condition + jump := len(c.current.Chunk.Code) - c.current.LoopStart + 2 + c.current.EmitInstruction(OpJump, uint16(jump)) + + c.current.PatchJump(exitJump) + c.current.EmitInstruction(OpPop) + + c.current.ExitLoop() +} + +// Table operations +func (c *Compiler) compileTableLiteral(node *parser.TableLiteral) { + c.current.EmitInstruction(OpNewTable) + + for _, pair := range node.Pairs { + if pair.Key == nil { + // Array-style element + c.compileExpression(pair.Value) + c.current.EmitInstruction(OpTableInsert) + } else { + // Key-value pair + c.current.EmitInstruction(OpDup) // Duplicate table reference + c.compileExpression(pair.Key) + c.compileExpression(pair.Value) + c.current.EmitInstruction(OpSetIndex) } } - cs.ContinueJumps = cs.ContinueJumps[:0] } -func (cs *CompilerState) EmitBreak() { - jumpOffset := cs.EmitJump(OpJump) - cs.BreakJumps = append(cs.BreakJumps, jumpOffset) +func (c *Compiler) compileDotExpression(node *parser.DotExpression) { + c.compileExpression(node.Left) + value := Value{Type: ValueString, Data: node.Key} + index := c.current.AddConstant(value) + if index == -1 { + c.addError("too many constants") + return + } + c.current.EmitInstruction(OpGetField, uint16(index)) } -func (cs *CompilerState) EmitContinue() { - if cs.LoopStart != -1 { - jumpOffset := cs.EmitJump(OpJump) - cs.ContinueJumps = append(cs.ContinueJumps, jumpOffset) +func (c *Compiler) compileIndexExpression(node *parser.IndexExpression) { + c.compileExpression(node.Left) + c.compileExpression(node.Index) + c.current.EmitInstruction(OpGetIndex) +} + +// Function compilation +func (c *Compiler) compileCallExpression(node *parser.CallExpression) { + c.compileExpression(node.Function) + + // Compile arguments + for _, arg := range node.Arguments { + c.compileExpression(arg) + } + + c.current.EmitInstruction(OpCall, uint16(len(node.Arguments))) +} + +func (c *Compiler) compileReturnStatement(node *parser.ReturnStatement) { + if node.Value != nil { + c.compileExpression(node.Value) + c.current.EmitInstruction(OpReturn) + } else { + c.current.EmitInstruction(OpReturnNil) } } + +func (c *Compiler) compileExitStatement(node *parser.ExitStatement) { + if node.Value != nil { + c.compileExpression(node.Value) + } else { + // Default exit code 0 + value := Value{Type: ValueNumber, Data: float64(0)} + index := c.current.AddConstant(value) + if index == -1 { + c.addError("too many constants") + return + } + c.current.EmitInstruction(OpLoadConst, uint16(index)) + } + c.current.EmitInstruction(OpExit) +} + +// Placeholder implementations for complex features +func (c *Compiler) compileStructStatement(node *parser.StructStatement) { + // TODO: Implement struct compilation + c.addError("struct compilation not yet implemented") +} + +func (c *Compiler) compileMethodDefinition(node *parser.MethodDefinition) { + // TODO: Implement method compilation + c.addError("method compilation not yet implemented") +} + +func (c *Compiler) compileStructConstructor(node *parser.StructConstructor) { + // TODO: Implement struct constructor compilation + c.addError("struct constructor compilation not yet implemented") +} + +func (c *Compiler) compileFunctionLiteral(node *parser.FunctionLiteral) { + // TODO: Implement function literal compilation + c.addError("function literal compilation not yet implemented") +} + +func (c *Compiler) compileForStatement(node *parser.ForStatement) { + // TODO: Implement numeric for loop compilation + c.addError("for statement compilation not yet implemented") +} + +func (c *Compiler) compileForInStatement(node *parser.ForInStatement) { + // TODO: Implement for-in loop compilation + c.addError("for-in statement compilation not yet implemented") +} + +// Helper methods +func (c *Compiler) resolveUpvalue(name string) int { + if c.enclosing == nil { + return -1 + } + + local := c.enclosing.ResolveLocal(name) + if local != -1 { + c.enclosing.Locals[local].IsCaptured = true + return c.current.AddUpvalue(uint8(local), true) + } + + upvalue := c.resolveUpvalueInEnclosing(name) + if upvalue != -1 { + return c.current.AddUpvalue(uint8(upvalue), false) + } + + return -1 +} + +func (c *Compiler) resolveUpvalueInEnclosing(name string) int { + if c.enclosing == nil { + return -1 + } + + // This would recursively check enclosing scopes + // Simplified for now + return -1 +} + +func (c *Compiler) addError(message string) { + c.errors = append(c.errors, CompileError{ + Message: message, + Line: 0, // TODO: Add line tracking + Column: 0, // TODO: Add column tracking + }) +} + +// Error reporting +func (c *Compiler) Errors() []CompileError { return c.errors } +func (c *Compiler) HasErrors() bool { return len(c.errors) > 0 } diff --git a/compiler/state.go b/compiler/state.go new file mode 100644 index 0000000..5b4fb37 --- /dev/null +++ b/compiler/state.go @@ -0,0 +1,290 @@ +package compiler + +import "fmt" + +// Constants for compiler limits +const ( + MaxLocals = 256 // Maximum local variables per function + MaxUpvalues = 256 // Maximum upvalues per function + MaxConstants = 65536 // Maximum constants per chunk +) + +// CompilerState holds state during compilation +type CompilerState struct { + Chunk *Chunk // Current chunk being compiled + Constants map[string]int // Constant pool index mapping + Functions []Function // Compiled functions + Structs []Struct // Compiled structs + Locals []Local // Local variable stack + Upvalues []UpvalueRef // Upvalue definitions + ScopeDepth int // Current scope nesting level + FunctionType FunctionType // Type of function being compiled + BreakJumps []int // Break jump addresses for loops + ContinueJumps []int // Continue jump addresses for loops + LoopStart int // Start of current loop for continue + LoopDepth int // Current loop nesting depth +} + +// Local represents a local variable during compilation +type Local struct { + Name string // Variable name + Depth int // Scope depth where declared + IsCaptured bool // Whether variable is captured by closure + Slot int // Stack slot index +} + +// UpvalueRef represents an upvalue reference during compilation +type UpvalueRef struct { + Index uint8 // Index in enclosing function's locals or upvalues + IsLocal bool // True if captures local, false if captures upvalue +} + +// FunctionType represents the type of function being compiled +type FunctionType uint8 + +const ( + FunctionTypeScript FunctionType = iota // Top-level script + FunctionTypeFunction // Regular function + FunctionTypeMethod // Struct method +) + +// CompileError represents a compilation error with location information +type CompileError struct { + Message string + Line int + Column int +} + +func (ce CompileError) Error() string { + return fmt.Sprintf("Compile error at line %d, column %d: %s", ce.Line, ce.Column, ce.Message) +} + +// NewCompilerState creates a new compiler state for compilation +func NewCompilerState(functionType FunctionType) *CompilerState { + return &CompilerState{ + Chunk: NewChunk(), + Constants: make(map[string]int), + Functions: make([]Function, 0), + Structs: make([]Struct, 0), + Locals: make([]Local, 0, MaxLocals), + Upvalues: make([]UpvalueRef, 0, MaxUpvalues), + ScopeDepth: 0, + FunctionType: functionType, + BreakJumps: make([]int, 0), + ContinueJumps: make([]int, 0), + LoopStart: -1, + LoopDepth: 0, + } +} + +// NewChunk creates a new bytecode chunk +func NewChunk() *Chunk { + return &Chunk{ + Code: make([]uint8, 0, 256), + Constants: make([]Value, 0, 64), + Lines: make([]int, 0, 256), + Functions: make([]Function, 0), + Structs: make([]Struct, 0), + } +} + +// Scope management methods +func (cs *CompilerState) BeginScope() { + cs.ScopeDepth++ +} + +func (cs *CompilerState) EndScope() { + cs.ScopeDepth-- + + // Remove locals that go out of scope + for len(cs.Locals) > 0 && cs.Locals[len(cs.Locals)-1].Depth > cs.ScopeDepth { + local := cs.Locals[len(cs.Locals)-1] + if local.IsCaptured { + // Emit close upvalue instruction + cs.EmitByte(uint8(OpCloseUpvalue)) + } else { + // Emit pop instruction + cs.EmitByte(uint8(OpPop)) + } + cs.Locals = cs.Locals[:len(cs.Locals)-1] + } +} + +// Local variable management +func (cs *CompilerState) AddLocal(name string) error { + if len(cs.Locals) >= MaxLocals { + return CompileError{ + Message: "too many local variables in function", + } + } + + local := Local{ + Name: name, + Depth: -1, // Mark as uninitialized + IsCaptured: false, + Slot: len(cs.Locals), + } + + cs.Locals = append(cs.Locals, local) + return nil +} + +func (cs *CompilerState) MarkInitialized() { + if len(cs.Locals) > 0 { + cs.Locals[len(cs.Locals)-1].Depth = cs.ScopeDepth + } +} + +func (cs *CompilerState) ResolveLocal(name string) int { + for i := len(cs.Locals) - 1; i >= 0; i-- { + local := &cs.Locals[i] + if local.Name == name { + if local.Depth == -1 { + // Variable used before initialization + return -2 + } + return i + } + } + return -1 +} + +// Upvalue management +func (cs *CompilerState) AddUpvalue(index uint8, isLocal bool) int { + upvalueCount := len(cs.Upvalues) + + // Check if upvalue already exists + for i := range upvalueCount { + upvalue := &cs.Upvalues[i] + if upvalue.Index == index && upvalue.IsLocal == isLocal { + return i + } + } + + if upvalueCount >= MaxUpvalues { + return -1 // Too many upvalues + } + + cs.Upvalues = append(cs.Upvalues, UpvalueRef{ + Index: index, + IsLocal: isLocal, + }) + + return upvalueCount +} + +// Constant pool management +func (cs *CompilerState) AddConstant(value Value) int { + // Check if constant already exists to avoid duplicates + key := cs.valueKey(value) + if index, exists := cs.Constants[key]; exists { + return index + } + + if len(cs.Chunk.Constants) >= MaxConstants { + return -1 // Too many constants + } + + index := len(cs.Chunk.Constants) + cs.Chunk.Constants = append(cs.Chunk.Constants, value) + cs.Constants[key] = index + return index +} + +// Generate unique key for value in constant pool +func (cs *CompilerState) valueKey(value Value) string { + switch value.Type { + case ValueNil: + return "nil" + case ValueBool: + if value.Data.(bool) { + return "bool:true" + } + return "bool:false" + case ValueNumber: + return fmt.Sprintf("number:%g", value.Data.(float64)) + case ValueString: + return fmt.Sprintf("string:%s", value.Data.(string)) + default: + // For complex types, use memory address as fallback + return fmt.Sprintf("%T:%p", value.Data, value.Data) + } +} + +// Bytecode emission methods +func (cs *CompilerState) EmitByte(byte uint8) { + cs.Chunk.Code = append(cs.Chunk.Code, byte) + cs.Chunk.Lines = append(cs.Chunk.Lines, 0) // Line will be set by caller +} + +func (cs *CompilerState) EmitBytes(bytes ...uint8) { + for _, b := range bytes { + cs.EmitByte(b) + } +} + +func (cs *CompilerState) EmitInstruction(op Opcode, operands ...uint16) { + bytes := EncodeInstruction(op, operands...) + cs.EmitBytes(bytes...) +} + +func (cs *CompilerState) EmitJump(op Opcode) int { + cs.EmitByte(uint8(op)) + cs.EmitByte(0xFF) // Placeholder + cs.EmitByte(0xFF) // Placeholder + return len(cs.Chunk.Code) - 2 // Return offset of jump address +} + +func (cs *CompilerState) PatchJump(offset int) { + // Calculate jump distance + jump := len(cs.Chunk.Code) - offset - 2 + + if jump > 65535 { + // Jump too large - would need long jump instruction + return + } + + cs.Chunk.Code[offset] = uint8(jump & 0xFF) + cs.Chunk.Code[offset+1] = uint8((jump >> 8) & 0xFF) +} + +// Loop management +func (cs *CompilerState) EnterLoop() { + cs.LoopStart = len(cs.Chunk.Code) + cs.LoopDepth++ +} + +func (cs *CompilerState) ExitLoop() { + cs.LoopDepth-- + if cs.LoopDepth == 0 { + cs.LoopStart = -1 + } + + // Patch break jumps + for _, jumpOffset := range cs.BreakJumps { + cs.PatchJump(jumpOffset) + } + cs.BreakJumps = cs.BreakJumps[:0] + + // Patch continue jumps + for _, jumpOffset := range cs.ContinueJumps { + jump := cs.LoopStart - jumpOffset - 2 + if jump < 65535 { + cs.Chunk.Code[jumpOffset] = uint8(jump & 0xFF) + cs.Chunk.Code[jumpOffset+1] = uint8((jump >> 8) & 0xFF) + } + } + cs.ContinueJumps = cs.ContinueJumps[:0] +} + +func (cs *CompilerState) EmitBreak() { + jumpOffset := cs.EmitJump(OpJump) + cs.BreakJumps = append(cs.BreakJumps, jumpOffset) +} + +func (cs *CompilerState) EmitContinue() { + if cs.LoopStart != -1 { + jumpOffset := cs.EmitJump(OpJump) + cs.ContinueJumps = append(cs.ContinueJumps, jumpOffset) + } +} diff --git a/compiler/tests/compiler_test.go b/compiler/tests/compiler_test.go new file mode 100644 index 0000000..2299aa3 --- /dev/null +++ b/compiler/tests/compiler_test.go @@ -0,0 +1,364 @@ +package compiler_test + +import ( + "testing" + + "git.sharkk.net/Sharkk/Mako/compiler" + "git.sharkk.net/Sharkk/Mako/parser" +) + +// Helper function to compile source code and return chunk +func compileSource(t *testing.T, source string) *compiler.Chunk { + lexer := parser.NewLexer(source) + p := parser.NewParser(lexer) + program := p.ParseProgram() + + if p.HasErrors() { + t.Fatalf("Parser errors: %v", p.ErrorStrings()) + } + + comp := compiler.NewCompiler() + chunk, errors := comp.Compile(program) + + if len(errors) > 0 { + t.Fatalf("Compiler errors: %v", errors) + } + + return chunk +} + +// Helper to check instruction at position +func checkInstruction(t *testing.T, chunk *compiler.Chunk, pos int, expected compiler.Opcode, operands ...uint16) { + if pos >= len(chunk.Code) { + t.Fatalf("Position %d out of bounds (code length: %d)", pos, len(chunk.Code)) + } + + op, actualOperands, _ := compiler.DecodeInstruction(chunk.Code, pos) + + if op != expected { + t.Errorf("Expected opcode %v at position %d, got %v", expected, pos, op) + } + + if len(actualOperands) != len(operands) { + t.Errorf("Expected %d operands, got %d", len(operands), len(actualOperands)) + return + } + + for i, expected := range operands { + if actualOperands[i] != expected { + t.Errorf("Expected operand %d to be %d, got %d", i, expected, actualOperands[i]) + } + } +} + +// Test literal compilation +func TestNumberLiteral(t *testing.T) { + chunk := compileSource(t, "echo 42") + + // Should have one constant (42) and load it + if len(chunk.Constants) != 1 { + t.Fatalf("Expected 1 constant, got %d", len(chunk.Constants)) + } + + if chunk.Constants[0].Type != compiler.ValueNumber { + t.Errorf("Expected number constant, got %v", chunk.Constants[0].Type) + } + + if chunk.Constants[0].Data.(float64) != 42.0 { + t.Errorf("Expected constant value 42, got %v", chunk.Constants[0].Data) + } + + // Check bytecode: OpLoadConst 0, OpEcho, OpReturnNil + checkInstruction(t, chunk, 0, compiler.OpLoadConst, 0) + checkInstruction(t, chunk, 3, compiler.OpEcho) + checkInstruction(t, chunk, 4, compiler.OpReturnNil) +} + +func TestStringLiteral(t *testing.T) { + chunk := compileSource(t, `echo "hello"`) + + if len(chunk.Constants) != 1 { + t.Fatalf("Expected 1 constant, got %d", len(chunk.Constants)) + } + + if chunk.Constants[0].Type != compiler.ValueString { + t.Errorf("Expected string constant, got %v", chunk.Constants[0].Type) + } + + if chunk.Constants[0].Data.(string) != "hello" { + t.Errorf("Expected constant value 'hello', got %v", chunk.Constants[0].Data) + } + + checkInstruction(t, chunk, 0, compiler.OpLoadConst, 0) +} + +func TestBooleanLiterals(t *testing.T) { + chunk := compileSource(t, "echo true") + + if chunk.Constants[0].Type != compiler.ValueBool { + t.Errorf("Expected bool constant, got %v", chunk.Constants[0].Type) + } + + if chunk.Constants[0].Data.(bool) != true { + t.Errorf("Expected true, got %v", chunk.Constants[0].Data) + } +} + +func TestNilLiteral(t *testing.T) { + chunk := compileSource(t, "echo nil") + + if chunk.Constants[0].Type != compiler.ValueNil { + t.Errorf("Expected nil constant, got %v", chunk.Constants[0].Type) + } +} + +// Test arithmetic operations +func TestArithmetic(t *testing.T) { + tests := []struct { + source string + expected compiler.Opcode + }{ + {"echo 1 + 2", compiler.OpAdd}, + {"echo 5 - 3", compiler.OpSub}, + {"echo 4 * 6", compiler.OpMul}, + {"echo 8 / 2", compiler.OpDiv}, + } + + for _, test := range tests { + chunk := compileSource(t, test.source) + + // Should have: LoadConst 0, LoadConst 1, OpArithmetic, OpEcho, OpReturnNil + checkInstruction(t, chunk, 0, compiler.OpLoadConst, 0) + checkInstruction(t, chunk, 3, compiler.OpLoadConst, 1) + checkInstruction(t, chunk, 6, test.expected) + checkInstruction(t, chunk, 7, compiler.OpEcho) + } +} + +// Test comparison operations +func TestComparison(t *testing.T) { + tests := []struct { + source string + expected compiler.Opcode + }{ + {"echo 1 == 2", compiler.OpEq}, + {"echo 1 != 2", compiler.OpNeq}, + {"echo 1 < 2", compiler.OpLt}, + {"echo 1 <= 2", compiler.OpLte}, + {"echo 1 > 2", compiler.OpGt}, + {"echo 1 >= 2", compiler.OpGte}, + } + + for _, test := range tests { + chunk := compileSource(t, test.source) + checkInstruction(t, chunk, 6, test.expected) + } +} + +// Test prefix operations +func TestPrefixOperations(t *testing.T) { + tests := []struct { + source string + expected compiler.Opcode + }{ + {"echo -42", compiler.OpNeg}, + {"echo not true", compiler.OpNot}, + } + + for _, test := range tests { + chunk := compileSource(t, test.source) + checkInstruction(t, chunk, 3, test.expected) + } +} + +// Test variable assignment +func TestLocalAssignment(t *testing.T) { + // Test local assignment within a function scope + chunk := compileSource(t, ` + fn test() + x: number = 42 + end + `) + + // This tests function compilation which is not yet implemented + // For now, just check that it doesn't crash + if chunk == nil { + t.Skip("Function compilation not yet implemented") + } +} + +func TestGlobalAssignment(t *testing.T) { + chunk := compileSource(t, "x = 42") + + // Should have: LoadConst 0, StoreGlobal 1, OpReturnNil + // Constants: [42, "x"] + if len(chunk.Constants) != 2 { + t.Fatalf("Expected 2 constants, got %d", len(chunk.Constants)) + } + + // Check that we have the number and variable name + if chunk.Constants[0].Data.(float64) != 42.0 { + t.Errorf("Expected first constant to be 42, got %v", chunk.Constants[0].Data) + } + if chunk.Constants[1].Data.(string) != "x" { + t.Errorf("Expected second constant to be 'x', got %v", chunk.Constants[1].Data) + } + + checkInstruction(t, chunk, 0, compiler.OpLoadConst, 0) // Load 42 + checkInstruction(t, chunk, 3, compiler.OpStoreGlobal, 1) // Store to "x" +} + +// Test echo statement +func TestEchoStatement(t *testing.T) { + chunk := compileSource(t, "echo 42") + + // Should have: LoadConst 0, OpEcho, OpReturnNil + checkInstruction(t, chunk, 0, compiler.OpLoadConst, 0) + checkInstruction(t, chunk, 3, compiler.OpEcho) + checkInstruction(t, chunk, 4, compiler.OpReturnNil) +} + +// Test if statement +func TestIfStatement(t *testing.T) { + chunk := compileSource(t, ` + if true then + echo 1 + end + `) + + // Should start with: LoadConst, JumpIfFalse (with offset), Pop + checkInstruction(t, chunk, 0, compiler.OpLoadConst, 0) // Load true + + // JumpIfFalse has 1 operand (the jump offset), but we don't need to check the exact value + op, operands, _ := compiler.DecodeInstruction(chunk.Code, 3) + if op != compiler.OpJumpIfFalse { + t.Errorf("Expected OpJumpIfFalse at position 3, got %v", op) + } + if len(operands) != 1 { + t.Errorf("Expected 1 operand for JumpIfFalse, got %d", len(operands)) + } + + checkInstruction(t, chunk, 6, compiler.OpPop) // Pop condition +} + +// Test while loop +func TestWhileLoop(t *testing.T) { + chunk := compileSource(t, ` + while true do + break + end + `) + + // Should have condition evaluation and loop structure + checkInstruction(t, chunk, 0, compiler.OpLoadConst, 0) // Load true + + // JumpIfFalse has 1 operand (the jump offset) + op, operands, _ := compiler.DecodeInstruction(chunk.Code, 3) + if op != compiler.OpJumpIfFalse { + t.Errorf("Expected OpJumpIfFalse at position 3, got %v", op) + } + if len(operands) != 1 { + t.Errorf("Expected 1 operand for JumpIfFalse, got %d", len(operands)) + } +} + +// Test table creation +func TestTableLiteral(t *testing.T) { + chunk := compileSource(t, "echo {1, 2, 3}") + + // Should start with OpNewTable + checkInstruction(t, chunk, 0, compiler.OpNewTable) +} + +// Test table with key-value pairs +func TestTableWithKeys(t *testing.T) { + chunk := compileSource(t, `echo {x = 1, y = 2}`) + + checkInstruction(t, chunk, 0, compiler.OpNewTable) + // Should have subsequent operations to set fields +} + +// Test function call +func TestFunctionCall(t *testing.T) { + chunk := compileSource(t, "print(42)") + + // Should have: LoadGlobal "print", LoadConst 42, Call 1 + // The exact positions depend on constant ordering + found := false + for i := 0; i < len(chunk.Code)-2; i++ { + op, operands, _ := compiler.DecodeInstruction(chunk.Code, i) + if op == compiler.OpCall && len(operands) > 0 && operands[0] == 1 { + found = true + break + } + } + if !found { + t.Error("Expected OpCall with 1 argument") + } +} + +// Test constant deduplication +func TestConstantDeduplication(t *testing.T) { + chunk := compileSource(t, "echo 42\necho 42\necho 42") + + // Should only have one constant despite multiple uses + if len(chunk.Constants) != 1 { + t.Errorf("Expected 1 constant (deduplicated), got %d", len(chunk.Constants)) + } +} + +// Test short-circuit evaluation +func TestShortCircuitAnd(t *testing.T) { + chunk := compileSource(t, "echo true and false") + + // Should have conditional jumping for short-circuit + found := false + for i := 0; i < len(chunk.Code); i++ { + op, _, _ := compiler.DecodeInstruction(chunk.Code, i) + if op == compiler.OpJumpIfFalse { + found = true + break + } + } + if !found { + t.Error("Expected JumpIfFalse for short-circuit and") + } +} + +func TestShortCircuitOr(t *testing.T) { + chunk := compileSource(t, "echo false or true") + + // Should have conditional jumping for short-circuit + foundFalseJump := false + foundJump := false + for i := 0; i < len(chunk.Code); i++ { + op, _, _ := compiler.DecodeInstruction(chunk.Code, i) + if op == compiler.OpJumpIfFalse { + foundFalseJump = true + } + if op == compiler.OpJump { + foundJump = true + } + } + if !foundFalseJump || !foundJump { + t.Error("Expected JumpIfFalse and Jump for short-circuit or") + } +} + +// Test complex expressions +func TestComplexExpression(t *testing.T) { + chunk := compileSource(t, "echo 1 + 2 * 3") + + // Should follow correct precedence: Load 1, Load 2, Load 3, Mul, Add + if len(chunk.Constants) != 3 { + t.Fatalf("Expected 3 constants, got %d", len(chunk.Constants)) + } + + // Verify constants + expected := []float64{1, 2, 3} + for i, exp := range expected { + if chunk.Constants[i].Data.(float64) != exp { + t.Errorf("Expected constant %d to be %v, got %v", i, exp, chunk.Constants[i].Data) + } + } +}