type hint/inference

This commit is contained in:
Sky Johnson 2025-06-11 10:20:38 -05:00
parent c691c90c69
commit e98e0643e2
7 changed files with 1316 additions and 131 deletions

View File

@ -2,6 +2,12 @@ package parser
import "fmt"
// TypeInfo represents type information for expressions
type TypeInfo struct {
Type string // "number", "string", "bool", "table", "function", "nil", "any"
Inferred bool // true if type was inferred, false if explicitly declared
}
// Node represents any node in the AST
type Node interface {
String() string
@ -17,6 +23,8 @@ type Statement interface {
type Expression interface {
Node
expressionNode()
GetType() *TypeInfo
SetType(*TypeInfo)
}
// Program represents the root of the AST
@ -33,9 +41,10 @@ func (p *Program) String() string {
return result
}
// AssignStatement represents variable assignment
// AssignStatement represents variable assignment with optional type hint
type AssignStatement struct {
Name Expression // Changed from *Identifier to Expression for member access
TypeHint *TypeInfo // optional type hint
Value Expression
IsDeclaration bool // true if this is the first assignment in current scope
}
@ -46,7 +55,15 @@ func (as *AssignStatement) String() string {
if as.IsDeclaration {
prefix = "local "
}
return fmt.Sprintf("%s%s = %s", prefix, as.Name.String(), as.Value.String())
var nameStr string
if as.TypeHint != nil {
nameStr = fmt.Sprintf("%s: %s", as.Name.String(), as.TypeHint.Type)
} else {
nameStr = as.Name.String()
}
return fmt.Sprintf("%s%s = %s", prefix, nameStr, as.Value.String())
}
// EchoStatement represents echo output statements
@ -216,33 +233,56 @@ func (fis *ForInStatement) String() string {
return result
}
// Identifier represents identifiers
type Identifier struct {
Value string
// FunctionParameter represents a function parameter with optional type hint
type FunctionParameter struct {
Name string
TypeHint *TypeInfo
}
func (i *Identifier) expressionNode() {}
func (i *Identifier) String() string { return i.Value }
func (fp *FunctionParameter) String() string {
if fp.TypeHint != nil {
return fmt.Sprintf("%s: %s", fp.Name, fp.TypeHint.Type)
}
return fp.Name
}
// Identifier represents identifiers
type Identifier struct {
Value string
typeInfo *TypeInfo
}
func (i *Identifier) expressionNode() {}
func (i *Identifier) String() string { return i.Value }
func (i *Identifier) GetType() *TypeInfo { return i.typeInfo }
func (i *Identifier) SetType(t *TypeInfo) { i.typeInfo = t }
// NumberLiteral represents numeric literals
type NumberLiteral struct {
Value float64
Value float64
typeInfo *TypeInfo
}
func (nl *NumberLiteral) expressionNode() {}
func (nl *NumberLiteral) String() string { return fmt.Sprintf("%.2f", nl.Value) }
func (nl *NumberLiteral) expressionNode() {}
func (nl *NumberLiteral) String() string { return fmt.Sprintf("%.2f", nl.Value) }
func (nl *NumberLiteral) GetType() *TypeInfo { return nl.typeInfo }
func (nl *NumberLiteral) SetType(t *TypeInfo) { nl.typeInfo = t }
// StringLiteral represents string literals
type StringLiteral struct {
Value string
Value string
typeInfo *TypeInfo
}
func (sl *StringLiteral) expressionNode() {}
func (sl *StringLiteral) String() string { return fmt.Sprintf(`"%s"`, sl.Value) }
func (sl *StringLiteral) expressionNode() {}
func (sl *StringLiteral) String() string { return fmt.Sprintf(`"%s"`, sl.Value) }
func (sl *StringLiteral) GetType() *TypeInfo { return sl.typeInfo }
func (sl *StringLiteral) SetType(t *TypeInfo) { sl.typeInfo = t }
// BooleanLiteral represents boolean literals
type BooleanLiteral struct {
Value bool
Value bool
typeInfo *TypeInfo
}
func (bl *BooleanLiteral) expressionNode() {}
@ -252,18 +292,26 @@ func (bl *BooleanLiteral) String() string {
}
return "false"
}
func (bl *BooleanLiteral) GetType() *TypeInfo { return bl.typeInfo }
func (bl *BooleanLiteral) SetType(t *TypeInfo) { bl.typeInfo = t }
// NilLiteral represents nil literal
type NilLiteral struct{}
type NilLiteral struct {
typeInfo *TypeInfo
}
func (nl *NilLiteral) expressionNode() {}
func (nl *NilLiteral) String() string { return "nil" }
func (nl *NilLiteral) expressionNode() {}
func (nl *NilLiteral) String() string { return "nil" }
func (nl *NilLiteral) GetType() *TypeInfo { return nl.typeInfo }
func (nl *NilLiteral) SetType(t *TypeInfo) { nl.typeInfo = t }
// FunctionLiteral represents function literals: fn(a, b, ...) ... end
// FunctionLiteral represents function literals with typed parameters
type FunctionLiteral struct {
Parameters []string
Parameters []FunctionParameter
Variadic bool
ReturnType *TypeInfo // optional return type hint
Body []Statement
typeInfo *TypeInfo
}
func (fl *FunctionLiteral) expressionNode() {}
@ -273,7 +321,7 @@ func (fl *FunctionLiteral) String() string {
if i > 0 {
params += ", "
}
params += param
params += param.String()
}
if fl.Variadic {
if len(fl.Parameters) > 0 {
@ -282,18 +330,26 @@ func (fl *FunctionLiteral) String() string {
params += "..."
}
result := fmt.Sprintf("fn(%s)\n", params)
result := fmt.Sprintf("fn(%s)", params)
if fl.ReturnType != nil {
result += ": " + fl.ReturnType.Type
}
result += "\n"
for _, stmt := range fl.Body {
result += "\t" + stmt.String() + "\n"
}
result += "end"
return result
}
func (fl *FunctionLiteral) GetType() *TypeInfo { return fl.typeInfo }
func (fl *FunctionLiteral) SetType(t *TypeInfo) { fl.typeInfo = t }
// CallExpression represents function calls: func(arg1, arg2, ...)
type CallExpression struct {
Function Expression
Arguments []Expression
typeInfo *TypeInfo
}
func (ce *CallExpression) expressionNode() {}
@ -304,11 +360,14 @@ func (ce *CallExpression) String() string {
}
return fmt.Sprintf("%s(%s)", ce.Function.String(), joinStrings(args, ", "))
}
func (ce *CallExpression) GetType() *TypeInfo { return ce.typeInfo }
func (ce *CallExpression) SetType(t *TypeInfo) { ce.typeInfo = t }
// PrefixExpression represents prefix operations like -x, not x
type PrefixExpression struct {
Operator string
Right Expression
typeInfo *TypeInfo
}
func (pe *PrefixExpression) expressionNode() {}
@ -319,40 +378,51 @@ func (pe *PrefixExpression) String() string {
}
return fmt.Sprintf("(%s%s)", pe.Operator, pe.Right.String())
}
func (pe *PrefixExpression) GetType() *TypeInfo { return pe.typeInfo }
func (pe *PrefixExpression) SetType(t *TypeInfo) { pe.typeInfo = t }
// InfixExpression represents binary operations
type InfixExpression struct {
Left Expression
Operator string
Right Expression
typeInfo *TypeInfo
}
func (ie *InfixExpression) expressionNode() {}
func (ie *InfixExpression) String() string {
return fmt.Sprintf("(%s %s %s)", ie.Left.String(), ie.Operator, ie.Right.String())
}
func (ie *InfixExpression) GetType() *TypeInfo { return ie.typeInfo }
func (ie *InfixExpression) SetType(t *TypeInfo) { ie.typeInfo = t }
// IndexExpression represents table[key] access
type IndexExpression struct {
Left Expression
Index Expression
Left Expression
Index Expression
typeInfo *TypeInfo
}
func (ie *IndexExpression) expressionNode() {}
func (ie *IndexExpression) String() string {
return fmt.Sprintf("%s[%s]", ie.Left.String(), ie.Index.String())
}
func (ie *IndexExpression) GetType() *TypeInfo { return ie.typeInfo }
func (ie *IndexExpression) SetType(t *TypeInfo) { ie.typeInfo = t }
// DotExpression represents table.key access
type DotExpression struct {
Left Expression
Key string
Left Expression
Key string
typeInfo *TypeInfo
}
func (de *DotExpression) expressionNode() {}
func (de *DotExpression) String() string {
return fmt.Sprintf("%s.%s", de.Left.String(), de.Key)
}
func (de *DotExpression) GetType() *TypeInfo { return de.typeInfo }
func (de *DotExpression) SetType(t *TypeInfo) { de.typeInfo = t }
// TablePair represents a key-value pair in a table
type TablePair struct {
@ -369,7 +439,8 @@ func (tp *TablePair) String() string {
// TableLiteral represents table literals {}
type TableLiteral struct {
Pairs []TablePair
Pairs []TablePair
typeInfo *TypeInfo
}
func (tl *TableLiteral) expressionNode() {}
@ -380,6 +451,8 @@ func (tl *TableLiteral) String() string {
}
return fmt.Sprintf("{%s}", joinStrings(pairs, ", "))
}
func (tl *TableLiteral) GetType() *TypeInfo { return tl.typeInfo }
func (tl *TableLiteral) SetType(t *TypeInfo) { tl.typeInfo = t }
// IsArray returns true if this table contains only array-style elements
func (tl *TableLiteral) IsArray() bool {

View File

@ -257,6 +257,8 @@ func (l *Lexer) NextToken() Token {
tok = Token{Type: STAR, Literal: string(l.ch), Line: l.line, Column: l.column}
case '/':
tok = Token{Type: SLASH, Literal: string(l.ch), Line: l.line, Column: l.column}
case ':':
tok = Token{Type: COLON, Literal: string(l.ch), Line: l.line, Column: l.column}
case '.':
// Check for ellipsis (...)
if l.peekChar() == '.' && l.peekCharAt(2) == '.' {

View File

@ -89,14 +89,13 @@ func (p *Parser) enterScope(scopeType string) {
}
func (p *Parser) exitScope() {
if len(p.scopes) > 1 { // never remove global scope
if len(p.scopes) > 1 {
p.scopes = p.scopes[:len(p.scopes)-1]
p.scopeTypes = p.scopeTypes[:len(p.scopeTypes)-1]
}
}
func (p *Parser) enterFunctionScope() {
// Functions create new variable scopes
p.enterScope("function")
}
@ -105,35 +104,29 @@ func (p *Parser) exitFunctionScope() {
}
func (p *Parser) enterLoopScope() {
// Create temporary scope for loop variables only
p.enterScope("loop")
}
func (p *Parser) exitLoopScope() {
// Remove temporary loop scope
p.exitScope()
}
func (p *Parser) enterBlockScope() {
// Blocks don't create new variable scopes, just control flow scopes
// We don't need to track these for variable declarations
// Blocks don't create new variable scopes
}
func (p *Parser) exitBlockScope() {
// No-op since blocks don't create variable scopes
// No-op
}
func (p *Parser) currentVariableScope() map[string]bool {
// If we're in a loop scope, declare variables in the parent scope
if len(p.scopeTypes) > 1 && p.scopeTypes[len(p.scopeTypes)-1] == "loop" {
return p.scopes[len(p.scopes)-2]
}
// Otherwise use the current scope
return p.scopes[len(p.scopes)-1]
}
func (p *Parser) isVariableDeclared(name string) bool {
// Check all scopes from current up to global
for i := len(p.scopes) - 1; i >= 0; i-- {
if p.scopes[i][name] {
return true
@ -147,10 +140,31 @@ func (p *Parser) declareVariable(name string) {
}
func (p *Parser) declareLoopVariable(name string) {
// Loop variables go in the current loop scope
p.scopes[len(p.scopes)-1][name] = true
}
// parseTypeHint parses optional type hint after colon
func (p *Parser) parseTypeHint() *TypeInfo {
if !p.peekTokenIs(COLON) {
return nil
}
p.nextToken() // consume ':'
if !p.expectPeekIdent() {
p.addError("expected type name after ':'")
return nil
}
typeName := p.curToken.Literal
if !ValidTypeName(typeName) {
p.addError(fmt.Sprintf("invalid type name '%s'", typeName))
return nil
}
return &TypeInfo{Type: typeName, Inferred: false}
}
// registerPrefix registers a prefix parse function
func (p *Parser) registerPrefix(tokenType TokenType, fn func() Expression) {
p.prefixParseFns[tokenType] = fn
@ -187,7 +201,6 @@ func (p *Parser) ParseProgram() *Program {
func (p *Parser) parseStatement() Statement {
switch p.curToken.Type {
case IDENT:
// Try to parse as assignment (handles both simple and member access)
return p.parseAssignStatement()
case IF:
return p.parseIfStatement()
@ -217,17 +230,22 @@ func (p *Parser) parseStatement() Statement {
}
}
// parseAssignStatement parses variable assignment
// parseAssignStatement parses variable assignment with optional type hint
func (p *Parser) parseAssignStatement() *AssignStatement {
stmt := &AssignStatement{}
// Parse left-hand side expression (can be identifier or member access)
// Parse left-hand side expression
stmt.Name = p.ParseExpression(LOWEST)
if stmt.Name == nil {
p.addError("expected expression for assignment left-hand side")
return nil
}
// Check for type hint on simple identifiers
if _, ok := stmt.Name.(*Identifier); ok {
stmt.TypeHint = p.parseTypeHint()
}
// Check if next token is assignment operator
if !p.peekTokenIs(ASSIGN) {
p.addError("unexpected identifier, expected assignment or declaration")
@ -237,13 +255,11 @@ func (p *Parser) parseAssignStatement() *AssignStatement {
// Validate assignment target and check if it's a declaration
switch name := stmt.Name.(type) {
case *Identifier:
// Simple variable assignment - check if it's a declaration
stmt.IsDeclaration = !p.isVariableDeclared(name.Value)
if stmt.IsDeclaration {
p.declareVariable(name.Value)
}
case *DotExpression, *IndexExpression:
// Member access - never a declaration
stmt.IsDeclaration = false
default:
p.addError("invalid assignment target")
@ -289,10 +305,8 @@ func (p *Parser) parseBreakStatement() *BreakStatement {
func (p *Parser) parseExitStatement() *ExitStatement {
stmt := &ExitStatement{}
// Check if there's an optional expression after 'exit'
// Only parse expression if next token can start an expression
if p.canStartExpression(p.peekToken.Type) {
p.nextToken() // move past 'exit'
p.nextToken()
stmt.Value = p.ParseExpression(LOWEST)
if stmt.Value == nil {
p.addError("expected expression after 'exit'")
@ -307,9 +321,8 @@ func (p *Parser) parseExitStatement() *ExitStatement {
func (p *Parser) parseReturnStatement() *ReturnStatement {
stmt := &ReturnStatement{}
// Check if there's an optional expression after 'return'
if p.canStartExpression(p.peekToken.Type) {
p.nextToken() // move past 'return'
p.nextToken()
stmt.Value = p.ParseExpression(LOWEST)
if stmt.Value == nil {
p.addError("expected expression after 'return'")
@ -330,11 +343,11 @@ func (p *Parser) canStartExpression(tokenType TokenType) bool {
}
}
// parseWhileStatement parses while loops: while condition do ... end
// parseWhileStatement parses while loops
func (p *Parser) parseWhileStatement() *WhileStatement {
stmt := &WhileStatement{}
p.nextToken() // move past 'while'
p.nextToken()
stmt.Condition = p.ParseExpression(LOWEST)
if stmt.Condition == nil {
@ -347,9 +360,8 @@ func (p *Parser) parseWhileStatement() *WhileStatement {
return nil
}
p.nextToken() // move past 'do'
p.nextToken()
// Parse loop body (no new variable scope)
p.enterBlockScope()
stmt.Body = p.parseBlockStatements(END)
p.exitBlockScope()
@ -362,9 +374,9 @@ func (p *Parser) parseWhileStatement() *WhileStatement {
return stmt
}
// parseForStatement parses for loops (both numeric and for-in)
// parseForStatement parses for loops
func (p *Parser) parseForStatement() Statement {
p.nextToken() // move past 'for'
p.nextToken()
if !p.curTokenIs(IDENT) {
p.addError("expected identifier after 'for'")
@ -373,12 +385,9 @@ func (p *Parser) parseForStatement() Statement {
firstVar := &Identifier{Value: p.curToken.Literal}
// Look ahead to determine which type of for loop
if p.peekTokenIs(ASSIGN) {
// Numeric for loop: for i = start, end, step do
return p.parseNumericForStatement(firstVar)
} else if p.peekTokenIs(COMMA) || p.peekTokenIs(IN) {
// For-in loop: for k, v in expr do or for v in expr do
return p.parseForInStatement(firstVar)
} else {
p.addError("expected '=', ',' or 'in' after for loop variable")
@ -386,7 +395,7 @@ func (p *Parser) parseForStatement() Statement {
}
}
// parseNumericForStatement parses numeric for loops: for i = start, end, step do
// parseNumericForStatement parses numeric for loops
func (p *Parser) parseNumericForStatement(variable *Identifier) *ForStatement {
stmt := &ForStatement{Variable: variable}
@ -394,9 +403,8 @@ func (p *Parser) parseNumericForStatement(variable *Identifier) *ForStatement {
return nil
}
p.nextToken() // move past '='
p.nextToken()
// Parse start expression
stmt.Start = p.ParseExpression(LOWEST)
if stmt.Start == nil {
p.addError("expected start expression in for loop")
@ -408,19 +416,17 @@ func (p *Parser) parseNumericForStatement(variable *Identifier) *ForStatement {
return nil
}
p.nextToken() // move past ','
p.nextToken()
// Parse end expression
stmt.End = p.ParseExpression(LOWEST)
if stmt.End == nil {
p.addError("expected end expression in for loop")
return nil
}
// Optional step expression
if p.peekTokenIs(COMMA) {
p.nextToken() // move to ','
p.nextToken() // move past ','
p.nextToken()
p.nextToken()
stmt.Step = p.ParseExpression(LOWEST)
if stmt.Step == nil {
@ -434,13 +440,12 @@ func (p *Parser) parseNumericForStatement(variable *Identifier) *ForStatement {
return nil
}
p.nextToken() // move past 'do'
p.nextToken()
// Create temporary scope for loop variable, assignments in body go to parent scope
p.enterLoopScope()
p.declareLoopVariable(variable.Value) // loop variable in temporary scope
p.declareLoopVariable(variable.Value)
stmt.Body = p.parseBlockStatements(END)
p.exitLoopScope() // discard temporary scope with loop variable
p.exitLoopScope()
if !p.curTokenIs(END) {
p.addError("expected 'end' to close for loop")
@ -450,15 +455,14 @@ func (p *Parser) parseNumericForStatement(variable *Identifier) *ForStatement {
return stmt
}
// parseForInStatement parses for-in loops: for k, v in expr do or for v in expr do
// parseForInStatement parses for-in loops
func (p *Parser) parseForInStatement(firstVar *Identifier) *ForInStatement {
stmt := &ForInStatement{}
if p.peekTokenIs(COMMA) {
// Two variables: for k, v in expr do
stmt.Key = firstVar
p.nextToken() // move to ','
p.nextToken() // move past ','
p.nextToken()
p.nextToken()
if !p.curTokenIs(IDENT) {
p.addError("expected identifier after ',' in for loop")
@ -467,7 +471,6 @@ func (p *Parser) parseForInStatement(firstVar *Identifier) *ForInStatement {
stmt.Value = &Identifier{Value: p.curToken.Literal}
} else {
// Single variable: for v in expr do
stmt.Value = firstVar
}
@ -476,9 +479,8 @@ func (p *Parser) parseForInStatement(firstVar *Identifier) *ForInStatement {
return nil
}
p.nextToken() // move past 'in'
p.nextToken()
// Parse iterable expression
stmt.Iterable = p.ParseExpression(LOWEST)
if stmt.Iterable == nil {
p.addError("expected expression after 'in' in for loop")
@ -490,16 +492,15 @@ func (p *Parser) parseForInStatement(firstVar *Identifier) *ForInStatement {
return nil
}
p.nextToken() // move past 'do'
p.nextToken()
// Create temporary scope for loop variables, assignments in body go to parent scope
p.enterLoopScope()
if stmt.Key != nil {
p.declareLoopVariable(stmt.Key.Value) // loop variable in temporary scope
p.declareLoopVariable(stmt.Key.Value)
}
p.declareLoopVariable(stmt.Value.Value) // loop variable in temporary scope
p.declareLoopVariable(stmt.Value.Value)
stmt.Body = p.parseBlockStatements(END)
p.exitLoopScope() // discard temporary scope with loop variables
p.exitLoopScope()
if !p.curTokenIs(END) {
p.addError("expected 'end' to close for loop")
@ -509,11 +510,11 @@ func (p *Parser) parseForInStatement(firstVar *Identifier) *ForInStatement {
return stmt
}
// parseIfStatement parses if/elseif/else/end statements
// parseIfStatement parses if statements
func (p *Parser) parseIfStatement() *IfStatement {
stmt := &IfStatement{}
p.nextToken() // move past 'if'
p.nextToken()
stmt.Condition = p.ParseExpression(LOWEST)
if stmt.Condition == nil {
@ -521,29 +522,25 @@ func (p *Parser) parseIfStatement() *IfStatement {
return nil
}
// Optional 'then' keyword
if p.peekTokenIs(THEN) {
p.nextToken()
}
p.nextToken() // move past condition (and optional 'then')
p.nextToken()
// Check if we immediately hit END (empty body should be an error)
if p.curTokenIs(END) {
p.addError("expected 'end' to close if statement")
return nil
}
// Parse if body (no new variable scope)
p.enterBlockScope()
stmt.Body = p.parseBlockStatements(ELSEIF, ELSE, END)
p.exitBlockScope()
// Parse elseif clauses
for p.curTokenIs(ELSEIF) {
elseif := ElseIfClause{}
p.nextToken() // move past 'elseif'
p.nextToken()
elseif.Condition = p.ParseExpression(LOWEST)
if elseif.Condition == nil {
@ -551,14 +548,12 @@ func (p *Parser) parseIfStatement() *IfStatement {
return nil
}
// Optional 'then' keyword
if p.peekTokenIs(THEN) {
p.nextToken()
}
p.nextToken() // move past condition (and optional 'then')
p.nextToken()
// Parse elseif body (no new variable scope)
p.enterBlockScope()
elseif.Body = p.parseBlockStatements(ELSEIF, ELSE, END)
p.exitBlockScope()
@ -566,11 +561,9 @@ func (p *Parser) parseIfStatement() *IfStatement {
stmt.ElseIfs = append(stmt.ElseIfs, elseif)
}
// Parse else clause
if p.curTokenIs(ELSE) {
p.nextToken() // move past 'else'
p.nextToken()
// Parse else body (no new variable scope)
p.enterBlockScope()
stmt.Else = p.parseBlockStatements(END)
p.exitBlockScope()
@ -584,7 +577,7 @@ func (p *Parser) parseIfStatement() *IfStatement {
return stmt
}
// parseBlockStatements parses statements until one of the terminator tokens
// parseBlockStatements parses statements until terminators
func (p *Parser) parseBlockStatements(terminators ...TokenType) []Statement {
statements := []Statement{}
@ -599,7 +592,7 @@ func (p *Parser) parseBlockStatements(terminators ...TokenType) []Statement {
return statements
}
// isTerminator checks if current token is one of the terminators
// isTerminator checks if current token is a terminator
func (p *Parser) isTerminator(terminators ...TokenType) bool {
for _, terminator := range terminators {
if p.curTokenIs(terminator) {
@ -650,9 +643,7 @@ func (p *Parser) parseNumberLiteral() Expression {
var value float64
var err error
// Check for hexadecimal (0x/0X prefix)
if strings.HasPrefix(literal, "0x") || strings.HasPrefix(literal, "0X") {
// Validate hex format
if len(literal) <= 2 {
p.addError(fmt.Sprintf("could not parse '%s' as hexadecimal number", literal))
return nil
@ -664,7 +655,6 @@ func (p *Parser) parseNumberLiteral() Expression {
return nil
}
}
// Parse as hex and convert to float64
intVal, parseErr := strconv.ParseInt(literal, 0, 64)
if parseErr != nil {
p.addError(fmt.Sprintf("could not parse '%s' as hexadecimal number", literal))
@ -672,7 +662,6 @@ func (p *Parser) parseNumberLiteral() Expression {
}
value = float64(intVal)
} else if strings.HasPrefix(literal, "0b") || strings.HasPrefix(literal, "0B") {
// Validate binary format
if len(literal) <= 2 {
p.addError(fmt.Sprintf("could not parse '%s' as binary number", literal))
return nil
@ -684,8 +673,7 @@ func (p *Parser) parseNumberLiteral() Expression {
return nil
}
}
// Parse binary manually since Go doesn't support 0b in ParseInt with base 0
binaryStr := literal[2:] // remove "0b" prefix
binaryStr := literal[2:]
intVal, parseErr := strconv.ParseInt(binaryStr, 2, 64)
if parseErr != nil {
p.addError(fmt.Sprintf("could not parse '%s' as binary number", literal))
@ -693,7 +681,6 @@ func (p *Parser) parseNumberLiteral() Expression {
}
value = float64(intVal)
} else {
// Parse as regular decimal (handles scientific notation automatically)
value, err = strconv.ParseFloat(literal, 64)
if err != nil {
p.addError(fmt.Sprintf("could not parse '%s' as number", literal))
@ -763,12 +750,14 @@ func (p *Parser) parseFunctionLiteral() Expression {
return nil
}
p.nextToken() // move past ')'
// Check for return type hint
fn.ReturnType = p.parseTypeHint()
p.nextToken()
// Enter new function scope and declare parameters
p.enterFunctionScope()
for _, param := range fn.Parameters {
p.declareVariable(param)
p.declareVariable(param.Name)
}
fn.Body = p.parseBlockStatements(END)
p.exitFunctionScope()
@ -781,8 +770,8 @@ func (p *Parser) parseFunctionLiteral() Expression {
return fn
}
func (p *Parser) parseFunctionParameters() ([]string, bool) {
var params []string
func (p *Parser) parseFunctionParameters() ([]FunctionParameter, bool) {
var params []FunctionParameter
var variadic bool
if p.peekTokenIs(RPAREN) {
@ -802,16 +791,20 @@ func (p *Parser) parseFunctionParameters() ([]string, bool) {
return nil, false
}
params = append(params, p.curToken.Literal)
param := FunctionParameter{Name: p.curToken.Literal}
// Check for type hint
param.TypeHint = p.parseTypeHint()
params = append(params, param)
if !p.peekTokenIs(COMMA) {
break
}
p.nextToken() // move to ','
p.nextToken() // move past ','
p.nextToken()
p.nextToken()
// Check for ellipsis after comma
if p.curTokenIs(ELLIPSIS) {
variadic = true
break
@ -833,7 +826,6 @@ func (p *Parser) parseTableLiteral() Expression {
p.nextToken()
for {
// Check for EOF
if p.curTokenIs(EOF) {
p.addError("unexpected end of input, expected }")
return nil
@ -841,17 +833,15 @@ func (p *Parser) parseTableLiteral() Expression {
pair := TablePair{}
// Check if this is a key=value pair (identifier or string key)
if (p.curTokenIs(IDENT) || p.curTokenIs(STRING)) && p.peekTokenIs(ASSIGN) {
if p.curTokenIs(IDENT) {
pair.Key = &Identifier{Value: p.curToken.Literal}
} else {
pair.Key = &StringLiteral{Value: p.curToken.Literal}
}
p.nextToken() // move to =
p.nextToken() // move past =
p.nextToken()
p.nextToken()
// Check for EOF after =
if p.curTokenIs(EOF) {
p.addError("expected expression after assignment operator")
return nil
@ -859,7 +849,6 @@ func (p *Parser) parseTableLiteral() Expression {
pair.Value = p.ParseExpression(LOWEST)
} else {
// Array-style element
pair.Value = p.ParseExpression(LOWEST)
}
@ -873,15 +862,13 @@ func (p *Parser) parseTableLiteral() Expression {
break
}
p.nextToken() // consume comma
p.nextToken() // move to next element
p.nextToken()
p.nextToken()
// Allow trailing comma
if p.curTokenIs(RBRACE) {
break
}
// Check for EOF after comma
if p.curTokenIs(EOF) {
p.addError("expected next token to be }")
return nil
@ -956,7 +943,7 @@ func (p *Parser) parseExpressionList(end TokenType) []Expression {
}
func (p *Parser) parseIndexExpression(left Expression) Expression {
p.nextToken() // move past '['
p.nextToken()
index := p.ParseExpression(LOWEST)
if index == nil {
@ -993,7 +980,6 @@ func (p *Parser) expectPeek(t TokenType) bool {
return false
}
// expectPeekIdent accepts IDENT or keyword tokens as identifiers
func (p *Parser) expectPeekIdent() bool {
if p.peekTokenIs(IDENT) || p.isKeyword(p.peekToken.Type) {
p.nextToken()
@ -1003,7 +989,6 @@ func (p *Parser) expectPeekIdent() bool {
return false
}
// isKeyword checks if a token type is a keyword that can be used as identifier
func (p *Parser) isKeyword(t TokenType) bool {
switch t {
case TRUE, FALSE, NIL, AND, OR, NOT, IF, THEN, ELSEIF, ELSE, END, ECHO, FOR, WHILE, IN, DO, BREAK, EXIT, FN, RETURN:
@ -1013,7 +998,7 @@ func (p *Parser) isKeyword(t TokenType) bool {
}
}
// Error handling methods
// Error handling
func (p *Parser) addError(message string) {
p.errors = append(p.errors, ParseError{
Message: message,
@ -1075,12 +1060,10 @@ func (p *Parser) Errors() []ParseError {
return p.errors
}
// HasErrors returns true if there are any parsing errors
func (p *Parser) HasErrors() bool {
return len(p.errors) > 0
}
// ErrorStrings returns error messages as strings for backward compatibility
func (p *Parser) ErrorStrings() []string {
result := make([]string, len(p.errors))
for i, err := range p.errors {
@ -1089,7 +1072,7 @@ func (p *Parser) ErrorStrings() []string {
return result
}
// tokenTypeString returns a human-readable string for token types
// tokenTypeString returns human-readable string for token types
func tokenTypeString(t TokenType) string {
switch t {
case IDENT:
@ -1114,6 +1097,8 @@ func tokenTypeString(t TokenType) string {
return "/"
case DOT:
return "."
case COLON:
return ":"
case EQ:
return "=="
case NOT_EQ:

View File

@ -81,8 +81,8 @@ func TestFunctionParameters(t *testing.T) {
}
for i, expected := range tt.params {
if fn.Parameters[i] != expected {
t.Errorf("parameter %d: expected %s, got %s", i, expected, fn.Parameters[i])
if fn.Parameters[i].Name != expected {
t.Errorf("parameter %d: expected %s, got %s", i, expected, fn.Parameters[i].Name)
}
}

533
parser/tests/types_test.go Normal file
View File

@ -0,0 +1,533 @@
package parser_test
import (
"testing"
"git.sharkk.net/Sharkk/Mako/parser"
)
func TestVariableTypeHints(t *testing.T) {
tests := []struct {
input string
variable string
typeHint string
hasHint bool
desc string
}{
{"x = 42", "x", "", false, "no type hint"},
{"x: number = 42", "x", "number", true, "number type hint"},
{"name: string = \"hello\"", "name", "string", true, "string type hint"},
{"flag: bool = true", "flag", "bool", true, "bool type hint"},
{"data: table = {}", "data", "table", true, "table type hint"},
{"fn_var: function = fn() end", "fn_var", "function", true, "function type hint"},
{"value: any = nil", "value", "any", true, "any type hint"},
{"ptr: nil = nil", "ptr", "nil", true, "nil type hint"},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
l := parser.NewLexer(tt.input)
p := parser.NewParser(l)
program := p.ParseProgram()
checkParserErrors(t, p)
if len(program.Statements) != 1 {
t.Fatalf("expected 1 statement, got %d", len(program.Statements))
}
stmt, ok := program.Statements[0].(*parser.AssignStatement)
if !ok {
t.Fatalf("expected AssignStatement, got %T", program.Statements[0])
}
// Check variable name
ident, ok := stmt.Name.(*parser.Identifier)
if !ok {
t.Fatalf("expected Identifier for Name, got %T", stmt.Name)
}
if ident.Value != tt.variable {
t.Errorf("expected variable %s, got %s", tt.variable, ident.Value)
}
// Check type hint
if tt.hasHint {
if stmt.TypeHint == nil {
t.Error("expected type hint but got nil")
} else {
if stmt.TypeHint.Type != tt.typeHint {
t.Errorf("expected type hint %s, got %s", tt.typeHint, stmt.TypeHint.Type)
}
if stmt.TypeHint.Inferred {
t.Error("expected type hint to not be inferred")
}
}
} else {
if stmt.TypeHint != nil {
t.Errorf("expected no type hint but got %s", stmt.TypeHint.Type)
}
}
})
}
}
func TestFunctionParameterTypeHints(t *testing.T) {
tests := []struct {
input string
params []struct{ name, typeHint string }
returnType string
hasReturn bool
desc string
}{
{
"fn(a, b) end",
[]struct{ name, typeHint string }{
{"a", ""},
{"b", ""},
},
"", false,
"no type hints",
},
{
"fn(a: number, b: string) end",
[]struct{ name, typeHint string }{
{"a", "number"},
{"b", "string"},
},
"", false,
"parameter type hints only",
},
{
"fn(x: number): string end",
[]struct{ name, typeHint string }{
{"x", "number"},
},
"string", true,
"parameter and return type hints",
},
{
"fn(): bool end",
[]struct{ name, typeHint string }{},
"bool", true,
"return type hint only",
},
{
"fn(a: number, b, c: string): table end",
[]struct{ name, typeHint string }{
{"a", "number"},
{"b", ""},
{"c", "string"},
},
"table", true,
"mixed parameter types with return",
},
{
"fn(callback: function, data: any): nil end",
[]struct{ name, typeHint string }{
{"callback", "function"},
{"data", "any"},
},
"nil", true,
"function and any types",
},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
l := parser.NewLexer(tt.input)
p := parser.NewParser(l)
expr := p.ParseExpression(parser.LOWEST)
checkParserErrors(t, p)
fn, ok := expr.(*parser.FunctionLiteral)
if !ok {
t.Fatalf("expected FunctionLiteral, got %T", expr)
}
// Check parameters
if len(fn.Parameters) != len(tt.params) {
t.Fatalf("expected %d parameters, got %d", len(tt.params), len(fn.Parameters))
}
for i, expected := range tt.params {
param := fn.Parameters[i]
if param.Name != expected.name {
t.Errorf("parameter %d: expected name %s, got %s", i, expected.name, param.Name)
}
if expected.typeHint == "" {
if param.TypeHint != nil {
t.Errorf("parameter %d: expected no type hint but got %s", i, param.TypeHint.Type)
}
} else {
if param.TypeHint == nil {
t.Errorf("parameter %d: expected type hint %s but got nil", i, expected.typeHint)
} else if param.TypeHint.Type != expected.typeHint {
t.Errorf("parameter %d: expected type hint %s, got %s", i, expected.typeHint, param.TypeHint.Type)
}
}
}
// Check return type
if tt.hasReturn {
if fn.ReturnType == nil {
t.Error("expected return type hint but got nil")
} else if fn.ReturnType.Type != tt.returnType {
t.Errorf("expected return type %s, got %s", tt.returnType, fn.ReturnType.Type)
}
} else {
if fn.ReturnType != nil {
t.Errorf("expected no return type but got %s", fn.ReturnType.Type)
}
}
})
}
}
func TestTypeHintStringRepresentation(t *testing.T) {
tests := []struct {
input string
expected string
desc string
}{
{"x: number = 42", "local x: number = 42.00", "typed variable assignment"},
{"fn(a: number, b: string): bool end", "fn(a: number, b: string): bool\nend", "typed function"},
{"callback: function = fn(x: number): string return \"\" end", "local callback: function = fn(x: number): string\n\treturn \"\"\nend", "typed function assignment"},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
l := parser.NewLexer(tt.input)
p := parser.NewParser(l)
var result string
if tt.input[0] == 'f' { // function literal
expr := p.ParseExpression(parser.LOWEST)
checkParserErrors(t, p)
result = expr.String()
} else { // assignment statement
program := p.ParseProgram()
checkParserErrors(t, p)
result = program.Statements[0].String()
}
if result != tt.expected {
t.Errorf("expected %q, got %q", tt.expected, result)
}
})
}
}
func TestTypeHintValidation(t *testing.T) {
tests := []struct {
input string
expectedError string
desc string
}{
{"x: invalid = 42", "invalid type name 'invalid'", "invalid type name"},
{"x: = 42", "expected type name after ':'", "missing type name"},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
l := parser.NewLexer(tt.input)
p := parser.NewParser(l)
p.ParseProgram()
if !p.HasErrors() {
t.Fatal("expected parsing errors")
}
errors := p.Errors()
found := false
for _, err := range errors {
if err.Message == tt.expectedError {
found = true
break
}
}
if !found {
errorMsgs := make([]string, len(errors))
for i, err := range errors {
errorMsgs[i] = err.Message
}
t.Errorf("expected error %q, got %v", tt.expectedError, errorMsgs)
}
})
}
}
func TestMemberAccessWithoutTypeHints(t *testing.T) {
tests := []struct {
input string
desc string
}{
{"table.key = 42", "dot notation assignment"},
{"arr[1] = \"hello\"", "bracket notation assignment"},
{"obj.nested.deep = true", "chained dot assignment"},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
l := parser.NewLexer(tt.input)
p := parser.NewParser(l)
program := p.ParseProgram()
checkParserErrors(t, p)
if len(program.Statements) != 1 {
t.Fatalf("expected 1 statement, got %d", len(program.Statements))
}
stmt, ok := program.Statements[0].(*parser.AssignStatement)
if !ok {
t.Fatalf("expected AssignStatement, got %T", program.Statements[0])
}
// Member access should never have type hints
if stmt.TypeHint != nil {
t.Error("member access assignment should not have type hints")
}
// Should not be a declaration
if stmt.IsDeclaration {
t.Error("member access assignment should not be a declaration")
}
})
}
}
func TestTypeInferenceEngine(t *testing.T) {
input := `x: number = 42
name: string = "hello"
fn_var: function = fn(a: number, b: string): bool
result: bool = a > 0
return result
end
echo name`
l := parser.NewLexer(input)
p := parser.NewParser(l)
program := p.ParseProgram()
checkParserErrors(t, p)
// Run type inference
inferrer := parser.NewTypeInferrer()
typeErrors := inferrer.InferTypes(program)
// Should have no type errors for valid code
if len(typeErrors) > 0 {
errorMsgs := make([]string, len(typeErrors))
for i, err := range typeErrors {
errorMsgs[i] = err.Error()
}
t.Errorf("unexpected type errors: %v", errorMsgs)
}
}
func TestTypeInferenceErrors(t *testing.T) {
tests := []struct {
input string
expectedError string
desc string
}{
{
"x: number = \"hello\"",
"cannot assign string to variable of type number",
"type mismatch in assignment",
},
{
"x = 42\ny: string = x",
"cannot assign number to variable of type string",
"type mismatch with inferred type",
},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
l := parser.NewLexer(tt.input)
p := parser.NewParser(l)
program := p.ParseProgram()
checkParserErrors(t, p)
inferrer := parser.NewTypeInferrer()
typeErrors := inferrer.InferTypes(program)
if len(typeErrors) == 0 {
t.Fatal("expected type errors")
}
found := false
for _, err := range typeErrors {
if err.Message == tt.expectedError {
found = true
break
}
}
if !found {
errorMsgs := make([]string, len(typeErrors))
for i, err := range typeErrors {
errorMsgs[i] = err.Message
}
t.Errorf("expected error %q, got %v", tt.expectedError, errorMsgs)
}
})
}
}
func TestVariadicFunctionTypeHints(t *testing.T) {
tests := []struct {
input string
variadic bool
desc string
}{
{"fn(a: number, ...) end", true, "variadic after typed param"},
{"fn(...) end", true, "variadic only"},
{"fn(a: number, b: string) end", false, "no variadic"},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
l := parser.NewLexer(tt.input)
p := parser.NewParser(l)
expr := p.ParseExpression(parser.LOWEST)
checkParserErrors(t, p)
fn, ok := expr.(*parser.FunctionLiteral)
if !ok {
t.Fatalf("expected FunctionLiteral, got %T", expr)
}
if fn.Variadic != tt.variadic {
t.Errorf("expected variadic = %t, got %t", tt.variadic, fn.Variadic)
}
})
}
}
func TestComplexTypeHintProgram(t *testing.T) {
input := `config: table = {
host = "localhost",
port = 8080,
enabled = true
}
handler: function = fn(request: table, callback: function): nil
status: number = 200
if request.method == "GET" then
response: table = {status = status, body = "OK"}
result = callback(response)
end
end
server: table = {
config = config,
handler = handler,
start = fn(self: table): bool
return true
end
}`
l := parser.NewLexer(input)
p := parser.NewParser(l)
program := p.ParseProgram()
checkParserErrors(t, p)
if len(program.Statements) != 3 {
t.Fatalf("expected 3 statements, got %d", len(program.Statements))
}
// Check first statement: config table with typed assignments
configStmt, ok := program.Statements[0].(*parser.AssignStatement)
if !ok {
t.Fatalf("expected AssignStatement, got %T", program.Statements[0])
}
if configStmt.TypeHint == nil || configStmt.TypeHint.Type != "table" {
t.Error("expected table type hint for config")
}
// Check second statement: handler function with typed parameters
handlerStmt, ok := program.Statements[1].(*parser.AssignStatement)
if !ok {
t.Fatalf("expected AssignStatement, got %T", program.Statements[1])
}
if handlerStmt.TypeHint == nil || handlerStmt.TypeHint.Type != "function" {
t.Error("expected function type hint for handler")
}
fn, ok := handlerStmt.Value.(*parser.FunctionLiteral)
if !ok {
t.Fatalf("expected FunctionLiteral value, got %T", handlerStmt.Value)
}
if len(fn.Parameters) != 2 {
t.Fatalf("expected 2 parameters, got %d", len(fn.Parameters))
}
// Check parameter types
if fn.Parameters[0].TypeHint == nil || fn.Parameters[0].TypeHint.Type != "table" {
t.Error("expected table type for request parameter")
}
if fn.Parameters[1].TypeHint == nil || fn.Parameters[1].TypeHint.Type != "function" {
t.Error("expected function type for callback parameter")
}
// Check return type
if fn.ReturnType == nil || fn.ReturnType.Type != "nil" {
t.Error("expected nil return type for handler")
}
// Check third statement: server table
serverStmt, ok := program.Statements[2].(*parser.AssignStatement)
if !ok {
t.Fatalf("expected AssignStatement, got %T", program.Statements[2])
}
if serverStmt.TypeHint == nil || serverStmt.TypeHint.Type != "table" {
t.Error("expected table type hint for server")
}
}
func TestTypeInfoGettersSetters(t *testing.T) {
// Test that all expression types properly implement GetType/SetType
typeInfo := &parser.TypeInfo{Type: "test", Inferred: true}
expressions := []parser.Expression{
&parser.Identifier{Value: "x"},
&parser.NumberLiteral{Value: 42},
&parser.StringLiteral{Value: "hello"},
&parser.BooleanLiteral{Value: true},
&parser.NilLiteral{},
&parser.TableLiteral{},
&parser.FunctionLiteral{},
&parser.CallExpression{Function: &parser.Identifier{Value: "fn"}},
&parser.PrefixExpression{Operator: "-", Right: &parser.NumberLiteral{Value: 1}},
&parser.InfixExpression{Left: &parser.NumberLiteral{Value: 1}, Operator: "+", Right: &parser.NumberLiteral{Value: 2}},
&parser.IndexExpression{Left: &parser.Identifier{Value: "arr"}, Index: &parser.NumberLiteral{Value: 0}},
&parser.DotExpression{Left: &parser.Identifier{Value: "obj"}, Key: "prop"},
}
for i, expr := range expressions {
t.Run(string(rune('0'+i)), func(t *testing.T) {
// Initially should have no type
if expr.GetType() != nil {
t.Error("expected nil type initially")
}
// Set type
expr.SetType(typeInfo)
// Get type should return what we set
retrieved := expr.GetType()
if retrieved == nil {
t.Error("expected non-nil type after setting")
} else if retrieved.Type != "test" || !retrieved.Inferred {
t.Errorf("expected {Type: test, Inferred: true}, got %+v", retrieved)
}
})
}
}

View File

@ -42,6 +42,7 @@ const (
RBRACKET // ]
COMMA // ,
ELLIPSIS // ...
COLON // :
// Keywords
IF

591
parser/types.go Normal file
View File

@ -0,0 +1,591 @@
package parser
import (
"fmt"
)
// Type constants for built-in types
const (
TypeNumber = "number"
TypeString = "string"
TypeBool = "bool"
TypeNil = "nil"
TypeTable = "table"
TypeFunction = "function"
TypeAny = "any"
)
// TypeError represents a type checking error
type TypeError struct {
Message string
Line int
Column int
Node Node
}
func (te TypeError) Error() string {
return fmt.Sprintf("Type error at line %d, column %d: %s", te.Line, te.Column, te.Message)
}
// Symbol represents a variable in the symbol table
type Symbol struct {
Name string
Type *TypeInfo
Declared bool
Line int
Column int
}
// Scope represents a scope in the symbol table
type Scope struct {
symbols map[string]*Symbol
parent *Scope
}
func NewScope(parent *Scope) *Scope {
return &Scope{
symbols: make(map[string]*Symbol),
parent: parent,
}
}
func (s *Scope) Define(symbol *Symbol) {
s.symbols[symbol.Name] = symbol
}
func (s *Scope) Lookup(name string) *Symbol {
if symbol, ok := s.symbols[name]; ok {
return symbol
}
if s.parent != nil {
return s.parent.Lookup(name)
}
return nil
}
// TypeInferrer performs type inference and checking
type TypeInferrer struct {
currentScope *Scope
globalScope *Scope
errors []TypeError
// Pre-allocated type objects for performance
numberType *TypeInfo
stringType *TypeInfo
boolType *TypeInfo
nilType *TypeInfo
tableType *TypeInfo
anyType *TypeInfo
}
// NewTypeInferrer creates a new type inference engine
func NewTypeInferrer() *TypeInferrer {
globalScope := NewScope(nil)
ti := &TypeInferrer{
currentScope: globalScope,
globalScope: globalScope,
errors: []TypeError{},
// Pre-allocate common types to reduce allocations
numberType: &TypeInfo{Type: TypeNumber, Inferred: true},
stringType: &TypeInfo{Type: TypeString, Inferred: true},
boolType: &TypeInfo{Type: TypeBool, Inferred: true},
nilType: &TypeInfo{Type: TypeNil, Inferred: true},
tableType: &TypeInfo{Type: TypeTable, Inferred: true},
anyType: &TypeInfo{Type: TypeAny, Inferred: true},
}
return ti
}
// InferTypes performs type inference on the entire program
func (ti *TypeInferrer) InferTypes(program *Program) []TypeError {
for _, stmt := range program.Statements {
ti.inferStatement(stmt)
}
return ti.errors
}
// enterScope creates a new scope
func (ti *TypeInferrer) enterScope() {
ti.currentScope = NewScope(ti.currentScope)
}
// exitScope returns to the parent scope
func (ti *TypeInferrer) exitScope() {
if ti.currentScope.parent != nil {
ti.currentScope = ti.currentScope.parent
}
}
// addError adds a type error
func (ti *TypeInferrer) addError(message string, node Node) {
ti.errors = append(ti.errors, TypeError{
Message: message,
Line: 0, // Would need to track position in AST nodes
Column: 0,
Node: node,
})
}
// inferStatement infers types for statements
func (ti *TypeInferrer) inferStatement(stmt Statement) {
switch s := stmt.(type) {
case *AssignStatement:
ti.inferAssignStatement(s)
case *EchoStatement:
ti.inferExpression(s.Value)
case *IfStatement:
ti.inferIfStatement(s)
case *WhileStatement:
ti.inferWhileStatement(s)
case *ForStatement:
ti.inferForStatement(s)
case *ForInStatement:
ti.inferForInStatement(s)
case *ReturnStatement:
if s.Value != nil {
ti.inferExpression(s.Value)
}
case *ExitStatement:
if s.Value != nil {
ti.inferExpression(s.Value)
}
}
}
// inferAssignStatement handles variable assignments with type checking
func (ti *TypeInferrer) inferAssignStatement(stmt *AssignStatement) {
// Infer the type of the value expression
valueType := ti.inferExpression(stmt.Value)
if ident, ok := stmt.Name.(*Identifier); ok {
// Simple variable assignment
symbol := ti.currentScope.Lookup(ident.Value)
if stmt.IsDeclaration {
// New variable declaration
varType := valueType
// If there's a type hint, validate it
if stmt.TypeHint != nil {
if !ti.isTypeCompatible(valueType, stmt.TypeHint) {
ti.addError(fmt.Sprintf("cannot assign %s to variable of type %s",
valueType.Type, stmt.TypeHint.Type), stmt)
}
varType = stmt.TypeHint
varType.Inferred = false
}
// Define the new symbol
ti.currentScope.Define(&Symbol{
Name: ident.Value,
Type: varType,
Declared: true,
})
ident.SetType(varType)
} else {
// Assignment to existing variable
if symbol == nil {
ti.addError(fmt.Sprintf("undefined variable '%s'", ident.Value), stmt)
return
}
// Check type compatibility
if !ti.isTypeCompatible(valueType, symbol.Type) {
ti.addError(fmt.Sprintf("cannot assign %s to variable of type %s",
valueType.Type, symbol.Type.Type), stmt)
}
ident.SetType(symbol.Type)
}
} else {
// Member access assignment (table.key or table[index])
ti.inferExpression(stmt.Name)
}
}
// inferIfStatement handles if statements
func (ti *TypeInferrer) inferIfStatement(stmt *IfStatement) {
condType := ti.inferExpression(stmt.Condition)
ti.validateBooleanContext(condType, stmt.Condition)
ti.enterScope()
for _, s := range stmt.Body {
ti.inferStatement(s)
}
ti.exitScope()
for _, elseif := range stmt.ElseIfs {
condType := ti.inferExpression(elseif.Condition)
ti.validateBooleanContext(condType, elseif.Condition)
ti.enterScope()
for _, s := range elseif.Body {
ti.inferStatement(s)
}
ti.exitScope()
}
if len(stmt.Else) > 0 {
ti.enterScope()
for _, s := range stmt.Else {
ti.inferStatement(s)
}
ti.exitScope()
}
}
// inferWhileStatement handles while loops
func (ti *TypeInferrer) inferWhileStatement(stmt *WhileStatement) {
condType := ti.inferExpression(stmt.Condition)
ti.validateBooleanContext(condType, stmt.Condition)
ti.enterScope()
for _, s := range stmt.Body {
ti.inferStatement(s)
}
ti.exitScope()
}
// inferForStatement handles numeric for loops
func (ti *TypeInferrer) inferForStatement(stmt *ForStatement) {
startType := ti.inferExpression(stmt.Start)
endType := ti.inferExpression(stmt.End)
if !ti.isNumericType(startType) {
ti.addError("for loop start value must be numeric", stmt.Start)
}
if !ti.isNumericType(endType) {
ti.addError("for loop end value must be numeric", stmt.End)
}
if stmt.Step != nil {
stepType := ti.inferExpression(stmt.Step)
if !ti.isNumericType(stepType) {
ti.addError("for loop step value must be numeric", stmt.Step)
}
}
ti.enterScope()
// Define loop variable as number
ti.currentScope.Define(&Symbol{
Name: stmt.Variable.Value,
Type: ti.numberType,
Declared: true,
})
stmt.Variable.SetType(ti.numberType)
for _, s := range stmt.Body {
ti.inferStatement(s)
}
ti.exitScope()
}
// inferForInStatement handles for-in loops
func (ti *TypeInferrer) inferForInStatement(stmt *ForInStatement) {
iterableType := ti.inferExpression(stmt.Iterable)
// For now, assume iterable is a table
if !ti.isTableType(iterableType) {
ti.addError("for-in requires an iterable (table)", stmt.Iterable)
}
ti.enterScope()
// Define loop variables (key and value are any for now)
if stmt.Key != nil {
ti.currentScope.Define(&Symbol{
Name: stmt.Key.Value,
Type: ti.anyType,
Declared: true,
})
stmt.Key.SetType(ti.anyType)
}
ti.currentScope.Define(&Symbol{
Name: stmt.Value.Value,
Type: ti.anyType,
Declared: true,
})
stmt.Value.SetType(ti.anyType)
for _, s := range stmt.Body {
ti.inferStatement(s)
}
ti.exitScope()
}
// inferExpression infers the type of an expression
func (ti *TypeInferrer) inferExpression(expr Expression) *TypeInfo {
if expr == nil {
return ti.nilType
}
switch e := expr.(type) {
case *Identifier:
return ti.inferIdentifier(e)
case *NumberLiteral:
e.SetType(ti.numberType)
return ti.numberType
case *StringLiteral:
e.SetType(ti.stringType)
return ti.stringType
case *BooleanLiteral:
e.SetType(ti.boolType)
return ti.boolType
case *NilLiteral:
e.SetType(ti.nilType)
return ti.nilType
case *TableLiteral:
return ti.inferTableLiteral(e)
case *FunctionLiteral:
return ti.inferFunctionLiteral(e)
case *CallExpression:
return ti.inferCallExpression(e)
case *PrefixExpression:
return ti.inferPrefixExpression(e)
case *InfixExpression:
return ti.inferInfixExpression(e)
case *IndexExpression:
return ti.inferIndexExpression(e)
case *DotExpression:
return ti.inferDotExpression(e)
default:
ti.addError("unknown expression type", expr)
return ti.anyType
}
}
// inferIdentifier looks up identifier type in symbol table
func (ti *TypeInferrer) inferIdentifier(ident *Identifier) *TypeInfo {
symbol := ti.currentScope.Lookup(ident.Value)
if symbol == nil {
ti.addError(fmt.Sprintf("undefined variable '%s'", ident.Value), ident)
return ti.anyType
}
ident.SetType(symbol.Type)
return symbol.Type
}
// inferTableLiteral infers table type
func (ti *TypeInferrer) inferTableLiteral(table *TableLiteral) *TypeInfo {
// Infer types of all values
for _, pair := range table.Pairs {
if pair.Key != nil {
ti.inferExpression(pair.Key)
}
ti.inferExpression(pair.Value)
}
table.SetType(ti.tableType)
return ti.tableType
}
// inferFunctionLiteral infers function type
func (ti *TypeInferrer) inferFunctionLiteral(fn *FunctionLiteral) *TypeInfo {
ti.enterScope()
// Define parameters in function scope
for _, param := range fn.Parameters {
paramType := ti.anyType
if param.TypeHint != nil {
paramType = param.TypeHint
}
ti.currentScope.Define(&Symbol{
Name: param.Name,
Type: paramType,
Declared: true,
})
}
// Infer body
for _, stmt := range fn.Body {
ti.inferStatement(stmt)
}
ti.exitScope()
// For now, all functions have type "function"
funcType := &TypeInfo{Type: TypeFunction, Inferred: true}
fn.SetType(funcType)
return funcType
}
// inferCallExpression infers function call return type
func (ti *TypeInferrer) inferCallExpression(call *CallExpression) *TypeInfo {
funcType := ti.inferExpression(call.Function)
if !ti.isFunctionType(funcType) {
ti.addError("cannot call non-function", call.Function)
return ti.anyType
}
// Infer argument types
for _, arg := range call.Arguments {
ti.inferExpression(arg)
}
// For now, assume function calls return any
call.SetType(ti.anyType)
return ti.anyType
}
// inferPrefixExpression infers prefix operation type
func (ti *TypeInferrer) inferPrefixExpression(prefix *PrefixExpression) *TypeInfo {
rightType := ti.inferExpression(prefix.Right)
var resultType *TypeInfo
switch prefix.Operator {
case "-":
if !ti.isNumericType(rightType) {
ti.addError("unary minus requires numeric operand", prefix)
}
resultType = ti.numberType
case "not":
resultType = ti.boolType
default:
ti.addError(fmt.Sprintf("unknown prefix operator '%s'", prefix.Operator), prefix)
resultType = ti.anyType
}
prefix.SetType(resultType)
return resultType
}
// inferInfixExpression infers binary operation type
func (ti *TypeInferrer) inferInfixExpression(infix *InfixExpression) *TypeInfo {
leftType := ti.inferExpression(infix.Left)
rightType := ti.inferExpression(infix.Right)
var resultType *TypeInfo
switch infix.Operator {
case "+", "-", "*", "/":
if !ti.isNumericType(leftType) || !ti.isNumericType(rightType) {
ti.addError(fmt.Sprintf("arithmetic operator '%s' requires numeric operands", infix.Operator), infix)
}
resultType = ti.numberType
case "==", "!=":
// Equality works with any types
resultType = ti.boolType
case "<", ">", "<=", ">=":
if !ti.isComparableTypes(leftType, rightType) {
ti.addError(fmt.Sprintf("comparison operator '%s' requires compatible operands", infix.Operator), infix)
}
resultType = ti.boolType
case "and", "or":
ti.validateBooleanContext(leftType, infix.Left)
ti.validateBooleanContext(rightType, infix.Right)
resultType = ti.boolType
default:
ti.addError(fmt.Sprintf("unknown infix operator '%s'", infix.Operator), infix)
resultType = ti.anyType
}
infix.SetType(resultType)
return resultType
}
// inferIndexExpression infers table[index] type
func (ti *TypeInferrer) inferIndexExpression(index *IndexExpression) *TypeInfo {
ti.inferExpression(index.Left)
ti.inferExpression(index.Index)
// For now, assume table access returns any
index.SetType(ti.anyType)
return ti.anyType
}
// inferDotExpression infers table.key type
func (ti *TypeInferrer) inferDotExpression(dot *DotExpression) *TypeInfo {
ti.inferExpression(dot.Left)
// For now, assume member access returns any
dot.SetType(ti.anyType)
return ti.anyType
}
// Type checking helper methods
func (ti *TypeInferrer) isTypeCompatible(valueType, targetType *TypeInfo) bool {
if targetType.Type == TypeAny || valueType.Type == TypeAny {
return true
}
return valueType.Type == targetType.Type
}
func (ti *TypeInferrer) isNumericType(t *TypeInfo) bool {
return t.Type == TypeNumber
}
func (ti *TypeInferrer) isBooleanType(t *TypeInfo) bool {
return t.Type == TypeBool
}
func (ti *TypeInferrer) isTableType(t *TypeInfo) bool {
return t.Type == TypeTable
}
func (ti *TypeInferrer) isFunctionType(t *TypeInfo) bool {
return t.Type == TypeFunction
}
func (ti *TypeInferrer) isComparableTypes(left, right *TypeInfo) bool {
if left.Type == TypeAny || right.Type == TypeAny {
return true
}
return left.Type == right.Type && (left.Type == TypeNumber || left.Type == TypeString)
}
func (ti *TypeInferrer) validateBooleanContext(t *TypeInfo, expr Expression) {
// In many languages, non-boolean values can be used in boolean context
// For strictness, we could require boolean type here
// For now, allow any type (truthy/falsy semantics)
}
// Errors returns all type checking errors
func (ti *TypeInferrer) Errors() []TypeError {
return ti.errors
}
// HasErrors returns true if there are any type errors
func (ti *TypeInferrer) HasErrors() bool {
return len(ti.errors) > 0
}
// ErrorStrings returns error messages as strings
func (ti *TypeInferrer) ErrorStrings() []string {
result := make([]string, len(ti.errors))
for i, err := range ti.errors {
result[i] = err.Error()
}
return result
}
// ValidTypeName checks if a string is a valid type name
func ValidTypeName(name string) bool {
validTypes := []string{TypeNumber, TypeString, TypeBool, TypeNil, TypeTable, TypeFunction, TypeAny}
for _, validType := range validTypes {
if name == validType {
return true
}
}
return false
}
// ParseTypeName converts a string to a TypeInfo (for parsing type hints)
func ParseTypeName(name string) *TypeInfo {
if ValidTypeName(name) {
return &TypeInfo{Type: name, Inferred: false}
}
return nil
}