Mako/parser/types.go
2025-06-11 23:08:34 -05:00

726 lines
18 KiB
Go

package parser
import "fmt"
// Type represents built-in and user-defined types using compact enum representation.
// Uses single byte instead of string pointers to minimize memory usage.
type Type uint8
// Enumerators of Type, numbered from zero by iota in declaration order.
const (
TypeUnknown Type = iota // zero value of Type; the inferrer treats a type hint of this kind as "no hint given"
TypeNumber
TypeString
TypeBool
TypeNil
TypeTable
TypeFunction
TypeAny // treated as compatible with every other type by isTypeCompatible and isComparableTypes
TypeStruct // struct types use StructID field for identification
)
// TypeInfo represents type information with zero-allocation design.
// Embeds directly in AST nodes instead of using pointers to reduce heap pressure.
// Common types are pre-allocated as globals to eliminate most allocations.
// All fields are comparable, so two TypeInfo values can be compared with ==.
type TypeInfo struct {
Type Type // Built-in type or TypeStruct for user types
Inferred bool // True if type was inferred, false if explicitly declared
StructID uint16 // Index into global struct table for struct types (0 for non-structs)
}
// Pre-allocated common types - eliminates heap allocations for built-in types.
// Each value is marked Inferred and leaves StructID at zero. They are package
// variables rather than constants, so code should treat them as read-only.
var (
UnknownType = TypeInfo{Type: TypeUnknown, Inferred: true}
NumberType = TypeInfo{Type: TypeNumber, Inferred: true}
StringType = TypeInfo{Type: TypeString, Inferred: true}
BoolType = TypeInfo{Type: TypeBool, Inferred: true}
NilType = TypeInfo{Type: TypeNil, Inferred: true}
TableType = TypeInfo{Type: TypeTable, Inferred: true}
FunctionType = TypeInfo{Type: TypeFunction, Inferred: true}
AnyType = TypeInfo{Type: TypeAny, Inferred: true}
)
// TypeError represents a type checking error with location information.
// It satisfies the error interface through its Error method.
type TypeError struct {
Message string // human-readable description of the problem
Line int // line taken from the offending node's Pos() when recorded by addError
Column int // column taken from the offending node's Pos() when recorded by addError
Node Node // AST node the error was reported against
}
// Error implements the error interface, rendering the message prefixed with
// the line and column at which the type error was detected.
func (te TypeError) Error() string {
	const layout = "Type error at line %d, column %d: %s"
	return fmt.Sprintf(layout, te.Line, te.Column, te.Message)
}
// Symbol represents a variable in the symbol table with optimized type storage
type Symbol struct {
Name string // identifier text; used as the key when the symbol is stored in a Scope
Type TypeInfo // Embed directly instead of pointer
Declared bool // set to true by every Define call site in this section of the file
Line int // NOTE(review): not populated by any Define call visible here - confirm whether it is used
Column int // NOTE(review): not populated by any Define call visible here - confirm whether it is used
}
// Scope represents a scope in the symbol table.
// Scopes form a chain through parent; Lookup walks that chain outward.
type Scope struct {
symbols map[string]*Symbol // symbols defined directly in this scope, keyed by name
parent *Scope // enclosing scope, or nil for the outermost scope
}
// NewScope returns an empty scope nested inside parent. Pass nil to create
// an outermost scope with no enclosing scope.
func NewScope(parent *Scope) *Scope {
	scope := new(Scope)
	scope.parent = parent
	scope.symbols = make(map[string]*Symbol)
	return scope
}
// Define stores symbol in this scope under its Name, replacing any symbol
// previously defined with the same name in this scope.
func (s *Scope) Define(symbol *Symbol) {
	name := symbol.Name
	s.symbols[name] = symbol
}
// Lookup resolves name starting in this scope and moving outward through the
// enclosing scopes. It returns the innermost matching symbol, or nil when no
// scope in the chain defines the name.
func (s *Scope) Lookup(name string) *Symbol {
	scope := s
	for {
		if found, ok := scope.symbols[name]; ok {
			return found
		}
		if scope.parent == nil {
			return nil
		}
		scope = scope.parent
	}
}
// TypeInferrer performs type inference and checking with optimized allocations
type TypeInferrer struct {
currentScope *Scope // scope in which names are currently defined and resolved
globalScope *Scope // outermost scope created by NewTypeInferrer
errors []TypeError // type errors accumulated so far, in the order they were reported
// Struct definitions with ID mapping
structs map[string]*StructStatement // registered struct definitions keyed by struct name
structIDs map[uint16]*StructStatement // registered struct definitions keyed by assigned ID
nextID uint16 // ID that the next RegisterStruct call will hand out
}
// NewTypeInferrer builds a type inference engine with a fresh global scope,
// no recorded errors and an empty struct registry.
func NewTypeInferrer() *TypeInferrer {
	root := NewScope(nil)
	inferrer := &TypeInferrer{
		errors:    []TypeError{},
		structs:   map[string]*StructStatement{},
		structIDs: map[uint16]*StructStatement{},
		nextID:    1, // ID 0 is reserved to mean "not a struct"
	}
	// Inference starts in the global scope.
	inferrer.globalScope = root
	inferrer.currentScope = root
	return inferrer
}
// RegisterStruct gives stmt the next free struct ID (written to stmt.ID) and
// records the definition under both its name and that ID.
func (ti *TypeInferrer) RegisterStruct(stmt *StructStatement) {
	id := ti.nextID
	stmt.ID = id
	ti.nextID = id + 1
	ti.structs[stmt.Name] = stmt
	ti.structIDs[id] = stmt
}
// GetStructByID looks up a registered struct definition by its ID and
// returns nil when no struct has been registered under that ID.
func (ti *TypeInferrer) GetStructByID(id uint16) *StructStatement {
	def, known := ti.structIDs[id]
	if !known {
		return nil
	}
	return def
}
// CreateStructType returns the TypeInfo describing the struct with the given
// ID. The result is always marked as inferred; the ID is not validated.
func (ti *TypeInferrer) CreateStructType(structID uint16) TypeInfo {
	info := TypeInfo{Type: TypeStruct, Inferred: true}
	info.StructID = structID
	return info
}
// InferTypes runs type inference over every top-level statement of program
// and returns all type errors recorded on this inferrer so far.
func (ti *TypeInferrer) InferTypes(program *Program) []TypeError {
	statements := program.Statements

	// Pass 1: register every struct definition up front so that later
	// statements can refer to structs regardless of declaration order.
	for i := range statements {
		decl, isStruct := statements[i].(*StructStatement)
		if !isStruct {
			continue
		}
		ti.RegisterStruct(decl)
	}

	// Pass 2: infer and check each statement in source order.
	for i := range statements {
		ti.inferStatement(statements[i])
	}
	return ti.errors
}
// inferStatement dispatches stmt to the inference routine matching its
// concrete node type. Statement kinds not listed here are silently ignored.
func (ti *TypeInferrer) inferStatement(stmt Statement) {
	switch node := stmt.(type) {
	case *Assignment:
		ti.inferAssignment(node)
	case *StructStatement:
		ti.inferStructStatement(node)
	case *MethodDefinition:
		ti.inferMethodDefinition(node)

	// Statements that simply wrap one expression.
	case *ExpressionStatement:
		ti.inferExpression(node.Expression)
	case *EchoStatement:
		ti.inferExpression(node.Value)

	// Control flow.
	case *IfStatement:
		ti.inferIfStatement(node)
	case *WhileStatement:
		ti.inferWhileStatement(node)
	case *ForStatement:
		ti.inferForStatement(node)
	case *ForInStatement:
		ti.inferForInStatement(node)

	// Statements with an optional value.
	case *ReturnStatement:
		if node.Value == nil {
			return
		}
		ti.inferExpression(node.Value)
	case *ExitStatement:
		if node.Value == nil {
			return
		}
		ti.inferExpression(node.Value)

	case *BreakStatement:
		// Nothing to infer.
	}
}
// inferStructStatement validates the declared type of every field of a
// struct definition, reporting one error per field whose type is invalid.
func (ti *TypeInferrer) inferStructStatement(stmt *StructStatement) {
	for i := range stmt.Fields {
		if ti.isValidType(stmt.Fields[i].TypeHint) {
			continue
		}
		ti.addError(fmt.Sprintf("invalid field type in struct '%s'", stmt.Name), stmt)
	}
}
// inferMethodDefinition type-checks a method body in a new scope that holds
// "self" (typed as the owning struct) and the method's parameters. A method
// attached to an unregistered struct is reported and its body is skipped.
func (ti *TypeInferrer) inferMethodDefinition(stmt *MethodDefinition) {
	if ti.GetStructByID(stmt.StructID) == nil {
		ti.addError("method defined on undefined struct", stmt)
		return
	}

	ti.enterScope()

	// The receiver is available inside the body as "self".
	receiver := &Symbol{Name: "self", Declared: true}
	receiver.Type = ti.CreateStructType(stmt.StructID)
	ti.currentScope.Define(receiver)

	// Parameters without a type hint are treated as "any".
	for _, p := range stmt.Function.Parameters {
		sym := &Symbol{Name: p.Name, Type: AnyType, Declared: true}
		if p.TypeHint.Type != TypeUnknown {
			sym.Type = p.TypeHint
		}
		ti.currentScope.Define(sym)
	}

	for _, inner := range stmt.Function.Body {
		ti.inferStatement(inner)
	}

	ti.exitScope()
}
// inferAssignment type-checks an assignment statement.
//
// The value is always inferred first. Then:
//   - a non-identifier target (member or index access) is just inferred;
//   - a reassignment must name an existing variable and supply a value
//     compatible with that variable's type;
//   - a declaration defines the variable in the current scope, using the type
//     hint when present (after checking the value against it) and the value's
//     type otherwise.
func (ti *TypeInferrer) inferAssignment(stmt *Assignment) {
	valueType := ti.inferExpression(stmt.Value)

	ident, isIdent := stmt.Target.(*Identifier)
	if !isIdent {
		// Member access assignment (table.key or table[index])
		ti.inferExpression(stmt.Target)
		return
	}

	if !stmt.IsDeclaration {
		existing := ti.currentScope.Lookup(ident.Value)
		if existing == nil {
			ti.addError(fmt.Sprintf("undefined variable '%s'", ident.Value), stmt)
			return
		}
		if !ti.isTypeCompatible(valueType, existing.Type) {
			ti.addError("type mismatch in assignment", stmt)
		}
		ident.typeInfo = existing.Type
		return
	}

	declared := valueType
	if hint := stmt.TypeHint; hint.Type != TypeUnknown {
		if !ti.isTypeCompatible(valueType, hint) {
			ti.addError("type mismatch in assignment", stmt)
		}
		// The explicit hint wins even when the value does not match it.
		declared = hint
	}
	ti.currentScope.Define(&Symbol{
		Name:     ident.Value,
		Type:     declared,
		Declared: true,
	})
	ident.typeInfo = declared
}
// inferIfStatement infers the condition and body of the if branch, of each
// elseif branch in order, and finally of the else branch when it has any
// statements. Every body is checked in its own nested scope.
func (ti *TypeInferrer) inferIfStatement(stmt *IfStatement) {
	// Primary branch.
	ti.inferExpression(stmt.Condition)
	ti.enterScope()
	for i := range stmt.Body {
		ti.inferStatement(stmt.Body[i])
	}
	ti.exitScope()

	// elseif branches.
	for i := range stmt.ElseIfs {
		branch := stmt.ElseIfs[i]
		ti.inferExpression(branch.Condition)
		ti.enterScope()
		for j := range branch.Body {
			ti.inferStatement(branch.Body[j])
		}
		ti.exitScope()
	}

	// else branch; an empty one needs no scope at all.
	if len(stmt.Else) == 0 {
		return
	}
	ti.enterScope()
	for i := range stmt.Else {
		ti.inferStatement(stmt.Else[i])
	}
	ti.exitScope()
}
// inferWhileStatement infers the loop condition in the enclosing scope and
// then checks the loop body inside a fresh nested scope.
func (ti *TypeInferrer) inferWhileStatement(stmt *WhileStatement) {
	ti.inferExpression(stmt.Condition)

	ti.enterScope()
	for i := range stmt.Body {
		ti.inferStatement(stmt.Body[i])
	}
	ti.exitScope()
}
// inferForStatement checks a numeric for loop. The start, end and optional
// step expressions are inferred in the enclosing scope; the loop variable is
// then defined as a number in a nested scope in which the body is checked.
func (ti *TypeInferrer) inferForStatement(stmt *ForStatement) {
	ti.inferExpression(stmt.Start)
	ti.inferExpression(stmt.End)
	if step := stmt.Step; step != nil {
		ti.inferExpression(step)
	}

	ti.enterScope()

	// The loop counter is always a number.
	counter := &Symbol{Name: stmt.Variable.Value, Type: NumberType, Declared: true}
	ti.currentScope.Define(counter)
	stmt.Variable.typeInfo = NumberType

	for i := range stmt.Body {
		ti.inferStatement(stmt.Body[i])
	}

	ti.exitScope()
}
// inferForInStatement checks a for-in loop. The iterable is inferred in the
// enclosing scope; the optional key variable and the value variable are then
// defined as "any" in a nested scope in which the body is checked.
func (ti *TypeInferrer) inferForInStatement(stmt *ForInStatement) {
	ti.inferExpression(stmt.Iterable)

	ti.enterScope()

	// The key variable is optional.
	if key := stmt.Key; key != nil {
		keySym := &Symbol{Name: key.Value, Type: AnyType, Declared: true}
		ti.currentScope.Define(keySym)
		key.typeInfo = AnyType
	}

	valueSym := &Symbol{Name: stmt.Value.Value, Type: AnyType, Declared: true}
	ti.currentScope.Define(valueSym)
	stmt.Value.typeInfo = AnyType

	for i := range stmt.Body {
		ti.inferStatement(stmt.Body[i])
	}

	ti.exitScope()
}
// inferExpression returns the inferred type of expr, dispatching on its
// concrete node type. A nil expression is typed as nil; an expression kind
// that is not recognised is reported as an error and typed as "any".
func (ti *TypeInferrer) inferExpression(expr Expression) TypeInfo {
	switch e := expr.(type) {
	case nil:
		return NilType

	// Literals with a fixed type.
	case *NumberLiteral:
		return NumberType
	case *StringLiteral:
		return StringType
	case *BooleanLiteral:
		return BoolType
	case *NilLiteral:
		return NilType

	// Everything else needs its own inference routine.
	case *Identifier:
		return ti.inferIdentifier(e)
	case *TableLiteral:
		return ti.inferTableLiteral(e)
	case *StructConstructor:
		return ti.inferStructConstructor(e)
	case *FunctionLiteral:
		return ti.inferFunctionLiteral(e)
	case *CallExpression:
		return ti.inferCallExpression(e)
	case *PrefixExpression:
		return ti.inferPrefixExpression(e)
	case *InfixExpression:
		return ti.inferInfixExpression(e)
	case *IndexExpression:
		return ti.inferIndexExpression(e)
	case *DotExpression:
		return ti.inferDotExpression(e)
	case *Assignment:
		return ti.inferAssignmentExpression(e)
	}

	ti.addError("unknown expression type", expr)
	return AnyType
}
// inferIdentifier resolves ident through the scope chain, annotates the node
// with the symbol's type and returns that type. An unresolved name is
// reported as an error and typed as "any"; the node is left unannotated.
func (ti *TypeInferrer) inferIdentifier(ident *Identifier) TypeInfo {
	if sym := ti.currentScope.Lookup(ident.Value); sym != nil {
		ident.typeInfo = sym.Type
		return sym.Type
	}
	ti.addError(fmt.Sprintf("undefined variable '%s'", ident.Value), ident)
	return AnyType
}
// inferTableLiteral infers every key (when present) and value expression of
// a table literal for their side effects and returns the table type.
func (ti *TypeInferrer) inferTableLiteral(table *TableLiteral) TypeInfo {
	for _, entry := range table.Pairs {
		// Entries without a key carry no key expression to infer.
		if key := entry.Key; key != nil {
			ti.inferExpression(key)
		}
		ti.inferExpression(entry.Value)
	}
	return TableType
}
// inferStructConstructor type-checks a struct constructor expression and
// returns the constructed struct's type.
//
// A constructor that refers to an unregistered struct is reported and typed
// as "any". Otherwise every entry is checked:
//   - an entry without a key is reported, because struct fields must be named;
//   - a key that does not name a field of the struct is reported;
//   - a value incompatible with the declared field type is reported.
//
// The value expression of every entry is inferred, including entries that are
// themselves invalid, so that errors inside the value are still reported and
// its nodes still receive type annotations. The key is treated as a field
// name and is never inferred as an expression.
func (ti *TypeInferrer) inferStructConstructor(expr *StructConstructor) TypeInfo {
	structDef := ti.GetStructByID(expr.StructID)
	if structDef == nil {
		ti.addError("undefined struct in constructor", expr)
		return AnyType
	}

	// Validate field assignments
	for _, pair := range expr.Fields {
		if pair.Key == nil {
			// Array-style assignment not valid for structs
			ti.addError("struct constructors require named field assignments", expr)
			ti.inferExpression(pair.Value)
			continue
		}

		// Only identifier and string-literal keys name a field; any other
		// key kind leaves the name empty and is reported as an unknown field.
		fieldName := ""
		switch key := pair.Key.(type) {
		case *Identifier:
			fieldName = key.Value
		case *StringLiteral:
			fieldName = key.Value
		}

		// Find field in struct definition
		var fieldType TypeInfo
		found := false
		for _, field := range structDef.Fields {
			if field.Name == fieldName {
				fieldType = field.TypeHint
				found = true
				break
			}
		}

		if !found {
			ti.addError(fmt.Sprintf("struct has no field '%s'", fieldName), expr)
			// Still walk the value so nested problems are not hidden.
			ti.inferExpression(pair.Value)
			continue
		}

		valueType := ti.inferExpression(pair.Value)
		if !ti.isTypeCompatible(valueType, fieldType) {
			ti.addError("field type mismatch in struct constructor", expr)
		}
	}

	structType := ti.CreateStructType(expr.StructID)
	expr.typeInfo = structType
	return structType
}
// inferFunctionLiteral checks a function literal's body in a new scope that
// holds its parameters and returns the function type. Parameters without a
// type hint are typed as "any".
func (ti *TypeInferrer) inferFunctionLiteral(fn *FunctionLiteral) TypeInfo {
	ti.enterScope()

	// Parameters live in the function's own scope.
	for _, p := range fn.Parameters {
		sym := &Symbol{Name: p.Name, Type: AnyType, Declared: true}
		if p.TypeHint.Type != TypeUnknown {
			sym.Type = p.TypeHint
		}
		ti.currentScope.Define(sym)
	}

	for i := range fn.Body {
		ti.inferStatement(fn.Body[i])
	}

	ti.exitScope()
	return FunctionType
}
// inferCallExpression infers the callee and then each argument in order.
// Return types are not tracked, so every call is annotated and typed as "any".
func (ti *TypeInferrer) inferCallExpression(call *CallExpression) TypeInfo {
	ti.inferExpression(call.Function)
	for i := range call.Arguments {
		ti.inferExpression(call.Arguments[i])
	}

	result := AnyType
	call.typeInfo = result
	return result
}
// inferPrefixExpression infers the type of a unary expression, annotates the
// node with it and returns it.
//
//   - "-" yields a number. The operand must be a number, or "any": an
//     any-typed operand (call result, untyped parameter, table access) cannot
//     be checked statically, so it is accepted just as isTypeCompatible and
//     isComparableTypes accept it.
//   - "not" yields a bool for any operand.
//   - any other operator is reported as an error and typed as "any".
func (ti *TypeInferrer) inferPrefixExpression(prefix *PrefixExpression) TypeInfo {
	rightType := ti.inferExpression(prefix.Right)

	var resultType TypeInfo
	switch prefix.Operator {
	case "-":
		if rightType.Type != TypeAny && !ti.isNumericType(rightType) {
			ti.addError("unary minus requires numeric operand", prefix)
		}
		resultType = NumberType
	case "not":
		resultType = BoolType
	default:
		ti.addError(fmt.Sprintf("unknown prefix operator '%s'", prefix.Operator), prefix)
		resultType = AnyType
	}

	prefix.typeInfo = resultType
	return resultType
}
// inferInfixExpression infers the type of a binary expression, annotates the
// node with it and returns it. Both operands are always inferred, left first.
//
//   - Arithmetic (+ - * / %) yields a number. Each operand must be a number
//     or "any": any-typed operands (call results, untyped parameters, table
//     and member accesses) cannot be checked statically, so they are accepted
//     just as isTypeCompatible and isComparableTypes accept them.
//   - Equality (== !=) yields a bool for any operands.
//   - Ordering (< > <= >=) yields a bool and requires comparable operands.
//   - Logical (and, or) yields a bool.
//   - Any other operator is reported as an error and typed as "any".
func (ti *TypeInferrer) inferInfixExpression(infix *InfixExpression) TypeInfo {
	leftType := ti.inferExpression(infix.Left)
	rightType := ti.inferExpression(infix.Right)

	// An operand is acceptable for arithmetic when it is a number or when its
	// type is not statically known.
	arithmeticOK := func(t TypeInfo) bool {
		return t.Type == TypeAny || ti.isNumericType(t)
	}

	var resultType TypeInfo
	switch infix.Operator {
	case "+", "-", "*", "/", "%":
		if !arithmeticOK(leftType) || !arithmeticOK(rightType) {
			ti.addError(fmt.Sprintf("arithmetic operator '%s' requires numeric operands", infix.Operator), infix)
		}
		resultType = NumberType
	case "==", "!=":
		// Equality works with any types
		resultType = BoolType
	case "<", ">", "<=", ">=":
		if !ti.isComparableTypes(leftType, rightType) {
			ti.addError(fmt.Sprintf("comparison operator '%s' requires compatible operands", infix.Operator), infix)
		}
		resultType = BoolType
	case "and", "or":
		resultType = BoolType
	default:
		ti.addError(fmt.Sprintf("unknown infix operator '%s'", infix.Operator), infix)
		resultType = AnyType
	}

	infix.typeInfo = resultType
	return resultType
}
// inferIndexExpression infers both sides of an index expression, annotates
// the node with the result type and returns it. The result is the declared
// field type when a struct value is indexed with a string literal that names
// one of its fields, and "any" in every other case.
func (ti *TypeInferrer) inferIndexExpression(index *IndexExpression) TypeInfo {
	leftType := ti.inferExpression(index.Left)
	ti.inferExpression(index.Index)

	// Unless a struct field can be pinned down, element access is "any".
	result := AnyType
	if key, isString := index.Index.(*StringLiteral); isString && ti.isStructType(leftType) {
		if def := ti.GetStructByID(leftType.StructID); def != nil {
			for _, f := range def.Fields {
				if f.Name == key.Value {
					result = f.TypeHint
					break
				}
			}
		}
	}

	index.typeInfo = result
	return result
}
// inferDotExpression infers the left side of a member access, annotates the
// node with the result type and returns it. The result is the declared field
// type when the left side is a struct that has a field named dot.Key, and
// "any" in every other case.
func (ti *TypeInferrer) inferDotExpression(dot *DotExpression) TypeInfo {
	leftType := ti.inferExpression(dot.Left)

	// Member access defaults to "any" unless a struct field matches.
	fieldType := AnyType
	if ti.isStructType(leftType) {
		if def := ti.GetStructByID(leftType.StructID); def != nil {
			for _, f := range def.Fields {
				if f.Name != dot.Key {
					continue
				}
				fieldType = f.TypeHint
				break
			}
		}
	}

	dot.typeInfo = fieldType
	return fieldType
}
// inferAssignmentExpression type-checks an assignment used in expression
// position and returns the type of the assigned value.
//
// It applies the same checks as the statement form (inferAssignment):
//   - a non-identifier target (member or index access) is just inferred;
//   - a declaration defines the variable in the current scope, using the type
//     hint when present (after checking the value against it) and the value's
//     type otherwise;
//   - a reassignment must name an existing variable ("undefined variable"
//     otherwise) and supply a compatible value ("type mismatch in assignment"
//     otherwise).
//
// The expression's own type is always the value's type, even when an error
// was reported.
func (ti *TypeInferrer) inferAssignmentExpression(expr *Assignment) TypeInfo {
	valueType := ti.inferExpression(expr.Value)

	ident, ok := expr.Target.(*Identifier)
	if !ok {
		ti.inferExpression(expr.Target)
		return valueType
	}

	if expr.IsDeclaration {
		varType := valueType
		if expr.TypeHint.Type != TypeUnknown {
			if !ti.isTypeCompatible(valueType, expr.TypeHint) {
				ti.addError("type mismatch in assignment", expr)
			}
			// The explicit hint wins even when the value does not match it.
			varType = expr.TypeHint
		}
		ti.currentScope.Define(&Symbol{
			Name:     ident.Value,
			Type:     varType,
			Declared: true,
		})
		ident.typeInfo = varType
		return valueType
	}

	// Reassignment: report the same problems the statement form reports
	// instead of silently accepting them.
	symbol := ti.currentScope.Lookup(ident.Value)
	if symbol == nil {
		ti.addError(fmt.Sprintf("undefined variable '%s'", ident.Value), expr)
		return valueType
	}
	if !ti.isTypeCompatible(valueType, symbol.Type) {
		ti.addError("type mismatch in assignment", expr)
	}
	ident.typeInfo = symbol.Type
	return valueType
}
// Helper methods

// enterScope pushes a new scope nested inside the current one and makes it
// the current scope.
func (ti *TypeInferrer) enterScope() {
	child := NewScope(ti.currentScope)
	ti.currentScope = child
}
// exitScope makes the enclosing scope current again. It does nothing when
// the current scope has no parent, so the outermost scope is never popped.
func (ti *TypeInferrer) exitScope() {
	parent := ti.currentScope.parent
	if parent == nil {
		return
	}
	ti.currentScope = parent
}
// addError appends a type error carrying message, the given node and the
// line/column reported by the node's Pos method.
func (ti *TypeInferrer) addError(message string, node Node) {
	pos := node.Pos()
	entry := TypeError{Message: message, Node: node}
	entry.Line, entry.Column = pos.Line, pos.Column
	ti.errors = append(ti.errors, entry)
}
// isValidType reports whether t names a usable type: a struct type must
// refer to a registered struct, and any other type must be one of the
// enumerators declared before TypeStruct.
func (ti *TypeInferrer) isValidType(t TypeInfo) bool {
	if t.Type != TypeStruct {
		// Values past the last enumerator are not valid types.
		return t.Type < TypeStruct
	}
	return ti.GetStructByID(t.StructID) != nil
}
// isTypeCompatible reports whether a value of valueType may be stored where
// targetType is expected. "any" on either side is always compatible, two
// struct types must refer to the same struct, and all other types must match
// exactly.
func (ti *TypeInferrer) isTypeCompatible(valueType, targetType TypeInfo) bool {
	switch {
	case targetType.Type == TypeAny, valueType.Type == TypeAny:
		return true
	case valueType.Type == TypeStruct && targetType.Type == TypeStruct:
		return valueType.StructID == targetType.StructID
	default:
		return valueType.Type == targetType.Type
	}
}
// isNumericType reports whether t is exactly the number type. "any" is not
// considered numeric by this predicate.
func (ti *TypeInferrer) isNumericType(t TypeInfo) bool {
	switch t.Type {
	case TypeNumber:
		return true
	default:
		return false
	}
}
// isBooleanType reports whether t is exactly the bool type.
func (ti *TypeInferrer) isBooleanType(t TypeInfo) bool {
	switch t.Type {
	case TypeBool:
		return true
	default:
		return false
	}
}
// isTableType reports whether t is exactly the table type.
func (ti *TypeInferrer) isTableType(t TypeInfo) bool {
	switch t.Type {
	case TypeTable:
		return true
	default:
		return false
	}
}
// isFunctionType reports whether t is exactly the function type.
func (ti *TypeInferrer) isFunctionType(t TypeInfo) bool {
	switch t.Type {
	case TypeFunction:
		return true
	default:
		return false
	}
}
// isStructType reports whether t is a user-defined struct type. It does not
// check that the struct ID refers to a registered struct.
func (ti *TypeInferrer) isStructType(t TypeInfo) bool {
	switch t.Type {
	case TypeStruct:
		return true
	default:
		return false
	}
}
// isComparableTypes reports whether two operand types may be used with an
// ordering operator: either side being "any" is accepted, otherwise both
// sides must be numbers or both must be strings.
func (ti *TypeInferrer) isComparableTypes(left, right TypeInfo) bool {
	if left.Type == TypeAny || right.Type == TypeAny {
		return true
	}
	if left.Type != right.Type {
		return false
	}
	switch left.Type {
	case TypeNumber, TypeString:
		return true
	}
	return false
}
// Error reporting

// Errors returns the type errors recorded so far. The returned slice is the
// inferrer's own slice, not a copy.
func (ti *TypeInferrer) Errors() []TypeError {
	recorded := ti.errors
	return recorded
}
// HasErrors reports whether at least one type error has been recorded.
func (ti *TypeInferrer) HasErrors() bool {
	return len(ti.errors) != 0
}
// ErrorStrings renders every recorded type error with its Error method and
// returns the messages in the order the errors were recorded. The result is
// an empty, non-nil slice when there are no errors.
func (ti *TypeInferrer) ErrorStrings() []string {
	messages := make([]string, 0, len(ti.errors))
	for i := range ti.errors {
		messages = append(messages, ti.errors[i].Error())
	}
	return messages
}
// Type string conversion

// TypeToString returns the display name of t. Struct types are rendered as
// "struct<ID>"; TypeUnknown and any unrecognised value render as "unknown".
func TypeToString(t TypeInfo) string {
	if t.Type == TypeStruct {
		return fmt.Sprintf("struct<%d>", t.StructID)
	}

	// Names of the built-in types, indexed by enumerator. TypeUnknown has no
	// entry and therefore falls through to "unknown" below.
	names := [...]string{
		TypeNumber:   "number",
		TypeString:   "string",
		TypeBool:     "bool",
		TypeNil:      "nil",
		TypeTable:    "table",
		TypeFunction: "function",
		TypeAny:      "any",
	}
	if int(t.Type) < len(names) && names[t.Type] != "" {
		return names[t.Type]
	}
	return "unknown"
}