functions
This commit is contained in:
parent
6eb4e21263
commit
fc988d257f
@ -75,6 +75,19 @@ func (es *ExitStatement) String() string {
|
||||
return fmt.Sprintf("exit %s", es.Value.String())
|
||||
}
|
||||
|
||||
// ReturnStatement represents return statements
|
||||
type ReturnStatement struct {
|
||||
Value Expression // optional, can be nil
|
||||
}
|
||||
|
||||
func (rs *ReturnStatement) statementNode() {}
|
||||
func (rs *ReturnStatement) String() string {
|
||||
if rs.Value == nil {
|
||||
return "return"
|
||||
}
|
||||
return fmt.Sprintf("return %s", rs.Value.String())
|
||||
}
|
||||
|
||||
// ElseIfClause represents an elseif condition
|
||||
type ElseIfClause struct {
|
||||
Condition Expression
|
||||
@ -241,6 +254,52 @@ type NilLiteral struct{}
|
||||
func (nl *NilLiteral) expressionNode() {}
|
||||
func (nl *NilLiteral) String() string { return "nil" }
|
||||
|
||||
// FunctionLiteral represents function literals: fn(a, b, ...) ... end
|
||||
type FunctionLiteral struct {
|
||||
Parameters []string
|
||||
Variadic bool
|
||||
Body []Statement
|
||||
}
|
||||
|
||||
func (fl *FunctionLiteral) expressionNode() {}
|
||||
func (fl *FunctionLiteral) String() string {
|
||||
var params string
|
||||
for i, param := range fl.Parameters {
|
||||
if i > 0 {
|
||||
params += ", "
|
||||
}
|
||||
params += param
|
||||
}
|
||||
if fl.Variadic {
|
||||
if len(fl.Parameters) > 0 {
|
||||
params += ", "
|
||||
}
|
||||
params += "..."
|
||||
}
|
||||
|
||||
result := fmt.Sprintf("fn(%s)\n", params)
|
||||
for _, stmt := range fl.Body {
|
||||
result += "\t" + stmt.String() + "\n"
|
||||
}
|
||||
result += "end"
|
||||
return result
|
||||
}
|
||||
|
||||
// CallExpression represents function calls: func(arg1, arg2, ...)
|
||||
type CallExpression struct {
|
||||
Function Expression
|
||||
Arguments []Expression
|
||||
}
|
||||
|
||||
func (ce *CallExpression) expressionNode() {}
|
||||
func (ce *CallExpression) String() string {
|
||||
var args []string
|
||||
for _, arg := range ce.Arguments {
|
||||
args = append(args, arg.String())
|
||||
}
|
||||
return fmt.Sprintf("%s(%s)", ce.Function.String(), joinStrings(args, ", "))
|
||||
}
|
||||
|
||||
// PrefixExpression represents prefix operations like -x
|
||||
type PrefixExpression struct {
|
||||
Operator string
|
||||
|
@ -47,6 +47,15 @@ func (l *Lexer) peekChar() byte {
|
||||
return l.input[l.readPosition]
|
||||
}
|
||||
|
||||
// peekCharAt returns the character at offset positions ahead
|
||||
func (l *Lexer) peekCharAt(offset int) byte {
|
||||
pos := l.readPosition + offset - 1
|
||||
if pos >= len(l.input) {
|
||||
return 0
|
||||
}
|
||||
return l.input[pos]
|
||||
}
|
||||
|
||||
// skipWhitespace skips whitespace characters
|
||||
func (l *Lexer) skipWhitespace() {
|
||||
for l.ch == ' ' || l.ch == '\t' || l.ch == '\n' || l.ch == '\r' {
|
||||
@ -249,7 +258,14 @@ func (l *Lexer) NextToken() Token {
|
||||
case '/':
|
||||
tok = Token{Type: SLASH, Literal: string(l.ch), Line: l.line, Column: l.column}
|
||||
case '.':
|
||||
tok = Token{Type: DOT, Literal: string(l.ch), Line: l.line, Column: l.column}
|
||||
// Check for ellipsis (...)
|
||||
if l.peekChar() == '.' && l.peekCharAt(2) == '.' {
|
||||
l.readChar() // skip first '.'
|
||||
l.readChar() // skip second '.'
|
||||
tok = Token{Type: ELLIPSIS, Literal: "...", Line: l.line, Column: l.column}
|
||||
} else {
|
||||
tok = Token{Type: DOT, Literal: string(l.ch), Line: l.line, Column: l.column}
|
||||
}
|
||||
case '(':
|
||||
tok = Token{Type: LPAREN, Literal: string(l.ch), Line: l.line, Column: l.column}
|
||||
case ')':
|
||||
|
129
parser/parser.go
129
parser/parser.go
@ -49,6 +49,7 @@ func NewParser(lexer *Lexer) *Parser {
|
||||
p.registerPrefix(LPAREN, p.parseGroupedExpression)
|
||||
p.registerPrefix(LBRACE, p.parseTableLiteral)
|
||||
p.registerPrefix(MINUS, p.parsePrefixExpression)
|
||||
p.registerPrefix(FN, p.parseFunctionLiteral)
|
||||
|
||||
p.infixParseFns = make(map[TokenType]func(Expression) Expression)
|
||||
p.registerInfix(PLUS, p.parseInfixExpression)
|
||||
@ -63,6 +64,7 @@ func NewParser(lexer *Lexer) *Parser {
|
||||
p.registerInfix(GT_EQ, p.parseInfixExpression)
|
||||
p.registerInfix(DOT, p.parseDotExpression)
|
||||
p.registerInfix(LBRACKET, p.parseIndexExpression)
|
||||
p.registerInfix(LPAREN, p.parseCallExpression)
|
||||
|
||||
// Read two tokens, so curToken and peekToken are both set
|
||||
p.nextToken()
|
||||
@ -121,6 +123,8 @@ func (p *Parser) parseStatement() Statement {
|
||||
return p.parseBreakStatement()
|
||||
case EXIT:
|
||||
return p.parseExitStatement()
|
||||
case RETURN:
|
||||
return p.parseReturnStatement()
|
||||
case ASSIGN:
|
||||
p.addError("assignment operator '=' without left-hand side identifier")
|
||||
return nil
|
||||
@ -214,10 +218,27 @@ func (p *Parser) parseExitStatement() *ExitStatement {
|
||||
return stmt
|
||||
}
|
||||
|
||||
// parseReturnStatement parses return statements
|
||||
func (p *Parser) parseReturnStatement() *ReturnStatement {
|
||||
stmt := &ReturnStatement{}
|
||||
|
||||
// Check if there's an optional expression after 'return'
|
||||
if p.canStartExpression(p.peekToken.Type) {
|
||||
p.nextToken() // move past 'return'
|
||||
stmt.Value = p.ParseExpression(LOWEST)
|
||||
if stmt.Value == nil {
|
||||
p.addError("expected expression after 'return'")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return stmt
|
||||
}
|
||||
|
||||
// canStartExpression checks if a token type can start an expression
|
||||
func (p *Parser) canStartExpression(tokenType TokenType) bool {
|
||||
switch tokenType {
|
||||
case IDENT, NUMBER, STRING, TRUE, FALSE, NIL, LPAREN, LBRACE, MINUS:
|
||||
case IDENT, NUMBER, STRING, TRUE, FALSE, NIL, LPAREN, LBRACE, MINUS, FN:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
@ -621,6 +642,74 @@ func (p *Parser) parseGroupedExpression() Expression {
|
||||
return exp
|
||||
}
|
||||
|
||||
func (p *Parser) parseFunctionLiteral() Expression {
|
||||
fn := &FunctionLiteral{}
|
||||
|
||||
if !p.expectPeek(LPAREN) {
|
||||
p.addError("expected '(' after 'fn'")
|
||||
return nil
|
||||
}
|
||||
|
||||
fn.Parameters, fn.Variadic = p.parseFunctionParameters()
|
||||
|
||||
if !p.expectPeek(RPAREN) {
|
||||
p.addError("expected ')' after function parameters")
|
||||
return nil
|
||||
}
|
||||
|
||||
p.nextToken() // move past ')'
|
||||
|
||||
// Parse function body
|
||||
fn.Body = p.parseBlockStatements(END)
|
||||
|
||||
if !p.curTokenIs(END) {
|
||||
p.addError("expected 'end' to close function")
|
||||
return nil
|
||||
}
|
||||
|
||||
return fn
|
||||
}
|
||||
|
||||
func (p *Parser) parseFunctionParameters() ([]string, bool) {
|
||||
var params []string
|
||||
var variadic bool
|
||||
|
||||
if p.peekTokenIs(RPAREN) {
|
||||
return params, false
|
||||
}
|
||||
|
||||
p.nextToken()
|
||||
|
||||
for {
|
||||
if p.curTokenIs(ELLIPSIS) {
|
||||
variadic = true
|
||||
break
|
||||
}
|
||||
|
||||
if !p.curTokenIs(IDENT) {
|
||||
p.addError("expected parameter name")
|
||||
return nil, false
|
||||
}
|
||||
|
||||
params = append(params, p.curToken.Literal)
|
||||
|
||||
if !p.peekTokenIs(COMMA) {
|
||||
break
|
||||
}
|
||||
|
||||
p.nextToken() // move to ','
|
||||
p.nextToken() // move past ','
|
||||
|
||||
// Check for ellipsis after comma
|
||||
if p.curTokenIs(ELLIPSIS) {
|
||||
variadic = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return params, variadic
|
||||
}
|
||||
|
||||
func (p *Parser) parseTableLiteral() Expression {
|
||||
table := &TableLiteral{}
|
||||
table.Pairs = []TablePair{}
|
||||
@ -725,6 +814,36 @@ func (p *Parser) parseDotExpression(left Expression) Expression {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Parser) parseCallExpression(fn Expression) Expression {
|
||||
call := &CallExpression{Function: fn}
|
||||
call.Arguments = p.parseExpressionList(RPAREN)
|
||||
return call
|
||||
}
|
||||
|
||||
func (p *Parser) parseExpressionList(end TokenType) []Expression {
|
||||
var args []Expression
|
||||
|
||||
if p.peekTokenIs(end) {
|
||||
p.nextToken()
|
||||
return args
|
||||
}
|
||||
|
||||
p.nextToken()
|
||||
args = append(args, p.ParseExpression(LOWEST))
|
||||
|
||||
for p.peekTokenIs(COMMA) {
|
||||
p.nextToken()
|
||||
p.nextToken()
|
||||
args = append(args, p.ParseExpression(LOWEST))
|
||||
}
|
||||
|
||||
if !p.expectPeek(end) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return args
|
||||
}
|
||||
|
||||
func (p *Parser) parseIndexExpression(left Expression) Expression {
|
||||
p.nextToken() // move past '['
|
||||
|
||||
@ -776,7 +895,7 @@ func (p *Parser) expectPeekIdent() bool {
|
||||
// isKeyword checks if a token type is a keyword that can be used as identifier
|
||||
func (p *Parser) isKeyword(t TokenType) bool {
|
||||
switch t {
|
||||
case TRUE, FALSE, NIL, IF, THEN, ELSEIF, ELSE, END, ECHO, FOR, WHILE, IN, DO, BREAK, EXIT:
|
||||
case TRUE, FALSE, NIL, IF, THEN, ELSEIF, ELSE, END, ECHO, FOR, WHILE, IN, DO, BREAK, EXIT, FN, RETURN:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
@ -910,6 +1029,8 @@ func tokenTypeString(t TokenType) string {
|
||||
return "]"
|
||||
case COMMA:
|
||||
return ","
|
||||
case ELLIPSIS:
|
||||
return "..."
|
||||
case IF:
|
||||
return "if"
|
||||
case THEN:
|
||||
@ -934,6 +1055,10 @@ func tokenTypeString(t TokenType) string {
|
||||
return "break"
|
||||
case EXIT:
|
||||
return "exit"
|
||||
case FN:
|
||||
return "fn"
|
||||
case RETURN:
|
||||
return "return"
|
||||
case EOF:
|
||||
return "end of file"
|
||||
case ILLEGAL:
|
||||
|
705
parser/tests/functions_test.go
Normal file
705
parser/tests/functions_test.go
Normal file
@ -0,0 +1,705 @@
|
||||
package parser_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"git.sharkk.net/Sharkk/Mako/parser"
|
||||
)
|
||||
|
||||
func TestBasicFunctionLiterals(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
paramCount int
|
||||
variadic bool
|
||||
bodyCount int
|
||||
description string
|
||||
}{
|
||||
{"fn() end", 0, false, 0, "empty function"},
|
||||
{"fn(a) echo a end", 1, false, 1, "single parameter"},
|
||||
{"fn(a, b) return a + b end", 2, false, 1, "two parameters"},
|
||||
{"fn(...) echo \"variadic\" end", 0, true, 1, "variadic only"},
|
||||
{"fn(a, b, ...) return a end", 2, true, 1, "mixed params and variadic"},
|
||||
{"fn(x, y) x = x + 1 y = y + 1 end", 2, false, 2, "multiple statements"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.description, func(t *testing.T) {
|
||||
l := parser.NewLexer(tt.input)
|
||||
p := parser.NewParser(l)
|
||||
expr := p.ParseExpression(parser.LOWEST)
|
||||
checkParserErrors(t, p)
|
||||
|
||||
fn, ok := expr.(*parser.FunctionLiteral)
|
||||
if !ok {
|
||||
t.Fatalf("expected FunctionLiteral, got %T", expr)
|
||||
}
|
||||
|
||||
if len(fn.Parameters) != tt.paramCount {
|
||||
t.Errorf("expected %d parameters, got %d", tt.paramCount, len(fn.Parameters))
|
||||
}
|
||||
|
||||
if fn.Variadic != tt.variadic {
|
||||
t.Errorf("expected variadic = %t, got %t", tt.variadic, fn.Variadic)
|
||||
}
|
||||
|
||||
if len(fn.Body) != tt.bodyCount {
|
||||
t.Errorf("expected %d body statements, got %d", tt.bodyCount, len(fn.Body))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFunctionParameters(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
params []string
|
||||
variadic bool
|
||||
desc string
|
||||
}{
|
||||
{"fn(a) end", []string{"a"}, false, "single param"},
|
||||
{"fn(a, b, c) end", []string{"a", "b", "c"}, false, "multiple params"},
|
||||
{"fn(...) end", []string{}, true, "only variadic"},
|
||||
{"fn(x, ...) end", []string{"x"}, true, "param with variadic"},
|
||||
{"fn(a, b, ...) end", []string{"a", "b"}, true, "multiple params with variadic"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
l := parser.NewLexer(tt.input)
|
||||
p := parser.NewParser(l)
|
||||
expr := p.ParseExpression(parser.LOWEST)
|
||||
checkParserErrors(t, p)
|
||||
|
||||
fn, ok := expr.(*parser.FunctionLiteral)
|
||||
if !ok {
|
||||
t.Fatalf("expected FunctionLiteral, got %T", expr)
|
||||
}
|
||||
|
||||
if len(fn.Parameters) != len(tt.params) {
|
||||
t.Fatalf("expected %d parameters, got %d", len(tt.params), len(fn.Parameters))
|
||||
}
|
||||
|
||||
for i, expected := range tt.params {
|
||||
if fn.Parameters[i] != expected {
|
||||
t.Errorf("parameter %d: expected %s, got %s", i, expected, fn.Parameters[i])
|
||||
}
|
||||
}
|
||||
|
||||
if fn.Variadic != tt.variadic {
|
||||
t.Errorf("expected variadic = %t, got %t", tt.variadic, fn.Variadic)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReturnStatements(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
hasValue bool
|
||||
desc string
|
||||
}{
|
||||
{"return", false, "return without value"},
|
||||
{"return 42", true, "return with number"},
|
||||
{"return \"hello\"", true, "return with string"},
|
||||
{"return x", true, "return with identifier"},
|
||||
{"return x + y", true, "return with expression"},
|
||||
{"return table.key", true, "return with member access"},
|
||||
{"return arr[0]", true, "return with index access"},
|
||||
{"return {a = 1, b = 2}", true, "return with table"},
|
||||
{"return fn(x) return x end", true, "return with function"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
l := parser.NewLexer(tt.input)
|
||||
p := parser.NewParser(l)
|
||||
program := p.ParseProgram()
|
||||
checkParserErrors(t, p)
|
||||
|
||||
if len(program.Statements) != 1 {
|
||||
t.Fatalf("expected 1 statement, got %d", len(program.Statements))
|
||||
}
|
||||
|
||||
stmt, ok := program.Statements[0].(*parser.ReturnStatement)
|
||||
if !ok {
|
||||
t.Fatalf("expected ReturnStatement, got %T", program.Statements[0])
|
||||
}
|
||||
|
||||
if tt.hasValue && stmt.Value == nil {
|
||||
t.Error("expected return value but got nil")
|
||||
}
|
||||
|
||||
if !tt.hasValue && stmt.Value != nil {
|
||||
t.Error("expected no return value but got one")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFunctionAssignments(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
desc string
|
||||
}{
|
||||
{"add = fn(a, b) return a + b end", "assign function to variable"},
|
||||
{"math.add = fn(x, y) return x + y end", "assign function to member"},
|
||||
{"funcs[\"add\"] = fn(a, b) return a + b end", "assign function to index"},
|
||||
{"callback = fn() echo \"called\" end", "assign simple function"},
|
||||
{"sum = fn(...) return 0 end", "assign variadic function"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
l := parser.NewLexer(tt.input)
|
||||
p := parser.NewParser(l)
|
||||
program := p.ParseProgram()
|
||||
checkParserErrors(t, p)
|
||||
|
||||
if len(program.Statements) != 1 {
|
||||
t.Fatalf("expected 1 statement, got %d", len(program.Statements))
|
||||
}
|
||||
|
||||
stmt, ok := program.Statements[0].(*parser.AssignStatement)
|
||||
if !ok {
|
||||
t.Fatalf("expected AssignStatement, got %T", program.Statements[0])
|
||||
}
|
||||
|
||||
_, ok = stmt.Value.(*parser.FunctionLiteral)
|
||||
if !ok {
|
||||
t.Fatalf("expected FunctionLiteral value, got %T", stmt.Value)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFunctionsInTables(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
desc string
|
||||
}{
|
||||
{"{add = fn(a, b) return a + b end}", "function in hash table"},
|
||||
{"{fn(x) return x end, fn(y) return y end}", "functions in array"},
|
||||
{"{math = {add = fn(a, b) return a + b end}}", "nested function in table"},
|
||||
{"{callback = fn() echo \"hi\" end, data = 42}", "mixed table with function"},
|
||||
{"{fn(...) return 0 end}", "variadic function in array"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
l := parser.NewLexer(tt.input)
|
||||
p := parser.NewParser(l)
|
||||
expr := p.ParseExpression(parser.LOWEST)
|
||||
checkParserErrors(t, p)
|
||||
|
||||
table, ok := expr.(*parser.TableLiteral)
|
||||
if !ok {
|
||||
t.Fatalf("expected TableLiteral, got %T", expr)
|
||||
}
|
||||
|
||||
// Verify at least one pair contains a function (or nested table with function)
|
||||
foundFunction := false
|
||||
for _, pair := range table.Pairs {
|
||||
if _, ok := pair.Value.(*parser.FunctionLiteral); ok {
|
||||
foundFunction = true
|
||||
break
|
||||
}
|
||||
// Check nested tables for functions
|
||||
if nestedTable, ok := pair.Value.(*parser.TableLiteral); ok {
|
||||
for _, nestedPair := range nestedTable.Pairs {
|
||||
if _, ok := nestedPair.Value.(*parser.FunctionLiteral); ok {
|
||||
foundFunction = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !foundFunction {
|
||||
t.Error("expected to find at least one function in table")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFunctionWithComplexBody(t *testing.T) {
|
||||
input := `fn(x, y)
|
||||
if x > y then
|
||||
return x
|
||||
else
|
||||
return y
|
||||
end
|
||||
echo "unreachable"
|
||||
end`
|
||||
|
||||
l := parser.NewLexer(input)
|
||||
p := parser.NewParser(l)
|
||||
expr := p.ParseExpression(parser.LOWEST)
|
||||
checkParserErrors(t, p)
|
||||
|
||||
fn, ok := expr.(*parser.FunctionLiteral)
|
||||
if !ok {
|
||||
t.Fatalf("expected FunctionLiteral, got %T", expr)
|
||||
}
|
||||
|
||||
if len(fn.Parameters) != 2 {
|
||||
t.Errorf("expected 2 parameters, got %d", len(fn.Parameters))
|
||||
}
|
||||
|
||||
if len(fn.Body) != 2 {
|
||||
t.Errorf("expected 2 body statements, got %d", len(fn.Body))
|
||||
}
|
||||
|
||||
// First statement should be if
|
||||
ifStmt, ok := fn.Body[0].(*parser.IfStatement)
|
||||
if !ok {
|
||||
t.Fatalf("expected IfStatement, got %T", fn.Body[0])
|
||||
}
|
||||
|
||||
// Check that if body contains return
|
||||
_, ok = ifStmt.Body[0].(*parser.ReturnStatement)
|
||||
if !ok {
|
||||
t.Fatalf("expected ReturnStatement in if body, got %T", ifStmt.Body[0])
|
||||
}
|
||||
|
||||
// Check that else body contains return
|
||||
_, ok = ifStmt.Else[0].(*parser.ReturnStatement)
|
||||
if !ok {
|
||||
t.Fatalf("expected ReturnStatement in else body, got %T", ifStmt.Else[0])
|
||||
}
|
||||
|
||||
// Second statement should be echo
|
||||
_, ok = fn.Body[1].(*parser.EchoStatement)
|
||||
if !ok {
|
||||
t.Fatalf("expected EchoStatement, got %T", fn.Body[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestNestedFunctions(t *testing.T) {
|
||||
input := `fn(x)
|
||||
inner = fn(y) return y * 2 end
|
||||
return inner(x)
|
||||
end`
|
||||
|
||||
l := parser.NewLexer(input)
|
||||
p := parser.NewParser(l)
|
||||
expr := p.ParseExpression(parser.LOWEST)
|
||||
checkParserErrors(t, p)
|
||||
|
||||
fn, ok := expr.(*parser.FunctionLiteral)
|
||||
if !ok {
|
||||
t.Fatalf("expected FunctionLiteral, got %T", expr)
|
||||
}
|
||||
|
||||
if len(fn.Body) != 2 {
|
||||
t.Fatalf("expected 2 body statements, got %d", len(fn.Body))
|
||||
}
|
||||
|
||||
// First statement: assignment of inner function
|
||||
assign, ok := fn.Body[0].(*parser.AssignStatement)
|
||||
if !ok {
|
||||
t.Fatalf("expected AssignStatement, got %T", fn.Body[0])
|
||||
}
|
||||
|
||||
innerFn, ok := assign.Value.(*parser.FunctionLiteral)
|
||||
if !ok {
|
||||
t.Fatalf("expected FunctionLiteral value, got %T", assign.Value)
|
||||
}
|
||||
|
||||
if len(innerFn.Parameters) != 1 {
|
||||
t.Errorf("expected 1 parameter in inner function, got %d", len(innerFn.Parameters))
|
||||
}
|
||||
|
||||
// Second statement: return
|
||||
_, ok = fn.Body[1].(*parser.ReturnStatement)
|
||||
if !ok {
|
||||
t.Fatalf("expected ReturnStatement, got %T", fn.Body[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestFunctionInLoop(t *testing.T) {
|
||||
input := `for i = 1, 10 do
|
||||
callback = fn(x) return x + i end
|
||||
echo callback(i)
|
||||
end`
|
||||
|
||||
l := parser.NewLexer(input)
|
||||
p := parser.NewParser(l)
|
||||
program := p.ParseProgram()
|
||||
checkParserErrors(t, p)
|
||||
|
||||
if len(program.Statements) != 1 {
|
||||
t.Fatalf("expected 1 statement, got %d", len(program.Statements))
|
||||
}
|
||||
|
||||
forStmt, ok := program.Statements[0].(*parser.ForStatement)
|
||||
if !ok {
|
||||
t.Fatalf("expected ForStatement, got %T", program.Statements[0])
|
||||
}
|
||||
|
||||
if len(forStmt.Body) != 2 {
|
||||
t.Fatalf("expected 2 body statements, got %d", len(forStmt.Body))
|
||||
}
|
||||
|
||||
// First: function assignment
|
||||
assign, ok := forStmt.Body[0].(*parser.AssignStatement)
|
||||
if !ok {
|
||||
t.Fatalf("expected AssignStatement, got %T", forStmt.Body[0])
|
||||
}
|
||||
|
||||
_, ok = assign.Value.(*parser.FunctionLiteral)
|
||||
if !ok {
|
||||
t.Fatalf("expected FunctionLiteral value, got %T", assign.Value)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFunctionErrors(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expectedError string
|
||||
desc string
|
||||
}{
|
||||
{"fn", "expected '(' after 'fn'", "missing parentheses"},
|
||||
{"fn(", "expected ')' after function parameters", "missing closing paren"},
|
||||
{"fn() echo x", "expected 'end' to close function", "missing end"},
|
||||
{"fn(123) end", "expected parameter name", "invalid parameter name"},
|
||||
{"fn(a,) end", "expected parameter name", "trailing comma"},
|
||||
{"fn(a, 123) end", "expected parameter name", "invalid second parameter"},
|
||||
{"fn(..., a) end", "expected next token to be ), got , instead", "parameter after ellipsis"},
|
||||
{"fn(a, ..., b) end", "expected next token to be ), got , instead", "parameter after ellipsis in middle"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
l := parser.NewLexer(tt.input)
|
||||
p := parser.NewParser(l)
|
||||
p.ParseExpression(parser.LOWEST)
|
||||
|
||||
if !p.HasErrors() {
|
||||
t.Fatal("expected parsing errors")
|
||||
}
|
||||
|
||||
errors := p.Errors()
|
||||
found := false
|
||||
for _, err := range errors {
|
||||
if strings.Contains(err.Message, tt.expectedError) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
errorMsgs := make([]string, len(errors))
|
||||
for i, err := range errors {
|
||||
errorMsgs[i] = err.Message
|
||||
}
|
||||
t.Errorf("expected error containing %q, got %v", tt.expectedError, errorMsgs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReturnErrors(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expectedError string
|
||||
desc string
|
||||
}{
|
||||
{"return +", "unexpected token '+'", "return with invalid expression"},
|
||||
{"return (", "unexpected end of input", "return with incomplete expression"},
|
||||
{"return 1 +", "expected expression after operator '+'", "return with incomplete infix"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
l := parser.NewLexer(tt.input)
|
||||
p := parser.NewParser(l)
|
||||
p.ParseProgram()
|
||||
|
||||
if !p.HasErrors() {
|
||||
t.Fatal("expected parsing errors")
|
||||
}
|
||||
|
||||
errors := p.Errors()
|
||||
found := false
|
||||
for _, err := range errors {
|
||||
if strings.Contains(err.Message, tt.expectedError) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
errorMsgs := make([]string, len(errors))
|
||||
for i, err := range errors {
|
||||
errorMsgs[i] = err.Message
|
||||
}
|
||||
t.Errorf("expected error containing %q, got %v", tt.expectedError, errorMsgs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFunctionStringRepresentation(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
contains []string
|
||||
desc string
|
||||
}{
|
||||
{
|
||||
"fn() end",
|
||||
[]string{"fn()", "end"},
|
||||
"empty function",
|
||||
},
|
||||
{
|
||||
"fn(a, b) return a + b end",
|
||||
[]string{"fn(a, b)", "return (a + b)", "end"},
|
||||
"function with params and return",
|
||||
},
|
||||
{
|
||||
"fn(...) echo \"variadic\" end",
|
||||
[]string{"fn(...)", "echo \"variadic\"", "end"},
|
||||
"variadic function",
|
||||
},
|
||||
{
|
||||
"fn(a, b, ...) return a end",
|
||||
[]string{"fn(a, b, ...)", "return a", "end"},
|
||||
"mixed params and variadic",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
l := parser.NewLexer(tt.input)
|
||||
p := parser.NewParser(l)
|
||||
expr := p.ParseExpression(parser.LOWEST)
|
||||
checkParserErrors(t, p)
|
||||
|
||||
fnStr := expr.String()
|
||||
for _, contain := range tt.contains {
|
||||
if !strings.Contains(fnStr, contain) {
|
||||
t.Errorf("expected function string to contain %q, got:\n%s", contain, fnStr)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReturnStringRepresentation(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
desc string
|
||||
}{
|
||||
{"return", "return", "simple return"},
|
||||
{"return 42", "return 42.00", "return with number"},
|
||||
{"return \"hello\"", "return \"hello\"", "return with string"},
|
||||
{"return x + 1", "return (x + 1.00)", "return with expression"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
l := parser.NewLexer(tt.input)
|
||||
p := parser.NewParser(l)
|
||||
program := p.ParseProgram()
|
||||
checkParserErrors(t, p)
|
||||
|
||||
result := strings.TrimSpace(program.String())
|
||||
if result != tt.expected {
|
||||
t.Errorf("expected %q, got %q", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestComplexFunctionProgram(t *testing.T) {
|
||||
input := `math = {
|
||||
add = fn(a, b) return a + b end,
|
||||
multiply = fn(x, y) return x * y end,
|
||||
max = fn(a, b)
|
||||
if a > b then
|
||||
return a
|
||||
else
|
||||
return b
|
||||
end
|
||||
end
|
||||
}
|
||||
|
||||
result = math.add(5, 3)
|
||||
echo result
|
||||
|
||||
calculator = fn(op, ...)
|
||||
if op == "add" then
|
||||
return fn(a, b) return a + b end
|
||||
elseif op == "sub" then
|
||||
return fn(a, b) return a - b end
|
||||
else
|
||||
return nil
|
||||
end
|
||||
end
|
||||
|
||||
adder = calculator("add")
|
||||
echo adder`
|
||||
|
||||
l := parser.NewLexer(input)
|
||||
p := parser.NewParser(l)
|
||||
program := p.ParseProgram()
|
||||
checkParserErrors(t, p)
|
||||
|
||||
if len(program.Statements) != 6 {
|
||||
t.Fatalf("expected 6 statements, got %d", len(program.Statements))
|
||||
}
|
||||
|
||||
// First: table with functions
|
||||
mathAssign, ok := program.Statements[0].(*parser.AssignStatement)
|
||||
if !ok {
|
||||
t.Fatalf("statement 0: expected AssignStatement, got %T", program.Statements[0])
|
||||
}
|
||||
|
||||
table, ok := mathAssign.Value.(*parser.TableLiteral)
|
||||
if !ok {
|
||||
t.Fatalf("expected TableLiteral, got %T", mathAssign.Value)
|
||||
}
|
||||
|
||||
if len(table.Pairs) != 3 {
|
||||
t.Errorf("expected 3 functions in math table, got %d", len(table.Pairs))
|
||||
}
|
||||
|
||||
// Verify all pairs contain functions
|
||||
for i, pair := range table.Pairs {
|
||||
_, ok := pair.Value.(*parser.FunctionLiteral)
|
||||
if !ok {
|
||||
t.Errorf("pair %d: expected FunctionLiteral, got %T", i, pair.Value)
|
||||
}
|
||||
}
|
||||
|
||||
// Second: result assignment (function call would be handled by interpreter)
|
||||
_, ok = program.Statements[1].(*parser.AssignStatement)
|
||||
if !ok {
|
||||
t.Fatalf("statement 1: expected AssignStatement, got %T", program.Statements[1])
|
||||
}
|
||||
|
||||
// Third: echo
|
||||
_, ok = program.Statements[2].(*parser.EchoStatement)
|
||||
if !ok {
|
||||
t.Fatalf("statement 2: expected EchoStatement, got %T", program.Statements[2])
|
||||
}
|
||||
|
||||
// Fourth: calculator function assignment
|
||||
calcAssign, ok := program.Statements[3].(*parser.AssignStatement)
|
||||
if !ok {
|
||||
t.Fatalf("statement 3: expected AssignStatement, got %T", program.Statements[3])
|
||||
}
|
||||
|
||||
calcFn, ok := calcAssign.Value.(*parser.FunctionLiteral)
|
||||
if !ok {
|
||||
t.Fatalf("expected FunctionLiteral, got %T", calcAssign.Value)
|
||||
}
|
||||
|
||||
if !calcFn.Variadic {
|
||||
t.Error("expected calculator function to be variadic")
|
||||
}
|
||||
|
||||
// Fifth: adder assignment
|
||||
_, ok = program.Statements[4].(*parser.AssignStatement)
|
||||
if !ok {
|
||||
t.Fatalf("statement 4: expected AssignStatement, got %T", program.Statements[4])
|
||||
}
|
||||
|
||||
// Sixth: echo adder
|
||||
_, ok = program.Statements[5].(*parser.EchoStatement)
|
||||
if !ok {
|
||||
t.Fatalf("statement 5: expected EchoStatement, got %T", program.Statements[5])
|
||||
}
|
||||
}
|
||||
|
||||
func TestFunctionInConditional(t *testing.T) {
|
||||
input := `if condition then
|
||||
handler = fn(x) return x * 2 end
|
||||
else
|
||||
handler = fn(x) return x + 1 end
|
||||
end
|
||||
|
||||
for i = 1, 10 do
|
||||
if i > 5 then
|
||||
result = fn() return "high" end
|
||||
else
|
||||
result = fn() return "low" end
|
||||
end
|
||||
echo result()
|
||||
end`
|
||||
|
||||
l := parser.NewLexer(input)
|
||||
p := parser.NewParser(l)
|
||||
program := p.ParseProgram()
|
||||
checkParserErrors(t, p)
|
||||
|
||||
if len(program.Statements) != 2 {
|
||||
t.Fatalf("expected 2 statements, got %d", len(program.Statements))
|
||||
}
|
||||
|
||||
// First: if statement with function assignments
|
||||
ifStmt, ok := program.Statements[0].(*parser.IfStatement)
|
||||
if !ok {
|
||||
t.Fatalf("statement 0: expected IfStatement, got %T", program.Statements[0])
|
||||
}
|
||||
|
||||
// Check if body has function assignment
|
||||
ifAssign, ok := ifStmt.Body[0].(*parser.AssignStatement)
|
||||
if !ok {
|
||||
t.Fatalf("if body: expected AssignStatement, got %T", ifStmt.Body[0])
|
||||
}
|
||||
|
||||
_, ok = ifAssign.Value.(*parser.FunctionLiteral)
|
||||
if !ok {
|
||||
t.Fatalf("if body: expected FunctionLiteral, got %T", ifAssign.Value)
|
||||
}
|
||||
|
||||
// Check else body has function assignment
|
||||
elseAssign, ok := ifStmt.Else[0].(*parser.AssignStatement)
|
||||
if !ok {
|
||||
t.Fatalf("else body: expected AssignStatement, got %T", ifStmt.Else[0])
|
||||
}
|
||||
|
||||
_, ok = elseAssign.Value.(*parser.FunctionLiteral)
|
||||
if !ok {
|
||||
t.Fatalf("else body: expected FunctionLiteral, got %T", elseAssign.Value)
|
||||
}
|
||||
|
||||
// Second: for loop with nested conditionals containing functions
|
||||
forStmt, ok := program.Statements[1].(*parser.ForStatement)
|
||||
if !ok {
|
||||
t.Fatalf("statement 1: expected ForStatement, got %T", program.Statements[1])
|
||||
}
|
||||
|
||||
if len(forStmt.Body) != 2 {
|
||||
t.Fatalf("expected 2 statements in for body, got %d", len(forStmt.Body))
|
||||
}
|
||||
|
||||
// Check nested if contains function assignments
|
||||
nestedIf, ok := forStmt.Body[0].(*parser.IfStatement)
|
||||
if !ok {
|
||||
t.Fatalf("for body[0]: expected IfStatement, got %T", forStmt.Body[0])
|
||||
}
|
||||
|
||||
// Verify both branches assign functions
|
||||
nestedIfAssign, ok := nestedIf.Body[0].(*parser.AssignStatement)
|
||||
if !ok {
|
||||
t.Fatalf("nested if body: expected AssignStatement, got %T", nestedIf.Body[0])
|
||||
}
|
||||
|
||||
_, ok = nestedIfAssign.Value.(*parser.FunctionLiteral)
|
||||
if !ok {
|
||||
t.Fatalf("nested if body: expected FunctionLiteral, got %T", nestedIfAssign.Value)
|
||||
}
|
||||
|
||||
nestedElseAssign, ok := nestedIf.Else[0].(*parser.AssignStatement)
|
||||
if !ok {
|
||||
t.Fatalf("nested else body: expected AssignStatement, got %T", nestedIf.Else[0])
|
||||
}
|
||||
|
||||
_, ok = nestedElseAssign.Value.(*parser.FunctionLiteral)
|
||||
if !ok {
|
||||
t.Fatalf("nested else body: expected FunctionLiteral, got %T", nestedElseAssign.Value)
|
||||
}
|
||||
}
|
@ -36,6 +36,7 @@ const (
|
||||
LBRACKET // [
|
||||
RBRACKET // ]
|
||||
COMMA // ,
|
||||
ELLIPSIS // ...
|
||||
|
||||
// Keywords
|
||||
IF
|
||||
@ -50,6 +51,8 @@ const (
|
||||
DO
|
||||
BREAK
|
||||
EXIT
|
||||
FN
|
||||
RETURN
|
||||
|
||||
// Special
|
||||
EOF
|
||||
@ -93,6 +96,7 @@ var precedences = map[TokenType]Precedence{
|
||||
STAR: PRODUCT,
|
||||
DOT: MEMBER,
|
||||
LBRACKET: MEMBER,
|
||||
LPAREN: CALL,
|
||||
}
|
||||
|
||||
// lookupIdent checks if an identifier is a keyword
|
||||
@ -113,6 +117,8 @@ func lookupIdent(ident string) TokenType {
|
||||
"do": DO,
|
||||
"break": BREAK,
|
||||
"exit": EXIT,
|
||||
"fn": FN,
|
||||
"return": RETURN,
|
||||
}
|
||||
|
||||
if tok, ok := keywords[ident]; ok {
|
||||
|
Loading…
x
Reference in New Issue
Block a user