type hint/inference
This commit is contained in:
parent
c691c90c69
commit
e98e0643e2
125
parser/ast.go
125
parser/ast.go
@ -2,6 +2,12 @@ package parser
|
||||
|
||||
import "fmt"
|
||||
|
||||
// TypeInfo represents type information for expressions
|
||||
type TypeInfo struct {
|
||||
Type string // "number", "string", "bool", "table", "function", "nil", "any"
|
||||
Inferred bool // true if type was inferred, false if explicitly declared
|
||||
}
|
||||
|
||||
// Node represents any node in the AST
|
||||
type Node interface {
|
||||
String() string
|
||||
@ -17,6 +23,8 @@ type Statement interface {
|
||||
type Expression interface {
|
||||
Node
|
||||
expressionNode()
|
||||
GetType() *TypeInfo
|
||||
SetType(*TypeInfo)
|
||||
}
|
||||
|
||||
// Program represents the root of the AST
|
||||
@ -33,9 +41,10 @@ func (p *Program) String() string {
|
||||
return result
|
||||
}
|
||||
|
||||
// AssignStatement represents variable assignment
|
||||
// AssignStatement represents variable assignment with optional type hint
|
||||
type AssignStatement struct {
|
||||
Name Expression // Changed from *Identifier to Expression for member access
|
||||
TypeHint *TypeInfo // optional type hint
|
||||
Value Expression
|
||||
IsDeclaration bool // true if this is the first assignment in current scope
|
||||
}
|
||||
@ -46,7 +55,15 @@ func (as *AssignStatement) String() string {
|
||||
if as.IsDeclaration {
|
||||
prefix = "local "
|
||||
}
|
||||
return fmt.Sprintf("%s%s = %s", prefix, as.Name.String(), as.Value.String())
|
||||
|
||||
var nameStr string
|
||||
if as.TypeHint != nil {
|
||||
nameStr = fmt.Sprintf("%s: %s", as.Name.String(), as.TypeHint.Type)
|
||||
} else {
|
||||
nameStr = as.Name.String()
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s%s = %s", prefix, nameStr, as.Value.String())
|
||||
}
|
||||
|
||||
// EchoStatement represents echo output statements
|
||||
@ -216,33 +233,56 @@ func (fis *ForInStatement) String() string {
|
||||
return result
|
||||
}
|
||||
|
||||
// Identifier represents identifiers
|
||||
type Identifier struct {
|
||||
Value string
|
||||
// FunctionParameter represents a function parameter with optional type hint
|
||||
type FunctionParameter struct {
|
||||
Name string
|
||||
TypeHint *TypeInfo
|
||||
}
|
||||
|
||||
func (i *Identifier) expressionNode() {}
|
||||
func (i *Identifier) String() string { return i.Value }
|
||||
func (fp *FunctionParameter) String() string {
|
||||
if fp.TypeHint != nil {
|
||||
return fmt.Sprintf("%s: %s", fp.Name, fp.TypeHint.Type)
|
||||
}
|
||||
return fp.Name
|
||||
}
|
||||
|
||||
// Identifier represents identifiers
|
||||
type Identifier struct {
|
||||
Value string
|
||||
typeInfo *TypeInfo
|
||||
}
|
||||
|
||||
func (i *Identifier) expressionNode() {}
|
||||
func (i *Identifier) String() string { return i.Value }
|
||||
func (i *Identifier) GetType() *TypeInfo { return i.typeInfo }
|
||||
func (i *Identifier) SetType(t *TypeInfo) { i.typeInfo = t }
|
||||
|
||||
// NumberLiteral represents numeric literals
|
||||
type NumberLiteral struct {
|
||||
Value float64
|
||||
Value float64
|
||||
typeInfo *TypeInfo
|
||||
}
|
||||
|
||||
func (nl *NumberLiteral) expressionNode() {}
|
||||
func (nl *NumberLiteral) String() string { return fmt.Sprintf("%.2f", nl.Value) }
|
||||
func (nl *NumberLiteral) expressionNode() {}
|
||||
func (nl *NumberLiteral) String() string { return fmt.Sprintf("%.2f", nl.Value) }
|
||||
func (nl *NumberLiteral) GetType() *TypeInfo { return nl.typeInfo }
|
||||
func (nl *NumberLiteral) SetType(t *TypeInfo) { nl.typeInfo = t }
|
||||
|
||||
// StringLiteral represents string literals
|
||||
type StringLiteral struct {
|
||||
Value string
|
||||
Value string
|
||||
typeInfo *TypeInfo
|
||||
}
|
||||
|
||||
func (sl *StringLiteral) expressionNode() {}
|
||||
func (sl *StringLiteral) String() string { return fmt.Sprintf(`"%s"`, sl.Value) }
|
||||
func (sl *StringLiteral) expressionNode() {}
|
||||
func (sl *StringLiteral) String() string { return fmt.Sprintf(`"%s"`, sl.Value) }
|
||||
func (sl *StringLiteral) GetType() *TypeInfo { return sl.typeInfo }
|
||||
func (sl *StringLiteral) SetType(t *TypeInfo) { sl.typeInfo = t }
|
||||
|
||||
// BooleanLiteral represents boolean literals
|
||||
type BooleanLiteral struct {
|
||||
Value bool
|
||||
Value bool
|
||||
typeInfo *TypeInfo
|
||||
}
|
||||
|
||||
func (bl *BooleanLiteral) expressionNode() {}
|
||||
@ -252,18 +292,26 @@ func (bl *BooleanLiteral) String() string {
|
||||
}
|
||||
return "false"
|
||||
}
|
||||
func (bl *BooleanLiteral) GetType() *TypeInfo { return bl.typeInfo }
|
||||
func (bl *BooleanLiteral) SetType(t *TypeInfo) { bl.typeInfo = t }
|
||||
|
||||
// NilLiteral represents nil literal
|
||||
type NilLiteral struct{}
|
||||
type NilLiteral struct {
|
||||
typeInfo *TypeInfo
|
||||
}
|
||||
|
||||
func (nl *NilLiteral) expressionNode() {}
|
||||
func (nl *NilLiteral) String() string { return "nil" }
|
||||
func (nl *NilLiteral) expressionNode() {}
|
||||
func (nl *NilLiteral) String() string { return "nil" }
|
||||
func (nl *NilLiteral) GetType() *TypeInfo { return nl.typeInfo }
|
||||
func (nl *NilLiteral) SetType(t *TypeInfo) { nl.typeInfo = t }
|
||||
|
||||
// FunctionLiteral represents function literals: fn(a, b, ...) ... end
|
||||
// FunctionLiteral represents function literals with typed parameters
|
||||
type FunctionLiteral struct {
|
||||
Parameters []string
|
||||
Parameters []FunctionParameter
|
||||
Variadic bool
|
||||
ReturnType *TypeInfo // optional return type hint
|
||||
Body []Statement
|
||||
typeInfo *TypeInfo
|
||||
}
|
||||
|
||||
func (fl *FunctionLiteral) expressionNode() {}
|
||||
@ -273,7 +321,7 @@ func (fl *FunctionLiteral) String() string {
|
||||
if i > 0 {
|
||||
params += ", "
|
||||
}
|
||||
params += param
|
||||
params += param.String()
|
||||
}
|
||||
if fl.Variadic {
|
||||
if len(fl.Parameters) > 0 {
|
||||
@ -282,18 +330,26 @@ func (fl *FunctionLiteral) String() string {
|
||||
params += "..."
|
||||
}
|
||||
|
||||
result := fmt.Sprintf("fn(%s)\n", params)
|
||||
result := fmt.Sprintf("fn(%s)", params)
|
||||
if fl.ReturnType != nil {
|
||||
result += ": " + fl.ReturnType.Type
|
||||
}
|
||||
result += "\n"
|
||||
|
||||
for _, stmt := range fl.Body {
|
||||
result += "\t" + stmt.String() + "\n"
|
||||
}
|
||||
result += "end"
|
||||
return result
|
||||
}
|
||||
func (fl *FunctionLiteral) GetType() *TypeInfo { return fl.typeInfo }
|
||||
func (fl *FunctionLiteral) SetType(t *TypeInfo) { fl.typeInfo = t }
|
||||
|
||||
// CallExpression represents function calls: func(arg1, arg2, ...)
|
||||
type CallExpression struct {
|
||||
Function Expression
|
||||
Arguments []Expression
|
||||
typeInfo *TypeInfo
|
||||
}
|
||||
|
||||
func (ce *CallExpression) expressionNode() {}
|
||||
@ -304,11 +360,14 @@ func (ce *CallExpression) String() string {
|
||||
}
|
||||
return fmt.Sprintf("%s(%s)", ce.Function.String(), joinStrings(args, ", "))
|
||||
}
|
||||
func (ce *CallExpression) GetType() *TypeInfo { return ce.typeInfo }
|
||||
func (ce *CallExpression) SetType(t *TypeInfo) { ce.typeInfo = t }
|
||||
|
||||
// PrefixExpression represents prefix operations like -x, not x
|
||||
type PrefixExpression struct {
|
||||
Operator string
|
||||
Right Expression
|
||||
typeInfo *TypeInfo
|
||||
}
|
||||
|
||||
func (pe *PrefixExpression) expressionNode() {}
|
||||
@ -319,40 +378,51 @@ func (pe *PrefixExpression) String() string {
|
||||
}
|
||||
return fmt.Sprintf("(%s%s)", pe.Operator, pe.Right.String())
|
||||
}
|
||||
func (pe *PrefixExpression) GetType() *TypeInfo { return pe.typeInfo }
|
||||
func (pe *PrefixExpression) SetType(t *TypeInfo) { pe.typeInfo = t }
|
||||
|
||||
// InfixExpression represents binary operations
|
||||
type InfixExpression struct {
|
||||
Left Expression
|
||||
Operator string
|
||||
Right Expression
|
||||
typeInfo *TypeInfo
|
||||
}
|
||||
|
||||
func (ie *InfixExpression) expressionNode() {}
|
||||
func (ie *InfixExpression) String() string {
|
||||
return fmt.Sprintf("(%s %s %s)", ie.Left.String(), ie.Operator, ie.Right.String())
|
||||
}
|
||||
func (ie *InfixExpression) GetType() *TypeInfo { return ie.typeInfo }
|
||||
func (ie *InfixExpression) SetType(t *TypeInfo) { ie.typeInfo = t }
|
||||
|
||||
// IndexExpression represents table[key] access
|
||||
type IndexExpression struct {
|
||||
Left Expression
|
||||
Index Expression
|
||||
Left Expression
|
||||
Index Expression
|
||||
typeInfo *TypeInfo
|
||||
}
|
||||
|
||||
func (ie *IndexExpression) expressionNode() {}
|
||||
func (ie *IndexExpression) String() string {
|
||||
return fmt.Sprintf("%s[%s]", ie.Left.String(), ie.Index.String())
|
||||
}
|
||||
func (ie *IndexExpression) GetType() *TypeInfo { return ie.typeInfo }
|
||||
func (ie *IndexExpression) SetType(t *TypeInfo) { ie.typeInfo = t }
|
||||
|
||||
// DotExpression represents table.key access
|
||||
type DotExpression struct {
|
||||
Left Expression
|
||||
Key string
|
||||
Left Expression
|
||||
Key string
|
||||
typeInfo *TypeInfo
|
||||
}
|
||||
|
||||
func (de *DotExpression) expressionNode() {}
|
||||
func (de *DotExpression) String() string {
|
||||
return fmt.Sprintf("%s.%s", de.Left.String(), de.Key)
|
||||
}
|
||||
func (de *DotExpression) GetType() *TypeInfo { return de.typeInfo }
|
||||
func (de *DotExpression) SetType(t *TypeInfo) { de.typeInfo = t }
|
||||
|
||||
// TablePair represents a key-value pair in a table
|
||||
type TablePair struct {
|
||||
@ -369,7 +439,8 @@ func (tp *TablePair) String() string {
|
||||
|
||||
// TableLiteral represents table literals {}
|
||||
type TableLiteral struct {
|
||||
Pairs []TablePair
|
||||
Pairs []TablePair
|
||||
typeInfo *TypeInfo
|
||||
}
|
||||
|
||||
func (tl *TableLiteral) expressionNode() {}
|
||||
@ -380,6 +451,8 @@ func (tl *TableLiteral) String() string {
|
||||
}
|
||||
return fmt.Sprintf("{%s}", joinStrings(pairs, ", "))
|
||||
}
|
||||
func (tl *TableLiteral) GetType() *TypeInfo { return tl.typeInfo }
|
||||
func (tl *TableLiteral) SetType(t *TypeInfo) { tl.typeInfo = t }
|
||||
|
||||
// IsArray returns true if this table contains only array-style elements
|
||||
func (tl *TableLiteral) IsArray() bool {
|
||||
|
@ -257,6 +257,8 @@ func (l *Lexer) NextToken() Token {
|
||||
tok = Token{Type: STAR, Literal: string(l.ch), Line: l.line, Column: l.column}
|
||||
case '/':
|
||||
tok = Token{Type: SLASH, Literal: string(l.ch), Line: l.line, Column: l.column}
|
||||
case ':':
|
||||
tok = Token{Type: COLON, Literal: string(l.ch), Line: l.line, Column: l.column}
|
||||
case '.':
|
||||
// Check for ellipsis (...)
|
||||
if l.peekChar() == '.' && l.peekCharAt(2) == '.' {
|
||||
|
191
parser/parser.go
191
parser/parser.go
@ -89,14 +89,13 @@ func (p *Parser) enterScope(scopeType string) {
|
||||
}
|
||||
|
||||
func (p *Parser) exitScope() {
|
||||
if len(p.scopes) > 1 { // never remove global scope
|
||||
if len(p.scopes) > 1 {
|
||||
p.scopes = p.scopes[:len(p.scopes)-1]
|
||||
p.scopeTypes = p.scopeTypes[:len(p.scopeTypes)-1]
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Parser) enterFunctionScope() {
|
||||
// Functions create new variable scopes
|
||||
p.enterScope("function")
|
||||
}
|
||||
|
||||
@ -105,35 +104,29 @@ func (p *Parser) exitFunctionScope() {
|
||||
}
|
||||
|
||||
func (p *Parser) enterLoopScope() {
|
||||
// Create temporary scope for loop variables only
|
||||
p.enterScope("loop")
|
||||
}
|
||||
|
||||
func (p *Parser) exitLoopScope() {
|
||||
// Remove temporary loop scope
|
||||
p.exitScope()
|
||||
}
|
||||
|
||||
func (p *Parser) enterBlockScope() {
|
||||
// Blocks don't create new variable scopes, just control flow scopes
|
||||
// We don't need to track these for variable declarations
|
||||
// Blocks don't create new variable scopes
|
||||
}
|
||||
|
||||
func (p *Parser) exitBlockScope() {
|
||||
// No-op since blocks don't create variable scopes
|
||||
// No-op
|
||||
}
|
||||
|
||||
func (p *Parser) currentVariableScope() map[string]bool {
|
||||
// If we're in a loop scope, declare variables in the parent scope
|
||||
if len(p.scopeTypes) > 1 && p.scopeTypes[len(p.scopeTypes)-1] == "loop" {
|
||||
return p.scopes[len(p.scopes)-2]
|
||||
}
|
||||
// Otherwise use the current scope
|
||||
return p.scopes[len(p.scopes)-1]
|
||||
}
|
||||
|
||||
func (p *Parser) isVariableDeclared(name string) bool {
|
||||
// Check all scopes from current up to global
|
||||
for i := len(p.scopes) - 1; i >= 0; i-- {
|
||||
if p.scopes[i][name] {
|
||||
return true
|
||||
@ -147,10 +140,31 @@ func (p *Parser) declareVariable(name string) {
|
||||
}
|
||||
|
||||
func (p *Parser) declareLoopVariable(name string) {
|
||||
// Loop variables go in the current loop scope
|
||||
p.scopes[len(p.scopes)-1][name] = true
|
||||
}
|
||||
|
||||
// parseTypeHint parses optional type hint after colon
|
||||
func (p *Parser) parseTypeHint() *TypeInfo {
|
||||
if !p.peekTokenIs(COLON) {
|
||||
return nil
|
||||
}
|
||||
|
||||
p.nextToken() // consume ':'
|
||||
|
||||
if !p.expectPeekIdent() {
|
||||
p.addError("expected type name after ':'")
|
||||
return nil
|
||||
}
|
||||
|
||||
typeName := p.curToken.Literal
|
||||
if !ValidTypeName(typeName) {
|
||||
p.addError(fmt.Sprintf("invalid type name '%s'", typeName))
|
||||
return nil
|
||||
}
|
||||
|
||||
return &TypeInfo{Type: typeName, Inferred: false}
|
||||
}
|
||||
|
||||
// registerPrefix registers a prefix parse function
|
||||
func (p *Parser) registerPrefix(tokenType TokenType, fn func() Expression) {
|
||||
p.prefixParseFns[tokenType] = fn
|
||||
@ -187,7 +201,6 @@ func (p *Parser) ParseProgram() *Program {
|
||||
func (p *Parser) parseStatement() Statement {
|
||||
switch p.curToken.Type {
|
||||
case IDENT:
|
||||
// Try to parse as assignment (handles both simple and member access)
|
||||
return p.parseAssignStatement()
|
||||
case IF:
|
||||
return p.parseIfStatement()
|
||||
@ -217,17 +230,22 @@ func (p *Parser) parseStatement() Statement {
|
||||
}
|
||||
}
|
||||
|
||||
// parseAssignStatement parses variable assignment
|
||||
// parseAssignStatement parses variable assignment with optional type hint
|
||||
func (p *Parser) parseAssignStatement() *AssignStatement {
|
||||
stmt := &AssignStatement{}
|
||||
|
||||
// Parse left-hand side expression (can be identifier or member access)
|
||||
// Parse left-hand side expression
|
||||
stmt.Name = p.ParseExpression(LOWEST)
|
||||
if stmt.Name == nil {
|
||||
p.addError("expected expression for assignment left-hand side")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check for type hint on simple identifiers
|
||||
if _, ok := stmt.Name.(*Identifier); ok {
|
||||
stmt.TypeHint = p.parseTypeHint()
|
||||
}
|
||||
|
||||
// Check if next token is assignment operator
|
||||
if !p.peekTokenIs(ASSIGN) {
|
||||
p.addError("unexpected identifier, expected assignment or declaration")
|
||||
@ -237,13 +255,11 @@ func (p *Parser) parseAssignStatement() *AssignStatement {
|
||||
// Validate assignment target and check if it's a declaration
|
||||
switch name := stmt.Name.(type) {
|
||||
case *Identifier:
|
||||
// Simple variable assignment - check if it's a declaration
|
||||
stmt.IsDeclaration = !p.isVariableDeclared(name.Value)
|
||||
if stmt.IsDeclaration {
|
||||
p.declareVariable(name.Value)
|
||||
}
|
||||
case *DotExpression, *IndexExpression:
|
||||
// Member access - never a declaration
|
||||
stmt.IsDeclaration = false
|
||||
default:
|
||||
p.addError("invalid assignment target")
|
||||
@ -289,10 +305,8 @@ func (p *Parser) parseBreakStatement() *BreakStatement {
|
||||
func (p *Parser) parseExitStatement() *ExitStatement {
|
||||
stmt := &ExitStatement{}
|
||||
|
||||
// Check if there's an optional expression after 'exit'
|
||||
// Only parse expression if next token can start an expression
|
||||
if p.canStartExpression(p.peekToken.Type) {
|
||||
p.nextToken() // move past 'exit'
|
||||
p.nextToken()
|
||||
stmt.Value = p.ParseExpression(LOWEST)
|
||||
if stmt.Value == nil {
|
||||
p.addError("expected expression after 'exit'")
|
||||
@ -307,9 +321,8 @@ func (p *Parser) parseExitStatement() *ExitStatement {
|
||||
func (p *Parser) parseReturnStatement() *ReturnStatement {
|
||||
stmt := &ReturnStatement{}
|
||||
|
||||
// Check if there's an optional expression after 'return'
|
||||
if p.canStartExpression(p.peekToken.Type) {
|
||||
p.nextToken() // move past 'return'
|
||||
p.nextToken()
|
||||
stmt.Value = p.ParseExpression(LOWEST)
|
||||
if stmt.Value == nil {
|
||||
p.addError("expected expression after 'return'")
|
||||
@ -330,11 +343,11 @@ func (p *Parser) canStartExpression(tokenType TokenType) bool {
|
||||
}
|
||||
}
|
||||
|
||||
// parseWhileStatement parses while loops: while condition do ... end
|
||||
// parseWhileStatement parses while loops
|
||||
func (p *Parser) parseWhileStatement() *WhileStatement {
|
||||
stmt := &WhileStatement{}
|
||||
|
||||
p.nextToken() // move past 'while'
|
||||
p.nextToken()
|
||||
|
||||
stmt.Condition = p.ParseExpression(LOWEST)
|
||||
if stmt.Condition == nil {
|
||||
@ -347,9 +360,8 @@ func (p *Parser) parseWhileStatement() *WhileStatement {
|
||||
return nil
|
||||
}
|
||||
|
||||
p.nextToken() // move past 'do'
|
||||
p.nextToken()
|
||||
|
||||
// Parse loop body (no new variable scope)
|
||||
p.enterBlockScope()
|
||||
stmt.Body = p.parseBlockStatements(END)
|
||||
p.exitBlockScope()
|
||||
@ -362,9 +374,9 @@ func (p *Parser) parseWhileStatement() *WhileStatement {
|
||||
return stmt
|
||||
}
|
||||
|
||||
// parseForStatement parses for loops (both numeric and for-in)
|
||||
// parseForStatement parses for loops
|
||||
func (p *Parser) parseForStatement() Statement {
|
||||
p.nextToken() // move past 'for'
|
||||
p.nextToken()
|
||||
|
||||
if !p.curTokenIs(IDENT) {
|
||||
p.addError("expected identifier after 'for'")
|
||||
@ -373,12 +385,9 @@ func (p *Parser) parseForStatement() Statement {
|
||||
|
||||
firstVar := &Identifier{Value: p.curToken.Literal}
|
||||
|
||||
// Look ahead to determine which type of for loop
|
||||
if p.peekTokenIs(ASSIGN) {
|
||||
// Numeric for loop: for i = start, end, step do
|
||||
return p.parseNumericForStatement(firstVar)
|
||||
} else if p.peekTokenIs(COMMA) || p.peekTokenIs(IN) {
|
||||
// For-in loop: for k, v in expr do or for v in expr do
|
||||
return p.parseForInStatement(firstVar)
|
||||
} else {
|
||||
p.addError("expected '=', ',' or 'in' after for loop variable")
|
||||
@ -386,7 +395,7 @@ func (p *Parser) parseForStatement() Statement {
|
||||
}
|
||||
}
|
||||
|
||||
// parseNumericForStatement parses numeric for loops: for i = start, end, step do
|
||||
// parseNumericForStatement parses numeric for loops
|
||||
func (p *Parser) parseNumericForStatement(variable *Identifier) *ForStatement {
|
||||
stmt := &ForStatement{Variable: variable}
|
||||
|
||||
@ -394,9 +403,8 @@ func (p *Parser) parseNumericForStatement(variable *Identifier) *ForStatement {
|
||||
return nil
|
||||
}
|
||||
|
||||
p.nextToken() // move past '='
|
||||
p.nextToken()
|
||||
|
||||
// Parse start expression
|
||||
stmt.Start = p.ParseExpression(LOWEST)
|
||||
if stmt.Start == nil {
|
||||
p.addError("expected start expression in for loop")
|
||||
@ -408,19 +416,17 @@ func (p *Parser) parseNumericForStatement(variable *Identifier) *ForStatement {
|
||||
return nil
|
||||
}
|
||||
|
||||
p.nextToken() // move past ','
|
||||
p.nextToken()
|
||||
|
||||
// Parse end expression
|
||||
stmt.End = p.ParseExpression(LOWEST)
|
||||
if stmt.End == nil {
|
||||
p.addError("expected end expression in for loop")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Optional step expression
|
||||
if p.peekTokenIs(COMMA) {
|
||||
p.nextToken() // move to ','
|
||||
p.nextToken() // move past ','
|
||||
p.nextToken()
|
||||
p.nextToken()
|
||||
|
||||
stmt.Step = p.ParseExpression(LOWEST)
|
||||
if stmt.Step == nil {
|
||||
@ -434,13 +440,12 @@ func (p *Parser) parseNumericForStatement(variable *Identifier) *ForStatement {
|
||||
return nil
|
||||
}
|
||||
|
||||
p.nextToken() // move past 'do'
|
||||
p.nextToken()
|
||||
|
||||
// Create temporary scope for loop variable, assignments in body go to parent scope
|
||||
p.enterLoopScope()
|
||||
p.declareLoopVariable(variable.Value) // loop variable in temporary scope
|
||||
p.declareLoopVariable(variable.Value)
|
||||
stmt.Body = p.parseBlockStatements(END)
|
||||
p.exitLoopScope() // discard temporary scope with loop variable
|
||||
p.exitLoopScope()
|
||||
|
||||
if !p.curTokenIs(END) {
|
||||
p.addError("expected 'end' to close for loop")
|
||||
@ -450,15 +455,14 @@ func (p *Parser) parseNumericForStatement(variable *Identifier) *ForStatement {
|
||||
return stmt
|
||||
}
|
||||
|
||||
// parseForInStatement parses for-in loops: for k, v in expr do or for v in expr do
|
||||
// parseForInStatement parses for-in loops
|
||||
func (p *Parser) parseForInStatement(firstVar *Identifier) *ForInStatement {
|
||||
stmt := &ForInStatement{}
|
||||
|
||||
if p.peekTokenIs(COMMA) {
|
||||
// Two variables: for k, v in expr do
|
||||
stmt.Key = firstVar
|
||||
p.nextToken() // move to ','
|
||||
p.nextToken() // move past ','
|
||||
p.nextToken()
|
||||
p.nextToken()
|
||||
|
||||
if !p.curTokenIs(IDENT) {
|
||||
p.addError("expected identifier after ',' in for loop")
|
||||
@ -467,7 +471,6 @@ func (p *Parser) parseForInStatement(firstVar *Identifier) *ForInStatement {
|
||||
|
||||
stmt.Value = &Identifier{Value: p.curToken.Literal}
|
||||
} else {
|
||||
// Single variable: for v in expr do
|
||||
stmt.Value = firstVar
|
||||
}
|
||||
|
||||
@ -476,9 +479,8 @@ func (p *Parser) parseForInStatement(firstVar *Identifier) *ForInStatement {
|
||||
return nil
|
||||
}
|
||||
|
||||
p.nextToken() // move past 'in'
|
||||
p.nextToken()
|
||||
|
||||
// Parse iterable expression
|
||||
stmt.Iterable = p.ParseExpression(LOWEST)
|
||||
if stmt.Iterable == nil {
|
||||
p.addError("expected expression after 'in' in for loop")
|
||||
@ -490,16 +492,15 @@ func (p *Parser) parseForInStatement(firstVar *Identifier) *ForInStatement {
|
||||
return nil
|
||||
}
|
||||
|
||||
p.nextToken() // move past 'do'
|
||||
p.nextToken()
|
||||
|
||||
// Create temporary scope for loop variables, assignments in body go to parent scope
|
||||
p.enterLoopScope()
|
||||
if stmt.Key != nil {
|
||||
p.declareLoopVariable(stmt.Key.Value) // loop variable in temporary scope
|
||||
p.declareLoopVariable(stmt.Key.Value)
|
||||
}
|
||||
p.declareLoopVariable(stmt.Value.Value) // loop variable in temporary scope
|
||||
p.declareLoopVariable(stmt.Value.Value)
|
||||
stmt.Body = p.parseBlockStatements(END)
|
||||
p.exitLoopScope() // discard temporary scope with loop variables
|
||||
p.exitLoopScope()
|
||||
|
||||
if !p.curTokenIs(END) {
|
||||
p.addError("expected 'end' to close for loop")
|
||||
@ -509,11 +510,11 @@ func (p *Parser) parseForInStatement(firstVar *Identifier) *ForInStatement {
|
||||
return stmt
|
||||
}
|
||||
|
||||
// parseIfStatement parses if/elseif/else/end statements
|
||||
// parseIfStatement parses if statements
|
||||
func (p *Parser) parseIfStatement() *IfStatement {
|
||||
stmt := &IfStatement{}
|
||||
|
||||
p.nextToken() // move past 'if'
|
||||
p.nextToken()
|
||||
|
||||
stmt.Condition = p.ParseExpression(LOWEST)
|
||||
if stmt.Condition == nil {
|
||||
@ -521,29 +522,25 @@ func (p *Parser) parseIfStatement() *IfStatement {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Optional 'then' keyword
|
||||
if p.peekTokenIs(THEN) {
|
||||
p.nextToken()
|
||||
}
|
||||
|
||||
p.nextToken() // move past condition (and optional 'then')
|
||||
p.nextToken()
|
||||
|
||||
// Check if we immediately hit END (empty body should be an error)
|
||||
if p.curTokenIs(END) {
|
||||
p.addError("expected 'end' to close if statement")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse if body (no new variable scope)
|
||||
p.enterBlockScope()
|
||||
stmt.Body = p.parseBlockStatements(ELSEIF, ELSE, END)
|
||||
p.exitBlockScope()
|
||||
|
||||
// Parse elseif clauses
|
||||
for p.curTokenIs(ELSEIF) {
|
||||
elseif := ElseIfClause{}
|
||||
|
||||
p.nextToken() // move past 'elseif'
|
||||
p.nextToken()
|
||||
|
||||
elseif.Condition = p.ParseExpression(LOWEST)
|
||||
if elseif.Condition == nil {
|
||||
@ -551,14 +548,12 @@ func (p *Parser) parseIfStatement() *IfStatement {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Optional 'then' keyword
|
||||
if p.peekTokenIs(THEN) {
|
||||
p.nextToken()
|
||||
}
|
||||
|
||||
p.nextToken() // move past condition (and optional 'then')
|
||||
p.nextToken()
|
||||
|
||||
// Parse elseif body (no new variable scope)
|
||||
p.enterBlockScope()
|
||||
elseif.Body = p.parseBlockStatements(ELSEIF, ELSE, END)
|
||||
p.exitBlockScope()
|
||||
@ -566,11 +561,9 @@ func (p *Parser) parseIfStatement() *IfStatement {
|
||||
stmt.ElseIfs = append(stmt.ElseIfs, elseif)
|
||||
}
|
||||
|
||||
// Parse else clause
|
||||
if p.curTokenIs(ELSE) {
|
||||
p.nextToken() // move past 'else'
|
||||
p.nextToken()
|
||||
|
||||
// Parse else body (no new variable scope)
|
||||
p.enterBlockScope()
|
||||
stmt.Else = p.parseBlockStatements(END)
|
||||
p.exitBlockScope()
|
||||
@ -584,7 +577,7 @@ func (p *Parser) parseIfStatement() *IfStatement {
|
||||
return stmt
|
||||
}
|
||||
|
||||
// parseBlockStatements parses statements until one of the terminator tokens
|
||||
// parseBlockStatements parses statements until terminators
|
||||
func (p *Parser) parseBlockStatements(terminators ...TokenType) []Statement {
|
||||
statements := []Statement{}
|
||||
|
||||
@ -599,7 +592,7 @@ func (p *Parser) parseBlockStatements(terminators ...TokenType) []Statement {
|
||||
return statements
|
||||
}
|
||||
|
||||
// isTerminator checks if current token is one of the terminators
|
||||
// isTerminator checks if current token is a terminator
|
||||
func (p *Parser) isTerminator(terminators ...TokenType) bool {
|
||||
for _, terminator := range terminators {
|
||||
if p.curTokenIs(terminator) {
|
||||
@ -650,9 +643,7 @@ func (p *Parser) parseNumberLiteral() Expression {
|
||||
var value float64
|
||||
var err error
|
||||
|
||||
// Check for hexadecimal (0x/0X prefix)
|
||||
if strings.HasPrefix(literal, "0x") || strings.HasPrefix(literal, "0X") {
|
||||
// Validate hex format
|
||||
if len(literal) <= 2 {
|
||||
p.addError(fmt.Sprintf("could not parse '%s' as hexadecimal number", literal))
|
||||
return nil
|
||||
@ -664,7 +655,6 @@ func (p *Parser) parseNumberLiteral() Expression {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
// Parse as hex and convert to float64
|
||||
intVal, parseErr := strconv.ParseInt(literal, 0, 64)
|
||||
if parseErr != nil {
|
||||
p.addError(fmt.Sprintf("could not parse '%s' as hexadecimal number", literal))
|
||||
@ -672,7 +662,6 @@ func (p *Parser) parseNumberLiteral() Expression {
|
||||
}
|
||||
value = float64(intVal)
|
||||
} else if strings.HasPrefix(literal, "0b") || strings.HasPrefix(literal, "0B") {
|
||||
// Validate binary format
|
||||
if len(literal) <= 2 {
|
||||
p.addError(fmt.Sprintf("could not parse '%s' as binary number", literal))
|
||||
return nil
|
||||
@ -684,8 +673,7 @@ func (p *Parser) parseNumberLiteral() Expression {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
// Parse binary manually since Go doesn't support 0b in ParseInt with base 0
|
||||
binaryStr := literal[2:] // remove "0b" prefix
|
||||
binaryStr := literal[2:]
|
||||
intVal, parseErr := strconv.ParseInt(binaryStr, 2, 64)
|
||||
if parseErr != nil {
|
||||
p.addError(fmt.Sprintf("could not parse '%s' as binary number", literal))
|
||||
@ -693,7 +681,6 @@ func (p *Parser) parseNumberLiteral() Expression {
|
||||
}
|
||||
value = float64(intVal)
|
||||
} else {
|
||||
// Parse as regular decimal (handles scientific notation automatically)
|
||||
value, err = strconv.ParseFloat(literal, 64)
|
||||
if err != nil {
|
||||
p.addError(fmt.Sprintf("could not parse '%s' as number", literal))
|
||||
@ -763,12 +750,14 @@ func (p *Parser) parseFunctionLiteral() Expression {
|
||||
return nil
|
||||
}
|
||||
|
||||
p.nextToken() // move past ')'
|
||||
// Check for return type hint
|
||||
fn.ReturnType = p.parseTypeHint()
|
||||
|
||||
p.nextToken()
|
||||
|
||||
// Enter new function scope and declare parameters
|
||||
p.enterFunctionScope()
|
||||
for _, param := range fn.Parameters {
|
||||
p.declareVariable(param)
|
||||
p.declareVariable(param.Name)
|
||||
}
|
||||
fn.Body = p.parseBlockStatements(END)
|
||||
p.exitFunctionScope()
|
||||
@ -781,8 +770,8 @@ func (p *Parser) parseFunctionLiteral() Expression {
|
||||
return fn
|
||||
}
|
||||
|
||||
func (p *Parser) parseFunctionParameters() ([]string, bool) {
|
||||
var params []string
|
||||
func (p *Parser) parseFunctionParameters() ([]FunctionParameter, bool) {
|
||||
var params []FunctionParameter
|
||||
var variadic bool
|
||||
|
||||
if p.peekTokenIs(RPAREN) {
|
||||
@ -802,16 +791,20 @@ func (p *Parser) parseFunctionParameters() ([]string, bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
params = append(params, p.curToken.Literal)
|
||||
param := FunctionParameter{Name: p.curToken.Literal}
|
||||
|
||||
// Check for type hint
|
||||
param.TypeHint = p.parseTypeHint()
|
||||
|
||||
params = append(params, param)
|
||||
|
||||
if !p.peekTokenIs(COMMA) {
|
||||
break
|
||||
}
|
||||
|
||||
p.nextToken() // move to ','
|
||||
p.nextToken() // move past ','
|
||||
p.nextToken()
|
||||
p.nextToken()
|
||||
|
||||
// Check for ellipsis after comma
|
||||
if p.curTokenIs(ELLIPSIS) {
|
||||
variadic = true
|
||||
break
|
||||
@ -833,7 +826,6 @@ func (p *Parser) parseTableLiteral() Expression {
|
||||
p.nextToken()
|
||||
|
||||
for {
|
||||
// Check for EOF
|
||||
if p.curTokenIs(EOF) {
|
||||
p.addError("unexpected end of input, expected }")
|
||||
return nil
|
||||
@ -841,17 +833,15 @@ func (p *Parser) parseTableLiteral() Expression {
|
||||
|
||||
pair := TablePair{}
|
||||
|
||||
// Check if this is a key=value pair (identifier or string key)
|
||||
if (p.curTokenIs(IDENT) || p.curTokenIs(STRING)) && p.peekTokenIs(ASSIGN) {
|
||||
if p.curTokenIs(IDENT) {
|
||||
pair.Key = &Identifier{Value: p.curToken.Literal}
|
||||
} else {
|
||||
pair.Key = &StringLiteral{Value: p.curToken.Literal}
|
||||
}
|
||||
p.nextToken() // move to =
|
||||
p.nextToken() // move past =
|
||||
p.nextToken()
|
||||
p.nextToken()
|
||||
|
||||
// Check for EOF after =
|
||||
if p.curTokenIs(EOF) {
|
||||
p.addError("expected expression after assignment operator")
|
||||
return nil
|
||||
@ -859,7 +849,6 @@ func (p *Parser) parseTableLiteral() Expression {
|
||||
|
||||
pair.Value = p.ParseExpression(LOWEST)
|
||||
} else {
|
||||
// Array-style element
|
||||
pair.Value = p.ParseExpression(LOWEST)
|
||||
}
|
||||
|
||||
@ -873,15 +862,13 @@ func (p *Parser) parseTableLiteral() Expression {
|
||||
break
|
||||
}
|
||||
|
||||
p.nextToken() // consume comma
|
||||
p.nextToken() // move to next element
|
||||
p.nextToken()
|
||||
p.nextToken()
|
||||
|
||||
// Allow trailing comma
|
||||
if p.curTokenIs(RBRACE) {
|
||||
break
|
||||
}
|
||||
|
||||
// Check for EOF after comma
|
||||
if p.curTokenIs(EOF) {
|
||||
p.addError("expected next token to be }")
|
||||
return nil
|
||||
@ -956,7 +943,7 @@ func (p *Parser) parseExpressionList(end TokenType) []Expression {
|
||||
}
|
||||
|
||||
func (p *Parser) parseIndexExpression(left Expression) Expression {
|
||||
p.nextToken() // move past '['
|
||||
p.nextToken()
|
||||
|
||||
index := p.ParseExpression(LOWEST)
|
||||
if index == nil {
|
||||
@ -993,7 +980,6 @@ func (p *Parser) expectPeek(t TokenType) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// expectPeekIdent accepts IDENT or keyword tokens as identifiers
|
||||
func (p *Parser) expectPeekIdent() bool {
|
||||
if p.peekTokenIs(IDENT) || p.isKeyword(p.peekToken.Type) {
|
||||
p.nextToken()
|
||||
@ -1003,7 +989,6 @@ func (p *Parser) expectPeekIdent() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// isKeyword checks if a token type is a keyword that can be used as identifier
|
||||
func (p *Parser) isKeyword(t TokenType) bool {
|
||||
switch t {
|
||||
case TRUE, FALSE, NIL, AND, OR, NOT, IF, THEN, ELSEIF, ELSE, END, ECHO, FOR, WHILE, IN, DO, BREAK, EXIT, FN, RETURN:
|
||||
@ -1013,7 +998,7 @@ func (p *Parser) isKeyword(t TokenType) bool {
|
||||
}
|
||||
}
|
||||
|
||||
// Error handling methods
|
||||
// Error handling
|
||||
func (p *Parser) addError(message string) {
|
||||
p.errors = append(p.errors, ParseError{
|
||||
Message: message,
|
||||
@ -1075,12 +1060,10 @@ func (p *Parser) Errors() []ParseError {
|
||||
return p.errors
|
||||
}
|
||||
|
||||
// HasErrors returns true if there are any parsing errors
|
||||
func (p *Parser) HasErrors() bool {
|
||||
return len(p.errors) > 0
|
||||
}
|
||||
|
||||
// ErrorStrings returns error messages as strings for backward compatibility
|
||||
func (p *Parser) ErrorStrings() []string {
|
||||
result := make([]string, len(p.errors))
|
||||
for i, err := range p.errors {
|
||||
@ -1089,7 +1072,7 @@ func (p *Parser) ErrorStrings() []string {
|
||||
return result
|
||||
}
|
||||
|
||||
// tokenTypeString returns a human-readable string for token types
|
||||
// tokenTypeString returns human-readable string for token types
|
||||
func tokenTypeString(t TokenType) string {
|
||||
switch t {
|
||||
case IDENT:
|
||||
@ -1114,6 +1097,8 @@ func tokenTypeString(t TokenType) string {
|
||||
return "/"
|
||||
case DOT:
|
||||
return "."
|
||||
case COLON:
|
||||
return ":"
|
||||
case EQ:
|
||||
return "=="
|
||||
case NOT_EQ:
|
||||
|
@ -81,8 +81,8 @@ func TestFunctionParameters(t *testing.T) {
|
||||
}
|
||||
|
||||
for i, expected := range tt.params {
|
||||
if fn.Parameters[i] != expected {
|
||||
t.Errorf("parameter %d: expected %s, got %s", i, expected, fn.Parameters[i])
|
||||
if fn.Parameters[i].Name != expected {
|
||||
t.Errorf("parameter %d: expected %s, got %s", i, expected, fn.Parameters[i].Name)
|
||||
}
|
||||
}
|
||||
|
||||
|
533
parser/tests/types_test.go
Normal file
533
parser/tests/types_test.go
Normal file
@ -0,0 +1,533 @@
|
||||
package parser_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.sharkk.net/Sharkk/Mako/parser"
|
||||
)
|
||||
|
||||
func TestVariableTypeHints(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
variable string
|
||||
typeHint string
|
||||
hasHint bool
|
||||
desc string
|
||||
}{
|
||||
{"x = 42", "x", "", false, "no type hint"},
|
||||
{"x: number = 42", "x", "number", true, "number type hint"},
|
||||
{"name: string = \"hello\"", "name", "string", true, "string type hint"},
|
||||
{"flag: bool = true", "flag", "bool", true, "bool type hint"},
|
||||
{"data: table = {}", "data", "table", true, "table type hint"},
|
||||
{"fn_var: function = fn() end", "fn_var", "function", true, "function type hint"},
|
||||
{"value: any = nil", "value", "any", true, "any type hint"},
|
||||
{"ptr: nil = nil", "ptr", "nil", true, "nil type hint"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
l := parser.NewLexer(tt.input)
|
||||
p := parser.NewParser(l)
|
||||
program := p.ParseProgram()
|
||||
checkParserErrors(t, p)
|
||||
|
||||
if len(program.Statements) != 1 {
|
||||
t.Fatalf("expected 1 statement, got %d", len(program.Statements))
|
||||
}
|
||||
|
||||
stmt, ok := program.Statements[0].(*parser.AssignStatement)
|
||||
if !ok {
|
||||
t.Fatalf("expected AssignStatement, got %T", program.Statements[0])
|
||||
}
|
||||
|
||||
// Check variable name
|
||||
ident, ok := stmt.Name.(*parser.Identifier)
|
||||
if !ok {
|
||||
t.Fatalf("expected Identifier for Name, got %T", stmt.Name)
|
||||
}
|
||||
|
||||
if ident.Value != tt.variable {
|
||||
t.Errorf("expected variable %s, got %s", tt.variable, ident.Value)
|
||||
}
|
||||
|
||||
// Check type hint
|
||||
if tt.hasHint {
|
||||
if stmt.TypeHint == nil {
|
||||
t.Error("expected type hint but got nil")
|
||||
} else {
|
||||
if stmt.TypeHint.Type != tt.typeHint {
|
||||
t.Errorf("expected type hint %s, got %s", tt.typeHint, stmt.TypeHint.Type)
|
||||
}
|
||||
if stmt.TypeHint.Inferred {
|
||||
t.Error("expected type hint to not be inferred")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if stmt.TypeHint != nil {
|
||||
t.Errorf("expected no type hint but got %s", stmt.TypeHint.Type)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFunctionParameterTypeHints(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
params []struct{ name, typeHint string }
|
||||
returnType string
|
||||
hasReturn bool
|
||||
desc string
|
||||
}{
|
||||
{
|
||||
"fn(a, b) end",
|
||||
[]struct{ name, typeHint string }{
|
||||
{"a", ""},
|
||||
{"b", ""},
|
||||
},
|
||||
"", false,
|
||||
"no type hints",
|
||||
},
|
||||
{
|
||||
"fn(a: number, b: string) end",
|
||||
[]struct{ name, typeHint string }{
|
||||
{"a", "number"},
|
||||
{"b", "string"},
|
||||
},
|
||||
"", false,
|
||||
"parameter type hints only",
|
||||
},
|
||||
{
|
||||
"fn(x: number): string end",
|
||||
[]struct{ name, typeHint string }{
|
||||
{"x", "number"},
|
||||
},
|
||||
"string", true,
|
||||
"parameter and return type hints",
|
||||
},
|
||||
{
|
||||
"fn(): bool end",
|
||||
[]struct{ name, typeHint string }{},
|
||||
"bool", true,
|
||||
"return type hint only",
|
||||
},
|
||||
{
|
||||
"fn(a: number, b, c: string): table end",
|
||||
[]struct{ name, typeHint string }{
|
||||
{"a", "number"},
|
||||
{"b", ""},
|
||||
{"c", "string"},
|
||||
},
|
||||
"table", true,
|
||||
"mixed parameter types with return",
|
||||
},
|
||||
{
|
||||
"fn(callback: function, data: any): nil end",
|
||||
[]struct{ name, typeHint string }{
|
||||
{"callback", "function"},
|
||||
{"data", "any"},
|
||||
},
|
||||
"nil", true,
|
||||
"function and any types",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
l := parser.NewLexer(tt.input)
|
||||
p := parser.NewParser(l)
|
||||
expr := p.ParseExpression(parser.LOWEST)
|
||||
checkParserErrors(t, p)
|
||||
|
||||
fn, ok := expr.(*parser.FunctionLiteral)
|
||||
if !ok {
|
||||
t.Fatalf("expected FunctionLiteral, got %T", expr)
|
||||
}
|
||||
|
||||
// Check parameters
|
||||
if len(fn.Parameters) != len(tt.params) {
|
||||
t.Fatalf("expected %d parameters, got %d", len(tt.params), len(fn.Parameters))
|
||||
}
|
||||
|
||||
for i, expected := range tt.params {
|
||||
param := fn.Parameters[i]
|
||||
if param.Name != expected.name {
|
||||
t.Errorf("parameter %d: expected name %s, got %s", i, expected.name, param.Name)
|
||||
}
|
||||
|
||||
if expected.typeHint == "" {
|
||||
if param.TypeHint != nil {
|
||||
t.Errorf("parameter %d: expected no type hint but got %s", i, param.TypeHint.Type)
|
||||
}
|
||||
} else {
|
||||
if param.TypeHint == nil {
|
||||
t.Errorf("parameter %d: expected type hint %s but got nil", i, expected.typeHint)
|
||||
} else if param.TypeHint.Type != expected.typeHint {
|
||||
t.Errorf("parameter %d: expected type hint %s, got %s", i, expected.typeHint, param.TypeHint.Type)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check return type
|
||||
if tt.hasReturn {
|
||||
if fn.ReturnType == nil {
|
||||
t.Error("expected return type hint but got nil")
|
||||
} else if fn.ReturnType.Type != tt.returnType {
|
||||
t.Errorf("expected return type %s, got %s", tt.returnType, fn.ReturnType.Type)
|
||||
}
|
||||
} else {
|
||||
if fn.ReturnType != nil {
|
||||
t.Errorf("expected no return type but got %s", fn.ReturnType.Type)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTypeHintStringRepresentation(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
desc string
|
||||
}{
|
||||
{"x: number = 42", "local x: number = 42.00", "typed variable assignment"},
|
||||
{"fn(a: number, b: string): bool end", "fn(a: number, b: string): bool\nend", "typed function"},
|
||||
{"callback: function = fn(x: number): string return \"\" end", "local callback: function = fn(x: number): string\n\treturn \"\"\nend", "typed function assignment"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
l := parser.NewLexer(tt.input)
|
||||
p := parser.NewParser(l)
|
||||
|
||||
var result string
|
||||
if tt.input[0] == 'f' { // function literal
|
||||
expr := p.ParseExpression(parser.LOWEST)
|
||||
checkParserErrors(t, p)
|
||||
result = expr.String()
|
||||
} else { // assignment statement
|
||||
program := p.ParseProgram()
|
||||
checkParserErrors(t, p)
|
||||
result = program.Statements[0].String()
|
||||
}
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("expected %q, got %q", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTypeHintValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expectedError string
|
||||
desc string
|
||||
}{
|
||||
{"x: invalid = 42", "invalid type name 'invalid'", "invalid type name"},
|
||||
{"x: = 42", "expected type name after ':'", "missing type name"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
l := parser.NewLexer(tt.input)
|
||||
p := parser.NewParser(l)
|
||||
p.ParseProgram()
|
||||
|
||||
if !p.HasErrors() {
|
||||
t.Fatal("expected parsing errors")
|
||||
}
|
||||
|
||||
errors := p.Errors()
|
||||
found := false
|
||||
for _, err := range errors {
|
||||
if err.Message == tt.expectedError {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
errorMsgs := make([]string, len(errors))
|
||||
for i, err := range errors {
|
||||
errorMsgs[i] = err.Message
|
||||
}
|
||||
t.Errorf("expected error %q, got %v", tt.expectedError, errorMsgs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemberAccessWithoutTypeHints(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
desc string
|
||||
}{
|
||||
{"table.key = 42", "dot notation assignment"},
|
||||
{"arr[1] = \"hello\"", "bracket notation assignment"},
|
||||
{"obj.nested.deep = true", "chained dot assignment"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
l := parser.NewLexer(tt.input)
|
||||
p := parser.NewParser(l)
|
||||
program := p.ParseProgram()
|
||||
checkParserErrors(t, p)
|
||||
|
||||
if len(program.Statements) != 1 {
|
||||
t.Fatalf("expected 1 statement, got %d", len(program.Statements))
|
||||
}
|
||||
|
||||
stmt, ok := program.Statements[0].(*parser.AssignStatement)
|
||||
if !ok {
|
||||
t.Fatalf("expected AssignStatement, got %T", program.Statements[0])
|
||||
}
|
||||
|
||||
// Member access should never have type hints
|
||||
if stmt.TypeHint != nil {
|
||||
t.Error("member access assignment should not have type hints")
|
||||
}
|
||||
|
||||
// Should not be a declaration
|
||||
if stmt.IsDeclaration {
|
||||
t.Error("member access assignment should not be a declaration")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTypeInferenceEngine(t *testing.T) {
|
||||
input := `x: number = 42
|
||||
name: string = "hello"
|
||||
fn_var: function = fn(a: number, b: string): bool
|
||||
result: bool = a > 0
|
||||
return result
|
||||
end
|
||||
echo name`
|
||||
|
||||
l := parser.NewLexer(input)
|
||||
p := parser.NewParser(l)
|
||||
program := p.ParseProgram()
|
||||
checkParserErrors(t, p)
|
||||
|
||||
// Run type inference
|
||||
inferrer := parser.NewTypeInferrer()
|
||||
typeErrors := inferrer.InferTypes(program)
|
||||
|
||||
// Should have no type errors for valid code
|
||||
if len(typeErrors) > 0 {
|
||||
errorMsgs := make([]string, len(typeErrors))
|
||||
for i, err := range typeErrors {
|
||||
errorMsgs[i] = err.Error()
|
||||
}
|
||||
t.Errorf("unexpected type errors: %v", errorMsgs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTypeInferenceErrors(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expectedError string
|
||||
desc string
|
||||
}{
|
||||
{
|
||||
"x: number = \"hello\"",
|
||||
"cannot assign string to variable of type number",
|
||||
"type mismatch in assignment",
|
||||
},
|
||||
{
|
||||
"x = 42\ny: string = x",
|
||||
"cannot assign number to variable of type string",
|
||||
"type mismatch with inferred type",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
l := parser.NewLexer(tt.input)
|
||||
p := parser.NewParser(l)
|
||||
program := p.ParseProgram()
|
||||
checkParserErrors(t, p)
|
||||
|
||||
inferrer := parser.NewTypeInferrer()
|
||||
typeErrors := inferrer.InferTypes(program)
|
||||
|
||||
if len(typeErrors) == 0 {
|
||||
t.Fatal("expected type errors")
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, err := range typeErrors {
|
||||
if err.Message == tt.expectedError {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
errorMsgs := make([]string, len(typeErrors))
|
||||
for i, err := range typeErrors {
|
||||
errorMsgs[i] = err.Message
|
||||
}
|
||||
t.Errorf("expected error %q, got %v", tt.expectedError, errorMsgs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVariadicFunctionTypeHints(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
variadic bool
|
||||
desc string
|
||||
}{
|
||||
{"fn(a: number, ...) end", true, "variadic after typed param"},
|
||||
{"fn(...) end", true, "variadic only"},
|
||||
{"fn(a: number, b: string) end", false, "no variadic"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
l := parser.NewLexer(tt.input)
|
||||
p := parser.NewParser(l)
|
||||
expr := p.ParseExpression(parser.LOWEST)
|
||||
checkParserErrors(t, p)
|
||||
|
||||
fn, ok := expr.(*parser.FunctionLiteral)
|
||||
if !ok {
|
||||
t.Fatalf("expected FunctionLiteral, got %T", expr)
|
||||
}
|
||||
|
||||
if fn.Variadic != tt.variadic {
|
||||
t.Errorf("expected variadic = %t, got %t", tt.variadic, fn.Variadic)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestComplexTypeHintProgram(t *testing.T) {
|
||||
input := `config: table = {
|
||||
host = "localhost",
|
||||
port = 8080,
|
||||
enabled = true
|
||||
}
|
||||
|
||||
handler: function = fn(request: table, callback: function): nil
|
||||
status: number = 200
|
||||
if request.method == "GET" then
|
||||
response: table = {status = status, body = "OK"}
|
||||
result = callback(response)
|
||||
end
|
||||
end
|
||||
|
||||
server: table = {
|
||||
config = config,
|
||||
handler = handler,
|
||||
start = fn(self: table): bool
|
||||
return true
|
||||
end
|
||||
}`
|
||||
|
||||
l := parser.NewLexer(input)
|
||||
p := parser.NewParser(l)
|
||||
program := p.ParseProgram()
|
||||
checkParserErrors(t, p)
|
||||
|
||||
if len(program.Statements) != 3 {
|
||||
t.Fatalf("expected 3 statements, got %d", len(program.Statements))
|
||||
}
|
||||
|
||||
// Check first statement: config table with typed assignments
|
||||
configStmt, ok := program.Statements[0].(*parser.AssignStatement)
|
||||
if !ok {
|
||||
t.Fatalf("expected AssignStatement, got %T", program.Statements[0])
|
||||
}
|
||||
|
||||
if configStmt.TypeHint == nil || configStmt.TypeHint.Type != "table" {
|
||||
t.Error("expected table type hint for config")
|
||||
}
|
||||
|
||||
// Check second statement: handler function with typed parameters
|
||||
handlerStmt, ok := program.Statements[1].(*parser.AssignStatement)
|
||||
if !ok {
|
||||
t.Fatalf("expected AssignStatement, got %T", program.Statements[1])
|
||||
}
|
||||
|
||||
if handlerStmt.TypeHint == nil || handlerStmt.TypeHint.Type != "function" {
|
||||
t.Error("expected function type hint for handler")
|
||||
}
|
||||
|
||||
fn, ok := handlerStmt.Value.(*parser.FunctionLiteral)
|
||||
if !ok {
|
||||
t.Fatalf("expected FunctionLiteral value, got %T", handlerStmt.Value)
|
||||
}
|
||||
|
||||
if len(fn.Parameters) != 2 {
|
||||
t.Fatalf("expected 2 parameters, got %d", len(fn.Parameters))
|
||||
}
|
||||
|
||||
// Check parameter types
|
||||
if fn.Parameters[0].TypeHint == nil || fn.Parameters[0].TypeHint.Type != "table" {
|
||||
t.Error("expected table type for request parameter")
|
||||
}
|
||||
|
||||
if fn.Parameters[1].TypeHint == nil || fn.Parameters[1].TypeHint.Type != "function" {
|
||||
t.Error("expected function type for callback parameter")
|
||||
}
|
||||
|
||||
// Check return type
|
||||
if fn.ReturnType == nil || fn.ReturnType.Type != "nil" {
|
||||
t.Error("expected nil return type for handler")
|
||||
}
|
||||
|
||||
// Check third statement: server table
|
||||
serverStmt, ok := program.Statements[2].(*parser.AssignStatement)
|
||||
if !ok {
|
||||
t.Fatalf("expected AssignStatement, got %T", program.Statements[2])
|
||||
}
|
||||
|
||||
if serverStmt.TypeHint == nil || serverStmt.TypeHint.Type != "table" {
|
||||
t.Error("expected table type hint for server")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTypeInfoGettersSetters(t *testing.T) {
|
||||
// Test that all expression types properly implement GetType/SetType
|
||||
typeInfo := &parser.TypeInfo{Type: "test", Inferred: true}
|
||||
|
||||
expressions := []parser.Expression{
|
||||
&parser.Identifier{Value: "x"},
|
||||
&parser.NumberLiteral{Value: 42},
|
||||
&parser.StringLiteral{Value: "hello"},
|
||||
&parser.BooleanLiteral{Value: true},
|
||||
&parser.NilLiteral{},
|
||||
&parser.TableLiteral{},
|
||||
&parser.FunctionLiteral{},
|
||||
&parser.CallExpression{Function: &parser.Identifier{Value: "fn"}},
|
||||
&parser.PrefixExpression{Operator: "-", Right: &parser.NumberLiteral{Value: 1}},
|
||||
&parser.InfixExpression{Left: &parser.NumberLiteral{Value: 1}, Operator: "+", Right: &parser.NumberLiteral{Value: 2}},
|
||||
&parser.IndexExpression{Left: &parser.Identifier{Value: "arr"}, Index: &parser.NumberLiteral{Value: 0}},
|
||||
&parser.DotExpression{Left: &parser.Identifier{Value: "obj"}, Key: "prop"},
|
||||
}
|
||||
|
||||
for i, expr := range expressions {
|
||||
t.Run(string(rune('0'+i)), func(t *testing.T) {
|
||||
// Initially should have no type
|
||||
if expr.GetType() != nil {
|
||||
t.Error("expected nil type initially")
|
||||
}
|
||||
|
||||
// Set type
|
||||
expr.SetType(typeInfo)
|
||||
|
||||
// Get type should return what we set
|
||||
retrieved := expr.GetType()
|
||||
if retrieved == nil {
|
||||
t.Error("expected non-nil type after setting")
|
||||
} else if retrieved.Type != "test" || !retrieved.Inferred {
|
||||
t.Errorf("expected {Type: test, Inferred: true}, got %+v", retrieved)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -42,6 +42,7 @@ const (
|
||||
RBRACKET // ]
|
||||
COMMA // ,
|
||||
ELLIPSIS // ...
|
||||
COLON // :
|
||||
|
||||
// Keywords
|
||||
IF
|
||||
|
591
parser/types.go
Normal file
591
parser/types.go
Normal file
@ -0,0 +1,591 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Type constants for built-in types
|
||||
const (
|
||||
TypeNumber = "number"
|
||||
TypeString = "string"
|
||||
TypeBool = "bool"
|
||||
TypeNil = "nil"
|
||||
TypeTable = "table"
|
||||
TypeFunction = "function"
|
||||
TypeAny = "any"
|
||||
)
|
||||
|
||||
// TypeError represents a type checking error
|
||||
type TypeError struct {
|
||||
Message string
|
||||
Line int
|
||||
Column int
|
||||
Node Node
|
||||
}
|
||||
|
||||
func (te TypeError) Error() string {
|
||||
return fmt.Sprintf("Type error at line %d, column %d: %s", te.Line, te.Column, te.Message)
|
||||
}
|
||||
|
||||
// Symbol represents a variable in the symbol table
|
||||
type Symbol struct {
|
||||
Name string
|
||||
Type *TypeInfo
|
||||
Declared bool
|
||||
Line int
|
||||
Column int
|
||||
}
|
||||
|
||||
// Scope represents a scope in the symbol table
|
||||
type Scope struct {
|
||||
symbols map[string]*Symbol
|
||||
parent *Scope
|
||||
}
|
||||
|
||||
func NewScope(parent *Scope) *Scope {
|
||||
return &Scope{
|
||||
symbols: make(map[string]*Symbol),
|
||||
parent: parent,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Scope) Define(symbol *Symbol) {
|
||||
s.symbols[symbol.Name] = symbol
|
||||
}
|
||||
|
||||
func (s *Scope) Lookup(name string) *Symbol {
|
||||
if symbol, ok := s.symbols[name]; ok {
|
||||
return symbol
|
||||
}
|
||||
if s.parent != nil {
|
||||
return s.parent.Lookup(name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TypeInferrer performs type inference and checking
|
||||
type TypeInferrer struct {
|
||||
currentScope *Scope
|
||||
globalScope *Scope
|
||||
errors []TypeError
|
||||
|
||||
// Pre-allocated type objects for performance
|
||||
numberType *TypeInfo
|
||||
stringType *TypeInfo
|
||||
boolType *TypeInfo
|
||||
nilType *TypeInfo
|
||||
tableType *TypeInfo
|
||||
anyType *TypeInfo
|
||||
}
|
||||
|
||||
// NewTypeInferrer creates a new type inference engine
|
||||
func NewTypeInferrer() *TypeInferrer {
|
||||
globalScope := NewScope(nil)
|
||||
|
||||
ti := &TypeInferrer{
|
||||
currentScope: globalScope,
|
||||
globalScope: globalScope,
|
||||
errors: []TypeError{},
|
||||
|
||||
// Pre-allocate common types to reduce allocations
|
||||
numberType: &TypeInfo{Type: TypeNumber, Inferred: true},
|
||||
stringType: &TypeInfo{Type: TypeString, Inferred: true},
|
||||
boolType: &TypeInfo{Type: TypeBool, Inferred: true},
|
||||
nilType: &TypeInfo{Type: TypeNil, Inferred: true},
|
||||
tableType: &TypeInfo{Type: TypeTable, Inferred: true},
|
||||
anyType: &TypeInfo{Type: TypeAny, Inferred: true},
|
||||
}
|
||||
|
||||
return ti
|
||||
}
|
||||
|
||||
// InferTypes performs type inference on the entire program
|
||||
func (ti *TypeInferrer) InferTypes(program *Program) []TypeError {
|
||||
for _, stmt := range program.Statements {
|
||||
ti.inferStatement(stmt)
|
||||
}
|
||||
return ti.errors
|
||||
}
|
||||
|
||||
// enterScope creates a new scope
|
||||
func (ti *TypeInferrer) enterScope() {
|
||||
ti.currentScope = NewScope(ti.currentScope)
|
||||
}
|
||||
|
||||
// exitScope returns to the parent scope
|
||||
func (ti *TypeInferrer) exitScope() {
|
||||
if ti.currentScope.parent != nil {
|
||||
ti.currentScope = ti.currentScope.parent
|
||||
}
|
||||
}
|
||||
|
||||
// addError adds a type error
|
||||
func (ti *TypeInferrer) addError(message string, node Node) {
|
||||
ti.errors = append(ti.errors, TypeError{
|
||||
Message: message,
|
||||
Line: 0, // Would need to track position in AST nodes
|
||||
Column: 0,
|
||||
Node: node,
|
||||
})
|
||||
}
|
||||
|
||||
// inferStatement infers types for statements
|
||||
func (ti *TypeInferrer) inferStatement(stmt Statement) {
|
||||
switch s := stmt.(type) {
|
||||
case *AssignStatement:
|
||||
ti.inferAssignStatement(s)
|
||||
case *EchoStatement:
|
||||
ti.inferExpression(s.Value)
|
||||
case *IfStatement:
|
||||
ti.inferIfStatement(s)
|
||||
case *WhileStatement:
|
||||
ti.inferWhileStatement(s)
|
||||
case *ForStatement:
|
||||
ti.inferForStatement(s)
|
||||
case *ForInStatement:
|
||||
ti.inferForInStatement(s)
|
||||
case *ReturnStatement:
|
||||
if s.Value != nil {
|
||||
ti.inferExpression(s.Value)
|
||||
}
|
||||
case *ExitStatement:
|
||||
if s.Value != nil {
|
||||
ti.inferExpression(s.Value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// inferAssignStatement handles variable assignments with type checking
|
||||
func (ti *TypeInferrer) inferAssignStatement(stmt *AssignStatement) {
|
||||
// Infer the type of the value expression
|
||||
valueType := ti.inferExpression(stmt.Value)
|
||||
|
||||
if ident, ok := stmt.Name.(*Identifier); ok {
|
||||
// Simple variable assignment
|
||||
symbol := ti.currentScope.Lookup(ident.Value)
|
||||
|
||||
if stmt.IsDeclaration {
|
||||
// New variable declaration
|
||||
varType := valueType
|
||||
|
||||
// If there's a type hint, validate it
|
||||
if stmt.TypeHint != nil {
|
||||
if !ti.isTypeCompatible(valueType, stmt.TypeHint) {
|
||||
ti.addError(fmt.Sprintf("cannot assign %s to variable of type %s",
|
||||
valueType.Type, stmt.TypeHint.Type), stmt)
|
||||
}
|
||||
varType = stmt.TypeHint
|
||||
varType.Inferred = false
|
||||
}
|
||||
|
||||
// Define the new symbol
|
||||
ti.currentScope.Define(&Symbol{
|
||||
Name: ident.Value,
|
||||
Type: varType,
|
||||
Declared: true,
|
||||
})
|
||||
|
||||
ident.SetType(varType)
|
||||
} else {
|
||||
// Assignment to existing variable
|
||||
if symbol == nil {
|
||||
ti.addError(fmt.Sprintf("undefined variable '%s'", ident.Value), stmt)
|
||||
return
|
||||
}
|
||||
|
||||
// Check type compatibility
|
||||
if !ti.isTypeCompatible(valueType, symbol.Type) {
|
||||
ti.addError(fmt.Sprintf("cannot assign %s to variable of type %s",
|
||||
valueType.Type, symbol.Type.Type), stmt)
|
||||
}
|
||||
|
||||
ident.SetType(symbol.Type)
|
||||
}
|
||||
} else {
|
||||
// Member access assignment (table.key or table[index])
|
||||
ti.inferExpression(stmt.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// inferIfStatement handles if statements
|
||||
func (ti *TypeInferrer) inferIfStatement(stmt *IfStatement) {
|
||||
condType := ti.inferExpression(stmt.Condition)
|
||||
ti.validateBooleanContext(condType, stmt.Condition)
|
||||
|
||||
ti.enterScope()
|
||||
for _, s := range stmt.Body {
|
||||
ti.inferStatement(s)
|
||||
}
|
||||
ti.exitScope()
|
||||
|
||||
for _, elseif := range stmt.ElseIfs {
|
||||
condType := ti.inferExpression(elseif.Condition)
|
||||
ti.validateBooleanContext(condType, elseif.Condition)
|
||||
|
||||
ti.enterScope()
|
||||
for _, s := range elseif.Body {
|
||||
ti.inferStatement(s)
|
||||
}
|
||||
ti.exitScope()
|
||||
}
|
||||
|
||||
if len(stmt.Else) > 0 {
|
||||
ti.enterScope()
|
||||
for _, s := range stmt.Else {
|
||||
ti.inferStatement(s)
|
||||
}
|
||||
ti.exitScope()
|
||||
}
|
||||
}
|
||||
|
||||
// inferWhileStatement handles while loops
|
||||
func (ti *TypeInferrer) inferWhileStatement(stmt *WhileStatement) {
|
||||
condType := ti.inferExpression(stmt.Condition)
|
||||
ti.validateBooleanContext(condType, stmt.Condition)
|
||||
|
||||
ti.enterScope()
|
||||
for _, s := range stmt.Body {
|
||||
ti.inferStatement(s)
|
||||
}
|
||||
ti.exitScope()
|
||||
}
|
||||
|
||||
// inferForStatement handles numeric for loops
|
||||
func (ti *TypeInferrer) inferForStatement(stmt *ForStatement) {
|
||||
startType := ti.inferExpression(stmt.Start)
|
||||
endType := ti.inferExpression(stmt.End)
|
||||
|
||||
if !ti.isNumericType(startType) {
|
||||
ti.addError("for loop start value must be numeric", stmt.Start)
|
||||
}
|
||||
if !ti.isNumericType(endType) {
|
||||
ti.addError("for loop end value must be numeric", stmt.End)
|
||||
}
|
||||
|
||||
if stmt.Step != nil {
|
||||
stepType := ti.inferExpression(stmt.Step)
|
||||
if !ti.isNumericType(stepType) {
|
||||
ti.addError("for loop step value must be numeric", stmt.Step)
|
||||
}
|
||||
}
|
||||
|
||||
ti.enterScope()
|
||||
// Define loop variable as number
|
||||
ti.currentScope.Define(&Symbol{
|
||||
Name: stmt.Variable.Value,
|
||||
Type: ti.numberType,
|
||||
Declared: true,
|
||||
})
|
||||
stmt.Variable.SetType(ti.numberType)
|
||||
|
||||
for _, s := range stmt.Body {
|
||||
ti.inferStatement(s)
|
||||
}
|
||||
ti.exitScope()
|
||||
}
|
||||
|
||||
// inferForInStatement handles for-in loops
|
||||
func (ti *TypeInferrer) inferForInStatement(stmt *ForInStatement) {
|
||||
iterableType := ti.inferExpression(stmt.Iterable)
|
||||
|
||||
// For now, assume iterable is a table
|
||||
if !ti.isTableType(iterableType) {
|
||||
ti.addError("for-in requires an iterable (table)", stmt.Iterable)
|
||||
}
|
||||
|
||||
ti.enterScope()
|
||||
|
||||
// Define loop variables (key and value are any for now)
|
||||
if stmt.Key != nil {
|
||||
ti.currentScope.Define(&Symbol{
|
||||
Name: stmt.Key.Value,
|
||||
Type: ti.anyType,
|
||||
Declared: true,
|
||||
})
|
||||
stmt.Key.SetType(ti.anyType)
|
||||
}
|
||||
|
||||
ti.currentScope.Define(&Symbol{
|
||||
Name: stmt.Value.Value,
|
||||
Type: ti.anyType,
|
||||
Declared: true,
|
||||
})
|
||||
stmt.Value.SetType(ti.anyType)
|
||||
|
||||
for _, s := range stmt.Body {
|
||||
ti.inferStatement(s)
|
||||
}
|
||||
ti.exitScope()
|
||||
}
|
||||
|
||||
// inferExpression infers the type of an expression
|
||||
func (ti *TypeInferrer) inferExpression(expr Expression) *TypeInfo {
|
||||
if expr == nil {
|
||||
return ti.nilType
|
||||
}
|
||||
|
||||
switch e := expr.(type) {
|
||||
case *Identifier:
|
||||
return ti.inferIdentifier(e)
|
||||
case *NumberLiteral:
|
||||
e.SetType(ti.numberType)
|
||||
return ti.numberType
|
||||
case *StringLiteral:
|
||||
e.SetType(ti.stringType)
|
||||
return ti.stringType
|
||||
case *BooleanLiteral:
|
||||
e.SetType(ti.boolType)
|
||||
return ti.boolType
|
||||
case *NilLiteral:
|
||||
e.SetType(ti.nilType)
|
||||
return ti.nilType
|
||||
case *TableLiteral:
|
||||
return ti.inferTableLiteral(e)
|
||||
case *FunctionLiteral:
|
||||
return ti.inferFunctionLiteral(e)
|
||||
case *CallExpression:
|
||||
return ti.inferCallExpression(e)
|
||||
case *PrefixExpression:
|
||||
return ti.inferPrefixExpression(e)
|
||||
case *InfixExpression:
|
||||
return ti.inferInfixExpression(e)
|
||||
case *IndexExpression:
|
||||
return ti.inferIndexExpression(e)
|
||||
case *DotExpression:
|
||||
return ti.inferDotExpression(e)
|
||||
default:
|
||||
ti.addError("unknown expression type", expr)
|
||||
return ti.anyType
|
||||
}
|
||||
}
|
||||
|
||||
// inferIdentifier looks up identifier type in symbol table
|
||||
func (ti *TypeInferrer) inferIdentifier(ident *Identifier) *TypeInfo {
|
||||
symbol := ti.currentScope.Lookup(ident.Value)
|
||||
if symbol == nil {
|
||||
ti.addError(fmt.Sprintf("undefined variable '%s'", ident.Value), ident)
|
||||
return ti.anyType
|
||||
}
|
||||
|
||||
ident.SetType(symbol.Type)
|
||||
return symbol.Type
|
||||
}
|
||||
|
||||
// inferTableLiteral infers table type
|
||||
func (ti *TypeInferrer) inferTableLiteral(table *TableLiteral) *TypeInfo {
|
||||
// Infer types of all values
|
||||
for _, pair := range table.Pairs {
|
||||
if pair.Key != nil {
|
||||
ti.inferExpression(pair.Key)
|
||||
}
|
||||
ti.inferExpression(pair.Value)
|
||||
}
|
||||
|
||||
table.SetType(ti.tableType)
|
||||
return ti.tableType
|
||||
}
|
||||
|
||||
// inferFunctionLiteral infers function type
|
||||
func (ti *TypeInferrer) inferFunctionLiteral(fn *FunctionLiteral) *TypeInfo {
|
||||
ti.enterScope()
|
||||
|
||||
// Define parameters in function scope
|
||||
for _, param := range fn.Parameters {
|
||||
paramType := ti.anyType
|
||||
if param.TypeHint != nil {
|
||||
paramType = param.TypeHint
|
||||
}
|
||||
|
||||
ti.currentScope.Define(&Symbol{
|
||||
Name: param.Name,
|
||||
Type: paramType,
|
||||
Declared: true,
|
||||
})
|
||||
}
|
||||
|
||||
// Infer body
|
||||
for _, stmt := range fn.Body {
|
||||
ti.inferStatement(stmt)
|
||||
}
|
||||
|
||||
ti.exitScope()
|
||||
|
||||
// For now, all functions have type "function"
|
||||
funcType := &TypeInfo{Type: TypeFunction, Inferred: true}
|
||||
fn.SetType(funcType)
|
||||
return funcType
|
||||
}
|
||||
|
||||
// inferCallExpression infers function call return type
|
||||
func (ti *TypeInferrer) inferCallExpression(call *CallExpression) *TypeInfo {
|
||||
funcType := ti.inferExpression(call.Function)
|
||||
|
||||
if !ti.isFunctionType(funcType) {
|
||||
ti.addError("cannot call non-function", call.Function)
|
||||
return ti.anyType
|
||||
}
|
||||
|
||||
// Infer argument types
|
||||
for _, arg := range call.Arguments {
|
||||
ti.inferExpression(arg)
|
||||
}
|
||||
|
||||
// For now, assume function calls return any
|
||||
call.SetType(ti.anyType)
|
||||
return ti.anyType
|
||||
}
|
||||
|
||||
// inferPrefixExpression infers prefix operation type
|
||||
func (ti *TypeInferrer) inferPrefixExpression(prefix *PrefixExpression) *TypeInfo {
|
||||
rightType := ti.inferExpression(prefix.Right)
|
||||
|
||||
var resultType *TypeInfo
|
||||
switch prefix.Operator {
|
||||
case "-":
|
||||
if !ti.isNumericType(rightType) {
|
||||
ti.addError("unary minus requires numeric operand", prefix)
|
||||
}
|
||||
resultType = ti.numberType
|
||||
case "not":
|
||||
resultType = ti.boolType
|
||||
default:
|
||||
ti.addError(fmt.Sprintf("unknown prefix operator '%s'", prefix.Operator), prefix)
|
||||
resultType = ti.anyType
|
||||
}
|
||||
|
||||
prefix.SetType(resultType)
|
||||
return resultType
|
||||
}
|
||||
|
||||
// inferInfixExpression infers binary operation type
|
||||
func (ti *TypeInferrer) inferInfixExpression(infix *InfixExpression) *TypeInfo {
|
||||
leftType := ti.inferExpression(infix.Left)
|
||||
rightType := ti.inferExpression(infix.Right)
|
||||
|
||||
var resultType *TypeInfo
|
||||
|
||||
switch infix.Operator {
|
||||
case "+", "-", "*", "/":
|
||||
if !ti.isNumericType(leftType) || !ti.isNumericType(rightType) {
|
||||
ti.addError(fmt.Sprintf("arithmetic operator '%s' requires numeric operands", infix.Operator), infix)
|
||||
}
|
||||
resultType = ti.numberType
|
||||
|
||||
case "==", "!=":
|
||||
// Equality works with any types
|
||||
resultType = ti.boolType
|
||||
|
||||
case "<", ">", "<=", ">=":
|
||||
if !ti.isComparableTypes(leftType, rightType) {
|
||||
ti.addError(fmt.Sprintf("comparison operator '%s' requires compatible operands", infix.Operator), infix)
|
||||
}
|
||||
resultType = ti.boolType
|
||||
|
||||
case "and", "or":
|
||||
ti.validateBooleanContext(leftType, infix.Left)
|
||||
ti.validateBooleanContext(rightType, infix.Right)
|
||||
resultType = ti.boolType
|
||||
|
||||
default:
|
||||
ti.addError(fmt.Sprintf("unknown infix operator '%s'", infix.Operator), infix)
|
||||
resultType = ti.anyType
|
||||
}
|
||||
|
||||
infix.SetType(resultType)
|
||||
return resultType
|
||||
}
|
||||
|
||||
// inferIndexExpression infers table[index] type
|
||||
func (ti *TypeInferrer) inferIndexExpression(index *IndexExpression) *TypeInfo {
|
||||
ti.inferExpression(index.Left)
|
||||
ti.inferExpression(index.Index)
|
||||
|
||||
// For now, assume table access returns any
|
||||
index.SetType(ti.anyType)
|
||||
return ti.anyType
|
||||
}
|
||||
|
||||
// inferDotExpression infers table.key type
|
||||
func (ti *TypeInferrer) inferDotExpression(dot *DotExpression) *TypeInfo {
|
||||
ti.inferExpression(dot.Left)
|
||||
|
||||
// For now, assume member access returns any
|
||||
dot.SetType(ti.anyType)
|
||||
return ti.anyType
|
||||
}
|
||||
|
||||
// Type checking helper methods
|
||||
|
||||
func (ti *TypeInferrer) isTypeCompatible(valueType, targetType *TypeInfo) bool {
|
||||
if targetType.Type == TypeAny || valueType.Type == TypeAny {
|
||||
return true
|
||||
}
|
||||
return valueType.Type == targetType.Type
|
||||
}
|
||||
|
||||
func (ti *TypeInferrer) isNumericType(t *TypeInfo) bool {
|
||||
return t.Type == TypeNumber
|
||||
}
|
||||
|
||||
func (ti *TypeInferrer) isBooleanType(t *TypeInfo) bool {
|
||||
return t.Type == TypeBool
|
||||
}
|
||||
|
||||
func (ti *TypeInferrer) isTableType(t *TypeInfo) bool {
|
||||
return t.Type == TypeTable
|
||||
}
|
||||
|
||||
func (ti *TypeInferrer) isFunctionType(t *TypeInfo) bool {
|
||||
return t.Type == TypeFunction
|
||||
}
|
||||
|
||||
func (ti *TypeInferrer) isComparableTypes(left, right *TypeInfo) bool {
|
||||
if left.Type == TypeAny || right.Type == TypeAny {
|
||||
return true
|
||||
}
|
||||
return left.Type == right.Type && (left.Type == TypeNumber || left.Type == TypeString)
|
||||
}
|
||||
|
||||
func (ti *TypeInferrer) validateBooleanContext(t *TypeInfo, expr Expression) {
|
||||
// In many languages, non-boolean values can be used in boolean context
|
||||
// For strictness, we could require boolean type here
|
||||
// For now, allow any type (truthy/falsy semantics)
|
||||
}
|
||||
|
||||
// Errors returns all type checking errors
|
||||
func (ti *TypeInferrer) Errors() []TypeError {
|
||||
return ti.errors
|
||||
}
|
||||
|
||||
// HasErrors returns true if there are any type errors
|
||||
func (ti *TypeInferrer) HasErrors() bool {
|
||||
return len(ti.errors) > 0
|
||||
}
|
||||
|
||||
// ErrorStrings returns error messages as strings
|
||||
func (ti *TypeInferrer) ErrorStrings() []string {
|
||||
result := make([]string, len(ti.errors))
|
||||
for i, err := range ti.errors {
|
||||
result[i] = err.Error()
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ValidTypeName checks if a string is a valid type name
|
||||
func ValidTypeName(name string) bool {
|
||||
validTypes := []string{TypeNumber, TypeString, TypeBool, TypeNil, TypeTable, TypeFunction, TypeAny}
|
||||
for _, validType := range validTypes {
|
||||
if name == validType {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ParseTypeName converts a string to a TypeInfo (for parsing type hints)
|
||||
func ParseTypeName(name string) *TypeInfo {
|
||||
if ValidTypeName(name) {
|
||||
return &TypeInfo{Type: name, Inferred: false}
|
||||
}
|
||||
return nil
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user