Mako/compiler/state.go
2025-06-11 21:50:55 -05:00

620 lines
15 KiB
Go

package compiler
import "fmt"
// Compiler limits.
//
// MaxLocals is enforced by AddLocal and MaxUpvalues by AddUpvalue.
// MaxConstants bounds the constant pool in AddConstant; it equals the number
// of distinct indices a 16-bit OpLoadConst operand can carry (EmitLoadConstant
// encodes the pool index as uint16).
const (
	MaxLocals    = 256   // Maximum local variables per function
	MaxUpvalues  = 256   // Maximum upvalues per function
	MaxConstants = 65536 // Maximum constants per chunk
)
// CompilerState holds the mutable state used while compiling one function
// body (or the top-level script): the chunk receiving bytecode, the constant
// deduplication table, local/upvalue bookkeeping, and the pending jump
// offsets of the loop currently being compiled.
type CompilerState struct {
	Chunk         *Chunk         // Chunk receiving emitted bytecode, per-byte line numbers and constants
	Constants     map[string]int // valueKey(value) -> index in Chunk.Constants; lets AddConstant deduplicate
	Functions     []Function     // Compiled functions
	Structs       []Struct       // Compiled structs
	Locals        []Local        // Declared locals, innermost last; EndScope pops those deeper than ScopeDepth
	Upvalues      []UpvalueRef   // Captured variables of this function, deduplicated by AddUpvalue
	ScopeDepth    int            // Current block nesting level (0 as created by NewCompilerState)
	FunctionType  FunctionType   // Kind of function being compiled
	BreakJumps    []int          // Operand offsets of pending break jumps, patched by ExitLoop
	ContinueJumps []int          // Operand offsets of pending continue jumps, patched by ExitLoop
	LoopStart     int            // Code offset where the current loop starts, or -1 when not in a loop
	LoopDepth     int            // Number of loops currently open (EnterLoop increments, ExitLoop decrements)
	parent        *CompilerState // Enclosing compiler state for nested functions (nil from NewCompilerState; presumably set by the caller)
	CurrentLine   int            // Source line recorded for every byte emitted (see SetLine / EmitByte)
}
// Local is a local variable declared in the function being compiled.
// AddLocal creates it with Depth -1 (declared but not yet initialized) and
// MarkInitialized sets the real depth; ResolveLocal reports a lookup of a
// still-uninitialized local as -2.
type Local struct {
	Name       string // Variable name
	Depth      int    // Scope depth where declared; -1 while uninitialized
	IsCaptured bool   // Set when a closure captures it; EndScope then emits OpCloseUpvalue instead of OpPop
	Slot       int    // Stack slot index (its position in Locals when added)
}
// UpvalueRef describes one variable captured by the function being compiled.
// AddUpvalue keeps at most one entry per (Index, IsLocal) pair.
type UpvalueRef struct {
	Index   uint8 // Index in the enclosing function's locals (IsLocal) or upvalues (!IsLocal)
	IsLocal bool  // True if it captures an enclosing local, false if it re-captures an enclosing upvalue
}
// FunctionType identifies what kind of code a CompilerState is compiling.
type FunctionType uint8

// FunctionType values.
const (
	FunctionTypeScript   FunctionType = iota // Top-level script
	FunctionTypeFunction                     // Regular function
	FunctionTypeMethod                       // Struct method
)
// CompileError is a compilation failure annotated with the source position
// it was detected at. Line and Column are zero when the reporter did not
// know them.
type CompileError struct {
	Message string
	Line    int
	Column  int
}

// Error implements the error interface, rendering the position followed by
// the message.
func (e CompileError) Error() string {
	const layout = "Compile error at line %d, column %d: %s"
	return fmt.Sprintf(layout, e.Line, e.Column, e.Message)
}
// NewCompilerState returns a fresh CompilerState for a function of the given
// kind. It has a new empty chunk and empty (non-nil) bookkeeping
// collections. Capacity is reserved for MaxLocals locals and MaxUpvalues
// upvalues. LoopStart is -1 (not inside any loop); the remaining counters
// and the parent link start at their zero values.
func NewCompilerState(functionType FunctionType) *CompilerState {
	state := new(CompilerState)
	state.Chunk = NewChunk()
	state.Constants = make(map[string]int)
	state.Functions = make([]Function, 0)
	state.Structs = make([]Struct, 0)
	state.Locals = make([]Local, 0, MaxLocals)
	state.Upvalues = make([]UpvalueRef, 0, MaxUpvalues)
	state.FunctionType = functionType
	state.BreakJumps = make([]int, 0)
	state.ContinueJumps = make([]int, 0)
	state.LoopStart = -1
	return state
}
// NewChunk returns an empty bytecode chunk. Code and Lines are pre-sized
// for 256 bytes and Constants for 64 entries; Functions and Structs are
// empty but non-nil.
func NewChunk() *Chunk {
	chunk := &Chunk{}
	chunk.Code = make([]uint8, 0, 256)
	chunk.Constants = make([]Value, 0, 64)
	chunk.Lines = make([]int, 0, 256)
	chunk.Functions = make([]Function, 0)
	chunk.Structs = make([]Struct, 0)
	return chunk
}
// BeginScope opens a nested block scope; locals declared (and initialized)
// after this call belong to the deeper level until EndScope.
func (cs *CompilerState) BeginScope() {
	cs.ScopeDepth++
}
// EndScope closes the innermost block scope. It discards, innermost first,
// every local whose depth exceeds the new ScopeDepth. For each one it emits
// OpCloseUpvalue when a closure captured the local, and OpPop otherwise.
// The scan stops at the first local that is still in scope (or still
// uninitialized, depth -1).
func (cs *CompilerState) EndScope() {
	cs.ScopeDepth--

	for n := len(cs.Locals); n > 0; n = len(cs.Locals) {
		top := cs.Locals[n-1]
		if top.Depth <= cs.ScopeDepth {
			break
		}

		op := OpPop
		if top.IsCaptured {
			op = OpCloseUpvalue
		}
		cs.EmitByte(uint8(op))

		cs.Locals = cs.Locals[:n-1]
	}
}
// AddLocal declares a new local variable named name in the current scope.
//
// The local starts uninitialized (Depth -1) so that ResolveLocal can detect
// reads of it inside its own initializer; call MarkInitialized once its
// value is in place. Its Slot is its index in cs.Locals.
//
// It returns a CompileError when the function already has MaxLocals locals.
// The error carries the line currently being compiled (cs.CurrentLine) so it
// does not report "line 0". The column is not tracked by CompilerState and
// stays 0.
func (cs *CompilerState) AddLocal(name string) error {
	if len(cs.Locals) >= MaxLocals {
		return CompileError{
			Message: "too many local variables in function",
			Line:    cs.CurrentLine,
		}
	}

	cs.Locals = append(cs.Locals, Local{
		Name:       name,
		Depth:      -1, // uninitialized until MarkInitialized
		IsCaptured: false,
		Slot:       len(cs.Locals), // evaluated before append runs
	})
	return nil
}
// MarkInitialized marks the most recently declared local as usable by
// stamping it with the current scope depth. It does nothing when there
// are no locals.
func (cs *CompilerState) MarkInitialized() {
	n := len(cs.Locals)
	if n == 0 {
		return
	}
	cs.Locals[n-1].Depth = cs.ScopeDepth
}
// ResolveLocal looks name up among the locals, innermost declaration first.
// It returns:
//   - the local's index in cs.Locals when found and initialized,
//   - -2 when the innermost match is declared but not yet initialized,
//   - -1 when no local has that name.
func (cs *CompilerState) ResolveLocal(name string) int {
	for idx := len(cs.Locals) - 1; idx >= 0; idx-- {
		if cs.Locals[idx].Name != name {
			continue
		}
		if cs.Locals[idx].Depth == -1 {
			return -2 // used before initialization
		}
		return idx
	}
	return -1
}
// AddUpvalue registers a captured variable and returns its upvalue index.
//
// index names the slot in the enclosing function's locals (isLocal true) or
// upvalues (isLocal false). An existing entry with the same pair is reused.
// -1 is returned when a new entry would exceed MaxUpvalues.
func (cs *CompilerState) AddUpvalue(index uint8, isLocal bool) int {
	for pos, uv := range cs.Upvalues {
		if uv.Index == index && uv.IsLocal == isLocal {
			return pos
		}
	}

	next := len(cs.Upvalues)
	if next >= MaxUpvalues {
		return -1 // too many upvalues
	}
	cs.Upvalues = append(cs.Upvalues, UpvalueRef{Index: index, IsLocal: isLocal})
	return next
}
// AddConstant returns the constant-pool index for value, appending it to
// cs.Chunk.Constants only if an equal key (see valueKey) is not already
// present. It returns -1 once the pool holds MaxConstants entries.
func (cs *CompilerState) AddConstant(value Value) int {
	key := cs.valueKey(value)
	if existing, ok := cs.Constants[key]; ok {
		return existing
	}

	slot := len(cs.Chunk.Constants)
	if slot >= MaxConstants {
		return -1 // too many constants
	}
	cs.Chunk.Constants = append(cs.Chunk.Constants, value)
	cs.Constants[key] = slot
	return slot
}
// valueKey builds the deduplication key used by AddConstant. Keys are
// prefixed by kind ("nil", "bool:", "number:", "string:") so values of
// different kinds never collide. Numbers use %g, the shortest exact
// representation, so distinct floats get distinct keys.
//
// NOTE(review): the fallback formats Data with %T:%p. For pointer-like
// data (pointers, maps, slices, funcs) that is an address. For other
// dynamic types fmt emits "%!p(T=value)", so structurally equal values
// share a key, and every nil Data shares one key. Confirm this is intended.
func (cs *CompilerState) valueKey(value Value) string {
	switch value.Type {
	case ValueNil:
		return "nil"
	case ValueBool:
		key := "bool:false"
		if value.Data.(bool) {
			key = "bool:true"
		}
		return key
	case ValueNumber:
		return fmt.Sprintf("number:%g", value.Data.(float64))
	case ValueString:
		return "string:" + value.Data.(string)
	}
	// Complex types: fall back to the dynamic type plus the %p rendering.
	return fmt.Sprintf("%T:%p", value.Data, value.Data)
}
// EmitByte appends one raw byte to the chunk. It also records cs.CurrentLine
// for that byte, which keeps Chunk.Lines index-aligned with Chunk.Code.
//
// The parameter is named b rather than byte so it does not shadow Go's
// predeclared byte type inside the function.
func (cs *CompilerState) EmitByte(b uint8) {
	cs.Chunk.Code = append(cs.Chunk.Code, b)
	cs.Chunk.Lines = append(cs.Chunk.Lines, cs.CurrentLine)
}
// EmitBytes appends each byte in order via EmitByte, so every byte gets the
// current line recorded.
func (cs *CompilerState) EmitBytes(bytes ...uint8) {
	for i := range bytes {
		cs.EmitByte(bytes[i])
	}
}
// EmitInstruction encodes op with its operands via EncodeInstruction and
// appends the resulting bytes to the chunk.
func (cs *CompilerState) EmitInstruction(op Opcode, operands ...uint16) {
	cs.EmitBytes(EncodeInstruction(op, operands...)...)
}
// EmitJump emits op followed by a two-byte 0xFFFF placeholder operand. It
// returns the offset of the placeholder's first byte, which is what
// PatchJump expects.
func (cs *CompilerState) EmitJump(op Opcode) int {
	cs.EmitBytes(uint8(op), 0xFF, 0xFF)
	return len(cs.Chunk.Code) - 2
}
// PatchJump fills the placeholder operand at offset (as returned by
// EmitJump). The value written is the forward distance from the end of the
// operand to the current end of the code, stored little-endian.
//
// When the distance exceeds 65535 it returns without writing, leaving the
// 0xFFFF placeholder in place; long jumps are not implemented.
func (cs *CompilerState) PatchJump(offset int) {
	distance := len(cs.Chunk.Code) - offset - 2
	if distance > 65535 {
		return
	}
	cs.Chunk.Code[offset] = uint8(distance)
	cs.Chunk.Code[offset+1] = uint8(distance >> 8)
}
// EnterLoop starts loop bookkeeping. It records the current code offset as
// LoopStart, increments LoopDepth, and empties the pending break/continue
// lists (their backing arrays are kept).
//
// NOTE(review): loop state is a single set of fields, not a stack. Entering
// a nested loop has three effects:
//   - it overwrites the outer LoopStart;
//   - it drops the outer loop's pending break/continue offsets, so nothing
//     in this file patches them and their 0xFFFF placeholders remain;
//   - after the inner ExitLoop, LoopStart still points at the inner loop,
//     because ExitLoop only resets it when LoopDepth reaches 0.
// Nested loops that use break/continue need a save/restore here and in
// ExitLoop.
func (cs *CompilerState) EnterLoop() {
	cs.LoopDepth++
	cs.LoopStart = len(cs.Chunk.Code)
	cs.BreakJumps = cs.BreakJumps[:0]
	cs.ContinueJumps = cs.ContinueJumps[:0]
}
// ExitLoop closes the loop opened by the matching EnterLoop.
//
// Pending break jumps are patched with PatchJump, so they land at the
// current end of the code, just past the loop. Pending continue jumps are
// aimed at the loop start. A target at or before the jump is encoded as
// OpLoopBack with a backward distance, because OpJump's unsigned operand
// only moves forward (see how the disassembler decodes both). A target
// after the jump stays OpJump with a forward distance. Distances are
// measured from the byte following the two-byte operand, as PatchJump
// does. Jumps that do not fit in 16 bits are left unpatched, also as
// PatchJump does.
//
// NOTE(review): see EnterLoop — loop state is not a stack, so nested loops
// with break/continue are not handled correctly.
func (cs *CompilerState) ExitLoop() {
	// Capture the start before it may be reset below; resetting first would
	// make the continue patching skip every jump of the outermost loop.
	loopStart := cs.LoopStart

	cs.LoopDepth--
	if cs.LoopDepth == 0 {
		cs.LoopStart = -1
	}

	for _, jumpOffset := range cs.BreakJumps {
		cs.PatchJump(jumpOffset)
	}
	cs.BreakJumps = cs.BreakJumps[:0]

	if loopStart != -1 {
		code := cs.Chunk.Code
		for _, jumpOffset := range cs.ContinueJumps {
			opPos := jumpOffset - 1  // EmitJump put the opcode just before the operand
			resume := jumpOffset + 2 // instruction pointer once the operand is read

			op, distance := OpJump, loopStart-resume
			if loopStart <= resume {
				op, distance = OpLoopBack, resume-loopStart
			}
			if distance > 0xFFFF {
				continue // too far for a 16-bit operand; long jumps unsupported
			}

			code[opPos] = uint8(op)
			code[jumpOffset] = uint8(distance & 0xFF)
			code[jumpOffset+1] = uint8((distance >> 8) & 0xFF)
		}
	}
	cs.ContinueJumps = cs.ContinueJumps[:0]
}
// EmitBreak emits a placeholder OpJump and queues its operand offset in
// BreakJumps, to be patched when the loop is closed by ExitLoop. It does not
// check whether a loop is open.
func (cs *CompilerState) EmitBreak() {
	cs.BreakJumps = append(cs.BreakJumps, cs.EmitJump(OpJump))
}
// EmitContinue emits a placeholder jump and queues its operand offset in
// ContinueJumps, to be aimed at the loop start when the loop is closed by
// ExitLoop. Outside a loop (LoopStart == -1) it emits nothing.
func (cs *CompilerState) EmitContinue() {
	if cs.LoopStart == -1 {
		return
	}
	placeholder := cs.EmitJump(OpJump)
	cs.ContinueJumps = append(cs.ContinueJumps, placeholder)
}
// EmitLoadConstant emits the cheapest instruction that pushes value.
//
// nil, true, false, +0 and 1 have dedicated zero-operand opcodes. Every
// other value is added to the deduplicated constant pool and loaded with
// OpLoadConst.
//
// Negative zero deliberately goes through the pool. In Go, -0 == 0 is true,
// so a plain `num == 0` test would turn a (for example constant-folded) -0
// into +0. The difference is observable in IEEE-754 arithmetic: 1/-0 is
// -Inf. The sign is tested with 1/num, which is +Inf only for +0.
//
// NOTE(review): when the pool is full, AddConstant returns -1 and nothing is
// emitted, silently leaving the stack one value short. The signature has no
// error return, so this is only flagged here.
func (cs *CompilerState) EmitLoadConstant(value Value) {
	switch value.Type {
	case ValueNil:
		cs.EmitInstruction(OpLoadNil)
		return
	case ValueBool:
		if value.Data.(bool) {
			cs.EmitInstruction(OpLoadTrue)
		} else {
			cs.EmitInstruction(OpLoadFalse)
		}
		return
	case ValueNumber:
		num := value.Data.(float64)
		if num == 0 && 1/num > 0 { // +0 only; -0 falls through to the pool
			cs.EmitInstruction(OpLoadZero)
			return
		}
		if num == 1 {
			cs.EmitInstruction(OpLoadOne)
			return
		}
	}

	// Everything else is loaded from the constant pool.
	if index := cs.AddConstant(value); index != -1 {
		cs.EmitInstruction(OpLoadConst, uint16(index))
	}
}
// EmitLoadLocal pushes the local in slot. Slots 0-2 use the operand-free
// OpLoadLocal0..2 opcodes; any other slot uses OpLoadLocal with the slot as
// a uint16 operand.
func (cs *CompilerState) EmitLoadLocal(slot int) {
	fast := [...]Opcode{OpLoadLocal0, OpLoadLocal1, OpLoadLocal2}
	if slot >= 0 && slot < len(fast) {
		cs.EmitInstruction(fast[slot])
		return
	}
	cs.EmitInstruction(OpLoadLocal, uint16(slot))
}
// EmitStoreLocal stores into the local in slot. Slots 0-2 use the
// operand-free OpStoreLocal0..2 opcodes; any other slot uses OpStoreLocal
// with the slot as a uint16 operand.
func (cs *CompilerState) EmitStoreLocal(slot int) {
	fast := [...]Opcode{OpStoreLocal0, OpStoreLocal1, OpStoreLocal2}
	if slot >= 0 && slot < len(fast) {
		cs.EmitInstruction(fast[slot])
		return
	}
	cs.EmitInstruction(OpStoreLocal, uint16(slot))
}
// GetLastInstruction decodes the final instruction in the chunk, returning
// its opcode and its little-endian uint16 operands.
//
// Bytecode is variable-length: an opcode byte followed by 2 bytes per
// operand. It cannot be decoded by scanning backwards, because a trailing
// operand byte is indistinguishable from an opcode. This function walks
// forward from offset 0 with DecodeInstruction, the same decoder
// GetOptimizationStats uses, to find where the last instruction starts.
// That costs O(len(code)) per call.
//
// It returns (OpNoop, nil) when the chunk is empty or the last
// instruction's operands are truncated.
func (cs *CompilerState) GetLastInstruction() (Opcode, []uint16) {
	code := cs.Chunk.Code
	if len(code) == 0 {
		return OpNoop, nil
	}

	last := 0
	for pc := 0; pc < len(code); {
		last = pc
		_, _, next := DecodeInstruction(code, pc)
		if next <= pc {
			break // defensive: never spin on a decoder that fails to advance
		}
		pc = next
	}

	op := Opcode(code[last])
	operandCount := GetOperandCount(op)
	if last+1+operandCount*2 > len(code) {
		return OpNoop, nil
	}

	operands := make([]uint16, operandCount)
	for j := range operands {
		lo := code[last+1+j*2]
		hi := code[last+2+j*2]
		operands[j] = uint16(lo) | uint16(hi)<<8
	}
	return op, operands
}
// ReplaceLastInstruction removes the last instruction, sized with
// InstructionSize of the opcode GetLastInstruction reports. It drops the
// matching Lines entries, then emits op with operands in its place. It
// returns false, changing nothing, when the chunk is empty or shorter than
// the instruction being removed.
//
// Intended for peephole optimization; callers must ensure no jump targets
// the removed bytes.
func (cs *CompilerState) ReplaceLastInstruction(op Opcode, operands ...uint16) bool {
	codeLen := len(cs.Chunk.Code)
	if codeLen == 0 {
		return false
	}

	previous, _ := cs.GetLastInstruction()
	size := InstructionSize(previous)
	if size > codeLen {
		return false
	}

	cs.Chunk.Code = cs.Chunk.Code[:codeLen-size]
	cs.Chunk.Lines = cs.Chunk.Lines[:len(cs.Chunk.Lines)-size]
	cs.EmitInstruction(op, operands...)
	return true
}
// TryConstantFolding evaluates op at compile time when its first two
// operands are both numbers.
//
// Arithmetic (OpAdd, OpSub, OpMul, OpDiv) yields a number; comparisons
// (OpEq, OpNeq, OpLt, OpLte, OpGt, OpGte) yield a bool. It returns nil
// in four cases: fewer than two operands, a non-number operand, division by
// zero, or any other opcode. Extra operands beyond the first two are ignored.
func (cs *CompilerState) TryConstantFolding(op Opcode, operands ...Value) *Value {
	if len(operands) < 2 {
		return nil
	}
	a, b := operands[0], operands[1]
	if a.Type != ValueNumber || b.Type != ValueNumber {
		return nil
	}
	x, y := a.Data.(float64), b.Data.(float64)

	number := func(f float64) *Value { return &Value{Type: ValueNumber, Data: f} }
	boolean := func(v bool) *Value { return &Value{Type: ValueBool, Data: v} }

	switch op {
	case OpAdd:
		return number(x + y)
	case OpSub:
		return number(x - y)
	case OpMul:
		return number(x * y)
	case OpDiv:
		if y == 0 {
			return nil // leave division by zero to runtime
		}
		return number(x / y)
	case OpEq:
		return boolean(x == y)
	case OpNeq:
		return boolean(x != y)
	case OpLt:
		return boolean(x < y)
	case OpLte:
		return boolean(x <= y)
	case OpGt:
		return boolean(x > y)
	case OpGte:
		return boolean(x >= y)
	}
	return nil
}
// MarkUnreachable overwrites code bytes [start, end) with OpNoop. Lines is
// left untouched. The call is ignored when start is negative or end is past
// the end of the code; an empty range (start >= end) changes nothing.
func (cs *CompilerState) MarkUnreachable(start, end int) {
	if start < 0 || end > len(cs.Chunk.Code) {
		return
	}
	code := cs.Chunk.Code
	for pos := start; pos < end; pos++ {
		code[pos] = uint8(OpNoop)
	}
}
// OptimizationStats summarizes optimizations visible in a chunk.
//
// GetOptimizationStats fills only InstructionsOpt (count of specialized
// instructions) and DeadCodeEliminated (count of OpNoop instructions).
// ConstantsFolded and JumpsOptimized are always left at zero by it.
type OptimizationStats struct {
	ConstantsFolded    int
	InstructionsOpt    int
	DeadCodeEliminated int
	JumpsOptimized     int
}
// GetOptimizationStats walks the chunk instruction by instruction using
// DecodeInstruction. It counts specialized instructions (reported as
// InstructionsOpt) and OpNoop instructions (reported as DeadCodeEliminated).
// The other fields are left at zero.
func (cs *CompilerState) GetOptimizationStats() OptimizationStats {
	var stats OptimizationStats
	code := cs.Chunk.Code

	pc := 0
	for pc < len(code) {
		op, _, next := DecodeInstruction(code, pc)
		if IsSpecializedInstruction(op) {
			stats.InstructionsOpt++
		}
		if op == OpNoop {
			stats.DeadCodeEliminated++
		}
		pc = next
	}
	return stats
}
// SetLine sets the source line recorded for every byte emitted from now on.
func (cs *CompilerState) SetLine(line int) {
	cs.CurrentLine = line
}
// PrintChunk writes a human-readable disassembly of the chunk to standard
// output, preceded by a "== name ==" header.
func (cs *CompilerState) PrintChunk(name string) {
	fmt.Printf("== %s ==\n", name)
	offset := 0
	for offset < len(cs.Chunk.Code) {
		offset = cs.disassembleInstruction(offset)
	}
}
// disassembleInstruction prints the instruction at offset and returns the
// offset of the next instruction.
//
// Each line shows four columns:
//   - the byte offset;
//   - the source line, or "|" when it matches the previous byte's line;
//   - the opcode name;
//   - the operands.
//
// Constant loads, local slots and jumps have dedicated formatters. Any other
// known opcode has its GetOperandCount little-endian uint16 operands
// printed and skipped. Without the skip, operand bytes would be decoded as
// opcodes and the listing would desynchronise. Unknown opcodes advance by
// one byte.
func (cs *CompilerState) disassembleInstruction(offset int) int {
	code, lines := cs.Chunk.Code, cs.Chunk.Lines

	fmt.Printf("%04d ", offset)
	switch {
	case offset > 0 && offset < len(lines) && lines[offset] == lines[offset-1]:
		fmt.Print(" | ")
	case offset < len(lines):
		fmt.Printf("%4d ", lines[offset])
	default:
		fmt.Print(" ? ")
	}

	if offset >= len(code) {
		fmt.Println("END")
		return offset + 1
	}

	instruction := code[offset]
	op := Opcode(instruction)
	name, known := opcodeNames[op]
	if known {
		fmt.Printf("%-16s", name)
	} else {
		fmt.Printf("UNKNOWN_%02x ", instruction)
	}

	switch op {
	case OpLoadConst:
		return cs.constantInstruction(offset)
	case OpLoadLocal, OpStoreLocal:
		return cs.byteInstruction(offset)
	case OpJump, OpJumpIfTrue, OpJumpIfFalse:
		return cs.jumpInstruction(offset, 1)
	case OpLoopBack:
		return cs.jumpInstruction(offset, -1)
	}

	if !known {
		fmt.Println()
		return offset + 1
	}

	// Generic path: print and skip the opcode's uint16 operands.
	count := GetOperandCount(op)
	end := offset + 1 + 2*count
	if count > 0 && end > len(code) {
		fmt.Println(" [incomplete]")
		return offset + 1
	}
	for p := offset + 1; p < end; p += 2 {
		fmt.Printf(" %4d", uint16(code[p])|uint16(code[p+1])<<8)
	}
	fmt.Println()
	return end
}
// constantInstruction prints an OpLoadConst at offset: the 16-bit
// little-endian pool index and the constant's value in quotes ("???" when
// the index is out of range). It returns offset+3, or offset+1 when the
// operand is truncated.
func (cs *CompilerState) constantInstruction(offset int) int {
	code := cs.Chunk.Code
	if offset+2 >= len(code) {
		fmt.Println(" [incomplete]")
		return offset + 1
	}

	index := int(code[offset+1]) | int(code[offset+2])<<8
	fmt.Printf(" %4d '", index)
	if index < len(cs.Chunk.Constants) {
		cs.printValue(cs.Chunk.Constants[index])
	} else {
		fmt.Print("???")
	}
	fmt.Println("'")
	return offset + 3
}
// byteInstruction prints the 16-bit little-endian slot operand of a local
// load/store at offset. It returns offset+3, or offset+1 when the operand
// is truncated.
func (cs *CompilerState) byteInstruction(offset int) int {
	code := cs.Chunk.Code
	if len(code) <= offset+2 {
		fmt.Println(" [incomplete]")
		return offset + 1
	}
	slot := int(code[offset+1]) | int(code[offset+2])<<8
	fmt.Printf(" %4d\n", slot)
	return offset + 3
}
// jumpInstruction prints a jump at offset: its 16-bit distance and the
// resulting target, which is the end of the 3-byte instruction plus
// sign*distance (sign is +1 for forward jumps, -1 for OpLoopBack). It
// returns offset+3, or offset+1 when the operand is truncated.
func (cs *CompilerState) jumpInstruction(offset int, sign int) int {
	code := cs.Chunk.Code
	if offset+2 >= len(code) {
		fmt.Println(" [incomplete]")
		return offset + 1
	}
	distance := int(code[offset+1]) | int(code[offset+2])<<8
	next := offset + 3
	fmt.Printf(" %4d -> %d\n", distance, next+sign*distance)
	return next
}
// printValue writes a human-readable form of a constant for disassembly.
//
// Numbers use %g, the shortest representation that round-trips. %.2g kept
// only two significant digits (123.456 printed as "1.2e+02"), which made
// distinct constants look identical. Strings are wrapped in double quotes.
// Non-primitive values print as <kind>, using valueTypeString.
func (cs *CompilerState) printValue(value Value) {
	switch value.Type {
	case ValueNil:
		fmt.Print("nil")
	case ValueBool:
		if value.Data.(bool) {
			fmt.Print("true")
		} else {
			fmt.Print("false")
		}
	case ValueNumber:
		fmt.Printf("%g", value.Data.(float64))
	case ValueString:
		fmt.Printf("\"%s\"", value.Data.(string))
	default:
		fmt.Printf("<%s>", cs.valueTypeString(value.Type))
	}
}
// valueTypeString names the non-primitive value kinds for disassembly
// output. It returns "unknown" for any kind not listed.
func (cs *CompilerState) valueTypeString(vt ValueType) string {
	name := "unknown"
	switch vt {
	case ValueTable:
		name = "table"
	case ValueFunction:
		name = "function"
	case ValueStruct:
		name = "struct"
	case ValueArray:
		name = "array"
	case ValueUpvalue:
		name = "upvalue"
	}
	return name
}