AST/types optimization
This commit is contained in:
parent
5ae2a6ef23
commit
30e4b11a96
358
parser/ast.go
358
parser/ast.go
@ -2,32 +2,28 @@ package parser
|
||||
|
||||
import "fmt"
|
||||
|
||||
// TypeInfo represents type information for expressions
|
||||
type TypeInfo struct {
|
||||
Type string // "number", "string", "bool", "table", "function", "nil", "any", struct name
|
||||
Inferred bool // true if type was inferred, false if explicitly declared
|
||||
}
|
||||
// Note: Type definitions moved to types.go for proper separation of concerns
|
||||
|
||||
// Node represents any node in the AST
|
||||
type Node interface {
|
||||
String() string
|
||||
}
|
||||
|
||||
// Statement represents statement nodes
|
||||
// Statement represents statement nodes that can appear at the top level or in blocks
|
||||
type Statement interface {
|
||||
Node
|
||||
statementNode()
|
||||
}
|
||||
|
||||
// Expression represents expression nodes
|
||||
// Expression represents expression nodes that produce values and have types
|
||||
type Expression interface {
|
||||
Node
|
||||
expressionNode()
|
||||
GetType() *TypeInfo
|
||||
SetType(*TypeInfo)
|
||||
TypeInfo() TypeInfo // Returns type by value, not pointer
|
||||
}
|
||||
|
||||
// Program represents the root of the AST
|
||||
// Program represents the root of the AST containing all top-level statements.
|
||||
// Tracks exit code for script termination and owns the statement list.
|
||||
type Program struct {
|
||||
Statements []Statement
|
||||
ExitCode int
|
||||
@ -41,23 +37,23 @@ func (p *Program) String() string {
|
||||
return result
|
||||
}
|
||||
|
||||
// StructField represents a field in a struct definition
|
||||
// StructField represents a field definition within a struct.
|
||||
// Contains field name and required type annotation for compile-time checking.
|
||||
type StructField struct {
|
||||
Name string
|
||||
TypeHint *TypeInfo
|
||||
TypeHint TypeInfo // Required for struct fields, embeds directly
|
||||
}
|
||||
|
||||
func (sf *StructField) String() string {
|
||||
if sf.TypeHint != nil {
|
||||
return fmt.Sprintf("%s: %s", sf.Name, sf.TypeHint.Type)
|
||||
}
|
||||
return sf.Name
|
||||
return fmt.Sprintf("%s: %s", sf.Name, typeToString(sf.TypeHint))
|
||||
}
|
||||
|
||||
// StructStatement represents struct definitions
|
||||
// StructStatement represents struct type definitions with named fields.
|
||||
// Defines new types that can be instantiated and used for type checking.
|
||||
type StructStatement struct {
|
||||
Name string
|
||||
Fields []StructField
|
||||
ID uint16 // Unique identifier for fast lookup
|
||||
}
|
||||
|
||||
func (ss *StructStatement) statementNode() {}
|
||||
@ -72,77 +68,72 @@ func (ss *StructStatement) String() string {
|
||||
return fmt.Sprintf("struct %s {\n\t%s\n}", ss.Name, fields)
|
||||
}
|
||||
|
||||
// MethodDefinition represents method definitions on structs
|
||||
// MethodDefinition represents method definitions attached to struct types.
|
||||
// Links a function implementation to a specific struct via struct ID.
|
||||
type MethodDefinition struct {
|
||||
StructName string
|
||||
StructID uint16 // Index into struct table for fast lookup
|
||||
MethodName string
|
||||
Function *FunctionLiteral
|
||||
}
|
||||
|
||||
func (md *MethodDefinition) statementNode() {}
|
||||
func (md *MethodDefinition) String() string {
|
||||
return fmt.Sprintf("fn %s.%s%s", md.StructName, md.MethodName, md.Function.String()[2:]) // skip "fn" from function string
|
||||
return fmt.Sprintf("fn <struct>.%s%s", md.MethodName, md.Function.String()[2:])
|
||||
}
|
||||
|
||||
// StructConstructorExpression represents struct constructor calls like my_type{...}
|
||||
type StructConstructorExpression struct {
|
||||
StructName string
|
||||
Fields []TablePair // reuse TablePair for field assignments
|
||||
typeInfo *TypeInfo
|
||||
// StructConstructor represents struct instantiation with field initialization.
|
||||
// Uses struct ID for fast type resolution and validation during parsing.
|
||||
type StructConstructor struct {
|
||||
StructID uint16 // Index into struct table
|
||||
Fields []TablePair // Reuses table pair structure for field assignments
|
||||
typeInfo TypeInfo // Cached type info for this constructor
|
||||
}
|
||||
|
||||
func (sce *StructConstructorExpression) expressionNode() {}
|
||||
func (sce *StructConstructorExpression) String() string {
|
||||
func (sc *StructConstructor) expressionNode() {}
|
||||
func (sc *StructConstructor) String() string {
|
||||
var pairs []string
|
||||
for _, pair := range sce.Fields {
|
||||
for _, pair := range sc.Fields {
|
||||
pairs = append(pairs, pair.String())
|
||||
}
|
||||
return fmt.Sprintf("%s{%s}", sce.StructName, joinStrings(pairs, ", "))
|
||||
return fmt.Sprintf("<struct>{%s}", joinStrings(pairs, ", "))
|
||||
}
|
||||
func (sce *StructConstructorExpression) GetType() *TypeInfo { return sce.typeInfo }
|
||||
func (sce *StructConstructorExpression) SetType(t *TypeInfo) { sce.typeInfo = t }
|
||||
func (sc *StructConstructor) TypeInfo() TypeInfo { return sc.typeInfo }
|
||||
|
||||
// AssignStatement represents variable assignment with optional type hint
|
||||
type AssignStatement struct {
|
||||
Name Expression // Changed from *Identifier to Expression for member access
|
||||
TypeHint *TypeInfo // optional type hint
|
||||
Value Expression
|
||||
IsDeclaration bool // true if this is the first assignment in current scope
|
||||
// Assignment represents both variable assignment statements and assignment expressions.
|
||||
// Unified design reduces AST node count and simplifies type checking logic.
|
||||
type Assignment struct {
|
||||
Target Expression // Target (identifier, dot, or index expression)
|
||||
Value Expression // Value being assigned
|
||||
TypeHint TypeInfo // Optional explicit type hint, embeds directly
|
||||
IsDeclaration bool // True if declaring new variable in current scope
|
||||
IsExpression bool // True if used as expression (wrapped in parentheses)
|
||||
}
|
||||
|
||||
func (as *AssignStatement) statementNode() {}
|
||||
func (as *AssignStatement) String() string {
|
||||
func (a *Assignment) statementNode() {}
|
||||
func (a *Assignment) expressionNode() {}
|
||||
func (a *Assignment) String() string {
|
||||
prefix := ""
|
||||
if as.IsDeclaration {
|
||||
if a.IsDeclaration {
|
||||
prefix = "local "
|
||||
}
|
||||
|
||||
var nameStr string
|
||||
if as.TypeHint != nil {
|
||||
nameStr = fmt.Sprintf("%s: %s", as.Name.String(), as.TypeHint.Type)
|
||||
if a.TypeHint.Type != TypeUnknown {
|
||||
nameStr = fmt.Sprintf("%s: %s", a.Target.String(), typeToString(a.TypeHint))
|
||||
} else {
|
||||
nameStr = as.Name.String()
|
||||
nameStr = a.Target.String()
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s%s = %s", prefix, nameStr, as.Value.String())
|
||||
result := fmt.Sprintf("%s%s = %s", prefix, nameStr, a.Value.String())
|
||||
if a.IsExpression {
|
||||
return "(" + result + ")"
|
||||
}
|
||||
return result
|
||||
}
|
||||
func (a *Assignment) TypeInfo() TypeInfo { return a.Value.TypeInfo() }
|
||||
|
||||
// AssignExpression represents assignment as an expression (only in parentheses)
|
||||
type AssignExpression struct {
|
||||
Name Expression // Target (identifier, dot, or index expression)
|
||||
Value Expression // Value to assign
|
||||
IsDeclaration bool // true if this declares a new variable
|
||||
typeInfo *TypeInfo // type of the expression (same as assigned value)
|
||||
}
|
||||
|
||||
func (ae *AssignExpression) expressionNode() {}
|
||||
func (ae *AssignExpression) String() string {
|
||||
return fmt.Sprintf("(%s = %s)", ae.Name.String(), ae.Value.String())
|
||||
}
|
||||
func (ae *AssignExpression) GetType() *TypeInfo { return ae.typeInfo }
|
||||
func (ae *AssignExpression) SetType(t *TypeInfo) { ae.typeInfo = t }
|
||||
|
||||
// ExpressionStatement represents expressions used as statements
|
||||
// ExpressionStatement wraps expressions used as statements.
|
||||
// Allows function calls and other expressions at statement level.
|
||||
type ExpressionStatement struct {
|
||||
Expression Expression
|
||||
}
|
||||
@ -152,7 +143,8 @@ func (es *ExpressionStatement) String() string {
|
||||
return es.Expression.String()
|
||||
}
|
||||
|
||||
// EchoStatement represents echo output statements
|
||||
// EchoStatement represents output statements for displaying values.
|
||||
// Simple debugging and output mechanism built into the language.
|
||||
type EchoStatement struct {
|
||||
Value Expression
|
||||
}
|
||||
@ -162,17 +154,17 @@ func (es *EchoStatement) String() string {
|
||||
return fmt.Sprintf("echo %s", es.Value.String())
|
||||
}
|
||||
|
||||
// BreakStatement represents break statements to exit loops
|
||||
// BreakStatement represents loop exit statements.
|
||||
// Simple marker node with no additional data needed.
|
||||
type BreakStatement struct{}
|
||||
|
||||
func (bs *BreakStatement) statementNode() {}
|
||||
func (bs *BreakStatement) String() string {
|
||||
return "break"
|
||||
}
|
||||
func (bs *BreakStatement) String() string { return "break" }
|
||||
|
||||
// ExitStatement represents exit statements to quit the script
|
||||
// ExitStatement represents script termination with optional exit code.
|
||||
// Value expression is nil for plain "exit", non-nil for "exit <code>".
|
||||
type ExitStatement struct {
|
||||
Value Expression // optional, can be nil
|
||||
Value Expression // Optional exit code expression
|
||||
}
|
||||
|
||||
func (es *ExitStatement) statementNode() {}
|
||||
@ -183,9 +175,10 @@ func (es *ExitStatement) String() string {
|
||||
return fmt.Sprintf("exit %s", es.Value.String())
|
||||
}
|
||||
|
||||
// ReturnStatement represents return statements
|
||||
// ReturnStatement represents function return with optional value.
|
||||
// Value expression is nil for plain "return", non-nil for "return <value>".
|
||||
type ReturnStatement struct {
|
||||
Value Expression // optional, can be nil
|
||||
Value Expression // Optional return value expression
|
||||
}
|
||||
|
||||
func (rs *ReturnStatement) statementNode() {}
|
||||
@ -196,7 +189,8 @@ func (rs *ReturnStatement) String() string {
|
||||
return fmt.Sprintf("return %s", rs.Value.String())
|
||||
}
|
||||
|
||||
// ElseIfClause represents an elseif condition
|
||||
// ElseIfClause represents conditional branches in if statements.
|
||||
// Contains condition expression and body statements for this branch.
|
||||
type ElseIfClause struct {
|
||||
Condition Expression
|
||||
Body []Statement
|
||||
@ -210,30 +204,28 @@ func (eic *ElseIfClause) String() string {
|
||||
return fmt.Sprintf("elseif %s then\n%s", eic.Condition.String(), body)
|
||||
}
|
||||
|
||||
// IfStatement represents conditional statements
|
||||
// IfStatement represents conditional execution with optional elseif and else branches.
|
||||
// Supports multiple elseif clauses and an optional final else clause.
|
||||
type IfStatement struct {
|
||||
Condition Expression
|
||||
Body []Statement
|
||||
ElseIfs []ElseIfClause
|
||||
Else []Statement
|
||||
Condition Expression // Main condition
|
||||
Body []Statement // Statements to execute if condition is true
|
||||
ElseIfs []ElseIfClause // Optional elseif branches
|
||||
Else []Statement // Optional else branch
|
||||
}
|
||||
|
||||
func (is *IfStatement) statementNode() {}
|
||||
func (is *IfStatement) String() string {
|
||||
var result string
|
||||
|
||||
// If clause
|
||||
result += fmt.Sprintf("if %s then\n", is.Condition.String())
|
||||
for _, stmt := range is.Body {
|
||||
result += "\t" + stmt.String() + "\n"
|
||||
}
|
||||
|
||||
// ElseIf clauses
|
||||
for _, elseif := range is.ElseIfs {
|
||||
result += elseif.String()
|
||||
}
|
||||
|
||||
// Else clause
|
||||
if len(is.Else) > 0 {
|
||||
result += "else\n"
|
||||
for _, stmt := range is.Else {
|
||||
@ -245,7 +237,8 @@ func (is *IfStatement) String() string {
|
||||
return result
|
||||
}
|
||||
|
||||
// WhileStatement represents while loops: while condition do ... end
|
||||
// WhileStatement represents condition-based loops that execute while condition is true.
|
||||
// Contains condition expression and body statements to repeat.
|
||||
type WhileStatement struct {
|
||||
Condition Expression
|
||||
Body []Statement
|
||||
@ -264,13 +257,14 @@ func (ws *WhileStatement) String() string {
|
||||
return result
|
||||
}
|
||||
|
||||
// ForStatement represents numeric for loops: for i = start, end, step do ... end
|
||||
// ForStatement represents numeric for loops with start, end, and optional step.
|
||||
// Variable is automatically scoped to the loop body.
|
||||
type ForStatement struct {
|
||||
Variable *Identifier
|
||||
Start Expression
|
||||
End Expression
|
||||
Step Expression // optional, nil means step of 1
|
||||
Body []Statement
|
||||
Variable *Identifier // Loop variable (automatically number type)
|
||||
Start Expression // Starting value expression
|
||||
End Expression // Ending value expression
|
||||
Step Expression // Optional step expression (nil means step of 1)
|
||||
Body []Statement // Loop body statements
|
||||
}
|
||||
|
||||
func (fs *ForStatement) statementNode() {}
|
||||
@ -292,12 +286,13 @@ func (fs *ForStatement) String() string {
|
||||
return result
|
||||
}
|
||||
|
||||
// ForInStatement represents iterator for loops: for k, v in expr do ... end
|
||||
// ForInStatement represents iterator-based loops over tables, arrays, or other iterables.
|
||||
// Supports both single variable (for v in iter) and key-value (for k,v in iter) forms.
|
||||
type ForInStatement struct {
|
||||
Key *Identifier // optional, nil for single variable iteration
|
||||
Value *Identifier
|
||||
Iterable Expression
|
||||
Body []Statement
|
||||
Key *Identifier // Optional key variable (nil for single variable iteration)
|
||||
Value *Identifier // Value variable (required)
|
||||
Iterable Expression // Expression to iterate over
|
||||
Body []Statement // Loop body statements
|
||||
}
|
||||
|
||||
func (fis *ForInStatement) statementNode() {}
|
||||
@ -319,56 +314,60 @@ func (fis *ForInStatement) String() string {
|
||||
return result
|
||||
}
|
||||
|
||||
// FunctionParameter represents a function parameter with optional type hint
|
||||
// FunctionParameter represents a parameter in function definitions.
|
||||
// Contains parameter name and optional type hint for type checking.
|
||||
type FunctionParameter struct {
|
||||
Name string
|
||||
TypeHint *TypeInfo
|
||||
TypeHint TypeInfo // Optional type constraint, embeds directly
|
||||
}
|
||||
|
||||
func (fp *FunctionParameter) String() string {
|
||||
if fp.TypeHint != nil {
|
||||
return fmt.Sprintf("%s: %s", fp.Name, fp.TypeHint.Type)
|
||||
if fp.TypeHint.Type != TypeUnknown {
|
||||
return fmt.Sprintf("%s: %s", fp.Name, typeToString(fp.TypeHint))
|
||||
}
|
||||
return fp.Name
|
||||
}
|
||||
|
||||
// Identifier represents identifiers
|
||||
// Identifier represents variable references and names.
|
||||
// Stores resolved type information for efficient type checking.
|
||||
type Identifier struct {
|
||||
Value string
|
||||
typeInfo *TypeInfo
|
||||
typeInfo TypeInfo // Resolved type, embeds directly
|
||||
}
|
||||
|
||||
func (i *Identifier) expressionNode() {}
|
||||
func (i *Identifier) String() string { return i.Value }
|
||||
func (i *Identifier) GetType() *TypeInfo { return i.typeInfo }
|
||||
func (i *Identifier) SetType(t *TypeInfo) { i.typeInfo = t }
|
||||
func (i *Identifier) TypeInfo() TypeInfo {
|
||||
if i.typeInfo.Type == TypeUnknown {
|
||||
return AnyType
|
||||
}
|
||||
return i.typeInfo
|
||||
}
|
||||
|
||||
// NumberLiteral represents numeric literals
|
||||
// NumberLiteral represents numeric constants including integers, floats, hex, and binary.
|
||||
// Always has number type, so no additional type storage needed.
|
||||
type NumberLiteral struct {
|
||||
Value float64
|
||||
typeInfo *TypeInfo
|
||||
Value float64 // All numbers stored as float64 for simplicity
|
||||
}
|
||||
|
||||
func (nl *NumberLiteral) expressionNode() {}
|
||||
func (nl *NumberLiteral) String() string { return fmt.Sprintf("%.2f", nl.Value) }
|
||||
func (nl *NumberLiteral) GetType() *TypeInfo { return nl.typeInfo }
|
||||
func (nl *NumberLiteral) SetType(t *TypeInfo) { nl.typeInfo = t }
|
||||
func (nl *NumberLiteral) TypeInfo() TypeInfo { return NumberType }
|
||||
|
||||
// StringLiteral represents string literals
|
||||
// StringLiteral represents string constants and multiline strings.
|
||||
// Always has string type, so no additional type storage needed.
|
||||
type StringLiteral struct {
|
||||
Value string
|
||||
typeInfo *TypeInfo
|
||||
Value string // String content without quotes
|
||||
}
|
||||
|
||||
func (sl *StringLiteral) expressionNode() {}
|
||||
func (sl *StringLiteral) String() string { return fmt.Sprintf(`"%s"`, sl.Value) }
|
||||
func (sl *StringLiteral) GetType() *TypeInfo { return sl.typeInfo }
|
||||
func (sl *StringLiteral) SetType(t *TypeInfo) { sl.typeInfo = t }
|
||||
func (sl *StringLiteral) TypeInfo() TypeInfo { return StringType }
|
||||
|
||||
// BooleanLiteral represents boolean literals
|
||||
// BooleanLiteral represents true and false constants.
|
||||
// Always has bool type, so no additional type storage needed.
|
||||
type BooleanLiteral struct {
|
||||
Value bool
|
||||
typeInfo *TypeInfo
|
||||
}
|
||||
|
||||
func (bl *BooleanLiteral) expressionNode() {}
|
||||
@ -378,26 +377,23 @@ func (bl *BooleanLiteral) String() string {
|
||||
}
|
||||
return "false"
|
||||
}
|
||||
func (bl *BooleanLiteral) GetType() *TypeInfo { return bl.typeInfo }
|
||||
func (bl *BooleanLiteral) SetType(t *TypeInfo) { bl.typeInfo = t }
|
||||
func (bl *BooleanLiteral) TypeInfo() TypeInfo { return BoolType }
|
||||
|
||||
// NilLiteral represents nil literal
|
||||
type NilLiteral struct {
|
||||
typeInfo *TypeInfo
|
||||
}
|
||||
// NilLiteral represents the nil constant value.
|
||||
// Always has nil type, so no additional type storage needed.
|
||||
type NilLiteral struct{}
|
||||
|
||||
func (nl *NilLiteral) expressionNode() {}
|
||||
func (nl *NilLiteral) String() string { return "nil" }
|
||||
func (nl *NilLiteral) GetType() *TypeInfo { return nl.typeInfo }
|
||||
func (nl *NilLiteral) SetType(t *TypeInfo) { nl.typeInfo = t }
|
||||
func (nl *NilLiteral) TypeInfo() TypeInfo { return NilType }
|
||||
|
||||
// FunctionLiteral represents function literals with typed parameters
|
||||
// FunctionLiteral represents function definitions with parameters, body, and optional return type.
|
||||
// Always has function type, stores additional return type information separately.
|
||||
type FunctionLiteral struct {
|
||||
Parameters []FunctionParameter
|
||||
Variadic bool
|
||||
ReturnType *TypeInfo // optional return type hint
|
||||
Body []Statement
|
||||
typeInfo *TypeInfo
|
||||
Parameters []FunctionParameter // Function parameters with optional types
|
||||
Body []Statement // Function body statements
|
||||
ReturnType TypeInfo // Optional return type hint, embeds directly
|
||||
Variadic bool // True if function accepts variable arguments
|
||||
}
|
||||
|
||||
func (fl *FunctionLiteral) expressionNode() {}
|
||||
@ -417,8 +413,8 @@ func (fl *FunctionLiteral) String() string {
|
||||
}
|
||||
|
||||
result := fmt.Sprintf("fn(%s)", params)
|
||||
if fl.ReturnType != nil {
|
||||
result += ": " + fl.ReturnType.Type
|
||||
if fl.ReturnType.Type != TypeUnknown {
|
||||
result += ": " + typeToString(fl.ReturnType)
|
||||
}
|
||||
result += "\n"
|
||||
|
||||
@ -428,14 +424,14 @@ func (fl *FunctionLiteral) String() string {
|
||||
result += "end"
|
||||
return result
|
||||
}
|
||||
func (fl *FunctionLiteral) GetType() *TypeInfo { return fl.typeInfo }
|
||||
func (fl *FunctionLiteral) SetType(t *TypeInfo) { fl.typeInfo = t }
|
||||
func (fl *FunctionLiteral) TypeInfo() TypeInfo { return FunctionType }
|
||||
|
||||
// CallExpression represents function calls: func(arg1, arg2, ...)
|
||||
// CallExpression represents function calls with arguments.
|
||||
// Stores inferred return type from function signature analysis.
|
||||
type CallExpression struct {
|
||||
Function Expression
|
||||
Arguments []Expression
|
||||
typeInfo *TypeInfo
|
||||
Function Expression // Function expression to call
|
||||
Arguments []Expression // Argument expressions
|
||||
typeInfo TypeInfo // Inferred return type, embeds directly
|
||||
}
|
||||
|
||||
func (ce *CallExpression) expressionNode() {}
|
||||
@ -446,74 +442,73 @@ func (ce *CallExpression) String() string {
|
||||
}
|
||||
return fmt.Sprintf("%s(%s)", ce.Function.String(), joinStrings(args, ", "))
|
||||
}
|
||||
func (ce *CallExpression) GetType() *TypeInfo { return ce.typeInfo }
|
||||
func (ce *CallExpression) SetType(t *TypeInfo) { ce.typeInfo = t }
|
||||
func (ce *CallExpression) TypeInfo() TypeInfo { return ce.typeInfo }
|
||||
|
||||
// PrefixExpression represents prefix operations like -x, not x
|
||||
// PrefixExpression represents unary operations like negation and logical not.
|
||||
// Stores result type based on operator and operand type analysis.
|
||||
type PrefixExpression struct {
|
||||
Operator string
|
||||
Right Expression
|
||||
typeInfo *TypeInfo
|
||||
Operator string // Operator symbol ("-", "not")
|
||||
Right Expression // Operand expression
|
||||
typeInfo TypeInfo // Result type, embeds directly
|
||||
}
|
||||
|
||||
func (pe *PrefixExpression) expressionNode() {}
|
||||
func (pe *PrefixExpression) String() string {
|
||||
// Add space for word operators
|
||||
if pe.Operator == "not" {
|
||||
return fmt.Sprintf("(%s %s)", pe.Operator, pe.Right.String())
|
||||
}
|
||||
return fmt.Sprintf("(%s%s)", pe.Operator, pe.Right.String())
|
||||
}
|
||||
func (pe *PrefixExpression) GetType() *TypeInfo { return pe.typeInfo }
|
||||
func (pe *PrefixExpression) SetType(t *TypeInfo) { pe.typeInfo = t }
|
||||
func (pe *PrefixExpression) TypeInfo() TypeInfo { return pe.typeInfo }
|
||||
|
||||
// InfixExpression represents binary operations
|
||||
// InfixExpression represents binary operations between two expressions.
|
||||
// Stores result type based on operator and operand type compatibility.
|
||||
type InfixExpression struct {
|
||||
Left Expression
|
||||
Operator string
|
||||
Right Expression
|
||||
typeInfo *TypeInfo
|
||||
Left Expression // Left operand
|
||||
Right Expression // Right operand
|
||||
Operator string // Operator symbol ("+", "-", "==", "and", etc.)
|
||||
typeInfo TypeInfo // Result type, embeds directly
|
||||
}
|
||||
|
||||
func (ie *InfixExpression) expressionNode() {}
|
||||
func (ie *InfixExpression) String() string {
|
||||
return fmt.Sprintf("(%s %s %s)", ie.Left.String(), ie.Operator, ie.Right.String())
|
||||
}
|
||||
func (ie *InfixExpression) GetType() *TypeInfo { return ie.typeInfo }
|
||||
func (ie *InfixExpression) SetType(t *TypeInfo) { ie.typeInfo = t }
|
||||
func (ie *InfixExpression) TypeInfo() TypeInfo { return ie.typeInfo }
|
||||
|
||||
// IndexExpression represents table[key] access
|
||||
// IndexExpression represents bracket-based member access (table[key]).
|
||||
// Stores inferred element type based on container type analysis.
|
||||
type IndexExpression struct {
|
||||
Left Expression
|
||||
Index Expression
|
||||
typeInfo *TypeInfo
|
||||
Left Expression // Container expression
|
||||
Index Expression // Index/key expression
|
||||
typeInfo TypeInfo // Element type, embeds directly
|
||||
}
|
||||
|
||||
func (ie *IndexExpression) expressionNode() {}
|
||||
func (ie *IndexExpression) String() string {
|
||||
return fmt.Sprintf("%s[%s]", ie.Left.String(), ie.Index.String())
|
||||
}
|
||||
func (ie *IndexExpression) GetType() *TypeInfo { return ie.typeInfo }
|
||||
func (ie *IndexExpression) SetType(t *TypeInfo) { ie.typeInfo = t }
|
||||
func (ie *IndexExpression) TypeInfo() TypeInfo { return ie.typeInfo }
|
||||
|
||||
// DotExpression represents table.key access
|
||||
// DotExpression represents dot-based member access (table.key).
|
||||
// Stores inferred member type based on container type and field analysis.
|
||||
type DotExpression struct {
|
||||
Left Expression
|
||||
Key string
|
||||
typeInfo *TypeInfo
|
||||
Left Expression // Container expression
|
||||
Key string // Member name
|
||||
typeInfo TypeInfo // Member type, embeds directly
|
||||
}
|
||||
|
||||
func (de *DotExpression) expressionNode() {}
|
||||
func (de *DotExpression) String() string {
|
||||
return fmt.Sprintf("%s.%s", de.Left.String(), de.Key)
|
||||
}
|
||||
func (de *DotExpression) GetType() *TypeInfo { return de.typeInfo }
|
||||
func (de *DotExpression) SetType(t *TypeInfo) { de.typeInfo = t }
|
||||
func (de *DotExpression) TypeInfo() TypeInfo { return de.typeInfo }
|
||||
|
||||
// TablePair represents a key-value pair in a table
|
||||
// TablePair represents key-value pairs in table literals and struct constructors.
|
||||
// Key is nil for array-style elements, non-nil for object-style elements.
|
||||
type TablePair struct {
|
||||
Key Expression // nil for array-style elements
|
||||
Value Expression
|
||||
Key Expression // Key expression (nil for array elements)
|
||||
Value Expression // Value expression
|
||||
}
|
||||
|
||||
func (tp *TablePair) String() string {
|
||||
@ -523,10 +518,10 @@ func (tp *TablePair) String() string {
|
||||
return fmt.Sprintf("%s = %s", tp.Key.String(), tp.Value.String())
|
||||
}
|
||||
|
||||
// TableLiteral represents table literals {}
|
||||
// TableLiteral represents table/array/object literals with key-value pairs.
|
||||
// Always has table type, provides methods to check if it's array-style.
|
||||
type TableLiteral struct {
|
||||
Pairs []TablePair
|
||||
typeInfo *TypeInfo
|
||||
Pairs []TablePair // Key-value pairs (key nil for array elements)
|
||||
}
|
||||
|
||||
func (tl *TableLiteral) expressionNode() {}
|
||||
@ -537,10 +532,9 @@ func (tl *TableLiteral) String() string {
|
||||
}
|
||||
return fmt.Sprintf("{%s}", joinStrings(pairs, ", "))
|
||||
}
|
||||
func (tl *TableLiteral) GetType() *TypeInfo { return tl.typeInfo }
|
||||
func (tl *TableLiteral) SetType(t *TypeInfo) { tl.typeInfo = t }
|
||||
func (tl *TableLiteral) TypeInfo() TypeInfo { return TableType }
|
||||
|
||||
// IsArray returns true if this table contains only array-style elements
|
||||
// IsArray returns true if this table contains only array-style elements (no explicit keys)
|
||||
func (tl *TableLiteral) IsArray() bool {
|
||||
for _, pair := range tl.Pairs {
|
||||
if pair.Key != nil {
|
||||
@ -550,7 +544,31 @@ func (tl *TableLiteral) IsArray() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// joinStrings joins string slice with separator
|
||||
// Helper function to convert TypeInfo to string representation
|
||||
func typeToString(t TypeInfo) string {
|
||||
switch t.Type {
|
||||
case TypeNumber:
|
||||
return "number"
|
||||
case TypeString:
|
||||
return "string"
|
||||
case TypeBool:
|
||||
return "bool"
|
||||
case TypeNil:
|
||||
return "nil"
|
||||
case TypeTable:
|
||||
return "table"
|
||||
case TypeFunction:
|
||||
return "function"
|
||||
case TypeAny:
|
||||
return "any"
|
||||
case TypeStruct:
|
||||
return fmt.Sprintf("struct<%d>", t.StructID)
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// joinStrings efficiently joins string slice with separator
|
||||
func joinStrings(strs []string, sep string) string {
|
||||
if len(strs) == 0 {
|
||||
return ""
|
||||
|
252
parser/parser.go
252
parser/parser.go
@ -19,7 +19,7 @@ func (pe ParseError) Error() string {
|
||||
pe.Line, pe.Column, pe.Message, pe.Token.Literal)
|
||||
}
|
||||
|
||||
// Parser implements a recursive descent Pratt parser
|
||||
// Parser implements a recursive descent Pratt parser with optimized AST generation
|
||||
type Parser struct {
|
||||
lexer *Lexer
|
||||
|
||||
@ -32,11 +32,13 @@ type Parser struct {
|
||||
errors []ParseError
|
||||
|
||||
// Scope tracking
|
||||
scopes []map[string]bool // stack of scopes, each tracking declared variables
|
||||
scopeTypes []string // track what type each scope is: "global", "function", "loop"
|
||||
scopes []map[string]bool
|
||||
scopeTypes []string
|
||||
|
||||
// Struct tracking
|
||||
structs map[string]*StructStatement // track defined structs
|
||||
// Struct tracking with ID mapping
|
||||
structs map[string]*StructStatement
|
||||
structIDs map[uint16]*StructStatement
|
||||
nextID uint16
|
||||
}
|
||||
|
||||
// NewParser creates a new parser instance
|
||||
@ -44,9 +46,11 @@ func NewParser(lexer *Lexer) *Parser {
|
||||
p := &Parser{
|
||||
lexer: lexer,
|
||||
errors: []ParseError{},
|
||||
scopes: []map[string]bool{make(map[string]bool)}, // start with global scope
|
||||
scopeTypes: []string{"global"}, // start with global scope type
|
||||
structs: make(map[string]*StructStatement), // track struct definitions
|
||||
scopes: []map[string]bool{make(map[string]bool)},
|
||||
scopeTypes: []string{"global"},
|
||||
structs: make(map[string]*StructStatement),
|
||||
structIDs: make(map[uint16]*StructStatement),
|
||||
nextID: 1, // 0 reserved for non-struct types
|
||||
}
|
||||
|
||||
p.prefixParseFns = make(map[TokenType]func() Expression)
|
||||
@ -78,15 +82,31 @@ func NewParser(lexer *Lexer) *Parser {
|
||||
p.registerInfix(DOT, p.parseDotExpression)
|
||||
p.registerInfix(LBRACKET, p.parseIndexExpression)
|
||||
p.registerInfix(LPAREN, p.parseCallExpression)
|
||||
p.registerInfix(LBRACE, p.parseStructConstructor) // struct constructor
|
||||
p.registerInfix(LBRACE, p.parseStructConstructor)
|
||||
|
||||
// Read two tokens, so curToken and peekToken are both set
|
||||
p.nextToken()
|
||||
p.nextToken()
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// Struct management
|
||||
func (p *Parser) registerStruct(stmt *StructStatement) {
|
||||
stmt.ID = p.nextID
|
||||
p.nextID++
|
||||
p.structs[stmt.Name] = stmt
|
||||
p.structIDs[stmt.ID] = stmt
|
||||
}
|
||||
|
||||
func (p *Parser) getStructByName(name string) *StructStatement {
|
||||
return p.structs[name]
|
||||
}
|
||||
|
||||
func (p *Parser) isStructDefined(name string) bool {
|
||||
_, exists := p.structs[name]
|
||||
return exists
|
||||
}
|
||||
|
||||
// Scope management
|
||||
func (p *Parser) enterScope(scopeType string) {
|
||||
p.scopes = append(p.scopes, make(map[string]bool))
|
||||
@ -100,30 +120,6 @@ func (p *Parser) exitScope() {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Parser) enterFunctionScope() {
|
||||
p.enterScope("function")
|
||||
}
|
||||
|
||||
func (p *Parser) exitFunctionScope() {
|
||||
p.exitScope()
|
||||
}
|
||||
|
||||
func (p *Parser) enterLoopScope() {
|
||||
p.enterScope("loop")
|
||||
}
|
||||
|
||||
func (p *Parser) exitLoopScope() {
|
||||
p.exitScope()
|
||||
}
|
||||
|
||||
func (p *Parser) enterBlockScope() {
|
||||
// Blocks don't create new variable scopes
|
||||
}
|
||||
|
||||
func (p *Parser) exitBlockScope() {
|
||||
// No-op
|
||||
}
|
||||
|
||||
func (p *Parser) currentVariableScope() map[string]bool {
|
||||
if len(p.scopeTypes) > 1 && p.scopeTypes[len(p.scopeTypes)-1] == "loop" {
|
||||
return p.scopes[len(p.scopes)-2]
|
||||
@ -148,45 +144,56 @@ func (p *Parser) declareLoopVariable(name string) {
|
||||
p.scopes[len(p.scopes)-1][name] = true
|
||||
}
|
||||
|
||||
// parseTypeHint parses optional type hint after colon
|
||||
func (p *Parser) parseTypeHint() *TypeInfo {
|
||||
// parseTypeHint parses optional type hint after colon, returns by value
|
||||
func (p *Parser) parseTypeHint() TypeInfo {
|
||||
if !p.peekTokenIs(COLON) {
|
||||
return nil
|
||||
return UnknownType
|
||||
}
|
||||
|
||||
p.nextToken() // consume ':'
|
||||
|
||||
if !p.expectPeekIdent() {
|
||||
p.addError("expected type name after ':'")
|
||||
return nil
|
||||
return UnknownType
|
||||
}
|
||||
|
||||
typeName := p.curToken.Literal
|
||||
if !ValidTypeName(typeName) && !p.isStructDefined(typeName) {
|
||||
p.addError(fmt.Sprintf("invalid type name '%s'", typeName))
|
||||
return nil
|
||||
|
||||
// Check built-in types
|
||||
switch typeName {
|
||||
case "number":
|
||||
return TypeInfo{Type: TypeNumber, Inferred: false}
|
||||
case "string":
|
||||
return TypeInfo{Type: TypeString, Inferred: false}
|
||||
case "bool":
|
||||
return TypeInfo{Type: TypeBool, Inferred: false}
|
||||
case "nil":
|
||||
return TypeInfo{Type: TypeNil, Inferred: false}
|
||||
case "table":
|
||||
return TypeInfo{Type: TypeTable, Inferred: false}
|
||||
case "function":
|
||||
return TypeInfo{Type: TypeFunction, Inferred: false}
|
||||
case "any":
|
||||
return TypeInfo{Type: TypeAny, Inferred: false}
|
||||
default:
|
||||
// Check if it's a struct type
|
||||
if structDef := p.getStructByName(typeName); structDef != nil {
|
||||
return TypeInfo{Type: TypeStruct, StructID: structDef.ID, Inferred: false}
|
||||
}
|
||||
p.addError(fmt.Sprintf("invalid type name '%s'", typeName))
|
||||
return UnknownType
|
||||
}
|
||||
|
||||
return &TypeInfo{Type: typeName, Inferred: false}
|
||||
}
|
||||
|
||||
// isStructDefined checks if a struct name is defined
|
||||
func (p *Parser) isStructDefined(name string) bool {
|
||||
_, exists := p.structs[name]
|
||||
return exists
|
||||
}
|
||||
|
||||
// registerPrefix registers a prefix parse function
|
||||
// registerPrefix/registerInfix
|
||||
func (p *Parser) registerPrefix(tokenType TokenType, fn func() Expression) {
|
||||
p.prefixParseFns[tokenType] = fn
|
||||
}
|
||||
|
||||
// registerInfix registers an infix parse function
|
||||
func (p *Parser) registerInfix(tokenType TokenType, fn func(Expression) Expression) {
|
||||
p.infixParseFns[tokenType] = fn
|
||||
}
|
||||
|
||||
// nextToken advances to the next token
|
||||
func (p *Parser) nextToken() {
|
||||
p.curToken = p.peekToken
|
||||
p.peekToken = p.lexer.NextToken()
|
||||
@ -265,7 +272,7 @@ func (p *Parser) parseStructStatement() *StructStatement {
|
||||
|
||||
if p.peekTokenIs(RBRACE) {
|
||||
p.nextToken()
|
||||
p.structs[stmt.Name] = stmt
|
||||
p.registerStruct(stmt)
|
||||
return stmt
|
||||
}
|
||||
|
||||
@ -284,9 +291,9 @@ func (p *Parser) parseStructStatement() *StructStatement {
|
||||
|
||||
field := StructField{Name: p.curToken.Literal}
|
||||
|
||||
// Parse optional type hint
|
||||
// Parse required type hint
|
||||
field.TypeHint = p.parseTypeHint()
|
||||
if field.TypeHint == nil {
|
||||
if field.TypeHint.Type == TypeUnknown {
|
||||
p.addError("struct fields require type annotation")
|
||||
return nil
|
||||
}
|
||||
@ -314,7 +321,7 @@ func (p *Parser) parseStructStatement() *StructStatement {
|
||||
return nil
|
||||
}
|
||||
|
||||
p.structs[stmt.Name] = stmt
|
||||
p.registerStruct(stmt)
|
||||
return stmt
|
||||
}
|
||||
|
||||
@ -338,12 +345,19 @@ func (p *Parser) parseFunctionStatement() Statement {
|
||||
|
||||
methodName := p.curToken.Literal
|
||||
|
||||
// Get struct ID
|
||||
structDef := p.getStructByName(funcName)
|
||||
if structDef == nil {
|
||||
p.addError(fmt.Sprintf("method defined on undefined struct '%s'", funcName))
|
||||
return nil
|
||||
}
|
||||
|
||||
if !p.expectPeek(LPAREN) {
|
||||
p.addError("expected '(' after method name")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse the function literal starting from parameters
|
||||
// Parse the function literal
|
||||
funcLit := &FunctionLiteral{}
|
||||
funcLit.Parameters, funcLit.Variadic = p.parseFunctionParameters()
|
||||
|
||||
@ -357,12 +371,12 @@ func (p *Parser) parseFunctionStatement() Statement {
|
||||
|
||||
p.nextToken()
|
||||
|
||||
p.enterFunctionScope()
|
||||
p.enterScope("function")
|
||||
for _, param := range funcLit.Parameters {
|
||||
p.declareVariable(param.Name)
|
||||
}
|
||||
funcLit.Body = p.parseBlockStatements(END)
|
||||
p.exitFunctionScope()
|
||||
p.exitScope()
|
||||
|
||||
if !p.curTokenIs(END) {
|
||||
p.addError("expected 'end' to close function")
|
||||
@ -370,14 +384,13 @@ func (p *Parser) parseFunctionStatement() Statement {
|
||||
}
|
||||
|
||||
return &MethodDefinition{
|
||||
StructName: funcName,
|
||||
StructID: structDef.ID,
|
||||
MethodName: methodName,
|
||||
Function: funcLit,
|
||||
}
|
||||
}
|
||||
|
||||
// Regular function - this should be handled as expression statement
|
||||
// Reset to handle as function literal
|
||||
// Regular function - handle as function literal expression statement
|
||||
funcLit := p.parseFunctionLiteral()
|
||||
if funcLit == nil {
|
||||
return nil
|
||||
@ -386,7 +399,7 @@ func (p *Parser) parseFunctionStatement() Statement {
|
||||
return &ExpressionStatement{Expression: funcLit}
|
||||
}
|
||||
|
||||
// parseIdentifierStatement handles both assignments and expression statements starting with identifiers
|
||||
// parseIdentifierStatement handles assignments and expression statements
|
||||
func (p *Parser) parseIdentifierStatement() Statement {
|
||||
// Parse the left-hand side expression first
|
||||
expr := p.ParseExpression(LOWEST)
|
||||
@ -395,28 +408,28 @@ func (p *Parser) parseIdentifierStatement() Statement {
|
||||
}
|
||||
|
||||
// Check for type hint (only valid on simple identifiers)
|
||||
var typeHint *TypeInfo
|
||||
var typeHint TypeInfo = UnknownType
|
||||
if _, ok := expr.(*Identifier); ok {
|
||||
typeHint = p.parseTypeHint()
|
||||
}
|
||||
|
||||
// Check if this is an assignment
|
||||
if p.peekTokenIs(ASSIGN) {
|
||||
// Convert to assignment statement
|
||||
stmt := &AssignStatement{
|
||||
Name: expr,
|
||||
// Create unified assignment
|
||||
assignment := &Assignment{
|
||||
Target: expr,
|
||||
TypeHint: typeHint,
|
||||
}
|
||||
|
||||
// Validate assignment target and check if it's a declaration
|
||||
switch name := expr.(type) {
|
||||
switch target := expr.(type) {
|
||||
case *Identifier:
|
||||
stmt.IsDeclaration = !p.isVariableDeclared(name.Value)
|
||||
if stmt.IsDeclaration {
|
||||
p.declareVariable(name.Value)
|
||||
assignment.IsDeclaration = !p.isVariableDeclared(target.Value)
|
||||
if assignment.IsDeclaration {
|
||||
p.declareVariable(target.Value)
|
||||
}
|
||||
case *DotExpression, *IndexExpression:
|
||||
stmt.IsDeclaration = false
|
||||
assignment.IsDeclaration = false
|
||||
default:
|
||||
p.addError("invalid assignment target")
|
||||
return nil
|
||||
@ -428,29 +441,19 @@ func (p *Parser) parseIdentifierStatement() Statement {
|
||||
|
||||
p.nextToken()
|
||||
|
||||
stmt.Value = p.ParseExpression(LOWEST)
|
||||
if stmt.Value == nil {
|
||||
assignment.Value = p.ParseExpression(LOWEST)
|
||||
if assignment.Value == nil {
|
||||
p.addError("expected expression after assignment operator")
|
||||
return nil
|
||||
}
|
||||
|
||||
return stmt
|
||||
return assignment
|
||||
} else {
|
||||
// This is an expression statement
|
||||
return &ExpressionStatement{Expression: expr}
|
||||
}
|
||||
}
|
||||
|
||||
// parseExpressionStatement parses expressions used as statements
|
||||
func (p *Parser) parseExpressionStatement() *ExpressionStatement {
|
||||
stmt := &ExpressionStatement{}
|
||||
stmt.Expression = p.ParseExpression(LOWEST)
|
||||
if stmt.Expression == nil {
|
||||
return nil
|
||||
}
|
||||
return stmt
|
||||
}
|
||||
|
||||
// parseEchoStatement parses echo statements
|
||||
func (p *Parser) parseEchoStatement() *EchoStatement {
|
||||
stmt := &EchoStatement{}
|
||||
@ -466,9 +469,8 @@ func (p *Parser) parseEchoStatement() *EchoStatement {
|
||||
return stmt
|
||||
}
|
||||
|
||||
// parseBreakStatement parses break statements
|
||||
// Simple statement parsers
|
||||
func (p *Parser) parseBreakStatement() *BreakStatement {
|
||||
// Check if break is followed by an identifier (invalid)
|
||||
if p.peekTokenIs(IDENT) {
|
||||
p.addError("unexpected identifier")
|
||||
return nil
|
||||
@ -476,7 +478,6 @@ func (p *Parser) parseBreakStatement() *BreakStatement {
|
||||
return &BreakStatement{}
|
||||
}
|
||||
|
||||
// parseExitStatement parses exit statements
|
||||
func (p *Parser) parseExitStatement() *ExitStatement {
|
||||
stmt := &ExitStatement{}
|
||||
|
||||
@ -492,7 +493,6 @@ func (p *Parser) parseExitStatement() *ExitStatement {
|
||||
return stmt
|
||||
}
|
||||
|
||||
// parseReturnStatement parses return statements
|
||||
func (p *Parser) parseReturnStatement() *ReturnStatement {
|
||||
stmt := &ReturnStatement{}
|
||||
|
||||
@ -508,7 +508,6 @@ func (p *Parser) parseReturnStatement() *ReturnStatement {
|
||||
return stmt
|
||||
}
|
||||
|
||||
// canStartExpression checks if a token type can start an expression
|
||||
func (p *Parser) canStartExpression(tokenType TokenType) bool {
|
||||
switch tokenType {
|
||||
case IDENT, NUMBER, STRING, TRUE, FALSE, NIL, LPAREN, LBRACE, MINUS, NOT, FN:
|
||||
@ -518,7 +517,7 @@ func (p *Parser) canStartExpression(tokenType TokenType) bool {
|
||||
}
|
||||
}
|
||||
|
||||
// parseWhileStatement parses while loops
|
||||
// Loop statement parsers
|
||||
func (p *Parser) parseWhileStatement() *WhileStatement {
|
||||
stmt := &WhileStatement{}
|
||||
|
||||
@ -537,9 +536,7 @@ func (p *Parser) parseWhileStatement() *WhileStatement {
|
||||
|
||||
p.nextToken()
|
||||
|
||||
p.enterBlockScope()
|
||||
stmt.Body = p.parseBlockStatements(END)
|
||||
p.exitBlockScope()
|
||||
|
||||
if !p.curTokenIs(END) {
|
||||
p.addError("expected 'end' to close while loop")
|
||||
@ -549,7 +546,6 @@ func (p *Parser) parseWhileStatement() *WhileStatement {
|
||||
return stmt
|
||||
}
|
||||
|
||||
// parseForStatement parses for loops
|
||||
func (p *Parser) parseForStatement() Statement {
|
||||
p.nextToken()
|
||||
|
||||
@ -570,7 +566,6 @@ func (p *Parser) parseForStatement() Statement {
|
||||
}
|
||||
}
|
||||
|
||||
// parseNumericForStatement parses numeric for loops
|
||||
func (p *Parser) parseNumericForStatement(variable *Identifier) *ForStatement {
|
||||
stmt := &ForStatement{Variable: variable}
|
||||
|
||||
@ -617,10 +612,10 @@ func (p *Parser) parseNumericForStatement(variable *Identifier) *ForStatement {
|
||||
|
||||
p.nextToken()
|
||||
|
||||
p.enterLoopScope()
|
||||
p.enterScope("loop")
|
||||
p.declareLoopVariable(variable.Value)
|
||||
stmt.Body = p.parseBlockStatements(END)
|
||||
p.exitLoopScope()
|
||||
p.exitScope()
|
||||
|
||||
if !p.curTokenIs(END) {
|
||||
p.addError("expected 'end' to close for loop")
|
||||
@ -630,7 +625,6 @@ func (p *Parser) parseNumericForStatement(variable *Identifier) *ForStatement {
|
||||
return stmt
|
||||
}
|
||||
|
||||
// parseForInStatement parses for-in loops
|
||||
func (p *Parser) parseForInStatement(firstVar *Identifier) *ForInStatement {
|
||||
stmt := &ForInStatement{}
|
||||
|
||||
@ -669,13 +663,13 @@ func (p *Parser) parseForInStatement(firstVar *Identifier) *ForInStatement {
|
||||
|
||||
p.nextToken()
|
||||
|
||||
p.enterLoopScope()
|
||||
p.enterScope("loop")
|
||||
if stmt.Key != nil {
|
||||
p.declareLoopVariable(stmt.Key.Value)
|
||||
}
|
||||
p.declareLoopVariable(stmt.Value.Value)
|
||||
stmt.Body = p.parseBlockStatements(END)
|
||||
p.exitLoopScope()
|
||||
p.exitScope()
|
||||
|
||||
if !p.curTokenIs(END) {
|
||||
p.addError("expected 'end' to close for loop")
|
||||
@ -708,9 +702,7 @@ func (p *Parser) parseIfStatement() *IfStatement {
|
||||
return nil
|
||||
}
|
||||
|
||||
p.enterBlockScope()
|
||||
stmt.Body = p.parseBlockStatements(ELSEIF, ELSE, END)
|
||||
p.exitBlockScope()
|
||||
|
||||
for p.curTokenIs(ELSEIF) {
|
||||
elseif := ElseIfClause{}
|
||||
@ -729,19 +721,13 @@ func (p *Parser) parseIfStatement() *IfStatement {
|
||||
|
||||
p.nextToken()
|
||||
|
||||
p.enterBlockScope()
|
||||
elseif.Body = p.parseBlockStatements(ELSEIF, ELSE, END)
|
||||
p.exitBlockScope()
|
||||
|
||||
stmt.ElseIfs = append(stmt.ElseIfs, elseif)
|
||||
}
|
||||
|
||||
if p.curTokenIs(ELSE) {
|
||||
p.nextToken()
|
||||
|
||||
p.enterBlockScope()
|
||||
stmt.Else = p.parseBlockStatements(END)
|
||||
p.exitBlockScope()
|
||||
}
|
||||
|
||||
if !p.curTokenIs(END) {
|
||||
@ -754,7 +740,7 @@ func (p *Parser) parseIfStatement() *IfStatement {
|
||||
|
||||
// parseBlockStatements parses statements until terminators
|
||||
func (p *Parser) parseBlockStatements(terminators ...TokenType) []Statement {
|
||||
statements := []Statement{}
|
||||
statements := make([]Statement, 0, 8) // Pre-allocate for performance
|
||||
|
||||
for !p.curTokenIs(EOF) && !p.isTerminator(terminators...) {
|
||||
stmt := p.parseStatement()
|
||||
@ -767,7 +753,6 @@ func (p *Parser) parseBlockStatements(terminators ...TokenType) []Statement {
|
||||
return statements
|
||||
}
|
||||
|
||||
// isTerminator checks if current token is a terminator
|
||||
func (p *Parser) isTerminator(terminators ...TokenType) bool {
|
||||
for _, terminator := range terminators {
|
||||
if p.curTokenIs(terminator) {
|
||||
@ -919,7 +904,6 @@ func (p *Parser) parseGroupedExpression() Expression {
|
||||
|
||||
// parseParenthesizedAssignment parses assignment expressions in parentheses
|
||||
func (p *Parser) parseParenthesizedAssignment() Expression {
|
||||
// We're at identifier, peek is ASSIGN
|
||||
target := p.parseIdentifier()
|
||||
|
||||
if !p.expectPeek(ASSIGN) {
|
||||
@ -939,9 +923,10 @@ func (p *Parser) parseParenthesizedAssignment() Expression {
|
||||
}
|
||||
|
||||
// Create assignment expression
|
||||
assignExpr := &AssignExpression{
|
||||
Name: target,
|
||||
assignExpr := &Assignment{
|
||||
Target: target,
|
||||
Value: value,
|
||||
IsExpression: true,
|
||||
}
|
||||
|
||||
// Handle variable declaration for assignment expressions
|
||||
@ -952,8 +937,6 @@ func (p *Parser) parseParenthesizedAssignment() Expression {
|
||||
}
|
||||
}
|
||||
|
||||
// Assignment expression evaluates to the assigned value
|
||||
assignExpr.SetType(value.GetType())
|
||||
return assignExpr
|
||||
}
|
||||
|
||||
@ -977,12 +960,12 @@ func (p *Parser) parseFunctionLiteral() Expression {
|
||||
|
||||
p.nextToken()
|
||||
|
||||
p.enterFunctionScope()
|
||||
p.enterScope("function")
|
||||
for _, param := range fn.Parameters {
|
||||
p.declareVariable(param.Name)
|
||||
}
|
||||
fn.Body = p.parseBlockStatements(END)
|
||||
p.exitFunctionScope()
|
||||
p.exitScope()
|
||||
|
||||
if !p.curTokenIs(END) {
|
||||
p.addError("expected 'end' to close function")
|
||||
@ -1038,7 +1021,7 @@ func (p *Parser) parseFunctionParameters() ([]FunctionParameter, bool) {
|
||||
|
||||
func (p *Parser) parseTableLiteral() Expression {
|
||||
table := &TableLiteral{}
|
||||
table.Pairs = []TablePair{}
|
||||
table.Pairs = make([]TablePair, 0, 4) // Pre-allocate
|
||||
|
||||
if p.peekTokenIs(RBRACE) {
|
||||
p.nextToken()
|
||||
@ -1104,22 +1087,24 @@ func (p *Parser) parseTableLiteral() Expression {
|
||||
return table
|
||||
}
|
||||
|
||||
// parseStructConstructor handles struct constructor calls like my_type{...}
|
||||
// parseStructConstructor handles struct constructor calls
|
||||
func (p *Parser) parseStructConstructor(left Expression) Expression {
|
||||
// left should be an identifier representing the struct name
|
||||
ident, ok := left.(*Identifier)
|
||||
if !ok {
|
||||
// Not an identifier, fall back to table literal parsing
|
||||
return p.parseTableLiteralFromBrace()
|
||||
}
|
||||
|
||||
structName := ident.Value
|
||||
structDef := p.getStructByName(structName)
|
||||
if structDef == nil {
|
||||
// Not a struct, parse as table literal
|
||||
return p.parseTableLiteralFromBrace()
|
||||
}
|
||||
|
||||
// Always try to parse as struct constructor if we have an identifier
|
||||
// Type checking will catch undefined structs later
|
||||
constructor := &StructConstructorExpression{
|
||||
StructName: structName,
|
||||
Fields: []TablePair{},
|
||||
constructor := &StructConstructor{
|
||||
StructID: structDef.ID,
|
||||
Fields: make([]TablePair, 0, 4),
|
||||
typeInfo: TypeInfo{Type: TypeStruct, StructID: structDef.ID, Inferred: true},
|
||||
}
|
||||
|
||||
if p.peekTokenIs(RBRACE) {
|
||||
@ -1187,9 +1172,8 @@ func (p *Parser) parseStructConstructor(left Expression) Expression {
|
||||
}
|
||||
|
||||
func (p *Parser) parseTableLiteralFromBrace() Expression {
|
||||
// We're already at the opening brace, so parse as table literal
|
||||
table := &TableLiteral{}
|
||||
table.Pairs = []TablePair{}
|
||||
table.Pairs = make([]TablePair, 0, 4)
|
||||
|
||||
if p.peekTokenIs(RBRACE) {
|
||||
p.nextToken()
|
||||
@ -1428,15 +1412,9 @@ func (p *Parser) curPrecedence() Precedence {
|
||||
return LOWEST
|
||||
}
|
||||
|
||||
// Errors returns all parsing errors
|
||||
func (p *Parser) Errors() []ParseError {
|
||||
return p.errors
|
||||
}
|
||||
|
||||
func (p *Parser) HasErrors() bool {
|
||||
return len(p.errors) > 0
|
||||
}
|
||||
|
||||
// Error reporting
|
||||
func (p *Parser) Errors() []ParseError { return p.errors }
|
||||
func (p *Parser) HasErrors() bool { return len(p.errors) > 0 }
|
||||
func (p *Parser) ErrorStrings() []string {
|
||||
result := make([]string, len(p.errors))
|
||||
for i, err := range p.errors {
|
||||
|
@ -31,15 +31,15 @@ func TestAssignStatements(t *testing.T) {
|
||||
t.Fatalf("expected 1 statement, got %d", len(program.Statements))
|
||||
}
|
||||
|
||||
stmt, ok := program.Statements[0].(*parser.AssignStatement)
|
||||
stmt, ok := program.Statements[0].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("expected AssignStatement, got %T", program.Statements[0])
|
||||
t.Fatalf("expected Assignment, got %T", program.Statements[0])
|
||||
}
|
||||
|
||||
// Check that Name is an Identifier
|
||||
ident, ok := stmt.Name.(*parser.Identifier)
|
||||
// Check that Target is an Identifier
|
||||
ident, ok := stmt.Target.(*parser.Identifier)
|
||||
if !ok {
|
||||
t.Fatalf("expected Identifier for Name, got %T", stmt.Name)
|
||||
t.Fatalf("expected Identifier for Target, got %T", stmt.Target)
|
||||
}
|
||||
|
||||
if ident.Value != tt.expectedIdentifier {
|
||||
@ -90,9 +90,9 @@ func TestMemberAccessAssignment(t *testing.T) {
|
||||
t.Fatalf("expected 1 statement, got %d", len(program.Statements))
|
||||
}
|
||||
|
||||
stmt, ok := program.Statements[0].(*parser.AssignStatement)
|
||||
stmt, ok := program.Statements[0].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("expected AssignStatement, got %T", program.Statements[0])
|
||||
t.Fatalf("expected Assignment, got %T", program.Statements[0])
|
||||
}
|
||||
|
||||
if stmt.String() != tt.expected {
|
||||
@ -158,15 +158,15 @@ func TestTableAssignments(t *testing.T) {
|
||||
t.Fatalf("expected 1 statement, got %d", len(program.Statements))
|
||||
}
|
||||
|
||||
stmt, ok := program.Statements[0].(*parser.AssignStatement)
|
||||
stmt, ok := program.Statements[0].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("expected AssignStatement, got %T", program.Statements[0])
|
||||
t.Fatalf("expected Assignment, got %T", program.Statements[0])
|
||||
}
|
||||
|
||||
// Check that Name is an Identifier
|
||||
ident, ok := stmt.Name.(*parser.Identifier)
|
||||
// Check that Target is an Identifier
|
||||
ident, ok := stmt.Target.(*parser.Identifier)
|
||||
if !ok {
|
||||
t.Fatalf("expected Identifier for Name, got %T", stmt.Name)
|
||||
t.Fatalf("expected Identifier for Target, got %T", stmt.Target)
|
||||
}
|
||||
|
||||
if ident.Value != tt.identifier {
|
||||
|
@ -247,9 +247,9 @@ exit "success"`
|
||||
}
|
||||
|
||||
// First: assignment
|
||||
_, ok := program.Statements[0].(*parser.AssignStatement)
|
||||
_, ok := program.Statements[0].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("statement 0: expected AssignStatement, got %T", program.Statements[0])
|
||||
t.Fatalf("statement 0: expected Assignment, got %T", program.Statements[0])
|
||||
}
|
||||
|
||||
// Second: if statement with exit in body
|
||||
@ -264,7 +264,7 @@ exit "success"`
|
||||
}
|
||||
|
||||
// Third: assignment
|
||||
_, ok = program.Statements[2].(*parser.AssignStatement)
|
||||
_, ok = program.Statements[2].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("statement 2: expected AssignStatement, got %T", program.Statements[2])
|
||||
}
|
||||
|
@ -32,15 +32,15 @@ end`
|
||||
t.Fatalf("expected 1 body statement, got %d", len(stmt.Body))
|
||||
}
|
||||
|
||||
bodyStmt, ok := stmt.Body[0].(*parser.AssignStatement)
|
||||
bodyStmt, ok := stmt.Body[0].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("expected AssignStatement in body, got %T", stmt.Body[0])
|
||||
}
|
||||
|
||||
// Check that Name is an Identifier
|
||||
ident, ok := bodyStmt.Name.(*parser.Identifier)
|
||||
// Check that Target is an Identifier
|
||||
ident, ok := bodyStmt.Target.(*parser.Identifier)
|
||||
if !ok {
|
||||
t.Fatalf("expected Identifier for Name, got %T", bodyStmt.Name)
|
||||
t.Fatalf("expected Identifier for Target, got %T", bodyStmt.Target)
|
||||
}
|
||||
|
||||
if ident.Value != "x" {
|
||||
@ -79,15 +79,15 @@ end`
|
||||
t.Fatalf("expected 1 else statement, got %d", len(stmt.Else))
|
||||
}
|
||||
|
||||
elseStmt, ok := stmt.Else[0].(*parser.AssignStatement)
|
||||
elseStmt, ok := stmt.Else[0].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("expected AssignStatement in else, got %T", stmt.Else[0])
|
||||
}
|
||||
|
||||
// Check that Name is an Identifier
|
||||
ident, ok := elseStmt.Name.(*parser.Identifier)
|
||||
ident, ok := elseStmt.Target.(*parser.Identifier)
|
||||
if !ok {
|
||||
t.Fatalf("expected Identifier for Name, got %T", elseStmt.Name)
|
||||
t.Fatalf("expected Identifier for Name, got %T", elseStmt.Target)
|
||||
}
|
||||
|
||||
if ident.Value != "x" {
|
||||
@ -169,25 +169,25 @@ end`
|
||||
}
|
||||
|
||||
// First assignment: arr[1] = "updated"
|
||||
assign1, ok := stmt.Body[0].(*parser.AssignStatement)
|
||||
assign1, ok := stmt.Body[0].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("expected AssignStatement, got %T", stmt.Body[0])
|
||||
}
|
||||
|
||||
_, ok = assign1.Name.(*parser.IndexExpression)
|
||||
_, ok = assign1.Target.(*parser.IndexExpression)
|
||||
if !ok {
|
||||
t.Fatalf("expected IndexExpression for assignment target, got %T", assign1.Name)
|
||||
t.Fatalf("expected IndexExpression for assignment target, got %T", assign1.Target)
|
||||
}
|
||||
|
||||
// Second assignment: obj.nested.count = obj.nested.count + 1
|
||||
assign2, ok := stmt.Body[1].(*parser.AssignStatement)
|
||||
assign2, ok := stmt.Body[1].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("expected AssignStatement, got %T", stmt.Body[1])
|
||||
}
|
||||
|
||||
_, ok = assign2.Name.(*parser.DotExpression)
|
||||
_, ok = assign2.Target.(*parser.DotExpression)
|
||||
if !ok {
|
||||
t.Fatalf("expected DotExpression for assignment target, got %T", assign2.Name)
|
||||
t.Fatalf("expected DotExpression for assignment target, got %T", assign2.Target)
|
||||
}
|
||||
}
|
||||
|
||||
@ -214,7 +214,7 @@ end`
|
||||
}
|
||||
|
||||
// Test body has expression assignment
|
||||
bodyStmt := stmt.Body[0].(*parser.AssignStatement)
|
||||
bodyStmt := stmt.Body[0].(*parser.Assignment)
|
||||
bodyInfix, ok := bodyStmt.Value.(*parser.InfixExpression)
|
||||
if !ok {
|
||||
t.Fatalf("expected InfixExpression value, got %T", bodyStmt.Value)
|
||||
|
@ -352,15 +352,15 @@ func TestAssignmentExpressions(t *testing.T) {
|
||||
expr := p.ParseExpression(parser.LOWEST)
|
||||
checkParserErrors(t, p)
|
||||
|
||||
assignExpr, ok := expr.(*parser.AssignExpression)
|
||||
assignExpr, ok := expr.(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("expected AssignExpression, got %T", expr)
|
||||
}
|
||||
|
||||
// Test target name
|
||||
ident, ok := assignExpr.Name.(*parser.Identifier)
|
||||
ident, ok := assignExpr.Target.(*parser.Identifier)
|
||||
if !ok {
|
||||
t.Fatalf("expected Identifier for assignment target, got %T", assignExpr.Name)
|
||||
t.Fatalf("expected Identifier for assignment target, got %T", assignExpr.Target)
|
||||
}
|
||||
|
||||
if ident.Value != tt.targetName {
|
||||
@ -413,12 +413,12 @@ func TestAssignmentExpressionWithComplexExpressions(t *testing.T) {
|
||||
expr := p.ParseExpression(parser.LOWEST)
|
||||
checkParserErrors(t, p)
|
||||
|
||||
assignExpr, ok := expr.(*parser.AssignExpression)
|
||||
assignExpr, ok := expr.(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("expected AssignExpression, got %T", expr)
|
||||
}
|
||||
|
||||
if assignExpr.Name == nil {
|
||||
if assignExpr.Target == nil {
|
||||
t.Error("expected non-nil assignment target")
|
||||
}
|
||||
|
||||
|
@ -160,7 +160,7 @@ func TestFunctionAssignments(t *testing.T) {
|
||||
t.Fatalf("expected 1 statement, got %d", len(program.Statements))
|
||||
}
|
||||
|
||||
stmt, ok := program.Statements[0].(*parser.AssignStatement)
|
||||
stmt, ok := program.Statements[0].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("expected AssignStatement, got %T", program.Statements[0])
|
||||
}
|
||||
@ -296,7 +296,7 @@ end`
|
||||
}
|
||||
|
||||
// First statement: assignment of inner function
|
||||
assign, ok := fn.Body[0].(*parser.AssignStatement)
|
||||
assign, ok := fn.Body[0].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("expected AssignStatement, got %T", fn.Body[0])
|
||||
}
|
||||
@ -342,7 +342,7 @@ end`
|
||||
}
|
||||
|
||||
// First: function assignment
|
||||
assign, ok := forStmt.Body[0].(*parser.AssignStatement)
|
||||
assign, ok := forStmt.Body[0].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("expected AssignStatement, got %T", forStmt.Body[0])
|
||||
}
|
||||
@ -551,7 +551,7 @@ echo adder`
|
||||
}
|
||||
|
||||
// First: table with functions
|
||||
mathAssign, ok := program.Statements[0].(*parser.AssignStatement)
|
||||
mathAssign, ok := program.Statements[0].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("statement 0: expected AssignStatement, got %T", program.Statements[0])
|
||||
}
|
||||
@ -574,7 +574,7 @@ echo adder`
|
||||
}
|
||||
|
||||
// Second: result assignment (function call would be handled by interpreter)
|
||||
_, ok = program.Statements[1].(*parser.AssignStatement)
|
||||
_, ok = program.Statements[1].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("statement 1: expected AssignStatement, got %T", program.Statements[1])
|
||||
}
|
||||
@ -586,7 +586,7 @@ echo adder`
|
||||
}
|
||||
|
||||
// Fourth: calculator function assignment
|
||||
calcAssign, ok := program.Statements[3].(*parser.AssignStatement)
|
||||
calcAssign, ok := program.Statements[3].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("statement 3: expected AssignStatement, got %T", program.Statements[3])
|
||||
}
|
||||
@ -601,7 +601,7 @@ echo adder`
|
||||
}
|
||||
|
||||
// Fifth: adder assignment
|
||||
_, ok = program.Statements[4].(*parser.AssignStatement)
|
||||
_, ok = program.Statements[4].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("statement 4: expected AssignStatement, got %T", program.Statements[4])
|
||||
}
|
||||
@ -645,7 +645,7 @@ end`
|
||||
}
|
||||
|
||||
// Check if body has function assignment
|
||||
ifAssign, ok := ifStmt.Body[0].(*parser.AssignStatement)
|
||||
ifAssign, ok := ifStmt.Body[0].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("if body: expected AssignStatement, got %T", ifStmt.Body[0])
|
||||
}
|
||||
@ -656,7 +656,7 @@ end`
|
||||
}
|
||||
|
||||
// Check else body has function assignment
|
||||
elseAssign, ok := ifStmt.Else[0].(*parser.AssignStatement)
|
||||
elseAssign, ok := ifStmt.Else[0].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("else body: expected AssignStatement, got %T", ifStmt.Else[0])
|
||||
}
|
||||
@ -683,7 +683,7 @@ end`
|
||||
}
|
||||
|
||||
// Verify both branches assign functions
|
||||
nestedIfAssign, ok := nestedIf.Body[0].(*parser.AssignStatement)
|
||||
nestedIfAssign, ok := nestedIf.Body[0].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("nested if body: expected AssignStatement, got %T", nestedIf.Body[0])
|
||||
}
|
||||
@ -693,7 +693,7 @@ end`
|
||||
t.Fatalf("nested if body: expected FunctionLiteral, got %T", nestedIfAssign.Value)
|
||||
}
|
||||
|
||||
nestedElseAssign, ok := nestedIf.Else[0].(*parser.AssignStatement)
|
||||
nestedElseAssign, ok := nestedIf.Else[0].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("nested else body: expected AssignStatement, got %T", nestedIf.Else[0])
|
||||
}
|
||||
|
@ -327,13 +327,13 @@ end`
|
||||
}
|
||||
|
||||
// First: table assignment
|
||||
_, ok := program.Statements[0].(*parser.AssignStatement)
|
||||
_, ok := program.Statements[0].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("statement 0: expected AssignStatement, got %T", program.Statements[0])
|
||||
}
|
||||
|
||||
// Second: variable assignment
|
||||
_, ok = program.Statements[1].(*parser.AssignStatement)
|
||||
_, ok = program.Statements[1].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("statement 1: expected AssignStatement, got %T", program.Statements[1])
|
||||
}
|
||||
@ -590,7 +590,7 @@ end`
|
||||
}
|
||||
|
||||
// Second body statement should be assignment
|
||||
_, ok = outerWhile.Body[1].(*parser.AssignStatement)
|
||||
_, ok = outerWhile.Body[1].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("expected AssignStatement, got %T", outerWhile.Body[1])
|
||||
}
|
||||
@ -634,14 +634,14 @@ end`
|
||||
}
|
||||
|
||||
// First assignment: data[index] = ...
|
||||
assign1, ok := stmt.Body[0].(*parser.AssignStatement)
|
||||
assign1, ok := stmt.Body[0].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("expected AssignStatement, got %T", stmt.Body[0])
|
||||
}
|
||||
|
||||
_, ok = assign1.Name.(*parser.IndexExpression)
|
||||
_, ok = assign1.Target.(*parser.IndexExpression)
|
||||
if !ok {
|
||||
t.Fatalf("expected IndexExpression for assignment target, got %T", assign1.Name)
|
||||
t.Fatalf("expected IndexExpression for assignment target, got %T", assign1.Target)
|
||||
}
|
||||
}
|
||||
|
||||
@ -755,7 +755,7 @@ end`
|
||||
|
||||
// First three: assignments
|
||||
for i := 0; i < 3; i++ {
|
||||
_, ok := program.Statements[i].(*parser.AssignStatement)
|
||||
_, ok := program.Statements[i].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("statement %d: expected AssignStatement, got %T", i, program.Statements[i])
|
||||
}
|
||||
@ -838,7 +838,7 @@ end`
|
||||
}
|
||||
|
||||
// Fourth: assignment
|
||||
_, ok = whileStmt.Body[3].(*parser.AssignStatement)
|
||||
_, ok = whileStmt.Body[3].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("body[3]: expected AssignStatement, got %T", whileStmt.Body[3])
|
||||
}
|
||||
|
@ -22,14 +22,14 @@ z = true + false`
|
||||
|
||||
expectedIdentifiers := []string{"x", "y", "z"}
|
||||
for i, expectedIdent := range expectedIdentifiers {
|
||||
stmt, ok := program.Statements[i].(*parser.AssignStatement)
|
||||
stmt, ok := program.Statements[i].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("statement %d: expected AssignStatement, got %T", i, program.Statements[i])
|
||||
}
|
||||
|
||||
ident, ok := stmt.Name.(*parser.Identifier)
|
||||
ident, ok := stmt.Target.(*parser.Identifier)
|
||||
if !ok {
|
||||
t.Fatalf("statement %d: expected Identifier for Name, got %T", i, stmt.Name)
|
||||
t.Fatalf("statement %d: expected Identifier for Name, got %T", i, stmt.Target)
|
||||
}
|
||||
|
||||
if ident.Value != expectedIdent {
|
||||
@ -58,13 +58,13 @@ arr = {a = 1, b = 2}`
|
||||
}
|
||||
|
||||
// First statement: assignment
|
||||
stmt1, ok := program.Statements[0].(*parser.AssignStatement)
|
||||
stmt1, ok := program.Statements[0].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("statement 0: expected AssignStatement, got %T", program.Statements[0])
|
||||
}
|
||||
ident1, ok := stmt1.Name.(*parser.Identifier)
|
||||
ident1, ok := stmt1.Target.(*parser.Identifier)
|
||||
if !ok {
|
||||
t.Fatalf("expected Identifier for Name, got %T", stmt1.Name)
|
||||
t.Fatalf("expected Identifier for Name, got %T", stmt1.Target)
|
||||
}
|
||||
if ident1.Value != "x" {
|
||||
t.Errorf("expected identifier 'x', got %s", ident1.Value)
|
||||
@ -80,7 +80,7 @@ arr = {a = 1, b = 2}`
|
||||
}
|
||||
|
||||
// Third statement: table assignment
|
||||
stmt3, ok := program.Statements[2].(*parser.AssignStatement)
|
||||
stmt3, ok := program.Statements[2].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("statement 2: expected AssignStatement, got %T", program.Statements[2])
|
||||
}
|
||||
@ -110,33 +110,33 @@ echo table[table.key]`
|
||||
}
|
||||
|
||||
// Second statement: dot assignment
|
||||
stmt2, ok := program.Statements[1].(*parser.AssignStatement)
|
||||
stmt2, ok := program.Statements[1].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("statement 1: expected AssignStatement, got %T", program.Statements[1])
|
||||
}
|
||||
_, ok = stmt2.Name.(*parser.DotExpression)
|
||||
_, ok = stmt2.Target.(*parser.DotExpression)
|
||||
if !ok {
|
||||
t.Fatalf("expected DotExpression for assignment target, got %T", stmt2.Name)
|
||||
t.Fatalf("expected DotExpression for assignment target, got %T", stmt2.Target)
|
||||
}
|
||||
|
||||
// Third statement: bracket assignment
|
||||
stmt3, ok := program.Statements[2].(*parser.AssignStatement)
|
||||
stmt3, ok := program.Statements[2].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("statement 2: expected AssignStatement, got %T", program.Statements[2])
|
||||
}
|
||||
_, ok = stmt3.Name.(*parser.IndexExpression)
|
||||
_, ok = stmt3.Target.(*parser.IndexExpression)
|
||||
if !ok {
|
||||
t.Fatalf("expected IndexExpression for assignment target, got %T", stmt3.Name)
|
||||
t.Fatalf("expected IndexExpression for assignment target, got %T", stmt3.Target)
|
||||
}
|
||||
|
||||
// Fourth statement: chained dot assignment
|
||||
stmt4, ok := program.Statements[3].(*parser.AssignStatement)
|
||||
stmt4, ok := program.Statements[3].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("statement 3: expected AssignStatement, got %T", program.Statements[3])
|
||||
}
|
||||
_, ok = stmt4.Name.(*parser.DotExpression)
|
||||
_, ok = stmt4.Target.(*parser.DotExpression)
|
||||
if !ok {
|
||||
t.Fatalf("expected DotExpression for assignment target, got %T", stmt4.Name)
|
||||
t.Fatalf("expected DotExpression for assignment target, got %T", stmt4.Target)
|
||||
}
|
||||
|
||||
// Fifth statement: echo with nested access
|
||||
@ -232,7 +232,7 @@ end`
|
||||
}
|
||||
|
||||
// First statement: complex expression assignment
|
||||
stmt1, ok := program.Statements[0].(*parser.AssignStatement)
|
||||
stmt1, ok := program.Statements[0].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("expected AssignStatement, got %T", program.Statements[0])
|
||||
}
|
||||
@ -253,7 +253,7 @@ end`
|
||||
t.Fatalf("expected 1 body statement, got %d", len(stmt2.Body))
|
||||
}
|
||||
|
||||
bodyStmt, ok := stmt2.Body[0].(*parser.AssignStatement)
|
||||
bodyStmt, ok := stmt2.Body[0].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("expected AssignStatement in body, got %T", stmt2.Body[0])
|
||||
}
|
||||
@ -286,7 +286,7 @@ echo {result = x}`
|
||||
}
|
||||
|
||||
// First: assignment
|
||||
_, ok := program.Statements[0].(*parser.AssignStatement)
|
||||
_, ok := program.Statements[0].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("statement 0: expected AssignStatement, got %T", program.Statements[0])
|
||||
}
|
||||
|
@ -63,7 +63,7 @@ x = 15`,
|
||||
|
||||
assignmentCount := 0
|
||||
for _, stmt := range program.Statements {
|
||||
if assign, ok := stmt.(*parser.AssignStatement); ok {
|
||||
if assign, ok := stmt.(*parser.Assignment); ok {
|
||||
if assignmentCount >= len(tt.assignments) {
|
||||
t.Fatalf("more assignments than expected")
|
||||
}
|
||||
@ -71,9 +71,9 @@ x = 15`,
|
||||
expected := tt.assignments[assignmentCount]
|
||||
|
||||
// Check variable name
|
||||
ident, ok := assign.Name.(*parser.Identifier)
|
||||
ident, ok := assign.Target.(*parser.Identifier)
|
||||
if !ok {
|
||||
t.Fatalf("expected Identifier, got %T", assign.Name)
|
||||
t.Fatalf("expected Identifier, got %T", assign.Target)
|
||||
}
|
||||
|
||||
if ident.Value != expected.variable {
|
||||
@ -135,9 +135,9 @@ z = 30`
|
||||
|
||||
for i, expected := range expectedAssignments {
|
||||
assign := assignments[i]
|
||||
ident, ok := assign.Name.(*parser.Identifier)
|
||||
ident, ok := assign.Target.(*parser.Identifier)
|
||||
if !ok {
|
||||
t.Fatalf("assignment %d: expected Identifier, got %T", i, assign.Name)
|
||||
t.Fatalf("assignment %d: expected Identifier, got %T", i, assign.Target)
|
||||
}
|
||||
|
||||
if ident.Value != expected.variable {
|
||||
@ -191,9 +191,9 @@ end`
|
||||
|
||||
for i, expected := range expectedAssignments {
|
||||
assign := assignments[i]
|
||||
ident, ok := assign.Name.(*parser.Identifier)
|
||||
ident, ok := assign.Target.(*parser.Identifier)
|
||||
if !ok {
|
||||
t.Fatalf("assignment %d: expected Identifier, got %T", i, assign.Name)
|
||||
t.Fatalf("assignment %d: expected Identifier, got %T", i, assign.Target)
|
||||
}
|
||||
|
||||
if ident.Value != expected.variable {
|
||||
@ -243,9 +243,9 @@ c = 20`
|
||||
|
||||
for i, expected := range expectedAssignments {
|
||||
assign := assignments[i]
|
||||
ident, ok := assign.Name.(*parser.Identifier)
|
||||
ident, ok := assign.Target.(*parser.Identifier)
|
||||
if !ok {
|
||||
t.Fatalf("assignment %d: expected Identifier, got %T", i, assign.Name)
|
||||
t.Fatalf("assignment %d: expected Identifier, got %T", i, assign.Target)
|
||||
}
|
||||
|
||||
if ident.Value != expected.variable {
|
||||
@ -344,9 +344,9 @@ count = 0`,
|
||||
|
||||
for i, expected := range tt.assignments {
|
||||
assign := assignments[i]
|
||||
ident, ok := assign.Name.(*parser.Identifier)
|
||||
ident, ok := assign.Target.(*parser.Identifier)
|
||||
if !ok {
|
||||
t.Fatalf("assignment %d: expected Identifier, got %T", i, assign.Name)
|
||||
t.Fatalf("assignment %d: expected Identifier, got %T", i, assign.Target)
|
||||
}
|
||||
|
||||
if ident.Value != expected.variable {
|
||||
@ -389,7 +389,7 @@ arr[1] = 99`
|
||||
|
||||
assignmentCount := 0
|
||||
for _, stmt := range program.Statements {
|
||||
if assign, ok := stmt.(*parser.AssignStatement); ok {
|
||||
if assign, ok := stmt.(*parser.Assignment); ok {
|
||||
if assignmentCount >= len(expectedAssignments) {
|
||||
t.Fatalf("more assignments than expected")
|
||||
}
|
||||
@ -398,7 +398,7 @@ arr[1] = 99`
|
||||
|
||||
if expected.isMemberAccess {
|
||||
// Should not be an identifier
|
||||
if _, ok := assign.Name.(*parser.Identifier); ok {
|
||||
if _, ok := assign.Target.(*parser.Identifier); ok {
|
||||
t.Errorf("assignment %d: expected member access, got Identifier", assignmentCount)
|
||||
}
|
||||
|
||||
@ -408,9 +408,9 @@ arr[1] = 99`
|
||||
}
|
||||
} else {
|
||||
// Should be an identifier
|
||||
ident, ok := assign.Name.(*parser.Identifier)
|
||||
ident, ok := assign.Target.(*parser.Identifier)
|
||||
if !ok {
|
||||
t.Errorf("assignment %d: expected Identifier, got %T", assignmentCount, assign.Name)
|
||||
t.Errorf("assignment %d: expected Identifier, got %T", assignmentCount, assign.Target)
|
||||
} else if ident.Value != expected.variable {
|
||||
t.Errorf("assignment %d: expected variable %s, got %s",
|
||||
assignmentCount, expected.variable, ident.Value)
|
||||
@ -487,9 +487,9 @@ local_var = "global_local"`
|
||||
|
||||
for i, expected := range expectedAssignments {
|
||||
assign := assignments[i]
|
||||
ident, ok := assign.Name.(*parser.Identifier)
|
||||
ident, ok := assign.Target.(*parser.Identifier)
|
||||
if !ok {
|
||||
t.Fatalf("assignment %d: expected Identifier, got %T", i, assign.Name)
|
||||
t.Fatalf("assignment %d: expected Identifier, got %T", i, assign.Target)
|
||||
}
|
||||
|
||||
if ident.Value != expected.variable {
|
||||
@ -536,8 +536,8 @@ y = 20.00`
|
||||
}
|
||||
|
||||
// Helper function to extract all assignments from a program recursively
|
||||
func extractAssignments(program *parser.Program) []*parser.AssignStatement {
|
||||
var assignments []*parser.AssignStatement
|
||||
func extractAssignments(program *parser.Program) []*parser.Assignment {
|
||||
var assignments []*parser.Assignment
|
||||
|
||||
for _, stmt := range program.Statements {
|
||||
assignments = append(assignments, extractAssignmentsFromStatement(stmt)...)
|
||||
@ -546,11 +546,11 @@ func extractAssignments(program *parser.Program) []*parser.AssignStatement {
|
||||
return assignments
|
||||
}
|
||||
|
||||
func extractAssignmentsFromStatement(stmt parser.Statement) []*parser.AssignStatement {
|
||||
var assignments []*parser.AssignStatement
|
||||
func extractAssignmentsFromStatement(stmt parser.Statement) []*parser.Assignment {
|
||||
var assignments []*parser.Assignment
|
||||
|
||||
switch s := stmt.(type) {
|
||||
case *parser.AssignStatement:
|
||||
case *parser.Assignment:
|
||||
assignments = append(assignments, s)
|
||||
|
||||
// Check if the value is a function literal with assignments in body
|
||||
|
@ -38,22 +38,22 @@ func TestBasicStructDefinition(t *testing.T) {
|
||||
if stmt.Fields[0].Name != "name" {
|
||||
t.Errorf("expected field name 'name', got %s", stmt.Fields[0].Name)
|
||||
}
|
||||
if stmt.Fields[0].TypeHint == nil {
|
||||
if stmt.Fields[0].TypeHint.Type == parser.TypeUnknown {
|
||||
t.Fatal("expected type hint for name field")
|
||||
}
|
||||
if stmt.Fields[0].TypeHint.Type != "string" {
|
||||
t.Errorf("expected type 'string', got %s", stmt.Fields[0].TypeHint.Type)
|
||||
if stmt.Fields[0].TypeHint.Type != parser.TypeString {
|
||||
t.Errorf("expected type string, got %v", stmt.Fields[0].TypeHint.Type)
|
||||
}
|
||||
|
||||
// Test second field
|
||||
if stmt.Fields[1].Name != "age" {
|
||||
t.Errorf("expected field name 'age', got %s", stmt.Fields[1].Name)
|
||||
}
|
||||
if stmt.Fields[1].TypeHint == nil {
|
||||
if stmt.Fields[1].TypeHint.Type == parser.TypeUnknown {
|
||||
t.Fatal("expected type hint for age field")
|
||||
}
|
||||
if stmt.Fields[1].TypeHint.Type != "number" {
|
||||
t.Errorf("expected type 'number', got %s", stmt.Fields[1].TypeHint.Type)
|
||||
if stmt.Fields[1].TypeHint.Type != parser.TypeNumber {
|
||||
t.Errorf("expected type number, got %v", stmt.Fields[1].TypeHint.Type)
|
||||
}
|
||||
}
|
||||
|
||||
@ -107,7 +107,7 @@ func TestComplexStructDefinition(t *testing.T) {
|
||||
t.Fatalf("expected StructStatement, got %T", program.Statements[0])
|
||||
}
|
||||
|
||||
expectedTypes := []string{"number", "string", "bool", "table", "function", "any"}
|
||||
expectedTypes := []parser.Type{parser.TypeNumber, parser.TypeString, parser.TypeBool, parser.TypeTable, parser.TypeFunction, parser.TypeAny}
|
||||
expectedNames := []string{"id", "name", "active", "data", "callback", "optional"}
|
||||
|
||||
if len(stmt.Fields) != len(expectedTypes) {
|
||||
@ -118,11 +118,11 @@ func TestComplexStructDefinition(t *testing.T) {
|
||||
if field.Name != expectedNames[i] {
|
||||
t.Errorf("field %d: expected name '%s', got '%s'", i, expectedNames[i], field.Name)
|
||||
}
|
||||
if field.TypeHint == nil {
|
||||
if field.TypeHint.Type == parser.TypeUnknown {
|
||||
t.Fatalf("field %d: expected type hint", i)
|
||||
}
|
||||
if field.TypeHint.Type != expectedTypes[i] {
|
||||
t.Errorf("field %d: expected type '%s', got '%s'", i, expectedTypes[i], field.TypeHint.Type)
|
||||
t.Errorf("field %d: expected type %v, got %v", i, expectedTypes[i], field.TypeHint.Type)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -164,17 +164,17 @@ end`
|
||||
if !ok {
|
||||
t.Fatalf("expected MethodDefinition, got %T", program.Statements[1])
|
||||
}
|
||||
if method1.StructName != "Person" {
|
||||
t.Errorf("expected struct name 'Person', got %s", method1.StructName)
|
||||
if method1.StructID != structStmt.ID {
|
||||
t.Errorf("expected struct ID %d, got %d", structStmt.ID, method1.StructID)
|
||||
}
|
||||
if method1.MethodName != "getName" {
|
||||
t.Errorf("expected method name 'getName', got %s", method1.MethodName)
|
||||
}
|
||||
if method1.Function.ReturnType == nil {
|
||||
if method1.Function.ReturnType.Type == parser.TypeUnknown {
|
||||
t.Fatal("expected return type for getName method")
|
||||
}
|
||||
if method1.Function.ReturnType.Type != "string" {
|
||||
t.Errorf("expected return type 'string', got %s", method1.Function.ReturnType.Type)
|
||||
if method1.Function.ReturnType.Type != parser.TypeString {
|
||||
t.Errorf("expected return type string, got %v", method1.Function.ReturnType.Type)
|
||||
}
|
||||
if len(method1.Function.Parameters) != 0 {
|
||||
t.Errorf("expected 0 parameters, got %d", len(method1.Function.Parameters))
|
||||
@ -185,14 +185,14 @@ end`
|
||||
if !ok {
|
||||
t.Fatalf("expected MethodDefinition, got %T", program.Statements[2])
|
||||
}
|
||||
if method2.StructName != "Person" {
|
||||
t.Errorf("expected struct name 'Person', got %s", method2.StructName)
|
||||
if method2.StructID != structStmt.ID {
|
||||
t.Errorf("expected struct ID %d, got %d", structStmt.ID, method2.StructID)
|
||||
}
|
||||
if method2.MethodName != "setAge" {
|
||||
t.Errorf("expected method name 'setAge', got %s", method2.MethodName)
|
||||
}
|
||||
if method2.Function.ReturnType != nil {
|
||||
t.Errorf("expected no return type for setAge method, got %s", method2.Function.ReturnType.Type)
|
||||
if method2.Function.ReturnType.Type != parser.TypeUnknown {
|
||||
t.Errorf("expected no return type for setAge method, got %v", method2.Function.ReturnType.Type)
|
||||
}
|
||||
if len(method2.Function.Parameters) != 1 {
|
||||
t.Fatalf("expected 1 parameter, got %d", len(method2.Function.Parameters))
|
||||
@ -200,11 +200,11 @@ end`
|
||||
if method2.Function.Parameters[0].Name != "newAge" {
|
||||
t.Errorf("expected parameter name 'newAge', got %s", method2.Function.Parameters[0].Name)
|
||||
}
|
||||
if method2.Function.Parameters[0].TypeHint == nil {
|
||||
if method2.Function.Parameters[0].TypeHint.Type == parser.TypeUnknown {
|
||||
t.Fatal("expected type hint for newAge parameter")
|
||||
}
|
||||
if method2.Function.Parameters[0].TypeHint.Type != "number" {
|
||||
t.Errorf("expected parameter type 'number', got %s", method2.Function.Parameters[0].TypeHint.Type)
|
||||
if method2.Function.Parameters[0].TypeHint.Type != parser.TypeNumber {
|
||||
t.Errorf("expected parameter type number, got %v", method2.Function.Parameters[0].TypeHint.Type)
|
||||
}
|
||||
}
|
||||
|
||||
@ -226,19 +226,21 @@ empty = Person{}`
|
||||
t.Fatalf("expected 3 statements, got %d", len(program.Statements))
|
||||
}
|
||||
|
||||
structStmt := program.Statements[0].(*parser.StructStatement)
|
||||
|
||||
// Second statement: constructor with fields
|
||||
assign1, ok := program.Statements[1].(*parser.AssignStatement)
|
||||
assign1, ok := program.Statements[1].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("expected AssignStatement, got %T", program.Statements[1])
|
||||
t.Fatalf("expected Assignment, got %T", program.Statements[1])
|
||||
}
|
||||
|
||||
constructor1, ok := assign1.Value.(*parser.StructConstructorExpression)
|
||||
constructor1, ok := assign1.Value.(*parser.StructConstructor)
|
||||
if !ok {
|
||||
t.Fatalf("expected StructConstructorExpression, got %T", assign1.Value)
|
||||
t.Fatalf("expected StructConstructor, got %T", assign1.Value)
|
||||
}
|
||||
|
||||
if constructor1.StructName != "Person" {
|
||||
t.Errorf("expected struct name 'Person', got %s", constructor1.StructName)
|
||||
if constructor1.StructID != structStmt.ID {
|
||||
t.Errorf("expected struct ID %d, got %d", structStmt.ID, constructor1.StructID)
|
||||
}
|
||||
|
||||
if len(constructor1.Fields) != 2 {
|
||||
@ -266,18 +268,18 @@ empty = Person{}`
|
||||
testNumberLiteral(t, constructor1.Fields[1].Value, 30)
|
||||
|
||||
// Third statement: empty constructor
|
||||
assign2, ok := program.Statements[2].(*parser.AssignStatement)
|
||||
assign2, ok := program.Statements[2].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("expected AssignStatement, got %T", program.Statements[2])
|
||||
t.Fatalf("expected Assignment, got %T", program.Statements[2])
|
||||
}
|
||||
|
||||
constructor2, ok := assign2.Value.(*parser.StructConstructorExpression)
|
||||
constructor2, ok := assign2.Value.(*parser.StructConstructor)
|
||||
if !ok {
|
||||
t.Fatalf("expected StructConstructorExpression, got %T", assign2.Value)
|
||||
t.Fatalf("expected StructConstructor, got %T", assign2.Value)
|
||||
}
|
||||
|
||||
if constructor2.StructName != "Person" {
|
||||
t.Errorf("expected struct name 'Person', got %s", constructor2.StructName)
|
||||
if constructor2.StructID != structStmt.ID {
|
||||
t.Errorf("expected struct ID %d, got %d", structStmt.ID, constructor2.StructID)
|
||||
}
|
||||
|
||||
if len(constructor2.Fields) != 0 {
|
||||
@ -310,6 +312,8 @@ person = Person{
|
||||
t.Fatalf("expected 3 statements, got %d", len(program.Statements))
|
||||
}
|
||||
|
||||
addressStruct := program.Statements[0].(*parser.StructStatement)
|
||||
|
||||
// Check Person struct has Address field type
|
||||
personStruct, ok := program.Statements[1].(*parser.StructStatement)
|
||||
if !ok {
|
||||
@ -320,29 +324,32 @@ person = Person{
|
||||
if addressField.Name != "address" {
|
||||
t.Errorf("expected field name 'address', got %s", addressField.Name)
|
||||
}
|
||||
if addressField.TypeHint.Type != "Address" {
|
||||
t.Errorf("expected field type 'Address', got %s", addressField.TypeHint.Type)
|
||||
if addressField.TypeHint.Type != parser.TypeStruct {
|
||||
t.Errorf("expected field type struct, got %v", addressField.TypeHint.Type)
|
||||
}
|
||||
if addressField.TypeHint.StructID != addressStruct.ID {
|
||||
t.Errorf("expected struct ID %d, got %d", addressStruct.ID, addressField.TypeHint.StructID)
|
||||
}
|
||||
|
||||
// Check nested constructor
|
||||
assign, ok := program.Statements[2].(*parser.AssignStatement)
|
||||
assign, ok := program.Statements[2].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("expected AssignStatement, got %T", program.Statements[2])
|
||||
t.Fatalf("expected Assignment, got %T", program.Statements[2])
|
||||
}
|
||||
|
||||
personConstructor, ok := assign.Value.(*parser.StructConstructorExpression)
|
||||
personConstructor, ok := assign.Value.(*parser.StructConstructor)
|
||||
if !ok {
|
||||
t.Fatalf("expected StructConstructorExpression, got %T", assign.Value)
|
||||
t.Fatalf("expected StructConstructor, got %T", assign.Value)
|
||||
}
|
||||
|
||||
// Check the nested Address constructor
|
||||
addressConstructor, ok := personConstructor.Fields[1].Value.(*parser.StructConstructorExpression)
|
||||
addressConstructor, ok := personConstructor.Fields[1].Value.(*parser.StructConstructor)
|
||||
if !ok {
|
||||
t.Fatalf("expected nested StructConstructorExpression, got %T", personConstructor.Fields[1].Value)
|
||||
t.Fatalf("expected nested StructConstructor, got %T", personConstructor.Fields[1].Value)
|
||||
}
|
||||
|
||||
if addressConstructor.StructName != "Address" {
|
||||
t.Errorf("expected nested struct name 'Address', got %s", addressConstructor.StructName)
|
||||
if addressConstructor.StructID != addressStruct.ID {
|
||||
t.Errorf("expected nested struct ID %d, got %d", addressStruct.ID, addressConstructor.StructID)
|
||||
}
|
||||
|
||||
if len(addressConstructor.Fields) != 2 {
|
||||
@ -397,8 +404,8 @@ end`
|
||||
if !ok {
|
||||
t.Fatalf("expected MethodDefinition, got %T", program.Statements[1])
|
||||
}
|
||||
if methodStmt.StructName != "Point" {
|
||||
t.Errorf("expected struct name 'Point', got %s", methodStmt.StructName)
|
||||
if methodStmt.StructID != structStmt.ID {
|
||||
t.Errorf("expected struct ID %d, got %d", structStmt.ID, methodStmt.StructID)
|
||||
}
|
||||
if methodStmt.MethodName != "distance" {
|
||||
t.Errorf("expected method name 'distance', got %s", methodStmt.MethodName)
|
||||
@ -406,16 +413,16 @@ end`
|
||||
|
||||
// Verify constructors
|
||||
for i := 2; i <= 3; i++ {
|
||||
assign, ok := program.Statements[i].(*parser.AssignStatement)
|
||||
assign, ok := program.Statements[i].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("statement %d: expected AssignStatement, got %T", i, program.Statements[i])
|
||||
t.Fatalf("statement %d: expected Assignment, got %T", i, program.Statements[i])
|
||||
}
|
||||
constructor, ok := assign.Value.(*parser.StructConstructorExpression)
|
||||
constructor, ok := assign.Value.(*parser.StructConstructor)
|
||||
if !ok {
|
||||
t.Fatalf("statement %d: expected StructConstructorExpression, got %T", i, assign.Value)
|
||||
t.Fatalf("statement %d: expected StructConstructor, got %T", i, assign.Value)
|
||||
}
|
||||
if constructor.StructName != "Point" {
|
||||
t.Errorf("statement %d: expected struct name 'Point', got %s", i, constructor.StructName)
|
||||
if constructor.StructID != structStmt.ID {
|
||||
t.Errorf("statement %d: expected struct ID %d, got %d", i, structStmt.ID, constructor.StructID)
|
||||
}
|
||||
}
|
||||
|
||||
@ -446,16 +453,16 @@ end`
|
||||
}
|
||||
|
||||
// Check struct constructor in loop
|
||||
loopAssign, ok := forStmt.Body[0].(*parser.AssignStatement)
|
||||
loopAssign, ok := forStmt.Body[0].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("expected AssignStatement in loop, got %T", forStmt.Body[0])
|
||||
t.Fatalf("expected Assignment in loop, got %T", forStmt.Body[0])
|
||||
}
|
||||
loopConstructor, ok := loopAssign.Value.(*parser.StructConstructorExpression)
|
||||
loopConstructor, ok := loopAssign.Value.(*parser.StructConstructor)
|
||||
if !ok {
|
||||
t.Fatalf("expected StructConstructorExpression in loop, got %T", loopAssign.Value)
|
||||
t.Fatalf("expected StructConstructor in loop, got %T", loopAssign.Value)
|
||||
}
|
||||
if loopConstructor.StructName != "Point" {
|
||||
t.Errorf("expected struct name 'Point' in loop, got %s", loopConstructor.StructName)
|
||||
if loopConstructor.StructID != structStmt.ID {
|
||||
t.Errorf("expected struct ID %d in loop, got %d", structStmt.ID, loopConstructor.StructID)
|
||||
}
|
||||
}
|
||||
|
||||
@ -552,13 +559,13 @@ func TestSingleLineStruct(t *testing.T) {
|
||||
t.Fatalf("expected 2 fields, got %d", len(stmt.Fields))
|
||||
}
|
||||
|
||||
if stmt.Fields[0].Name != "name" || stmt.Fields[0].TypeHint.Type != "string" {
|
||||
t.Errorf("expected first field 'name: string', got '%s: %s'",
|
||||
if stmt.Fields[0].Name != "name" || stmt.Fields[0].TypeHint.Type != parser.TypeString {
|
||||
t.Errorf("expected first field 'name: string', got '%s: %v'",
|
||||
stmt.Fields[0].Name, stmt.Fields[0].TypeHint.Type)
|
||||
}
|
||||
|
||||
if stmt.Fields[1].Name != "age" || stmt.Fields[1].TypeHint.Type != "number" {
|
||||
t.Errorf("expected second field 'age: number', got '%s: %s'",
|
||||
if stmt.Fields[1].Name != "age" || stmt.Fields[1].TypeHint.Type != parser.TypeNumber {
|
||||
t.Errorf("expected second field 'age: number', got '%s: %v'",
|
||||
stmt.Fields[1].Name, stmt.Fields[1].TypeHint.Type)
|
||||
}
|
||||
}
|
||||
@ -600,8 +607,8 @@ end`
|
||||
method := program.Statements[1].(*parser.MethodDefinition)
|
||||
str := method.String()
|
||||
|
||||
if !containsSubstring(str, "fn Person.getName") {
|
||||
t.Errorf("expected method string to contain 'fn Person.getName', got: %s", str)
|
||||
if !containsSubstring(str, "fn <struct>.getName") {
|
||||
t.Errorf("expected method string to contain 'fn <struct>.getName', got: %s", str)
|
||||
}
|
||||
if !containsSubstring(str, ": string") {
|
||||
t.Errorf("expected method string to contain return type, got: %s", str)
|
||||
@ -621,11 +628,11 @@ person = Person{name = "John", age = 30}`
|
||||
program := p.ParseProgram()
|
||||
checkParserErrors(t, p)
|
||||
|
||||
assign := program.Statements[1].(*parser.AssignStatement)
|
||||
constructor := assign.Value.(*parser.StructConstructorExpression)
|
||||
assign := program.Statements[1].(*parser.Assignment)
|
||||
constructor := assign.Value.(*parser.StructConstructor)
|
||||
str := constructor.String()
|
||||
|
||||
expected := `Person{name = "John", age = 30.00}`
|
||||
expected := `<struct>{name = "John", age = 30.00}`
|
||||
if str != expected {
|
||||
t.Errorf("expected constructor string:\n%s\ngot:\n%s", expected, str)
|
||||
}
|
||||
|
@ -10,18 +10,18 @@ func TestVariableTypeHints(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
variable string
|
||||
typeHint string
|
||||
typeHint parser.Type
|
||||
hasHint bool
|
||||
desc string
|
||||
}{
|
||||
{"x = 42", "x", "", false, "no type hint"},
|
||||
{"x: number = 42", "x", "number", true, "number type hint"},
|
||||
{"name: string = \"hello\"", "name", "string", true, "string type hint"},
|
||||
{"flag: bool = true", "flag", "bool", true, "bool type hint"},
|
||||
{"data: table = {}", "data", "table", true, "table type hint"},
|
||||
{"fn_var: function = fn() end", "fn_var", "function", true, "function type hint"},
|
||||
{"value: any = nil", "value", "any", true, "any type hint"},
|
||||
{"ptr: nil = nil", "ptr", "nil", true, "nil type hint"},
|
||||
{"x = 42", "x", parser.TypeUnknown, false, "no type hint"},
|
||||
{"x: number = 42", "x", parser.TypeNumber, true, "number type hint"},
|
||||
{"name: string = \"hello\"", "name", parser.TypeString, true, "string type hint"},
|
||||
{"flag: bool = true", "flag", parser.TypeBool, true, "bool type hint"},
|
||||
{"data: table = {}", "data", parser.TypeTable, true, "table type hint"},
|
||||
{"fn_var: function = fn() end", "fn_var", parser.TypeFunction, true, "function type hint"},
|
||||
{"value: any = nil", "value", parser.TypeAny, true, "any type hint"},
|
||||
{"ptr: nil = nil", "ptr", parser.TypeNil, true, "nil type hint"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@ -35,15 +35,15 @@ func TestVariableTypeHints(t *testing.T) {
|
||||
t.Fatalf("expected 1 statement, got %d", len(program.Statements))
|
||||
}
|
||||
|
||||
stmt, ok := program.Statements[0].(*parser.AssignStatement)
|
||||
stmt, ok := program.Statements[0].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("expected AssignStatement, got %T", program.Statements[0])
|
||||
t.Fatalf("expected Assignment, got %T", program.Statements[0])
|
||||
}
|
||||
|
||||
// Check variable name
|
||||
ident, ok := stmt.Name.(*parser.Identifier)
|
||||
ident, ok := stmt.Target.(*parser.Identifier)
|
||||
if !ok {
|
||||
t.Fatalf("expected Identifier for Name, got %T", stmt.Name)
|
||||
t.Fatalf("expected Identifier for Target, got %T", stmt.Target)
|
||||
}
|
||||
|
||||
if ident.Value != tt.variable {
|
||||
@ -52,19 +52,19 @@ func TestVariableTypeHints(t *testing.T) {
|
||||
|
||||
// Check type hint
|
||||
if tt.hasHint {
|
||||
if stmt.TypeHint == nil {
|
||||
t.Error("expected type hint but got nil")
|
||||
if stmt.TypeHint.Type == parser.TypeUnknown {
|
||||
t.Error("expected type hint but got TypeUnknown")
|
||||
} else {
|
||||
if stmt.TypeHint.Type != tt.typeHint {
|
||||
t.Errorf("expected type hint %s, got %s", tt.typeHint, stmt.TypeHint.Type)
|
||||
t.Errorf("expected type hint %v, got %v", tt.typeHint, stmt.TypeHint.Type)
|
||||
}
|
||||
if stmt.TypeHint.Inferred {
|
||||
t.Error("expected type hint to not be inferred")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if stmt.TypeHint != nil {
|
||||
t.Errorf("expected no type hint but got %s", stmt.TypeHint.Type)
|
||||
if stmt.TypeHint.Type != parser.TypeUnknown {
|
||||
t.Errorf("expected no type hint but got %v", stmt.TypeHint.Type)
|
||||
}
|
||||
}
|
||||
})
|
||||
@ -74,60 +74,81 @@ func TestVariableTypeHints(t *testing.T) {
|
||||
func TestFunctionParameterTypeHints(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
params []struct{ name, typeHint string }
|
||||
returnType string
|
||||
params []struct {
|
||||
name string
|
||||
typeHint parser.Type
|
||||
}
|
||||
returnType parser.Type
|
||||
hasReturn bool
|
||||
desc string
|
||||
}{
|
||||
{
|
||||
"fn(a, b) end",
|
||||
[]struct{ name, typeHint string }{
|
||||
{"a", ""},
|
||||
{"b", ""},
|
||||
[]struct {
|
||||
name string
|
||||
typeHint parser.Type
|
||||
}{
|
||||
{"a", parser.TypeUnknown},
|
||||
{"b", parser.TypeUnknown},
|
||||
},
|
||||
"", false,
|
||||
parser.TypeUnknown, false,
|
||||
"no type hints",
|
||||
},
|
||||
{
|
||||
"fn(a: number, b: string) end",
|
||||
[]struct{ name, typeHint string }{
|
||||
{"a", "number"},
|
||||
{"b", "string"},
|
||||
[]struct {
|
||||
name string
|
||||
typeHint parser.Type
|
||||
}{
|
||||
{"a", parser.TypeNumber},
|
||||
{"b", parser.TypeString},
|
||||
},
|
||||
"", false,
|
||||
parser.TypeUnknown, false,
|
||||
"parameter type hints only",
|
||||
},
|
||||
{
|
||||
"fn(x: number): string end",
|
||||
[]struct{ name, typeHint string }{
|
||||
{"x", "number"},
|
||||
[]struct {
|
||||
name string
|
||||
typeHint parser.Type
|
||||
}{
|
||||
{"x", parser.TypeNumber},
|
||||
},
|
||||
"string", true,
|
||||
parser.TypeString, true,
|
||||
"parameter and return type hints",
|
||||
},
|
||||
{
|
||||
"fn(): bool end",
|
||||
[]struct{ name, typeHint string }{},
|
||||
"bool", true,
|
||||
[]struct {
|
||||
name string
|
||||
typeHint parser.Type
|
||||
}{},
|
||||
parser.TypeBool, true,
|
||||
"return type hint only",
|
||||
},
|
||||
{
|
||||
"fn(a: number, b, c: string): table end",
|
||||
[]struct{ name, typeHint string }{
|
||||
{"a", "number"},
|
||||
{"b", ""},
|
||||
{"c", "string"},
|
||||
[]struct {
|
||||
name string
|
||||
typeHint parser.Type
|
||||
}{
|
||||
{"a", parser.TypeNumber},
|
||||
{"b", parser.TypeUnknown},
|
||||
{"c", parser.TypeString},
|
||||
},
|
||||
"table", true,
|
||||
parser.TypeTable, true,
|
||||
"mixed parameter types with return",
|
||||
},
|
||||
{
|
||||
"fn(callback: function, data: any): nil end",
|
||||
[]struct{ name, typeHint string }{
|
||||
{"callback", "function"},
|
||||
{"data", "any"},
|
||||
[]struct {
|
||||
name string
|
||||
typeHint parser.Type
|
||||
}{
|
||||
{"callback", parser.TypeFunction},
|
||||
{"data", parser.TypeAny},
|
||||
},
|
||||
"nil", true,
|
||||
parser.TypeNil, true,
|
||||
"function and any types",
|
||||
},
|
||||
}
|
||||
@ -155,29 +176,29 @@ func TestFunctionParameterTypeHints(t *testing.T) {
|
||||
t.Errorf("parameter %d: expected name %s, got %s", i, expected.name, param.Name)
|
||||
}
|
||||
|
||||
if expected.typeHint == "" {
|
||||
if param.TypeHint != nil {
|
||||
t.Errorf("parameter %d: expected no type hint but got %s", i, param.TypeHint.Type)
|
||||
if expected.typeHint == parser.TypeUnknown {
|
||||
if param.TypeHint.Type != parser.TypeUnknown {
|
||||
t.Errorf("parameter %d: expected no type hint but got %v", i, param.TypeHint.Type)
|
||||
}
|
||||
} else {
|
||||
if param.TypeHint == nil {
|
||||
t.Errorf("parameter %d: expected type hint %s but got nil", i, expected.typeHint)
|
||||
if param.TypeHint.Type == parser.TypeUnknown {
|
||||
t.Errorf("parameter %d: expected type hint %v but got TypeUnknown", i, expected.typeHint)
|
||||
} else if param.TypeHint.Type != expected.typeHint {
|
||||
t.Errorf("parameter %d: expected type hint %s, got %s", i, expected.typeHint, param.TypeHint.Type)
|
||||
t.Errorf("parameter %d: expected type hint %v, got %v", i, expected.typeHint, param.TypeHint.Type)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check return type
|
||||
if tt.hasReturn {
|
||||
if fn.ReturnType == nil {
|
||||
t.Error("expected return type hint but got nil")
|
||||
if fn.ReturnType.Type == parser.TypeUnknown {
|
||||
t.Error("expected return type hint but got TypeUnknown")
|
||||
} else if fn.ReturnType.Type != tt.returnType {
|
||||
t.Errorf("expected return type %s, got %s", tt.returnType, fn.ReturnType.Type)
|
||||
t.Errorf("expected return type %v, got %v", tt.returnType, fn.ReturnType.Type)
|
||||
}
|
||||
} else {
|
||||
if fn.ReturnType != nil {
|
||||
t.Errorf("expected no return type but got %s", fn.ReturnType.Type)
|
||||
if fn.ReturnType.Type != parser.TypeUnknown {
|
||||
t.Errorf("expected no return type but got %v", fn.ReturnType.Type)
|
||||
}
|
||||
}
|
||||
})
|
||||
@ -279,13 +300,13 @@ func TestMemberAccessWithoutTypeHints(t *testing.T) {
|
||||
t.Fatalf("expected 1 statement, got %d", len(program.Statements))
|
||||
}
|
||||
|
||||
stmt, ok := program.Statements[0].(*parser.AssignStatement)
|
||||
stmt, ok := program.Statements[0].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("expected AssignStatement, got %T", program.Statements[0])
|
||||
t.Fatalf("expected Assignment, got %T", program.Statements[0])
|
||||
}
|
||||
|
||||
// Member access should never have type hints
|
||||
if stmt.TypeHint != nil {
|
||||
if stmt.TypeHint.Type != parser.TypeUnknown {
|
||||
t.Error("member access assignment should not have type hints")
|
||||
}
|
||||
|
||||
@ -333,12 +354,12 @@ func TestTypeInferenceErrors(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
"x: number = \"hello\"",
|
||||
"cannot assign string to variable of type number",
|
||||
"type mismatch in assignment",
|
||||
"type mismatch in assignment",
|
||||
},
|
||||
{
|
||||
"x = 42\ny: string = x",
|
||||
"cannot assign number to variable of type string",
|
||||
"type mismatch in assignment",
|
||||
"type mismatch with inferred type",
|
||||
},
|
||||
}
|
||||
@ -359,7 +380,7 @@ func TestTypeInferenceErrors(t *testing.T) {
|
||||
|
||||
found := false
|
||||
for _, err := range typeErrors {
|
||||
if err.Message == tt.expectedError {
|
||||
if containsSubstring(err.Message, tt.expectedError) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
@ -370,7 +391,7 @@ func TestTypeInferenceErrors(t *testing.T) {
|
||||
for i, err := range typeErrors {
|
||||
errorMsgs[i] = err.Message
|
||||
}
|
||||
t.Errorf("expected error %q, got %v", tt.expectedError, errorMsgs)
|
||||
t.Errorf("expected error containing %q, got %v", tt.expectedError, errorMsgs)
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -439,22 +460,22 @@ server: table = {
|
||||
}
|
||||
|
||||
// Check first statement: config table with typed assignments
|
||||
configStmt, ok := program.Statements[0].(*parser.AssignStatement)
|
||||
configStmt, ok := program.Statements[0].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("expected AssignStatement, got %T", program.Statements[0])
|
||||
t.Fatalf("expected Assignment, got %T", program.Statements[0])
|
||||
}
|
||||
|
||||
if configStmt.TypeHint == nil || configStmt.TypeHint.Type != "table" {
|
||||
if configStmt.TypeHint.Type == parser.TypeUnknown || configStmt.TypeHint.Type != parser.TypeTable {
|
||||
t.Error("expected table type hint for config")
|
||||
}
|
||||
|
||||
// Check second statement: handler function with typed parameters
|
||||
handlerStmt, ok := program.Statements[1].(*parser.AssignStatement)
|
||||
handlerStmt, ok := program.Statements[1].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("expected AssignStatement, got %T", program.Statements[1])
|
||||
t.Fatalf("expected Assignment, got %T", program.Statements[1])
|
||||
}
|
||||
|
||||
if handlerStmt.TypeHint == nil || handlerStmt.TypeHint.Type != "function" {
|
||||
if handlerStmt.TypeHint.Type == parser.TypeUnknown || handlerStmt.TypeHint.Type != parser.TypeFunction {
|
||||
t.Error("expected function type hint for handler")
|
||||
}
|
||||
|
||||
@ -468,34 +489,32 @@ server: table = {
|
||||
}
|
||||
|
||||
// Check parameter types
|
||||
if fn.Parameters[0].TypeHint == nil || fn.Parameters[0].TypeHint.Type != "table" {
|
||||
if fn.Parameters[0].TypeHint.Type == parser.TypeUnknown || fn.Parameters[0].TypeHint.Type != parser.TypeTable {
|
||||
t.Error("expected table type for request parameter")
|
||||
}
|
||||
|
||||
if fn.Parameters[1].TypeHint == nil || fn.Parameters[1].TypeHint.Type != "function" {
|
||||
if fn.Parameters[1].TypeHint.Type == parser.TypeUnknown || fn.Parameters[1].TypeHint.Type != parser.TypeFunction {
|
||||
t.Error("expected function type for callback parameter")
|
||||
}
|
||||
|
||||
// Check return type
|
||||
if fn.ReturnType == nil || fn.ReturnType.Type != "nil" {
|
||||
if fn.ReturnType.Type == parser.TypeUnknown || fn.ReturnType.Type != parser.TypeNil {
|
||||
t.Error("expected nil return type for handler")
|
||||
}
|
||||
|
||||
// Check third statement: server table
|
||||
serverStmt, ok := program.Statements[2].(*parser.AssignStatement)
|
||||
serverStmt, ok := program.Statements[2].(*parser.Assignment)
|
||||
if !ok {
|
||||
t.Fatalf("expected AssignStatement, got %T", program.Statements[2])
|
||||
t.Fatalf("expected Assignment, got %T", program.Statements[2])
|
||||
}
|
||||
|
||||
if serverStmt.TypeHint == nil || serverStmt.TypeHint.Type != "table" {
|
||||
if serverStmt.TypeHint.Type == parser.TypeUnknown || serverStmt.TypeHint.Type != parser.TypeTable {
|
||||
t.Error("expected table type hint for server")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTypeInfoGettersSetters(t *testing.T) {
|
||||
// Test that all expression types properly implement GetType/SetType
|
||||
typeInfo := &parser.TypeInfo{Type: "test", Inferred: true}
|
||||
|
||||
func TestTypeInfoInterface(t *testing.T) {
|
||||
// Test that all expression types properly implement TypeInfo()
|
||||
expressions := []parser.Expression{
|
||||
&parser.Identifier{Value: "x"},
|
||||
&parser.NumberLiteral{Value: 42},
|
||||
@ -513,20 +532,43 @@ func TestTypeInfoGettersSetters(t *testing.T) {
|
||||
|
||||
for i, expr := range expressions {
|
||||
t.Run(string(rune('0'+i)), func(t *testing.T) {
|
||||
// Initially should have no type
|
||||
if expr.GetType() != nil {
|
||||
t.Error("expected nil type initially")
|
||||
// Should have default type initially
|
||||
typeInfo := expr.TypeInfo()
|
||||
|
||||
// Basic literals should have their expected types
|
||||
switch e := expr.(type) {
|
||||
case *parser.NumberLiteral:
|
||||
if typeInfo.Type != parser.TypeNumber {
|
||||
t.Errorf("expected number type, got %v", typeInfo.Type)
|
||||
}
|
||||
|
||||
// Set type
|
||||
expr.SetType(typeInfo)
|
||||
|
||||
// Get type should return what we set
|
||||
retrieved := expr.GetType()
|
||||
if retrieved == nil {
|
||||
t.Error("expected non-nil type after setting")
|
||||
} else if retrieved.Type != "test" || !retrieved.Inferred {
|
||||
t.Errorf("expected {Type: test, Inferred: true}, got %+v", retrieved)
|
||||
case *parser.StringLiteral:
|
||||
if typeInfo.Type != parser.TypeString {
|
||||
t.Errorf("expected string type, got %v", typeInfo.Type)
|
||||
}
|
||||
case *parser.BooleanLiteral:
|
||||
if typeInfo.Type != parser.TypeBool {
|
||||
t.Errorf("expected bool type, got %v", typeInfo.Type)
|
||||
}
|
||||
case *parser.NilLiteral:
|
||||
if typeInfo.Type != parser.TypeNil {
|
||||
t.Errorf("expected nil type, got %v", typeInfo.Type)
|
||||
}
|
||||
case *parser.TableLiteral:
|
||||
if typeInfo.Type != parser.TypeTable {
|
||||
t.Errorf("expected table type, got %v", typeInfo.Type)
|
||||
}
|
||||
case *parser.FunctionLiteral:
|
||||
if typeInfo.Type != parser.TypeFunction {
|
||||
t.Errorf("expected function type, got %v", typeInfo.Type)
|
||||
}
|
||||
case *parser.Identifier:
|
||||
// Identifiers default to any type
|
||||
if typeInfo.Type != parser.TypeAny {
|
||||
t.Errorf("expected any type for untyped identifier, got %v", typeInfo.Type)
|
||||
}
|
||||
default:
|
||||
// Other expressions may have unknown type initially
|
||||
_ = e
|
||||
}
|
||||
})
|
||||
}
|
||||
|
600
parser/types.go
600
parser/types.go
@ -1,21 +1,45 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
import "fmt"
|
||||
|
||||
// Type represents built-in and user-defined types using compact enum representation.
|
||||
// Uses single byte instead of string pointers to minimize memory usage.
|
||||
type Type uint8
|
||||
|
||||
// Type constants for built-in types
|
||||
const (
|
||||
TypeNumber = "number"
|
||||
TypeString = "string"
|
||||
TypeBool = "bool"
|
||||
TypeNil = "nil"
|
||||
TypeTable = "table"
|
||||
TypeFunction = "function"
|
||||
TypeAny = "any"
|
||||
TypeUnknown Type = iota
|
||||
TypeNumber
|
||||
TypeString
|
||||
TypeBool
|
||||
TypeNil
|
||||
TypeTable
|
||||
TypeFunction
|
||||
TypeAny
|
||||
TypeStruct // struct types use StructID field for identification
|
||||
)
|
||||
|
||||
// TypeError represents a type checking error
|
||||
// TypeInfo represents type information with zero-allocation design.
|
||||
// Embeds directly in AST nodes instead of using pointers to reduce heap pressure.
|
||||
// Common types are pre-allocated as globals to eliminate most allocations.
|
||||
type TypeInfo struct {
|
||||
Type Type // Built-in type or TypeStruct for user types
|
||||
Inferred bool // True if type was inferred, false if explicitly declared
|
||||
StructID uint16 // Index into global struct table for struct types (0 for non-structs)
|
||||
}
|
||||
|
||||
// Pre-allocated common types - eliminates heap allocations for built-in types
|
||||
var (
|
||||
UnknownType = TypeInfo{Type: TypeUnknown, Inferred: true}
|
||||
NumberType = TypeInfo{Type: TypeNumber, Inferred: true}
|
||||
StringType = TypeInfo{Type: TypeString, Inferred: true}
|
||||
BoolType = TypeInfo{Type: TypeBool, Inferred: true}
|
||||
NilType = TypeInfo{Type: TypeNil, Inferred: true}
|
||||
TableType = TypeInfo{Type: TypeTable, Inferred: true}
|
||||
FunctionType = TypeInfo{Type: TypeFunction, Inferred: true}
|
||||
AnyType = TypeInfo{Type: TypeAny, Inferred: true}
|
||||
)
|
||||
|
||||
// TypeError represents a type checking error with location information
|
||||
type TypeError struct {
|
||||
Message string
|
||||
Line int
|
||||
@ -27,10 +51,10 @@ func (te TypeError) Error() string {
|
||||
return fmt.Sprintf("Type error at line %d, column %d: %s", te.Line, te.Column, te.Message)
|
||||
}
|
||||
|
||||
// Symbol represents a variable in the symbol table
|
||||
// Symbol represents a variable in the symbol table with optimized type storage
|
||||
type Symbol struct {
|
||||
Name string
|
||||
Type *TypeInfo
|
||||
Type TypeInfo // Embed directly instead of pointer
|
||||
Declared bool
|
||||
Line int
|
||||
Column int
|
||||
@ -63,52 +87,56 @@ func (s *Scope) Lookup(name string) *Symbol {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TypeInferrer performs type inference and checking
|
||||
// TypeInferrer performs type inference and checking with optimized allocations
|
||||
type TypeInferrer struct {
|
||||
currentScope *Scope
|
||||
globalScope *Scope
|
||||
errors []TypeError
|
||||
|
||||
// Pre-allocated type objects for performance
|
||||
numberType *TypeInfo
|
||||
stringType *TypeInfo
|
||||
boolType *TypeInfo
|
||||
nilType *TypeInfo
|
||||
tableType *TypeInfo
|
||||
anyType *TypeInfo
|
||||
|
||||
// Struct definitions
|
||||
// Struct definitions with ID mapping
|
||||
structs map[string]*StructStatement
|
||||
structIDs map[uint16]*StructStatement
|
||||
nextID uint16
|
||||
}
|
||||
|
||||
// NewTypeInferrer creates a new type inference engine
|
||||
func NewTypeInferrer() *TypeInferrer {
|
||||
globalScope := NewScope(nil)
|
||||
|
||||
ti := &TypeInferrer{
|
||||
return &TypeInferrer{
|
||||
currentScope: globalScope,
|
||||
globalScope: globalScope,
|
||||
errors: []TypeError{},
|
||||
structs: make(map[string]*StructStatement),
|
||||
|
||||
// Pre-allocate common types to reduce allocations
|
||||
numberType: &TypeInfo{Type: TypeNumber, Inferred: true},
|
||||
stringType: &TypeInfo{Type: TypeString, Inferred: true},
|
||||
boolType: &TypeInfo{Type: TypeBool, Inferred: true},
|
||||
nilType: &TypeInfo{Type: TypeNil, Inferred: true},
|
||||
tableType: &TypeInfo{Type: TypeTable, Inferred: true},
|
||||
anyType: &TypeInfo{Type: TypeAny, Inferred: true},
|
||||
structIDs: make(map[uint16]*StructStatement),
|
||||
nextID: 1, // 0 reserved for non-struct types
|
||||
}
|
||||
}
|
||||
|
||||
return ti
|
||||
// RegisterStruct assigns ID to struct and tracks it
|
||||
func (ti *TypeInferrer) RegisterStruct(stmt *StructStatement) {
|
||||
stmt.ID = ti.nextID
|
||||
ti.nextID++
|
||||
ti.structs[stmt.Name] = stmt
|
||||
ti.structIDs[stmt.ID] = stmt
|
||||
}
|
||||
|
||||
// GetStructByID returns struct definition by ID
|
||||
func (ti *TypeInferrer) GetStructByID(id uint16) *StructStatement {
|
||||
return ti.structIDs[id]
|
||||
}
|
||||
|
||||
// CreateStructType creates TypeInfo for a struct
|
||||
func (ti *TypeInferrer) CreateStructType(structID uint16) TypeInfo {
|
||||
return TypeInfo{Type: TypeStruct, StructID: structID, Inferred: true}
|
||||
}
|
||||
|
||||
// InferTypes performs type inference on the entire program
|
||||
func (ti *TypeInferrer) InferTypes(program *Program) []TypeError {
|
||||
// First pass: collect struct definitions
|
||||
// First pass: collect and register struct definitions
|
||||
for _, stmt := range program.Statements {
|
||||
if structStmt, ok := stmt.(*StructStatement); ok {
|
||||
ti.structs[structStmt.Name] = structStmt
|
||||
ti.RegisterStruct(structStmt)
|
||||
}
|
||||
}
|
||||
|
||||
@ -119,42 +147,6 @@ func (ti *TypeInferrer) InferTypes(program *Program) []TypeError {
|
||||
return ti.errors
|
||||
}
|
||||
|
||||
// enterScope creates a new scope
|
||||
func (ti *TypeInferrer) enterScope() {
|
||||
ti.currentScope = NewScope(ti.currentScope)
|
||||
}
|
||||
|
||||
// exitScope returns to the parent scope
|
||||
func (ti *TypeInferrer) exitScope() {
|
||||
if ti.currentScope.parent != nil {
|
||||
ti.currentScope = ti.currentScope.parent
|
||||
}
|
||||
}
|
||||
|
||||
// addError adds a type error
|
||||
func (ti *TypeInferrer) addError(message string, node Node) {
|
||||
ti.errors = append(ti.errors, TypeError{
|
||||
Message: message,
|
||||
Line: 0, // Would need to track position in AST nodes
|
||||
Column: 0,
|
||||
Node: node,
|
||||
})
|
||||
}
|
||||
|
||||
// getStructType returns TypeInfo for a struct
|
||||
func (ti *TypeInferrer) getStructType(name string) *TypeInfo {
|
||||
if _, exists := ti.structs[name]; exists {
|
||||
return &TypeInfo{Type: name, Inferred: true}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// isStructType checks if a type is a struct type
|
||||
func (ti *TypeInferrer) isStructType(t *TypeInfo) bool {
|
||||
_, exists := ti.structs[t.Type]
|
||||
return exists
|
||||
}
|
||||
|
||||
// inferStatement infers types for statements
|
||||
func (ti *TypeInferrer) inferStatement(stmt Statement) {
|
||||
switch s := stmt.(type) {
|
||||
@ -162,8 +154,10 @@ func (ti *TypeInferrer) inferStatement(stmt Statement) {
|
||||
ti.inferStructStatement(s)
|
||||
case *MethodDefinition:
|
||||
ti.inferMethodDefinition(s)
|
||||
case *AssignStatement:
|
||||
ti.inferAssignStatement(s)
|
||||
case *Assignment:
|
||||
ti.inferAssignment(s)
|
||||
case *ExpressionStatement:
|
||||
ti.inferExpression(s.Expression)
|
||||
case *EchoStatement:
|
||||
ti.inferExpression(s.Value)
|
||||
case *IfStatement:
|
||||
@ -182,46 +176,38 @@ func (ti *TypeInferrer) inferStatement(stmt Statement) {
|
||||
if s.Value != nil {
|
||||
ti.inferExpression(s.Value)
|
||||
}
|
||||
case *ExpressionStatement:
|
||||
ti.inferExpression(s.Expression)
|
||||
case *BreakStatement:
|
||||
// No-op
|
||||
}
|
||||
}
|
||||
|
||||
// inferStructStatement handles struct definitions
|
||||
func (ti *TypeInferrer) inferStructStatement(stmt *StructStatement) {
|
||||
// Validate field types
|
||||
for _, field := range stmt.Fields {
|
||||
if field.TypeHint != nil {
|
||||
if !ValidTypeName(field.TypeHint.Type) && !ti.isStructType(field.TypeHint) {
|
||||
ti.addError(fmt.Sprintf("invalid field type '%s' in struct '%s'",
|
||||
field.TypeHint.Type, stmt.Name), stmt)
|
||||
}
|
||||
if !ti.isValidType(field.TypeHint) {
|
||||
ti.addError(fmt.Sprintf("invalid field type in struct '%s'", stmt.Name), stmt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// inferMethodDefinition handles method definitions
|
||||
func (ti *TypeInferrer) inferMethodDefinition(stmt *MethodDefinition) {
|
||||
// Check if struct exists
|
||||
if _, exists := ti.structs[stmt.StructName]; !exists {
|
||||
ti.addError(fmt.Sprintf("method defined on undefined struct '%s'", stmt.StructName), stmt)
|
||||
structDef := ti.GetStructByID(stmt.StructID)
|
||||
if structDef == nil {
|
||||
ti.addError("method defined on undefined struct", stmt)
|
||||
return
|
||||
}
|
||||
|
||||
// Infer the function body
|
||||
ti.enterScope()
|
||||
|
||||
// Add self parameter implicitly
|
||||
// Add self parameter
|
||||
ti.currentScope.Define(&Symbol{
|
||||
Name: "self",
|
||||
Type: ti.getStructType(stmt.StructName),
|
||||
Type: ti.CreateStructType(stmt.StructID),
|
||||
Declared: true,
|
||||
})
|
||||
|
||||
// Add explicit parameters
|
||||
// Add function parameters
|
||||
for _, param := range stmt.Function.Parameters {
|
||||
paramType := ti.anyType
|
||||
if param.TypeHint != nil {
|
||||
paramType := AnyType
|
||||
if param.TypeHint.Type != TypeUnknown {
|
||||
paramType = param.TypeHint
|
||||
}
|
||||
ti.currentScope.Define(&Symbol{
|
||||
@ -235,66 +221,47 @@ func (ti *TypeInferrer) inferMethodDefinition(stmt *MethodDefinition) {
|
||||
for _, bodyStmt := range stmt.Function.Body {
|
||||
ti.inferStatement(bodyStmt)
|
||||
}
|
||||
|
||||
ti.exitScope()
|
||||
}
|
||||
|
||||
// inferAssignStatement handles variable assignments with type checking
|
||||
func (ti *TypeInferrer) inferAssignStatement(stmt *AssignStatement) {
|
||||
// Infer the type of the value expression
|
||||
func (ti *TypeInferrer) inferAssignment(stmt *Assignment) {
|
||||
valueType := ti.inferExpression(stmt.Value)
|
||||
|
||||
if ident, ok := stmt.Name.(*Identifier); ok {
|
||||
// Simple variable assignment
|
||||
symbol := ti.currentScope.Lookup(ident.Value)
|
||||
|
||||
if ident, ok := stmt.Target.(*Identifier); ok {
|
||||
if stmt.IsDeclaration {
|
||||
// New variable declaration
|
||||
varType := valueType
|
||||
|
||||
// If there's a type hint, validate it
|
||||
if stmt.TypeHint != nil {
|
||||
if stmt.TypeHint.Type != TypeUnknown {
|
||||
if !ti.isTypeCompatible(valueType, stmt.TypeHint) {
|
||||
ti.addError(fmt.Sprintf("cannot assign %s to variable of type %s",
|
||||
valueType.Type, stmt.TypeHint.Type), stmt)
|
||||
ti.addError("type mismatch in assignment", stmt)
|
||||
}
|
||||
varType = stmt.TypeHint
|
||||
varType.Inferred = false
|
||||
}
|
||||
|
||||
// Define the new symbol
|
||||
ti.currentScope.Define(&Symbol{
|
||||
Name: ident.Value,
|
||||
Type: varType,
|
||||
Declared: true,
|
||||
})
|
||||
|
||||
ident.SetType(varType)
|
||||
ident.typeInfo = varType
|
||||
} else {
|
||||
// Assignment to existing variable
|
||||
symbol := ti.currentScope.Lookup(ident.Value)
|
||||
if symbol == nil {
|
||||
ti.addError(fmt.Sprintf("undefined variable '%s'", ident.Value), stmt)
|
||||
return
|
||||
}
|
||||
|
||||
// Check type compatibility
|
||||
if !ti.isTypeCompatible(valueType, symbol.Type) {
|
||||
ti.addError(fmt.Sprintf("cannot assign %s to variable of type %s",
|
||||
valueType.Type, symbol.Type.Type), stmt)
|
||||
ti.addError("type mismatch in assignment", stmt)
|
||||
}
|
||||
|
||||
ident.SetType(symbol.Type)
|
||||
ident.typeInfo = symbol.Type
|
||||
}
|
||||
} else {
|
||||
// Member access assignment (table.key or table[index])
|
||||
ti.inferExpression(stmt.Name)
|
||||
ti.inferExpression(stmt.Target)
|
||||
}
|
||||
}
|
||||
|
||||
// inferIfStatement handles if statements
|
||||
func (ti *TypeInferrer) inferIfStatement(stmt *IfStatement) {
|
||||
condType := ti.inferExpression(stmt.Condition)
|
||||
ti.validateBooleanContext(condType, stmt.Condition)
|
||||
ti.inferExpression(stmt.Condition)
|
||||
|
||||
ti.enterScope()
|
||||
for _, s := range stmt.Body {
|
||||
@ -303,9 +270,7 @@ func (ti *TypeInferrer) inferIfStatement(stmt *IfStatement) {
|
||||
ti.exitScope()
|
||||
|
||||
for _, elseif := range stmt.ElseIfs {
|
||||
condType := ti.inferExpression(elseif.Condition)
|
||||
ti.validateBooleanContext(condType, elseif.Condition)
|
||||
|
||||
ti.inferExpression(elseif.Condition)
|
||||
ti.enterScope()
|
||||
for _, s := range elseif.Body {
|
||||
ti.inferStatement(s)
|
||||
@ -322,10 +287,8 @@ func (ti *TypeInferrer) inferIfStatement(stmt *IfStatement) {
|
||||
}
|
||||
}
|
||||
|
||||
// inferWhileStatement handles while loops
|
||||
func (ti *TypeInferrer) inferWhileStatement(stmt *WhileStatement) {
|
||||
condType := ti.inferExpression(stmt.Condition)
|
||||
ti.validateBooleanContext(condType, stmt.Condition)
|
||||
ti.inferExpression(stmt.Condition)
|
||||
|
||||
ti.enterScope()
|
||||
for _, s := range stmt.Body {
|
||||
@ -334,33 +297,21 @@ func (ti *TypeInferrer) inferWhileStatement(stmt *WhileStatement) {
|
||||
ti.exitScope()
|
||||
}
|
||||
|
||||
// inferForStatement handles numeric for loops
|
||||
func (ti *TypeInferrer) inferForStatement(stmt *ForStatement) {
|
||||
startType := ti.inferExpression(stmt.Start)
|
||||
endType := ti.inferExpression(stmt.End)
|
||||
|
||||
if !ti.isNumericType(startType) {
|
||||
ti.addError("for loop start value must be numeric", stmt.Start)
|
||||
}
|
||||
if !ti.isNumericType(endType) {
|
||||
ti.addError("for loop end value must be numeric", stmt.End)
|
||||
}
|
||||
|
||||
ti.inferExpression(stmt.Start)
|
||||
ti.inferExpression(stmt.End)
|
||||
if stmt.Step != nil {
|
||||
stepType := ti.inferExpression(stmt.Step)
|
||||
if !ti.isNumericType(stepType) {
|
||||
ti.addError("for loop step value must be numeric", stmt.Step)
|
||||
}
|
||||
ti.inferExpression(stmt.Step)
|
||||
}
|
||||
|
||||
ti.enterScope()
|
||||
// Define loop variable as number
|
||||
ti.currentScope.Define(&Symbol{
|
||||
Name: stmt.Variable.Value,
|
||||
Type: ti.numberType,
|
||||
Type: NumberType,
|
||||
Declared: true,
|
||||
})
|
||||
stmt.Variable.SetType(ti.numberType)
|
||||
stmt.Variable.typeInfo = NumberType
|
||||
|
||||
for _, s := range stmt.Body {
|
||||
ti.inferStatement(s)
|
||||
@ -368,33 +319,26 @@ func (ti *TypeInferrer) inferForStatement(stmt *ForStatement) {
|
||||
ti.exitScope()
|
||||
}
|
||||
|
||||
// inferForInStatement handles for-in loops
|
||||
func (ti *TypeInferrer) inferForInStatement(stmt *ForInStatement) {
|
||||
iterableType := ti.inferExpression(stmt.Iterable)
|
||||
|
||||
// For now, assume iterable is a table or struct
|
||||
if !ti.isTableType(iterableType) && !ti.isStructType(iterableType) {
|
||||
ti.addError("for-in requires an iterable (table or struct)", stmt.Iterable)
|
||||
}
|
||||
ti.inferExpression(stmt.Iterable)
|
||||
|
||||
ti.enterScope()
|
||||
|
||||
// Define loop variables (key and value are any for now)
|
||||
// Define loop variables
|
||||
if stmt.Key != nil {
|
||||
ti.currentScope.Define(&Symbol{
|
||||
Name: stmt.Key.Value,
|
||||
Type: ti.anyType,
|
||||
Type: AnyType,
|
||||
Declared: true,
|
||||
})
|
||||
stmt.Key.SetType(ti.anyType)
|
||||
stmt.Key.typeInfo = AnyType
|
||||
}
|
||||
|
||||
ti.currentScope.Define(&Symbol{
|
||||
Name: stmt.Value.Value,
|
||||
Type: ti.anyType,
|
||||
Type: AnyType,
|
||||
Declared: true,
|
||||
})
|
||||
stmt.Value.SetType(ti.anyType)
|
||||
stmt.Value.typeInfo = AnyType
|
||||
|
||||
for _, s := range stmt.Body {
|
||||
ti.inferStatement(s)
|
||||
@ -403,29 +347,25 @@ func (ti *TypeInferrer) inferForInStatement(stmt *ForInStatement) {
|
||||
}
|
||||
|
||||
// inferExpression infers the type of an expression
|
||||
func (ti *TypeInferrer) inferExpression(expr Expression) *TypeInfo {
|
||||
func (ti *TypeInferrer) inferExpression(expr Expression) TypeInfo {
|
||||
if expr == nil {
|
||||
return ti.nilType
|
||||
return NilType
|
||||
}
|
||||
|
||||
switch e := expr.(type) {
|
||||
case *Identifier:
|
||||
return ti.inferIdentifier(e)
|
||||
case *NumberLiteral:
|
||||
e.SetType(ti.numberType)
|
||||
return ti.numberType
|
||||
return NumberType
|
||||
case *StringLiteral:
|
||||
e.SetType(ti.stringType)
|
||||
return ti.stringType
|
||||
return StringType
|
||||
case *BooleanLiteral:
|
||||
e.SetType(ti.boolType)
|
||||
return ti.boolType
|
||||
return BoolType
|
||||
case *NilLiteral:
|
||||
e.SetType(ti.nilType)
|
||||
return ti.nilType
|
||||
return NilType
|
||||
case *TableLiteral:
|
||||
return ti.inferTableLiteral(e)
|
||||
case *StructConstructorExpression:
|
||||
case *StructConstructor:
|
||||
return ti.inferStructConstructor(e)
|
||||
case *FunctionLiteral:
|
||||
return ti.inferFunctionLiteral(e)
|
||||
@ -439,20 +379,40 @@ func (ti *TypeInferrer) inferExpression(expr Expression) *TypeInfo {
|
||||
return ti.inferIndexExpression(e)
|
||||
case *DotExpression:
|
||||
return ti.inferDotExpression(e)
|
||||
case *AssignExpression:
|
||||
return ti.inferAssignExpression(e)
|
||||
case *Assignment:
|
||||
return ti.inferAssignmentExpression(e)
|
||||
default:
|
||||
ti.addError("unknown expression type", expr)
|
||||
return ti.anyType
|
||||
return AnyType
|
||||
}
|
||||
}
|
||||
|
||||
// inferStructConstructor handles struct constructor expressions
|
||||
func (ti *TypeInferrer) inferStructConstructor(expr *StructConstructorExpression) *TypeInfo {
|
||||
structDef, exists := ti.structs[expr.StructName]
|
||||
if !exists {
|
||||
ti.addError(fmt.Sprintf("undefined struct '%s'", expr.StructName), expr)
|
||||
return ti.anyType
|
||||
func (ti *TypeInferrer) inferIdentifier(ident *Identifier) TypeInfo {
|
||||
symbol := ti.currentScope.Lookup(ident.Value)
|
||||
if symbol == nil {
|
||||
ti.addError(fmt.Sprintf("undefined variable '%s'", ident.Value), ident)
|
||||
return AnyType
|
||||
}
|
||||
ident.typeInfo = symbol.Type
|
||||
return symbol.Type
|
||||
}
|
||||
|
||||
func (ti *TypeInferrer) inferTableLiteral(table *TableLiteral) TypeInfo {
|
||||
// Infer types of all values
|
||||
for _, pair := range table.Pairs {
|
||||
if pair.Key != nil {
|
||||
ti.inferExpression(pair.Key)
|
||||
}
|
||||
ti.inferExpression(pair.Value)
|
||||
}
|
||||
return TableType
|
||||
}
|
||||
|
||||
func (ti *TypeInferrer) inferStructConstructor(expr *StructConstructor) TypeInfo {
|
||||
structDef := ti.GetStructByID(expr.StructID)
|
||||
if structDef == nil {
|
||||
ti.addError("undefined struct in constructor", expr)
|
||||
return AnyType
|
||||
}
|
||||
|
||||
// Validate field assignments
|
||||
@ -465,25 +425,23 @@ func (ti *TypeInferrer) inferStructConstructor(expr *StructConstructorExpression
|
||||
fieldName = str.Value
|
||||
}
|
||||
|
||||
// Check if field exists in struct
|
||||
fieldExists := false
|
||||
var fieldType *TypeInfo
|
||||
// Find field in struct definition
|
||||
var fieldType TypeInfo
|
||||
found := false
|
||||
for _, field := range structDef.Fields {
|
||||
if field.Name == fieldName {
|
||||
fieldExists = true
|
||||
fieldType = field.TypeHint
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !fieldExists {
|
||||
ti.addError(fmt.Sprintf("struct '%s' has no field '%s'", expr.StructName, fieldName), expr)
|
||||
if !found {
|
||||
ti.addError(fmt.Sprintf("struct has no field '%s'", fieldName), expr)
|
||||
} else {
|
||||
// Check type compatibility
|
||||
valueType := ti.inferExpression(pair.Value)
|
||||
if !ti.isTypeCompatible(valueType, fieldType) {
|
||||
ti.addError(fmt.Sprintf("cannot assign %s to field '%s' of type %s",
|
||||
valueType.Type, fieldName, fieldType.Type), expr)
|
||||
ti.addError("field type mismatch in struct constructor", expr)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@ -492,66 +450,18 @@ func (ti *TypeInferrer) inferStructConstructor(expr *StructConstructorExpression
|
||||
}
|
||||
}
|
||||
|
||||
structType := ti.getStructType(expr.StructName)
|
||||
expr.SetType(structType)
|
||||
structType := ti.CreateStructType(expr.StructID)
|
||||
expr.typeInfo = structType
|
||||
return structType
|
||||
}
|
||||
|
||||
// inferAssignExpression handles assignment expressions
|
||||
func (ti *TypeInferrer) inferAssignExpression(expr *AssignExpression) *TypeInfo {
|
||||
valueType := ti.inferExpression(expr.Value)
|
||||
|
||||
if ident, ok := expr.Name.(*Identifier); ok {
|
||||
if expr.IsDeclaration {
|
||||
ti.currentScope.Define(&Symbol{
|
||||
Name: ident.Value,
|
||||
Type: valueType,
|
||||
Declared: true,
|
||||
})
|
||||
}
|
||||
ident.SetType(valueType)
|
||||
} else {
|
||||
ti.inferExpression(expr.Name)
|
||||
}
|
||||
|
||||
expr.SetType(valueType)
|
||||
return valueType
|
||||
}
|
||||
|
||||
// inferIdentifier looks up identifier type in symbol table
|
||||
func (ti *TypeInferrer) inferIdentifier(ident *Identifier) *TypeInfo {
|
||||
symbol := ti.currentScope.Lookup(ident.Value)
|
||||
if symbol == nil {
|
||||
ti.addError(fmt.Sprintf("undefined variable '%s'", ident.Value), ident)
|
||||
return ti.anyType
|
||||
}
|
||||
|
||||
ident.SetType(symbol.Type)
|
||||
return symbol.Type
|
||||
}
|
||||
|
||||
// inferTableLiteral infers table type
|
||||
func (ti *TypeInferrer) inferTableLiteral(table *TableLiteral) *TypeInfo {
|
||||
// Infer types of all values
|
||||
for _, pair := range table.Pairs {
|
||||
if pair.Key != nil {
|
||||
ti.inferExpression(pair.Key)
|
||||
}
|
||||
ti.inferExpression(pair.Value)
|
||||
}
|
||||
|
||||
table.SetType(ti.tableType)
|
||||
return ti.tableType
|
||||
}
|
||||
|
||||
// inferFunctionLiteral infers function type
|
||||
func (ti *TypeInferrer) inferFunctionLiteral(fn *FunctionLiteral) *TypeInfo {
|
||||
func (ti *TypeInferrer) inferFunctionLiteral(fn *FunctionLiteral) TypeInfo {
|
||||
ti.enterScope()
|
||||
|
||||
// Define parameters in function scope
|
||||
for _, param := range fn.Parameters {
|
||||
paramType := ti.anyType
|
||||
if param.TypeHint != nil {
|
||||
paramType := AnyType
|
||||
if param.TypeHint.Type != TypeUnknown {
|
||||
paramType = param.TypeHint
|
||||
}
|
||||
|
||||
@ -568,104 +478,88 @@ func (ti *TypeInferrer) inferFunctionLiteral(fn *FunctionLiteral) *TypeInfo {
|
||||
}
|
||||
|
||||
ti.exitScope()
|
||||
|
||||
// For now, all functions have type "function"
|
||||
funcType := &TypeInfo{Type: TypeFunction, Inferred: true}
|
||||
fn.SetType(funcType)
|
||||
return funcType
|
||||
return FunctionType
|
||||
}
|
||||
|
||||
// inferCallExpression infers function call return type
|
||||
func (ti *TypeInferrer) inferCallExpression(call *CallExpression) *TypeInfo {
|
||||
funcType := ti.inferExpression(call.Function)
|
||||
|
||||
if !ti.isFunctionType(funcType) {
|
||||
ti.addError("cannot call non-function", call.Function)
|
||||
return ti.anyType
|
||||
}
|
||||
func (ti *TypeInferrer) inferCallExpression(call *CallExpression) TypeInfo {
|
||||
ti.inferExpression(call.Function)
|
||||
|
||||
// Infer argument types
|
||||
for _, arg := range call.Arguments {
|
||||
ti.inferExpression(arg)
|
||||
}
|
||||
|
||||
// For now, assume function calls return any
|
||||
call.SetType(ti.anyType)
|
||||
return ti.anyType
|
||||
call.typeInfo = AnyType
|
||||
return AnyType
|
||||
}
|
||||
|
||||
// inferPrefixExpression infers prefix operation type
|
||||
func (ti *TypeInferrer) inferPrefixExpression(prefix *PrefixExpression) *TypeInfo {
|
||||
func (ti *TypeInferrer) inferPrefixExpression(prefix *PrefixExpression) TypeInfo {
|
||||
rightType := ti.inferExpression(prefix.Right)
|
||||
|
||||
var resultType *TypeInfo
|
||||
var resultType TypeInfo
|
||||
switch prefix.Operator {
|
||||
case "-":
|
||||
if !ti.isNumericType(rightType) {
|
||||
ti.addError("unary minus requires numeric operand", prefix)
|
||||
}
|
||||
resultType = ti.numberType
|
||||
resultType = NumberType
|
||||
case "not":
|
||||
resultType = ti.boolType
|
||||
resultType = BoolType
|
||||
default:
|
||||
ti.addError(fmt.Sprintf("unknown prefix operator '%s'", prefix.Operator), prefix)
|
||||
resultType = ti.anyType
|
||||
resultType = AnyType
|
||||
}
|
||||
|
||||
prefix.SetType(resultType)
|
||||
prefix.typeInfo = resultType
|
||||
return resultType
|
||||
}
|
||||
|
||||
// inferInfixExpression infers binary operation type
|
||||
func (ti *TypeInferrer) inferInfixExpression(infix *InfixExpression) *TypeInfo {
|
||||
func (ti *TypeInferrer) inferInfixExpression(infix *InfixExpression) TypeInfo {
|
||||
leftType := ti.inferExpression(infix.Left)
|
||||
rightType := ti.inferExpression(infix.Right)
|
||||
|
||||
var resultType *TypeInfo
|
||||
var resultType TypeInfo
|
||||
|
||||
switch infix.Operator {
|
||||
case "+", "-", "*", "/":
|
||||
if !ti.isNumericType(leftType) || !ti.isNumericType(rightType) {
|
||||
ti.addError(fmt.Sprintf("arithmetic operator '%s' requires numeric operands", infix.Operator), infix)
|
||||
}
|
||||
resultType = ti.numberType
|
||||
resultType = NumberType
|
||||
|
||||
case "==", "!=":
|
||||
// Equality works with any types
|
||||
resultType = ti.boolType
|
||||
resultType = BoolType
|
||||
|
||||
case "<", ">", "<=", ">=":
|
||||
if !ti.isComparableTypes(leftType, rightType) {
|
||||
ti.addError(fmt.Sprintf("comparison operator '%s' requires compatible operands", infix.Operator), infix)
|
||||
}
|
||||
resultType = ti.boolType
|
||||
resultType = BoolType
|
||||
|
||||
case "and", "or":
|
||||
ti.validateBooleanContext(leftType, infix.Left)
|
||||
ti.validateBooleanContext(rightType, infix.Right)
|
||||
resultType = ti.boolType
|
||||
resultType = BoolType
|
||||
|
||||
default:
|
||||
ti.addError(fmt.Sprintf("unknown infix operator '%s'", infix.Operator), infix)
|
||||
resultType = ti.anyType
|
||||
resultType = AnyType
|
||||
}
|
||||
|
||||
infix.SetType(resultType)
|
||||
infix.typeInfo = resultType
|
||||
return resultType
|
||||
}
|
||||
|
||||
// inferIndexExpression infers table[index] type
|
||||
func (ti *TypeInferrer) inferIndexExpression(index *IndexExpression) *TypeInfo {
|
||||
func (ti *TypeInferrer) inferIndexExpression(index *IndexExpression) TypeInfo {
|
||||
leftType := ti.inferExpression(index.Left)
|
||||
ti.inferExpression(index.Index)
|
||||
|
||||
// If indexing a struct, try to infer field type
|
||||
if ti.isStructType(leftType) {
|
||||
if strLit, ok := index.Index.(*StringLiteral); ok {
|
||||
if structDef, exists := ti.structs[leftType.Type]; exists {
|
||||
if structDef := ti.GetStructByID(leftType.StructID); structDef != nil {
|
||||
for _, field := range structDef.Fields {
|
||||
if field.Name == strLit.Value {
|
||||
index.SetType(field.TypeHint)
|
||||
index.typeInfo = field.TypeHint
|
||||
return field.TypeHint
|
||||
}
|
||||
}
|
||||
@ -674,20 +568,19 @@ func (ti *TypeInferrer) inferIndexExpression(index *IndexExpression) *TypeInfo {
|
||||
}
|
||||
|
||||
// For now, assume table/struct access returns any
|
||||
index.SetType(ti.anyType)
|
||||
return ti.anyType
|
||||
index.typeInfo = AnyType
|
||||
return AnyType
|
||||
}
|
||||
|
||||
// inferDotExpression infers table.key type
|
||||
func (ti *TypeInferrer) inferDotExpression(dot *DotExpression) *TypeInfo {
|
||||
func (ti *TypeInferrer) inferDotExpression(dot *DotExpression) TypeInfo {
|
||||
leftType := ti.inferExpression(dot.Left)
|
||||
|
||||
// If accessing a struct field, try to infer field type
|
||||
if ti.isStructType(leftType) {
|
||||
if structDef, exists := ti.structs[leftType.Type]; exists {
|
||||
if structDef := ti.GetStructByID(leftType.StructID); structDef != nil {
|
||||
for _, field := range structDef.Fields {
|
||||
if field.Name == dot.Key {
|
||||
dot.SetType(field.TypeHint)
|
||||
dot.typeInfo = field.TypeHint
|
||||
return field.TypeHint
|
||||
}
|
||||
}
|
||||
@ -695,59 +588,107 @@ func (ti *TypeInferrer) inferDotExpression(dot *DotExpression) *TypeInfo {
|
||||
}
|
||||
|
||||
// For now, assume member access returns any
|
||||
dot.SetType(ti.anyType)
|
||||
return ti.anyType
|
||||
dot.typeInfo = AnyType
|
||||
return AnyType
|
||||
}
|
||||
|
||||
// Type checking helper methods
|
||||
func (ti *TypeInferrer) inferAssignmentExpression(expr *Assignment) TypeInfo {
|
||||
valueType := ti.inferExpression(expr.Value)
|
||||
|
||||
func (ti *TypeInferrer) isTypeCompatible(valueType, targetType *TypeInfo) bool {
|
||||
if ident, ok := expr.Target.(*Identifier); ok {
|
||||
if expr.IsDeclaration {
|
||||
varType := valueType
|
||||
if expr.TypeHint.Type != TypeUnknown {
|
||||
if !ti.isTypeCompatible(valueType, expr.TypeHint) {
|
||||
ti.addError("type mismatch in assignment", expr)
|
||||
}
|
||||
varType = expr.TypeHint
|
||||
}
|
||||
|
||||
ti.currentScope.Define(&Symbol{
|
||||
Name: ident.Value,
|
||||
Type: varType,
|
||||
Declared: true,
|
||||
})
|
||||
ident.typeInfo = varType
|
||||
} else {
|
||||
symbol := ti.currentScope.Lookup(ident.Value)
|
||||
if symbol != nil {
|
||||
ident.typeInfo = symbol.Type
|
||||
}
|
||||
}
|
||||
} else {
|
||||
ti.inferExpression(expr.Target)
|
||||
}
|
||||
|
||||
return valueType
|
||||
}
|
||||
|
||||
// Helper methods
|
||||
func (ti *TypeInferrer) enterScope() {
|
||||
ti.currentScope = NewScope(ti.currentScope)
|
||||
}
|
||||
|
||||
func (ti *TypeInferrer) exitScope() {
|
||||
if ti.currentScope.parent != nil {
|
||||
ti.currentScope = ti.currentScope.parent
|
||||
}
|
||||
}
|
||||
|
||||
func (ti *TypeInferrer) addError(message string, node Node) {
|
||||
ti.errors = append(ti.errors, TypeError{
|
||||
Message: message,
|
||||
Node: node,
|
||||
})
|
||||
}
|
||||
|
||||
func (ti *TypeInferrer) isValidType(t TypeInfo) bool {
|
||||
if t.Type == TypeStruct {
|
||||
return ti.GetStructByID(t.StructID) != nil
|
||||
}
|
||||
return t.Type <= TypeStruct
|
||||
}
|
||||
|
||||
func (ti *TypeInferrer) isTypeCompatible(valueType, targetType TypeInfo) bool {
|
||||
if targetType.Type == TypeAny || valueType.Type == TypeAny {
|
||||
return true
|
||||
}
|
||||
if valueType.Type == TypeStruct && targetType.Type == TypeStruct {
|
||||
return valueType.StructID == targetType.StructID
|
||||
}
|
||||
return valueType.Type == targetType.Type
|
||||
}
|
||||
|
||||
func (ti *TypeInferrer) isNumericType(t *TypeInfo) bool {
|
||||
func (ti *TypeInferrer) isNumericType(t TypeInfo) bool {
|
||||
return t.Type == TypeNumber
|
||||
}
|
||||
|
||||
func (ti *TypeInferrer) isBooleanType(t *TypeInfo) bool {
|
||||
func (ti *TypeInferrer) isBooleanType(t TypeInfo) bool {
|
||||
return t.Type == TypeBool
|
||||
}
|
||||
|
||||
func (ti *TypeInferrer) isTableType(t *TypeInfo) bool {
|
||||
func (ti *TypeInferrer) isTableType(t TypeInfo) bool {
|
||||
return t.Type == TypeTable
|
||||
}
|
||||
|
||||
func (ti *TypeInferrer) isFunctionType(t *TypeInfo) bool {
|
||||
func (ti *TypeInferrer) isFunctionType(t TypeInfo) bool {
|
||||
return t.Type == TypeFunction
|
||||
}
|
||||
|
||||
func (ti *TypeInferrer) isComparableTypes(left, right *TypeInfo) bool {
|
||||
func (ti *TypeInferrer) isStructType(t TypeInfo) bool {
|
||||
return t.Type == TypeStruct
|
||||
}
|
||||
|
||||
func (ti *TypeInferrer) isComparableTypes(left, right TypeInfo) bool {
|
||||
if left.Type == TypeAny || right.Type == TypeAny {
|
||||
return true
|
||||
}
|
||||
return left.Type == right.Type && (left.Type == TypeNumber || left.Type == TypeString)
|
||||
}
|
||||
|
||||
func (ti *TypeInferrer) validateBooleanContext(t *TypeInfo, expr Expression) {
|
||||
// In many languages, non-boolean values can be used in boolean context
|
||||
// For strictness, we could require boolean type here
|
||||
// For now, allow any type (truthy/falsy semantics)
|
||||
}
|
||||
|
||||
// Errors returns all type checking errors
|
||||
func (ti *TypeInferrer) Errors() []TypeError {
|
||||
return ti.errors
|
||||
}
|
||||
|
||||
// HasErrors returns true if there are any type errors
|
||||
func (ti *TypeInferrer) HasErrors() bool {
|
||||
return len(ti.errors) > 0
|
||||
}
|
||||
|
||||
// ErrorStrings returns error messages as strings
|
||||
// Error reporting
|
||||
func (ti *TypeInferrer) Errors() []TypeError { return ti.errors }
|
||||
func (ti *TypeInferrer) HasErrors() bool { return len(ti.errors) > 0 }
|
||||
func (ti *TypeInferrer) ErrorStrings() []string {
|
||||
result := make([]string, len(ti.errors))
|
||||
for i, err := range ti.errors {
|
||||
@ -756,21 +697,26 @@ func (ti *TypeInferrer) ErrorStrings() []string {
|
||||
return result
|
||||
}
|
||||
|
||||
// ValidTypeName checks if a string is a valid type name
|
||||
func ValidTypeName(name string) bool {
|
||||
validTypes := []string{TypeNumber, TypeString, TypeBool, TypeNil, TypeTable, TypeFunction, TypeAny}
|
||||
for _, validType := range validTypes {
|
||||
if name == validType {
|
||||
return true
|
||||
// Type string conversion
|
||||
func TypeToString(t TypeInfo) string {
|
||||
switch t.Type {
|
||||
case TypeNumber:
|
||||
return "number"
|
||||
case TypeString:
|
||||
return "string"
|
||||
case TypeBool:
|
||||
return "bool"
|
||||
case TypeNil:
|
||||
return "nil"
|
||||
case TypeTable:
|
||||
return "table"
|
||||
case TypeFunction:
|
||||
return "function"
|
||||
case TypeAny:
|
||||
return "any"
|
||||
case TypeStruct:
|
||||
return fmt.Sprintf("struct<%d>", t.StructID)
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ParseTypeName converts a string to a TypeInfo (for parsing type hints)
|
||||
func ParseTypeName(name string) *TypeInfo {
|
||||
if ValidTypeName(name) {
|
||||
return &TypeInfo{Type: name, Inferred: false}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user