diff --git a/parser/parser_test.go b/parser/parser_test.go new file mode 100644 index 0000000..d94ddb8 --- /dev/null +++ b/parser/parser_test.go @@ -0,0 +1,331 @@ +package parser + +import ( + "testing" +) + +func TestLiterals(t *testing.T) { + tests := []struct { + input string + expected any + }{ + {"42", 42.0}, + {"3.14", 3.14}, + {`"hello"`, "hello"}, + {"true", true}, + {"false", false}, + {"nil", nil}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + l := NewLexer(tt.input) + p := NewParser(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + if len(program.Statements) != 0 { + t.Fatalf("expected 0 statements for literal, got %d", len(program.Statements)) + } + + // Parse as expression + l = NewLexer(tt.input) + p = NewParser(l) + expr := p.parseExpression(LOWEST) + + switch expected := tt.expected.(type) { + case float64: + testNumberLiteral(t, expr, expected) + case string: + testStringLiteral(t, expr, expected) + case bool: + testBooleanLiteral(t, expr, expected) + case nil: + testNilLiteral(t, expr) + } + }) + } +} + +func TestAssignStatements(t *testing.T) { + tests := []struct { + input string + expectedIdentifier string + expectedValue any + isExpression bool // true if expectedValue is expression string representation + }{ + {"x = 42", "x", 42.0, false}, + {"name = \"test\"", "name", "test", false}, + {"flag = true", "flag", true, false}, + {"ptr = nil", "ptr", nil, false}, + {"result = 3 + 4", "result", "(3.00 + 4.00)", true}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + l := NewLexer(tt.input) + p := NewParser(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + if len(program.Statements) != 1 { + t.Fatalf("expected 1 statement, got %d", len(program.Statements)) + } + + stmt, ok := program.Statements[0].(*AssignStatement) + if !ok { + t.Fatalf("expected AssignStatement, got %T", program.Statements[0]) + } + + if stmt.Name.Value != tt.expectedIdentifier { + t.Errorf("expected identifier %s, got %s", tt.expectedIdentifier, stmt.Name.Value) + } + + if tt.isExpression { + // Test the string representation of the expression + if stmt.Value.String() != tt.expectedValue.(string) { + t.Errorf("expected expression %s, got %s", tt.expectedValue.(string), stmt.Value.String()) + } + } else { + // Test the actual value based on type + switch expected := tt.expectedValue.(type) { + case float64: + testNumberLiteral(t, stmt.Value, expected) + case string: + testStringLiteral(t, stmt.Value, expected) + case bool: + testBooleanLiteral(t, stmt.Value, expected) + case nil: + testNilLiteral(t, stmt.Value) + } + } + }) + } +} + +func TestInfixExpressions(t *testing.T) { + tests := []struct { + input string + leftValue any + operator string + rightValue any + }{ + {"5 + 5", 5.0, "+", 5.0}, + {"5 - 5", 5.0, "-", 5.0}, + {"5 * 5", 5.0, "*", 5.0}, + {"5 / 5", 5.0, "/", 5.0}, + {"true + false", true, "+", false}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + l := NewLexer(tt.input) + p := NewParser(l) + expr := p.parseExpression(LOWEST) + checkParserErrors(t, p) + + testInfixExpression(t, expr, tt.leftValue, tt.operator, tt.rightValue) + }) + } +} + +func TestOperatorPrecedence(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"1 + 2 * 3", "(1.00 + (2.00 * 3.00))"}, + {"2 * 3 + 4", "((2.00 * 3.00) + 4.00)"}, + {"(1 + 2) * 3", "((1.00 + 2.00) * 3.00)"}, + {"1 + 2 - 3", "((1.00 + 2.00) - 3.00)"}, + {"2 * 3 / 4", "((2.00 * 3.00) / 4.00)"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + l := NewLexer(tt.input) + p := NewParser(l) + expr := p.parseExpression(LOWEST) + checkParserErrors(t, p) + + if expr.String() != tt.expected { + t.Errorf("expected %s, got %s", tt.expected, expr.String()) + } + }) + } +} + +func TestParsingErrors(t *testing.T) { + tests := []struct { + input string + expectedError string + }{ + {"x =", "no prefix parse function"}, + {"= 5", "no prefix parse function"}, + {"(1 + 2", "expected next token to be"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + l := NewLexer(tt.input) + p := NewParser(l) + + if tt.input == "x =" { + p.ParseProgram() + } else { + // Parse as expression to catch syntax errors + p.parseExpression(LOWEST) + } + + errors := p.Errors() + if len(errors) == 0 { + t.Fatalf("expected parsing errors, got none") + } + + found := false + for _, err := range errors { + if containsSubstring(err, tt.expectedError) { + found = true + break + } + } + + if !found { + t.Errorf("expected error containing %q, got %v", tt.expectedError, errors) + } + }) + } +} + +func TestProgram(t *testing.T) { + input := `x = 42 +y = "hello" +z = true + false` + + l := NewLexer(input) + p := NewParser(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + if len(program.Statements) != 3 { + t.Fatalf("expected 3 statements, got %d", len(program.Statements)) + } + + expectedIdentifiers := []string{"x", "y", "z"} + for i, expectedIdent := range expectedIdentifiers { + stmt, ok := program.Statements[i].(*AssignStatement) + if !ok { + t.Fatalf("statement %d: expected AssignStatement, got %T", i, program.Statements[i]) + } + + if stmt.Name.Value != expectedIdent { + t.Errorf("statement %d: expected identifier %s, got %s", i, expectedIdent, stmt.Name.Value) + } + } +} + +// Helper functions for testing specific node types +func testNumberLiteral(t *testing.T, expr Expression, expected float64) { + t.Helper() + num, ok := expr.(*NumberLiteral) + if !ok { + t.Fatalf("expected NumberLiteral, got %T", expr) + } + if num.Value != expected { + t.Errorf("expected %f, got %f", expected, num.Value) + } +} + +func testStringLiteral(t *testing.T, expr Expression, expected string) { + t.Helper() + str, ok := expr.(*StringLiteral) + if !ok { + t.Fatalf("expected StringLiteral, got %T", expr) + } + if str.Value != expected { + t.Errorf("expected %s, got %s", expected, str.Value) + } +} + +func testBooleanLiteral(t *testing.T, expr Expression, expected bool) { + t.Helper() + boolean, ok := expr.(*BooleanLiteral) + if !ok { + t.Fatalf("expected BooleanLiteral, got %T", expr) + } + if boolean.Value != expected { + t.Errorf("expected %t, got %t", expected, boolean.Value) + } +} + +func testNilLiteral(t *testing.T, expr Expression) { + t.Helper() + _, ok := expr.(*NilLiteral) + if !ok { + t.Fatalf("expected NilLiteral, got %T", expr) + } +} + +func testIdentifier(t *testing.T, expr Expression, expected string) { + t.Helper() + ident, ok := expr.(*Identifier) + if !ok { + t.Fatalf("expected Identifier, got %T", expr) + } + if ident.Value != expected { + t.Errorf("expected %s, got %s", expected, ident.Value) + } +} + +func testInfixExpression(t *testing.T, expr Expression, left any, operator string, right any) { + t.Helper() + infix, ok := expr.(*InfixExpression) + if !ok { + t.Fatalf("expected InfixExpression, got %T", expr) + } + + if infix.Operator != operator { + t.Errorf("expected operator %s, got %s", operator, infix.Operator) + } + + switch leftVal := left.(type) { + case float64: + testNumberLiteral(t, infix.Left, leftVal) + case string: + testIdentifier(t, infix.Left, leftVal) + case bool: + testBooleanLiteral(t, infix.Left, leftVal) + } + + switch rightVal := right.(type) { + case float64: + testNumberLiteral(t, infix.Right, rightVal) + case string: + testIdentifier(t, infix.Right, rightVal) + case bool: + testBooleanLiteral(t, infix.Right, rightVal) + } +} + +func checkParserErrors(t *testing.T, p *Parser) { + t.Helper() + errors := p.Errors() + if len(errors) == 0 { + return + } + + t.Errorf("parser has %d errors", len(errors)) + for _, msg := range errors { + t.Errorf("parser error: %q", msg) + } + t.FailNow() +} + +func containsSubstring(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +}