This commit is contained in:
Sky Johnson 2025-06-11 16:28:55 -05:00
parent fc439b6d5a
commit 5ae2a6ef23
5 changed files with 1199 additions and 9 deletions

View File

@ -4,7 +4,7 @@ import "fmt"
// TypeInfo represents type information for expressions
type TypeInfo struct {
Type string // "number", "string", "bool", "table", "function", "nil", "any"
Type string // "number", "string", "bool", "table", "function", "nil", "any", struct name
Inferred bool // true if type was inferred, false if explicitly declared
}
@ -41,6 +41,67 @@ func (p *Program) String() string {
return result
}
// StructField represents a field in a struct definition
type StructField struct {
Name string
TypeHint *TypeInfo
}
func (sf *StructField) String() string {
if sf.TypeHint != nil {
return fmt.Sprintf("%s: %s", sf.Name, sf.TypeHint.Type)
}
return sf.Name
}
// StructStatement represents struct definitions
type StructStatement struct {
Name string
Fields []StructField
}
func (ss *StructStatement) statementNode() {}
func (ss *StructStatement) String() string {
var fields string
for i, field := range ss.Fields {
if i > 0 {
fields += ",\n\t"
}
fields += field.String()
}
return fmt.Sprintf("struct %s {\n\t%s\n}", ss.Name, fields)
}
// MethodDefinition represents method definitions on structs
type MethodDefinition struct {
StructName string
MethodName string
Function *FunctionLiteral
}
func (md *MethodDefinition) statementNode() {}
func (md *MethodDefinition) String() string {
return fmt.Sprintf("fn %s.%s%s", md.StructName, md.MethodName, md.Function.String()[2:]) // skip "fn" from function string
}
// StructConstructorExpression represents struct constructor calls like my_type{...}
type StructConstructorExpression struct {
StructName string
Fields []TablePair // reuse TablePair for field assignments
typeInfo *TypeInfo
}
func (sce *StructConstructorExpression) expressionNode() {}
func (sce *StructConstructorExpression) String() string {
var pairs []string
for _, pair := range sce.Fields {
pairs = append(pairs, pair.String())
}
return fmt.Sprintf("%s{%s}", sce.StructName, joinStrings(pairs, ", "))
}
func (sce *StructConstructorExpression) GetType() *TypeInfo { return sce.typeInfo }
func (sce *StructConstructorExpression) SetType(t *TypeInfo) { sce.typeInfo = t }
// AssignStatement represents variable assignment with optional type hint
type AssignStatement struct {
Name Expression // Changed from *Identifier to Expression for member access

View File

@ -34,6 +34,9 @@ type Parser struct {
// Scope tracking
scopes []map[string]bool // stack of scopes, each tracking declared variables
scopeTypes []string // track what type each scope is: "global", "function", "loop"
// Struct tracking
structs map[string]*StructStatement // track defined structs
}
// NewParser creates a new parser instance
@ -43,6 +46,7 @@ func NewParser(lexer *Lexer) *Parser {
errors: []ParseError{},
scopes: []map[string]bool{make(map[string]bool)}, // start with global scope
scopeTypes: []string{"global"}, // start with global scope type
structs: make(map[string]*StructStatement), // track struct definitions
}
p.prefixParseFns = make(map[TokenType]func() Expression)
@ -74,6 +78,7 @@ func NewParser(lexer *Lexer) *Parser {
p.registerInfix(DOT, p.parseDotExpression)
p.registerInfix(LBRACKET, p.parseIndexExpression)
p.registerInfix(LPAREN, p.parseCallExpression)
p.registerInfix(LBRACE, p.parseStructConstructor) // struct constructor
// Read two tokens, so curToken and peekToken are both set
p.nextToken()
@ -157,7 +162,7 @@ func (p *Parser) parseTypeHint() *TypeInfo {
}
typeName := p.curToken.Literal
if !ValidTypeName(typeName) {
if !ValidTypeName(typeName) && !p.isStructDefined(typeName) {
p.addError(fmt.Sprintf("invalid type name '%s'", typeName))
return nil
}
@ -165,6 +170,12 @@ func (p *Parser) parseTypeHint() *TypeInfo {
return &TypeInfo{Type: typeName, Inferred: false}
}
// isStructDefined checks if a struct name is defined
func (p *Parser) isStructDefined(name string) bool {
_, exists := p.structs[name]
return exists
}
// registerPrefix registers a prefix parse function
func (p *Parser) registerPrefix(tokenType TokenType, fn func() Expression) {
p.prefixParseFns[tokenType] = fn
@ -200,6 +211,10 @@ func (p *Parser) ParseProgram() *Program {
// parseStatement parses a statement
func (p *Parser) parseStatement() Statement {
switch p.curToken.Type {
case STRUCT:
return p.parseStructStatement()
case FN:
return p.parseFunctionStatement()
case IDENT:
return p.parseIdentifierStatement()
case IF:
@ -230,6 +245,147 @@ func (p *Parser) parseStatement() Statement {
}
}
// parseStructStatement parses struct definitions
func (p *Parser) parseStructStatement() *StructStatement {
stmt := &StructStatement{}
if !p.expectPeek(IDENT) {
p.addError("expected struct name")
return nil
}
stmt.Name = p.curToken.Literal
if !p.expectPeek(LBRACE) {
p.addError("expected '{' after struct name")
return nil
}
stmt.Fields = []StructField{}
if p.peekTokenIs(RBRACE) {
p.nextToken()
p.structs[stmt.Name] = stmt
return stmt
}
p.nextToken()
for {
if p.curTokenIs(EOF) {
p.addError("unexpected end of input, expected }")
return nil
}
if !p.curTokenIs(IDENT) {
p.addError("expected field name")
return nil
}
field := StructField{Name: p.curToken.Literal}
// Parse optional type hint
field.TypeHint = p.parseTypeHint()
if field.TypeHint == nil {
p.addError("struct fields require type annotation")
return nil
}
stmt.Fields = append(stmt.Fields, field)
if !p.peekTokenIs(COMMA) {
break
}
p.nextToken() // consume comma
p.nextToken() // move to next field
if p.curTokenIs(RBRACE) {
break
}
if p.curTokenIs(EOF) {
p.addError("expected next token to be }")
return nil
}
}
if !p.expectPeek(RBRACE) {
return nil
}
p.structs[stmt.Name] = stmt
return stmt
}
// parseFunctionStatement handles both regular functions and methods
func (p *Parser) parseFunctionStatement() Statement {
if !p.expectPeek(IDENT) {
p.addError("expected function name")
return nil
}
funcName := p.curToken.Literal
// Check if this is a method definition (struct.method)
if p.peekTokenIs(DOT) {
p.nextToken() // consume '.'
if !p.expectPeek(IDENT) {
p.addError("expected method name after '.'")
return nil
}
methodName := p.curToken.Literal
if !p.expectPeek(LPAREN) {
p.addError("expected '(' after method name")
return nil
}
// Parse the function literal starting from parameters
funcLit := &FunctionLiteral{}
funcLit.Parameters, funcLit.Variadic = p.parseFunctionParameters()
if !p.expectPeek(RPAREN) {
p.addError("expected ')' after function parameters")
return nil
}
// Check for return type hint
funcLit.ReturnType = p.parseTypeHint()
p.nextToken()
p.enterFunctionScope()
for _, param := range funcLit.Parameters {
p.declareVariable(param.Name)
}
funcLit.Body = p.parseBlockStatements(END)
p.exitFunctionScope()
if !p.curTokenIs(END) {
p.addError("expected 'end' to close function")
return nil
}
return &MethodDefinition{
StructName: funcName,
MethodName: methodName,
Function: funcLit,
}
}
// Regular function - this should be handled as expression statement
// Reset to handle as function literal
funcLit := p.parseFunctionLiteral()
if funcLit == nil {
return nil
}
return &ExpressionStatement{Expression: funcLit}
}
// parseIdentifierStatement handles both assignments and expression statements starting with identifiers
func (p *Parser) parseIdentifierStatement() Statement {
// Parse the left-hand side expression first
@ -948,6 +1104,157 @@ func (p *Parser) parseTableLiteral() Expression {
return table
}
// parseStructConstructor handles struct constructor calls like my_type{...}
func (p *Parser) parseStructConstructor(left Expression) Expression {
// left should be an identifier representing the struct name
ident, ok := left.(*Identifier)
if !ok {
// Not an identifier, fall back to table literal parsing
return p.parseTableLiteralFromBrace()
}
structName := ident.Value
// Always try to parse as struct constructor if we have an identifier
// Type checking will catch undefined structs later
constructor := &StructConstructorExpression{
StructName: structName,
Fields: []TablePair{},
}
if p.peekTokenIs(RBRACE) {
p.nextToken()
return constructor
}
p.nextToken()
for {
if p.curTokenIs(EOF) {
p.addError("unexpected end of input, expected }")
return nil
}
pair := TablePair{}
if (p.curTokenIs(IDENT) || p.curTokenIs(STRING)) && p.peekTokenIs(ASSIGN) {
if p.curTokenIs(IDENT) {
pair.Key = &Identifier{Value: p.curToken.Literal}
} else {
pair.Key = &StringLiteral{Value: p.curToken.Literal}
}
p.nextToken()
p.nextToken()
if p.curTokenIs(EOF) {
p.addError("expected expression after assignment operator")
return nil
}
pair.Value = p.ParseExpression(LOWEST)
} else {
pair.Value = p.ParseExpression(LOWEST)
}
if pair.Value == nil {
return nil
}
constructor.Fields = append(constructor.Fields, pair)
if !p.peekTokenIs(COMMA) {
break
}
p.nextToken()
p.nextToken()
if p.curTokenIs(RBRACE) {
break
}
if p.curTokenIs(EOF) {
p.addError("expected next token to be }")
return nil
}
}
if !p.expectPeek(RBRACE) {
return nil
}
return constructor
}
func (p *Parser) parseTableLiteralFromBrace() Expression {
// We're already at the opening brace, so parse as table literal
table := &TableLiteral{}
table.Pairs = []TablePair{}
if p.peekTokenIs(RBRACE) {
p.nextToken()
return table
}
p.nextToken()
for {
if p.curTokenIs(EOF) {
p.addError("unexpected end of input, expected }")
return nil
}
pair := TablePair{}
if (p.curTokenIs(IDENT) || p.curTokenIs(STRING)) && p.peekTokenIs(ASSIGN) {
if p.curTokenIs(IDENT) {
pair.Key = &Identifier{Value: p.curToken.Literal}
} else {
pair.Key = &StringLiteral{Value: p.curToken.Literal}
}
p.nextToken()
p.nextToken()
if p.curTokenIs(EOF) {
p.addError("expected expression after assignment operator")
return nil
}
pair.Value = p.ParseExpression(LOWEST)
} else {
pair.Value = p.ParseExpression(LOWEST)
}
if pair.Value == nil {
return nil
}
table.Pairs = append(table.Pairs, pair)
if !p.peekTokenIs(COMMA) {
break
}
p.nextToken()
p.nextToken()
if p.curTokenIs(RBRACE) {
break
}
if p.curTokenIs(EOF) {
p.addError("expected next token to be }")
return nil
}
}
if !p.expectPeek(RBRACE) {
return nil
}
return table
}
func (p *Parser) parseInfixExpression(left Expression) Expression {
expression := &InfixExpression{
Left: left,
@ -1057,7 +1364,7 @@ func (p *Parser) expectPeekIdent() bool {
func (p *Parser) isKeyword(t TokenType) bool {
switch t {
case TRUE, FALSE, NIL, AND, OR, NOT, IF, THEN, ELSEIF, ELSE, END, ECHO, FOR, WHILE, IN, DO, BREAK, EXIT, FN, RETURN:
case TRUE, FALSE, NIL, AND, OR, NOT, IF, THEN, ELSEIF, ELSE, END, ECHO, FOR, WHILE, IN, DO, BREAK, EXIT, FN, RETURN, STRUCT:
return true
default:
return false
@ -1227,6 +1534,8 @@ func tokenTypeString(t TokenType) string {
return "fn"
case RETURN:
return "return"
case STRUCT:
return "struct"
case EOF:
return "end of file"
case ILLEGAL:

View File

@ -0,0 +1,632 @@
package parser_test
import (
"testing"
"git.sharkk.net/Sharkk/Mako/parser"
)
func TestBasicStructDefinition(t *testing.T) {
input := `struct Person {
name: string,
age: number
}`
l := parser.NewLexer(input)
p := parser.NewParser(l)
program := p.ParseProgram()
checkParserErrors(t, p)
if len(program.Statements) != 1 {
t.Fatalf("expected 1 statement, got %d", len(program.Statements))
}
stmt, ok := program.Statements[0].(*parser.StructStatement)
if !ok {
t.Fatalf("expected StructStatement, got %T", program.Statements[0])
}
if stmt.Name != "Person" {
t.Errorf("expected struct name 'Person', got %s", stmt.Name)
}
if len(stmt.Fields) != 2 {
t.Fatalf("expected 2 fields, got %d", len(stmt.Fields))
}
// Test first field
if stmt.Fields[0].Name != "name" {
t.Errorf("expected field name 'name', got %s", stmt.Fields[0].Name)
}
if stmt.Fields[0].TypeHint == nil {
t.Fatal("expected type hint for name field")
}
if stmt.Fields[0].TypeHint.Type != "string" {
t.Errorf("expected type 'string', got %s", stmt.Fields[0].TypeHint.Type)
}
// Test second field
if stmt.Fields[1].Name != "age" {
t.Errorf("expected field name 'age', got %s", stmt.Fields[1].Name)
}
if stmt.Fields[1].TypeHint == nil {
t.Fatal("expected type hint for age field")
}
if stmt.Fields[1].TypeHint.Type != "number" {
t.Errorf("expected type 'number', got %s", stmt.Fields[1].TypeHint.Type)
}
}
func TestEmptyStructDefinition(t *testing.T) {
input := `struct Empty {}`
l := parser.NewLexer(input)
p := parser.NewParser(l)
program := p.ParseProgram()
checkParserErrors(t, p)
if len(program.Statements) != 1 {
t.Fatalf("expected 1 statement, got %d", len(program.Statements))
}
stmt, ok := program.Statements[0].(*parser.StructStatement)
if !ok {
t.Fatalf("expected StructStatement, got %T", program.Statements[0])
}
if stmt.Name != "Empty" {
t.Errorf("expected struct name 'Empty', got %s", stmt.Name)
}
if len(stmt.Fields) != 0 {
t.Errorf("expected 0 fields, got %d", len(stmt.Fields))
}
}
func TestComplexStructDefinition(t *testing.T) {
input := `struct Complex {
id: number,
name: string,
active: bool,
data: table,
callback: function,
optional: any
}`
l := parser.NewLexer(input)
p := parser.NewParser(l)
program := p.ParseProgram()
checkParserErrors(t, p)
if len(program.Statements) != 1 {
t.Fatalf("expected 1 statement, got %d", len(program.Statements))
}
stmt, ok := program.Statements[0].(*parser.StructStatement)
if !ok {
t.Fatalf("expected StructStatement, got %T", program.Statements[0])
}
expectedTypes := []string{"number", "string", "bool", "table", "function", "any"}
expectedNames := []string{"id", "name", "active", "data", "callback", "optional"}
if len(stmt.Fields) != len(expectedTypes) {
t.Fatalf("expected %d fields, got %d", len(expectedTypes), len(stmt.Fields))
}
for i, field := range stmt.Fields {
if field.Name != expectedNames[i] {
t.Errorf("field %d: expected name '%s', got '%s'", i, expectedNames[i], field.Name)
}
if field.TypeHint == nil {
t.Fatalf("field %d: expected type hint", i)
}
if field.TypeHint.Type != expectedTypes[i] {
t.Errorf("field %d: expected type '%s', got '%s'", i, expectedTypes[i], field.TypeHint.Type)
}
}
}
func TestMethodDefinition(t *testing.T) {
input := `struct Person {
name: string,
age: number
}
fn Person.getName(): string
return self.name
end
fn Person.setAge(newAge: number)
self.age = newAge
end`
l := parser.NewLexer(input)
p := parser.NewParser(l)
program := p.ParseProgram()
checkParserErrors(t, p)
if len(program.Statements) != 3 {
t.Fatalf("expected 3 statements, got %d", len(program.Statements))
}
// First statement: struct definition
structStmt, ok := program.Statements[0].(*parser.StructStatement)
if !ok {
t.Fatalf("expected StructStatement, got %T", program.Statements[0])
}
if structStmt.Name != "Person" {
t.Errorf("expected struct name 'Person', got %s", structStmt.Name)
}
// Second statement: getter method
method1, ok := program.Statements[1].(*parser.MethodDefinition)
if !ok {
t.Fatalf("expected MethodDefinition, got %T", program.Statements[1])
}
if method1.StructName != "Person" {
t.Errorf("expected struct name 'Person', got %s", method1.StructName)
}
if method1.MethodName != "getName" {
t.Errorf("expected method name 'getName', got %s", method1.MethodName)
}
if method1.Function.ReturnType == nil {
t.Fatal("expected return type for getName method")
}
if method1.Function.ReturnType.Type != "string" {
t.Errorf("expected return type 'string', got %s", method1.Function.ReturnType.Type)
}
if len(method1.Function.Parameters) != 0 {
t.Errorf("expected 0 parameters, got %d", len(method1.Function.Parameters))
}
// Third statement: setter method
method2, ok := program.Statements[2].(*parser.MethodDefinition)
if !ok {
t.Fatalf("expected MethodDefinition, got %T", program.Statements[2])
}
if method2.StructName != "Person" {
t.Errorf("expected struct name 'Person', got %s", method2.StructName)
}
if method2.MethodName != "setAge" {
t.Errorf("expected method name 'setAge', got %s", method2.MethodName)
}
if method2.Function.ReturnType != nil {
t.Errorf("expected no return type for setAge method, got %s", method2.Function.ReturnType.Type)
}
if len(method2.Function.Parameters) != 1 {
t.Fatalf("expected 1 parameter, got %d", len(method2.Function.Parameters))
}
if method2.Function.Parameters[0].Name != "newAge" {
t.Errorf("expected parameter name 'newAge', got %s", method2.Function.Parameters[0].Name)
}
if method2.Function.Parameters[0].TypeHint == nil {
t.Fatal("expected type hint for newAge parameter")
}
if method2.Function.Parameters[0].TypeHint.Type != "number" {
t.Errorf("expected parameter type 'number', got %s", method2.Function.Parameters[0].TypeHint.Type)
}
}
func TestStructConstructor(t *testing.T) {
input := `struct Person {
name: string,
age: number
}
person = Person{name = "John", age = 30}
empty = Person{}`
l := parser.NewLexer(input)
p := parser.NewParser(l)
program := p.ParseProgram()
checkParserErrors(t, p)
if len(program.Statements) != 3 {
t.Fatalf("expected 3 statements, got %d", len(program.Statements))
}
// Second statement: constructor with fields
assign1, ok := program.Statements[1].(*parser.AssignStatement)
if !ok {
t.Fatalf("expected AssignStatement, got %T", program.Statements[1])
}
constructor1, ok := assign1.Value.(*parser.StructConstructorExpression)
if !ok {
t.Fatalf("expected StructConstructorExpression, got %T", assign1.Value)
}
if constructor1.StructName != "Person" {
t.Errorf("expected struct name 'Person', got %s", constructor1.StructName)
}
if len(constructor1.Fields) != 2 {
t.Fatalf("expected 2 fields, got %d", len(constructor1.Fields))
}
// Check name field
nameKey, ok := constructor1.Fields[0].Key.(*parser.Identifier)
if !ok {
t.Fatalf("expected Identifier for name key, got %T", constructor1.Fields[0].Key)
}
if nameKey.Value != "name" {
t.Errorf("expected key 'name', got %s", nameKey.Value)
}
testStringLiteral(t, constructor1.Fields[0].Value, "John")
// Check age field
ageKey, ok := constructor1.Fields[1].Key.(*parser.Identifier)
if !ok {
t.Fatalf("expected Identifier for age key, got %T", constructor1.Fields[1].Key)
}
if ageKey.Value != "age" {
t.Errorf("expected key 'age', got %s", ageKey.Value)
}
testNumberLiteral(t, constructor1.Fields[1].Value, 30)
// Third statement: empty constructor
assign2, ok := program.Statements[2].(*parser.AssignStatement)
if !ok {
t.Fatalf("expected AssignStatement, got %T", program.Statements[2])
}
constructor2, ok := assign2.Value.(*parser.StructConstructorExpression)
if !ok {
t.Fatalf("expected StructConstructorExpression, got %T", assign2.Value)
}
if constructor2.StructName != "Person" {
t.Errorf("expected struct name 'Person', got %s", constructor2.StructName)
}
if len(constructor2.Fields) != 0 {
t.Errorf("expected 0 fields, got %d", len(constructor2.Fields))
}
}
func TestNestedStructTypes(t *testing.T) {
input := `struct Address {
street: string,
city: string
}
struct Person {
name: string,
address: Address
}
person = Person{
name = "John",
address = Address{street = "Main St", city = "NYC"}
}`
l := parser.NewLexer(input)
p := parser.NewParser(l)
program := p.ParseProgram()
checkParserErrors(t, p)
if len(program.Statements) != 3 {
t.Fatalf("expected 3 statements, got %d", len(program.Statements))
}
// Check Person struct has Address field type
personStruct, ok := program.Statements[1].(*parser.StructStatement)
if !ok {
t.Fatalf("expected StructStatement, got %T", program.Statements[1])
}
addressField := personStruct.Fields[1]
if addressField.Name != "address" {
t.Errorf("expected field name 'address', got %s", addressField.Name)
}
if addressField.TypeHint.Type != "Address" {
t.Errorf("expected field type 'Address', got %s", addressField.TypeHint.Type)
}
// Check nested constructor
assign, ok := program.Statements[2].(*parser.AssignStatement)
if !ok {
t.Fatalf("expected AssignStatement, got %T", program.Statements[2])
}
personConstructor, ok := assign.Value.(*parser.StructConstructorExpression)
if !ok {
t.Fatalf("expected StructConstructorExpression, got %T", assign.Value)
}
// Check the nested Address constructor
addressConstructor, ok := personConstructor.Fields[1].Value.(*parser.StructConstructorExpression)
if !ok {
t.Fatalf("expected nested StructConstructorExpression, got %T", personConstructor.Fields[1].Value)
}
if addressConstructor.StructName != "Address" {
t.Errorf("expected nested struct name 'Address', got %s", addressConstructor.StructName)
}
if len(addressConstructor.Fields) != 2 {
t.Errorf("expected 2 fields in nested constructor, got %d", len(addressConstructor.Fields))
}
}
func TestStructIntegrationWithProgram(t *testing.T) {
input := `struct Point {
x: number,
y: number
}
fn Point.distance(other: Point): number
dx = self.x - other.x
dy = self.y - other.y
return (dx * dx + dy * dy)
end
p1 = Point{x = 0, y = 0}
p2 = Point{x = 3, y = 4}
if p1.distance(p2) then
echo "Distance calculated"
end
for i = 1, 10 do
point = Point{x = i, y = i * 2}
echo point.x
end`
l := parser.NewLexer(input)
p := parser.NewParser(l)
program := p.ParseProgram()
checkParserErrors(t, p)
if len(program.Statements) != 6 {
t.Fatalf("expected 6 statements, got %d", len(program.Statements))
}
// Verify struct definition
structStmt, ok := program.Statements[0].(*parser.StructStatement)
if !ok {
t.Fatalf("expected StructStatement, got %T", program.Statements[0])
}
if structStmt.Name != "Point" {
t.Errorf("expected struct name 'Point', got %s", structStmt.Name)
}
// Verify method definition
methodStmt, ok := program.Statements[1].(*parser.MethodDefinition)
if !ok {
t.Fatalf("expected MethodDefinition, got %T", program.Statements[1])
}
if methodStmt.StructName != "Point" {
t.Errorf("expected struct name 'Point', got %s", methodStmt.StructName)
}
if methodStmt.MethodName != "distance" {
t.Errorf("expected method name 'distance', got %s", methodStmt.MethodName)
}
// Verify constructors
for i := 2; i <= 3; i++ {
assign, ok := program.Statements[i].(*parser.AssignStatement)
if !ok {
t.Fatalf("statement %d: expected AssignStatement, got %T", i, program.Statements[i])
}
constructor, ok := assign.Value.(*parser.StructConstructorExpression)
if !ok {
t.Fatalf("statement %d: expected StructConstructorExpression, got %T", i, assign.Value)
}
if constructor.StructName != "Point" {
t.Errorf("statement %d: expected struct name 'Point', got %s", i, constructor.StructName)
}
}
// Verify if statement with method call
ifStmt, ok := program.Statements[4].(*parser.IfStatement)
if !ok {
t.Fatalf("expected IfStatement, got %T", program.Statements[4])
}
callExpr, ok := ifStmt.Condition.(*parser.CallExpression)
if !ok {
t.Fatalf("expected CallExpression in if condition, got %T", ifStmt.Condition)
}
dotExpr, ok := callExpr.Function.(*parser.DotExpression)
if !ok {
t.Fatalf("expected DotExpression for method call, got %T", callExpr.Function)
}
if dotExpr.Key != "distance" {
t.Errorf("expected method name 'distance', got %s", dotExpr.Key)
}
// Verify for loop with struct creation
forStmt, ok := program.Statements[5].(*parser.ForStatement)
if !ok {
t.Fatalf("expected ForStatement, got %T", program.Statements[5])
}
if len(forStmt.Body) != 2 {
t.Errorf("expected 2 statements in for body, got %d", len(forStmt.Body))
}
// Check struct constructor in loop
loopAssign, ok := forStmt.Body[0].(*parser.AssignStatement)
if !ok {
t.Fatalf("expected AssignStatement in loop, got %T", forStmt.Body[0])
}
loopConstructor, ok := loopAssign.Value.(*parser.StructConstructorExpression)
if !ok {
t.Fatalf("expected StructConstructorExpression in loop, got %T", loopAssign.Value)
}
if loopConstructor.StructName != "Point" {
t.Errorf("expected struct name 'Point' in loop, got %s", loopConstructor.StructName)
}
}
func TestStructErrorCases(t *testing.T) {
tests := []struct {
name string
input string
expectedErrorSubstring string
}{
{
name: "missing field type",
input: `struct Person {
name
}`,
expectedErrorSubstring: "struct fields require type annotation",
},
{
name: "missing struct name",
input: `struct {
name: string
}`,
expectedErrorSubstring: "expected struct name",
},
{
name: "missing opening brace",
input: `struct Person
name: string
}`,
expectedErrorSubstring: "expected '{' after struct name",
},
{
name: "missing closing brace",
input: `struct Person {
name: string`,
expectedErrorSubstring: "expected next token to be }, got end of file instead",
},
{
name: "invalid field type",
input: `struct Person {
name: invalidtype
}`,
expectedErrorSubstring: "invalid type name 'invalidtype'",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
l := parser.NewLexer(tt.input)
p := parser.NewParser(l)
_ = p.ParseProgram()
if !p.HasErrors() {
t.Fatalf("expected parser errors, but got none")
}
errors := p.ErrorStrings()
found := false
for _, err := range errors {
if containsSubstring(err, tt.expectedErrorSubstring) {
found = true
break
}
}
if !found {
t.Errorf("expected error containing '%s', got errors: %v", tt.expectedErrorSubstring, errors)
}
})
}
}
func TestSingleLineStruct(t *testing.T) {
input := `struct Person { name: string, age: number }`
l := parser.NewLexer(input)
p := parser.NewParser(l)
program := p.ParseProgram()
checkParserErrors(t, p)
if len(program.Statements) != 1 {
t.Fatalf("expected 1 statement, got %d", len(program.Statements))
}
stmt, ok := program.Statements[0].(*parser.StructStatement)
if !ok {
t.Fatalf("expected StructStatement, got %T", program.Statements[0])
}
if stmt.Name != "Person" {
t.Errorf("expected struct name 'Person', got %s", stmt.Name)
}
if len(stmt.Fields) != 2 {
t.Fatalf("expected 2 fields, got %d", len(stmt.Fields))
}
if stmt.Fields[0].Name != "name" || stmt.Fields[0].TypeHint.Type != "string" {
t.Errorf("expected first field 'name: string', got '%s: %s'",
stmt.Fields[0].Name, stmt.Fields[0].TypeHint.Type)
}
if stmt.Fields[1].Name != "age" || stmt.Fields[1].TypeHint.Type != "number" {
t.Errorf("expected second field 'age: number', got '%s: %s'",
stmt.Fields[1].Name, stmt.Fields[1].TypeHint.Type)
}
}
func TestStructString(t *testing.T) {
input := `struct Person {
name: string,
age: number
}`
l := parser.NewLexer(input)
p := parser.NewParser(l)
program := p.ParseProgram()
checkParserErrors(t, p)
stmt := program.Statements[0].(*parser.StructStatement)
str := stmt.String()
expected := "struct Person {\n\tname: string,\n\tage: number\n}"
if str != expected {
t.Errorf("expected string representation:\n%s\ngot:\n%s", expected, str)
}
}
func TestMethodString(t *testing.T) {
input := `struct Person {
name: string
}
fn Person.getName(): string
return self.name
end`
l := parser.NewLexer(input)
p := parser.NewParser(l)
program := p.ParseProgram()
checkParserErrors(t, p)
method := program.Statements[1].(*parser.MethodDefinition)
str := method.String()
if !containsSubstring(str, "fn Person.getName") {
t.Errorf("expected method string to contain 'fn Person.getName', got: %s", str)
}
if !containsSubstring(str, ": string") {
t.Errorf("expected method string to contain return type, got: %s", str)
}
}
func TestConstructorString(t *testing.T) {
input := `struct Person {
name: string,
age: number
}
person = Person{name = "John", age = 30}`
l := parser.NewLexer(input)
p := parser.NewParser(l)
program := p.ParseProgram()
checkParserErrors(t, p)
assign := program.Statements[1].(*parser.AssignStatement)
constructor := assign.Value.(*parser.StructConstructorExpression)
str := constructor.String()
expected := `Person{name = "John", age = 30.00}`
if str != expected {
t.Errorf("expected constructor string:\n%s\ngot:\n%s", expected, str)
}
}

View File

@ -59,6 +59,7 @@ const (
EXIT
FN
RETURN
STRUCT
// Special
EOF
@ -107,6 +108,7 @@ var precedences = map[TokenType]Precedence{
DOT: MEMBER,
LBRACKET: MEMBER,
LPAREN: CALL,
LBRACE: CALL,
}
// lookupIdent checks if an identifier is a keyword
@ -132,6 +134,7 @@ func lookupIdent(ident string) TokenType {
"exit": EXIT,
"fn": FN,
"return": RETURN,
"struct": STRUCT,
}
if tok, ok := keywords[ident]; ok {

View File

@ -76,6 +76,9 @@ type TypeInferrer struct {
nilType *TypeInfo
tableType *TypeInfo
anyType *TypeInfo
// Struct definitions
structs map[string]*StructStatement
}
// NewTypeInferrer creates a new type inference engine
@ -86,6 +89,7 @@ func NewTypeInferrer() *TypeInferrer {
currentScope: globalScope,
globalScope: globalScope,
errors: []TypeError{},
structs: make(map[string]*StructStatement),
// Pre-allocate common types to reduce allocations
numberType: &TypeInfo{Type: TypeNumber, Inferred: true},
@ -101,6 +105,14 @@ func NewTypeInferrer() *TypeInferrer {
// InferTypes performs type inference on the entire program
func (ti *TypeInferrer) InferTypes(program *Program) []TypeError {
// First pass: collect struct definitions
for _, stmt := range program.Statements {
if structStmt, ok := stmt.(*StructStatement); ok {
ti.structs[structStmt.Name] = structStmt
}
}
// Second pass: infer types
for _, stmt := range program.Statements {
ti.inferStatement(stmt)
}
@ -129,9 +141,27 @@ func (ti *TypeInferrer) addError(message string, node Node) {
})
}
// getStructType returns TypeInfo for a struct
func (ti *TypeInferrer) getStructType(name string) *TypeInfo {
if _, exists := ti.structs[name]; exists {
return &TypeInfo{Type: name, Inferred: true}
}
return nil
}
// isStructType checks if a type is a struct type
func (ti *TypeInferrer) isStructType(t *TypeInfo) bool {
_, exists := ti.structs[t.Type]
return exists
}
// inferStatement infers types for statements
func (ti *TypeInferrer) inferStatement(stmt Statement) {
switch s := stmt.(type) {
case *StructStatement:
ti.inferStructStatement(s)
case *MethodDefinition:
ti.inferMethodDefinition(s)
case *AssignStatement:
ti.inferAssignStatement(s)
case *EchoStatement:
@ -152,9 +182,63 @@ func (ti *TypeInferrer) inferStatement(stmt Statement) {
if s.Value != nil {
ti.inferExpression(s.Value)
}
case *ExpressionStatement:
ti.inferExpression(s.Expression)
}
}
// inferStructStatement handles struct definitions
func (ti *TypeInferrer) inferStructStatement(stmt *StructStatement) {
// Validate field types
for _, field := range stmt.Fields {
if field.TypeHint != nil {
if !ValidTypeName(field.TypeHint.Type) && !ti.isStructType(field.TypeHint) {
ti.addError(fmt.Sprintf("invalid field type '%s' in struct '%s'",
field.TypeHint.Type, stmt.Name), stmt)
}
}
}
}
// inferMethodDefinition handles method definitions
func (ti *TypeInferrer) inferMethodDefinition(stmt *MethodDefinition) {
// Check if struct exists
if _, exists := ti.structs[stmt.StructName]; !exists {
ti.addError(fmt.Sprintf("method defined on undefined struct '%s'", stmt.StructName), stmt)
return
}
// Infer the function body
ti.enterScope()
// Add self parameter implicitly
ti.currentScope.Define(&Symbol{
Name: "self",
Type: ti.getStructType(stmt.StructName),
Declared: true,
})
// Add explicit parameters
for _, param := range stmt.Function.Parameters {
paramType := ti.anyType
if param.TypeHint != nil {
paramType = param.TypeHint
}
ti.currentScope.Define(&Symbol{
Name: param.Name,
Type: paramType,
Declared: true,
})
}
// Infer function body
for _, bodyStmt := range stmt.Function.Body {
ti.inferStatement(bodyStmt)
}
ti.exitScope()
}
// inferAssignStatement handles variable assignments with type checking
func (ti *TypeInferrer) inferAssignStatement(stmt *AssignStatement) {
// Infer the type of the value expression
@ -288,9 +372,9 @@ func (ti *TypeInferrer) inferForStatement(stmt *ForStatement) {
func (ti *TypeInferrer) inferForInStatement(stmt *ForInStatement) {
iterableType := ti.inferExpression(stmt.Iterable)
// For now, assume iterable is a table
if !ti.isTableType(iterableType) {
ti.addError("for-in requires an iterable (table)", stmt.Iterable)
// For now, assume iterable is a table or struct
if !ti.isTableType(iterableType) && !ti.isStructType(iterableType) {
ti.addError("for-in requires an iterable (table or struct)", stmt.Iterable)
}
ti.enterScope()
@ -341,6 +425,8 @@ func (ti *TypeInferrer) inferExpression(expr Expression) *TypeInfo {
return ti.nilType
case *TableLiteral:
return ti.inferTableLiteral(e)
case *StructConstructorExpression:
return ti.inferStructConstructor(e)
case *FunctionLiteral:
return ti.inferFunctionLiteral(e)
case *CallExpression:
@ -353,12 +439,85 @@ func (ti *TypeInferrer) inferExpression(expr Expression) *TypeInfo {
return ti.inferIndexExpression(e)
case *DotExpression:
return ti.inferDotExpression(e)
case *AssignExpression:
return ti.inferAssignExpression(e)
default:
ti.addError("unknown expression type", expr)
return ti.anyType
}
}
// inferStructConstructor handles struct constructor expressions
func (ti *TypeInferrer) inferStructConstructor(expr *StructConstructorExpression) *TypeInfo {
structDef, exists := ti.structs[expr.StructName]
if !exists {
ti.addError(fmt.Sprintf("undefined struct '%s'", expr.StructName), expr)
return ti.anyType
}
// Validate field assignments
for _, pair := range expr.Fields {
if pair.Key != nil {
fieldName := ""
if ident, ok := pair.Key.(*Identifier); ok {
fieldName = ident.Value
} else if str, ok := pair.Key.(*StringLiteral); ok {
fieldName = str.Value
}
// Check if field exists in struct
fieldExists := false
var fieldType *TypeInfo
for _, field := range structDef.Fields {
if field.Name == fieldName {
fieldExists = true
fieldType = field.TypeHint
break
}
}
if !fieldExists {
ti.addError(fmt.Sprintf("struct '%s' has no field '%s'", expr.StructName, fieldName), expr)
} else {
// Check type compatibility
valueType := ti.inferExpression(pair.Value)
if !ti.isTypeCompatible(valueType, fieldType) {
ti.addError(fmt.Sprintf("cannot assign %s to field '%s' of type %s",
valueType.Type, fieldName, fieldType.Type), expr)
}
}
} else {
// Array-style assignment not valid for structs
ti.addError("struct constructors require named field assignments", expr)
}
}
structType := ti.getStructType(expr.StructName)
expr.SetType(structType)
return structType
}
// inferAssignExpression handles assignment expressions
func (ti *TypeInferrer) inferAssignExpression(expr *AssignExpression) *TypeInfo {
valueType := ti.inferExpression(expr.Value)
if ident, ok := expr.Name.(*Identifier); ok {
if expr.IsDeclaration {
ti.currentScope.Define(&Symbol{
Name: ident.Value,
Type: valueType,
Declared: true,
})
}
ident.SetType(valueType)
} else {
ti.inferExpression(expr.Name)
}
expr.SetType(valueType)
return valueType
}
// inferIdentifier looks up identifier type in symbol table
func (ti *TypeInferrer) inferIdentifier(ident *Identifier) *TypeInfo {
symbol := ti.currentScope.Lookup(ident.Value)
@ -497,17 +656,43 @@ func (ti *TypeInferrer) inferInfixExpression(infix *InfixExpression) *TypeInfo {
// inferIndexExpression infers table[index] type
func (ti *TypeInferrer) inferIndexExpression(index *IndexExpression) *TypeInfo {
ti.inferExpression(index.Left)
leftType := ti.inferExpression(index.Left)
ti.inferExpression(index.Index)
// For now, assume table access returns any
// If indexing a struct, try to infer field type
if ti.isStructType(leftType) {
if strLit, ok := index.Index.(*StringLiteral); ok {
if structDef, exists := ti.structs[leftType.Type]; exists {
for _, field := range structDef.Fields {
if field.Name == strLit.Value {
index.SetType(field.TypeHint)
return field.TypeHint
}
}
}
}
}
// For now, assume table/struct access returns any
index.SetType(ti.anyType)
return ti.anyType
}
// inferDotExpression infers table.key type
func (ti *TypeInferrer) inferDotExpression(dot *DotExpression) *TypeInfo {
ti.inferExpression(dot.Left)
leftType := ti.inferExpression(dot.Left)
// If accessing a struct field, try to infer field type
if ti.isStructType(leftType) {
if structDef, exists := ti.structs[leftType.Type]; exists {
for _, field := range structDef.Fields {
if field.Name == dot.Key {
dot.SetType(field.TypeHint)
return field.TypeHint
}
}
}
}
// For now, assume member access returns any
dot.SetType(ti.anyType)