From fc988d257f58ae29d7aca4e7a0b4d1b081c4cae7 Mon Sep 17 00:00:00 2001 From: Sky Johnson Date: Tue, 10 Jun 2025 23:05:21 -0500 Subject: [PATCH] functions --- parser/ast.go | 59 +++ parser/lexer.go | 18 +- parser/parser.go | 129 +++++- parser/tests/functions_test.go | 705 +++++++++++++++++++++++++++++++++ parser/token.go | 6 + 5 files changed, 914 insertions(+), 3 deletions(-) create mode 100644 parser/tests/functions_test.go diff --git a/parser/ast.go b/parser/ast.go index 8d789be..8069b12 100644 --- a/parser/ast.go +++ b/parser/ast.go @@ -75,6 +75,19 @@ func (es *ExitStatement) String() string { return fmt.Sprintf("exit %s", es.Value.String()) } +// ReturnStatement represents return statements +type ReturnStatement struct { + Value Expression // optional, can be nil +} + +func (rs *ReturnStatement) statementNode() {} +func (rs *ReturnStatement) String() string { + if rs.Value == nil { + return "return" + } + return fmt.Sprintf("return %s", rs.Value.String()) +} + // ElseIfClause represents an elseif condition type ElseIfClause struct { Condition Expression @@ -241,6 +254,52 @@ type NilLiteral struct{} func (nl *NilLiteral) expressionNode() {} func (nl *NilLiteral) String() string { return "nil" } +// FunctionLiteral represents function literals: fn(a, b, ...) ... end +type FunctionLiteral struct { + Parameters []string + Variadic bool + Body []Statement +} + +func (fl *FunctionLiteral) expressionNode() {} +func (fl *FunctionLiteral) String() string { + var params string + for i, param := range fl.Parameters { + if i > 0 { + params += ", " + } + params += param + } + if fl.Variadic { + if len(fl.Parameters) > 0 { + params += ", " + } + params += "..." + } + + result := fmt.Sprintf("fn(%s)\n", params) + for _, stmt := range fl.Body { + result += "\t" + stmt.String() + "\n" + } + result += "end" + return result +} + +// CallExpression represents function calls: func(arg1, arg2, ...) +type CallExpression struct { + Function Expression + Arguments []Expression +} + +func (ce *CallExpression) expressionNode() {} +func (ce *CallExpression) String() string { + var args []string + for _, arg := range ce.Arguments { + args = append(args, arg.String()) + } + return fmt.Sprintf("%s(%s)", ce.Function.String(), joinStrings(args, ", ")) +} + // PrefixExpression represents prefix operations like -x type PrefixExpression struct { Operator string diff --git a/parser/lexer.go b/parser/lexer.go index e0b0297..01d8e44 100644 --- a/parser/lexer.go +++ b/parser/lexer.go @@ -47,6 +47,15 @@ func (l *Lexer) peekChar() byte { return l.input[l.readPosition] } +// peekCharAt returns the character at offset positions ahead +func (l *Lexer) peekCharAt(offset int) byte { + pos := l.readPosition + offset - 1 + if pos >= len(l.input) { + return 0 + } + return l.input[pos] +} + // skipWhitespace skips whitespace characters func (l *Lexer) skipWhitespace() { for l.ch == ' ' || l.ch == '\t' || l.ch == '\n' || l.ch == '\r' { @@ -249,7 +258,14 @@ func (l *Lexer) NextToken() Token { case '/': tok = Token{Type: SLASH, Literal: string(l.ch), Line: l.line, Column: l.column} case '.': - tok = Token{Type: DOT, Literal: string(l.ch), Line: l.line, Column: l.column} + // Check for ellipsis (...) + if l.peekChar() == '.' && l.peekCharAt(2) == '.' { + l.readChar() // skip first '.' + l.readChar() // skip second '.' + tok = Token{Type: ELLIPSIS, Literal: "...", Line: l.line, Column: l.column} + } else { + tok = Token{Type: DOT, Literal: string(l.ch), Line: l.line, Column: l.column} + } case '(': tok = Token{Type: LPAREN, Literal: string(l.ch), Line: l.line, Column: l.column} case ')': diff --git a/parser/parser.go b/parser/parser.go index 33409ef..e580d1f 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -49,6 +49,7 @@ func NewParser(lexer *Lexer) *Parser { p.registerPrefix(LPAREN, p.parseGroupedExpression) p.registerPrefix(LBRACE, p.parseTableLiteral) p.registerPrefix(MINUS, p.parsePrefixExpression) + p.registerPrefix(FN, p.parseFunctionLiteral) p.infixParseFns = make(map[TokenType]func(Expression) Expression) p.registerInfix(PLUS, p.parseInfixExpression) @@ -63,6 +64,7 @@ func NewParser(lexer *Lexer) *Parser { p.registerInfix(GT_EQ, p.parseInfixExpression) p.registerInfix(DOT, p.parseDotExpression) p.registerInfix(LBRACKET, p.parseIndexExpression) + p.registerInfix(LPAREN, p.parseCallExpression) // Read two tokens, so curToken and peekToken are both set p.nextToken() @@ -121,6 +123,8 @@ func (p *Parser) parseStatement() Statement { return p.parseBreakStatement() case EXIT: return p.parseExitStatement() + case RETURN: + return p.parseReturnStatement() case ASSIGN: p.addError("assignment operator '=' without left-hand side identifier") return nil @@ -214,10 +218,27 @@ func (p *Parser) parseExitStatement() *ExitStatement { return stmt } +// parseReturnStatement parses return statements +func (p *Parser) parseReturnStatement() *ReturnStatement { + stmt := &ReturnStatement{} + + // Check if there's an optional expression after 'return' + if p.canStartExpression(p.peekToken.Type) { + p.nextToken() // move past 'return' + stmt.Value = p.ParseExpression(LOWEST) + if stmt.Value == nil { + p.addError("expected expression after 'return'") + return nil + } + } + + return stmt +} + // canStartExpression checks if a token type can start an expression func (p *Parser) canStartExpression(tokenType TokenType) bool { switch tokenType { - case IDENT, NUMBER, STRING, TRUE, FALSE, NIL, LPAREN, LBRACE, MINUS: + case IDENT, NUMBER, STRING, TRUE, FALSE, NIL, LPAREN, LBRACE, MINUS, FN: return true default: return false @@ -621,6 +642,74 @@ func (p *Parser) parseGroupedExpression() Expression { return exp } +func (p *Parser) parseFunctionLiteral() Expression { + fn := &FunctionLiteral{} + + if !p.expectPeek(LPAREN) { + p.addError("expected '(' after 'fn'") + return nil + } + + fn.Parameters, fn.Variadic = p.parseFunctionParameters() + + if !p.expectPeek(RPAREN) { + p.addError("expected ')' after function parameters") + return nil + } + + p.nextToken() // move past ')' + + // Parse function body + fn.Body = p.parseBlockStatements(END) + + if !p.curTokenIs(END) { + p.addError("expected 'end' to close function") + return nil + } + + return fn +} + +func (p *Parser) parseFunctionParameters() ([]string, bool) { + var params []string + var variadic bool + + if p.peekTokenIs(RPAREN) { + return params, false + } + + p.nextToken() + + for { + if p.curTokenIs(ELLIPSIS) { + variadic = true + break + } + + if !p.curTokenIs(IDENT) { + p.addError("expected parameter name") + return nil, false + } + + params = append(params, p.curToken.Literal) + + if !p.peekTokenIs(COMMA) { + break + } + + p.nextToken() // move to ',' + p.nextToken() // move past ',' + + // Check for ellipsis after comma + if p.curTokenIs(ELLIPSIS) { + variadic = true + break + } + } + + return params, variadic +} + func (p *Parser) parseTableLiteral() Expression { table := &TableLiteral{} table.Pairs = []TablePair{} @@ -725,6 +814,36 @@ func (p *Parser) parseDotExpression(left Expression) Expression { } } +func (p *Parser) parseCallExpression(fn Expression) Expression { + call := &CallExpression{Function: fn} + call.Arguments = p.parseExpressionList(RPAREN) + return call +} + +func (p *Parser) parseExpressionList(end TokenType) []Expression { + var args []Expression + + if p.peekTokenIs(end) { + p.nextToken() + return args + } + + p.nextToken() + args = append(args, p.ParseExpression(LOWEST)) + + for p.peekTokenIs(COMMA) { + p.nextToken() + p.nextToken() + args = append(args, p.ParseExpression(LOWEST)) + } + + if !p.expectPeek(end) { + return nil + } + + return args +} + func (p *Parser) parseIndexExpression(left Expression) Expression { p.nextToken() // move past '[' @@ -776,7 +895,7 @@ func (p *Parser) expectPeekIdent() bool { // isKeyword checks if a token type is a keyword that can be used as identifier func (p *Parser) isKeyword(t TokenType) bool { switch t { - case TRUE, FALSE, NIL, IF, THEN, ELSEIF, ELSE, END, ECHO, FOR, WHILE, IN, DO, BREAK, EXIT: + case TRUE, FALSE, NIL, IF, THEN, ELSEIF, ELSE, END, ECHO, FOR, WHILE, IN, DO, BREAK, EXIT, FN, RETURN: return true default: return false @@ -910,6 +1029,8 @@ func tokenTypeString(t TokenType) string { return "]" case COMMA: return "," + case ELLIPSIS: + return "..." case IF: return "if" case THEN: @@ -934,6 +1055,10 @@ func tokenTypeString(t TokenType) string { return "break" case EXIT: return "exit" + case FN: + return "fn" + case RETURN: + return "return" case EOF: return "end of file" case ILLEGAL: diff --git a/parser/tests/functions_test.go b/parser/tests/functions_test.go new file mode 100644 index 0000000..fccad34 --- /dev/null +++ b/parser/tests/functions_test.go @@ -0,0 +1,705 @@ +package parser_test + +import ( + "strings" + "testing" + + "git.sharkk.net/Sharkk/Mako/parser" +) + +func TestBasicFunctionLiterals(t *testing.T) { + tests := []struct { + input string + paramCount int + variadic bool + bodyCount int + description string + }{ + {"fn() end", 0, false, 0, "empty function"}, + {"fn(a) echo a end", 1, false, 1, "single parameter"}, + {"fn(a, b) return a + b end", 2, false, 1, "two parameters"}, + {"fn(...) echo \"variadic\" end", 0, true, 1, "variadic only"}, + {"fn(a, b, ...) return a end", 2, true, 1, "mixed params and variadic"}, + {"fn(x, y) x = x + 1 y = y + 1 end", 2, false, 2, "multiple statements"}, + } + + for _, tt := range tests { + t.Run(tt.description, func(t *testing.T) { + l := parser.NewLexer(tt.input) + p := parser.NewParser(l) + expr := p.ParseExpression(parser.LOWEST) + checkParserErrors(t, p) + + fn, ok := expr.(*parser.FunctionLiteral) + if !ok { + t.Fatalf("expected FunctionLiteral, got %T", expr) + } + + if len(fn.Parameters) != tt.paramCount { + t.Errorf("expected %d parameters, got %d", tt.paramCount, len(fn.Parameters)) + } + + if fn.Variadic != tt.variadic { + t.Errorf("expected variadic = %t, got %t", tt.variadic, fn.Variadic) + } + + if len(fn.Body) != tt.bodyCount { + t.Errorf("expected %d body statements, got %d", tt.bodyCount, len(fn.Body)) + } + }) + } +} + +func TestFunctionParameters(t *testing.T) { + tests := []struct { + input string + params []string + variadic bool + desc string + }{ + {"fn(a) end", []string{"a"}, false, "single param"}, + {"fn(a, b, c) end", []string{"a", "b", "c"}, false, "multiple params"}, + {"fn(...) end", []string{}, true, "only variadic"}, + {"fn(x, ...) end", []string{"x"}, true, "param with variadic"}, + {"fn(a, b, ...) end", []string{"a", "b"}, true, "multiple params with variadic"}, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + l := parser.NewLexer(tt.input) + p := parser.NewParser(l) + expr := p.ParseExpression(parser.LOWEST) + checkParserErrors(t, p) + + fn, ok := expr.(*parser.FunctionLiteral) + if !ok { + t.Fatalf("expected FunctionLiteral, got %T", expr) + } + + if len(fn.Parameters) != len(tt.params) { + t.Fatalf("expected %d parameters, got %d", len(tt.params), len(fn.Parameters)) + } + + for i, expected := range tt.params { + if fn.Parameters[i] != expected { + t.Errorf("parameter %d: expected %s, got %s", i, expected, fn.Parameters[i]) + } + } + + if fn.Variadic != tt.variadic { + t.Errorf("expected variadic = %t, got %t", tt.variadic, fn.Variadic) + } + }) + } +} + +func TestReturnStatements(t *testing.T) { + tests := []struct { + input string + hasValue bool + desc string + }{ + {"return", false, "return without value"}, + {"return 42", true, "return with number"}, + {"return \"hello\"", true, "return with string"}, + {"return x", true, "return with identifier"}, + {"return x + y", true, "return with expression"}, + {"return table.key", true, "return with member access"}, + {"return arr[0]", true, "return with index access"}, + {"return {a = 1, b = 2}", true, "return with table"}, + {"return fn(x) return x end", true, "return with function"}, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + l := parser.NewLexer(tt.input) + p := parser.NewParser(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + if len(program.Statements) != 1 { + t.Fatalf("expected 1 statement, got %d", len(program.Statements)) + } + + stmt, ok := program.Statements[0].(*parser.ReturnStatement) + if !ok { + t.Fatalf("expected ReturnStatement, got %T", program.Statements[0]) + } + + if tt.hasValue && stmt.Value == nil { + t.Error("expected return value but got nil") + } + + if !tt.hasValue && stmt.Value != nil { + t.Error("expected no return value but got one") + } + }) + } +} + +func TestFunctionAssignments(t *testing.T) { + tests := []struct { + input string + desc string + }{ + {"add = fn(a, b) return a + b end", "assign function to variable"}, + {"math.add = fn(x, y) return x + y end", "assign function to member"}, + {"funcs[\"add\"] = fn(a, b) return a + b end", "assign function to index"}, + {"callback = fn() echo \"called\" end", "assign simple function"}, + {"sum = fn(...) return 0 end", "assign variadic function"}, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + l := parser.NewLexer(tt.input) + p := parser.NewParser(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + if len(program.Statements) != 1 { + t.Fatalf("expected 1 statement, got %d", len(program.Statements)) + } + + stmt, ok := program.Statements[0].(*parser.AssignStatement) + if !ok { + t.Fatalf("expected AssignStatement, got %T", program.Statements[0]) + } + + _, ok = stmt.Value.(*parser.FunctionLiteral) + if !ok { + t.Fatalf("expected FunctionLiteral value, got %T", stmt.Value) + } + }) + } +} + +func TestFunctionsInTables(t *testing.T) { + tests := []struct { + input string + desc string + }{ + {"{add = fn(a, b) return a + b end}", "function in hash table"}, + {"{fn(x) return x end, fn(y) return y end}", "functions in array"}, + {"{math = {add = fn(a, b) return a + b end}}", "nested function in table"}, + {"{callback = fn() echo \"hi\" end, data = 42}", "mixed table with function"}, + {"{fn(...) return 0 end}", "variadic function in array"}, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + l := parser.NewLexer(tt.input) + p := parser.NewParser(l) + expr := p.ParseExpression(parser.LOWEST) + checkParserErrors(t, p) + + table, ok := expr.(*parser.TableLiteral) + if !ok { + t.Fatalf("expected TableLiteral, got %T", expr) + } + + // Verify at least one pair contains a function (or nested table with function) + foundFunction := false + for _, pair := range table.Pairs { + if _, ok := pair.Value.(*parser.FunctionLiteral); ok { + foundFunction = true + break + } + // Check nested tables for functions + if nestedTable, ok := pair.Value.(*parser.TableLiteral); ok { + for _, nestedPair := range nestedTable.Pairs { + if _, ok := nestedPair.Value.(*parser.FunctionLiteral); ok { + foundFunction = true + break + } + } + } + } + + if !foundFunction { + t.Error("expected to find at least one function in table") + } + }) + } +} + +func TestFunctionWithComplexBody(t *testing.T) { + input := `fn(x, y) + if x > y then + return x + else + return y + end + echo "unreachable" +end` + + l := parser.NewLexer(input) + p := parser.NewParser(l) + expr := p.ParseExpression(parser.LOWEST) + checkParserErrors(t, p) + + fn, ok := expr.(*parser.FunctionLiteral) + if !ok { + t.Fatalf("expected FunctionLiteral, got %T", expr) + } + + if len(fn.Parameters) != 2 { + t.Errorf("expected 2 parameters, got %d", len(fn.Parameters)) + } + + if len(fn.Body) != 2 { + t.Errorf("expected 2 body statements, got %d", len(fn.Body)) + } + + // First statement should be if + ifStmt, ok := fn.Body[0].(*parser.IfStatement) + if !ok { + t.Fatalf("expected IfStatement, got %T", fn.Body[0]) + } + + // Check that if body contains return + _, ok = ifStmt.Body[0].(*parser.ReturnStatement) + if !ok { + t.Fatalf("expected ReturnStatement in if body, got %T", ifStmt.Body[0]) + } + + // Check that else body contains return + _, ok = ifStmt.Else[0].(*parser.ReturnStatement) + if !ok { + t.Fatalf("expected ReturnStatement in else body, got %T", ifStmt.Else[0]) + } + + // Second statement should be echo + _, ok = fn.Body[1].(*parser.EchoStatement) + if !ok { + t.Fatalf("expected EchoStatement, got %T", fn.Body[1]) + } +} + +func TestNestedFunctions(t *testing.T) { + input := `fn(x) + inner = fn(y) return y * 2 end + return inner(x) +end` + + l := parser.NewLexer(input) + p := parser.NewParser(l) + expr := p.ParseExpression(parser.LOWEST) + checkParserErrors(t, p) + + fn, ok := expr.(*parser.FunctionLiteral) + if !ok { + t.Fatalf("expected FunctionLiteral, got %T", expr) + } + + if len(fn.Body) != 2 { + t.Fatalf("expected 2 body statements, got %d", len(fn.Body)) + } + + // First statement: assignment of inner function + assign, ok := fn.Body[0].(*parser.AssignStatement) + if !ok { + t.Fatalf("expected AssignStatement, got %T", fn.Body[0]) + } + + innerFn, ok := assign.Value.(*parser.FunctionLiteral) + if !ok { + t.Fatalf("expected FunctionLiteral value, got %T", assign.Value) + } + + if len(innerFn.Parameters) != 1 { + t.Errorf("expected 1 parameter in inner function, got %d", len(innerFn.Parameters)) + } + + // Second statement: return + _, ok = fn.Body[1].(*parser.ReturnStatement) + if !ok { + t.Fatalf("expected ReturnStatement, got %T", fn.Body[1]) + } +} + +func TestFunctionInLoop(t *testing.T) { + input := `for i = 1, 10 do + callback = fn(x) return x + i end + echo callback(i) +end` + + l := parser.NewLexer(input) + p := parser.NewParser(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + if len(program.Statements) != 1 { + t.Fatalf("expected 1 statement, got %d", len(program.Statements)) + } + + forStmt, ok := program.Statements[0].(*parser.ForStatement) + if !ok { + t.Fatalf("expected ForStatement, got %T", program.Statements[0]) + } + + if len(forStmt.Body) != 2 { + t.Fatalf("expected 2 body statements, got %d", len(forStmt.Body)) + } + + // First: function assignment + assign, ok := forStmt.Body[0].(*parser.AssignStatement) + if !ok { + t.Fatalf("expected AssignStatement, got %T", forStmt.Body[0]) + } + + _, ok = assign.Value.(*parser.FunctionLiteral) + if !ok { + t.Fatalf("expected FunctionLiteral value, got %T", assign.Value) + } +} + +func TestFunctionErrors(t *testing.T) { + tests := []struct { + input string + expectedError string + desc string + }{ + {"fn", "expected '(' after 'fn'", "missing parentheses"}, + {"fn(", "expected ')' after function parameters", "missing closing paren"}, + {"fn() echo x", "expected 'end' to close function", "missing end"}, + {"fn(123) end", "expected parameter name", "invalid parameter name"}, + {"fn(a,) end", "expected parameter name", "trailing comma"}, + {"fn(a, 123) end", "expected parameter name", "invalid second parameter"}, + {"fn(..., a) end", "expected next token to be ), got , instead", "parameter after ellipsis"}, + {"fn(a, ..., b) end", "expected next token to be ), got , instead", "parameter after ellipsis in middle"}, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + l := parser.NewLexer(tt.input) + p := parser.NewParser(l) + p.ParseExpression(parser.LOWEST) + + if !p.HasErrors() { + t.Fatal("expected parsing errors") + } + + errors := p.Errors() + found := false + for _, err := range errors { + if strings.Contains(err.Message, tt.expectedError) { + found = true + break + } + } + + if !found { + errorMsgs := make([]string, len(errors)) + for i, err := range errors { + errorMsgs[i] = err.Message + } + t.Errorf("expected error containing %q, got %v", tt.expectedError, errorMsgs) + } + }) + } +} + +func TestReturnErrors(t *testing.T) { + tests := []struct { + input string + expectedError string + desc string + }{ + {"return +", "unexpected token '+'", "return with invalid expression"}, + {"return (", "unexpected end of input", "return with incomplete expression"}, + {"return 1 +", "expected expression after operator '+'", "return with incomplete infix"}, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + l := parser.NewLexer(tt.input) + p := parser.NewParser(l) + p.ParseProgram() + + if !p.HasErrors() { + t.Fatal("expected parsing errors") + } + + errors := p.Errors() + found := false + for _, err := range errors { + if strings.Contains(err.Message, tt.expectedError) { + found = true + break + } + } + + if !found { + errorMsgs := make([]string, len(errors)) + for i, err := range errors { + errorMsgs[i] = err.Message + } + t.Errorf("expected error containing %q, got %v", tt.expectedError, errorMsgs) + } + }) + } +} + +func TestFunctionStringRepresentation(t *testing.T) { + tests := []struct { + input string + contains []string + desc string + }{ + { + "fn() end", + []string{"fn()", "end"}, + "empty function", + }, + { + "fn(a, b) return a + b end", + []string{"fn(a, b)", "return (a + b)", "end"}, + "function with params and return", + }, + { + "fn(...) echo \"variadic\" end", + []string{"fn(...)", "echo \"variadic\"", "end"}, + "variadic function", + }, + { + "fn(a, b, ...) return a end", + []string{"fn(a, b, ...)", "return a", "end"}, + "mixed params and variadic", + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + l := parser.NewLexer(tt.input) + p := parser.NewParser(l) + expr := p.ParseExpression(parser.LOWEST) + checkParserErrors(t, p) + + fnStr := expr.String() + for _, contain := range tt.contains { + if !strings.Contains(fnStr, contain) { + t.Errorf("expected function string to contain %q, got:\n%s", contain, fnStr) + } + } + }) + } +} + +func TestReturnStringRepresentation(t *testing.T) { + tests := []struct { + input string + expected string + desc string + }{ + {"return", "return", "simple return"}, + {"return 42", "return 42.00", "return with number"}, + {"return \"hello\"", "return \"hello\"", "return with string"}, + {"return x + 1", "return (x + 1.00)", "return with expression"}, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + l := parser.NewLexer(tt.input) + p := parser.NewParser(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + result := strings.TrimSpace(program.String()) + if result != tt.expected { + t.Errorf("expected %q, got %q", tt.expected, result) + } + }) + } +} + +func TestComplexFunctionProgram(t *testing.T) { + input := `math = { + add = fn(a, b) return a + b end, + multiply = fn(x, y) return x * y end, + max = fn(a, b) + if a > b then + return a + else + return b + end + end +} + +result = math.add(5, 3) +echo result + +calculator = fn(op, ...) + if op == "add" then + return fn(a, b) return a + b end + elseif op == "sub" then + return fn(a, b) return a - b end + else + return nil + end +end + +adder = calculator("add") +echo adder` + + l := parser.NewLexer(input) + p := parser.NewParser(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + if len(program.Statements) != 6 { + t.Fatalf("expected 6 statements, got %d", len(program.Statements)) + } + + // First: table with functions + mathAssign, ok := program.Statements[0].(*parser.AssignStatement) + if !ok { + t.Fatalf("statement 0: expected AssignStatement, got %T", program.Statements[0]) + } + + table, ok := mathAssign.Value.(*parser.TableLiteral) + if !ok { + t.Fatalf("expected TableLiteral, got %T", mathAssign.Value) + } + + if len(table.Pairs) != 3 { + t.Errorf("expected 3 functions in math table, got %d", len(table.Pairs)) + } + + // Verify all pairs contain functions + for i, pair := range table.Pairs { + _, ok := pair.Value.(*parser.FunctionLiteral) + if !ok { + t.Errorf("pair %d: expected FunctionLiteral, got %T", i, pair.Value) + } + } + + // Second: result assignment (function call would be handled by interpreter) + _, ok = program.Statements[1].(*parser.AssignStatement) + if !ok { + t.Fatalf("statement 1: expected AssignStatement, got %T", program.Statements[1]) + } + + // Third: echo + _, ok = program.Statements[2].(*parser.EchoStatement) + if !ok { + t.Fatalf("statement 2: expected EchoStatement, got %T", program.Statements[2]) + } + + // Fourth: calculator function assignment + calcAssign, ok := program.Statements[3].(*parser.AssignStatement) + if !ok { + t.Fatalf("statement 3: expected AssignStatement, got %T", program.Statements[3]) + } + + calcFn, ok := calcAssign.Value.(*parser.FunctionLiteral) + if !ok { + t.Fatalf("expected FunctionLiteral, got %T", calcAssign.Value) + } + + if !calcFn.Variadic { + t.Error("expected calculator function to be variadic") + } + + // Fifth: adder assignment + _, ok = program.Statements[4].(*parser.AssignStatement) + if !ok { + t.Fatalf("statement 4: expected AssignStatement, got %T", program.Statements[4]) + } + + // Sixth: echo adder + _, ok = program.Statements[5].(*parser.EchoStatement) + if !ok { + t.Fatalf("statement 5: expected EchoStatement, got %T", program.Statements[5]) + } +} + +func TestFunctionInConditional(t *testing.T) { + input := `if condition then + handler = fn(x) return x * 2 end +else + handler = fn(x) return x + 1 end +end + +for i = 1, 10 do + if i > 5 then + result = fn() return "high" end + else + result = fn() return "low" end + end + echo result() +end` + + l := parser.NewLexer(input) + p := parser.NewParser(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + if len(program.Statements) != 2 { + t.Fatalf("expected 2 statements, got %d", len(program.Statements)) + } + + // First: if statement with function assignments + ifStmt, ok := program.Statements[0].(*parser.IfStatement) + if !ok { + t.Fatalf("statement 0: expected IfStatement, got %T", program.Statements[0]) + } + + // Check if body has function assignment + ifAssign, ok := ifStmt.Body[0].(*parser.AssignStatement) + if !ok { + t.Fatalf("if body: expected AssignStatement, got %T", ifStmt.Body[0]) + } + + _, ok = ifAssign.Value.(*parser.FunctionLiteral) + if !ok { + t.Fatalf("if body: expected FunctionLiteral, got %T", ifAssign.Value) + } + + // Check else body has function assignment + elseAssign, ok := ifStmt.Else[0].(*parser.AssignStatement) + if !ok { + t.Fatalf("else body: expected AssignStatement, got %T", ifStmt.Else[0]) + } + + _, ok = elseAssign.Value.(*parser.FunctionLiteral) + if !ok { + t.Fatalf("else body: expected FunctionLiteral, got %T", elseAssign.Value) + } + + // Second: for loop with nested conditionals containing functions + forStmt, ok := program.Statements[1].(*parser.ForStatement) + if !ok { + t.Fatalf("statement 1: expected ForStatement, got %T", program.Statements[1]) + } + + if len(forStmt.Body) != 2 { + t.Fatalf("expected 2 statements in for body, got %d", len(forStmt.Body)) + } + + // Check nested if contains function assignments + nestedIf, ok := forStmt.Body[0].(*parser.IfStatement) + if !ok { + t.Fatalf("for body[0]: expected IfStatement, got %T", forStmt.Body[0]) + } + + // Verify both branches assign functions + nestedIfAssign, ok := nestedIf.Body[0].(*parser.AssignStatement) + if !ok { + t.Fatalf("nested if body: expected AssignStatement, got %T", nestedIf.Body[0]) + } + + _, ok = nestedIfAssign.Value.(*parser.FunctionLiteral) + if !ok { + t.Fatalf("nested if body: expected FunctionLiteral, got %T", nestedIfAssign.Value) + } + + nestedElseAssign, ok := nestedIf.Else[0].(*parser.AssignStatement) + if !ok { + t.Fatalf("nested else body: expected AssignStatement, got %T", nestedIf.Else[0]) + } + + _, ok = nestedElseAssign.Value.(*parser.FunctionLiteral) + if !ok { + t.Fatalf("nested else body: expected FunctionLiteral, got %T", nestedElseAssign.Value) + } +} diff --git a/parser/token.go b/parser/token.go index b61a3b9..ae808dc 100644 --- a/parser/token.go +++ b/parser/token.go @@ -36,6 +36,7 @@ const ( LBRACKET // [ RBRACKET // ] COMMA // , + ELLIPSIS // ... // Keywords IF @@ -50,6 +51,8 @@ const ( DO BREAK EXIT + FN + RETURN // Special EOF @@ -93,6 +96,7 @@ var precedences = map[TokenType]Precedence{ STAR: PRODUCT, DOT: MEMBER, LBRACKET: MEMBER, + LPAREN: CALL, } // lookupIdent checks if an identifier is a keyword @@ -113,6 +117,8 @@ func lookupIdent(ident string) TokenType { "do": DO, "break": BREAK, "exit": EXIT, + "fn": FN, + "return": RETURN, } if tok, ok := keywords[ident]; ok {