diff --git a/parser/ast.go b/parser/ast.go index 80c961d..2d1c4ae 100644 --- a/parser/ast.go +++ b/parser/ast.go @@ -2,32 +2,28 @@ package parser import "fmt" -// TypeInfo represents type information for expressions -type TypeInfo struct { - Type string // "number", "string", "bool", "table", "function", "nil", "any", struct name - Inferred bool // true if type was inferred, false if explicitly declared -} +// Note: Type definitions moved to types.go for proper separation of concerns // Node represents any node in the AST type Node interface { String() string } -// Statement represents statement nodes +// Statement represents statement nodes that can appear at the top level or in blocks type Statement interface { Node statementNode() } -// Expression represents expression nodes +// Expression represents expression nodes that produce values and have types type Expression interface { Node expressionNode() - GetType() *TypeInfo - SetType(*TypeInfo) + TypeInfo() TypeInfo // Returns type by value, not pointer } -// Program represents the root of the AST +// Program represents the root of the AST containing all top-level statements. +// Tracks exit code for script termination and owns the statement list. type Program struct { Statements []Statement ExitCode int @@ -41,23 +37,23 @@ func (p *Program) String() string { return result } -// StructField represents a field in a struct definition +// StructField represents a field definition within a struct. +// Contains field name and required type annotation for compile-time checking. type StructField struct { Name string - TypeHint *TypeInfo + TypeHint TypeInfo // Required for struct fields, embeds directly } func (sf *StructField) String() string { - if sf.TypeHint != nil { - return fmt.Sprintf("%s: %s", sf.Name, sf.TypeHint.Type) - } - return sf.Name + return fmt.Sprintf("%s: %s", sf.Name, typeToString(sf.TypeHint)) } -// StructStatement represents struct definitions +// StructStatement represents struct type definitions with named fields. +// Defines new types that can be instantiated and used for type checking. type StructStatement struct { Name string Fields []StructField + ID uint16 // Unique identifier for fast lookup } func (ss *StructStatement) statementNode() {} @@ -72,77 +68,72 @@ func (ss *StructStatement) String() string { return fmt.Sprintf("struct %s {\n\t%s\n}", ss.Name, fields) } -// MethodDefinition represents method definitions on structs +// MethodDefinition represents method definitions attached to struct types. +// Links a function implementation to a specific struct via struct ID. type MethodDefinition struct { - StructName string + StructID uint16 // Index into struct table for fast lookup MethodName string Function *FunctionLiteral } func (md *MethodDefinition) statementNode() {} func (md *MethodDefinition) String() string { - return fmt.Sprintf("fn %s.%s%s", md.StructName, md.MethodName, md.Function.String()[2:]) // skip "fn" from function string + return fmt.Sprintf("fn .%s%s", md.MethodName, md.Function.String()[2:]) } -// StructConstructorExpression represents struct constructor calls like my_type{...} -type StructConstructorExpression struct { - StructName string - Fields []TablePair // reuse TablePair for field assignments - typeInfo *TypeInfo +// StructConstructor represents struct instantiation with field initialization. +// Uses struct ID for fast type resolution and validation during parsing. +type StructConstructor struct { + StructID uint16 // Index into struct table + Fields []TablePair // Reuses table pair structure for field assignments + typeInfo TypeInfo // Cached type info for this constructor } -func (sce *StructConstructorExpression) expressionNode() {} -func (sce *StructConstructorExpression) String() string { +func (sc *StructConstructor) expressionNode() {} +func (sc *StructConstructor) String() string { var pairs []string - for _, pair := range sce.Fields { + for _, pair := range sc.Fields { pairs = append(pairs, pair.String()) } - return fmt.Sprintf("%s{%s}", sce.StructName, joinStrings(pairs, ", ")) + return fmt.Sprintf("{%s}", joinStrings(pairs, ", ")) } -func (sce *StructConstructorExpression) GetType() *TypeInfo { return sce.typeInfo } -func (sce *StructConstructorExpression) SetType(t *TypeInfo) { sce.typeInfo = t } +func (sc *StructConstructor) TypeInfo() TypeInfo { return sc.typeInfo } -// AssignStatement represents variable assignment with optional type hint -type AssignStatement struct { - Name Expression // Changed from *Identifier to Expression for member access - TypeHint *TypeInfo // optional type hint - Value Expression - IsDeclaration bool // true if this is the first assignment in current scope +// Assignment represents both variable assignment statements and assignment expressions. +// Unified design reduces AST node count and simplifies type checking logic. +type Assignment struct { + Target Expression // Target (identifier, dot, or index expression) + Value Expression // Value being assigned + TypeHint TypeInfo // Optional explicit type hint, embeds directly + IsDeclaration bool // True if declaring new variable in current scope + IsExpression bool // True if used as expression (wrapped in parentheses) } -func (as *AssignStatement) statementNode() {} -func (as *AssignStatement) String() string { +func (a *Assignment) statementNode() {} +func (a *Assignment) expressionNode() {} +func (a *Assignment) String() string { prefix := "" - if as.IsDeclaration { + if a.IsDeclaration { prefix = "local " } var nameStr string - if as.TypeHint != nil { - nameStr = fmt.Sprintf("%s: %s", as.Name.String(), as.TypeHint.Type) + if a.TypeHint.Type != TypeUnknown { + nameStr = fmt.Sprintf("%s: %s", a.Target.String(), typeToString(a.TypeHint)) } else { - nameStr = as.Name.String() + nameStr = a.Target.String() } - return fmt.Sprintf("%s%s = %s", prefix, nameStr, as.Value.String()) + result := fmt.Sprintf("%s%s = %s", prefix, nameStr, a.Value.String()) + if a.IsExpression { + return "(" + result + ")" + } + return result } +func (a *Assignment) TypeInfo() TypeInfo { return a.Value.TypeInfo() } -// AssignExpression represents assignment as an expression (only in parentheses) -type AssignExpression struct { - Name Expression // Target (identifier, dot, or index expression) - Value Expression // Value to assign - IsDeclaration bool // true if this declares a new variable - typeInfo *TypeInfo // type of the expression (same as assigned value) -} - -func (ae *AssignExpression) expressionNode() {} -func (ae *AssignExpression) String() string { - return fmt.Sprintf("(%s = %s)", ae.Name.String(), ae.Value.String()) -} -func (ae *AssignExpression) GetType() *TypeInfo { return ae.typeInfo } -func (ae *AssignExpression) SetType(t *TypeInfo) { ae.typeInfo = t } - -// ExpressionStatement represents expressions used as statements +// ExpressionStatement wraps expressions used as statements. +// Allows function calls and other expressions at statement level. type ExpressionStatement struct { Expression Expression } @@ -152,7 +143,8 @@ func (es *ExpressionStatement) String() string { return es.Expression.String() } -// EchoStatement represents echo output statements +// EchoStatement represents output statements for displaying values. +// Simple debugging and output mechanism built into the language. type EchoStatement struct { Value Expression } @@ -162,17 +154,17 @@ func (es *EchoStatement) String() string { return fmt.Sprintf("echo %s", es.Value.String()) } -// BreakStatement represents break statements to exit loops +// BreakStatement represents loop exit statements. +// Simple marker node with no additional data needed. type BreakStatement struct{} func (bs *BreakStatement) statementNode() {} -func (bs *BreakStatement) String() string { - return "break" -} +func (bs *BreakStatement) String() string { return "break" } -// ExitStatement represents exit statements to quit the script +// ExitStatement represents script termination with optional exit code. +// Value expression is nil for plain "exit", non-nil for "exit ". type ExitStatement struct { - Value Expression // optional, can be nil + Value Expression // Optional exit code expression } func (es *ExitStatement) statementNode() {} @@ -183,9 +175,10 @@ func (es *ExitStatement) String() string { return fmt.Sprintf("exit %s", es.Value.String()) } -// ReturnStatement represents return statements +// ReturnStatement represents function return with optional value. +// Value expression is nil for plain "return", non-nil for "return ". type ReturnStatement struct { - Value Expression // optional, can be nil + Value Expression // Optional return value expression } func (rs *ReturnStatement) statementNode() {} @@ -196,7 +189,8 @@ func (rs *ReturnStatement) String() string { return fmt.Sprintf("return %s", rs.Value.String()) } -// ElseIfClause represents an elseif condition +// ElseIfClause represents conditional branches in if statements. +// Contains condition expression and body statements for this branch. type ElseIfClause struct { Condition Expression Body []Statement @@ -210,30 +204,28 @@ func (eic *ElseIfClause) String() string { return fmt.Sprintf("elseif %s then\n%s", eic.Condition.String(), body) } -// IfStatement represents conditional statements +// IfStatement represents conditional execution with optional elseif and else branches. +// Supports multiple elseif clauses and an optional final else clause. type IfStatement struct { - Condition Expression - Body []Statement - ElseIfs []ElseIfClause - Else []Statement + Condition Expression // Main condition + Body []Statement // Statements to execute if condition is true + ElseIfs []ElseIfClause // Optional elseif branches + Else []Statement // Optional else branch } func (is *IfStatement) statementNode() {} func (is *IfStatement) String() string { var result string - // If clause result += fmt.Sprintf("if %s then\n", is.Condition.String()) for _, stmt := range is.Body { result += "\t" + stmt.String() + "\n" } - // ElseIf clauses for _, elseif := range is.ElseIfs { result += elseif.String() } - // Else clause if len(is.Else) > 0 { result += "else\n" for _, stmt := range is.Else { @@ -245,7 +237,8 @@ func (is *IfStatement) String() string { return result } -// WhileStatement represents while loops: while condition do ... end +// WhileStatement represents condition-based loops that execute while condition is true. +// Contains condition expression and body statements to repeat. type WhileStatement struct { Condition Expression Body []Statement @@ -264,13 +257,14 @@ func (ws *WhileStatement) String() string { return result } -// ForStatement represents numeric for loops: for i = start, end, step do ... end +// ForStatement represents numeric for loops with start, end, and optional step. +// Variable is automatically scoped to the loop body. type ForStatement struct { - Variable *Identifier - Start Expression - End Expression - Step Expression // optional, nil means step of 1 - Body []Statement + Variable *Identifier // Loop variable (automatically number type) + Start Expression // Starting value expression + End Expression // Ending value expression + Step Expression // Optional step expression (nil means step of 1) + Body []Statement // Loop body statements } func (fs *ForStatement) statementNode() {} @@ -292,12 +286,13 @@ func (fs *ForStatement) String() string { return result } -// ForInStatement represents iterator for loops: for k, v in expr do ... end +// ForInStatement represents iterator-based loops over tables, arrays, or other iterables. +// Supports both single variable (for v in iter) and key-value (for k,v in iter) forms. type ForInStatement struct { - Key *Identifier // optional, nil for single variable iteration - Value *Identifier - Iterable Expression - Body []Statement + Key *Identifier // Optional key variable (nil for single variable iteration) + Value *Identifier // Value variable (required) + Iterable Expression // Expression to iterate over + Body []Statement // Loop body statements } func (fis *ForInStatement) statementNode() {} @@ -319,56 +314,60 @@ func (fis *ForInStatement) String() string { return result } -// FunctionParameter represents a function parameter with optional type hint +// FunctionParameter represents a parameter in function definitions. +// Contains parameter name and optional type hint for type checking. type FunctionParameter struct { Name string - TypeHint *TypeInfo + TypeHint TypeInfo // Optional type constraint, embeds directly } func (fp *FunctionParameter) String() string { - if fp.TypeHint != nil { - return fmt.Sprintf("%s: %s", fp.Name, fp.TypeHint.Type) + if fp.TypeHint.Type != TypeUnknown { + return fmt.Sprintf("%s: %s", fp.Name, typeToString(fp.TypeHint)) } return fp.Name } -// Identifier represents identifiers +// Identifier represents variable references and names. +// Stores resolved type information for efficient type checking. type Identifier struct { Value string - typeInfo *TypeInfo + typeInfo TypeInfo // Resolved type, embeds directly } -func (i *Identifier) expressionNode() {} -func (i *Identifier) String() string { return i.Value } -func (i *Identifier) GetType() *TypeInfo { return i.typeInfo } -func (i *Identifier) SetType(t *TypeInfo) { i.typeInfo = t } +func (i *Identifier) expressionNode() {} +func (i *Identifier) String() string { return i.Value } +func (i *Identifier) TypeInfo() TypeInfo { + if i.typeInfo.Type == TypeUnknown { + return AnyType + } + return i.typeInfo +} -// NumberLiteral represents numeric literals +// NumberLiteral represents numeric constants including integers, floats, hex, and binary. +// Always has number type, so no additional type storage needed. type NumberLiteral struct { - Value float64 - typeInfo *TypeInfo + Value float64 // All numbers stored as float64 for simplicity } -func (nl *NumberLiteral) expressionNode() {} -func (nl *NumberLiteral) String() string { return fmt.Sprintf("%.2f", nl.Value) } -func (nl *NumberLiteral) GetType() *TypeInfo { return nl.typeInfo } -func (nl *NumberLiteral) SetType(t *TypeInfo) { nl.typeInfo = t } +func (nl *NumberLiteral) expressionNode() {} +func (nl *NumberLiteral) String() string { return fmt.Sprintf("%.2f", nl.Value) } +func (nl *NumberLiteral) TypeInfo() TypeInfo { return NumberType } -// StringLiteral represents string literals +// StringLiteral represents string constants and multiline strings. +// Always has string type, so no additional type storage needed. type StringLiteral struct { - Value string - typeInfo *TypeInfo + Value string // String content without quotes } -func (sl *StringLiteral) expressionNode() {} -func (sl *StringLiteral) String() string { return fmt.Sprintf(`"%s"`, sl.Value) } -func (sl *StringLiteral) GetType() *TypeInfo { return sl.typeInfo } -func (sl *StringLiteral) SetType(t *TypeInfo) { sl.typeInfo = t } +func (sl *StringLiteral) expressionNode() {} +func (sl *StringLiteral) String() string { return fmt.Sprintf(`"%s"`, sl.Value) } +func (sl *StringLiteral) TypeInfo() TypeInfo { return StringType } -// BooleanLiteral represents boolean literals +// BooleanLiteral represents true and false constants. +// Always has bool type, so no additional type storage needed. type BooleanLiteral struct { - Value bool - typeInfo *TypeInfo + Value bool } func (bl *BooleanLiteral) expressionNode() {} @@ -378,26 +377,23 @@ func (bl *BooleanLiteral) String() string { } return "false" } -func (bl *BooleanLiteral) GetType() *TypeInfo { return bl.typeInfo } -func (bl *BooleanLiteral) SetType(t *TypeInfo) { bl.typeInfo = t } +func (bl *BooleanLiteral) TypeInfo() TypeInfo { return BoolType } -// NilLiteral represents nil literal -type NilLiteral struct { - typeInfo *TypeInfo -} +// NilLiteral represents the nil constant value. +// Always has nil type, so no additional type storage needed. +type NilLiteral struct{} -func (nl *NilLiteral) expressionNode() {} -func (nl *NilLiteral) String() string { return "nil" } -func (nl *NilLiteral) GetType() *TypeInfo { return nl.typeInfo } -func (nl *NilLiteral) SetType(t *TypeInfo) { nl.typeInfo = t } +func (nl *NilLiteral) expressionNode() {} +func (nl *NilLiteral) String() string { return "nil" } +func (nl *NilLiteral) TypeInfo() TypeInfo { return NilType } -// FunctionLiteral represents function literals with typed parameters +// FunctionLiteral represents function definitions with parameters, body, and optional return type. +// Always has function type, stores additional return type information separately. type FunctionLiteral struct { - Parameters []FunctionParameter - Variadic bool - ReturnType *TypeInfo // optional return type hint - Body []Statement - typeInfo *TypeInfo + Parameters []FunctionParameter // Function parameters with optional types + Body []Statement // Function body statements + ReturnType TypeInfo // Optional return type hint, embeds directly + Variadic bool // True if function accepts variable arguments } func (fl *FunctionLiteral) expressionNode() {} @@ -417,8 +413,8 @@ func (fl *FunctionLiteral) String() string { } result := fmt.Sprintf("fn(%s)", params) - if fl.ReturnType != nil { - result += ": " + fl.ReturnType.Type + if fl.ReturnType.Type != TypeUnknown { + result += ": " + typeToString(fl.ReturnType) } result += "\n" @@ -428,14 +424,14 @@ func (fl *FunctionLiteral) String() string { result += "end" return result } -func (fl *FunctionLiteral) GetType() *TypeInfo { return fl.typeInfo } -func (fl *FunctionLiteral) SetType(t *TypeInfo) { fl.typeInfo = t } +func (fl *FunctionLiteral) TypeInfo() TypeInfo { return FunctionType } -// CallExpression represents function calls: func(arg1, arg2, ...) +// CallExpression represents function calls with arguments. +// Stores inferred return type from function signature analysis. type CallExpression struct { - Function Expression - Arguments []Expression - typeInfo *TypeInfo + Function Expression // Function expression to call + Arguments []Expression // Argument expressions + typeInfo TypeInfo // Inferred return type, embeds directly } func (ce *CallExpression) expressionNode() {} @@ -446,74 +442,73 @@ func (ce *CallExpression) String() string { } return fmt.Sprintf("%s(%s)", ce.Function.String(), joinStrings(args, ", ")) } -func (ce *CallExpression) GetType() *TypeInfo { return ce.typeInfo } -func (ce *CallExpression) SetType(t *TypeInfo) { ce.typeInfo = t } +func (ce *CallExpression) TypeInfo() TypeInfo { return ce.typeInfo } -// PrefixExpression represents prefix operations like -x, not x +// PrefixExpression represents unary operations like negation and logical not. +// Stores result type based on operator and operand type analysis. type PrefixExpression struct { - Operator string - Right Expression - typeInfo *TypeInfo + Operator string // Operator symbol ("-", "not") + Right Expression // Operand expression + typeInfo TypeInfo // Result type, embeds directly } func (pe *PrefixExpression) expressionNode() {} func (pe *PrefixExpression) String() string { - // Add space for word operators if pe.Operator == "not" { return fmt.Sprintf("(%s %s)", pe.Operator, pe.Right.String()) } return fmt.Sprintf("(%s%s)", pe.Operator, pe.Right.String()) } -func (pe *PrefixExpression) GetType() *TypeInfo { return pe.typeInfo } -func (pe *PrefixExpression) SetType(t *TypeInfo) { pe.typeInfo = t } +func (pe *PrefixExpression) TypeInfo() TypeInfo { return pe.typeInfo } -// InfixExpression represents binary operations +// InfixExpression represents binary operations between two expressions. +// Stores result type based on operator and operand type compatibility. type InfixExpression struct { - Left Expression - Operator string - Right Expression - typeInfo *TypeInfo + Left Expression // Left operand + Right Expression // Right operand + Operator string // Operator symbol ("+", "-", "==", "and", etc.) + typeInfo TypeInfo // Result type, embeds directly } func (ie *InfixExpression) expressionNode() {} func (ie *InfixExpression) String() string { return fmt.Sprintf("(%s %s %s)", ie.Left.String(), ie.Operator, ie.Right.String()) } -func (ie *InfixExpression) GetType() *TypeInfo { return ie.typeInfo } -func (ie *InfixExpression) SetType(t *TypeInfo) { ie.typeInfo = t } +func (ie *InfixExpression) TypeInfo() TypeInfo { return ie.typeInfo } -// IndexExpression represents table[key] access +// IndexExpression represents bracket-based member access (table[key]). +// Stores inferred element type based on container type analysis. type IndexExpression struct { - Left Expression - Index Expression - typeInfo *TypeInfo + Left Expression // Container expression + Index Expression // Index/key expression + typeInfo TypeInfo // Element type, embeds directly } func (ie *IndexExpression) expressionNode() {} func (ie *IndexExpression) String() string { return fmt.Sprintf("%s[%s]", ie.Left.String(), ie.Index.String()) } -func (ie *IndexExpression) GetType() *TypeInfo { return ie.typeInfo } -func (ie *IndexExpression) SetType(t *TypeInfo) { ie.typeInfo = t } +func (ie *IndexExpression) TypeInfo() TypeInfo { return ie.typeInfo } -// DotExpression represents table.key access +// DotExpression represents dot-based member access (table.key). +// Stores inferred member type based on container type and field analysis. type DotExpression struct { - Left Expression - Key string - typeInfo *TypeInfo + Left Expression // Container expression + Key string // Member name + typeInfo TypeInfo // Member type, embeds directly } func (de *DotExpression) expressionNode() {} func (de *DotExpression) String() string { return fmt.Sprintf("%s.%s", de.Left.String(), de.Key) } -func (de *DotExpression) GetType() *TypeInfo { return de.typeInfo } -func (de *DotExpression) SetType(t *TypeInfo) { de.typeInfo = t } +func (de *DotExpression) TypeInfo() TypeInfo { return de.typeInfo } -// TablePair represents a key-value pair in a table +// TablePair represents key-value pairs in table literals and struct constructors. +// Key is nil for array-style elements, non-nil for object-style elements. type TablePair struct { - Key Expression // nil for array-style elements - Value Expression + Key Expression // Key expression (nil for array elements) + Value Expression // Value expression } func (tp *TablePair) String() string { @@ -523,10 +518,10 @@ func (tp *TablePair) String() string { return fmt.Sprintf("%s = %s", tp.Key.String(), tp.Value.String()) } -// TableLiteral represents table literals {} +// TableLiteral represents table/array/object literals with key-value pairs. +// Always has table type, provides methods to check if it's array-style. type TableLiteral struct { - Pairs []TablePair - typeInfo *TypeInfo + Pairs []TablePair // Key-value pairs (key nil for array elements) } func (tl *TableLiteral) expressionNode() {} @@ -537,10 +532,9 @@ func (tl *TableLiteral) String() string { } return fmt.Sprintf("{%s}", joinStrings(pairs, ", ")) } -func (tl *TableLiteral) GetType() *TypeInfo { return tl.typeInfo } -func (tl *TableLiteral) SetType(t *TypeInfo) { tl.typeInfo = t } +func (tl *TableLiteral) TypeInfo() TypeInfo { return TableType } -// IsArray returns true if this table contains only array-style elements +// IsArray returns true if this table contains only array-style elements (no explicit keys) func (tl *TableLiteral) IsArray() bool { for _, pair := range tl.Pairs { if pair.Key != nil { @@ -550,7 +544,31 @@ func (tl *TableLiteral) IsArray() bool { return true } -// joinStrings joins string slice with separator +// Helper function to convert TypeInfo to string representation +func typeToString(t TypeInfo) string { + switch t.Type { + case TypeNumber: + return "number" + case TypeString: + return "string" + case TypeBool: + return "bool" + case TypeNil: + return "nil" + case TypeTable: + return "table" + case TypeFunction: + return "function" + case TypeAny: + return "any" + case TypeStruct: + return fmt.Sprintf("struct<%d>", t.StructID) + default: + return "unknown" + } +} + +// joinStrings efficiently joins string slice with separator func joinStrings(strs []string, sep string) string { if len(strs) == 0 { return "" diff --git a/parser/parser.go b/parser/parser.go index b29d0d6..15dd89c 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -19,7 +19,7 @@ func (pe ParseError) Error() string { pe.Line, pe.Column, pe.Message, pe.Token.Literal) } -// Parser implements a recursive descent Pratt parser +// Parser implements a recursive descent Pratt parser with optimized AST generation type Parser struct { lexer *Lexer @@ -32,11 +32,13 @@ type Parser struct { errors []ParseError // Scope tracking - scopes []map[string]bool // stack of scopes, each tracking declared variables - scopeTypes []string // track what type each scope is: "global", "function", "loop" + scopes []map[string]bool + scopeTypes []string - // Struct tracking - structs map[string]*StructStatement // track defined structs + // Struct tracking with ID mapping + structs map[string]*StructStatement + structIDs map[uint16]*StructStatement + nextID uint16 } // NewParser creates a new parser instance @@ -44,9 +46,11 @@ func NewParser(lexer *Lexer) *Parser { p := &Parser{ lexer: lexer, errors: []ParseError{}, - scopes: []map[string]bool{make(map[string]bool)}, // start with global scope - scopeTypes: []string{"global"}, // start with global scope type - structs: make(map[string]*StructStatement), // track struct definitions + scopes: []map[string]bool{make(map[string]bool)}, + scopeTypes: []string{"global"}, + structs: make(map[string]*StructStatement), + structIDs: make(map[uint16]*StructStatement), + nextID: 1, // 0 reserved for non-struct types } p.prefixParseFns = make(map[TokenType]func() Expression) @@ -78,15 +82,31 @@ func NewParser(lexer *Lexer) *Parser { p.registerInfix(DOT, p.parseDotExpression) p.registerInfix(LBRACKET, p.parseIndexExpression) p.registerInfix(LPAREN, p.parseCallExpression) - p.registerInfix(LBRACE, p.parseStructConstructor) // struct constructor + p.registerInfix(LBRACE, p.parseStructConstructor) - // Read two tokens, so curToken and peekToken are both set p.nextToken() p.nextToken() return p } +// Struct management +func (p *Parser) registerStruct(stmt *StructStatement) { + stmt.ID = p.nextID + p.nextID++ + p.structs[stmt.Name] = stmt + p.structIDs[stmt.ID] = stmt +} + +func (p *Parser) getStructByName(name string) *StructStatement { + return p.structs[name] +} + +func (p *Parser) isStructDefined(name string) bool { + _, exists := p.structs[name] + return exists +} + // Scope management func (p *Parser) enterScope(scopeType string) { p.scopes = append(p.scopes, make(map[string]bool)) @@ -100,30 +120,6 @@ func (p *Parser) exitScope() { } } -func (p *Parser) enterFunctionScope() { - p.enterScope("function") -} - -func (p *Parser) exitFunctionScope() { - p.exitScope() -} - -func (p *Parser) enterLoopScope() { - p.enterScope("loop") -} - -func (p *Parser) exitLoopScope() { - p.exitScope() -} - -func (p *Parser) enterBlockScope() { - // Blocks don't create new variable scopes -} - -func (p *Parser) exitBlockScope() { - // No-op -} - func (p *Parser) currentVariableScope() map[string]bool { if len(p.scopeTypes) > 1 && p.scopeTypes[len(p.scopeTypes)-1] == "loop" { return p.scopes[len(p.scopes)-2] @@ -148,45 +144,56 @@ func (p *Parser) declareLoopVariable(name string) { p.scopes[len(p.scopes)-1][name] = true } -// parseTypeHint parses optional type hint after colon -func (p *Parser) parseTypeHint() *TypeInfo { +// parseTypeHint parses optional type hint after colon, returns by value +func (p *Parser) parseTypeHint() TypeInfo { if !p.peekTokenIs(COLON) { - return nil + return UnknownType } p.nextToken() // consume ':' if !p.expectPeekIdent() { p.addError("expected type name after ':'") - return nil + return UnknownType } typeName := p.curToken.Literal - if !ValidTypeName(typeName) && !p.isStructDefined(typeName) { + + // Check built-in types + switch typeName { + case "number": + return TypeInfo{Type: TypeNumber, Inferred: false} + case "string": + return TypeInfo{Type: TypeString, Inferred: false} + case "bool": + return TypeInfo{Type: TypeBool, Inferred: false} + case "nil": + return TypeInfo{Type: TypeNil, Inferred: false} + case "table": + return TypeInfo{Type: TypeTable, Inferred: false} + case "function": + return TypeInfo{Type: TypeFunction, Inferred: false} + case "any": + return TypeInfo{Type: TypeAny, Inferred: false} + default: + // Check if it's a struct type + if structDef := p.getStructByName(typeName); structDef != nil { + return TypeInfo{Type: TypeStruct, StructID: structDef.ID, Inferred: false} + } p.addError(fmt.Sprintf("invalid type name '%s'", typeName)) - return nil + return UnknownType } - - return &TypeInfo{Type: typeName, Inferred: false} } -// isStructDefined checks if a struct name is defined -func (p *Parser) isStructDefined(name string) bool { - _, exists := p.structs[name] - return exists -} - -// registerPrefix registers a prefix parse function +// registerPrefix/registerInfix func (p *Parser) registerPrefix(tokenType TokenType, fn func() Expression) { p.prefixParseFns[tokenType] = fn } -// registerInfix registers an infix parse function func (p *Parser) registerInfix(tokenType TokenType, fn func(Expression) Expression) { p.infixParseFns[tokenType] = fn } -// nextToken advances to the next token func (p *Parser) nextToken() { p.curToken = p.peekToken p.peekToken = p.lexer.NextToken() @@ -265,7 +272,7 @@ func (p *Parser) parseStructStatement() *StructStatement { if p.peekTokenIs(RBRACE) { p.nextToken() - p.structs[stmt.Name] = stmt + p.registerStruct(stmt) return stmt } @@ -284,9 +291,9 @@ func (p *Parser) parseStructStatement() *StructStatement { field := StructField{Name: p.curToken.Literal} - // Parse optional type hint + // Parse required type hint field.TypeHint = p.parseTypeHint() - if field.TypeHint == nil { + if field.TypeHint.Type == TypeUnknown { p.addError("struct fields require type annotation") return nil } @@ -314,7 +321,7 @@ func (p *Parser) parseStructStatement() *StructStatement { return nil } - p.structs[stmt.Name] = stmt + p.registerStruct(stmt) return stmt } @@ -338,12 +345,19 @@ func (p *Parser) parseFunctionStatement() Statement { methodName := p.curToken.Literal + // Get struct ID + structDef := p.getStructByName(funcName) + if structDef == nil { + p.addError(fmt.Sprintf("method defined on undefined struct '%s'", funcName)) + return nil + } + if !p.expectPeek(LPAREN) { p.addError("expected '(' after method name") return nil } - // Parse the function literal starting from parameters + // Parse the function literal funcLit := &FunctionLiteral{} funcLit.Parameters, funcLit.Variadic = p.parseFunctionParameters() @@ -357,12 +371,12 @@ func (p *Parser) parseFunctionStatement() Statement { p.nextToken() - p.enterFunctionScope() + p.enterScope("function") for _, param := range funcLit.Parameters { p.declareVariable(param.Name) } funcLit.Body = p.parseBlockStatements(END) - p.exitFunctionScope() + p.exitScope() if !p.curTokenIs(END) { p.addError("expected 'end' to close function") @@ -370,14 +384,13 @@ func (p *Parser) parseFunctionStatement() Statement { } return &MethodDefinition{ - StructName: funcName, + StructID: structDef.ID, MethodName: methodName, Function: funcLit, } } - // Regular function - this should be handled as expression statement - // Reset to handle as function literal + // Regular function - handle as function literal expression statement funcLit := p.parseFunctionLiteral() if funcLit == nil { return nil @@ -386,7 +399,7 @@ func (p *Parser) parseFunctionStatement() Statement { return &ExpressionStatement{Expression: funcLit} } -// parseIdentifierStatement handles both assignments and expression statements starting with identifiers +// parseIdentifierStatement handles assignments and expression statements func (p *Parser) parseIdentifierStatement() Statement { // Parse the left-hand side expression first expr := p.ParseExpression(LOWEST) @@ -395,28 +408,28 @@ func (p *Parser) parseIdentifierStatement() Statement { } // Check for type hint (only valid on simple identifiers) - var typeHint *TypeInfo + var typeHint TypeInfo = UnknownType if _, ok := expr.(*Identifier); ok { typeHint = p.parseTypeHint() } // Check if this is an assignment if p.peekTokenIs(ASSIGN) { - // Convert to assignment statement - stmt := &AssignStatement{ - Name: expr, + // Create unified assignment + assignment := &Assignment{ + Target: expr, TypeHint: typeHint, } // Validate assignment target and check if it's a declaration - switch name := expr.(type) { + switch target := expr.(type) { case *Identifier: - stmt.IsDeclaration = !p.isVariableDeclared(name.Value) - if stmt.IsDeclaration { - p.declareVariable(name.Value) + assignment.IsDeclaration = !p.isVariableDeclared(target.Value) + if assignment.IsDeclaration { + p.declareVariable(target.Value) } case *DotExpression, *IndexExpression: - stmt.IsDeclaration = false + assignment.IsDeclaration = false default: p.addError("invalid assignment target") return nil @@ -428,29 +441,19 @@ func (p *Parser) parseIdentifierStatement() Statement { p.nextToken() - stmt.Value = p.ParseExpression(LOWEST) - if stmt.Value == nil { + assignment.Value = p.ParseExpression(LOWEST) + if assignment.Value == nil { p.addError("expected expression after assignment operator") return nil } - return stmt + return assignment } else { // This is an expression statement return &ExpressionStatement{Expression: expr} } } -// parseExpressionStatement parses expressions used as statements -func (p *Parser) parseExpressionStatement() *ExpressionStatement { - stmt := &ExpressionStatement{} - stmt.Expression = p.ParseExpression(LOWEST) - if stmt.Expression == nil { - return nil - } - return stmt -} - // parseEchoStatement parses echo statements func (p *Parser) parseEchoStatement() *EchoStatement { stmt := &EchoStatement{} @@ -466,9 +469,8 @@ func (p *Parser) parseEchoStatement() *EchoStatement { return stmt } -// parseBreakStatement parses break statements +// Simple statement parsers func (p *Parser) parseBreakStatement() *BreakStatement { - // Check if break is followed by an identifier (invalid) if p.peekTokenIs(IDENT) { p.addError("unexpected identifier") return nil @@ -476,7 +478,6 @@ func (p *Parser) parseBreakStatement() *BreakStatement { return &BreakStatement{} } -// parseExitStatement parses exit statements func (p *Parser) parseExitStatement() *ExitStatement { stmt := &ExitStatement{} @@ -492,7 +493,6 @@ func (p *Parser) parseExitStatement() *ExitStatement { return stmt } -// parseReturnStatement parses return statements func (p *Parser) parseReturnStatement() *ReturnStatement { stmt := &ReturnStatement{} @@ -508,7 +508,6 @@ func (p *Parser) parseReturnStatement() *ReturnStatement { return stmt } -// canStartExpression checks if a token type can start an expression func (p *Parser) canStartExpression(tokenType TokenType) bool { switch tokenType { case IDENT, NUMBER, STRING, TRUE, FALSE, NIL, LPAREN, LBRACE, MINUS, NOT, FN: @@ -518,7 +517,7 @@ func (p *Parser) canStartExpression(tokenType TokenType) bool { } } -// parseWhileStatement parses while loops +// Loop statement parsers func (p *Parser) parseWhileStatement() *WhileStatement { stmt := &WhileStatement{} @@ -537,9 +536,7 @@ func (p *Parser) parseWhileStatement() *WhileStatement { p.nextToken() - p.enterBlockScope() stmt.Body = p.parseBlockStatements(END) - p.exitBlockScope() if !p.curTokenIs(END) { p.addError("expected 'end' to close while loop") @@ -549,7 +546,6 @@ func (p *Parser) parseWhileStatement() *WhileStatement { return stmt } -// parseForStatement parses for loops func (p *Parser) parseForStatement() Statement { p.nextToken() @@ -570,7 +566,6 @@ func (p *Parser) parseForStatement() Statement { } } -// parseNumericForStatement parses numeric for loops func (p *Parser) parseNumericForStatement(variable *Identifier) *ForStatement { stmt := &ForStatement{Variable: variable} @@ -617,10 +612,10 @@ func (p *Parser) parseNumericForStatement(variable *Identifier) *ForStatement { p.nextToken() - p.enterLoopScope() + p.enterScope("loop") p.declareLoopVariable(variable.Value) stmt.Body = p.parseBlockStatements(END) - p.exitLoopScope() + p.exitScope() if !p.curTokenIs(END) { p.addError("expected 'end' to close for loop") @@ -630,7 +625,6 @@ func (p *Parser) parseNumericForStatement(variable *Identifier) *ForStatement { return stmt } -// parseForInStatement parses for-in loops func (p *Parser) parseForInStatement(firstVar *Identifier) *ForInStatement { stmt := &ForInStatement{} @@ -669,13 +663,13 @@ func (p *Parser) parseForInStatement(firstVar *Identifier) *ForInStatement { p.nextToken() - p.enterLoopScope() + p.enterScope("loop") if stmt.Key != nil { p.declareLoopVariable(stmt.Key.Value) } p.declareLoopVariable(stmt.Value.Value) stmt.Body = p.parseBlockStatements(END) - p.exitLoopScope() + p.exitScope() if !p.curTokenIs(END) { p.addError("expected 'end' to close for loop") @@ -708,9 +702,7 @@ func (p *Parser) parseIfStatement() *IfStatement { return nil } - p.enterBlockScope() stmt.Body = p.parseBlockStatements(ELSEIF, ELSE, END) - p.exitBlockScope() for p.curTokenIs(ELSEIF) { elseif := ElseIfClause{} @@ -729,19 +721,13 @@ func (p *Parser) parseIfStatement() *IfStatement { p.nextToken() - p.enterBlockScope() elseif.Body = p.parseBlockStatements(ELSEIF, ELSE, END) - p.exitBlockScope() - stmt.ElseIfs = append(stmt.ElseIfs, elseif) } if p.curTokenIs(ELSE) { p.nextToken() - - p.enterBlockScope() stmt.Else = p.parseBlockStatements(END) - p.exitBlockScope() } if !p.curTokenIs(END) { @@ -754,7 +740,7 @@ func (p *Parser) parseIfStatement() *IfStatement { // parseBlockStatements parses statements until terminators func (p *Parser) parseBlockStatements(terminators ...TokenType) []Statement { - statements := []Statement{} + statements := make([]Statement, 0, 8) // Pre-allocate for performance for !p.curTokenIs(EOF) && !p.isTerminator(terminators...) { stmt := p.parseStatement() @@ -767,7 +753,6 @@ func (p *Parser) parseBlockStatements(terminators ...TokenType) []Statement { return statements } -// isTerminator checks if current token is a terminator func (p *Parser) isTerminator(terminators ...TokenType) bool { for _, terminator := range terminators { if p.curTokenIs(terminator) { @@ -919,7 +904,6 @@ func (p *Parser) parseGroupedExpression() Expression { // parseParenthesizedAssignment parses assignment expressions in parentheses func (p *Parser) parseParenthesizedAssignment() Expression { - // We're at identifier, peek is ASSIGN target := p.parseIdentifier() if !p.expectPeek(ASSIGN) { @@ -939,9 +923,10 @@ func (p *Parser) parseParenthesizedAssignment() Expression { } // Create assignment expression - assignExpr := &AssignExpression{ - Name: target, - Value: value, + assignExpr := &Assignment{ + Target: target, + Value: value, + IsExpression: true, } // Handle variable declaration for assignment expressions @@ -952,8 +937,6 @@ func (p *Parser) parseParenthesizedAssignment() Expression { } } - // Assignment expression evaluates to the assigned value - assignExpr.SetType(value.GetType()) return assignExpr } @@ -977,12 +960,12 @@ func (p *Parser) parseFunctionLiteral() Expression { p.nextToken() - p.enterFunctionScope() + p.enterScope("function") for _, param := range fn.Parameters { p.declareVariable(param.Name) } fn.Body = p.parseBlockStatements(END) - p.exitFunctionScope() + p.exitScope() if !p.curTokenIs(END) { p.addError("expected 'end' to close function") @@ -1038,7 +1021,7 @@ func (p *Parser) parseFunctionParameters() ([]FunctionParameter, bool) { func (p *Parser) parseTableLiteral() Expression { table := &TableLiteral{} - table.Pairs = []TablePair{} + table.Pairs = make([]TablePair, 0, 4) // Pre-allocate if p.peekTokenIs(RBRACE) { p.nextToken() @@ -1104,22 +1087,24 @@ func (p *Parser) parseTableLiteral() Expression { return table } -// parseStructConstructor handles struct constructor calls like my_type{...} +// parseStructConstructor handles struct constructor calls func (p *Parser) parseStructConstructor(left Expression) Expression { - // left should be an identifier representing the struct name ident, ok := left.(*Identifier) if !ok { - // Not an identifier, fall back to table literal parsing return p.parseTableLiteralFromBrace() } structName := ident.Value + structDef := p.getStructByName(structName) + if structDef == nil { + // Not a struct, parse as table literal + return p.parseTableLiteralFromBrace() + } - // Always try to parse as struct constructor if we have an identifier - // Type checking will catch undefined structs later - constructor := &StructConstructorExpression{ - StructName: structName, - Fields: []TablePair{}, + constructor := &StructConstructor{ + StructID: structDef.ID, + Fields: make([]TablePair, 0, 4), + typeInfo: TypeInfo{Type: TypeStruct, StructID: structDef.ID, Inferred: true}, } if p.peekTokenIs(RBRACE) { @@ -1187,9 +1172,8 @@ func (p *Parser) parseStructConstructor(left Expression) Expression { } func (p *Parser) parseTableLiteralFromBrace() Expression { - // We're already at the opening brace, so parse as table literal table := &TableLiteral{} - table.Pairs = []TablePair{} + table.Pairs = make([]TablePair, 0, 4) if p.peekTokenIs(RBRACE) { p.nextToken() @@ -1428,15 +1412,9 @@ func (p *Parser) curPrecedence() Precedence { return LOWEST } -// Errors returns all parsing errors -func (p *Parser) Errors() []ParseError { - return p.errors -} - -func (p *Parser) HasErrors() bool { - return len(p.errors) > 0 -} - +// Error reporting +func (p *Parser) Errors() []ParseError { return p.errors } +func (p *Parser) HasErrors() bool { return len(p.errors) > 0 } func (p *Parser) ErrorStrings() []string { result := make([]string, len(p.errors)) for i, err := range p.errors { diff --git a/parser/tests/assignments_test.go b/parser/tests/assignments_test.go index 3055b17..03e2079 100644 --- a/parser/tests/assignments_test.go +++ b/parser/tests/assignments_test.go @@ -31,15 +31,15 @@ func TestAssignStatements(t *testing.T) { t.Fatalf("expected 1 statement, got %d", len(program.Statements)) } - stmt, ok := program.Statements[0].(*parser.AssignStatement) + stmt, ok := program.Statements[0].(*parser.Assignment) if !ok { - t.Fatalf("expected AssignStatement, got %T", program.Statements[0]) + t.Fatalf("expected Assignment, got %T", program.Statements[0]) } - // Check that Name is an Identifier - ident, ok := stmt.Name.(*parser.Identifier) + // Check that Target is an Identifier + ident, ok := stmt.Target.(*parser.Identifier) if !ok { - t.Fatalf("expected Identifier for Name, got %T", stmt.Name) + t.Fatalf("expected Identifier for Target, got %T", stmt.Target) } if ident.Value != tt.expectedIdentifier { @@ -90,9 +90,9 @@ func TestMemberAccessAssignment(t *testing.T) { t.Fatalf("expected 1 statement, got %d", len(program.Statements)) } - stmt, ok := program.Statements[0].(*parser.AssignStatement) + stmt, ok := program.Statements[0].(*parser.Assignment) if !ok { - t.Fatalf("expected AssignStatement, got %T", program.Statements[0]) + t.Fatalf("expected Assignment, got %T", program.Statements[0]) } if stmt.String() != tt.expected { @@ -158,15 +158,15 @@ func TestTableAssignments(t *testing.T) { t.Fatalf("expected 1 statement, got %d", len(program.Statements)) } - stmt, ok := program.Statements[0].(*parser.AssignStatement) + stmt, ok := program.Statements[0].(*parser.Assignment) if !ok { - t.Fatalf("expected AssignStatement, got %T", program.Statements[0]) + t.Fatalf("expected Assignment, got %T", program.Statements[0]) } - // Check that Name is an Identifier - ident, ok := stmt.Name.(*parser.Identifier) + // Check that Target is an Identifier + ident, ok := stmt.Target.(*parser.Identifier) if !ok { - t.Fatalf("expected Identifier for Name, got %T", stmt.Name) + t.Fatalf("expected Identifier for Target, got %T", stmt.Target) } if ident.Value != tt.identifier { diff --git a/parser/tests/breakexit_test.go b/parser/tests/breakexit_test.go index f31f594..10cad50 100644 --- a/parser/tests/breakexit_test.go +++ b/parser/tests/breakexit_test.go @@ -247,9 +247,9 @@ exit "success"` } // First: assignment - _, ok := program.Statements[0].(*parser.AssignStatement) + _, ok := program.Statements[0].(*parser.Assignment) if !ok { - t.Fatalf("statement 0: expected AssignStatement, got %T", program.Statements[0]) + t.Fatalf("statement 0: expected Assignment, got %T", program.Statements[0]) } // Second: if statement with exit in body @@ -264,7 +264,7 @@ exit "success"` } // Third: assignment - _, ok = program.Statements[2].(*parser.AssignStatement) + _, ok = program.Statements[2].(*parser.Assignment) if !ok { t.Fatalf("statement 2: expected AssignStatement, got %T", program.Statements[2]) } diff --git a/parser/tests/conditionals_test.go b/parser/tests/conditionals_test.go index 6fdcfb8..e057b93 100644 --- a/parser/tests/conditionals_test.go +++ b/parser/tests/conditionals_test.go @@ -32,15 +32,15 @@ end` t.Fatalf("expected 1 body statement, got %d", len(stmt.Body)) } - bodyStmt, ok := stmt.Body[0].(*parser.AssignStatement) + bodyStmt, ok := stmt.Body[0].(*parser.Assignment) if !ok { t.Fatalf("expected AssignStatement in body, got %T", stmt.Body[0]) } - // Check that Name is an Identifier - ident, ok := bodyStmt.Name.(*parser.Identifier) + // Check that Target is an Identifier + ident, ok := bodyStmt.Target.(*parser.Identifier) if !ok { - t.Fatalf("expected Identifier for Name, got %T", bodyStmt.Name) + t.Fatalf("expected Identifier for Target, got %T", bodyStmt.Target) } if ident.Value != "x" { @@ -79,15 +79,15 @@ end` t.Fatalf("expected 1 else statement, got %d", len(stmt.Else)) } - elseStmt, ok := stmt.Else[0].(*parser.AssignStatement) + elseStmt, ok := stmt.Else[0].(*parser.Assignment) if !ok { t.Fatalf("expected AssignStatement in else, got %T", stmt.Else[0]) } // Check that Name is an Identifier - ident, ok := elseStmt.Name.(*parser.Identifier) + ident, ok := elseStmt.Target.(*parser.Identifier) if !ok { - t.Fatalf("expected Identifier for Name, got %T", elseStmt.Name) + t.Fatalf("expected Identifier for Name, got %T", elseStmt.Target) } if ident.Value != "x" { @@ -169,25 +169,25 @@ end` } // First assignment: arr[1] = "updated" - assign1, ok := stmt.Body[0].(*parser.AssignStatement) + assign1, ok := stmt.Body[0].(*parser.Assignment) if !ok { t.Fatalf("expected AssignStatement, got %T", stmt.Body[0]) } - _, ok = assign1.Name.(*parser.IndexExpression) + _, ok = assign1.Target.(*parser.IndexExpression) if !ok { - t.Fatalf("expected IndexExpression for assignment target, got %T", assign1.Name) + t.Fatalf("expected IndexExpression for assignment target, got %T", assign1.Target) } // Second assignment: obj.nested.count = obj.nested.count + 1 - assign2, ok := stmt.Body[1].(*parser.AssignStatement) + assign2, ok := stmt.Body[1].(*parser.Assignment) if !ok { t.Fatalf("expected AssignStatement, got %T", stmt.Body[1]) } - _, ok = assign2.Name.(*parser.DotExpression) + _, ok = assign2.Target.(*parser.DotExpression) if !ok { - t.Fatalf("expected DotExpression for assignment target, got %T", assign2.Name) + t.Fatalf("expected DotExpression for assignment target, got %T", assign2.Target) } } @@ -214,7 +214,7 @@ end` } // Test body has expression assignment - bodyStmt := stmt.Body[0].(*parser.AssignStatement) + bodyStmt := stmt.Body[0].(*parser.Assignment) bodyInfix, ok := bodyStmt.Value.(*parser.InfixExpression) if !ok { t.Fatalf("expected InfixExpression value, got %T", bodyStmt.Value) diff --git a/parser/tests/expressions_test.go b/parser/tests/expressions_test.go index 688738e..d270c45 100644 --- a/parser/tests/expressions_test.go +++ b/parser/tests/expressions_test.go @@ -352,15 +352,15 @@ func TestAssignmentExpressions(t *testing.T) { expr := p.ParseExpression(parser.LOWEST) checkParserErrors(t, p) - assignExpr, ok := expr.(*parser.AssignExpression) + assignExpr, ok := expr.(*parser.Assignment) if !ok { t.Fatalf("expected AssignExpression, got %T", expr) } // Test target name - ident, ok := assignExpr.Name.(*parser.Identifier) + ident, ok := assignExpr.Target.(*parser.Identifier) if !ok { - t.Fatalf("expected Identifier for assignment target, got %T", assignExpr.Name) + t.Fatalf("expected Identifier for assignment target, got %T", assignExpr.Target) } if ident.Value != tt.targetName { @@ -413,12 +413,12 @@ func TestAssignmentExpressionWithComplexExpressions(t *testing.T) { expr := p.ParseExpression(parser.LOWEST) checkParserErrors(t, p) - assignExpr, ok := expr.(*parser.AssignExpression) + assignExpr, ok := expr.(*parser.Assignment) if !ok { t.Fatalf("expected AssignExpression, got %T", expr) } - if assignExpr.Name == nil { + if assignExpr.Target == nil { t.Error("expected non-nil assignment target") } diff --git a/parser/tests/functions_test.go b/parser/tests/functions_test.go index 5dd4c46..f08e383 100644 --- a/parser/tests/functions_test.go +++ b/parser/tests/functions_test.go @@ -160,7 +160,7 @@ func TestFunctionAssignments(t *testing.T) { t.Fatalf("expected 1 statement, got %d", len(program.Statements)) } - stmt, ok := program.Statements[0].(*parser.AssignStatement) + stmt, ok := program.Statements[0].(*parser.Assignment) if !ok { t.Fatalf("expected AssignStatement, got %T", program.Statements[0]) } @@ -296,7 +296,7 @@ end` } // First statement: assignment of inner function - assign, ok := fn.Body[0].(*parser.AssignStatement) + assign, ok := fn.Body[0].(*parser.Assignment) if !ok { t.Fatalf("expected AssignStatement, got %T", fn.Body[0]) } @@ -342,7 +342,7 @@ end` } // First: function assignment - assign, ok := forStmt.Body[0].(*parser.AssignStatement) + assign, ok := forStmt.Body[0].(*parser.Assignment) if !ok { t.Fatalf("expected AssignStatement, got %T", forStmt.Body[0]) } @@ -551,7 +551,7 @@ echo adder` } // First: table with functions - mathAssign, ok := program.Statements[0].(*parser.AssignStatement) + mathAssign, ok := program.Statements[0].(*parser.Assignment) if !ok { t.Fatalf("statement 0: expected AssignStatement, got %T", program.Statements[0]) } @@ -574,7 +574,7 @@ echo adder` } // Second: result assignment (function call would be handled by interpreter) - _, ok = program.Statements[1].(*parser.AssignStatement) + _, ok = program.Statements[1].(*parser.Assignment) if !ok { t.Fatalf("statement 1: expected AssignStatement, got %T", program.Statements[1]) } @@ -586,7 +586,7 @@ echo adder` } // Fourth: calculator function assignment - calcAssign, ok := program.Statements[3].(*parser.AssignStatement) + calcAssign, ok := program.Statements[3].(*parser.Assignment) if !ok { t.Fatalf("statement 3: expected AssignStatement, got %T", program.Statements[3]) } @@ -601,7 +601,7 @@ echo adder` } // Fifth: adder assignment - _, ok = program.Statements[4].(*parser.AssignStatement) + _, ok = program.Statements[4].(*parser.Assignment) if !ok { t.Fatalf("statement 4: expected AssignStatement, got %T", program.Statements[4]) } @@ -645,7 +645,7 @@ end` } // Check if body has function assignment - ifAssign, ok := ifStmt.Body[0].(*parser.AssignStatement) + ifAssign, ok := ifStmt.Body[0].(*parser.Assignment) if !ok { t.Fatalf("if body: expected AssignStatement, got %T", ifStmt.Body[0]) } @@ -656,7 +656,7 @@ end` } // Check else body has function assignment - elseAssign, ok := ifStmt.Else[0].(*parser.AssignStatement) + elseAssign, ok := ifStmt.Else[0].(*parser.Assignment) if !ok { t.Fatalf("else body: expected AssignStatement, got %T", ifStmt.Else[0]) } @@ -683,7 +683,7 @@ end` } // Verify both branches assign functions - nestedIfAssign, ok := nestedIf.Body[0].(*parser.AssignStatement) + nestedIfAssign, ok := nestedIf.Body[0].(*parser.Assignment) if !ok { t.Fatalf("nested if body: expected AssignStatement, got %T", nestedIf.Body[0]) } @@ -693,7 +693,7 @@ end` t.Fatalf("nested if body: expected FunctionLiteral, got %T", nestedIfAssign.Value) } - nestedElseAssign, ok := nestedIf.Else[0].(*parser.AssignStatement) + nestedElseAssign, ok := nestedIf.Else[0].(*parser.Assignment) if !ok { t.Fatalf("nested else body: expected AssignStatement, got %T", nestedIf.Else[0]) } diff --git a/parser/tests/loops_test.go b/parser/tests/loops_test.go index 9d3f847..5d6fc08 100644 --- a/parser/tests/loops_test.go +++ b/parser/tests/loops_test.go @@ -327,13 +327,13 @@ end` } // First: table assignment - _, ok := program.Statements[0].(*parser.AssignStatement) + _, ok := program.Statements[0].(*parser.Assignment) if !ok { t.Fatalf("statement 0: expected AssignStatement, got %T", program.Statements[0]) } // Second: variable assignment - _, ok = program.Statements[1].(*parser.AssignStatement) + _, ok = program.Statements[1].(*parser.Assignment) if !ok { t.Fatalf("statement 1: expected AssignStatement, got %T", program.Statements[1]) } @@ -590,7 +590,7 @@ end` } // Second body statement should be assignment - _, ok = outerWhile.Body[1].(*parser.AssignStatement) + _, ok = outerWhile.Body[1].(*parser.Assignment) if !ok { t.Fatalf("expected AssignStatement, got %T", outerWhile.Body[1]) } @@ -634,14 +634,14 @@ end` } // First assignment: data[index] = ... - assign1, ok := stmt.Body[0].(*parser.AssignStatement) + assign1, ok := stmt.Body[0].(*parser.Assignment) if !ok { t.Fatalf("expected AssignStatement, got %T", stmt.Body[0]) } - _, ok = assign1.Name.(*parser.IndexExpression) + _, ok = assign1.Target.(*parser.IndexExpression) if !ok { - t.Fatalf("expected IndexExpression for assignment target, got %T", assign1.Name) + t.Fatalf("expected IndexExpression for assignment target, got %T", assign1.Target) } } @@ -755,7 +755,7 @@ end` // First three: assignments for i := 0; i < 3; i++ { - _, ok := program.Statements[i].(*parser.AssignStatement) + _, ok := program.Statements[i].(*parser.Assignment) if !ok { t.Fatalf("statement %d: expected AssignStatement, got %T", i, program.Statements[i]) } @@ -838,7 +838,7 @@ end` } // Fourth: assignment - _, ok = whileStmt.Body[3].(*parser.AssignStatement) + _, ok = whileStmt.Body[3].(*parser.Assignment) if !ok { t.Fatalf("body[3]: expected AssignStatement, got %T", whileStmt.Body[3]) } diff --git a/parser/tests/parser_test.go b/parser/tests/parser_test.go index fd5a355..55add71 100644 --- a/parser/tests/parser_test.go +++ b/parser/tests/parser_test.go @@ -22,14 +22,14 @@ z = true + false` expectedIdentifiers := []string{"x", "y", "z"} for i, expectedIdent := range expectedIdentifiers { - stmt, ok := program.Statements[i].(*parser.AssignStatement) + stmt, ok := program.Statements[i].(*parser.Assignment) if !ok { t.Fatalf("statement %d: expected AssignStatement, got %T", i, program.Statements[i]) } - ident, ok := stmt.Name.(*parser.Identifier) + ident, ok := stmt.Target.(*parser.Identifier) if !ok { - t.Fatalf("statement %d: expected Identifier for Name, got %T", i, stmt.Name) + t.Fatalf("statement %d: expected Identifier for Name, got %T", i, stmt.Target) } if ident.Value != expectedIdent { @@ -58,13 +58,13 @@ arr = {a = 1, b = 2}` } // First statement: assignment - stmt1, ok := program.Statements[0].(*parser.AssignStatement) + stmt1, ok := program.Statements[0].(*parser.Assignment) if !ok { t.Fatalf("statement 0: expected AssignStatement, got %T", program.Statements[0]) } - ident1, ok := stmt1.Name.(*parser.Identifier) + ident1, ok := stmt1.Target.(*parser.Identifier) if !ok { - t.Fatalf("expected Identifier for Name, got %T", stmt1.Name) + t.Fatalf("expected Identifier for Name, got %T", stmt1.Target) } if ident1.Value != "x" { t.Errorf("expected identifier 'x', got %s", ident1.Value) @@ -80,7 +80,7 @@ arr = {a = 1, b = 2}` } // Third statement: table assignment - stmt3, ok := program.Statements[2].(*parser.AssignStatement) + stmt3, ok := program.Statements[2].(*parser.Assignment) if !ok { t.Fatalf("statement 2: expected AssignStatement, got %T", program.Statements[2]) } @@ -110,33 +110,33 @@ echo table[table.key]` } // Second statement: dot assignment - stmt2, ok := program.Statements[1].(*parser.AssignStatement) + stmt2, ok := program.Statements[1].(*parser.Assignment) if !ok { t.Fatalf("statement 1: expected AssignStatement, got %T", program.Statements[1]) } - _, ok = stmt2.Name.(*parser.DotExpression) + _, ok = stmt2.Target.(*parser.DotExpression) if !ok { - t.Fatalf("expected DotExpression for assignment target, got %T", stmt2.Name) + t.Fatalf("expected DotExpression for assignment target, got %T", stmt2.Target) } // Third statement: bracket assignment - stmt3, ok := program.Statements[2].(*parser.AssignStatement) + stmt3, ok := program.Statements[2].(*parser.Assignment) if !ok { t.Fatalf("statement 2: expected AssignStatement, got %T", program.Statements[2]) } - _, ok = stmt3.Name.(*parser.IndexExpression) + _, ok = stmt3.Target.(*parser.IndexExpression) if !ok { - t.Fatalf("expected IndexExpression for assignment target, got %T", stmt3.Name) + t.Fatalf("expected IndexExpression for assignment target, got %T", stmt3.Target) } // Fourth statement: chained dot assignment - stmt4, ok := program.Statements[3].(*parser.AssignStatement) + stmt4, ok := program.Statements[3].(*parser.Assignment) if !ok { t.Fatalf("statement 3: expected AssignStatement, got %T", program.Statements[3]) } - _, ok = stmt4.Name.(*parser.DotExpression) + _, ok = stmt4.Target.(*parser.DotExpression) if !ok { - t.Fatalf("expected DotExpression for assignment target, got %T", stmt4.Name) + t.Fatalf("expected DotExpression for assignment target, got %T", stmt4.Target) } // Fifth statement: echo with nested access @@ -232,7 +232,7 @@ end` } // First statement: complex expression assignment - stmt1, ok := program.Statements[0].(*parser.AssignStatement) + stmt1, ok := program.Statements[0].(*parser.Assignment) if !ok { t.Fatalf("expected AssignStatement, got %T", program.Statements[0]) } @@ -253,7 +253,7 @@ end` t.Fatalf("expected 1 body statement, got %d", len(stmt2.Body)) } - bodyStmt, ok := stmt2.Body[0].(*parser.AssignStatement) + bodyStmt, ok := stmt2.Body[0].(*parser.Assignment) if !ok { t.Fatalf("expected AssignStatement in body, got %T", stmt2.Body[0]) } @@ -286,7 +286,7 @@ echo {result = x}` } // First: assignment - _, ok := program.Statements[0].(*parser.AssignStatement) + _, ok := program.Statements[0].(*parser.Assignment) if !ok { t.Fatalf("statement 0: expected AssignStatement, got %T", program.Statements[0]) } diff --git a/parser/tests/scope_test.go b/parser/tests/scope_test.go index 397c957..dc585d4 100644 --- a/parser/tests/scope_test.go +++ b/parser/tests/scope_test.go @@ -63,7 +63,7 @@ x = 15`, assignmentCount := 0 for _, stmt := range program.Statements { - if assign, ok := stmt.(*parser.AssignStatement); ok { + if assign, ok := stmt.(*parser.Assignment); ok { if assignmentCount >= len(tt.assignments) { t.Fatalf("more assignments than expected") } @@ -71,9 +71,9 @@ x = 15`, expected := tt.assignments[assignmentCount] // Check variable name - ident, ok := assign.Name.(*parser.Identifier) + ident, ok := assign.Target.(*parser.Identifier) if !ok { - t.Fatalf("expected Identifier, got %T", assign.Name) + t.Fatalf("expected Identifier, got %T", assign.Target) } if ident.Value != expected.variable { @@ -135,9 +135,9 @@ z = 30` for i, expected := range expectedAssignments { assign := assignments[i] - ident, ok := assign.Name.(*parser.Identifier) + ident, ok := assign.Target.(*parser.Identifier) if !ok { - t.Fatalf("assignment %d: expected Identifier, got %T", i, assign.Name) + t.Fatalf("assignment %d: expected Identifier, got %T", i, assign.Target) } if ident.Value != expected.variable { @@ -191,9 +191,9 @@ end` for i, expected := range expectedAssignments { assign := assignments[i] - ident, ok := assign.Name.(*parser.Identifier) + ident, ok := assign.Target.(*parser.Identifier) if !ok { - t.Fatalf("assignment %d: expected Identifier, got %T", i, assign.Name) + t.Fatalf("assignment %d: expected Identifier, got %T", i, assign.Target) } if ident.Value != expected.variable { @@ -243,9 +243,9 @@ c = 20` for i, expected := range expectedAssignments { assign := assignments[i] - ident, ok := assign.Name.(*parser.Identifier) + ident, ok := assign.Target.(*parser.Identifier) if !ok { - t.Fatalf("assignment %d: expected Identifier, got %T", i, assign.Name) + t.Fatalf("assignment %d: expected Identifier, got %T", i, assign.Target) } if ident.Value != expected.variable { @@ -344,9 +344,9 @@ count = 0`, for i, expected := range tt.assignments { assign := assignments[i] - ident, ok := assign.Name.(*parser.Identifier) + ident, ok := assign.Target.(*parser.Identifier) if !ok { - t.Fatalf("assignment %d: expected Identifier, got %T", i, assign.Name) + t.Fatalf("assignment %d: expected Identifier, got %T", i, assign.Target) } if ident.Value != expected.variable { @@ -389,7 +389,7 @@ arr[1] = 99` assignmentCount := 0 for _, stmt := range program.Statements { - if assign, ok := stmt.(*parser.AssignStatement); ok { + if assign, ok := stmt.(*parser.Assignment); ok { if assignmentCount >= len(expectedAssignments) { t.Fatalf("more assignments than expected") } @@ -398,7 +398,7 @@ arr[1] = 99` if expected.isMemberAccess { // Should not be an identifier - if _, ok := assign.Name.(*parser.Identifier); ok { + if _, ok := assign.Target.(*parser.Identifier); ok { t.Errorf("assignment %d: expected member access, got Identifier", assignmentCount) } @@ -408,9 +408,9 @@ arr[1] = 99` } } else { // Should be an identifier - ident, ok := assign.Name.(*parser.Identifier) + ident, ok := assign.Target.(*parser.Identifier) if !ok { - t.Errorf("assignment %d: expected Identifier, got %T", assignmentCount, assign.Name) + t.Errorf("assignment %d: expected Identifier, got %T", assignmentCount, assign.Target) } else if ident.Value != expected.variable { t.Errorf("assignment %d: expected variable %s, got %s", assignmentCount, expected.variable, ident.Value) @@ -487,9 +487,9 @@ local_var = "global_local"` for i, expected := range expectedAssignments { assign := assignments[i] - ident, ok := assign.Name.(*parser.Identifier) + ident, ok := assign.Target.(*parser.Identifier) if !ok { - t.Fatalf("assignment %d: expected Identifier, got %T", i, assign.Name) + t.Fatalf("assignment %d: expected Identifier, got %T", i, assign.Target) } if ident.Value != expected.variable { @@ -536,8 +536,8 @@ y = 20.00` } // Helper function to extract all assignments from a program recursively -func extractAssignments(program *parser.Program) []*parser.AssignStatement { - var assignments []*parser.AssignStatement +func extractAssignments(program *parser.Program) []*parser.Assignment { + var assignments []*parser.Assignment for _, stmt := range program.Statements { assignments = append(assignments, extractAssignmentsFromStatement(stmt)...) @@ -546,11 +546,11 @@ func extractAssignments(program *parser.Program) []*parser.AssignStatement { return assignments } -func extractAssignmentsFromStatement(stmt parser.Statement) []*parser.AssignStatement { - var assignments []*parser.AssignStatement +func extractAssignmentsFromStatement(stmt parser.Statement) []*parser.Assignment { + var assignments []*parser.Assignment switch s := stmt.(type) { - case *parser.AssignStatement: + case *parser.Assignment: assignments = append(assignments, s) // Check if the value is a function literal with assignments in body diff --git a/parser/tests/structs_test.go b/parser/tests/structs_test.go index 36a3394..67a4f08 100644 --- a/parser/tests/structs_test.go +++ b/parser/tests/structs_test.go @@ -38,22 +38,22 @@ func TestBasicStructDefinition(t *testing.T) { if stmt.Fields[0].Name != "name" { t.Errorf("expected field name 'name', got %s", stmt.Fields[0].Name) } - if stmt.Fields[0].TypeHint == nil { + if stmt.Fields[0].TypeHint.Type == parser.TypeUnknown { t.Fatal("expected type hint for name field") } - if stmt.Fields[0].TypeHint.Type != "string" { - t.Errorf("expected type 'string', got %s", stmt.Fields[0].TypeHint.Type) + if stmt.Fields[0].TypeHint.Type != parser.TypeString { + t.Errorf("expected type string, got %v", stmt.Fields[0].TypeHint.Type) } // Test second field if stmt.Fields[1].Name != "age" { t.Errorf("expected field name 'age', got %s", stmt.Fields[1].Name) } - if stmt.Fields[1].TypeHint == nil { + if stmt.Fields[1].TypeHint.Type == parser.TypeUnknown { t.Fatal("expected type hint for age field") } - if stmt.Fields[1].TypeHint.Type != "number" { - t.Errorf("expected type 'number', got %s", stmt.Fields[1].TypeHint.Type) + if stmt.Fields[1].TypeHint.Type != parser.TypeNumber { + t.Errorf("expected type number, got %v", stmt.Fields[1].TypeHint.Type) } } @@ -107,7 +107,7 @@ func TestComplexStructDefinition(t *testing.T) { t.Fatalf("expected StructStatement, got %T", program.Statements[0]) } - expectedTypes := []string{"number", "string", "bool", "table", "function", "any"} + expectedTypes := []parser.Type{parser.TypeNumber, parser.TypeString, parser.TypeBool, parser.TypeTable, parser.TypeFunction, parser.TypeAny} expectedNames := []string{"id", "name", "active", "data", "callback", "optional"} if len(stmt.Fields) != len(expectedTypes) { @@ -118,11 +118,11 @@ func TestComplexStructDefinition(t *testing.T) { if field.Name != expectedNames[i] { t.Errorf("field %d: expected name '%s', got '%s'", i, expectedNames[i], field.Name) } - if field.TypeHint == nil { + if field.TypeHint.Type == parser.TypeUnknown { t.Fatalf("field %d: expected type hint", i) } if field.TypeHint.Type != expectedTypes[i] { - t.Errorf("field %d: expected type '%s', got '%s'", i, expectedTypes[i], field.TypeHint.Type) + t.Errorf("field %d: expected type %v, got %v", i, expectedTypes[i], field.TypeHint.Type) } } } @@ -164,17 +164,17 @@ end` if !ok { t.Fatalf("expected MethodDefinition, got %T", program.Statements[1]) } - if method1.StructName != "Person" { - t.Errorf("expected struct name 'Person', got %s", method1.StructName) + if method1.StructID != structStmt.ID { + t.Errorf("expected struct ID %d, got %d", structStmt.ID, method1.StructID) } if method1.MethodName != "getName" { t.Errorf("expected method name 'getName', got %s", method1.MethodName) } - if method1.Function.ReturnType == nil { + if method1.Function.ReturnType.Type == parser.TypeUnknown { t.Fatal("expected return type for getName method") } - if method1.Function.ReturnType.Type != "string" { - t.Errorf("expected return type 'string', got %s", method1.Function.ReturnType.Type) + if method1.Function.ReturnType.Type != parser.TypeString { + t.Errorf("expected return type string, got %v", method1.Function.ReturnType.Type) } if len(method1.Function.Parameters) != 0 { t.Errorf("expected 0 parameters, got %d", len(method1.Function.Parameters)) @@ -185,14 +185,14 @@ end` if !ok { t.Fatalf("expected MethodDefinition, got %T", program.Statements[2]) } - if method2.StructName != "Person" { - t.Errorf("expected struct name 'Person', got %s", method2.StructName) + if method2.StructID != structStmt.ID { + t.Errorf("expected struct ID %d, got %d", structStmt.ID, method2.StructID) } if method2.MethodName != "setAge" { t.Errorf("expected method name 'setAge', got %s", method2.MethodName) } - if method2.Function.ReturnType != nil { - t.Errorf("expected no return type for setAge method, got %s", method2.Function.ReturnType.Type) + if method2.Function.ReturnType.Type != parser.TypeUnknown { + t.Errorf("expected no return type for setAge method, got %v", method2.Function.ReturnType.Type) } if len(method2.Function.Parameters) != 1 { t.Fatalf("expected 1 parameter, got %d", len(method2.Function.Parameters)) @@ -200,11 +200,11 @@ end` if method2.Function.Parameters[0].Name != "newAge" { t.Errorf("expected parameter name 'newAge', got %s", method2.Function.Parameters[0].Name) } - if method2.Function.Parameters[0].TypeHint == nil { + if method2.Function.Parameters[0].TypeHint.Type == parser.TypeUnknown { t.Fatal("expected type hint for newAge parameter") } - if method2.Function.Parameters[0].TypeHint.Type != "number" { - t.Errorf("expected parameter type 'number', got %s", method2.Function.Parameters[0].TypeHint.Type) + if method2.Function.Parameters[0].TypeHint.Type != parser.TypeNumber { + t.Errorf("expected parameter type number, got %v", method2.Function.Parameters[0].TypeHint.Type) } } @@ -226,19 +226,21 @@ empty = Person{}` t.Fatalf("expected 3 statements, got %d", len(program.Statements)) } + structStmt := program.Statements[0].(*parser.StructStatement) + // Second statement: constructor with fields - assign1, ok := program.Statements[1].(*parser.AssignStatement) + assign1, ok := program.Statements[1].(*parser.Assignment) if !ok { - t.Fatalf("expected AssignStatement, got %T", program.Statements[1]) + t.Fatalf("expected Assignment, got %T", program.Statements[1]) } - constructor1, ok := assign1.Value.(*parser.StructConstructorExpression) + constructor1, ok := assign1.Value.(*parser.StructConstructor) if !ok { - t.Fatalf("expected StructConstructorExpression, got %T", assign1.Value) + t.Fatalf("expected StructConstructor, got %T", assign1.Value) } - if constructor1.StructName != "Person" { - t.Errorf("expected struct name 'Person', got %s", constructor1.StructName) + if constructor1.StructID != structStmt.ID { + t.Errorf("expected struct ID %d, got %d", structStmt.ID, constructor1.StructID) } if len(constructor1.Fields) != 2 { @@ -266,18 +268,18 @@ empty = Person{}` testNumberLiteral(t, constructor1.Fields[1].Value, 30) // Third statement: empty constructor - assign2, ok := program.Statements[2].(*parser.AssignStatement) + assign2, ok := program.Statements[2].(*parser.Assignment) if !ok { - t.Fatalf("expected AssignStatement, got %T", program.Statements[2]) + t.Fatalf("expected Assignment, got %T", program.Statements[2]) } - constructor2, ok := assign2.Value.(*parser.StructConstructorExpression) + constructor2, ok := assign2.Value.(*parser.StructConstructor) if !ok { - t.Fatalf("expected StructConstructorExpression, got %T", assign2.Value) + t.Fatalf("expected StructConstructor, got %T", assign2.Value) } - if constructor2.StructName != "Person" { - t.Errorf("expected struct name 'Person', got %s", constructor2.StructName) + if constructor2.StructID != structStmt.ID { + t.Errorf("expected struct ID %d, got %d", structStmt.ID, constructor2.StructID) } if len(constructor2.Fields) != 0 { @@ -310,6 +312,8 @@ person = Person{ t.Fatalf("expected 3 statements, got %d", len(program.Statements)) } + addressStruct := program.Statements[0].(*parser.StructStatement) + // Check Person struct has Address field type personStruct, ok := program.Statements[1].(*parser.StructStatement) if !ok { @@ -320,29 +324,32 @@ person = Person{ if addressField.Name != "address" { t.Errorf("expected field name 'address', got %s", addressField.Name) } - if addressField.TypeHint.Type != "Address" { - t.Errorf("expected field type 'Address', got %s", addressField.TypeHint.Type) + if addressField.TypeHint.Type != parser.TypeStruct { + t.Errorf("expected field type struct, got %v", addressField.TypeHint.Type) + } + if addressField.TypeHint.StructID != addressStruct.ID { + t.Errorf("expected struct ID %d, got %d", addressStruct.ID, addressField.TypeHint.StructID) } // Check nested constructor - assign, ok := program.Statements[2].(*parser.AssignStatement) + assign, ok := program.Statements[2].(*parser.Assignment) if !ok { - t.Fatalf("expected AssignStatement, got %T", program.Statements[2]) + t.Fatalf("expected Assignment, got %T", program.Statements[2]) } - personConstructor, ok := assign.Value.(*parser.StructConstructorExpression) + personConstructor, ok := assign.Value.(*parser.StructConstructor) if !ok { - t.Fatalf("expected StructConstructorExpression, got %T", assign.Value) + t.Fatalf("expected StructConstructor, got %T", assign.Value) } // Check the nested Address constructor - addressConstructor, ok := personConstructor.Fields[1].Value.(*parser.StructConstructorExpression) + addressConstructor, ok := personConstructor.Fields[1].Value.(*parser.StructConstructor) if !ok { - t.Fatalf("expected nested StructConstructorExpression, got %T", personConstructor.Fields[1].Value) + t.Fatalf("expected nested StructConstructor, got %T", personConstructor.Fields[1].Value) } - if addressConstructor.StructName != "Address" { - t.Errorf("expected nested struct name 'Address', got %s", addressConstructor.StructName) + if addressConstructor.StructID != addressStruct.ID { + t.Errorf("expected nested struct ID %d, got %d", addressStruct.ID, addressConstructor.StructID) } if len(addressConstructor.Fields) != 2 { @@ -397,8 +404,8 @@ end` if !ok { t.Fatalf("expected MethodDefinition, got %T", program.Statements[1]) } - if methodStmt.StructName != "Point" { - t.Errorf("expected struct name 'Point', got %s", methodStmt.StructName) + if methodStmt.StructID != structStmt.ID { + t.Errorf("expected struct ID %d, got %d", structStmt.ID, methodStmt.StructID) } if methodStmt.MethodName != "distance" { t.Errorf("expected method name 'distance', got %s", methodStmt.MethodName) @@ -406,16 +413,16 @@ end` // Verify constructors for i := 2; i <= 3; i++ { - assign, ok := program.Statements[i].(*parser.AssignStatement) + assign, ok := program.Statements[i].(*parser.Assignment) if !ok { - t.Fatalf("statement %d: expected AssignStatement, got %T", i, program.Statements[i]) + t.Fatalf("statement %d: expected Assignment, got %T", i, program.Statements[i]) } - constructor, ok := assign.Value.(*parser.StructConstructorExpression) + constructor, ok := assign.Value.(*parser.StructConstructor) if !ok { - t.Fatalf("statement %d: expected StructConstructorExpression, got %T", i, assign.Value) + t.Fatalf("statement %d: expected StructConstructor, got %T", i, assign.Value) } - if constructor.StructName != "Point" { - t.Errorf("statement %d: expected struct name 'Point', got %s", i, constructor.StructName) + if constructor.StructID != structStmt.ID { + t.Errorf("statement %d: expected struct ID %d, got %d", i, structStmt.ID, constructor.StructID) } } @@ -446,16 +453,16 @@ end` } // Check struct constructor in loop - loopAssign, ok := forStmt.Body[0].(*parser.AssignStatement) + loopAssign, ok := forStmt.Body[0].(*parser.Assignment) if !ok { - t.Fatalf("expected AssignStatement in loop, got %T", forStmt.Body[0]) + t.Fatalf("expected Assignment in loop, got %T", forStmt.Body[0]) } - loopConstructor, ok := loopAssign.Value.(*parser.StructConstructorExpression) + loopConstructor, ok := loopAssign.Value.(*parser.StructConstructor) if !ok { - t.Fatalf("expected StructConstructorExpression in loop, got %T", loopAssign.Value) + t.Fatalf("expected StructConstructor in loop, got %T", loopAssign.Value) } - if loopConstructor.StructName != "Point" { - t.Errorf("expected struct name 'Point' in loop, got %s", loopConstructor.StructName) + if loopConstructor.StructID != structStmt.ID { + t.Errorf("expected struct ID %d in loop, got %d", structStmt.ID, loopConstructor.StructID) } } @@ -552,13 +559,13 @@ func TestSingleLineStruct(t *testing.T) { t.Fatalf("expected 2 fields, got %d", len(stmt.Fields)) } - if stmt.Fields[0].Name != "name" || stmt.Fields[0].TypeHint.Type != "string" { - t.Errorf("expected first field 'name: string', got '%s: %s'", + if stmt.Fields[0].Name != "name" || stmt.Fields[0].TypeHint.Type != parser.TypeString { + t.Errorf("expected first field 'name: string', got '%s: %v'", stmt.Fields[0].Name, stmt.Fields[0].TypeHint.Type) } - if stmt.Fields[1].Name != "age" || stmt.Fields[1].TypeHint.Type != "number" { - t.Errorf("expected second field 'age: number', got '%s: %s'", + if stmt.Fields[1].Name != "age" || stmt.Fields[1].TypeHint.Type != parser.TypeNumber { + t.Errorf("expected second field 'age: number', got '%s: %v'", stmt.Fields[1].Name, stmt.Fields[1].TypeHint.Type) } } @@ -600,8 +607,8 @@ end` method := program.Statements[1].(*parser.MethodDefinition) str := method.String() - if !containsSubstring(str, "fn Person.getName") { - t.Errorf("expected method string to contain 'fn Person.getName', got: %s", str) + if !containsSubstring(str, "fn .getName") { + t.Errorf("expected method string to contain 'fn .getName', got: %s", str) } if !containsSubstring(str, ": string") { t.Errorf("expected method string to contain return type, got: %s", str) @@ -621,11 +628,11 @@ person = Person{name = "John", age = 30}` program := p.ParseProgram() checkParserErrors(t, p) - assign := program.Statements[1].(*parser.AssignStatement) - constructor := assign.Value.(*parser.StructConstructorExpression) + assign := program.Statements[1].(*parser.Assignment) + constructor := assign.Value.(*parser.StructConstructor) str := constructor.String() - expected := `Person{name = "John", age = 30.00}` + expected := `{name = "John", age = 30.00}` if str != expected { t.Errorf("expected constructor string:\n%s\ngot:\n%s", expected, str) } diff --git a/parser/tests/types_test.go b/parser/tests/types_test.go index d949450..f10b4fe 100644 --- a/parser/tests/types_test.go +++ b/parser/tests/types_test.go @@ -10,18 +10,18 @@ func TestVariableTypeHints(t *testing.T) { tests := []struct { input string variable string - typeHint string + typeHint parser.Type hasHint bool desc string }{ - {"x = 42", "x", "", false, "no type hint"}, - {"x: number = 42", "x", "number", true, "number type hint"}, - {"name: string = \"hello\"", "name", "string", true, "string type hint"}, - {"flag: bool = true", "flag", "bool", true, "bool type hint"}, - {"data: table = {}", "data", "table", true, "table type hint"}, - {"fn_var: function = fn() end", "fn_var", "function", true, "function type hint"}, - {"value: any = nil", "value", "any", true, "any type hint"}, - {"ptr: nil = nil", "ptr", "nil", true, "nil type hint"}, + {"x = 42", "x", parser.TypeUnknown, false, "no type hint"}, + {"x: number = 42", "x", parser.TypeNumber, true, "number type hint"}, + {"name: string = \"hello\"", "name", parser.TypeString, true, "string type hint"}, + {"flag: bool = true", "flag", parser.TypeBool, true, "bool type hint"}, + {"data: table = {}", "data", parser.TypeTable, true, "table type hint"}, + {"fn_var: function = fn() end", "fn_var", parser.TypeFunction, true, "function type hint"}, + {"value: any = nil", "value", parser.TypeAny, true, "any type hint"}, + {"ptr: nil = nil", "ptr", parser.TypeNil, true, "nil type hint"}, } for _, tt := range tests { @@ -35,15 +35,15 @@ func TestVariableTypeHints(t *testing.T) { t.Fatalf("expected 1 statement, got %d", len(program.Statements)) } - stmt, ok := program.Statements[0].(*parser.AssignStatement) + stmt, ok := program.Statements[0].(*parser.Assignment) if !ok { - t.Fatalf("expected AssignStatement, got %T", program.Statements[0]) + t.Fatalf("expected Assignment, got %T", program.Statements[0]) } // Check variable name - ident, ok := stmt.Name.(*parser.Identifier) + ident, ok := stmt.Target.(*parser.Identifier) if !ok { - t.Fatalf("expected Identifier for Name, got %T", stmt.Name) + t.Fatalf("expected Identifier for Target, got %T", stmt.Target) } if ident.Value != tt.variable { @@ -52,19 +52,19 @@ func TestVariableTypeHints(t *testing.T) { // Check type hint if tt.hasHint { - if stmt.TypeHint == nil { - t.Error("expected type hint but got nil") + if stmt.TypeHint.Type == parser.TypeUnknown { + t.Error("expected type hint but got TypeUnknown") } else { if stmt.TypeHint.Type != tt.typeHint { - t.Errorf("expected type hint %s, got %s", tt.typeHint, stmt.TypeHint.Type) + t.Errorf("expected type hint %v, got %v", tt.typeHint, stmt.TypeHint.Type) } if stmt.TypeHint.Inferred { t.Error("expected type hint to not be inferred") } } } else { - if stmt.TypeHint != nil { - t.Errorf("expected no type hint but got %s", stmt.TypeHint.Type) + if stmt.TypeHint.Type != parser.TypeUnknown { + t.Errorf("expected no type hint but got %v", stmt.TypeHint.Type) } } }) @@ -73,61 +73,82 @@ func TestVariableTypeHints(t *testing.T) { func TestFunctionParameterTypeHints(t *testing.T) { tests := []struct { - input string - params []struct{ name, typeHint string } - returnType string + input string + params []struct { + name string + typeHint parser.Type + } + returnType parser.Type hasReturn bool desc string }{ { "fn(a, b) end", - []struct{ name, typeHint string }{ - {"a", ""}, - {"b", ""}, + []struct { + name string + typeHint parser.Type + }{ + {"a", parser.TypeUnknown}, + {"b", parser.TypeUnknown}, }, - "", false, + parser.TypeUnknown, false, "no type hints", }, { "fn(a: number, b: string) end", - []struct{ name, typeHint string }{ - {"a", "number"}, - {"b", "string"}, + []struct { + name string + typeHint parser.Type + }{ + {"a", parser.TypeNumber}, + {"b", parser.TypeString}, }, - "", false, + parser.TypeUnknown, false, "parameter type hints only", }, { "fn(x: number): string end", - []struct{ name, typeHint string }{ - {"x", "number"}, + []struct { + name string + typeHint parser.Type + }{ + {"x", parser.TypeNumber}, }, - "string", true, + parser.TypeString, true, "parameter and return type hints", }, { "fn(): bool end", - []struct{ name, typeHint string }{}, - "bool", true, + []struct { + name string + typeHint parser.Type + }{}, + parser.TypeBool, true, "return type hint only", }, { "fn(a: number, b, c: string): table end", - []struct{ name, typeHint string }{ - {"a", "number"}, - {"b", ""}, - {"c", "string"}, + []struct { + name string + typeHint parser.Type + }{ + {"a", parser.TypeNumber}, + {"b", parser.TypeUnknown}, + {"c", parser.TypeString}, }, - "table", true, + parser.TypeTable, true, "mixed parameter types with return", }, { "fn(callback: function, data: any): nil end", - []struct{ name, typeHint string }{ - {"callback", "function"}, - {"data", "any"}, + []struct { + name string + typeHint parser.Type + }{ + {"callback", parser.TypeFunction}, + {"data", parser.TypeAny}, }, - "nil", true, + parser.TypeNil, true, "function and any types", }, } @@ -155,29 +176,29 @@ func TestFunctionParameterTypeHints(t *testing.T) { t.Errorf("parameter %d: expected name %s, got %s", i, expected.name, param.Name) } - if expected.typeHint == "" { - if param.TypeHint != nil { - t.Errorf("parameter %d: expected no type hint but got %s", i, param.TypeHint.Type) + if expected.typeHint == parser.TypeUnknown { + if param.TypeHint.Type != parser.TypeUnknown { + t.Errorf("parameter %d: expected no type hint but got %v", i, param.TypeHint.Type) } } else { - if param.TypeHint == nil { - t.Errorf("parameter %d: expected type hint %s but got nil", i, expected.typeHint) + if param.TypeHint.Type == parser.TypeUnknown { + t.Errorf("parameter %d: expected type hint %v but got TypeUnknown", i, expected.typeHint) } else if param.TypeHint.Type != expected.typeHint { - t.Errorf("parameter %d: expected type hint %s, got %s", i, expected.typeHint, param.TypeHint.Type) + t.Errorf("parameter %d: expected type hint %v, got %v", i, expected.typeHint, param.TypeHint.Type) } } } // Check return type if tt.hasReturn { - if fn.ReturnType == nil { - t.Error("expected return type hint but got nil") + if fn.ReturnType.Type == parser.TypeUnknown { + t.Error("expected return type hint but got TypeUnknown") } else if fn.ReturnType.Type != tt.returnType { - t.Errorf("expected return type %s, got %s", tt.returnType, fn.ReturnType.Type) + t.Errorf("expected return type %v, got %v", tt.returnType, fn.ReturnType.Type) } } else { - if fn.ReturnType != nil { - t.Errorf("expected no return type but got %s", fn.ReturnType.Type) + if fn.ReturnType.Type != parser.TypeUnknown { + t.Errorf("expected no return type but got %v", fn.ReturnType.Type) } } }) @@ -279,13 +300,13 @@ func TestMemberAccessWithoutTypeHints(t *testing.T) { t.Fatalf("expected 1 statement, got %d", len(program.Statements)) } - stmt, ok := program.Statements[0].(*parser.AssignStatement) + stmt, ok := program.Statements[0].(*parser.Assignment) if !ok { - t.Fatalf("expected AssignStatement, got %T", program.Statements[0]) + t.Fatalf("expected Assignment, got %T", program.Statements[0]) } // Member access should never have type hints - if stmt.TypeHint != nil { + if stmt.TypeHint.Type != parser.TypeUnknown { t.Error("member access assignment should not have type hints") } @@ -333,12 +354,12 @@ func TestTypeInferenceErrors(t *testing.T) { }{ { "x: number = \"hello\"", - "cannot assign string to variable of type number", + "type mismatch in assignment", "type mismatch in assignment", }, { "x = 42\ny: string = x", - "cannot assign number to variable of type string", + "type mismatch in assignment", "type mismatch with inferred type", }, } @@ -359,7 +380,7 @@ func TestTypeInferenceErrors(t *testing.T) { found := false for _, err := range typeErrors { - if err.Message == tt.expectedError { + if containsSubstring(err.Message, tt.expectedError) { found = true break } @@ -370,7 +391,7 @@ func TestTypeInferenceErrors(t *testing.T) { for i, err := range typeErrors { errorMsgs[i] = err.Message } - t.Errorf("expected error %q, got %v", tt.expectedError, errorMsgs) + t.Errorf("expected error containing %q, got %v", tt.expectedError, errorMsgs) } }) } @@ -439,22 +460,22 @@ server: table = { } // Check first statement: config table with typed assignments - configStmt, ok := program.Statements[0].(*parser.AssignStatement) + configStmt, ok := program.Statements[0].(*parser.Assignment) if !ok { - t.Fatalf("expected AssignStatement, got %T", program.Statements[0]) + t.Fatalf("expected Assignment, got %T", program.Statements[0]) } - if configStmt.TypeHint == nil || configStmt.TypeHint.Type != "table" { + if configStmt.TypeHint.Type == parser.TypeUnknown || configStmt.TypeHint.Type != parser.TypeTable { t.Error("expected table type hint for config") } // Check second statement: handler function with typed parameters - handlerStmt, ok := program.Statements[1].(*parser.AssignStatement) + handlerStmt, ok := program.Statements[1].(*parser.Assignment) if !ok { - t.Fatalf("expected AssignStatement, got %T", program.Statements[1]) + t.Fatalf("expected Assignment, got %T", program.Statements[1]) } - if handlerStmt.TypeHint == nil || handlerStmt.TypeHint.Type != "function" { + if handlerStmt.TypeHint.Type == parser.TypeUnknown || handlerStmt.TypeHint.Type != parser.TypeFunction { t.Error("expected function type hint for handler") } @@ -468,34 +489,32 @@ server: table = { } // Check parameter types - if fn.Parameters[0].TypeHint == nil || fn.Parameters[0].TypeHint.Type != "table" { + if fn.Parameters[0].TypeHint.Type == parser.TypeUnknown || fn.Parameters[0].TypeHint.Type != parser.TypeTable { t.Error("expected table type for request parameter") } - if fn.Parameters[1].TypeHint == nil || fn.Parameters[1].TypeHint.Type != "function" { + if fn.Parameters[1].TypeHint.Type == parser.TypeUnknown || fn.Parameters[1].TypeHint.Type != parser.TypeFunction { t.Error("expected function type for callback parameter") } // Check return type - if fn.ReturnType == nil || fn.ReturnType.Type != "nil" { + if fn.ReturnType.Type == parser.TypeUnknown || fn.ReturnType.Type != parser.TypeNil { t.Error("expected nil return type for handler") } // Check third statement: server table - serverStmt, ok := program.Statements[2].(*parser.AssignStatement) + serverStmt, ok := program.Statements[2].(*parser.Assignment) if !ok { - t.Fatalf("expected AssignStatement, got %T", program.Statements[2]) + t.Fatalf("expected Assignment, got %T", program.Statements[2]) } - if serverStmt.TypeHint == nil || serverStmt.TypeHint.Type != "table" { + if serverStmt.TypeHint.Type == parser.TypeUnknown || serverStmt.TypeHint.Type != parser.TypeTable { t.Error("expected table type hint for server") } } -func TestTypeInfoGettersSetters(t *testing.T) { - // Test that all expression types properly implement GetType/SetType - typeInfo := &parser.TypeInfo{Type: "test", Inferred: true} - +func TestTypeInfoInterface(t *testing.T) { + // Test that all expression types properly implement TypeInfo() expressions := []parser.Expression{ &parser.Identifier{Value: "x"}, &parser.NumberLiteral{Value: 42}, @@ -513,20 +532,43 @@ func TestTypeInfoGettersSetters(t *testing.T) { for i, expr := range expressions { t.Run(string(rune('0'+i)), func(t *testing.T) { - // Initially should have no type - if expr.GetType() != nil { - t.Error("expected nil type initially") - } + // Should have default type initially + typeInfo := expr.TypeInfo() - // Set type - expr.SetType(typeInfo) - - // Get type should return what we set - retrieved := expr.GetType() - if retrieved == nil { - t.Error("expected non-nil type after setting") - } else if retrieved.Type != "test" || !retrieved.Inferred { - t.Errorf("expected {Type: test, Inferred: true}, got %+v", retrieved) + // Basic literals should have their expected types + switch e := expr.(type) { + case *parser.NumberLiteral: + if typeInfo.Type != parser.TypeNumber { + t.Errorf("expected number type, got %v", typeInfo.Type) + } + case *parser.StringLiteral: + if typeInfo.Type != parser.TypeString { + t.Errorf("expected string type, got %v", typeInfo.Type) + } + case *parser.BooleanLiteral: + if typeInfo.Type != parser.TypeBool { + t.Errorf("expected bool type, got %v", typeInfo.Type) + } + case *parser.NilLiteral: + if typeInfo.Type != parser.TypeNil { + t.Errorf("expected nil type, got %v", typeInfo.Type) + } + case *parser.TableLiteral: + if typeInfo.Type != parser.TypeTable { + t.Errorf("expected table type, got %v", typeInfo.Type) + } + case *parser.FunctionLiteral: + if typeInfo.Type != parser.TypeFunction { + t.Errorf("expected function type, got %v", typeInfo.Type) + } + case *parser.Identifier: + // Identifiers default to any type + if typeInfo.Type != parser.TypeAny { + t.Errorf("expected any type for untyped identifier, got %v", typeInfo.Type) + } + default: + // Other expressions may have unknown type initially + _ = e } }) } diff --git a/parser/types.go b/parser/types.go index 7bce16c..1dc4ea6 100644 --- a/parser/types.go +++ b/parser/types.go @@ -1,21 +1,45 @@ package parser -import ( - "fmt" -) +import "fmt" + +// Type represents built-in and user-defined types using compact enum representation. +// Uses single byte instead of string pointers to minimize memory usage. +type Type uint8 -// Type constants for built-in types const ( - TypeNumber = "number" - TypeString = "string" - TypeBool = "bool" - TypeNil = "nil" - TypeTable = "table" - TypeFunction = "function" - TypeAny = "any" + TypeUnknown Type = iota + TypeNumber + TypeString + TypeBool + TypeNil + TypeTable + TypeFunction + TypeAny + TypeStruct // struct types use StructID field for identification ) -// TypeError represents a type checking error +// TypeInfo represents type information with zero-allocation design. +// Embeds directly in AST nodes instead of using pointers to reduce heap pressure. +// Common types are pre-allocated as globals to eliminate most allocations. +type TypeInfo struct { + Type Type // Built-in type or TypeStruct for user types + Inferred bool // True if type was inferred, false if explicitly declared + StructID uint16 // Index into global struct table for struct types (0 for non-structs) +} + +// Pre-allocated common types - eliminates heap allocations for built-in types +var ( + UnknownType = TypeInfo{Type: TypeUnknown, Inferred: true} + NumberType = TypeInfo{Type: TypeNumber, Inferred: true} + StringType = TypeInfo{Type: TypeString, Inferred: true} + BoolType = TypeInfo{Type: TypeBool, Inferred: true} + NilType = TypeInfo{Type: TypeNil, Inferred: true} + TableType = TypeInfo{Type: TypeTable, Inferred: true} + FunctionType = TypeInfo{Type: TypeFunction, Inferred: true} + AnyType = TypeInfo{Type: TypeAny, Inferred: true} +) + +// TypeError represents a type checking error with location information type TypeError struct { Message string Line int @@ -27,10 +51,10 @@ func (te TypeError) Error() string { return fmt.Sprintf("Type error at line %d, column %d: %s", te.Line, te.Column, te.Message) } -// Symbol represents a variable in the symbol table +// Symbol represents a variable in the symbol table with optimized type storage type Symbol struct { Name string - Type *TypeInfo + Type TypeInfo // Embed directly instead of pointer Declared bool Line int Column int @@ -63,52 +87,56 @@ func (s *Scope) Lookup(name string) *Symbol { return nil } -// TypeInferrer performs type inference and checking +// TypeInferrer performs type inference and checking with optimized allocations type TypeInferrer struct { currentScope *Scope globalScope *Scope errors []TypeError - // Pre-allocated type objects for performance - numberType *TypeInfo - stringType *TypeInfo - boolType *TypeInfo - nilType *TypeInfo - tableType *TypeInfo - anyType *TypeInfo - - // Struct definitions - structs map[string]*StructStatement + // Struct definitions with ID mapping + structs map[string]*StructStatement + structIDs map[uint16]*StructStatement + nextID uint16 } // NewTypeInferrer creates a new type inference engine func NewTypeInferrer() *TypeInferrer { globalScope := NewScope(nil) - ti := &TypeInferrer{ + return &TypeInferrer{ currentScope: globalScope, globalScope: globalScope, errors: []TypeError{}, structs: make(map[string]*StructStatement), - - // Pre-allocate common types to reduce allocations - numberType: &TypeInfo{Type: TypeNumber, Inferred: true}, - stringType: &TypeInfo{Type: TypeString, Inferred: true}, - boolType: &TypeInfo{Type: TypeBool, Inferred: true}, - nilType: &TypeInfo{Type: TypeNil, Inferred: true}, - tableType: &TypeInfo{Type: TypeTable, Inferred: true}, - anyType: &TypeInfo{Type: TypeAny, Inferred: true}, + structIDs: make(map[uint16]*StructStatement), + nextID: 1, // 0 reserved for non-struct types } +} - return ti +// RegisterStruct assigns ID to struct and tracks it +func (ti *TypeInferrer) RegisterStruct(stmt *StructStatement) { + stmt.ID = ti.nextID + ti.nextID++ + ti.structs[stmt.Name] = stmt + ti.structIDs[stmt.ID] = stmt +} + +// GetStructByID returns struct definition by ID +func (ti *TypeInferrer) GetStructByID(id uint16) *StructStatement { + return ti.structIDs[id] +} + +// CreateStructType creates TypeInfo for a struct +func (ti *TypeInferrer) CreateStructType(structID uint16) TypeInfo { + return TypeInfo{Type: TypeStruct, StructID: structID, Inferred: true} } // InferTypes performs type inference on the entire program func (ti *TypeInferrer) InferTypes(program *Program) []TypeError { - // First pass: collect struct definitions + // First pass: collect and register struct definitions for _, stmt := range program.Statements { if structStmt, ok := stmt.(*StructStatement); ok { - ti.structs[structStmt.Name] = structStmt + ti.RegisterStruct(structStmt) } } @@ -119,42 +147,6 @@ func (ti *TypeInferrer) InferTypes(program *Program) []TypeError { return ti.errors } -// enterScope creates a new scope -func (ti *TypeInferrer) enterScope() { - ti.currentScope = NewScope(ti.currentScope) -} - -// exitScope returns to the parent scope -func (ti *TypeInferrer) exitScope() { - if ti.currentScope.parent != nil { - ti.currentScope = ti.currentScope.parent - } -} - -// addError adds a type error -func (ti *TypeInferrer) addError(message string, node Node) { - ti.errors = append(ti.errors, TypeError{ - Message: message, - Line: 0, // Would need to track position in AST nodes - Column: 0, - Node: node, - }) -} - -// getStructType returns TypeInfo for a struct -func (ti *TypeInferrer) getStructType(name string) *TypeInfo { - if _, exists := ti.structs[name]; exists { - return &TypeInfo{Type: name, Inferred: true} - } - return nil -} - -// isStructType checks if a type is a struct type -func (ti *TypeInferrer) isStructType(t *TypeInfo) bool { - _, exists := ti.structs[t.Type] - return exists -} - // inferStatement infers types for statements func (ti *TypeInferrer) inferStatement(stmt Statement) { switch s := stmt.(type) { @@ -162,8 +154,10 @@ func (ti *TypeInferrer) inferStatement(stmt Statement) { ti.inferStructStatement(s) case *MethodDefinition: ti.inferMethodDefinition(s) - case *AssignStatement: - ti.inferAssignStatement(s) + case *Assignment: + ti.inferAssignment(s) + case *ExpressionStatement: + ti.inferExpression(s.Expression) case *EchoStatement: ti.inferExpression(s.Value) case *IfStatement: @@ -182,46 +176,38 @@ func (ti *TypeInferrer) inferStatement(stmt Statement) { if s.Value != nil { ti.inferExpression(s.Value) } - case *ExpressionStatement: - ti.inferExpression(s.Expression) + case *BreakStatement: + // No-op } } -// inferStructStatement handles struct definitions func (ti *TypeInferrer) inferStructStatement(stmt *StructStatement) { - // Validate field types for _, field := range stmt.Fields { - if field.TypeHint != nil { - if !ValidTypeName(field.TypeHint.Type) && !ti.isStructType(field.TypeHint) { - ti.addError(fmt.Sprintf("invalid field type '%s' in struct '%s'", - field.TypeHint.Type, stmt.Name), stmt) - } + if !ti.isValidType(field.TypeHint) { + ti.addError(fmt.Sprintf("invalid field type in struct '%s'", stmt.Name), stmt) } } } -// inferMethodDefinition handles method definitions func (ti *TypeInferrer) inferMethodDefinition(stmt *MethodDefinition) { - // Check if struct exists - if _, exists := ti.structs[stmt.StructName]; !exists { - ti.addError(fmt.Sprintf("method defined on undefined struct '%s'", stmt.StructName), stmt) + structDef := ti.GetStructByID(stmt.StructID) + if structDef == nil { + ti.addError("method defined on undefined struct", stmt) return } - // Infer the function body ti.enterScope() - - // Add self parameter implicitly + // Add self parameter ti.currentScope.Define(&Symbol{ Name: "self", - Type: ti.getStructType(stmt.StructName), + Type: ti.CreateStructType(stmt.StructID), Declared: true, }) - // Add explicit parameters + // Add function parameters for _, param := range stmt.Function.Parameters { - paramType := ti.anyType - if param.TypeHint != nil { + paramType := AnyType + if param.TypeHint.Type != TypeUnknown { paramType = param.TypeHint } ti.currentScope.Define(&Symbol{ @@ -235,66 +221,47 @@ func (ti *TypeInferrer) inferMethodDefinition(stmt *MethodDefinition) { for _, bodyStmt := range stmt.Function.Body { ti.inferStatement(bodyStmt) } - ti.exitScope() } -// inferAssignStatement handles variable assignments with type checking -func (ti *TypeInferrer) inferAssignStatement(stmt *AssignStatement) { - // Infer the type of the value expression +func (ti *TypeInferrer) inferAssignment(stmt *Assignment) { valueType := ti.inferExpression(stmt.Value) - if ident, ok := stmt.Name.(*Identifier); ok { - // Simple variable assignment - symbol := ti.currentScope.Lookup(ident.Value) - + if ident, ok := stmt.Target.(*Identifier); ok { if stmt.IsDeclaration { - // New variable declaration varType := valueType - - // If there's a type hint, validate it - if stmt.TypeHint != nil { + if stmt.TypeHint.Type != TypeUnknown { if !ti.isTypeCompatible(valueType, stmt.TypeHint) { - ti.addError(fmt.Sprintf("cannot assign %s to variable of type %s", - valueType.Type, stmt.TypeHint.Type), stmt) + ti.addError("type mismatch in assignment", stmt) } varType = stmt.TypeHint - varType.Inferred = false } - // Define the new symbol ti.currentScope.Define(&Symbol{ Name: ident.Value, Type: varType, Declared: true, }) - - ident.SetType(varType) + ident.typeInfo = varType } else { - // Assignment to existing variable + symbol := ti.currentScope.Lookup(ident.Value) if symbol == nil { ti.addError(fmt.Sprintf("undefined variable '%s'", ident.Value), stmt) return } - - // Check type compatibility if !ti.isTypeCompatible(valueType, symbol.Type) { - ti.addError(fmt.Sprintf("cannot assign %s to variable of type %s", - valueType.Type, symbol.Type.Type), stmt) + ti.addError("type mismatch in assignment", stmt) } - - ident.SetType(symbol.Type) + ident.typeInfo = symbol.Type } } else { // Member access assignment (table.key or table[index]) - ti.inferExpression(stmt.Name) + ti.inferExpression(stmt.Target) } } -// inferIfStatement handles if statements func (ti *TypeInferrer) inferIfStatement(stmt *IfStatement) { - condType := ti.inferExpression(stmt.Condition) - ti.validateBooleanContext(condType, stmt.Condition) + ti.inferExpression(stmt.Condition) ti.enterScope() for _, s := range stmt.Body { @@ -303,9 +270,7 @@ func (ti *TypeInferrer) inferIfStatement(stmt *IfStatement) { ti.exitScope() for _, elseif := range stmt.ElseIfs { - condType := ti.inferExpression(elseif.Condition) - ti.validateBooleanContext(condType, elseif.Condition) - + ti.inferExpression(elseif.Condition) ti.enterScope() for _, s := range elseif.Body { ti.inferStatement(s) @@ -322,10 +287,8 @@ func (ti *TypeInferrer) inferIfStatement(stmt *IfStatement) { } } -// inferWhileStatement handles while loops func (ti *TypeInferrer) inferWhileStatement(stmt *WhileStatement) { - condType := ti.inferExpression(stmt.Condition) - ti.validateBooleanContext(condType, stmt.Condition) + ti.inferExpression(stmt.Condition) ti.enterScope() for _, s := range stmt.Body { @@ -334,33 +297,21 @@ func (ti *TypeInferrer) inferWhileStatement(stmt *WhileStatement) { ti.exitScope() } -// inferForStatement handles numeric for loops func (ti *TypeInferrer) inferForStatement(stmt *ForStatement) { - startType := ti.inferExpression(stmt.Start) - endType := ti.inferExpression(stmt.End) - - if !ti.isNumericType(startType) { - ti.addError("for loop start value must be numeric", stmt.Start) - } - if !ti.isNumericType(endType) { - ti.addError("for loop end value must be numeric", stmt.End) - } - + ti.inferExpression(stmt.Start) + ti.inferExpression(stmt.End) if stmt.Step != nil { - stepType := ti.inferExpression(stmt.Step) - if !ti.isNumericType(stepType) { - ti.addError("for loop step value must be numeric", stmt.Step) - } + ti.inferExpression(stmt.Step) } ti.enterScope() // Define loop variable as number ti.currentScope.Define(&Symbol{ Name: stmt.Variable.Value, - Type: ti.numberType, + Type: NumberType, Declared: true, }) - stmt.Variable.SetType(ti.numberType) + stmt.Variable.typeInfo = NumberType for _, s := range stmt.Body { ti.inferStatement(s) @@ -368,33 +319,26 @@ func (ti *TypeInferrer) inferForStatement(stmt *ForStatement) { ti.exitScope() } -// inferForInStatement handles for-in loops func (ti *TypeInferrer) inferForInStatement(stmt *ForInStatement) { - iterableType := ti.inferExpression(stmt.Iterable) - - // For now, assume iterable is a table or struct - if !ti.isTableType(iterableType) && !ti.isStructType(iterableType) { - ti.addError("for-in requires an iterable (table or struct)", stmt.Iterable) - } + ti.inferExpression(stmt.Iterable) ti.enterScope() - - // Define loop variables (key and value are any for now) + // Define loop variables if stmt.Key != nil { ti.currentScope.Define(&Symbol{ Name: stmt.Key.Value, - Type: ti.anyType, + Type: AnyType, Declared: true, }) - stmt.Key.SetType(ti.anyType) + stmt.Key.typeInfo = AnyType } ti.currentScope.Define(&Symbol{ Name: stmt.Value.Value, - Type: ti.anyType, + Type: AnyType, Declared: true, }) - stmt.Value.SetType(ti.anyType) + stmt.Value.typeInfo = AnyType for _, s := range stmt.Body { ti.inferStatement(s) @@ -403,29 +347,25 @@ func (ti *TypeInferrer) inferForInStatement(stmt *ForInStatement) { } // inferExpression infers the type of an expression -func (ti *TypeInferrer) inferExpression(expr Expression) *TypeInfo { +func (ti *TypeInferrer) inferExpression(expr Expression) TypeInfo { if expr == nil { - return ti.nilType + return NilType } switch e := expr.(type) { case *Identifier: return ti.inferIdentifier(e) case *NumberLiteral: - e.SetType(ti.numberType) - return ti.numberType + return NumberType case *StringLiteral: - e.SetType(ti.stringType) - return ti.stringType + return StringType case *BooleanLiteral: - e.SetType(ti.boolType) - return ti.boolType + return BoolType case *NilLiteral: - e.SetType(ti.nilType) - return ti.nilType + return NilType case *TableLiteral: return ti.inferTableLiteral(e) - case *StructConstructorExpression: + case *StructConstructor: return ti.inferStructConstructor(e) case *FunctionLiteral: return ti.inferFunctionLiteral(e) @@ -439,20 +379,40 @@ func (ti *TypeInferrer) inferExpression(expr Expression) *TypeInfo { return ti.inferIndexExpression(e) case *DotExpression: return ti.inferDotExpression(e) - case *AssignExpression: - return ti.inferAssignExpression(e) + case *Assignment: + return ti.inferAssignmentExpression(e) default: ti.addError("unknown expression type", expr) - return ti.anyType + return AnyType } } -// inferStructConstructor handles struct constructor expressions -func (ti *TypeInferrer) inferStructConstructor(expr *StructConstructorExpression) *TypeInfo { - structDef, exists := ti.structs[expr.StructName] - if !exists { - ti.addError(fmt.Sprintf("undefined struct '%s'", expr.StructName), expr) - return ti.anyType +func (ti *TypeInferrer) inferIdentifier(ident *Identifier) TypeInfo { + symbol := ti.currentScope.Lookup(ident.Value) + if symbol == nil { + ti.addError(fmt.Sprintf("undefined variable '%s'", ident.Value), ident) + return AnyType + } + ident.typeInfo = symbol.Type + return symbol.Type +} + +func (ti *TypeInferrer) inferTableLiteral(table *TableLiteral) TypeInfo { + // Infer types of all values + for _, pair := range table.Pairs { + if pair.Key != nil { + ti.inferExpression(pair.Key) + } + ti.inferExpression(pair.Value) + } + return TableType +} + +func (ti *TypeInferrer) inferStructConstructor(expr *StructConstructor) TypeInfo { + structDef := ti.GetStructByID(expr.StructID) + if structDef == nil { + ti.addError("undefined struct in constructor", expr) + return AnyType } // Validate field assignments @@ -465,25 +425,23 @@ func (ti *TypeInferrer) inferStructConstructor(expr *StructConstructorExpression fieldName = str.Value } - // Check if field exists in struct - fieldExists := false - var fieldType *TypeInfo + // Find field in struct definition + var fieldType TypeInfo + found := false for _, field := range structDef.Fields { if field.Name == fieldName { - fieldExists = true fieldType = field.TypeHint + found = true break } } - if !fieldExists { - ti.addError(fmt.Sprintf("struct '%s' has no field '%s'", expr.StructName, fieldName), expr) + if !found { + ti.addError(fmt.Sprintf("struct has no field '%s'", fieldName), expr) } else { - // Check type compatibility valueType := ti.inferExpression(pair.Value) if !ti.isTypeCompatible(valueType, fieldType) { - ti.addError(fmt.Sprintf("cannot assign %s to field '%s' of type %s", - valueType.Type, fieldName, fieldType.Type), expr) + ti.addError("field type mismatch in struct constructor", expr) } } } else { @@ -492,66 +450,18 @@ func (ti *TypeInferrer) inferStructConstructor(expr *StructConstructorExpression } } - structType := ti.getStructType(expr.StructName) - expr.SetType(structType) + structType := ti.CreateStructType(expr.StructID) + expr.typeInfo = structType return structType } -// inferAssignExpression handles assignment expressions -func (ti *TypeInferrer) inferAssignExpression(expr *AssignExpression) *TypeInfo { - valueType := ti.inferExpression(expr.Value) - - if ident, ok := expr.Name.(*Identifier); ok { - if expr.IsDeclaration { - ti.currentScope.Define(&Symbol{ - Name: ident.Value, - Type: valueType, - Declared: true, - }) - } - ident.SetType(valueType) - } else { - ti.inferExpression(expr.Name) - } - - expr.SetType(valueType) - return valueType -} - -// inferIdentifier looks up identifier type in symbol table -func (ti *TypeInferrer) inferIdentifier(ident *Identifier) *TypeInfo { - symbol := ti.currentScope.Lookup(ident.Value) - if symbol == nil { - ti.addError(fmt.Sprintf("undefined variable '%s'", ident.Value), ident) - return ti.anyType - } - - ident.SetType(symbol.Type) - return symbol.Type -} - -// inferTableLiteral infers table type -func (ti *TypeInferrer) inferTableLiteral(table *TableLiteral) *TypeInfo { - // Infer types of all values - for _, pair := range table.Pairs { - if pair.Key != nil { - ti.inferExpression(pair.Key) - } - ti.inferExpression(pair.Value) - } - - table.SetType(ti.tableType) - return ti.tableType -} - -// inferFunctionLiteral infers function type -func (ti *TypeInferrer) inferFunctionLiteral(fn *FunctionLiteral) *TypeInfo { +func (ti *TypeInferrer) inferFunctionLiteral(fn *FunctionLiteral) TypeInfo { ti.enterScope() // Define parameters in function scope for _, param := range fn.Parameters { - paramType := ti.anyType - if param.TypeHint != nil { + paramType := AnyType + if param.TypeHint.Type != TypeUnknown { paramType = param.TypeHint } @@ -568,104 +478,88 @@ func (ti *TypeInferrer) inferFunctionLiteral(fn *FunctionLiteral) *TypeInfo { } ti.exitScope() - - // For now, all functions have type "function" - funcType := &TypeInfo{Type: TypeFunction, Inferred: true} - fn.SetType(funcType) - return funcType + return FunctionType } -// inferCallExpression infers function call return type -func (ti *TypeInferrer) inferCallExpression(call *CallExpression) *TypeInfo { - funcType := ti.inferExpression(call.Function) - - if !ti.isFunctionType(funcType) { - ti.addError("cannot call non-function", call.Function) - return ti.anyType - } +func (ti *TypeInferrer) inferCallExpression(call *CallExpression) TypeInfo { + ti.inferExpression(call.Function) // Infer argument types for _, arg := range call.Arguments { ti.inferExpression(arg) } - // For now, assume function calls return any - call.SetType(ti.anyType) - return ti.anyType + call.typeInfo = AnyType + return AnyType } -// inferPrefixExpression infers prefix operation type -func (ti *TypeInferrer) inferPrefixExpression(prefix *PrefixExpression) *TypeInfo { +func (ti *TypeInferrer) inferPrefixExpression(prefix *PrefixExpression) TypeInfo { rightType := ti.inferExpression(prefix.Right) - var resultType *TypeInfo + var resultType TypeInfo switch prefix.Operator { case "-": if !ti.isNumericType(rightType) { ti.addError("unary minus requires numeric operand", prefix) } - resultType = ti.numberType + resultType = NumberType case "not": - resultType = ti.boolType + resultType = BoolType default: ti.addError(fmt.Sprintf("unknown prefix operator '%s'", prefix.Operator), prefix) - resultType = ti.anyType + resultType = AnyType } - prefix.SetType(resultType) + prefix.typeInfo = resultType return resultType } -// inferInfixExpression infers binary operation type -func (ti *TypeInferrer) inferInfixExpression(infix *InfixExpression) *TypeInfo { +func (ti *TypeInferrer) inferInfixExpression(infix *InfixExpression) TypeInfo { leftType := ti.inferExpression(infix.Left) rightType := ti.inferExpression(infix.Right) - var resultType *TypeInfo + var resultType TypeInfo switch infix.Operator { case "+", "-", "*", "/": if !ti.isNumericType(leftType) || !ti.isNumericType(rightType) { ti.addError(fmt.Sprintf("arithmetic operator '%s' requires numeric operands", infix.Operator), infix) } - resultType = ti.numberType + resultType = NumberType case "==", "!=": // Equality works with any types - resultType = ti.boolType + resultType = BoolType case "<", ">", "<=", ">=": if !ti.isComparableTypes(leftType, rightType) { ti.addError(fmt.Sprintf("comparison operator '%s' requires compatible operands", infix.Operator), infix) } - resultType = ti.boolType + resultType = BoolType case "and", "or": - ti.validateBooleanContext(leftType, infix.Left) - ti.validateBooleanContext(rightType, infix.Right) - resultType = ti.boolType + resultType = BoolType default: ti.addError(fmt.Sprintf("unknown infix operator '%s'", infix.Operator), infix) - resultType = ti.anyType + resultType = AnyType } - infix.SetType(resultType) + infix.typeInfo = resultType return resultType } -// inferIndexExpression infers table[index] type -func (ti *TypeInferrer) inferIndexExpression(index *IndexExpression) *TypeInfo { +func (ti *TypeInferrer) inferIndexExpression(index *IndexExpression) TypeInfo { leftType := ti.inferExpression(index.Left) ti.inferExpression(index.Index) // If indexing a struct, try to infer field type if ti.isStructType(leftType) { if strLit, ok := index.Index.(*StringLiteral); ok { - if structDef, exists := ti.structs[leftType.Type]; exists { + if structDef := ti.GetStructByID(leftType.StructID); structDef != nil { for _, field := range structDef.Fields { if field.Name == strLit.Value { - index.SetType(field.TypeHint) + index.typeInfo = field.TypeHint return field.TypeHint } } @@ -674,20 +568,19 @@ func (ti *TypeInferrer) inferIndexExpression(index *IndexExpression) *TypeInfo { } // For now, assume table/struct access returns any - index.SetType(ti.anyType) - return ti.anyType + index.typeInfo = AnyType + return AnyType } -// inferDotExpression infers table.key type -func (ti *TypeInferrer) inferDotExpression(dot *DotExpression) *TypeInfo { +func (ti *TypeInferrer) inferDotExpression(dot *DotExpression) TypeInfo { leftType := ti.inferExpression(dot.Left) // If accessing a struct field, try to infer field type if ti.isStructType(leftType) { - if structDef, exists := ti.structs[leftType.Type]; exists { + if structDef := ti.GetStructByID(leftType.StructID); structDef != nil { for _, field := range structDef.Fields { if field.Name == dot.Key { - dot.SetType(field.TypeHint) + dot.typeInfo = field.TypeHint return field.TypeHint } } @@ -695,59 +588,107 @@ func (ti *TypeInferrer) inferDotExpression(dot *DotExpression) *TypeInfo { } // For now, assume member access returns any - dot.SetType(ti.anyType) - return ti.anyType + dot.typeInfo = AnyType + return AnyType } -// Type checking helper methods +func (ti *TypeInferrer) inferAssignmentExpression(expr *Assignment) TypeInfo { + valueType := ti.inferExpression(expr.Value) -func (ti *TypeInferrer) isTypeCompatible(valueType, targetType *TypeInfo) bool { + if ident, ok := expr.Target.(*Identifier); ok { + if expr.IsDeclaration { + varType := valueType + if expr.TypeHint.Type != TypeUnknown { + if !ti.isTypeCompatible(valueType, expr.TypeHint) { + ti.addError("type mismatch in assignment", expr) + } + varType = expr.TypeHint + } + + ti.currentScope.Define(&Symbol{ + Name: ident.Value, + Type: varType, + Declared: true, + }) + ident.typeInfo = varType + } else { + symbol := ti.currentScope.Lookup(ident.Value) + if symbol != nil { + ident.typeInfo = symbol.Type + } + } + } else { + ti.inferExpression(expr.Target) + } + + return valueType +} + +// Helper methods +func (ti *TypeInferrer) enterScope() { + ti.currentScope = NewScope(ti.currentScope) +} + +func (ti *TypeInferrer) exitScope() { + if ti.currentScope.parent != nil { + ti.currentScope = ti.currentScope.parent + } +} + +func (ti *TypeInferrer) addError(message string, node Node) { + ti.errors = append(ti.errors, TypeError{ + Message: message, + Node: node, + }) +} + +func (ti *TypeInferrer) isValidType(t TypeInfo) bool { + if t.Type == TypeStruct { + return ti.GetStructByID(t.StructID) != nil + } + return t.Type <= TypeStruct +} + +func (ti *TypeInferrer) isTypeCompatible(valueType, targetType TypeInfo) bool { if targetType.Type == TypeAny || valueType.Type == TypeAny { return true } + if valueType.Type == TypeStruct && targetType.Type == TypeStruct { + return valueType.StructID == targetType.StructID + } return valueType.Type == targetType.Type } -func (ti *TypeInferrer) isNumericType(t *TypeInfo) bool { +func (ti *TypeInferrer) isNumericType(t TypeInfo) bool { return t.Type == TypeNumber } -func (ti *TypeInferrer) isBooleanType(t *TypeInfo) bool { +func (ti *TypeInferrer) isBooleanType(t TypeInfo) bool { return t.Type == TypeBool } -func (ti *TypeInferrer) isTableType(t *TypeInfo) bool { +func (ti *TypeInferrer) isTableType(t TypeInfo) bool { return t.Type == TypeTable } -func (ti *TypeInferrer) isFunctionType(t *TypeInfo) bool { +func (ti *TypeInferrer) isFunctionType(t TypeInfo) bool { return t.Type == TypeFunction } -func (ti *TypeInferrer) isComparableTypes(left, right *TypeInfo) bool { +func (ti *TypeInferrer) isStructType(t TypeInfo) bool { + return t.Type == TypeStruct +} + +func (ti *TypeInferrer) isComparableTypes(left, right TypeInfo) bool { if left.Type == TypeAny || right.Type == TypeAny { return true } return left.Type == right.Type && (left.Type == TypeNumber || left.Type == TypeString) } -func (ti *TypeInferrer) validateBooleanContext(t *TypeInfo, expr Expression) { - // In many languages, non-boolean values can be used in boolean context - // For strictness, we could require boolean type here - // For now, allow any type (truthy/falsy semantics) -} - -// Errors returns all type checking errors -func (ti *TypeInferrer) Errors() []TypeError { - return ti.errors -} - -// HasErrors returns true if there are any type errors -func (ti *TypeInferrer) HasErrors() bool { - return len(ti.errors) > 0 -} - -// ErrorStrings returns error messages as strings +// Error reporting +func (ti *TypeInferrer) Errors() []TypeError { return ti.errors } +func (ti *TypeInferrer) HasErrors() bool { return len(ti.errors) > 0 } func (ti *TypeInferrer) ErrorStrings() []string { result := make([]string, len(ti.errors)) for i, err := range ti.errors { @@ -756,21 +697,26 @@ func (ti *TypeInferrer) ErrorStrings() []string { return result } -// ValidTypeName checks if a string is a valid type name -func ValidTypeName(name string) bool { - validTypes := []string{TypeNumber, TypeString, TypeBool, TypeNil, TypeTable, TypeFunction, TypeAny} - for _, validType := range validTypes { - if name == validType { - return true - } +// Type string conversion +func TypeToString(t TypeInfo) string { + switch t.Type { + case TypeNumber: + return "number" + case TypeString: + return "string" + case TypeBool: + return "bool" + case TypeNil: + return "nil" + case TypeTable: + return "table" + case TypeFunction: + return "function" + case TypeAny: + return "any" + case TypeStruct: + return fmt.Sprintf("struct<%d>", t.StructID) + default: + return "unknown" } - return false -} - -// ParseTypeName converts a string to a TypeInfo (for parsing type hints) -func ParseTypeName(name string) *TypeInfo { - if ValidTypeName(name) { - return &TypeInfo{Type: name, Inferred: false} - } - return nil }