// Mako/compiler/compiler_test.go
// 2025-05-07 09:45:50 -05:00
//
// 326 lines
// 9.3 KiB
// Go

package compiler_test
import (
"testing"
assert "git.sharkk.net/Go/Assert"
"git.sharkk.net/Sharkk/Mako/compiler"
"git.sharkk.net/Sharkk/Mako/types"
)
// hasOpCode reports whether at least one instruction in chunk's bytecode
// carries the given opcode. It stops scanning at the first match.
func hasOpCode(chunk *types.Chunk, opCode types.OpCode) bool {
	for i := range chunk.Code {
		if chunk.Code[i].Op != opCode {
			continue
		}
		return true
	}
	return false
}
// countOpCodes returns how many instructions in chunk's bytecode carry the
// given opcode. The whole instruction stream is scanned.
func countOpCodes(chunk *types.Chunk, opCode types.OpCode) int {
	var n int
	for i := range chunk.Code {
		if chunk.Code[i].Op == opCode {
			n++
		}
	}
	return n
}
// hasConstant reports whether chunk's constant pool contains an entry whose
// String() form equals value's String() form.
//
// Only the printed text is compared, not the dynamic type. If two values of
// different kinds print identically, they are treated as equal here.
// NOTE(review): confirm this looseness is acceptable for these tests (e.g.
// a string constant "10" vs. a number constant 10, if both print as "10").
func hasConstant(chunk *types.Chunk, value types.Value) bool {
	for i := range chunk.Constants {
		if value.String() != chunk.Constants[i].String() {
			continue
		}
		return true
	}
	return false
}
// TestCompileLiterals checks that each literal form compiles without errors
// and emits its expected opcode. For constant-backed literals it also checks
// that the value appears in the chunk's constant pool.
func TestCompileLiterals(t *testing.T) {
	tests := []struct {
		source   string
		opCode   types.OpCode
		hasValue bool
		value    types.Value
	}{
		{"nil", types.OP_NIL, false, nil},
		{"true", types.OP_TRUE, false, nil},
		{"false", types.OP_FALSE, false, nil},
		{"123", types.OP_CONSTANT, true, types.NumberValue{Value: 123}},
		{"\"hello\"", types.OP_CONSTANT, true, types.StringValue{Value: "hello"}},
	}
	for _, test := range tests {
		// One named subtest per case, so a failure reports which literal
		// broke instead of an anonymous loop iteration.
		t.Run(test.source, func(t *testing.T) {
			function, errs := compiler.CompileSource(test.source)
			assert.Equal(t, 0, len(errs))
			assert.NotNil(t, function)
			assert.NotNil(t, function.Chunk)

			// The literal's dedicated opcode must be present.
			assert.True(t, hasOpCode(function.Chunk, test.opCode))

			// Constant-backed literals must also land in the constant pool.
			if test.hasValue {
				assert.True(t, hasConstant(function.Chunk, test.value))
			}
		})
	}
}
// TestCompileVariables compiles an assignment to a global followed by a read
// of the same name. It checks for the global store/load opcodes and for the
// name and assigned value in the constant pool.
func TestCompileVariables(t *testing.T) {
	fn, errs := compiler.CompileSource("x = 10\necho x")
	assert.Equal(t, 0, len(errs))
	assert.NotNil(t, fn)

	// Assignment stores to a global; echo loads it back.
	assert.True(t, hasOpCode(fn.Chunk, types.OP_SET_GLOBAL))
	assert.True(t, hasOpCode(fn.Chunk, types.OP_GET_GLOBAL))

	// Both the variable name and the assigned number are pooled constants.
	assert.True(t, hasConstant(fn.Chunk, types.StringValue{Value: "x"}))
	assert.True(t, hasConstant(fn.Chunk, types.NumberValue{Value: 10}))
}
// TestCompileExpressions checks that each unary/binary operator compiles to
// its expected opcode. It also checks that a mixed-precedence arithmetic
// expression emits all four arithmetic opcodes.
func TestCompileExpressions(t *testing.T) {
	tests := []struct {
		source string
		opCode types.OpCode
	}{
		{"1 + 2", types.OP_ADD},
		{"3 - 4", types.OP_SUBTRACT},
		{"5 * 6", types.OP_MULTIPLY},
		{"7 / 8", types.OP_DIVIDE},
		{"9 == 10", types.OP_EQUAL},
		// The table expects "!=" to involve OP_NOT (presumably a negated
		// equality) rather than a dedicated opcode — confirm against the compiler.
		{"11 != 12", types.OP_NOT},
		{"13 < 14", types.OP_LESS},
		{"15 > 16", types.OP_GREATER},
		{"-17", types.OP_NEGATE},
	}
	for _, test := range tests {
		// A named subtest per operator pinpoints which case failed.
		t.Run(test.source, func(t *testing.T) {
			function, errs := compiler.CompileSource(test.source)
			assert.Equal(t, 0, len(errs))
			assert.NotNil(t, function)
			assert.NotNil(t, function.Chunk)
			assert.True(t, hasOpCode(function.Chunk, test.opCode))
		})
	}

	t.Run("mixed precedence", func(t *testing.T) {
		function, errs := compiler.CompileSource("1 + 2 * 3 - 4 / 5")
		assert.Equal(t, 0, len(errs))
		assert.NotNil(t, function)

		// Every arithmetic operator in the expression must be emitted.
		assert.True(t, hasOpCode(function.Chunk, types.OP_ADD))
		assert.True(t, hasOpCode(function.Chunk, types.OP_MULTIPLY))
		assert.True(t, hasOpCode(function.Chunk, types.OP_SUBTRACT))
		assert.True(t, hasOpCode(function.Chunk, types.OP_DIVIDE))
	})
}
// TestCompileStatements covers two programs. The first is a lone echo of a
// string literal. The second assigns two globals and then echoes their sum.
func TestCompileStatements(t *testing.T) {
	// Single echo statement: emits PRINT and pools the string literal.
	fn, errs := compiler.CompileSource("echo \"hello\"")
	assert.Equal(t, 0, len(errs))
	assert.NotNil(t, fn)
	assert.True(t, hasOpCode(fn.Chunk, types.OP_PRINT))
	assert.True(t, hasConstant(fn.Chunk, types.StringValue{Value: "hello"}))

	// Several statements in one program.
	fn, errs = compiler.CompileSource("x = 1\ny = 2\necho x + y")
	assert.Equal(t, 0, len(errs))
	assert.NotNil(t, fn)
	for _, op := range []types.OpCode{
		types.OP_SET_GLOBAL,
		types.OP_GET_GLOBAL,
		types.OP_ADD,
		types.OP_PRINT,
	} {
		assert.True(t, hasOpCode(fn.Chunk, op))
	}

	// Both x and y are assigned, so at least two global stores are expected.
	assert.True(t, countOpCodes(fn.Chunk, types.OP_SET_GLOBAL) >= 2)
}
func TestCompileControlFlow(t *testing.T) {
// Test if statement
source := "if true then echo \"yes\" end"
function, errors := compiler.CompileSource(source)
assert.Equal(t, 0, len(errors))
assert.NotNil(t, function)
// Check for TRUE, JUMP_IF_FALSE, and PRINT opcodes
assert.True(t, hasOpCode(function.Chunk, types.OP_TRUE))
assert.True(t, hasOpCode(function.Chunk, types.OP_JUMP_IF_FALSE))
assert.True(t, hasOpCode(function.Chunk, types.OP_PRINT))
// Test if-else statement
source = "if false then echo \"yes\" else echo \"no\" end"
function, errors = compiler.CompileSource(source)
assert.Equal(t, 0, len(errors))
assert.NotNil(t, function)
// Check for FALSE, JUMP_IF_FALSE, JUMP, and PRINT opcodes
assert.True(t, hasOpCode(function.Chunk, types.OP_FALSE))
assert.True(t, hasOpCode(function.Chunk, types.OP_JUMP_IF_FALSE))
assert.True(t, hasOpCode(function.Chunk, types.OP_JUMP))
assert.True(t, hasOpCode(function.Chunk, types.OP_PRINT))
// Test if-elseif-else statement
source = "if false then echo \"a\" elseif true then echo \"b\" else echo \"c\" end"
function, errors = compiler.CompileSource(source)
assert.Equal(t, 0, len(errors))
assert.NotNil(t, function)
// Check for appropriate opcodes
assert.True(t, hasOpCode(function.Chunk, types.OP_FALSE))
assert.True(t, hasOpCode(function.Chunk, types.OP_TRUE))
assert.True(t, hasOpCode(function.Chunk, types.OP_JUMP_IF_FALSE))
assert.True(t, hasOpCode(function.Chunk, types.OP_JUMP))
assert.True(t, hasOpCode(function.Chunk, types.OP_PRINT))
}
func TestCompileFunctions(t *testing.T) {
// Test function declaration
source := "fn add(a, b) return a + b end"
function, errors := compiler.CompileSource(source)
assert.Equal(t, 0, len(errors))
assert.NotNil(t, function)
// Check for CLOSURE opcode
assert.True(t, hasOpCode(function.Chunk, types.OP_CLOSURE))
// Test function call
source = "fn add(a, b) return a + b end\necho add(1, 2)"
function, errors = compiler.CompileSource(source)
assert.Equal(t, 0, len(errors))
assert.NotNil(t, function)
// Check for CLOSURE, CALL, and PRINT opcodes
assert.True(t, hasOpCode(function.Chunk, types.OP_CLOSURE))
assert.True(t, hasOpCode(function.Chunk, types.OP_CALL))
assert.True(t, hasOpCode(function.Chunk, types.OP_PRINT))
// Test anonymous function expression
source = "fnVar = fn(x, y) return x * y end"
function, errors = compiler.CompileSource(source)
assert.Equal(t, 0, len(errors))
assert.NotNil(t, function)
// Check for CLOSURE and SET_GLOBAL opcodes
assert.True(t, hasOpCode(function.Chunk, types.OP_CLOSURE))
assert.True(t, hasOpCode(function.Chunk, types.OP_SET_GLOBAL))
}
// TestCompileClosures compiles a program in which makeAdder returns an
// anonymous function, then checks the top-level chunk for OP_CLOSURE.
//
// Only the top-level chunk (function.Chunk) is inspected. The chunks of the
// nested functions are not examined here.
//
// NOTE(review): x is assigned at the top level, so the inner fn(y) may read it
// as a global rather than capturing it as an upvalue. The ">= 2" count also
// assumes the inner closure's OP_CLOSURE is emitted into the top-level chunk.
// If it is emitted into makeAdder's own chunk instead, that threshold will not
// hold. Confirm both against the compiler.
func TestCompileClosures(t *testing.T) {
	source := `
x = 10
fn makeAdder()
return fn(y) return x + y end
end
adder = makeAdder()
echo adder(5)
`
	function, errs := compiler.CompileSource(source)
	assert.Equal(t, 0, len(errs))
	assert.NotNil(t, function)

	// At least one closure (makeAdder) is created at the top level.
	assert.True(t, hasOpCode(function.Chunk, types.OP_CLOSURE))

	// Intended to cover both makeAdder and the nested anonymous function; see
	// the NOTE above about where the nested closure is actually emitted.
	assert.True(t, countOpCodes(function.Chunk, types.OP_CLOSURE) >= 2)
}
// TestCompileErrors checks that a malformed program reports compile errors.
// It also checks that referencing an undefined global is not a compile-time
// error, since variable resolution happens at runtime.
func TestCompileErrors(t *testing.T) {
	t.Run("syntax error", func(t *testing.T) {
		// Missing "then" after the condition.
		_, errs := compiler.CompileSource("if true echo \"missing then\" end")
		assert.True(t, len(errs) > 0)
	})

	t.Run("undefined variable", func(t *testing.T) {
		// This must compile cleanly and emit a global lookup. The error list
		// is asserted empty rather than discarded, so an unexpected compile
		// error is reported instead of surfacing later as a nil dereference.
		function, errs := compiler.CompileSource("echo undefinedVar")
		assert.Equal(t, 0, len(errs))
		assert.NotNil(t, function)
		assert.True(t, hasOpCode(function.Chunk, types.OP_GET_GLOBAL))
	})
}
// TestCompileComplexProgram compiles a program with a recursive and an
// iterative factorial, then checks the top-level chunk for closure creation,
// calls and printing.
//
// The program uses 'while', which the original note says is not implemented
// yet. If compilation fails, the test is skipped with the reported errors
// rather than passing silently with no assertions, so the missing coverage
// stays visible in `go test -v` output.
func TestCompileComplexProgram(t *testing.T) {
	source := `
// Calculate factorial recursively
fn factorial(n)
if n <= 1 then
return 1
else
return n * factorial(n - 1)
end
end
// Calculate factorial iteratively
fn factorialIter(n)
result = 1
if n > 1 then
current = 2
while current <= n
result = result * current
current = current + 1
end
end
return result
end
// Use both functions
echo factorial(5)
echo factorialIter(5)
`
	function, errs := compiler.CompileSource(source)
	if len(errs) > 0 {
		t.Skipf("program does not compile yet (%d errors): %v", len(errs), errs)
	}

	assert.NotNil(t, function)
	assert.True(t, hasOpCode(function.Chunk, types.OP_CLOSURE))
	assert.True(t, hasOpCode(function.Chunk, types.OP_CALL))
	assert.True(t, hasOpCode(function.Chunk, types.OP_PRINT))
}