Mako/parser/tests/functions_test.go

706 lines
18 KiB
Go

package parser_test
import (
"strings"
"testing"
"git.sharkk.net/Sharkk/Mako/parser"
)
// TestBasicFunctionLiterals parses several function literals as expressions
// and checks the parameter count, the variadic flag, and the number of body
// statements on the resulting *parser.FunctionLiteral.
func TestBasicFunctionLiterals(t *testing.T) {
	cases := []struct {
		name       string
		src        string
		wantParams int
		wantVararg bool
		wantStmts  int
	}{
		{name: "empty function", src: "fn() end", wantParams: 0, wantVararg: false, wantStmts: 0},
		{name: "single parameter", src: "fn(a) echo a end", wantParams: 1, wantVararg: false, wantStmts: 1},
		{name: "two parameters", src: "fn(a, b) return a + b end", wantParams: 2, wantVararg: false, wantStmts: 1},
		{name: "variadic only", src: "fn(...) echo \"variadic\" end", wantParams: 0, wantVararg: true, wantStmts: 1},
		{name: "mixed params and variadic", src: "fn(a, b, ...) return a end", wantParams: 2, wantVararg: true, wantStmts: 1},
		{name: "multiple statements", src: "fn(x, y) x = x + 1 y = y + 1 end", wantParams: 2, wantVararg: false, wantStmts: 2},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			psr := parser.NewParser(parser.NewLexer(tc.src))
			node := psr.ParseExpression(parser.LOWEST)
			checkParserErrors(t, psr)

			lit, isFn := node.(*parser.FunctionLiteral)
			if !isFn {
				t.Fatalf("expected FunctionLiteral, got %T", node)
			}

			// The three properties are independent, so each mismatch is
			// reported without aborting the subtest.
			if got := len(lit.Parameters); got != tc.wantParams {
				t.Errorf("expected %d parameters, got %d", tc.wantParams, got)
			}
			if got := lit.Variadic; got != tc.wantVararg {
				t.Errorf("expected variadic = %t, got %t", tc.wantVararg, got)
			}
			if got := len(lit.Body); got != tc.wantStmts {
				t.Errorf("expected %d body statements, got %d", tc.wantStmts, got)
			}
		})
	}
}
// TestFunctionParameters checks that the parameter names recorded on a parsed
// function literal match the source, in order, and that the variadic flag is
// set exactly when the parameter list contains "...".
func TestFunctionParameters(t *testing.T) {
	cases := []struct {
		src        string
		wantNames  []string
		wantVararg bool
		name       string
	}{
		{"fn(a) end", []string{"a"}, false, "single param"},
		{"fn(a, b, c) end", []string{"a", "b", "c"}, false, "multiple params"},
		{"fn(...) end", []string{}, true, "only variadic"},
		{"fn(x, ...) end", []string{"x"}, true, "param with variadic"},
		{"fn(a, b, ...) end", []string{"a", "b"}, true, "multiple params with variadic"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			psr := parser.NewParser(parser.NewLexer(tc.src))
			node := psr.ParseExpression(parser.LOWEST)
			checkParserErrors(t, psr)

			lit, isFn := node.(*parser.FunctionLiteral)
			if !isFn {
				t.Fatalf("expected FunctionLiteral, got %T", node)
			}

			// Stop on a count mismatch: the name comparison below indexes
			// both slices with the same position.
			if want, got := len(tc.wantNames), len(lit.Parameters); got != want {
				t.Fatalf("expected %d parameters, got %d", want, got)
			}
			for idx := range tc.wantNames {
				if got := lit.Parameters[idx].Name; got != tc.wantNames[idx] {
					t.Errorf("parameter %d: expected %s, got %s", idx, tc.wantNames[idx], got)
				}
			}

			if lit.Variadic != tc.wantVararg {
				t.Errorf("expected variadic = %t, got %t", tc.wantVararg, lit.Variadic)
			}
		})
	}
}
// TestReturnStatements parses single-statement programs consisting of a
// return and checks that the resulting *parser.ReturnStatement carries a
// value exactly when the source supplies one.
func TestReturnStatements(t *testing.T) {
	cases := []struct {
		src       string
		wantValue bool
		name      string
	}{
		{"return", false, "return without value"},
		{"return 42", true, "return with number"},
		{"return \"hello\"", true, "return with string"},
		{"return x", true, "return with identifier"},
		{"return x + y", true, "return with expression"},
		{"return table.key", true, "return with member access"},
		{"return arr[0]", true, "return with index access"},
		{"return {a = 1, b = 2}", true, "return with table"},
		{"return fn(x) return x end", true, "return with function"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			psr := parser.NewParser(parser.NewLexer(tc.src))
			prog := psr.ParseProgram()
			checkParserErrors(t, psr)

			if count := len(prog.Statements); count != 1 {
				t.Fatalf("expected 1 statement, got %d", count)
			}

			ret, isReturn := prog.Statements[0].(*parser.ReturnStatement)
			if !isReturn {
				t.Fatalf("expected ReturnStatement, got %T", prog.Statements[0])
			}

			// The two failure modes are mutually exclusive.
			switch {
			case tc.wantValue && ret.Value == nil:
				t.Error("expected return value but got nil")
			case !tc.wantValue && ret.Value != nil:
				t.Error("expected no return value but got one")
			}
		})
	}
}
// TestFunctionAssignments verifies that assigning a function literal to a
// variable, a member, or an index target parses as a single
// *parser.Assignment whose value is a *parser.FunctionLiteral.
func TestFunctionAssignments(t *testing.T) {
	cases := []struct {
		src  string
		name string
	}{
		{"add = fn(a, b) return a + b end", "assign function to variable"},
		{"math.add = fn(x, y) return x + y end", "assign function to member"},
		{"funcs[\"add\"] = fn(a, b) return a + b end", "assign function to index"},
		{"callback = fn() echo \"called\" end", "assign simple function"},
		{"sum = fn(...) return 0 end", "assign variadic function"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			psr := parser.NewParser(parser.NewLexer(tc.src))
			prog := psr.ParseProgram()
			checkParserErrors(t, psr)

			if count := len(prog.Statements); count != 1 {
				t.Fatalf("expected 1 statement, got %d", count)
			}

			assign, isAssign := prog.Statements[0].(*parser.Assignment)
			if !isAssign {
				t.Fatalf("expected AssignStatement, got %T", prog.Statements[0])
			}

			if _, isFn := assign.Value.(*parser.FunctionLiteral); !isFn {
				t.Fatalf("expected FunctionLiteral value, got %T", assign.Value)
			}
		})
	}
}
// TestFunctionsInTables parses table literals that embed function literals
// and checks that at least one function is found, either as a direct pair
// value or as a pair value of a table nested one level down.
func TestFunctionsInTables(t *testing.T) {
	cases := []struct {
		src  string
		name string
	}{
		{"{add = fn(a, b) return a + b end}", "function in hash table"},
		{"{fn(x) return x end, fn(y) return y end}", "functions in array"},
		{"{math = {add = fn(a, b) return a + b end}}", "nested function in table"},
		{"{callback = fn() echo \"hi\" end, data = 42}", "mixed table with function"},
		{"{fn(...) return 0 end}", "variadic function in array"},
	}

	// hasDirectFunction reports whether any pair of tbl holds a function
	// literal as its value; it does not descend into nested tables.
	hasDirectFunction := func(tbl *parser.TableLiteral) bool {
		for _, entry := range tbl.Pairs {
			if _, isFn := entry.Value.(*parser.FunctionLiteral); isFn {
				return true
			}
		}
		return false
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			psr := parser.NewParser(parser.NewLexer(tc.src))
			node := psr.ParseExpression(parser.LOWEST)
			checkParserErrors(t, psr)

			tbl, isTable := node.(*parser.TableLiteral)
			if !isTable {
				t.Fatalf("expected TableLiteral, got %T", node)
			}

			// A function counts if it is a direct value or sits inside a
			// table that is itself a direct value.
			located := false
			for _, entry := range tbl.Pairs {
				switch val := entry.Value.(type) {
				case *parser.FunctionLiteral:
					located = true
				case *parser.TableLiteral:
					located = hasDirectFunction(val)
				}
				if located {
					break
				}
			}

			if !located {
				t.Error("expected to find at least one function in table")
			}
		})
	}
}
// TestFunctionWithComplexBody parses a function literal whose body holds an
// if/else statement followed by an echo, then verifies the shape of the
// resulting tree: two parameters, two body statements, a return as the first
// statement of each branch, and a trailing echo.
//
// Every slice is length-checked before it is indexed, so a parser regression
// surfaces as an ordinary test failure instead of an index-out-of-range panic.
func TestFunctionWithComplexBody(t *testing.T) {
	input := `fn(x, y)
if x > y then
return x
else
return y
end
echo "unreachable"
end`

	l := parser.NewLexer(input)
	p := parser.NewParser(l)
	expr := p.ParseExpression(parser.LOWEST)
	checkParserErrors(t, p)

	fn, ok := expr.(*parser.FunctionLiteral)
	if !ok {
		t.Fatalf("expected FunctionLiteral, got %T", expr)
	}

	if len(fn.Parameters) != 2 {
		t.Errorf("expected 2 parameters, got %d", len(fn.Parameters))
	}

	// Fatal rather than Error: fn.Body[0] and fn.Body[1] are indexed below.
	if len(fn.Body) != 2 {
		t.Fatalf("expected 2 body statements, got %d", len(fn.Body))
	}

	// First statement should be if.
	ifStmt, ok := fn.Body[0].(*parser.IfStatement)
	if !ok {
		t.Fatalf("expected IfStatement, got %T", fn.Body[0])
	}

	// Check that the if body starts with a return.
	if len(ifStmt.Body) == 0 {
		t.Fatal("expected at least one statement in if body, got none")
	}
	if _, ok := ifStmt.Body[0].(*parser.ReturnStatement); !ok {
		t.Fatalf("expected ReturnStatement in if body, got %T", ifStmt.Body[0])
	}

	// Check that the else body starts with a return.
	if len(ifStmt.Else) == 0 {
		t.Fatal("expected at least one statement in else body, got none")
	}
	if _, ok := ifStmt.Else[0].(*parser.ReturnStatement); !ok {
		t.Fatalf("expected ReturnStatement in else body, got %T", ifStmt.Else[0])
	}

	// Second statement should be echo.
	if _, ok := fn.Body[1].(*parser.EchoStatement); !ok {
		t.Fatalf("expected EchoStatement, got %T", fn.Body[1])
	}
}
// TestNestedFunctions parses a function literal that assigns an inner
// function literal and then returns a call to it. It checks that the outer
// body has two statements: an assignment whose value is a one-parameter
// function literal, followed by a return.
func TestNestedFunctions(t *testing.T) {
	const src = `fn(x)
inner = fn(y) return y * 2 end
return inner(x)
end`

	psr := parser.NewParser(parser.NewLexer(src))
	node := psr.ParseExpression(parser.LOWEST)
	checkParserErrors(t, psr)

	outer, isFn := node.(*parser.FunctionLiteral)
	if !isFn {
		t.Fatalf("expected FunctionLiteral, got %T", node)
	}

	body := outer.Body
	if len(body) != 2 {
		t.Fatalf("expected 2 body statements, got %d", len(body))
	}

	// Statement one: the inner function is bound to a name.
	binding, isAssign := body[0].(*parser.Assignment)
	if !isAssign {
		t.Fatalf("expected AssignStatement, got %T", body[0])
	}

	inner, isInnerFn := binding.Value.(*parser.FunctionLiteral)
	if !isInnerFn {
		t.Fatalf("expected FunctionLiteral value, got %T", binding.Value)
	}

	if count := len(inner.Parameters); count != 1 {
		t.Errorf("expected 1 parameter in inner function, got %d", count)
	}

	// Statement two: the return.
	if _, isReturn := body[1].(*parser.ReturnStatement); !isReturn {
		t.Fatalf("expected ReturnStatement, got %T", body[1])
	}
}
// TestFunctionInLoop parses a numeric for loop whose body assigns a function
// literal and then echoes a call to it. It checks that the program is a
// single *parser.ForStatement with two body statements, the first being an
// assignment of a function literal.
func TestFunctionInLoop(t *testing.T) {
	const src = `for i = 1, 10 do
callback = fn(x) return x + i end
echo callback(i)
end`

	psr := parser.NewParser(parser.NewLexer(src))
	prog := psr.ParseProgram()
	checkParserErrors(t, psr)

	if count := len(prog.Statements); count != 1 {
		t.Fatalf("expected 1 statement, got %d", count)
	}

	loop, isFor := prog.Statements[0].(*parser.ForStatement)
	if !isFor {
		t.Fatalf("expected ForStatement, got %T", prog.Statements[0])
	}

	if count := len(loop.Body); count != 2 {
		t.Fatalf("expected 2 body statements, got %d", count)
	}

	// Only the first body statement (the function assignment) is inspected.
	binding, isAssign := loop.Body[0].(*parser.Assignment)
	if !isAssign {
		t.Fatalf("expected AssignStatement, got %T", loop.Body[0])
	}

	if _, isFn := binding.Value.(*parser.FunctionLiteral); !isFn {
		t.Fatalf("expected FunctionLiteral value, got %T", binding.Value)
	}
}
// TestFunctionErrors feeds malformed function literals to ParseExpression and
// checks that the parser reports at least one error whose message contains
// the expected fragment.
func TestFunctionErrors(t *testing.T) {
	cases := []struct {
		src      string
		fragment string
		name     string
	}{
		{"fn", "expected '(' after 'fn'", "missing parentheses"},
		{"fn(", "expected ')' after function parameters", "missing closing paren"},
		{"fn() echo x", "expected 'end' to close function", "missing end"},
		{"fn(123) end", "expected parameter name", "invalid parameter name"},
		{"fn(a,) end", "expected parameter name", "trailing comma"},
		{"fn(a, 123) end", "expected parameter name", "invalid second parameter"},
		{"fn(..., a) end", "expected next token to be ), got , instead", "parameter after ellipsis"},
		{"fn(a, ..., b) end", "expected next token to be ), got , instead", "parameter after ellipsis in middle"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			psr := parser.NewParser(parser.NewLexer(tc.src))
			psr.ParseExpression(parser.LOWEST)

			if !psr.HasErrors() {
				t.Fatal("expected parsing errors")
			}

			// Gather every message in one pass; the list is only printed
			// when none of them matches.
			reported := psr.Errors()
			messages := make([]string, 0, len(reported))
			matched := false
			for _, perr := range reported {
				messages = append(messages, perr.Message)
				if strings.Contains(perr.Message, tc.fragment) {
					matched = true
				}
			}

			if !matched {
				t.Errorf("expected error containing %q, got %v", tc.fragment, messages)
			}
		})
	}
}
// TestReturnErrors feeds malformed return statements to ParseProgram and
// checks that the parser reports at least one error whose message contains
// the expected fragment.
func TestReturnErrors(t *testing.T) {
	cases := []struct {
		src      string
		fragment string
		name     string
	}{
		{"return +", "unexpected token '+'", "return with invalid expression"},
		{"return (", "unexpected end of input", "return with incomplete expression"},
		{"return 1 +", "expected expression after operator '+'", "return with incomplete infix"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			psr := parser.NewParser(parser.NewLexer(tc.src))
			psr.ParseProgram()

			if !psr.HasErrors() {
				t.Fatal("expected parsing errors")
			}

			reported := psr.Errors()
			messages := make([]string, len(reported))
			hit := false
			for idx := range reported {
				messages[idx] = reported[idx].Message
				hit = hit || strings.Contains(messages[idx], tc.fragment)
			}

			if hit {
				return
			}
			t.Errorf("expected error containing %q, got %v", tc.fragment, messages)
		})
	}
}
// TestFunctionStringRepresentation parses function literals and checks that
// the String() output of the parsed expression contains each expected
// fragment.
func TestFunctionStringRepresentation(t *testing.T) {
	cases := []struct {
		src       string
		fragments []string
		name      string
	}{
		{
			src:       "fn() end",
			fragments: []string{"fn()", "end"},
			name:      "empty function",
		},
		{
			src:       "fn(a, b) return a + b end",
			fragments: []string{"fn(a, b)", "return (a + b)", "end"},
			name:      "function with params and return",
		},
		{
			src:       "fn(...) echo \"variadic\" end",
			fragments: []string{"fn(...)", "echo \"variadic\"", "end"},
			name:      "variadic function",
		},
		{
			src:       "fn(a, b, ...) return a end",
			fragments: []string{"fn(a, b, ...)", "return a", "end"},
			name:      "mixed params and variadic",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			psr := parser.NewParser(parser.NewLexer(tc.src))
			node := psr.ParseExpression(parser.LOWEST)
			checkParserErrors(t, psr)

			rendered := node.String()
			for _, fragment := range tc.fragments {
				if strings.Contains(rendered, fragment) {
					continue
				}
				t.Errorf("expected function string to contain %q, got:\n%s", fragment, rendered)
			}
		})
	}
}
// TestReturnStringRepresentation parses return statements as programs and
// compares the whitespace-trimmed String() output of the program against an
// exact expected rendering.
func TestReturnStringRepresentation(t *testing.T) {
	cases := []struct {
		src  string
		want string
		name string
	}{
		{"return", "return", "simple return"},
		{"return 42", "return 42.00", "return with number"},
		{"return \"hello\"", "return \"hello\"", "return with string"},
		{"return x + 1", "return (x + 1.00)", "return with expression"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			psr := parser.NewParser(parser.NewLexer(tc.src))
			prog := psr.ParseProgram()
			checkParserErrors(t, psr)

			if got := strings.TrimSpace(prog.String()); got != tc.want {
				t.Errorf("expected %q, got %q", tc.want, got)
			}
		})
	}
}
// TestComplexFunctionProgram parses a six-statement program that mixes a
// table of functions, call expressions, echo statements, and a variadic
// function returning other functions. It checks the concrete type of each
// top-level statement, that the math table has three pairs whose values are
// all function literals, and that the calculator function is variadic.
func TestComplexFunctionProgram(t *testing.T) {
	const src = `math = {
add = fn(a, b) return a + b end,
multiply = fn(x, y) return x * y end,
max = fn(a, b)
if a > b then
return a
else
return b
end
end
}
result = math.add(5, 3)
echo result
calculator = fn(op, ...)
if op == "add" then
return fn(a, b) return a + b end
elseif op == "sub" then
return fn(a, b) return a - b end
else
return nil
end
end
adder = calculator("add")
echo adder`

	psr := parser.NewParser(parser.NewLexer(src))
	prog := psr.ParseProgram()
	checkParserErrors(t, psr)

	stmts := prog.Statements
	if len(stmts) != 6 {
		t.Fatalf("expected 6 statements, got %d", len(stmts))
	}

	// Statement 0: table of functions bound to math.
	mathBinding, isAssign := stmts[0].(*parser.Assignment)
	if !isAssign {
		t.Fatalf("statement 0: expected AssignStatement, got %T", stmts[0])
	}

	mathTable, isTable := mathBinding.Value.(*parser.TableLiteral)
	if !isTable {
		t.Fatalf("expected TableLiteral, got %T", mathBinding.Value)
	}

	if count := len(mathTable.Pairs); count != 3 {
		t.Errorf("expected 3 functions in math table, got %d", count)
	}

	// Every pair in the table must hold a function literal.
	for idx, entry := range mathTable.Pairs {
		if _, isFn := entry.Value.(*parser.FunctionLiteral); !isFn {
			t.Errorf("pair %d: expected FunctionLiteral, got %T", idx, entry.Value)
		}
	}

	// Statement 1: result assignment; only the statement type is checked.
	if _, isAssign := stmts[1].(*parser.Assignment); !isAssign {
		t.Fatalf("statement 1: expected AssignStatement, got %T", stmts[1])
	}

	// Statement 2: echo.
	if _, isEcho := stmts[2].(*parser.EchoStatement); !isEcho {
		t.Fatalf("statement 2: expected EchoStatement, got %T", stmts[2])
	}

	// Statement 3: calculator function assignment.
	calcBinding, isAssign := stmts[3].(*parser.Assignment)
	if !isAssign {
		t.Fatalf("statement 3: expected AssignStatement, got %T", stmts[3])
	}

	calc, isFn := calcBinding.Value.(*parser.FunctionLiteral)
	if !isFn {
		t.Fatalf("expected FunctionLiteral, got %T", calcBinding.Value)
	}

	if !calc.Variadic {
		t.Error("expected calculator function to be variadic")
	}

	// Statement 4: adder assignment.
	if _, isAssign := stmts[4].(*parser.Assignment); !isAssign {
		t.Fatalf("statement 4: expected AssignStatement, got %T", stmts[4])
	}

	// Statement 5: echo adder.
	if _, isEcho := stmts[5].(*parser.EchoStatement); !isEcho {
		t.Fatalf("statement 5: expected EchoStatement, got %T", stmts[5])
	}
}
// TestFunctionInConditional parses a program with two top-level statements:
// an if/else whose branches each assign a function literal, and a for loop
// whose body holds a nested if/else doing the same, followed by an echo. It
// checks that the first statement of every branch is an assignment whose
// value is a function literal.
//
// Each branch slice is length-checked before it is indexed, so a parser
// regression that yields an empty branch fails the test cleanly instead of
// panicking with an index-out-of-range error.
func TestFunctionInConditional(t *testing.T) {
	input := `if condition then
handler = fn(x) return x * 2 end
else
handler = fn(x) return x + 1 end
end
for i = 1, 10 do
if i > 5 then
result = fn() return "high" end
else
result = fn() return "low" end
end
echo result()
end`

	l := parser.NewLexer(input)
	p := parser.NewParser(l)
	program := p.ParseProgram()
	checkParserErrors(t, p)

	if len(program.Statements) != 2 {
		t.Fatalf("expected 2 statements, got %d", len(program.Statements))
	}

	// First: if statement with function assignments.
	ifStmt, ok := program.Statements[0].(*parser.IfStatement)
	if !ok {
		t.Fatalf("statement 0: expected IfStatement, got %T", program.Statements[0])
	}

	// Check if body has function assignment.
	if len(ifStmt.Body) == 0 {
		t.Fatal("if body: expected at least one statement, got none")
	}
	ifAssign, ok := ifStmt.Body[0].(*parser.Assignment)
	if !ok {
		t.Fatalf("if body: expected AssignStatement, got %T", ifStmt.Body[0])
	}
	if _, ok := ifAssign.Value.(*parser.FunctionLiteral); !ok {
		t.Fatalf("if body: expected FunctionLiteral, got %T", ifAssign.Value)
	}

	// Check else body has function assignment.
	if len(ifStmt.Else) == 0 {
		t.Fatal("else body: expected at least one statement, got none")
	}
	elseAssign, ok := ifStmt.Else[0].(*parser.Assignment)
	if !ok {
		t.Fatalf("else body: expected AssignStatement, got %T", ifStmt.Else[0])
	}
	if _, ok := elseAssign.Value.(*parser.FunctionLiteral); !ok {
		t.Fatalf("else body: expected FunctionLiteral, got %T", elseAssign.Value)
	}

	// Second: for loop with nested conditionals containing functions.
	forStmt, ok := program.Statements[1].(*parser.ForStatement)
	if !ok {
		t.Fatalf("statement 1: expected ForStatement, got %T", program.Statements[1])
	}

	if len(forStmt.Body) != 2 {
		t.Fatalf("expected 2 statements in for body, got %d", len(forStmt.Body))
	}

	// Check nested if contains function assignments.
	nestedIf, ok := forStmt.Body[0].(*parser.IfStatement)
	if !ok {
		t.Fatalf("for body[0]: expected IfStatement, got %T", forStmt.Body[0])
	}

	// Verify both branches assign functions.
	if len(nestedIf.Body) == 0 {
		t.Fatal("nested if body: expected at least one statement, got none")
	}
	nestedIfAssign, ok := nestedIf.Body[0].(*parser.Assignment)
	if !ok {
		t.Fatalf("nested if body: expected AssignStatement, got %T", nestedIf.Body[0])
	}
	if _, ok := nestedIfAssign.Value.(*parser.FunctionLiteral); !ok {
		t.Fatalf("nested if body: expected FunctionLiteral, got %T", nestedIfAssign.Value)
	}

	if len(nestedIf.Else) == 0 {
		t.Fatal("nested else body: expected at least one statement, got none")
	}
	nestedElseAssign, ok := nestedIf.Else[0].(*parser.Assignment)
	if !ok {
		t.Fatalf("nested else body: expected AssignStatement, got %T", nestedIf.Else[0])
	}
	if _, ok := nestedElseAssign.Value.(*parser.FunctionLiteral); !ok {
		t.Fatalf("nested else body: expected FunctionLiteral, got %T", nestedElseAssign.Value)
	}
}