Mako/parser/types.go
2025-06-11 10:20:38 -05:00

592 lines
14 KiB
Go

package parser
import (
"fmt"
)
// Names of the built-in types. These are the string values stored in
// TypeInfo.Type and the only names accepted by ValidTypeName.
const (
	TypeNumber   = "number"
	TypeString   = "string"
	TypeBool     = "bool"
	TypeNil      = "nil"
	TypeTable    = "table"
	TypeFunction = "function"
	TypeAny      = "any" // treated as compatible with every type by isTypeCompatible
)
// TypeError describes a single problem found during type checking.
type TypeError struct {
	Message      string // human-readable description of the problem
	Line, Column int    // source position printed by Error
	Node         Node   // AST node the problem was reported on
}
// Error implements the error interface. The result contains the recorded
// line and column followed by the message.
func (te TypeError) Error() string {
	const layout = "Type error at line %d, column %d: %s"
	return fmt.Sprintf(layout, te.Line, te.Column, te.Message)
}
// Symbol is one entry of the symbol table: a named variable together with
// the type the inferrer has associated with it.
type Symbol struct {
	Name         string    // identifier the symbol is stored under in its Scope
	Type         *TypeInfo // declared or inferred type of the variable
	Declared     bool      // set to true by every definition made in this file
	Line, Column int       // source position; not filled in by the inferrer in this file
}
// Scope represents a scope in the symbol table.
// Scopes are chained through parent; Lookup searches the chain outward.
type Scope struct {
	symbols map[string]*Symbol // symbols defined directly in this scope, keyed by name
	parent  *Scope             // enclosing scope, or nil for the outermost scope
}
// NewScope returns an empty scope nested inside parent. Pass nil to create
// an outermost scope.
func NewScope(parent *Scope) *Scope {
	scope := new(Scope)
	scope.parent = parent
	scope.symbols = make(map[string]*Symbol)
	return scope
}
// Define stores symbol in this scope under symbol.Name, replacing any symbol
// previously defined with the same name in this scope.
func (s *Scope) Define(symbol *Symbol) {
	name := symbol.Name
	s.symbols[name] = symbol
}
// Lookup resolves name starting in this scope and continuing through the
// enclosing scopes. It returns the innermost matching symbol, or nil when no
// scope in the chain defines name.
func (s *Scope) Lookup(name string) *Symbol {
	scope := s
	for {
		if found, ok := scope.symbols[name]; ok {
			return found
		}
		if scope.parent == nil {
			return nil
		}
		scope = scope.parent
	}
}
// TypeInferrer performs type inference and checking over a parsed program.
type TypeInferrer struct {
	// currentScope is the innermost scope being analyzed; globalScope is the
	// root scope created by NewTypeInferrer.
	currentScope, globalScope *Scope
	// errors accumulates every problem reported through addError.
	errors []TypeError

	// Shared TypeInfo values for the built-in types. They are allocated once
	// in NewTypeInferrer and reused to avoid an allocation per node.
	numberType, stringType, boolType, nilType, tableType, anyType *TypeInfo
}
// NewTypeInferrer creates a type inference engine whose current scope is a
// fresh global scope and whose error list is empty.
func NewTypeInferrer() *TypeInferrer {
	// builtin allocates the shared TypeInfo for one built-in type name.
	builtin := func(name string) *TypeInfo {
		return &TypeInfo{Type: name, Inferred: true}
	}

	root := NewScope(nil)
	return &TypeInferrer{
		currentScope: root,
		globalScope:  root,
		errors:       []TypeError{},
		numberType:   builtin(TypeNumber),
		stringType:   builtin(TypeString),
		boolType:     builtin(TypeBool),
		nilType:      builtin(TypeNil),
		tableType:    builtin(TypeTable),
		anyType:      builtin(TypeAny),
	}
}
// InferTypes analyzes every top-level statement of program in order and
// returns all type errors collected by this inferrer so far.
func (ti *TypeInferrer) InferTypes(program *Program) []TypeError {
	statements := program.Statements
	for i := range statements {
		ti.inferStatement(statements[i])
	}
	return ti.errors
}
// enterScope pushes a new scope nested inside the current one and makes it
// the current scope.
func (ti *TypeInferrer) enterScope() {
	child := NewScope(ti.currentScope)
	ti.currentScope = child
}
// exitScope makes the parent of the current scope current again. Calling it
// while the outermost scope is current does nothing.
func (ti *TypeInferrer) exitScope() {
	enclosing := ti.currentScope.parent
	if enclosing == nil {
		return
	}
	ti.currentScope = enclosing
}
// addError records a type error with the given message for node.
//
// Line and Column are left at zero: filling them in would require position
// information on the AST nodes.
func (ti *TypeInferrer) addError(message string, node Node) {
	typeErr := TypeError{Message: message, Node: node}
	ti.errors = append(ti.errors, typeErr)
}
// inferStatement dispatches a statement to the matching inference routine.
// Statement kinds that are not listed here are ignored.
func (ti *TypeInferrer) inferStatement(stmt Statement) {
	switch node := stmt.(type) {
	// Statements with dedicated handlers.
	case *AssignStatement:
		ti.inferAssignStatement(node)
	case *IfStatement:
		ti.inferIfStatement(node)
	case *WhileStatement:
		ti.inferWhileStatement(node)
	case *ForStatement:
		ti.inferForStatement(node)
	case *ForInStatement:
		ti.inferForInStatement(node)

	// Statements that only wrap a single, possibly absent, expression.
	case *EchoStatement:
		ti.inferExpression(node.Value)
	case *ReturnStatement:
		if node.Value == nil {
			return
		}
		ti.inferExpression(node.Value)
	case *ExitStatement:
		if node.Value == nil {
			return
		}
		ti.inferExpression(node.Value)
	}
}
// inferAssignStatement type-checks an assignment.
//
// The value expression is always inferred first. Then, depending on the
// target:
//   - a target that is not a plain identifier (table.key or table[index]) is
//     only inferred as an expression;
//   - a plain assignment requires the variable to exist and the value to be
//     compatible with its type;
//   - a declaration defines a new symbol in the current scope, typed by the
//     type hint when one is present and by the value otherwise.
func (ti *TypeInferrer) inferAssignStatement(stmt *AssignStatement) {
	valueType := ti.inferExpression(stmt.Value)

	target, isIdent := stmt.Name.(*Identifier)
	if !isIdent {
		// Member access assignment (table.key or table[index]).
		ti.inferExpression(stmt.Name)
		return
	}

	if !stmt.IsDeclaration {
		// Assignment to an existing variable.
		existing := ti.currentScope.Lookup(target.Value)
		if existing == nil {
			ti.addError(fmt.Sprintf("undefined variable '%s'", target.Value), stmt)
			return
		}
		if !ti.isTypeCompatible(valueType, existing.Type) {
			ti.addError(fmt.Sprintf("cannot assign %s to variable of type %s",
				valueType.Type, existing.Type.Type), stmt)
		}
		target.SetType(existing.Type)
		return
	}

	// New variable declaration. An explicit hint wins over the inferred type
	// and is marked as not inferred; a mismatch is reported but the hint is
	// still used.
	declaredType := valueType
	if hint := stmt.TypeHint; hint != nil {
		if !ti.isTypeCompatible(valueType, hint) {
			ti.addError(fmt.Sprintf("cannot assign %s to variable of type %s",
				valueType.Type, hint.Type), stmt)
		}
		hint.Inferred = false
		declaredType = hint
	}

	ti.currentScope.Define(&Symbol{
		Name:     target.Value,
		Type:     declaredType,
		Declared: true,
	})
	target.SetType(declaredType)
}
// inferIfStatement type-checks an if statement: each condition is inferred
// and passed to validateBooleanContext, and the if body, every elseif body
// and a non-empty else body are each analyzed in their own scope.
func (ti *TypeInferrer) inferIfStatement(stmt *IfStatement) {
	// inScope runs analyze between enterScope and exitScope.
	inScope := func(analyze func()) {
		ti.enterScope()
		analyze()
		ti.exitScope()
	}

	ti.validateBooleanContext(ti.inferExpression(stmt.Condition), stmt.Condition)
	inScope(func() {
		for _, bodyStmt := range stmt.Body {
			ti.inferStatement(bodyStmt)
		}
	})

	for _, branch := range stmt.ElseIfs {
		ti.validateBooleanContext(ti.inferExpression(branch.Condition), branch.Condition)
		inScope(func() {
			for _, bodyStmt := range branch.Body {
				ti.inferStatement(bodyStmt)
			}
		})
	}

	if len(stmt.Else) == 0 {
		return
	}
	inScope(func() {
		for _, bodyStmt := range stmt.Else {
			ti.inferStatement(bodyStmt)
		}
	})
}
// inferWhileStatement type-checks a while loop: the condition is inferred
// and passed to validateBooleanContext, then the body is analyzed in its own
// scope.
func (ti *TypeInferrer) inferWhileStatement(stmt *WhileStatement) {
	ti.validateBooleanContext(ti.inferExpression(stmt.Condition), stmt.Condition)

	ti.enterScope()
	for _, bodyStmt := range stmt.Body {
		ti.inferStatement(bodyStmt)
	}
	ti.exitScope()
}
// inferForStatement type-checks a numeric for loop.
//
// The start, end and optional step expressions must be numbers. Values of
// type any are accepted too: index, member and call expressions as well as
// untyped parameters are all inferred as any, so rejecting them would flag
// every loop bound that is not a literal or a typed variable.
//
// The loop variable is defined as a number in a new scope that also holds
// the loop body.
func (ti *TypeInferrer) inferForStatement(stmt *ForStatement) {
	startType := ti.inferExpression(stmt.Start)
	endType := ti.inferExpression(stmt.End)

	// isTypeCompatible treats any as a wildcard, so only bounds that are
	// statically known to be non-numeric are reported.
	if !ti.isTypeCompatible(startType, ti.numberType) {
		ti.addError("for loop start value must be numeric", stmt.Start)
	}
	if !ti.isTypeCompatible(endType, ti.numberType) {
		ti.addError("for loop end value must be numeric", stmt.End)
	}
	if stmt.Step != nil {
		stepType := ti.inferExpression(stmt.Step)
		if !ti.isTypeCompatible(stepType, ti.numberType) {
			ti.addError("for loop step value must be numeric", stmt.Step)
		}
	}

	ti.enterScope()
	// Define loop variable as number.
	ti.currentScope.Define(&Symbol{
		Name:     stmt.Variable.Value,
		Type:     ti.numberType,
		Declared: true,
	})
	stmt.Variable.SetType(ti.numberType)
	for _, s := range stmt.Body {
		ti.inferStatement(s)
	}
	ti.exitScope()
}
// inferForInStatement type-checks a for-in loop.
//
// The iterable must be a table. A value of type any is accepted as well,
// because table fields, indexed values and call results are all inferred as
// any and their concrete type is only known at run time.
//
// The optional key variable and the value variable are defined with type any
// in a new scope that also holds the loop body.
func (ti *TypeInferrer) inferForInStatement(stmt *ForInStatement) {
	iterableType := ti.inferExpression(stmt.Iterable)
	// isTypeCompatible treats any as a wildcard, so only iterables that are
	// statically known not to be tables are reported.
	if !ti.isTypeCompatible(iterableType, ti.tableType) {
		ti.addError("for-in requires an iterable (table)", stmt.Iterable)
	}

	ti.enterScope()
	// Define loop variables (key and value are any for now).
	if stmt.Key != nil {
		ti.currentScope.Define(&Symbol{
			Name:     stmt.Key.Value,
			Type:     ti.anyType,
			Declared: true,
		})
		stmt.Key.SetType(ti.anyType)
	}
	ti.currentScope.Define(&Symbol{
		Name:     stmt.Value.Value,
		Type:     ti.anyType,
		Declared: true,
	})
	stmt.Value.SetType(ti.anyType)
	for _, s := range stmt.Body {
		ti.inferStatement(s)
	}
	ti.exitScope()
}
// inferExpression infers and returns the type of expr.
//
// A nil expression yields the nil type. Literals are annotated directly with
// the shared built-in types; every other known expression kind is delegated
// to its own routine. An expression of an unknown kind is reported as an
// error and typed as any.
func (ti *TypeInferrer) inferExpression(expr Expression) *TypeInfo {
	if expr == nil {
		return ti.nilType
	}

	var inferred *TypeInfo
	switch node := expr.(type) {
	// Literals with a fixed built-in type.
	case *NumberLiteral:
		inferred = ti.numberType
		node.SetType(inferred)
	case *StringLiteral:
		inferred = ti.stringType
		node.SetType(inferred)
	case *BooleanLiteral:
		inferred = ti.boolType
		node.SetType(inferred)
	case *NilLiteral:
		inferred = ti.nilType
		node.SetType(inferred)

	// Expressions handled by a dedicated routine.
	case *Identifier:
		inferred = ti.inferIdentifier(node)
	case *TableLiteral:
		inferred = ti.inferTableLiteral(node)
	case *FunctionLiteral:
		inferred = ti.inferFunctionLiteral(node)
	case *CallExpression:
		inferred = ti.inferCallExpression(node)
	case *PrefixExpression:
		inferred = ti.inferPrefixExpression(node)
	case *InfixExpression:
		inferred = ti.inferInfixExpression(node)
	case *IndexExpression:
		inferred = ti.inferIndexExpression(node)
	case *DotExpression:
		inferred = ti.inferDotExpression(node)

	default:
		ti.addError("unknown expression type", expr)
		inferred = ti.anyType
	}
	return inferred
}
// inferIdentifier resolves ident through the current scope chain. A known
// identifier is annotated with, and evaluates to, its symbol's type. An
// unknown identifier is reported as an error and typed as any.
func (ti *TypeInferrer) inferIdentifier(ident *Identifier) *TypeInfo {
	if symbol := ti.currentScope.Lookup(ident.Value); symbol != nil {
		ident.SetType(symbol.Type)
		return symbol.Type
	}

	ti.addError(fmt.Sprintf("undefined variable '%s'", ident.Value), ident)
	return ti.anyType
}
// inferTableLiteral infers every key (when present) and value of a table
// literal, then annotates the literal with the shared table type.
func (ti *TypeInferrer) inferTableLiteral(table *TableLiteral) *TypeInfo {
	for _, entry := range table.Pairs {
		if entry.Key != nil {
			ti.inferExpression(entry.Key)
		}
		ti.inferExpression(entry.Value)
	}

	result := ti.tableType
	table.SetType(result)
	return result
}
// inferFunctionLiteral analyzes a function literal in its own scope.
//
// Each parameter is defined in that scope, typed by its hint or as any when
// it has none, and the body statements are then analyzed. The literal itself
// is annotated with a newly allocated "function" type; parameter and return
// types are not part of it for now.
func (ti *TypeInferrer) inferFunctionLiteral(fn *FunctionLiteral) *TypeInfo {
	ti.enterScope()
	for _, param := range fn.Parameters {
		symbol := &Symbol{Name: param.Name, Type: ti.anyType, Declared: true}
		if param.TypeHint != nil {
			symbol.Type = param.TypeHint
		}
		ti.currentScope.Define(symbol)
	}
	for _, bodyStmt := range fn.Body {
		ti.inferStatement(bodyStmt)
	}
	ti.exitScope()

	result := &TypeInfo{Type: TypeFunction, Inferred: true}
	fn.SetType(result)
	return result
}
// inferCallExpression type-checks a call and returns the type of its result.
//
// The callee must be a function. A callee of type any is accepted as well:
// member and index expressions, call results and untyped parameters are all
// inferred as any, so calls such as t.method() can only be checked at run
// time. A callee that is statically known not to be a function is reported.
//
// The arguments are always analyzed, even when the callee is rejected, so
// that errors inside them are reported and their nodes are annotated. Every
// call is typed as any for now.
func (ti *TypeInferrer) inferCallExpression(call *CallExpression) *TypeInfo {
	calleeType := ti.inferExpression(call.Function)
	if calleeType.Type != TypeAny && !ti.isFunctionType(calleeType) {
		ti.addError("cannot call non-function", call.Function)
	}

	for _, arg := range call.Arguments {
		ti.inferExpression(arg)
	}

	// For now, assume function calls return any.
	call.SetType(ti.anyType)
	return ti.anyType
}
// inferPrefixExpression type-checks a prefix operation and returns its type.
//
//   - "-" requires a numeric operand and yields a number. An operand of type
//     any is accepted because its concrete type is only known at run time
//     (table fields, call results and untyped parameters are all any).
//   - "not" accepts any operand and yields a bool.
//   - Any other operator is reported as an error and typed as any.
func (ti *TypeInferrer) inferPrefixExpression(prefix *PrefixExpression) *TypeInfo {
	rightType := ti.inferExpression(prefix.Right)

	var resultType *TypeInfo
	switch prefix.Operator {
	case "-":
		// isTypeCompatible treats any as a wildcard, so only operands that
		// are statically known to be non-numeric are reported.
		if !ti.isTypeCompatible(rightType, ti.numberType) {
			ti.addError("unary minus requires numeric operand", prefix)
		}
		resultType = ti.numberType
	case "not":
		resultType = ti.boolType
	default:
		ti.addError(fmt.Sprintf("unknown prefix operator '%s'", prefix.Operator), prefix)
		resultType = ti.anyType
	}

	prefix.SetType(resultType)
	return resultType
}
// inferInfixExpression type-checks a binary operation and returns its type.
//
//   - "+", "-", "*", "/" require numeric operands and yield a number.
//     Operands of type any are accepted, matching the comparison operators:
//     table fields, call results and untyped parameters are all inferred as
//     any and can only be checked at run time.
//   - "==", "!=" accept any operands and yield a bool.
//   - "<", ">", "<=", ">=" require operands accepted by isComparableTypes and
//     yield a bool.
//   - "and", "or" pass both operands to validateBooleanContext and yield a
//     bool.
//   - Any other operator is reported as an error and typed as any.
func (ti *TypeInferrer) inferInfixExpression(infix *InfixExpression) *TypeInfo {
	leftType := ti.inferExpression(infix.Left)
	rightType := ti.inferExpression(infix.Right)

	var resultType *TypeInfo
	switch infix.Operator {
	case "+", "-", "*", "/":
		// isTypeCompatible treats any as a wildcard, so only operands that
		// are statically known to be non-numeric are reported.
		if !ti.isTypeCompatible(leftType, ti.numberType) || !ti.isTypeCompatible(rightType, ti.numberType) {
			ti.addError(fmt.Sprintf("arithmetic operator '%s' requires numeric operands", infix.Operator), infix)
		}
		resultType = ti.numberType
	case "==", "!=":
		// Equality works with any types.
		resultType = ti.boolType
	case "<", ">", "<=", ">=":
		if !ti.isComparableTypes(leftType, rightType) {
			ti.addError(fmt.Sprintf("comparison operator '%s' requires compatible operands", infix.Operator), infix)
		}
		resultType = ti.boolType
	case "and", "or":
		ti.validateBooleanContext(leftType, infix.Left)
		ti.validateBooleanContext(rightType, infix.Right)
		resultType = ti.boolType
	default:
		ti.addError(fmt.Sprintf("unknown infix operator '%s'", infix.Operator), infix)
		resultType = ti.anyType
	}

	infix.SetType(resultType)
	return resultType
}
// inferIndexExpression analyzes both operands of table[index]. The element
// type is not tracked for now, so the expression itself is typed as any.
func (ti *TypeInferrer) inferIndexExpression(index *IndexExpression) *TypeInfo {
	ti.inferExpression(index.Left)
	ti.inferExpression(index.Index)

	result := ti.anyType
	index.SetType(result)
	return result
}
// inferDotExpression analyzes the left-hand side of table.key. The member
// type is not tracked for now, so the expression itself is typed as any.
func (ti *TypeInferrer) inferDotExpression(dot *DotExpression) *TypeInfo {
	ti.inferExpression(dot.Left)

	result := ti.anyType
	dot.SetType(result)
	return result
}
// Type checking helper methods

// isTypeCompatible reports whether a value of valueType may be stored in a
// target of targetType. The any type is compatible in both directions; all
// other types must have the same name.
func (ti *TypeInferrer) isTypeCompatible(valueType, targetType *TypeInfo) bool {
	switch {
	case targetType.Type == TypeAny, valueType.Type == TypeAny:
		return true
	default:
		return valueType.Type == targetType.Type
	}
}
// isNumericType reports whether t is exactly the number type. The any type
// does not count as numeric here.
func (ti *TypeInferrer) isNumericType(t *TypeInfo) bool {
	kind := t.Type
	return kind == TypeNumber
}
// isBooleanType reports whether t is exactly the bool type. The any type
// does not count as boolean here.
func (ti *TypeInferrer) isBooleanType(t *TypeInfo) bool {
	kind := t.Type
	return kind == TypeBool
}
// isTableType reports whether t is exactly the table type. The any type does
// not count as a table here.
func (ti *TypeInferrer) isTableType(t *TypeInfo) bool {
	kind := t.Type
	return kind == TypeTable
}
// isFunctionType reports whether t is exactly the function type. The any
// type does not count as a function here.
func (ti *TypeInferrer) isFunctionType(t *TypeInfo) bool {
	kind := t.Type
	return kind == TypeFunction
}
// isComparableTypes reports whether left and right may be used with the
// ordering operators. That is the case when either side is any, or when both
// sides are numbers or both sides are strings.
func (ti *TypeInferrer) isComparableTypes(left, right *TypeInfo) bool {
	if left.Type == TypeAny || right.Type == TypeAny {
		return true
	}
	if left.Type != right.Type {
		return false
	}
	switch left.Type {
	case TypeNumber, TypeString:
		return true
	}
	return false
}
// validateBooleanContext is the hook for checking an expression that is used
// as a condition or as an operand of "and"/"or". It is currently a no-op:
// t and expr are ignored and no error is ever reported.
func (ti *TypeInferrer) validateBooleanContext(t *TypeInfo, expr Expression) {
	// In many languages, non-boolean values can be used in boolean context
	// For strictness, we could require boolean type here
	// For now, allow any type (truthy/falsy semantics)
}
// Errors returns the type errors collected so far, in the order they were
// reported. The returned slice is the inferrer's own slice, not a copy.
func (ti *TypeInferrer) Errors() []TypeError {
	collected := ti.errors
	return collected
}
// HasErrors reports whether at least one type error has been collected.
func (ti *TypeInferrer) HasErrors() bool {
	return len(ti.errors) != 0
}
// ErrorStrings returns the formatted message of every collected type error,
// in reporting order. The result is an empty, non-nil slice when there are
// no errors.
func (ti *TypeInferrer) ErrorStrings() []string {
	messages := make([]string, 0, len(ti.errors))
	for i := range ti.errors {
		messages = append(messages, ti.errors[i].Error())
	}
	return messages
}
// ValidTypeName reports whether name is one of the built-in type names.
// The comparison is exact and case-sensitive.
func ValidTypeName(name string) bool {
	switch name {
	case TypeNumber, TypeString, TypeBool, TypeNil, TypeTable, TypeFunction, TypeAny:
		return true
	}
	return false
}
// ParseTypeName converts a type name written in a type hint into a TypeInfo.
// The result is marked as not inferred. It returns nil when name is not a
// valid built-in type name.
func ParseTypeName(name string) *TypeInfo {
	if !ValidTypeName(name) {
		return nil
	}
	return &TypeInfo{Type: name, Inferred: false}
}