Mako/parser/ast.go

632 lines
16 KiB
Go

package parser
import (
	"fmt"
	"strings"
)
// Position identifies a location in the source text by line and column.
type Position struct {
	Line, Column int
}
// Node is the contract shared by every AST node: it reports the source
// position recorded for it and renders itself as text.
type Node interface {
	Pos() Position
	String() string
}
// Statement is a Node that can appear at the top level or inside a block.
// The unexported statementNode method is a marker: only types that declare it
// satisfy Statement.
type Statement interface {
	statementNode()
	Node
}
// Expression represents expression nodes that produce values and have types
type Expression interface {
Node
expressionNode()
TypeInfo() TypeInfo
}
// Program is the root of the AST and holds every top-level statement.
type Program struct {
	Statements []Statement
	ExitCode   int
	Position   Position
}

// String renders each top-level statement in order, each followed by a
// newline (including the last). A program with no statements renders as the
// empty string.
func (p *Program) String() string {
	// A builder keeps rendering linear in the output size; repeated string
	// concatenation would re-copy the accumulated text for every statement.
	var sb strings.Builder
	for _, stmt := range p.Statements {
		sb.WriteString(stmt.String())
		sb.WriteByte('\n')
	}
	return sb.String()
}

// Pos returns the position recorded for the program node.
func (p *Program) Pos() Position { return p.Position }
// StructField describes one named, typed field declared inside a struct.
type StructField struct {
	Name     string
	TypeHint TypeInfo
	Position Position
}

// String renders the field as "name: type", where the type text comes from
// typeToString.
func (sf *StructField) String() string {
	typeName := typeToString(sf.TypeHint)
	return fmt.Sprintf("%s: %s", sf.Name, typeName)
}

// Pos returns the position recorded for the field.
func (sf *StructField) Pos() Position {
	return sf.Position
}
// StructStatement is a struct type definition: a name plus its fields.
type StructStatement struct {
	Name     string
	Fields   []StructField
	ID       uint16
	Position Position
}

func (ss *StructStatement) statementNode() {}

// String renders the definition as "struct Name {", the fields on
// tab-indented lines separated by ",\n\t", and a closing "}" on its own line.
// With no fields the body is a single line holding only the tab.
func (ss *StructStatement) String() string {
	// Built with a strings.Builder so the cost stays linear in the number of
	// fields instead of re-copying the accumulated text on every append.
	var sb strings.Builder
	sb.WriteString("struct ")
	sb.WriteString(ss.Name)
	sb.WriteString(" {\n\t")
	for i := range ss.Fields {
		if i > 0 {
			sb.WriteString(",\n\t")
		}
		sb.WriteString(ss.Fields[i].String())
	}
	sb.WriteString("\n}")
	return sb.String()
}

// Pos returns the position recorded for the struct definition.
func (ss *StructStatement) Pos() Position { return ss.Position }
// MethodDefinition attaches a function to the struct type identified by
// StructID, under the name MethodName.
type MethodDefinition struct {
	StructID   uint16
	MethodName string
	Function   *FunctionLiteral
	Position   Position
}

func (md *MethodDefinition) statementNode() {}

// String renders the method as "fn <struct>.name" followed by the function's
// own rendering with its leading "fn" keyword removed, so the parameter list
// follows the method name directly. A definition whose Function is nil renders
// as just the "fn <struct>.name" header rather than panicking.
func (md *MethodDefinition) String() string {
	header := "fn <struct>." + md.MethodName
	if md.Function == nil {
		return header
	}
	// TrimPrefix names the keyword being dropped instead of slicing off a
	// fixed number of bytes.
	return header + strings.TrimPrefix(md.Function.String(), "fn")
}

// Pos returns the position recorded for the method definition.
func (md *MethodDefinition) Pos() Position { return md.Position }
// StructConstructor is a struct instantiation expression: the ID of the struct
// being built plus the field initializers supplied for it.
type StructConstructor struct {
	StructID uint16
	Fields   []TablePair
	typeInfo TypeInfo
	Position Position
}

func (sc *StructConstructor) expressionNode() {}

// String renders the constructor as "<struct>{...}" with the field
// initializers separated by ", ".
func (sc *StructConstructor) String() string {
	rendered := make([]string, len(sc.Fields))
	for i := range sc.Fields {
		rendered[i] = sc.Fields[i].String()
	}
	body := joinStrings(rendered, ", ")
	return fmt.Sprintf("<struct>{%s}", body)
}

// TypeInfo returns the type information stored on the constructor.
func (sc *StructConstructor) TypeInfo() TypeInfo {
	return sc.typeInfo
}

// Pos returns the position recorded for the constructor.
func (sc *StructConstructor) Pos() Position {
	return sc.Position
}
// Assignment covers both assignment statements and assignment expressions.
// IsDeclaration adds a "local " prefix when rendered; IsExpression wraps the
// rendered text in parentheses.
type Assignment struct {
	Target        Expression
	Value         Expression
	TypeHint      TypeInfo
	IsDeclaration bool
	IsExpression  bool
	Position      Position
}

func (a *Assignment) statementNode()  {}
func (a *Assignment) expressionNode() {}

// String renders "[local ]target[: type] = value". The type annotation is
// included only when TypeHint is not TypeUnknown, and the whole text is
// parenthesized when the assignment is used as an expression.
func (a *Assignment) String() string {
	lhs := a.Target.String()
	if a.TypeHint.Type != TypeUnknown {
		lhs = fmt.Sprintf("%s: %s", lhs, typeToString(a.TypeHint))
	}
	if a.IsDeclaration {
		lhs = "local " + lhs
	}

	text := fmt.Sprintf("%s = %s", lhs, a.Value.String())
	if !a.IsExpression {
		return text
	}
	return "(" + text + ")"
}

// TypeInfo returns the type information of the assigned value.
func (a *Assignment) TypeInfo() TypeInfo {
	return a.Value.TypeInfo()
}

// Pos returns the position recorded for the assignment.
func (a *Assignment) Pos() Position {
	return a.Position
}
// ExpressionStatement lets an expression stand where a statement is expected.
type ExpressionStatement struct {
	Expression Expression
	Position   Position
}

func (es *ExpressionStatement) statementNode() {}

// String renders the wrapped expression unchanged.
func (es *ExpressionStatement) String() string { return es.Expression.String() }

// Pos returns the position recorded for the statement.
func (es *ExpressionStatement) Pos() Position {
	return es.Position
}
// EchoStatement is an output statement carrying the value to display.
type EchoStatement struct {
	Value    Expression
	Position Position
}

func (es *EchoStatement) statementNode() {}

// String renders the statement as "echo " followed by the value.
func (es *EchoStatement) String() string {
	operand := es.Value.String()
	return fmt.Sprintf("echo %s", operand)
}

// Pos returns the position recorded for the statement.
func (es *EchoStatement) Pos() Position {
	return es.Position
}
// BreakStatement is a loop exit statement; it carries only its position.
type BreakStatement struct {
	Position Position
}

func (bs *BreakStatement) statementNode() {}

// String always renders the keyword "break".
func (bs *BreakStatement) String() string {
	return "break"
}

// Pos returns the position recorded for the statement.
func (bs *BreakStatement) Pos() Position {
	return bs.Position
}
// ExitStatement is a script termination statement. Value holds the optional
// exit code expression and is nil when none was given.
type ExitStatement struct {
	Value    Expression
	Position Position
}

func (es *ExitStatement) statementNode() {}

// String renders "exit", followed by a space and the value when one is set.
func (es *ExitStatement) String() string {
	if es.Value != nil {
		return fmt.Sprintf("exit %s", es.Value.String())
	}
	return "exit"
}

// Pos returns the position recorded for the statement.
func (es *ExitStatement) Pos() Position {
	return es.Position
}
// ReturnStatement is a function return. Value holds the optional returned
// expression and is nil when the return is bare.
type ReturnStatement struct {
	Value    Expression
	Position Position
}

func (rs *ReturnStatement) statementNode() {}

// String renders "return", followed by a space and the value when one is set.
func (rs *ReturnStatement) String() string {
	switch {
	case rs.Value == nil:
		return "return"
	default:
		return fmt.Sprintf("return %s", rs.Value.String())
	}
}

// Pos returns the position recorded for the statement.
func (rs *ReturnStatement) Pos() Position {
	return rs.Position
}
// ElseIfClause is one "elseif" branch of an if statement: a condition and the
// statements that belong to it.
type ElseIfClause struct {
	Condition Expression
	Body      []Statement
	Position  Position
}

// String renders "elseif <condition> then" on one line, followed by each body
// statement on its own tab-indented line. The text ends with a newline and
// carries no closing "end".
func (eic *ElseIfClause) String() string {
	// A builder avoids re-copying the accumulated text for every statement.
	var sb strings.Builder
	sb.WriteString("elseif ")
	sb.WriteString(eic.Condition.String())
	sb.WriteString(" then\n")
	for _, stmt := range eic.Body {
		sb.WriteByte('\t')
		sb.WriteString(stmt.String())
		sb.WriteByte('\n')
	}
	return sb.String()
}

// Pos returns the position recorded for the clause.
func (eic *ElseIfClause) Pos() Position { return eic.Position }
// IfStatement is a conditional with optional elseif branches and an optional
// else branch.
type IfStatement struct {
	Condition Expression
	Body      []Statement
	ElseIfs   []ElseIfClause
	Else      []Statement
	Position  Position
}

func (is *IfStatement) statementNode() {}

// String renders "if <condition> then", the tab-indented body, every elseif
// clause in order, an "else" section when Else is non-empty, and a closing
// "end" with no trailing newline.
func (is *IfStatement) String() string {
	// A builder keeps rendering linear in the output size; repeated string
	// concatenation would re-copy the accumulated text for every statement.
	var sb strings.Builder
	writeIndented := func(stmts []Statement) {
		for _, stmt := range stmts {
			sb.WriteByte('\t')
			sb.WriteString(stmt.String())
			sb.WriteByte('\n')
		}
	}

	sb.WriteString("if ")
	sb.WriteString(is.Condition.String())
	sb.WriteString(" then\n")
	writeIndented(is.Body)
	for i := range is.ElseIfs {
		sb.WriteString(is.ElseIfs[i].String())
	}
	if len(is.Else) > 0 {
		sb.WriteString("else\n")
		writeIndented(is.Else)
	}
	sb.WriteString("end")
	return sb.String()
}

// Pos returns the position recorded for the statement.
func (is *IfStatement) Pos() Position { return is.Position }
// WhileStatement is a condition-based loop: a condition and the loop body.
type WhileStatement struct {
	Condition Expression
	Body      []Statement
	Position  Position
}

func (ws *WhileStatement) statementNode() {}

// String renders "while <condition> do", each body statement on its own
// tab-indented line, and a closing "end" with no trailing newline.
func (ws *WhileStatement) String() string {
	// A builder avoids re-copying the accumulated text for every statement.
	var sb strings.Builder
	sb.WriteString("while ")
	sb.WriteString(ws.Condition.String())
	sb.WriteString(" do\n")
	for _, stmt := range ws.Body {
		sb.WriteByte('\t')
		sb.WriteString(stmt.String())
		sb.WriteByte('\n')
	}
	sb.WriteString("end")
	return sb.String()
}

// Pos returns the position recorded for the statement.
func (ws *WhileStatement) Pos() Position { return ws.Position }
// ForStatement is a numeric for loop over a loop variable with start and end
// expressions. Step is optional and nil when not given.
type ForStatement struct {
	Variable *Identifier
	Start    Expression
	End      Expression
	Step     Expression
	Body     []Statement
	Position Position
}

func (fs *ForStatement) statementNode() {}

// String renders "for <var> = <start>, <end>[, <step>] do", each body
// statement on its own tab-indented line, and a closing "end" with no
// trailing newline. The step is included only when Step is non-nil.
func (fs *ForStatement) String() string {
	// A builder avoids re-copying the accumulated text for every statement and
	// lets the optional step share one header path.
	var sb strings.Builder
	sb.WriteString("for ")
	sb.WriteString(fs.Variable.String())
	sb.WriteString(" = ")
	sb.WriteString(fs.Start.String())
	sb.WriteString(", ")
	sb.WriteString(fs.End.String())
	if fs.Step != nil {
		sb.WriteString(", ")
		sb.WriteString(fs.Step.String())
	}
	sb.WriteString(" do\n")
	for _, stmt := range fs.Body {
		sb.WriteByte('\t')
		sb.WriteString(stmt.String())
		sb.WriteByte('\n')
	}
	sb.WriteString("end")
	return sb.String()
}

// Pos returns the position recorded for the statement.
func (fs *ForStatement) Pos() Position { return fs.Position }
// ForInStatement is an iterator-based loop over an iterable expression. Key
// is optional and nil when the loop binds only a value.
type ForInStatement struct {
	Key      *Identifier
	Value    *Identifier
	Iterable Expression
	Body     []Statement
	Position Position
}

func (fis *ForInStatement) statementNode() {}

// String renders "for [<key>, ]<value> in <iterable> do", each body statement
// on its own tab-indented line, and a closing "end" with no trailing newline.
// The key is included only when Key is non-nil.
func (fis *ForInStatement) String() string {
	// A builder avoids re-copying the accumulated text for every statement and
	// lets the optional key share one header path.
	var sb strings.Builder
	sb.WriteString("for ")
	if fis.Key != nil {
		sb.WriteString(fis.Key.String())
		sb.WriteString(", ")
	}
	sb.WriteString(fis.Value.String())
	sb.WriteString(" in ")
	sb.WriteString(fis.Iterable.String())
	sb.WriteString(" do\n")
	for _, stmt := range fis.Body {
		sb.WriteByte('\t')
		sb.WriteString(stmt.String())
		sb.WriteByte('\n')
	}
	sb.WriteString("end")
	return sb.String()
}

// Pos returns the position recorded for the statement.
func (fis *ForInStatement) Pos() Position { return fis.Position }
// FunctionParameter is one parameter of a function definition: a name and an
// optional type hint.
type FunctionParameter struct {
	Name     string
	TypeHint TypeInfo
	Position Position
}

// String renders the bare name when the type hint is TypeUnknown, and
// "name: type" otherwise.
func (fp *FunctionParameter) String() string {
	if fp.TypeHint.Type == TypeUnknown {
		return fp.Name
	}
	hint := typeToString(fp.TypeHint)
	return fmt.Sprintf("%s: %s", fp.Name, hint)
}

// Pos returns the position recorded for the parameter.
func (fp *FunctionParameter) Pos() Position {
	return fp.Position
}
// Identifier is a variable reference or name, together with the scope and
// slot bookkeeping recorded for it.
type Identifier struct {
	Value      string
	ScopeDepth int // 0 = global, 1+ = local depth
	SlotIndex  int // register/stack slot (-1 = unresolved)
	typeInfo   TypeInfo
	Position   Position
}

func (i *Identifier) expressionNode() {}

// String renders the identifier's name.
func (i *Identifier) String() string {
	return i.Value
}

// TypeInfo returns the stored type information, falling back to AnyType while
// the stored type is still TypeUnknown.
func (i *Identifier) TypeInfo() TypeInfo {
	if i.typeInfo.Type != TypeUnknown {
		return i.typeInfo
	}
	return AnyType
}

// Pos returns the position recorded for the identifier.
func (i *Identifier) Pos() Position {
	return i.Position
}

// IsResolved reports whether SlotIndex is non-negative.
func (i *Identifier) IsResolved() bool {
	return i.SlotIndex >= 0
}
// NumberLiteral is a numeric constant stored as a float64.
type NumberLiteral struct {
	Value    float64
	Position Position
}

func (nl *NumberLiteral) expressionNode() {}

// String renders the value in fixed-point notation with exactly two decimal
// places, so digits beyond the second decimal are rounded away.
func (nl *NumberLiteral) String() string {
	const twoDecimals = "%.2f"
	return fmt.Sprintf(twoDecimals, nl.Value)
}

// TypeInfo always returns NumberType.
func (nl *NumberLiteral) TypeInfo() TypeInfo {
	return NumberType
}

// Pos returns the position recorded for the literal.
func (nl *NumberLiteral) Pos() Position {
	return nl.Position
}
// StringLiteral is a string constant.
type StringLiteral struct {
	Value    string
	Position Position
}

func (sl *StringLiteral) expressionNode() {}

// String renders the value wrapped in double quotes. The value is inserted
// as-is; quotes or other characters inside it are not escaped.
func (sl *StringLiteral) String() string {
	return fmt.Sprintf("\"%s\"", sl.Value)
}

// TypeInfo always returns StringType.
func (sl *StringLiteral) TypeInfo() TypeInfo {
	return StringType
}

// Pos returns the position recorded for the literal.
func (sl *StringLiteral) Pos() Position {
	return sl.Position
}
// BooleanLiteral is a true or false constant.
type BooleanLiteral struct {
	Value    bool
	Position Position
}

func (bl *BooleanLiteral) expressionNode() {}

// String renders the keyword "true" or "false" matching Value.
func (bl *BooleanLiteral) String() string {
	if !bl.Value {
		return "false"
	}
	return "true"
}

// TypeInfo always returns BoolType.
func (bl *BooleanLiteral) TypeInfo() TypeInfo {
	return BoolType
}

// Pos returns the position recorded for the literal.
func (bl *BooleanLiteral) Pos() Position {
	return bl.Position
}
// NilLiteral is the nil constant; it carries only its position.
type NilLiteral struct {
	Position Position
}

func (nl *NilLiteral) expressionNode() {}

// String always renders the keyword "nil".
func (nl *NilLiteral) String() string {
	return "nil"
}

// TypeInfo always returns NilType.
func (nl *NilLiteral) TypeInfo() TypeInfo {
	return NilType
}

// Pos returns the position recorded for the literal.
func (nl *NilLiteral) Pos() Position {
	return nl.Position
}
// FunctionLiteral is a function definition: parameters, body, an optional
// return type, and a variadic flag, plus counters recorded for the function.
type FunctionLiteral struct {
	Parameters    []FunctionParameter
	Body          []Statement
	ReturnType    TypeInfo
	Variadic      bool
	LocalCount    int // Pre-computed local variable count
	UpvalueCount  int // Number of captured variables
	MaxStackDepth int // Maximum expression evaluation depth
	Position      Position
}

func (fl *FunctionLiteral) expressionNode() {}

// String renders "fn(<params>)[: <return type>]", each body statement on its
// own tab-indented line, and a closing "end" with no trailing newline.
// Parameters are separated by ", "; a variadic function gets a trailing "..."
// entry; the return type is shown only when it is not TypeUnknown.
func (fl *FunctionLiteral) String() string {
	// A builder keeps rendering linear in the output size; repeated string
	// concatenation would re-copy the accumulated text on every append.
	var sb strings.Builder
	sb.WriteString("fn(")
	for i := range fl.Parameters {
		if i > 0 {
			sb.WriteString(", ")
		}
		sb.WriteString(fl.Parameters[i].String())
	}
	if fl.Variadic {
		if len(fl.Parameters) > 0 {
			sb.WriteString(", ")
		}
		sb.WriteString("...")
	}
	sb.WriteByte(')')
	if fl.ReturnType.Type != TypeUnknown {
		sb.WriteString(": ")
		sb.WriteString(typeToString(fl.ReturnType))
	}
	sb.WriteByte('\n')
	for _, stmt := range fl.Body {
		sb.WriteByte('\t')
		sb.WriteString(stmt.String())
		sb.WriteByte('\n')
	}
	sb.WriteString("end")
	return sb.String()
}

// TypeInfo always returns FunctionType.
func (fl *FunctionLiteral) TypeInfo() TypeInfo { return FunctionType }

// Pos returns the position recorded for the function literal.
func (fl *FunctionLiteral) Pos() Position { return fl.Position }
// CallExpression is a function call: the callee expression and its arguments.
type CallExpression struct {
	Function  Expression
	Arguments []Expression
	typeInfo  TypeInfo
	Position  Position
}

func (ce *CallExpression) expressionNode() {}

// String renders the callee followed by the parenthesized arguments, which
// are separated by ", ".
func (ce *CallExpression) String() string {
	rendered := make([]string, len(ce.Arguments))
	for i := range ce.Arguments {
		rendered[i] = ce.Arguments[i].String()
	}
	callee := ce.Function.String()
	return fmt.Sprintf("%s(%s)", callee, joinStrings(rendered, ", "))
}

// TypeInfo returns the type information stored on the call.
func (ce *CallExpression) TypeInfo() TypeInfo {
	return ce.typeInfo
}

// Pos returns the position recorded for the call.
func (ce *CallExpression) Pos() Position {
	return ce.Position
}
// PrefixExpression is a unary operation: an operator applied to the operand
// on its right.
type PrefixExpression struct {
	Operator string
	Right    Expression
	typeInfo TypeInfo
	Position Position
}

func (pe *PrefixExpression) expressionNode() {}

// String renders the operation in parentheses. The "not" operator is
// separated from its operand by a space; every other operator is written
// directly against the operand.
func (pe *PrefixExpression) String() string {
	operand := pe.Right.String()
	switch pe.Operator {
	case "not":
		return fmt.Sprintf("(%s %s)", pe.Operator, operand)
	default:
		return fmt.Sprintf("(%s%s)", pe.Operator, operand)
	}
}

// TypeInfo returns the type information stored on the expression.
func (pe *PrefixExpression) TypeInfo() TypeInfo {
	return pe.typeInfo
}

// Pos returns the position recorded for the expression.
func (pe *PrefixExpression) Pos() Position {
	return pe.Position
}
// InfixExpression is a binary operation between a left and a right operand.
type InfixExpression struct {
	Left     Expression
	Right    Expression
	Operator string
	typeInfo TypeInfo
	Position Position
}

func (ie *InfixExpression) expressionNode() {}

// String renders "(left operator right)" with single spaces around the
// operator.
func (ie *InfixExpression) String() string {
	lhs := ie.Left.String()
	rhs := ie.Right.String()
	return fmt.Sprintf("(%s %s %s)", lhs, ie.Operator, rhs)
}

// TypeInfo returns the type information stored on the expression.
func (ie *InfixExpression) TypeInfo() TypeInfo {
	return ie.typeInfo
}

// Pos returns the position recorded for the expression.
func (ie *InfixExpression) Pos() Position {
	return ie.Position
}
// IndexExpression is bracket-based member access (table[key]).
type IndexExpression struct {
	Left     Expression
	Index    Expression
	typeInfo TypeInfo
	Position Position
}

func (ie *IndexExpression) expressionNode() {}

// String renders the indexed expression followed by the index in square
// brackets.
func (ie *IndexExpression) String() string {
	container := ie.Left.String()
	key := ie.Index.String()
	return fmt.Sprintf("%s[%s]", container, key)
}

// TypeInfo returns the type information stored on the expression.
func (ie *IndexExpression) TypeInfo() TypeInfo {
	return ie.typeInfo
}

// Pos returns the position recorded for the expression.
func (ie *IndexExpression) Pos() Position {
	return ie.Position
}
// DotExpression is dot-based member access (table.key).
type DotExpression struct {
	Left     Expression
	Key      string
	typeInfo TypeInfo
	Position Position
}

func (de *DotExpression) expressionNode() {}

// String renders the accessed expression, a dot, and the key.
func (de *DotExpression) String() string {
	receiver := de.Left.String()
	return fmt.Sprintf("%s.%s", receiver, de.Key)
}

// TypeInfo returns the type information stored on the expression.
func (de *DotExpression) TypeInfo() TypeInfo {
	return de.typeInfo
}

// Pos returns the position recorded for the expression.
func (de *DotExpression) Pos() Position {
	return de.Position
}
// TablePair is one entry of a table literal or struct constructor. Key is nil
// for entries written without an explicit key.
type TablePair struct {
	Key      Expression
	Value    Expression
	Position Position
}

// String renders "key = value" when a key is present and just the value
// otherwise.
func (tp *TablePair) String() string {
	if tp.Key != nil {
		return fmt.Sprintf("%s = %s", tp.Key.String(), tp.Value.String())
	}
	return tp.Value.String()
}

// Pos returns the position recorded for the pair.
func (tp *TablePair) Pos() Position {
	return tp.Position
}
// TableLiteral is a table/array/object literal made of key-value pairs.
type TableLiteral struct {
	Pairs    []TablePair
	Position Position
}

func (tl *TableLiteral) expressionNode() {}

// String renders the pairs inside braces, separated by ", ".
func (tl *TableLiteral) String() string {
	rendered := make([]string, len(tl.Pairs))
	for i := range tl.Pairs {
		rendered[i] = tl.Pairs[i].String()
	}
	body := joinStrings(rendered, ", ")
	return fmt.Sprintf("{%s}", body)
}

// TypeInfo always returns TableType.
func (tl *TableLiteral) TypeInfo() TypeInfo {
	return TableType
}

// Pos returns the position recorded for the literal.
func (tl *TableLiteral) Pos() Position {
	return tl.Position
}

// IsArray reports whether every pair was written without an explicit key.
// A literal with no pairs counts as an array.
func (tl *TableLiteral) IsArray() bool {
	for i := range tl.Pairs {
		if tl.Pairs[i].Key != nil {
			return false
		}
	}
	return true
}
// typeToString returns the display name for t's type. Struct types render as
// "struct<ID>" using t.StructID; any type not listed below renders as
// "unknown".
func typeToString(t TypeInfo) string {
	name := "unknown"
	switch t.Type {
	case TypeNumber:
		name = "number"
	case TypeString:
		name = "string"
	case TypeBool:
		name = "bool"
	case TypeNil:
		name = "nil"
	case TypeTable:
		name = "table"
	case TypeFunction:
		name = "function"
	case TypeAny:
		name = "any"
	case TypeStruct:
		name = fmt.Sprintf("struct<%d>", t.StructID)
	}
	return name
}
// joinStrings concatenates strs, placing sep between consecutive elements.
// A nil or empty slice yields the empty string, and a single element is
// returned unchanged. It delegates to strings.Join, which sizes the result
// once instead of re-copying the accumulated text for every element.
func joinStrings(strs []string, sep string) string {
	return strings.Join(strs, sep)
}