290 lines
6.9 KiB
Go
290 lines
6.9 KiB
Go
package compiler
|
|
|
|
import (
|
|
"fmt"
|
|
|
|
"git.sharkk.net/Sharkk/Mako/parser"
|
|
"git.sharkk.net/Sharkk/Mako/types"
|
|
)
|
|
|
|
// Compile converts AST to bytecode
|
|
func Compile(program *parser.Program) *types.Bytecode {
|
|
c := &compiler{
|
|
constants: []any{},
|
|
instructions: []types.Instruction{},
|
|
scopes: []scope{},
|
|
}
|
|
|
|
// Start in global scope
|
|
c.enterScope()
|
|
|
|
for _, stmt := range program.Statements {
|
|
c.compileStatement(stmt)
|
|
}
|
|
|
|
c.exitScope()
|
|
|
|
return &types.Bytecode{
|
|
Constants: c.constants,
|
|
Instructions: c.instructions,
|
|
}
|
|
}
|
|
|
|
// scope holds the names declared in one lexical block of the program being
// compiled.
type scope struct {
	// variables is used as a set: only key presence is consulted, and
	// declareVariable always stores true.
	variables map[string]bool
}
|
|
|
|
type compiler struct {
|
|
constants []any
|
|
instructions []types.Instruction
|
|
scopes []scope
|
|
}
|
|
|
|
func (c *compiler) enterScope() {
|
|
c.scopes = append(c.scopes, scope{
|
|
variables: make(map[string]bool),
|
|
})
|
|
c.emit(types.OpEnterScope, 0)
|
|
}
|
|
|
|
func (c *compiler) exitScope() {
|
|
c.scopes = c.scopes[:len(c.scopes)-1]
|
|
c.emit(types.OpExitScope, 0)
|
|
}
|
|
|
|
func (c *compiler) declareVariable(name string) {
|
|
if len(c.scopes) > 0 {
|
|
c.scopes[len(c.scopes)-1].variables[name] = true
|
|
}
|
|
}
|
|
|
|
func (c *compiler) isLocalVariable(name string) bool {
|
|
for i := len(c.scopes) - 1; i >= 0; i-- {
|
|
if _, ok := c.scopes[i].variables[name]; ok {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (c *compiler) compileStatement(stmt parser.Statement) {
|
|
switch s := stmt.(type) {
|
|
case *parser.VariableStatement:
|
|
c.compileExpression(s.Value)
|
|
nameIndex := c.addConstant(s.Name.Value)
|
|
|
|
// Use SetGlobal for top-level variables to persist between REPL lines
|
|
if len(c.scopes) <= 1 {
|
|
c.emit(types.OpSetGlobal, nameIndex)
|
|
} else {
|
|
c.declareVariable(s.Name.Value)
|
|
c.emit(types.OpSetLocal, nameIndex)
|
|
}
|
|
|
|
case *parser.IndexAssignmentStatement:
|
|
c.compileExpression(s.Left)
|
|
c.compileExpression(s.Index)
|
|
c.compileExpression(s.Value)
|
|
c.emit(types.OpSetIndex, 0)
|
|
|
|
case *parser.EchoStatement:
|
|
c.compileExpression(s.Value)
|
|
c.emit(types.OpEcho, 0)
|
|
|
|
case *parser.BlockStatement:
|
|
c.enterScope()
|
|
for _, blockStmt := range s.Statements {
|
|
c.compileStatement(blockStmt)
|
|
}
|
|
c.exitScope()
|
|
}
|
|
}
|
|
|
|
func (c *compiler) compileExpression(expr parser.Expression) {
|
|
switch e := expr.(type) {
|
|
case *parser.StringLiteral:
|
|
constIndex := c.addConstant(e.Value)
|
|
c.emit(types.OpConstant, constIndex)
|
|
|
|
case *parser.NumberLiteral:
|
|
constIndex := c.addConstant(e.Value)
|
|
c.emit(types.OpConstant, constIndex)
|
|
|
|
case *parser.BooleanLiteral:
|
|
constIndex := c.addConstant(e.Value)
|
|
c.emit(types.OpConstant, constIndex)
|
|
|
|
case *parser.Identifier:
|
|
nameIndex := c.addConstant(e.Value)
|
|
|
|
// Check if it's a local variable first
|
|
if c.isLocalVariable(e.Value) {
|
|
c.emit(types.OpGetLocal, nameIndex)
|
|
} else {
|
|
// Otherwise treat as global
|
|
c.emit(types.OpGetGlobal, nameIndex)
|
|
}
|
|
|
|
case *parser.TableLiteral:
|
|
c.emit(types.OpNewTable, 0)
|
|
|
|
for key, value := range e.Pairs {
|
|
c.emit(types.OpDup, 0)
|
|
|
|
// Special handling for identifier keys in tables
|
|
if ident, ok := key.(*parser.Identifier); ok {
|
|
// Treat identifiers as string literals in table keys
|
|
strIndex := c.addConstant(ident.Value)
|
|
c.emit(types.OpConstant, strIndex)
|
|
} else {
|
|
// For other expressions, compile normally
|
|
c.compileExpression(key)
|
|
}
|
|
|
|
c.compileExpression(value)
|
|
c.emit(types.OpSetIndex, 0)
|
|
c.emit(types.OpPop, 0)
|
|
}
|
|
|
|
case *parser.IndexExpression:
|
|
c.compileExpression(e.Left)
|
|
c.compileExpression(e.Index)
|
|
c.emit(types.OpGetIndex, 0)
|
|
|
|
// Arithmetic expressions
|
|
case *parser.InfixExpression:
|
|
// Compile left and right expressions
|
|
c.compileExpression(e.Left)
|
|
c.compileExpression(e.Right)
|
|
|
|
// Generate the appropriate operation
|
|
switch e.Operator {
|
|
case "+":
|
|
c.emit(types.OpAdd, 0)
|
|
case "-":
|
|
c.emit(types.OpSubtract, 0)
|
|
case "*":
|
|
c.emit(types.OpMultiply, 0)
|
|
case "/":
|
|
c.emit(types.OpDivide, 0)
|
|
case "==":
|
|
c.emit(types.OpEqual, 0)
|
|
case "!=":
|
|
c.emit(types.OpNotEqual, 0)
|
|
case "<":
|
|
c.emit(types.OpLessThan, 0)
|
|
case ">":
|
|
c.emit(types.OpGreaterThan, 0)
|
|
case "<=":
|
|
c.emit(types.OpLessEqual, 0)
|
|
case ">=":
|
|
c.emit(types.OpGreaterEqual, 0)
|
|
default:
|
|
panic(fmt.Sprintf("Unknown infix operator: %s", e.Operator))
|
|
}
|
|
|
|
case *parser.PrefixExpression:
|
|
// Compile the operand
|
|
c.compileExpression(e.Right)
|
|
|
|
// Generate the appropriate operation
|
|
switch e.Operator {
|
|
case "-":
|
|
c.emit(types.OpNegate, 0)
|
|
default:
|
|
panic(fmt.Sprintf("Unknown prefix operator: %s", e.Operator))
|
|
}
|
|
|
|
case *parser.GroupedExpression:
|
|
// Just compile the inner expression
|
|
c.compileExpression(e.Expr)
|
|
|
|
case *parser.IfExpression:
|
|
// Compile condition
|
|
c.compileExpression(e.Condition)
|
|
|
|
// Emit jump-if-false with placeholder
|
|
jumpNotTruePos := len(c.instructions)
|
|
c.emit(types.OpJumpIfFalse, 0) // Will backpatch
|
|
|
|
// Compile consequence (then block)
|
|
lastStmtIndex := len(e.Consequence.Statements) - 1
|
|
for i, stmt := range e.Consequence.Statements {
|
|
if i == lastStmtIndex {
|
|
// For the last statement, we need to ensure it leaves a value
|
|
if exprStmt, ok := stmt.(*parser.ExpressionStatement); ok {
|
|
c.compileExpression(exprStmt.Expression)
|
|
} else {
|
|
c.compileStatement(stmt)
|
|
// Push null if not an expression statement
|
|
nullIndex := c.addConstant(nil)
|
|
c.emit(types.OpConstant, nullIndex)
|
|
}
|
|
} else {
|
|
c.compileStatement(stmt)
|
|
}
|
|
}
|
|
|
|
// If no statements, push null
|
|
if len(e.Consequence.Statements) == 0 {
|
|
nullIndex := c.addConstant(nil)
|
|
c.emit(types.OpConstant, nullIndex)
|
|
}
|
|
|
|
// Emit jump to skip else part
|
|
jumpPos := len(c.instructions)
|
|
c.emit(types.OpJump, 0) // Will backpatch
|
|
|
|
// Backpatch jump-if-false to point to else
|
|
afterConsequencePos := len(c.instructions)
|
|
c.instructions[jumpNotTruePos].Operand = afterConsequencePos
|
|
|
|
// Compile alternative (else block)
|
|
if e.Alternative != nil {
|
|
lastStmtIndex = len(e.Alternative.Statements) - 1
|
|
for i, stmt := range e.Alternative.Statements {
|
|
if i == lastStmtIndex {
|
|
// For the last statement, we need to ensure it leaves a value
|
|
if exprStmt, ok := stmt.(*parser.ExpressionStatement); ok {
|
|
c.compileExpression(exprStmt.Expression)
|
|
} else {
|
|
c.compileStatement(stmt)
|
|
// Push null if not an expression statement
|
|
nullIndex := c.addConstant(nil)
|
|
c.emit(types.OpConstant, nullIndex)
|
|
}
|
|
} else {
|
|
c.compileStatement(stmt)
|
|
}
|
|
}
|
|
|
|
// If no statements, push null
|
|
if len(e.Alternative.Statements) == 0 {
|
|
nullIndex := c.addConstant(nil)
|
|
c.emit(types.OpConstant, nullIndex)
|
|
}
|
|
} else {
|
|
// No else - push null
|
|
nullIndex := c.addConstant(nil)
|
|
c.emit(types.OpConstant, nullIndex)
|
|
}
|
|
|
|
// Backpatch jump to point after else
|
|
afterAlternativePos := len(c.instructions)
|
|
c.instructions[jumpPos].Operand = afterAlternativePos
|
|
}
|
|
}
|
|
|
|
func (c *compiler) addConstant(value any) int {
|
|
c.constants = append(c.constants, value)
|
|
return len(c.constants) - 1
|
|
}
|
|
|
|
func (c *compiler) emit(op types.Opcode, operand int) {
|
|
instruction := types.Instruction{
|
|
Opcode: op,
|
|
Operand: operand,
|
|
}
|
|
c.instructions = append(c.instructions, instruction)
|
|
}
|