diff --git a/parser/ast.go b/parser/ast.go index 4486b97..cb25448 100644 --- a/parser/ast.go +++ b/parser/ast.go @@ -2,6 +2,12 @@ package parser import "fmt" +// TypeInfo represents type information for expressions +type TypeInfo struct { + Type string // "number", "string", "bool", "table", "function", "nil", "any" + Inferred bool // true if type was inferred, false if explicitly declared +} + // Node represents any node in the AST type Node interface { String() string @@ -17,6 +23,8 @@ type Statement interface { type Expression interface { Node expressionNode() + GetType() *TypeInfo + SetType(*TypeInfo) } // Program represents the root of the AST @@ -33,9 +41,10 @@ func (p *Program) String() string { return result } -// AssignStatement represents variable assignment +// AssignStatement represents variable assignment with optional type hint type AssignStatement struct { Name Expression // Changed from *Identifier to Expression for member access + TypeHint *TypeInfo // optional type hint Value Expression IsDeclaration bool // true if this is the first assignment in current scope } @@ -46,7 +55,15 @@ func (as *AssignStatement) String() string { if as.IsDeclaration { prefix = "local " } - return fmt.Sprintf("%s%s = %s", prefix, as.Name.String(), as.Value.String()) + + var nameStr string + if as.TypeHint != nil { + nameStr = fmt.Sprintf("%s: %s", as.Name.String(), as.TypeHint.Type) + } else { + nameStr = as.Name.String() + } + + return fmt.Sprintf("%s%s = %s", prefix, nameStr, as.Value.String()) } // EchoStatement represents echo output statements @@ -216,33 +233,56 @@ func (fis *ForInStatement) String() string { return result } -// Identifier represents identifiers -type Identifier struct { - Value string +// FunctionParameter represents a function parameter with optional type hint +type FunctionParameter struct { + Name string + TypeHint *TypeInfo } -func (i *Identifier) expressionNode() {} -func (i *Identifier) String() string { return i.Value } +func (fp *FunctionParameter) String() string { + if fp.TypeHint != nil { + return fmt.Sprintf("%s: %s", fp.Name, fp.TypeHint.Type) + } + return fp.Name +} + +// Identifier represents identifiers +type Identifier struct { + Value string + typeInfo *TypeInfo +} + +func (i *Identifier) expressionNode() {} +func (i *Identifier) String() string { return i.Value } +func (i *Identifier) GetType() *TypeInfo { return i.typeInfo } +func (i *Identifier) SetType(t *TypeInfo) { i.typeInfo = t } // NumberLiteral represents numeric literals type NumberLiteral struct { - Value float64 + Value float64 + typeInfo *TypeInfo } -func (nl *NumberLiteral) expressionNode() {} -func (nl *NumberLiteral) String() string { return fmt.Sprintf("%.2f", nl.Value) } +func (nl *NumberLiteral) expressionNode() {} +func (nl *NumberLiteral) String() string { return fmt.Sprintf("%.2f", nl.Value) } +func (nl *NumberLiteral) GetType() *TypeInfo { return nl.typeInfo } +func (nl *NumberLiteral) SetType(t *TypeInfo) { nl.typeInfo = t } // StringLiteral represents string literals type StringLiteral struct { - Value string + Value string + typeInfo *TypeInfo } -func (sl *StringLiteral) expressionNode() {} -func (sl *StringLiteral) String() string { return fmt.Sprintf(`"%s"`, sl.Value) } +func (sl *StringLiteral) expressionNode() {} +func (sl *StringLiteral) String() string { return fmt.Sprintf(`"%s"`, sl.Value) } +func (sl *StringLiteral) GetType() *TypeInfo { return sl.typeInfo } +func (sl *StringLiteral) SetType(t *TypeInfo) { sl.typeInfo = t } // BooleanLiteral represents boolean literals type BooleanLiteral struct { - Value bool + Value bool + typeInfo *TypeInfo } func (bl *BooleanLiteral) expressionNode() {} @@ -252,18 +292,26 @@ func (bl *BooleanLiteral) String() string { } return "false" } +func (bl *BooleanLiteral) GetType() *TypeInfo { return bl.typeInfo } +func (bl *BooleanLiteral) SetType(t *TypeInfo) { bl.typeInfo = t } // NilLiteral represents nil literal -type NilLiteral struct{} +type NilLiteral struct { + typeInfo *TypeInfo +} -func (nl *NilLiteral) expressionNode() {} -func (nl *NilLiteral) String() string { return "nil" } +func (nl *NilLiteral) expressionNode() {} +func (nl *NilLiteral) String() string { return "nil" } +func (nl *NilLiteral) GetType() *TypeInfo { return nl.typeInfo } +func (nl *NilLiteral) SetType(t *TypeInfo) { nl.typeInfo = t } -// FunctionLiteral represents function literals: fn(a, b, ...) ... end +// FunctionLiteral represents function literals with typed parameters type FunctionLiteral struct { - Parameters []string + Parameters []FunctionParameter Variadic bool + ReturnType *TypeInfo // optional return type hint Body []Statement + typeInfo *TypeInfo } func (fl *FunctionLiteral) expressionNode() {} @@ -273,7 +321,7 @@ func (fl *FunctionLiteral) String() string { if i > 0 { params += ", " } - params += param + params += param.String() } if fl.Variadic { if len(fl.Parameters) > 0 { @@ -282,18 +330,26 @@ func (fl *FunctionLiteral) String() string { params += "..." } - result := fmt.Sprintf("fn(%s)\n", params) + result := fmt.Sprintf("fn(%s)", params) + if fl.ReturnType != nil { + result += ": " + fl.ReturnType.Type + } + result += "\n" + for _, stmt := range fl.Body { result += "\t" + stmt.String() + "\n" } result += "end" return result } +func (fl *FunctionLiteral) GetType() *TypeInfo { return fl.typeInfo } +func (fl *FunctionLiteral) SetType(t *TypeInfo) { fl.typeInfo = t } // CallExpression represents function calls: func(arg1, arg2, ...) type CallExpression struct { Function Expression Arguments []Expression + typeInfo *TypeInfo } func (ce *CallExpression) expressionNode() {} @@ -304,11 +360,14 @@ func (ce *CallExpression) String() string { } return fmt.Sprintf("%s(%s)", ce.Function.String(), joinStrings(args, ", ")) } +func (ce *CallExpression) GetType() *TypeInfo { return ce.typeInfo } +func (ce *CallExpression) SetType(t *TypeInfo) { ce.typeInfo = t } // PrefixExpression represents prefix operations like -x, not x type PrefixExpression struct { Operator string Right Expression + typeInfo *TypeInfo } func (pe *PrefixExpression) expressionNode() {} @@ -319,40 +378,51 @@ func (pe *PrefixExpression) String() string { } return fmt.Sprintf("(%s%s)", pe.Operator, pe.Right.String()) } +func (pe *PrefixExpression) GetType() *TypeInfo { return pe.typeInfo } +func (pe *PrefixExpression) SetType(t *TypeInfo) { pe.typeInfo = t } // InfixExpression represents binary operations type InfixExpression struct { Left Expression Operator string Right Expression + typeInfo *TypeInfo } func (ie *InfixExpression) expressionNode() {} func (ie *InfixExpression) String() string { return fmt.Sprintf("(%s %s %s)", ie.Left.String(), ie.Operator, ie.Right.String()) } +func (ie *InfixExpression) GetType() *TypeInfo { return ie.typeInfo } +func (ie *InfixExpression) SetType(t *TypeInfo) { ie.typeInfo = t } // IndexExpression represents table[key] access type IndexExpression struct { - Left Expression - Index Expression + Left Expression + Index Expression + typeInfo *TypeInfo } func (ie *IndexExpression) expressionNode() {} func (ie *IndexExpression) String() string { return fmt.Sprintf("%s[%s]", ie.Left.String(), ie.Index.String()) } +func (ie *IndexExpression) GetType() *TypeInfo { return ie.typeInfo } +func (ie *IndexExpression) SetType(t *TypeInfo) { ie.typeInfo = t } // DotExpression represents table.key access type DotExpression struct { - Left Expression - Key string + Left Expression + Key string + typeInfo *TypeInfo } func (de *DotExpression) expressionNode() {} func (de *DotExpression) String() string { return fmt.Sprintf("%s.%s", de.Left.String(), de.Key) } +func (de *DotExpression) GetType() *TypeInfo { return de.typeInfo } +func (de *DotExpression) SetType(t *TypeInfo) { de.typeInfo = t } // TablePair represents a key-value pair in a table type TablePair struct { @@ -369,7 +439,8 @@ func (tp *TablePair) String() string { // TableLiteral represents table literals {} type TableLiteral struct { - Pairs []TablePair + Pairs []TablePair + typeInfo *TypeInfo } func (tl *TableLiteral) expressionNode() {} @@ -380,6 +451,8 @@ func (tl *TableLiteral) String() string { } return fmt.Sprintf("{%s}", joinStrings(pairs, ", ")) } +func (tl *TableLiteral) GetType() *TypeInfo { return tl.typeInfo } +func (tl *TableLiteral) SetType(t *TypeInfo) { tl.typeInfo = t } // IsArray returns true if this table contains only array-style elements func (tl *TableLiteral) IsArray() bool { diff --git a/parser/lexer.go b/parser/lexer.go index 01d8e44..4d263b4 100644 --- a/parser/lexer.go +++ b/parser/lexer.go @@ -257,6 +257,8 @@ func (l *Lexer) NextToken() Token { tok = Token{Type: STAR, Literal: string(l.ch), Line: l.line, Column: l.column} case '/': tok = Token{Type: SLASH, Literal: string(l.ch), Line: l.line, Column: l.column} + case ':': + tok = Token{Type: COLON, Literal: string(l.ch), Line: l.line, Column: l.column} case '.': // Check for ellipsis (...) if l.peekChar() == '.' && l.peekCharAt(2) == '.' { diff --git a/parser/parser.go b/parser/parser.go index b6e330a..42c2e52 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -89,14 +89,13 @@ func (p *Parser) enterScope(scopeType string) { } func (p *Parser) exitScope() { - if len(p.scopes) > 1 { // never remove global scope + if len(p.scopes) > 1 { p.scopes = p.scopes[:len(p.scopes)-1] p.scopeTypes = p.scopeTypes[:len(p.scopeTypes)-1] } } func (p *Parser) enterFunctionScope() { - // Functions create new variable scopes p.enterScope("function") } @@ -105,35 +104,29 @@ func (p *Parser) exitFunctionScope() { } func (p *Parser) enterLoopScope() { - // Create temporary scope for loop variables only p.enterScope("loop") } func (p *Parser) exitLoopScope() { - // Remove temporary loop scope p.exitScope() } func (p *Parser) enterBlockScope() { - // Blocks don't create new variable scopes, just control flow scopes - // We don't need to track these for variable declarations + // Blocks don't create new variable scopes } func (p *Parser) exitBlockScope() { - // No-op since blocks don't create variable scopes + // No-op } func (p *Parser) currentVariableScope() map[string]bool { - // If we're in a loop scope, declare variables in the parent scope if len(p.scopeTypes) > 1 && p.scopeTypes[len(p.scopeTypes)-1] == "loop" { return p.scopes[len(p.scopes)-2] } - // Otherwise use the current scope return p.scopes[len(p.scopes)-1] } func (p *Parser) isVariableDeclared(name string) bool { - // Check all scopes from current up to global for i := len(p.scopes) - 1; i >= 0; i-- { if p.scopes[i][name] { return true @@ -147,10 +140,31 @@ func (p *Parser) declareVariable(name string) { } func (p *Parser) declareLoopVariable(name string) { - // Loop variables go in the current loop scope p.scopes[len(p.scopes)-1][name] = true } +// parseTypeHint parses optional type hint after colon +func (p *Parser) parseTypeHint() *TypeInfo { + if !p.peekTokenIs(COLON) { + return nil + } + + p.nextToken() // consume ':' + + if !p.expectPeekIdent() { + p.addError("expected type name after ':'") + return nil + } + + typeName := p.curToken.Literal + if !ValidTypeName(typeName) { + p.addError(fmt.Sprintf("invalid type name '%s'", typeName)) + return nil + } + + return &TypeInfo{Type: typeName, Inferred: false} +} + // registerPrefix registers a prefix parse function func (p *Parser) registerPrefix(tokenType TokenType, fn func() Expression) { p.prefixParseFns[tokenType] = fn @@ -187,7 +201,6 @@ func (p *Parser) ParseProgram() *Program { func (p *Parser) parseStatement() Statement { switch p.curToken.Type { case IDENT: - // Try to parse as assignment (handles both simple and member access) return p.parseAssignStatement() case IF: return p.parseIfStatement() @@ -217,17 +230,22 @@ func (p *Parser) parseStatement() Statement { } } -// parseAssignStatement parses variable assignment +// parseAssignStatement parses variable assignment with optional type hint func (p *Parser) parseAssignStatement() *AssignStatement { stmt := &AssignStatement{} - // Parse left-hand side expression (can be identifier or member access) + // Parse left-hand side expression stmt.Name = p.ParseExpression(LOWEST) if stmt.Name == nil { p.addError("expected expression for assignment left-hand side") return nil } + // Check for type hint on simple identifiers + if _, ok := stmt.Name.(*Identifier); ok { + stmt.TypeHint = p.parseTypeHint() + } + // Check if next token is assignment operator if !p.peekTokenIs(ASSIGN) { p.addError("unexpected identifier, expected assignment or declaration") @@ -237,13 +255,11 @@ func (p *Parser) parseAssignStatement() *AssignStatement { // Validate assignment target and check if it's a declaration switch name := stmt.Name.(type) { case *Identifier: - // Simple variable assignment - check if it's a declaration stmt.IsDeclaration = !p.isVariableDeclared(name.Value) if stmt.IsDeclaration { p.declareVariable(name.Value) } case *DotExpression, *IndexExpression: - // Member access - never a declaration stmt.IsDeclaration = false default: p.addError("invalid assignment target") @@ -289,10 +305,8 @@ func (p *Parser) parseBreakStatement() *BreakStatement { func (p *Parser) parseExitStatement() *ExitStatement { stmt := &ExitStatement{} - // Check if there's an optional expression after 'exit' - // Only parse expression if next token can start an expression if p.canStartExpression(p.peekToken.Type) { - p.nextToken() // move past 'exit' + p.nextToken() stmt.Value = p.ParseExpression(LOWEST) if stmt.Value == nil { p.addError("expected expression after 'exit'") @@ -307,9 +321,8 @@ func (p *Parser) parseExitStatement() *ExitStatement { func (p *Parser) parseReturnStatement() *ReturnStatement { stmt := &ReturnStatement{} - // Check if there's an optional expression after 'return' if p.canStartExpression(p.peekToken.Type) { - p.nextToken() // move past 'return' + p.nextToken() stmt.Value = p.ParseExpression(LOWEST) if stmt.Value == nil { p.addError("expected expression after 'return'") @@ -330,11 +343,11 @@ func (p *Parser) canStartExpression(tokenType TokenType) bool { } } -// parseWhileStatement parses while loops: while condition do ... end +// parseWhileStatement parses while loops func (p *Parser) parseWhileStatement() *WhileStatement { stmt := &WhileStatement{} - p.nextToken() // move past 'while' + p.nextToken() stmt.Condition = p.ParseExpression(LOWEST) if stmt.Condition == nil { @@ -347,9 +360,8 @@ func (p *Parser) parseWhileStatement() *WhileStatement { return nil } - p.nextToken() // move past 'do' + p.nextToken() - // Parse loop body (no new variable scope) p.enterBlockScope() stmt.Body = p.parseBlockStatements(END) p.exitBlockScope() @@ -362,9 +374,9 @@ func (p *Parser) parseWhileStatement() *WhileStatement { return stmt } -// parseForStatement parses for loops (both numeric and for-in) +// parseForStatement parses for loops func (p *Parser) parseForStatement() Statement { - p.nextToken() // move past 'for' + p.nextToken() if !p.curTokenIs(IDENT) { p.addError("expected identifier after 'for'") @@ -373,12 +385,9 @@ func (p *Parser) parseForStatement() Statement { firstVar := &Identifier{Value: p.curToken.Literal} - // Look ahead to determine which type of for loop if p.peekTokenIs(ASSIGN) { - // Numeric for loop: for i = start, end, step do return p.parseNumericForStatement(firstVar) } else if p.peekTokenIs(COMMA) || p.peekTokenIs(IN) { - // For-in loop: for k, v in expr do or for v in expr do return p.parseForInStatement(firstVar) } else { p.addError("expected '=', ',' or 'in' after for loop variable") @@ -386,7 +395,7 @@ func (p *Parser) parseForStatement() Statement { } } -// parseNumericForStatement parses numeric for loops: for i = start, end, step do +// parseNumericForStatement parses numeric for loops func (p *Parser) parseNumericForStatement(variable *Identifier) *ForStatement { stmt := &ForStatement{Variable: variable} @@ -394,9 +403,8 @@ func (p *Parser) parseNumericForStatement(variable *Identifier) *ForStatement { return nil } - p.nextToken() // move past '=' + p.nextToken() - // Parse start expression stmt.Start = p.ParseExpression(LOWEST) if stmt.Start == nil { p.addError("expected start expression in for loop") @@ -408,19 +416,17 @@ func (p *Parser) parseNumericForStatement(variable *Identifier) *ForStatement { return nil } - p.nextToken() // move past ',' + p.nextToken() - // Parse end expression stmt.End = p.ParseExpression(LOWEST) if stmt.End == nil { p.addError("expected end expression in for loop") return nil } - // Optional step expression if p.peekTokenIs(COMMA) { - p.nextToken() // move to ',' - p.nextToken() // move past ',' + p.nextToken() + p.nextToken() stmt.Step = p.ParseExpression(LOWEST) if stmt.Step == nil { @@ -434,13 +440,12 @@ func (p *Parser) parseNumericForStatement(variable *Identifier) *ForStatement { return nil } - p.nextToken() // move past 'do' + p.nextToken() - // Create temporary scope for loop variable, assignments in body go to parent scope p.enterLoopScope() - p.declareLoopVariable(variable.Value) // loop variable in temporary scope + p.declareLoopVariable(variable.Value) stmt.Body = p.parseBlockStatements(END) - p.exitLoopScope() // discard temporary scope with loop variable + p.exitLoopScope() if !p.curTokenIs(END) { p.addError("expected 'end' to close for loop") @@ -450,15 +455,14 @@ func (p *Parser) parseNumericForStatement(variable *Identifier) *ForStatement { return stmt } -// parseForInStatement parses for-in loops: for k, v in expr do or for v in expr do +// parseForInStatement parses for-in loops func (p *Parser) parseForInStatement(firstVar *Identifier) *ForInStatement { stmt := &ForInStatement{} if p.peekTokenIs(COMMA) { - // Two variables: for k, v in expr do stmt.Key = firstVar - p.nextToken() // move to ',' - p.nextToken() // move past ',' + p.nextToken() + p.nextToken() if !p.curTokenIs(IDENT) { p.addError("expected identifier after ',' in for loop") @@ -467,7 +471,6 @@ func (p *Parser) parseForInStatement(firstVar *Identifier) *ForInStatement { stmt.Value = &Identifier{Value: p.curToken.Literal} } else { - // Single variable: for v in expr do stmt.Value = firstVar } @@ -476,9 +479,8 @@ func (p *Parser) parseForInStatement(firstVar *Identifier) *ForInStatement { return nil } - p.nextToken() // move past 'in' + p.nextToken() - // Parse iterable expression stmt.Iterable = p.ParseExpression(LOWEST) if stmt.Iterable == nil { p.addError("expected expression after 'in' in for loop") @@ -490,16 +492,15 @@ func (p *Parser) parseForInStatement(firstVar *Identifier) *ForInStatement { return nil } - p.nextToken() // move past 'do' + p.nextToken() - // Create temporary scope for loop variables, assignments in body go to parent scope p.enterLoopScope() if stmt.Key != nil { - p.declareLoopVariable(stmt.Key.Value) // loop variable in temporary scope + p.declareLoopVariable(stmt.Key.Value) } - p.declareLoopVariable(stmt.Value.Value) // loop variable in temporary scope + p.declareLoopVariable(stmt.Value.Value) stmt.Body = p.parseBlockStatements(END) - p.exitLoopScope() // discard temporary scope with loop variables + p.exitLoopScope() if !p.curTokenIs(END) { p.addError("expected 'end' to close for loop") @@ -509,11 +510,11 @@ func (p *Parser) parseForInStatement(firstVar *Identifier) *ForInStatement { return stmt } -// parseIfStatement parses if/elseif/else/end statements +// parseIfStatement parses if statements func (p *Parser) parseIfStatement() *IfStatement { stmt := &IfStatement{} - p.nextToken() // move past 'if' + p.nextToken() stmt.Condition = p.ParseExpression(LOWEST) if stmt.Condition == nil { @@ -521,29 +522,25 @@ func (p *Parser) parseIfStatement() *IfStatement { return nil } - // Optional 'then' keyword if p.peekTokenIs(THEN) { p.nextToken() } - p.nextToken() // move past condition (and optional 'then') + p.nextToken() - // Check if we immediately hit END (empty body should be an error) if p.curTokenIs(END) { p.addError("expected 'end' to close if statement") return nil } - // Parse if body (no new variable scope) p.enterBlockScope() stmt.Body = p.parseBlockStatements(ELSEIF, ELSE, END) p.exitBlockScope() - // Parse elseif clauses for p.curTokenIs(ELSEIF) { elseif := ElseIfClause{} - p.nextToken() // move past 'elseif' + p.nextToken() elseif.Condition = p.ParseExpression(LOWEST) if elseif.Condition == nil { @@ -551,14 +548,12 @@ func (p *Parser) parseIfStatement() *IfStatement { return nil } - // Optional 'then' keyword if p.peekTokenIs(THEN) { p.nextToken() } - p.nextToken() // move past condition (and optional 'then') + p.nextToken() - // Parse elseif body (no new variable scope) p.enterBlockScope() elseif.Body = p.parseBlockStatements(ELSEIF, ELSE, END) p.exitBlockScope() @@ -566,11 +561,9 @@ func (p *Parser) parseIfStatement() *IfStatement { stmt.ElseIfs = append(stmt.ElseIfs, elseif) } - // Parse else clause if p.curTokenIs(ELSE) { - p.nextToken() // move past 'else' + p.nextToken() - // Parse else body (no new variable scope) p.enterBlockScope() stmt.Else = p.parseBlockStatements(END) p.exitBlockScope() @@ -584,7 +577,7 @@ func (p *Parser) parseIfStatement() *IfStatement { return stmt } -// parseBlockStatements parses statements until one of the terminator tokens +// parseBlockStatements parses statements until terminators func (p *Parser) parseBlockStatements(terminators ...TokenType) []Statement { statements := []Statement{} @@ -599,7 +592,7 @@ func (p *Parser) parseBlockStatements(terminators ...TokenType) []Statement { return statements } -// isTerminator checks if current token is one of the terminators +// isTerminator checks if current token is a terminator func (p *Parser) isTerminator(terminators ...TokenType) bool { for _, terminator := range terminators { if p.curTokenIs(terminator) { @@ -650,9 +643,7 @@ func (p *Parser) parseNumberLiteral() Expression { var value float64 var err error - // Check for hexadecimal (0x/0X prefix) if strings.HasPrefix(literal, "0x") || strings.HasPrefix(literal, "0X") { - // Validate hex format if len(literal) <= 2 { p.addError(fmt.Sprintf("could not parse '%s' as hexadecimal number", literal)) return nil @@ -664,7 +655,6 @@ func (p *Parser) parseNumberLiteral() Expression { return nil } } - // Parse as hex and convert to float64 intVal, parseErr := strconv.ParseInt(literal, 0, 64) if parseErr != nil { p.addError(fmt.Sprintf("could not parse '%s' as hexadecimal number", literal)) @@ -672,7 +662,6 @@ func (p *Parser) parseNumberLiteral() Expression { } value = float64(intVal) } else if strings.HasPrefix(literal, "0b") || strings.HasPrefix(literal, "0B") { - // Validate binary format if len(literal) <= 2 { p.addError(fmt.Sprintf("could not parse '%s' as binary number", literal)) return nil @@ -684,8 +673,7 @@ func (p *Parser) parseNumberLiteral() Expression { return nil } } - // Parse binary manually since Go doesn't support 0b in ParseInt with base 0 - binaryStr := literal[2:] // remove "0b" prefix + binaryStr := literal[2:] intVal, parseErr := strconv.ParseInt(binaryStr, 2, 64) if parseErr != nil { p.addError(fmt.Sprintf("could not parse '%s' as binary number", literal)) @@ -693,7 +681,6 @@ func (p *Parser) parseNumberLiteral() Expression { } value = float64(intVal) } else { - // Parse as regular decimal (handles scientific notation automatically) value, err = strconv.ParseFloat(literal, 64) if err != nil { p.addError(fmt.Sprintf("could not parse '%s' as number", literal)) @@ -763,12 +750,14 @@ func (p *Parser) parseFunctionLiteral() Expression { return nil } - p.nextToken() // move past ')' + // Check for return type hint + fn.ReturnType = p.parseTypeHint() + + p.nextToken() - // Enter new function scope and declare parameters p.enterFunctionScope() for _, param := range fn.Parameters { - p.declareVariable(param) + p.declareVariable(param.Name) } fn.Body = p.parseBlockStatements(END) p.exitFunctionScope() @@ -781,8 +770,8 @@ func (p *Parser) parseFunctionLiteral() Expression { return fn } -func (p *Parser) parseFunctionParameters() ([]string, bool) { - var params []string +func (p *Parser) parseFunctionParameters() ([]FunctionParameter, bool) { + var params []FunctionParameter var variadic bool if p.peekTokenIs(RPAREN) { @@ -802,16 +791,20 @@ func (p *Parser) parseFunctionParameters() ([]string, bool) { return nil, false } - params = append(params, p.curToken.Literal) + param := FunctionParameter{Name: p.curToken.Literal} + + // Check for type hint + param.TypeHint = p.parseTypeHint() + + params = append(params, param) if !p.peekTokenIs(COMMA) { break } - p.nextToken() // move to ',' - p.nextToken() // move past ',' + p.nextToken() + p.nextToken() - // Check for ellipsis after comma if p.curTokenIs(ELLIPSIS) { variadic = true break @@ -833,7 +826,6 @@ func (p *Parser) parseTableLiteral() Expression { p.nextToken() for { - // Check for EOF if p.curTokenIs(EOF) { p.addError("unexpected end of input, expected }") return nil @@ -841,17 +833,15 @@ func (p *Parser) parseTableLiteral() Expression { pair := TablePair{} - // Check if this is a key=value pair (identifier or string key) if (p.curTokenIs(IDENT) || p.curTokenIs(STRING)) && p.peekTokenIs(ASSIGN) { if p.curTokenIs(IDENT) { pair.Key = &Identifier{Value: p.curToken.Literal} } else { pair.Key = &StringLiteral{Value: p.curToken.Literal} } - p.nextToken() // move to = - p.nextToken() // move past = + p.nextToken() + p.nextToken() - // Check for EOF after = if p.curTokenIs(EOF) { p.addError("expected expression after assignment operator") return nil @@ -859,7 +849,6 @@ func (p *Parser) parseTableLiteral() Expression { pair.Value = p.ParseExpression(LOWEST) } else { - // Array-style element pair.Value = p.ParseExpression(LOWEST) } @@ -873,15 +862,13 @@ func (p *Parser) parseTableLiteral() Expression { break } - p.nextToken() // consume comma - p.nextToken() // move to next element + p.nextToken() + p.nextToken() - // Allow trailing comma if p.curTokenIs(RBRACE) { break } - // Check for EOF after comma if p.curTokenIs(EOF) { p.addError("expected next token to be }") return nil @@ -956,7 +943,7 @@ func (p *Parser) parseExpressionList(end TokenType) []Expression { } func (p *Parser) parseIndexExpression(left Expression) Expression { - p.nextToken() // move past '[' + p.nextToken() index := p.ParseExpression(LOWEST) if index == nil { @@ -993,7 +980,6 @@ func (p *Parser) expectPeek(t TokenType) bool { return false } -// expectPeekIdent accepts IDENT or keyword tokens as identifiers func (p *Parser) expectPeekIdent() bool { if p.peekTokenIs(IDENT) || p.isKeyword(p.peekToken.Type) { p.nextToken() @@ -1003,7 +989,6 @@ func (p *Parser) expectPeekIdent() bool { return false } -// isKeyword checks if a token type is a keyword that can be used as identifier func (p *Parser) isKeyword(t TokenType) bool { switch t { case TRUE, FALSE, NIL, AND, OR, NOT, IF, THEN, ELSEIF, ELSE, END, ECHO, FOR, WHILE, IN, DO, BREAK, EXIT, FN, RETURN: @@ -1013,7 +998,7 @@ func (p *Parser) isKeyword(t TokenType) bool { } } -// Error handling methods +// Error handling func (p *Parser) addError(message string) { p.errors = append(p.errors, ParseError{ Message: message, @@ -1075,12 +1060,10 @@ func (p *Parser) Errors() []ParseError { return p.errors } -// HasErrors returns true if there are any parsing errors func (p *Parser) HasErrors() bool { return len(p.errors) > 0 } -// ErrorStrings returns error messages as strings for backward compatibility func (p *Parser) ErrorStrings() []string { result := make([]string, len(p.errors)) for i, err := range p.errors { @@ -1089,7 +1072,7 @@ func (p *Parser) ErrorStrings() []string { return result } -// tokenTypeString returns a human-readable string for token types +// tokenTypeString returns human-readable string for token types func tokenTypeString(t TokenType) string { switch t { case IDENT: @@ -1114,6 +1097,8 @@ func tokenTypeString(t TokenType) string { return "/" case DOT: return "." + case COLON: + return ":" case EQ: return "==" case NOT_EQ: diff --git a/parser/tests/functions_test.go b/parser/tests/functions_test.go index fccad34..5dd4c46 100644 --- a/parser/tests/functions_test.go +++ b/parser/tests/functions_test.go @@ -81,8 +81,8 @@ func TestFunctionParameters(t *testing.T) { } for i, expected := range tt.params { - if fn.Parameters[i] != expected { - t.Errorf("parameter %d: expected %s, got %s", i, expected, fn.Parameters[i]) + if fn.Parameters[i].Name != expected { + t.Errorf("parameter %d: expected %s, got %s", i, expected, fn.Parameters[i].Name) } } diff --git a/parser/tests/types_test.go b/parser/tests/types_test.go new file mode 100644 index 0000000..d949450 --- /dev/null +++ b/parser/tests/types_test.go @@ -0,0 +1,533 @@ +package parser_test + +import ( + "testing" + + "git.sharkk.net/Sharkk/Mako/parser" +) + +func TestVariableTypeHints(t *testing.T) { + tests := []struct { + input string + variable string + typeHint string + hasHint bool + desc string + }{ + {"x = 42", "x", "", false, "no type hint"}, + {"x: number = 42", "x", "number", true, "number type hint"}, + {"name: string = \"hello\"", "name", "string", true, "string type hint"}, + {"flag: bool = true", "flag", "bool", true, "bool type hint"}, + {"data: table = {}", "data", "table", true, "table type hint"}, + {"fn_var: function = fn() end", "fn_var", "function", true, "function type hint"}, + {"value: any = nil", "value", "any", true, "any type hint"}, + {"ptr: nil = nil", "ptr", "nil", true, "nil type hint"}, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + l := parser.NewLexer(tt.input) + p := parser.NewParser(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + if len(program.Statements) != 1 { + t.Fatalf("expected 1 statement, got %d", len(program.Statements)) + } + + stmt, ok := program.Statements[0].(*parser.AssignStatement) + if !ok { + t.Fatalf("expected AssignStatement, got %T", program.Statements[0]) + } + + // Check variable name + ident, ok := stmt.Name.(*parser.Identifier) + if !ok { + t.Fatalf("expected Identifier for Name, got %T", stmt.Name) + } + + if ident.Value != tt.variable { + t.Errorf("expected variable %s, got %s", tt.variable, ident.Value) + } + + // Check type hint + if tt.hasHint { + if stmt.TypeHint == nil { + t.Error("expected type hint but got nil") + } else { + if stmt.TypeHint.Type != tt.typeHint { + t.Errorf("expected type hint %s, got %s", tt.typeHint, stmt.TypeHint.Type) + } + if stmt.TypeHint.Inferred { + t.Error("expected type hint to not be inferred") + } + } + } else { + if stmt.TypeHint != nil { + t.Errorf("expected no type hint but got %s", stmt.TypeHint.Type) + } + } + }) + } +} + +func TestFunctionParameterTypeHints(t *testing.T) { + tests := []struct { + input string + params []struct{ name, typeHint string } + returnType string + hasReturn bool + desc string + }{ + { + "fn(a, b) end", + []struct{ name, typeHint string }{ + {"a", ""}, + {"b", ""}, + }, + "", false, + "no type hints", + }, + { + "fn(a: number, b: string) end", + []struct{ name, typeHint string }{ + {"a", "number"}, + {"b", "string"}, + }, + "", false, + "parameter type hints only", + }, + { + "fn(x: number): string end", + []struct{ name, typeHint string }{ + {"x", "number"}, + }, + "string", true, + "parameter and return type hints", + }, + { + "fn(): bool end", + []struct{ name, typeHint string }{}, + "bool", true, + "return type hint only", + }, + { + "fn(a: number, b, c: string): table end", + []struct{ name, typeHint string }{ + {"a", "number"}, + {"b", ""}, + {"c", "string"}, + }, + "table", true, + "mixed parameter types with return", + }, + { + "fn(callback: function, data: any): nil end", + []struct{ name, typeHint string }{ + {"callback", "function"}, + {"data", "any"}, + }, + "nil", true, + "function and any types", + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + l := parser.NewLexer(tt.input) + p := parser.NewParser(l) + expr := p.ParseExpression(parser.LOWEST) + checkParserErrors(t, p) + + fn, ok := expr.(*parser.FunctionLiteral) + if !ok { + t.Fatalf("expected FunctionLiteral, got %T", expr) + } + + // Check parameters + if len(fn.Parameters) != len(tt.params) { + t.Fatalf("expected %d parameters, got %d", len(tt.params), len(fn.Parameters)) + } + + for i, expected := range tt.params { + param := fn.Parameters[i] + if param.Name != expected.name { + t.Errorf("parameter %d: expected name %s, got %s", i, expected.name, param.Name) + } + + if expected.typeHint == "" { + if param.TypeHint != nil { + t.Errorf("parameter %d: expected no type hint but got %s", i, param.TypeHint.Type) + } + } else { + if param.TypeHint == nil { + t.Errorf("parameter %d: expected type hint %s but got nil", i, expected.typeHint) + } else if param.TypeHint.Type != expected.typeHint { + t.Errorf("parameter %d: expected type hint %s, got %s", i, expected.typeHint, param.TypeHint.Type) + } + } + } + + // Check return type + if tt.hasReturn { + if fn.ReturnType == nil { + t.Error("expected return type hint but got nil") + } else if fn.ReturnType.Type != tt.returnType { + t.Errorf("expected return type %s, got %s", tt.returnType, fn.ReturnType.Type) + } + } else { + if fn.ReturnType != nil { + t.Errorf("expected no return type but got %s", fn.ReturnType.Type) + } + } + }) + } +} + +func TestTypeHintStringRepresentation(t *testing.T) { + tests := []struct { + input string + expected string + desc string + }{ + {"x: number = 42", "local x: number = 42.00", "typed variable assignment"}, + {"fn(a: number, b: string): bool end", "fn(a: number, b: string): bool\nend", "typed function"}, + {"callback: function = fn(x: number): string return \"\" end", "local callback: function = fn(x: number): string\n\treturn \"\"\nend", "typed function assignment"}, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + l := parser.NewLexer(tt.input) + p := parser.NewParser(l) + + var result string + if tt.input[0] == 'f' { // function literal + expr := p.ParseExpression(parser.LOWEST) + checkParserErrors(t, p) + result = expr.String() + } else { // assignment statement + program := p.ParseProgram() + checkParserErrors(t, p) + result = program.Statements[0].String() + } + + if result != tt.expected { + t.Errorf("expected %q, got %q", tt.expected, result) + } + }) + } +} + +func TestTypeHintValidation(t *testing.T) { + tests := []struct { + input string + expectedError string + desc string + }{ + {"x: invalid = 42", "invalid type name 'invalid'", "invalid type name"}, + {"x: = 42", "expected type name after ':'", "missing type name"}, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + l := parser.NewLexer(tt.input) + p := parser.NewParser(l) + p.ParseProgram() + + if !p.HasErrors() { + t.Fatal("expected parsing errors") + } + + errors := p.Errors() + found := false + for _, err := range errors { + if err.Message == tt.expectedError { + found = true + break + } + } + + if !found { + errorMsgs := make([]string, len(errors)) + for i, err := range errors { + errorMsgs[i] = err.Message + } + t.Errorf("expected error %q, got %v", tt.expectedError, errorMsgs) + } + }) + } +} + +func TestMemberAccessWithoutTypeHints(t *testing.T) { + tests := []struct { + input string + desc string + }{ + {"table.key = 42", "dot notation assignment"}, + {"arr[1] = \"hello\"", "bracket notation assignment"}, + {"obj.nested.deep = true", "chained dot assignment"}, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + l := parser.NewLexer(tt.input) + p := parser.NewParser(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + if len(program.Statements) != 1 { + t.Fatalf("expected 1 statement, got %d", len(program.Statements)) + } + + stmt, ok := program.Statements[0].(*parser.AssignStatement) + if !ok { + t.Fatalf("expected AssignStatement, got %T", program.Statements[0]) + } + + // Member access should never have type hints + if stmt.TypeHint != nil { + t.Error("member access assignment should not have type hints") + } + + // Should not be a declaration + if stmt.IsDeclaration { + t.Error("member access assignment should not be a declaration") + } + }) + } +} + +func TestTypeInferenceEngine(t *testing.T) { + input := `x: number = 42 +name: string = "hello" +fn_var: function = fn(a: number, b: string): bool + result: bool = a > 0 + return result +end +echo name` + + l := parser.NewLexer(input) + p := parser.NewParser(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + // Run type inference + inferrer := parser.NewTypeInferrer() + typeErrors := inferrer.InferTypes(program) + + // Should have no type errors for valid code + if len(typeErrors) > 0 { + errorMsgs := make([]string, len(typeErrors)) + for i, err := range typeErrors { + errorMsgs[i] = err.Error() + } + t.Errorf("unexpected type errors: %v", errorMsgs) + } +} + +func TestTypeInferenceErrors(t *testing.T) { + tests := []struct { + input string + expectedError string + desc string + }{ + { + "x: number = \"hello\"", + "cannot assign string to variable of type number", + "type mismatch in assignment", + }, + { + "x = 42\ny: string = x", + "cannot assign number to variable of type string", + "type mismatch with inferred type", + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + l := parser.NewLexer(tt.input) + p := parser.NewParser(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + inferrer := parser.NewTypeInferrer() + typeErrors := inferrer.InferTypes(program) + + if len(typeErrors) == 0 { + t.Fatal("expected type errors") + } + + found := false + for _, err := range typeErrors { + if err.Message == tt.expectedError { + found = true + break + } + } + + if !found { + errorMsgs := make([]string, len(typeErrors)) + for i, err := range typeErrors { + errorMsgs[i] = err.Message + } + t.Errorf("expected error %q, got %v", tt.expectedError, errorMsgs) + } + }) + } +} + +func TestVariadicFunctionTypeHints(t *testing.T) { + tests := []struct { + input string + variadic bool + desc string + }{ + {"fn(a: number, ...) end", true, "variadic after typed param"}, + {"fn(...) end", true, "variadic only"}, + {"fn(a: number, b: string) end", false, "no variadic"}, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + l := parser.NewLexer(tt.input) + p := parser.NewParser(l) + expr := p.ParseExpression(parser.LOWEST) + checkParserErrors(t, p) + + fn, ok := expr.(*parser.FunctionLiteral) + if !ok { + t.Fatalf("expected FunctionLiteral, got %T", expr) + } + + if fn.Variadic != tt.variadic { + t.Errorf("expected variadic = %t, got %t", tt.variadic, fn.Variadic) + } + }) + } +} + +func TestComplexTypeHintProgram(t *testing.T) { + input := `config: table = { + host = "localhost", + port = 8080, + enabled = true +} + +handler: function = fn(request: table, callback: function): nil + status: number = 200 + if request.method == "GET" then + response: table = {status = status, body = "OK"} + result = callback(response) + end +end + +server: table = { + config = config, + handler = handler, + start = fn(self: table): bool + return true + end +}` + + l := parser.NewLexer(input) + p := parser.NewParser(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + if len(program.Statements) != 3 { + t.Fatalf("expected 3 statements, got %d", len(program.Statements)) + } + + // Check first statement: config table with typed assignments + configStmt, ok := program.Statements[0].(*parser.AssignStatement) + if !ok { + t.Fatalf("expected AssignStatement, got %T", program.Statements[0]) + } + + if configStmt.TypeHint == nil || configStmt.TypeHint.Type != "table" { + t.Error("expected table type hint for config") + } + + // Check second statement: handler function with typed parameters + handlerStmt, ok := program.Statements[1].(*parser.AssignStatement) + if !ok { + t.Fatalf("expected AssignStatement, got %T", program.Statements[1]) + } + + if handlerStmt.TypeHint == nil || handlerStmt.TypeHint.Type != "function" { + t.Error("expected function type hint for handler") + } + + fn, ok := handlerStmt.Value.(*parser.FunctionLiteral) + if !ok { + t.Fatalf("expected FunctionLiteral value, got %T", handlerStmt.Value) + } + + if len(fn.Parameters) != 2 { + t.Fatalf("expected 2 parameters, got %d", len(fn.Parameters)) + } + + // Check parameter types + if fn.Parameters[0].TypeHint == nil || fn.Parameters[0].TypeHint.Type != "table" { + t.Error("expected table type for request parameter") + } + + if fn.Parameters[1].TypeHint == nil || fn.Parameters[1].TypeHint.Type != "function" { + t.Error("expected function type for callback parameter") + } + + // Check return type + if fn.ReturnType == nil || fn.ReturnType.Type != "nil" { + t.Error("expected nil return type for handler") + } + + // Check third statement: server table + serverStmt, ok := program.Statements[2].(*parser.AssignStatement) + if !ok { + t.Fatalf("expected AssignStatement, got %T", program.Statements[2]) + } + + if serverStmt.TypeHint == nil || serverStmt.TypeHint.Type != "table" { + t.Error("expected table type hint for server") + } +} + +func TestTypeInfoGettersSetters(t *testing.T) { + // Test that all expression types properly implement GetType/SetType + typeInfo := &parser.TypeInfo{Type: "test", Inferred: true} + + expressions := []parser.Expression{ + &parser.Identifier{Value: "x"}, + &parser.NumberLiteral{Value: 42}, + &parser.StringLiteral{Value: "hello"}, + &parser.BooleanLiteral{Value: true}, + &parser.NilLiteral{}, + &parser.TableLiteral{}, + &parser.FunctionLiteral{}, + &parser.CallExpression{Function: &parser.Identifier{Value: "fn"}}, + &parser.PrefixExpression{Operator: "-", Right: &parser.NumberLiteral{Value: 1}}, + &parser.InfixExpression{Left: &parser.NumberLiteral{Value: 1}, Operator: "+", Right: &parser.NumberLiteral{Value: 2}}, + &parser.IndexExpression{Left: &parser.Identifier{Value: "arr"}, Index: &parser.NumberLiteral{Value: 0}}, + &parser.DotExpression{Left: &parser.Identifier{Value: "obj"}, Key: "prop"}, + } + + for i, expr := range expressions { + t.Run(string(rune('0'+i)), func(t *testing.T) { + // Initially should have no type + if expr.GetType() != nil { + t.Error("expected nil type initially") + } + + // Set type + expr.SetType(typeInfo) + + // Get type should return what we set + retrieved := expr.GetType() + if retrieved == nil { + t.Error("expected non-nil type after setting") + } else if retrieved.Type != "test" || !retrieved.Inferred { + t.Errorf("expected {Type: test, Inferred: true}, got %+v", retrieved) + } + }) + } +} diff --git a/parser/token.go b/parser/token.go index afb33e1..4a8fb80 100644 --- a/parser/token.go +++ b/parser/token.go @@ -42,6 +42,7 @@ const ( RBRACKET // ] COMMA // , ELLIPSIS // ... + COLON // : // Keywords IF diff --git a/parser/types.go b/parser/types.go new file mode 100644 index 0000000..c674b36 --- /dev/null +++ b/parser/types.go @@ -0,0 +1,591 @@ +package parser + +import ( + "fmt" +) + +// Type constants for built-in types +const ( + TypeNumber = "number" + TypeString = "string" + TypeBool = "bool" + TypeNil = "nil" + TypeTable = "table" + TypeFunction = "function" + TypeAny = "any" +) + +// TypeError represents a type checking error +type TypeError struct { + Message string + Line int + Column int + Node Node +} + +func (te TypeError) Error() string { + return fmt.Sprintf("Type error at line %d, column %d: %s", te.Line, te.Column, te.Message) +} + +// Symbol represents a variable in the symbol table +type Symbol struct { + Name string + Type *TypeInfo + Declared bool + Line int + Column int +} + +// Scope represents a scope in the symbol table +type Scope struct { + symbols map[string]*Symbol + parent *Scope +} + +func NewScope(parent *Scope) *Scope { + return &Scope{ + symbols: make(map[string]*Symbol), + parent: parent, + } +} + +func (s *Scope) Define(symbol *Symbol) { + s.symbols[symbol.Name] = symbol +} + +func (s *Scope) Lookup(name string) *Symbol { + if symbol, ok := s.symbols[name]; ok { + return symbol + } + if s.parent != nil { + return s.parent.Lookup(name) + } + return nil +} + +// TypeInferrer performs type inference and checking +type TypeInferrer struct { + currentScope *Scope + globalScope *Scope + errors []TypeError + + // Pre-allocated type objects for performance + numberType *TypeInfo + stringType *TypeInfo + boolType *TypeInfo + nilType *TypeInfo + tableType *TypeInfo + anyType *TypeInfo +} + +// NewTypeInferrer creates a new type inference engine +func NewTypeInferrer() *TypeInferrer { + globalScope := NewScope(nil) + + ti := &TypeInferrer{ + currentScope: globalScope, + globalScope: globalScope, + errors: []TypeError{}, + + // Pre-allocate common types to reduce allocations + numberType: &TypeInfo{Type: TypeNumber, Inferred: true}, + stringType: &TypeInfo{Type: TypeString, Inferred: true}, + boolType: &TypeInfo{Type: TypeBool, Inferred: true}, + nilType: &TypeInfo{Type: TypeNil, Inferred: true}, + tableType: &TypeInfo{Type: TypeTable, Inferred: true}, + anyType: &TypeInfo{Type: TypeAny, Inferred: true}, + } + + return ti +} + +// InferTypes performs type inference on the entire program +func (ti *TypeInferrer) InferTypes(program *Program) []TypeError { + for _, stmt := range program.Statements { + ti.inferStatement(stmt) + } + return ti.errors +} + +// enterScope creates a new scope +func (ti *TypeInferrer) enterScope() { + ti.currentScope = NewScope(ti.currentScope) +} + +// exitScope returns to the parent scope +func (ti *TypeInferrer) exitScope() { + if ti.currentScope.parent != nil { + ti.currentScope = ti.currentScope.parent + } +} + +// addError adds a type error +func (ti *TypeInferrer) addError(message string, node Node) { + ti.errors = append(ti.errors, TypeError{ + Message: message, + Line: 0, // Would need to track position in AST nodes + Column: 0, + Node: node, + }) +} + +// inferStatement infers types for statements +func (ti *TypeInferrer) inferStatement(stmt Statement) { + switch s := stmt.(type) { + case *AssignStatement: + ti.inferAssignStatement(s) + case *EchoStatement: + ti.inferExpression(s.Value) + case *IfStatement: + ti.inferIfStatement(s) + case *WhileStatement: + ti.inferWhileStatement(s) + case *ForStatement: + ti.inferForStatement(s) + case *ForInStatement: + ti.inferForInStatement(s) + case *ReturnStatement: + if s.Value != nil { + ti.inferExpression(s.Value) + } + case *ExitStatement: + if s.Value != nil { + ti.inferExpression(s.Value) + } + } +} + +// inferAssignStatement handles variable assignments with type checking +func (ti *TypeInferrer) inferAssignStatement(stmt *AssignStatement) { + // Infer the type of the value expression + valueType := ti.inferExpression(stmt.Value) + + if ident, ok := stmt.Name.(*Identifier); ok { + // Simple variable assignment + symbol := ti.currentScope.Lookup(ident.Value) + + if stmt.IsDeclaration { + // New variable declaration + varType := valueType + + // If there's a type hint, validate it + if stmt.TypeHint != nil { + if !ti.isTypeCompatible(valueType, stmt.TypeHint) { + ti.addError(fmt.Sprintf("cannot assign %s to variable of type %s", + valueType.Type, stmt.TypeHint.Type), stmt) + } + varType = stmt.TypeHint + varType.Inferred = false + } + + // Define the new symbol + ti.currentScope.Define(&Symbol{ + Name: ident.Value, + Type: varType, + Declared: true, + }) + + ident.SetType(varType) + } else { + // Assignment to existing variable + if symbol == nil { + ti.addError(fmt.Sprintf("undefined variable '%s'", ident.Value), stmt) + return + } + + // Check type compatibility + if !ti.isTypeCompatible(valueType, symbol.Type) { + ti.addError(fmt.Sprintf("cannot assign %s to variable of type %s", + valueType.Type, symbol.Type.Type), stmt) + } + + ident.SetType(symbol.Type) + } + } else { + // Member access assignment (table.key or table[index]) + ti.inferExpression(stmt.Name) + } +} + +// inferIfStatement handles if statements +func (ti *TypeInferrer) inferIfStatement(stmt *IfStatement) { + condType := ti.inferExpression(stmt.Condition) + ti.validateBooleanContext(condType, stmt.Condition) + + ti.enterScope() + for _, s := range stmt.Body { + ti.inferStatement(s) + } + ti.exitScope() + + for _, elseif := range stmt.ElseIfs { + condType := ti.inferExpression(elseif.Condition) + ti.validateBooleanContext(condType, elseif.Condition) + + ti.enterScope() + for _, s := range elseif.Body { + ti.inferStatement(s) + } + ti.exitScope() + } + + if len(stmt.Else) > 0 { + ti.enterScope() + for _, s := range stmt.Else { + ti.inferStatement(s) + } + ti.exitScope() + } +} + +// inferWhileStatement handles while loops +func (ti *TypeInferrer) inferWhileStatement(stmt *WhileStatement) { + condType := ti.inferExpression(stmt.Condition) + ti.validateBooleanContext(condType, stmt.Condition) + + ti.enterScope() + for _, s := range stmt.Body { + ti.inferStatement(s) + } + ti.exitScope() +} + +// inferForStatement handles numeric for loops +func (ti *TypeInferrer) inferForStatement(stmt *ForStatement) { + startType := ti.inferExpression(stmt.Start) + endType := ti.inferExpression(stmt.End) + + if !ti.isNumericType(startType) { + ti.addError("for loop start value must be numeric", stmt.Start) + } + if !ti.isNumericType(endType) { + ti.addError("for loop end value must be numeric", stmt.End) + } + + if stmt.Step != nil { + stepType := ti.inferExpression(stmt.Step) + if !ti.isNumericType(stepType) { + ti.addError("for loop step value must be numeric", stmt.Step) + } + } + + ti.enterScope() + // Define loop variable as number + ti.currentScope.Define(&Symbol{ + Name: stmt.Variable.Value, + Type: ti.numberType, + Declared: true, + }) + stmt.Variable.SetType(ti.numberType) + + for _, s := range stmt.Body { + ti.inferStatement(s) + } + ti.exitScope() +} + +// inferForInStatement handles for-in loops +func (ti *TypeInferrer) inferForInStatement(stmt *ForInStatement) { + iterableType := ti.inferExpression(stmt.Iterable) + + // For now, assume iterable is a table + if !ti.isTableType(iterableType) { + ti.addError("for-in requires an iterable (table)", stmt.Iterable) + } + + ti.enterScope() + + // Define loop variables (key and value are any for now) + if stmt.Key != nil { + ti.currentScope.Define(&Symbol{ + Name: stmt.Key.Value, + Type: ti.anyType, + Declared: true, + }) + stmt.Key.SetType(ti.anyType) + } + + ti.currentScope.Define(&Symbol{ + Name: stmt.Value.Value, + Type: ti.anyType, + Declared: true, + }) + stmt.Value.SetType(ti.anyType) + + for _, s := range stmt.Body { + ti.inferStatement(s) + } + ti.exitScope() +} + +// inferExpression infers the type of an expression +func (ti *TypeInferrer) inferExpression(expr Expression) *TypeInfo { + if expr == nil { + return ti.nilType + } + + switch e := expr.(type) { + case *Identifier: + return ti.inferIdentifier(e) + case *NumberLiteral: + e.SetType(ti.numberType) + return ti.numberType + case *StringLiteral: + e.SetType(ti.stringType) + return ti.stringType + case *BooleanLiteral: + e.SetType(ti.boolType) + return ti.boolType + case *NilLiteral: + e.SetType(ti.nilType) + return ti.nilType + case *TableLiteral: + return ti.inferTableLiteral(e) + case *FunctionLiteral: + return ti.inferFunctionLiteral(e) + case *CallExpression: + return ti.inferCallExpression(e) + case *PrefixExpression: + return ti.inferPrefixExpression(e) + case *InfixExpression: + return ti.inferInfixExpression(e) + case *IndexExpression: + return ti.inferIndexExpression(e) + case *DotExpression: + return ti.inferDotExpression(e) + default: + ti.addError("unknown expression type", expr) + return ti.anyType + } +} + +// inferIdentifier looks up identifier type in symbol table +func (ti *TypeInferrer) inferIdentifier(ident *Identifier) *TypeInfo { + symbol := ti.currentScope.Lookup(ident.Value) + if symbol == nil { + ti.addError(fmt.Sprintf("undefined variable '%s'", ident.Value), ident) + return ti.anyType + } + + ident.SetType(symbol.Type) + return symbol.Type +} + +// inferTableLiteral infers table type +func (ti *TypeInferrer) inferTableLiteral(table *TableLiteral) *TypeInfo { + // Infer types of all values + for _, pair := range table.Pairs { + if pair.Key != nil { + ti.inferExpression(pair.Key) + } + ti.inferExpression(pair.Value) + } + + table.SetType(ti.tableType) + return ti.tableType +} + +// inferFunctionLiteral infers function type +func (ti *TypeInferrer) inferFunctionLiteral(fn *FunctionLiteral) *TypeInfo { + ti.enterScope() + + // Define parameters in function scope + for _, param := range fn.Parameters { + paramType := ti.anyType + if param.TypeHint != nil { + paramType = param.TypeHint + } + + ti.currentScope.Define(&Symbol{ + Name: param.Name, + Type: paramType, + Declared: true, + }) + } + + // Infer body + for _, stmt := range fn.Body { + ti.inferStatement(stmt) + } + + ti.exitScope() + + // For now, all functions have type "function" + funcType := &TypeInfo{Type: TypeFunction, Inferred: true} + fn.SetType(funcType) + return funcType +} + +// inferCallExpression infers function call return type +func (ti *TypeInferrer) inferCallExpression(call *CallExpression) *TypeInfo { + funcType := ti.inferExpression(call.Function) + + if !ti.isFunctionType(funcType) { + ti.addError("cannot call non-function", call.Function) + return ti.anyType + } + + // Infer argument types + for _, arg := range call.Arguments { + ti.inferExpression(arg) + } + + // For now, assume function calls return any + call.SetType(ti.anyType) + return ti.anyType +} + +// inferPrefixExpression infers prefix operation type +func (ti *TypeInferrer) inferPrefixExpression(prefix *PrefixExpression) *TypeInfo { + rightType := ti.inferExpression(prefix.Right) + + var resultType *TypeInfo + switch prefix.Operator { + case "-": + if !ti.isNumericType(rightType) { + ti.addError("unary minus requires numeric operand", prefix) + } + resultType = ti.numberType + case "not": + resultType = ti.boolType + default: + ti.addError(fmt.Sprintf("unknown prefix operator '%s'", prefix.Operator), prefix) + resultType = ti.anyType + } + + prefix.SetType(resultType) + return resultType +} + +// inferInfixExpression infers binary operation type +func (ti *TypeInferrer) inferInfixExpression(infix *InfixExpression) *TypeInfo { + leftType := ti.inferExpression(infix.Left) + rightType := ti.inferExpression(infix.Right) + + var resultType *TypeInfo + + switch infix.Operator { + case "+", "-", "*", "/": + if !ti.isNumericType(leftType) || !ti.isNumericType(rightType) { + ti.addError(fmt.Sprintf("arithmetic operator '%s' requires numeric operands", infix.Operator), infix) + } + resultType = ti.numberType + + case "==", "!=": + // Equality works with any types + resultType = ti.boolType + + case "<", ">", "<=", ">=": + if !ti.isComparableTypes(leftType, rightType) { + ti.addError(fmt.Sprintf("comparison operator '%s' requires compatible operands", infix.Operator), infix) + } + resultType = ti.boolType + + case "and", "or": + ti.validateBooleanContext(leftType, infix.Left) + ti.validateBooleanContext(rightType, infix.Right) + resultType = ti.boolType + + default: + ti.addError(fmt.Sprintf("unknown infix operator '%s'", infix.Operator), infix) + resultType = ti.anyType + } + + infix.SetType(resultType) + return resultType +} + +// inferIndexExpression infers table[index] type +func (ti *TypeInferrer) inferIndexExpression(index *IndexExpression) *TypeInfo { + ti.inferExpression(index.Left) + ti.inferExpression(index.Index) + + // For now, assume table access returns any + index.SetType(ti.anyType) + return ti.anyType +} + +// inferDotExpression infers table.key type +func (ti *TypeInferrer) inferDotExpression(dot *DotExpression) *TypeInfo { + ti.inferExpression(dot.Left) + + // For now, assume member access returns any + dot.SetType(ti.anyType) + return ti.anyType +} + +// Type checking helper methods + +func (ti *TypeInferrer) isTypeCompatible(valueType, targetType *TypeInfo) bool { + if targetType.Type == TypeAny || valueType.Type == TypeAny { + return true + } + return valueType.Type == targetType.Type +} + +func (ti *TypeInferrer) isNumericType(t *TypeInfo) bool { + return t.Type == TypeNumber +} + +func (ti *TypeInferrer) isBooleanType(t *TypeInfo) bool { + return t.Type == TypeBool +} + +func (ti *TypeInferrer) isTableType(t *TypeInfo) bool { + return t.Type == TypeTable +} + +func (ti *TypeInferrer) isFunctionType(t *TypeInfo) bool { + return t.Type == TypeFunction +} + +func (ti *TypeInferrer) isComparableTypes(left, right *TypeInfo) bool { + if left.Type == TypeAny || right.Type == TypeAny { + return true + } + return left.Type == right.Type && (left.Type == TypeNumber || left.Type == TypeString) +} + +func (ti *TypeInferrer) validateBooleanContext(t *TypeInfo, expr Expression) { + // In many languages, non-boolean values can be used in boolean context + // For strictness, we could require boolean type here + // For now, allow any type (truthy/falsy semantics) +} + +// Errors returns all type checking errors +func (ti *TypeInferrer) Errors() []TypeError { + return ti.errors +} + +// HasErrors returns true if there are any type errors +func (ti *TypeInferrer) HasErrors() bool { + return len(ti.errors) > 0 +} + +// ErrorStrings returns error messages as strings +func (ti *TypeInferrer) ErrorStrings() []string { + result := make([]string, len(ti.errors)) + for i, err := range ti.errors { + result[i] = err.Error() + } + return result +} + +// ValidTypeName checks if a string is a valid type name +func ValidTypeName(name string) bool { + validTypes := []string{TypeNumber, TypeString, TypeBool, TypeNil, TypeTable, TypeFunction, TypeAny} + for _, validType := range validTypes { + if name == validType { + return true + } + } + return false +} + +// ParseTypeName converts a string to a TypeInfo (for parsing type hints) +func ParseTypeName(name string) *TypeInfo { + if ValidTypeName(name) { + return &TypeInfo{Type: name, Inferred: false} + } + return nil +}