From 5ae2a6ef23876037778ed7b39a1bfc7d81a4cb2a Mon Sep 17 00:00:00 2001 From: Sky Johnson Date: Wed, 11 Jun 2025 16:28:55 -0500 Subject: [PATCH] structs --- parser/ast.go | 63 +++- parser/parser.go | 313 ++++++++++++++++- parser/tests/structs_test.go | 632 +++++++++++++++++++++++++++++++++++ parser/token.go | 3 + parser/types.go | 197 ++++++++++- 5 files changed, 1199 insertions(+), 9 deletions(-) create mode 100644 parser/tests/structs_test.go diff --git a/parser/ast.go b/parser/ast.go index 941eb82..80c961d 100644 --- a/parser/ast.go +++ b/parser/ast.go @@ -4,7 +4,7 @@ import "fmt" // TypeInfo represents type information for expressions type TypeInfo struct { - Type string // "number", "string", "bool", "table", "function", "nil", "any" + Type string // "number", "string", "bool", "table", "function", "nil", "any", struct name Inferred bool // true if type was inferred, false if explicitly declared } @@ -41,6 +41,67 @@ func (p *Program) String() string { return result } +// StructField represents a field in a struct definition +type StructField struct { + Name string + TypeHint *TypeInfo +} + +func (sf *StructField) String() string { + if sf.TypeHint != nil { + return fmt.Sprintf("%s: %s", sf.Name, sf.TypeHint.Type) + } + return sf.Name +} + +// StructStatement represents struct definitions +type StructStatement struct { + Name string + Fields []StructField +} + +func (ss *StructStatement) statementNode() {} +func (ss *StructStatement) String() string { + var fields string + for i, field := range ss.Fields { + if i > 0 { + fields += ",\n\t" + } + fields += field.String() + } + return fmt.Sprintf("struct %s {\n\t%s\n}", ss.Name, fields) +} + +// MethodDefinition represents method definitions on structs +type MethodDefinition struct { + StructName string + MethodName string + Function *FunctionLiteral +} + +func (md *MethodDefinition) statementNode() {} +func (md *MethodDefinition) String() string { + return fmt.Sprintf("fn %s.%s%s", md.StructName, md.MethodName, md.Function.String()[2:]) // skip "fn" from function string +} + +// StructConstructorExpression represents struct constructor calls like my_type{...} +type StructConstructorExpression struct { + StructName string + Fields []TablePair // reuse TablePair for field assignments + typeInfo *TypeInfo +} + +func (sce *StructConstructorExpression) expressionNode() {} +func (sce *StructConstructorExpression) String() string { + var pairs []string + for _, pair := range sce.Fields { + pairs = append(pairs, pair.String()) + } + return fmt.Sprintf("%s{%s}", sce.StructName, joinStrings(pairs, ", ")) +} +func (sce *StructConstructorExpression) GetType() *TypeInfo { return sce.typeInfo } +func (sce *StructConstructorExpression) SetType(t *TypeInfo) { sce.typeInfo = t } + // AssignStatement represents variable assignment with optional type hint type AssignStatement struct { Name Expression // Changed from *Identifier to Expression for member access diff --git a/parser/parser.go b/parser/parser.go index 906e559..b29d0d6 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -34,6 +34,9 @@ type Parser struct { // Scope tracking scopes []map[string]bool // stack of scopes, each tracking declared variables scopeTypes []string // track what type each scope is: "global", "function", "loop" + + // Struct tracking + structs map[string]*StructStatement // track defined structs } // NewParser creates a new parser instance @@ -43,6 +46,7 @@ func NewParser(lexer *Lexer) *Parser { errors: []ParseError{}, scopes: []map[string]bool{make(map[string]bool)}, // start with global scope scopeTypes: []string{"global"}, // start with global scope type + structs: make(map[string]*StructStatement), // track struct definitions } p.prefixParseFns = make(map[TokenType]func() Expression) @@ -74,6 +78,7 @@ func NewParser(lexer *Lexer) *Parser { p.registerInfix(DOT, p.parseDotExpression) p.registerInfix(LBRACKET, p.parseIndexExpression) p.registerInfix(LPAREN, p.parseCallExpression) + p.registerInfix(LBRACE, p.parseStructConstructor) // struct constructor // Read two tokens, so curToken and peekToken are both set p.nextToken() @@ -157,7 +162,7 @@ func (p *Parser) parseTypeHint() *TypeInfo { } typeName := p.curToken.Literal - if !ValidTypeName(typeName) { + if !ValidTypeName(typeName) && !p.isStructDefined(typeName) { p.addError(fmt.Sprintf("invalid type name '%s'", typeName)) return nil } @@ -165,6 +170,12 @@ func (p *Parser) parseTypeHint() *TypeInfo { return &TypeInfo{Type: typeName, Inferred: false} } +// isStructDefined checks if a struct name is defined +func (p *Parser) isStructDefined(name string) bool { + _, exists := p.structs[name] + return exists +} + // registerPrefix registers a prefix parse function func (p *Parser) registerPrefix(tokenType TokenType, fn func() Expression) { p.prefixParseFns[tokenType] = fn @@ -200,6 +211,10 @@ func (p *Parser) ParseProgram() *Program { // parseStatement parses a statement func (p *Parser) parseStatement() Statement { switch p.curToken.Type { + case STRUCT: + return p.parseStructStatement() + case FN: + return p.parseFunctionStatement() case IDENT: return p.parseIdentifierStatement() case IF: @@ -230,6 +245,147 @@ func (p *Parser) parseStatement() Statement { } } +// parseStructStatement parses struct definitions +func (p *Parser) parseStructStatement() *StructStatement { + stmt := &StructStatement{} + + if !p.expectPeek(IDENT) { + p.addError("expected struct name") + return nil + } + + stmt.Name = p.curToken.Literal + + if !p.expectPeek(LBRACE) { + p.addError("expected '{' after struct name") + return nil + } + + stmt.Fields = []StructField{} + + if p.peekTokenIs(RBRACE) { + p.nextToken() + p.structs[stmt.Name] = stmt + return stmt + } + + p.nextToken() + + for { + if p.curTokenIs(EOF) { + p.addError("unexpected end of input, expected }") + return nil + } + + if !p.curTokenIs(IDENT) { + p.addError("expected field name") + return nil + } + + field := StructField{Name: p.curToken.Literal} + + // Parse optional type hint + field.TypeHint = p.parseTypeHint() + if field.TypeHint == nil { + p.addError("struct fields require type annotation") + return nil + } + + stmt.Fields = append(stmt.Fields, field) + + if !p.peekTokenIs(COMMA) { + break + } + + p.nextToken() // consume comma + p.nextToken() // move to next field + + if p.curTokenIs(RBRACE) { + break + } + + if p.curTokenIs(EOF) { + p.addError("expected next token to be }") + return nil + } + } + + if !p.expectPeek(RBRACE) { + return nil + } + + p.structs[stmt.Name] = stmt + return stmt +} + +// parseFunctionStatement handles both regular functions and methods +func (p *Parser) parseFunctionStatement() Statement { + if !p.expectPeek(IDENT) { + p.addError("expected function name") + return nil + } + + funcName := p.curToken.Literal + + // Check if this is a method definition (struct.method) + if p.peekTokenIs(DOT) { + p.nextToken() // consume '.' + + if !p.expectPeek(IDENT) { + p.addError("expected method name after '.'") + return nil + } + + methodName := p.curToken.Literal + + if !p.expectPeek(LPAREN) { + p.addError("expected '(' after method name") + return nil + } + + // Parse the function literal starting from parameters + funcLit := &FunctionLiteral{} + funcLit.Parameters, funcLit.Variadic = p.parseFunctionParameters() + + if !p.expectPeek(RPAREN) { + p.addError("expected ')' after function parameters") + return nil + } + + // Check for return type hint + funcLit.ReturnType = p.parseTypeHint() + + p.nextToken() + + p.enterFunctionScope() + for _, param := range funcLit.Parameters { + p.declareVariable(param.Name) + } + funcLit.Body = p.parseBlockStatements(END) + p.exitFunctionScope() + + if !p.curTokenIs(END) { + p.addError("expected 'end' to close function") + return nil + } + + return &MethodDefinition{ + StructName: funcName, + MethodName: methodName, + Function: funcLit, + } + } + + // Regular function - this should be handled as expression statement + // Reset to handle as function literal + funcLit := p.parseFunctionLiteral() + if funcLit == nil { + return nil + } + + return &ExpressionStatement{Expression: funcLit} +} + // parseIdentifierStatement handles both assignments and expression statements starting with identifiers func (p *Parser) parseIdentifierStatement() Statement { // Parse the left-hand side expression first @@ -948,6 +1104,157 @@ func (p *Parser) parseTableLiteral() Expression { return table } +// parseStructConstructor handles struct constructor calls like my_type{...} +func (p *Parser) parseStructConstructor(left Expression) Expression { + // left should be an identifier representing the struct name + ident, ok := left.(*Identifier) + if !ok { + // Not an identifier, fall back to table literal parsing + return p.parseTableLiteralFromBrace() + } + + structName := ident.Value + + // Always try to parse as struct constructor if we have an identifier + // Type checking will catch undefined structs later + constructor := &StructConstructorExpression{ + StructName: structName, + Fields: []TablePair{}, + } + + if p.peekTokenIs(RBRACE) { + p.nextToken() + return constructor + } + + p.nextToken() + + for { + if p.curTokenIs(EOF) { + p.addError("unexpected end of input, expected }") + return nil + } + + pair := TablePair{} + + if (p.curTokenIs(IDENT) || p.curTokenIs(STRING)) && p.peekTokenIs(ASSIGN) { + if p.curTokenIs(IDENT) { + pair.Key = &Identifier{Value: p.curToken.Literal} + } else { + pair.Key = &StringLiteral{Value: p.curToken.Literal} + } + p.nextToken() + p.nextToken() + + if p.curTokenIs(EOF) { + p.addError("expected expression after assignment operator") + return nil + } + + pair.Value = p.ParseExpression(LOWEST) + } else { + pair.Value = p.ParseExpression(LOWEST) + } + + if pair.Value == nil { + return nil + } + + constructor.Fields = append(constructor.Fields, pair) + + if !p.peekTokenIs(COMMA) { + break + } + + p.nextToken() + p.nextToken() + + if p.curTokenIs(RBRACE) { + break + } + + if p.curTokenIs(EOF) { + p.addError("expected next token to be }") + return nil + } + } + + if !p.expectPeek(RBRACE) { + return nil + } + + return constructor +} + +func (p *Parser) parseTableLiteralFromBrace() Expression { + // We're already at the opening brace, so parse as table literal + table := &TableLiteral{} + table.Pairs = []TablePair{} + + if p.peekTokenIs(RBRACE) { + p.nextToken() + return table + } + + p.nextToken() + + for { + if p.curTokenIs(EOF) { + p.addError("unexpected end of input, expected }") + return nil + } + + pair := TablePair{} + + if (p.curTokenIs(IDENT) || p.curTokenIs(STRING)) && p.peekTokenIs(ASSIGN) { + if p.curTokenIs(IDENT) { + pair.Key = &Identifier{Value: p.curToken.Literal} + } else { + pair.Key = &StringLiteral{Value: p.curToken.Literal} + } + p.nextToken() + p.nextToken() + + if p.curTokenIs(EOF) { + p.addError("expected expression after assignment operator") + return nil + } + + pair.Value = p.ParseExpression(LOWEST) + } else { + pair.Value = p.ParseExpression(LOWEST) + } + + if pair.Value == nil { + return nil + } + + table.Pairs = append(table.Pairs, pair) + + if !p.peekTokenIs(COMMA) { + break + } + + p.nextToken() + p.nextToken() + + if p.curTokenIs(RBRACE) { + break + } + + if p.curTokenIs(EOF) { + p.addError("expected next token to be }") + return nil + } + } + + if !p.expectPeek(RBRACE) { + return nil + } + + return table +} + func (p *Parser) parseInfixExpression(left Expression) Expression { expression := &InfixExpression{ Left: left, @@ -1057,7 +1364,7 @@ func (p *Parser) expectPeekIdent() bool { func (p *Parser) isKeyword(t TokenType) bool { switch t { - case TRUE, FALSE, NIL, AND, OR, NOT, IF, THEN, ELSEIF, ELSE, END, ECHO, FOR, WHILE, IN, DO, BREAK, EXIT, FN, RETURN: + case TRUE, FALSE, NIL, AND, OR, NOT, IF, THEN, ELSEIF, ELSE, END, ECHO, FOR, WHILE, IN, DO, BREAK, EXIT, FN, RETURN, STRUCT: return true default: return false @@ -1227,6 +1534,8 @@ func tokenTypeString(t TokenType) string { return "fn" case RETURN: return "return" + case STRUCT: + return "struct" case EOF: return "end of file" case ILLEGAL: diff --git a/parser/tests/structs_test.go b/parser/tests/structs_test.go new file mode 100644 index 0000000..36a3394 --- /dev/null +++ b/parser/tests/structs_test.go @@ -0,0 +1,632 @@ +package parser_test + +import ( + "testing" + + "git.sharkk.net/Sharkk/Mako/parser" +) + +func TestBasicStructDefinition(t *testing.T) { + input := `struct Person { + name: string, + age: number +}` + + l := parser.NewLexer(input) + p := parser.NewParser(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + if len(program.Statements) != 1 { + t.Fatalf("expected 1 statement, got %d", len(program.Statements)) + } + + stmt, ok := program.Statements[0].(*parser.StructStatement) + if !ok { + t.Fatalf("expected StructStatement, got %T", program.Statements[0]) + } + + if stmt.Name != "Person" { + t.Errorf("expected struct name 'Person', got %s", stmt.Name) + } + + if len(stmt.Fields) != 2 { + t.Fatalf("expected 2 fields, got %d", len(stmt.Fields)) + } + + // Test first field + if stmt.Fields[0].Name != "name" { + t.Errorf("expected field name 'name', got %s", stmt.Fields[0].Name) + } + if stmt.Fields[0].TypeHint == nil { + t.Fatal("expected type hint for name field") + } + if stmt.Fields[0].TypeHint.Type != "string" { + t.Errorf("expected type 'string', got %s", stmt.Fields[0].TypeHint.Type) + } + + // Test second field + if stmt.Fields[1].Name != "age" { + t.Errorf("expected field name 'age', got %s", stmt.Fields[1].Name) + } + if stmt.Fields[1].TypeHint == nil { + t.Fatal("expected type hint for age field") + } + if stmt.Fields[1].TypeHint.Type != "number" { + t.Errorf("expected type 'number', got %s", stmt.Fields[1].TypeHint.Type) + } +} + +func TestEmptyStructDefinition(t *testing.T) { + input := `struct Empty {}` + + l := parser.NewLexer(input) + p := parser.NewParser(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + if len(program.Statements) != 1 { + t.Fatalf("expected 1 statement, got %d", len(program.Statements)) + } + + stmt, ok := program.Statements[0].(*parser.StructStatement) + if !ok { + t.Fatalf("expected StructStatement, got %T", program.Statements[0]) + } + + if stmt.Name != "Empty" { + t.Errorf("expected struct name 'Empty', got %s", stmt.Name) + } + + if len(stmt.Fields) != 0 { + t.Errorf("expected 0 fields, got %d", len(stmt.Fields)) + } +} + +func TestComplexStructDefinition(t *testing.T) { + input := `struct Complex { + id: number, + name: string, + active: bool, + data: table, + callback: function, + optional: any +}` + + l := parser.NewLexer(input) + p := parser.NewParser(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + if len(program.Statements) != 1 { + t.Fatalf("expected 1 statement, got %d", len(program.Statements)) + } + + stmt, ok := program.Statements[0].(*parser.StructStatement) + if !ok { + t.Fatalf("expected StructStatement, got %T", program.Statements[0]) + } + + expectedTypes := []string{"number", "string", "bool", "table", "function", "any"} + expectedNames := []string{"id", "name", "active", "data", "callback", "optional"} + + if len(stmt.Fields) != len(expectedTypes) { + t.Fatalf("expected %d fields, got %d", len(expectedTypes), len(stmt.Fields)) + } + + for i, field := range stmt.Fields { + if field.Name != expectedNames[i] { + t.Errorf("field %d: expected name '%s', got '%s'", i, expectedNames[i], field.Name) + } + if field.TypeHint == nil { + t.Fatalf("field %d: expected type hint", i) + } + if field.TypeHint.Type != expectedTypes[i] { + t.Errorf("field %d: expected type '%s', got '%s'", i, expectedTypes[i], field.TypeHint.Type) + } + } +} + +func TestMethodDefinition(t *testing.T) { + input := `struct Person { + name: string, + age: number +} + +fn Person.getName(): string + return self.name +end + +fn Person.setAge(newAge: number) + self.age = newAge +end` + + l := parser.NewLexer(input) + p := parser.NewParser(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + if len(program.Statements) != 3 { + t.Fatalf("expected 3 statements, got %d", len(program.Statements)) + } + + // First statement: struct definition + structStmt, ok := program.Statements[0].(*parser.StructStatement) + if !ok { + t.Fatalf("expected StructStatement, got %T", program.Statements[0]) + } + if structStmt.Name != "Person" { + t.Errorf("expected struct name 'Person', got %s", structStmt.Name) + } + + // Second statement: getter method + method1, ok := program.Statements[1].(*parser.MethodDefinition) + if !ok { + t.Fatalf("expected MethodDefinition, got %T", program.Statements[1]) + } + if method1.StructName != "Person" { + t.Errorf("expected struct name 'Person', got %s", method1.StructName) + } + if method1.MethodName != "getName" { + t.Errorf("expected method name 'getName', got %s", method1.MethodName) + } + if method1.Function.ReturnType == nil { + t.Fatal("expected return type for getName method") + } + if method1.Function.ReturnType.Type != "string" { + t.Errorf("expected return type 'string', got %s", method1.Function.ReturnType.Type) + } + if len(method1.Function.Parameters) != 0 { + t.Errorf("expected 0 parameters, got %d", len(method1.Function.Parameters)) + } + + // Third statement: setter method + method2, ok := program.Statements[2].(*parser.MethodDefinition) + if !ok { + t.Fatalf("expected MethodDefinition, got %T", program.Statements[2]) + } + if method2.StructName != "Person" { + t.Errorf("expected struct name 'Person', got %s", method2.StructName) + } + if method2.MethodName != "setAge" { + t.Errorf("expected method name 'setAge', got %s", method2.MethodName) + } + if method2.Function.ReturnType != nil { + t.Errorf("expected no return type for setAge method, got %s", method2.Function.ReturnType.Type) + } + if len(method2.Function.Parameters) != 1 { + t.Fatalf("expected 1 parameter, got %d", len(method2.Function.Parameters)) + } + if method2.Function.Parameters[0].Name != "newAge" { + t.Errorf("expected parameter name 'newAge', got %s", method2.Function.Parameters[0].Name) + } + if method2.Function.Parameters[0].TypeHint == nil { + t.Fatal("expected type hint for newAge parameter") + } + if method2.Function.Parameters[0].TypeHint.Type != "number" { + t.Errorf("expected parameter type 'number', got %s", method2.Function.Parameters[0].TypeHint.Type) + } +} + +func TestStructConstructor(t *testing.T) { + input := `struct Person { + name: string, + age: number +} + +person = Person{name = "John", age = 30} +empty = Person{}` + + l := parser.NewLexer(input) + p := parser.NewParser(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + if len(program.Statements) != 3 { + t.Fatalf("expected 3 statements, got %d", len(program.Statements)) + } + + // Second statement: constructor with fields + assign1, ok := program.Statements[1].(*parser.AssignStatement) + if !ok { + t.Fatalf("expected AssignStatement, got %T", program.Statements[1]) + } + + constructor1, ok := assign1.Value.(*parser.StructConstructorExpression) + if !ok { + t.Fatalf("expected StructConstructorExpression, got %T", assign1.Value) + } + + if constructor1.StructName != "Person" { + t.Errorf("expected struct name 'Person', got %s", constructor1.StructName) + } + + if len(constructor1.Fields) != 2 { + t.Fatalf("expected 2 fields, got %d", len(constructor1.Fields)) + } + + // Check name field + nameKey, ok := constructor1.Fields[0].Key.(*parser.Identifier) + if !ok { + t.Fatalf("expected Identifier for name key, got %T", constructor1.Fields[0].Key) + } + if nameKey.Value != "name" { + t.Errorf("expected key 'name', got %s", nameKey.Value) + } + testStringLiteral(t, constructor1.Fields[0].Value, "John") + + // Check age field + ageKey, ok := constructor1.Fields[1].Key.(*parser.Identifier) + if !ok { + t.Fatalf("expected Identifier for age key, got %T", constructor1.Fields[1].Key) + } + if ageKey.Value != "age" { + t.Errorf("expected key 'age', got %s", ageKey.Value) + } + testNumberLiteral(t, constructor1.Fields[1].Value, 30) + + // Third statement: empty constructor + assign2, ok := program.Statements[2].(*parser.AssignStatement) + if !ok { + t.Fatalf("expected AssignStatement, got %T", program.Statements[2]) + } + + constructor2, ok := assign2.Value.(*parser.StructConstructorExpression) + if !ok { + t.Fatalf("expected StructConstructorExpression, got %T", assign2.Value) + } + + if constructor2.StructName != "Person" { + t.Errorf("expected struct name 'Person', got %s", constructor2.StructName) + } + + if len(constructor2.Fields) != 0 { + t.Errorf("expected 0 fields, got %d", len(constructor2.Fields)) + } +} + +func TestNestedStructTypes(t *testing.T) { + input := `struct Address { + street: string, + city: string +} + +struct Person { + name: string, + address: Address +} + +person = Person{ + name = "John", + address = Address{street = "Main St", city = "NYC"} +}` + + l := parser.NewLexer(input) + p := parser.NewParser(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + if len(program.Statements) != 3 { + t.Fatalf("expected 3 statements, got %d", len(program.Statements)) + } + + // Check Person struct has Address field type + personStruct, ok := program.Statements[1].(*parser.StructStatement) + if !ok { + t.Fatalf("expected StructStatement, got %T", program.Statements[1]) + } + + addressField := personStruct.Fields[1] + if addressField.Name != "address" { + t.Errorf("expected field name 'address', got %s", addressField.Name) + } + if addressField.TypeHint.Type != "Address" { + t.Errorf("expected field type 'Address', got %s", addressField.TypeHint.Type) + } + + // Check nested constructor + assign, ok := program.Statements[2].(*parser.AssignStatement) + if !ok { + t.Fatalf("expected AssignStatement, got %T", program.Statements[2]) + } + + personConstructor, ok := assign.Value.(*parser.StructConstructorExpression) + if !ok { + t.Fatalf("expected StructConstructorExpression, got %T", assign.Value) + } + + // Check the nested Address constructor + addressConstructor, ok := personConstructor.Fields[1].Value.(*parser.StructConstructorExpression) + if !ok { + t.Fatalf("expected nested StructConstructorExpression, got %T", personConstructor.Fields[1].Value) + } + + if addressConstructor.StructName != "Address" { + t.Errorf("expected nested struct name 'Address', got %s", addressConstructor.StructName) + } + + if len(addressConstructor.Fields) != 2 { + t.Errorf("expected 2 fields in nested constructor, got %d", len(addressConstructor.Fields)) + } +} + +func TestStructIntegrationWithProgram(t *testing.T) { + input := `struct Point { + x: number, + y: number +} + +fn Point.distance(other: Point): number + dx = self.x - other.x + dy = self.y - other.y + return (dx * dx + dy * dy) +end + +p1 = Point{x = 0, y = 0} +p2 = Point{x = 3, y = 4} + +if p1.distance(p2) then + echo "Distance calculated" +end + +for i = 1, 10 do + point = Point{x = i, y = i * 2} + echo point.x +end` + + l := parser.NewLexer(input) + p := parser.NewParser(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + if len(program.Statements) != 6 { + t.Fatalf("expected 6 statements, got %d", len(program.Statements)) + } + + // Verify struct definition + structStmt, ok := program.Statements[0].(*parser.StructStatement) + if !ok { + t.Fatalf("expected StructStatement, got %T", program.Statements[0]) + } + if structStmt.Name != "Point" { + t.Errorf("expected struct name 'Point', got %s", structStmt.Name) + } + + // Verify method definition + methodStmt, ok := program.Statements[1].(*parser.MethodDefinition) + if !ok { + t.Fatalf("expected MethodDefinition, got %T", program.Statements[1]) + } + if methodStmt.StructName != "Point" { + t.Errorf("expected struct name 'Point', got %s", methodStmt.StructName) + } + if methodStmt.MethodName != "distance" { + t.Errorf("expected method name 'distance', got %s", methodStmt.MethodName) + } + + // Verify constructors + for i := 2; i <= 3; i++ { + assign, ok := program.Statements[i].(*parser.AssignStatement) + if !ok { + t.Fatalf("statement %d: expected AssignStatement, got %T", i, program.Statements[i]) + } + constructor, ok := assign.Value.(*parser.StructConstructorExpression) + if !ok { + t.Fatalf("statement %d: expected StructConstructorExpression, got %T", i, assign.Value) + } + if constructor.StructName != "Point" { + t.Errorf("statement %d: expected struct name 'Point', got %s", i, constructor.StructName) + } + } + + // Verify if statement with method call + ifStmt, ok := program.Statements[4].(*parser.IfStatement) + if !ok { + t.Fatalf("expected IfStatement, got %T", program.Statements[4]) + } + callExpr, ok := ifStmt.Condition.(*parser.CallExpression) + if !ok { + t.Fatalf("expected CallExpression in if condition, got %T", ifStmt.Condition) + } + dotExpr, ok := callExpr.Function.(*parser.DotExpression) + if !ok { + t.Fatalf("expected DotExpression for method call, got %T", callExpr.Function) + } + if dotExpr.Key != "distance" { + t.Errorf("expected method name 'distance', got %s", dotExpr.Key) + } + + // Verify for loop with struct creation + forStmt, ok := program.Statements[5].(*parser.ForStatement) + if !ok { + t.Fatalf("expected ForStatement, got %T", program.Statements[5]) + } + if len(forStmt.Body) != 2 { + t.Errorf("expected 2 statements in for body, got %d", len(forStmt.Body)) + } + + // Check struct constructor in loop + loopAssign, ok := forStmt.Body[0].(*parser.AssignStatement) + if !ok { + t.Fatalf("expected AssignStatement in loop, got %T", forStmt.Body[0]) + } + loopConstructor, ok := loopAssign.Value.(*parser.StructConstructorExpression) + if !ok { + t.Fatalf("expected StructConstructorExpression in loop, got %T", loopAssign.Value) + } + if loopConstructor.StructName != "Point" { + t.Errorf("expected struct name 'Point' in loop, got %s", loopConstructor.StructName) + } +} + +func TestStructErrorCases(t *testing.T) { + tests := []struct { + name string + input string + expectedErrorSubstring string + }{ + { + name: "missing field type", + input: `struct Person { + name + }`, + expectedErrorSubstring: "struct fields require type annotation", + }, + { + name: "missing struct name", + input: `struct { + name: string + }`, + expectedErrorSubstring: "expected struct name", + }, + { + name: "missing opening brace", + input: `struct Person + name: string + }`, + expectedErrorSubstring: "expected '{' after struct name", + }, + { + name: "missing closing brace", + input: `struct Person { + name: string`, + expectedErrorSubstring: "expected next token to be }, got end of file instead", + }, + { + name: "invalid field type", + input: `struct Person { + name: invalidtype + }`, + expectedErrorSubstring: "invalid type name 'invalidtype'", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + l := parser.NewLexer(tt.input) + p := parser.NewParser(l) + _ = p.ParseProgram() + + if !p.HasErrors() { + t.Fatalf("expected parser errors, but got none") + } + + errors := p.ErrorStrings() + found := false + for _, err := range errors { + if containsSubstring(err, tt.expectedErrorSubstring) { + found = true + break + } + } + + if !found { + t.Errorf("expected error containing '%s', got errors: %v", tt.expectedErrorSubstring, errors) + } + }) + } +} + +func TestSingleLineStruct(t *testing.T) { + input := `struct Person { name: string, age: number }` + + l := parser.NewLexer(input) + p := parser.NewParser(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + if len(program.Statements) != 1 { + t.Fatalf("expected 1 statement, got %d", len(program.Statements)) + } + + stmt, ok := program.Statements[0].(*parser.StructStatement) + if !ok { + t.Fatalf("expected StructStatement, got %T", program.Statements[0]) + } + + if stmt.Name != "Person" { + t.Errorf("expected struct name 'Person', got %s", stmt.Name) + } + + if len(stmt.Fields) != 2 { + t.Fatalf("expected 2 fields, got %d", len(stmt.Fields)) + } + + if stmt.Fields[0].Name != "name" || stmt.Fields[0].TypeHint.Type != "string" { + t.Errorf("expected first field 'name: string', got '%s: %s'", + stmt.Fields[0].Name, stmt.Fields[0].TypeHint.Type) + } + + if stmt.Fields[1].Name != "age" || stmt.Fields[1].TypeHint.Type != "number" { + t.Errorf("expected second field 'age: number', got '%s: %s'", + stmt.Fields[1].Name, stmt.Fields[1].TypeHint.Type) + } +} + +func TestStructString(t *testing.T) { + input := `struct Person { + name: string, + age: number +}` + + l := parser.NewLexer(input) + p := parser.NewParser(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + stmt := program.Statements[0].(*parser.StructStatement) + str := stmt.String() + + expected := "struct Person {\n\tname: string,\n\tage: number\n}" + if str != expected { + t.Errorf("expected string representation:\n%s\ngot:\n%s", expected, str) + } +} + +func TestMethodString(t *testing.T) { + input := `struct Person { + name: string +} + +fn Person.getName(): string + return self.name +end` + + l := parser.NewLexer(input) + p := parser.NewParser(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + method := program.Statements[1].(*parser.MethodDefinition) + str := method.String() + + if !containsSubstring(str, "fn Person.getName") { + t.Errorf("expected method string to contain 'fn Person.getName', got: %s", str) + } + if !containsSubstring(str, ": string") { + t.Errorf("expected method string to contain return type, got: %s", str) + } +} + +func TestConstructorString(t *testing.T) { + input := `struct Person { + name: string, + age: number +} + +person = Person{name = "John", age = 30}` + + l := parser.NewLexer(input) + p := parser.NewParser(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + assign := program.Statements[1].(*parser.AssignStatement) + constructor := assign.Value.(*parser.StructConstructorExpression) + str := constructor.String() + + expected := `Person{name = "John", age = 30.00}` + if str != expected { + t.Errorf("expected constructor string:\n%s\ngot:\n%s", expected, str) + } +} diff --git a/parser/token.go b/parser/token.go index 4a8fb80..3fcdf28 100644 --- a/parser/token.go +++ b/parser/token.go @@ -59,6 +59,7 @@ const ( EXIT FN RETURN + STRUCT // Special EOF @@ -107,6 +108,7 @@ var precedences = map[TokenType]Precedence{ DOT: MEMBER, LBRACKET: MEMBER, LPAREN: CALL, + LBRACE: CALL, } // lookupIdent checks if an identifier is a keyword @@ -132,6 +134,7 @@ func lookupIdent(ident string) TokenType { "exit": EXIT, "fn": FN, "return": RETURN, + "struct": STRUCT, } if tok, ok := keywords[ident]; ok { diff --git a/parser/types.go b/parser/types.go index c674b36..7bce16c 100644 --- a/parser/types.go +++ b/parser/types.go @@ -76,6 +76,9 @@ type TypeInferrer struct { nilType *TypeInfo tableType *TypeInfo anyType *TypeInfo + + // Struct definitions + structs map[string]*StructStatement } // NewTypeInferrer creates a new type inference engine @@ -86,6 +89,7 @@ func NewTypeInferrer() *TypeInferrer { currentScope: globalScope, globalScope: globalScope, errors: []TypeError{}, + structs: make(map[string]*StructStatement), // Pre-allocate common types to reduce allocations numberType: &TypeInfo{Type: TypeNumber, Inferred: true}, @@ -101,6 +105,14 @@ func NewTypeInferrer() *TypeInferrer { // InferTypes performs type inference on the entire program func (ti *TypeInferrer) InferTypes(program *Program) []TypeError { + // First pass: collect struct definitions + for _, stmt := range program.Statements { + if structStmt, ok := stmt.(*StructStatement); ok { + ti.structs[structStmt.Name] = structStmt + } + } + + // Second pass: infer types for _, stmt := range program.Statements { ti.inferStatement(stmt) } @@ -129,9 +141,27 @@ func (ti *TypeInferrer) addError(message string, node Node) { }) } +// getStructType returns TypeInfo for a struct +func (ti *TypeInferrer) getStructType(name string) *TypeInfo { + if _, exists := ti.structs[name]; exists { + return &TypeInfo{Type: name, Inferred: true} + } + return nil +} + +// isStructType checks if a type is a struct type +func (ti *TypeInferrer) isStructType(t *TypeInfo) bool { + _, exists := ti.structs[t.Type] + return exists +} + // inferStatement infers types for statements func (ti *TypeInferrer) inferStatement(stmt Statement) { switch s := stmt.(type) { + case *StructStatement: + ti.inferStructStatement(s) + case *MethodDefinition: + ti.inferMethodDefinition(s) case *AssignStatement: ti.inferAssignStatement(s) case *EchoStatement: @@ -152,9 +182,63 @@ func (ti *TypeInferrer) inferStatement(stmt Statement) { if s.Value != nil { ti.inferExpression(s.Value) } + case *ExpressionStatement: + ti.inferExpression(s.Expression) } } +// inferStructStatement handles struct definitions +func (ti *TypeInferrer) inferStructStatement(stmt *StructStatement) { + // Validate field types + for _, field := range stmt.Fields { + if field.TypeHint != nil { + if !ValidTypeName(field.TypeHint.Type) && !ti.isStructType(field.TypeHint) { + ti.addError(fmt.Sprintf("invalid field type '%s' in struct '%s'", + field.TypeHint.Type, stmt.Name), stmt) + } + } + } +} + +// inferMethodDefinition handles method definitions +func (ti *TypeInferrer) inferMethodDefinition(stmt *MethodDefinition) { + // Check if struct exists + if _, exists := ti.structs[stmt.StructName]; !exists { + ti.addError(fmt.Sprintf("method defined on undefined struct '%s'", stmt.StructName), stmt) + return + } + + // Infer the function body + ti.enterScope() + + // Add self parameter implicitly + ti.currentScope.Define(&Symbol{ + Name: "self", + Type: ti.getStructType(stmt.StructName), + Declared: true, + }) + + // Add explicit parameters + for _, param := range stmt.Function.Parameters { + paramType := ti.anyType + if param.TypeHint != nil { + paramType = param.TypeHint + } + ti.currentScope.Define(&Symbol{ + Name: param.Name, + Type: paramType, + Declared: true, + }) + } + + // Infer function body + for _, bodyStmt := range stmt.Function.Body { + ti.inferStatement(bodyStmt) + } + + ti.exitScope() +} + // inferAssignStatement handles variable assignments with type checking func (ti *TypeInferrer) inferAssignStatement(stmt *AssignStatement) { // Infer the type of the value expression @@ -288,9 +372,9 @@ func (ti *TypeInferrer) inferForStatement(stmt *ForStatement) { func (ti *TypeInferrer) inferForInStatement(stmt *ForInStatement) { iterableType := ti.inferExpression(stmt.Iterable) - // For now, assume iterable is a table - if !ti.isTableType(iterableType) { - ti.addError("for-in requires an iterable (table)", stmt.Iterable) + // For now, assume iterable is a table or struct + if !ti.isTableType(iterableType) && !ti.isStructType(iterableType) { + ti.addError("for-in requires an iterable (table or struct)", stmt.Iterable) } ti.enterScope() @@ -341,6 +425,8 @@ func (ti *TypeInferrer) inferExpression(expr Expression) *TypeInfo { return ti.nilType case *TableLiteral: return ti.inferTableLiteral(e) + case *StructConstructorExpression: + return ti.inferStructConstructor(e) case *FunctionLiteral: return ti.inferFunctionLiteral(e) case *CallExpression: @@ -353,12 +439,85 @@ func (ti *TypeInferrer) inferExpression(expr Expression) *TypeInfo { return ti.inferIndexExpression(e) case *DotExpression: return ti.inferDotExpression(e) + case *AssignExpression: + return ti.inferAssignExpression(e) default: ti.addError("unknown expression type", expr) return ti.anyType } } +// inferStructConstructor handles struct constructor expressions +func (ti *TypeInferrer) inferStructConstructor(expr *StructConstructorExpression) *TypeInfo { + structDef, exists := ti.structs[expr.StructName] + if !exists { + ti.addError(fmt.Sprintf("undefined struct '%s'", expr.StructName), expr) + return ti.anyType + } + + // Validate field assignments + for _, pair := range expr.Fields { + if pair.Key != nil { + fieldName := "" + if ident, ok := pair.Key.(*Identifier); ok { + fieldName = ident.Value + } else if str, ok := pair.Key.(*StringLiteral); ok { + fieldName = str.Value + } + + // Check if field exists in struct + fieldExists := false + var fieldType *TypeInfo + for _, field := range structDef.Fields { + if field.Name == fieldName { + fieldExists = true + fieldType = field.TypeHint + break + } + } + + if !fieldExists { + ti.addError(fmt.Sprintf("struct '%s' has no field '%s'", expr.StructName, fieldName), expr) + } else { + // Check type compatibility + valueType := ti.inferExpression(pair.Value) + if !ti.isTypeCompatible(valueType, fieldType) { + ti.addError(fmt.Sprintf("cannot assign %s to field '%s' of type %s", + valueType.Type, fieldName, fieldType.Type), expr) + } + } + } else { + // Array-style assignment not valid for structs + ti.addError("struct constructors require named field assignments", expr) + } + } + + structType := ti.getStructType(expr.StructName) + expr.SetType(structType) + return structType +} + +// inferAssignExpression handles assignment expressions +func (ti *TypeInferrer) inferAssignExpression(expr *AssignExpression) *TypeInfo { + valueType := ti.inferExpression(expr.Value) + + if ident, ok := expr.Name.(*Identifier); ok { + if expr.IsDeclaration { + ti.currentScope.Define(&Symbol{ + Name: ident.Value, + Type: valueType, + Declared: true, + }) + } + ident.SetType(valueType) + } else { + ti.inferExpression(expr.Name) + } + + expr.SetType(valueType) + return valueType +} + // inferIdentifier looks up identifier type in symbol table func (ti *TypeInferrer) inferIdentifier(ident *Identifier) *TypeInfo { symbol := ti.currentScope.Lookup(ident.Value) @@ -497,17 +656,43 @@ func (ti *TypeInferrer) inferInfixExpression(infix *InfixExpression) *TypeInfo { // inferIndexExpression infers table[index] type func (ti *TypeInferrer) inferIndexExpression(index *IndexExpression) *TypeInfo { - ti.inferExpression(index.Left) + leftType := ti.inferExpression(index.Left) ti.inferExpression(index.Index) - // For now, assume table access returns any + // If indexing a struct, try to infer field type + if ti.isStructType(leftType) { + if strLit, ok := index.Index.(*StringLiteral); ok { + if structDef, exists := ti.structs[leftType.Type]; exists { + for _, field := range structDef.Fields { + if field.Name == strLit.Value { + index.SetType(field.TypeHint) + return field.TypeHint + } + } + } + } + } + + // For now, assume table/struct access returns any index.SetType(ti.anyType) return ti.anyType } // inferDotExpression infers table.key type func (ti *TypeInferrer) inferDotExpression(dot *DotExpression) *TypeInfo { - ti.inferExpression(dot.Left) + leftType := ti.inferExpression(dot.Left) + + // If accessing a struct field, try to infer field type + if ti.isStructType(leftType) { + if structDef, exists := ti.structs[leftType.Type]; exists { + for _, field := range structDef.Fields { + if field.Name == dot.Key { + dot.SetType(field.TypeHint) + return field.TypeHint + } + } + } + } // For now, assume member access returns any dot.SetType(ti.anyType)