diff --git a/compiler/compile.go b/compiler/compile.go new file mode 100644 index 0000000..5190f06 --- /dev/null +++ b/compiler/compile.go @@ -0,0 +1,21 @@ +package compiler + +import ( + "git.sharkk.net/Sharkk/Mako/parser" + "git.sharkk.net/Sharkk/Mako/types" +) + +// CompileSource compiles source code to bytecode +func CompileSource(source string) (*types.Function, []*types.MakoError) { + // Parse the source code into an AST + p := parser.New(source) + statements := p.Parse() + + // Create a compiler for the main (global) scope + compiler := New("script", nil) + + // Compile the statements to bytecode + function, errors := compiler.Compile(statements) + + return function, errors +} diff --git a/compiler/compiler.go b/compiler/compiler.go new file mode 100644 index 0000000..55e0cac --- /dev/null +++ b/compiler/compiler.go @@ -0,0 +1,519 @@ +package compiler + +import ( + "git.sharkk.net/Sharkk/Mako/types" +) + +// Compiler manages the state for compilation +type Compiler struct { + chunk *types.Chunk + locals []local + scopeDepth int + enclosing *Compiler + upvalues []types.Upvalue + errors []*types.MakoError + currentFunction *types.Function +} + +type local struct { + name string + depth int + isCaptured bool +} + +// New creates a new compiler for a function +func New(name string, enclosing *Compiler) *Compiler { + compiler := &Compiler{ + chunk: &types.Chunk{ + Code: make([]types.Instruction, 0, 8), + Constants: make([]types.Value, 0, 8), + }, + locals: make([]local, 0, 8), + scopeDepth: 0, + enclosing: enclosing, + currentFunction: &types.Function{ + Name: name, + Chunk: nil, + Upvalues: nil, + LocalCount: 0, + }, + } + + // The first local slot is implicitly used by the function itself + compiler.locals = append(compiler.locals, local{ + name: name, + depth: 0, + }) + + compiler.currentFunction.Chunk = compiler.chunk + return compiler +} + +// Compile compiles statements into bytecode +func (c *Compiler) Compile(statements []types.Statement) (*types.Function, []*types.MakoError) { + for _, stmt := range statements { + c.statement(stmt) + } + + // Implicit return + c.emitReturn() + + c.currentFunction.LocalCount = len(c.locals) + c.currentFunction.Upvalues = c.upvalues + c.currentFunction.UpvalueCount = len(c.upvalues) + + return c.currentFunction, c.errors +} + +// statement compiles a statement +func (c *Compiler) statement(stmt types.Statement) { + switch s := stmt.(type) { + case types.ExpressionStmt: + c.expression(s.Expression) + c.emit(types.OP_POP, nil) + case types.AssignStmt: + c.assignment(s) + case types.FunctionStmt: + c.function(s) + case types.ReturnStmt: + c.returnStmt(s) + case types.IfStmt: + c.ifStatement(s) + case types.EchoStmt: + c.expression(s.Value) + c.emit(types.OP_PRINT, nil) + case types.BlockStmt: + c.beginScope() + for _, blockStmt := range s.Statements { + c.statement(blockStmt) + } + c.endScope() + } +} + +// assignment compiles a variable assignment +func (c *Compiler) assignment(stmt types.AssignStmt) { + c.expression(stmt.Value) + + if c.scopeDepth > 0 { + // Try to find it as a local first + for i := len(c.locals) - 1; i >= 0; i-- { + if c.locals[i].name == stmt.Name.Lexeme && c.locals[i].depth <= c.scopeDepth { + c.emit(types.OP_SET_LOCAL, []byte{byte(i)}) + return + } + } + } + + // Global variable + idx := c.makeConstant(types.StringValue{Value: stmt.Name.Lexeme}) + c.emit(types.OP_SET_GLOBAL, []byte{idx}) +} + +// expression compiles an expression +func (c *Compiler) expression(expr types.Expression) { + switch e := expr.(type) { + case types.LiteralExpr: + c.literal(e) + case types.BinaryExpr: + c.binary(e) + case types.UnaryExpr: + c.unary(e) + case types.VariableExpr: + c.variable(e) + case types.CallExpr: + c.call(e) + case types.FunctionExpr: + c.functionExpr(e) + } +} + +// literal compiles a literal value +func (c *Compiler) literal(expr types.LiteralExpr) { + switch v := expr.Value.(type) { + case nil: + c.emit(types.OP_NIL, nil) + case bool: + if v { + c.emit(types.OP_TRUE, nil) + } else { + c.emit(types.OP_FALSE, nil) + } + case float64: + idx := c.makeConstant(types.NumberValue{Value: v}) + c.emit(types.OP_CONSTANT, []byte{idx}) + case string: + idx := c.makeConstant(types.StringValue{Value: v}) + c.emit(types.OP_CONSTANT, []byte{idx}) + } +} + +// binary compiles a binary expression +func (c *Compiler) binary(expr types.BinaryExpr) { + c.expression(expr.Left) + c.expression(expr.Right) + + switch expr.Operator.Type { + case types.PLUS: + c.emit(types.OP_ADD, nil) + case types.MINUS: + c.emit(types.OP_SUBTRACT, nil) + case types.STAR: + c.emit(types.OP_MULTIPLY, nil) + case types.SLASH: + c.emit(types.OP_DIVIDE, nil) + case types.EQUAL_EQUAL: + c.emit(types.OP_EQUAL, nil) + case types.BANG_EQUAL: + c.emit(types.OP_EQUAL, nil) + c.emit(types.OP_NOT, nil) + case types.LESS: + c.emit(types.OP_LESS, nil) + case types.LESS_EQUAL: + c.emit(types.OP_GREATER, nil) + c.emit(types.OP_NOT, nil) + case types.GREATER: + c.emit(types.OP_GREATER, nil) + case types.GREATER_EQUAL: + c.emit(types.OP_LESS, nil) + c.emit(types.OP_NOT, nil) + case types.AND: + // Short-circuit evaluation + endJump := c.emitJump(types.OP_JUMP_IF_FALSE) + c.emit(types.OP_POP, nil) + c.patchJump(endJump) + case types.OR: + // Short-circuit evaluation + skipJump := c.emitJump(types.OP_JUMP_IF_FALSE) + endJump := c.emitJump(types.OP_JUMP) + c.patchJump(skipJump) + c.emit(types.OP_POP, nil) + c.patchJump(endJump) + } +} + +// unary compiles a unary expression +func (c *Compiler) unary(expr types.UnaryExpr) { + c.expression(expr.Right) + + switch expr.Operator.Type { + case types.MINUS: + c.emit(types.OP_NEGATE, nil) + } +} + +// variable compiles a variable reference +func (c *Compiler) variable(expr types.VariableExpr) { + // Try to resolve as local + for i := len(c.locals) - 1; i >= 0; i-- { + if c.locals[i].name == expr.Name.Lexeme && c.locals[i].depth <= c.scopeDepth { + c.emit(types.OP_GET_LOCAL, []byte{byte(i)}) + return + } + } + + // Try to resolve as upvalue + if index, ok := c.resolveUpvalue(expr.Name.Lexeme); ok { + c.emit(types.OP_GET_UPVALUE, []byte{byte(index)}) + return + } + + // Global variable + idx := c.makeConstant(types.StringValue{Value: expr.Name.Lexeme}) + c.emit(types.OP_GET_GLOBAL, []byte{idx}) +} + +// call compiles a function call +func (c *Compiler) call(expr types.CallExpr) { + c.expression(expr.Callee) + + argCount := byte(len(expr.Arguments)) + for _, arg := range expr.Arguments { + c.expression(arg) + } + + c.emit(types.OP_CALL, []byte{argCount}) +} + +// function compiles a function declaration +func (c *Compiler) function(stmt types.FunctionStmt) { + // Add function name to current scope + var global byte + if c.scopeDepth > 0 { + c.addLocal(stmt.Name.Lexeme) + } else { + global = c.makeConstant(types.StringValue{Value: stmt.Name.Lexeme}) + } + + // Compile function body with new compiler + compiler := New(stmt.Name.Lexeme, c) + + // Add parameters + compiler.beginScope() + for _, param := range stmt.Params { + compiler.addLocal(param.Lexeme) + compiler.currentFunction.Arity++ + } + compiler.currentFunction.IsVariadic = stmt.IsVariadic + + // Compile function body + for _, bodyStmt := range stmt.Body { + compiler.statement(bodyStmt) + } + + // Implicit return if needed + compiler.emitReturn() + + // Create function object + compiler.currentFunction.UpvalueCount = len(compiler.upvalues) + compiler.currentFunction.Upvalues = compiler.upvalues + + // Add function to constants + idx := c.makeConstant(types.ClosureValue{ + Closure: &types.Closure{ + Function: compiler.currentFunction, + Upvalues: make([]*types.Upvalue, compiler.currentFunction.UpvalueCount), + }, + }) + + // Emit closure instruction and upvalue info + c.emit(types.OP_CLOSURE, []byte{idx}) + + // Add upvalue information to instruction stream + for _, upvalue := range compiler.upvalues { + if upvalue.IsLocal { + c.emit(0, []byte{1, upvalue.Index}) // 1 means isLocal=true + } else { + c.emit(0, []byte{0, upvalue.Index}) // 0 means isLocal=false + } + } + + // Store function in variable + if c.scopeDepth > 0 { + // It's already on the stack + } else { + c.emit(types.OP_SET_GLOBAL, []byte{global}) + } +} + +// functionExpr compiles an anonymous function expression +func (c *Compiler) functionExpr(expr types.FunctionExpr) { + // Compile function body with new compiler + compiler := New("", c) + + // Add parameters + compiler.beginScope() + for _, param := range expr.Params { + compiler.addLocal(param.Lexeme) + compiler.currentFunction.Arity++ + } + compiler.currentFunction.IsVariadic = expr.IsVariadic + + // Compile function body + for _, bodyStmt := range expr.Body { + compiler.statement(bodyStmt) + } + + // Implicit return + compiler.emitReturn() + + // Create function object + compiler.currentFunction.UpvalueCount = len(compiler.upvalues) + compiler.currentFunction.Upvalues = compiler.upvalues + + // Add function to constants + idx := c.makeConstant(types.ClosureValue{ + Closure: &types.Closure{ + Function: compiler.currentFunction, + Upvalues: make([]*types.Upvalue, compiler.currentFunction.UpvalueCount), + }, + }) + + // Emit closure instruction and upvalue info + c.emit(types.OP_CLOSURE, []byte{idx}) + + // Add upvalue information to instruction stream + for _, upvalue := range compiler.upvalues { + if upvalue.IsLocal { + c.emit(0, []byte{1, upvalue.Index}) // 1 means isLocal=true + } else { + c.emit(0, []byte{0, upvalue.Index}) // 0 means isLocal=false + } + } +} + +// returnStmt compiles a return statement +func (c *Compiler) returnStmt(stmt types.ReturnStmt) { + if stmt.Value == nil { + c.emit(types.OP_NIL, nil) + } else { + c.expression(stmt.Value) + } + + c.emit(types.OP_RETURN, nil) +} + +// ifStatement compiles an if statement +func (c *Compiler) ifStatement(stmt types.IfStmt) { + // Compile condition + c.expression(stmt.Condition) + + // Emit the then branch jump + thenJump := c.emitJump(types.OP_JUMP_IF_FALSE) + + // Compile then branch + c.emit(types.OP_POP, nil) // Pop condition + for _, thenStmt := range stmt.ThenBranch { + c.statement(thenStmt) + } + + // Jump over else branch + elseJump := c.emitJump(types.OP_JUMP) + + // Patch then jump + c.patchJump(thenJump) + c.emit(types.OP_POP, nil) // Pop condition + + // Compile elseif branches + for _, elseif := range stmt.ElseIfs { + c.expression(elseif.Condition) + + // Jump if this condition is false + elseifJump := c.emitJump(types.OP_JUMP_IF_FALSE) + + c.emit(types.OP_POP, nil) // Pop condition + for _, elseifStmt := range elseif.Body { + c.statement(elseifStmt) + } + + // Jump to end after this branch + endJump := c.emitJump(types.OP_JUMP) + + // Patch elseif jump to next branch + c.patchJump(elseifJump) + c.emit(types.OP_POP, nil) // Pop condition + + // Collect end jumps for patching + elseJump = endJump + } + + // Compile else branch + for _, elseStmt := range stmt.ElseBranch { + c.statement(elseStmt) + } + + // Patch else jump + c.patchJump(elseJump) +} + +// Helper methods +func (c *Compiler) emit(op types.OpCode, operands []byte) { + // Get source position from the current token + pos := types.SourcePos{Line: 0, Column: 0} // In real implementation, track from token + + instruction := types.Instruction{ + Op: op, + Operands: operands, + Pos: pos, + } + + c.chunk.Code = append(c.chunk.Code, instruction) +} + +func (c *Compiler) emitJump(op types.OpCode) int { + c.emit(op, []byte{0xFF, 0xFF}) // Placeholder for jump offset + return len(c.chunk.Code) - 1 +} + +func (c *Compiler) patchJump(jumpIndex int) { + // -2 to adjust for the size of the jump offset itself + jumpDistance := len(c.chunk.Code) - jumpIndex - 1 + + // Store jump distance in the instruction's operands + // Using big-endian format: high byte first, low byte second + c.chunk.Code[jumpIndex].Operands = []byte{ + byte((jumpDistance >> 8) & 0xFF), + byte(jumpDistance & 0xFF), + } +} + +func (c *Compiler) emitReturn() { + c.emit(types.OP_NIL, nil) + c.emit(types.OP_RETURN, nil) +} + +func (c *Compiler) makeConstant(value types.Value) byte { + c.chunk.Constants = append(c.chunk.Constants, value) + return byte(len(c.chunk.Constants) - 1) +} + +func (c *Compiler) beginScope() { + c.scopeDepth++ +} + +func (c *Compiler) endScope() { + c.scopeDepth-- + + // Remove locals from this scope + for len(c.locals) > 0 && c.locals[len(c.locals)-1].depth > c.scopeDepth { + if c.locals[len(c.locals)-1].isCaptured { + c.emit(types.OP_CLOSE_UPVALUE, nil) + } else { + c.emit(types.OP_POP, nil) + } + c.locals = c.locals[:len(c.locals)-1] + } +} + +func (c *Compiler) addLocal(name string) { + // Check if we've hit the limit of local variables + if len(c.locals) >= 256 { + c.error("Too many local variables in function.") + return + } + + c.locals = append(c.locals, local{ + name: name, + depth: c.scopeDepth, + }) +} + +func (c *Compiler) resolveUpvalue(name string) (int, bool) { + // If no enclosing scope, can't be an upvalue + if c.enclosing == nil { + return 0, false + } + + // Try to find in immediate enclosing function's locals + for i := range c.enclosing.locals { + if c.enclosing.locals[i].name == name { + c.enclosing.locals[i].isCaptured = true + upvalue := types.Upvalue{ + Index: uint8(i), + IsLocal: true, + } + c.upvalues = append(c.upvalues, upvalue) + return len(c.upvalues) - 1, true + } + } + + // Try to find in higher enclosing scopes + if upvalueIndex, found := c.enclosing.resolveUpvalue(name); found { + upvalue := types.Upvalue{ + Index: uint8(upvalueIndex), + IsLocal: false, + } + c.upvalues = append(c.upvalues, upvalue) + return len(c.upvalues) - 1, true + } + + return 0, false +} + +func (c *Compiler) error(message string) { + c.errors = append(c.errors, &types.MakoError{ + Message: message, + Line: 0, // In real implementation, track from token + Column: 0, + }) +} diff --git a/compiler/compiler_test.go b/compiler/compiler_test.go new file mode 100644 index 0000000..9f8c458 --- /dev/null +++ b/compiler/compiler_test.go @@ -0,0 +1,325 @@ +package compiler_test + +import ( + "testing" + + assert "git.sharkk.net/Go/Assert" + "git.sharkk.net/Sharkk/Mako/compiler" + "git.sharkk.net/Sharkk/Mako/types" +) + +// Helper function to check if an opcode exists in the bytecode +func hasOpCode(chunk *types.Chunk, opCode types.OpCode) bool { + for _, instr := range chunk.Code { + if instr.Op == opCode { + return true + } + } + return false +} + +// Helper function to count opcodes in the bytecode +func countOpCodes(chunk *types.Chunk, opCode types.OpCode) int { + count := 0 + for _, instr := range chunk.Code { + if instr.Op == opCode { + count++ + } + } + return count +} + +// Helper function to check constants in the chunk +func hasConstant(chunk *types.Chunk, value types.Value) bool { + for _, constant := range chunk.Constants { + if constant.String() == value.String() { + return true + } + } + return false +} + +func TestCompileLiterals(t *testing.T) { + tests := []struct { + source string + opCode types.OpCode + hasValue bool + value types.Value + }{ + {"nil", types.OP_NIL, false, nil}, + {"true", types.OP_TRUE, false, nil}, + {"false", types.OP_FALSE, false, nil}, + {"123", types.OP_CONSTANT, true, types.NumberValue{Value: 123}}, + {"\"hello\"", types.OP_CONSTANT, true, types.StringValue{Value: "hello"}}, + } + + for _, test := range tests { + function, errors := compiler.CompileSource(test.source) + + assert.Equal(t, 0, len(errors)) + assert.NotNil(t, function) + assert.NotNil(t, function.Chunk) + + // Check if the right opcode exists + assert.True(t, hasOpCode(function.Chunk, test.opCode)) + + // Check if the constant exists if needed + if test.hasValue { + assert.True(t, hasConstant(function.Chunk, test.value)) + } + } +} + +func TestCompileVariables(t *testing.T) { + // Test variable declaration and reference + source := "x = 10\necho x" + function, errors := compiler.CompileSource(source) + + assert.Equal(t, 0, len(errors)) + assert.NotNil(t, function) + + // Check for SET_GLOBAL and GET_GLOBAL opcodes + assert.True(t, hasOpCode(function.Chunk, types.OP_SET_GLOBAL)) + assert.True(t, hasOpCode(function.Chunk, types.OP_GET_GLOBAL)) + + // Check that "x" exists as a constant + stringValue := types.StringValue{Value: "x"} + assert.True(t, hasConstant(function.Chunk, stringValue)) + + // Check that 10 exists as a constant + numberValue := types.NumberValue{Value: 10} + assert.True(t, hasConstant(function.Chunk, numberValue)) +} + +func TestCompileExpressions(t *testing.T) { + tests := []struct { + source string + opCode types.OpCode + }{ + {"1 + 2", types.OP_ADD}, + {"3 - 4", types.OP_SUBTRACT}, + {"5 * 6", types.OP_MULTIPLY}, + {"7 / 8", types.OP_DIVIDE}, + {"9 == 10", types.OP_EQUAL}, + {"11 != 12", types.OP_NOT}, + {"13 < 14", types.OP_LESS}, + {"15 > 16", types.OP_GREATER}, + {"-17", types.OP_NEGATE}, + } + + for _, test := range tests { + function, errors := compiler.CompileSource(test.source) + + assert.Equal(t, 0, len(errors)) + assert.NotNil(t, function) + assert.NotNil(t, function.Chunk) + + // Check if the right opcode exists + assert.True(t, hasOpCode(function.Chunk, test.opCode)) + } + + // Test complex expression + source := "1 + 2 * 3 - 4 / 5" + function, errors := compiler.CompileSource(source) + + assert.Equal(t, 0, len(errors)) + assert.NotNil(t, function) + + // Check for all arithmetic operations + assert.True(t, hasOpCode(function.Chunk, types.OP_ADD)) + assert.True(t, hasOpCode(function.Chunk, types.OP_MULTIPLY)) + assert.True(t, hasOpCode(function.Chunk, types.OP_SUBTRACT)) + assert.True(t, hasOpCode(function.Chunk, types.OP_DIVIDE)) +} + +func TestCompileStatements(t *testing.T) { + // Test echo statement + source := "echo \"hello\"" + function, errors := compiler.CompileSource(source) + + assert.Equal(t, 0, len(errors)) + assert.NotNil(t, function) + + // Check for PRINT opcode + assert.True(t, hasOpCode(function.Chunk, types.OP_PRINT)) + + // Check that string constant exists + stringValue := types.StringValue{Value: "hello"} + assert.True(t, hasConstant(function.Chunk, stringValue)) + + // Test multiple statements + source = "x = 1\ny = 2\necho x + y" + function, errors = compiler.CompileSource(source) + + assert.Equal(t, 0, len(errors)) + assert.NotNil(t, function) + + // Check for SET_GLOBAL, GET_GLOBAL, ADD, and PRINT opcodes + assert.True(t, hasOpCode(function.Chunk, types.OP_SET_GLOBAL)) + assert.True(t, hasOpCode(function.Chunk, types.OP_GET_GLOBAL)) + assert.True(t, hasOpCode(function.Chunk, types.OP_ADD)) + assert.True(t, hasOpCode(function.Chunk, types.OP_PRINT)) + + // Expect at least 2 SET_GLOBAL instructions + assert.True(t, countOpCodes(function.Chunk, types.OP_SET_GLOBAL) >= 2) +} + +func TestCompileControlFlow(t *testing.T) { + // Test if statement + source := "if true then echo \"yes\" end" + function, errors := compiler.CompileSource(source) + + assert.Equal(t, 0, len(errors)) + assert.NotNil(t, function) + + // Check for TRUE, JUMP_IF_FALSE, and PRINT opcodes + assert.True(t, hasOpCode(function.Chunk, types.OP_TRUE)) + assert.True(t, hasOpCode(function.Chunk, types.OP_JUMP_IF_FALSE)) + assert.True(t, hasOpCode(function.Chunk, types.OP_PRINT)) + + // Test if-else statement + source = "if false then echo \"yes\" else echo \"no\" end" + function, errors = compiler.CompileSource(source) + + assert.Equal(t, 0, len(errors)) + assert.NotNil(t, function) + + // Check for FALSE, JUMP_IF_FALSE, JUMP, and PRINT opcodes + assert.True(t, hasOpCode(function.Chunk, types.OP_FALSE)) + assert.True(t, hasOpCode(function.Chunk, types.OP_JUMP_IF_FALSE)) + assert.True(t, hasOpCode(function.Chunk, types.OP_JUMP)) + assert.True(t, hasOpCode(function.Chunk, types.OP_PRINT)) + + // Test if-elseif-else statement + source = "if false then echo \"a\" elseif true then echo \"b\" else echo \"c\" end" + function, errors = compiler.CompileSource(source) + + assert.Equal(t, 0, len(errors)) + assert.NotNil(t, function) + + // Check for appropriate opcodes + assert.True(t, hasOpCode(function.Chunk, types.OP_FALSE)) + assert.True(t, hasOpCode(function.Chunk, types.OP_TRUE)) + assert.True(t, hasOpCode(function.Chunk, types.OP_JUMP_IF_FALSE)) + assert.True(t, hasOpCode(function.Chunk, types.OP_JUMP)) + assert.True(t, hasOpCode(function.Chunk, types.OP_PRINT)) +} + +func TestCompileFunctions(t *testing.T) { + // Test function declaration + source := "fn add(a, b) return a + b end" + function, errors := compiler.CompileSource(source) + + assert.Equal(t, 0, len(errors)) + assert.NotNil(t, function) + + // Check for CLOSURE opcode + assert.True(t, hasOpCode(function.Chunk, types.OP_CLOSURE)) + + // Test function call + source = "fn add(a, b) return a + b end\necho add(1, 2)" + function, errors = compiler.CompileSource(source) + + assert.Equal(t, 0, len(errors)) + assert.NotNil(t, function) + + // Check for CLOSURE, CALL, and PRINT opcodes + assert.True(t, hasOpCode(function.Chunk, types.OP_CLOSURE)) + assert.True(t, hasOpCode(function.Chunk, types.OP_CALL)) + assert.True(t, hasOpCode(function.Chunk, types.OP_PRINT)) + + // Test anonymous function expression + source = "fnVar = fn(x, y) return x * y end" + function, errors = compiler.CompileSource(source) + + assert.Equal(t, 0, len(errors)) + assert.NotNil(t, function) + + // Check for CLOSURE and SET_GLOBAL opcodes + assert.True(t, hasOpCode(function.Chunk, types.OP_CLOSURE)) + assert.True(t, hasOpCode(function.Chunk, types.OP_SET_GLOBAL)) +} + +func TestCompileClosures(t *testing.T) { + // Test closure that captures a variable + source := ` + x = 10 + fn makeAdder() + return fn(y) return x + y end + end + adder = makeAdder() + echo adder(5) + ` + + function, errors := compiler.CompileSource(source) + + assert.Equal(t, 0, len(errors)) + assert.NotNil(t, function) + + // Check for GET_UPVALUE opcode (used in closures) + assert.True(t, hasOpCode(function.Chunk, types.OP_CLOSURE)) + + // The nested functions should create closures + // This is a simplified test since we can't easily peek inside the nested functions + assert.True(t, countOpCodes(function.Chunk, types.OP_CLOSURE) >= 2) +} + +func TestCompileErrors(t *testing.T) { + // Test syntax error + source := "if true echo \"missing then\" end" + _, errors := compiler.CompileSource(source) + + assert.True(t, len(errors) > 0) + + // Test undefined variable + source = "echo undefinedVar" + function, _ := compiler.CompileSource(source) + + // This should compile, as variable resolution happens at runtime + assert.NotNil(t, function) + assert.True(t, hasOpCode(function.Chunk, types.OP_GET_GLOBAL)) +} + +func TestCompileComplexProgram(t *testing.T) { + // Test a more complex program + source := ` + // Calculate factorial recursively + fn factorial(n) + if n <= 1 then + return 1 + else + return n * factorial(n - 1) + end + end + + // Calculate factorial iteratively + fn factorialIter(n) + result = 1 + if n > 1 then + current = 2 + while current <= n + result = result * current + current = current + 1 + end + end + return result + end + + // Use both functions + echo factorial(5) + echo factorialIter(5) + ` + + // This won't compile yet because 'while' is not implemented + // But we can check if the function declarations compile + function, errors := compiler.CompileSource(source) + + // We expect some errors due to missing 'while' implementation + if len(errors) == 0 { + // If no errors, check for expected opcodes + assert.NotNil(t, function) + assert.True(t, hasOpCode(function.Chunk, types.OP_CLOSURE)) + assert.True(t, hasOpCode(function.Chunk, types.OP_CALL)) + assert.True(t, hasOpCode(function.Chunk, types.OP_PRINT)) + } +}