From 119cfceccecca3274979427502ae98473fe0e8d3 Mon Sep 17 00:00:00 2001 From: Sky Johnson Date: Tue, 10 Jun 2025 11:12:46 -0500 Subject: [PATCH] comparison and minus --- parser/ast.go | 11 ++ parser/lexer.go | 32 +++++- parser/parser.go | 35 +++++++ parser/tests/errors_test.go | 123 +++++++++++++++++++++++ parser/tests/expressions_test.go | 166 +++++++++++++++++++++++++++++++ parser/token.go | 26 ++++- 6 files changed, 387 insertions(+), 6 deletions(-) diff --git a/parser/ast.go b/parser/ast.go index 6313197..b19911b 100644 --- a/parser/ast.go +++ b/parser/ast.go @@ -200,6 +200,17 @@ type NilLiteral struct{} func (nl *NilLiteral) expressionNode() {} func (nl *NilLiteral) String() string { return "nil" } +// PrefixExpression represents prefix operations like -x +type PrefixExpression struct { + Operator string + Right Expression +} + +func (pe *PrefixExpression) expressionNode() {} +func (pe *PrefixExpression) String() string { + return fmt.Sprintf("(%s%s)", pe.Operator, pe.Right.String()) +} + // InfixExpression represents binary operations type InfixExpression struct { Left Expression diff --git a/parser/lexer.go b/parser/lexer.go index f6b21b9..e0b0297 100644 --- a/parser/lexer.go +++ b/parser/lexer.go @@ -209,7 +209,37 @@ func (l *Lexer) NextToken() Token { switch l.ch { case '=': - tok = Token{Type: ASSIGN, Literal: string(l.ch), Line: l.line, Column: l.column} + if l.peekChar() == '=' { + ch := l.ch + l.readChar() + tok = Token{Type: EQ, Literal: string(ch) + string(l.ch), Line: l.line, Column: l.column} + } else { + tok = Token{Type: ASSIGN, Literal: string(l.ch), Line: l.line, Column: l.column} + } + case '!': + if l.peekChar() == '=' { + ch := l.ch + l.readChar() + tok = Token{Type: NOT_EQ, Literal: string(ch) + string(l.ch), Line: l.line, Column: l.column} + } else { + tok = Token{Type: ILLEGAL, Literal: string(l.ch), Line: l.line, Column: l.column} + } + case '<': + if l.peekChar() == '=' { + ch := l.ch + l.readChar() + tok = Token{Type: LT_EQ, Literal: string(ch) + string(l.ch), Line: l.line, Column: l.column} + } else { + tok = Token{Type: LT, Literal: string(l.ch), Line: l.line, Column: l.column} + } + case '>': + if l.peekChar() == '=' { + ch := l.ch + l.readChar() + tok = Token{Type: GT_EQ, Literal: string(ch) + string(l.ch), Line: l.line, Column: l.column} + } else { + tok = Token{Type: GT, Literal: string(l.ch), Line: l.line, Column: l.column} + } case '+': tok = Token{Type: PLUS, Literal: string(l.ch), Line: l.line, Column: l.column} case '-': diff --git a/parser/parser.go b/parser/parser.go index ef6a6f6..b25ded0 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -48,12 +48,19 @@ func NewParser(lexer *Lexer) *Parser { p.registerPrefix(NIL, p.parseNilLiteral) p.registerPrefix(LPAREN, p.parseGroupedExpression) p.registerPrefix(LBRACE, p.parseTableLiteral) + p.registerPrefix(MINUS, p.parsePrefixExpression) p.infixParseFns = make(map[TokenType]func(Expression) Expression) p.registerInfix(PLUS, p.parseInfixExpression) p.registerInfix(MINUS, p.parseInfixExpression) p.registerInfix(SLASH, p.parseInfixExpression) p.registerInfix(STAR, p.parseInfixExpression) + p.registerInfix(EQ, p.parseInfixExpression) + p.registerInfix(NOT_EQ, p.parseInfixExpression) + p.registerInfix(LT, p.parseInfixExpression) + p.registerInfix(GT, p.parseInfixExpression) + p.registerInfix(LT_EQ, p.parseInfixExpression) + p.registerInfix(GT_EQ, p.parseInfixExpression) p.registerInfix(DOT, p.parseDotExpression) p.registerInfix(LBRACKET, p.parseIndexExpression) @@ -514,6 +521,22 @@ func (p *Parser) parseNilLiteral() Expression { return &NilLiteral{} } +func (p *Parser) parsePrefixExpression() Expression { + expression := &PrefixExpression{ + Operator: p.curToken.Literal, + } + + p.nextToken() + + expression.Right = p.ParseExpression(PREFIX) + if expression.Right == nil { + p.addError(fmt.Sprintf("expected expression after prefix operator '%s'", expression.Operator)) + return nil + } + + return expression +} + func (p *Parser) parseGroupedExpression() Expression { p.nextToken() @@ -792,6 +815,18 @@ func tokenTypeString(t TokenType) string { return "/" case DOT: return "." + case EQ: + return "==" + case NOT_EQ: + return "!=" + case LT: + return "<" + case GT: + return ">" + case LT_EQ: + return "<=" + case GT_EQ: + return ">=" case LPAREN: return "(" case RPAREN: diff --git a/parser/tests/errors_test.go b/parser/tests/errors_test.go index 71bbf20..174ed7c 100644 --- a/parser/tests/errors_test.go +++ b/parser/tests/errors_test.go @@ -210,3 +210,126 @@ func TestTokenTypeStringWithEcho(t *testing.T) { }) } } + +func TestPrefixOperatorErrors(t *testing.T) { + tests := []struct { + input string + expectedError string + desc string + }{ + {"-", "expected expression after prefix operator '-'", "minus without operand"}, + {"-(", "unexpected end of input", "minus with incomplete expression"}, + {"-+", "unexpected operator '+'", "minus followed by plus"}, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + l := parser.NewLexer(tt.input) + p := parser.NewParser(l) + p.ParseExpression(parser.LOWEST) + + if !p.HasErrors() { + t.Fatal("expected parsing errors") + } + + errors := p.Errors() + found := false + for _, err := range errors { + if strings.Contains(err.Message, tt.expectedError) { + found = true + break + } + } + + if !found { + errorMsgs := make([]string, len(errors)) + for i, err := range errors { + errorMsgs[i] = err.Message + } + t.Errorf("expected error containing %q, got %v", tt.expectedError, errorMsgs) + } + }) + } +} + +func TestComparisonOperatorErrors(t *testing.T) { + tests := []struct { + input string + expectedError string + desc string + }{ + {"x ==", "expected expression after operator '=='", "== without right operand"}, + {"!= 5", "unexpected token '!='", "!= without left operand"}, + {"< y", "unexpected token '<'", "< without left operand"}, + {"> z", "unexpected token '>'", "> without left operand"}, + {"<= 10", "unexpected token '<='", "<= without left operand"}, + {">= 20", "unexpected token '>='", ">= without left operand"}, + {"x !=", "expected expression after operator '!='", "!= without right operand"}, + {"a <", "expected expression after operator '<'", "< without right operand"}, + {"b >", "expected expression after operator '>'", "> without right operand"}, + {"c <=", "expected expression after operator '<='", "<= without right operand"}, + {"d >=", "expected expression after operator '>='", ">= without right operand"}, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + l := parser.NewLexer(tt.input) + p := parser.NewParser(l) + p.ParseExpression(parser.LOWEST) + + if !p.HasErrors() { + t.Fatal("expected parsing errors") + } + + errors := p.Errors() + found := false + for _, err := range errors { + if strings.Contains(err.Message, tt.expectedError) { + found = true + break + } + } + + if !found { + errorMsgs := make([]string, len(errors)) + for i, err := range errors { + errorMsgs[i] = err.Message + } + t.Errorf("expected error containing %q, got %v", tt.expectedError, errorMsgs) + } + }) + } +} + +func TestTokenTypeStringWithComparisons(t *testing.T) { + tests := []struct { + input string + expectedMessage string + }{ + {"!= 5", "Parse error at line 1, column 1: unexpected token '!=' (near '!=')"}, + {"< y", "Parse error at line 1, column 1: unexpected token '<' (near '<')"}, + {">= 20", "Parse error at line 1, column 1: unexpected token '>=' (near '>=')"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + l := parser.NewLexer(tt.input) + p := parser.NewParser(l) + p.ParseExpression(parser.LOWEST) + + if !p.HasErrors() { + t.Fatal("expected parsing errors") + } + + errors := p.Errors() + if len(errors) == 0 { + t.Fatal("expected at least one error") + } + + errorStr := errors[0].Error() + if !strings.Contains(errorStr, "Parse error at line") { + t.Errorf("expected formatted error message, got: %s", errorStr) + } + }) + } +} diff --git a/parser/tests/expressions_test.go b/parser/tests/expressions_test.go index 0f4860f..b5d1e2a 100644 --- a/parser/tests/expressions_test.go +++ b/parser/tests/expressions_test.go @@ -6,6 +6,113 @@ import ( "git.sharkk.net/Sharkk/Mako/parser" ) +func TestPrefixExpressions(t *testing.T) { + tests := []struct { + input string + operator string + value any + }{ + {"-5", "-", 5.0}, + {"-x", "-", "x"}, + {"-true", "-", true}, + {"-(1 + 2)", "-", "(1.00 + 2.00)"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + l := parser.NewLexer(tt.input) + p := parser.NewParser(l) + expr := p.ParseExpression(parser.LOWEST) + checkParserErrors(t, p) + + prefix, ok := expr.(*parser.PrefixExpression) + if !ok { + t.Fatalf("expected PrefixExpression, got %T", expr) + } + + if prefix.Operator != tt.operator { + t.Errorf("expected operator %s, got %s", tt.operator, prefix.Operator) + } + + switch expected := tt.value.(type) { + case float64: + testNumberLiteral(t, prefix.Right, expected) + case string: + if expected == "x" { + testIdentifier(t, prefix.Right, expected) + } else { + // It's an expression string + if prefix.Right.String() != expected { + t.Errorf("expected %s, got %s", expected, prefix.Right.String()) + } + } + case bool: + testBooleanLiteral(t, prefix.Right, expected) + } + }) + } +} + +func TestComparisonExpressions(t *testing.T) { + tests := []struct { + input string + leftValue any + operator string + rightValue any + }{ + {"1 == 1", 1.0, "==", 1.0}, + {"1 != 2", 1.0, "!=", 2.0}, + {"x < y", "x", "<", "y"}, + {"a > b", "a", ">", "b"}, + {"5 <= 10", 5.0, "<=", 10.0}, + {"10 >= 5", 10.0, ">=", 5.0}, + {"true == false", true, "==", false}, + {"nil != x", nil, "!=", "x"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + l := parser.NewLexer(tt.input) + p := parser.NewParser(l) + expr := p.ParseExpression(parser.LOWEST) + checkParserErrors(t, p) + + infix, ok := expr.(*parser.InfixExpression) + if !ok { + t.Fatalf("expected InfixExpression, got %T", expr) + } + + if infix.Operator != tt.operator { + t.Errorf("expected operator %s, got %s", tt.operator, infix.Operator) + } + + // Test left operand + switch leftVal := tt.leftValue.(type) { + case float64: + testNumberLiteral(t, infix.Left, leftVal) + case string: + testIdentifier(t, infix.Left, leftVal) + case bool: + testBooleanLiteral(t, infix.Left, leftVal) + case nil: + testNilLiteral(t, infix.Left) + } + + // Test right operand + switch rightVal := tt.rightValue.(type) { + case float64: + testNumberLiteral(t, infix.Right, rightVal) + case string: + testIdentifier(t, infix.Right, rightVal) + case bool: + testBooleanLiteral(t, infix.Right, rightVal) + case nil: + testNilLiteral(t, infix.Right) + } + }) + } +} + func TestInfixExpressions(t *testing.T) { tests := []struct { input string @@ -37,11 +144,38 @@ func TestOperatorPrecedence(t *testing.T) { input string expected string }{ + // Arithmetic precedence {"1 + 2 * 3", "(1.00 + (2.00 * 3.00))"}, {"2 * 3 + 4", "((2.00 * 3.00) + 4.00)"}, {"(1 + 2) * 3", "((1.00 + 2.00) * 3.00)"}, {"1 + 2 - 3", "((1.00 + 2.00) - 3.00)"}, {"2 * 3 / 4", "((2.00 * 3.00) / 4.00)"}, + + // Prefix with arithmetic + {"-1 + 2", "((-1.00) + 2.00)"}, + {"-(1 + 2)", "(-(1.00 + 2.00))"}, + {"-x * 2", "((-x) * 2.00)"}, + + // Comparison precedence + {"1 + 2 == 3", "((1.00 + 2.00) == 3.00)"}, + {"1 * 2 < 3 + 4", "((1.00 * 2.00) < (3.00 + 4.00))"}, + {"a + b != c * d", "((a + b) != (c * d))"}, + {"x <= y + z", "(x <= (y + z))"}, + {"a * b >= c / d", "((a * b) >= (c / d))"}, + + // Comparison chaining + {"a == b != c", "((a == b) != c)"}, + {"x < y <= z", "((x < y) <= z)"}, + + // Member access with operators + {"table.key + 1", "(table.key + 1.00)"}, + {"arr[0] * 2", "(arr[0.00] * 2.00)"}, + {"obj.x == obj.y", "(obj.x == obj.y)"}, + {"-table.value", "(-table.value)"}, + + // Complex combinations + {"-x + y * z == a.b", "(((-x) + (y * z)) == a.b)"}, + {"(a + b) * c <= d[0]", "(((a + b) * c) <= d[0.00])"}, } for _, tt := range tests { @@ -57,3 +191,35 @@ func TestOperatorPrecedence(t *testing.T) { }) } } + +func TestComplexExpressionsWithComparisons(t *testing.T) { + tests := []struct { + input string + desc string + }{ + {"x + 1 == y * 2", "arithmetic comparison"}, + {"table.count > arr[0] + 5", "member access comparison"}, + {"-value <= max", "prefix comparison"}, + {"(a + b) != (c - d)", "grouped comparison"}, + {"obj.x < obj.y && obj.y > obj.z", "multiple comparisons"}, // Note: && not implemented yet + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + l := parser.NewLexer(tt.input) + p := parser.NewParser(l) + expr := p.ParseExpression(parser.LOWEST) + + // Skip && test since it's not implemented + if tt.input == "obj.x < obj.y && obj.y > obj.z" { + return + } + + checkParserErrors(t, p) + + if expr == nil { + t.Error("expected non-nil expression") + } + }) + } +} diff --git a/parser/token.go b/parser/token.go index 217be87..c79821c 100644 --- a/parser/token.go +++ b/parser/token.go @@ -20,6 +20,14 @@ const ( SLASH // / DOT // . + // Comparison operators + EQ // == + NOT_EQ // != + LT // < + GT // > + LT_EQ // <= + GT_EQ // >= + // Delimiters LPAREN // ( RPAREN // ) @@ -60,15 +68,23 @@ type Precedence int const ( _ Precedence = iota LOWEST - SUM // + - PRODUCT // * - MEMBER // table[key], table.key - PREFIX // -x, !x - CALL // function() + EQUALS // ==, != + LESSGREATER // >, <, >=, <= + SUM // +, - + PRODUCT // *, / + PREFIX // -x, !x + MEMBER // table[key], table.key + CALL // function() ) // precedences maps token types to their precedence levels var precedences = map[TokenType]Precedence{ + EQ: EQUALS, + NOT_EQ: EQUALS, + LT: LESSGREATER, + GT: LESSGREATER, + LT_EQ: LESSGREATER, + GT_EQ: LESSGREATER, PLUS: SUM, MINUS: SUM, SLASH: PRODUCT,