Compare commits

..

No commits in common. "53cdb95b6e3ad7f0a22e76f21a93f1e49ac604dd" and "5ae2a6ef23876037778ed7b39a1bfc7d81a4cb2a" have entirely different histories.

15 changed files with 894 additions and 1409 deletions

View File

@ -1,234 +0,0 @@
package compiler
// Opcode represents a single bytecode instruction
type Opcode uint8
const (
// Stack Operations
OpLoadConst Opcode = iota // Load constant onto stack [idx]
OpLoadLocal // Load local variable [slot]
OpStoreLocal // Store top of stack to local [slot]
OpLoadGlobal // Load global variable [idx]
OpStoreGlobal // Store top of stack to global [idx]
OpPop // Pop top value from stack
OpDup // Duplicate top value on stack
// Arithmetic Operations
OpAdd // a + b
OpSub // a - b
OpMul // a * b
OpDiv // a / b
OpNeg // -a
OpMod // a % b
// Comparison Operations
OpEq // a == b
OpNeq // a != b
OpLt // a < b
OpLte // a <= b
OpGt // a > b
OpGte // a >= b
// Logical Operations
OpNot // not a
OpAnd // a and b
OpOr // a or b
// Control Flow
OpJump // Unconditional jump [offset]
OpJumpIfTrue // Jump if top of stack is true [offset]
OpJumpIfFalse // Jump if top of stack is false [offset]
OpCall // Call function [argCount]
OpReturn // Return from function
OpReturnNil // Return nil from function
// Table Operations
OpNewTable // Create new empty table
OpGetIndex // table[key] -> value
OpSetIndex // table[key] = value
OpGetField // table.field -> value [fieldIdx]
OpSetField // table.field = value [fieldIdx]
OpTableInsert // Insert value into table at next index
// Struct Operations
OpNewStruct // Create new struct instance [structId]
OpGetProperty // struct.field -> value [fieldIdx]
OpSetProperty // struct.field = value [fieldIdx]
OpCallMethod // Call method on struct [methodIdx, argCount]
// Function Operations
OpClosure // Create closure from function [funcIdx, upvalueCount]
OpGetUpvalue // Get upvalue [idx]
OpSetUpvalue // Set upvalue [idx]
OpCloseUpvalue // Close upvalue (move to heap)
// Array Operations
OpNewArray // Create new array with size [size]
OpArrayAppend // Append value to array
// Type Operations
OpGetType // Get type of value on stack
OpCast // Cast value to type [typeId]
// I/O Operations
OpEcho // Echo value to output
OpExit // Exit with code
// Special Operations
OpNoop // No operation
OpBreak // Break from loop
OpContinue // Continue loop iteration
// Debug Operations
OpDebugPrint // Debug print stack top
OpDebugStack // Debug print entire stack
)
// Instruction represents a single bytecode instruction with operands
type Instruction struct {
Op Opcode
Operands []uint16 // Variable length operands
}
// Chunk represents a compiled chunk of bytecode
type Chunk struct {
Code []uint8 // Raw bytecode stream
Constants []Value // Constant pool
Lines []int // Line numbers for debugging
Functions []Function // Function definitions
Structs []Struct // Struct definitions
}
// Value represents a runtime value in the VM
type Value struct {
Type ValueType
Data any // Actual value data
}
// ValueType represents the type of a runtime value
type ValueType uint8
const (
ValueNil ValueType = iota
ValueBool
ValueNumber
ValueString
ValueTable
ValueFunction
ValueStruct
ValueArray
ValueUpvalue
)
// Function represents a compiled function
type Function struct {
Name string // Function name (empty for anonymous)
Arity int // Number of parameters
Variadic bool // Whether function accepts variable args
LocalCount int // Number of local variable slots
UpvalCount int // Number of upvalues
Chunk Chunk // Function bytecode
Defaults []Value // Default parameter values
}
// Struct represents a compiled struct definition
type Struct struct {
Name string // Struct name
Fields []StructField // Field definitions
Methods map[string]uint16 // Method name -> function index
ID uint16 // Unique struct identifier
}
// StructField represents a field in a struct
type StructField struct {
Name string // Field name
Type ValueType // Field type
Offset uint16 // Offset in struct layout
}
// Table represents a key-value table/map
type Table struct {
Array map[int]Value // Array part (integer keys)
Hash map[string]Value // Hash part (string keys)
Meta *Table // Metatable for operations
}
// Array represents a dynamic array
type Array struct {
Elements []Value // Array elements
Count int // Current element count
Capacity int // Current capacity
}
// StructInstance represents an instance of a struct
type StructInstance struct {
StructID uint16 // Reference to struct definition
Fields map[string]Value // Field values
}
// Upvalue represents a captured variable
type Upvalue struct {
Location *Value // Pointer to actual value location
Closed Value // Closed-over value (when moved to heap)
IsClosed bool // Whether upvalue has been closed
}
// Instruction encoding helpers
// EncodeInstruction encodes an instruction into bytecode
func EncodeInstruction(op Opcode, operands ...uint16) []uint8 {
bytes := []uint8{uint8(op)}
for _, operand := range operands {
bytes = append(bytes, uint8(operand&0xFF), uint8(operand>>8))
}
return bytes
}
// DecodeInstruction decodes bytecode into instruction
func DecodeInstruction(code []uint8, offset int) (Opcode, []uint16, int) {
if offset >= len(code) {
return OpNoop, nil, offset
}
op := Opcode(code[offset])
operands := []uint16{}
nextOffset := offset + 1
// Decode operands based on instruction type
operandCount := GetOperandCount(op)
for range operandCount {
if nextOffset+1 >= len(code) {
break
}
operand := uint16(code[nextOffset]) | (uint16(code[nextOffset+1]) << 8)
operands = append(operands, operand)
nextOffset += 2
}
return op, operands, nextOffset
}
// GetOperandCount returns the number of operands for an instruction
func GetOperandCount(op Opcode) int {
switch op {
case OpLoadConst, OpLoadLocal, OpStoreLocal, OpLoadGlobal, OpStoreGlobal:
return 1
case OpJump, OpJumpIfTrue, OpJumpIfFalse:
return 1
case OpCall, OpNewStruct, OpGetField, OpSetField, OpGetProperty, OpSetProperty:
return 1
case OpCallMethod:
return 2
case OpClosure:
return 2
case OpNewArray, OpCast:
return 1
default:
return 0
}
}
// Instruction size calculation
func InstructionSize(op Opcode) int {
return 1 + (GetOperandCount(op) * 2) // 1 byte opcode + 2 bytes per operand
}

View File

@ -1,290 +0,0 @@
package compiler
import "fmt"
// Constants for compiler limits
const (
MaxLocals = 256 // Maximum local variables per function
MaxUpvalues = 256 // Maximum upvalues per function
MaxConstants = 65536 // Maximum constants per chunk
)
// CompilerState holds state during compilation
type CompilerState struct {
Chunk *Chunk // Current chunk being compiled
Constants map[string]int // Constant pool index mapping
Functions []Function // Compiled functions
Structs []Struct // Compiled structs
Locals []Local // Local variable stack
Upvalues []UpvalueRef // Upvalue definitions
ScopeDepth int // Current scope nesting level
FunctionType FunctionType // Type of function being compiled
BreakJumps []int // Break jump addresses for loops
ContinueJumps []int // Continue jump addresses for loops
LoopStart int // Start of current loop for continue
LoopDepth int // Current loop nesting depth
}
// Local represents a local variable during compilation
type Local struct {
Name string // Variable name
Depth int // Scope depth where declared
IsCaptured bool // Whether variable is captured by closure
Slot int // Stack slot index
}
// UpvalueRef represents an upvalue reference during compilation
type UpvalueRef struct {
Index uint8 // Index in enclosing function's locals or upvalues
IsLocal bool // True if captures local, false if captures upvalue
}
// FunctionType represents the type of function being compiled
type FunctionType uint8
const (
FunctionTypeScript FunctionType = iota // Top-level script
FunctionTypeFunction // Regular function
FunctionTypeMethod // Struct method
)
// CompileError represents a compilation error with location information
type CompileError struct {
Message string
Line int
Column int
}
func (ce CompileError) Error() string {
return fmt.Sprintf("Compile error at line %d, column %d: %s", ce.Line, ce.Column, ce.Message)
}
// NewCompilerState creates a new compiler state for compilation
func NewCompilerState(functionType FunctionType) *CompilerState {
return &CompilerState{
Chunk: NewChunk(),
Constants: make(map[string]int),
Functions: make([]Function, 0),
Structs: make([]Struct, 0),
Locals: make([]Local, 0, MaxLocals),
Upvalues: make([]UpvalueRef, 0, MaxUpvalues),
ScopeDepth: 0,
FunctionType: functionType,
BreakJumps: make([]int, 0),
ContinueJumps: make([]int, 0),
LoopStart: -1,
LoopDepth: 0,
}
}
// NewChunk creates a new bytecode chunk
func NewChunk() *Chunk {
return &Chunk{
Code: make([]uint8, 0, 256),
Constants: make([]Value, 0, 64),
Lines: make([]int, 0, 256),
Functions: make([]Function, 0),
Structs: make([]Struct, 0),
}
}
// Scope management methods
func (cs *CompilerState) BeginScope() {
cs.ScopeDepth++
}
func (cs *CompilerState) EndScope() {
cs.ScopeDepth--
// Remove locals that go out of scope
for len(cs.Locals) > 0 && cs.Locals[len(cs.Locals)-1].Depth > cs.ScopeDepth {
local := cs.Locals[len(cs.Locals)-1]
if local.IsCaptured {
// Emit close upvalue instruction
cs.EmitByte(uint8(OpCloseUpvalue))
} else {
// Emit pop instruction
cs.EmitByte(uint8(OpPop))
}
cs.Locals = cs.Locals[:len(cs.Locals)-1]
}
}
// Local variable management
func (cs *CompilerState) AddLocal(name string) error {
if len(cs.Locals) >= MaxLocals {
return CompileError{
Message: "too many local variables in function",
}
}
local := Local{
Name: name,
Depth: -1, // Mark as uninitialized
IsCaptured: false,
Slot: len(cs.Locals),
}
cs.Locals = append(cs.Locals, local)
return nil
}
func (cs *CompilerState) MarkInitialized() {
if len(cs.Locals) > 0 {
cs.Locals[len(cs.Locals)-1].Depth = cs.ScopeDepth
}
}
func (cs *CompilerState) ResolveLocal(name string) int {
for i := len(cs.Locals) - 1; i >= 0; i-- {
local := &cs.Locals[i]
if local.Name == name {
if local.Depth == -1 {
// Variable used before initialization
return -2
}
return i
}
}
return -1
}
// Upvalue management
func (cs *CompilerState) AddUpvalue(index uint8, isLocal bool) int {
upvalueCount := len(cs.Upvalues)
// Check if upvalue already exists
for i := range upvalueCount {
upvalue := &cs.Upvalues[i]
if upvalue.Index == index && upvalue.IsLocal == isLocal {
return i
}
}
if upvalueCount >= MaxUpvalues {
return -1 // Too many upvalues
}
cs.Upvalues = append(cs.Upvalues, UpvalueRef{
Index: index,
IsLocal: isLocal,
})
return upvalueCount
}
// Constant pool management
func (cs *CompilerState) AddConstant(value Value) int {
// Check if constant already exists to avoid duplicates
key := cs.valueKey(value)
if index, exists := cs.Constants[key]; exists {
return index
}
if len(cs.Chunk.Constants) >= MaxConstants {
return -1 // Too many constants
}
index := len(cs.Chunk.Constants)
cs.Chunk.Constants = append(cs.Chunk.Constants, value)
cs.Constants[key] = index
return index
}
// Generate unique key for value in constant pool
func (cs *CompilerState) valueKey(value Value) string {
switch value.Type {
case ValueNil:
return "nil"
case ValueBool:
if value.Data.(bool) {
return "bool:true"
}
return "bool:false"
case ValueNumber:
return fmt.Sprintf("number:%g", value.Data.(float64))
case ValueString:
return fmt.Sprintf("string:%s", value.Data.(string))
default:
// For complex types, use memory address as fallback
return fmt.Sprintf("%T:%p", value.Data, value.Data)
}
}
// Bytecode emission methods
func (cs *CompilerState) EmitByte(byte uint8) {
cs.Chunk.Code = append(cs.Chunk.Code, byte)
cs.Chunk.Lines = append(cs.Chunk.Lines, 0) // Line will be set by caller
}
func (cs *CompilerState) EmitBytes(bytes ...uint8) {
for _, b := range bytes {
cs.EmitByte(b)
}
}
func (cs *CompilerState) EmitInstruction(op Opcode, operands ...uint16) {
bytes := EncodeInstruction(op, operands...)
cs.EmitBytes(bytes...)
}
func (cs *CompilerState) EmitJump(op Opcode) int {
cs.EmitByte(uint8(op))
cs.EmitByte(0xFF) // Placeholder
cs.EmitByte(0xFF) // Placeholder
return len(cs.Chunk.Code) - 2 // Return offset of jump address
}
func (cs *CompilerState) PatchJump(offset int) {
// Calculate jump distance
jump := len(cs.Chunk.Code) - offset - 2
if jump > 65535 {
// Jump too large - would need long jump instruction
return
}
cs.Chunk.Code[offset] = uint8(jump & 0xFF)
cs.Chunk.Code[offset+1] = uint8((jump >> 8) & 0xFF)
}
// Loop management
func (cs *CompilerState) EnterLoop() {
cs.LoopStart = len(cs.Chunk.Code)
cs.LoopDepth++
}
func (cs *CompilerState) ExitLoop() {
cs.LoopDepth--
if cs.LoopDepth == 0 {
cs.LoopStart = -1
}
// Patch break jumps
for _, jumpOffset := range cs.BreakJumps {
cs.PatchJump(jumpOffset)
}
cs.BreakJumps = cs.BreakJumps[:0]
// Patch continue jumps
for _, jumpOffset := range cs.ContinueJumps {
jump := cs.LoopStart - jumpOffset - 2
if jump < 65535 {
cs.Chunk.Code[jumpOffset] = uint8(jump & 0xFF)
cs.Chunk.Code[jumpOffset+1] = uint8((jump >> 8) & 0xFF)
}
}
cs.ContinueJumps = cs.ContinueJumps[:0]
}
func (cs *CompilerState) EmitBreak() {
jumpOffset := cs.EmitJump(OpJump)
cs.BreakJumps = append(cs.BreakJumps, jumpOffset)
}
func (cs *CompilerState) EmitContinue() {
if cs.LoopStart != -1 {
jumpOffset := cs.EmitJump(OpJump)
cs.ContinueJumps = append(cs.ContinueJumps, jumpOffset)
}
}

View File

@ -2,28 +2,32 @@ package parser
import "fmt" import "fmt"
// Note: Type definitions moved to types.go for proper separation of concerns // TypeInfo represents type information for expressions
type TypeInfo struct {
Type string // "number", "string", "bool", "table", "function", "nil", "any", struct name
Inferred bool // true if type was inferred, false if explicitly declared
}
// Node represents any node in the AST // Node represents any node in the AST
type Node interface { type Node interface {
String() string String() string
} }
// Statement represents statement nodes that can appear at the top level or in blocks // Statement represents statement nodes
type Statement interface { type Statement interface {
Node Node
statementNode() statementNode()
} }
// Expression represents expression nodes that produce values and have types // Expression represents expression nodes
type Expression interface { type Expression interface {
Node Node
expressionNode() expressionNode()
TypeInfo() TypeInfo // Returns type by value, not pointer GetType() *TypeInfo
SetType(*TypeInfo)
} }
// Program represents the root of the AST containing all top-level statements. // Program represents the root of the AST
// Tracks exit code for script termination and owns the statement list.
type Program struct { type Program struct {
Statements []Statement Statements []Statement
ExitCode int ExitCode int
@ -37,23 +41,23 @@ func (p *Program) String() string {
return result return result
} }
// StructField represents a field definition within a struct. // StructField represents a field in a struct definition
// Contains field name and required type annotation for compile-time checking.
type StructField struct { type StructField struct {
Name string Name string
TypeHint TypeInfo // Required for struct fields, embeds directly TypeHint *TypeInfo
} }
func (sf *StructField) String() string { func (sf *StructField) String() string {
return fmt.Sprintf("%s: %s", sf.Name, typeToString(sf.TypeHint)) if sf.TypeHint != nil {
return fmt.Sprintf("%s: %s", sf.Name, sf.TypeHint.Type)
}
return sf.Name
} }
// StructStatement represents struct type definitions with named fields. // StructStatement represents struct definitions
// Defines new types that can be instantiated and used for type checking.
type StructStatement struct { type StructStatement struct {
Name string Name string
Fields []StructField Fields []StructField
ID uint16 // Unique identifier for fast lookup
} }
func (ss *StructStatement) statementNode() {} func (ss *StructStatement) statementNode() {}
@ -68,72 +72,77 @@ func (ss *StructStatement) String() string {
return fmt.Sprintf("struct %s {\n\t%s\n}", ss.Name, fields) return fmt.Sprintf("struct %s {\n\t%s\n}", ss.Name, fields)
} }
// MethodDefinition represents method definitions attached to struct types. // MethodDefinition represents method definitions on structs
// Links a function implementation to a specific struct via struct ID.
type MethodDefinition struct { type MethodDefinition struct {
StructID uint16 // Index into struct table for fast lookup StructName string
MethodName string MethodName string
Function *FunctionLiteral Function *FunctionLiteral
} }
func (md *MethodDefinition) statementNode() {} func (md *MethodDefinition) statementNode() {}
func (md *MethodDefinition) String() string { func (md *MethodDefinition) String() string {
return fmt.Sprintf("fn <struct>.%s%s", md.MethodName, md.Function.String()[2:]) return fmt.Sprintf("fn %s.%s%s", md.StructName, md.MethodName, md.Function.String()[2:]) // skip "fn" from function string
} }
// StructConstructor represents struct instantiation with field initialization. // StructConstructorExpression represents struct constructor calls like my_type{...}
// Uses struct ID for fast type resolution and validation during parsing. type StructConstructorExpression struct {
type StructConstructor struct { StructName string
StructID uint16 // Index into struct table Fields []TablePair // reuse TablePair for field assignments
Fields []TablePair // Reuses table pair structure for field assignments typeInfo *TypeInfo
typeInfo TypeInfo // Cached type info for this constructor
} }
func (sc *StructConstructor) expressionNode() {} func (sce *StructConstructorExpression) expressionNode() {}
func (sc *StructConstructor) String() string { func (sce *StructConstructorExpression) String() string {
var pairs []string var pairs []string
for _, pair := range sc.Fields { for _, pair := range sce.Fields {
pairs = append(pairs, pair.String()) pairs = append(pairs, pair.String())
} }
return fmt.Sprintf("<struct>{%s}", joinStrings(pairs, ", ")) return fmt.Sprintf("%s{%s}", sce.StructName, joinStrings(pairs, ", "))
} }
func (sc *StructConstructor) TypeInfo() TypeInfo { return sc.typeInfo } func (sce *StructConstructorExpression) GetType() *TypeInfo { return sce.typeInfo }
func (sce *StructConstructorExpression) SetType(t *TypeInfo) { sce.typeInfo = t }
// Assignment represents both variable assignment statements and assignment expressions. // AssignStatement represents variable assignment with optional type hint
// Unified design reduces AST node count and simplifies type checking logic. type AssignStatement struct {
type Assignment struct { Name Expression // Changed from *Identifier to Expression for member access
Target Expression // Target (identifier, dot, or index expression) TypeHint *TypeInfo // optional type hint
Value Expression // Value being assigned Value Expression
TypeHint TypeInfo // Optional explicit type hint, embeds directly IsDeclaration bool // true if this is the first assignment in current scope
IsDeclaration bool // True if declaring new variable in current scope
IsExpression bool // True if used as expression (wrapped in parentheses)
} }
func (a *Assignment) statementNode() {} func (as *AssignStatement) statementNode() {}
func (a *Assignment) expressionNode() {} func (as *AssignStatement) String() string {
func (a *Assignment) String() string {
prefix := "" prefix := ""
if a.IsDeclaration { if as.IsDeclaration {
prefix = "local " prefix = "local "
} }
var nameStr string var nameStr string
if a.TypeHint.Type != TypeUnknown { if as.TypeHint != nil {
nameStr = fmt.Sprintf("%s: %s", a.Target.String(), typeToString(a.TypeHint)) nameStr = fmt.Sprintf("%s: %s", as.Name.String(), as.TypeHint.Type)
} else { } else {
nameStr = a.Target.String() nameStr = as.Name.String()
} }
result := fmt.Sprintf("%s%s = %s", prefix, nameStr, a.Value.String()) return fmt.Sprintf("%s%s = %s", prefix, nameStr, as.Value.String())
if a.IsExpression {
return "(" + result + ")"
} }
return result
}
func (a *Assignment) TypeInfo() TypeInfo { return a.Value.TypeInfo() }
// ExpressionStatement wraps expressions used as statements. // AssignExpression represents assignment as an expression (only in parentheses)
// Allows function calls and other expressions at statement level. type AssignExpression struct {
Name Expression // Target (identifier, dot, or index expression)
Value Expression // Value to assign
IsDeclaration bool // true if this declares a new variable
typeInfo *TypeInfo // type of the expression (same as assigned value)
}
func (ae *AssignExpression) expressionNode() {}
func (ae *AssignExpression) String() string {
return fmt.Sprintf("(%s = %s)", ae.Name.String(), ae.Value.String())
}
func (ae *AssignExpression) GetType() *TypeInfo { return ae.typeInfo }
func (ae *AssignExpression) SetType(t *TypeInfo) { ae.typeInfo = t }
// ExpressionStatement represents expressions used as statements
type ExpressionStatement struct { type ExpressionStatement struct {
Expression Expression Expression Expression
} }
@ -143,8 +152,7 @@ func (es *ExpressionStatement) String() string {
return es.Expression.String() return es.Expression.String()
} }
// EchoStatement represents output statements for displaying values. // EchoStatement represents echo output statements
// Simple debugging and output mechanism built into the language.
type EchoStatement struct { type EchoStatement struct {
Value Expression Value Expression
} }
@ -154,17 +162,17 @@ func (es *EchoStatement) String() string {
return fmt.Sprintf("echo %s", es.Value.String()) return fmt.Sprintf("echo %s", es.Value.String())
} }
// BreakStatement represents loop exit statements. // BreakStatement represents break statements to exit loops
// Simple marker node with no additional data needed.
type BreakStatement struct{} type BreakStatement struct{}
func (bs *BreakStatement) statementNode() {} func (bs *BreakStatement) statementNode() {}
func (bs *BreakStatement) String() string { return "break" } func (bs *BreakStatement) String() string {
return "break"
}
// ExitStatement represents script termination with optional exit code. // ExitStatement represents exit statements to quit the script
// Value expression is nil for plain "exit", non-nil for "exit <code>".
type ExitStatement struct { type ExitStatement struct {
Value Expression // Optional exit code expression Value Expression // optional, can be nil
} }
func (es *ExitStatement) statementNode() {} func (es *ExitStatement) statementNode() {}
@ -175,10 +183,9 @@ func (es *ExitStatement) String() string {
return fmt.Sprintf("exit %s", es.Value.String()) return fmt.Sprintf("exit %s", es.Value.String())
} }
// ReturnStatement represents function return with optional value. // ReturnStatement represents return statements
// Value expression is nil for plain "return", non-nil for "return <value>".
type ReturnStatement struct { type ReturnStatement struct {
Value Expression // Optional return value expression Value Expression // optional, can be nil
} }
func (rs *ReturnStatement) statementNode() {} func (rs *ReturnStatement) statementNode() {}
@ -189,8 +196,7 @@ func (rs *ReturnStatement) String() string {
return fmt.Sprintf("return %s", rs.Value.String()) return fmt.Sprintf("return %s", rs.Value.String())
} }
// ElseIfClause represents conditional branches in if statements. // ElseIfClause represents an elseif condition
// Contains condition expression and body statements for this branch.
type ElseIfClause struct { type ElseIfClause struct {
Condition Expression Condition Expression
Body []Statement Body []Statement
@ -204,28 +210,30 @@ func (eic *ElseIfClause) String() string {
return fmt.Sprintf("elseif %s then\n%s", eic.Condition.String(), body) return fmt.Sprintf("elseif %s then\n%s", eic.Condition.String(), body)
} }
// IfStatement represents conditional execution with optional elseif and else branches. // IfStatement represents conditional statements
// Supports multiple elseif clauses and an optional final else clause.
type IfStatement struct { type IfStatement struct {
Condition Expression // Main condition Condition Expression
Body []Statement // Statements to execute if condition is true Body []Statement
ElseIfs []ElseIfClause // Optional elseif branches ElseIfs []ElseIfClause
Else []Statement // Optional else branch Else []Statement
} }
func (is *IfStatement) statementNode() {} func (is *IfStatement) statementNode() {}
func (is *IfStatement) String() string { func (is *IfStatement) String() string {
var result string var result string
// If clause
result += fmt.Sprintf("if %s then\n", is.Condition.String()) result += fmt.Sprintf("if %s then\n", is.Condition.String())
for _, stmt := range is.Body { for _, stmt := range is.Body {
result += "\t" + stmt.String() + "\n" result += "\t" + stmt.String() + "\n"
} }
// ElseIf clauses
for _, elseif := range is.ElseIfs { for _, elseif := range is.ElseIfs {
result += elseif.String() result += elseif.String()
} }
// Else clause
if len(is.Else) > 0 { if len(is.Else) > 0 {
result += "else\n" result += "else\n"
for _, stmt := range is.Else { for _, stmt := range is.Else {
@ -237,8 +245,7 @@ func (is *IfStatement) String() string {
return result return result
} }
// WhileStatement represents condition-based loops that execute while condition is true. // WhileStatement represents while loops: while condition do ... end
// Contains condition expression and body statements to repeat.
type WhileStatement struct { type WhileStatement struct {
Condition Expression Condition Expression
Body []Statement Body []Statement
@ -257,14 +264,13 @@ func (ws *WhileStatement) String() string {
return result return result
} }
// ForStatement represents numeric for loops with start, end, and optional step. // ForStatement represents numeric for loops: for i = start, end, step do ... end
// Variable is automatically scoped to the loop body.
type ForStatement struct { type ForStatement struct {
Variable *Identifier // Loop variable (automatically number type) Variable *Identifier
Start Expression // Starting value expression Start Expression
End Expression // Ending value expression End Expression
Step Expression // Optional step expression (nil means step of 1) Step Expression // optional, nil means step of 1
Body []Statement // Loop body statements Body []Statement
} }
func (fs *ForStatement) statementNode() {} func (fs *ForStatement) statementNode() {}
@ -286,13 +292,12 @@ func (fs *ForStatement) String() string {
return result return result
} }
// ForInStatement represents iterator-based loops over tables, arrays, or other iterables. // ForInStatement represents iterator for loops: for k, v in expr do ... end
// Supports both single variable (for v in iter) and key-value (for k,v in iter) forms.
type ForInStatement struct { type ForInStatement struct {
Key *Identifier // Optional key variable (nil for single variable iteration) Key *Identifier // optional, nil for single variable iteration
Value *Identifier // Value variable (required) Value *Identifier
Iterable Expression // Expression to iterate over Iterable Expression
Body []Statement // Loop body statements Body []Statement
} }
func (fis *ForInStatement) statementNode() {} func (fis *ForInStatement) statementNode() {}
@ -314,60 +319,56 @@ func (fis *ForInStatement) String() string {
return result return result
} }
// FunctionParameter represents a parameter in function definitions. // FunctionParameter represents a function parameter with optional type hint
// Contains parameter name and optional type hint for type checking.
type FunctionParameter struct { type FunctionParameter struct {
Name string Name string
TypeHint TypeInfo // Optional type constraint, embeds directly TypeHint *TypeInfo
} }
func (fp *FunctionParameter) String() string { func (fp *FunctionParameter) String() string {
if fp.TypeHint.Type != TypeUnknown { if fp.TypeHint != nil {
return fmt.Sprintf("%s: %s", fp.Name, typeToString(fp.TypeHint)) return fmt.Sprintf("%s: %s", fp.Name, fp.TypeHint.Type)
} }
return fp.Name return fp.Name
} }
// Identifier represents variable references and names. // Identifier represents identifiers
// Stores resolved type information for efficient type checking.
type Identifier struct { type Identifier struct {
Value string Value string
typeInfo TypeInfo // Resolved type, embeds directly typeInfo *TypeInfo
} }
func (i *Identifier) expressionNode() {} func (i *Identifier) expressionNode() {}
func (i *Identifier) String() string { return i.Value } func (i *Identifier) String() string { return i.Value }
func (i *Identifier) TypeInfo() TypeInfo { func (i *Identifier) GetType() *TypeInfo { return i.typeInfo }
if i.typeInfo.Type == TypeUnknown { func (i *Identifier) SetType(t *TypeInfo) { i.typeInfo = t }
return AnyType
}
return i.typeInfo
}
// NumberLiteral represents numeric constants including integers, floats, hex, and binary. // NumberLiteral represents numeric literals
// Always has number type, so no additional type storage needed.
type NumberLiteral struct { type NumberLiteral struct {
Value float64 // All numbers stored as float64 for simplicity Value float64
typeInfo *TypeInfo
} }
func (nl *NumberLiteral) expressionNode() {} func (nl *NumberLiteral) expressionNode() {}
func (nl *NumberLiteral) String() string { return fmt.Sprintf("%.2f", nl.Value) } func (nl *NumberLiteral) String() string { return fmt.Sprintf("%.2f", nl.Value) }
func (nl *NumberLiteral) TypeInfo() TypeInfo { return NumberType } func (nl *NumberLiteral) GetType() *TypeInfo { return nl.typeInfo }
func (nl *NumberLiteral) SetType(t *TypeInfo) { nl.typeInfo = t }
// StringLiteral represents string constants and multiline strings. // StringLiteral represents string literals
// Always has string type, so no additional type storage needed.
type StringLiteral struct { type StringLiteral struct {
Value string // String content without quotes Value string
typeInfo *TypeInfo
} }
func (sl *StringLiteral) expressionNode() {} func (sl *StringLiteral) expressionNode() {}
func (sl *StringLiteral) String() string { return fmt.Sprintf(`"%s"`, sl.Value) } func (sl *StringLiteral) String() string { return fmt.Sprintf(`"%s"`, sl.Value) }
func (sl *StringLiteral) TypeInfo() TypeInfo { return StringType } func (sl *StringLiteral) GetType() *TypeInfo { return sl.typeInfo }
func (sl *StringLiteral) SetType(t *TypeInfo) { sl.typeInfo = t }
// BooleanLiteral represents true and false constants. // BooleanLiteral represents boolean literals
// Always has bool type, so no additional type storage needed.
type BooleanLiteral struct { type BooleanLiteral struct {
Value bool Value bool
typeInfo *TypeInfo
} }
func (bl *BooleanLiteral) expressionNode() {} func (bl *BooleanLiteral) expressionNode() {}
@ -377,23 +378,26 @@ func (bl *BooleanLiteral) String() string {
} }
return "false" return "false"
} }
func (bl *BooleanLiteral) TypeInfo() TypeInfo { return BoolType } func (bl *BooleanLiteral) GetType() *TypeInfo { return bl.typeInfo }
func (bl *BooleanLiteral) SetType(t *TypeInfo) { bl.typeInfo = t }
// NilLiteral represents the nil constant value. // NilLiteral represents nil literal
// Always has nil type, so no additional type storage needed. type NilLiteral struct {
type NilLiteral struct{} typeInfo *TypeInfo
}
func (nl *NilLiteral) expressionNode() {} func (nl *NilLiteral) expressionNode() {}
func (nl *NilLiteral) String() string { return "nil" } func (nl *NilLiteral) String() string { return "nil" }
func (nl *NilLiteral) TypeInfo() TypeInfo { return NilType } func (nl *NilLiteral) GetType() *TypeInfo { return nl.typeInfo }
func (nl *NilLiteral) SetType(t *TypeInfo) { nl.typeInfo = t }
// FunctionLiteral represents function definitions with parameters, body, and optional return type. // FunctionLiteral represents function literals with typed parameters
// Always has function type, stores additional return type information separately.
type FunctionLiteral struct { type FunctionLiteral struct {
Parameters []FunctionParameter // Function parameters with optional types Parameters []FunctionParameter
Body []Statement // Function body statements Variadic bool
ReturnType TypeInfo // Optional return type hint, embeds directly ReturnType *TypeInfo // optional return type hint
Variadic bool // True if function accepts variable arguments Body []Statement
typeInfo *TypeInfo
} }
func (fl *FunctionLiteral) expressionNode() {} func (fl *FunctionLiteral) expressionNode() {}
@ -413,8 +417,8 @@ func (fl *FunctionLiteral) String() string {
} }
result := fmt.Sprintf("fn(%s)", params) result := fmt.Sprintf("fn(%s)", params)
if fl.ReturnType.Type != TypeUnknown { if fl.ReturnType != nil {
result += ": " + typeToString(fl.ReturnType) result += ": " + fl.ReturnType.Type
} }
result += "\n" result += "\n"
@ -424,14 +428,14 @@ func (fl *FunctionLiteral) String() string {
result += "end" result += "end"
return result return result
} }
func (fl *FunctionLiteral) TypeInfo() TypeInfo { return FunctionType } func (fl *FunctionLiteral) GetType() *TypeInfo { return fl.typeInfo }
func (fl *FunctionLiteral) SetType(t *TypeInfo) { fl.typeInfo = t }
// CallExpression represents function calls with arguments. // CallExpression represents function calls: func(arg1, arg2, ...)
// Stores inferred return type from function signature analysis.
type CallExpression struct { type CallExpression struct {
Function Expression // Function expression to call Function Expression
Arguments []Expression // Argument expressions Arguments []Expression
typeInfo TypeInfo // Inferred return type, embeds directly typeInfo *TypeInfo
} }
func (ce *CallExpression) expressionNode() {} func (ce *CallExpression) expressionNode() {}
@ -442,73 +446,74 @@ func (ce *CallExpression) String() string {
} }
return fmt.Sprintf("%s(%s)", ce.Function.String(), joinStrings(args, ", ")) return fmt.Sprintf("%s(%s)", ce.Function.String(), joinStrings(args, ", "))
} }
func (ce *CallExpression) TypeInfo() TypeInfo { return ce.typeInfo } func (ce *CallExpression) GetType() *TypeInfo { return ce.typeInfo }
func (ce *CallExpression) SetType(t *TypeInfo) { ce.typeInfo = t }
// PrefixExpression represents unary operations like negation and logical not. // PrefixExpression represents prefix operations like -x, not x
// Stores result type based on operator and operand type analysis.
type PrefixExpression struct { type PrefixExpression struct {
Operator string // Operator symbol ("-", "not") Operator string
Right Expression // Operand expression Right Expression
typeInfo TypeInfo // Result type, embeds directly typeInfo *TypeInfo
} }
func (pe *PrefixExpression) expressionNode() {} func (pe *PrefixExpression) expressionNode() {}
func (pe *PrefixExpression) String() string { func (pe *PrefixExpression) String() string {
// Add space for word operators
if pe.Operator == "not" { if pe.Operator == "not" {
return fmt.Sprintf("(%s %s)", pe.Operator, pe.Right.String()) return fmt.Sprintf("(%s %s)", pe.Operator, pe.Right.String())
} }
return fmt.Sprintf("(%s%s)", pe.Operator, pe.Right.String()) return fmt.Sprintf("(%s%s)", pe.Operator, pe.Right.String())
} }
func (pe *PrefixExpression) TypeInfo() TypeInfo { return pe.typeInfo } func (pe *PrefixExpression) GetType() *TypeInfo { return pe.typeInfo }
func (pe *PrefixExpression) SetType(t *TypeInfo) { pe.typeInfo = t }
// InfixExpression represents binary operations between two expressions. // InfixExpression represents binary operations
// Stores result type based on operator and operand type compatibility.
type InfixExpression struct { type InfixExpression struct {
Left Expression // Left operand Left Expression
Right Expression // Right operand Operator string
Operator string // Operator symbol ("+", "-", "==", "and", etc.) Right Expression
typeInfo TypeInfo // Result type, embeds directly typeInfo *TypeInfo
} }
func (ie *InfixExpression) expressionNode() {} func (ie *InfixExpression) expressionNode() {}
func (ie *InfixExpression) String() string { func (ie *InfixExpression) String() string {
return fmt.Sprintf("(%s %s %s)", ie.Left.String(), ie.Operator, ie.Right.String()) return fmt.Sprintf("(%s %s %s)", ie.Left.String(), ie.Operator, ie.Right.String())
} }
func (ie *InfixExpression) TypeInfo() TypeInfo { return ie.typeInfo } func (ie *InfixExpression) GetType() *TypeInfo { return ie.typeInfo }
func (ie *InfixExpression) SetType(t *TypeInfo) { ie.typeInfo = t }
// IndexExpression represents bracket-based member access (table[key]). // IndexExpression represents table[key] access
// Stores inferred element type based on container type analysis.
type IndexExpression struct { type IndexExpression struct {
Left Expression // Container expression Left Expression
Index Expression // Index/key expression Index Expression
typeInfo TypeInfo // Element type, embeds directly typeInfo *TypeInfo
} }
func (ie *IndexExpression) expressionNode() {} func (ie *IndexExpression) expressionNode() {}
func (ie *IndexExpression) String() string { func (ie *IndexExpression) String() string {
return fmt.Sprintf("%s[%s]", ie.Left.String(), ie.Index.String()) return fmt.Sprintf("%s[%s]", ie.Left.String(), ie.Index.String())
} }
func (ie *IndexExpression) TypeInfo() TypeInfo { return ie.typeInfo } func (ie *IndexExpression) GetType() *TypeInfo { return ie.typeInfo }
func (ie *IndexExpression) SetType(t *TypeInfo) { ie.typeInfo = t }
// DotExpression represents dot-based member access (table.key). // DotExpression represents table.key access
// Stores inferred member type based on container type and field analysis.
type DotExpression struct { type DotExpression struct {
Left Expression // Container expression Left Expression
Key string // Member name Key string
typeInfo TypeInfo // Member type, embeds directly typeInfo *TypeInfo
} }
func (de *DotExpression) expressionNode() {} func (de *DotExpression) expressionNode() {}
func (de *DotExpression) String() string { func (de *DotExpression) String() string {
return fmt.Sprintf("%s.%s", de.Left.String(), de.Key) return fmt.Sprintf("%s.%s", de.Left.String(), de.Key)
} }
func (de *DotExpression) TypeInfo() TypeInfo { return de.typeInfo } func (de *DotExpression) GetType() *TypeInfo { return de.typeInfo }
func (de *DotExpression) SetType(t *TypeInfo) { de.typeInfo = t }
// TablePair represents key-value pairs in table literals and struct constructors. // TablePair represents a key-value pair in a table
// Key is nil for array-style elements, non-nil for object-style elements.
type TablePair struct { type TablePair struct {
Key Expression // Key expression (nil for array elements) Key Expression // nil for array-style elements
Value Expression // Value expression Value Expression
} }
func (tp *TablePair) String() string { func (tp *TablePair) String() string {
@ -518,10 +523,10 @@ func (tp *TablePair) String() string {
return fmt.Sprintf("%s = %s", tp.Key.String(), tp.Value.String()) return fmt.Sprintf("%s = %s", tp.Key.String(), tp.Value.String())
} }
// TableLiteral represents table/array/object literals with key-value pairs. // TableLiteral represents table literals {}
// Always has table type, provides methods to check if it's array-style.
type TableLiteral struct { type TableLiteral struct {
Pairs []TablePair // Key-value pairs (key nil for array elements) Pairs []TablePair
typeInfo *TypeInfo
} }
func (tl *TableLiteral) expressionNode() {} func (tl *TableLiteral) expressionNode() {}
@ -532,9 +537,10 @@ func (tl *TableLiteral) String() string {
} }
return fmt.Sprintf("{%s}", joinStrings(pairs, ", ")) return fmt.Sprintf("{%s}", joinStrings(pairs, ", "))
} }
func (tl *TableLiteral) TypeInfo() TypeInfo { return TableType } func (tl *TableLiteral) GetType() *TypeInfo { return tl.typeInfo }
func (tl *TableLiteral) SetType(t *TypeInfo) { tl.typeInfo = t }
// IsArray returns true if this table contains only array-style elements (no explicit keys) // IsArray returns true if this table contains only array-style elements
func (tl *TableLiteral) IsArray() bool { func (tl *TableLiteral) IsArray() bool {
for _, pair := range tl.Pairs { for _, pair := range tl.Pairs {
if pair.Key != nil { if pair.Key != nil {
@ -544,31 +550,7 @@ func (tl *TableLiteral) IsArray() bool {
return true return true
} }
// Helper function to convert TypeInfo to string representation // joinStrings joins string slice with separator
func typeToString(t TypeInfo) string {
switch t.Type {
case TypeNumber:
return "number"
case TypeString:
return "string"
case TypeBool:
return "bool"
case TypeNil:
return "nil"
case TypeTable:
return "table"
case TypeFunction:
return "function"
case TypeAny:
return "any"
case TypeStruct:
return fmt.Sprintf("struct<%d>", t.StructID)
default:
return "unknown"
}
}
// joinStrings efficiently joins string slice with separator
func joinStrings(strs []string, sep string) string { func joinStrings(strs []string, sep string) string {
if len(strs) == 0 { if len(strs) == 0 {
return "" return ""

View File

@ -19,7 +19,7 @@ func (pe ParseError) Error() string {
pe.Line, pe.Column, pe.Message, pe.Token.Literal) pe.Line, pe.Column, pe.Message, pe.Token.Literal)
} }
// Parser implements a recursive descent Pratt parser with optimized AST generation // Parser implements a recursive descent Pratt parser
type Parser struct { type Parser struct {
lexer *Lexer lexer *Lexer
@ -32,13 +32,11 @@ type Parser struct {
errors []ParseError errors []ParseError
// Scope tracking // Scope tracking
scopes []map[string]bool scopes []map[string]bool // stack of scopes, each tracking declared variables
scopeTypes []string scopeTypes []string // track what type each scope is: "global", "function", "loop"
// Struct tracking with ID mapping // Struct tracking
structs map[string]*StructStatement structs map[string]*StructStatement // track defined structs
structIDs map[uint16]*StructStatement
nextID uint16
} }
// NewParser creates a new parser instance // NewParser creates a new parser instance
@ -46,11 +44,9 @@ func NewParser(lexer *Lexer) *Parser {
p := &Parser{ p := &Parser{
lexer: lexer, lexer: lexer,
errors: []ParseError{}, errors: []ParseError{},
scopes: []map[string]bool{make(map[string]bool)}, scopes: []map[string]bool{make(map[string]bool)}, // start with global scope
scopeTypes: []string{"global"}, scopeTypes: []string{"global"}, // start with global scope type
structs: make(map[string]*StructStatement), structs: make(map[string]*StructStatement), // track struct definitions
structIDs: make(map[uint16]*StructStatement),
nextID: 1, // 0 reserved for non-struct types
} }
p.prefixParseFns = make(map[TokenType]func() Expression) p.prefixParseFns = make(map[TokenType]func() Expression)
@ -82,31 +78,15 @@ func NewParser(lexer *Lexer) *Parser {
p.registerInfix(DOT, p.parseDotExpression) p.registerInfix(DOT, p.parseDotExpression)
p.registerInfix(LBRACKET, p.parseIndexExpression) p.registerInfix(LBRACKET, p.parseIndexExpression)
p.registerInfix(LPAREN, p.parseCallExpression) p.registerInfix(LPAREN, p.parseCallExpression)
p.registerInfix(LBRACE, p.parseStructConstructor) p.registerInfix(LBRACE, p.parseStructConstructor) // struct constructor
// Read two tokens, so curToken and peekToken are both set
p.nextToken() p.nextToken()
p.nextToken() p.nextToken()
return p return p
} }
// Struct management
func (p *Parser) registerStruct(stmt *StructStatement) {
stmt.ID = p.nextID
p.nextID++
p.structs[stmt.Name] = stmt
p.structIDs[stmt.ID] = stmt
}
func (p *Parser) getStructByName(name string) *StructStatement {
return p.structs[name]
}
func (p *Parser) isStructDefined(name string) bool {
_, exists := p.structs[name]
return exists
}
// Scope management // Scope management
func (p *Parser) enterScope(scopeType string) { func (p *Parser) enterScope(scopeType string) {
p.scopes = append(p.scopes, make(map[string]bool)) p.scopes = append(p.scopes, make(map[string]bool))
@ -120,6 +100,30 @@ func (p *Parser) exitScope() {
} }
} }
func (p *Parser) enterFunctionScope() {
p.enterScope("function")
}
func (p *Parser) exitFunctionScope() {
p.exitScope()
}
func (p *Parser) enterLoopScope() {
p.enterScope("loop")
}
func (p *Parser) exitLoopScope() {
p.exitScope()
}
func (p *Parser) enterBlockScope() {
// Blocks don't create new variable scopes
}
func (p *Parser) exitBlockScope() {
// No-op
}
func (p *Parser) currentVariableScope() map[string]bool { func (p *Parser) currentVariableScope() map[string]bool {
if len(p.scopeTypes) > 1 && p.scopeTypes[len(p.scopeTypes)-1] == "loop" { if len(p.scopeTypes) > 1 && p.scopeTypes[len(p.scopeTypes)-1] == "loop" {
return p.scopes[len(p.scopes)-2] return p.scopes[len(p.scopes)-2]
@ -144,56 +148,45 @@ func (p *Parser) declareLoopVariable(name string) {
p.scopes[len(p.scopes)-1][name] = true p.scopes[len(p.scopes)-1][name] = true
} }
// parseTypeHint parses optional type hint after colon, returns by value // parseTypeHint parses optional type hint after colon
func (p *Parser) parseTypeHint() TypeInfo { func (p *Parser) parseTypeHint() *TypeInfo {
if !p.peekTokenIs(COLON) { if !p.peekTokenIs(COLON) {
return UnknownType return nil
} }
p.nextToken() // consume ':' p.nextToken() // consume ':'
if !p.expectPeekIdent() { if !p.expectPeekIdent() {
p.addError("expected type name after ':'") p.addError("expected type name after ':'")
return UnknownType return nil
} }
typeName := p.curToken.Literal typeName := p.curToken.Literal
if !ValidTypeName(typeName) && !p.isStructDefined(typeName) {
// Check built-in types
switch typeName {
case "number":
return TypeInfo{Type: TypeNumber, Inferred: false}
case "string":
return TypeInfo{Type: TypeString, Inferred: false}
case "bool":
return TypeInfo{Type: TypeBool, Inferred: false}
case "nil":
return TypeInfo{Type: TypeNil, Inferred: false}
case "table":
return TypeInfo{Type: TypeTable, Inferred: false}
case "function":
return TypeInfo{Type: TypeFunction, Inferred: false}
case "any":
return TypeInfo{Type: TypeAny, Inferred: false}
default:
// Check if it's a struct type
if structDef := p.getStructByName(typeName); structDef != nil {
return TypeInfo{Type: TypeStruct, StructID: structDef.ID, Inferred: false}
}
p.addError(fmt.Sprintf("invalid type name '%s'", typeName)) p.addError(fmt.Sprintf("invalid type name '%s'", typeName))
return UnknownType return nil
}
} }
// registerPrefix/registerInfix return &TypeInfo{Type: typeName, Inferred: false}
}
// isStructDefined checks if a struct name is defined
func (p *Parser) isStructDefined(name string) bool {
_, exists := p.structs[name]
return exists
}
// registerPrefix registers a prefix parse function
func (p *Parser) registerPrefix(tokenType TokenType, fn func() Expression) { func (p *Parser) registerPrefix(tokenType TokenType, fn func() Expression) {
p.prefixParseFns[tokenType] = fn p.prefixParseFns[tokenType] = fn
} }
// registerInfix registers an infix parse function
func (p *Parser) registerInfix(tokenType TokenType, fn func(Expression) Expression) { func (p *Parser) registerInfix(tokenType TokenType, fn func(Expression) Expression) {
p.infixParseFns[tokenType] = fn p.infixParseFns[tokenType] = fn
} }
// nextToken advances to the next token
func (p *Parser) nextToken() { func (p *Parser) nextToken() {
p.curToken = p.peekToken p.curToken = p.peekToken
p.peekToken = p.lexer.NextToken() p.peekToken = p.lexer.NextToken()
@ -272,7 +265,7 @@ func (p *Parser) parseStructStatement() *StructStatement {
if p.peekTokenIs(RBRACE) { if p.peekTokenIs(RBRACE) {
p.nextToken() p.nextToken()
p.registerStruct(stmt) p.structs[stmt.Name] = stmt
return stmt return stmt
} }
@ -291,9 +284,9 @@ func (p *Parser) parseStructStatement() *StructStatement {
field := StructField{Name: p.curToken.Literal} field := StructField{Name: p.curToken.Literal}
// Parse required type hint // Parse optional type hint
field.TypeHint = p.parseTypeHint() field.TypeHint = p.parseTypeHint()
if field.TypeHint.Type == TypeUnknown { if field.TypeHint == nil {
p.addError("struct fields require type annotation") p.addError("struct fields require type annotation")
return nil return nil
} }
@ -321,7 +314,7 @@ func (p *Parser) parseStructStatement() *StructStatement {
return nil return nil
} }
p.registerStruct(stmt) p.structs[stmt.Name] = stmt
return stmt return stmt
} }
@ -345,19 +338,12 @@ func (p *Parser) parseFunctionStatement() Statement {
methodName := p.curToken.Literal methodName := p.curToken.Literal
// Get struct ID
structDef := p.getStructByName(funcName)
if structDef == nil {
p.addError(fmt.Sprintf("method defined on undefined struct '%s'", funcName))
return nil
}
if !p.expectPeek(LPAREN) { if !p.expectPeek(LPAREN) {
p.addError("expected '(' after method name") p.addError("expected '(' after method name")
return nil return nil
} }
// Parse the function literal // Parse the function literal starting from parameters
funcLit := &FunctionLiteral{} funcLit := &FunctionLiteral{}
funcLit.Parameters, funcLit.Variadic = p.parseFunctionParameters() funcLit.Parameters, funcLit.Variadic = p.parseFunctionParameters()
@ -371,12 +357,12 @@ func (p *Parser) parseFunctionStatement() Statement {
p.nextToken() p.nextToken()
p.enterScope("function") p.enterFunctionScope()
for _, param := range funcLit.Parameters { for _, param := range funcLit.Parameters {
p.declareVariable(param.Name) p.declareVariable(param.Name)
} }
funcLit.Body = p.parseBlockStatements(END) funcLit.Body = p.parseBlockStatements(END)
p.exitScope() p.exitFunctionScope()
if !p.curTokenIs(END) { if !p.curTokenIs(END) {
p.addError("expected 'end' to close function") p.addError("expected 'end' to close function")
@ -384,13 +370,14 @@ func (p *Parser) parseFunctionStatement() Statement {
} }
return &MethodDefinition{ return &MethodDefinition{
StructID: structDef.ID, StructName: funcName,
MethodName: methodName, MethodName: methodName,
Function: funcLit, Function: funcLit,
} }
} }
// Regular function - handle as function literal expression statement // Regular function - this should be handled as expression statement
// Reset to handle as function literal
funcLit := p.parseFunctionLiteral() funcLit := p.parseFunctionLiteral()
if funcLit == nil { if funcLit == nil {
return nil return nil
@ -399,7 +386,7 @@ func (p *Parser) parseFunctionStatement() Statement {
return &ExpressionStatement{Expression: funcLit} return &ExpressionStatement{Expression: funcLit}
} }
// parseIdentifierStatement handles assignments and expression statements // parseIdentifierStatement handles both assignments and expression statements starting with identifiers
func (p *Parser) parseIdentifierStatement() Statement { func (p *Parser) parseIdentifierStatement() Statement {
// Parse the left-hand side expression first // Parse the left-hand side expression first
expr := p.ParseExpression(LOWEST) expr := p.ParseExpression(LOWEST)
@ -408,28 +395,28 @@ func (p *Parser) parseIdentifierStatement() Statement {
} }
// Check for type hint (only valid on simple identifiers) // Check for type hint (only valid on simple identifiers)
var typeHint TypeInfo = UnknownType var typeHint *TypeInfo
if _, ok := expr.(*Identifier); ok { if _, ok := expr.(*Identifier); ok {
typeHint = p.parseTypeHint() typeHint = p.parseTypeHint()
} }
// Check if this is an assignment // Check if this is an assignment
if p.peekTokenIs(ASSIGN) { if p.peekTokenIs(ASSIGN) {
// Create unified assignment // Convert to assignment statement
assignment := &Assignment{ stmt := &AssignStatement{
Target: expr, Name: expr,
TypeHint: typeHint, TypeHint: typeHint,
} }
// Validate assignment target and check if it's a declaration // Validate assignment target and check if it's a declaration
switch target := expr.(type) { switch name := expr.(type) {
case *Identifier: case *Identifier:
assignment.IsDeclaration = !p.isVariableDeclared(target.Value) stmt.IsDeclaration = !p.isVariableDeclared(name.Value)
if assignment.IsDeclaration { if stmt.IsDeclaration {
p.declareVariable(target.Value) p.declareVariable(name.Value)
} }
case *DotExpression, *IndexExpression: case *DotExpression, *IndexExpression:
assignment.IsDeclaration = false stmt.IsDeclaration = false
default: default:
p.addError("invalid assignment target") p.addError("invalid assignment target")
return nil return nil
@ -441,19 +428,29 @@ func (p *Parser) parseIdentifierStatement() Statement {
p.nextToken() p.nextToken()
assignment.Value = p.ParseExpression(LOWEST) stmt.Value = p.ParseExpression(LOWEST)
if assignment.Value == nil { if stmt.Value == nil {
p.addError("expected expression after assignment operator") p.addError("expected expression after assignment operator")
return nil return nil
} }
return assignment return stmt
} else { } else {
// This is an expression statement // This is an expression statement
return &ExpressionStatement{Expression: expr} return &ExpressionStatement{Expression: expr}
} }
} }
// parseExpressionStatement parses expressions used as statements
func (p *Parser) parseExpressionStatement() *ExpressionStatement {
stmt := &ExpressionStatement{}
stmt.Expression = p.ParseExpression(LOWEST)
if stmt.Expression == nil {
return nil
}
return stmt
}
// parseEchoStatement parses echo statements // parseEchoStatement parses echo statements
func (p *Parser) parseEchoStatement() *EchoStatement { func (p *Parser) parseEchoStatement() *EchoStatement {
stmt := &EchoStatement{} stmt := &EchoStatement{}
@ -469,8 +466,9 @@ func (p *Parser) parseEchoStatement() *EchoStatement {
return stmt return stmt
} }
// Simple statement parsers // parseBreakStatement parses break statements
func (p *Parser) parseBreakStatement() *BreakStatement { func (p *Parser) parseBreakStatement() *BreakStatement {
// Check if break is followed by an identifier (invalid)
if p.peekTokenIs(IDENT) { if p.peekTokenIs(IDENT) {
p.addError("unexpected identifier") p.addError("unexpected identifier")
return nil return nil
@ -478,6 +476,7 @@ func (p *Parser) parseBreakStatement() *BreakStatement {
return &BreakStatement{} return &BreakStatement{}
} }
// parseExitStatement parses exit statements
func (p *Parser) parseExitStatement() *ExitStatement { func (p *Parser) parseExitStatement() *ExitStatement {
stmt := &ExitStatement{} stmt := &ExitStatement{}
@ -493,6 +492,7 @@ func (p *Parser) parseExitStatement() *ExitStatement {
return stmt return stmt
} }
// parseReturnStatement parses return statements
func (p *Parser) parseReturnStatement() *ReturnStatement { func (p *Parser) parseReturnStatement() *ReturnStatement {
stmt := &ReturnStatement{} stmt := &ReturnStatement{}
@ -508,6 +508,7 @@ func (p *Parser) parseReturnStatement() *ReturnStatement {
return stmt return stmt
} }
// canStartExpression checks if a token type can start an expression
func (p *Parser) canStartExpression(tokenType TokenType) bool { func (p *Parser) canStartExpression(tokenType TokenType) bool {
switch tokenType { switch tokenType {
case IDENT, NUMBER, STRING, TRUE, FALSE, NIL, LPAREN, LBRACE, MINUS, NOT, FN: case IDENT, NUMBER, STRING, TRUE, FALSE, NIL, LPAREN, LBRACE, MINUS, NOT, FN:
@ -517,7 +518,7 @@ func (p *Parser) canStartExpression(tokenType TokenType) bool {
} }
} }
// Loop statement parsers // parseWhileStatement parses while loops
func (p *Parser) parseWhileStatement() *WhileStatement { func (p *Parser) parseWhileStatement() *WhileStatement {
stmt := &WhileStatement{} stmt := &WhileStatement{}
@ -536,7 +537,9 @@ func (p *Parser) parseWhileStatement() *WhileStatement {
p.nextToken() p.nextToken()
p.enterBlockScope()
stmt.Body = p.parseBlockStatements(END) stmt.Body = p.parseBlockStatements(END)
p.exitBlockScope()
if !p.curTokenIs(END) { if !p.curTokenIs(END) {
p.addError("expected 'end' to close while loop") p.addError("expected 'end' to close while loop")
@ -546,6 +549,7 @@ func (p *Parser) parseWhileStatement() *WhileStatement {
return stmt return stmt
} }
// parseForStatement parses for loops
func (p *Parser) parseForStatement() Statement { func (p *Parser) parseForStatement() Statement {
p.nextToken() p.nextToken()
@ -566,6 +570,7 @@ func (p *Parser) parseForStatement() Statement {
} }
} }
// parseNumericForStatement parses numeric for loops
func (p *Parser) parseNumericForStatement(variable *Identifier) *ForStatement { func (p *Parser) parseNumericForStatement(variable *Identifier) *ForStatement {
stmt := &ForStatement{Variable: variable} stmt := &ForStatement{Variable: variable}
@ -612,10 +617,10 @@ func (p *Parser) parseNumericForStatement(variable *Identifier) *ForStatement {
p.nextToken() p.nextToken()
p.enterScope("loop") p.enterLoopScope()
p.declareLoopVariable(variable.Value) p.declareLoopVariable(variable.Value)
stmt.Body = p.parseBlockStatements(END) stmt.Body = p.parseBlockStatements(END)
p.exitScope() p.exitLoopScope()
if !p.curTokenIs(END) { if !p.curTokenIs(END) {
p.addError("expected 'end' to close for loop") p.addError("expected 'end' to close for loop")
@ -625,6 +630,7 @@ func (p *Parser) parseNumericForStatement(variable *Identifier) *ForStatement {
return stmt return stmt
} }
// parseForInStatement parses for-in loops
func (p *Parser) parseForInStatement(firstVar *Identifier) *ForInStatement { func (p *Parser) parseForInStatement(firstVar *Identifier) *ForInStatement {
stmt := &ForInStatement{} stmt := &ForInStatement{}
@ -663,13 +669,13 @@ func (p *Parser) parseForInStatement(firstVar *Identifier) *ForInStatement {
p.nextToken() p.nextToken()
p.enterScope("loop") p.enterLoopScope()
if stmt.Key != nil { if stmt.Key != nil {
p.declareLoopVariable(stmt.Key.Value) p.declareLoopVariable(stmt.Key.Value)
} }
p.declareLoopVariable(stmt.Value.Value) p.declareLoopVariable(stmt.Value.Value)
stmt.Body = p.parseBlockStatements(END) stmt.Body = p.parseBlockStatements(END)
p.exitScope() p.exitLoopScope()
if !p.curTokenIs(END) { if !p.curTokenIs(END) {
p.addError("expected 'end' to close for loop") p.addError("expected 'end' to close for loop")
@ -702,7 +708,9 @@ func (p *Parser) parseIfStatement() *IfStatement {
return nil return nil
} }
p.enterBlockScope()
stmt.Body = p.parseBlockStatements(ELSEIF, ELSE, END) stmt.Body = p.parseBlockStatements(ELSEIF, ELSE, END)
p.exitBlockScope()
for p.curTokenIs(ELSEIF) { for p.curTokenIs(ELSEIF) {
elseif := ElseIfClause{} elseif := ElseIfClause{}
@ -721,13 +729,19 @@ func (p *Parser) parseIfStatement() *IfStatement {
p.nextToken() p.nextToken()
p.enterBlockScope()
elseif.Body = p.parseBlockStatements(ELSEIF, ELSE, END) elseif.Body = p.parseBlockStatements(ELSEIF, ELSE, END)
p.exitBlockScope()
stmt.ElseIfs = append(stmt.ElseIfs, elseif) stmt.ElseIfs = append(stmt.ElseIfs, elseif)
} }
if p.curTokenIs(ELSE) { if p.curTokenIs(ELSE) {
p.nextToken() p.nextToken()
p.enterBlockScope()
stmt.Else = p.parseBlockStatements(END) stmt.Else = p.parseBlockStatements(END)
p.exitBlockScope()
} }
if !p.curTokenIs(END) { if !p.curTokenIs(END) {
@ -740,7 +754,7 @@ func (p *Parser) parseIfStatement() *IfStatement {
// parseBlockStatements parses statements until terminators // parseBlockStatements parses statements until terminators
func (p *Parser) parseBlockStatements(terminators ...TokenType) []Statement { func (p *Parser) parseBlockStatements(terminators ...TokenType) []Statement {
statements := make([]Statement, 0, 8) // Pre-allocate for performance statements := []Statement{}
for !p.curTokenIs(EOF) && !p.isTerminator(terminators...) { for !p.curTokenIs(EOF) && !p.isTerminator(terminators...) {
stmt := p.parseStatement() stmt := p.parseStatement()
@ -753,6 +767,7 @@ func (p *Parser) parseBlockStatements(terminators ...TokenType) []Statement {
return statements return statements
} }
// isTerminator checks if current token is a terminator
func (p *Parser) isTerminator(terminators ...TokenType) bool { func (p *Parser) isTerminator(terminators ...TokenType) bool {
for _, terminator := range terminators { for _, terminator := range terminators {
if p.curTokenIs(terminator) { if p.curTokenIs(terminator) {
@ -904,6 +919,7 @@ func (p *Parser) parseGroupedExpression() Expression {
// parseParenthesizedAssignment parses assignment expressions in parentheses // parseParenthesizedAssignment parses assignment expressions in parentheses
func (p *Parser) parseParenthesizedAssignment() Expression { func (p *Parser) parseParenthesizedAssignment() Expression {
// We're at identifier, peek is ASSIGN
target := p.parseIdentifier() target := p.parseIdentifier()
if !p.expectPeek(ASSIGN) { if !p.expectPeek(ASSIGN) {
@ -923,10 +939,9 @@ func (p *Parser) parseParenthesizedAssignment() Expression {
} }
// Create assignment expression // Create assignment expression
assignExpr := &Assignment{ assignExpr := &AssignExpression{
Target: target, Name: target,
Value: value, Value: value,
IsExpression: true,
} }
// Handle variable declaration for assignment expressions // Handle variable declaration for assignment expressions
@ -937,6 +952,8 @@ func (p *Parser) parseParenthesizedAssignment() Expression {
} }
} }
// Assignment expression evaluates to the assigned value
assignExpr.SetType(value.GetType())
return assignExpr return assignExpr
} }
@ -960,12 +977,12 @@ func (p *Parser) parseFunctionLiteral() Expression {
p.nextToken() p.nextToken()
p.enterScope("function") p.enterFunctionScope()
for _, param := range fn.Parameters { for _, param := range fn.Parameters {
p.declareVariable(param.Name) p.declareVariable(param.Name)
} }
fn.Body = p.parseBlockStatements(END) fn.Body = p.parseBlockStatements(END)
p.exitScope() p.exitFunctionScope()
if !p.curTokenIs(END) { if !p.curTokenIs(END) {
p.addError("expected 'end' to close function") p.addError("expected 'end' to close function")
@ -1021,7 +1038,7 @@ func (p *Parser) parseFunctionParameters() ([]FunctionParameter, bool) {
func (p *Parser) parseTableLiteral() Expression { func (p *Parser) parseTableLiteral() Expression {
table := &TableLiteral{} table := &TableLiteral{}
table.Pairs = make([]TablePair, 0, 4) // Pre-allocate table.Pairs = []TablePair{}
if p.peekTokenIs(RBRACE) { if p.peekTokenIs(RBRACE) {
p.nextToken() p.nextToken()
@ -1087,24 +1104,22 @@ func (p *Parser) parseTableLiteral() Expression {
return table return table
} }
// parseStructConstructor handles struct constructor calls // parseStructConstructor handles struct constructor calls like my_type{...}
func (p *Parser) parseStructConstructor(left Expression) Expression { func (p *Parser) parseStructConstructor(left Expression) Expression {
// left should be an identifier representing the struct name
ident, ok := left.(*Identifier) ident, ok := left.(*Identifier)
if !ok { if !ok {
// Not an identifier, fall back to table literal parsing
return p.parseTableLiteralFromBrace() return p.parseTableLiteralFromBrace()
} }
structName := ident.Value structName := ident.Value
structDef := p.getStructByName(structName)
if structDef == nil {
// Not a struct, parse as table literal
return p.parseTableLiteralFromBrace()
}
constructor := &StructConstructor{ // Always try to parse as struct constructor if we have an identifier
StructID: structDef.ID, // Type checking will catch undefined structs later
Fields: make([]TablePair, 0, 4), constructor := &StructConstructorExpression{
typeInfo: TypeInfo{Type: TypeStruct, StructID: structDef.ID, Inferred: true}, StructName: structName,
Fields: []TablePair{},
} }
if p.peekTokenIs(RBRACE) { if p.peekTokenIs(RBRACE) {
@ -1172,8 +1187,9 @@ func (p *Parser) parseStructConstructor(left Expression) Expression {
} }
func (p *Parser) parseTableLiteralFromBrace() Expression { func (p *Parser) parseTableLiteralFromBrace() Expression {
// We're already at the opening brace, so parse as table literal
table := &TableLiteral{} table := &TableLiteral{}
table.Pairs = make([]TablePair, 0, 4) table.Pairs = []TablePair{}
if p.peekTokenIs(RBRACE) { if p.peekTokenIs(RBRACE) {
p.nextToken() p.nextToken()
@ -1412,9 +1428,15 @@ func (p *Parser) curPrecedence() Precedence {
return LOWEST return LOWEST
} }
// Error reporting // Errors returns all parsing errors
func (p *Parser) Errors() []ParseError { return p.errors } func (p *Parser) Errors() []ParseError {
func (p *Parser) HasErrors() bool { return len(p.errors) > 0 } return p.errors
}
func (p *Parser) HasErrors() bool {
return len(p.errors) > 0
}
func (p *Parser) ErrorStrings() []string { func (p *Parser) ErrorStrings() []string {
result := make([]string, len(p.errors)) result := make([]string, len(p.errors))
for i, err := range p.errors { for i, err := range p.errors {

View File

@ -31,15 +31,15 @@ func TestAssignStatements(t *testing.T) {
t.Fatalf("expected 1 statement, got %d", len(program.Statements)) t.Fatalf("expected 1 statement, got %d", len(program.Statements))
} }
stmt, ok := program.Statements[0].(*parser.Assignment) stmt, ok := program.Statements[0].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("expected Assignment, got %T", program.Statements[0]) t.Fatalf("expected AssignStatement, got %T", program.Statements[0])
} }
// Check that Target is an Identifier // Check that Name is an Identifier
ident, ok := stmt.Target.(*parser.Identifier) ident, ok := stmt.Name.(*parser.Identifier)
if !ok { if !ok {
t.Fatalf("expected Identifier for Target, got %T", stmt.Target) t.Fatalf("expected Identifier for Name, got %T", stmt.Name)
} }
if ident.Value != tt.expectedIdentifier { if ident.Value != tt.expectedIdentifier {
@ -90,9 +90,9 @@ func TestMemberAccessAssignment(t *testing.T) {
t.Fatalf("expected 1 statement, got %d", len(program.Statements)) t.Fatalf("expected 1 statement, got %d", len(program.Statements))
} }
stmt, ok := program.Statements[0].(*parser.Assignment) stmt, ok := program.Statements[0].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("expected Assignment, got %T", program.Statements[0]) t.Fatalf("expected AssignStatement, got %T", program.Statements[0])
} }
if stmt.String() != tt.expected { if stmt.String() != tt.expected {
@ -158,15 +158,15 @@ func TestTableAssignments(t *testing.T) {
t.Fatalf("expected 1 statement, got %d", len(program.Statements)) t.Fatalf("expected 1 statement, got %d", len(program.Statements))
} }
stmt, ok := program.Statements[0].(*parser.Assignment) stmt, ok := program.Statements[0].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("expected Assignment, got %T", program.Statements[0]) t.Fatalf("expected AssignStatement, got %T", program.Statements[0])
} }
// Check that Target is an Identifier // Check that Name is an Identifier
ident, ok := stmt.Target.(*parser.Identifier) ident, ok := stmt.Name.(*parser.Identifier)
if !ok { if !ok {
t.Fatalf("expected Identifier for Target, got %T", stmt.Target) t.Fatalf("expected Identifier for Name, got %T", stmt.Name)
} }
if ident.Value != tt.identifier { if ident.Value != tt.identifier {

View File

@ -247,9 +247,9 @@ exit "success"`
} }
// First: assignment // First: assignment
_, ok := program.Statements[0].(*parser.Assignment) _, ok := program.Statements[0].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("statement 0: expected Assignment, got %T", program.Statements[0]) t.Fatalf("statement 0: expected AssignStatement, got %T", program.Statements[0])
} }
// Second: if statement with exit in body // Second: if statement with exit in body
@ -264,7 +264,7 @@ exit "success"`
} }
// Third: assignment // Third: assignment
_, ok = program.Statements[2].(*parser.Assignment) _, ok = program.Statements[2].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("statement 2: expected AssignStatement, got %T", program.Statements[2]) t.Fatalf("statement 2: expected AssignStatement, got %T", program.Statements[2])
} }

View File

@ -32,15 +32,15 @@ end`
t.Fatalf("expected 1 body statement, got %d", len(stmt.Body)) t.Fatalf("expected 1 body statement, got %d", len(stmt.Body))
} }
bodyStmt, ok := stmt.Body[0].(*parser.Assignment) bodyStmt, ok := stmt.Body[0].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("expected AssignStatement in body, got %T", stmt.Body[0]) t.Fatalf("expected AssignStatement in body, got %T", stmt.Body[0])
} }
// Check that Target is an Identifier // Check that Name is an Identifier
ident, ok := bodyStmt.Target.(*parser.Identifier) ident, ok := bodyStmt.Name.(*parser.Identifier)
if !ok { if !ok {
t.Fatalf("expected Identifier for Target, got %T", bodyStmt.Target) t.Fatalf("expected Identifier for Name, got %T", bodyStmt.Name)
} }
if ident.Value != "x" { if ident.Value != "x" {
@ -79,15 +79,15 @@ end`
t.Fatalf("expected 1 else statement, got %d", len(stmt.Else)) t.Fatalf("expected 1 else statement, got %d", len(stmt.Else))
} }
elseStmt, ok := stmt.Else[0].(*parser.Assignment) elseStmt, ok := stmt.Else[0].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("expected AssignStatement in else, got %T", stmt.Else[0]) t.Fatalf("expected AssignStatement in else, got %T", stmt.Else[0])
} }
// Check that Name is an Identifier // Check that Name is an Identifier
ident, ok := elseStmt.Target.(*parser.Identifier) ident, ok := elseStmt.Name.(*parser.Identifier)
if !ok { if !ok {
t.Fatalf("expected Identifier for Name, got %T", elseStmt.Target) t.Fatalf("expected Identifier for Name, got %T", elseStmt.Name)
} }
if ident.Value != "x" { if ident.Value != "x" {
@ -169,25 +169,25 @@ end`
} }
// First assignment: arr[1] = "updated" // First assignment: arr[1] = "updated"
assign1, ok := stmt.Body[0].(*parser.Assignment) assign1, ok := stmt.Body[0].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("expected AssignStatement, got %T", stmt.Body[0]) t.Fatalf("expected AssignStatement, got %T", stmt.Body[0])
} }
_, ok = assign1.Target.(*parser.IndexExpression) _, ok = assign1.Name.(*parser.IndexExpression)
if !ok { if !ok {
t.Fatalf("expected IndexExpression for assignment target, got %T", assign1.Target) t.Fatalf("expected IndexExpression for assignment target, got %T", assign1.Name)
} }
// Second assignment: obj.nested.count = obj.nested.count + 1 // Second assignment: obj.nested.count = obj.nested.count + 1
assign2, ok := stmt.Body[1].(*parser.Assignment) assign2, ok := stmt.Body[1].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("expected AssignStatement, got %T", stmt.Body[1]) t.Fatalf("expected AssignStatement, got %T", stmt.Body[1])
} }
_, ok = assign2.Target.(*parser.DotExpression) _, ok = assign2.Name.(*parser.DotExpression)
if !ok { if !ok {
t.Fatalf("expected DotExpression for assignment target, got %T", assign2.Target) t.Fatalf("expected DotExpression for assignment target, got %T", assign2.Name)
} }
} }
@ -214,7 +214,7 @@ end`
} }
// Test body has expression assignment // Test body has expression assignment
bodyStmt := stmt.Body[0].(*parser.Assignment) bodyStmt := stmt.Body[0].(*parser.AssignStatement)
bodyInfix, ok := bodyStmt.Value.(*parser.InfixExpression) bodyInfix, ok := bodyStmt.Value.(*parser.InfixExpression)
if !ok { if !ok {
t.Fatalf("expected InfixExpression value, got %T", bodyStmt.Value) t.Fatalf("expected InfixExpression value, got %T", bodyStmt.Value)

View File

@ -352,15 +352,15 @@ func TestAssignmentExpressions(t *testing.T) {
expr := p.ParseExpression(parser.LOWEST) expr := p.ParseExpression(parser.LOWEST)
checkParserErrors(t, p) checkParserErrors(t, p)
assignExpr, ok := expr.(*parser.Assignment) assignExpr, ok := expr.(*parser.AssignExpression)
if !ok { if !ok {
t.Fatalf("expected AssignExpression, got %T", expr) t.Fatalf("expected AssignExpression, got %T", expr)
} }
// Test target name // Test target name
ident, ok := assignExpr.Target.(*parser.Identifier) ident, ok := assignExpr.Name.(*parser.Identifier)
if !ok { if !ok {
t.Fatalf("expected Identifier for assignment target, got %T", assignExpr.Target) t.Fatalf("expected Identifier for assignment target, got %T", assignExpr.Name)
} }
if ident.Value != tt.targetName { if ident.Value != tt.targetName {
@ -413,12 +413,12 @@ func TestAssignmentExpressionWithComplexExpressions(t *testing.T) {
expr := p.ParseExpression(parser.LOWEST) expr := p.ParseExpression(parser.LOWEST)
checkParserErrors(t, p) checkParserErrors(t, p)
assignExpr, ok := expr.(*parser.Assignment) assignExpr, ok := expr.(*parser.AssignExpression)
if !ok { if !ok {
t.Fatalf("expected AssignExpression, got %T", expr) t.Fatalf("expected AssignExpression, got %T", expr)
} }
if assignExpr.Target == nil { if assignExpr.Name == nil {
t.Error("expected non-nil assignment target") t.Error("expected non-nil assignment target")
} }

View File

@ -160,7 +160,7 @@ func TestFunctionAssignments(t *testing.T) {
t.Fatalf("expected 1 statement, got %d", len(program.Statements)) t.Fatalf("expected 1 statement, got %d", len(program.Statements))
} }
stmt, ok := program.Statements[0].(*parser.Assignment) stmt, ok := program.Statements[0].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("expected AssignStatement, got %T", program.Statements[0]) t.Fatalf("expected AssignStatement, got %T", program.Statements[0])
} }
@ -296,7 +296,7 @@ end`
} }
// First statement: assignment of inner function // First statement: assignment of inner function
assign, ok := fn.Body[0].(*parser.Assignment) assign, ok := fn.Body[0].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("expected AssignStatement, got %T", fn.Body[0]) t.Fatalf("expected AssignStatement, got %T", fn.Body[0])
} }
@ -342,7 +342,7 @@ end`
} }
// First: function assignment // First: function assignment
assign, ok := forStmt.Body[0].(*parser.Assignment) assign, ok := forStmt.Body[0].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("expected AssignStatement, got %T", forStmt.Body[0]) t.Fatalf("expected AssignStatement, got %T", forStmt.Body[0])
} }
@ -551,7 +551,7 @@ echo adder`
} }
// First: table with functions // First: table with functions
mathAssign, ok := program.Statements[0].(*parser.Assignment) mathAssign, ok := program.Statements[0].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("statement 0: expected AssignStatement, got %T", program.Statements[0]) t.Fatalf("statement 0: expected AssignStatement, got %T", program.Statements[0])
} }
@ -574,7 +574,7 @@ echo adder`
} }
// Second: result assignment (function call would be handled by interpreter) // Second: result assignment (function call would be handled by interpreter)
_, ok = program.Statements[1].(*parser.Assignment) _, ok = program.Statements[1].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("statement 1: expected AssignStatement, got %T", program.Statements[1]) t.Fatalf("statement 1: expected AssignStatement, got %T", program.Statements[1])
} }
@ -586,7 +586,7 @@ echo adder`
} }
// Fourth: calculator function assignment // Fourth: calculator function assignment
calcAssign, ok := program.Statements[3].(*parser.Assignment) calcAssign, ok := program.Statements[3].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("statement 3: expected AssignStatement, got %T", program.Statements[3]) t.Fatalf("statement 3: expected AssignStatement, got %T", program.Statements[3])
} }
@ -601,7 +601,7 @@ echo adder`
} }
// Fifth: adder assignment // Fifth: adder assignment
_, ok = program.Statements[4].(*parser.Assignment) _, ok = program.Statements[4].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("statement 4: expected AssignStatement, got %T", program.Statements[4]) t.Fatalf("statement 4: expected AssignStatement, got %T", program.Statements[4])
} }
@ -645,7 +645,7 @@ end`
} }
// Check if body has function assignment // Check if body has function assignment
ifAssign, ok := ifStmt.Body[0].(*parser.Assignment) ifAssign, ok := ifStmt.Body[0].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("if body: expected AssignStatement, got %T", ifStmt.Body[0]) t.Fatalf("if body: expected AssignStatement, got %T", ifStmt.Body[0])
} }
@ -656,7 +656,7 @@ end`
} }
// Check else body has function assignment // Check else body has function assignment
elseAssign, ok := ifStmt.Else[0].(*parser.Assignment) elseAssign, ok := ifStmt.Else[0].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("else body: expected AssignStatement, got %T", ifStmt.Else[0]) t.Fatalf("else body: expected AssignStatement, got %T", ifStmt.Else[0])
} }
@ -683,7 +683,7 @@ end`
} }
// Verify both branches assign functions // Verify both branches assign functions
nestedIfAssign, ok := nestedIf.Body[0].(*parser.Assignment) nestedIfAssign, ok := nestedIf.Body[0].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("nested if body: expected AssignStatement, got %T", nestedIf.Body[0]) t.Fatalf("nested if body: expected AssignStatement, got %T", nestedIf.Body[0])
} }
@ -693,7 +693,7 @@ end`
t.Fatalf("nested if body: expected FunctionLiteral, got %T", nestedIfAssign.Value) t.Fatalf("nested if body: expected FunctionLiteral, got %T", nestedIfAssign.Value)
} }
nestedElseAssign, ok := nestedIf.Else[0].(*parser.Assignment) nestedElseAssign, ok := nestedIf.Else[0].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("nested else body: expected AssignStatement, got %T", nestedIf.Else[0]) t.Fatalf("nested else body: expected AssignStatement, got %T", nestedIf.Else[0])
} }

View File

@ -327,13 +327,13 @@ end`
} }
// First: table assignment // First: table assignment
_, ok := program.Statements[0].(*parser.Assignment) _, ok := program.Statements[0].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("statement 0: expected AssignStatement, got %T", program.Statements[0]) t.Fatalf("statement 0: expected AssignStatement, got %T", program.Statements[0])
} }
// Second: variable assignment // Second: variable assignment
_, ok = program.Statements[1].(*parser.Assignment) _, ok = program.Statements[1].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("statement 1: expected AssignStatement, got %T", program.Statements[1]) t.Fatalf("statement 1: expected AssignStatement, got %T", program.Statements[1])
} }
@ -590,7 +590,7 @@ end`
} }
// Second body statement should be assignment // Second body statement should be assignment
_, ok = outerWhile.Body[1].(*parser.Assignment) _, ok = outerWhile.Body[1].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("expected AssignStatement, got %T", outerWhile.Body[1]) t.Fatalf("expected AssignStatement, got %T", outerWhile.Body[1])
} }
@ -634,14 +634,14 @@ end`
} }
// First assignment: data[index] = ... // First assignment: data[index] = ...
assign1, ok := stmt.Body[0].(*parser.Assignment) assign1, ok := stmt.Body[0].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("expected AssignStatement, got %T", stmt.Body[0]) t.Fatalf("expected AssignStatement, got %T", stmt.Body[0])
} }
_, ok = assign1.Target.(*parser.IndexExpression) _, ok = assign1.Name.(*parser.IndexExpression)
if !ok { if !ok {
t.Fatalf("expected IndexExpression for assignment target, got %T", assign1.Target) t.Fatalf("expected IndexExpression for assignment target, got %T", assign1.Name)
} }
} }
@ -755,7 +755,7 @@ end`
// First three: assignments // First three: assignments
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
_, ok := program.Statements[i].(*parser.Assignment) _, ok := program.Statements[i].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("statement %d: expected AssignStatement, got %T", i, program.Statements[i]) t.Fatalf("statement %d: expected AssignStatement, got %T", i, program.Statements[i])
} }
@ -838,7 +838,7 @@ end`
} }
// Fourth: assignment // Fourth: assignment
_, ok = whileStmt.Body[3].(*parser.Assignment) _, ok = whileStmt.Body[3].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("body[3]: expected AssignStatement, got %T", whileStmt.Body[3]) t.Fatalf("body[3]: expected AssignStatement, got %T", whileStmt.Body[3])
} }

View File

@ -22,14 +22,14 @@ z = true + false`
expectedIdentifiers := []string{"x", "y", "z"} expectedIdentifiers := []string{"x", "y", "z"}
for i, expectedIdent := range expectedIdentifiers { for i, expectedIdent := range expectedIdentifiers {
stmt, ok := program.Statements[i].(*parser.Assignment) stmt, ok := program.Statements[i].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("statement %d: expected AssignStatement, got %T", i, program.Statements[i]) t.Fatalf("statement %d: expected AssignStatement, got %T", i, program.Statements[i])
} }
ident, ok := stmt.Target.(*parser.Identifier) ident, ok := stmt.Name.(*parser.Identifier)
if !ok { if !ok {
t.Fatalf("statement %d: expected Identifier for Name, got %T", i, stmt.Target) t.Fatalf("statement %d: expected Identifier for Name, got %T", i, stmt.Name)
} }
if ident.Value != expectedIdent { if ident.Value != expectedIdent {
@ -58,13 +58,13 @@ arr = {a = 1, b = 2}`
} }
// First statement: assignment // First statement: assignment
stmt1, ok := program.Statements[0].(*parser.Assignment) stmt1, ok := program.Statements[0].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("statement 0: expected AssignStatement, got %T", program.Statements[0]) t.Fatalf("statement 0: expected AssignStatement, got %T", program.Statements[0])
} }
ident1, ok := stmt1.Target.(*parser.Identifier) ident1, ok := stmt1.Name.(*parser.Identifier)
if !ok { if !ok {
t.Fatalf("expected Identifier for Name, got %T", stmt1.Target) t.Fatalf("expected Identifier for Name, got %T", stmt1.Name)
} }
if ident1.Value != "x" { if ident1.Value != "x" {
t.Errorf("expected identifier 'x', got %s", ident1.Value) t.Errorf("expected identifier 'x', got %s", ident1.Value)
@ -80,7 +80,7 @@ arr = {a = 1, b = 2}`
} }
// Third statement: table assignment // Third statement: table assignment
stmt3, ok := program.Statements[2].(*parser.Assignment) stmt3, ok := program.Statements[2].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("statement 2: expected AssignStatement, got %T", program.Statements[2]) t.Fatalf("statement 2: expected AssignStatement, got %T", program.Statements[2])
} }
@ -110,33 +110,33 @@ echo table[table.key]`
} }
// Second statement: dot assignment // Second statement: dot assignment
stmt2, ok := program.Statements[1].(*parser.Assignment) stmt2, ok := program.Statements[1].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("statement 1: expected AssignStatement, got %T", program.Statements[1]) t.Fatalf("statement 1: expected AssignStatement, got %T", program.Statements[1])
} }
_, ok = stmt2.Target.(*parser.DotExpression) _, ok = stmt2.Name.(*parser.DotExpression)
if !ok { if !ok {
t.Fatalf("expected DotExpression for assignment target, got %T", stmt2.Target) t.Fatalf("expected DotExpression for assignment target, got %T", stmt2.Name)
} }
// Third statement: bracket assignment // Third statement: bracket assignment
stmt3, ok := program.Statements[2].(*parser.Assignment) stmt3, ok := program.Statements[2].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("statement 2: expected AssignStatement, got %T", program.Statements[2]) t.Fatalf("statement 2: expected AssignStatement, got %T", program.Statements[2])
} }
_, ok = stmt3.Target.(*parser.IndexExpression) _, ok = stmt3.Name.(*parser.IndexExpression)
if !ok { if !ok {
t.Fatalf("expected IndexExpression for assignment target, got %T", stmt3.Target) t.Fatalf("expected IndexExpression for assignment target, got %T", stmt3.Name)
} }
// Fourth statement: chained dot assignment // Fourth statement: chained dot assignment
stmt4, ok := program.Statements[3].(*parser.Assignment) stmt4, ok := program.Statements[3].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("statement 3: expected AssignStatement, got %T", program.Statements[3]) t.Fatalf("statement 3: expected AssignStatement, got %T", program.Statements[3])
} }
_, ok = stmt4.Target.(*parser.DotExpression) _, ok = stmt4.Name.(*parser.DotExpression)
if !ok { if !ok {
t.Fatalf("expected DotExpression for assignment target, got %T", stmt4.Target) t.Fatalf("expected DotExpression for assignment target, got %T", stmt4.Name)
} }
// Fifth statement: echo with nested access // Fifth statement: echo with nested access
@ -232,7 +232,7 @@ end`
} }
// First statement: complex expression assignment // First statement: complex expression assignment
stmt1, ok := program.Statements[0].(*parser.Assignment) stmt1, ok := program.Statements[0].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("expected AssignStatement, got %T", program.Statements[0]) t.Fatalf("expected AssignStatement, got %T", program.Statements[0])
} }
@ -253,7 +253,7 @@ end`
t.Fatalf("expected 1 body statement, got %d", len(stmt2.Body)) t.Fatalf("expected 1 body statement, got %d", len(stmt2.Body))
} }
bodyStmt, ok := stmt2.Body[0].(*parser.Assignment) bodyStmt, ok := stmt2.Body[0].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("expected AssignStatement in body, got %T", stmt2.Body[0]) t.Fatalf("expected AssignStatement in body, got %T", stmt2.Body[0])
} }
@ -286,7 +286,7 @@ echo {result = x}`
} }
// First: assignment // First: assignment
_, ok := program.Statements[0].(*parser.Assignment) _, ok := program.Statements[0].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("statement 0: expected AssignStatement, got %T", program.Statements[0]) t.Fatalf("statement 0: expected AssignStatement, got %T", program.Statements[0])
} }

View File

@ -63,7 +63,7 @@ x = 15`,
assignmentCount := 0 assignmentCount := 0
for _, stmt := range program.Statements { for _, stmt := range program.Statements {
if assign, ok := stmt.(*parser.Assignment); ok { if assign, ok := stmt.(*parser.AssignStatement); ok {
if assignmentCount >= len(tt.assignments) { if assignmentCount >= len(tt.assignments) {
t.Fatalf("more assignments than expected") t.Fatalf("more assignments than expected")
} }
@ -71,9 +71,9 @@ x = 15`,
expected := tt.assignments[assignmentCount] expected := tt.assignments[assignmentCount]
// Check variable name // Check variable name
ident, ok := assign.Target.(*parser.Identifier) ident, ok := assign.Name.(*parser.Identifier)
if !ok { if !ok {
t.Fatalf("expected Identifier, got %T", assign.Target) t.Fatalf("expected Identifier, got %T", assign.Name)
} }
if ident.Value != expected.variable { if ident.Value != expected.variable {
@ -135,9 +135,9 @@ z = 30`
for i, expected := range expectedAssignments { for i, expected := range expectedAssignments {
assign := assignments[i] assign := assignments[i]
ident, ok := assign.Target.(*parser.Identifier) ident, ok := assign.Name.(*parser.Identifier)
if !ok { if !ok {
t.Fatalf("assignment %d: expected Identifier, got %T", i, assign.Target) t.Fatalf("assignment %d: expected Identifier, got %T", i, assign.Name)
} }
if ident.Value != expected.variable { if ident.Value != expected.variable {
@ -191,9 +191,9 @@ end`
for i, expected := range expectedAssignments { for i, expected := range expectedAssignments {
assign := assignments[i] assign := assignments[i]
ident, ok := assign.Target.(*parser.Identifier) ident, ok := assign.Name.(*parser.Identifier)
if !ok { if !ok {
t.Fatalf("assignment %d: expected Identifier, got %T", i, assign.Target) t.Fatalf("assignment %d: expected Identifier, got %T", i, assign.Name)
} }
if ident.Value != expected.variable { if ident.Value != expected.variable {
@ -243,9 +243,9 @@ c = 20`
for i, expected := range expectedAssignments { for i, expected := range expectedAssignments {
assign := assignments[i] assign := assignments[i]
ident, ok := assign.Target.(*parser.Identifier) ident, ok := assign.Name.(*parser.Identifier)
if !ok { if !ok {
t.Fatalf("assignment %d: expected Identifier, got %T", i, assign.Target) t.Fatalf("assignment %d: expected Identifier, got %T", i, assign.Name)
} }
if ident.Value != expected.variable { if ident.Value != expected.variable {
@ -344,9 +344,9 @@ count = 0`,
for i, expected := range tt.assignments { for i, expected := range tt.assignments {
assign := assignments[i] assign := assignments[i]
ident, ok := assign.Target.(*parser.Identifier) ident, ok := assign.Name.(*parser.Identifier)
if !ok { if !ok {
t.Fatalf("assignment %d: expected Identifier, got %T", i, assign.Target) t.Fatalf("assignment %d: expected Identifier, got %T", i, assign.Name)
} }
if ident.Value != expected.variable { if ident.Value != expected.variable {
@ -389,7 +389,7 @@ arr[1] = 99`
assignmentCount := 0 assignmentCount := 0
for _, stmt := range program.Statements { for _, stmt := range program.Statements {
if assign, ok := stmt.(*parser.Assignment); ok { if assign, ok := stmt.(*parser.AssignStatement); ok {
if assignmentCount >= len(expectedAssignments) { if assignmentCount >= len(expectedAssignments) {
t.Fatalf("more assignments than expected") t.Fatalf("more assignments than expected")
} }
@ -398,7 +398,7 @@ arr[1] = 99`
if expected.isMemberAccess { if expected.isMemberAccess {
// Should not be an identifier // Should not be an identifier
if _, ok := assign.Target.(*parser.Identifier); ok { if _, ok := assign.Name.(*parser.Identifier); ok {
t.Errorf("assignment %d: expected member access, got Identifier", assignmentCount) t.Errorf("assignment %d: expected member access, got Identifier", assignmentCount)
} }
@ -408,9 +408,9 @@ arr[1] = 99`
} }
} else { } else {
// Should be an identifier // Should be an identifier
ident, ok := assign.Target.(*parser.Identifier) ident, ok := assign.Name.(*parser.Identifier)
if !ok { if !ok {
t.Errorf("assignment %d: expected Identifier, got %T", assignmentCount, assign.Target) t.Errorf("assignment %d: expected Identifier, got %T", assignmentCount, assign.Name)
} else if ident.Value != expected.variable { } else if ident.Value != expected.variable {
t.Errorf("assignment %d: expected variable %s, got %s", t.Errorf("assignment %d: expected variable %s, got %s",
assignmentCount, expected.variable, ident.Value) assignmentCount, expected.variable, ident.Value)
@ -487,9 +487,9 @@ local_var = "global_local"`
for i, expected := range expectedAssignments { for i, expected := range expectedAssignments {
assign := assignments[i] assign := assignments[i]
ident, ok := assign.Target.(*parser.Identifier) ident, ok := assign.Name.(*parser.Identifier)
if !ok { if !ok {
t.Fatalf("assignment %d: expected Identifier, got %T", i, assign.Target) t.Fatalf("assignment %d: expected Identifier, got %T", i, assign.Name)
} }
if ident.Value != expected.variable { if ident.Value != expected.variable {
@ -536,8 +536,8 @@ y = 20.00`
} }
// Helper function to extract all assignments from a program recursively // Helper function to extract all assignments from a program recursively
func extractAssignments(program *parser.Program) []*parser.Assignment { func extractAssignments(program *parser.Program) []*parser.AssignStatement {
var assignments []*parser.Assignment var assignments []*parser.AssignStatement
for _, stmt := range program.Statements { for _, stmt := range program.Statements {
assignments = append(assignments, extractAssignmentsFromStatement(stmt)...) assignments = append(assignments, extractAssignmentsFromStatement(stmt)...)
@ -546,11 +546,11 @@ func extractAssignments(program *parser.Program) []*parser.Assignment {
return assignments return assignments
} }
func extractAssignmentsFromStatement(stmt parser.Statement) []*parser.Assignment { func extractAssignmentsFromStatement(stmt parser.Statement) []*parser.AssignStatement {
var assignments []*parser.Assignment var assignments []*parser.AssignStatement
switch s := stmt.(type) { switch s := stmt.(type) {
case *parser.Assignment: case *parser.AssignStatement:
assignments = append(assignments, s) assignments = append(assignments, s)
// Check if the value is a function literal with assignments in body // Check if the value is a function literal with assignments in body

View File

@ -38,22 +38,22 @@ func TestBasicStructDefinition(t *testing.T) {
if stmt.Fields[0].Name != "name" { if stmt.Fields[0].Name != "name" {
t.Errorf("expected field name 'name', got %s", stmt.Fields[0].Name) t.Errorf("expected field name 'name', got %s", stmt.Fields[0].Name)
} }
if stmt.Fields[0].TypeHint.Type == parser.TypeUnknown { if stmt.Fields[0].TypeHint == nil {
t.Fatal("expected type hint for name field") t.Fatal("expected type hint for name field")
} }
if stmt.Fields[0].TypeHint.Type != parser.TypeString { if stmt.Fields[0].TypeHint.Type != "string" {
t.Errorf("expected type string, got %v", stmt.Fields[0].TypeHint.Type) t.Errorf("expected type 'string', got %s", stmt.Fields[0].TypeHint.Type)
} }
// Test second field // Test second field
if stmt.Fields[1].Name != "age" { if stmt.Fields[1].Name != "age" {
t.Errorf("expected field name 'age', got %s", stmt.Fields[1].Name) t.Errorf("expected field name 'age', got %s", stmt.Fields[1].Name)
} }
if stmt.Fields[1].TypeHint.Type == parser.TypeUnknown { if stmt.Fields[1].TypeHint == nil {
t.Fatal("expected type hint for age field") t.Fatal("expected type hint for age field")
} }
if stmt.Fields[1].TypeHint.Type != parser.TypeNumber { if stmt.Fields[1].TypeHint.Type != "number" {
t.Errorf("expected type number, got %v", stmt.Fields[1].TypeHint.Type) t.Errorf("expected type 'number', got %s", stmt.Fields[1].TypeHint.Type)
} }
} }
@ -107,7 +107,7 @@ func TestComplexStructDefinition(t *testing.T) {
t.Fatalf("expected StructStatement, got %T", program.Statements[0]) t.Fatalf("expected StructStatement, got %T", program.Statements[0])
} }
expectedTypes := []parser.Type{parser.TypeNumber, parser.TypeString, parser.TypeBool, parser.TypeTable, parser.TypeFunction, parser.TypeAny} expectedTypes := []string{"number", "string", "bool", "table", "function", "any"}
expectedNames := []string{"id", "name", "active", "data", "callback", "optional"} expectedNames := []string{"id", "name", "active", "data", "callback", "optional"}
if len(stmt.Fields) != len(expectedTypes) { if len(stmt.Fields) != len(expectedTypes) {
@ -118,11 +118,11 @@ func TestComplexStructDefinition(t *testing.T) {
if field.Name != expectedNames[i] { if field.Name != expectedNames[i] {
t.Errorf("field %d: expected name '%s', got '%s'", i, expectedNames[i], field.Name) t.Errorf("field %d: expected name '%s', got '%s'", i, expectedNames[i], field.Name)
} }
if field.TypeHint.Type == parser.TypeUnknown { if field.TypeHint == nil {
t.Fatalf("field %d: expected type hint", i) t.Fatalf("field %d: expected type hint", i)
} }
if field.TypeHint.Type != expectedTypes[i] { if field.TypeHint.Type != expectedTypes[i] {
t.Errorf("field %d: expected type %v, got %v", i, expectedTypes[i], field.TypeHint.Type) t.Errorf("field %d: expected type '%s', got '%s'", i, expectedTypes[i], field.TypeHint.Type)
} }
} }
} }
@ -164,17 +164,17 @@ end`
if !ok { if !ok {
t.Fatalf("expected MethodDefinition, got %T", program.Statements[1]) t.Fatalf("expected MethodDefinition, got %T", program.Statements[1])
} }
if method1.StructID != structStmt.ID { if method1.StructName != "Person" {
t.Errorf("expected struct ID %d, got %d", structStmt.ID, method1.StructID) t.Errorf("expected struct name 'Person', got %s", method1.StructName)
} }
if method1.MethodName != "getName" { if method1.MethodName != "getName" {
t.Errorf("expected method name 'getName', got %s", method1.MethodName) t.Errorf("expected method name 'getName', got %s", method1.MethodName)
} }
if method1.Function.ReturnType.Type == parser.TypeUnknown { if method1.Function.ReturnType == nil {
t.Fatal("expected return type for getName method") t.Fatal("expected return type for getName method")
} }
if method1.Function.ReturnType.Type != parser.TypeString { if method1.Function.ReturnType.Type != "string" {
t.Errorf("expected return type string, got %v", method1.Function.ReturnType.Type) t.Errorf("expected return type 'string', got %s", method1.Function.ReturnType.Type)
} }
if len(method1.Function.Parameters) != 0 { if len(method1.Function.Parameters) != 0 {
t.Errorf("expected 0 parameters, got %d", len(method1.Function.Parameters)) t.Errorf("expected 0 parameters, got %d", len(method1.Function.Parameters))
@ -185,14 +185,14 @@ end`
if !ok { if !ok {
t.Fatalf("expected MethodDefinition, got %T", program.Statements[2]) t.Fatalf("expected MethodDefinition, got %T", program.Statements[2])
} }
if method2.StructID != structStmt.ID { if method2.StructName != "Person" {
t.Errorf("expected struct ID %d, got %d", structStmt.ID, method2.StructID) t.Errorf("expected struct name 'Person', got %s", method2.StructName)
} }
if method2.MethodName != "setAge" { if method2.MethodName != "setAge" {
t.Errorf("expected method name 'setAge', got %s", method2.MethodName) t.Errorf("expected method name 'setAge', got %s", method2.MethodName)
} }
if method2.Function.ReturnType.Type != parser.TypeUnknown { if method2.Function.ReturnType != nil {
t.Errorf("expected no return type for setAge method, got %v", method2.Function.ReturnType.Type) t.Errorf("expected no return type for setAge method, got %s", method2.Function.ReturnType.Type)
} }
if len(method2.Function.Parameters) != 1 { if len(method2.Function.Parameters) != 1 {
t.Fatalf("expected 1 parameter, got %d", len(method2.Function.Parameters)) t.Fatalf("expected 1 parameter, got %d", len(method2.Function.Parameters))
@ -200,11 +200,11 @@ end`
if method2.Function.Parameters[0].Name != "newAge" { if method2.Function.Parameters[0].Name != "newAge" {
t.Errorf("expected parameter name 'newAge', got %s", method2.Function.Parameters[0].Name) t.Errorf("expected parameter name 'newAge', got %s", method2.Function.Parameters[0].Name)
} }
if method2.Function.Parameters[0].TypeHint.Type == parser.TypeUnknown { if method2.Function.Parameters[0].TypeHint == nil {
t.Fatal("expected type hint for newAge parameter") t.Fatal("expected type hint for newAge parameter")
} }
if method2.Function.Parameters[0].TypeHint.Type != parser.TypeNumber { if method2.Function.Parameters[0].TypeHint.Type != "number" {
t.Errorf("expected parameter type number, got %v", method2.Function.Parameters[0].TypeHint.Type) t.Errorf("expected parameter type 'number', got %s", method2.Function.Parameters[0].TypeHint.Type)
} }
} }
@ -226,21 +226,19 @@ empty = Person{}`
t.Fatalf("expected 3 statements, got %d", len(program.Statements)) t.Fatalf("expected 3 statements, got %d", len(program.Statements))
} }
structStmt := program.Statements[0].(*parser.StructStatement)
// Second statement: constructor with fields // Second statement: constructor with fields
assign1, ok := program.Statements[1].(*parser.Assignment) assign1, ok := program.Statements[1].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("expected Assignment, got %T", program.Statements[1]) t.Fatalf("expected AssignStatement, got %T", program.Statements[1])
} }
constructor1, ok := assign1.Value.(*parser.StructConstructor) constructor1, ok := assign1.Value.(*parser.StructConstructorExpression)
if !ok { if !ok {
t.Fatalf("expected StructConstructor, got %T", assign1.Value) t.Fatalf("expected StructConstructorExpression, got %T", assign1.Value)
} }
if constructor1.StructID != structStmt.ID { if constructor1.StructName != "Person" {
t.Errorf("expected struct ID %d, got %d", structStmt.ID, constructor1.StructID) t.Errorf("expected struct name 'Person', got %s", constructor1.StructName)
} }
if len(constructor1.Fields) != 2 { if len(constructor1.Fields) != 2 {
@ -268,18 +266,18 @@ empty = Person{}`
testNumberLiteral(t, constructor1.Fields[1].Value, 30) testNumberLiteral(t, constructor1.Fields[1].Value, 30)
// Third statement: empty constructor // Third statement: empty constructor
assign2, ok := program.Statements[2].(*parser.Assignment) assign2, ok := program.Statements[2].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("expected Assignment, got %T", program.Statements[2]) t.Fatalf("expected AssignStatement, got %T", program.Statements[2])
} }
constructor2, ok := assign2.Value.(*parser.StructConstructor) constructor2, ok := assign2.Value.(*parser.StructConstructorExpression)
if !ok { if !ok {
t.Fatalf("expected StructConstructor, got %T", assign2.Value) t.Fatalf("expected StructConstructorExpression, got %T", assign2.Value)
} }
if constructor2.StructID != structStmt.ID { if constructor2.StructName != "Person" {
t.Errorf("expected struct ID %d, got %d", structStmt.ID, constructor2.StructID) t.Errorf("expected struct name 'Person', got %s", constructor2.StructName)
} }
if len(constructor2.Fields) != 0 { if len(constructor2.Fields) != 0 {
@ -312,8 +310,6 @@ person = Person{
t.Fatalf("expected 3 statements, got %d", len(program.Statements)) t.Fatalf("expected 3 statements, got %d", len(program.Statements))
} }
addressStruct := program.Statements[0].(*parser.StructStatement)
// Check Person struct has Address field type // Check Person struct has Address field type
personStruct, ok := program.Statements[1].(*parser.StructStatement) personStruct, ok := program.Statements[1].(*parser.StructStatement)
if !ok { if !ok {
@ -324,32 +320,29 @@ person = Person{
if addressField.Name != "address" { if addressField.Name != "address" {
t.Errorf("expected field name 'address', got %s", addressField.Name) t.Errorf("expected field name 'address', got %s", addressField.Name)
} }
if addressField.TypeHint.Type != parser.TypeStruct { if addressField.TypeHint.Type != "Address" {
t.Errorf("expected field type struct, got %v", addressField.TypeHint.Type) t.Errorf("expected field type 'Address', got %s", addressField.TypeHint.Type)
}
if addressField.TypeHint.StructID != addressStruct.ID {
t.Errorf("expected struct ID %d, got %d", addressStruct.ID, addressField.TypeHint.StructID)
} }
// Check nested constructor // Check nested constructor
assign, ok := program.Statements[2].(*parser.Assignment) assign, ok := program.Statements[2].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("expected Assignment, got %T", program.Statements[2]) t.Fatalf("expected AssignStatement, got %T", program.Statements[2])
} }
personConstructor, ok := assign.Value.(*parser.StructConstructor) personConstructor, ok := assign.Value.(*parser.StructConstructorExpression)
if !ok { if !ok {
t.Fatalf("expected StructConstructor, got %T", assign.Value) t.Fatalf("expected StructConstructorExpression, got %T", assign.Value)
} }
// Check the nested Address constructor // Check the nested Address constructor
addressConstructor, ok := personConstructor.Fields[1].Value.(*parser.StructConstructor) addressConstructor, ok := personConstructor.Fields[1].Value.(*parser.StructConstructorExpression)
if !ok { if !ok {
t.Fatalf("expected nested StructConstructor, got %T", personConstructor.Fields[1].Value) t.Fatalf("expected nested StructConstructorExpression, got %T", personConstructor.Fields[1].Value)
} }
if addressConstructor.StructID != addressStruct.ID { if addressConstructor.StructName != "Address" {
t.Errorf("expected nested struct ID %d, got %d", addressStruct.ID, addressConstructor.StructID) t.Errorf("expected nested struct name 'Address', got %s", addressConstructor.StructName)
} }
if len(addressConstructor.Fields) != 2 { if len(addressConstructor.Fields) != 2 {
@ -404,8 +397,8 @@ end`
if !ok { if !ok {
t.Fatalf("expected MethodDefinition, got %T", program.Statements[1]) t.Fatalf("expected MethodDefinition, got %T", program.Statements[1])
} }
if methodStmt.StructID != structStmt.ID { if methodStmt.StructName != "Point" {
t.Errorf("expected struct ID %d, got %d", structStmt.ID, methodStmt.StructID) t.Errorf("expected struct name 'Point', got %s", methodStmt.StructName)
} }
if methodStmt.MethodName != "distance" { if methodStmt.MethodName != "distance" {
t.Errorf("expected method name 'distance', got %s", methodStmt.MethodName) t.Errorf("expected method name 'distance', got %s", methodStmt.MethodName)
@ -413,16 +406,16 @@ end`
// Verify constructors // Verify constructors
for i := 2; i <= 3; i++ { for i := 2; i <= 3; i++ {
assign, ok := program.Statements[i].(*parser.Assignment) assign, ok := program.Statements[i].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("statement %d: expected Assignment, got %T", i, program.Statements[i]) t.Fatalf("statement %d: expected AssignStatement, got %T", i, program.Statements[i])
} }
constructor, ok := assign.Value.(*parser.StructConstructor) constructor, ok := assign.Value.(*parser.StructConstructorExpression)
if !ok { if !ok {
t.Fatalf("statement %d: expected StructConstructor, got %T", i, assign.Value) t.Fatalf("statement %d: expected StructConstructorExpression, got %T", i, assign.Value)
} }
if constructor.StructID != structStmt.ID { if constructor.StructName != "Point" {
t.Errorf("statement %d: expected struct ID %d, got %d", i, structStmt.ID, constructor.StructID) t.Errorf("statement %d: expected struct name 'Point', got %s", i, constructor.StructName)
} }
} }
@ -453,16 +446,16 @@ end`
} }
// Check struct constructor in loop // Check struct constructor in loop
loopAssign, ok := forStmt.Body[0].(*parser.Assignment) loopAssign, ok := forStmt.Body[0].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("expected Assignment in loop, got %T", forStmt.Body[0]) t.Fatalf("expected AssignStatement in loop, got %T", forStmt.Body[0])
} }
loopConstructor, ok := loopAssign.Value.(*parser.StructConstructor) loopConstructor, ok := loopAssign.Value.(*parser.StructConstructorExpression)
if !ok { if !ok {
t.Fatalf("expected StructConstructor in loop, got %T", loopAssign.Value) t.Fatalf("expected StructConstructorExpression in loop, got %T", loopAssign.Value)
} }
if loopConstructor.StructID != structStmt.ID { if loopConstructor.StructName != "Point" {
t.Errorf("expected struct ID %d in loop, got %d", structStmt.ID, loopConstructor.StructID) t.Errorf("expected struct name 'Point' in loop, got %s", loopConstructor.StructName)
} }
} }
@ -559,13 +552,13 @@ func TestSingleLineStruct(t *testing.T) {
t.Fatalf("expected 2 fields, got %d", len(stmt.Fields)) t.Fatalf("expected 2 fields, got %d", len(stmt.Fields))
} }
if stmt.Fields[0].Name != "name" || stmt.Fields[0].TypeHint.Type != parser.TypeString { if stmt.Fields[0].Name != "name" || stmt.Fields[0].TypeHint.Type != "string" {
t.Errorf("expected first field 'name: string', got '%s: %v'", t.Errorf("expected first field 'name: string', got '%s: %s'",
stmt.Fields[0].Name, stmt.Fields[0].TypeHint.Type) stmt.Fields[0].Name, stmt.Fields[0].TypeHint.Type)
} }
if stmt.Fields[1].Name != "age" || stmt.Fields[1].TypeHint.Type != parser.TypeNumber { if stmt.Fields[1].Name != "age" || stmt.Fields[1].TypeHint.Type != "number" {
t.Errorf("expected second field 'age: number', got '%s: %v'", t.Errorf("expected second field 'age: number', got '%s: %s'",
stmt.Fields[1].Name, stmt.Fields[1].TypeHint.Type) stmt.Fields[1].Name, stmt.Fields[1].TypeHint.Type)
} }
} }
@ -607,8 +600,8 @@ end`
method := program.Statements[1].(*parser.MethodDefinition) method := program.Statements[1].(*parser.MethodDefinition)
str := method.String() str := method.String()
if !containsSubstring(str, "fn <struct>.getName") { if !containsSubstring(str, "fn Person.getName") {
t.Errorf("expected method string to contain 'fn <struct>.getName', got: %s", str) t.Errorf("expected method string to contain 'fn Person.getName', got: %s", str)
} }
if !containsSubstring(str, ": string") { if !containsSubstring(str, ": string") {
t.Errorf("expected method string to contain return type, got: %s", str) t.Errorf("expected method string to contain return type, got: %s", str)
@ -628,11 +621,11 @@ person = Person{name = "John", age = 30}`
program := p.ParseProgram() program := p.ParseProgram()
checkParserErrors(t, p) checkParserErrors(t, p)
assign := program.Statements[1].(*parser.Assignment) assign := program.Statements[1].(*parser.AssignStatement)
constructor := assign.Value.(*parser.StructConstructor) constructor := assign.Value.(*parser.StructConstructorExpression)
str := constructor.String() str := constructor.String()
expected := `<struct>{name = "John", age = 30.00}` expected := `Person{name = "John", age = 30.00}`
if str != expected { if str != expected {
t.Errorf("expected constructor string:\n%s\ngot:\n%s", expected, str) t.Errorf("expected constructor string:\n%s\ngot:\n%s", expected, str)
} }

View File

@ -10,18 +10,18 @@ func TestVariableTypeHints(t *testing.T) {
tests := []struct { tests := []struct {
input string input string
variable string variable string
typeHint parser.Type typeHint string
hasHint bool hasHint bool
desc string desc string
}{ }{
{"x = 42", "x", parser.TypeUnknown, false, "no type hint"}, {"x = 42", "x", "", false, "no type hint"},
{"x: number = 42", "x", parser.TypeNumber, true, "number type hint"}, {"x: number = 42", "x", "number", true, "number type hint"},
{"name: string = \"hello\"", "name", parser.TypeString, true, "string type hint"}, {"name: string = \"hello\"", "name", "string", true, "string type hint"},
{"flag: bool = true", "flag", parser.TypeBool, true, "bool type hint"}, {"flag: bool = true", "flag", "bool", true, "bool type hint"},
{"data: table = {}", "data", parser.TypeTable, true, "table type hint"}, {"data: table = {}", "data", "table", true, "table type hint"},
{"fn_var: function = fn() end", "fn_var", parser.TypeFunction, true, "function type hint"}, {"fn_var: function = fn() end", "fn_var", "function", true, "function type hint"},
{"value: any = nil", "value", parser.TypeAny, true, "any type hint"}, {"value: any = nil", "value", "any", true, "any type hint"},
{"ptr: nil = nil", "ptr", parser.TypeNil, true, "nil type hint"}, {"ptr: nil = nil", "ptr", "nil", true, "nil type hint"},
} }
for _, tt := range tests { for _, tt := range tests {
@ -35,15 +35,15 @@ func TestVariableTypeHints(t *testing.T) {
t.Fatalf("expected 1 statement, got %d", len(program.Statements)) t.Fatalf("expected 1 statement, got %d", len(program.Statements))
} }
stmt, ok := program.Statements[0].(*parser.Assignment) stmt, ok := program.Statements[0].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("expected Assignment, got %T", program.Statements[0]) t.Fatalf("expected AssignStatement, got %T", program.Statements[0])
} }
// Check variable name // Check variable name
ident, ok := stmt.Target.(*parser.Identifier) ident, ok := stmt.Name.(*parser.Identifier)
if !ok { if !ok {
t.Fatalf("expected Identifier for Target, got %T", stmt.Target) t.Fatalf("expected Identifier for Name, got %T", stmt.Name)
} }
if ident.Value != tt.variable { if ident.Value != tt.variable {
@ -52,19 +52,19 @@ func TestVariableTypeHints(t *testing.T) {
// Check type hint // Check type hint
if tt.hasHint { if tt.hasHint {
if stmt.TypeHint.Type == parser.TypeUnknown { if stmt.TypeHint == nil {
t.Error("expected type hint but got TypeUnknown") t.Error("expected type hint but got nil")
} else { } else {
if stmt.TypeHint.Type != tt.typeHint { if stmt.TypeHint.Type != tt.typeHint {
t.Errorf("expected type hint %v, got %v", tt.typeHint, stmt.TypeHint.Type) t.Errorf("expected type hint %s, got %s", tt.typeHint, stmt.TypeHint.Type)
} }
if stmt.TypeHint.Inferred { if stmt.TypeHint.Inferred {
t.Error("expected type hint to not be inferred") t.Error("expected type hint to not be inferred")
} }
} }
} else { } else {
if stmt.TypeHint.Type != parser.TypeUnknown { if stmt.TypeHint != nil {
t.Errorf("expected no type hint but got %v", stmt.TypeHint.Type) t.Errorf("expected no type hint but got %s", stmt.TypeHint.Type)
} }
} }
}) })
@ -74,81 +74,60 @@ func TestVariableTypeHints(t *testing.T) {
func TestFunctionParameterTypeHints(t *testing.T) { func TestFunctionParameterTypeHints(t *testing.T) {
tests := []struct { tests := []struct {
input string input string
params []struct { params []struct{ name, typeHint string }
name string returnType string
typeHint parser.Type
}
returnType parser.Type
hasReturn bool hasReturn bool
desc string desc string
}{ }{
{ {
"fn(a, b) end", "fn(a, b) end",
[]struct { []struct{ name, typeHint string }{
name string {"a", ""},
typeHint parser.Type {"b", ""},
}{
{"a", parser.TypeUnknown},
{"b", parser.TypeUnknown},
}, },
parser.TypeUnknown, false, "", false,
"no type hints", "no type hints",
}, },
{ {
"fn(a: number, b: string) end", "fn(a: number, b: string) end",
[]struct { []struct{ name, typeHint string }{
name string {"a", "number"},
typeHint parser.Type {"b", "string"},
}{
{"a", parser.TypeNumber},
{"b", parser.TypeString},
}, },
parser.TypeUnknown, false, "", false,
"parameter type hints only", "parameter type hints only",
}, },
{ {
"fn(x: number): string end", "fn(x: number): string end",
[]struct { []struct{ name, typeHint string }{
name string {"x", "number"},
typeHint parser.Type
}{
{"x", parser.TypeNumber},
}, },
parser.TypeString, true, "string", true,
"parameter and return type hints", "parameter and return type hints",
}, },
{ {
"fn(): bool end", "fn(): bool end",
[]struct { []struct{ name, typeHint string }{},
name string "bool", true,
typeHint parser.Type
}{},
parser.TypeBool, true,
"return type hint only", "return type hint only",
}, },
{ {
"fn(a: number, b, c: string): table end", "fn(a: number, b, c: string): table end",
[]struct { []struct{ name, typeHint string }{
name string {"a", "number"},
typeHint parser.Type {"b", ""},
}{ {"c", "string"},
{"a", parser.TypeNumber},
{"b", parser.TypeUnknown},
{"c", parser.TypeString},
}, },
parser.TypeTable, true, "table", true,
"mixed parameter types with return", "mixed parameter types with return",
}, },
{ {
"fn(callback: function, data: any): nil end", "fn(callback: function, data: any): nil end",
[]struct { []struct{ name, typeHint string }{
name string {"callback", "function"},
typeHint parser.Type {"data", "any"},
}{
{"callback", parser.TypeFunction},
{"data", parser.TypeAny},
}, },
parser.TypeNil, true, "nil", true,
"function and any types", "function and any types",
}, },
} }
@ -176,29 +155,29 @@ func TestFunctionParameterTypeHints(t *testing.T) {
t.Errorf("parameter %d: expected name %s, got %s", i, expected.name, param.Name) t.Errorf("parameter %d: expected name %s, got %s", i, expected.name, param.Name)
} }
if expected.typeHint == parser.TypeUnknown { if expected.typeHint == "" {
if param.TypeHint.Type != parser.TypeUnknown { if param.TypeHint != nil {
t.Errorf("parameter %d: expected no type hint but got %v", i, param.TypeHint.Type) t.Errorf("parameter %d: expected no type hint but got %s", i, param.TypeHint.Type)
} }
} else { } else {
if param.TypeHint.Type == parser.TypeUnknown { if param.TypeHint == nil {
t.Errorf("parameter %d: expected type hint %v but got TypeUnknown", i, expected.typeHint) t.Errorf("parameter %d: expected type hint %s but got nil", i, expected.typeHint)
} else if param.TypeHint.Type != expected.typeHint { } else if param.TypeHint.Type != expected.typeHint {
t.Errorf("parameter %d: expected type hint %v, got %v", i, expected.typeHint, param.TypeHint.Type) t.Errorf("parameter %d: expected type hint %s, got %s", i, expected.typeHint, param.TypeHint.Type)
} }
} }
} }
// Check return type // Check return type
if tt.hasReturn { if tt.hasReturn {
if fn.ReturnType.Type == parser.TypeUnknown { if fn.ReturnType == nil {
t.Error("expected return type hint but got TypeUnknown") t.Error("expected return type hint but got nil")
} else if fn.ReturnType.Type != tt.returnType { } else if fn.ReturnType.Type != tt.returnType {
t.Errorf("expected return type %v, got %v", tt.returnType, fn.ReturnType.Type) t.Errorf("expected return type %s, got %s", tt.returnType, fn.ReturnType.Type)
} }
} else { } else {
if fn.ReturnType.Type != parser.TypeUnknown { if fn.ReturnType != nil {
t.Errorf("expected no return type but got %v", fn.ReturnType.Type) t.Errorf("expected no return type but got %s", fn.ReturnType.Type)
} }
} }
}) })
@ -300,13 +279,13 @@ func TestMemberAccessWithoutTypeHints(t *testing.T) {
t.Fatalf("expected 1 statement, got %d", len(program.Statements)) t.Fatalf("expected 1 statement, got %d", len(program.Statements))
} }
stmt, ok := program.Statements[0].(*parser.Assignment) stmt, ok := program.Statements[0].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("expected Assignment, got %T", program.Statements[0]) t.Fatalf("expected AssignStatement, got %T", program.Statements[0])
} }
// Member access should never have type hints // Member access should never have type hints
if stmt.TypeHint.Type != parser.TypeUnknown { if stmt.TypeHint != nil {
t.Error("member access assignment should not have type hints") t.Error("member access assignment should not have type hints")
} }
@ -354,12 +333,12 @@ func TestTypeInferenceErrors(t *testing.T) {
}{ }{
{ {
"x: number = \"hello\"", "x: number = \"hello\"",
"type mismatch in assignment", "cannot assign string to variable of type number",
"type mismatch in assignment", "type mismatch in assignment",
}, },
{ {
"x = 42\ny: string = x", "x = 42\ny: string = x",
"type mismatch in assignment", "cannot assign number to variable of type string",
"type mismatch with inferred type", "type mismatch with inferred type",
}, },
} }
@ -380,7 +359,7 @@ func TestTypeInferenceErrors(t *testing.T) {
found := false found := false
for _, err := range typeErrors { for _, err := range typeErrors {
if containsSubstring(err.Message, tt.expectedError) { if err.Message == tt.expectedError {
found = true found = true
break break
} }
@ -391,7 +370,7 @@ func TestTypeInferenceErrors(t *testing.T) {
for i, err := range typeErrors { for i, err := range typeErrors {
errorMsgs[i] = err.Message errorMsgs[i] = err.Message
} }
t.Errorf("expected error containing %q, got %v", tt.expectedError, errorMsgs) t.Errorf("expected error %q, got %v", tt.expectedError, errorMsgs)
} }
}) })
} }
@ -460,22 +439,22 @@ server: table = {
} }
// Check first statement: config table with typed assignments // Check first statement: config table with typed assignments
configStmt, ok := program.Statements[0].(*parser.Assignment) configStmt, ok := program.Statements[0].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("expected Assignment, got %T", program.Statements[0]) t.Fatalf("expected AssignStatement, got %T", program.Statements[0])
} }
if configStmt.TypeHint.Type == parser.TypeUnknown || configStmt.TypeHint.Type != parser.TypeTable { if configStmt.TypeHint == nil || configStmt.TypeHint.Type != "table" {
t.Error("expected table type hint for config") t.Error("expected table type hint for config")
} }
// Check second statement: handler function with typed parameters // Check second statement: handler function with typed parameters
handlerStmt, ok := program.Statements[1].(*parser.Assignment) handlerStmt, ok := program.Statements[1].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("expected Assignment, got %T", program.Statements[1]) t.Fatalf("expected AssignStatement, got %T", program.Statements[1])
} }
if handlerStmt.TypeHint.Type == parser.TypeUnknown || handlerStmt.TypeHint.Type != parser.TypeFunction { if handlerStmt.TypeHint == nil || handlerStmt.TypeHint.Type != "function" {
t.Error("expected function type hint for handler") t.Error("expected function type hint for handler")
} }
@ -489,32 +468,34 @@ server: table = {
} }
// Check parameter types // Check parameter types
if fn.Parameters[0].TypeHint.Type == parser.TypeUnknown || fn.Parameters[0].TypeHint.Type != parser.TypeTable { if fn.Parameters[0].TypeHint == nil || fn.Parameters[0].TypeHint.Type != "table" {
t.Error("expected table type for request parameter") t.Error("expected table type for request parameter")
} }
if fn.Parameters[1].TypeHint.Type == parser.TypeUnknown || fn.Parameters[1].TypeHint.Type != parser.TypeFunction { if fn.Parameters[1].TypeHint == nil || fn.Parameters[1].TypeHint.Type != "function" {
t.Error("expected function type for callback parameter") t.Error("expected function type for callback parameter")
} }
// Check return type // Check return type
if fn.ReturnType.Type == parser.TypeUnknown || fn.ReturnType.Type != parser.TypeNil { if fn.ReturnType == nil || fn.ReturnType.Type != "nil" {
t.Error("expected nil return type for handler") t.Error("expected nil return type for handler")
} }
// Check third statement: server table // Check third statement: server table
serverStmt, ok := program.Statements[2].(*parser.Assignment) serverStmt, ok := program.Statements[2].(*parser.AssignStatement)
if !ok { if !ok {
t.Fatalf("expected Assignment, got %T", program.Statements[2]) t.Fatalf("expected AssignStatement, got %T", program.Statements[2])
} }
if serverStmt.TypeHint.Type == parser.TypeUnknown || serverStmt.TypeHint.Type != parser.TypeTable { if serverStmt.TypeHint == nil || serverStmt.TypeHint.Type != "table" {
t.Error("expected table type hint for server") t.Error("expected table type hint for server")
} }
} }
func TestTypeInfoInterface(t *testing.T) { func TestTypeInfoGettersSetters(t *testing.T) {
// Test that all expression types properly implement TypeInfo() // Test that all expression types properly implement GetType/SetType
typeInfo := &parser.TypeInfo{Type: "test", Inferred: true}
expressions := []parser.Expression{ expressions := []parser.Expression{
&parser.Identifier{Value: "x"}, &parser.Identifier{Value: "x"},
&parser.NumberLiteral{Value: 42}, &parser.NumberLiteral{Value: 42},
@ -532,43 +513,20 @@ func TestTypeInfoInterface(t *testing.T) {
for i, expr := range expressions { for i, expr := range expressions {
t.Run(string(rune('0'+i)), func(t *testing.T) { t.Run(string(rune('0'+i)), func(t *testing.T) {
// Should have default type initially // Initially should have no type
typeInfo := expr.TypeInfo() if expr.GetType() != nil {
t.Error("expected nil type initially")
}
// Basic literals should have their expected types // Set type
switch e := expr.(type) { expr.SetType(typeInfo)
case *parser.NumberLiteral:
if typeInfo.Type != parser.TypeNumber { // Get type should return what we set
t.Errorf("expected number type, got %v", typeInfo.Type) retrieved := expr.GetType()
} if retrieved == nil {
case *parser.StringLiteral: t.Error("expected non-nil type after setting")
if typeInfo.Type != parser.TypeString { } else if retrieved.Type != "test" || !retrieved.Inferred {
t.Errorf("expected string type, got %v", typeInfo.Type) t.Errorf("expected {Type: test, Inferred: true}, got %+v", retrieved)
}
case *parser.BooleanLiteral:
if typeInfo.Type != parser.TypeBool {
t.Errorf("expected bool type, got %v", typeInfo.Type)
}
case *parser.NilLiteral:
if typeInfo.Type != parser.TypeNil {
t.Errorf("expected nil type, got %v", typeInfo.Type)
}
case *parser.TableLiteral:
if typeInfo.Type != parser.TypeTable {
t.Errorf("expected table type, got %v", typeInfo.Type)
}
case *parser.FunctionLiteral:
if typeInfo.Type != parser.TypeFunction {
t.Errorf("expected function type, got %v", typeInfo.Type)
}
case *parser.Identifier:
// Identifiers default to any type
if typeInfo.Type != parser.TypeAny {
t.Errorf("expected any type for untyped identifier, got %v", typeInfo.Type)
}
default:
// Other expressions may have unknown type initially
_ = e
} }
}) })
} }

View File

@ -1,45 +1,21 @@
package parser package parser
import "fmt" import (
"fmt"
// Type represents built-in and user-defined types using compact enum representation. )
// Uses single byte instead of string pointers to minimize memory usage.
type Type uint8
// Type constants for built-in types
const ( const (
TypeUnknown Type = iota TypeNumber = "number"
TypeNumber TypeString = "string"
TypeString TypeBool = "bool"
TypeBool TypeNil = "nil"
TypeNil TypeTable = "table"
TypeTable TypeFunction = "function"
TypeFunction TypeAny = "any"
TypeAny
TypeStruct // struct types use StructID field for identification
) )
// TypeInfo represents type information with zero-allocation design. // TypeError represents a type checking error
// Embeds directly in AST nodes instead of using pointers to reduce heap pressure.
// Common types are pre-allocated as globals to eliminate most allocations.
type TypeInfo struct {
Type Type // Built-in type or TypeStruct for user types
Inferred bool // True if type was inferred, false if explicitly declared
StructID uint16 // Index into global struct table for struct types (0 for non-structs)
}
// Pre-allocated common types - eliminates heap allocations for built-in types
var (
UnknownType = TypeInfo{Type: TypeUnknown, Inferred: true}
NumberType = TypeInfo{Type: TypeNumber, Inferred: true}
StringType = TypeInfo{Type: TypeString, Inferred: true}
BoolType = TypeInfo{Type: TypeBool, Inferred: true}
NilType = TypeInfo{Type: TypeNil, Inferred: true}
TableType = TypeInfo{Type: TypeTable, Inferred: true}
FunctionType = TypeInfo{Type: TypeFunction, Inferred: true}
AnyType = TypeInfo{Type: TypeAny, Inferred: true}
)
// TypeError represents a type checking error with location information
type TypeError struct { type TypeError struct {
Message string Message string
Line int Line int
@ -51,10 +27,10 @@ func (te TypeError) Error() string {
return fmt.Sprintf("Type error at line %d, column %d: %s", te.Line, te.Column, te.Message) return fmt.Sprintf("Type error at line %d, column %d: %s", te.Line, te.Column, te.Message)
} }
// Symbol represents a variable in the symbol table with optimized type storage // Symbol represents a variable in the symbol table
type Symbol struct { type Symbol struct {
Name string Name string
Type TypeInfo // Embed directly instead of pointer Type *TypeInfo
Declared bool Declared bool
Line int Line int
Column int Column int
@ -87,56 +63,52 @@ func (s *Scope) Lookup(name string) *Symbol {
return nil return nil
} }
// TypeInferrer performs type inference and checking with optimized allocations // TypeInferrer performs type inference and checking
type TypeInferrer struct { type TypeInferrer struct {
currentScope *Scope currentScope *Scope
globalScope *Scope globalScope *Scope
errors []TypeError errors []TypeError
// Struct definitions with ID mapping // Pre-allocated type objects for performance
numberType *TypeInfo
stringType *TypeInfo
boolType *TypeInfo
nilType *TypeInfo
tableType *TypeInfo
anyType *TypeInfo
// Struct definitions
structs map[string]*StructStatement structs map[string]*StructStatement
structIDs map[uint16]*StructStatement
nextID uint16
} }
// NewTypeInferrer creates a new type inference engine // NewTypeInferrer creates a new type inference engine
func NewTypeInferrer() *TypeInferrer { func NewTypeInferrer() *TypeInferrer {
globalScope := NewScope(nil) globalScope := NewScope(nil)
return &TypeInferrer{ ti := &TypeInferrer{
currentScope: globalScope, currentScope: globalScope,
globalScope: globalScope, globalScope: globalScope,
errors: []TypeError{}, errors: []TypeError{},
structs: make(map[string]*StructStatement), structs: make(map[string]*StructStatement),
structIDs: make(map[uint16]*StructStatement),
nextID: 1, // 0 reserved for non-struct types // Pre-allocate common types to reduce allocations
} numberType: &TypeInfo{Type: TypeNumber, Inferred: true},
stringType: &TypeInfo{Type: TypeString, Inferred: true},
boolType: &TypeInfo{Type: TypeBool, Inferred: true},
nilType: &TypeInfo{Type: TypeNil, Inferred: true},
tableType: &TypeInfo{Type: TypeTable, Inferred: true},
anyType: &TypeInfo{Type: TypeAny, Inferred: true},
} }
// RegisterStruct assigns ID to struct and tracks it return ti
func (ti *TypeInferrer) RegisterStruct(stmt *StructStatement) {
stmt.ID = ti.nextID
ti.nextID++
ti.structs[stmt.Name] = stmt
ti.structIDs[stmt.ID] = stmt
}
// GetStructByID returns struct definition by ID
func (ti *TypeInferrer) GetStructByID(id uint16) *StructStatement {
return ti.structIDs[id]
}
// CreateStructType creates TypeInfo for a struct
func (ti *TypeInferrer) CreateStructType(structID uint16) TypeInfo {
return TypeInfo{Type: TypeStruct, StructID: structID, Inferred: true}
} }
// InferTypes performs type inference on the entire program // InferTypes performs type inference on the entire program
func (ti *TypeInferrer) InferTypes(program *Program) []TypeError { func (ti *TypeInferrer) InferTypes(program *Program) []TypeError {
// First pass: collect and register struct definitions // First pass: collect struct definitions
for _, stmt := range program.Statements { for _, stmt := range program.Statements {
if structStmt, ok := stmt.(*StructStatement); ok { if structStmt, ok := stmt.(*StructStatement); ok {
ti.RegisterStruct(structStmt) ti.structs[structStmt.Name] = structStmt
} }
} }
@ -147,6 +119,42 @@ func (ti *TypeInferrer) InferTypes(program *Program) []TypeError {
return ti.errors return ti.errors
} }
// enterScope creates a new scope
func (ti *TypeInferrer) enterScope() {
ti.currentScope = NewScope(ti.currentScope)
}
// exitScope returns to the parent scope
func (ti *TypeInferrer) exitScope() {
if ti.currentScope.parent != nil {
ti.currentScope = ti.currentScope.parent
}
}
// addError adds a type error
func (ti *TypeInferrer) addError(message string, node Node) {
ti.errors = append(ti.errors, TypeError{
Message: message,
Line: 0, // Would need to track position in AST nodes
Column: 0,
Node: node,
})
}
// getStructType returns TypeInfo for a struct
func (ti *TypeInferrer) getStructType(name string) *TypeInfo {
if _, exists := ti.structs[name]; exists {
return &TypeInfo{Type: name, Inferred: true}
}
return nil
}
// isStructType checks if a type is a struct type
func (ti *TypeInferrer) isStructType(t *TypeInfo) bool {
_, exists := ti.structs[t.Type]
return exists
}
// inferStatement infers types for statements // inferStatement infers types for statements
func (ti *TypeInferrer) inferStatement(stmt Statement) { func (ti *TypeInferrer) inferStatement(stmt Statement) {
switch s := stmt.(type) { switch s := stmt.(type) {
@ -154,10 +162,8 @@ func (ti *TypeInferrer) inferStatement(stmt Statement) {
ti.inferStructStatement(s) ti.inferStructStatement(s)
case *MethodDefinition: case *MethodDefinition:
ti.inferMethodDefinition(s) ti.inferMethodDefinition(s)
case *Assignment: case *AssignStatement:
ti.inferAssignment(s) ti.inferAssignStatement(s)
case *ExpressionStatement:
ti.inferExpression(s.Expression)
case *EchoStatement: case *EchoStatement:
ti.inferExpression(s.Value) ti.inferExpression(s.Value)
case *IfStatement: case *IfStatement:
@ -176,38 +182,46 @@ func (ti *TypeInferrer) inferStatement(stmt Statement) {
if s.Value != nil { if s.Value != nil {
ti.inferExpression(s.Value) ti.inferExpression(s.Value)
} }
case *BreakStatement: case *ExpressionStatement:
// No-op ti.inferExpression(s.Expression)
} }
} }
// inferStructStatement handles struct definitions
func (ti *TypeInferrer) inferStructStatement(stmt *StructStatement) { func (ti *TypeInferrer) inferStructStatement(stmt *StructStatement) {
// Validate field types
for _, field := range stmt.Fields { for _, field := range stmt.Fields {
if !ti.isValidType(field.TypeHint) { if field.TypeHint != nil {
ti.addError(fmt.Sprintf("invalid field type in struct '%s'", stmt.Name), stmt) if !ValidTypeName(field.TypeHint.Type) && !ti.isStructType(field.TypeHint) {
ti.addError(fmt.Sprintf("invalid field type '%s' in struct '%s'",
field.TypeHint.Type, stmt.Name), stmt)
}
} }
} }
} }
// inferMethodDefinition handles method definitions
func (ti *TypeInferrer) inferMethodDefinition(stmt *MethodDefinition) { func (ti *TypeInferrer) inferMethodDefinition(stmt *MethodDefinition) {
structDef := ti.GetStructByID(stmt.StructID) // Check if struct exists
if structDef == nil { if _, exists := ti.structs[stmt.StructName]; !exists {
ti.addError("method defined on undefined struct", stmt) ti.addError(fmt.Sprintf("method defined on undefined struct '%s'", stmt.StructName), stmt)
return return
} }
// Infer the function body
ti.enterScope() ti.enterScope()
// Add self parameter
// Add self parameter implicitly
ti.currentScope.Define(&Symbol{ ti.currentScope.Define(&Symbol{
Name: "self", Name: "self",
Type: ti.CreateStructType(stmt.StructID), Type: ti.getStructType(stmt.StructName),
Declared: true, Declared: true,
}) })
// Add function parameters // Add explicit parameters
for _, param := range stmt.Function.Parameters { for _, param := range stmt.Function.Parameters {
paramType := AnyType paramType := ti.anyType
if param.TypeHint.Type != TypeUnknown { if param.TypeHint != nil {
paramType = param.TypeHint paramType = param.TypeHint
} }
ti.currentScope.Define(&Symbol{ ti.currentScope.Define(&Symbol{
@ -221,47 +235,66 @@ func (ti *TypeInferrer) inferMethodDefinition(stmt *MethodDefinition) {
for _, bodyStmt := range stmt.Function.Body { for _, bodyStmt := range stmt.Function.Body {
ti.inferStatement(bodyStmt) ti.inferStatement(bodyStmt)
} }
ti.exitScope() ti.exitScope()
} }
func (ti *TypeInferrer) inferAssignment(stmt *Assignment) { // inferAssignStatement handles variable assignments with type checking
func (ti *TypeInferrer) inferAssignStatement(stmt *AssignStatement) {
// Infer the type of the value expression
valueType := ti.inferExpression(stmt.Value) valueType := ti.inferExpression(stmt.Value)
if ident, ok := stmt.Target.(*Identifier); ok { if ident, ok := stmt.Name.(*Identifier); ok {
// Simple variable assignment
symbol := ti.currentScope.Lookup(ident.Value)
if stmt.IsDeclaration { if stmt.IsDeclaration {
// New variable declaration
varType := valueType varType := valueType
if stmt.TypeHint.Type != TypeUnknown {
// If there's a type hint, validate it
if stmt.TypeHint != nil {
if !ti.isTypeCompatible(valueType, stmt.TypeHint) { if !ti.isTypeCompatible(valueType, stmt.TypeHint) {
ti.addError("type mismatch in assignment", stmt) ti.addError(fmt.Sprintf("cannot assign %s to variable of type %s",
valueType.Type, stmt.TypeHint.Type), stmt)
} }
varType = stmt.TypeHint varType = stmt.TypeHint
varType.Inferred = false
} }
// Define the new symbol
ti.currentScope.Define(&Symbol{ ti.currentScope.Define(&Symbol{
Name: ident.Value, Name: ident.Value,
Type: varType, Type: varType,
Declared: true, Declared: true,
}) })
ident.typeInfo = varType
ident.SetType(varType)
} else { } else {
symbol := ti.currentScope.Lookup(ident.Value) // Assignment to existing variable
if symbol == nil { if symbol == nil {
ti.addError(fmt.Sprintf("undefined variable '%s'", ident.Value), stmt) ti.addError(fmt.Sprintf("undefined variable '%s'", ident.Value), stmt)
return return
} }
// Check type compatibility
if !ti.isTypeCompatible(valueType, symbol.Type) { if !ti.isTypeCompatible(valueType, symbol.Type) {
ti.addError("type mismatch in assignment", stmt) ti.addError(fmt.Sprintf("cannot assign %s to variable of type %s",
valueType.Type, symbol.Type.Type), stmt)
} }
ident.typeInfo = symbol.Type
ident.SetType(symbol.Type)
} }
} else { } else {
// Member access assignment (table.key or table[index]) // Member access assignment (table.key or table[index])
ti.inferExpression(stmt.Target) ti.inferExpression(stmt.Name)
} }
} }
// inferIfStatement handles if statements
func (ti *TypeInferrer) inferIfStatement(stmt *IfStatement) { func (ti *TypeInferrer) inferIfStatement(stmt *IfStatement) {
ti.inferExpression(stmt.Condition) condType := ti.inferExpression(stmt.Condition)
ti.validateBooleanContext(condType, stmt.Condition)
ti.enterScope() ti.enterScope()
for _, s := range stmt.Body { for _, s := range stmt.Body {
@ -270,7 +303,9 @@ func (ti *TypeInferrer) inferIfStatement(stmt *IfStatement) {
ti.exitScope() ti.exitScope()
for _, elseif := range stmt.ElseIfs { for _, elseif := range stmt.ElseIfs {
ti.inferExpression(elseif.Condition) condType := ti.inferExpression(elseif.Condition)
ti.validateBooleanContext(condType, elseif.Condition)
ti.enterScope() ti.enterScope()
for _, s := range elseif.Body { for _, s := range elseif.Body {
ti.inferStatement(s) ti.inferStatement(s)
@ -287,8 +322,10 @@ func (ti *TypeInferrer) inferIfStatement(stmt *IfStatement) {
} }
} }
// inferWhileStatement handles while loops
func (ti *TypeInferrer) inferWhileStatement(stmt *WhileStatement) { func (ti *TypeInferrer) inferWhileStatement(stmt *WhileStatement) {
ti.inferExpression(stmt.Condition) condType := ti.inferExpression(stmt.Condition)
ti.validateBooleanContext(condType, stmt.Condition)
ti.enterScope() ti.enterScope()
for _, s := range stmt.Body { for _, s := range stmt.Body {
@ -297,21 +334,33 @@ func (ti *TypeInferrer) inferWhileStatement(stmt *WhileStatement) {
ti.exitScope() ti.exitScope()
} }
// inferForStatement handles numeric for loops
func (ti *TypeInferrer) inferForStatement(stmt *ForStatement) { func (ti *TypeInferrer) inferForStatement(stmt *ForStatement) {
ti.inferExpression(stmt.Start) startType := ti.inferExpression(stmt.Start)
ti.inferExpression(stmt.End) endType := ti.inferExpression(stmt.End)
if !ti.isNumericType(startType) {
ti.addError("for loop start value must be numeric", stmt.Start)
}
if !ti.isNumericType(endType) {
ti.addError("for loop end value must be numeric", stmt.End)
}
if stmt.Step != nil { if stmt.Step != nil {
ti.inferExpression(stmt.Step) stepType := ti.inferExpression(stmt.Step)
if !ti.isNumericType(stepType) {
ti.addError("for loop step value must be numeric", stmt.Step)
}
} }
ti.enterScope() ti.enterScope()
// Define loop variable as number // Define loop variable as number
ti.currentScope.Define(&Symbol{ ti.currentScope.Define(&Symbol{
Name: stmt.Variable.Value, Name: stmt.Variable.Value,
Type: NumberType, Type: ti.numberType,
Declared: true, Declared: true,
}) })
stmt.Variable.typeInfo = NumberType stmt.Variable.SetType(ti.numberType)
for _, s := range stmt.Body { for _, s := range stmt.Body {
ti.inferStatement(s) ti.inferStatement(s)
@ -319,26 +368,33 @@ func (ti *TypeInferrer) inferForStatement(stmt *ForStatement) {
ti.exitScope() ti.exitScope()
} }
// inferForInStatement handles for-in loops
func (ti *TypeInferrer) inferForInStatement(stmt *ForInStatement) { func (ti *TypeInferrer) inferForInStatement(stmt *ForInStatement) {
ti.inferExpression(stmt.Iterable) iterableType := ti.inferExpression(stmt.Iterable)
// For now, assume iterable is a table or struct
if !ti.isTableType(iterableType) && !ti.isStructType(iterableType) {
ti.addError("for-in requires an iterable (table or struct)", stmt.Iterable)
}
ti.enterScope() ti.enterScope()
// Define loop variables
// Define loop variables (key and value are any for now)
if stmt.Key != nil { if stmt.Key != nil {
ti.currentScope.Define(&Symbol{ ti.currentScope.Define(&Symbol{
Name: stmt.Key.Value, Name: stmt.Key.Value,
Type: AnyType, Type: ti.anyType,
Declared: true, Declared: true,
}) })
stmt.Key.typeInfo = AnyType stmt.Key.SetType(ti.anyType)
} }
ti.currentScope.Define(&Symbol{ ti.currentScope.Define(&Symbol{
Name: stmt.Value.Value, Name: stmt.Value.Value,
Type: AnyType, Type: ti.anyType,
Declared: true, Declared: true,
}) })
stmt.Value.typeInfo = AnyType stmt.Value.SetType(ti.anyType)
for _, s := range stmt.Body { for _, s := range stmt.Body {
ti.inferStatement(s) ti.inferStatement(s)
@ -347,25 +403,29 @@ func (ti *TypeInferrer) inferForInStatement(stmt *ForInStatement) {
} }
// inferExpression infers the type of an expression // inferExpression infers the type of an expression
func (ti *TypeInferrer) inferExpression(expr Expression) TypeInfo { func (ti *TypeInferrer) inferExpression(expr Expression) *TypeInfo {
if expr == nil { if expr == nil {
return NilType return ti.nilType
} }
switch e := expr.(type) { switch e := expr.(type) {
case *Identifier: case *Identifier:
return ti.inferIdentifier(e) return ti.inferIdentifier(e)
case *NumberLiteral: case *NumberLiteral:
return NumberType e.SetType(ti.numberType)
return ti.numberType
case *StringLiteral: case *StringLiteral:
return StringType e.SetType(ti.stringType)
return ti.stringType
case *BooleanLiteral: case *BooleanLiteral:
return BoolType e.SetType(ti.boolType)
return ti.boolType
case *NilLiteral: case *NilLiteral:
return NilType e.SetType(ti.nilType)
return ti.nilType
case *TableLiteral: case *TableLiteral:
return ti.inferTableLiteral(e) return ti.inferTableLiteral(e)
case *StructConstructor: case *StructConstructorExpression:
return ti.inferStructConstructor(e) return ti.inferStructConstructor(e)
case *FunctionLiteral: case *FunctionLiteral:
return ti.inferFunctionLiteral(e) return ti.inferFunctionLiteral(e)
@ -379,40 +439,20 @@ func (ti *TypeInferrer) inferExpression(expr Expression) TypeInfo {
return ti.inferIndexExpression(e) return ti.inferIndexExpression(e)
case *DotExpression: case *DotExpression:
return ti.inferDotExpression(e) return ti.inferDotExpression(e)
case *Assignment: case *AssignExpression:
return ti.inferAssignmentExpression(e) return ti.inferAssignExpression(e)
default: default:
ti.addError("unknown expression type", expr) ti.addError("unknown expression type", expr)
return AnyType return ti.anyType
} }
} }
func (ti *TypeInferrer) inferIdentifier(ident *Identifier) TypeInfo { // inferStructConstructor handles struct constructor expressions
symbol := ti.currentScope.Lookup(ident.Value) func (ti *TypeInferrer) inferStructConstructor(expr *StructConstructorExpression) *TypeInfo {
if symbol == nil { structDef, exists := ti.structs[expr.StructName]
ti.addError(fmt.Sprintf("undefined variable '%s'", ident.Value), ident) if !exists {
return AnyType ti.addError(fmt.Sprintf("undefined struct '%s'", expr.StructName), expr)
} return ti.anyType
ident.typeInfo = symbol.Type
return symbol.Type
}
func (ti *TypeInferrer) inferTableLiteral(table *TableLiteral) TypeInfo {
// Infer types of all values
for _, pair := range table.Pairs {
if pair.Key != nil {
ti.inferExpression(pair.Key)
}
ti.inferExpression(pair.Value)
}
return TableType
}
func (ti *TypeInferrer) inferStructConstructor(expr *StructConstructor) TypeInfo {
structDef := ti.GetStructByID(expr.StructID)
if structDef == nil {
ti.addError("undefined struct in constructor", expr)
return AnyType
} }
// Validate field assignments // Validate field assignments
@ -425,23 +465,25 @@ func (ti *TypeInferrer) inferStructConstructor(expr *StructConstructor) TypeInfo
fieldName = str.Value fieldName = str.Value
} }
// Find field in struct definition // Check if field exists in struct
var fieldType TypeInfo fieldExists := false
found := false var fieldType *TypeInfo
for _, field := range structDef.Fields { for _, field := range structDef.Fields {
if field.Name == fieldName { if field.Name == fieldName {
fieldExists = true
fieldType = field.TypeHint fieldType = field.TypeHint
found = true
break break
} }
} }
if !found { if !fieldExists {
ti.addError(fmt.Sprintf("struct has no field '%s'", fieldName), expr) ti.addError(fmt.Sprintf("struct '%s' has no field '%s'", expr.StructName, fieldName), expr)
} else { } else {
// Check type compatibility
valueType := ti.inferExpression(pair.Value) valueType := ti.inferExpression(pair.Value)
if !ti.isTypeCompatible(valueType, fieldType) { if !ti.isTypeCompatible(valueType, fieldType) {
ti.addError("field type mismatch in struct constructor", expr) ti.addError(fmt.Sprintf("cannot assign %s to field '%s' of type %s",
valueType.Type, fieldName, fieldType.Type), expr)
} }
} }
} else { } else {
@ -450,18 +492,66 @@ func (ti *TypeInferrer) inferStructConstructor(expr *StructConstructor) TypeInfo
} }
} }
structType := ti.CreateStructType(expr.StructID) structType := ti.getStructType(expr.StructName)
expr.typeInfo = structType expr.SetType(structType)
return structType return structType
} }
func (ti *TypeInferrer) inferFunctionLiteral(fn *FunctionLiteral) TypeInfo { // inferAssignExpression handles assignment expressions
func (ti *TypeInferrer) inferAssignExpression(expr *AssignExpression) *TypeInfo {
valueType := ti.inferExpression(expr.Value)
if ident, ok := expr.Name.(*Identifier); ok {
if expr.IsDeclaration {
ti.currentScope.Define(&Symbol{
Name: ident.Value,
Type: valueType,
Declared: true,
})
}
ident.SetType(valueType)
} else {
ti.inferExpression(expr.Name)
}
expr.SetType(valueType)
return valueType
}
// inferIdentifier looks up identifier type in symbol table
func (ti *TypeInferrer) inferIdentifier(ident *Identifier) *TypeInfo {
symbol := ti.currentScope.Lookup(ident.Value)
if symbol == nil {
ti.addError(fmt.Sprintf("undefined variable '%s'", ident.Value), ident)
return ti.anyType
}
ident.SetType(symbol.Type)
return symbol.Type
}
// inferTableLiteral infers table type
func (ti *TypeInferrer) inferTableLiteral(table *TableLiteral) *TypeInfo {
// Infer types of all values
for _, pair := range table.Pairs {
if pair.Key != nil {
ti.inferExpression(pair.Key)
}
ti.inferExpression(pair.Value)
}
table.SetType(ti.tableType)
return ti.tableType
}
// inferFunctionLiteral infers function type
func (ti *TypeInferrer) inferFunctionLiteral(fn *FunctionLiteral) *TypeInfo {
ti.enterScope() ti.enterScope()
// Define parameters in function scope // Define parameters in function scope
for _, param := range fn.Parameters { for _, param := range fn.Parameters {
paramType := AnyType paramType := ti.anyType
if param.TypeHint.Type != TypeUnknown { if param.TypeHint != nil {
paramType = param.TypeHint paramType = param.TypeHint
} }
@ -478,88 +568,104 @@ func (ti *TypeInferrer) inferFunctionLiteral(fn *FunctionLiteral) TypeInfo {
} }
ti.exitScope() ti.exitScope()
return FunctionType
// For now, all functions have type "function"
funcType := &TypeInfo{Type: TypeFunction, Inferred: true}
fn.SetType(funcType)
return funcType
} }
func (ti *TypeInferrer) inferCallExpression(call *CallExpression) TypeInfo { // inferCallExpression infers function call return type
ti.inferExpression(call.Function) func (ti *TypeInferrer) inferCallExpression(call *CallExpression) *TypeInfo {
funcType := ti.inferExpression(call.Function)
if !ti.isFunctionType(funcType) {
ti.addError("cannot call non-function", call.Function)
return ti.anyType
}
// Infer argument types // Infer argument types
for _, arg := range call.Arguments { for _, arg := range call.Arguments {
ti.inferExpression(arg) ti.inferExpression(arg)
} }
call.typeInfo = AnyType // For now, assume function calls return any
return AnyType call.SetType(ti.anyType)
return ti.anyType
} }
func (ti *TypeInferrer) inferPrefixExpression(prefix *PrefixExpression) TypeInfo { // inferPrefixExpression infers prefix operation type
func (ti *TypeInferrer) inferPrefixExpression(prefix *PrefixExpression) *TypeInfo {
rightType := ti.inferExpression(prefix.Right) rightType := ti.inferExpression(prefix.Right)
var resultType TypeInfo var resultType *TypeInfo
switch prefix.Operator { switch prefix.Operator {
case "-": case "-":
if !ti.isNumericType(rightType) { if !ti.isNumericType(rightType) {
ti.addError("unary minus requires numeric operand", prefix) ti.addError("unary minus requires numeric operand", prefix)
} }
resultType = NumberType resultType = ti.numberType
case "not": case "not":
resultType = BoolType resultType = ti.boolType
default: default:
ti.addError(fmt.Sprintf("unknown prefix operator '%s'", prefix.Operator), prefix) ti.addError(fmt.Sprintf("unknown prefix operator '%s'", prefix.Operator), prefix)
resultType = AnyType resultType = ti.anyType
} }
prefix.typeInfo = resultType prefix.SetType(resultType)
return resultType return resultType
} }
func (ti *TypeInferrer) inferInfixExpression(infix *InfixExpression) TypeInfo { // inferInfixExpression infers binary operation type
func (ti *TypeInferrer) inferInfixExpression(infix *InfixExpression) *TypeInfo {
leftType := ti.inferExpression(infix.Left) leftType := ti.inferExpression(infix.Left)
rightType := ti.inferExpression(infix.Right) rightType := ti.inferExpression(infix.Right)
var resultType TypeInfo var resultType *TypeInfo
switch infix.Operator { switch infix.Operator {
case "+", "-", "*", "/": case "+", "-", "*", "/":
if !ti.isNumericType(leftType) || !ti.isNumericType(rightType) { if !ti.isNumericType(leftType) || !ti.isNumericType(rightType) {
ti.addError(fmt.Sprintf("arithmetic operator '%s' requires numeric operands", infix.Operator), infix) ti.addError(fmt.Sprintf("arithmetic operator '%s' requires numeric operands", infix.Operator), infix)
} }
resultType = NumberType resultType = ti.numberType
case "==", "!=": case "==", "!=":
// Equality works with any types // Equality works with any types
resultType = BoolType resultType = ti.boolType
case "<", ">", "<=", ">=": case "<", ">", "<=", ">=":
if !ti.isComparableTypes(leftType, rightType) { if !ti.isComparableTypes(leftType, rightType) {
ti.addError(fmt.Sprintf("comparison operator '%s' requires compatible operands", infix.Operator), infix) ti.addError(fmt.Sprintf("comparison operator '%s' requires compatible operands", infix.Operator), infix)
} }
resultType = BoolType resultType = ti.boolType
case "and", "or": case "and", "or":
resultType = BoolType ti.validateBooleanContext(leftType, infix.Left)
ti.validateBooleanContext(rightType, infix.Right)
resultType = ti.boolType
default: default:
ti.addError(fmt.Sprintf("unknown infix operator '%s'", infix.Operator), infix) ti.addError(fmt.Sprintf("unknown infix operator '%s'", infix.Operator), infix)
resultType = AnyType resultType = ti.anyType
} }
infix.typeInfo = resultType infix.SetType(resultType)
return resultType return resultType
} }
func (ti *TypeInferrer) inferIndexExpression(index *IndexExpression) TypeInfo { // inferIndexExpression infers table[index] type
func (ti *TypeInferrer) inferIndexExpression(index *IndexExpression) *TypeInfo {
leftType := ti.inferExpression(index.Left) leftType := ti.inferExpression(index.Left)
ti.inferExpression(index.Index) ti.inferExpression(index.Index)
// If indexing a struct, try to infer field type // If indexing a struct, try to infer field type
if ti.isStructType(leftType) { if ti.isStructType(leftType) {
if strLit, ok := index.Index.(*StringLiteral); ok { if strLit, ok := index.Index.(*StringLiteral); ok {
if structDef := ti.GetStructByID(leftType.StructID); structDef != nil { if structDef, exists := ti.structs[leftType.Type]; exists {
for _, field := range structDef.Fields { for _, field := range structDef.Fields {
if field.Name == strLit.Value { if field.Name == strLit.Value {
index.typeInfo = field.TypeHint index.SetType(field.TypeHint)
return field.TypeHint return field.TypeHint
} }
} }
@ -568,19 +674,20 @@ func (ti *TypeInferrer) inferIndexExpression(index *IndexExpression) TypeInfo {
} }
// For now, assume table/struct access returns any // For now, assume table/struct access returns any
index.typeInfo = AnyType index.SetType(ti.anyType)
return AnyType return ti.anyType
} }
func (ti *TypeInferrer) inferDotExpression(dot *DotExpression) TypeInfo { // inferDotExpression infers table.key type
func (ti *TypeInferrer) inferDotExpression(dot *DotExpression) *TypeInfo {
leftType := ti.inferExpression(dot.Left) leftType := ti.inferExpression(dot.Left)
// If accessing a struct field, try to infer field type // If accessing a struct field, try to infer field type
if ti.isStructType(leftType) { if ti.isStructType(leftType) {
if structDef := ti.GetStructByID(leftType.StructID); structDef != nil { if structDef, exists := ti.structs[leftType.Type]; exists {
for _, field := range structDef.Fields { for _, field := range structDef.Fields {
if field.Name == dot.Key { if field.Name == dot.Key {
dot.typeInfo = field.TypeHint dot.SetType(field.TypeHint)
return field.TypeHint return field.TypeHint
} }
} }
@ -588,107 +695,59 @@ func (ti *TypeInferrer) inferDotExpression(dot *DotExpression) TypeInfo {
} }
// For now, assume member access returns any // For now, assume member access returns any
dot.typeInfo = AnyType dot.SetType(ti.anyType)
return AnyType return ti.anyType
} }
func (ti *TypeInferrer) inferAssignmentExpression(expr *Assignment) TypeInfo { // Type checking helper methods
valueType := ti.inferExpression(expr.Value)
if ident, ok := expr.Target.(*Identifier); ok { func (ti *TypeInferrer) isTypeCompatible(valueType, targetType *TypeInfo) bool {
if expr.IsDeclaration {
varType := valueType
if expr.TypeHint.Type != TypeUnknown {
if !ti.isTypeCompatible(valueType, expr.TypeHint) {
ti.addError("type mismatch in assignment", expr)
}
varType = expr.TypeHint
}
ti.currentScope.Define(&Symbol{
Name: ident.Value,
Type: varType,
Declared: true,
})
ident.typeInfo = varType
} else {
symbol := ti.currentScope.Lookup(ident.Value)
if symbol != nil {
ident.typeInfo = symbol.Type
}
}
} else {
ti.inferExpression(expr.Target)
}
return valueType
}
// Helper methods
func (ti *TypeInferrer) enterScope() {
ti.currentScope = NewScope(ti.currentScope)
}
func (ti *TypeInferrer) exitScope() {
if ti.currentScope.parent != nil {
ti.currentScope = ti.currentScope.parent
}
}
func (ti *TypeInferrer) addError(message string, node Node) {
ti.errors = append(ti.errors, TypeError{
Message: message,
Node: node,
})
}
func (ti *TypeInferrer) isValidType(t TypeInfo) bool {
if t.Type == TypeStruct {
return ti.GetStructByID(t.StructID) != nil
}
return t.Type <= TypeStruct
}
func (ti *TypeInferrer) isTypeCompatible(valueType, targetType TypeInfo) bool {
if targetType.Type == TypeAny || valueType.Type == TypeAny { if targetType.Type == TypeAny || valueType.Type == TypeAny {
return true return true
} }
if valueType.Type == TypeStruct && targetType.Type == TypeStruct {
return valueType.StructID == targetType.StructID
}
return valueType.Type == targetType.Type return valueType.Type == targetType.Type
} }
func (ti *TypeInferrer) isNumericType(t TypeInfo) bool { func (ti *TypeInferrer) isNumericType(t *TypeInfo) bool {
return t.Type == TypeNumber return t.Type == TypeNumber
} }
func (ti *TypeInferrer) isBooleanType(t TypeInfo) bool { func (ti *TypeInferrer) isBooleanType(t *TypeInfo) bool {
return t.Type == TypeBool return t.Type == TypeBool
} }
func (ti *TypeInferrer) isTableType(t TypeInfo) bool { func (ti *TypeInferrer) isTableType(t *TypeInfo) bool {
return t.Type == TypeTable return t.Type == TypeTable
} }
func (ti *TypeInferrer) isFunctionType(t TypeInfo) bool { func (ti *TypeInferrer) isFunctionType(t *TypeInfo) bool {
return t.Type == TypeFunction return t.Type == TypeFunction
} }
func (ti *TypeInferrer) isStructType(t TypeInfo) bool { func (ti *TypeInferrer) isComparableTypes(left, right *TypeInfo) bool {
return t.Type == TypeStruct
}
func (ti *TypeInferrer) isComparableTypes(left, right TypeInfo) bool {
if left.Type == TypeAny || right.Type == TypeAny { if left.Type == TypeAny || right.Type == TypeAny {
return true return true
} }
return left.Type == right.Type && (left.Type == TypeNumber || left.Type == TypeString) return left.Type == right.Type && (left.Type == TypeNumber || left.Type == TypeString)
} }
// Error reporting func (ti *TypeInferrer) validateBooleanContext(t *TypeInfo, expr Expression) {
func (ti *TypeInferrer) Errors() []TypeError { return ti.errors } // In many languages, non-boolean values can be used in boolean context
func (ti *TypeInferrer) HasErrors() bool { return len(ti.errors) > 0 } // For strictness, we could require boolean type here
// For now, allow any type (truthy/falsy semantics)
}
// Errors returns all type checking errors
func (ti *TypeInferrer) Errors() []TypeError {
return ti.errors
}
// HasErrors returns true if there are any type errors
func (ti *TypeInferrer) HasErrors() bool {
return len(ti.errors) > 0
}
// ErrorStrings returns error messages as strings
func (ti *TypeInferrer) ErrorStrings() []string { func (ti *TypeInferrer) ErrorStrings() []string {
result := make([]string, len(ti.errors)) result := make([]string, len(ti.errors))
for i, err := range ti.errors { for i, err := range ti.errors {
@ -697,26 +756,21 @@ func (ti *TypeInferrer) ErrorStrings() []string {
return result return result
} }
// Type string conversion // ValidTypeName checks if a string is a valid type name
func TypeToString(t TypeInfo) string { func ValidTypeName(name string) bool {
switch t.Type { validTypes := []string{TypeNumber, TypeString, TypeBool, TypeNil, TypeTable, TypeFunction, TypeAny}
case TypeNumber: for _, validType := range validTypes {
return "number" if name == validType {
case TypeString: return true
return "string"
case TypeBool:
return "bool"
case TypeNil:
return "nil"
case TypeTable:
return "table"
case TypeFunction:
return "function"
case TypeAny:
return "any"
case TypeStruct:
return fmt.Sprintf("struct<%d>", t.StructID)
default:
return "unknown"
} }
} }
return false
}
// ParseTypeName converts a string to a TypeInfo (for parsing type hints)
func ParseTypeName(name string) *TypeInfo {
if ValidTypeName(name) {
return &TypeInfo{Type: name, Inferred: false}
}
return nil
}