structs
This commit is contained in:
parent
fc439b6d5a
commit
5ae2a6ef23
@ -4,7 +4,7 @@ import "fmt"
|
||||
|
||||
// TypeInfo represents type information for expressions
|
||||
type TypeInfo struct {
|
||||
Type string // "number", "string", "bool", "table", "function", "nil", "any"
|
||||
Type string // "number", "string", "bool", "table", "function", "nil", "any", struct name
|
||||
Inferred bool // true if type was inferred, false if explicitly declared
|
||||
}
|
||||
|
||||
@ -41,6 +41,67 @@ func (p *Program) String() string {
|
||||
return result
|
||||
}
|
||||
|
||||
// StructField represents a field in a struct definition
|
||||
type StructField struct {
|
||||
Name string
|
||||
TypeHint *TypeInfo
|
||||
}
|
||||
|
||||
func (sf *StructField) String() string {
|
||||
if sf.TypeHint != nil {
|
||||
return fmt.Sprintf("%s: %s", sf.Name, sf.TypeHint.Type)
|
||||
}
|
||||
return sf.Name
|
||||
}
|
||||
|
||||
// StructStatement represents struct definitions
|
||||
type StructStatement struct {
|
||||
Name string
|
||||
Fields []StructField
|
||||
}
|
||||
|
||||
func (ss *StructStatement) statementNode() {}
|
||||
func (ss *StructStatement) String() string {
|
||||
var fields string
|
||||
for i, field := range ss.Fields {
|
||||
if i > 0 {
|
||||
fields += ",\n\t"
|
||||
}
|
||||
fields += field.String()
|
||||
}
|
||||
return fmt.Sprintf("struct %s {\n\t%s\n}", ss.Name, fields)
|
||||
}
|
||||
|
||||
// MethodDefinition represents method definitions on structs
|
||||
type MethodDefinition struct {
|
||||
StructName string
|
||||
MethodName string
|
||||
Function *FunctionLiteral
|
||||
}
|
||||
|
||||
func (md *MethodDefinition) statementNode() {}
|
||||
func (md *MethodDefinition) String() string {
|
||||
return fmt.Sprintf("fn %s.%s%s", md.StructName, md.MethodName, md.Function.String()[2:]) // skip "fn" from function string
|
||||
}
|
||||
|
||||
// StructConstructorExpression represents struct constructor calls like my_type{...}
|
||||
type StructConstructorExpression struct {
|
||||
StructName string
|
||||
Fields []TablePair // reuse TablePair for field assignments
|
||||
typeInfo *TypeInfo
|
||||
}
|
||||
|
||||
func (sce *StructConstructorExpression) expressionNode() {}
|
||||
func (sce *StructConstructorExpression) String() string {
|
||||
var pairs []string
|
||||
for _, pair := range sce.Fields {
|
||||
pairs = append(pairs, pair.String())
|
||||
}
|
||||
return fmt.Sprintf("%s{%s}", sce.StructName, joinStrings(pairs, ", "))
|
||||
}
|
||||
func (sce *StructConstructorExpression) GetType() *TypeInfo { return sce.typeInfo }
|
||||
func (sce *StructConstructorExpression) SetType(t *TypeInfo) { sce.typeInfo = t }
|
||||
|
||||
// AssignStatement represents variable assignment with optional type hint
|
||||
type AssignStatement struct {
|
||||
Name Expression // Changed from *Identifier to Expression for member access
|
||||
|
313
parser/parser.go
313
parser/parser.go
@ -34,6 +34,9 @@ type Parser struct {
|
||||
// Scope tracking
|
||||
scopes []map[string]bool // stack of scopes, each tracking declared variables
|
||||
scopeTypes []string // track what type each scope is: "global", "function", "loop"
|
||||
|
||||
// Struct tracking
|
||||
structs map[string]*StructStatement // track defined structs
|
||||
}
|
||||
|
||||
// NewParser creates a new parser instance
|
||||
@ -43,6 +46,7 @@ func NewParser(lexer *Lexer) *Parser {
|
||||
errors: []ParseError{},
|
||||
scopes: []map[string]bool{make(map[string]bool)}, // start with global scope
|
||||
scopeTypes: []string{"global"}, // start with global scope type
|
||||
structs: make(map[string]*StructStatement), // track struct definitions
|
||||
}
|
||||
|
||||
p.prefixParseFns = make(map[TokenType]func() Expression)
|
||||
@ -74,6 +78,7 @@ func NewParser(lexer *Lexer) *Parser {
|
||||
p.registerInfix(DOT, p.parseDotExpression)
|
||||
p.registerInfix(LBRACKET, p.parseIndexExpression)
|
||||
p.registerInfix(LPAREN, p.parseCallExpression)
|
||||
p.registerInfix(LBRACE, p.parseStructConstructor) // struct constructor
|
||||
|
||||
// Read two tokens, so curToken and peekToken are both set
|
||||
p.nextToken()
|
||||
@ -157,7 +162,7 @@ func (p *Parser) parseTypeHint() *TypeInfo {
|
||||
}
|
||||
|
||||
typeName := p.curToken.Literal
|
||||
if !ValidTypeName(typeName) {
|
||||
if !ValidTypeName(typeName) && !p.isStructDefined(typeName) {
|
||||
p.addError(fmt.Sprintf("invalid type name '%s'", typeName))
|
||||
return nil
|
||||
}
|
||||
@ -165,6 +170,12 @@ func (p *Parser) parseTypeHint() *TypeInfo {
|
||||
return &TypeInfo{Type: typeName, Inferred: false}
|
||||
}
|
||||
|
||||
// isStructDefined checks if a struct name is defined
|
||||
func (p *Parser) isStructDefined(name string) bool {
|
||||
_, exists := p.structs[name]
|
||||
return exists
|
||||
}
|
||||
|
||||
// registerPrefix registers a prefix parse function
|
||||
func (p *Parser) registerPrefix(tokenType TokenType, fn func() Expression) {
|
||||
p.prefixParseFns[tokenType] = fn
|
||||
@ -200,6 +211,10 @@ func (p *Parser) ParseProgram() *Program {
|
||||
// parseStatement parses a statement
|
||||
func (p *Parser) parseStatement() Statement {
|
||||
switch p.curToken.Type {
|
||||
case STRUCT:
|
||||
return p.parseStructStatement()
|
||||
case FN:
|
||||
return p.parseFunctionStatement()
|
||||
case IDENT:
|
||||
return p.parseIdentifierStatement()
|
||||
case IF:
|
||||
@ -230,6 +245,147 @@ func (p *Parser) parseStatement() Statement {
|
||||
}
|
||||
}
|
||||
|
||||
// parseStructStatement parses struct definitions
|
||||
func (p *Parser) parseStructStatement() *StructStatement {
|
||||
stmt := &StructStatement{}
|
||||
|
||||
if !p.expectPeek(IDENT) {
|
||||
p.addError("expected struct name")
|
||||
return nil
|
||||
}
|
||||
|
||||
stmt.Name = p.curToken.Literal
|
||||
|
||||
if !p.expectPeek(LBRACE) {
|
||||
p.addError("expected '{' after struct name")
|
||||
return nil
|
||||
}
|
||||
|
||||
stmt.Fields = []StructField{}
|
||||
|
||||
if p.peekTokenIs(RBRACE) {
|
||||
p.nextToken()
|
||||
p.structs[stmt.Name] = stmt
|
||||
return stmt
|
||||
}
|
||||
|
||||
p.nextToken()
|
||||
|
||||
for {
|
||||
if p.curTokenIs(EOF) {
|
||||
p.addError("unexpected end of input, expected }")
|
||||
return nil
|
||||
}
|
||||
|
||||
if !p.curTokenIs(IDENT) {
|
||||
p.addError("expected field name")
|
||||
return nil
|
||||
}
|
||||
|
||||
field := StructField{Name: p.curToken.Literal}
|
||||
|
||||
// Parse optional type hint
|
||||
field.TypeHint = p.parseTypeHint()
|
||||
if field.TypeHint == nil {
|
||||
p.addError("struct fields require type annotation")
|
||||
return nil
|
||||
}
|
||||
|
||||
stmt.Fields = append(stmt.Fields, field)
|
||||
|
||||
if !p.peekTokenIs(COMMA) {
|
||||
break
|
||||
}
|
||||
|
||||
p.nextToken() // consume comma
|
||||
p.nextToken() // move to next field
|
||||
|
||||
if p.curTokenIs(RBRACE) {
|
||||
break
|
||||
}
|
||||
|
||||
if p.curTokenIs(EOF) {
|
||||
p.addError("expected next token to be }")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if !p.expectPeek(RBRACE) {
|
||||
return nil
|
||||
}
|
||||
|
||||
p.structs[stmt.Name] = stmt
|
||||
return stmt
|
||||
}
|
||||
|
||||
// parseFunctionStatement handles both regular functions and methods
|
||||
func (p *Parser) parseFunctionStatement() Statement {
|
||||
if !p.expectPeek(IDENT) {
|
||||
p.addError("expected function name")
|
||||
return nil
|
||||
}
|
||||
|
||||
funcName := p.curToken.Literal
|
||||
|
||||
// Check if this is a method definition (struct.method)
|
||||
if p.peekTokenIs(DOT) {
|
||||
p.nextToken() // consume '.'
|
||||
|
||||
if !p.expectPeek(IDENT) {
|
||||
p.addError("expected method name after '.'")
|
||||
return nil
|
||||
}
|
||||
|
||||
methodName := p.curToken.Literal
|
||||
|
||||
if !p.expectPeek(LPAREN) {
|
||||
p.addError("expected '(' after method name")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse the function literal starting from parameters
|
||||
funcLit := &FunctionLiteral{}
|
||||
funcLit.Parameters, funcLit.Variadic = p.parseFunctionParameters()
|
||||
|
||||
if !p.expectPeek(RPAREN) {
|
||||
p.addError("expected ')' after function parameters")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check for return type hint
|
||||
funcLit.ReturnType = p.parseTypeHint()
|
||||
|
||||
p.nextToken()
|
||||
|
||||
p.enterFunctionScope()
|
||||
for _, param := range funcLit.Parameters {
|
||||
p.declareVariable(param.Name)
|
||||
}
|
||||
funcLit.Body = p.parseBlockStatements(END)
|
||||
p.exitFunctionScope()
|
||||
|
||||
if !p.curTokenIs(END) {
|
||||
p.addError("expected 'end' to close function")
|
||||
return nil
|
||||
}
|
||||
|
||||
return &MethodDefinition{
|
||||
StructName: funcName,
|
||||
MethodName: methodName,
|
||||
Function: funcLit,
|
||||
}
|
||||
}
|
||||
|
||||
// Regular function - this should be handled as expression statement
|
||||
// Reset to handle as function literal
|
||||
funcLit := p.parseFunctionLiteral()
|
||||
if funcLit == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &ExpressionStatement{Expression: funcLit}
|
||||
}
|
||||
|
||||
// parseIdentifierStatement handles both assignments and expression statements starting with identifiers
|
||||
func (p *Parser) parseIdentifierStatement() Statement {
|
||||
// Parse the left-hand side expression first
|
||||
@ -948,6 +1104,157 @@ func (p *Parser) parseTableLiteral() Expression {
|
||||
return table
|
||||
}
|
||||
|
||||
// parseStructConstructor handles struct constructor calls like my_type{...}
|
||||
func (p *Parser) parseStructConstructor(left Expression) Expression {
|
||||
// left should be an identifier representing the struct name
|
||||
ident, ok := left.(*Identifier)
|
||||
if !ok {
|
||||
// Not an identifier, fall back to table literal parsing
|
||||
return p.parseTableLiteralFromBrace()
|
||||
}
|
||||
|
||||
structName := ident.Value
|
||||
|
||||
// Always try to parse as struct constructor if we have an identifier
|
||||
// Type checking will catch undefined structs later
|
||||
constructor := &StructConstructorExpression{
|
||||
StructName: structName,
|
||||
Fields: []TablePair{},
|
||||
}
|
||||
|
||||
if p.peekTokenIs(RBRACE) {
|
||||
p.nextToken()
|
||||
return constructor
|
||||
}
|
||||
|
||||
p.nextToken()
|
||||
|
||||
for {
|
||||
if p.curTokenIs(EOF) {
|
||||
p.addError("unexpected end of input, expected }")
|
||||
return nil
|
||||
}
|
||||
|
||||
pair := TablePair{}
|
||||
|
||||
if (p.curTokenIs(IDENT) || p.curTokenIs(STRING)) && p.peekTokenIs(ASSIGN) {
|
||||
if p.curTokenIs(IDENT) {
|
||||
pair.Key = &Identifier{Value: p.curToken.Literal}
|
||||
} else {
|
||||
pair.Key = &StringLiteral{Value: p.curToken.Literal}
|
||||
}
|
||||
p.nextToken()
|
||||
p.nextToken()
|
||||
|
||||
if p.curTokenIs(EOF) {
|
||||
p.addError("expected expression after assignment operator")
|
||||
return nil
|
||||
}
|
||||
|
||||
pair.Value = p.ParseExpression(LOWEST)
|
||||
} else {
|
||||
pair.Value = p.ParseExpression(LOWEST)
|
||||
}
|
||||
|
||||
if pair.Value == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
constructor.Fields = append(constructor.Fields, pair)
|
||||
|
||||
if !p.peekTokenIs(COMMA) {
|
||||
break
|
||||
}
|
||||
|
||||
p.nextToken()
|
||||
p.nextToken()
|
||||
|
||||
if p.curTokenIs(RBRACE) {
|
||||
break
|
||||
}
|
||||
|
||||
if p.curTokenIs(EOF) {
|
||||
p.addError("expected next token to be }")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if !p.expectPeek(RBRACE) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return constructor
|
||||
}
|
||||
|
||||
func (p *Parser) parseTableLiteralFromBrace() Expression {
|
||||
// We're already at the opening brace, so parse as table literal
|
||||
table := &TableLiteral{}
|
||||
table.Pairs = []TablePair{}
|
||||
|
||||
if p.peekTokenIs(RBRACE) {
|
||||
p.nextToken()
|
||||
return table
|
||||
}
|
||||
|
||||
p.nextToken()
|
||||
|
||||
for {
|
||||
if p.curTokenIs(EOF) {
|
||||
p.addError("unexpected end of input, expected }")
|
||||
return nil
|
||||
}
|
||||
|
||||
pair := TablePair{}
|
||||
|
||||
if (p.curTokenIs(IDENT) || p.curTokenIs(STRING)) && p.peekTokenIs(ASSIGN) {
|
||||
if p.curTokenIs(IDENT) {
|
||||
pair.Key = &Identifier{Value: p.curToken.Literal}
|
||||
} else {
|
||||
pair.Key = &StringLiteral{Value: p.curToken.Literal}
|
||||
}
|
||||
p.nextToken()
|
||||
p.nextToken()
|
||||
|
||||
if p.curTokenIs(EOF) {
|
||||
p.addError("expected expression after assignment operator")
|
||||
return nil
|
||||
}
|
||||
|
||||
pair.Value = p.ParseExpression(LOWEST)
|
||||
} else {
|
||||
pair.Value = p.ParseExpression(LOWEST)
|
||||
}
|
||||
|
||||
if pair.Value == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
table.Pairs = append(table.Pairs, pair)
|
||||
|
||||
if !p.peekTokenIs(COMMA) {
|
||||
break
|
||||
}
|
||||
|
||||
p.nextToken()
|
||||
p.nextToken()
|
||||
|
||||
if p.curTokenIs(RBRACE) {
|
||||
break
|
||||
}
|
||||
|
||||
if p.curTokenIs(EOF) {
|
||||
p.addError("expected next token to be }")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if !p.expectPeek(RBRACE) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return table
|
||||
}
|
||||
|
||||
func (p *Parser) parseInfixExpression(left Expression) Expression {
|
||||
expression := &InfixExpression{
|
||||
Left: left,
|
||||
@ -1057,7 +1364,7 @@ func (p *Parser) expectPeekIdent() bool {
|
||||
|
||||
func (p *Parser) isKeyword(t TokenType) bool {
|
||||
switch t {
|
||||
case TRUE, FALSE, NIL, AND, OR, NOT, IF, THEN, ELSEIF, ELSE, END, ECHO, FOR, WHILE, IN, DO, BREAK, EXIT, FN, RETURN:
|
||||
case TRUE, FALSE, NIL, AND, OR, NOT, IF, THEN, ELSEIF, ELSE, END, ECHO, FOR, WHILE, IN, DO, BREAK, EXIT, FN, RETURN, STRUCT:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
@ -1227,6 +1534,8 @@ func tokenTypeString(t TokenType) string {
|
||||
return "fn"
|
||||
case RETURN:
|
||||
return "return"
|
||||
case STRUCT:
|
||||
return "struct"
|
||||
case EOF:
|
||||
return "end of file"
|
||||
case ILLEGAL:
|
||||
|
632
parser/tests/structs_test.go
Normal file
632
parser/tests/structs_test.go
Normal file
@ -0,0 +1,632 @@
|
||||
package parser_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.sharkk.net/Sharkk/Mako/parser"
|
||||
)
|
||||
|
||||
func TestBasicStructDefinition(t *testing.T) {
|
||||
input := `struct Person {
|
||||
name: string,
|
||||
age: number
|
||||
}`
|
||||
|
||||
l := parser.NewLexer(input)
|
||||
p := parser.NewParser(l)
|
||||
program := p.ParseProgram()
|
||||
checkParserErrors(t, p)
|
||||
|
||||
if len(program.Statements) != 1 {
|
||||
t.Fatalf("expected 1 statement, got %d", len(program.Statements))
|
||||
}
|
||||
|
||||
stmt, ok := program.Statements[0].(*parser.StructStatement)
|
||||
if !ok {
|
||||
t.Fatalf("expected StructStatement, got %T", program.Statements[0])
|
||||
}
|
||||
|
||||
if stmt.Name != "Person" {
|
||||
t.Errorf("expected struct name 'Person', got %s", stmt.Name)
|
||||
}
|
||||
|
||||
if len(stmt.Fields) != 2 {
|
||||
t.Fatalf("expected 2 fields, got %d", len(stmt.Fields))
|
||||
}
|
||||
|
||||
// Test first field
|
||||
if stmt.Fields[0].Name != "name" {
|
||||
t.Errorf("expected field name 'name', got %s", stmt.Fields[0].Name)
|
||||
}
|
||||
if stmt.Fields[0].TypeHint == nil {
|
||||
t.Fatal("expected type hint for name field")
|
||||
}
|
||||
if stmt.Fields[0].TypeHint.Type != "string" {
|
||||
t.Errorf("expected type 'string', got %s", stmt.Fields[0].TypeHint.Type)
|
||||
}
|
||||
|
||||
// Test second field
|
||||
if stmt.Fields[1].Name != "age" {
|
||||
t.Errorf("expected field name 'age', got %s", stmt.Fields[1].Name)
|
||||
}
|
||||
if stmt.Fields[1].TypeHint == nil {
|
||||
t.Fatal("expected type hint for age field")
|
||||
}
|
||||
if stmt.Fields[1].TypeHint.Type != "number" {
|
||||
t.Errorf("expected type 'number', got %s", stmt.Fields[1].TypeHint.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmptyStructDefinition(t *testing.T) {
|
||||
input := `struct Empty {}`
|
||||
|
||||
l := parser.NewLexer(input)
|
||||
p := parser.NewParser(l)
|
||||
program := p.ParseProgram()
|
||||
checkParserErrors(t, p)
|
||||
|
||||
if len(program.Statements) != 1 {
|
||||
t.Fatalf("expected 1 statement, got %d", len(program.Statements))
|
||||
}
|
||||
|
||||
stmt, ok := program.Statements[0].(*parser.StructStatement)
|
||||
if !ok {
|
||||
t.Fatalf("expected StructStatement, got %T", program.Statements[0])
|
||||
}
|
||||
|
||||
if stmt.Name != "Empty" {
|
||||
t.Errorf("expected struct name 'Empty', got %s", stmt.Name)
|
||||
}
|
||||
|
||||
if len(stmt.Fields) != 0 {
|
||||
t.Errorf("expected 0 fields, got %d", len(stmt.Fields))
|
||||
}
|
||||
}
|
||||
|
||||
func TestComplexStructDefinition(t *testing.T) {
|
||||
input := `struct Complex {
|
||||
id: number,
|
||||
name: string,
|
||||
active: bool,
|
||||
data: table,
|
||||
callback: function,
|
||||
optional: any
|
||||
}`
|
||||
|
||||
l := parser.NewLexer(input)
|
||||
p := parser.NewParser(l)
|
||||
program := p.ParseProgram()
|
||||
checkParserErrors(t, p)
|
||||
|
||||
if len(program.Statements) != 1 {
|
||||
t.Fatalf("expected 1 statement, got %d", len(program.Statements))
|
||||
}
|
||||
|
||||
stmt, ok := program.Statements[0].(*parser.StructStatement)
|
||||
if !ok {
|
||||
t.Fatalf("expected StructStatement, got %T", program.Statements[0])
|
||||
}
|
||||
|
||||
expectedTypes := []string{"number", "string", "bool", "table", "function", "any"}
|
||||
expectedNames := []string{"id", "name", "active", "data", "callback", "optional"}
|
||||
|
||||
if len(stmt.Fields) != len(expectedTypes) {
|
||||
t.Fatalf("expected %d fields, got %d", len(expectedTypes), len(stmt.Fields))
|
||||
}
|
||||
|
||||
for i, field := range stmt.Fields {
|
||||
if field.Name != expectedNames[i] {
|
||||
t.Errorf("field %d: expected name '%s', got '%s'", i, expectedNames[i], field.Name)
|
||||
}
|
||||
if field.TypeHint == nil {
|
||||
t.Fatalf("field %d: expected type hint", i)
|
||||
}
|
||||
if field.TypeHint.Type != expectedTypes[i] {
|
||||
t.Errorf("field %d: expected type '%s', got '%s'", i, expectedTypes[i], field.TypeHint.Type)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMethodDefinition(t *testing.T) {
|
||||
input := `struct Person {
|
||||
name: string,
|
||||
age: number
|
||||
}
|
||||
|
||||
fn Person.getName(): string
|
||||
return self.name
|
||||
end
|
||||
|
||||
fn Person.setAge(newAge: number)
|
||||
self.age = newAge
|
||||
end`
|
||||
|
||||
l := parser.NewLexer(input)
|
||||
p := parser.NewParser(l)
|
||||
program := p.ParseProgram()
|
||||
checkParserErrors(t, p)
|
||||
|
||||
if len(program.Statements) != 3 {
|
||||
t.Fatalf("expected 3 statements, got %d", len(program.Statements))
|
||||
}
|
||||
|
||||
// First statement: struct definition
|
||||
structStmt, ok := program.Statements[0].(*parser.StructStatement)
|
||||
if !ok {
|
||||
t.Fatalf("expected StructStatement, got %T", program.Statements[0])
|
||||
}
|
||||
if structStmt.Name != "Person" {
|
||||
t.Errorf("expected struct name 'Person', got %s", structStmt.Name)
|
||||
}
|
||||
|
||||
// Second statement: getter method
|
||||
method1, ok := program.Statements[1].(*parser.MethodDefinition)
|
||||
if !ok {
|
||||
t.Fatalf("expected MethodDefinition, got %T", program.Statements[1])
|
||||
}
|
||||
if method1.StructName != "Person" {
|
||||
t.Errorf("expected struct name 'Person', got %s", method1.StructName)
|
||||
}
|
||||
if method1.MethodName != "getName" {
|
||||
t.Errorf("expected method name 'getName', got %s", method1.MethodName)
|
||||
}
|
||||
if method1.Function.ReturnType == nil {
|
||||
t.Fatal("expected return type for getName method")
|
||||
}
|
||||
if method1.Function.ReturnType.Type != "string" {
|
||||
t.Errorf("expected return type 'string', got %s", method1.Function.ReturnType.Type)
|
||||
}
|
||||
if len(method1.Function.Parameters) != 0 {
|
||||
t.Errorf("expected 0 parameters, got %d", len(method1.Function.Parameters))
|
||||
}
|
||||
|
||||
// Third statement: setter method
|
||||
method2, ok := program.Statements[2].(*parser.MethodDefinition)
|
||||
if !ok {
|
||||
t.Fatalf("expected MethodDefinition, got %T", program.Statements[2])
|
||||
}
|
||||
if method2.StructName != "Person" {
|
||||
t.Errorf("expected struct name 'Person', got %s", method2.StructName)
|
||||
}
|
||||
if method2.MethodName != "setAge" {
|
||||
t.Errorf("expected method name 'setAge', got %s", method2.MethodName)
|
||||
}
|
||||
if method2.Function.ReturnType != nil {
|
||||
t.Errorf("expected no return type for setAge method, got %s", method2.Function.ReturnType.Type)
|
||||
}
|
||||
if len(method2.Function.Parameters) != 1 {
|
||||
t.Fatalf("expected 1 parameter, got %d", len(method2.Function.Parameters))
|
||||
}
|
||||
if method2.Function.Parameters[0].Name != "newAge" {
|
||||
t.Errorf("expected parameter name 'newAge', got %s", method2.Function.Parameters[0].Name)
|
||||
}
|
||||
if method2.Function.Parameters[0].TypeHint == nil {
|
||||
t.Fatal("expected type hint for newAge parameter")
|
||||
}
|
||||
if method2.Function.Parameters[0].TypeHint.Type != "number" {
|
||||
t.Errorf("expected parameter type 'number', got %s", method2.Function.Parameters[0].TypeHint.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStructConstructor(t *testing.T) {
|
||||
input := `struct Person {
|
||||
name: string,
|
||||
age: number
|
||||
}
|
||||
|
||||
person = Person{name = "John", age = 30}
|
||||
empty = Person{}`
|
||||
|
||||
l := parser.NewLexer(input)
|
||||
p := parser.NewParser(l)
|
||||
program := p.ParseProgram()
|
||||
checkParserErrors(t, p)
|
||||
|
||||
if len(program.Statements) != 3 {
|
||||
t.Fatalf("expected 3 statements, got %d", len(program.Statements))
|
||||
}
|
||||
|
||||
// Second statement: constructor with fields
|
||||
assign1, ok := program.Statements[1].(*parser.AssignStatement)
|
||||
if !ok {
|
||||
t.Fatalf("expected AssignStatement, got %T", program.Statements[1])
|
||||
}
|
||||
|
||||
constructor1, ok := assign1.Value.(*parser.StructConstructorExpression)
|
||||
if !ok {
|
||||
t.Fatalf("expected StructConstructorExpression, got %T", assign1.Value)
|
||||
}
|
||||
|
||||
if constructor1.StructName != "Person" {
|
||||
t.Errorf("expected struct name 'Person', got %s", constructor1.StructName)
|
||||
}
|
||||
|
||||
if len(constructor1.Fields) != 2 {
|
||||
t.Fatalf("expected 2 fields, got %d", len(constructor1.Fields))
|
||||
}
|
||||
|
||||
// Check name field
|
||||
nameKey, ok := constructor1.Fields[0].Key.(*parser.Identifier)
|
||||
if !ok {
|
||||
t.Fatalf("expected Identifier for name key, got %T", constructor1.Fields[0].Key)
|
||||
}
|
||||
if nameKey.Value != "name" {
|
||||
t.Errorf("expected key 'name', got %s", nameKey.Value)
|
||||
}
|
||||
testStringLiteral(t, constructor1.Fields[0].Value, "John")
|
||||
|
||||
// Check age field
|
||||
ageKey, ok := constructor1.Fields[1].Key.(*parser.Identifier)
|
||||
if !ok {
|
||||
t.Fatalf("expected Identifier for age key, got %T", constructor1.Fields[1].Key)
|
||||
}
|
||||
if ageKey.Value != "age" {
|
||||
t.Errorf("expected key 'age', got %s", ageKey.Value)
|
||||
}
|
||||
testNumberLiteral(t, constructor1.Fields[1].Value, 30)
|
||||
|
||||
// Third statement: empty constructor
|
||||
assign2, ok := program.Statements[2].(*parser.AssignStatement)
|
||||
if !ok {
|
||||
t.Fatalf("expected AssignStatement, got %T", program.Statements[2])
|
||||
}
|
||||
|
||||
constructor2, ok := assign2.Value.(*parser.StructConstructorExpression)
|
||||
if !ok {
|
||||
t.Fatalf("expected StructConstructorExpression, got %T", assign2.Value)
|
||||
}
|
||||
|
||||
if constructor2.StructName != "Person" {
|
||||
t.Errorf("expected struct name 'Person', got %s", constructor2.StructName)
|
||||
}
|
||||
|
||||
if len(constructor2.Fields) != 0 {
|
||||
t.Errorf("expected 0 fields, got %d", len(constructor2.Fields))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNestedStructTypes(t *testing.T) {
|
||||
input := `struct Address {
|
||||
street: string,
|
||||
city: string
|
||||
}
|
||||
|
||||
struct Person {
|
||||
name: string,
|
||||
address: Address
|
||||
}
|
||||
|
||||
person = Person{
|
||||
name = "John",
|
||||
address = Address{street = "Main St", city = "NYC"}
|
||||
}`
|
||||
|
||||
l := parser.NewLexer(input)
|
||||
p := parser.NewParser(l)
|
||||
program := p.ParseProgram()
|
||||
checkParserErrors(t, p)
|
||||
|
||||
if len(program.Statements) != 3 {
|
||||
t.Fatalf("expected 3 statements, got %d", len(program.Statements))
|
||||
}
|
||||
|
||||
// Check Person struct has Address field type
|
||||
personStruct, ok := program.Statements[1].(*parser.StructStatement)
|
||||
if !ok {
|
||||
t.Fatalf("expected StructStatement, got %T", program.Statements[1])
|
||||
}
|
||||
|
||||
addressField := personStruct.Fields[1]
|
||||
if addressField.Name != "address" {
|
||||
t.Errorf("expected field name 'address', got %s", addressField.Name)
|
||||
}
|
||||
if addressField.TypeHint.Type != "Address" {
|
||||
t.Errorf("expected field type 'Address', got %s", addressField.TypeHint.Type)
|
||||
}
|
||||
|
||||
// Check nested constructor
|
||||
assign, ok := program.Statements[2].(*parser.AssignStatement)
|
||||
if !ok {
|
||||
t.Fatalf("expected AssignStatement, got %T", program.Statements[2])
|
||||
}
|
||||
|
||||
personConstructor, ok := assign.Value.(*parser.StructConstructorExpression)
|
||||
if !ok {
|
||||
t.Fatalf("expected StructConstructorExpression, got %T", assign.Value)
|
||||
}
|
||||
|
||||
// Check the nested Address constructor
|
||||
addressConstructor, ok := personConstructor.Fields[1].Value.(*parser.StructConstructorExpression)
|
||||
if !ok {
|
||||
t.Fatalf("expected nested StructConstructorExpression, got %T", personConstructor.Fields[1].Value)
|
||||
}
|
||||
|
||||
if addressConstructor.StructName != "Address" {
|
||||
t.Errorf("expected nested struct name 'Address', got %s", addressConstructor.StructName)
|
||||
}
|
||||
|
||||
if len(addressConstructor.Fields) != 2 {
|
||||
t.Errorf("expected 2 fields in nested constructor, got %d", len(addressConstructor.Fields))
|
||||
}
|
||||
}
|
||||
|
||||
func TestStructIntegrationWithProgram(t *testing.T) {
|
||||
input := `struct Point {
|
||||
x: number,
|
||||
y: number
|
||||
}
|
||||
|
||||
fn Point.distance(other: Point): number
|
||||
dx = self.x - other.x
|
||||
dy = self.y - other.y
|
||||
return (dx * dx + dy * dy)
|
||||
end
|
||||
|
||||
p1 = Point{x = 0, y = 0}
|
||||
p2 = Point{x = 3, y = 4}
|
||||
|
||||
if p1.distance(p2) then
|
||||
echo "Distance calculated"
|
||||
end
|
||||
|
||||
for i = 1, 10 do
|
||||
point = Point{x = i, y = i * 2}
|
||||
echo point.x
|
||||
end`
|
||||
|
||||
l := parser.NewLexer(input)
|
||||
p := parser.NewParser(l)
|
||||
program := p.ParseProgram()
|
||||
checkParserErrors(t, p)
|
||||
|
||||
if len(program.Statements) != 6 {
|
||||
t.Fatalf("expected 6 statements, got %d", len(program.Statements))
|
||||
}
|
||||
|
||||
// Verify struct definition
|
||||
structStmt, ok := program.Statements[0].(*parser.StructStatement)
|
||||
if !ok {
|
||||
t.Fatalf("expected StructStatement, got %T", program.Statements[0])
|
||||
}
|
||||
if structStmt.Name != "Point" {
|
||||
t.Errorf("expected struct name 'Point', got %s", structStmt.Name)
|
||||
}
|
||||
|
||||
// Verify method definition
|
||||
methodStmt, ok := program.Statements[1].(*parser.MethodDefinition)
|
||||
if !ok {
|
||||
t.Fatalf("expected MethodDefinition, got %T", program.Statements[1])
|
||||
}
|
||||
if methodStmt.StructName != "Point" {
|
||||
t.Errorf("expected struct name 'Point', got %s", methodStmt.StructName)
|
||||
}
|
||||
if methodStmt.MethodName != "distance" {
|
||||
t.Errorf("expected method name 'distance', got %s", methodStmt.MethodName)
|
||||
}
|
||||
|
||||
// Verify constructors
|
||||
for i := 2; i <= 3; i++ {
|
||||
assign, ok := program.Statements[i].(*parser.AssignStatement)
|
||||
if !ok {
|
||||
t.Fatalf("statement %d: expected AssignStatement, got %T", i, program.Statements[i])
|
||||
}
|
||||
constructor, ok := assign.Value.(*parser.StructConstructorExpression)
|
||||
if !ok {
|
||||
t.Fatalf("statement %d: expected StructConstructorExpression, got %T", i, assign.Value)
|
||||
}
|
||||
if constructor.StructName != "Point" {
|
||||
t.Errorf("statement %d: expected struct name 'Point', got %s", i, constructor.StructName)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify if statement with method call
|
||||
ifStmt, ok := program.Statements[4].(*parser.IfStatement)
|
||||
if !ok {
|
||||
t.Fatalf("expected IfStatement, got %T", program.Statements[4])
|
||||
}
|
||||
callExpr, ok := ifStmt.Condition.(*parser.CallExpression)
|
||||
if !ok {
|
||||
t.Fatalf("expected CallExpression in if condition, got %T", ifStmt.Condition)
|
||||
}
|
||||
dotExpr, ok := callExpr.Function.(*parser.DotExpression)
|
||||
if !ok {
|
||||
t.Fatalf("expected DotExpression for method call, got %T", callExpr.Function)
|
||||
}
|
||||
if dotExpr.Key != "distance" {
|
||||
t.Errorf("expected method name 'distance', got %s", dotExpr.Key)
|
||||
}
|
||||
|
||||
// Verify for loop with struct creation
|
||||
forStmt, ok := program.Statements[5].(*parser.ForStatement)
|
||||
if !ok {
|
||||
t.Fatalf("expected ForStatement, got %T", program.Statements[5])
|
||||
}
|
||||
if len(forStmt.Body) != 2 {
|
||||
t.Errorf("expected 2 statements in for body, got %d", len(forStmt.Body))
|
||||
}
|
||||
|
||||
// Check struct constructor in loop
|
||||
loopAssign, ok := forStmt.Body[0].(*parser.AssignStatement)
|
||||
if !ok {
|
||||
t.Fatalf("expected AssignStatement in loop, got %T", forStmt.Body[0])
|
||||
}
|
||||
loopConstructor, ok := loopAssign.Value.(*parser.StructConstructorExpression)
|
||||
if !ok {
|
||||
t.Fatalf("expected StructConstructorExpression in loop, got %T", loopAssign.Value)
|
||||
}
|
||||
if loopConstructor.StructName != "Point" {
|
||||
t.Errorf("expected struct name 'Point' in loop, got %s", loopConstructor.StructName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStructErrorCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectedErrorSubstring string
|
||||
}{
|
||||
{
|
||||
name: "missing field type",
|
||||
input: `struct Person {
|
||||
name
|
||||
}`,
|
||||
expectedErrorSubstring: "struct fields require type annotation",
|
||||
},
|
||||
{
|
||||
name: "missing struct name",
|
||||
input: `struct {
|
||||
name: string
|
||||
}`,
|
||||
expectedErrorSubstring: "expected struct name",
|
||||
},
|
||||
{
|
||||
name: "missing opening brace",
|
||||
input: `struct Person
|
||||
name: string
|
||||
}`,
|
||||
expectedErrorSubstring: "expected '{' after struct name",
|
||||
},
|
||||
{
|
||||
name: "missing closing brace",
|
||||
input: `struct Person {
|
||||
name: string`,
|
||||
expectedErrorSubstring: "expected next token to be }, got end of file instead",
|
||||
},
|
||||
{
|
||||
name: "invalid field type",
|
||||
input: `struct Person {
|
||||
name: invalidtype
|
||||
}`,
|
||||
expectedErrorSubstring: "invalid type name 'invalidtype'",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
l := parser.NewLexer(tt.input)
|
||||
p := parser.NewParser(l)
|
||||
_ = p.ParseProgram()
|
||||
|
||||
if !p.HasErrors() {
|
||||
t.Fatalf("expected parser errors, but got none")
|
||||
}
|
||||
|
||||
errors := p.ErrorStrings()
|
||||
found := false
|
||||
for _, err := range errors {
|
||||
if containsSubstring(err, tt.expectedErrorSubstring) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
t.Errorf("expected error containing '%s', got errors: %v", tt.expectedErrorSubstring, errors)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSingleLineStruct(t *testing.T) {
|
||||
input := `struct Person { name: string, age: number }`
|
||||
|
||||
l := parser.NewLexer(input)
|
||||
p := parser.NewParser(l)
|
||||
program := p.ParseProgram()
|
||||
checkParserErrors(t, p)
|
||||
|
||||
if len(program.Statements) != 1 {
|
||||
t.Fatalf("expected 1 statement, got %d", len(program.Statements))
|
||||
}
|
||||
|
||||
stmt, ok := program.Statements[0].(*parser.StructStatement)
|
||||
if !ok {
|
||||
t.Fatalf("expected StructStatement, got %T", program.Statements[0])
|
||||
}
|
||||
|
||||
if stmt.Name != "Person" {
|
||||
t.Errorf("expected struct name 'Person', got %s", stmt.Name)
|
||||
}
|
||||
|
||||
if len(stmt.Fields) != 2 {
|
||||
t.Fatalf("expected 2 fields, got %d", len(stmt.Fields))
|
||||
}
|
||||
|
||||
if stmt.Fields[0].Name != "name" || stmt.Fields[0].TypeHint.Type != "string" {
|
||||
t.Errorf("expected first field 'name: string', got '%s: %s'",
|
||||
stmt.Fields[0].Name, stmt.Fields[0].TypeHint.Type)
|
||||
}
|
||||
|
||||
if stmt.Fields[1].Name != "age" || stmt.Fields[1].TypeHint.Type != "number" {
|
||||
t.Errorf("expected second field 'age: number', got '%s: %s'",
|
||||
stmt.Fields[1].Name, stmt.Fields[1].TypeHint.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStructString(t *testing.T) {
|
||||
input := `struct Person {
|
||||
name: string,
|
||||
age: number
|
||||
}`
|
||||
|
||||
l := parser.NewLexer(input)
|
||||
p := parser.NewParser(l)
|
||||
program := p.ParseProgram()
|
||||
checkParserErrors(t, p)
|
||||
|
||||
stmt := program.Statements[0].(*parser.StructStatement)
|
||||
str := stmt.String()
|
||||
|
||||
expected := "struct Person {\n\tname: string,\n\tage: number\n}"
|
||||
if str != expected {
|
||||
t.Errorf("expected string representation:\n%s\ngot:\n%s", expected, str)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMethodString(t *testing.T) {
|
||||
input := `struct Person {
|
||||
name: string
|
||||
}
|
||||
|
||||
fn Person.getName(): string
|
||||
return self.name
|
||||
end`
|
||||
|
||||
l := parser.NewLexer(input)
|
||||
p := parser.NewParser(l)
|
||||
program := p.ParseProgram()
|
||||
checkParserErrors(t, p)
|
||||
|
||||
method := program.Statements[1].(*parser.MethodDefinition)
|
||||
str := method.String()
|
||||
|
||||
if !containsSubstring(str, "fn Person.getName") {
|
||||
t.Errorf("expected method string to contain 'fn Person.getName', got: %s", str)
|
||||
}
|
||||
if !containsSubstring(str, ": string") {
|
||||
t.Errorf("expected method string to contain return type, got: %s", str)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConstructorString(t *testing.T) {
|
||||
input := `struct Person {
|
||||
name: string,
|
||||
age: number
|
||||
}
|
||||
|
||||
person = Person{name = "John", age = 30}`
|
||||
|
||||
l := parser.NewLexer(input)
|
||||
p := parser.NewParser(l)
|
||||
program := p.ParseProgram()
|
||||
checkParserErrors(t, p)
|
||||
|
||||
assign := program.Statements[1].(*parser.AssignStatement)
|
||||
constructor := assign.Value.(*parser.StructConstructorExpression)
|
||||
str := constructor.String()
|
||||
|
||||
expected := `Person{name = "John", age = 30.00}`
|
||||
if str != expected {
|
||||
t.Errorf("expected constructor string:\n%s\ngot:\n%s", expected, str)
|
||||
}
|
||||
}
|
@ -59,6 +59,7 @@ const (
|
||||
EXIT
|
||||
FN
|
||||
RETURN
|
||||
STRUCT
|
||||
|
||||
// Special
|
||||
EOF
|
||||
@ -107,6 +108,7 @@ var precedences = map[TokenType]Precedence{
|
||||
DOT: MEMBER,
|
||||
LBRACKET: MEMBER,
|
||||
LPAREN: CALL,
|
||||
LBRACE: CALL,
|
||||
}
|
||||
|
||||
// lookupIdent checks if an identifier is a keyword
|
||||
@ -132,6 +134,7 @@ func lookupIdent(ident string) TokenType {
|
||||
"exit": EXIT,
|
||||
"fn": FN,
|
||||
"return": RETURN,
|
||||
"struct": STRUCT,
|
||||
}
|
||||
|
||||
if tok, ok := keywords[ident]; ok {
|
||||
|
197
parser/types.go
197
parser/types.go
@ -76,6 +76,9 @@ type TypeInferrer struct {
|
||||
nilType *TypeInfo
|
||||
tableType *TypeInfo
|
||||
anyType *TypeInfo
|
||||
|
||||
// Struct definitions
|
||||
structs map[string]*StructStatement
|
||||
}
|
||||
|
||||
// NewTypeInferrer creates a new type inference engine
|
||||
@ -86,6 +89,7 @@ func NewTypeInferrer() *TypeInferrer {
|
||||
currentScope: globalScope,
|
||||
globalScope: globalScope,
|
||||
errors: []TypeError{},
|
||||
structs: make(map[string]*StructStatement),
|
||||
|
||||
// Pre-allocate common types to reduce allocations
|
||||
numberType: &TypeInfo{Type: TypeNumber, Inferred: true},
|
||||
@ -101,6 +105,14 @@ func NewTypeInferrer() *TypeInferrer {
|
||||
|
||||
// InferTypes performs type inference on the entire program
|
||||
func (ti *TypeInferrer) InferTypes(program *Program) []TypeError {
|
||||
// First pass: collect struct definitions
|
||||
for _, stmt := range program.Statements {
|
||||
if structStmt, ok := stmt.(*StructStatement); ok {
|
||||
ti.structs[structStmt.Name] = structStmt
|
||||
}
|
||||
}
|
||||
|
||||
// Second pass: infer types
|
||||
for _, stmt := range program.Statements {
|
||||
ti.inferStatement(stmt)
|
||||
}
|
||||
@ -129,9 +141,27 @@ func (ti *TypeInferrer) addError(message string, node Node) {
|
||||
})
|
||||
}
|
||||
|
||||
// getStructType returns TypeInfo for a struct
|
||||
func (ti *TypeInferrer) getStructType(name string) *TypeInfo {
|
||||
if _, exists := ti.structs[name]; exists {
|
||||
return &TypeInfo{Type: name, Inferred: true}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// isStructType checks if a type is a struct type
|
||||
func (ti *TypeInferrer) isStructType(t *TypeInfo) bool {
|
||||
_, exists := ti.structs[t.Type]
|
||||
return exists
|
||||
}
|
||||
|
||||
// inferStatement infers types for statements
|
||||
func (ti *TypeInferrer) inferStatement(stmt Statement) {
|
||||
switch s := stmt.(type) {
|
||||
case *StructStatement:
|
||||
ti.inferStructStatement(s)
|
||||
case *MethodDefinition:
|
||||
ti.inferMethodDefinition(s)
|
||||
case *AssignStatement:
|
||||
ti.inferAssignStatement(s)
|
||||
case *EchoStatement:
|
||||
@ -152,9 +182,63 @@ func (ti *TypeInferrer) inferStatement(stmt Statement) {
|
||||
if s.Value != nil {
|
||||
ti.inferExpression(s.Value)
|
||||
}
|
||||
case *ExpressionStatement:
|
||||
ti.inferExpression(s.Expression)
|
||||
}
|
||||
}
|
||||
|
||||
// inferStructStatement handles struct definitions
|
||||
func (ti *TypeInferrer) inferStructStatement(stmt *StructStatement) {
|
||||
// Validate field types
|
||||
for _, field := range stmt.Fields {
|
||||
if field.TypeHint != nil {
|
||||
if !ValidTypeName(field.TypeHint.Type) && !ti.isStructType(field.TypeHint) {
|
||||
ti.addError(fmt.Sprintf("invalid field type '%s' in struct '%s'",
|
||||
field.TypeHint.Type, stmt.Name), stmt)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// inferMethodDefinition handles method definitions
|
||||
func (ti *TypeInferrer) inferMethodDefinition(stmt *MethodDefinition) {
|
||||
// Check if struct exists
|
||||
if _, exists := ti.structs[stmt.StructName]; !exists {
|
||||
ti.addError(fmt.Sprintf("method defined on undefined struct '%s'", stmt.StructName), stmt)
|
||||
return
|
||||
}
|
||||
|
||||
// Infer the function body
|
||||
ti.enterScope()
|
||||
|
||||
// Add self parameter implicitly
|
||||
ti.currentScope.Define(&Symbol{
|
||||
Name: "self",
|
||||
Type: ti.getStructType(stmt.StructName),
|
||||
Declared: true,
|
||||
})
|
||||
|
||||
// Add explicit parameters
|
||||
for _, param := range stmt.Function.Parameters {
|
||||
paramType := ti.anyType
|
||||
if param.TypeHint != nil {
|
||||
paramType = param.TypeHint
|
||||
}
|
||||
ti.currentScope.Define(&Symbol{
|
||||
Name: param.Name,
|
||||
Type: paramType,
|
||||
Declared: true,
|
||||
})
|
||||
}
|
||||
|
||||
// Infer function body
|
||||
for _, bodyStmt := range stmt.Function.Body {
|
||||
ti.inferStatement(bodyStmt)
|
||||
}
|
||||
|
||||
ti.exitScope()
|
||||
}
|
||||
|
||||
// inferAssignStatement handles variable assignments with type checking
|
||||
func (ti *TypeInferrer) inferAssignStatement(stmt *AssignStatement) {
|
||||
// Infer the type of the value expression
|
||||
@ -288,9 +372,9 @@ func (ti *TypeInferrer) inferForStatement(stmt *ForStatement) {
|
||||
func (ti *TypeInferrer) inferForInStatement(stmt *ForInStatement) {
|
||||
iterableType := ti.inferExpression(stmt.Iterable)
|
||||
|
||||
// For now, assume iterable is a table
|
||||
if !ti.isTableType(iterableType) {
|
||||
ti.addError("for-in requires an iterable (table)", stmt.Iterable)
|
||||
// For now, assume iterable is a table or struct
|
||||
if !ti.isTableType(iterableType) && !ti.isStructType(iterableType) {
|
||||
ti.addError("for-in requires an iterable (table or struct)", stmt.Iterable)
|
||||
}
|
||||
|
||||
ti.enterScope()
|
||||
@ -341,6 +425,8 @@ func (ti *TypeInferrer) inferExpression(expr Expression) *TypeInfo {
|
||||
return ti.nilType
|
||||
case *TableLiteral:
|
||||
return ti.inferTableLiteral(e)
|
||||
case *StructConstructorExpression:
|
||||
return ti.inferStructConstructor(e)
|
||||
case *FunctionLiteral:
|
||||
return ti.inferFunctionLiteral(e)
|
||||
case *CallExpression:
|
||||
@ -353,12 +439,85 @@ func (ti *TypeInferrer) inferExpression(expr Expression) *TypeInfo {
|
||||
return ti.inferIndexExpression(e)
|
||||
case *DotExpression:
|
||||
return ti.inferDotExpression(e)
|
||||
case *AssignExpression:
|
||||
return ti.inferAssignExpression(e)
|
||||
default:
|
||||
ti.addError("unknown expression type", expr)
|
||||
return ti.anyType
|
||||
}
|
||||
}
|
||||
|
||||
// inferStructConstructor handles struct constructor expressions
|
||||
func (ti *TypeInferrer) inferStructConstructor(expr *StructConstructorExpression) *TypeInfo {
|
||||
structDef, exists := ti.structs[expr.StructName]
|
||||
if !exists {
|
||||
ti.addError(fmt.Sprintf("undefined struct '%s'", expr.StructName), expr)
|
||||
return ti.anyType
|
||||
}
|
||||
|
||||
// Validate field assignments
|
||||
for _, pair := range expr.Fields {
|
||||
if pair.Key != nil {
|
||||
fieldName := ""
|
||||
if ident, ok := pair.Key.(*Identifier); ok {
|
||||
fieldName = ident.Value
|
||||
} else if str, ok := pair.Key.(*StringLiteral); ok {
|
||||
fieldName = str.Value
|
||||
}
|
||||
|
||||
// Check if field exists in struct
|
||||
fieldExists := false
|
||||
var fieldType *TypeInfo
|
||||
for _, field := range structDef.Fields {
|
||||
if field.Name == fieldName {
|
||||
fieldExists = true
|
||||
fieldType = field.TypeHint
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !fieldExists {
|
||||
ti.addError(fmt.Sprintf("struct '%s' has no field '%s'", expr.StructName, fieldName), expr)
|
||||
} else {
|
||||
// Check type compatibility
|
||||
valueType := ti.inferExpression(pair.Value)
|
||||
if !ti.isTypeCompatible(valueType, fieldType) {
|
||||
ti.addError(fmt.Sprintf("cannot assign %s to field '%s' of type %s",
|
||||
valueType.Type, fieldName, fieldType.Type), expr)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Array-style assignment not valid for structs
|
||||
ti.addError("struct constructors require named field assignments", expr)
|
||||
}
|
||||
}
|
||||
|
||||
structType := ti.getStructType(expr.StructName)
|
||||
expr.SetType(structType)
|
||||
return structType
|
||||
}
|
||||
|
||||
// inferAssignExpression handles assignment expressions
|
||||
func (ti *TypeInferrer) inferAssignExpression(expr *AssignExpression) *TypeInfo {
|
||||
valueType := ti.inferExpression(expr.Value)
|
||||
|
||||
if ident, ok := expr.Name.(*Identifier); ok {
|
||||
if expr.IsDeclaration {
|
||||
ti.currentScope.Define(&Symbol{
|
||||
Name: ident.Value,
|
||||
Type: valueType,
|
||||
Declared: true,
|
||||
})
|
||||
}
|
||||
ident.SetType(valueType)
|
||||
} else {
|
||||
ti.inferExpression(expr.Name)
|
||||
}
|
||||
|
||||
expr.SetType(valueType)
|
||||
return valueType
|
||||
}
|
||||
|
||||
// inferIdentifier looks up identifier type in symbol table
|
||||
func (ti *TypeInferrer) inferIdentifier(ident *Identifier) *TypeInfo {
|
||||
symbol := ti.currentScope.Lookup(ident.Value)
|
||||
@ -497,17 +656,43 @@ func (ti *TypeInferrer) inferInfixExpression(infix *InfixExpression) *TypeInfo {
|
||||
|
||||
// inferIndexExpression infers table[index] type
|
||||
func (ti *TypeInferrer) inferIndexExpression(index *IndexExpression) *TypeInfo {
|
||||
ti.inferExpression(index.Left)
|
||||
leftType := ti.inferExpression(index.Left)
|
||||
ti.inferExpression(index.Index)
|
||||
|
||||
// For now, assume table access returns any
|
||||
// If indexing a struct, try to infer field type
|
||||
if ti.isStructType(leftType) {
|
||||
if strLit, ok := index.Index.(*StringLiteral); ok {
|
||||
if structDef, exists := ti.structs[leftType.Type]; exists {
|
||||
for _, field := range structDef.Fields {
|
||||
if field.Name == strLit.Value {
|
||||
index.SetType(field.TypeHint)
|
||||
return field.TypeHint
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For now, assume table/struct access returns any
|
||||
index.SetType(ti.anyType)
|
||||
return ti.anyType
|
||||
}
|
||||
|
||||
// inferDotExpression infers table.key type
|
||||
func (ti *TypeInferrer) inferDotExpression(dot *DotExpression) *TypeInfo {
|
||||
ti.inferExpression(dot.Left)
|
||||
leftType := ti.inferExpression(dot.Left)
|
||||
|
||||
// If accessing a struct field, try to infer field type
|
||||
if ti.isStructType(leftType) {
|
||||
if structDef, exists := ti.structs[leftType.Type]; exists {
|
||||
for _, field := range structDef.Fields {
|
||||
if field.Name == dot.Key {
|
||||
dot.SetType(field.TypeHint)
|
||||
return field.TypeHint
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For now, assume member access returns any
|
||||
dot.SetType(ti.anyType)
|
||||
|
Loading…
x
Reference in New Issue
Block a user