type hint/inference
This commit is contained in:
parent
c691c90c69
commit
e98e0643e2
125
parser/ast.go
125
parser/ast.go
@ -2,6 +2,12 @@ package parser
|
|||||||
|
|
||||||
import "fmt"
|
import "fmt"
|
||||||
|
|
||||||
|
// TypeInfo represents type information for expressions
|
||||||
|
type TypeInfo struct {
|
||||||
|
Type string // "number", "string", "bool", "table", "function", "nil", "any"
|
||||||
|
Inferred bool // true if type was inferred, false if explicitly declared
|
||||||
|
}
|
||||||
|
|
||||||
// Node represents any node in the AST
|
// Node represents any node in the AST
|
||||||
type Node interface {
|
type Node interface {
|
||||||
String() string
|
String() string
|
||||||
@ -17,6 +23,8 @@ type Statement interface {
|
|||||||
type Expression interface {
|
type Expression interface {
|
||||||
Node
|
Node
|
||||||
expressionNode()
|
expressionNode()
|
||||||
|
GetType() *TypeInfo
|
||||||
|
SetType(*TypeInfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Program represents the root of the AST
|
// Program represents the root of the AST
|
||||||
@ -33,9 +41,10 @@ func (p *Program) String() string {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
// AssignStatement represents variable assignment
|
// AssignStatement represents variable assignment with optional type hint
|
||||||
type AssignStatement struct {
|
type AssignStatement struct {
|
||||||
Name Expression // Changed from *Identifier to Expression for member access
|
Name Expression // Changed from *Identifier to Expression for member access
|
||||||
|
TypeHint *TypeInfo // optional type hint
|
||||||
Value Expression
|
Value Expression
|
||||||
IsDeclaration bool // true if this is the first assignment in current scope
|
IsDeclaration bool // true if this is the first assignment in current scope
|
||||||
}
|
}
|
||||||
@ -46,7 +55,15 @@ func (as *AssignStatement) String() string {
|
|||||||
if as.IsDeclaration {
|
if as.IsDeclaration {
|
||||||
prefix = "local "
|
prefix = "local "
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("%s%s = %s", prefix, as.Name.String(), as.Value.String())
|
|
||||||
|
var nameStr string
|
||||||
|
if as.TypeHint != nil {
|
||||||
|
nameStr = fmt.Sprintf("%s: %s", as.Name.String(), as.TypeHint.Type)
|
||||||
|
} else {
|
||||||
|
nameStr = as.Name.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("%s%s = %s", prefix, nameStr, as.Value.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
// EchoStatement represents echo output statements
|
// EchoStatement represents echo output statements
|
||||||
@ -216,33 +233,56 @@ func (fis *ForInStatement) String() string {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
// Identifier represents identifiers
|
// FunctionParameter represents a function parameter with optional type hint
|
||||||
type Identifier struct {
|
type FunctionParameter struct {
|
||||||
Value string
|
Name string
|
||||||
|
TypeHint *TypeInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *Identifier) expressionNode() {}
|
func (fp *FunctionParameter) String() string {
|
||||||
func (i *Identifier) String() string { return i.Value }
|
if fp.TypeHint != nil {
|
||||||
|
return fmt.Sprintf("%s: %s", fp.Name, fp.TypeHint.Type)
|
||||||
|
}
|
||||||
|
return fp.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
// Identifier represents identifiers
|
||||||
|
type Identifier struct {
|
||||||
|
Value string
|
||||||
|
typeInfo *TypeInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *Identifier) expressionNode() {}
|
||||||
|
func (i *Identifier) String() string { return i.Value }
|
||||||
|
func (i *Identifier) GetType() *TypeInfo { return i.typeInfo }
|
||||||
|
func (i *Identifier) SetType(t *TypeInfo) { i.typeInfo = t }
|
||||||
|
|
||||||
// NumberLiteral represents numeric literals
|
// NumberLiteral represents numeric literals
|
||||||
type NumberLiteral struct {
|
type NumberLiteral struct {
|
||||||
Value float64
|
Value float64
|
||||||
|
typeInfo *TypeInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
func (nl *NumberLiteral) expressionNode() {}
|
func (nl *NumberLiteral) expressionNode() {}
|
||||||
func (nl *NumberLiteral) String() string { return fmt.Sprintf("%.2f", nl.Value) }
|
func (nl *NumberLiteral) String() string { return fmt.Sprintf("%.2f", nl.Value) }
|
||||||
|
func (nl *NumberLiteral) GetType() *TypeInfo { return nl.typeInfo }
|
||||||
|
func (nl *NumberLiteral) SetType(t *TypeInfo) { nl.typeInfo = t }
|
||||||
|
|
||||||
// StringLiteral represents string literals
|
// StringLiteral represents string literals
|
||||||
type StringLiteral struct {
|
type StringLiteral struct {
|
||||||
Value string
|
Value string
|
||||||
|
typeInfo *TypeInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sl *StringLiteral) expressionNode() {}
|
func (sl *StringLiteral) expressionNode() {}
|
||||||
func (sl *StringLiteral) String() string { return fmt.Sprintf(`"%s"`, sl.Value) }
|
func (sl *StringLiteral) String() string { return fmt.Sprintf(`"%s"`, sl.Value) }
|
||||||
|
func (sl *StringLiteral) GetType() *TypeInfo { return sl.typeInfo }
|
||||||
|
func (sl *StringLiteral) SetType(t *TypeInfo) { sl.typeInfo = t }
|
||||||
|
|
||||||
// BooleanLiteral represents boolean literals
|
// BooleanLiteral represents boolean literals
|
||||||
type BooleanLiteral struct {
|
type BooleanLiteral struct {
|
||||||
Value bool
|
Value bool
|
||||||
|
typeInfo *TypeInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bl *BooleanLiteral) expressionNode() {}
|
func (bl *BooleanLiteral) expressionNode() {}
|
||||||
@ -252,18 +292,26 @@ func (bl *BooleanLiteral) String() string {
|
|||||||
}
|
}
|
||||||
return "false"
|
return "false"
|
||||||
}
|
}
|
||||||
|
func (bl *BooleanLiteral) GetType() *TypeInfo { return bl.typeInfo }
|
||||||
|
func (bl *BooleanLiteral) SetType(t *TypeInfo) { bl.typeInfo = t }
|
||||||
|
|
||||||
// NilLiteral represents nil literal
|
// NilLiteral represents nil literal
|
||||||
type NilLiteral struct{}
|
type NilLiteral struct {
|
||||||
|
typeInfo *TypeInfo
|
||||||
|
}
|
||||||
|
|
||||||
func (nl *NilLiteral) expressionNode() {}
|
func (nl *NilLiteral) expressionNode() {}
|
||||||
func (nl *NilLiteral) String() string { return "nil" }
|
func (nl *NilLiteral) String() string { return "nil" }
|
||||||
|
func (nl *NilLiteral) GetType() *TypeInfo { return nl.typeInfo }
|
||||||
|
func (nl *NilLiteral) SetType(t *TypeInfo) { nl.typeInfo = t }
|
||||||
|
|
||||||
// FunctionLiteral represents function literals: fn(a, b, ...) ... end
|
// FunctionLiteral represents function literals with typed parameters
|
||||||
type FunctionLiteral struct {
|
type FunctionLiteral struct {
|
||||||
Parameters []string
|
Parameters []FunctionParameter
|
||||||
Variadic bool
|
Variadic bool
|
||||||
|
ReturnType *TypeInfo // optional return type hint
|
||||||
Body []Statement
|
Body []Statement
|
||||||
|
typeInfo *TypeInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fl *FunctionLiteral) expressionNode() {}
|
func (fl *FunctionLiteral) expressionNode() {}
|
||||||
@ -273,7 +321,7 @@ func (fl *FunctionLiteral) String() string {
|
|||||||
if i > 0 {
|
if i > 0 {
|
||||||
params += ", "
|
params += ", "
|
||||||
}
|
}
|
||||||
params += param
|
params += param.String()
|
||||||
}
|
}
|
||||||
if fl.Variadic {
|
if fl.Variadic {
|
||||||
if len(fl.Parameters) > 0 {
|
if len(fl.Parameters) > 0 {
|
||||||
@ -282,18 +330,26 @@ func (fl *FunctionLiteral) String() string {
|
|||||||
params += "..."
|
params += "..."
|
||||||
}
|
}
|
||||||
|
|
||||||
result := fmt.Sprintf("fn(%s)\n", params)
|
result := fmt.Sprintf("fn(%s)", params)
|
||||||
|
if fl.ReturnType != nil {
|
||||||
|
result += ": " + fl.ReturnType.Type
|
||||||
|
}
|
||||||
|
result += "\n"
|
||||||
|
|
||||||
for _, stmt := range fl.Body {
|
for _, stmt := range fl.Body {
|
||||||
result += "\t" + stmt.String() + "\n"
|
result += "\t" + stmt.String() + "\n"
|
||||||
}
|
}
|
||||||
result += "end"
|
result += "end"
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
func (fl *FunctionLiteral) GetType() *TypeInfo { return fl.typeInfo }
|
||||||
|
func (fl *FunctionLiteral) SetType(t *TypeInfo) { fl.typeInfo = t }
|
||||||
|
|
||||||
// CallExpression represents function calls: func(arg1, arg2, ...)
|
// CallExpression represents function calls: func(arg1, arg2, ...)
|
||||||
type CallExpression struct {
|
type CallExpression struct {
|
||||||
Function Expression
|
Function Expression
|
||||||
Arguments []Expression
|
Arguments []Expression
|
||||||
|
typeInfo *TypeInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ce *CallExpression) expressionNode() {}
|
func (ce *CallExpression) expressionNode() {}
|
||||||
@ -304,11 +360,14 @@ func (ce *CallExpression) String() string {
|
|||||||
}
|
}
|
||||||
return fmt.Sprintf("%s(%s)", ce.Function.String(), joinStrings(args, ", "))
|
return fmt.Sprintf("%s(%s)", ce.Function.String(), joinStrings(args, ", "))
|
||||||
}
|
}
|
||||||
|
func (ce *CallExpression) GetType() *TypeInfo { return ce.typeInfo }
|
||||||
|
func (ce *CallExpression) SetType(t *TypeInfo) { ce.typeInfo = t }
|
||||||
|
|
||||||
// PrefixExpression represents prefix operations like -x, not x
|
// PrefixExpression represents prefix operations like -x, not x
|
||||||
type PrefixExpression struct {
|
type PrefixExpression struct {
|
||||||
Operator string
|
Operator string
|
||||||
Right Expression
|
Right Expression
|
||||||
|
typeInfo *TypeInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pe *PrefixExpression) expressionNode() {}
|
func (pe *PrefixExpression) expressionNode() {}
|
||||||
@ -319,40 +378,51 @@ func (pe *PrefixExpression) String() string {
|
|||||||
}
|
}
|
||||||
return fmt.Sprintf("(%s%s)", pe.Operator, pe.Right.String())
|
return fmt.Sprintf("(%s%s)", pe.Operator, pe.Right.String())
|
||||||
}
|
}
|
||||||
|
func (pe *PrefixExpression) GetType() *TypeInfo { return pe.typeInfo }
|
||||||
|
func (pe *PrefixExpression) SetType(t *TypeInfo) { pe.typeInfo = t }
|
||||||
|
|
||||||
// InfixExpression represents binary operations
|
// InfixExpression represents binary operations
|
||||||
type InfixExpression struct {
|
type InfixExpression struct {
|
||||||
Left Expression
|
Left Expression
|
||||||
Operator string
|
Operator string
|
||||||
Right Expression
|
Right Expression
|
||||||
|
typeInfo *TypeInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ie *InfixExpression) expressionNode() {}
|
func (ie *InfixExpression) expressionNode() {}
|
||||||
func (ie *InfixExpression) String() string {
|
func (ie *InfixExpression) String() string {
|
||||||
return fmt.Sprintf("(%s %s %s)", ie.Left.String(), ie.Operator, ie.Right.String())
|
return fmt.Sprintf("(%s %s %s)", ie.Left.String(), ie.Operator, ie.Right.String())
|
||||||
}
|
}
|
||||||
|
func (ie *InfixExpression) GetType() *TypeInfo { return ie.typeInfo }
|
||||||
|
func (ie *InfixExpression) SetType(t *TypeInfo) { ie.typeInfo = t }
|
||||||
|
|
||||||
// IndexExpression represents table[key] access
|
// IndexExpression represents table[key] access
|
||||||
type IndexExpression struct {
|
type IndexExpression struct {
|
||||||
Left Expression
|
Left Expression
|
||||||
Index Expression
|
Index Expression
|
||||||
|
typeInfo *TypeInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ie *IndexExpression) expressionNode() {}
|
func (ie *IndexExpression) expressionNode() {}
|
||||||
func (ie *IndexExpression) String() string {
|
func (ie *IndexExpression) String() string {
|
||||||
return fmt.Sprintf("%s[%s]", ie.Left.String(), ie.Index.String())
|
return fmt.Sprintf("%s[%s]", ie.Left.String(), ie.Index.String())
|
||||||
}
|
}
|
||||||
|
func (ie *IndexExpression) GetType() *TypeInfo { return ie.typeInfo }
|
||||||
|
func (ie *IndexExpression) SetType(t *TypeInfo) { ie.typeInfo = t }
|
||||||
|
|
||||||
// DotExpression represents table.key access
|
// DotExpression represents table.key access
|
||||||
type DotExpression struct {
|
type DotExpression struct {
|
||||||
Left Expression
|
Left Expression
|
||||||
Key string
|
Key string
|
||||||
|
typeInfo *TypeInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
func (de *DotExpression) expressionNode() {}
|
func (de *DotExpression) expressionNode() {}
|
||||||
func (de *DotExpression) String() string {
|
func (de *DotExpression) String() string {
|
||||||
return fmt.Sprintf("%s.%s", de.Left.String(), de.Key)
|
return fmt.Sprintf("%s.%s", de.Left.String(), de.Key)
|
||||||
}
|
}
|
||||||
|
func (de *DotExpression) GetType() *TypeInfo { return de.typeInfo }
|
||||||
|
func (de *DotExpression) SetType(t *TypeInfo) { de.typeInfo = t }
|
||||||
|
|
||||||
// TablePair represents a key-value pair in a table
|
// TablePair represents a key-value pair in a table
|
||||||
type TablePair struct {
|
type TablePair struct {
|
||||||
@ -369,7 +439,8 @@ func (tp *TablePair) String() string {
|
|||||||
|
|
||||||
// TableLiteral represents table literals {}
|
// TableLiteral represents table literals {}
|
||||||
type TableLiteral struct {
|
type TableLiteral struct {
|
||||||
Pairs []TablePair
|
Pairs []TablePair
|
||||||
|
typeInfo *TypeInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tl *TableLiteral) expressionNode() {}
|
func (tl *TableLiteral) expressionNode() {}
|
||||||
@ -380,6 +451,8 @@ func (tl *TableLiteral) String() string {
|
|||||||
}
|
}
|
||||||
return fmt.Sprintf("{%s}", joinStrings(pairs, ", "))
|
return fmt.Sprintf("{%s}", joinStrings(pairs, ", "))
|
||||||
}
|
}
|
||||||
|
func (tl *TableLiteral) GetType() *TypeInfo { return tl.typeInfo }
|
||||||
|
func (tl *TableLiteral) SetType(t *TypeInfo) { tl.typeInfo = t }
|
||||||
|
|
||||||
// IsArray returns true if this table contains only array-style elements
|
// IsArray returns true if this table contains only array-style elements
|
||||||
func (tl *TableLiteral) IsArray() bool {
|
func (tl *TableLiteral) IsArray() bool {
|
||||||
|
@ -257,6 +257,8 @@ func (l *Lexer) NextToken() Token {
|
|||||||
tok = Token{Type: STAR, Literal: string(l.ch), Line: l.line, Column: l.column}
|
tok = Token{Type: STAR, Literal: string(l.ch), Line: l.line, Column: l.column}
|
||||||
case '/':
|
case '/':
|
||||||
tok = Token{Type: SLASH, Literal: string(l.ch), Line: l.line, Column: l.column}
|
tok = Token{Type: SLASH, Literal: string(l.ch), Line: l.line, Column: l.column}
|
||||||
|
case ':':
|
||||||
|
tok = Token{Type: COLON, Literal: string(l.ch), Line: l.line, Column: l.column}
|
||||||
case '.':
|
case '.':
|
||||||
// Check for ellipsis (...)
|
// Check for ellipsis (...)
|
||||||
if l.peekChar() == '.' && l.peekCharAt(2) == '.' {
|
if l.peekChar() == '.' && l.peekCharAt(2) == '.' {
|
||||||
|
191
parser/parser.go
191
parser/parser.go
@ -89,14 +89,13 @@ func (p *Parser) enterScope(scopeType string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *Parser) exitScope() {
|
func (p *Parser) exitScope() {
|
||||||
if len(p.scopes) > 1 { // never remove global scope
|
if len(p.scopes) > 1 {
|
||||||
p.scopes = p.scopes[:len(p.scopes)-1]
|
p.scopes = p.scopes[:len(p.scopes)-1]
|
||||||
p.scopeTypes = p.scopeTypes[:len(p.scopeTypes)-1]
|
p.scopeTypes = p.scopeTypes[:len(p.scopeTypes)-1]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Parser) enterFunctionScope() {
|
func (p *Parser) enterFunctionScope() {
|
||||||
// Functions create new variable scopes
|
|
||||||
p.enterScope("function")
|
p.enterScope("function")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -105,35 +104,29 @@ func (p *Parser) exitFunctionScope() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *Parser) enterLoopScope() {
|
func (p *Parser) enterLoopScope() {
|
||||||
// Create temporary scope for loop variables only
|
|
||||||
p.enterScope("loop")
|
p.enterScope("loop")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Parser) exitLoopScope() {
|
func (p *Parser) exitLoopScope() {
|
||||||
// Remove temporary loop scope
|
|
||||||
p.exitScope()
|
p.exitScope()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Parser) enterBlockScope() {
|
func (p *Parser) enterBlockScope() {
|
||||||
// Blocks don't create new variable scopes, just control flow scopes
|
// Blocks don't create new variable scopes
|
||||||
// We don't need to track these for variable declarations
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Parser) exitBlockScope() {
|
func (p *Parser) exitBlockScope() {
|
||||||
// No-op since blocks don't create variable scopes
|
// No-op
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Parser) currentVariableScope() map[string]bool {
|
func (p *Parser) currentVariableScope() map[string]bool {
|
||||||
// If we're in a loop scope, declare variables in the parent scope
|
|
||||||
if len(p.scopeTypes) > 1 && p.scopeTypes[len(p.scopeTypes)-1] == "loop" {
|
if len(p.scopeTypes) > 1 && p.scopeTypes[len(p.scopeTypes)-1] == "loop" {
|
||||||
return p.scopes[len(p.scopes)-2]
|
return p.scopes[len(p.scopes)-2]
|
||||||
}
|
}
|
||||||
// Otherwise use the current scope
|
|
||||||
return p.scopes[len(p.scopes)-1]
|
return p.scopes[len(p.scopes)-1]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Parser) isVariableDeclared(name string) bool {
|
func (p *Parser) isVariableDeclared(name string) bool {
|
||||||
// Check all scopes from current up to global
|
|
||||||
for i := len(p.scopes) - 1; i >= 0; i-- {
|
for i := len(p.scopes) - 1; i >= 0; i-- {
|
||||||
if p.scopes[i][name] {
|
if p.scopes[i][name] {
|
||||||
return true
|
return true
|
||||||
@ -147,10 +140,31 @@ func (p *Parser) declareVariable(name string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *Parser) declareLoopVariable(name string) {
|
func (p *Parser) declareLoopVariable(name string) {
|
||||||
// Loop variables go in the current loop scope
|
|
||||||
p.scopes[len(p.scopes)-1][name] = true
|
p.scopes[len(p.scopes)-1][name] = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parseTypeHint parses optional type hint after colon
|
||||||
|
func (p *Parser) parseTypeHint() *TypeInfo {
|
||||||
|
if !p.peekTokenIs(COLON) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
p.nextToken() // consume ':'
|
||||||
|
|
||||||
|
if !p.expectPeekIdent() {
|
||||||
|
p.addError("expected type name after ':'")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
typeName := p.curToken.Literal
|
||||||
|
if !ValidTypeName(typeName) {
|
||||||
|
p.addError(fmt.Sprintf("invalid type name '%s'", typeName))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &TypeInfo{Type: typeName, Inferred: false}
|
||||||
|
}
|
||||||
|
|
||||||
// registerPrefix registers a prefix parse function
|
// registerPrefix registers a prefix parse function
|
||||||
func (p *Parser) registerPrefix(tokenType TokenType, fn func() Expression) {
|
func (p *Parser) registerPrefix(tokenType TokenType, fn func() Expression) {
|
||||||
p.prefixParseFns[tokenType] = fn
|
p.prefixParseFns[tokenType] = fn
|
||||||
@ -187,7 +201,6 @@ func (p *Parser) ParseProgram() *Program {
|
|||||||
func (p *Parser) parseStatement() Statement {
|
func (p *Parser) parseStatement() Statement {
|
||||||
switch p.curToken.Type {
|
switch p.curToken.Type {
|
||||||
case IDENT:
|
case IDENT:
|
||||||
// Try to parse as assignment (handles both simple and member access)
|
|
||||||
return p.parseAssignStatement()
|
return p.parseAssignStatement()
|
||||||
case IF:
|
case IF:
|
||||||
return p.parseIfStatement()
|
return p.parseIfStatement()
|
||||||
@ -217,17 +230,22 @@ func (p *Parser) parseStatement() Statement {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseAssignStatement parses variable assignment
|
// parseAssignStatement parses variable assignment with optional type hint
|
||||||
func (p *Parser) parseAssignStatement() *AssignStatement {
|
func (p *Parser) parseAssignStatement() *AssignStatement {
|
||||||
stmt := &AssignStatement{}
|
stmt := &AssignStatement{}
|
||||||
|
|
||||||
// Parse left-hand side expression (can be identifier or member access)
|
// Parse left-hand side expression
|
||||||
stmt.Name = p.ParseExpression(LOWEST)
|
stmt.Name = p.ParseExpression(LOWEST)
|
||||||
if stmt.Name == nil {
|
if stmt.Name == nil {
|
||||||
p.addError("expected expression for assignment left-hand side")
|
p.addError("expected expression for assignment left-hand side")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check for type hint on simple identifiers
|
||||||
|
if _, ok := stmt.Name.(*Identifier); ok {
|
||||||
|
stmt.TypeHint = p.parseTypeHint()
|
||||||
|
}
|
||||||
|
|
||||||
// Check if next token is assignment operator
|
// Check if next token is assignment operator
|
||||||
if !p.peekTokenIs(ASSIGN) {
|
if !p.peekTokenIs(ASSIGN) {
|
||||||
p.addError("unexpected identifier, expected assignment or declaration")
|
p.addError("unexpected identifier, expected assignment or declaration")
|
||||||
@ -237,13 +255,11 @@ func (p *Parser) parseAssignStatement() *AssignStatement {
|
|||||||
// Validate assignment target and check if it's a declaration
|
// Validate assignment target and check if it's a declaration
|
||||||
switch name := stmt.Name.(type) {
|
switch name := stmt.Name.(type) {
|
||||||
case *Identifier:
|
case *Identifier:
|
||||||
// Simple variable assignment - check if it's a declaration
|
|
||||||
stmt.IsDeclaration = !p.isVariableDeclared(name.Value)
|
stmt.IsDeclaration = !p.isVariableDeclared(name.Value)
|
||||||
if stmt.IsDeclaration {
|
if stmt.IsDeclaration {
|
||||||
p.declareVariable(name.Value)
|
p.declareVariable(name.Value)
|
||||||
}
|
}
|
||||||
case *DotExpression, *IndexExpression:
|
case *DotExpression, *IndexExpression:
|
||||||
// Member access - never a declaration
|
|
||||||
stmt.IsDeclaration = false
|
stmt.IsDeclaration = false
|
||||||
default:
|
default:
|
||||||
p.addError("invalid assignment target")
|
p.addError("invalid assignment target")
|
||||||
@ -289,10 +305,8 @@ func (p *Parser) parseBreakStatement() *BreakStatement {
|
|||||||
func (p *Parser) parseExitStatement() *ExitStatement {
|
func (p *Parser) parseExitStatement() *ExitStatement {
|
||||||
stmt := &ExitStatement{}
|
stmt := &ExitStatement{}
|
||||||
|
|
||||||
// Check if there's an optional expression after 'exit'
|
|
||||||
// Only parse expression if next token can start an expression
|
|
||||||
if p.canStartExpression(p.peekToken.Type) {
|
if p.canStartExpression(p.peekToken.Type) {
|
||||||
p.nextToken() // move past 'exit'
|
p.nextToken()
|
||||||
stmt.Value = p.ParseExpression(LOWEST)
|
stmt.Value = p.ParseExpression(LOWEST)
|
||||||
if stmt.Value == nil {
|
if stmt.Value == nil {
|
||||||
p.addError("expected expression after 'exit'")
|
p.addError("expected expression after 'exit'")
|
||||||
@ -307,9 +321,8 @@ func (p *Parser) parseExitStatement() *ExitStatement {
|
|||||||
func (p *Parser) parseReturnStatement() *ReturnStatement {
|
func (p *Parser) parseReturnStatement() *ReturnStatement {
|
||||||
stmt := &ReturnStatement{}
|
stmt := &ReturnStatement{}
|
||||||
|
|
||||||
// Check if there's an optional expression after 'return'
|
|
||||||
if p.canStartExpression(p.peekToken.Type) {
|
if p.canStartExpression(p.peekToken.Type) {
|
||||||
p.nextToken() // move past 'return'
|
p.nextToken()
|
||||||
stmt.Value = p.ParseExpression(LOWEST)
|
stmt.Value = p.ParseExpression(LOWEST)
|
||||||
if stmt.Value == nil {
|
if stmt.Value == nil {
|
||||||
p.addError("expected expression after 'return'")
|
p.addError("expected expression after 'return'")
|
||||||
@ -330,11 +343,11 @@ func (p *Parser) canStartExpression(tokenType TokenType) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseWhileStatement parses while loops: while condition do ... end
|
// parseWhileStatement parses while loops
|
||||||
func (p *Parser) parseWhileStatement() *WhileStatement {
|
func (p *Parser) parseWhileStatement() *WhileStatement {
|
||||||
stmt := &WhileStatement{}
|
stmt := &WhileStatement{}
|
||||||
|
|
||||||
p.nextToken() // move past 'while'
|
p.nextToken()
|
||||||
|
|
||||||
stmt.Condition = p.ParseExpression(LOWEST)
|
stmt.Condition = p.ParseExpression(LOWEST)
|
||||||
if stmt.Condition == nil {
|
if stmt.Condition == nil {
|
||||||
@ -347,9 +360,8 @@ func (p *Parser) parseWhileStatement() *WhileStatement {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
p.nextToken() // move past 'do'
|
p.nextToken()
|
||||||
|
|
||||||
// Parse loop body (no new variable scope)
|
|
||||||
p.enterBlockScope()
|
p.enterBlockScope()
|
||||||
stmt.Body = p.parseBlockStatements(END)
|
stmt.Body = p.parseBlockStatements(END)
|
||||||
p.exitBlockScope()
|
p.exitBlockScope()
|
||||||
@ -362,9 +374,9 @@ func (p *Parser) parseWhileStatement() *WhileStatement {
|
|||||||
return stmt
|
return stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseForStatement parses for loops (both numeric and for-in)
|
// parseForStatement parses for loops
|
||||||
func (p *Parser) parseForStatement() Statement {
|
func (p *Parser) parseForStatement() Statement {
|
||||||
p.nextToken() // move past 'for'
|
p.nextToken()
|
||||||
|
|
||||||
if !p.curTokenIs(IDENT) {
|
if !p.curTokenIs(IDENT) {
|
||||||
p.addError("expected identifier after 'for'")
|
p.addError("expected identifier after 'for'")
|
||||||
@ -373,12 +385,9 @@ func (p *Parser) parseForStatement() Statement {
|
|||||||
|
|
||||||
firstVar := &Identifier{Value: p.curToken.Literal}
|
firstVar := &Identifier{Value: p.curToken.Literal}
|
||||||
|
|
||||||
// Look ahead to determine which type of for loop
|
|
||||||
if p.peekTokenIs(ASSIGN) {
|
if p.peekTokenIs(ASSIGN) {
|
||||||
// Numeric for loop: for i = start, end, step do
|
|
||||||
return p.parseNumericForStatement(firstVar)
|
return p.parseNumericForStatement(firstVar)
|
||||||
} else if p.peekTokenIs(COMMA) || p.peekTokenIs(IN) {
|
} else if p.peekTokenIs(COMMA) || p.peekTokenIs(IN) {
|
||||||
// For-in loop: for k, v in expr do or for v in expr do
|
|
||||||
return p.parseForInStatement(firstVar)
|
return p.parseForInStatement(firstVar)
|
||||||
} else {
|
} else {
|
||||||
p.addError("expected '=', ',' or 'in' after for loop variable")
|
p.addError("expected '=', ',' or 'in' after for loop variable")
|
||||||
@ -386,7 +395,7 @@ func (p *Parser) parseForStatement() Statement {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseNumericForStatement parses numeric for loops: for i = start, end, step do
|
// parseNumericForStatement parses numeric for loops
|
||||||
func (p *Parser) parseNumericForStatement(variable *Identifier) *ForStatement {
|
func (p *Parser) parseNumericForStatement(variable *Identifier) *ForStatement {
|
||||||
stmt := &ForStatement{Variable: variable}
|
stmt := &ForStatement{Variable: variable}
|
||||||
|
|
||||||
@ -394,9 +403,8 @@ func (p *Parser) parseNumericForStatement(variable *Identifier) *ForStatement {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
p.nextToken() // move past '='
|
p.nextToken()
|
||||||
|
|
||||||
// Parse start expression
|
|
||||||
stmt.Start = p.ParseExpression(LOWEST)
|
stmt.Start = p.ParseExpression(LOWEST)
|
||||||
if stmt.Start == nil {
|
if stmt.Start == nil {
|
||||||
p.addError("expected start expression in for loop")
|
p.addError("expected start expression in for loop")
|
||||||
@ -408,19 +416,17 @@ func (p *Parser) parseNumericForStatement(variable *Identifier) *ForStatement {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
p.nextToken() // move past ','
|
p.nextToken()
|
||||||
|
|
||||||
// Parse end expression
|
|
||||||
stmt.End = p.ParseExpression(LOWEST)
|
stmt.End = p.ParseExpression(LOWEST)
|
||||||
if stmt.End == nil {
|
if stmt.End == nil {
|
||||||
p.addError("expected end expression in for loop")
|
p.addError("expected end expression in for loop")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Optional step expression
|
|
||||||
if p.peekTokenIs(COMMA) {
|
if p.peekTokenIs(COMMA) {
|
||||||
p.nextToken() // move to ','
|
p.nextToken()
|
||||||
p.nextToken() // move past ','
|
p.nextToken()
|
||||||
|
|
||||||
stmt.Step = p.ParseExpression(LOWEST)
|
stmt.Step = p.ParseExpression(LOWEST)
|
||||||
if stmt.Step == nil {
|
if stmt.Step == nil {
|
||||||
@ -434,13 +440,12 @@ func (p *Parser) parseNumericForStatement(variable *Identifier) *ForStatement {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
p.nextToken() // move past 'do'
|
p.nextToken()
|
||||||
|
|
||||||
// Create temporary scope for loop variable, assignments in body go to parent scope
|
|
||||||
p.enterLoopScope()
|
p.enterLoopScope()
|
||||||
p.declareLoopVariable(variable.Value) // loop variable in temporary scope
|
p.declareLoopVariable(variable.Value)
|
||||||
stmt.Body = p.parseBlockStatements(END)
|
stmt.Body = p.parseBlockStatements(END)
|
||||||
p.exitLoopScope() // discard temporary scope with loop variable
|
p.exitLoopScope()
|
||||||
|
|
||||||
if !p.curTokenIs(END) {
|
if !p.curTokenIs(END) {
|
||||||
p.addError("expected 'end' to close for loop")
|
p.addError("expected 'end' to close for loop")
|
||||||
@ -450,15 +455,14 @@ func (p *Parser) parseNumericForStatement(variable *Identifier) *ForStatement {
|
|||||||
return stmt
|
return stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseForInStatement parses for-in loops: for k, v in expr do or for v in expr do
|
// parseForInStatement parses for-in loops
|
||||||
func (p *Parser) parseForInStatement(firstVar *Identifier) *ForInStatement {
|
func (p *Parser) parseForInStatement(firstVar *Identifier) *ForInStatement {
|
||||||
stmt := &ForInStatement{}
|
stmt := &ForInStatement{}
|
||||||
|
|
||||||
if p.peekTokenIs(COMMA) {
|
if p.peekTokenIs(COMMA) {
|
||||||
// Two variables: for k, v in expr do
|
|
||||||
stmt.Key = firstVar
|
stmt.Key = firstVar
|
||||||
p.nextToken() // move to ','
|
p.nextToken()
|
||||||
p.nextToken() // move past ','
|
p.nextToken()
|
||||||
|
|
||||||
if !p.curTokenIs(IDENT) {
|
if !p.curTokenIs(IDENT) {
|
||||||
p.addError("expected identifier after ',' in for loop")
|
p.addError("expected identifier after ',' in for loop")
|
||||||
@ -467,7 +471,6 @@ func (p *Parser) parseForInStatement(firstVar *Identifier) *ForInStatement {
|
|||||||
|
|
||||||
stmt.Value = &Identifier{Value: p.curToken.Literal}
|
stmt.Value = &Identifier{Value: p.curToken.Literal}
|
||||||
} else {
|
} else {
|
||||||
// Single variable: for v in expr do
|
|
||||||
stmt.Value = firstVar
|
stmt.Value = firstVar
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -476,9 +479,8 @@ func (p *Parser) parseForInStatement(firstVar *Identifier) *ForInStatement {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
p.nextToken() // move past 'in'
|
p.nextToken()
|
||||||
|
|
||||||
// Parse iterable expression
|
|
||||||
stmt.Iterable = p.ParseExpression(LOWEST)
|
stmt.Iterable = p.ParseExpression(LOWEST)
|
||||||
if stmt.Iterable == nil {
|
if stmt.Iterable == nil {
|
||||||
p.addError("expected expression after 'in' in for loop")
|
p.addError("expected expression after 'in' in for loop")
|
||||||
@ -490,16 +492,15 @@ func (p *Parser) parseForInStatement(firstVar *Identifier) *ForInStatement {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
p.nextToken() // move past 'do'
|
p.nextToken()
|
||||||
|
|
||||||
// Create temporary scope for loop variables, assignments in body go to parent scope
|
|
||||||
p.enterLoopScope()
|
p.enterLoopScope()
|
||||||
if stmt.Key != nil {
|
if stmt.Key != nil {
|
||||||
p.declareLoopVariable(stmt.Key.Value) // loop variable in temporary scope
|
p.declareLoopVariable(stmt.Key.Value)
|
||||||
}
|
}
|
||||||
p.declareLoopVariable(stmt.Value.Value) // loop variable in temporary scope
|
p.declareLoopVariable(stmt.Value.Value)
|
||||||
stmt.Body = p.parseBlockStatements(END)
|
stmt.Body = p.parseBlockStatements(END)
|
||||||
p.exitLoopScope() // discard temporary scope with loop variables
|
p.exitLoopScope()
|
||||||
|
|
||||||
if !p.curTokenIs(END) {
|
if !p.curTokenIs(END) {
|
||||||
p.addError("expected 'end' to close for loop")
|
p.addError("expected 'end' to close for loop")
|
||||||
@ -509,11 +510,11 @@ func (p *Parser) parseForInStatement(firstVar *Identifier) *ForInStatement {
|
|||||||
return stmt
|
return stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseIfStatement parses if/elseif/else/end statements
|
// parseIfStatement parses if statements
|
||||||
func (p *Parser) parseIfStatement() *IfStatement {
|
func (p *Parser) parseIfStatement() *IfStatement {
|
||||||
stmt := &IfStatement{}
|
stmt := &IfStatement{}
|
||||||
|
|
||||||
p.nextToken() // move past 'if'
|
p.nextToken()
|
||||||
|
|
||||||
stmt.Condition = p.ParseExpression(LOWEST)
|
stmt.Condition = p.ParseExpression(LOWEST)
|
||||||
if stmt.Condition == nil {
|
if stmt.Condition == nil {
|
||||||
@ -521,29 +522,25 @@ func (p *Parser) parseIfStatement() *IfStatement {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Optional 'then' keyword
|
|
||||||
if p.peekTokenIs(THEN) {
|
if p.peekTokenIs(THEN) {
|
||||||
p.nextToken()
|
p.nextToken()
|
||||||
}
|
}
|
||||||
|
|
||||||
p.nextToken() // move past condition (and optional 'then')
|
p.nextToken()
|
||||||
|
|
||||||
// Check if we immediately hit END (empty body should be an error)
|
|
||||||
if p.curTokenIs(END) {
|
if p.curTokenIs(END) {
|
||||||
p.addError("expected 'end' to close if statement")
|
p.addError("expected 'end' to close if statement")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse if body (no new variable scope)
|
|
||||||
p.enterBlockScope()
|
p.enterBlockScope()
|
||||||
stmt.Body = p.parseBlockStatements(ELSEIF, ELSE, END)
|
stmt.Body = p.parseBlockStatements(ELSEIF, ELSE, END)
|
||||||
p.exitBlockScope()
|
p.exitBlockScope()
|
||||||
|
|
||||||
// Parse elseif clauses
|
|
||||||
for p.curTokenIs(ELSEIF) {
|
for p.curTokenIs(ELSEIF) {
|
||||||
elseif := ElseIfClause{}
|
elseif := ElseIfClause{}
|
||||||
|
|
||||||
p.nextToken() // move past 'elseif'
|
p.nextToken()
|
||||||
|
|
||||||
elseif.Condition = p.ParseExpression(LOWEST)
|
elseif.Condition = p.ParseExpression(LOWEST)
|
||||||
if elseif.Condition == nil {
|
if elseif.Condition == nil {
|
||||||
@ -551,14 +548,12 @@ func (p *Parser) parseIfStatement() *IfStatement {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Optional 'then' keyword
|
|
||||||
if p.peekTokenIs(THEN) {
|
if p.peekTokenIs(THEN) {
|
||||||
p.nextToken()
|
p.nextToken()
|
||||||
}
|
}
|
||||||
|
|
||||||
p.nextToken() // move past condition (and optional 'then')
|
p.nextToken()
|
||||||
|
|
||||||
// Parse elseif body (no new variable scope)
|
|
||||||
p.enterBlockScope()
|
p.enterBlockScope()
|
||||||
elseif.Body = p.parseBlockStatements(ELSEIF, ELSE, END)
|
elseif.Body = p.parseBlockStatements(ELSEIF, ELSE, END)
|
||||||
p.exitBlockScope()
|
p.exitBlockScope()
|
||||||
@ -566,11 +561,9 @@ func (p *Parser) parseIfStatement() *IfStatement {
|
|||||||
stmt.ElseIfs = append(stmt.ElseIfs, elseif)
|
stmt.ElseIfs = append(stmt.ElseIfs, elseif)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse else clause
|
|
||||||
if p.curTokenIs(ELSE) {
|
if p.curTokenIs(ELSE) {
|
||||||
p.nextToken() // move past 'else'
|
p.nextToken()
|
||||||
|
|
||||||
// Parse else body (no new variable scope)
|
|
||||||
p.enterBlockScope()
|
p.enterBlockScope()
|
||||||
stmt.Else = p.parseBlockStatements(END)
|
stmt.Else = p.parseBlockStatements(END)
|
||||||
p.exitBlockScope()
|
p.exitBlockScope()
|
||||||
@ -584,7 +577,7 @@ func (p *Parser) parseIfStatement() *IfStatement {
|
|||||||
return stmt
|
return stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseBlockStatements parses statements until one of the terminator tokens
|
// parseBlockStatements parses statements until terminators
|
||||||
func (p *Parser) parseBlockStatements(terminators ...TokenType) []Statement {
|
func (p *Parser) parseBlockStatements(terminators ...TokenType) []Statement {
|
||||||
statements := []Statement{}
|
statements := []Statement{}
|
||||||
|
|
||||||
@ -599,7 +592,7 @@ func (p *Parser) parseBlockStatements(terminators ...TokenType) []Statement {
|
|||||||
return statements
|
return statements
|
||||||
}
|
}
|
||||||
|
|
||||||
// isTerminator checks if current token is one of the terminators
|
// isTerminator checks if current token is a terminator
|
||||||
func (p *Parser) isTerminator(terminators ...TokenType) bool {
|
func (p *Parser) isTerminator(terminators ...TokenType) bool {
|
||||||
for _, terminator := range terminators {
|
for _, terminator := range terminators {
|
||||||
if p.curTokenIs(terminator) {
|
if p.curTokenIs(terminator) {
|
||||||
@ -650,9 +643,7 @@ func (p *Parser) parseNumberLiteral() Expression {
|
|||||||
var value float64
|
var value float64
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// Check for hexadecimal (0x/0X prefix)
|
|
||||||
if strings.HasPrefix(literal, "0x") || strings.HasPrefix(literal, "0X") {
|
if strings.HasPrefix(literal, "0x") || strings.HasPrefix(literal, "0X") {
|
||||||
// Validate hex format
|
|
||||||
if len(literal) <= 2 {
|
if len(literal) <= 2 {
|
||||||
p.addError(fmt.Sprintf("could not parse '%s' as hexadecimal number", literal))
|
p.addError(fmt.Sprintf("could not parse '%s' as hexadecimal number", literal))
|
||||||
return nil
|
return nil
|
||||||
@ -664,7 +655,6 @@ func (p *Parser) parseNumberLiteral() Expression {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Parse as hex and convert to float64
|
|
||||||
intVal, parseErr := strconv.ParseInt(literal, 0, 64)
|
intVal, parseErr := strconv.ParseInt(literal, 0, 64)
|
||||||
if parseErr != nil {
|
if parseErr != nil {
|
||||||
p.addError(fmt.Sprintf("could not parse '%s' as hexadecimal number", literal))
|
p.addError(fmt.Sprintf("could not parse '%s' as hexadecimal number", literal))
|
||||||
@ -672,7 +662,6 @@ func (p *Parser) parseNumberLiteral() Expression {
|
|||||||
}
|
}
|
||||||
value = float64(intVal)
|
value = float64(intVal)
|
||||||
} else if strings.HasPrefix(literal, "0b") || strings.HasPrefix(literal, "0B") {
|
} else if strings.HasPrefix(literal, "0b") || strings.HasPrefix(literal, "0B") {
|
||||||
// Validate binary format
|
|
||||||
if len(literal) <= 2 {
|
if len(literal) <= 2 {
|
||||||
p.addError(fmt.Sprintf("could not parse '%s' as binary number", literal))
|
p.addError(fmt.Sprintf("could not parse '%s' as binary number", literal))
|
||||||
return nil
|
return nil
|
||||||
@ -684,8 +673,7 @@ func (p *Parser) parseNumberLiteral() Expression {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Parse binary manually since Go doesn't support 0b in ParseInt with base 0
|
binaryStr := literal[2:]
|
||||||
binaryStr := literal[2:] // remove "0b" prefix
|
|
||||||
intVal, parseErr := strconv.ParseInt(binaryStr, 2, 64)
|
intVal, parseErr := strconv.ParseInt(binaryStr, 2, 64)
|
||||||
if parseErr != nil {
|
if parseErr != nil {
|
||||||
p.addError(fmt.Sprintf("could not parse '%s' as binary number", literal))
|
p.addError(fmt.Sprintf("could not parse '%s' as binary number", literal))
|
||||||
@ -693,7 +681,6 @@ func (p *Parser) parseNumberLiteral() Expression {
|
|||||||
}
|
}
|
||||||
value = float64(intVal)
|
value = float64(intVal)
|
||||||
} else {
|
} else {
|
||||||
// Parse as regular decimal (handles scientific notation automatically)
|
|
||||||
value, err = strconv.ParseFloat(literal, 64)
|
value, err = strconv.ParseFloat(literal, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
p.addError(fmt.Sprintf("could not parse '%s' as number", literal))
|
p.addError(fmt.Sprintf("could not parse '%s' as number", literal))
|
||||||
@ -763,12 +750,14 @@ func (p *Parser) parseFunctionLiteral() Expression {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
p.nextToken() // move past ')'
|
// Check for return type hint
|
||||||
|
fn.ReturnType = p.parseTypeHint()
|
||||||
|
|
||||||
|
p.nextToken()
|
||||||
|
|
||||||
// Enter new function scope and declare parameters
|
|
||||||
p.enterFunctionScope()
|
p.enterFunctionScope()
|
||||||
for _, param := range fn.Parameters {
|
for _, param := range fn.Parameters {
|
||||||
p.declareVariable(param)
|
p.declareVariable(param.Name)
|
||||||
}
|
}
|
||||||
fn.Body = p.parseBlockStatements(END)
|
fn.Body = p.parseBlockStatements(END)
|
||||||
p.exitFunctionScope()
|
p.exitFunctionScope()
|
||||||
@ -781,8 +770,8 @@ func (p *Parser) parseFunctionLiteral() Expression {
|
|||||||
return fn
|
return fn
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Parser) parseFunctionParameters() ([]string, bool) {
|
func (p *Parser) parseFunctionParameters() ([]FunctionParameter, bool) {
|
||||||
var params []string
|
var params []FunctionParameter
|
||||||
var variadic bool
|
var variadic bool
|
||||||
|
|
||||||
if p.peekTokenIs(RPAREN) {
|
if p.peekTokenIs(RPAREN) {
|
||||||
@ -802,16 +791,20 @@ func (p *Parser) parseFunctionParameters() ([]string, bool) {
|
|||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
params = append(params, p.curToken.Literal)
|
param := FunctionParameter{Name: p.curToken.Literal}
|
||||||
|
|
||||||
|
// Check for type hint
|
||||||
|
param.TypeHint = p.parseTypeHint()
|
||||||
|
|
||||||
|
params = append(params, param)
|
||||||
|
|
||||||
if !p.peekTokenIs(COMMA) {
|
if !p.peekTokenIs(COMMA) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
p.nextToken() // move to ','
|
p.nextToken()
|
||||||
p.nextToken() // move past ','
|
p.nextToken()
|
||||||
|
|
||||||
// Check for ellipsis after comma
|
|
||||||
if p.curTokenIs(ELLIPSIS) {
|
if p.curTokenIs(ELLIPSIS) {
|
||||||
variadic = true
|
variadic = true
|
||||||
break
|
break
|
||||||
@ -833,7 +826,6 @@ func (p *Parser) parseTableLiteral() Expression {
|
|||||||
p.nextToken()
|
p.nextToken()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
// Check for EOF
|
|
||||||
if p.curTokenIs(EOF) {
|
if p.curTokenIs(EOF) {
|
||||||
p.addError("unexpected end of input, expected }")
|
p.addError("unexpected end of input, expected }")
|
||||||
return nil
|
return nil
|
||||||
@ -841,17 +833,15 @@ func (p *Parser) parseTableLiteral() Expression {
|
|||||||
|
|
||||||
pair := TablePair{}
|
pair := TablePair{}
|
||||||
|
|
||||||
// Check if this is a key=value pair (identifier or string key)
|
|
||||||
if (p.curTokenIs(IDENT) || p.curTokenIs(STRING)) && p.peekTokenIs(ASSIGN) {
|
if (p.curTokenIs(IDENT) || p.curTokenIs(STRING)) && p.peekTokenIs(ASSIGN) {
|
||||||
if p.curTokenIs(IDENT) {
|
if p.curTokenIs(IDENT) {
|
||||||
pair.Key = &Identifier{Value: p.curToken.Literal}
|
pair.Key = &Identifier{Value: p.curToken.Literal}
|
||||||
} else {
|
} else {
|
||||||
pair.Key = &StringLiteral{Value: p.curToken.Literal}
|
pair.Key = &StringLiteral{Value: p.curToken.Literal}
|
||||||
}
|
}
|
||||||
p.nextToken() // move to =
|
p.nextToken()
|
||||||
p.nextToken() // move past =
|
p.nextToken()
|
||||||
|
|
||||||
// Check for EOF after =
|
|
||||||
if p.curTokenIs(EOF) {
|
if p.curTokenIs(EOF) {
|
||||||
p.addError("expected expression after assignment operator")
|
p.addError("expected expression after assignment operator")
|
||||||
return nil
|
return nil
|
||||||
@ -859,7 +849,6 @@ func (p *Parser) parseTableLiteral() Expression {
|
|||||||
|
|
||||||
pair.Value = p.ParseExpression(LOWEST)
|
pair.Value = p.ParseExpression(LOWEST)
|
||||||
} else {
|
} else {
|
||||||
// Array-style element
|
|
||||||
pair.Value = p.ParseExpression(LOWEST)
|
pair.Value = p.ParseExpression(LOWEST)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -873,15 +862,13 @@ func (p *Parser) parseTableLiteral() Expression {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
p.nextToken() // consume comma
|
p.nextToken()
|
||||||
p.nextToken() // move to next element
|
p.nextToken()
|
||||||
|
|
||||||
// Allow trailing comma
|
|
||||||
if p.curTokenIs(RBRACE) {
|
if p.curTokenIs(RBRACE) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for EOF after comma
|
|
||||||
if p.curTokenIs(EOF) {
|
if p.curTokenIs(EOF) {
|
||||||
p.addError("expected next token to be }")
|
p.addError("expected next token to be }")
|
||||||
return nil
|
return nil
|
||||||
@ -956,7 +943,7 @@ func (p *Parser) parseExpressionList(end TokenType) []Expression {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *Parser) parseIndexExpression(left Expression) Expression {
|
func (p *Parser) parseIndexExpression(left Expression) Expression {
|
||||||
p.nextToken() // move past '['
|
p.nextToken()
|
||||||
|
|
||||||
index := p.ParseExpression(LOWEST)
|
index := p.ParseExpression(LOWEST)
|
||||||
if index == nil {
|
if index == nil {
|
||||||
@ -993,7 +980,6 @@ func (p *Parser) expectPeek(t TokenType) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// expectPeekIdent accepts IDENT or keyword tokens as identifiers
|
|
||||||
func (p *Parser) expectPeekIdent() bool {
|
func (p *Parser) expectPeekIdent() bool {
|
||||||
if p.peekTokenIs(IDENT) || p.isKeyword(p.peekToken.Type) {
|
if p.peekTokenIs(IDENT) || p.isKeyword(p.peekToken.Type) {
|
||||||
p.nextToken()
|
p.nextToken()
|
||||||
@ -1003,7 +989,6 @@ func (p *Parser) expectPeekIdent() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// isKeyword checks if a token type is a keyword that can be used as identifier
|
|
||||||
func (p *Parser) isKeyword(t TokenType) bool {
|
func (p *Parser) isKeyword(t TokenType) bool {
|
||||||
switch t {
|
switch t {
|
||||||
case TRUE, FALSE, NIL, AND, OR, NOT, IF, THEN, ELSEIF, ELSE, END, ECHO, FOR, WHILE, IN, DO, BREAK, EXIT, FN, RETURN:
|
case TRUE, FALSE, NIL, AND, OR, NOT, IF, THEN, ELSEIF, ELSE, END, ECHO, FOR, WHILE, IN, DO, BREAK, EXIT, FN, RETURN:
|
||||||
@ -1013,7 +998,7 @@ func (p *Parser) isKeyword(t TokenType) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Error handling methods
|
// Error handling
|
||||||
func (p *Parser) addError(message string) {
|
func (p *Parser) addError(message string) {
|
||||||
p.errors = append(p.errors, ParseError{
|
p.errors = append(p.errors, ParseError{
|
||||||
Message: message,
|
Message: message,
|
||||||
@ -1075,12 +1060,10 @@ func (p *Parser) Errors() []ParseError {
|
|||||||
return p.errors
|
return p.errors
|
||||||
}
|
}
|
||||||
|
|
||||||
// HasErrors returns true if there are any parsing errors
|
|
||||||
func (p *Parser) HasErrors() bool {
|
func (p *Parser) HasErrors() bool {
|
||||||
return len(p.errors) > 0
|
return len(p.errors) > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// ErrorStrings returns error messages as strings for backward compatibility
|
|
||||||
func (p *Parser) ErrorStrings() []string {
|
func (p *Parser) ErrorStrings() []string {
|
||||||
result := make([]string, len(p.errors))
|
result := make([]string, len(p.errors))
|
||||||
for i, err := range p.errors {
|
for i, err := range p.errors {
|
||||||
@ -1089,7 +1072,7 @@ func (p *Parser) ErrorStrings() []string {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
// tokenTypeString returns a human-readable string for token types
|
// tokenTypeString returns human-readable string for token types
|
||||||
func tokenTypeString(t TokenType) string {
|
func tokenTypeString(t TokenType) string {
|
||||||
switch t {
|
switch t {
|
||||||
case IDENT:
|
case IDENT:
|
||||||
@ -1114,6 +1097,8 @@ func tokenTypeString(t TokenType) string {
|
|||||||
return "/"
|
return "/"
|
||||||
case DOT:
|
case DOT:
|
||||||
return "."
|
return "."
|
||||||
|
case COLON:
|
||||||
|
return ":"
|
||||||
case EQ:
|
case EQ:
|
||||||
return "=="
|
return "=="
|
||||||
case NOT_EQ:
|
case NOT_EQ:
|
||||||
|
@ -81,8 +81,8 @@ func TestFunctionParameters(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i, expected := range tt.params {
|
for i, expected := range tt.params {
|
||||||
if fn.Parameters[i] != expected {
|
if fn.Parameters[i].Name != expected {
|
||||||
t.Errorf("parameter %d: expected %s, got %s", i, expected, fn.Parameters[i])
|
t.Errorf("parameter %d: expected %s, got %s", i, expected, fn.Parameters[i].Name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
533
parser/tests/types_test.go
Normal file
533
parser/tests/types_test.go
Normal file
@ -0,0 +1,533 @@
|
|||||||
|
package parser_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.sharkk.net/Sharkk/Mako/parser"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestVariableTypeHints(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
variable string
|
||||||
|
typeHint string
|
||||||
|
hasHint bool
|
||||||
|
desc string
|
||||||
|
}{
|
||||||
|
{"x = 42", "x", "", false, "no type hint"},
|
||||||
|
{"x: number = 42", "x", "number", true, "number type hint"},
|
||||||
|
{"name: string = \"hello\"", "name", "string", true, "string type hint"},
|
||||||
|
{"flag: bool = true", "flag", "bool", true, "bool type hint"},
|
||||||
|
{"data: table = {}", "data", "table", true, "table type hint"},
|
||||||
|
{"fn_var: function = fn() end", "fn_var", "function", true, "function type hint"},
|
||||||
|
{"value: any = nil", "value", "any", true, "any type hint"},
|
||||||
|
{"ptr: nil = nil", "ptr", "nil", true, "nil type hint"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.desc, func(t *testing.T) {
|
||||||
|
l := parser.NewLexer(tt.input)
|
||||||
|
p := parser.NewParser(l)
|
||||||
|
program := p.ParseProgram()
|
||||||
|
checkParserErrors(t, p)
|
||||||
|
|
||||||
|
if len(program.Statements) != 1 {
|
||||||
|
t.Fatalf("expected 1 statement, got %d", len(program.Statements))
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt, ok := program.Statements[0].(*parser.AssignStatement)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected AssignStatement, got %T", program.Statements[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check variable name
|
||||||
|
ident, ok := stmt.Name.(*parser.Identifier)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected Identifier for Name, got %T", stmt.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ident.Value != tt.variable {
|
||||||
|
t.Errorf("expected variable %s, got %s", tt.variable, ident.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check type hint
|
||||||
|
if tt.hasHint {
|
||||||
|
if stmt.TypeHint == nil {
|
||||||
|
t.Error("expected type hint but got nil")
|
||||||
|
} else {
|
||||||
|
if stmt.TypeHint.Type != tt.typeHint {
|
||||||
|
t.Errorf("expected type hint %s, got %s", tt.typeHint, stmt.TypeHint.Type)
|
||||||
|
}
|
||||||
|
if stmt.TypeHint.Inferred {
|
||||||
|
t.Error("expected type hint to not be inferred")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if stmt.TypeHint != nil {
|
||||||
|
t.Errorf("expected no type hint but got %s", stmt.TypeHint.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFunctionParameterTypeHints(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
params []struct{ name, typeHint string }
|
||||||
|
returnType string
|
||||||
|
hasReturn bool
|
||||||
|
desc string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"fn(a, b) end",
|
||||||
|
[]struct{ name, typeHint string }{
|
||||||
|
{"a", ""},
|
||||||
|
{"b", ""},
|
||||||
|
},
|
||||||
|
"", false,
|
||||||
|
"no type hints",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"fn(a: number, b: string) end",
|
||||||
|
[]struct{ name, typeHint string }{
|
||||||
|
{"a", "number"},
|
||||||
|
{"b", "string"},
|
||||||
|
},
|
||||||
|
"", false,
|
||||||
|
"parameter type hints only",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"fn(x: number): string end",
|
||||||
|
[]struct{ name, typeHint string }{
|
||||||
|
{"x", "number"},
|
||||||
|
},
|
||||||
|
"string", true,
|
||||||
|
"parameter and return type hints",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"fn(): bool end",
|
||||||
|
[]struct{ name, typeHint string }{},
|
||||||
|
"bool", true,
|
||||||
|
"return type hint only",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"fn(a: number, b, c: string): table end",
|
||||||
|
[]struct{ name, typeHint string }{
|
||||||
|
{"a", "number"},
|
||||||
|
{"b", ""},
|
||||||
|
{"c", "string"},
|
||||||
|
},
|
||||||
|
"table", true,
|
||||||
|
"mixed parameter types with return",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"fn(callback: function, data: any): nil end",
|
||||||
|
[]struct{ name, typeHint string }{
|
||||||
|
{"callback", "function"},
|
||||||
|
{"data", "any"},
|
||||||
|
},
|
||||||
|
"nil", true,
|
||||||
|
"function and any types",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.desc, func(t *testing.T) {
|
||||||
|
l := parser.NewLexer(tt.input)
|
||||||
|
p := parser.NewParser(l)
|
||||||
|
expr := p.ParseExpression(parser.LOWEST)
|
||||||
|
checkParserErrors(t, p)
|
||||||
|
|
||||||
|
fn, ok := expr.(*parser.FunctionLiteral)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected FunctionLiteral, got %T", expr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check parameters
|
||||||
|
if len(fn.Parameters) != len(tt.params) {
|
||||||
|
t.Fatalf("expected %d parameters, got %d", len(tt.params), len(fn.Parameters))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, expected := range tt.params {
|
||||||
|
param := fn.Parameters[i]
|
||||||
|
if param.Name != expected.name {
|
||||||
|
t.Errorf("parameter %d: expected name %s, got %s", i, expected.name, param.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if expected.typeHint == "" {
|
||||||
|
if param.TypeHint != nil {
|
||||||
|
t.Errorf("parameter %d: expected no type hint but got %s", i, param.TypeHint.Type)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if param.TypeHint == nil {
|
||||||
|
t.Errorf("parameter %d: expected type hint %s but got nil", i, expected.typeHint)
|
||||||
|
} else if param.TypeHint.Type != expected.typeHint {
|
||||||
|
t.Errorf("parameter %d: expected type hint %s, got %s", i, expected.typeHint, param.TypeHint.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check return type
|
||||||
|
if tt.hasReturn {
|
||||||
|
if fn.ReturnType == nil {
|
||||||
|
t.Error("expected return type hint but got nil")
|
||||||
|
} else if fn.ReturnType.Type != tt.returnType {
|
||||||
|
t.Errorf("expected return type %s, got %s", tt.returnType, fn.ReturnType.Type)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if fn.ReturnType != nil {
|
||||||
|
t.Errorf("expected no return type but got %s", fn.ReturnType.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTypeHintStringRepresentation(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
desc string
|
||||||
|
}{
|
||||||
|
{"x: number = 42", "local x: number = 42.00", "typed variable assignment"},
|
||||||
|
{"fn(a: number, b: string): bool end", "fn(a: number, b: string): bool\nend", "typed function"},
|
||||||
|
{"callback: function = fn(x: number): string return \"\" end", "local callback: function = fn(x: number): string\n\treturn \"\"\nend", "typed function assignment"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.desc, func(t *testing.T) {
|
||||||
|
l := parser.NewLexer(tt.input)
|
||||||
|
p := parser.NewParser(l)
|
||||||
|
|
||||||
|
var result string
|
||||||
|
if tt.input[0] == 'f' { // function literal
|
||||||
|
expr := p.ParseExpression(parser.LOWEST)
|
||||||
|
checkParserErrors(t, p)
|
||||||
|
result = expr.String()
|
||||||
|
} else { // assignment statement
|
||||||
|
program := p.ParseProgram()
|
||||||
|
checkParserErrors(t, p)
|
||||||
|
result = program.Statements[0].String()
|
||||||
|
}
|
||||||
|
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("expected %q, got %q", tt.expected, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTypeHintValidation(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
expectedError string
|
||||||
|
desc string
|
||||||
|
}{
|
||||||
|
{"x: invalid = 42", "invalid type name 'invalid'", "invalid type name"},
|
||||||
|
{"x: = 42", "expected type name after ':'", "missing type name"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.desc, func(t *testing.T) {
|
||||||
|
l := parser.NewLexer(tt.input)
|
||||||
|
p := parser.NewParser(l)
|
||||||
|
p.ParseProgram()
|
||||||
|
|
||||||
|
if !p.HasErrors() {
|
||||||
|
t.Fatal("expected parsing errors")
|
||||||
|
}
|
||||||
|
|
||||||
|
errors := p.Errors()
|
||||||
|
found := false
|
||||||
|
for _, err := range errors {
|
||||||
|
if err.Message == tt.expectedError {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !found {
|
||||||
|
errorMsgs := make([]string, len(errors))
|
||||||
|
for i, err := range errors {
|
||||||
|
errorMsgs[i] = err.Message
|
||||||
|
}
|
||||||
|
t.Errorf("expected error %q, got %v", tt.expectedError, errorMsgs)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemberAccessWithoutTypeHints(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
desc string
|
||||||
|
}{
|
||||||
|
{"table.key = 42", "dot notation assignment"},
|
||||||
|
{"arr[1] = \"hello\"", "bracket notation assignment"},
|
||||||
|
{"obj.nested.deep = true", "chained dot assignment"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.desc, func(t *testing.T) {
|
||||||
|
l := parser.NewLexer(tt.input)
|
||||||
|
p := parser.NewParser(l)
|
||||||
|
program := p.ParseProgram()
|
||||||
|
checkParserErrors(t, p)
|
||||||
|
|
||||||
|
if len(program.Statements) != 1 {
|
||||||
|
t.Fatalf("expected 1 statement, got %d", len(program.Statements))
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt, ok := program.Statements[0].(*parser.AssignStatement)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected AssignStatement, got %T", program.Statements[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Member access should never have type hints
|
||||||
|
if stmt.TypeHint != nil {
|
||||||
|
t.Error("member access assignment should not have type hints")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should not be a declaration
|
||||||
|
if stmt.IsDeclaration {
|
||||||
|
t.Error("member access assignment should not be a declaration")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTypeInferenceEngine(t *testing.T) {
|
||||||
|
input := `x: number = 42
|
||||||
|
name: string = "hello"
|
||||||
|
fn_var: function = fn(a: number, b: string): bool
|
||||||
|
result: bool = a > 0
|
||||||
|
return result
|
||||||
|
end
|
||||||
|
echo name`
|
||||||
|
|
||||||
|
l := parser.NewLexer(input)
|
||||||
|
p := parser.NewParser(l)
|
||||||
|
program := p.ParseProgram()
|
||||||
|
checkParserErrors(t, p)
|
||||||
|
|
||||||
|
// Run type inference
|
||||||
|
inferrer := parser.NewTypeInferrer()
|
||||||
|
typeErrors := inferrer.InferTypes(program)
|
||||||
|
|
||||||
|
// Should have no type errors for valid code
|
||||||
|
if len(typeErrors) > 0 {
|
||||||
|
errorMsgs := make([]string, len(typeErrors))
|
||||||
|
for i, err := range typeErrors {
|
||||||
|
errorMsgs[i] = err.Error()
|
||||||
|
}
|
||||||
|
t.Errorf("unexpected type errors: %v", errorMsgs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTypeInferenceErrors(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
expectedError string
|
||||||
|
desc string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"x: number = \"hello\"",
|
||||||
|
"cannot assign string to variable of type number",
|
||||||
|
"type mismatch in assignment",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"x = 42\ny: string = x",
|
||||||
|
"cannot assign number to variable of type string",
|
||||||
|
"type mismatch with inferred type",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.desc, func(t *testing.T) {
|
||||||
|
l := parser.NewLexer(tt.input)
|
||||||
|
p := parser.NewParser(l)
|
||||||
|
program := p.ParseProgram()
|
||||||
|
checkParserErrors(t, p)
|
||||||
|
|
||||||
|
inferrer := parser.NewTypeInferrer()
|
||||||
|
typeErrors := inferrer.InferTypes(program)
|
||||||
|
|
||||||
|
if len(typeErrors) == 0 {
|
||||||
|
t.Fatal("expected type errors")
|
||||||
|
}
|
||||||
|
|
||||||
|
found := false
|
||||||
|
for _, err := range typeErrors {
|
||||||
|
if err.Message == tt.expectedError {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !found {
|
||||||
|
errorMsgs := make([]string, len(typeErrors))
|
||||||
|
for i, err := range typeErrors {
|
||||||
|
errorMsgs[i] = err.Message
|
||||||
|
}
|
||||||
|
t.Errorf("expected error %q, got %v", tt.expectedError, errorMsgs)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVariadicFunctionTypeHints(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
variadic bool
|
||||||
|
desc string
|
||||||
|
}{
|
||||||
|
{"fn(a: number, ...) end", true, "variadic after typed param"},
|
||||||
|
{"fn(...) end", true, "variadic only"},
|
||||||
|
{"fn(a: number, b: string) end", false, "no variadic"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.desc, func(t *testing.T) {
|
||||||
|
l := parser.NewLexer(tt.input)
|
||||||
|
p := parser.NewParser(l)
|
||||||
|
expr := p.ParseExpression(parser.LOWEST)
|
||||||
|
checkParserErrors(t, p)
|
||||||
|
|
||||||
|
fn, ok := expr.(*parser.FunctionLiteral)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected FunctionLiteral, got %T", expr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fn.Variadic != tt.variadic {
|
||||||
|
t.Errorf("expected variadic = %t, got %t", tt.variadic, fn.Variadic)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComplexTypeHintProgram(t *testing.T) {
|
||||||
|
input := `config: table = {
|
||||||
|
host = "localhost",
|
||||||
|
port = 8080,
|
||||||
|
enabled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
handler: function = fn(request: table, callback: function): nil
|
||||||
|
status: number = 200
|
||||||
|
if request.method == "GET" then
|
||||||
|
response: table = {status = status, body = "OK"}
|
||||||
|
result = callback(response)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
server: table = {
|
||||||
|
config = config,
|
||||||
|
handler = handler,
|
||||||
|
start = fn(self: table): bool
|
||||||
|
return true
|
||||||
|
end
|
||||||
|
}`
|
||||||
|
|
||||||
|
l := parser.NewLexer(input)
|
||||||
|
p := parser.NewParser(l)
|
||||||
|
program := p.ParseProgram()
|
||||||
|
checkParserErrors(t, p)
|
||||||
|
|
||||||
|
if len(program.Statements) != 3 {
|
||||||
|
t.Fatalf("expected 3 statements, got %d", len(program.Statements))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check first statement: config table with typed assignments
|
||||||
|
configStmt, ok := program.Statements[0].(*parser.AssignStatement)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected AssignStatement, got %T", program.Statements[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
if configStmt.TypeHint == nil || configStmt.TypeHint.Type != "table" {
|
||||||
|
t.Error("expected table type hint for config")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check second statement: handler function with typed parameters
|
||||||
|
handlerStmt, ok := program.Statements[1].(*parser.AssignStatement)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected AssignStatement, got %T", program.Statements[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
if handlerStmt.TypeHint == nil || handlerStmt.TypeHint.Type != "function" {
|
||||||
|
t.Error("expected function type hint for handler")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn, ok := handlerStmt.Value.(*parser.FunctionLiteral)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected FunctionLiteral value, got %T", handlerStmt.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(fn.Parameters) != 2 {
|
||||||
|
t.Fatalf("expected 2 parameters, got %d", len(fn.Parameters))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check parameter types
|
||||||
|
if fn.Parameters[0].TypeHint == nil || fn.Parameters[0].TypeHint.Type != "table" {
|
||||||
|
t.Error("expected table type for request parameter")
|
||||||
|
}
|
||||||
|
|
||||||
|
if fn.Parameters[1].TypeHint == nil || fn.Parameters[1].TypeHint.Type != "function" {
|
||||||
|
t.Error("expected function type for callback parameter")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check return type
|
||||||
|
if fn.ReturnType == nil || fn.ReturnType.Type != "nil" {
|
||||||
|
t.Error("expected nil return type for handler")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check third statement: server table
|
||||||
|
serverStmt, ok := program.Statements[2].(*parser.AssignStatement)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected AssignStatement, got %T", program.Statements[2])
|
||||||
|
}
|
||||||
|
|
||||||
|
if serverStmt.TypeHint == nil || serverStmt.TypeHint.Type != "table" {
|
||||||
|
t.Error("expected table type hint for server")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTypeInfoGettersSetters(t *testing.T) {
|
||||||
|
// Test that all expression types properly implement GetType/SetType
|
||||||
|
typeInfo := &parser.TypeInfo{Type: "test", Inferred: true}
|
||||||
|
|
||||||
|
expressions := []parser.Expression{
|
||||||
|
&parser.Identifier{Value: "x"},
|
||||||
|
&parser.NumberLiteral{Value: 42},
|
||||||
|
&parser.StringLiteral{Value: "hello"},
|
||||||
|
&parser.BooleanLiteral{Value: true},
|
||||||
|
&parser.NilLiteral{},
|
||||||
|
&parser.TableLiteral{},
|
||||||
|
&parser.FunctionLiteral{},
|
||||||
|
&parser.CallExpression{Function: &parser.Identifier{Value: "fn"}},
|
||||||
|
&parser.PrefixExpression{Operator: "-", Right: &parser.NumberLiteral{Value: 1}},
|
||||||
|
&parser.InfixExpression{Left: &parser.NumberLiteral{Value: 1}, Operator: "+", Right: &parser.NumberLiteral{Value: 2}},
|
||||||
|
&parser.IndexExpression{Left: &parser.Identifier{Value: "arr"}, Index: &parser.NumberLiteral{Value: 0}},
|
||||||
|
&parser.DotExpression{Left: &parser.Identifier{Value: "obj"}, Key: "prop"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, expr := range expressions {
|
||||||
|
t.Run(string(rune('0'+i)), func(t *testing.T) {
|
||||||
|
// Initially should have no type
|
||||||
|
if expr.GetType() != nil {
|
||||||
|
t.Error("expected nil type initially")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set type
|
||||||
|
expr.SetType(typeInfo)
|
||||||
|
|
||||||
|
// Get type should return what we set
|
||||||
|
retrieved := expr.GetType()
|
||||||
|
if retrieved == nil {
|
||||||
|
t.Error("expected non-nil type after setting")
|
||||||
|
} else if retrieved.Type != "test" || !retrieved.Inferred {
|
||||||
|
t.Errorf("expected {Type: test, Inferred: true}, got %+v", retrieved)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@ -42,6 +42,7 @@ const (
|
|||||||
RBRACKET // ]
|
RBRACKET // ]
|
||||||
COMMA // ,
|
COMMA // ,
|
||||||
ELLIPSIS // ...
|
ELLIPSIS // ...
|
||||||
|
COLON // :
|
||||||
|
|
||||||
// Keywords
|
// Keywords
|
||||||
IF
|
IF
|
||||||
|
591
parser/types.go
Normal file
591
parser/types.go
Normal file
@ -0,0 +1,591 @@
|
|||||||
|
package parser
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Type constants for built-in types
|
||||||
|
const (
|
||||||
|
TypeNumber = "number"
|
||||||
|
TypeString = "string"
|
||||||
|
TypeBool = "bool"
|
||||||
|
TypeNil = "nil"
|
||||||
|
TypeTable = "table"
|
||||||
|
TypeFunction = "function"
|
||||||
|
TypeAny = "any"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TypeError represents a type checking error
|
||||||
|
type TypeError struct {
|
||||||
|
Message string
|
||||||
|
Line int
|
||||||
|
Column int
|
||||||
|
Node Node
|
||||||
|
}
|
||||||
|
|
||||||
|
func (te TypeError) Error() string {
|
||||||
|
return fmt.Sprintf("Type error at line %d, column %d: %s", te.Line, te.Column, te.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Symbol represents a variable in the symbol table
|
||||||
|
type Symbol struct {
|
||||||
|
Name string
|
||||||
|
Type *TypeInfo
|
||||||
|
Declared bool
|
||||||
|
Line int
|
||||||
|
Column int
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scope represents a scope in the symbol table
|
||||||
|
type Scope struct {
|
||||||
|
symbols map[string]*Symbol
|
||||||
|
parent *Scope
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewScope(parent *Scope) *Scope {
|
||||||
|
return &Scope{
|
||||||
|
symbols: make(map[string]*Symbol),
|
||||||
|
parent: parent,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Scope) Define(symbol *Symbol) {
|
||||||
|
s.symbols[symbol.Name] = symbol
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Scope) Lookup(name string) *Symbol {
|
||||||
|
if symbol, ok := s.symbols[name]; ok {
|
||||||
|
return symbol
|
||||||
|
}
|
||||||
|
if s.parent != nil {
|
||||||
|
return s.parent.Lookup(name)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TypeInferrer performs type inference and checking
|
||||||
|
type TypeInferrer struct {
|
||||||
|
currentScope *Scope
|
||||||
|
globalScope *Scope
|
||||||
|
errors []TypeError
|
||||||
|
|
||||||
|
// Pre-allocated type objects for performance
|
||||||
|
numberType *TypeInfo
|
||||||
|
stringType *TypeInfo
|
||||||
|
boolType *TypeInfo
|
||||||
|
nilType *TypeInfo
|
||||||
|
tableType *TypeInfo
|
||||||
|
anyType *TypeInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTypeInferrer creates a new type inference engine
|
||||||
|
func NewTypeInferrer() *TypeInferrer {
|
||||||
|
globalScope := NewScope(nil)
|
||||||
|
|
||||||
|
ti := &TypeInferrer{
|
||||||
|
currentScope: globalScope,
|
||||||
|
globalScope: globalScope,
|
||||||
|
errors: []TypeError{},
|
||||||
|
|
||||||
|
// Pre-allocate common types to reduce allocations
|
||||||
|
numberType: &TypeInfo{Type: TypeNumber, Inferred: true},
|
||||||
|
stringType: &TypeInfo{Type: TypeString, Inferred: true},
|
||||||
|
boolType: &TypeInfo{Type: TypeBool, Inferred: true},
|
||||||
|
nilType: &TypeInfo{Type: TypeNil, Inferred: true},
|
||||||
|
tableType: &TypeInfo{Type: TypeTable, Inferred: true},
|
||||||
|
anyType: &TypeInfo{Type: TypeAny, Inferred: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
return ti
|
||||||
|
}
|
||||||
|
|
||||||
|
// InferTypes performs type inference on the entire program
|
||||||
|
func (ti *TypeInferrer) InferTypes(program *Program) []TypeError {
|
||||||
|
for _, stmt := range program.Statements {
|
||||||
|
ti.inferStatement(stmt)
|
||||||
|
}
|
||||||
|
return ti.errors
|
||||||
|
}
|
||||||
|
|
||||||
|
// enterScope creates a new scope
|
||||||
|
func (ti *TypeInferrer) enterScope() {
|
||||||
|
ti.currentScope = NewScope(ti.currentScope)
|
||||||
|
}
|
||||||
|
|
||||||
|
// exitScope returns to the parent scope
|
||||||
|
func (ti *TypeInferrer) exitScope() {
|
||||||
|
if ti.currentScope.parent != nil {
|
||||||
|
ti.currentScope = ti.currentScope.parent
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// addError adds a type error
|
||||||
|
func (ti *TypeInferrer) addError(message string, node Node) {
|
||||||
|
ti.errors = append(ti.errors, TypeError{
|
||||||
|
Message: message,
|
||||||
|
Line: 0, // Would need to track position in AST nodes
|
||||||
|
Column: 0,
|
||||||
|
Node: node,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// inferStatement infers types for statements
|
||||||
|
func (ti *TypeInferrer) inferStatement(stmt Statement) {
|
||||||
|
switch s := stmt.(type) {
|
||||||
|
case *AssignStatement:
|
||||||
|
ti.inferAssignStatement(s)
|
||||||
|
case *EchoStatement:
|
||||||
|
ti.inferExpression(s.Value)
|
||||||
|
case *IfStatement:
|
||||||
|
ti.inferIfStatement(s)
|
||||||
|
case *WhileStatement:
|
||||||
|
ti.inferWhileStatement(s)
|
||||||
|
case *ForStatement:
|
||||||
|
ti.inferForStatement(s)
|
||||||
|
case *ForInStatement:
|
||||||
|
ti.inferForInStatement(s)
|
||||||
|
case *ReturnStatement:
|
||||||
|
if s.Value != nil {
|
||||||
|
ti.inferExpression(s.Value)
|
||||||
|
}
|
||||||
|
case *ExitStatement:
|
||||||
|
if s.Value != nil {
|
||||||
|
ti.inferExpression(s.Value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// inferAssignStatement handles variable assignments with type checking
|
||||||
|
func (ti *TypeInferrer) inferAssignStatement(stmt *AssignStatement) {
|
||||||
|
// Infer the type of the value expression
|
||||||
|
valueType := ti.inferExpression(stmt.Value)
|
||||||
|
|
||||||
|
if ident, ok := stmt.Name.(*Identifier); ok {
|
||||||
|
// Simple variable assignment
|
||||||
|
symbol := ti.currentScope.Lookup(ident.Value)
|
||||||
|
|
||||||
|
if stmt.IsDeclaration {
|
||||||
|
// New variable declaration
|
||||||
|
varType := valueType
|
||||||
|
|
||||||
|
// If there's a type hint, validate it
|
||||||
|
if stmt.TypeHint != nil {
|
||||||
|
if !ti.isTypeCompatible(valueType, stmt.TypeHint) {
|
||||||
|
ti.addError(fmt.Sprintf("cannot assign %s to variable of type %s",
|
||||||
|
valueType.Type, stmt.TypeHint.Type), stmt)
|
||||||
|
}
|
||||||
|
varType = stmt.TypeHint
|
||||||
|
varType.Inferred = false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Define the new symbol
|
||||||
|
ti.currentScope.Define(&Symbol{
|
||||||
|
Name: ident.Value,
|
||||||
|
Type: varType,
|
||||||
|
Declared: true,
|
||||||
|
})
|
||||||
|
|
||||||
|
ident.SetType(varType)
|
||||||
|
} else {
|
||||||
|
// Assignment to existing variable
|
||||||
|
if symbol == nil {
|
||||||
|
ti.addError(fmt.Sprintf("undefined variable '%s'", ident.Value), stmt)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check type compatibility
|
||||||
|
if !ti.isTypeCompatible(valueType, symbol.Type) {
|
||||||
|
ti.addError(fmt.Sprintf("cannot assign %s to variable of type %s",
|
||||||
|
valueType.Type, symbol.Type.Type), stmt)
|
||||||
|
}
|
||||||
|
|
||||||
|
ident.SetType(symbol.Type)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Member access assignment (table.key or table[index])
|
||||||
|
ti.inferExpression(stmt.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// inferIfStatement handles if statements
|
||||||
|
func (ti *TypeInferrer) inferIfStatement(stmt *IfStatement) {
|
||||||
|
condType := ti.inferExpression(stmt.Condition)
|
||||||
|
ti.validateBooleanContext(condType, stmt.Condition)
|
||||||
|
|
||||||
|
ti.enterScope()
|
||||||
|
for _, s := range stmt.Body {
|
||||||
|
ti.inferStatement(s)
|
||||||
|
}
|
||||||
|
ti.exitScope()
|
||||||
|
|
||||||
|
for _, elseif := range stmt.ElseIfs {
|
||||||
|
condType := ti.inferExpression(elseif.Condition)
|
||||||
|
ti.validateBooleanContext(condType, elseif.Condition)
|
||||||
|
|
||||||
|
ti.enterScope()
|
||||||
|
for _, s := range elseif.Body {
|
||||||
|
ti.inferStatement(s)
|
||||||
|
}
|
||||||
|
ti.exitScope()
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(stmt.Else) > 0 {
|
||||||
|
ti.enterScope()
|
||||||
|
for _, s := range stmt.Else {
|
||||||
|
ti.inferStatement(s)
|
||||||
|
}
|
||||||
|
ti.exitScope()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// inferWhileStatement handles while loops
|
||||||
|
func (ti *TypeInferrer) inferWhileStatement(stmt *WhileStatement) {
|
||||||
|
condType := ti.inferExpression(stmt.Condition)
|
||||||
|
ti.validateBooleanContext(condType, stmt.Condition)
|
||||||
|
|
||||||
|
ti.enterScope()
|
||||||
|
for _, s := range stmt.Body {
|
||||||
|
ti.inferStatement(s)
|
||||||
|
}
|
||||||
|
ti.exitScope()
|
||||||
|
}
|
||||||
|
|
||||||
|
// inferForStatement handles numeric for loops
|
||||||
|
func (ti *TypeInferrer) inferForStatement(stmt *ForStatement) {
|
||||||
|
startType := ti.inferExpression(stmt.Start)
|
||||||
|
endType := ti.inferExpression(stmt.End)
|
||||||
|
|
||||||
|
if !ti.isNumericType(startType) {
|
||||||
|
ti.addError("for loop start value must be numeric", stmt.Start)
|
||||||
|
}
|
||||||
|
if !ti.isNumericType(endType) {
|
||||||
|
ti.addError("for loop end value must be numeric", stmt.End)
|
||||||
|
}
|
||||||
|
|
||||||
|
if stmt.Step != nil {
|
||||||
|
stepType := ti.inferExpression(stmt.Step)
|
||||||
|
if !ti.isNumericType(stepType) {
|
||||||
|
ti.addError("for loop step value must be numeric", stmt.Step)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ti.enterScope()
|
||||||
|
// Define loop variable as number
|
||||||
|
ti.currentScope.Define(&Symbol{
|
||||||
|
Name: stmt.Variable.Value,
|
||||||
|
Type: ti.numberType,
|
||||||
|
Declared: true,
|
||||||
|
})
|
||||||
|
stmt.Variable.SetType(ti.numberType)
|
||||||
|
|
||||||
|
for _, s := range stmt.Body {
|
||||||
|
ti.inferStatement(s)
|
||||||
|
}
|
||||||
|
ti.exitScope()
|
||||||
|
}
|
||||||
|
|
||||||
|
// inferForInStatement handles for-in loops
|
||||||
|
func (ti *TypeInferrer) inferForInStatement(stmt *ForInStatement) {
|
||||||
|
iterableType := ti.inferExpression(stmt.Iterable)
|
||||||
|
|
||||||
|
// For now, assume iterable is a table
|
||||||
|
if !ti.isTableType(iterableType) {
|
||||||
|
ti.addError("for-in requires an iterable (table)", stmt.Iterable)
|
||||||
|
}
|
||||||
|
|
||||||
|
ti.enterScope()
|
||||||
|
|
||||||
|
// Define loop variables (key and value are any for now)
|
||||||
|
if stmt.Key != nil {
|
||||||
|
ti.currentScope.Define(&Symbol{
|
||||||
|
Name: stmt.Key.Value,
|
||||||
|
Type: ti.anyType,
|
||||||
|
Declared: true,
|
||||||
|
})
|
||||||
|
stmt.Key.SetType(ti.anyType)
|
||||||
|
}
|
||||||
|
|
||||||
|
ti.currentScope.Define(&Symbol{
|
||||||
|
Name: stmt.Value.Value,
|
||||||
|
Type: ti.anyType,
|
||||||
|
Declared: true,
|
||||||
|
})
|
||||||
|
stmt.Value.SetType(ti.anyType)
|
||||||
|
|
||||||
|
for _, s := range stmt.Body {
|
||||||
|
ti.inferStatement(s)
|
||||||
|
}
|
||||||
|
ti.exitScope()
|
||||||
|
}
|
||||||
|
|
||||||
|
// inferExpression infers the type of an expression
|
||||||
|
func (ti *TypeInferrer) inferExpression(expr Expression) *TypeInfo {
|
||||||
|
if expr == nil {
|
||||||
|
return ti.nilType
|
||||||
|
}
|
||||||
|
|
||||||
|
switch e := expr.(type) {
|
||||||
|
case *Identifier:
|
||||||
|
return ti.inferIdentifier(e)
|
||||||
|
case *NumberLiteral:
|
||||||
|
e.SetType(ti.numberType)
|
||||||
|
return ti.numberType
|
||||||
|
case *StringLiteral:
|
||||||
|
e.SetType(ti.stringType)
|
||||||
|
return ti.stringType
|
||||||
|
case *BooleanLiteral:
|
||||||
|
e.SetType(ti.boolType)
|
||||||
|
return ti.boolType
|
||||||
|
case *NilLiteral:
|
||||||
|
e.SetType(ti.nilType)
|
||||||
|
return ti.nilType
|
||||||
|
case *TableLiteral:
|
||||||
|
return ti.inferTableLiteral(e)
|
||||||
|
case *FunctionLiteral:
|
||||||
|
return ti.inferFunctionLiteral(e)
|
||||||
|
case *CallExpression:
|
||||||
|
return ti.inferCallExpression(e)
|
||||||
|
case *PrefixExpression:
|
||||||
|
return ti.inferPrefixExpression(e)
|
||||||
|
case *InfixExpression:
|
||||||
|
return ti.inferInfixExpression(e)
|
||||||
|
case *IndexExpression:
|
||||||
|
return ti.inferIndexExpression(e)
|
||||||
|
case *DotExpression:
|
||||||
|
return ti.inferDotExpression(e)
|
||||||
|
default:
|
||||||
|
ti.addError("unknown expression type", expr)
|
||||||
|
return ti.anyType
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// inferIdentifier looks up identifier type in symbol table
|
||||||
|
func (ti *TypeInferrer) inferIdentifier(ident *Identifier) *TypeInfo {
|
||||||
|
symbol := ti.currentScope.Lookup(ident.Value)
|
||||||
|
if symbol == nil {
|
||||||
|
ti.addError(fmt.Sprintf("undefined variable '%s'", ident.Value), ident)
|
||||||
|
return ti.anyType
|
||||||
|
}
|
||||||
|
|
||||||
|
ident.SetType(symbol.Type)
|
||||||
|
return symbol.Type
|
||||||
|
}
|
||||||
|
|
||||||
|
// inferTableLiteral infers table type
|
||||||
|
func (ti *TypeInferrer) inferTableLiteral(table *TableLiteral) *TypeInfo {
|
||||||
|
// Infer types of all values
|
||||||
|
for _, pair := range table.Pairs {
|
||||||
|
if pair.Key != nil {
|
||||||
|
ti.inferExpression(pair.Key)
|
||||||
|
}
|
||||||
|
ti.inferExpression(pair.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
table.SetType(ti.tableType)
|
||||||
|
return ti.tableType
|
||||||
|
}
|
||||||
|
|
||||||
|
// inferFunctionLiteral infers function type
|
||||||
|
func (ti *TypeInferrer) inferFunctionLiteral(fn *FunctionLiteral) *TypeInfo {
|
||||||
|
ti.enterScope()
|
||||||
|
|
||||||
|
// Define parameters in function scope
|
||||||
|
for _, param := range fn.Parameters {
|
||||||
|
paramType := ti.anyType
|
||||||
|
if param.TypeHint != nil {
|
||||||
|
paramType = param.TypeHint
|
||||||
|
}
|
||||||
|
|
||||||
|
ti.currentScope.Define(&Symbol{
|
||||||
|
Name: param.Name,
|
||||||
|
Type: paramType,
|
||||||
|
Declared: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Infer body
|
||||||
|
for _, stmt := range fn.Body {
|
||||||
|
ti.inferStatement(stmt)
|
||||||
|
}
|
||||||
|
|
||||||
|
ti.exitScope()
|
||||||
|
|
||||||
|
// For now, all functions have type "function"
|
||||||
|
funcType := &TypeInfo{Type: TypeFunction, Inferred: true}
|
||||||
|
fn.SetType(funcType)
|
||||||
|
return funcType
|
||||||
|
}
|
||||||
|
|
||||||
|
// inferCallExpression infers function call return type
|
||||||
|
func (ti *TypeInferrer) inferCallExpression(call *CallExpression) *TypeInfo {
|
||||||
|
funcType := ti.inferExpression(call.Function)
|
||||||
|
|
||||||
|
if !ti.isFunctionType(funcType) {
|
||||||
|
ti.addError("cannot call non-function", call.Function)
|
||||||
|
return ti.anyType
|
||||||
|
}
|
||||||
|
|
||||||
|
// Infer argument types
|
||||||
|
for _, arg := range call.Arguments {
|
||||||
|
ti.inferExpression(arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// For now, assume function calls return any
|
||||||
|
call.SetType(ti.anyType)
|
||||||
|
return ti.anyType
|
||||||
|
}
|
||||||
|
|
||||||
|
// inferPrefixExpression infers prefix operation type
|
||||||
|
func (ti *TypeInferrer) inferPrefixExpression(prefix *PrefixExpression) *TypeInfo {
|
||||||
|
rightType := ti.inferExpression(prefix.Right)
|
||||||
|
|
||||||
|
var resultType *TypeInfo
|
||||||
|
switch prefix.Operator {
|
||||||
|
case "-":
|
||||||
|
if !ti.isNumericType(rightType) {
|
||||||
|
ti.addError("unary minus requires numeric operand", prefix)
|
||||||
|
}
|
||||||
|
resultType = ti.numberType
|
||||||
|
case "not":
|
||||||
|
resultType = ti.boolType
|
||||||
|
default:
|
||||||
|
ti.addError(fmt.Sprintf("unknown prefix operator '%s'", prefix.Operator), prefix)
|
||||||
|
resultType = ti.anyType
|
||||||
|
}
|
||||||
|
|
||||||
|
prefix.SetType(resultType)
|
||||||
|
return resultType
|
||||||
|
}
|
||||||
|
|
||||||
|
// inferInfixExpression infers binary operation type
|
||||||
|
func (ti *TypeInferrer) inferInfixExpression(infix *InfixExpression) *TypeInfo {
|
||||||
|
leftType := ti.inferExpression(infix.Left)
|
||||||
|
rightType := ti.inferExpression(infix.Right)
|
||||||
|
|
||||||
|
var resultType *TypeInfo
|
||||||
|
|
||||||
|
switch infix.Operator {
|
||||||
|
case "+", "-", "*", "/":
|
||||||
|
if !ti.isNumericType(leftType) || !ti.isNumericType(rightType) {
|
||||||
|
ti.addError(fmt.Sprintf("arithmetic operator '%s' requires numeric operands", infix.Operator), infix)
|
||||||
|
}
|
||||||
|
resultType = ti.numberType
|
||||||
|
|
||||||
|
case "==", "!=":
|
||||||
|
// Equality works with any types
|
||||||
|
resultType = ti.boolType
|
||||||
|
|
||||||
|
case "<", ">", "<=", ">=":
|
||||||
|
if !ti.isComparableTypes(leftType, rightType) {
|
||||||
|
ti.addError(fmt.Sprintf("comparison operator '%s' requires compatible operands", infix.Operator), infix)
|
||||||
|
}
|
||||||
|
resultType = ti.boolType
|
||||||
|
|
||||||
|
case "and", "or":
|
||||||
|
ti.validateBooleanContext(leftType, infix.Left)
|
||||||
|
ti.validateBooleanContext(rightType, infix.Right)
|
||||||
|
resultType = ti.boolType
|
||||||
|
|
||||||
|
default:
|
||||||
|
ti.addError(fmt.Sprintf("unknown infix operator '%s'", infix.Operator), infix)
|
||||||
|
resultType = ti.anyType
|
||||||
|
}
|
||||||
|
|
||||||
|
infix.SetType(resultType)
|
||||||
|
return resultType
|
||||||
|
}
|
||||||
|
|
||||||
|
// inferIndexExpression infers table[index] type
|
||||||
|
func (ti *TypeInferrer) inferIndexExpression(index *IndexExpression) *TypeInfo {
|
||||||
|
ti.inferExpression(index.Left)
|
||||||
|
ti.inferExpression(index.Index)
|
||||||
|
|
||||||
|
// For now, assume table access returns any
|
||||||
|
index.SetType(ti.anyType)
|
||||||
|
return ti.anyType
|
||||||
|
}
|
||||||
|
|
||||||
|
// inferDotExpression infers table.key type
|
||||||
|
func (ti *TypeInferrer) inferDotExpression(dot *DotExpression) *TypeInfo {
|
||||||
|
ti.inferExpression(dot.Left)
|
||||||
|
|
||||||
|
// For now, assume member access returns any
|
||||||
|
dot.SetType(ti.anyType)
|
||||||
|
return ti.anyType
|
||||||
|
}
|
||||||
|
|
||||||
|
// Type checking helper methods
|
||||||
|
|
||||||
|
func (ti *TypeInferrer) isTypeCompatible(valueType, targetType *TypeInfo) bool {
|
||||||
|
if targetType.Type == TypeAny || valueType.Type == TypeAny {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return valueType.Type == targetType.Type
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ti *TypeInferrer) isNumericType(t *TypeInfo) bool {
|
||||||
|
return t.Type == TypeNumber
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ti *TypeInferrer) isBooleanType(t *TypeInfo) bool {
|
||||||
|
return t.Type == TypeBool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ti *TypeInferrer) isTableType(t *TypeInfo) bool {
|
||||||
|
return t.Type == TypeTable
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ti *TypeInferrer) isFunctionType(t *TypeInfo) bool {
|
||||||
|
return t.Type == TypeFunction
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ti *TypeInferrer) isComparableTypes(left, right *TypeInfo) bool {
|
||||||
|
if left.Type == TypeAny || right.Type == TypeAny {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return left.Type == right.Type && (left.Type == TypeNumber || left.Type == TypeString)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ti *TypeInferrer) validateBooleanContext(t *TypeInfo, expr Expression) {
|
||||||
|
// In many languages, non-boolean values can be used in boolean context
|
||||||
|
// For strictness, we could require boolean type here
|
||||||
|
// For now, allow any type (truthy/falsy semantics)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Errors returns all type checking errors
|
||||||
|
func (ti *TypeInferrer) Errors() []TypeError {
|
||||||
|
return ti.errors
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasErrors returns true if there are any type errors
|
||||||
|
func (ti *TypeInferrer) HasErrors() bool {
|
||||||
|
return len(ti.errors) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrorStrings returns error messages as strings
|
||||||
|
func (ti *TypeInferrer) ErrorStrings() []string {
|
||||||
|
result := make([]string, len(ti.errors))
|
||||||
|
for i, err := range ti.errors {
|
||||||
|
result[i] = err.Error()
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidTypeName checks if a string is a valid type name
|
||||||
|
func ValidTypeName(name string) bool {
|
||||||
|
validTypes := []string{TypeNumber, TypeString, TypeBool, TypeNil, TypeTable, TypeFunction, TypeAny}
|
||||||
|
for _, validType := range validTypes {
|
||||||
|
if name == validType {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseTypeName converts a string to a TypeInfo (for parsing type hints)
|
||||||
|
func ParseTypeName(name string) *TypeInfo {
|
||||||
|
if ValidTypeName(name) {
|
||||||
|
return &TypeInfo{Type: name, Inferred: false}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user