package parser import ( "fmt" ) // Type constants for built-in types const ( TypeNumber = "number" TypeString = "string" TypeBool = "bool" TypeNil = "nil" TypeTable = "table" TypeFunction = "function" TypeAny = "any" ) // TypeError represents a type checking error type TypeError struct { Message string Line int Column int Node Node } func (te TypeError) Error() string { return fmt.Sprintf("Type error at line %d, column %d: %s", te.Line, te.Column, te.Message) } // Symbol represents a variable in the symbol table type Symbol struct { Name string Type *TypeInfo Declared bool Line int Column int } // Scope represents a scope in the symbol table type Scope struct { symbols map[string]*Symbol parent *Scope } func NewScope(parent *Scope) *Scope { return &Scope{ symbols: make(map[string]*Symbol), parent: parent, } } func (s *Scope) Define(symbol *Symbol) { s.symbols[symbol.Name] = symbol } func (s *Scope) Lookup(name string) *Symbol { if symbol, ok := s.symbols[name]; ok { return symbol } if s.parent != nil { return s.parent.Lookup(name) } return nil } // TypeInferrer performs type inference and checking type TypeInferrer struct { currentScope *Scope globalScope *Scope errors []TypeError // Pre-allocated type objects for performance numberType *TypeInfo stringType *TypeInfo boolType *TypeInfo nilType *TypeInfo tableType *TypeInfo anyType *TypeInfo } // NewTypeInferrer creates a new type inference engine func NewTypeInferrer() *TypeInferrer { globalScope := NewScope(nil) ti := &TypeInferrer{ currentScope: globalScope, globalScope: globalScope, errors: []TypeError{}, // Pre-allocate common types to reduce allocations numberType: &TypeInfo{Type: TypeNumber, Inferred: true}, stringType: &TypeInfo{Type: TypeString, Inferred: true}, boolType: &TypeInfo{Type: TypeBool, Inferred: true}, nilType: &TypeInfo{Type: TypeNil, Inferred: true}, tableType: &TypeInfo{Type: TypeTable, Inferred: true}, anyType: &TypeInfo{Type: TypeAny, Inferred: true}, } return ti } // InferTypes performs type inference on the entire program func (ti *TypeInferrer) InferTypes(program *Program) []TypeError { for _, stmt := range program.Statements { ti.inferStatement(stmt) } return ti.errors } // enterScope creates a new scope func (ti *TypeInferrer) enterScope() { ti.currentScope = NewScope(ti.currentScope) } // exitScope returns to the parent scope func (ti *TypeInferrer) exitScope() { if ti.currentScope.parent != nil { ti.currentScope = ti.currentScope.parent } } // addError adds a type error func (ti *TypeInferrer) addError(message string, node Node) { ti.errors = append(ti.errors, TypeError{ Message: message, Line: 0, // Would need to track position in AST nodes Column: 0, Node: node, }) } // inferStatement infers types for statements func (ti *TypeInferrer) inferStatement(stmt Statement) { switch s := stmt.(type) { case *AssignStatement: ti.inferAssignStatement(s) case *EchoStatement: ti.inferExpression(s.Value) case *IfStatement: ti.inferIfStatement(s) case *WhileStatement: ti.inferWhileStatement(s) case *ForStatement: ti.inferForStatement(s) case *ForInStatement: ti.inferForInStatement(s) case *ReturnStatement: if s.Value != nil { ti.inferExpression(s.Value) } case *ExitStatement: if s.Value != nil { ti.inferExpression(s.Value) } } } // inferAssignStatement handles variable assignments with type checking func (ti *TypeInferrer) inferAssignStatement(stmt *AssignStatement) { // Infer the type of the value expression valueType := ti.inferExpression(stmt.Value) if ident, ok := stmt.Name.(*Identifier); ok { // Simple variable assignment symbol := ti.currentScope.Lookup(ident.Value) if stmt.IsDeclaration { // New variable declaration varType := valueType // If there's a type hint, validate it if stmt.TypeHint != nil { if !ti.isTypeCompatible(valueType, stmt.TypeHint) { ti.addError(fmt.Sprintf("cannot assign %s to variable of type %s", valueType.Type, stmt.TypeHint.Type), stmt) } varType = stmt.TypeHint varType.Inferred = false } // Define the new symbol ti.currentScope.Define(&Symbol{ Name: ident.Value, Type: varType, Declared: true, }) ident.SetType(varType) } else { // Assignment to existing variable if symbol == nil { ti.addError(fmt.Sprintf("undefined variable '%s'", ident.Value), stmt) return } // Check type compatibility if !ti.isTypeCompatible(valueType, symbol.Type) { ti.addError(fmt.Sprintf("cannot assign %s to variable of type %s", valueType.Type, symbol.Type.Type), stmt) } ident.SetType(symbol.Type) } } else { // Member access assignment (table.key or table[index]) ti.inferExpression(stmt.Name) } } // inferIfStatement handles if statements func (ti *TypeInferrer) inferIfStatement(stmt *IfStatement) { condType := ti.inferExpression(stmt.Condition) ti.validateBooleanContext(condType, stmt.Condition) ti.enterScope() for _, s := range stmt.Body { ti.inferStatement(s) } ti.exitScope() for _, elseif := range stmt.ElseIfs { condType := ti.inferExpression(elseif.Condition) ti.validateBooleanContext(condType, elseif.Condition) ti.enterScope() for _, s := range elseif.Body { ti.inferStatement(s) } ti.exitScope() } if len(stmt.Else) > 0 { ti.enterScope() for _, s := range stmt.Else { ti.inferStatement(s) } ti.exitScope() } } // inferWhileStatement handles while loops func (ti *TypeInferrer) inferWhileStatement(stmt *WhileStatement) { condType := ti.inferExpression(stmt.Condition) ti.validateBooleanContext(condType, stmt.Condition) ti.enterScope() for _, s := range stmt.Body { ti.inferStatement(s) } ti.exitScope() } // inferForStatement handles numeric for loops func (ti *TypeInferrer) inferForStatement(stmt *ForStatement) { startType := ti.inferExpression(stmt.Start) endType := ti.inferExpression(stmt.End) if !ti.isNumericType(startType) { ti.addError("for loop start value must be numeric", stmt.Start) } if !ti.isNumericType(endType) { ti.addError("for loop end value must be numeric", stmt.End) } if stmt.Step != nil { stepType := ti.inferExpression(stmt.Step) if !ti.isNumericType(stepType) { ti.addError("for loop step value must be numeric", stmt.Step) } } ti.enterScope() // Define loop variable as number ti.currentScope.Define(&Symbol{ Name: stmt.Variable.Value, Type: ti.numberType, Declared: true, }) stmt.Variable.SetType(ti.numberType) for _, s := range stmt.Body { ti.inferStatement(s) } ti.exitScope() } // inferForInStatement handles for-in loops func (ti *TypeInferrer) inferForInStatement(stmt *ForInStatement) { iterableType := ti.inferExpression(stmt.Iterable) // For now, assume iterable is a table if !ti.isTableType(iterableType) { ti.addError("for-in requires an iterable (table)", stmt.Iterable) } ti.enterScope() // Define loop variables (key and value are any for now) if stmt.Key != nil { ti.currentScope.Define(&Symbol{ Name: stmt.Key.Value, Type: ti.anyType, Declared: true, }) stmt.Key.SetType(ti.anyType) } ti.currentScope.Define(&Symbol{ Name: stmt.Value.Value, Type: ti.anyType, Declared: true, }) stmt.Value.SetType(ti.anyType) for _, s := range stmt.Body { ti.inferStatement(s) } ti.exitScope() } // inferExpression infers the type of an expression func (ti *TypeInferrer) inferExpression(expr Expression) *TypeInfo { if expr == nil { return ti.nilType } switch e := expr.(type) { case *Identifier: return ti.inferIdentifier(e) case *NumberLiteral: e.SetType(ti.numberType) return ti.numberType case *StringLiteral: e.SetType(ti.stringType) return ti.stringType case *BooleanLiteral: e.SetType(ti.boolType) return ti.boolType case *NilLiteral: e.SetType(ti.nilType) return ti.nilType case *TableLiteral: return ti.inferTableLiteral(e) case *FunctionLiteral: return ti.inferFunctionLiteral(e) case *CallExpression: return ti.inferCallExpression(e) case *PrefixExpression: return ti.inferPrefixExpression(e) case *InfixExpression: return ti.inferInfixExpression(e) case *IndexExpression: return ti.inferIndexExpression(e) case *DotExpression: return ti.inferDotExpression(e) default: ti.addError("unknown expression type", expr) return ti.anyType } } // inferIdentifier looks up identifier type in symbol table func (ti *TypeInferrer) inferIdentifier(ident *Identifier) *TypeInfo { symbol := ti.currentScope.Lookup(ident.Value) if symbol == nil { ti.addError(fmt.Sprintf("undefined variable '%s'", ident.Value), ident) return ti.anyType } ident.SetType(symbol.Type) return symbol.Type } // inferTableLiteral infers table type func (ti *TypeInferrer) inferTableLiteral(table *TableLiteral) *TypeInfo { // Infer types of all values for _, pair := range table.Pairs { if pair.Key != nil { ti.inferExpression(pair.Key) } ti.inferExpression(pair.Value) } table.SetType(ti.tableType) return ti.tableType } // inferFunctionLiteral infers function type func (ti *TypeInferrer) inferFunctionLiteral(fn *FunctionLiteral) *TypeInfo { ti.enterScope() // Define parameters in function scope for _, param := range fn.Parameters { paramType := ti.anyType if param.TypeHint != nil { paramType = param.TypeHint } ti.currentScope.Define(&Symbol{ Name: param.Name, Type: paramType, Declared: true, }) } // Infer body for _, stmt := range fn.Body { ti.inferStatement(stmt) } ti.exitScope() // For now, all functions have type "function" funcType := &TypeInfo{Type: TypeFunction, Inferred: true} fn.SetType(funcType) return funcType } // inferCallExpression infers function call return type func (ti *TypeInferrer) inferCallExpression(call *CallExpression) *TypeInfo { funcType := ti.inferExpression(call.Function) if !ti.isFunctionType(funcType) { ti.addError("cannot call non-function", call.Function) return ti.anyType } // Infer argument types for _, arg := range call.Arguments { ti.inferExpression(arg) } // For now, assume function calls return any call.SetType(ti.anyType) return ti.anyType } // inferPrefixExpression infers prefix operation type func (ti *TypeInferrer) inferPrefixExpression(prefix *PrefixExpression) *TypeInfo { rightType := ti.inferExpression(prefix.Right) var resultType *TypeInfo switch prefix.Operator { case "-": if !ti.isNumericType(rightType) { ti.addError("unary minus requires numeric operand", prefix) } resultType = ti.numberType case "not": resultType = ti.boolType default: ti.addError(fmt.Sprintf("unknown prefix operator '%s'", prefix.Operator), prefix) resultType = ti.anyType } prefix.SetType(resultType) return resultType } // inferInfixExpression infers binary operation type func (ti *TypeInferrer) inferInfixExpression(infix *InfixExpression) *TypeInfo { leftType := ti.inferExpression(infix.Left) rightType := ti.inferExpression(infix.Right) var resultType *TypeInfo switch infix.Operator { case "+", "-", "*", "/": if !ti.isNumericType(leftType) || !ti.isNumericType(rightType) { ti.addError(fmt.Sprintf("arithmetic operator '%s' requires numeric operands", infix.Operator), infix) } resultType = ti.numberType case "==", "!=": // Equality works with any types resultType = ti.boolType case "<", ">", "<=", ">=": if !ti.isComparableTypes(leftType, rightType) { ti.addError(fmt.Sprintf("comparison operator '%s' requires compatible operands", infix.Operator), infix) } resultType = ti.boolType case "and", "or": ti.validateBooleanContext(leftType, infix.Left) ti.validateBooleanContext(rightType, infix.Right) resultType = ti.boolType default: ti.addError(fmt.Sprintf("unknown infix operator '%s'", infix.Operator), infix) resultType = ti.anyType } infix.SetType(resultType) return resultType } // inferIndexExpression infers table[index] type func (ti *TypeInferrer) inferIndexExpression(index *IndexExpression) *TypeInfo { ti.inferExpression(index.Left) ti.inferExpression(index.Index) // For now, assume table access returns any index.SetType(ti.anyType) return ti.anyType } // inferDotExpression infers table.key type func (ti *TypeInferrer) inferDotExpression(dot *DotExpression) *TypeInfo { ti.inferExpression(dot.Left) // For now, assume member access returns any dot.SetType(ti.anyType) return ti.anyType } // Type checking helper methods func (ti *TypeInferrer) isTypeCompatible(valueType, targetType *TypeInfo) bool { if targetType.Type == TypeAny || valueType.Type == TypeAny { return true } return valueType.Type == targetType.Type } func (ti *TypeInferrer) isNumericType(t *TypeInfo) bool { return t.Type == TypeNumber } func (ti *TypeInferrer) isBooleanType(t *TypeInfo) bool { return t.Type == TypeBool } func (ti *TypeInferrer) isTableType(t *TypeInfo) bool { return t.Type == TypeTable } func (ti *TypeInferrer) isFunctionType(t *TypeInfo) bool { return t.Type == TypeFunction } func (ti *TypeInferrer) isComparableTypes(left, right *TypeInfo) bool { if left.Type == TypeAny || right.Type == TypeAny { return true } return left.Type == right.Type && (left.Type == TypeNumber || left.Type == TypeString) } func (ti *TypeInferrer) validateBooleanContext(t *TypeInfo, expr Expression) { // In many languages, non-boolean values can be used in boolean context // For strictness, we could require boolean type here // For now, allow any type (truthy/falsy semantics) } // Errors returns all type checking errors func (ti *TypeInferrer) Errors() []TypeError { return ti.errors } // HasErrors returns true if there are any type errors func (ti *TypeInferrer) HasErrors() bool { return len(ti.errors) > 0 } // ErrorStrings returns error messages as strings func (ti *TypeInferrer) ErrorStrings() []string { result := make([]string, len(ti.errors)) for i, err := range ti.errors { result[i] = err.Error() } return result } // ValidTypeName checks if a string is a valid type name func ValidTypeName(name string) bool { validTypes := []string{TypeNumber, TypeString, TypeBool, TypeNil, TypeTable, TypeFunction, TypeAny} for _, validType := range validTypes { if name == validType { return true } } return false } // ParseTypeName converts a string to a TypeInfo (for parsing type hints) func ParseTypeName(name string) *TypeInfo { if ValidTypeName(name) { return &TypeInfo{Type: name, Inferred: false} } return nil }