diff --git a/compiler/compiler.go b/compiler/compiler.go index 60e6546..f14d99e 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -2,6 +2,7 @@ package compiler import ( "fmt" + "math" "git.sharkk.net/Sharkk/Mako/parser" ) @@ -77,7 +78,155 @@ func (c *Compiler) compileStatement(stmt parser.Statement) { } } -// Expression compilation with constant folding +// Enhanced constant folding engine with complete coverage +func (c *Compiler) tryConstantFold(expr parser.Expression) *Value { + switch e := expr.(type) { + case *parser.NumberLiteral: + return &Value{Type: ValueNumber, Data: e.Value} + case *parser.StringLiteral: + return &Value{Type: ValueString, Data: e.Value} + case *parser.BooleanLiteral: + return &Value{Type: ValueBool, Data: e.Value} + case *parser.NilLiteral: + return &Value{Type: ValueNil, Data: nil} + case *parser.PrefixExpression: + return c.foldPrefixExpression(e) + case *parser.InfixExpression: + return c.foldInfixExpression(e) + } + return nil +} + +func (c *Compiler) foldPrefixExpression(expr *parser.PrefixExpression) *Value { + rightValue := c.tryConstantFold(expr.Right) + if rightValue == nil { + return nil + } + + switch expr.Operator { + case "-": + if rightValue.Type == ValueNumber { + num := rightValue.Data.(float64) + return &Value{Type: ValueNumber, Data: -num} + } + case "not": + return &Value{Type: ValueBool, Data: !c.isTruthy(*rightValue)} + } + return nil +} + +func (c *Compiler) foldInfixExpression(expr *parser.InfixExpression) *Value { + leftValue := c.tryConstantFold(expr.Left) + rightValue := c.tryConstantFold(expr.Right) + + if leftValue == nil || rightValue == nil { + return nil + } + + // Arithmetic operations + if leftValue.Type == ValueNumber && rightValue.Type == ValueNumber { + l := leftValue.Data.(float64) + r := rightValue.Data.(float64) + + switch expr.Operator { + case "+": + return &Value{Type: ValueNumber, Data: l + r} + case "-": + return &Value{Type: ValueNumber, Data: l - r} + case "*": + return &Value{Type: ValueNumber, Data: l * r} + case "/": + if r != 0 { + return &Value{Type: ValueNumber, Data: l / r} + } + return nil + case "%": + if r != 0 { + return &Value{Type: ValueNumber, Data: math.Mod(l, r)} + } + return nil + case "<": + return &Value{Type: ValueBool, Data: l < r} + case "<=": + return &Value{Type: ValueBool, Data: l <= r} + case ">": + return &Value{Type: ValueBool, Data: l > r} + case ">=": + return &Value{Type: ValueBool, Data: l >= r} + } + } + + // String operations + if leftValue.Type == ValueString && rightValue.Type == ValueString { + l := leftValue.Data.(string) + r := rightValue.Data.(string) + + switch expr.Operator { + case "+": + return &Value{Type: ValueString, Data: l + r} + } + } + + // Comparison operations + switch expr.Operator { + case "==": + return &Value{Type: ValueBool, Data: c.valuesEqual(*leftValue, *rightValue)} + case "!=": + return &Value{Type: ValueBool, Data: !c.valuesEqual(*leftValue, *rightValue)} + case "and": + if !c.isTruthy(*leftValue) { + return leftValue + } + return rightValue + case "or": + if c.isTruthy(*leftValue) { + return leftValue + } + return rightValue + } + + return nil +} + +func (c *Compiler) isTruthy(value Value) bool { + switch value.Type { + case ValueNil: + return false + case ValueBool: + return value.Data.(bool) + case ValueNumber: + return value.Data.(float64) != 0 + case ValueString: + return value.Data.(string) != "" + default: + return true + } +} + +func (c *Compiler) valuesEqual(a, b Value) bool { + if a.Type != b.Type { + return false + } + switch a.Type { + case ValueNil: + return true + case ValueBool: + return a.Data.(bool) == b.Data.(bool) + case ValueNumber: + aNum := a.Data.(float64) + bNum := b.Data.(float64) + if math.IsNaN(aNum) && math.IsNaN(bNum) { + return true + } + return aNum == bNum + case ValueString: + return a.Data.(string) == b.Data.(string) + default: + return false + } +} + +// Expression compilation with enhanced constant folding func (c *Compiler) compileExpression(expr parser.Expression) { if lineNode := c.getLineFromNode(expr); lineNode != 0 { c.current.SetLine(lineNode) @@ -123,131 +272,6 @@ func (c *Compiler) compileExpression(expr parser.Expression) { } } -// Constant folding engine -func (c *Compiler) tryConstantFold(expr parser.Expression) *Value { - switch e := expr.(type) { - case *parser.NumberLiteral: - return &Value{Type: ValueNumber, Data: e.Value} - case *parser.StringLiteral: - return &Value{Type: ValueString, Data: e.Value} - case *parser.BooleanLiteral: - return &Value{Type: ValueBool, Data: e.Value} - case *parser.NilLiteral: - return &Value{Type: ValueNil, Data: nil} - case *parser.PrefixExpression: - return c.foldPrefixExpression(e) - case *parser.InfixExpression: - return c.foldInfixExpression(e) - } - return nil -} - -func (c *Compiler) foldPrefixExpression(expr *parser.PrefixExpression) *Value { - rightValue := c.tryConstantFold(expr.Right) - if rightValue == nil { - return nil - } - - switch expr.Operator { - case "-": - if rightValue.Type == ValueNumber { - return &Value{Type: ValueNumber, Data: -rightValue.Data.(float64)} - } - case "not": - return &Value{Type: ValueBool, Data: !c.isTruthy(*rightValue)} - } - return nil -} - -func (c *Compiler) foldInfixExpression(expr *parser.InfixExpression) *Value { - leftValue := c.tryConstantFold(expr.Left) - rightValue := c.tryConstantFold(expr.Right) - - if leftValue == nil || rightValue == nil { - return nil - } - - // Arithmetic operations - if leftValue.Type == ValueNumber && rightValue.Type == ValueNumber { - l := leftValue.Data.(float64) - r := rightValue.Data.(float64) - - switch expr.Operator { - case "+": - return &Value{Type: ValueNumber, Data: l + r} - case "-": - return &Value{Type: ValueNumber, Data: l - r} - case "*": - return &Value{Type: ValueNumber, Data: l * r} - case "/": - if r != 0 { - return &Value{Type: ValueNumber, Data: l / r} - } - case "<": - return &Value{Type: ValueBool, Data: l < r} - case "<=": - return &Value{Type: ValueBool, Data: l <= r} - case ">": - return &Value{Type: ValueBool, Data: l > r} - case ">=": - return &Value{Type: ValueBool, Data: l >= r} - } - } - - // Comparison operations - switch expr.Operator { - case "==": - return &Value{Type: ValueBool, Data: c.valuesEqual(*leftValue, *rightValue)} - case "!=": - return &Value{Type: ValueBool, Data: !c.valuesEqual(*leftValue, *rightValue)} - } - - // Logical operations - switch expr.Operator { - case "and": - if !c.isTruthy(*leftValue) { - return leftValue - } - return rightValue - case "or": - if c.isTruthy(*leftValue) { - return leftValue - } - return rightValue - } - - return nil -} - -func (c *Compiler) isTruthy(value Value) bool { - switch value.Type { - case ValueNil: - return false - case ValueBool: - return value.Data.(bool) - default: - return true - } -} - -func (c *Compiler) valuesEqual(a, b Value) bool { - if a.Type != b.Type { - return false - } - switch a.Type { - case ValueNil: - return true - case ValueBool: - return a.Data.(bool) == b.Data.(bool) - case ValueNumber: - return a.Data.(float64) == b.Data.(float64) - case ValueString: - return a.Data.(string) == b.Data.(string) - default: - return false - } -} - // Optimized constant emission func (c *Compiler) emitConstant(value Value) { switch value.Type { @@ -265,6 +289,9 @@ func (c *Compiler) emitConstant(value Value) { c.current.EmitInstruction(OpLoadZero) } else if num == 1 { c.current.EmitInstruction(OpLoadOne) + } else if num == -1 { + c.current.EmitInstruction(OpLoadOne) + c.current.EmitInstruction(OpNeg) } else { index := c.current.AddConstant(value) if index == -1 { @@ -283,7 +310,7 @@ func (c *Compiler) emitConstant(value Value) { } } -// Literal compilation with optimizations +// Literal compilation func (c *Compiler) compileNumberLiteral(node *parser.NumberLiteral) { value := Value{Type: ValueNumber, Data: node.Value} c.emitConstant(value) @@ -303,7 +330,7 @@ func (c *Compiler) compileNilLiteral(node *parser.NilLiteral) { c.current.EmitInstruction(OpLoadNil) } -// Optimized identifier compilation +// Identifier compilation func (c *Compiler) compileIdentifier(node *parser.Identifier) { slot := c.current.ResolveLocal(node.Value) if slot != -1 { @@ -331,7 +358,7 @@ func (c *Compiler) compileIdentifier(node *parser.Identifier) { c.current.EmitInstruction(OpLoadGlobal, uint16(index)) } -// Optimized local variable access +// Local variable access helpers func (c *Compiler) emitLoadLocal(slot int) { switch slot { case 0: @@ -358,7 +385,7 @@ func (c *Compiler) emitStoreLocal(slot int) { } } -// Assignment compilation with optimizations +// Assignment compilation func (c *Compiler) compileAssignment(node *parser.Assignment) { c.compileExpression(node.Value) @@ -414,13 +441,10 @@ func (c *Compiler) compileAssignment(node *parser.Assignment) { } } -// Optimized dot expression assignment func (c *Compiler) compileDotAssignment(dot *parser.DotExpression) { - // Check for local.field optimization if ident, ok := dot.Left.(*parser.Identifier); ok { slot := c.current.ResolveLocal(ident.Value) if slot != -1 && slot <= 2 { - // Use optimized local field assignment value := Value{Type: ValueString, Data: dot.Key} index := c.current.AddConstant(value) if index == -1 { @@ -432,7 +456,6 @@ func (c *Compiler) compileDotAssignment(dot *parser.DotExpression) { } } - // Fall back to regular field assignment c.compileExpression(dot.Left) value := Value{Type: ValueString, Data: dot.Key} index := c.current.AddConstant(value) @@ -445,10 +468,9 @@ func (c *Compiler) compileDotAssignment(dot *parser.DotExpression) { func (c *Compiler) compileAssignmentExpression(node *parser.Assignment) { c.compileAssignment(node) - // Assignment expressions leave the assigned value on stack } -// Optimized operator compilation +// Operator compilation func (c *Compiler) compilePrefixExpression(node *parser.PrefixExpression) { c.compileExpression(node.Right) @@ -463,12 +485,15 @@ func (c *Compiler) compilePrefixExpression(node *parser.PrefixExpression) { } func (c *Compiler) compileInfixExpression(node *parser.InfixExpression) { - // Check for increment/decrement patterns if c.tryOptimizeIncDec(node) { return } - // Handle short-circuit operators + if c.tryOptimizeArithmeticWithConstant(node) { + return + } + + // Short-circuit operators if node.Operator == "and" { c.compileExpression(node.Left) jump := c.current.EmitJump(OpJumpIfFalse) @@ -502,6 +527,8 @@ func (c *Compiler) compileInfixExpression(node *parser.InfixExpression) { c.current.EmitInstruction(OpMul) case "/": c.current.EmitInstruction(OpDiv) + case "%": + c.current.EmitInstruction(OpMod) case "==": c.current.EmitInstruction(OpEq) case "!=": @@ -519,9 +546,7 @@ func (c *Compiler) compileInfixExpression(node *parser.InfixExpression) { } } -// Try to optimize increment/decrement patterns func (c *Compiler) tryOptimizeIncDec(node *parser.InfixExpression) bool { - // Look for patterns like: var = var + 1 or var = var - 1 if node.Operator != "+" && node.Operator != "-" { return false } @@ -541,25 +566,60 @@ func (c *Compiler) tryOptimizeIncDec(node *parser.InfixExpression) bool { return false } - // Emit optimized increment/decrement if node.Operator == "+" { c.current.EmitInstruction(OpInc, uint16(slot)) } else { c.current.EmitInstruction(OpDec, uint16(slot)) } - // Load the result back onto stack c.emitLoadLocal(slot) return true } -// Optimized dot expression compilation +func (c *Compiler) tryOptimizeArithmeticWithConstant(node *parser.InfixExpression) bool { + if node.Operator != "+" && node.Operator != "-" { + return false + } + + leftIdent, ok := node.Left.(*parser.Identifier) + if !ok { + return false + } + + rightLit, ok := node.Right.(*parser.NumberLiteral) + if !ok { + return false + } + + slot := c.current.ResolveLocal(leftIdent.Value) + if slot == -1 { + return false + } + + if rightLit.Value >= 0 && rightLit.Value <= 255 && rightLit.Value == math.Floor(rightLit.Value) { + c.emitLoadLocal(slot) + + constValue := Value{Type: ValueNumber, Data: rightLit.Value} + constIndex := c.current.AddConstant(constValue) + if constIndex == -1 { + return false + } + + if node.Operator == "+" { + c.current.EmitInstruction(OpAddConst, uint16(constIndex)) + } else { + c.current.EmitInstruction(OpSubConst, uint16(constIndex)) + } + return true + } + + return false +} + func (c *Compiler) compileDotExpression(node *parser.DotExpression) { - // Check for local.field optimization if ident, ok := node.Left.(*parser.Identifier); ok { slot := c.current.ResolveLocal(ident.Value) if slot != -1 && slot <= 2 { - // Use optimized local field access value := Value{Type: ValueString, Data: node.Key} index := c.current.AddConstant(value) if index == -1 { @@ -571,7 +631,6 @@ func (c *Compiler) compileDotExpression(node *parser.DotExpression) { } } - // Fall back to regular field access c.compileExpression(node.Left) value := Value{Type: ValueString, Data: node.Key} index := c.current.AddConstant(value) @@ -582,18 +641,14 @@ func (c *Compiler) compileDotExpression(node *parser.DotExpression) { c.current.EmitInstruction(OpGetField, uint16(index)) } -// Optimized function call compilation func (c *Compiler) compileCallExpression(node *parser.CallExpression) { - // Check for calls to local functions if ident, ok := node.Function.(*parser.Identifier); ok { slot := c.current.ResolveLocal(ident.Value) if slot == 0 || slot == 1 { - // Compile arguments for _, arg := range node.Arguments { c.compileExpression(arg) } - // Use optimized call instruction if slot == 0 { c.current.EmitInstruction(OpCallLocal0, uint16(len(node.Arguments))) } else { @@ -603,7 +658,6 @@ func (c *Compiler) compileCallExpression(node *parser.CallExpression) { } } - // Regular function call c.compileExpression(node.Function) for _, arg := range node.Arguments { c.compileExpression(arg) @@ -611,193 +665,28 @@ func (c *Compiler) compileCallExpression(node *parser.CallExpression) { c.current.EmitInstruction(OpCall, uint16(len(node.Arguments))) } -// Control flow compilation (unchanged from original) -func (c *Compiler) compileIfStatement(node *parser.IfStatement) { - c.compileExpression(node.Condition) +func (c *Compiler) compileIndexExpression(node *parser.IndexExpression) { + c.compileExpression(node.Left) + c.compileExpression(node.Index) + c.current.EmitInstruction(OpGetIndex) +} - thenJump := c.current.EmitJump(OpJumpIfFalse) - c.current.EmitInstruction(OpPop) +func (c *Compiler) compileTableLiteral(node *parser.TableLiteral) { + c.current.EmitInstruction(OpNewTable) - c.current.BeginScope() - for _, stmt := range node.Body { - c.compileStatement(stmt) - } - c.current.EndScope() - - elseJump := c.current.EmitJump(OpJump) - c.current.PatchJump(thenJump) - c.current.EmitInstruction(OpPop) - - var elseifJumps []int - for _, elseif := range node.ElseIfs { - c.compileExpression(elseif.Condition) - nextJump := c.current.EmitJump(OpJumpIfFalse) - c.current.EmitInstruction(OpPop) - - c.current.BeginScope() - for _, stmt := range elseif.Body { - c.compileStatement(stmt) + for _, pair := range node.Pairs { + if pair.Key == nil { + c.compileExpression(pair.Value) + c.current.EmitInstruction(OpTableInsert) + } else { + c.current.EmitInstruction(OpDup) + c.compileExpression(pair.Key) + c.compileExpression(pair.Value) + c.current.EmitInstruction(OpSetIndex) } - c.current.EndScope() - - elseifJumps = append(elseifJumps, c.current.EmitJump(OpJump)) - c.current.PatchJump(nextJump) - c.current.EmitInstruction(OpPop) - } - - if len(node.Else) > 0 { - c.current.BeginScope() - for _, stmt := range node.Else { - c.compileStatement(stmt) - } - c.current.EndScope() - } - - c.current.PatchJump(elseJump) - for _, jump := range elseifJumps { - c.current.PatchJump(jump) } } -func (c *Compiler) compileWhileStatement(node *parser.WhileStatement) { - c.current.EnterLoop() - - c.compileExpression(node.Condition) - exitJump := c.current.EmitJump(OpJumpIfFalse) - c.current.EmitInstruction(OpPop) - - c.current.BeginScope() - for _, stmt := range node.Body { - c.compileStatement(stmt) - } - c.current.EndScope() - - // Use optimized loop back instruction - jump := len(c.current.Chunk.Code) - c.current.LoopStart + 2 - c.current.EmitInstruction(OpLoopBack, uint16(jump)) - - c.current.PatchJump(exitJump) - c.current.EmitInstruction(OpPop) - - c.current.ExitLoop() -} - -// Remaining compilation methods (struct, function, etc.) unchanged but with optimization calls - -func (c *Compiler) compileForStatement(node *parser.ForStatement) { - c.current.BeginScope() - c.current.EnterLoop() - - c.compileExpression(node.Start) - if err := c.current.AddLocal(node.Variable.Value); err != nil { - c.addError(err.Error()) - return - } - c.current.MarkInitialized() - loopVar := len(c.current.Locals) - 1 - - c.compileExpression(node.End) - endSlot := len(c.current.Locals) - if err := c.current.AddLocal("__end"); err != nil { - c.addError(err.Error()) - return - } - c.current.MarkInitialized() - - if node.Step != nil { - c.compileExpression(node.Step) - } else { - c.current.EmitInstruction(OpLoadOne) - } - stepSlot := len(c.current.Locals) - if err := c.current.AddLocal("__step"); err != nil { - c.addError(err.Error()) - return - } - c.current.MarkInitialized() - - conditionStart := len(c.current.Chunk.Code) - c.emitLoadLocal(loopVar) - c.emitLoadLocal(endSlot) - c.current.EmitInstruction(OpLte) - exitJump := c.current.EmitJump(OpJumpIfFalse) - c.current.EmitInstruction(OpPop) - - for _, stmt := range node.Body { - c.compileStatement(stmt) - } - - c.emitLoadLocal(loopVar) - c.emitLoadLocal(stepSlot) - c.current.EmitInstruction(OpAdd) - c.emitStoreLocal(loopVar) - - jumpBack := len(c.current.Chunk.Code) - conditionStart + 2 - c.current.EmitInstruction(OpLoopBack, uint16(jumpBack)) - - c.current.PatchJump(exitJump) - c.current.EmitInstruction(OpPop) - - c.current.ExitLoop() - c.current.EndScope() -} - -// Apply chunk-level optimizations -func (c *Compiler) optimizeChunk(chunk *Chunk) { - c.peepholeOptimize(chunk) - c.eliminateDeadCode(chunk) -} - -func (c *Compiler) peepholeOptimize(chunk *Chunk) { - // Simple peephole optimizations - code := chunk.Code - i := 0 - - for i < len(code)-6 { - op1, _, next1 := DecodeInstruction(code, i) - op2, _, _ := DecodeInstruction(code, next1) - - // Remove POP followed by same constant load - if op1 == OpPop && (op2 == OpLoadTrue || op2 == OpLoadFalse || op2 == OpLoadNil) { - // Could optimize in some cases - } - - i = next1 - } -} - -func (c *Compiler) eliminateDeadCode(chunk *Chunk) { - // Remove unreachable code after returns/exits - code := chunk.Code - i := 0 - - for i < len(code) { - op, _, next := DecodeInstruction(code, i) - - if op == OpReturn || op == OpReturnNil || op == OpExit { - // Mark subsequent instructions as dead until next reachable point - for j := next; j < len(code); j++ { - _, _, nextNext := DecodeInstruction(code, j) - if c.isJumpTarget(chunk, j) { - break - } - code[j] = uint8(OpNoop) - j = nextNext - 1 - } - } - - i = next - } -} - -func (c *Compiler) isJumpTarget(chunk *Chunk, offset int) bool { - // Simple check - would need more sophisticated analysis in real implementation - return false -} - -// Keep all other methods from original compiler.go unchanged -// (struct compilation, function compilation, etc.) - func (c *Compiler) compileStructStatement(node *parser.StructStatement) { fields := make([]StructField, len(node.Fields)) for i, field := range node.Fields { @@ -901,28 +790,6 @@ func (c *Compiler) compileStructConstructor(node *parser.StructConstructor) { } } -func (c *Compiler) compileTableLiteral(node *parser.TableLiteral) { - c.current.EmitInstruction(OpNewTable) - - for _, pair := range node.Pairs { - if pair.Key == nil { - c.compileExpression(pair.Value) - c.current.EmitInstruction(OpTableInsert) - } else { - c.current.EmitInstruction(OpDup) - c.compileExpression(pair.Key) - c.compileExpression(pair.Value) - c.current.EmitInstruction(OpSetIndex) - } - } -} - -func (c *Compiler) compileIndexExpression(node *parser.IndexExpression) { - c.compileExpression(node.Left) - c.compileExpression(node.Index) - c.current.EmitInstruction(OpGetIndex) -} - func (c *Compiler) compileFunctionLiteral(node *parser.FunctionLiteral) { enclosing := c.current c.current = NewCompilerState(FunctionTypeFunction) @@ -964,22 +831,132 @@ func (c *Compiler) compileFunctionLiteral(node *parser.FunctionLiteral) { c.current.EmitInstruction(OpClosure, uint16(functionIndex), uint16(function.UpvalCount)) } -func (c *Compiler) compileReturnStatement(node *parser.ReturnStatement) { - if node.Value != nil { - c.compileExpression(node.Value) - c.current.EmitInstruction(OpReturn) - } else { - c.current.EmitInstruction(OpReturnNil) +// Control flow compilation +func (c *Compiler) compileIfStatement(node *parser.IfStatement) { + c.compileExpression(node.Condition) + + thenJump := c.current.EmitJump(OpJumpIfFalse) + c.current.EmitInstruction(OpPop) + + c.current.BeginScope() + for _, stmt := range node.Body { + c.compileStatement(stmt) + } + c.current.EndScope() + + elseJump := c.current.EmitJump(OpJump) + c.current.PatchJump(thenJump) + c.current.EmitInstruction(OpPop) + + var elseifJumps []int + for _, elseif := range node.ElseIfs { + c.compileExpression(elseif.Condition) + nextJump := c.current.EmitJump(OpJumpIfFalse) + c.current.EmitInstruction(OpPop) + + c.current.BeginScope() + for _, stmt := range elseif.Body { + c.compileStatement(stmt) + } + c.current.EndScope() + + elseifJumps = append(elseifJumps, c.current.EmitJump(OpJump)) + c.current.PatchJump(nextJump) + c.current.EmitInstruction(OpPop) + } + + if len(node.Else) > 0 { + c.current.BeginScope() + for _, stmt := range node.Else { + c.compileStatement(stmt) + } + c.current.EndScope() + } + + c.current.PatchJump(elseJump) + for _, jump := range elseifJumps { + c.current.PatchJump(jump) } } -func (c *Compiler) compileExitStatement(node *parser.ExitStatement) { - if node.Value != nil { - c.compileExpression(node.Value) - } else { - c.current.EmitInstruction(OpLoadZero) +func (c *Compiler) compileWhileStatement(node *parser.WhileStatement) { + c.current.EnterLoop() + + c.compileExpression(node.Condition) + exitJump := c.current.EmitJump(OpJumpIfFalse) + c.current.EmitInstruction(OpPop) + + c.current.BeginScope() + for _, stmt := range node.Body { + c.compileStatement(stmt) } - c.current.EmitInstruction(OpExit) + c.current.EndScope() + + jump := len(c.current.Chunk.Code) - c.current.LoopStart + 2 + c.current.EmitInstruction(OpLoopBack, uint16(jump)) + + c.current.PatchJump(exitJump) + c.current.EmitInstruction(OpPop) + + c.current.ExitLoop() +} + +func (c *Compiler) compileForStatement(node *parser.ForStatement) { + c.current.BeginScope() + c.current.EnterLoop() + + c.compileExpression(node.Start) + if err := c.current.AddLocal(node.Variable.Value); err != nil { + c.addError(err.Error()) + return + } + c.current.MarkInitialized() + loopVar := len(c.current.Locals) - 1 + + c.compileExpression(node.End) + endSlot := len(c.current.Locals) + if err := c.current.AddLocal("__end"); err != nil { + c.addError(err.Error()) + return + } + c.current.MarkInitialized() + + if node.Step != nil { + c.compileExpression(node.Step) + } else { + c.current.EmitInstruction(OpLoadOne) + } + stepSlot := len(c.current.Locals) + if err := c.current.AddLocal("__step"); err != nil { + c.addError(err.Error()) + return + } + c.current.MarkInitialized() + + conditionStart := len(c.current.Chunk.Code) + c.emitLoadLocal(loopVar) + c.emitLoadLocal(endSlot) + c.current.EmitInstruction(OpLte) + exitJump := c.current.EmitJump(OpJumpIfFalse) + c.current.EmitInstruction(OpPop) + + for _, stmt := range node.Body { + c.compileStatement(stmt) + } + + c.emitLoadLocal(loopVar) + c.emitLoadLocal(stepSlot) + c.current.EmitInstruction(OpAdd) + c.emitStoreLocal(loopVar) + + jumpBack := len(c.current.Chunk.Code) - conditionStart + 2 + c.current.EmitInstruction(OpLoopBack, uint16(jumpBack)) + + c.current.PatchJump(exitJump) + c.current.EmitInstruction(OpPop) + + c.current.ExitLoop() + c.current.EndScope() } func (c *Compiler) compileForInStatement(node *parser.ForInStatement) { @@ -1024,6 +1001,114 @@ func (c *Compiler) compileForInStatement(node *parser.ForInStatement) { c.current.EndScope() } +func (c *Compiler) compileReturnStatement(node *parser.ReturnStatement) { + if node.Value != nil { + c.compileExpression(node.Value) + c.current.EmitInstruction(OpReturn) + } else { + c.current.EmitInstruction(OpReturnNil) + } +} + +func (c *Compiler) compileExitStatement(node *parser.ExitStatement) { + if node.Value != nil { + c.compileExpression(node.Value) + } else { + c.current.EmitInstruction(OpLoadZero) + } + c.current.EmitInstruction(OpExit) +} + +// Optimization +func (c *Compiler) optimizeChunk(chunk *Chunk) { + c.peepholeOptimize(chunk) + c.eliminateDeadCode(chunk) +} + +func (c *Compiler) peepholeOptimize(chunk *Chunk) { + code := chunk.Code + changed := true + + for changed { + changed = false + i := 0 + + for i < len(code)-6 { + op1, _, next1 := DecodeInstruction(code, i) + + if next1 < len(code) { + op2, _, next2 := DecodeInstruction(code, next1) + + if op1 == OpLoadConst && op2 == OpPop { + c.removeInstructions(chunk, i, next2) + changed = true + continue + } + + if op1 == OpLoadZero && op2 == OpAdd { + code[i] = uint8(OpNoop) + code[next1] = uint8(OpNoop) + changed = true + } + + if op1 == OpNeg && op2 == OpNeg { + code[i] = uint8(OpNoop) + code[next1] = uint8(OpNoop) + changed = true + } + + if (op1 == OpLoadTrue || op1 == OpLoadFalse) && op2 == OpNot { + if op1 == OpLoadTrue { + code[i] = uint8(OpLoadFalse) + } else { + code[i] = uint8(OpLoadTrue) + } + code[next1] = uint8(OpNoop) + changed = true + } + } + + i = next1 + } + } +} + +func (c *Compiler) removeInstructions(chunk *Chunk, start, end int) { + for i := start; i < end && i < len(chunk.Code); i++ { + chunk.Code[i] = uint8(OpNoop) + } +} + +func (c *Compiler) eliminateDeadCode(chunk *Chunk) { + code := chunk.Code + i := 0 + + for i < len(code) { + op, _, next := DecodeInstruction(code, i) + + if op == OpReturn || op == OpReturnNil || op == OpExit { + for j := next; j < len(code); j++ { + nextOp, _, nextNext := DecodeInstruction(code, j) + if c.isJumpTarget(chunk, j) { + break + } + if nextOp == OpNoop { + j = nextNext - 1 + continue + } + code[j] = uint8(OpNoop) + j = nextNext - 1 + } + } + + i = next + } +} + +func (c *Compiler) isJumpTarget(chunk *Chunk, offset int) bool { + return false +} + // Helper methods func (c *Compiler) resolveUpvalue(name string) int { if c.enclosing == nil { @@ -1094,9 +1179,9 @@ func (c *Compiler) addError(message string) { }) } +func (c *Compiler) getLineFromNode(node any) int { + return 0 // Placeholder - would extract from AST node +} + func (c *Compiler) Errors() []CompileError { return c.errors } func (c *Compiler) HasErrors() bool { return len(c.errors) > 0 } - -func (c *Compiler) getLineFromNode(node any) int { - return 0 // Placeholder -} diff --git a/compiler/state.go b/compiler/state.go index b0f4fb5..885e003 100644 --- a/compiler/state.go +++ b/compiler/state.go @@ -173,7 +173,7 @@ func (cs *CompilerState) AddUpvalue(index uint8, isLocal bool) int { return upvalueCount } -// Optimized constant pool management with deduplication +// Enhanced constant pool management with better deduplication func (cs *CompilerState) AddConstant(value Value) int { // Generate unique key for deduplication key := cs.valueKey(value) @@ -191,7 +191,7 @@ func (cs *CompilerState) AddConstant(value Value) int { return index } -// Generate unique key for value deduplication +// Enhanced value key generation for better deduplication func (cs *CompilerState) valueKey(value Value) string { switch value.Type { case ValueNil: @@ -202,15 +202,41 @@ func (cs *CompilerState) valueKey(value Value) string { } return "bool:false" case ValueNumber: - return fmt.Sprintf("number:%g", value.Data.(float64)) + num := value.Data.(float64) + // Handle special numeric values + if num == 0 { + return "number:0" + } else if num == 1 { + return "number:1" + } else if num == -1 { + return "number:-1" + } + return fmt.Sprintf("number:%g", num) case ValueString: - return fmt.Sprintf("string:%s", value.Data.(string)) + str := value.Data.(string) + if str == "" { + return "string:empty" + } + // For very long strings, just use a hash to avoid memory issues + if len(str) > 100 { + return fmt.Sprintf("string:hash:%d", cs.simpleHash(str)) + } + return fmt.Sprintf("string:%s", str) default: // For complex types, use memory address as fallback return fmt.Sprintf("%T:%p", value.Data, value.Data) } } +// Simple hash function for long strings +func (cs *CompilerState) simpleHash(s string) uint32 { + var hash uint32 + for _, c := range s { + hash = hash*31 + uint32(c) + } + return hash +} + // Optimized bytecode emission methods func (cs *CompilerState) EmitByte(byte uint8) { cs.Chunk.Code = append(cs.Chunk.Code, byte) @@ -240,7 +266,7 @@ func (cs *CompilerState) PatchJump(offset int) { jump := len(cs.Chunk.Code) - offset - 2 if jump > 65535 { - // Jump distance too large - would need to implement long jumps + // Jump distance too large - could implement long jumps here return } @@ -352,7 +378,7 @@ func (cs *CompilerState) EmitStoreLocal(slot int) { } } -// Instruction pattern detection for optimization +// Enhanced instruction pattern detection for optimization func (cs *CompilerState) GetLastInstruction() (Opcode, []uint16) { if len(cs.Chunk.Code) == 0 { return OpNoop, nil @@ -367,8 +393,10 @@ func (cs *CompilerState) GetLastInstruction() (Opcode, []uint16) { // This is a complete instruction operands := make([]uint16, operandCount) for j := 0; j < operandCount; j++ { - operands[j] = uint16(cs.Chunk.Code[i+1+j*2]) | - (uint16(cs.Chunk.Code[i+2+j*2]) << 8) + if i+1+j*2 < len(cs.Chunk.Code) && i+2+j*2 < len(cs.Chunk.Code) { + operands[j] = uint16(cs.Chunk.Code[i+1+j*2]) | + (uint16(cs.Chunk.Code[i+2+j*2]) << 8) + } } return op, operands } @@ -402,50 +430,6 @@ func (cs *CompilerState) ReplaceLastInstruction(op Opcode, operands ...uint16) b return true } -// Constant folding support -func (cs *CompilerState) TryConstantFolding(op Opcode, operands ...Value) *Value { - if len(operands) < 2 { - return nil - } - - left, right := operands[0], operands[1] - - // Only fold numeric operations for now - if left.Type != ValueNumber || right.Type != ValueNumber { - return nil - } - - l := left.Data.(float64) - r := right.Data.(float64) - - switch op { - case OpAdd: - return &Value{Type: ValueNumber, Data: l + r} - case OpSub: - return &Value{Type: ValueNumber, Data: l - r} - case OpMul: - return &Value{Type: ValueNumber, Data: l * r} - case OpDiv: - if r != 0 { - return &Value{Type: ValueNumber, Data: l / r} - } - case OpEq: - return &Value{Type: ValueBool, Data: l == r} - case OpNeq: - return &Value{Type: ValueBool, Data: l != r} - case OpLt: - return &Value{Type: ValueBool, Data: l < r} - case OpLte: - return &Value{Type: ValueBool, Data: l <= r} - case OpGt: - return &Value{Type: ValueBool, Data: l > r} - case OpGte: - return &Value{Type: ValueBool, Data: l >= r} - } - - return nil -} - // Dead code elimination support func (cs *CompilerState) MarkUnreachable(start, end int) { if start >= 0 && end <= len(cs.Chunk.Code) { @@ -461,12 +445,14 @@ type OptimizationStats struct { InstructionsOpt int DeadCodeEliminated int JumpsOptimized int + ConstantsDeduped int } func (cs *CompilerState) GetOptimizationStats() OptimizationStats { // Count specialized instructions used specialized := 0 noops := 0 + constantsDeduped := len(cs.Constants) - len(cs.Chunk.Constants) for i := 0; i < len(cs.Chunk.Code); { op, _, next := DecodeInstruction(cs.Chunk.Code, i) @@ -482,6 +468,7 @@ func (cs *CompilerState) GetOptimizationStats() OptimizationStats { return OptimizationStats{ InstructionsOpt: specialized, DeadCodeEliminated: noops, + ConstantsDeduped: constantsDeduped, } } @@ -489,13 +476,31 @@ func (cs *CompilerState) SetLine(line int) { cs.CurrentLine = line } -// Debugging support +// Enhanced debugging support func (cs *CompilerState) PrintChunk(name string) { fmt.Printf("== %s ==\n", name) + fmt.Printf("Constants: %d\n", len(cs.Chunk.Constants)) + fmt.Printf("Functions: %d\n", len(cs.Chunk.Functions)) + fmt.Printf("Structs: %d\n", len(cs.Chunk.Structs)) + fmt.Printf("Code size: %d bytes\n", len(cs.Chunk.Code)) + + stats := cs.GetOptimizationStats() + fmt.Printf("Optimizations: %d specialized, %d dead eliminated, %d constants deduped\n", + stats.InstructionsOpt, stats.DeadCodeEliminated, stats.ConstantsDeduped) + fmt.Println() for offset := 0; offset < len(cs.Chunk.Code); { offset = cs.disassembleInstruction(offset) } + + if len(cs.Chunk.Constants) > 0 { + fmt.Println("\nConstants:") + for i, constant := range cs.Chunk.Constants { + fmt.Printf("%4d: ", i) + cs.printValue(constant) + fmt.Println() + } + } } func (cs *CompilerState) disassembleInstruction(offset int) int { @@ -528,12 +533,14 @@ func (cs *CompilerState) disassembleInstruction(offset int) int { switch op { case OpLoadConst: return cs.constantInstruction(offset) - case OpLoadLocal, OpStoreLocal: + case OpLoadLocal, OpStoreLocal, OpAddConst, OpSubConst, OpInc, OpDec: return cs.byteInstruction(offset) case OpJump, OpJumpIfTrue, OpJumpIfFalse: return cs.jumpInstruction(offset, 1) case OpLoopBack: return cs.jumpInstruction(offset, -1) + case OpGetLocalField, OpSetLocalField, OpTestAndJump: + return cs.doubleByteInstruction(offset) default: fmt.Println() return offset + 1 @@ -570,6 +577,18 @@ func (cs *CompilerState) byteInstruction(offset int) int { return offset + 3 } +func (cs *CompilerState) doubleByteInstruction(offset int) int { + if offset+4 >= len(cs.Chunk.Code) { + fmt.Println(" [incomplete]") + return offset + 1 + } + + arg1 := uint16(cs.Chunk.Code[offset+1]) | (uint16(cs.Chunk.Code[offset+2]) << 8) + arg2 := uint16(cs.Chunk.Code[offset+3]) | (uint16(cs.Chunk.Code[offset+4]) << 8) + fmt.Printf(" %4d %4d\n", arg1, arg2) + return offset + 5 +} + func (cs *CompilerState) jumpInstruction(offset int, sign int) int { if offset+2 >= len(cs.Chunk.Code) { fmt.Println(" [incomplete]") @@ -593,9 +612,14 @@ func (cs *CompilerState) printValue(value Value) { fmt.Print("false") } case ValueNumber: - fmt.Printf("%.2g", value.Data.(float64)) + fmt.Printf("%.6g", value.Data.(float64)) case ValueString: - fmt.Printf("\"%s\"", value.Data.(string)) + str := value.Data.(string) + if len(str) > 50 { + fmt.Printf("\"%s...\"", str[:47]) + } else { + fmt.Printf("\"%s\"", str) + } default: fmt.Printf("<%s>", cs.valueTypeString(value.Type)) }