AST/types optimization

This commit is contained in:
Sky Johnson 2025-06-11 17:01:01 -05:00
parent 5ae2a6ef23
commit 30e4b11a96
13 changed files with 885 additions and 894 deletions

View File

@ -2,32 +2,28 @@ package parser
import "fmt"
// TypeInfo represents type information for expressions
type TypeInfo struct {
Type string // "number", "string", "bool", "table", "function", "nil", "any", struct name
Inferred bool // true if type was inferred, false if explicitly declared
}
// Note: Type definitions moved to types.go for proper separation of concerns
// Node represents any node in the AST
type Node interface {
String() string
}
// Statement represents statement nodes
// Statement represents statement nodes that can appear at the top level or in blocks
type Statement interface {
Node
statementNode()
}
// Expression represents expression nodes
// Expression represents expression nodes that produce values and have types
type Expression interface {
Node
expressionNode()
GetType() *TypeInfo
SetType(*TypeInfo)
TypeInfo() TypeInfo // Returns type by value, not pointer
}
// Program represents the root of the AST
// Program represents the root of the AST containing all top-level statements.
// Tracks exit code for script termination and owns the statement list.
type Program struct {
Statements []Statement
ExitCode int
@ -41,23 +37,23 @@ func (p *Program) String() string {
return result
}
// StructField represents a field in a struct definition
// StructField represents a field definition within a struct.
// Contains field name and required type annotation for compile-time checking.
type StructField struct {
Name string
TypeHint *TypeInfo
TypeHint TypeInfo // Required for struct fields, embeds directly
}
func (sf *StructField) String() string {
if sf.TypeHint != nil {
return fmt.Sprintf("%s: %s", sf.Name, sf.TypeHint.Type)
}
return sf.Name
return fmt.Sprintf("%s: %s", sf.Name, typeToString(sf.TypeHint))
}
// StructStatement represents struct definitions
// StructStatement represents struct type definitions with named fields.
// Defines new types that can be instantiated and used for type checking.
type StructStatement struct {
Name string
Fields []StructField
ID uint16 // Unique identifier for fast lookup
}
func (ss *StructStatement) statementNode() {}
@ -72,77 +68,72 @@ func (ss *StructStatement) String() string {
return fmt.Sprintf("struct %s {\n\t%s\n}", ss.Name, fields)
}
// MethodDefinition represents method definitions on structs
// MethodDefinition represents method definitions attached to struct types.
// Links a function implementation to a specific struct via struct ID.
type MethodDefinition struct {
StructName string
StructID uint16 // Index into struct table for fast lookup
MethodName string
Function *FunctionLiteral
}
func (md *MethodDefinition) statementNode() {}
func (md *MethodDefinition) String() string {
return fmt.Sprintf("fn %s.%s%s", md.StructName, md.MethodName, md.Function.String()[2:]) // skip "fn" from function string
return fmt.Sprintf("fn <struct>.%s%s", md.MethodName, md.Function.String()[2:])
}
// StructConstructorExpression represents struct constructor calls like my_type{...}
type StructConstructorExpression struct {
StructName string
Fields []TablePair // reuse TablePair for field assignments
typeInfo *TypeInfo
// StructConstructor represents struct instantiation with field initialization.
// Uses struct ID for fast type resolution and validation during parsing.
type StructConstructor struct {
StructID uint16 // Index into struct table
Fields []TablePair // Reuses table pair structure for field assignments
typeInfo TypeInfo // Cached type info for this constructor
}
func (sce *StructConstructorExpression) expressionNode() {}
func (sce *StructConstructorExpression) String() string {
func (sc *StructConstructor) expressionNode() {}
func (sc *StructConstructor) String() string {
var pairs []string
for _, pair := range sce.Fields {
for _, pair := range sc.Fields {
pairs = append(pairs, pair.String())
}
return fmt.Sprintf("%s{%s}", sce.StructName, joinStrings(pairs, ", "))
return fmt.Sprintf("<struct>{%s}", joinStrings(pairs, ", "))
}
func (sce *StructConstructorExpression) GetType() *TypeInfo { return sce.typeInfo }
func (sce *StructConstructorExpression) SetType(t *TypeInfo) { sce.typeInfo = t }
func (sc *StructConstructor) TypeInfo() TypeInfo { return sc.typeInfo }
// AssignStatement represents variable assignment with optional type hint
type AssignStatement struct {
Name Expression // Changed from *Identifier to Expression for member access
TypeHint *TypeInfo // optional type hint
Value Expression
IsDeclaration bool // true if this is the first assignment in current scope
// Assignment represents both variable assignment statements and assignment expressions.
// Unified design reduces AST node count and simplifies type checking logic.
type Assignment struct {
Target Expression // Target (identifier, dot, or index expression)
Value Expression // Value being assigned
TypeHint TypeInfo // Optional explicit type hint, embeds directly
IsDeclaration bool // True if declaring new variable in current scope
IsExpression bool // True if used as expression (wrapped in parentheses)
}
func (as *AssignStatement) statementNode() {}
func (as *AssignStatement) String() string {
func (a *Assignment) statementNode() {}
func (a *Assignment) expressionNode() {}
func (a *Assignment) String() string {
prefix := ""
if as.IsDeclaration {
if a.IsDeclaration {
prefix = "local "
}
var nameStr string
if as.TypeHint != nil {
nameStr = fmt.Sprintf("%s: %s", as.Name.String(), as.TypeHint.Type)
if a.TypeHint.Type != TypeUnknown {
nameStr = fmt.Sprintf("%s: %s", a.Target.String(), typeToString(a.TypeHint))
} else {
nameStr = as.Name.String()
nameStr = a.Target.String()
}
return fmt.Sprintf("%s%s = %s", prefix, nameStr, as.Value.String())
result := fmt.Sprintf("%s%s = %s", prefix, nameStr, a.Value.String())
if a.IsExpression {
return "(" + result + ")"
}
// AssignExpression represents assignment as an expression (only in parentheses)
type AssignExpression struct {
Name Expression // Target (identifier, dot, or index expression)
Value Expression // Value to assign
IsDeclaration bool // true if this declares a new variable
typeInfo *TypeInfo // type of the expression (same as assigned value)
return result
}
func (a *Assignment) TypeInfo() TypeInfo { return a.Value.TypeInfo() }
func (ae *AssignExpression) expressionNode() {}
func (ae *AssignExpression) String() string {
return fmt.Sprintf("(%s = %s)", ae.Name.String(), ae.Value.String())
}
func (ae *AssignExpression) GetType() *TypeInfo { return ae.typeInfo }
func (ae *AssignExpression) SetType(t *TypeInfo) { ae.typeInfo = t }
// ExpressionStatement represents expressions used as statements
// ExpressionStatement wraps expressions used as statements.
// Allows function calls and other expressions at statement level.
type ExpressionStatement struct {
Expression Expression
}
@ -152,7 +143,8 @@ func (es *ExpressionStatement) String() string {
return es.Expression.String()
}
// EchoStatement represents echo output statements
// EchoStatement represents output statements for displaying values.
// Simple debugging and output mechanism built into the language.
type EchoStatement struct {
Value Expression
}
@ -162,17 +154,17 @@ func (es *EchoStatement) String() string {
return fmt.Sprintf("echo %s", es.Value.String())
}
// BreakStatement represents break statements to exit loops
// BreakStatement represents loop exit statements.
// Simple marker node with no additional data needed.
type BreakStatement struct{}
func (bs *BreakStatement) statementNode() {}
func (bs *BreakStatement) String() string {
return "break"
}
func (bs *BreakStatement) String() string { return "break" }
// ExitStatement represents exit statements to quit the script
// ExitStatement represents script termination with optional exit code.
// Value expression is nil for plain "exit", non-nil for "exit <code>".
type ExitStatement struct {
Value Expression // optional, can be nil
Value Expression // Optional exit code expression
}
func (es *ExitStatement) statementNode() {}
@ -183,9 +175,10 @@ func (es *ExitStatement) String() string {
return fmt.Sprintf("exit %s", es.Value.String())
}
// ReturnStatement represents return statements
// ReturnStatement represents function return with optional value.
// Value expression is nil for plain "return", non-nil for "return <value>".
type ReturnStatement struct {
Value Expression // optional, can be nil
Value Expression // Optional return value expression
}
func (rs *ReturnStatement) statementNode() {}
@ -196,7 +189,8 @@ func (rs *ReturnStatement) String() string {
return fmt.Sprintf("return %s", rs.Value.String())
}
// ElseIfClause represents an elseif condition
// ElseIfClause represents conditional branches in if statements.
// Contains condition expression and body statements for this branch.
type ElseIfClause struct {
Condition Expression
Body []Statement
@ -210,30 +204,28 @@ func (eic *ElseIfClause) String() string {
return fmt.Sprintf("elseif %s then\n%s", eic.Condition.String(), body)
}
// IfStatement represents conditional statements
// IfStatement represents conditional execution with optional elseif and else branches.
// Supports multiple elseif clauses and an optional final else clause.
type IfStatement struct {
Condition Expression
Body []Statement
ElseIfs []ElseIfClause
Else []Statement
Condition Expression // Main condition
Body []Statement // Statements to execute if condition is true
ElseIfs []ElseIfClause // Optional elseif branches
Else []Statement // Optional else branch
}
func (is *IfStatement) statementNode() {}
func (is *IfStatement) String() string {
var result string
// If clause
result += fmt.Sprintf("if %s then\n", is.Condition.String())
for _, stmt := range is.Body {
result += "\t" + stmt.String() + "\n"
}
// ElseIf clauses
for _, elseif := range is.ElseIfs {
result += elseif.String()
}
// Else clause
if len(is.Else) > 0 {
result += "else\n"
for _, stmt := range is.Else {
@ -245,7 +237,8 @@ func (is *IfStatement) String() string {
return result
}
// WhileStatement represents while loops: while condition do ... end
// WhileStatement represents condition-based loops that execute while condition is true.
// Contains condition expression and body statements to repeat.
type WhileStatement struct {
Condition Expression
Body []Statement
@ -264,13 +257,14 @@ func (ws *WhileStatement) String() string {
return result
}
// ForStatement represents numeric for loops: for i = start, end, step do ... end
// ForStatement represents numeric for loops with start, end, and optional step.
// Variable is automatically scoped to the loop body.
type ForStatement struct {
Variable *Identifier
Start Expression
End Expression
Step Expression // optional, nil means step of 1
Body []Statement
Variable *Identifier // Loop variable (automatically number type)
Start Expression // Starting value expression
End Expression // Ending value expression
Step Expression // Optional step expression (nil means step of 1)
Body []Statement // Loop body statements
}
func (fs *ForStatement) statementNode() {}
@ -292,12 +286,13 @@ func (fs *ForStatement) String() string {
return result
}
// ForInStatement represents iterator for loops: for k, v in expr do ... end
// ForInStatement represents iterator-based loops over tables, arrays, or other iterables.
// Supports both single variable (for v in iter) and key-value (for k,v in iter) forms.
type ForInStatement struct {
Key *Identifier // optional, nil for single variable iteration
Value *Identifier
Iterable Expression
Body []Statement
Key *Identifier // Optional key variable (nil for single variable iteration)
Value *Identifier // Value variable (required)
Iterable Expression // Expression to iterate over
Body []Statement // Loop body statements
}
func (fis *ForInStatement) statementNode() {}
@ -319,56 +314,60 @@ func (fis *ForInStatement) String() string {
return result
}
// FunctionParameter represents a function parameter with optional type hint
// FunctionParameter represents a parameter in function definitions.
// Contains parameter name and optional type hint for type checking.
type FunctionParameter struct {
Name string
TypeHint *TypeInfo
TypeHint TypeInfo // Optional type constraint, embeds directly
}
func (fp *FunctionParameter) String() string {
if fp.TypeHint != nil {
return fmt.Sprintf("%s: %s", fp.Name, fp.TypeHint.Type)
if fp.TypeHint.Type != TypeUnknown {
return fmt.Sprintf("%s: %s", fp.Name, typeToString(fp.TypeHint))
}
return fp.Name
}
// Identifier represents identifiers
// Identifier represents variable references and names.
// Stores resolved type information for efficient type checking.
type Identifier struct {
Value string
typeInfo *TypeInfo
typeInfo TypeInfo // Resolved type, embeds directly
}
func (i *Identifier) expressionNode() {}
func (i *Identifier) String() string { return i.Value }
func (i *Identifier) GetType() *TypeInfo { return i.typeInfo }
func (i *Identifier) SetType(t *TypeInfo) { i.typeInfo = t }
func (i *Identifier) TypeInfo() TypeInfo {
if i.typeInfo.Type == TypeUnknown {
return AnyType
}
return i.typeInfo
}
// NumberLiteral represents numeric literals
// NumberLiteral represents numeric constants including integers, floats, hex, and binary.
// Always has number type, so no additional type storage needed.
type NumberLiteral struct {
Value float64
typeInfo *TypeInfo
Value float64 // All numbers stored as float64 for simplicity
}
func (nl *NumberLiteral) expressionNode() {}
func (nl *NumberLiteral) String() string { return fmt.Sprintf("%.2f", nl.Value) }
func (nl *NumberLiteral) GetType() *TypeInfo { return nl.typeInfo }
func (nl *NumberLiteral) SetType(t *TypeInfo) { nl.typeInfo = t }
func (nl *NumberLiteral) TypeInfo() TypeInfo { return NumberType }
// StringLiteral represents string literals
// StringLiteral represents string constants and multiline strings.
// Always has string type, so no additional type storage needed.
type StringLiteral struct {
Value string
typeInfo *TypeInfo
Value string // String content without quotes
}
func (sl *StringLiteral) expressionNode() {}
func (sl *StringLiteral) String() string { return fmt.Sprintf(`"%s"`, sl.Value) }
func (sl *StringLiteral) GetType() *TypeInfo { return sl.typeInfo }
func (sl *StringLiteral) SetType(t *TypeInfo) { sl.typeInfo = t }
func (sl *StringLiteral) TypeInfo() TypeInfo { return StringType }
// BooleanLiteral represents boolean literals
// BooleanLiteral represents true and false constants.
// Always has bool type, so no additional type storage needed.
type BooleanLiteral struct {
Value bool
typeInfo *TypeInfo
}
func (bl *BooleanLiteral) expressionNode() {}
@ -378,26 +377,23 @@ func (bl *BooleanLiteral) String() string {
}
return "false"
}
func (bl *BooleanLiteral) GetType() *TypeInfo { return bl.typeInfo }
func (bl *BooleanLiteral) SetType(t *TypeInfo) { bl.typeInfo = t }
func (bl *BooleanLiteral) TypeInfo() TypeInfo { return BoolType }
// NilLiteral represents nil literal
type NilLiteral struct {
typeInfo *TypeInfo
}
// NilLiteral represents the nil constant value.
// Always has nil type, so no additional type storage needed.
type NilLiteral struct{}
func (nl *NilLiteral) expressionNode() {}
func (nl *NilLiteral) String() string { return "nil" }
func (nl *NilLiteral) GetType() *TypeInfo { return nl.typeInfo }
func (nl *NilLiteral) SetType(t *TypeInfo) { nl.typeInfo = t }
func (nl *NilLiteral) TypeInfo() TypeInfo { return NilType }
// FunctionLiteral represents function literals with typed parameters
// FunctionLiteral represents function definitions with parameters, body, and optional return type.
// Always has function type, stores additional return type information separately.
type FunctionLiteral struct {
Parameters []FunctionParameter
Variadic bool
ReturnType *TypeInfo // optional return type hint
Body []Statement
typeInfo *TypeInfo
Parameters []FunctionParameter // Function parameters with optional types
Body []Statement // Function body statements
ReturnType TypeInfo // Optional return type hint, embeds directly
Variadic bool // True if function accepts variable arguments
}
func (fl *FunctionLiteral) expressionNode() {}
@ -417,8 +413,8 @@ func (fl *FunctionLiteral) String() string {
}
result := fmt.Sprintf("fn(%s)", params)
if fl.ReturnType != nil {
result += ": " + fl.ReturnType.Type
if fl.ReturnType.Type != TypeUnknown {
result += ": " + typeToString(fl.ReturnType)
}
result += "\n"
@ -428,14 +424,14 @@ func (fl *FunctionLiteral) String() string {
result += "end"
return result
}
func (fl *FunctionLiteral) GetType() *TypeInfo { return fl.typeInfo }
func (fl *FunctionLiteral) SetType(t *TypeInfo) { fl.typeInfo = t }
func (fl *FunctionLiteral) TypeInfo() TypeInfo { return FunctionType }
// CallExpression represents function calls: func(arg1, arg2, ...)
// CallExpression represents function calls with arguments.
// Stores inferred return type from function signature analysis.
type CallExpression struct {
Function Expression
Arguments []Expression
typeInfo *TypeInfo
Function Expression // Function expression to call
Arguments []Expression // Argument expressions
typeInfo TypeInfo // Inferred return type, embeds directly
}
func (ce *CallExpression) expressionNode() {}
@ -446,74 +442,73 @@ func (ce *CallExpression) String() string {
}
return fmt.Sprintf("%s(%s)", ce.Function.String(), joinStrings(args, ", "))
}
func (ce *CallExpression) GetType() *TypeInfo { return ce.typeInfo }
func (ce *CallExpression) SetType(t *TypeInfo) { ce.typeInfo = t }
func (ce *CallExpression) TypeInfo() TypeInfo { return ce.typeInfo }
// PrefixExpression represents prefix operations like -x, not x
// PrefixExpression represents unary operations like negation and logical not.
// Stores result type based on operator and operand type analysis.
type PrefixExpression struct {
Operator string
Right Expression
typeInfo *TypeInfo
Operator string // Operator symbol ("-", "not")
Right Expression // Operand expression
typeInfo TypeInfo // Result type, embeds directly
}
func (pe *PrefixExpression) expressionNode() {}
func (pe *PrefixExpression) String() string {
// Add space for word operators
if pe.Operator == "not" {
return fmt.Sprintf("(%s %s)", pe.Operator, pe.Right.String())
}
return fmt.Sprintf("(%s%s)", pe.Operator, pe.Right.String())
}
func (pe *PrefixExpression) GetType() *TypeInfo { return pe.typeInfo }
func (pe *PrefixExpression) SetType(t *TypeInfo) { pe.typeInfo = t }
func (pe *PrefixExpression) TypeInfo() TypeInfo { return pe.typeInfo }
// InfixExpression represents binary operations
// InfixExpression represents binary operations between two expressions.
// Stores result type based on operator and operand type compatibility.
type InfixExpression struct {
Left Expression
Operator string
Right Expression
typeInfo *TypeInfo
Left Expression // Left operand
Right Expression // Right operand
Operator string // Operator symbol ("+", "-", "==", "and", etc.)
typeInfo TypeInfo // Result type, embeds directly
}
func (ie *InfixExpression) expressionNode() {}
func (ie *InfixExpression) String() string {
return fmt.Sprintf("(%s %s %s)", ie.Left.String(), ie.Operator, ie.Right.String())
}
func (ie *InfixExpression) GetType() *TypeInfo { return ie.typeInfo }
func (ie *InfixExpression) SetType(t *TypeInfo) { ie.typeInfo = t }
func (ie *InfixExpression) TypeInfo() TypeInfo { return ie.typeInfo }
// IndexExpression represents table[key] access
// IndexExpression represents bracket-based member access (table[key]).
// Stores inferred element type based on container type analysis.
type IndexExpression struct {
Left Expression
Index Expression
typeInfo *TypeInfo
Left Expression // Container expression
Index Expression // Index/key expression
typeInfo TypeInfo // Element type, embeds directly
}
func (ie *IndexExpression) expressionNode() {}
func (ie *IndexExpression) String() string {
return fmt.Sprintf("%s[%s]", ie.Left.String(), ie.Index.String())
}
func (ie *IndexExpression) GetType() *TypeInfo { return ie.typeInfo }
func (ie *IndexExpression) SetType(t *TypeInfo) { ie.typeInfo = t }
func (ie *IndexExpression) TypeInfo() TypeInfo { return ie.typeInfo }
// DotExpression represents table.key access
// DotExpression represents dot-based member access (table.key).
// Stores inferred member type based on container type and field analysis.
type DotExpression struct {
Left Expression
Key string
typeInfo *TypeInfo
Left Expression // Container expression
Key string // Member name
typeInfo TypeInfo // Member type, embeds directly
}
func (de *DotExpression) expressionNode() {}
func (de *DotExpression) String() string {
return fmt.Sprintf("%s.%s", de.Left.String(), de.Key)
}
func (de *DotExpression) GetType() *TypeInfo { return de.typeInfo }
func (de *DotExpression) SetType(t *TypeInfo) { de.typeInfo = t }
func (de *DotExpression) TypeInfo() TypeInfo { return de.typeInfo }
// TablePair represents a key-value pair in a table
// TablePair represents key-value pairs in table literals and struct constructors.
// Key is nil for array-style elements, non-nil for object-style elements.
type TablePair struct {
Key Expression // nil for array-style elements
Value Expression
Key Expression // Key expression (nil for array elements)
Value Expression // Value expression
}
func (tp *TablePair) String() string {
@ -523,10 +518,10 @@ func (tp *TablePair) String() string {
return fmt.Sprintf("%s = %s", tp.Key.String(), tp.Value.String())
}
// TableLiteral represents table literals {}
// TableLiteral represents table/array/object literals with key-value pairs.
// Always has table type, provides methods to check if it's array-style.
type TableLiteral struct {
Pairs []TablePair
typeInfo *TypeInfo
Pairs []TablePair // Key-value pairs (key nil for array elements)
}
func (tl *TableLiteral) expressionNode() {}
@ -537,10 +532,9 @@ func (tl *TableLiteral) String() string {
}
return fmt.Sprintf("{%s}", joinStrings(pairs, ", "))
}
func (tl *TableLiteral) GetType() *TypeInfo { return tl.typeInfo }
func (tl *TableLiteral) SetType(t *TypeInfo) { tl.typeInfo = t }
func (tl *TableLiteral) TypeInfo() TypeInfo { return TableType }
// IsArray returns true if this table contains only array-style elements
// IsArray returns true if this table contains only array-style elements (no explicit keys)
func (tl *TableLiteral) IsArray() bool {
for _, pair := range tl.Pairs {
if pair.Key != nil {
@ -550,7 +544,31 @@ func (tl *TableLiteral) IsArray() bool {
return true
}
// joinStrings joins string slice with separator
// Helper function to convert TypeInfo to string representation
func typeToString(t TypeInfo) string {
switch t.Type {
case TypeNumber:
return "number"
case TypeString:
return "string"
case TypeBool:
return "bool"
case TypeNil:
return "nil"
case TypeTable:
return "table"
case TypeFunction:
return "function"
case TypeAny:
return "any"
case TypeStruct:
return fmt.Sprintf("struct<%d>", t.StructID)
default:
return "unknown"
}
}
// joinStrings efficiently joins string slice with separator
func joinStrings(strs []string, sep string) string {
if len(strs) == 0 {
return ""

View File

@ -19,7 +19,7 @@ func (pe ParseError) Error() string {
pe.Line, pe.Column, pe.Message, pe.Token.Literal)
}
// Parser implements a recursive descent Pratt parser
// Parser implements a recursive descent Pratt parser with optimized AST generation
type Parser struct {
lexer *Lexer
@ -32,11 +32,13 @@ type Parser struct {
errors []ParseError
// Scope tracking
scopes []map[string]bool // stack of scopes, each tracking declared variables
scopeTypes []string // track what type each scope is: "global", "function", "loop"
scopes []map[string]bool
scopeTypes []string
// Struct tracking
structs map[string]*StructStatement // track defined structs
// Struct tracking with ID mapping
structs map[string]*StructStatement
structIDs map[uint16]*StructStatement
nextID uint16
}
// NewParser creates a new parser instance
@ -44,9 +46,11 @@ func NewParser(lexer *Lexer) *Parser {
p := &Parser{
lexer: lexer,
errors: []ParseError{},
scopes: []map[string]bool{make(map[string]bool)}, // start with global scope
scopeTypes: []string{"global"}, // start with global scope type
structs: make(map[string]*StructStatement), // track struct definitions
scopes: []map[string]bool{make(map[string]bool)},
scopeTypes: []string{"global"},
structs: make(map[string]*StructStatement),
structIDs: make(map[uint16]*StructStatement),
nextID: 1, // 0 reserved for non-struct types
}
p.prefixParseFns = make(map[TokenType]func() Expression)
@ -78,15 +82,31 @@ func NewParser(lexer *Lexer) *Parser {
p.registerInfix(DOT, p.parseDotExpression)
p.registerInfix(LBRACKET, p.parseIndexExpression)
p.registerInfix(LPAREN, p.parseCallExpression)
p.registerInfix(LBRACE, p.parseStructConstructor) // struct constructor
p.registerInfix(LBRACE, p.parseStructConstructor)
// Read two tokens, so curToken and peekToken are both set
p.nextToken()
p.nextToken()
return p
}
// Struct management
func (p *Parser) registerStruct(stmt *StructStatement) {
stmt.ID = p.nextID
p.nextID++
p.structs[stmt.Name] = stmt
p.structIDs[stmt.ID] = stmt
}
func (p *Parser) getStructByName(name string) *StructStatement {
return p.structs[name]
}
func (p *Parser) isStructDefined(name string) bool {
_, exists := p.structs[name]
return exists
}
// Scope management
func (p *Parser) enterScope(scopeType string) {
p.scopes = append(p.scopes, make(map[string]bool))
@ -100,30 +120,6 @@ func (p *Parser) exitScope() {
}
}
func (p *Parser) enterFunctionScope() {
p.enterScope("function")
}
func (p *Parser) exitFunctionScope() {
p.exitScope()
}
func (p *Parser) enterLoopScope() {
p.enterScope("loop")
}
func (p *Parser) exitLoopScope() {
p.exitScope()
}
func (p *Parser) enterBlockScope() {
// Blocks don't create new variable scopes
}
func (p *Parser) exitBlockScope() {
// No-op
}
func (p *Parser) currentVariableScope() map[string]bool {
if len(p.scopeTypes) > 1 && p.scopeTypes[len(p.scopeTypes)-1] == "loop" {
return p.scopes[len(p.scopes)-2]
@ -148,45 +144,56 @@ func (p *Parser) declareLoopVariable(name string) {
p.scopes[len(p.scopes)-1][name] = true
}
// parseTypeHint parses optional type hint after colon
func (p *Parser) parseTypeHint() *TypeInfo {
// parseTypeHint parses optional type hint after colon, returns by value
func (p *Parser) parseTypeHint() TypeInfo {
if !p.peekTokenIs(COLON) {
return nil
return UnknownType
}
p.nextToken() // consume ':'
if !p.expectPeekIdent() {
p.addError("expected type name after ':'")
return nil
return UnknownType
}
typeName := p.curToken.Literal
if !ValidTypeName(typeName) && !p.isStructDefined(typeName) {
// Check built-in types
switch typeName {
case "number":
return TypeInfo{Type: TypeNumber, Inferred: false}
case "string":
return TypeInfo{Type: TypeString, Inferred: false}
case "bool":
return TypeInfo{Type: TypeBool, Inferred: false}
case "nil":
return TypeInfo{Type: TypeNil, Inferred: false}
case "table":
return TypeInfo{Type: TypeTable, Inferred: false}
case "function":
return TypeInfo{Type: TypeFunction, Inferred: false}
case "any":
return TypeInfo{Type: TypeAny, Inferred: false}
default:
// Check if it's a struct type
if structDef := p.getStructByName(typeName); structDef != nil {
return TypeInfo{Type: TypeStruct, StructID: structDef.ID, Inferred: false}
}
p.addError(fmt.Sprintf("invalid type name '%s'", typeName))
return nil
return UnknownType
}
}
return &TypeInfo{Type: typeName, Inferred: false}
}
// isStructDefined checks if a struct name is defined
func (p *Parser) isStructDefined(name string) bool {
_, exists := p.structs[name]
return exists
}
// registerPrefix registers a prefix parse function
// registerPrefix/registerInfix
func (p *Parser) registerPrefix(tokenType TokenType, fn func() Expression) {
p.prefixParseFns[tokenType] = fn
}
// registerInfix registers an infix parse function
func (p *Parser) registerInfix(tokenType TokenType, fn func(Expression) Expression) {
p.infixParseFns[tokenType] = fn
}
// nextToken advances to the next token
func (p *Parser) nextToken() {
p.curToken = p.peekToken
p.peekToken = p.lexer.NextToken()
@ -265,7 +272,7 @@ func (p *Parser) parseStructStatement() *StructStatement {
if p.peekTokenIs(RBRACE) {
p.nextToken()
p.structs[stmt.Name] = stmt
p.registerStruct(stmt)
return stmt
}
@ -284,9 +291,9 @@ func (p *Parser) parseStructStatement() *StructStatement {
field := StructField{Name: p.curToken.Literal}
// Parse optional type hint
// Parse required type hint
field.TypeHint = p.parseTypeHint()
if field.TypeHint == nil {
if field.TypeHint.Type == TypeUnknown {
p.addError("struct fields require type annotation")
return nil
}
@ -314,7 +321,7 @@ func (p *Parser) parseStructStatement() *StructStatement {
return nil
}
p.structs[stmt.Name] = stmt
p.registerStruct(stmt)
return stmt
}
@ -338,12 +345,19 @@ func (p *Parser) parseFunctionStatement() Statement {
methodName := p.curToken.Literal
// Get struct ID
structDef := p.getStructByName(funcName)
if structDef == nil {
p.addError(fmt.Sprintf("method defined on undefined struct '%s'", funcName))
return nil
}
if !p.expectPeek(LPAREN) {
p.addError("expected '(' after method name")
return nil
}
// Parse the function literal starting from parameters
// Parse the function literal
funcLit := &FunctionLiteral{}
funcLit.Parameters, funcLit.Variadic = p.parseFunctionParameters()
@ -357,12 +371,12 @@ func (p *Parser) parseFunctionStatement() Statement {
p.nextToken()
p.enterFunctionScope()
p.enterScope("function")
for _, param := range funcLit.Parameters {
p.declareVariable(param.Name)
}
funcLit.Body = p.parseBlockStatements(END)
p.exitFunctionScope()
p.exitScope()
if !p.curTokenIs(END) {
p.addError("expected 'end' to close function")
@ -370,14 +384,13 @@ func (p *Parser) parseFunctionStatement() Statement {
}
return &MethodDefinition{
StructName: funcName,
StructID: structDef.ID,
MethodName: methodName,
Function: funcLit,
}
}
// Regular function - this should be handled as expression statement
// Reset to handle as function literal
// Regular function - handle as function literal expression statement
funcLit := p.parseFunctionLiteral()
if funcLit == nil {
return nil
@ -386,7 +399,7 @@ func (p *Parser) parseFunctionStatement() Statement {
return &ExpressionStatement{Expression: funcLit}
}
// parseIdentifierStatement handles both assignments and expression statements starting with identifiers
// parseIdentifierStatement handles assignments and expression statements
func (p *Parser) parseIdentifierStatement() Statement {
// Parse the left-hand side expression first
expr := p.ParseExpression(LOWEST)
@ -395,28 +408,28 @@ func (p *Parser) parseIdentifierStatement() Statement {
}
// Check for type hint (only valid on simple identifiers)
var typeHint *TypeInfo
var typeHint TypeInfo = UnknownType
if _, ok := expr.(*Identifier); ok {
typeHint = p.parseTypeHint()
}
// Check if this is an assignment
if p.peekTokenIs(ASSIGN) {
// Convert to assignment statement
stmt := &AssignStatement{
Name: expr,
// Create unified assignment
assignment := &Assignment{
Target: expr,
TypeHint: typeHint,
}
// Validate assignment target and check if it's a declaration
switch name := expr.(type) {
switch target := expr.(type) {
case *Identifier:
stmt.IsDeclaration = !p.isVariableDeclared(name.Value)
if stmt.IsDeclaration {
p.declareVariable(name.Value)
assignment.IsDeclaration = !p.isVariableDeclared(target.Value)
if assignment.IsDeclaration {
p.declareVariable(target.Value)
}
case *DotExpression, *IndexExpression:
stmt.IsDeclaration = false
assignment.IsDeclaration = false
default:
p.addError("invalid assignment target")
return nil
@ -428,29 +441,19 @@ func (p *Parser) parseIdentifierStatement() Statement {
p.nextToken()
stmt.Value = p.ParseExpression(LOWEST)
if stmt.Value == nil {
assignment.Value = p.ParseExpression(LOWEST)
if assignment.Value == nil {
p.addError("expected expression after assignment operator")
return nil
}
return stmt
return assignment
} else {
// This is an expression statement
return &ExpressionStatement{Expression: expr}
}
}
// parseExpressionStatement parses expressions used as statements
func (p *Parser) parseExpressionStatement() *ExpressionStatement {
stmt := &ExpressionStatement{}
stmt.Expression = p.ParseExpression(LOWEST)
if stmt.Expression == nil {
return nil
}
return stmt
}
// parseEchoStatement parses echo statements
func (p *Parser) parseEchoStatement() *EchoStatement {
stmt := &EchoStatement{}
@ -466,9 +469,8 @@ func (p *Parser) parseEchoStatement() *EchoStatement {
return stmt
}
// parseBreakStatement parses break statements
// Simple statement parsers
func (p *Parser) parseBreakStatement() *BreakStatement {
// Check if break is followed by an identifier (invalid)
if p.peekTokenIs(IDENT) {
p.addError("unexpected identifier")
return nil
@ -476,7 +478,6 @@ func (p *Parser) parseBreakStatement() *BreakStatement {
return &BreakStatement{}
}
// parseExitStatement parses exit statements
func (p *Parser) parseExitStatement() *ExitStatement {
stmt := &ExitStatement{}
@ -492,7 +493,6 @@ func (p *Parser) parseExitStatement() *ExitStatement {
return stmt
}
// parseReturnStatement parses return statements
func (p *Parser) parseReturnStatement() *ReturnStatement {
stmt := &ReturnStatement{}
@ -508,7 +508,6 @@ func (p *Parser) parseReturnStatement() *ReturnStatement {
return stmt
}
// canStartExpression checks if a token type can start an expression
func (p *Parser) canStartExpression(tokenType TokenType) bool {
switch tokenType {
case IDENT, NUMBER, STRING, TRUE, FALSE, NIL, LPAREN, LBRACE, MINUS, NOT, FN:
@ -518,7 +517,7 @@ func (p *Parser) canStartExpression(tokenType TokenType) bool {
}
}
// parseWhileStatement parses while loops
// Loop statement parsers
func (p *Parser) parseWhileStatement() *WhileStatement {
stmt := &WhileStatement{}
@ -537,9 +536,7 @@ func (p *Parser) parseWhileStatement() *WhileStatement {
p.nextToken()
p.enterBlockScope()
stmt.Body = p.parseBlockStatements(END)
p.exitBlockScope()
if !p.curTokenIs(END) {
p.addError("expected 'end' to close while loop")
@ -549,7 +546,6 @@ func (p *Parser) parseWhileStatement() *WhileStatement {
return stmt
}
// parseForStatement parses for loops
func (p *Parser) parseForStatement() Statement {
p.nextToken()
@ -570,7 +566,6 @@ func (p *Parser) parseForStatement() Statement {
}
}
// parseNumericForStatement parses numeric for loops
func (p *Parser) parseNumericForStatement(variable *Identifier) *ForStatement {
stmt := &ForStatement{Variable: variable}
@ -617,10 +612,10 @@ func (p *Parser) parseNumericForStatement(variable *Identifier) *ForStatement {
p.nextToken()
p.enterLoopScope()
p.enterScope("loop")
p.declareLoopVariable(variable.Value)
stmt.Body = p.parseBlockStatements(END)
p.exitLoopScope()
p.exitScope()
if !p.curTokenIs(END) {
p.addError("expected 'end' to close for loop")
@ -630,7 +625,6 @@ func (p *Parser) parseNumericForStatement(variable *Identifier) *ForStatement {
return stmt
}
// parseForInStatement parses for-in loops
func (p *Parser) parseForInStatement(firstVar *Identifier) *ForInStatement {
stmt := &ForInStatement{}
@ -669,13 +663,13 @@ func (p *Parser) parseForInStatement(firstVar *Identifier) *ForInStatement {
p.nextToken()
p.enterLoopScope()
p.enterScope("loop")
if stmt.Key != nil {
p.declareLoopVariable(stmt.Key.Value)
}
p.declareLoopVariable(stmt.Value.Value)
stmt.Body = p.parseBlockStatements(END)
p.exitLoopScope()
p.exitScope()
if !p.curTokenIs(END) {
p.addError("expected 'end' to close for loop")
@ -708,9 +702,7 @@ func (p *Parser) parseIfStatement() *IfStatement {
return nil
}
p.enterBlockScope()
stmt.Body = p.parseBlockStatements(ELSEIF, ELSE, END)
p.exitBlockScope()
for p.curTokenIs(ELSEIF) {
elseif := ElseIfClause{}
@ -729,19 +721,13 @@ func (p *Parser) parseIfStatement() *IfStatement {
p.nextToken()
p.enterBlockScope()
elseif.Body = p.parseBlockStatements(ELSEIF, ELSE, END)
p.exitBlockScope()
stmt.ElseIfs = append(stmt.ElseIfs, elseif)
}
if p.curTokenIs(ELSE) {
p.nextToken()
p.enterBlockScope()
stmt.Else = p.parseBlockStatements(END)
p.exitBlockScope()
}
if !p.curTokenIs(END) {
@ -754,7 +740,7 @@ func (p *Parser) parseIfStatement() *IfStatement {
// parseBlockStatements parses statements until terminators
func (p *Parser) parseBlockStatements(terminators ...TokenType) []Statement {
statements := []Statement{}
statements := make([]Statement, 0, 8) // Pre-allocate for performance
for !p.curTokenIs(EOF) && !p.isTerminator(terminators...) {
stmt := p.parseStatement()
@ -767,7 +753,6 @@ func (p *Parser) parseBlockStatements(terminators ...TokenType) []Statement {
return statements
}
// isTerminator checks if current token is a terminator
func (p *Parser) isTerminator(terminators ...TokenType) bool {
for _, terminator := range terminators {
if p.curTokenIs(terminator) {
@ -919,7 +904,6 @@ func (p *Parser) parseGroupedExpression() Expression {
// parseParenthesizedAssignment parses assignment expressions in parentheses
func (p *Parser) parseParenthesizedAssignment() Expression {
// We're at identifier, peek is ASSIGN
target := p.parseIdentifier()
if !p.expectPeek(ASSIGN) {
@ -939,9 +923,10 @@ func (p *Parser) parseParenthesizedAssignment() Expression {
}
// Create assignment expression
assignExpr := &AssignExpression{
Name: target,
assignExpr := &Assignment{
Target: target,
Value: value,
IsExpression: true,
}
// Handle variable declaration for assignment expressions
@ -952,8 +937,6 @@ func (p *Parser) parseParenthesizedAssignment() Expression {
}
}
// Assignment expression evaluates to the assigned value
assignExpr.SetType(value.GetType())
return assignExpr
}
@ -977,12 +960,12 @@ func (p *Parser) parseFunctionLiteral() Expression {
p.nextToken()
p.enterFunctionScope()
p.enterScope("function")
for _, param := range fn.Parameters {
p.declareVariable(param.Name)
}
fn.Body = p.parseBlockStatements(END)
p.exitFunctionScope()
p.exitScope()
if !p.curTokenIs(END) {
p.addError("expected 'end' to close function")
@ -1038,7 +1021,7 @@ func (p *Parser) parseFunctionParameters() ([]FunctionParameter, bool) {
func (p *Parser) parseTableLiteral() Expression {
table := &TableLiteral{}
table.Pairs = []TablePair{}
table.Pairs = make([]TablePair, 0, 4) // Pre-allocate
if p.peekTokenIs(RBRACE) {
p.nextToken()
@ -1104,22 +1087,24 @@ func (p *Parser) parseTableLiteral() Expression {
return table
}
// parseStructConstructor handles struct constructor calls like my_type{...}
// parseStructConstructor handles struct constructor calls
func (p *Parser) parseStructConstructor(left Expression) Expression {
// left should be an identifier representing the struct name
ident, ok := left.(*Identifier)
if !ok {
// Not an identifier, fall back to table literal parsing
return p.parseTableLiteralFromBrace()
}
structName := ident.Value
structDef := p.getStructByName(structName)
if structDef == nil {
// Not a struct, parse as table literal
return p.parseTableLiteralFromBrace()
}
// Always try to parse as struct constructor if we have an identifier
// Type checking will catch undefined structs later
constructor := &StructConstructorExpression{
StructName: structName,
Fields: []TablePair{},
constructor := &StructConstructor{
StructID: structDef.ID,
Fields: make([]TablePair, 0, 4),
typeInfo: TypeInfo{Type: TypeStruct, StructID: structDef.ID, Inferred: true},
}
if p.peekTokenIs(RBRACE) {
@ -1187,9 +1172,8 @@ func (p *Parser) parseStructConstructor(left Expression) Expression {
}
func (p *Parser) parseTableLiteralFromBrace() Expression {
// We're already at the opening brace, so parse as table literal
table := &TableLiteral{}
table.Pairs = []TablePair{}
table.Pairs = make([]TablePair, 0, 4)
if p.peekTokenIs(RBRACE) {
p.nextToken()
@ -1428,15 +1412,9 @@ func (p *Parser) curPrecedence() Precedence {
return LOWEST
}
// Errors returns all parsing errors
func (p *Parser) Errors() []ParseError {
return p.errors
}
func (p *Parser) HasErrors() bool {
return len(p.errors) > 0
}
// Error reporting
func (p *Parser) Errors() []ParseError { return p.errors }
func (p *Parser) HasErrors() bool { return len(p.errors) > 0 }
func (p *Parser) ErrorStrings() []string {
result := make([]string, len(p.errors))
for i, err := range p.errors {

View File

@ -31,15 +31,15 @@ func TestAssignStatements(t *testing.T) {
t.Fatalf("expected 1 statement, got %d", len(program.Statements))
}
stmt, ok := program.Statements[0].(*parser.AssignStatement)
stmt, ok := program.Statements[0].(*parser.Assignment)
if !ok {
t.Fatalf("expected AssignStatement, got %T", program.Statements[0])
t.Fatalf("expected Assignment, got %T", program.Statements[0])
}
// Check that Name is an Identifier
ident, ok := stmt.Name.(*parser.Identifier)
// Check that Target is an Identifier
ident, ok := stmt.Target.(*parser.Identifier)
if !ok {
t.Fatalf("expected Identifier for Name, got %T", stmt.Name)
t.Fatalf("expected Identifier for Target, got %T", stmt.Target)
}
if ident.Value != tt.expectedIdentifier {
@ -90,9 +90,9 @@ func TestMemberAccessAssignment(t *testing.T) {
t.Fatalf("expected 1 statement, got %d", len(program.Statements))
}
stmt, ok := program.Statements[0].(*parser.AssignStatement)
stmt, ok := program.Statements[0].(*parser.Assignment)
if !ok {
t.Fatalf("expected AssignStatement, got %T", program.Statements[0])
t.Fatalf("expected Assignment, got %T", program.Statements[0])
}
if stmt.String() != tt.expected {
@ -158,15 +158,15 @@ func TestTableAssignments(t *testing.T) {
t.Fatalf("expected 1 statement, got %d", len(program.Statements))
}
stmt, ok := program.Statements[0].(*parser.AssignStatement)
stmt, ok := program.Statements[0].(*parser.Assignment)
if !ok {
t.Fatalf("expected AssignStatement, got %T", program.Statements[0])
t.Fatalf("expected Assignment, got %T", program.Statements[0])
}
// Check that Name is an Identifier
ident, ok := stmt.Name.(*parser.Identifier)
// Check that Target is an Identifier
ident, ok := stmt.Target.(*parser.Identifier)
if !ok {
t.Fatalf("expected Identifier for Name, got %T", stmt.Name)
t.Fatalf("expected Identifier for Target, got %T", stmt.Target)
}
if ident.Value != tt.identifier {

View File

@ -247,9 +247,9 @@ exit "success"`
}
// First: assignment
_, ok := program.Statements[0].(*parser.AssignStatement)
_, ok := program.Statements[0].(*parser.Assignment)
if !ok {
t.Fatalf("statement 0: expected AssignStatement, got %T", program.Statements[0])
t.Fatalf("statement 0: expected Assignment, got %T", program.Statements[0])
}
// Second: if statement with exit in body
@ -264,7 +264,7 @@ exit "success"`
}
// Third: assignment
_, ok = program.Statements[2].(*parser.AssignStatement)
_, ok = program.Statements[2].(*parser.Assignment)
if !ok {
t.Fatalf("statement 2: expected AssignStatement, got %T", program.Statements[2])
}

View File

@ -32,15 +32,15 @@ end`
t.Fatalf("expected 1 body statement, got %d", len(stmt.Body))
}
bodyStmt, ok := stmt.Body[0].(*parser.AssignStatement)
bodyStmt, ok := stmt.Body[0].(*parser.Assignment)
if !ok {
t.Fatalf("expected AssignStatement in body, got %T", stmt.Body[0])
}
// Check that Name is an Identifier
ident, ok := bodyStmt.Name.(*parser.Identifier)
// Check that Target is an Identifier
ident, ok := bodyStmt.Target.(*parser.Identifier)
if !ok {
t.Fatalf("expected Identifier for Name, got %T", bodyStmt.Name)
t.Fatalf("expected Identifier for Target, got %T", bodyStmt.Target)
}
if ident.Value != "x" {
@ -79,15 +79,15 @@ end`
t.Fatalf("expected 1 else statement, got %d", len(stmt.Else))
}
elseStmt, ok := stmt.Else[0].(*parser.AssignStatement)
elseStmt, ok := stmt.Else[0].(*parser.Assignment)
if !ok {
t.Fatalf("expected AssignStatement in else, got %T", stmt.Else[0])
}
// Check that Name is an Identifier
ident, ok := elseStmt.Name.(*parser.Identifier)
ident, ok := elseStmt.Target.(*parser.Identifier)
if !ok {
t.Fatalf("expected Identifier for Name, got %T", elseStmt.Name)
t.Fatalf("expected Identifier for Name, got %T", elseStmt.Target)
}
if ident.Value != "x" {
@ -169,25 +169,25 @@ end`
}
// First assignment: arr[1] = "updated"
assign1, ok := stmt.Body[0].(*parser.AssignStatement)
assign1, ok := stmt.Body[0].(*parser.Assignment)
if !ok {
t.Fatalf("expected AssignStatement, got %T", stmt.Body[0])
}
_, ok = assign1.Name.(*parser.IndexExpression)
_, ok = assign1.Target.(*parser.IndexExpression)
if !ok {
t.Fatalf("expected IndexExpression for assignment target, got %T", assign1.Name)
t.Fatalf("expected IndexExpression for assignment target, got %T", assign1.Target)
}
// Second assignment: obj.nested.count = obj.nested.count + 1
assign2, ok := stmt.Body[1].(*parser.AssignStatement)
assign2, ok := stmt.Body[1].(*parser.Assignment)
if !ok {
t.Fatalf("expected AssignStatement, got %T", stmt.Body[1])
}
_, ok = assign2.Name.(*parser.DotExpression)
_, ok = assign2.Target.(*parser.DotExpression)
if !ok {
t.Fatalf("expected DotExpression for assignment target, got %T", assign2.Name)
t.Fatalf("expected DotExpression for assignment target, got %T", assign2.Target)
}
}
@ -214,7 +214,7 @@ end`
}
// Test body has expression assignment
bodyStmt := stmt.Body[0].(*parser.AssignStatement)
bodyStmt := stmt.Body[0].(*parser.Assignment)
bodyInfix, ok := bodyStmt.Value.(*parser.InfixExpression)
if !ok {
t.Fatalf("expected InfixExpression value, got %T", bodyStmt.Value)

View File

@ -352,15 +352,15 @@ func TestAssignmentExpressions(t *testing.T) {
expr := p.ParseExpression(parser.LOWEST)
checkParserErrors(t, p)
assignExpr, ok := expr.(*parser.AssignExpression)
assignExpr, ok := expr.(*parser.Assignment)
if !ok {
t.Fatalf("expected AssignExpression, got %T", expr)
}
// Test target name
ident, ok := assignExpr.Name.(*parser.Identifier)
ident, ok := assignExpr.Target.(*parser.Identifier)
if !ok {
t.Fatalf("expected Identifier for assignment target, got %T", assignExpr.Name)
t.Fatalf("expected Identifier for assignment target, got %T", assignExpr.Target)
}
if ident.Value != tt.targetName {
@ -413,12 +413,12 @@ func TestAssignmentExpressionWithComplexExpressions(t *testing.T) {
expr := p.ParseExpression(parser.LOWEST)
checkParserErrors(t, p)
assignExpr, ok := expr.(*parser.AssignExpression)
assignExpr, ok := expr.(*parser.Assignment)
if !ok {
t.Fatalf("expected AssignExpression, got %T", expr)
}
if assignExpr.Name == nil {
if assignExpr.Target == nil {
t.Error("expected non-nil assignment target")
}

View File

@ -160,7 +160,7 @@ func TestFunctionAssignments(t *testing.T) {
t.Fatalf("expected 1 statement, got %d", len(program.Statements))
}
stmt, ok := program.Statements[0].(*parser.AssignStatement)
stmt, ok := program.Statements[0].(*parser.Assignment)
if !ok {
t.Fatalf("expected AssignStatement, got %T", program.Statements[0])
}
@ -296,7 +296,7 @@ end`
}
// First statement: assignment of inner function
assign, ok := fn.Body[0].(*parser.AssignStatement)
assign, ok := fn.Body[0].(*parser.Assignment)
if !ok {
t.Fatalf("expected AssignStatement, got %T", fn.Body[0])
}
@ -342,7 +342,7 @@ end`
}
// First: function assignment
assign, ok := forStmt.Body[0].(*parser.AssignStatement)
assign, ok := forStmt.Body[0].(*parser.Assignment)
if !ok {
t.Fatalf("expected AssignStatement, got %T", forStmt.Body[0])
}
@ -551,7 +551,7 @@ echo adder`
}
// First: table with functions
mathAssign, ok := program.Statements[0].(*parser.AssignStatement)
mathAssign, ok := program.Statements[0].(*parser.Assignment)
if !ok {
t.Fatalf("statement 0: expected AssignStatement, got %T", program.Statements[0])
}
@ -574,7 +574,7 @@ echo adder`
}
// Second: result assignment (function call would be handled by interpreter)
_, ok = program.Statements[1].(*parser.AssignStatement)
_, ok = program.Statements[1].(*parser.Assignment)
if !ok {
t.Fatalf("statement 1: expected AssignStatement, got %T", program.Statements[1])
}
@ -586,7 +586,7 @@ echo adder`
}
// Fourth: calculator function assignment
calcAssign, ok := program.Statements[3].(*parser.AssignStatement)
calcAssign, ok := program.Statements[3].(*parser.Assignment)
if !ok {
t.Fatalf("statement 3: expected AssignStatement, got %T", program.Statements[3])
}
@ -601,7 +601,7 @@ echo adder`
}
// Fifth: adder assignment
_, ok = program.Statements[4].(*parser.AssignStatement)
_, ok = program.Statements[4].(*parser.Assignment)
if !ok {
t.Fatalf("statement 4: expected AssignStatement, got %T", program.Statements[4])
}
@ -645,7 +645,7 @@ end`
}
// Check if body has function assignment
ifAssign, ok := ifStmt.Body[0].(*parser.AssignStatement)
ifAssign, ok := ifStmt.Body[0].(*parser.Assignment)
if !ok {
t.Fatalf("if body: expected AssignStatement, got %T", ifStmt.Body[0])
}
@ -656,7 +656,7 @@ end`
}
// Check else body has function assignment
elseAssign, ok := ifStmt.Else[0].(*parser.AssignStatement)
elseAssign, ok := ifStmt.Else[0].(*parser.Assignment)
if !ok {
t.Fatalf("else body: expected AssignStatement, got %T", ifStmt.Else[0])
}
@ -683,7 +683,7 @@ end`
}
// Verify both branches assign functions
nestedIfAssign, ok := nestedIf.Body[0].(*parser.AssignStatement)
nestedIfAssign, ok := nestedIf.Body[0].(*parser.Assignment)
if !ok {
t.Fatalf("nested if body: expected AssignStatement, got %T", nestedIf.Body[0])
}
@ -693,7 +693,7 @@ end`
t.Fatalf("nested if body: expected FunctionLiteral, got %T", nestedIfAssign.Value)
}
nestedElseAssign, ok := nestedIf.Else[0].(*parser.AssignStatement)
nestedElseAssign, ok := nestedIf.Else[0].(*parser.Assignment)
if !ok {
t.Fatalf("nested else body: expected AssignStatement, got %T", nestedIf.Else[0])
}

View File

@ -327,13 +327,13 @@ end`
}
// First: table assignment
_, ok := program.Statements[0].(*parser.AssignStatement)
_, ok := program.Statements[0].(*parser.Assignment)
if !ok {
t.Fatalf("statement 0: expected AssignStatement, got %T", program.Statements[0])
}
// Second: variable assignment
_, ok = program.Statements[1].(*parser.AssignStatement)
_, ok = program.Statements[1].(*parser.Assignment)
if !ok {
t.Fatalf("statement 1: expected AssignStatement, got %T", program.Statements[1])
}
@ -590,7 +590,7 @@ end`
}
// Second body statement should be assignment
_, ok = outerWhile.Body[1].(*parser.AssignStatement)
_, ok = outerWhile.Body[1].(*parser.Assignment)
if !ok {
t.Fatalf("expected AssignStatement, got %T", outerWhile.Body[1])
}
@ -634,14 +634,14 @@ end`
}
// First assignment: data[index] = ...
assign1, ok := stmt.Body[0].(*parser.AssignStatement)
assign1, ok := stmt.Body[0].(*parser.Assignment)
if !ok {
t.Fatalf("expected AssignStatement, got %T", stmt.Body[0])
}
_, ok = assign1.Name.(*parser.IndexExpression)
_, ok = assign1.Target.(*parser.IndexExpression)
if !ok {
t.Fatalf("expected IndexExpression for assignment target, got %T", assign1.Name)
t.Fatalf("expected IndexExpression for assignment target, got %T", assign1.Target)
}
}
@ -755,7 +755,7 @@ end`
// First three: assignments
for i := 0; i < 3; i++ {
_, ok := program.Statements[i].(*parser.AssignStatement)
_, ok := program.Statements[i].(*parser.Assignment)
if !ok {
t.Fatalf("statement %d: expected AssignStatement, got %T", i, program.Statements[i])
}
@ -838,7 +838,7 @@ end`
}
// Fourth: assignment
_, ok = whileStmt.Body[3].(*parser.AssignStatement)
_, ok = whileStmt.Body[3].(*parser.Assignment)
if !ok {
t.Fatalf("body[3]: expected AssignStatement, got %T", whileStmt.Body[3])
}

View File

@ -22,14 +22,14 @@ z = true + false`
expectedIdentifiers := []string{"x", "y", "z"}
for i, expectedIdent := range expectedIdentifiers {
stmt, ok := program.Statements[i].(*parser.AssignStatement)
stmt, ok := program.Statements[i].(*parser.Assignment)
if !ok {
t.Fatalf("statement %d: expected AssignStatement, got %T", i, program.Statements[i])
}
ident, ok := stmt.Name.(*parser.Identifier)
ident, ok := stmt.Target.(*parser.Identifier)
if !ok {
t.Fatalf("statement %d: expected Identifier for Name, got %T", i, stmt.Name)
t.Fatalf("statement %d: expected Identifier for Name, got %T", i, stmt.Target)
}
if ident.Value != expectedIdent {
@ -58,13 +58,13 @@ arr = {a = 1, b = 2}`
}
// First statement: assignment
stmt1, ok := program.Statements[0].(*parser.AssignStatement)
stmt1, ok := program.Statements[0].(*parser.Assignment)
if !ok {
t.Fatalf("statement 0: expected AssignStatement, got %T", program.Statements[0])
}
ident1, ok := stmt1.Name.(*parser.Identifier)
ident1, ok := stmt1.Target.(*parser.Identifier)
if !ok {
t.Fatalf("expected Identifier for Name, got %T", stmt1.Name)
t.Fatalf("expected Identifier for Name, got %T", stmt1.Target)
}
if ident1.Value != "x" {
t.Errorf("expected identifier 'x', got %s", ident1.Value)
@ -80,7 +80,7 @@ arr = {a = 1, b = 2}`
}
// Third statement: table assignment
stmt3, ok := program.Statements[2].(*parser.AssignStatement)
stmt3, ok := program.Statements[2].(*parser.Assignment)
if !ok {
t.Fatalf("statement 2: expected AssignStatement, got %T", program.Statements[2])
}
@ -110,33 +110,33 @@ echo table[table.key]`
}
// Second statement: dot assignment
stmt2, ok := program.Statements[1].(*parser.AssignStatement)
stmt2, ok := program.Statements[1].(*parser.Assignment)
if !ok {
t.Fatalf("statement 1: expected AssignStatement, got %T", program.Statements[1])
}
_, ok = stmt2.Name.(*parser.DotExpression)
_, ok = stmt2.Target.(*parser.DotExpression)
if !ok {
t.Fatalf("expected DotExpression for assignment target, got %T", stmt2.Name)
t.Fatalf("expected DotExpression for assignment target, got %T", stmt2.Target)
}
// Third statement: bracket assignment
stmt3, ok := program.Statements[2].(*parser.AssignStatement)
stmt3, ok := program.Statements[2].(*parser.Assignment)
if !ok {
t.Fatalf("statement 2: expected AssignStatement, got %T", program.Statements[2])
}
_, ok = stmt3.Name.(*parser.IndexExpression)
_, ok = stmt3.Target.(*parser.IndexExpression)
if !ok {
t.Fatalf("expected IndexExpression for assignment target, got %T", stmt3.Name)
t.Fatalf("expected IndexExpression for assignment target, got %T", stmt3.Target)
}
// Fourth statement: chained dot assignment
stmt4, ok := program.Statements[3].(*parser.AssignStatement)
stmt4, ok := program.Statements[3].(*parser.Assignment)
if !ok {
t.Fatalf("statement 3: expected AssignStatement, got %T", program.Statements[3])
}
_, ok = stmt4.Name.(*parser.DotExpression)
_, ok = stmt4.Target.(*parser.DotExpression)
if !ok {
t.Fatalf("expected DotExpression for assignment target, got %T", stmt4.Name)
t.Fatalf("expected DotExpression for assignment target, got %T", stmt4.Target)
}
// Fifth statement: echo with nested access
@ -232,7 +232,7 @@ end`
}
// First statement: complex expression assignment
stmt1, ok := program.Statements[0].(*parser.AssignStatement)
stmt1, ok := program.Statements[0].(*parser.Assignment)
if !ok {
t.Fatalf("expected AssignStatement, got %T", program.Statements[0])
}
@ -253,7 +253,7 @@ end`
t.Fatalf("expected 1 body statement, got %d", len(stmt2.Body))
}
bodyStmt, ok := stmt2.Body[0].(*parser.AssignStatement)
bodyStmt, ok := stmt2.Body[0].(*parser.Assignment)
if !ok {
t.Fatalf("expected AssignStatement in body, got %T", stmt2.Body[0])
}
@ -286,7 +286,7 @@ echo {result = x}`
}
// First: assignment
_, ok := program.Statements[0].(*parser.AssignStatement)
_, ok := program.Statements[0].(*parser.Assignment)
if !ok {
t.Fatalf("statement 0: expected AssignStatement, got %T", program.Statements[0])
}

View File

@ -63,7 +63,7 @@ x = 15`,
assignmentCount := 0
for _, stmt := range program.Statements {
if assign, ok := stmt.(*parser.AssignStatement); ok {
if assign, ok := stmt.(*parser.Assignment); ok {
if assignmentCount >= len(tt.assignments) {
t.Fatalf("more assignments than expected")
}
@ -71,9 +71,9 @@ x = 15`,
expected := tt.assignments[assignmentCount]
// Check variable name
ident, ok := assign.Name.(*parser.Identifier)
ident, ok := assign.Target.(*parser.Identifier)
if !ok {
t.Fatalf("expected Identifier, got %T", assign.Name)
t.Fatalf("expected Identifier, got %T", assign.Target)
}
if ident.Value != expected.variable {
@ -135,9 +135,9 @@ z = 30`
for i, expected := range expectedAssignments {
assign := assignments[i]
ident, ok := assign.Name.(*parser.Identifier)
ident, ok := assign.Target.(*parser.Identifier)
if !ok {
t.Fatalf("assignment %d: expected Identifier, got %T", i, assign.Name)
t.Fatalf("assignment %d: expected Identifier, got %T", i, assign.Target)
}
if ident.Value != expected.variable {
@ -191,9 +191,9 @@ end`
for i, expected := range expectedAssignments {
assign := assignments[i]
ident, ok := assign.Name.(*parser.Identifier)
ident, ok := assign.Target.(*parser.Identifier)
if !ok {
t.Fatalf("assignment %d: expected Identifier, got %T", i, assign.Name)
t.Fatalf("assignment %d: expected Identifier, got %T", i, assign.Target)
}
if ident.Value != expected.variable {
@ -243,9 +243,9 @@ c = 20`
for i, expected := range expectedAssignments {
assign := assignments[i]
ident, ok := assign.Name.(*parser.Identifier)
ident, ok := assign.Target.(*parser.Identifier)
if !ok {
t.Fatalf("assignment %d: expected Identifier, got %T", i, assign.Name)
t.Fatalf("assignment %d: expected Identifier, got %T", i, assign.Target)
}
if ident.Value != expected.variable {
@ -344,9 +344,9 @@ count = 0`,
for i, expected := range tt.assignments {
assign := assignments[i]
ident, ok := assign.Name.(*parser.Identifier)
ident, ok := assign.Target.(*parser.Identifier)
if !ok {
t.Fatalf("assignment %d: expected Identifier, got %T", i, assign.Name)
t.Fatalf("assignment %d: expected Identifier, got %T", i, assign.Target)
}
if ident.Value != expected.variable {
@ -389,7 +389,7 @@ arr[1] = 99`
assignmentCount := 0
for _, stmt := range program.Statements {
if assign, ok := stmt.(*parser.AssignStatement); ok {
if assign, ok := stmt.(*parser.Assignment); ok {
if assignmentCount >= len(expectedAssignments) {
t.Fatalf("more assignments than expected")
}
@ -398,7 +398,7 @@ arr[1] = 99`
if expected.isMemberAccess {
// Should not be an identifier
if _, ok := assign.Name.(*parser.Identifier); ok {
if _, ok := assign.Target.(*parser.Identifier); ok {
t.Errorf("assignment %d: expected member access, got Identifier", assignmentCount)
}
@ -408,9 +408,9 @@ arr[1] = 99`
}
} else {
// Should be an identifier
ident, ok := assign.Name.(*parser.Identifier)
ident, ok := assign.Target.(*parser.Identifier)
if !ok {
t.Errorf("assignment %d: expected Identifier, got %T", assignmentCount, assign.Name)
t.Errorf("assignment %d: expected Identifier, got %T", assignmentCount, assign.Target)
} else if ident.Value != expected.variable {
t.Errorf("assignment %d: expected variable %s, got %s",
assignmentCount, expected.variable, ident.Value)
@ -487,9 +487,9 @@ local_var = "global_local"`
for i, expected := range expectedAssignments {
assign := assignments[i]
ident, ok := assign.Name.(*parser.Identifier)
ident, ok := assign.Target.(*parser.Identifier)
if !ok {
t.Fatalf("assignment %d: expected Identifier, got %T", i, assign.Name)
t.Fatalf("assignment %d: expected Identifier, got %T", i, assign.Target)
}
if ident.Value != expected.variable {
@ -536,8 +536,8 @@ y = 20.00`
}
// Helper function to extract all assignments from a program recursively
func extractAssignments(program *parser.Program) []*parser.AssignStatement {
var assignments []*parser.AssignStatement
func extractAssignments(program *parser.Program) []*parser.Assignment {
var assignments []*parser.Assignment
for _, stmt := range program.Statements {
assignments = append(assignments, extractAssignmentsFromStatement(stmt)...)
@ -546,11 +546,11 @@ func extractAssignments(program *parser.Program) []*parser.AssignStatement {
return assignments
}
func extractAssignmentsFromStatement(stmt parser.Statement) []*parser.AssignStatement {
var assignments []*parser.AssignStatement
func extractAssignmentsFromStatement(stmt parser.Statement) []*parser.Assignment {
var assignments []*parser.Assignment
switch s := stmt.(type) {
case *parser.AssignStatement:
case *parser.Assignment:
assignments = append(assignments, s)
// Check if the value is a function literal with assignments in body

View File

@ -38,22 +38,22 @@ func TestBasicStructDefinition(t *testing.T) {
if stmt.Fields[0].Name != "name" {
t.Errorf("expected field name 'name', got %s", stmt.Fields[0].Name)
}
if stmt.Fields[0].TypeHint == nil {
if stmt.Fields[0].TypeHint.Type == parser.TypeUnknown {
t.Fatal("expected type hint for name field")
}
if stmt.Fields[0].TypeHint.Type != "string" {
t.Errorf("expected type 'string', got %s", stmt.Fields[0].TypeHint.Type)
if stmt.Fields[0].TypeHint.Type != parser.TypeString {
t.Errorf("expected type string, got %v", stmt.Fields[0].TypeHint.Type)
}
// Test second field
if stmt.Fields[1].Name != "age" {
t.Errorf("expected field name 'age', got %s", stmt.Fields[1].Name)
}
if stmt.Fields[1].TypeHint == nil {
if stmt.Fields[1].TypeHint.Type == parser.TypeUnknown {
t.Fatal("expected type hint for age field")
}
if stmt.Fields[1].TypeHint.Type != "number" {
t.Errorf("expected type 'number', got %s", stmt.Fields[1].TypeHint.Type)
if stmt.Fields[1].TypeHint.Type != parser.TypeNumber {
t.Errorf("expected type number, got %v", stmt.Fields[1].TypeHint.Type)
}
}
@ -107,7 +107,7 @@ func TestComplexStructDefinition(t *testing.T) {
t.Fatalf("expected StructStatement, got %T", program.Statements[0])
}
expectedTypes := []string{"number", "string", "bool", "table", "function", "any"}
expectedTypes := []parser.Type{parser.TypeNumber, parser.TypeString, parser.TypeBool, parser.TypeTable, parser.TypeFunction, parser.TypeAny}
expectedNames := []string{"id", "name", "active", "data", "callback", "optional"}
if len(stmt.Fields) != len(expectedTypes) {
@ -118,11 +118,11 @@ func TestComplexStructDefinition(t *testing.T) {
if field.Name != expectedNames[i] {
t.Errorf("field %d: expected name '%s', got '%s'", i, expectedNames[i], field.Name)
}
if field.TypeHint == nil {
if field.TypeHint.Type == parser.TypeUnknown {
t.Fatalf("field %d: expected type hint", i)
}
if field.TypeHint.Type != expectedTypes[i] {
t.Errorf("field %d: expected type '%s', got '%s'", i, expectedTypes[i], field.TypeHint.Type)
t.Errorf("field %d: expected type %v, got %v", i, expectedTypes[i], field.TypeHint.Type)
}
}
}
@ -164,17 +164,17 @@ end`
if !ok {
t.Fatalf("expected MethodDefinition, got %T", program.Statements[1])
}
if method1.StructName != "Person" {
t.Errorf("expected struct name 'Person', got %s", method1.StructName)
if method1.StructID != structStmt.ID {
t.Errorf("expected struct ID %d, got %d", structStmt.ID, method1.StructID)
}
if method1.MethodName != "getName" {
t.Errorf("expected method name 'getName', got %s", method1.MethodName)
}
if method1.Function.ReturnType == nil {
if method1.Function.ReturnType.Type == parser.TypeUnknown {
t.Fatal("expected return type for getName method")
}
if method1.Function.ReturnType.Type != "string" {
t.Errorf("expected return type 'string', got %s", method1.Function.ReturnType.Type)
if method1.Function.ReturnType.Type != parser.TypeString {
t.Errorf("expected return type string, got %v", method1.Function.ReturnType.Type)
}
if len(method1.Function.Parameters) != 0 {
t.Errorf("expected 0 parameters, got %d", len(method1.Function.Parameters))
@ -185,14 +185,14 @@ end`
if !ok {
t.Fatalf("expected MethodDefinition, got %T", program.Statements[2])
}
if method2.StructName != "Person" {
t.Errorf("expected struct name 'Person', got %s", method2.StructName)
if method2.StructID != structStmt.ID {
t.Errorf("expected struct ID %d, got %d", structStmt.ID, method2.StructID)
}
if method2.MethodName != "setAge" {
t.Errorf("expected method name 'setAge', got %s", method2.MethodName)
}
if method2.Function.ReturnType != nil {
t.Errorf("expected no return type for setAge method, got %s", method2.Function.ReturnType.Type)
if method2.Function.ReturnType.Type != parser.TypeUnknown {
t.Errorf("expected no return type for setAge method, got %v", method2.Function.ReturnType.Type)
}
if len(method2.Function.Parameters) != 1 {
t.Fatalf("expected 1 parameter, got %d", len(method2.Function.Parameters))
@ -200,11 +200,11 @@ end`
if method2.Function.Parameters[0].Name != "newAge" {
t.Errorf("expected parameter name 'newAge', got %s", method2.Function.Parameters[0].Name)
}
if method2.Function.Parameters[0].TypeHint == nil {
if method2.Function.Parameters[0].TypeHint.Type == parser.TypeUnknown {
t.Fatal("expected type hint for newAge parameter")
}
if method2.Function.Parameters[0].TypeHint.Type != "number" {
t.Errorf("expected parameter type 'number', got %s", method2.Function.Parameters[0].TypeHint.Type)
if method2.Function.Parameters[0].TypeHint.Type != parser.TypeNumber {
t.Errorf("expected parameter type number, got %v", method2.Function.Parameters[0].TypeHint.Type)
}
}
@ -226,19 +226,21 @@ empty = Person{}`
t.Fatalf("expected 3 statements, got %d", len(program.Statements))
}
structStmt := program.Statements[0].(*parser.StructStatement)
// Second statement: constructor with fields
assign1, ok := program.Statements[1].(*parser.AssignStatement)
assign1, ok := program.Statements[1].(*parser.Assignment)
if !ok {
t.Fatalf("expected AssignStatement, got %T", program.Statements[1])
t.Fatalf("expected Assignment, got %T", program.Statements[1])
}
constructor1, ok := assign1.Value.(*parser.StructConstructorExpression)
constructor1, ok := assign1.Value.(*parser.StructConstructor)
if !ok {
t.Fatalf("expected StructConstructorExpression, got %T", assign1.Value)
t.Fatalf("expected StructConstructor, got %T", assign1.Value)
}
if constructor1.StructName != "Person" {
t.Errorf("expected struct name 'Person', got %s", constructor1.StructName)
if constructor1.StructID != structStmt.ID {
t.Errorf("expected struct ID %d, got %d", structStmt.ID, constructor1.StructID)
}
if len(constructor1.Fields) != 2 {
@ -266,18 +268,18 @@ empty = Person{}`
testNumberLiteral(t, constructor1.Fields[1].Value, 30)
// Third statement: empty constructor
assign2, ok := program.Statements[2].(*parser.AssignStatement)
assign2, ok := program.Statements[2].(*parser.Assignment)
if !ok {
t.Fatalf("expected AssignStatement, got %T", program.Statements[2])
t.Fatalf("expected Assignment, got %T", program.Statements[2])
}
constructor2, ok := assign2.Value.(*parser.StructConstructorExpression)
constructor2, ok := assign2.Value.(*parser.StructConstructor)
if !ok {
t.Fatalf("expected StructConstructorExpression, got %T", assign2.Value)
t.Fatalf("expected StructConstructor, got %T", assign2.Value)
}
if constructor2.StructName != "Person" {
t.Errorf("expected struct name 'Person', got %s", constructor2.StructName)
if constructor2.StructID != structStmt.ID {
t.Errorf("expected struct ID %d, got %d", structStmt.ID, constructor2.StructID)
}
if len(constructor2.Fields) != 0 {
@ -310,6 +312,8 @@ person = Person{
t.Fatalf("expected 3 statements, got %d", len(program.Statements))
}
addressStruct := program.Statements[0].(*parser.StructStatement)
// Check Person struct has Address field type
personStruct, ok := program.Statements[1].(*parser.StructStatement)
if !ok {
@ -320,29 +324,32 @@ person = Person{
if addressField.Name != "address" {
t.Errorf("expected field name 'address', got %s", addressField.Name)
}
if addressField.TypeHint.Type != "Address" {
t.Errorf("expected field type 'Address', got %s", addressField.TypeHint.Type)
if addressField.TypeHint.Type != parser.TypeStruct {
t.Errorf("expected field type struct, got %v", addressField.TypeHint.Type)
}
if addressField.TypeHint.StructID != addressStruct.ID {
t.Errorf("expected struct ID %d, got %d", addressStruct.ID, addressField.TypeHint.StructID)
}
// Check nested constructor
assign, ok := program.Statements[2].(*parser.AssignStatement)
assign, ok := program.Statements[2].(*parser.Assignment)
if !ok {
t.Fatalf("expected AssignStatement, got %T", program.Statements[2])
t.Fatalf("expected Assignment, got %T", program.Statements[2])
}
personConstructor, ok := assign.Value.(*parser.StructConstructorExpression)
personConstructor, ok := assign.Value.(*parser.StructConstructor)
if !ok {
t.Fatalf("expected StructConstructorExpression, got %T", assign.Value)
t.Fatalf("expected StructConstructor, got %T", assign.Value)
}
// Check the nested Address constructor
addressConstructor, ok := personConstructor.Fields[1].Value.(*parser.StructConstructorExpression)
addressConstructor, ok := personConstructor.Fields[1].Value.(*parser.StructConstructor)
if !ok {
t.Fatalf("expected nested StructConstructorExpression, got %T", personConstructor.Fields[1].Value)
t.Fatalf("expected nested StructConstructor, got %T", personConstructor.Fields[1].Value)
}
if addressConstructor.StructName != "Address" {
t.Errorf("expected nested struct name 'Address', got %s", addressConstructor.StructName)
if addressConstructor.StructID != addressStruct.ID {
t.Errorf("expected nested struct ID %d, got %d", addressStruct.ID, addressConstructor.StructID)
}
if len(addressConstructor.Fields) != 2 {
@ -397,8 +404,8 @@ end`
if !ok {
t.Fatalf("expected MethodDefinition, got %T", program.Statements[1])
}
if methodStmt.StructName != "Point" {
t.Errorf("expected struct name 'Point', got %s", methodStmt.StructName)
if methodStmt.StructID != structStmt.ID {
t.Errorf("expected struct ID %d, got %d", structStmt.ID, methodStmt.StructID)
}
if methodStmt.MethodName != "distance" {
t.Errorf("expected method name 'distance', got %s", methodStmt.MethodName)
@ -406,16 +413,16 @@ end`
// Verify constructors
for i := 2; i <= 3; i++ {
assign, ok := program.Statements[i].(*parser.AssignStatement)
assign, ok := program.Statements[i].(*parser.Assignment)
if !ok {
t.Fatalf("statement %d: expected AssignStatement, got %T", i, program.Statements[i])
t.Fatalf("statement %d: expected Assignment, got %T", i, program.Statements[i])
}
constructor, ok := assign.Value.(*parser.StructConstructorExpression)
constructor, ok := assign.Value.(*parser.StructConstructor)
if !ok {
t.Fatalf("statement %d: expected StructConstructorExpression, got %T", i, assign.Value)
t.Fatalf("statement %d: expected StructConstructor, got %T", i, assign.Value)
}
if constructor.StructName != "Point" {
t.Errorf("statement %d: expected struct name 'Point', got %s", i, constructor.StructName)
if constructor.StructID != structStmt.ID {
t.Errorf("statement %d: expected struct ID %d, got %d", i, structStmt.ID, constructor.StructID)
}
}
@ -446,16 +453,16 @@ end`
}
// Check struct constructor in loop
loopAssign, ok := forStmt.Body[0].(*parser.AssignStatement)
loopAssign, ok := forStmt.Body[0].(*parser.Assignment)
if !ok {
t.Fatalf("expected AssignStatement in loop, got %T", forStmt.Body[0])
t.Fatalf("expected Assignment in loop, got %T", forStmt.Body[0])
}
loopConstructor, ok := loopAssign.Value.(*parser.StructConstructorExpression)
loopConstructor, ok := loopAssign.Value.(*parser.StructConstructor)
if !ok {
t.Fatalf("expected StructConstructorExpression in loop, got %T", loopAssign.Value)
t.Fatalf("expected StructConstructor in loop, got %T", loopAssign.Value)
}
if loopConstructor.StructName != "Point" {
t.Errorf("expected struct name 'Point' in loop, got %s", loopConstructor.StructName)
if loopConstructor.StructID != structStmt.ID {
t.Errorf("expected struct ID %d in loop, got %d", structStmt.ID, loopConstructor.StructID)
}
}
@ -552,13 +559,13 @@ func TestSingleLineStruct(t *testing.T) {
t.Fatalf("expected 2 fields, got %d", len(stmt.Fields))
}
if stmt.Fields[0].Name != "name" || stmt.Fields[0].TypeHint.Type != "string" {
t.Errorf("expected first field 'name: string', got '%s: %s'",
if stmt.Fields[0].Name != "name" || stmt.Fields[0].TypeHint.Type != parser.TypeString {
t.Errorf("expected first field 'name: string', got '%s: %v'",
stmt.Fields[0].Name, stmt.Fields[0].TypeHint.Type)
}
if stmt.Fields[1].Name != "age" || stmt.Fields[1].TypeHint.Type != "number" {
t.Errorf("expected second field 'age: number', got '%s: %s'",
if stmt.Fields[1].Name != "age" || stmt.Fields[1].TypeHint.Type != parser.TypeNumber {
t.Errorf("expected second field 'age: number', got '%s: %v'",
stmt.Fields[1].Name, stmt.Fields[1].TypeHint.Type)
}
}
@ -600,8 +607,8 @@ end`
method := program.Statements[1].(*parser.MethodDefinition)
str := method.String()
if !containsSubstring(str, "fn Person.getName") {
t.Errorf("expected method string to contain 'fn Person.getName', got: %s", str)
if !containsSubstring(str, "fn <struct>.getName") {
t.Errorf("expected method string to contain 'fn <struct>.getName', got: %s", str)
}
if !containsSubstring(str, ": string") {
t.Errorf("expected method string to contain return type, got: %s", str)
@ -621,11 +628,11 @@ person = Person{name = "John", age = 30}`
program := p.ParseProgram()
checkParserErrors(t, p)
assign := program.Statements[1].(*parser.AssignStatement)
constructor := assign.Value.(*parser.StructConstructorExpression)
assign := program.Statements[1].(*parser.Assignment)
constructor := assign.Value.(*parser.StructConstructor)
str := constructor.String()
expected := `Person{name = "John", age = 30.00}`
expected := `<struct>{name = "John", age = 30.00}`
if str != expected {
t.Errorf("expected constructor string:\n%s\ngot:\n%s", expected, str)
}

View File

@ -10,18 +10,18 @@ func TestVariableTypeHints(t *testing.T) {
tests := []struct {
input string
variable string
typeHint string
typeHint parser.Type
hasHint bool
desc string
}{
{"x = 42", "x", "", false, "no type hint"},
{"x: number = 42", "x", "number", true, "number type hint"},
{"name: string = \"hello\"", "name", "string", true, "string type hint"},
{"flag: bool = true", "flag", "bool", true, "bool type hint"},
{"data: table = {}", "data", "table", true, "table type hint"},
{"fn_var: function = fn() end", "fn_var", "function", true, "function type hint"},
{"value: any = nil", "value", "any", true, "any type hint"},
{"ptr: nil = nil", "ptr", "nil", true, "nil type hint"},
{"x = 42", "x", parser.TypeUnknown, false, "no type hint"},
{"x: number = 42", "x", parser.TypeNumber, true, "number type hint"},
{"name: string = \"hello\"", "name", parser.TypeString, true, "string type hint"},
{"flag: bool = true", "flag", parser.TypeBool, true, "bool type hint"},
{"data: table = {}", "data", parser.TypeTable, true, "table type hint"},
{"fn_var: function = fn() end", "fn_var", parser.TypeFunction, true, "function type hint"},
{"value: any = nil", "value", parser.TypeAny, true, "any type hint"},
{"ptr: nil = nil", "ptr", parser.TypeNil, true, "nil type hint"},
}
for _, tt := range tests {
@ -35,15 +35,15 @@ func TestVariableTypeHints(t *testing.T) {
t.Fatalf("expected 1 statement, got %d", len(program.Statements))
}
stmt, ok := program.Statements[0].(*parser.AssignStatement)
stmt, ok := program.Statements[0].(*parser.Assignment)
if !ok {
t.Fatalf("expected AssignStatement, got %T", program.Statements[0])
t.Fatalf("expected Assignment, got %T", program.Statements[0])
}
// Check variable name
ident, ok := stmt.Name.(*parser.Identifier)
ident, ok := stmt.Target.(*parser.Identifier)
if !ok {
t.Fatalf("expected Identifier for Name, got %T", stmt.Name)
t.Fatalf("expected Identifier for Target, got %T", stmt.Target)
}
if ident.Value != tt.variable {
@ -52,19 +52,19 @@ func TestVariableTypeHints(t *testing.T) {
// Check type hint
if tt.hasHint {
if stmt.TypeHint == nil {
t.Error("expected type hint but got nil")
if stmt.TypeHint.Type == parser.TypeUnknown {
t.Error("expected type hint but got TypeUnknown")
} else {
if stmt.TypeHint.Type != tt.typeHint {
t.Errorf("expected type hint %s, got %s", tt.typeHint, stmt.TypeHint.Type)
t.Errorf("expected type hint %v, got %v", tt.typeHint, stmt.TypeHint.Type)
}
if stmt.TypeHint.Inferred {
t.Error("expected type hint to not be inferred")
}
}
} else {
if stmt.TypeHint != nil {
t.Errorf("expected no type hint but got %s", stmt.TypeHint.Type)
if stmt.TypeHint.Type != parser.TypeUnknown {
t.Errorf("expected no type hint but got %v", stmt.TypeHint.Type)
}
}
})
@ -74,60 +74,81 @@ func TestVariableTypeHints(t *testing.T) {
func TestFunctionParameterTypeHints(t *testing.T) {
tests := []struct {
input string
params []struct{ name, typeHint string }
returnType string
params []struct {
name string
typeHint parser.Type
}
returnType parser.Type
hasReturn bool
desc string
}{
{
"fn(a, b) end",
[]struct{ name, typeHint string }{
{"a", ""},
{"b", ""},
[]struct {
name string
typeHint parser.Type
}{
{"a", parser.TypeUnknown},
{"b", parser.TypeUnknown},
},
"", false,
parser.TypeUnknown, false,
"no type hints",
},
{
"fn(a: number, b: string) end",
[]struct{ name, typeHint string }{
{"a", "number"},
{"b", "string"},
[]struct {
name string
typeHint parser.Type
}{
{"a", parser.TypeNumber},
{"b", parser.TypeString},
},
"", false,
parser.TypeUnknown, false,
"parameter type hints only",
},
{
"fn(x: number): string end",
[]struct{ name, typeHint string }{
{"x", "number"},
[]struct {
name string
typeHint parser.Type
}{
{"x", parser.TypeNumber},
},
"string", true,
parser.TypeString, true,
"parameter and return type hints",
},
{
"fn(): bool end",
[]struct{ name, typeHint string }{},
"bool", true,
[]struct {
name string
typeHint parser.Type
}{},
parser.TypeBool, true,
"return type hint only",
},
{
"fn(a: number, b, c: string): table end",
[]struct{ name, typeHint string }{
{"a", "number"},
{"b", ""},
{"c", "string"},
[]struct {
name string
typeHint parser.Type
}{
{"a", parser.TypeNumber},
{"b", parser.TypeUnknown},
{"c", parser.TypeString},
},
"table", true,
parser.TypeTable, true,
"mixed parameter types with return",
},
{
"fn(callback: function, data: any): nil end",
[]struct{ name, typeHint string }{
{"callback", "function"},
{"data", "any"},
[]struct {
name string
typeHint parser.Type
}{
{"callback", parser.TypeFunction},
{"data", parser.TypeAny},
},
"nil", true,
parser.TypeNil, true,
"function and any types",
},
}
@ -155,29 +176,29 @@ func TestFunctionParameterTypeHints(t *testing.T) {
t.Errorf("parameter %d: expected name %s, got %s", i, expected.name, param.Name)
}
if expected.typeHint == "" {
if param.TypeHint != nil {
t.Errorf("parameter %d: expected no type hint but got %s", i, param.TypeHint.Type)
if expected.typeHint == parser.TypeUnknown {
if param.TypeHint.Type != parser.TypeUnknown {
t.Errorf("parameter %d: expected no type hint but got %v", i, param.TypeHint.Type)
}
} else {
if param.TypeHint == nil {
t.Errorf("parameter %d: expected type hint %s but got nil", i, expected.typeHint)
if param.TypeHint.Type == parser.TypeUnknown {
t.Errorf("parameter %d: expected type hint %v but got TypeUnknown", i, expected.typeHint)
} else if param.TypeHint.Type != expected.typeHint {
t.Errorf("parameter %d: expected type hint %s, got %s", i, expected.typeHint, param.TypeHint.Type)
t.Errorf("parameter %d: expected type hint %v, got %v", i, expected.typeHint, param.TypeHint.Type)
}
}
}
// Check return type
if tt.hasReturn {
if fn.ReturnType == nil {
t.Error("expected return type hint but got nil")
if fn.ReturnType.Type == parser.TypeUnknown {
t.Error("expected return type hint but got TypeUnknown")
} else if fn.ReturnType.Type != tt.returnType {
t.Errorf("expected return type %s, got %s", tt.returnType, fn.ReturnType.Type)
t.Errorf("expected return type %v, got %v", tt.returnType, fn.ReturnType.Type)
}
} else {
if fn.ReturnType != nil {
t.Errorf("expected no return type but got %s", fn.ReturnType.Type)
if fn.ReturnType.Type != parser.TypeUnknown {
t.Errorf("expected no return type but got %v", fn.ReturnType.Type)
}
}
})
@ -279,13 +300,13 @@ func TestMemberAccessWithoutTypeHints(t *testing.T) {
t.Fatalf("expected 1 statement, got %d", len(program.Statements))
}
stmt, ok := program.Statements[0].(*parser.AssignStatement)
stmt, ok := program.Statements[0].(*parser.Assignment)
if !ok {
t.Fatalf("expected AssignStatement, got %T", program.Statements[0])
t.Fatalf("expected Assignment, got %T", program.Statements[0])
}
// Member access should never have type hints
if stmt.TypeHint != nil {
if stmt.TypeHint.Type != parser.TypeUnknown {
t.Error("member access assignment should not have type hints")
}
@ -333,12 +354,12 @@ func TestTypeInferenceErrors(t *testing.T) {
}{
{
"x: number = \"hello\"",
"cannot assign string to variable of type number",
"type mismatch in assignment",
"type mismatch in assignment",
},
{
"x = 42\ny: string = x",
"cannot assign number to variable of type string",
"type mismatch in assignment",
"type mismatch with inferred type",
},
}
@ -359,7 +380,7 @@ func TestTypeInferenceErrors(t *testing.T) {
found := false
for _, err := range typeErrors {
if err.Message == tt.expectedError {
if containsSubstring(err.Message, tt.expectedError) {
found = true
break
}
@ -370,7 +391,7 @@ func TestTypeInferenceErrors(t *testing.T) {
for i, err := range typeErrors {
errorMsgs[i] = err.Message
}
t.Errorf("expected error %q, got %v", tt.expectedError, errorMsgs)
t.Errorf("expected error containing %q, got %v", tt.expectedError, errorMsgs)
}
})
}
@ -439,22 +460,22 @@ server: table = {
}
// Check first statement: config table with typed assignments
configStmt, ok := program.Statements[0].(*parser.AssignStatement)
configStmt, ok := program.Statements[0].(*parser.Assignment)
if !ok {
t.Fatalf("expected AssignStatement, got %T", program.Statements[0])
t.Fatalf("expected Assignment, got %T", program.Statements[0])
}
if configStmt.TypeHint == nil || configStmt.TypeHint.Type != "table" {
if configStmt.TypeHint.Type == parser.TypeUnknown || configStmt.TypeHint.Type != parser.TypeTable {
t.Error("expected table type hint for config")
}
// Check second statement: handler function with typed parameters
handlerStmt, ok := program.Statements[1].(*parser.AssignStatement)
handlerStmt, ok := program.Statements[1].(*parser.Assignment)
if !ok {
t.Fatalf("expected AssignStatement, got %T", program.Statements[1])
t.Fatalf("expected Assignment, got %T", program.Statements[1])
}
if handlerStmt.TypeHint == nil || handlerStmt.TypeHint.Type != "function" {
if handlerStmt.TypeHint.Type == parser.TypeUnknown || handlerStmt.TypeHint.Type != parser.TypeFunction {
t.Error("expected function type hint for handler")
}
@ -468,34 +489,32 @@ server: table = {
}
// Check parameter types
if fn.Parameters[0].TypeHint == nil || fn.Parameters[0].TypeHint.Type != "table" {
if fn.Parameters[0].TypeHint.Type == parser.TypeUnknown || fn.Parameters[0].TypeHint.Type != parser.TypeTable {
t.Error("expected table type for request parameter")
}
if fn.Parameters[1].TypeHint == nil || fn.Parameters[1].TypeHint.Type != "function" {
if fn.Parameters[1].TypeHint.Type == parser.TypeUnknown || fn.Parameters[1].TypeHint.Type != parser.TypeFunction {
t.Error("expected function type for callback parameter")
}
// Check return type
if fn.ReturnType == nil || fn.ReturnType.Type != "nil" {
if fn.ReturnType.Type == parser.TypeUnknown || fn.ReturnType.Type != parser.TypeNil {
t.Error("expected nil return type for handler")
}
// Check third statement: server table
serverStmt, ok := program.Statements[2].(*parser.AssignStatement)
serverStmt, ok := program.Statements[2].(*parser.Assignment)
if !ok {
t.Fatalf("expected AssignStatement, got %T", program.Statements[2])
t.Fatalf("expected Assignment, got %T", program.Statements[2])
}
if serverStmt.TypeHint == nil || serverStmt.TypeHint.Type != "table" {
if serverStmt.TypeHint.Type == parser.TypeUnknown || serverStmt.TypeHint.Type != parser.TypeTable {
t.Error("expected table type hint for server")
}
}
func TestTypeInfoGettersSetters(t *testing.T) {
// Test that all expression types properly implement GetType/SetType
typeInfo := &parser.TypeInfo{Type: "test", Inferred: true}
func TestTypeInfoInterface(t *testing.T) {
// Test that all expression types properly implement TypeInfo()
expressions := []parser.Expression{
&parser.Identifier{Value: "x"},
&parser.NumberLiteral{Value: 42},
@ -513,20 +532,43 @@ func TestTypeInfoGettersSetters(t *testing.T) {
for i, expr := range expressions {
t.Run(string(rune('0'+i)), func(t *testing.T) {
// Initially should have no type
if expr.GetType() != nil {
t.Error("expected nil type initially")
// Should have default type initially
typeInfo := expr.TypeInfo()
// Basic literals should have their expected types
switch e := expr.(type) {
case *parser.NumberLiteral:
if typeInfo.Type != parser.TypeNumber {
t.Errorf("expected number type, got %v", typeInfo.Type)
}
// Set type
expr.SetType(typeInfo)
// Get type should return what we set
retrieved := expr.GetType()
if retrieved == nil {
t.Error("expected non-nil type after setting")
} else if retrieved.Type != "test" || !retrieved.Inferred {
t.Errorf("expected {Type: test, Inferred: true}, got %+v", retrieved)
case *parser.StringLiteral:
if typeInfo.Type != parser.TypeString {
t.Errorf("expected string type, got %v", typeInfo.Type)
}
case *parser.BooleanLiteral:
if typeInfo.Type != parser.TypeBool {
t.Errorf("expected bool type, got %v", typeInfo.Type)
}
case *parser.NilLiteral:
if typeInfo.Type != parser.TypeNil {
t.Errorf("expected nil type, got %v", typeInfo.Type)
}
case *parser.TableLiteral:
if typeInfo.Type != parser.TypeTable {
t.Errorf("expected table type, got %v", typeInfo.Type)
}
case *parser.FunctionLiteral:
if typeInfo.Type != parser.TypeFunction {
t.Errorf("expected function type, got %v", typeInfo.Type)
}
case *parser.Identifier:
// Identifiers default to any type
if typeInfo.Type != parser.TypeAny {
t.Errorf("expected any type for untyped identifier, got %v", typeInfo.Type)
}
default:
// Other expressions may have unknown type initially
_ = e
}
})
}

View File

@ -1,21 +1,45 @@
package parser
import (
"fmt"
)
import "fmt"
// Type represents built-in and user-defined types using compact enum representation.
// Uses single byte instead of string pointers to minimize memory usage.
type Type uint8
// Type constants for built-in types
const (
TypeNumber = "number"
TypeString = "string"
TypeBool = "bool"
TypeNil = "nil"
TypeTable = "table"
TypeFunction = "function"
TypeAny = "any"
TypeUnknown Type = iota
TypeNumber
TypeString
TypeBool
TypeNil
TypeTable
TypeFunction
TypeAny
TypeStruct // struct types use StructID field for identification
)
// TypeError represents a type checking error
// TypeInfo represents type information with zero-allocation design.
// Embeds directly in AST nodes instead of using pointers to reduce heap pressure.
// Common types are pre-allocated as globals to eliminate most allocations.
type TypeInfo struct {
Type Type // Built-in type or TypeStruct for user types
Inferred bool // True if type was inferred, false if explicitly declared
StructID uint16 // Index into global struct table for struct types (0 for non-structs)
}
// Pre-allocated common types - eliminates heap allocations for built-in types
var (
UnknownType = TypeInfo{Type: TypeUnknown, Inferred: true}
NumberType = TypeInfo{Type: TypeNumber, Inferred: true}
StringType = TypeInfo{Type: TypeString, Inferred: true}
BoolType = TypeInfo{Type: TypeBool, Inferred: true}
NilType = TypeInfo{Type: TypeNil, Inferred: true}
TableType = TypeInfo{Type: TypeTable, Inferred: true}
FunctionType = TypeInfo{Type: TypeFunction, Inferred: true}
AnyType = TypeInfo{Type: TypeAny, Inferred: true}
)
// TypeError represents a type checking error with location information
type TypeError struct {
Message string
Line int
@ -27,10 +51,10 @@ func (te TypeError) Error() string {
return fmt.Sprintf("Type error at line %d, column %d: %s", te.Line, te.Column, te.Message)
}
// Symbol represents a variable in the symbol table
// Symbol represents a variable in the symbol table with optimized type storage
type Symbol struct {
Name string
Type *TypeInfo
Type TypeInfo // Embed directly instead of pointer
Declared bool
Line int
Column int
@ -63,52 +87,56 @@ func (s *Scope) Lookup(name string) *Symbol {
return nil
}
// TypeInferrer performs type inference and checking
// TypeInferrer performs type inference and checking with optimized allocations
type TypeInferrer struct {
currentScope *Scope
globalScope *Scope
errors []TypeError
// Pre-allocated type objects for performance
numberType *TypeInfo
stringType *TypeInfo
boolType *TypeInfo
nilType *TypeInfo
tableType *TypeInfo
anyType *TypeInfo
// Struct definitions
// Struct definitions with ID mapping
structs map[string]*StructStatement
structIDs map[uint16]*StructStatement
nextID uint16
}
// NewTypeInferrer creates a new type inference engine
func NewTypeInferrer() *TypeInferrer {
globalScope := NewScope(nil)
ti := &TypeInferrer{
return &TypeInferrer{
currentScope: globalScope,
globalScope: globalScope,
errors: []TypeError{},
structs: make(map[string]*StructStatement),
// Pre-allocate common types to reduce allocations
numberType: &TypeInfo{Type: TypeNumber, Inferred: true},
stringType: &TypeInfo{Type: TypeString, Inferred: true},
boolType: &TypeInfo{Type: TypeBool, Inferred: true},
nilType: &TypeInfo{Type: TypeNil, Inferred: true},
tableType: &TypeInfo{Type: TypeTable, Inferred: true},
anyType: &TypeInfo{Type: TypeAny, Inferred: true},
structIDs: make(map[uint16]*StructStatement),
nextID: 1, // 0 reserved for non-struct types
}
}
return ti
// RegisterStruct assigns ID to struct and tracks it
func (ti *TypeInferrer) RegisterStruct(stmt *StructStatement) {
stmt.ID = ti.nextID
ti.nextID++
ti.structs[stmt.Name] = stmt
ti.structIDs[stmt.ID] = stmt
}
// GetStructByID returns struct definition by ID
func (ti *TypeInferrer) GetStructByID(id uint16) *StructStatement {
return ti.structIDs[id]
}
// CreateStructType creates TypeInfo for a struct
func (ti *TypeInferrer) CreateStructType(structID uint16) TypeInfo {
return TypeInfo{Type: TypeStruct, StructID: structID, Inferred: true}
}
// InferTypes performs type inference on the entire program
func (ti *TypeInferrer) InferTypes(program *Program) []TypeError {
// First pass: collect struct definitions
// First pass: collect and register struct definitions
for _, stmt := range program.Statements {
if structStmt, ok := stmt.(*StructStatement); ok {
ti.structs[structStmt.Name] = structStmt
ti.RegisterStruct(structStmt)
}
}
@ -119,42 +147,6 @@ func (ti *TypeInferrer) InferTypes(program *Program) []TypeError {
return ti.errors
}
// enterScope creates a new scope
func (ti *TypeInferrer) enterScope() {
ti.currentScope = NewScope(ti.currentScope)
}
// exitScope returns to the parent scope
func (ti *TypeInferrer) exitScope() {
if ti.currentScope.parent != nil {
ti.currentScope = ti.currentScope.parent
}
}
// addError adds a type error
func (ti *TypeInferrer) addError(message string, node Node) {
ti.errors = append(ti.errors, TypeError{
Message: message,
Line: 0, // Would need to track position in AST nodes
Column: 0,
Node: node,
})
}
// getStructType returns TypeInfo for a struct
func (ti *TypeInferrer) getStructType(name string) *TypeInfo {
if _, exists := ti.structs[name]; exists {
return &TypeInfo{Type: name, Inferred: true}
}
return nil
}
// isStructType checks if a type is a struct type
func (ti *TypeInferrer) isStructType(t *TypeInfo) bool {
_, exists := ti.structs[t.Type]
return exists
}
// inferStatement infers types for statements
func (ti *TypeInferrer) inferStatement(stmt Statement) {
switch s := stmt.(type) {
@ -162,8 +154,10 @@ func (ti *TypeInferrer) inferStatement(stmt Statement) {
ti.inferStructStatement(s)
case *MethodDefinition:
ti.inferMethodDefinition(s)
case *AssignStatement:
ti.inferAssignStatement(s)
case *Assignment:
ti.inferAssignment(s)
case *ExpressionStatement:
ti.inferExpression(s.Expression)
case *EchoStatement:
ti.inferExpression(s.Value)
case *IfStatement:
@ -182,46 +176,38 @@ func (ti *TypeInferrer) inferStatement(stmt Statement) {
if s.Value != nil {
ti.inferExpression(s.Value)
}
case *ExpressionStatement:
ti.inferExpression(s.Expression)
case *BreakStatement:
// No-op
}
}
// inferStructStatement handles struct definitions
func (ti *TypeInferrer) inferStructStatement(stmt *StructStatement) {
// Validate field types
for _, field := range stmt.Fields {
if field.TypeHint != nil {
if !ValidTypeName(field.TypeHint.Type) && !ti.isStructType(field.TypeHint) {
ti.addError(fmt.Sprintf("invalid field type '%s' in struct '%s'",
field.TypeHint.Type, stmt.Name), stmt)
}
if !ti.isValidType(field.TypeHint) {
ti.addError(fmt.Sprintf("invalid field type in struct '%s'", stmt.Name), stmt)
}
}
}
// inferMethodDefinition handles method definitions
func (ti *TypeInferrer) inferMethodDefinition(stmt *MethodDefinition) {
// Check if struct exists
if _, exists := ti.structs[stmt.StructName]; !exists {
ti.addError(fmt.Sprintf("method defined on undefined struct '%s'", stmt.StructName), stmt)
structDef := ti.GetStructByID(stmt.StructID)
if structDef == nil {
ti.addError("method defined on undefined struct", stmt)
return
}
// Infer the function body
ti.enterScope()
// Add self parameter implicitly
// Add self parameter
ti.currentScope.Define(&Symbol{
Name: "self",
Type: ti.getStructType(stmt.StructName),
Type: ti.CreateStructType(stmt.StructID),
Declared: true,
})
// Add explicit parameters
// Add function parameters
for _, param := range stmt.Function.Parameters {
paramType := ti.anyType
if param.TypeHint != nil {
paramType := AnyType
if param.TypeHint.Type != TypeUnknown {
paramType = param.TypeHint
}
ti.currentScope.Define(&Symbol{
@ -235,66 +221,47 @@ func (ti *TypeInferrer) inferMethodDefinition(stmt *MethodDefinition) {
for _, bodyStmt := range stmt.Function.Body {
ti.inferStatement(bodyStmt)
}
ti.exitScope()
}
// inferAssignStatement handles variable assignments with type checking
func (ti *TypeInferrer) inferAssignStatement(stmt *AssignStatement) {
// Infer the type of the value expression
func (ti *TypeInferrer) inferAssignment(stmt *Assignment) {
valueType := ti.inferExpression(stmt.Value)
if ident, ok := stmt.Name.(*Identifier); ok {
// Simple variable assignment
symbol := ti.currentScope.Lookup(ident.Value)
if ident, ok := stmt.Target.(*Identifier); ok {
if stmt.IsDeclaration {
// New variable declaration
varType := valueType
// If there's a type hint, validate it
if stmt.TypeHint != nil {
if stmt.TypeHint.Type != TypeUnknown {
if !ti.isTypeCompatible(valueType, stmt.TypeHint) {
ti.addError(fmt.Sprintf("cannot assign %s to variable of type %s",
valueType.Type, stmt.TypeHint.Type), stmt)
ti.addError("type mismatch in assignment", stmt)
}
varType = stmt.TypeHint
varType.Inferred = false
}
// Define the new symbol
ti.currentScope.Define(&Symbol{
Name: ident.Value,
Type: varType,
Declared: true,
})
ident.SetType(varType)
ident.typeInfo = varType
} else {
// Assignment to existing variable
symbol := ti.currentScope.Lookup(ident.Value)
if symbol == nil {
ti.addError(fmt.Sprintf("undefined variable '%s'", ident.Value), stmt)
return
}
// Check type compatibility
if !ti.isTypeCompatible(valueType, symbol.Type) {
ti.addError(fmt.Sprintf("cannot assign %s to variable of type %s",
valueType.Type, symbol.Type.Type), stmt)
ti.addError("type mismatch in assignment", stmt)
}
ident.SetType(symbol.Type)
ident.typeInfo = symbol.Type
}
} else {
// Member access assignment (table.key or table[index])
ti.inferExpression(stmt.Name)
ti.inferExpression(stmt.Target)
}
}
// inferIfStatement handles if statements
func (ti *TypeInferrer) inferIfStatement(stmt *IfStatement) {
condType := ti.inferExpression(stmt.Condition)
ti.validateBooleanContext(condType, stmt.Condition)
ti.inferExpression(stmt.Condition)
ti.enterScope()
for _, s := range stmt.Body {
@ -303,9 +270,7 @@ func (ti *TypeInferrer) inferIfStatement(stmt *IfStatement) {
ti.exitScope()
for _, elseif := range stmt.ElseIfs {
condType := ti.inferExpression(elseif.Condition)
ti.validateBooleanContext(condType, elseif.Condition)
ti.inferExpression(elseif.Condition)
ti.enterScope()
for _, s := range elseif.Body {
ti.inferStatement(s)
@ -322,10 +287,8 @@ func (ti *TypeInferrer) inferIfStatement(stmt *IfStatement) {
}
}
// inferWhileStatement handles while loops
func (ti *TypeInferrer) inferWhileStatement(stmt *WhileStatement) {
condType := ti.inferExpression(stmt.Condition)
ti.validateBooleanContext(condType, stmt.Condition)
ti.inferExpression(stmt.Condition)
ti.enterScope()
for _, s := range stmt.Body {
@ -334,33 +297,21 @@ func (ti *TypeInferrer) inferWhileStatement(stmt *WhileStatement) {
ti.exitScope()
}
// inferForStatement handles numeric for loops
func (ti *TypeInferrer) inferForStatement(stmt *ForStatement) {
startType := ti.inferExpression(stmt.Start)
endType := ti.inferExpression(stmt.End)
if !ti.isNumericType(startType) {
ti.addError("for loop start value must be numeric", stmt.Start)
}
if !ti.isNumericType(endType) {
ti.addError("for loop end value must be numeric", stmt.End)
}
ti.inferExpression(stmt.Start)
ti.inferExpression(stmt.End)
if stmt.Step != nil {
stepType := ti.inferExpression(stmt.Step)
if !ti.isNumericType(stepType) {
ti.addError("for loop step value must be numeric", stmt.Step)
}
ti.inferExpression(stmt.Step)
}
ti.enterScope()
// Define loop variable as number
ti.currentScope.Define(&Symbol{
Name: stmt.Variable.Value,
Type: ti.numberType,
Type: NumberType,
Declared: true,
})
stmt.Variable.SetType(ti.numberType)
stmt.Variable.typeInfo = NumberType
for _, s := range stmt.Body {
ti.inferStatement(s)
@ -368,33 +319,26 @@ func (ti *TypeInferrer) inferForStatement(stmt *ForStatement) {
ti.exitScope()
}
// inferForInStatement handles for-in loops
func (ti *TypeInferrer) inferForInStatement(stmt *ForInStatement) {
iterableType := ti.inferExpression(stmt.Iterable)
// For now, assume iterable is a table or struct
if !ti.isTableType(iterableType) && !ti.isStructType(iterableType) {
ti.addError("for-in requires an iterable (table or struct)", stmt.Iterable)
}
ti.inferExpression(stmt.Iterable)
ti.enterScope()
// Define loop variables (key and value are any for now)
// Define loop variables
if stmt.Key != nil {
ti.currentScope.Define(&Symbol{
Name: stmt.Key.Value,
Type: ti.anyType,
Type: AnyType,
Declared: true,
})
stmt.Key.SetType(ti.anyType)
stmt.Key.typeInfo = AnyType
}
ti.currentScope.Define(&Symbol{
Name: stmt.Value.Value,
Type: ti.anyType,
Type: AnyType,
Declared: true,
})
stmt.Value.SetType(ti.anyType)
stmt.Value.typeInfo = AnyType
for _, s := range stmt.Body {
ti.inferStatement(s)
@ -403,29 +347,25 @@ func (ti *TypeInferrer) inferForInStatement(stmt *ForInStatement) {
}
// inferExpression infers the type of an expression
func (ti *TypeInferrer) inferExpression(expr Expression) *TypeInfo {
func (ti *TypeInferrer) inferExpression(expr Expression) TypeInfo {
if expr == nil {
return ti.nilType
return NilType
}
switch e := expr.(type) {
case *Identifier:
return ti.inferIdentifier(e)
case *NumberLiteral:
e.SetType(ti.numberType)
return ti.numberType
return NumberType
case *StringLiteral:
e.SetType(ti.stringType)
return ti.stringType
return StringType
case *BooleanLiteral:
e.SetType(ti.boolType)
return ti.boolType
return BoolType
case *NilLiteral:
e.SetType(ti.nilType)
return ti.nilType
return NilType
case *TableLiteral:
return ti.inferTableLiteral(e)
case *StructConstructorExpression:
case *StructConstructor:
return ti.inferStructConstructor(e)
case *FunctionLiteral:
return ti.inferFunctionLiteral(e)
@ -439,20 +379,40 @@ func (ti *TypeInferrer) inferExpression(expr Expression) *TypeInfo {
return ti.inferIndexExpression(e)
case *DotExpression:
return ti.inferDotExpression(e)
case *AssignExpression:
return ti.inferAssignExpression(e)
case *Assignment:
return ti.inferAssignmentExpression(e)
default:
ti.addError("unknown expression type", expr)
return ti.anyType
return AnyType
}
}
// inferStructConstructor handles struct constructor expressions
func (ti *TypeInferrer) inferStructConstructor(expr *StructConstructorExpression) *TypeInfo {
structDef, exists := ti.structs[expr.StructName]
if !exists {
ti.addError(fmt.Sprintf("undefined struct '%s'", expr.StructName), expr)
return ti.anyType
func (ti *TypeInferrer) inferIdentifier(ident *Identifier) TypeInfo {
symbol := ti.currentScope.Lookup(ident.Value)
if symbol == nil {
ti.addError(fmt.Sprintf("undefined variable '%s'", ident.Value), ident)
return AnyType
}
ident.typeInfo = symbol.Type
return symbol.Type
}
func (ti *TypeInferrer) inferTableLiteral(table *TableLiteral) TypeInfo {
// Infer types of all values
for _, pair := range table.Pairs {
if pair.Key != nil {
ti.inferExpression(pair.Key)
}
ti.inferExpression(pair.Value)
}
return TableType
}
func (ti *TypeInferrer) inferStructConstructor(expr *StructConstructor) TypeInfo {
structDef := ti.GetStructByID(expr.StructID)
if structDef == nil {
ti.addError("undefined struct in constructor", expr)
return AnyType
}
// Validate field assignments
@ -465,25 +425,23 @@ func (ti *TypeInferrer) inferStructConstructor(expr *StructConstructorExpression
fieldName = str.Value
}
// Check if field exists in struct
fieldExists := false
var fieldType *TypeInfo
// Find field in struct definition
var fieldType TypeInfo
found := false
for _, field := range structDef.Fields {
if field.Name == fieldName {
fieldExists = true
fieldType = field.TypeHint
found = true
break
}
}
if !fieldExists {
ti.addError(fmt.Sprintf("struct '%s' has no field '%s'", expr.StructName, fieldName), expr)
if !found {
ti.addError(fmt.Sprintf("struct has no field '%s'", fieldName), expr)
} else {
// Check type compatibility
valueType := ti.inferExpression(pair.Value)
if !ti.isTypeCompatible(valueType, fieldType) {
ti.addError(fmt.Sprintf("cannot assign %s to field '%s' of type %s",
valueType.Type, fieldName, fieldType.Type), expr)
ti.addError("field type mismatch in struct constructor", expr)
}
}
} else {
@ -492,66 +450,18 @@ func (ti *TypeInferrer) inferStructConstructor(expr *StructConstructorExpression
}
}
structType := ti.getStructType(expr.StructName)
expr.SetType(structType)
structType := ti.CreateStructType(expr.StructID)
expr.typeInfo = structType
return structType
}
// inferAssignExpression handles assignment expressions
func (ti *TypeInferrer) inferAssignExpression(expr *AssignExpression) *TypeInfo {
valueType := ti.inferExpression(expr.Value)
if ident, ok := expr.Name.(*Identifier); ok {
if expr.IsDeclaration {
ti.currentScope.Define(&Symbol{
Name: ident.Value,
Type: valueType,
Declared: true,
})
}
ident.SetType(valueType)
} else {
ti.inferExpression(expr.Name)
}
expr.SetType(valueType)
return valueType
}
// inferIdentifier looks up identifier type in symbol table
func (ti *TypeInferrer) inferIdentifier(ident *Identifier) *TypeInfo {
symbol := ti.currentScope.Lookup(ident.Value)
if symbol == nil {
ti.addError(fmt.Sprintf("undefined variable '%s'", ident.Value), ident)
return ti.anyType
}
ident.SetType(symbol.Type)
return symbol.Type
}
// inferTableLiteral infers table type
func (ti *TypeInferrer) inferTableLiteral(table *TableLiteral) *TypeInfo {
// Infer types of all values
for _, pair := range table.Pairs {
if pair.Key != nil {
ti.inferExpression(pair.Key)
}
ti.inferExpression(pair.Value)
}
table.SetType(ti.tableType)
return ti.tableType
}
// inferFunctionLiteral infers function type
func (ti *TypeInferrer) inferFunctionLiteral(fn *FunctionLiteral) *TypeInfo {
func (ti *TypeInferrer) inferFunctionLiteral(fn *FunctionLiteral) TypeInfo {
ti.enterScope()
// Define parameters in function scope
for _, param := range fn.Parameters {
paramType := ti.anyType
if param.TypeHint != nil {
paramType := AnyType
if param.TypeHint.Type != TypeUnknown {
paramType = param.TypeHint
}
@ -568,104 +478,88 @@ func (ti *TypeInferrer) inferFunctionLiteral(fn *FunctionLiteral) *TypeInfo {
}
ti.exitScope()
// For now, all functions have type "function"
funcType := &TypeInfo{Type: TypeFunction, Inferred: true}
fn.SetType(funcType)
return funcType
return FunctionType
}
// inferCallExpression infers function call return type
func (ti *TypeInferrer) inferCallExpression(call *CallExpression) *TypeInfo {
funcType := ti.inferExpression(call.Function)
if !ti.isFunctionType(funcType) {
ti.addError("cannot call non-function", call.Function)
return ti.anyType
}
func (ti *TypeInferrer) inferCallExpression(call *CallExpression) TypeInfo {
ti.inferExpression(call.Function)
// Infer argument types
for _, arg := range call.Arguments {
ti.inferExpression(arg)
}
// For now, assume function calls return any
call.SetType(ti.anyType)
return ti.anyType
call.typeInfo = AnyType
return AnyType
}
// inferPrefixExpression infers prefix operation type
func (ti *TypeInferrer) inferPrefixExpression(prefix *PrefixExpression) *TypeInfo {
func (ti *TypeInferrer) inferPrefixExpression(prefix *PrefixExpression) TypeInfo {
rightType := ti.inferExpression(prefix.Right)
var resultType *TypeInfo
var resultType TypeInfo
switch prefix.Operator {
case "-":
if !ti.isNumericType(rightType) {
ti.addError("unary minus requires numeric operand", prefix)
}
resultType = ti.numberType
resultType = NumberType
case "not":
resultType = ti.boolType
resultType = BoolType
default:
ti.addError(fmt.Sprintf("unknown prefix operator '%s'", prefix.Operator), prefix)
resultType = ti.anyType
resultType = AnyType
}
prefix.SetType(resultType)
prefix.typeInfo = resultType
return resultType
}
// inferInfixExpression infers binary operation type
func (ti *TypeInferrer) inferInfixExpression(infix *InfixExpression) *TypeInfo {
func (ti *TypeInferrer) inferInfixExpression(infix *InfixExpression) TypeInfo {
leftType := ti.inferExpression(infix.Left)
rightType := ti.inferExpression(infix.Right)
var resultType *TypeInfo
var resultType TypeInfo
switch infix.Operator {
case "+", "-", "*", "/":
if !ti.isNumericType(leftType) || !ti.isNumericType(rightType) {
ti.addError(fmt.Sprintf("arithmetic operator '%s' requires numeric operands", infix.Operator), infix)
}
resultType = ti.numberType
resultType = NumberType
case "==", "!=":
// Equality works with any types
resultType = ti.boolType
resultType = BoolType
case "<", ">", "<=", ">=":
if !ti.isComparableTypes(leftType, rightType) {
ti.addError(fmt.Sprintf("comparison operator '%s' requires compatible operands", infix.Operator), infix)
}
resultType = ti.boolType
resultType = BoolType
case "and", "or":
ti.validateBooleanContext(leftType, infix.Left)
ti.validateBooleanContext(rightType, infix.Right)
resultType = ti.boolType
resultType = BoolType
default:
ti.addError(fmt.Sprintf("unknown infix operator '%s'", infix.Operator), infix)
resultType = ti.anyType
resultType = AnyType
}
infix.SetType(resultType)
infix.typeInfo = resultType
return resultType
}
// inferIndexExpression infers table[index] type
func (ti *TypeInferrer) inferIndexExpression(index *IndexExpression) *TypeInfo {
func (ti *TypeInferrer) inferIndexExpression(index *IndexExpression) TypeInfo {
leftType := ti.inferExpression(index.Left)
ti.inferExpression(index.Index)
// If indexing a struct, try to infer field type
if ti.isStructType(leftType) {
if strLit, ok := index.Index.(*StringLiteral); ok {
if structDef, exists := ti.structs[leftType.Type]; exists {
if structDef := ti.GetStructByID(leftType.StructID); structDef != nil {
for _, field := range structDef.Fields {
if field.Name == strLit.Value {
index.SetType(field.TypeHint)
index.typeInfo = field.TypeHint
return field.TypeHint
}
}
@ -674,20 +568,19 @@ func (ti *TypeInferrer) inferIndexExpression(index *IndexExpression) *TypeInfo {
}
// For now, assume table/struct access returns any
index.SetType(ti.anyType)
return ti.anyType
index.typeInfo = AnyType
return AnyType
}
// inferDotExpression infers table.key type
func (ti *TypeInferrer) inferDotExpression(dot *DotExpression) *TypeInfo {
func (ti *TypeInferrer) inferDotExpression(dot *DotExpression) TypeInfo {
leftType := ti.inferExpression(dot.Left)
// If accessing a struct field, try to infer field type
if ti.isStructType(leftType) {
if structDef, exists := ti.structs[leftType.Type]; exists {
if structDef := ti.GetStructByID(leftType.StructID); structDef != nil {
for _, field := range structDef.Fields {
if field.Name == dot.Key {
dot.SetType(field.TypeHint)
dot.typeInfo = field.TypeHint
return field.TypeHint
}
}
@ -695,59 +588,107 @@ func (ti *TypeInferrer) inferDotExpression(dot *DotExpression) *TypeInfo {
}
// For now, assume member access returns any
dot.SetType(ti.anyType)
return ti.anyType
dot.typeInfo = AnyType
return AnyType
}
// Type checking helper methods
func (ti *TypeInferrer) inferAssignmentExpression(expr *Assignment) TypeInfo {
valueType := ti.inferExpression(expr.Value)
func (ti *TypeInferrer) isTypeCompatible(valueType, targetType *TypeInfo) bool {
if ident, ok := expr.Target.(*Identifier); ok {
if expr.IsDeclaration {
varType := valueType
if expr.TypeHint.Type != TypeUnknown {
if !ti.isTypeCompatible(valueType, expr.TypeHint) {
ti.addError("type mismatch in assignment", expr)
}
varType = expr.TypeHint
}
ti.currentScope.Define(&Symbol{
Name: ident.Value,
Type: varType,
Declared: true,
})
ident.typeInfo = varType
} else {
symbol := ti.currentScope.Lookup(ident.Value)
if symbol != nil {
ident.typeInfo = symbol.Type
}
}
} else {
ti.inferExpression(expr.Target)
}
return valueType
}
// Helper methods
func (ti *TypeInferrer) enterScope() {
ti.currentScope = NewScope(ti.currentScope)
}
func (ti *TypeInferrer) exitScope() {
if ti.currentScope.parent != nil {
ti.currentScope = ti.currentScope.parent
}
}
func (ti *TypeInferrer) addError(message string, node Node) {
ti.errors = append(ti.errors, TypeError{
Message: message,
Node: node,
})
}
func (ti *TypeInferrer) isValidType(t TypeInfo) bool {
if t.Type == TypeStruct {
return ti.GetStructByID(t.StructID) != nil
}
return t.Type <= TypeStruct
}
func (ti *TypeInferrer) isTypeCompatible(valueType, targetType TypeInfo) bool {
if targetType.Type == TypeAny || valueType.Type == TypeAny {
return true
}
if valueType.Type == TypeStruct && targetType.Type == TypeStruct {
return valueType.StructID == targetType.StructID
}
return valueType.Type == targetType.Type
}
func (ti *TypeInferrer) isNumericType(t *TypeInfo) bool {
func (ti *TypeInferrer) isNumericType(t TypeInfo) bool {
return t.Type == TypeNumber
}
func (ti *TypeInferrer) isBooleanType(t *TypeInfo) bool {
func (ti *TypeInferrer) isBooleanType(t TypeInfo) bool {
return t.Type == TypeBool
}
func (ti *TypeInferrer) isTableType(t *TypeInfo) bool {
func (ti *TypeInferrer) isTableType(t TypeInfo) bool {
return t.Type == TypeTable
}
func (ti *TypeInferrer) isFunctionType(t *TypeInfo) bool {
func (ti *TypeInferrer) isFunctionType(t TypeInfo) bool {
return t.Type == TypeFunction
}
func (ti *TypeInferrer) isComparableTypes(left, right *TypeInfo) bool {
func (ti *TypeInferrer) isStructType(t TypeInfo) bool {
return t.Type == TypeStruct
}
func (ti *TypeInferrer) isComparableTypes(left, right TypeInfo) bool {
if left.Type == TypeAny || right.Type == TypeAny {
return true
}
return left.Type == right.Type && (left.Type == TypeNumber || left.Type == TypeString)
}
func (ti *TypeInferrer) validateBooleanContext(t *TypeInfo, expr Expression) {
// In many languages, non-boolean values can be used in boolean context
// For strictness, we could require boolean type here
// For now, allow any type (truthy/falsy semantics)
}
// Errors returns all type checking errors
func (ti *TypeInferrer) Errors() []TypeError {
return ti.errors
}
// HasErrors returns true if there are any type errors
func (ti *TypeInferrer) HasErrors() bool {
return len(ti.errors) > 0
}
// ErrorStrings returns error messages as strings
// Error reporting
func (ti *TypeInferrer) Errors() []TypeError { return ti.errors }
func (ti *TypeInferrer) HasErrors() bool { return len(ti.errors) > 0 }
func (ti *TypeInferrer) ErrorStrings() []string {
result := make([]string, len(ti.errors))
for i, err := range ti.errors {
@ -756,21 +697,26 @@ func (ti *TypeInferrer) ErrorStrings() []string {
return result
}
// ValidTypeName checks if a string is a valid type name
func ValidTypeName(name string) bool {
validTypes := []string{TypeNumber, TypeString, TypeBool, TypeNil, TypeTable, TypeFunction, TypeAny}
for _, validType := range validTypes {
if name == validType {
return true
// Type string conversion
func TypeToString(t TypeInfo) string {
switch t.Type {
case TypeNumber:
return "number"
case TypeString:
return "string"
case TypeBool:
return "bool"
case TypeNil:
return "nil"
case TypeTable:
return "table"
case TypeFunction:
return "function"
case TypeAny:
return "any"
case TypeStruct:
return fmt.Sprintf("struct<%d>", t.StructID)
default:
return "unknown"
}
}
return false
}
// ParseTypeName converts a string to a TypeInfo (for parsing type hints)
func ParseTypeName(name string) *TypeInfo {
if ValidTypeName(name) {
return &TypeInfo{Type: name, Inferred: false}
}
return nil
}