functions

This commit is contained in:
Sky Johnson 2025-06-10 23:05:21 -05:00
parent 6eb4e21263
commit fc988d257f
5 changed files with 914 additions and 3 deletions

View File

@ -75,6 +75,19 @@ func (es *ExitStatement) String() string {
return fmt.Sprintf("exit %s", es.Value.String())
}
// ReturnStatement represents return statements
type ReturnStatement struct {
Value Expression // optional, can be nil
}
func (rs *ReturnStatement) statementNode() {}
func (rs *ReturnStatement) String() string {
if rs.Value == nil {
return "return"
}
return fmt.Sprintf("return %s", rs.Value.String())
}
// ElseIfClause represents an elseif condition
type ElseIfClause struct {
Condition Expression
@ -241,6 +254,52 @@ type NilLiteral struct{}
func (nl *NilLiteral) expressionNode() {}
func (nl *NilLiteral) String() string { return "nil" }
// FunctionLiteral represents function literals: fn(a, b, ...) ... end
type FunctionLiteral struct {
Parameters []string
Variadic bool
Body []Statement
}
func (fl *FunctionLiteral) expressionNode() {}
func (fl *FunctionLiteral) String() string {
var params string
for i, param := range fl.Parameters {
if i > 0 {
params += ", "
}
params += param
}
if fl.Variadic {
if len(fl.Parameters) > 0 {
params += ", "
}
params += "..."
}
result := fmt.Sprintf("fn(%s)\n", params)
for _, stmt := range fl.Body {
result += "\t" + stmt.String() + "\n"
}
result += "end"
return result
}
// CallExpression represents function calls: func(arg1, arg2, ...)
type CallExpression struct {
Function Expression
Arguments []Expression
}
func (ce *CallExpression) expressionNode() {}
func (ce *CallExpression) String() string {
var args []string
for _, arg := range ce.Arguments {
args = append(args, arg.String())
}
return fmt.Sprintf("%s(%s)", ce.Function.String(), joinStrings(args, ", "))
}
// PrefixExpression represents prefix operations like -x
type PrefixExpression struct {
Operator string

View File

@ -47,6 +47,15 @@ func (l *Lexer) peekChar() byte {
return l.input[l.readPosition]
}
// peekCharAt returns the character at offset positions ahead
func (l *Lexer) peekCharAt(offset int) byte {
pos := l.readPosition + offset - 1
if pos >= len(l.input) {
return 0
}
return l.input[pos]
}
// skipWhitespace skips whitespace characters
func (l *Lexer) skipWhitespace() {
for l.ch == ' ' || l.ch == '\t' || l.ch == '\n' || l.ch == '\r' {
@ -249,7 +258,14 @@ func (l *Lexer) NextToken() Token {
case '/':
tok = Token{Type: SLASH, Literal: string(l.ch), Line: l.line, Column: l.column}
case '.':
tok = Token{Type: DOT, Literal: string(l.ch), Line: l.line, Column: l.column}
// Check for ellipsis (...)
if l.peekChar() == '.' && l.peekCharAt(2) == '.' {
l.readChar() // skip first '.'
l.readChar() // skip second '.'
tok = Token{Type: ELLIPSIS, Literal: "...", Line: l.line, Column: l.column}
} else {
tok = Token{Type: DOT, Literal: string(l.ch), Line: l.line, Column: l.column}
}
case '(':
tok = Token{Type: LPAREN, Literal: string(l.ch), Line: l.line, Column: l.column}
case ')':

View File

@ -49,6 +49,7 @@ func NewParser(lexer *Lexer) *Parser {
p.registerPrefix(LPAREN, p.parseGroupedExpression)
p.registerPrefix(LBRACE, p.parseTableLiteral)
p.registerPrefix(MINUS, p.parsePrefixExpression)
p.registerPrefix(FN, p.parseFunctionLiteral)
p.infixParseFns = make(map[TokenType]func(Expression) Expression)
p.registerInfix(PLUS, p.parseInfixExpression)
@ -63,6 +64,7 @@ func NewParser(lexer *Lexer) *Parser {
p.registerInfix(GT_EQ, p.parseInfixExpression)
p.registerInfix(DOT, p.parseDotExpression)
p.registerInfix(LBRACKET, p.parseIndexExpression)
p.registerInfix(LPAREN, p.parseCallExpression)
// Read two tokens, so curToken and peekToken are both set
p.nextToken()
@ -121,6 +123,8 @@ func (p *Parser) parseStatement() Statement {
return p.parseBreakStatement()
case EXIT:
return p.parseExitStatement()
case RETURN:
return p.parseReturnStatement()
case ASSIGN:
p.addError("assignment operator '=' without left-hand side identifier")
return nil
@ -214,10 +218,27 @@ func (p *Parser) parseExitStatement() *ExitStatement {
return stmt
}
// parseReturnStatement parses return statements
func (p *Parser) parseReturnStatement() *ReturnStatement {
stmt := &ReturnStatement{}
// Check if there's an optional expression after 'return'
if p.canStartExpression(p.peekToken.Type) {
p.nextToken() // move past 'return'
stmt.Value = p.ParseExpression(LOWEST)
if stmt.Value == nil {
p.addError("expected expression after 'return'")
return nil
}
}
return stmt
}
// canStartExpression checks if a token type can start an expression
func (p *Parser) canStartExpression(tokenType TokenType) bool {
switch tokenType {
case IDENT, NUMBER, STRING, TRUE, FALSE, NIL, LPAREN, LBRACE, MINUS:
case IDENT, NUMBER, STRING, TRUE, FALSE, NIL, LPAREN, LBRACE, MINUS, FN:
return true
default:
return false
@ -621,6 +642,74 @@ func (p *Parser) parseGroupedExpression() Expression {
return exp
}
func (p *Parser) parseFunctionLiteral() Expression {
fn := &FunctionLiteral{}
if !p.expectPeek(LPAREN) {
p.addError("expected '(' after 'fn'")
return nil
}
fn.Parameters, fn.Variadic = p.parseFunctionParameters()
if !p.expectPeek(RPAREN) {
p.addError("expected ')' after function parameters")
return nil
}
p.nextToken() // move past ')'
// Parse function body
fn.Body = p.parseBlockStatements(END)
if !p.curTokenIs(END) {
p.addError("expected 'end' to close function")
return nil
}
return fn
}
func (p *Parser) parseFunctionParameters() ([]string, bool) {
var params []string
var variadic bool
if p.peekTokenIs(RPAREN) {
return params, false
}
p.nextToken()
for {
if p.curTokenIs(ELLIPSIS) {
variadic = true
break
}
if !p.curTokenIs(IDENT) {
p.addError("expected parameter name")
return nil, false
}
params = append(params, p.curToken.Literal)
if !p.peekTokenIs(COMMA) {
break
}
p.nextToken() // move to ','
p.nextToken() // move past ','
// Check for ellipsis after comma
if p.curTokenIs(ELLIPSIS) {
variadic = true
break
}
}
return params, variadic
}
func (p *Parser) parseTableLiteral() Expression {
table := &TableLiteral{}
table.Pairs = []TablePair{}
@ -725,6 +814,36 @@ func (p *Parser) parseDotExpression(left Expression) Expression {
}
}
func (p *Parser) parseCallExpression(fn Expression) Expression {
call := &CallExpression{Function: fn}
call.Arguments = p.parseExpressionList(RPAREN)
return call
}
func (p *Parser) parseExpressionList(end TokenType) []Expression {
var args []Expression
if p.peekTokenIs(end) {
p.nextToken()
return args
}
p.nextToken()
args = append(args, p.ParseExpression(LOWEST))
for p.peekTokenIs(COMMA) {
p.nextToken()
p.nextToken()
args = append(args, p.ParseExpression(LOWEST))
}
if !p.expectPeek(end) {
return nil
}
return args
}
func (p *Parser) parseIndexExpression(left Expression) Expression {
p.nextToken() // move past '['
@ -776,7 +895,7 @@ func (p *Parser) expectPeekIdent() bool {
// isKeyword checks if a token type is a keyword that can be used as identifier
func (p *Parser) isKeyword(t TokenType) bool {
switch t {
case TRUE, FALSE, NIL, IF, THEN, ELSEIF, ELSE, END, ECHO, FOR, WHILE, IN, DO, BREAK, EXIT:
case TRUE, FALSE, NIL, IF, THEN, ELSEIF, ELSE, END, ECHO, FOR, WHILE, IN, DO, BREAK, EXIT, FN, RETURN:
return true
default:
return false
@ -910,6 +1029,8 @@ func tokenTypeString(t TokenType) string {
return "]"
case COMMA:
return ","
case ELLIPSIS:
return "..."
case IF:
return "if"
case THEN:
@ -934,6 +1055,10 @@ func tokenTypeString(t TokenType) string {
return "break"
case EXIT:
return "exit"
case FN:
return "fn"
case RETURN:
return "return"
case EOF:
return "end of file"
case ILLEGAL:

View File

@ -0,0 +1,705 @@
package parser_test
import (
"strings"
"testing"
"git.sharkk.net/Sharkk/Mako/parser"
)
func TestBasicFunctionLiterals(t *testing.T) {
tests := []struct {
input string
paramCount int
variadic bool
bodyCount int
description string
}{
{"fn() end", 0, false, 0, "empty function"},
{"fn(a) echo a end", 1, false, 1, "single parameter"},
{"fn(a, b) return a + b end", 2, false, 1, "two parameters"},
{"fn(...) echo \"variadic\" end", 0, true, 1, "variadic only"},
{"fn(a, b, ...) return a end", 2, true, 1, "mixed params and variadic"},
{"fn(x, y) x = x + 1 y = y + 1 end", 2, false, 2, "multiple statements"},
}
for _, tt := range tests {
t.Run(tt.description, func(t *testing.T) {
l := parser.NewLexer(tt.input)
p := parser.NewParser(l)
expr := p.ParseExpression(parser.LOWEST)
checkParserErrors(t, p)
fn, ok := expr.(*parser.FunctionLiteral)
if !ok {
t.Fatalf("expected FunctionLiteral, got %T", expr)
}
if len(fn.Parameters) != tt.paramCount {
t.Errorf("expected %d parameters, got %d", tt.paramCount, len(fn.Parameters))
}
if fn.Variadic != tt.variadic {
t.Errorf("expected variadic = %t, got %t", tt.variadic, fn.Variadic)
}
if len(fn.Body) != tt.bodyCount {
t.Errorf("expected %d body statements, got %d", tt.bodyCount, len(fn.Body))
}
})
}
}
func TestFunctionParameters(t *testing.T) {
tests := []struct {
input string
params []string
variadic bool
desc string
}{
{"fn(a) end", []string{"a"}, false, "single param"},
{"fn(a, b, c) end", []string{"a", "b", "c"}, false, "multiple params"},
{"fn(...) end", []string{}, true, "only variadic"},
{"fn(x, ...) end", []string{"x"}, true, "param with variadic"},
{"fn(a, b, ...) end", []string{"a", "b"}, true, "multiple params with variadic"},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
l := parser.NewLexer(tt.input)
p := parser.NewParser(l)
expr := p.ParseExpression(parser.LOWEST)
checkParserErrors(t, p)
fn, ok := expr.(*parser.FunctionLiteral)
if !ok {
t.Fatalf("expected FunctionLiteral, got %T", expr)
}
if len(fn.Parameters) != len(tt.params) {
t.Fatalf("expected %d parameters, got %d", len(tt.params), len(fn.Parameters))
}
for i, expected := range tt.params {
if fn.Parameters[i] != expected {
t.Errorf("parameter %d: expected %s, got %s", i, expected, fn.Parameters[i])
}
}
if fn.Variadic != tt.variadic {
t.Errorf("expected variadic = %t, got %t", tt.variadic, fn.Variadic)
}
})
}
}
func TestReturnStatements(t *testing.T) {
tests := []struct {
input string
hasValue bool
desc string
}{
{"return", false, "return without value"},
{"return 42", true, "return with number"},
{"return \"hello\"", true, "return with string"},
{"return x", true, "return with identifier"},
{"return x + y", true, "return with expression"},
{"return table.key", true, "return with member access"},
{"return arr[0]", true, "return with index access"},
{"return {a = 1, b = 2}", true, "return with table"},
{"return fn(x) return x end", true, "return with function"},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
l := parser.NewLexer(tt.input)
p := parser.NewParser(l)
program := p.ParseProgram()
checkParserErrors(t, p)
if len(program.Statements) != 1 {
t.Fatalf("expected 1 statement, got %d", len(program.Statements))
}
stmt, ok := program.Statements[0].(*parser.ReturnStatement)
if !ok {
t.Fatalf("expected ReturnStatement, got %T", program.Statements[0])
}
if tt.hasValue && stmt.Value == nil {
t.Error("expected return value but got nil")
}
if !tt.hasValue && stmt.Value != nil {
t.Error("expected no return value but got one")
}
})
}
}
func TestFunctionAssignments(t *testing.T) {
tests := []struct {
input string
desc string
}{
{"add = fn(a, b) return a + b end", "assign function to variable"},
{"math.add = fn(x, y) return x + y end", "assign function to member"},
{"funcs[\"add\"] = fn(a, b) return a + b end", "assign function to index"},
{"callback = fn() echo \"called\" end", "assign simple function"},
{"sum = fn(...) return 0 end", "assign variadic function"},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
l := parser.NewLexer(tt.input)
p := parser.NewParser(l)
program := p.ParseProgram()
checkParserErrors(t, p)
if len(program.Statements) != 1 {
t.Fatalf("expected 1 statement, got %d", len(program.Statements))
}
stmt, ok := program.Statements[0].(*parser.AssignStatement)
if !ok {
t.Fatalf("expected AssignStatement, got %T", program.Statements[0])
}
_, ok = stmt.Value.(*parser.FunctionLiteral)
if !ok {
t.Fatalf("expected FunctionLiteral value, got %T", stmt.Value)
}
})
}
}
func TestFunctionsInTables(t *testing.T) {
tests := []struct {
input string
desc string
}{
{"{add = fn(a, b) return a + b end}", "function in hash table"},
{"{fn(x) return x end, fn(y) return y end}", "functions in array"},
{"{math = {add = fn(a, b) return a + b end}}", "nested function in table"},
{"{callback = fn() echo \"hi\" end, data = 42}", "mixed table with function"},
{"{fn(...) return 0 end}", "variadic function in array"},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
l := parser.NewLexer(tt.input)
p := parser.NewParser(l)
expr := p.ParseExpression(parser.LOWEST)
checkParserErrors(t, p)
table, ok := expr.(*parser.TableLiteral)
if !ok {
t.Fatalf("expected TableLiteral, got %T", expr)
}
// Verify at least one pair contains a function (or nested table with function)
foundFunction := false
for _, pair := range table.Pairs {
if _, ok := pair.Value.(*parser.FunctionLiteral); ok {
foundFunction = true
break
}
// Check nested tables for functions
if nestedTable, ok := pair.Value.(*parser.TableLiteral); ok {
for _, nestedPair := range nestedTable.Pairs {
if _, ok := nestedPair.Value.(*parser.FunctionLiteral); ok {
foundFunction = true
break
}
}
}
}
if !foundFunction {
t.Error("expected to find at least one function in table")
}
})
}
}
func TestFunctionWithComplexBody(t *testing.T) {
input := `fn(x, y)
if x > y then
return x
else
return y
end
echo "unreachable"
end`
l := parser.NewLexer(input)
p := parser.NewParser(l)
expr := p.ParseExpression(parser.LOWEST)
checkParserErrors(t, p)
fn, ok := expr.(*parser.FunctionLiteral)
if !ok {
t.Fatalf("expected FunctionLiteral, got %T", expr)
}
if len(fn.Parameters) != 2 {
t.Errorf("expected 2 parameters, got %d", len(fn.Parameters))
}
if len(fn.Body) != 2 {
t.Errorf("expected 2 body statements, got %d", len(fn.Body))
}
// First statement should be if
ifStmt, ok := fn.Body[0].(*parser.IfStatement)
if !ok {
t.Fatalf("expected IfStatement, got %T", fn.Body[0])
}
// Check that if body contains return
_, ok = ifStmt.Body[0].(*parser.ReturnStatement)
if !ok {
t.Fatalf("expected ReturnStatement in if body, got %T", ifStmt.Body[0])
}
// Check that else body contains return
_, ok = ifStmt.Else[0].(*parser.ReturnStatement)
if !ok {
t.Fatalf("expected ReturnStatement in else body, got %T", ifStmt.Else[0])
}
// Second statement should be echo
_, ok = fn.Body[1].(*parser.EchoStatement)
if !ok {
t.Fatalf("expected EchoStatement, got %T", fn.Body[1])
}
}
func TestNestedFunctions(t *testing.T) {
input := `fn(x)
inner = fn(y) return y * 2 end
return inner(x)
end`
l := parser.NewLexer(input)
p := parser.NewParser(l)
expr := p.ParseExpression(parser.LOWEST)
checkParserErrors(t, p)
fn, ok := expr.(*parser.FunctionLiteral)
if !ok {
t.Fatalf("expected FunctionLiteral, got %T", expr)
}
if len(fn.Body) != 2 {
t.Fatalf("expected 2 body statements, got %d", len(fn.Body))
}
// First statement: assignment of inner function
assign, ok := fn.Body[0].(*parser.AssignStatement)
if !ok {
t.Fatalf("expected AssignStatement, got %T", fn.Body[0])
}
innerFn, ok := assign.Value.(*parser.FunctionLiteral)
if !ok {
t.Fatalf("expected FunctionLiteral value, got %T", assign.Value)
}
if len(innerFn.Parameters) != 1 {
t.Errorf("expected 1 parameter in inner function, got %d", len(innerFn.Parameters))
}
// Second statement: return
_, ok = fn.Body[1].(*parser.ReturnStatement)
if !ok {
t.Fatalf("expected ReturnStatement, got %T", fn.Body[1])
}
}
func TestFunctionInLoop(t *testing.T) {
input := `for i = 1, 10 do
callback = fn(x) return x + i end
echo callback(i)
end`
l := parser.NewLexer(input)
p := parser.NewParser(l)
program := p.ParseProgram()
checkParserErrors(t, p)
if len(program.Statements) != 1 {
t.Fatalf("expected 1 statement, got %d", len(program.Statements))
}
forStmt, ok := program.Statements[0].(*parser.ForStatement)
if !ok {
t.Fatalf("expected ForStatement, got %T", program.Statements[0])
}
if len(forStmt.Body) != 2 {
t.Fatalf("expected 2 body statements, got %d", len(forStmt.Body))
}
// First: function assignment
assign, ok := forStmt.Body[0].(*parser.AssignStatement)
if !ok {
t.Fatalf("expected AssignStatement, got %T", forStmt.Body[0])
}
_, ok = assign.Value.(*parser.FunctionLiteral)
if !ok {
t.Fatalf("expected FunctionLiteral value, got %T", assign.Value)
}
}
func TestFunctionErrors(t *testing.T) {
tests := []struct {
input string
expectedError string
desc string
}{
{"fn", "expected '(' after 'fn'", "missing parentheses"},
{"fn(", "expected ')' after function parameters", "missing closing paren"},
{"fn() echo x", "expected 'end' to close function", "missing end"},
{"fn(123) end", "expected parameter name", "invalid parameter name"},
{"fn(a,) end", "expected parameter name", "trailing comma"},
{"fn(a, 123) end", "expected parameter name", "invalid second parameter"},
{"fn(..., a) end", "expected next token to be ), got , instead", "parameter after ellipsis"},
{"fn(a, ..., b) end", "expected next token to be ), got , instead", "parameter after ellipsis in middle"},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
l := parser.NewLexer(tt.input)
p := parser.NewParser(l)
p.ParseExpression(parser.LOWEST)
if !p.HasErrors() {
t.Fatal("expected parsing errors")
}
errors := p.Errors()
found := false
for _, err := range errors {
if strings.Contains(err.Message, tt.expectedError) {
found = true
break
}
}
if !found {
errorMsgs := make([]string, len(errors))
for i, err := range errors {
errorMsgs[i] = err.Message
}
t.Errorf("expected error containing %q, got %v", tt.expectedError, errorMsgs)
}
})
}
}
func TestReturnErrors(t *testing.T) {
tests := []struct {
input string
expectedError string
desc string
}{
{"return +", "unexpected token '+'", "return with invalid expression"},
{"return (", "unexpected end of input", "return with incomplete expression"},
{"return 1 +", "expected expression after operator '+'", "return with incomplete infix"},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
l := parser.NewLexer(tt.input)
p := parser.NewParser(l)
p.ParseProgram()
if !p.HasErrors() {
t.Fatal("expected parsing errors")
}
errors := p.Errors()
found := false
for _, err := range errors {
if strings.Contains(err.Message, tt.expectedError) {
found = true
break
}
}
if !found {
errorMsgs := make([]string, len(errors))
for i, err := range errors {
errorMsgs[i] = err.Message
}
t.Errorf("expected error containing %q, got %v", tt.expectedError, errorMsgs)
}
})
}
}
func TestFunctionStringRepresentation(t *testing.T) {
tests := []struct {
input string
contains []string
desc string
}{
{
"fn() end",
[]string{"fn()", "end"},
"empty function",
},
{
"fn(a, b) return a + b end",
[]string{"fn(a, b)", "return (a + b)", "end"},
"function with params and return",
},
{
"fn(...) echo \"variadic\" end",
[]string{"fn(...)", "echo \"variadic\"", "end"},
"variadic function",
},
{
"fn(a, b, ...) return a end",
[]string{"fn(a, b, ...)", "return a", "end"},
"mixed params and variadic",
},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
l := parser.NewLexer(tt.input)
p := parser.NewParser(l)
expr := p.ParseExpression(parser.LOWEST)
checkParserErrors(t, p)
fnStr := expr.String()
for _, contain := range tt.contains {
if !strings.Contains(fnStr, contain) {
t.Errorf("expected function string to contain %q, got:\n%s", contain, fnStr)
}
}
})
}
}
func TestReturnStringRepresentation(t *testing.T) {
tests := []struct {
input string
expected string
desc string
}{
{"return", "return", "simple return"},
{"return 42", "return 42.00", "return with number"},
{"return \"hello\"", "return \"hello\"", "return with string"},
{"return x + 1", "return (x + 1.00)", "return with expression"},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
l := parser.NewLexer(tt.input)
p := parser.NewParser(l)
program := p.ParseProgram()
checkParserErrors(t, p)
result := strings.TrimSpace(program.String())
if result != tt.expected {
t.Errorf("expected %q, got %q", tt.expected, result)
}
})
}
}
func TestComplexFunctionProgram(t *testing.T) {
input := `math = {
add = fn(a, b) return a + b end,
multiply = fn(x, y) return x * y end,
max = fn(a, b)
if a > b then
return a
else
return b
end
end
}
result = math.add(5, 3)
echo result
calculator = fn(op, ...)
if op == "add" then
return fn(a, b) return a + b end
elseif op == "sub" then
return fn(a, b) return a - b end
else
return nil
end
end
adder = calculator("add")
echo adder`
l := parser.NewLexer(input)
p := parser.NewParser(l)
program := p.ParseProgram()
checkParserErrors(t, p)
if len(program.Statements) != 6 {
t.Fatalf("expected 6 statements, got %d", len(program.Statements))
}
// First: table with functions
mathAssign, ok := program.Statements[0].(*parser.AssignStatement)
if !ok {
t.Fatalf("statement 0: expected AssignStatement, got %T", program.Statements[0])
}
table, ok := mathAssign.Value.(*parser.TableLiteral)
if !ok {
t.Fatalf("expected TableLiteral, got %T", mathAssign.Value)
}
if len(table.Pairs) != 3 {
t.Errorf("expected 3 functions in math table, got %d", len(table.Pairs))
}
// Verify all pairs contain functions
for i, pair := range table.Pairs {
_, ok := pair.Value.(*parser.FunctionLiteral)
if !ok {
t.Errorf("pair %d: expected FunctionLiteral, got %T", i, pair.Value)
}
}
// Second: result assignment (function call would be handled by interpreter)
_, ok = program.Statements[1].(*parser.AssignStatement)
if !ok {
t.Fatalf("statement 1: expected AssignStatement, got %T", program.Statements[1])
}
// Third: echo
_, ok = program.Statements[2].(*parser.EchoStatement)
if !ok {
t.Fatalf("statement 2: expected EchoStatement, got %T", program.Statements[2])
}
// Fourth: calculator function assignment
calcAssign, ok := program.Statements[3].(*parser.AssignStatement)
if !ok {
t.Fatalf("statement 3: expected AssignStatement, got %T", program.Statements[3])
}
calcFn, ok := calcAssign.Value.(*parser.FunctionLiteral)
if !ok {
t.Fatalf("expected FunctionLiteral, got %T", calcAssign.Value)
}
if !calcFn.Variadic {
t.Error("expected calculator function to be variadic")
}
// Fifth: adder assignment
_, ok = program.Statements[4].(*parser.AssignStatement)
if !ok {
t.Fatalf("statement 4: expected AssignStatement, got %T", program.Statements[4])
}
// Sixth: echo adder
_, ok = program.Statements[5].(*parser.EchoStatement)
if !ok {
t.Fatalf("statement 5: expected EchoStatement, got %T", program.Statements[5])
}
}
func TestFunctionInConditional(t *testing.T) {
input := `if condition then
handler = fn(x) return x * 2 end
else
handler = fn(x) return x + 1 end
end
for i = 1, 10 do
if i > 5 then
result = fn() return "high" end
else
result = fn() return "low" end
end
echo result()
end`
l := parser.NewLexer(input)
p := parser.NewParser(l)
program := p.ParseProgram()
checkParserErrors(t, p)
if len(program.Statements) != 2 {
t.Fatalf("expected 2 statements, got %d", len(program.Statements))
}
// First: if statement with function assignments
ifStmt, ok := program.Statements[0].(*parser.IfStatement)
if !ok {
t.Fatalf("statement 0: expected IfStatement, got %T", program.Statements[0])
}
// Check if body has function assignment
ifAssign, ok := ifStmt.Body[0].(*parser.AssignStatement)
if !ok {
t.Fatalf("if body: expected AssignStatement, got %T", ifStmt.Body[0])
}
_, ok = ifAssign.Value.(*parser.FunctionLiteral)
if !ok {
t.Fatalf("if body: expected FunctionLiteral, got %T", ifAssign.Value)
}
// Check else body has function assignment
elseAssign, ok := ifStmt.Else[0].(*parser.AssignStatement)
if !ok {
t.Fatalf("else body: expected AssignStatement, got %T", ifStmt.Else[0])
}
_, ok = elseAssign.Value.(*parser.FunctionLiteral)
if !ok {
t.Fatalf("else body: expected FunctionLiteral, got %T", elseAssign.Value)
}
// Second: for loop with nested conditionals containing functions
forStmt, ok := program.Statements[1].(*parser.ForStatement)
if !ok {
t.Fatalf("statement 1: expected ForStatement, got %T", program.Statements[1])
}
if len(forStmt.Body) != 2 {
t.Fatalf("expected 2 statements in for body, got %d", len(forStmt.Body))
}
// Check nested if contains function assignments
nestedIf, ok := forStmt.Body[0].(*parser.IfStatement)
if !ok {
t.Fatalf("for body[0]: expected IfStatement, got %T", forStmt.Body[0])
}
// Verify both branches assign functions
nestedIfAssign, ok := nestedIf.Body[0].(*parser.AssignStatement)
if !ok {
t.Fatalf("nested if body: expected AssignStatement, got %T", nestedIf.Body[0])
}
_, ok = nestedIfAssign.Value.(*parser.FunctionLiteral)
if !ok {
t.Fatalf("nested if body: expected FunctionLiteral, got %T", nestedIfAssign.Value)
}
nestedElseAssign, ok := nestedIf.Else[0].(*parser.AssignStatement)
if !ok {
t.Fatalf("nested else body: expected AssignStatement, got %T", nestedIf.Else[0])
}
_, ok = nestedElseAssign.Value.(*parser.FunctionLiteral)
if !ok {
t.Fatalf("nested else body: expected FunctionLiteral, got %T", nestedElseAssign.Value)
}
}

View File

@ -36,6 +36,7 @@ const (
LBRACKET // [
RBRACKET // ]
COMMA // ,
ELLIPSIS // ...
// Keywords
IF
@ -50,6 +51,8 @@ const (
DO
BREAK
EXIT
FN
RETURN
// Special
EOF
@ -93,6 +96,7 @@ var precedences = map[TokenType]Precedence{
STAR: PRODUCT,
DOT: MEMBER,
LBRACKET: MEMBER,
LPAREN: CALL,
}
// lookupIdent checks if an identifier is a keyword
@ -113,6 +117,8 @@ func lookupIdent(ident string) TokenType {
"do": DO,
"break": BREAK,
"exit": EXIT,
"fn": FN,
"return": RETURN,
}
if tok, ok := keywords[ident]; ok {