structs
This commit is contained in:
parent
fc439b6d5a
commit
5ae2a6ef23
@ -4,7 +4,7 @@ import "fmt"
|
|||||||
|
|
||||||
// TypeInfo represents type information for expressions
|
// TypeInfo represents type information for expressions
|
||||||
type TypeInfo struct {
|
type TypeInfo struct {
|
||||||
Type string // "number", "string", "bool", "table", "function", "nil", "any"
|
Type string // "number", "string", "bool", "table", "function", "nil", "any", struct name
|
||||||
Inferred bool // true if type was inferred, false if explicitly declared
|
Inferred bool // true if type was inferred, false if explicitly declared
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -41,6 +41,67 @@ func (p *Program) String() string {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StructField represents a field in a struct definition
|
||||||
|
type StructField struct {
|
||||||
|
Name string
|
||||||
|
TypeHint *TypeInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sf *StructField) String() string {
|
||||||
|
if sf.TypeHint != nil {
|
||||||
|
return fmt.Sprintf("%s: %s", sf.Name, sf.TypeHint.Type)
|
||||||
|
}
|
||||||
|
return sf.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
// StructStatement represents struct definitions
|
||||||
|
type StructStatement struct {
|
||||||
|
Name string
|
||||||
|
Fields []StructField
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ss *StructStatement) statementNode() {}
|
||||||
|
func (ss *StructStatement) String() string {
|
||||||
|
var fields string
|
||||||
|
for i, field := range ss.Fields {
|
||||||
|
if i > 0 {
|
||||||
|
fields += ",\n\t"
|
||||||
|
}
|
||||||
|
fields += field.String()
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("struct %s {\n\t%s\n}", ss.Name, fields)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MethodDefinition represents method definitions on structs
|
||||||
|
type MethodDefinition struct {
|
||||||
|
StructName string
|
||||||
|
MethodName string
|
||||||
|
Function *FunctionLiteral
|
||||||
|
}
|
||||||
|
|
||||||
|
func (md *MethodDefinition) statementNode() {}
|
||||||
|
func (md *MethodDefinition) String() string {
|
||||||
|
return fmt.Sprintf("fn %s.%s%s", md.StructName, md.MethodName, md.Function.String()[2:]) // skip "fn" from function string
|
||||||
|
}
|
||||||
|
|
||||||
|
// StructConstructorExpression represents struct constructor calls like my_type{...}
|
||||||
|
type StructConstructorExpression struct {
|
||||||
|
StructName string
|
||||||
|
Fields []TablePair // reuse TablePair for field assignments
|
||||||
|
typeInfo *TypeInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sce *StructConstructorExpression) expressionNode() {}
|
||||||
|
func (sce *StructConstructorExpression) String() string {
|
||||||
|
var pairs []string
|
||||||
|
for _, pair := range sce.Fields {
|
||||||
|
pairs = append(pairs, pair.String())
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s{%s}", sce.StructName, joinStrings(pairs, ", "))
|
||||||
|
}
|
||||||
|
func (sce *StructConstructorExpression) GetType() *TypeInfo { return sce.typeInfo }
|
||||||
|
func (sce *StructConstructorExpression) SetType(t *TypeInfo) { sce.typeInfo = t }
|
||||||
|
|
||||||
// AssignStatement represents variable assignment with optional type hint
|
// AssignStatement represents variable assignment with optional type hint
|
||||||
type AssignStatement struct {
|
type AssignStatement struct {
|
||||||
Name Expression // Changed from *Identifier to Expression for member access
|
Name Expression // Changed from *Identifier to Expression for member access
|
||||||
|
313
parser/parser.go
313
parser/parser.go
@ -34,6 +34,9 @@ type Parser struct {
|
|||||||
// Scope tracking
|
// Scope tracking
|
||||||
scopes []map[string]bool // stack of scopes, each tracking declared variables
|
scopes []map[string]bool // stack of scopes, each tracking declared variables
|
||||||
scopeTypes []string // track what type each scope is: "global", "function", "loop"
|
scopeTypes []string // track what type each scope is: "global", "function", "loop"
|
||||||
|
|
||||||
|
// Struct tracking
|
||||||
|
structs map[string]*StructStatement // track defined structs
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewParser creates a new parser instance
|
// NewParser creates a new parser instance
|
||||||
@ -43,6 +46,7 @@ func NewParser(lexer *Lexer) *Parser {
|
|||||||
errors: []ParseError{},
|
errors: []ParseError{},
|
||||||
scopes: []map[string]bool{make(map[string]bool)}, // start with global scope
|
scopes: []map[string]bool{make(map[string]bool)}, // start with global scope
|
||||||
scopeTypes: []string{"global"}, // start with global scope type
|
scopeTypes: []string{"global"}, // start with global scope type
|
||||||
|
structs: make(map[string]*StructStatement), // track struct definitions
|
||||||
}
|
}
|
||||||
|
|
||||||
p.prefixParseFns = make(map[TokenType]func() Expression)
|
p.prefixParseFns = make(map[TokenType]func() Expression)
|
||||||
@ -74,6 +78,7 @@ func NewParser(lexer *Lexer) *Parser {
|
|||||||
p.registerInfix(DOT, p.parseDotExpression)
|
p.registerInfix(DOT, p.parseDotExpression)
|
||||||
p.registerInfix(LBRACKET, p.parseIndexExpression)
|
p.registerInfix(LBRACKET, p.parseIndexExpression)
|
||||||
p.registerInfix(LPAREN, p.parseCallExpression)
|
p.registerInfix(LPAREN, p.parseCallExpression)
|
||||||
|
p.registerInfix(LBRACE, p.parseStructConstructor) // struct constructor
|
||||||
|
|
||||||
// Read two tokens, so curToken and peekToken are both set
|
// Read two tokens, so curToken and peekToken are both set
|
||||||
p.nextToken()
|
p.nextToken()
|
||||||
@ -157,7 +162,7 @@ func (p *Parser) parseTypeHint() *TypeInfo {
|
|||||||
}
|
}
|
||||||
|
|
||||||
typeName := p.curToken.Literal
|
typeName := p.curToken.Literal
|
||||||
if !ValidTypeName(typeName) {
|
if !ValidTypeName(typeName) && !p.isStructDefined(typeName) {
|
||||||
p.addError(fmt.Sprintf("invalid type name '%s'", typeName))
|
p.addError(fmt.Sprintf("invalid type name '%s'", typeName))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -165,6 +170,12 @@ func (p *Parser) parseTypeHint() *TypeInfo {
|
|||||||
return &TypeInfo{Type: typeName, Inferred: false}
|
return &TypeInfo{Type: typeName, Inferred: false}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isStructDefined checks if a struct name is defined
|
||||||
|
func (p *Parser) isStructDefined(name string) bool {
|
||||||
|
_, exists := p.structs[name]
|
||||||
|
return exists
|
||||||
|
}
|
||||||
|
|
||||||
// registerPrefix registers a prefix parse function
|
// registerPrefix registers a prefix parse function
|
||||||
func (p *Parser) registerPrefix(tokenType TokenType, fn func() Expression) {
|
func (p *Parser) registerPrefix(tokenType TokenType, fn func() Expression) {
|
||||||
p.prefixParseFns[tokenType] = fn
|
p.prefixParseFns[tokenType] = fn
|
||||||
@ -200,6 +211,10 @@ func (p *Parser) ParseProgram() *Program {
|
|||||||
// parseStatement parses a statement
|
// parseStatement parses a statement
|
||||||
func (p *Parser) parseStatement() Statement {
|
func (p *Parser) parseStatement() Statement {
|
||||||
switch p.curToken.Type {
|
switch p.curToken.Type {
|
||||||
|
case STRUCT:
|
||||||
|
return p.parseStructStatement()
|
||||||
|
case FN:
|
||||||
|
return p.parseFunctionStatement()
|
||||||
case IDENT:
|
case IDENT:
|
||||||
return p.parseIdentifierStatement()
|
return p.parseIdentifierStatement()
|
||||||
case IF:
|
case IF:
|
||||||
@ -230,6 +245,147 @@ func (p *Parser) parseStatement() Statement {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parseStructStatement parses struct definitions
|
||||||
|
func (p *Parser) parseStructStatement() *StructStatement {
|
||||||
|
stmt := &StructStatement{}
|
||||||
|
|
||||||
|
if !p.expectPeek(IDENT) {
|
||||||
|
p.addError("expected struct name")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt.Name = p.curToken.Literal
|
||||||
|
|
||||||
|
if !p.expectPeek(LBRACE) {
|
||||||
|
p.addError("expected '{' after struct name")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt.Fields = []StructField{}
|
||||||
|
|
||||||
|
if p.peekTokenIs(RBRACE) {
|
||||||
|
p.nextToken()
|
||||||
|
p.structs[stmt.Name] = stmt
|
||||||
|
return stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
p.nextToken()
|
||||||
|
|
||||||
|
for {
|
||||||
|
if p.curTokenIs(EOF) {
|
||||||
|
p.addError("unexpected end of input, expected }")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !p.curTokenIs(IDENT) {
|
||||||
|
p.addError("expected field name")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
field := StructField{Name: p.curToken.Literal}
|
||||||
|
|
||||||
|
// Parse optional type hint
|
||||||
|
field.TypeHint = p.parseTypeHint()
|
||||||
|
if field.TypeHint == nil {
|
||||||
|
p.addError("struct fields require type annotation")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt.Fields = append(stmt.Fields, field)
|
||||||
|
|
||||||
|
if !p.peekTokenIs(COMMA) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
p.nextToken() // consume comma
|
||||||
|
p.nextToken() // move to next field
|
||||||
|
|
||||||
|
if p.curTokenIs(RBRACE) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.curTokenIs(EOF) {
|
||||||
|
p.addError("expected next token to be }")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !p.expectPeek(RBRACE) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
p.structs[stmt.Name] = stmt
|
||||||
|
return stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseFunctionStatement handles both regular functions and methods
|
||||||
|
func (p *Parser) parseFunctionStatement() Statement {
|
||||||
|
if !p.expectPeek(IDENT) {
|
||||||
|
p.addError("expected function name")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
funcName := p.curToken.Literal
|
||||||
|
|
||||||
|
// Check if this is a method definition (struct.method)
|
||||||
|
if p.peekTokenIs(DOT) {
|
||||||
|
p.nextToken() // consume '.'
|
||||||
|
|
||||||
|
if !p.expectPeek(IDENT) {
|
||||||
|
p.addError("expected method name after '.'")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
methodName := p.curToken.Literal
|
||||||
|
|
||||||
|
if !p.expectPeek(LPAREN) {
|
||||||
|
p.addError("expected '(' after method name")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the function literal starting from parameters
|
||||||
|
funcLit := &FunctionLiteral{}
|
||||||
|
funcLit.Parameters, funcLit.Variadic = p.parseFunctionParameters()
|
||||||
|
|
||||||
|
if !p.expectPeek(RPAREN) {
|
||||||
|
p.addError("expected ')' after function parameters")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for return type hint
|
||||||
|
funcLit.ReturnType = p.parseTypeHint()
|
||||||
|
|
||||||
|
p.nextToken()
|
||||||
|
|
||||||
|
p.enterFunctionScope()
|
||||||
|
for _, param := range funcLit.Parameters {
|
||||||
|
p.declareVariable(param.Name)
|
||||||
|
}
|
||||||
|
funcLit.Body = p.parseBlockStatements(END)
|
||||||
|
p.exitFunctionScope()
|
||||||
|
|
||||||
|
if !p.curTokenIs(END) {
|
||||||
|
p.addError("expected 'end' to close function")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &MethodDefinition{
|
||||||
|
StructName: funcName,
|
||||||
|
MethodName: methodName,
|
||||||
|
Function: funcLit,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Regular function - this should be handled as expression statement
|
||||||
|
// Reset to handle as function literal
|
||||||
|
funcLit := p.parseFunctionLiteral()
|
||||||
|
if funcLit == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ExpressionStatement{Expression: funcLit}
|
||||||
|
}
|
||||||
|
|
||||||
// parseIdentifierStatement handles both assignments and expression statements starting with identifiers
|
// parseIdentifierStatement handles both assignments and expression statements starting with identifiers
|
||||||
func (p *Parser) parseIdentifierStatement() Statement {
|
func (p *Parser) parseIdentifierStatement() Statement {
|
||||||
// Parse the left-hand side expression first
|
// Parse the left-hand side expression first
|
||||||
@ -948,6 +1104,157 @@ func (p *Parser) parseTableLiteral() Expression {
|
|||||||
return table
|
return table
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parseStructConstructor handles struct constructor calls like my_type{...}
|
||||||
|
func (p *Parser) parseStructConstructor(left Expression) Expression {
|
||||||
|
// left should be an identifier representing the struct name
|
||||||
|
ident, ok := left.(*Identifier)
|
||||||
|
if !ok {
|
||||||
|
// Not an identifier, fall back to table literal parsing
|
||||||
|
return p.parseTableLiteralFromBrace()
|
||||||
|
}
|
||||||
|
|
||||||
|
structName := ident.Value
|
||||||
|
|
||||||
|
// Always try to parse as struct constructor if we have an identifier
|
||||||
|
// Type checking will catch undefined structs later
|
||||||
|
constructor := &StructConstructorExpression{
|
||||||
|
StructName: structName,
|
||||||
|
Fields: []TablePair{},
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.peekTokenIs(RBRACE) {
|
||||||
|
p.nextToken()
|
||||||
|
return constructor
|
||||||
|
}
|
||||||
|
|
||||||
|
p.nextToken()
|
||||||
|
|
||||||
|
for {
|
||||||
|
if p.curTokenIs(EOF) {
|
||||||
|
p.addError("unexpected end of input, expected }")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
pair := TablePair{}
|
||||||
|
|
||||||
|
if (p.curTokenIs(IDENT) || p.curTokenIs(STRING)) && p.peekTokenIs(ASSIGN) {
|
||||||
|
if p.curTokenIs(IDENT) {
|
||||||
|
pair.Key = &Identifier{Value: p.curToken.Literal}
|
||||||
|
} else {
|
||||||
|
pair.Key = &StringLiteral{Value: p.curToken.Literal}
|
||||||
|
}
|
||||||
|
p.nextToken()
|
||||||
|
p.nextToken()
|
||||||
|
|
||||||
|
if p.curTokenIs(EOF) {
|
||||||
|
p.addError("expected expression after assignment operator")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
pair.Value = p.ParseExpression(LOWEST)
|
||||||
|
} else {
|
||||||
|
pair.Value = p.ParseExpression(LOWEST)
|
||||||
|
}
|
||||||
|
|
||||||
|
if pair.Value == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
constructor.Fields = append(constructor.Fields, pair)
|
||||||
|
|
||||||
|
if !p.peekTokenIs(COMMA) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
p.nextToken()
|
||||||
|
p.nextToken()
|
||||||
|
|
||||||
|
if p.curTokenIs(RBRACE) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.curTokenIs(EOF) {
|
||||||
|
p.addError("expected next token to be }")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !p.expectPeek(RBRACE) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return constructor
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Parser) parseTableLiteralFromBrace() Expression {
|
||||||
|
// We're already at the opening brace, so parse as table literal
|
||||||
|
table := &TableLiteral{}
|
||||||
|
table.Pairs = []TablePair{}
|
||||||
|
|
||||||
|
if p.peekTokenIs(RBRACE) {
|
||||||
|
p.nextToken()
|
||||||
|
return table
|
||||||
|
}
|
||||||
|
|
||||||
|
p.nextToken()
|
||||||
|
|
||||||
|
for {
|
||||||
|
if p.curTokenIs(EOF) {
|
||||||
|
p.addError("unexpected end of input, expected }")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
pair := TablePair{}
|
||||||
|
|
||||||
|
if (p.curTokenIs(IDENT) || p.curTokenIs(STRING)) && p.peekTokenIs(ASSIGN) {
|
||||||
|
if p.curTokenIs(IDENT) {
|
||||||
|
pair.Key = &Identifier{Value: p.curToken.Literal}
|
||||||
|
} else {
|
||||||
|
pair.Key = &StringLiteral{Value: p.curToken.Literal}
|
||||||
|
}
|
||||||
|
p.nextToken()
|
||||||
|
p.nextToken()
|
||||||
|
|
||||||
|
if p.curTokenIs(EOF) {
|
||||||
|
p.addError("expected expression after assignment operator")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
pair.Value = p.ParseExpression(LOWEST)
|
||||||
|
} else {
|
||||||
|
pair.Value = p.ParseExpression(LOWEST)
|
||||||
|
}
|
||||||
|
|
||||||
|
if pair.Value == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
table.Pairs = append(table.Pairs, pair)
|
||||||
|
|
||||||
|
if !p.peekTokenIs(COMMA) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
p.nextToken()
|
||||||
|
p.nextToken()
|
||||||
|
|
||||||
|
if p.curTokenIs(RBRACE) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.curTokenIs(EOF) {
|
||||||
|
p.addError("expected next token to be }")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !p.expectPeek(RBRACE) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return table
|
||||||
|
}
|
||||||
|
|
||||||
func (p *Parser) parseInfixExpression(left Expression) Expression {
|
func (p *Parser) parseInfixExpression(left Expression) Expression {
|
||||||
expression := &InfixExpression{
|
expression := &InfixExpression{
|
||||||
Left: left,
|
Left: left,
|
||||||
@ -1057,7 +1364,7 @@ func (p *Parser) expectPeekIdent() bool {
|
|||||||
|
|
||||||
func (p *Parser) isKeyword(t TokenType) bool {
|
func (p *Parser) isKeyword(t TokenType) bool {
|
||||||
switch t {
|
switch t {
|
||||||
case TRUE, FALSE, NIL, AND, OR, NOT, IF, THEN, ELSEIF, ELSE, END, ECHO, FOR, WHILE, IN, DO, BREAK, EXIT, FN, RETURN:
|
case TRUE, FALSE, NIL, AND, OR, NOT, IF, THEN, ELSEIF, ELSE, END, ECHO, FOR, WHILE, IN, DO, BREAK, EXIT, FN, RETURN, STRUCT:
|
||||||
return true
|
return true
|
||||||
default:
|
default:
|
||||||
return false
|
return false
|
||||||
@ -1227,6 +1534,8 @@ func tokenTypeString(t TokenType) string {
|
|||||||
return "fn"
|
return "fn"
|
||||||
case RETURN:
|
case RETURN:
|
||||||
return "return"
|
return "return"
|
||||||
|
case STRUCT:
|
||||||
|
return "struct"
|
||||||
case EOF:
|
case EOF:
|
||||||
return "end of file"
|
return "end of file"
|
||||||
case ILLEGAL:
|
case ILLEGAL:
|
||||||
|
632
parser/tests/structs_test.go
Normal file
632
parser/tests/structs_test.go
Normal file
@ -0,0 +1,632 @@
|
|||||||
|
package parser_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.sharkk.net/Sharkk/Mako/parser"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBasicStructDefinition(t *testing.T) {
|
||||||
|
input := `struct Person {
|
||||||
|
name: string,
|
||||||
|
age: number
|
||||||
|
}`
|
||||||
|
|
||||||
|
l := parser.NewLexer(input)
|
||||||
|
p := parser.NewParser(l)
|
||||||
|
program := p.ParseProgram()
|
||||||
|
checkParserErrors(t, p)
|
||||||
|
|
||||||
|
if len(program.Statements) != 1 {
|
||||||
|
t.Fatalf("expected 1 statement, got %d", len(program.Statements))
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt, ok := program.Statements[0].(*parser.StructStatement)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected StructStatement, got %T", program.Statements[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
if stmt.Name != "Person" {
|
||||||
|
t.Errorf("expected struct name 'Person', got %s", stmt.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(stmt.Fields) != 2 {
|
||||||
|
t.Fatalf("expected 2 fields, got %d", len(stmt.Fields))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test first field
|
||||||
|
if stmt.Fields[0].Name != "name" {
|
||||||
|
t.Errorf("expected field name 'name', got %s", stmt.Fields[0].Name)
|
||||||
|
}
|
||||||
|
if stmt.Fields[0].TypeHint == nil {
|
||||||
|
t.Fatal("expected type hint for name field")
|
||||||
|
}
|
||||||
|
if stmt.Fields[0].TypeHint.Type != "string" {
|
||||||
|
t.Errorf("expected type 'string', got %s", stmt.Fields[0].TypeHint.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test second field
|
||||||
|
if stmt.Fields[1].Name != "age" {
|
||||||
|
t.Errorf("expected field name 'age', got %s", stmt.Fields[1].Name)
|
||||||
|
}
|
||||||
|
if stmt.Fields[1].TypeHint == nil {
|
||||||
|
t.Fatal("expected type hint for age field")
|
||||||
|
}
|
||||||
|
if stmt.Fields[1].TypeHint.Type != "number" {
|
||||||
|
t.Errorf("expected type 'number', got %s", stmt.Fields[1].TypeHint.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmptyStructDefinition(t *testing.T) {
|
||||||
|
input := `struct Empty {}`
|
||||||
|
|
||||||
|
l := parser.NewLexer(input)
|
||||||
|
p := parser.NewParser(l)
|
||||||
|
program := p.ParseProgram()
|
||||||
|
checkParserErrors(t, p)
|
||||||
|
|
||||||
|
if len(program.Statements) != 1 {
|
||||||
|
t.Fatalf("expected 1 statement, got %d", len(program.Statements))
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt, ok := program.Statements[0].(*parser.StructStatement)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected StructStatement, got %T", program.Statements[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
if stmt.Name != "Empty" {
|
||||||
|
t.Errorf("expected struct name 'Empty', got %s", stmt.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(stmt.Fields) != 0 {
|
||||||
|
t.Errorf("expected 0 fields, got %d", len(stmt.Fields))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComplexStructDefinition(t *testing.T) {
|
||||||
|
input := `struct Complex {
|
||||||
|
id: number,
|
||||||
|
name: string,
|
||||||
|
active: bool,
|
||||||
|
data: table,
|
||||||
|
callback: function,
|
||||||
|
optional: any
|
||||||
|
}`
|
||||||
|
|
||||||
|
l := parser.NewLexer(input)
|
||||||
|
p := parser.NewParser(l)
|
||||||
|
program := p.ParseProgram()
|
||||||
|
checkParserErrors(t, p)
|
||||||
|
|
||||||
|
if len(program.Statements) != 1 {
|
||||||
|
t.Fatalf("expected 1 statement, got %d", len(program.Statements))
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt, ok := program.Statements[0].(*parser.StructStatement)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected StructStatement, got %T", program.Statements[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedTypes := []string{"number", "string", "bool", "table", "function", "any"}
|
||||||
|
expectedNames := []string{"id", "name", "active", "data", "callback", "optional"}
|
||||||
|
|
||||||
|
if len(stmt.Fields) != len(expectedTypes) {
|
||||||
|
t.Fatalf("expected %d fields, got %d", len(expectedTypes), len(stmt.Fields))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, field := range stmt.Fields {
|
||||||
|
if field.Name != expectedNames[i] {
|
||||||
|
t.Errorf("field %d: expected name '%s', got '%s'", i, expectedNames[i], field.Name)
|
||||||
|
}
|
||||||
|
if field.TypeHint == nil {
|
||||||
|
t.Fatalf("field %d: expected type hint", i)
|
||||||
|
}
|
||||||
|
if field.TypeHint.Type != expectedTypes[i] {
|
||||||
|
t.Errorf("field %d: expected type '%s', got '%s'", i, expectedTypes[i], field.TypeHint.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMethodDefinition(t *testing.T) {
|
||||||
|
input := `struct Person {
|
||||||
|
name: string,
|
||||||
|
age: number
|
||||||
|
}
|
||||||
|
|
||||||
|
fn Person.getName(): string
|
||||||
|
return self.name
|
||||||
|
end
|
||||||
|
|
||||||
|
fn Person.setAge(newAge: number)
|
||||||
|
self.age = newAge
|
||||||
|
end`
|
||||||
|
|
||||||
|
l := parser.NewLexer(input)
|
||||||
|
p := parser.NewParser(l)
|
||||||
|
program := p.ParseProgram()
|
||||||
|
checkParserErrors(t, p)
|
||||||
|
|
||||||
|
if len(program.Statements) != 3 {
|
||||||
|
t.Fatalf("expected 3 statements, got %d", len(program.Statements))
|
||||||
|
}
|
||||||
|
|
||||||
|
// First statement: struct definition
|
||||||
|
structStmt, ok := program.Statements[0].(*parser.StructStatement)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected StructStatement, got %T", program.Statements[0])
|
||||||
|
}
|
||||||
|
if structStmt.Name != "Person" {
|
||||||
|
t.Errorf("expected struct name 'Person', got %s", structStmt.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second statement: getter method
|
||||||
|
method1, ok := program.Statements[1].(*parser.MethodDefinition)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected MethodDefinition, got %T", program.Statements[1])
|
||||||
|
}
|
||||||
|
if method1.StructName != "Person" {
|
||||||
|
t.Errorf("expected struct name 'Person', got %s", method1.StructName)
|
||||||
|
}
|
||||||
|
if method1.MethodName != "getName" {
|
||||||
|
t.Errorf("expected method name 'getName', got %s", method1.MethodName)
|
||||||
|
}
|
||||||
|
if method1.Function.ReturnType == nil {
|
||||||
|
t.Fatal("expected return type for getName method")
|
||||||
|
}
|
||||||
|
if method1.Function.ReturnType.Type != "string" {
|
||||||
|
t.Errorf("expected return type 'string', got %s", method1.Function.ReturnType.Type)
|
||||||
|
}
|
||||||
|
if len(method1.Function.Parameters) != 0 {
|
||||||
|
t.Errorf("expected 0 parameters, got %d", len(method1.Function.Parameters))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Third statement: setter method
|
||||||
|
method2, ok := program.Statements[2].(*parser.MethodDefinition)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected MethodDefinition, got %T", program.Statements[2])
|
||||||
|
}
|
||||||
|
if method2.StructName != "Person" {
|
||||||
|
t.Errorf("expected struct name 'Person', got %s", method2.StructName)
|
||||||
|
}
|
||||||
|
if method2.MethodName != "setAge" {
|
||||||
|
t.Errorf("expected method name 'setAge', got %s", method2.MethodName)
|
||||||
|
}
|
||||||
|
if method2.Function.ReturnType != nil {
|
||||||
|
t.Errorf("expected no return type for setAge method, got %s", method2.Function.ReturnType.Type)
|
||||||
|
}
|
||||||
|
if len(method2.Function.Parameters) != 1 {
|
||||||
|
t.Fatalf("expected 1 parameter, got %d", len(method2.Function.Parameters))
|
||||||
|
}
|
||||||
|
if method2.Function.Parameters[0].Name != "newAge" {
|
||||||
|
t.Errorf("expected parameter name 'newAge', got %s", method2.Function.Parameters[0].Name)
|
||||||
|
}
|
||||||
|
if method2.Function.Parameters[0].TypeHint == nil {
|
||||||
|
t.Fatal("expected type hint for newAge parameter")
|
||||||
|
}
|
||||||
|
if method2.Function.Parameters[0].TypeHint.Type != "number" {
|
||||||
|
t.Errorf("expected parameter type 'number', got %s", method2.Function.Parameters[0].TypeHint.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStructConstructor(t *testing.T) {
|
||||||
|
input := `struct Person {
|
||||||
|
name: string,
|
||||||
|
age: number
|
||||||
|
}
|
||||||
|
|
||||||
|
person = Person{name = "John", age = 30}
|
||||||
|
empty = Person{}`
|
||||||
|
|
||||||
|
l := parser.NewLexer(input)
|
||||||
|
p := parser.NewParser(l)
|
||||||
|
program := p.ParseProgram()
|
||||||
|
checkParserErrors(t, p)
|
||||||
|
|
||||||
|
if len(program.Statements) != 3 {
|
||||||
|
t.Fatalf("expected 3 statements, got %d", len(program.Statements))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second statement: constructor with fields
|
||||||
|
assign1, ok := program.Statements[1].(*parser.AssignStatement)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected AssignStatement, got %T", program.Statements[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
constructor1, ok := assign1.Value.(*parser.StructConstructorExpression)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected StructConstructorExpression, got %T", assign1.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
if constructor1.StructName != "Person" {
|
||||||
|
t.Errorf("expected struct name 'Person', got %s", constructor1.StructName)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(constructor1.Fields) != 2 {
|
||||||
|
t.Fatalf("expected 2 fields, got %d", len(constructor1.Fields))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check name field
|
||||||
|
nameKey, ok := constructor1.Fields[0].Key.(*parser.Identifier)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected Identifier for name key, got %T", constructor1.Fields[0].Key)
|
||||||
|
}
|
||||||
|
if nameKey.Value != "name" {
|
||||||
|
t.Errorf("expected key 'name', got %s", nameKey.Value)
|
||||||
|
}
|
||||||
|
testStringLiteral(t, constructor1.Fields[0].Value, "John")
|
||||||
|
|
||||||
|
// Check age field
|
||||||
|
ageKey, ok := constructor1.Fields[1].Key.(*parser.Identifier)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected Identifier for age key, got %T", constructor1.Fields[1].Key)
|
||||||
|
}
|
||||||
|
if ageKey.Value != "age" {
|
||||||
|
t.Errorf("expected key 'age', got %s", ageKey.Value)
|
||||||
|
}
|
||||||
|
testNumberLiteral(t, constructor1.Fields[1].Value, 30)
|
||||||
|
|
||||||
|
// Third statement: empty constructor
|
||||||
|
assign2, ok := program.Statements[2].(*parser.AssignStatement)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected AssignStatement, got %T", program.Statements[2])
|
||||||
|
}
|
||||||
|
|
||||||
|
constructor2, ok := assign2.Value.(*parser.StructConstructorExpression)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected StructConstructorExpression, got %T", assign2.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
if constructor2.StructName != "Person" {
|
||||||
|
t.Errorf("expected struct name 'Person', got %s", constructor2.StructName)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(constructor2.Fields) != 0 {
|
||||||
|
t.Errorf("expected 0 fields, got %d", len(constructor2.Fields))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNestedStructTypes(t *testing.T) {
|
||||||
|
input := `struct Address {
|
||||||
|
street: string,
|
||||||
|
city: string
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Person {
|
||||||
|
name: string,
|
||||||
|
address: Address
|
||||||
|
}
|
||||||
|
|
||||||
|
person = Person{
|
||||||
|
name = "John",
|
||||||
|
address = Address{street = "Main St", city = "NYC"}
|
||||||
|
}`
|
||||||
|
|
||||||
|
l := parser.NewLexer(input)
|
||||||
|
p := parser.NewParser(l)
|
||||||
|
program := p.ParseProgram()
|
||||||
|
checkParserErrors(t, p)
|
||||||
|
|
||||||
|
if len(program.Statements) != 3 {
|
||||||
|
t.Fatalf("expected 3 statements, got %d", len(program.Statements))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check Person struct has Address field type
|
||||||
|
personStruct, ok := program.Statements[1].(*parser.StructStatement)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected StructStatement, got %T", program.Statements[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
addressField := personStruct.Fields[1]
|
||||||
|
if addressField.Name != "address" {
|
||||||
|
t.Errorf("expected field name 'address', got %s", addressField.Name)
|
||||||
|
}
|
||||||
|
if addressField.TypeHint.Type != "Address" {
|
||||||
|
t.Errorf("expected field type 'Address', got %s", addressField.TypeHint.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check nested constructor
|
||||||
|
assign, ok := program.Statements[2].(*parser.AssignStatement)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected AssignStatement, got %T", program.Statements[2])
|
||||||
|
}
|
||||||
|
|
||||||
|
personConstructor, ok := assign.Value.(*parser.StructConstructorExpression)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected StructConstructorExpression, got %T", assign.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check the nested Address constructor
|
||||||
|
addressConstructor, ok := personConstructor.Fields[1].Value.(*parser.StructConstructorExpression)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected nested StructConstructorExpression, got %T", personConstructor.Fields[1].Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
if addressConstructor.StructName != "Address" {
|
||||||
|
t.Errorf("expected nested struct name 'Address', got %s", addressConstructor.StructName)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(addressConstructor.Fields) != 2 {
|
||||||
|
t.Errorf("expected 2 fields in nested constructor, got %d", len(addressConstructor.Fields))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStructIntegrationWithProgram(t *testing.T) {
|
||||||
|
input := `struct Point {
|
||||||
|
x: number,
|
||||||
|
y: number
|
||||||
|
}
|
||||||
|
|
||||||
|
fn Point.distance(other: Point): number
|
||||||
|
dx = self.x - other.x
|
||||||
|
dy = self.y - other.y
|
||||||
|
return (dx * dx + dy * dy)
|
||||||
|
end
|
||||||
|
|
||||||
|
p1 = Point{x = 0, y = 0}
|
||||||
|
p2 = Point{x = 3, y = 4}
|
||||||
|
|
||||||
|
if p1.distance(p2) then
|
||||||
|
echo "Distance calculated"
|
||||||
|
end
|
||||||
|
|
||||||
|
for i = 1, 10 do
|
||||||
|
point = Point{x = i, y = i * 2}
|
||||||
|
echo point.x
|
||||||
|
end`
|
||||||
|
|
||||||
|
l := parser.NewLexer(input)
|
||||||
|
p := parser.NewParser(l)
|
||||||
|
program := p.ParseProgram()
|
||||||
|
checkParserErrors(t, p)
|
||||||
|
|
||||||
|
if len(program.Statements) != 6 {
|
||||||
|
t.Fatalf("expected 6 statements, got %d", len(program.Statements))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify struct definition
|
||||||
|
structStmt, ok := program.Statements[0].(*parser.StructStatement)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected StructStatement, got %T", program.Statements[0])
|
||||||
|
}
|
||||||
|
if structStmt.Name != "Point" {
|
||||||
|
t.Errorf("expected struct name 'Point', got %s", structStmt.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify method definition
|
||||||
|
methodStmt, ok := program.Statements[1].(*parser.MethodDefinition)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected MethodDefinition, got %T", program.Statements[1])
|
||||||
|
}
|
||||||
|
if methodStmt.StructName != "Point" {
|
||||||
|
t.Errorf("expected struct name 'Point', got %s", methodStmt.StructName)
|
||||||
|
}
|
||||||
|
if methodStmt.MethodName != "distance" {
|
||||||
|
t.Errorf("expected method name 'distance', got %s", methodStmt.MethodName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify constructors
|
||||||
|
for i := 2; i <= 3; i++ {
|
||||||
|
assign, ok := program.Statements[i].(*parser.AssignStatement)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("statement %d: expected AssignStatement, got %T", i, program.Statements[i])
|
||||||
|
}
|
||||||
|
constructor, ok := assign.Value.(*parser.StructConstructorExpression)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("statement %d: expected StructConstructorExpression, got %T", i, assign.Value)
|
||||||
|
}
|
||||||
|
if constructor.StructName != "Point" {
|
||||||
|
t.Errorf("statement %d: expected struct name 'Point', got %s", i, constructor.StructName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify if statement with method call
|
||||||
|
ifStmt, ok := program.Statements[4].(*parser.IfStatement)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected IfStatement, got %T", program.Statements[4])
|
||||||
|
}
|
||||||
|
callExpr, ok := ifStmt.Condition.(*parser.CallExpression)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected CallExpression in if condition, got %T", ifStmt.Condition)
|
||||||
|
}
|
||||||
|
dotExpr, ok := callExpr.Function.(*parser.DotExpression)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected DotExpression for method call, got %T", callExpr.Function)
|
||||||
|
}
|
||||||
|
if dotExpr.Key != "distance" {
|
||||||
|
t.Errorf("expected method name 'distance', got %s", dotExpr.Key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify for loop with struct creation
|
||||||
|
forStmt, ok := program.Statements[5].(*parser.ForStatement)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected ForStatement, got %T", program.Statements[5])
|
||||||
|
}
|
||||||
|
if len(forStmt.Body) != 2 {
|
||||||
|
t.Errorf("expected 2 statements in for body, got %d", len(forStmt.Body))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check struct constructor in loop
|
||||||
|
loopAssign, ok := forStmt.Body[0].(*parser.AssignStatement)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected AssignStatement in loop, got %T", forStmt.Body[0])
|
||||||
|
}
|
||||||
|
loopConstructor, ok := loopAssign.Value.(*parser.StructConstructorExpression)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected StructConstructorExpression in loop, got %T", loopAssign.Value)
|
||||||
|
}
|
||||||
|
if loopConstructor.StructName != "Point" {
|
||||||
|
t.Errorf("expected struct name 'Point' in loop, got %s", loopConstructor.StructName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStructErrorCases(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expectedErrorSubstring string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "missing field type",
|
||||||
|
input: `struct Person {
|
||||||
|
name
|
||||||
|
}`,
|
||||||
|
expectedErrorSubstring: "struct fields require type annotation",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing struct name",
|
||||||
|
input: `struct {
|
||||||
|
name: string
|
||||||
|
}`,
|
||||||
|
expectedErrorSubstring: "expected struct name",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing opening brace",
|
||||||
|
input: `struct Person
|
||||||
|
name: string
|
||||||
|
}`,
|
||||||
|
expectedErrorSubstring: "expected '{' after struct name",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing closing brace",
|
||||||
|
input: `struct Person {
|
||||||
|
name: string`,
|
||||||
|
expectedErrorSubstring: "expected next token to be }, got end of file instead",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid field type",
|
||||||
|
input: `struct Person {
|
||||||
|
name: invalidtype
|
||||||
|
}`,
|
||||||
|
expectedErrorSubstring: "invalid type name 'invalidtype'",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
l := parser.NewLexer(tt.input)
|
||||||
|
p := parser.NewParser(l)
|
||||||
|
_ = p.ParseProgram()
|
||||||
|
|
||||||
|
if !p.HasErrors() {
|
||||||
|
t.Fatalf("expected parser errors, but got none")
|
||||||
|
}
|
||||||
|
|
||||||
|
errors := p.ErrorStrings()
|
||||||
|
found := false
|
||||||
|
for _, err := range errors {
|
||||||
|
if containsSubstring(err, tt.expectedErrorSubstring) {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !found {
|
||||||
|
t.Errorf("expected error containing '%s', got errors: %v", tt.expectedErrorSubstring, errors)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSingleLineStruct(t *testing.T) {
|
||||||
|
input := `struct Person { name: string, age: number }`
|
||||||
|
|
||||||
|
l := parser.NewLexer(input)
|
||||||
|
p := parser.NewParser(l)
|
||||||
|
program := p.ParseProgram()
|
||||||
|
checkParserErrors(t, p)
|
||||||
|
|
||||||
|
if len(program.Statements) != 1 {
|
||||||
|
t.Fatalf("expected 1 statement, got %d", len(program.Statements))
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt, ok := program.Statements[0].(*parser.StructStatement)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected StructStatement, got %T", program.Statements[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
if stmt.Name != "Person" {
|
||||||
|
t.Errorf("expected struct name 'Person', got %s", stmt.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(stmt.Fields) != 2 {
|
||||||
|
t.Fatalf("expected 2 fields, got %d", len(stmt.Fields))
|
||||||
|
}
|
||||||
|
|
||||||
|
if stmt.Fields[0].Name != "name" || stmt.Fields[0].TypeHint.Type != "string" {
|
||||||
|
t.Errorf("expected first field 'name: string', got '%s: %s'",
|
||||||
|
stmt.Fields[0].Name, stmt.Fields[0].TypeHint.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
if stmt.Fields[1].Name != "age" || stmt.Fields[1].TypeHint.Type != "number" {
|
||||||
|
t.Errorf("expected second field 'age: number', got '%s: %s'",
|
||||||
|
stmt.Fields[1].Name, stmt.Fields[1].TypeHint.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStructString(t *testing.T) {
|
||||||
|
input := `struct Person {
|
||||||
|
name: string,
|
||||||
|
age: number
|
||||||
|
}`
|
||||||
|
|
||||||
|
l := parser.NewLexer(input)
|
||||||
|
p := parser.NewParser(l)
|
||||||
|
program := p.ParseProgram()
|
||||||
|
checkParserErrors(t, p)
|
||||||
|
|
||||||
|
stmt := program.Statements[0].(*parser.StructStatement)
|
||||||
|
str := stmt.String()
|
||||||
|
|
||||||
|
expected := "struct Person {\n\tname: string,\n\tage: number\n}"
|
||||||
|
if str != expected {
|
||||||
|
t.Errorf("expected string representation:\n%s\ngot:\n%s", expected, str)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMethodString(t *testing.T) {
|
||||||
|
input := `struct Person {
|
||||||
|
name: string
|
||||||
|
}
|
||||||
|
|
||||||
|
fn Person.getName(): string
|
||||||
|
return self.name
|
||||||
|
end`
|
||||||
|
|
||||||
|
l := parser.NewLexer(input)
|
||||||
|
p := parser.NewParser(l)
|
||||||
|
program := p.ParseProgram()
|
||||||
|
checkParserErrors(t, p)
|
||||||
|
|
||||||
|
method := program.Statements[1].(*parser.MethodDefinition)
|
||||||
|
str := method.String()
|
||||||
|
|
||||||
|
if !containsSubstring(str, "fn Person.getName") {
|
||||||
|
t.Errorf("expected method string to contain 'fn Person.getName', got: %s", str)
|
||||||
|
}
|
||||||
|
if !containsSubstring(str, ": string") {
|
||||||
|
t.Errorf("expected method string to contain return type, got: %s", str)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConstructorString(t *testing.T) {
|
||||||
|
input := `struct Person {
|
||||||
|
name: string,
|
||||||
|
age: number
|
||||||
|
}
|
||||||
|
|
||||||
|
person = Person{name = "John", age = 30}`
|
||||||
|
|
||||||
|
l := parser.NewLexer(input)
|
||||||
|
p := parser.NewParser(l)
|
||||||
|
program := p.ParseProgram()
|
||||||
|
checkParserErrors(t, p)
|
||||||
|
|
||||||
|
assign := program.Statements[1].(*parser.AssignStatement)
|
||||||
|
constructor := assign.Value.(*parser.StructConstructorExpression)
|
||||||
|
str := constructor.String()
|
||||||
|
|
||||||
|
expected := `Person{name = "John", age = 30.00}`
|
||||||
|
if str != expected {
|
||||||
|
t.Errorf("expected constructor string:\n%s\ngot:\n%s", expected, str)
|
||||||
|
}
|
||||||
|
}
|
@ -59,6 +59,7 @@ const (
|
|||||||
EXIT
|
EXIT
|
||||||
FN
|
FN
|
||||||
RETURN
|
RETURN
|
||||||
|
STRUCT
|
||||||
|
|
||||||
// Special
|
// Special
|
||||||
EOF
|
EOF
|
||||||
@ -107,6 +108,7 @@ var precedences = map[TokenType]Precedence{
|
|||||||
DOT: MEMBER,
|
DOT: MEMBER,
|
||||||
LBRACKET: MEMBER,
|
LBRACKET: MEMBER,
|
||||||
LPAREN: CALL,
|
LPAREN: CALL,
|
||||||
|
LBRACE: CALL,
|
||||||
}
|
}
|
||||||
|
|
||||||
// lookupIdent checks if an identifier is a keyword
|
// lookupIdent checks if an identifier is a keyword
|
||||||
@ -132,6 +134,7 @@ func lookupIdent(ident string) TokenType {
|
|||||||
"exit": EXIT,
|
"exit": EXIT,
|
||||||
"fn": FN,
|
"fn": FN,
|
||||||
"return": RETURN,
|
"return": RETURN,
|
||||||
|
"struct": STRUCT,
|
||||||
}
|
}
|
||||||
|
|
||||||
if tok, ok := keywords[ident]; ok {
|
if tok, ok := keywords[ident]; ok {
|
||||||
|
197
parser/types.go
197
parser/types.go
@ -76,6 +76,9 @@ type TypeInferrer struct {
|
|||||||
nilType *TypeInfo
|
nilType *TypeInfo
|
||||||
tableType *TypeInfo
|
tableType *TypeInfo
|
||||||
anyType *TypeInfo
|
anyType *TypeInfo
|
||||||
|
|
||||||
|
// Struct definitions
|
||||||
|
structs map[string]*StructStatement
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTypeInferrer creates a new type inference engine
|
// NewTypeInferrer creates a new type inference engine
|
||||||
@ -86,6 +89,7 @@ func NewTypeInferrer() *TypeInferrer {
|
|||||||
currentScope: globalScope,
|
currentScope: globalScope,
|
||||||
globalScope: globalScope,
|
globalScope: globalScope,
|
||||||
errors: []TypeError{},
|
errors: []TypeError{},
|
||||||
|
structs: make(map[string]*StructStatement),
|
||||||
|
|
||||||
// Pre-allocate common types to reduce allocations
|
// Pre-allocate common types to reduce allocations
|
||||||
numberType: &TypeInfo{Type: TypeNumber, Inferred: true},
|
numberType: &TypeInfo{Type: TypeNumber, Inferred: true},
|
||||||
@ -101,6 +105,14 @@ func NewTypeInferrer() *TypeInferrer {
|
|||||||
|
|
||||||
// InferTypes performs type inference on the entire program
|
// InferTypes performs type inference on the entire program
|
||||||
func (ti *TypeInferrer) InferTypes(program *Program) []TypeError {
|
func (ti *TypeInferrer) InferTypes(program *Program) []TypeError {
|
||||||
|
// First pass: collect struct definitions
|
||||||
|
for _, stmt := range program.Statements {
|
||||||
|
if structStmt, ok := stmt.(*StructStatement); ok {
|
||||||
|
ti.structs[structStmt.Name] = structStmt
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second pass: infer types
|
||||||
for _, stmt := range program.Statements {
|
for _, stmt := range program.Statements {
|
||||||
ti.inferStatement(stmt)
|
ti.inferStatement(stmt)
|
||||||
}
|
}
|
||||||
@ -129,9 +141,27 @@ func (ti *TypeInferrer) addError(message string, node Node) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getStructType returns TypeInfo for a struct
|
||||||
|
func (ti *TypeInferrer) getStructType(name string) *TypeInfo {
|
||||||
|
if _, exists := ti.structs[name]; exists {
|
||||||
|
return &TypeInfo{Type: name, Inferred: true}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isStructType checks if a type is a struct type
|
||||||
|
func (ti *TypeInferrer) isStructType(t *TypeInfo) bool {
|
||||||
|
_, exists := ti.structs[t.Type]
|
||||||
|
return exists
|
||||||
|
}
|
||||||
|
|
||||||
// inferStatement infers types for statements
|
// inferStatement infers types for statements
|
||||||
func (ti *TypeInferrer) inferStatement(stmt Statement) {
|
func (ti *TypeInferrer) inferStatement(stmt Statement) {
|
||||||
switch s := stmt.(type) {
|
switch s := stmt.(type) {
|
||||||
|
case *StructStatement:
|
||||||
|
ti.inferStructStatement(s)
|
||||||
|
case *MethodDefinition:
|
||||||
|
ti.inferMethodDefinition(s)
|
||||||
case *AssignStatement:
|
case *AssignStatement:
|
||||||
ti.inferAssignStatement(s)
|
ti.inferAssignStatement(s)
|
||||||
case *EchoStatement:
|
case *EchoStatement:
|
||||||
@ -152,9 +182,63 @@ func (ti *TypeInferrer) inferStatement(stmt Statement) {
|
|||||||
if s.Value != nil {
|
if s.Value != nil {
|
||||||
ti.inferExpression(s.Value)
|
ti.inferExpression(s.Value)
|
||||||
}
|
}
|
||||||
|
case *ExpressionStatement:
|
||||||
|
ti.inferExpression(s.Expression)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// inferStructStatement handles struct definitions
|
||||||
|
func (ti *TypeInferrer) inferStructStatement(stmt *StructStatement) {
|
||||||
|
// Validate field types
|
||||||
|
for _, field := range stmt.Fields {
|
||||||
|
if field.TypeHint != nil {
|
||||||
|
if !ValidTypeName(field.TypeHint.Type) && !ti.isStructType(field.TypeHint) {
|
||||||
|
ti.addError(fmt.Sprintf("invalid field type '%s' in struct '%s'",
|
||||||
|
field.TypeHint.Type, stmt.Name), stmt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// inferMethodDefinition handles method definitions
|
||||||
|
func (ti *TypeInferrer) inferMethodDefinition(stmt *MethodDefinition) {
|
||||||
|
// Check if struct exists
|
||||||
|
if _, exists := ti.structs[stmt.StructName]; !exists {
|
||||||
|
ti.addError(fmt.Sprintf("method defined on undefined struct '%s'", stmt.StructName), stmt)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Infer the function body
|
||||||
|
ti.enterScope()
|
||||||
|
|
||||||
|
// Add self parameter implicitly
|
||||||
|
ti.currentScope.Define(&Symbol{
|
||||||
|
Name: "self",
|
||||||
|
Type: ti.getStructType(stmt.StructName),
|
||||||
|
Declared: true,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Add explicit parameters
|
||||||
|
for _, param := range stmt.Function.Parameters {
|
||||||
|
paramType := ti.anyType
|
||||||
|
if param.TypeHint != nil {
|
||||||
|
paramType = param.TypeHint
|
||||||
|
}
|
||||||
|
ti.currentScope.Define(&Symbol{
|
||||||
|
Name: param.Name,
|
||||||
|
Type: paramType,
|
||||||
|
Declared: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Infer function body
|
||||||
|
for _, bodyStmt := range stmt.Function.Body {
|
||||||
|
ti.inferStatement(bodyStmt)
|
||||||
|
}
|
||||||
|
|
||||||
|
ti.exitScope()
|
||||||
|
}
|
||||||
|
|
||||||
// inferAssignStatement handles variable assignments with type checking
|
// inferAssignStatement handles variable assignments with type checking
|
||||||
func (ti *TypeInferrer) inferAssignStatement(stmt *AssignStatement) {
|
func (ti *TypeInferrer) inferAssignStatement(stmt *AssignStatement) {
|
||||||
// Infer the type of the value expression
|
// Infer the type of the value expression
|
||||||
@ -288,9 +372,9 @@ func (ti *TypeInferrer) inferForStatement(stmt *ForStatement) {
|
|||||||
func (ti *TypeInferrer) inferForInStatement(stmt *ForInStatement) {
|
func (ti *TypeInferrer) inferForInStatement(stmt *ForInStatement) {
|
||||||
iterableType := ti.inferExpression(stmt.Iterable)
|
iterableType := ti.inferExpression(stmt.Iterable)
|
||||||
|
|
||||||
// For now, assume iterable is a table
|
// For now, assume iterable is a table or struct
|
||||||
if !ti.isTableType(iterableType) {
|
if !ti.isTableType(iterableType) && !ti.isStructType(iterableType) {
|
||||||
ti.addError("for-in requires an iterable (table)", stmt.Iterable)
|
ti.addError("for-in requires an iterable (table or struct)", stmt.Iterable)
|
||||||
}
|
}
|
||||||
|
|
||||||
ti.enterScope()
|
ti.enterScope()
|
||||||
@ -341,6 +425,8 @@ func (ti *TypeInferrer) inferExpression(expr Expression) *TypeInfo {
|
|||||||
return ti.nilType
|
return ti.nilType
|
||||||
case *TableLiteral:
|
case *TableLiteral:
|
||||||
return ti.inferTableLiteral(e)
|
return ti.inferTableLiteral(e)
|
||||||
|
case *StructConstructorExpression:
|
||||||
|
return ti.inferStructConstructor(e)
|
||||||
case *FunctionLiteral:
|
case *FunctionLiteral:
|
||||||
return ti.inferFunctionLiteral(e)
|
return ti.inferFunctionLiteral(e)
|
||||||
case *CallExpression:
|
case *CallExpression:
|
||||||
@ -353,12 +439,85 @@ func (ti *TypeInferrer) inferExpression(expr Expression) *TypeInfo {
|
|||||||
return ti.inferIndexExpression(e)
|
return ti.inferIndexExpression(e)
|
||||||
case *DotExpression:
|
case *DotExpression:
|
||||||
return ti.inferDotExpression(e)
|
return ti.inferDotExpression(e)
|
||||||
|
case *AssignExpression:
|
||||||
|
return ti.inferAssignExpression(e)
|
||||||
default:
|
default:
|
||||||
ti.addError("unknown expression type", expr)
|
ti.addError("unknown expression type", expr)
|
||||||
return ti.anyType
|
return ti.anyType
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// inferStructConstructor handles struct constructor expressions
|
||||||
|
func (ti *TypeInferrer) inferStructConstructor(expr *StructConstructorExpression) *TypeInfo {
|
||||||
|
structDef, exists := ti.structs[expr.StructName]
|
||||||
|
if !exists {
|
||||||
|
ti.addError(fmt.Sprintf("undefined struct '%s'", expr.StructName), expr)
|
||||||
|
return ti.anyType
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate field assignments
|
||||||
|
for _, pair := range expr.Fields {
|
||||||
|
if pair.Key != nil {
|
||||||
|
fieldName := ""
|
||||||
|
if ident, ok := pair.Key.(*Identifier); ok {
|
||||||
|
fieldName = ident.Value
|
||||||
|
} else if str, ok := pair.Key.(*StringLiteral); ok {
|
||||||
|
fieldName = str.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if field exists in struct
|
||||||
|
fieldExists := false
|
||||||
|
var fieldType *TypeInfo
|
||||||
|
for _, field := range structDef.Fields {
|
||||||
|
if field.Name == fieldName {
|
||||||
|
fieldExists = true
|
||||||
|
fieldType = field.TypeHint
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !fieldExists {
|
||||||
|
ti.addError(fmt.Sprintf("struct '%s' has no field '%s'", expr.StructName, fieldName), expr)
|
||||||
|
} else {
|
||||||
|
// Check type compatibility
|
||||||
|
valueType := ti.inferExpression(pair.Value)
|
||||||
|
if !ti.isTypeCompatible(valueType, fieldType) {
|
||||||
|
ti.addError(fmt.Sprintf("cannot assign %s to field '%s' of type %s",
|
||||||
|
valueType.Type, fieldName, fieldType.Type), expr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Array-style assignment not valid for structs
|
||||||
|
ti.addError("struct constructors require named field assignments", expr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
structType := ti.getStructType(expr.StructName)
|
||||||
|
expr.SetType(structType)
|
||||||
|
return structType
|
||||||
|
}
|
||||||
|
|
||||||
|
// inferAssignExpression handles assignment expressions
|
||||||
|
func (ti *TypeInferrer) inferAssignExpression(expr *AssignExpression) *TypeInfo {
|
||||||
|
valueType := ti.inferExpression(expr.Value)
|
||||||
|
|
||||||
|
if ident, ok := expr.Name.(*Identifier); ok {
|
||||||
|
if expr.IsDeclaration {
|
||||||
|
ti.currentScope.Define(&Symbol{
|
||||||
|
Name: ident.Value,
|
||||||
|
Type: valueType,
|
||||||
|
Declared: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
ident.SetType(valueType)
|
||||||
|
} else {
|
||||||
|
ti.inferExpression(expr.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
expr.SetType(valueType)
|
||||||
|
return valueType
|
||||||
|
}
|
||||||
|
|
||||||
// inferIdentifier looks up identifier type in symbol table
|
// inferIdentifier looks up identifier type in symbol table
|
||||||
func (ti *TypeInferrer) inferIdentifier(ident *Identifier) *TypeInfo {
|
func (ti *TypeInferrer) inferIdentifier(ident *Identifier) *TypeInfo {
|
||||||
symbol := ti.currentScope.Lookup(ident.Value)
|
symbol := ti.currentScope.Lookup(ident.Value)
|
||||||
@ -497,17 +656,43 @@ func (ti *TypeInferrer) inferInfixExpression(infix *InfixExpression) *TypeInfo {
|
|||||||
|
|
||||||
// inferIndexExpression infers table[index] type
|
// inferIndexExpression infers table[index] type
|
||||||
func (ti *TypeInferrer) inferIndexExpression(index *IndexExpression) *TypeInfo {
|
func (ti *TypeInferrer) inferIndexExpression(index *IndexExpression) *TypeInfo {
|
||||||
ti.inferExpression(index.Left)
|
leftType := ti.inferExpression(index.Left)
|
||||||
ti.inferExpression(index.Index)
|
ti.inferExpression(index.Index)
|
||||||
|
|
||||||
// For now, assume table access returns any
|
// If indexing a struct, try to infer field type
|
||||||
|
if ti.isStructType(leftType) {
|
||||||
|
if strLit, ok := index.Index.(*StringLiteral); ok {
|
||||||
|
if structDef, exists := ti.structs[leftType.Type]; exists {
|
||||||
|
for _, field := range structDef.Fields {
|
||||||
|
if field.Name == strLit.Value {
|
||||||
|
index.SetType(field.TypeHint)
|
||||||
|
return field.TypeHint
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// For now, assume table/struct access returns any
|
||||||
index.SetType(ti.anyType)
|
index.SetType(ti.anyType)
|
||||||
return ti.anyType
|
return ti.anyType
|
||||||
}
|
}
|
||||||
|
|
||||||
// inferDotExpression infers table.key type
|
// inferDotExpression infers table.key type
|
||||||
func (ti *TypeInferrer) inferDotExpression(dot *DotExpression) *TypeInfo {
|
func (ti *TypeInferrer) inferDotExpression(dot *DotExpression) *TypeInfo {
|
||||||
ti.inferExpression(dot.Left)
|
leftType := ti.inferExpression(dot.Left)
|
||||||
|
|
||||||
|
// If accessing a struct field, try to infer field type
|
||||||
|
if ti.isStructType(leftType) {
|
||||||
|
if structDef, exists := ti.structs[leftType.Type]; exists {
|
||||||
|
for _, field := range structDef.Fields {
|
||||||
|
if field.Name == dot.Key {
|
||||||
|
dot.SetType(field.TypeHint)
|
||||||
|
return field.TypeHint
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// For now, assume member access returns any
|
// For now, assume member access returns any
|
||||||
dot.SetType(ti.anyType)
|
dot.SetType(ti.anyType)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user