// Mako/compiler/compiler.go (repository listing metadata: 373 lines, 8.9 KiB, Go)

package compiler
import (
"fmt"
"git.sharkk.net/Sharkk/Mako/parser"
"git.sharkk.net/Sharkk/Mako/types"
)
// Compile converts a parsed program into bytecode.
//
// The whole program is compiled inside one global scope, so the emitted
// instruction stream always begins with OpEnterScope and ends with
// OpExitScope. A nil program yields just that scope bracket, and nil
// statements inside the program are skipped.
func Compile(program *parser.Program) *types.Bytecode {
	c := &compiler{
		constants:    []any{},
		instructions: []types.Instruction{},
		scopes:       []scope{},
	}

	c.enterScope()
	if program != nil {
		for _, stmt := range program.Statements {
			if stmt != nil {
				c.compileStatement(stmt)
			}
		}
	}
	c.exitScope()

	return &types.Bytecode{
		Constants:    c.constants,
		Instructions: c.instructions,
	}
}
// scope is one level of the compiler's compile-time scope stack.
type scope struct {
	// variables is the set of names declared in this scope. declareVariable
	// always stores true; isLocalVariable only checks whether the key exists.
	variables map[string]bool
}
// compiler holds the state accumulated while compiling a single program.
type compiler struct {
	// constants is the constant pool. Operands of OpConstant and the
	// variable get/set instructions are indexes into it. addConstant only
	// appends, so equal values may appear more than once.
	constants []any
	// instructions is the emitted bytecode, in emission order.
	instructions []types.Instruction
	// scopes is the compile-time scope stack. Index 0 is the global scope
	// opened by Compile.
	scopes []scope
}
// enterScope pushes a new, empty compile-time scope and emits OpEnterScope
// (presumably so the VM opens a matching runtime scope).
func (c *compiler) enterScope() {
	fresh := scope{variables: map[string]bool{}}
	c.scopes = append(c.scopes, fresh)
	c.emit(types.OpEnterScope, 0)
}
// exitScope discards the innermost compile-time scope and emits OpExitScope.
// Calling it with no open scope panics (slice bounds out of range).
func (c *compiler) exitScope() {
	innermost := len(c.scopes) - 1
	c.scopes = c.scopes[:innermost]
	c.emit(types.OpExitScope, 0)
}
// declareVariable records name as declared in the innermost scope. It does
// nothing when no scope is open.
func (c *compiler) declareVariable(name string) {
	depth := len(c.scopes)
	if depth == 0 {
		return
	}
	c.scopes[depth-1].variables[name] = true
}
// isLocalVariable reports whether name was declared in any open scope,
// searching from the innermost scope outward.
//
// The global scope (index 0) is searched as well. Top-level variables are
// emitted with OpSetGlobal without being declared, so they never match here.
func (c *compiler) isLocalVariable(name string) bool {
	for depth := len(c.scopes); depth > 0; depth-- {
		if _, declared := c.scopes[depth-1].variables[name]; declared {
			return true
		}
	}
	return false
}
// compileStatement emits bytecode for a single statement. A nil statement
// and statement types not listed below emit nothing.
//
//   - VariableStatement: evaluates the value, then stores it with
//     OpSetGlobal at top level (so values persist between REPL lines) or
//     OpSetLocal in a nested scope after declaring the name.
//   - IndexAssignmentStatement: pushes target, index and value, then emits
//     OpSetIndex. NOTE(review): the TableLiteral path pops after OpSetIndex
//     but this path does not. Confirm how many values the VM's OpSetIndex
//     leaves on the stack.
//   - EchoStatement: evaluates the value and emits OpEcho.
//   - BlockStatement: expected only for keyword blocks such as
//     if-then-else-end. Its statements are compiled in order without
//     opening a new scope.
//   - ExpressionStatement: evaluates the expression and discards the result.
func (c *compiler) compileStatement(stmt parser.Statement) {
	switch node := stmt.(type) {
	case nil:
		// Nothing to emit for a missing statement.

	case *parser.VariableStatement:
		c.compileExpression(node.Value)
		name := node.Name.Value
		slot := c.addConstant(name)

		var op types.Opcode = types.OpSetGlobal
		if len(c.scopes) > 1 {
			c.declareVariable(name)
			op = types.OpSetLocal
		}
		c.emit(op, slot)

	case *parser.IndexAssignmentStatement:
		c.compileExpression(node.Left)
		c.compileExpression(node.Index)
		c.compileExpression(node.Value)
		c.emit(types.OpSetIndex, 0)

	case *parser.EchoStatement:
		c.compileExpression(node.Value)
		c.emit(types.OpEcho, 0)

	case *parser.BlockStatement:
		for _, inner := range node.Statements {
			c.compileStatement(inner)
		}

	case *parser.ExpressionStatement:
		c.compileExpression(node.Expression)
		// The statement's value is unused, so drop it.
		c.emit(types.OpPop, 0)
	}
}
// compileExpression emits bytecode that evaluates expr and leaves exactly one
// result value on top of the VM stack.
//
// Callers depend on that invariant: ExpressionStatement pops the value,
// VariableStatement stores it, and the operator cases below consume their
// operands' values. To keep it:
//   - a nil expression compiles to a nil constant, and
//   - an expression type this compiler does not support panics. Previously
//     such a type emitted nothing, silently corrupting the stack layout the
//     caller expects.
//
// It also panics on an unknown infix or prefix operator.
func (c *compiler) compileExpression(expr parser.Expression) {
	switch e := expr.(type) {
	case nil:
		nullIndex := c.addConstant(nil)
		c.emit(types.OpConstant, nullIndex)

	case *parser.StringLiteral:
		constIndex := c.addConstant(e.Value)
		c.emit(types.OpConstant, constIndex)
	case *parser.NumberLiteral:
		constIndex := c.addConstant(e.Value)
		c.emit(types.OpConstant, constIndex)
	case *parser.BooleanLiteral:
		constIndex := c.addConstant(e.Value)
		c.emit(types.OpConstant, constIndex)

	case *parser.Identifier:
		nameIndex := c.addConstant(e.Value)
		// A name declared in an open compile-time scope is a local.
		// Anything else is looked up as a global.
		if c.isLocalVariable(e.Value) {
			c.emit(types.OpGetLocal, nameIndex)
		} else {
			c.emit(types.OpGetGlobal, nameIndex)
		}

	case *parser.TableLiteral:
		// NOTE(review): if e.Pairs is a map (the two-value range with an
		// expression key suggests so), Go randomizes its iteration order.
		// The emitted instruction and constant order then differs between
		// compilations of the same source. Consider an ordered pair list.
		c.emit(types.OpNewTable, 0)
		for key, value := range e.Pairs {
			// Duplicate the table so one copy survives the OpSetIndex/OpPop
			// pair below. This assumes OpSetIndex leaves one value on the
			// stack; TODO confirm against the VM.
			c.emit(types.OpDup, 0)
			if ident, ok := key.(*parser.Identifier); ok {
				// Identifier keys are literal string keys, not variable lookups.
				strIndex := c.addConstant(ident.Value)
				c.emit(types.OpConstant, strIndex)
			} else {
				c.compileExpression(key)
			}
			c.compileExpression(value)
			c.emit(types.OpSetIndex, 0)
			c.emit(types.OpPop, 0)
		}

	case *parser.IndexExpression:
		c.compileExpression(e.Left)
		c.compileExpression(e.Index)
		c.emit(types.OpGetIndex, 0)

	case *parser.InfixExpression:
		// The and/or sequences below assume OpJumpIfFalse consumes the value
		// it tests. That is why the left operand is duplicated first; verify
		// against the VM.
		switch e.Operator {
		case "and":
			// left; dup; jump-if-false end; pop; right; end:
			// A falsy left operand is the result. Otherwise the result is right.
			c.compileExpression(e.Left)
			c.emit(types.OpDup, 0)
			jumpFalsePos := len(c.instructions)
			c.emit(types.OpJumpIfFalse, 0) // backpatched to end
			c.emit(types.OpPop, 0)
			c.compileExpression(e.Right)
			c.instructions[jumpFalsePos].Operand = len(c.instructions)
		case "or":
			// left; dup; jump-if-false rhs; jump end; rhs: pop; right; end:
			// A truthy left operand is the result. Otherwise the result is right.
			c.compileExpression(e.Left)
			c.emit(types.OpDup, 0)
			falseJumpPos := len(c.instructions)
			c.emit(types.OpJumpIfFalse, 0) // backpatched to rhs
			trueJumpPos := len(c.instructions)
			c.emit(types.OpJump, 0) // backpatched to end
			c.instructions[falseJumpPos].Operand = len(c.instructions)
			c.emit(types.OpPop, 0)
			c.compileExpression(e.Right)
			c.instructions[trueJumpPos].Operand = len(c.instructions)
		default:
			c.compileExpression(e.Left)
			c.compileExpression(e.Right)
			switch e.Operator {
			case "+":
				c.emit(types.OpAdd, 0)
			case "-":
				c.emit(types.OpSubtract, 0)
			case "*":
				c.emit(types.OpMultiply, 0)
			case "/":
				c.emit(types.OpDivide, 0)
			case "==":
				c.emit(types.OpEqual, 0)
			case "!=":
				c.emit(types.OpNotEqual, 0)
			case "<":
				c.emit(types.OpLessThan, 0)
			case ">":
				c.emit(types.OpGreaterThan, 0)
			case "<=":
				c.emit(types.OpLessEqual, 0)
			case ">=":
				c.emit(types.OpGreaterEqual, 0)
			default:
				panic(fmt.Sprintf("Unknown infix operator: %s", e.Operator))
			}
		}

	case *parser.PrefixExpression:
		c.compileExpression(e.Right)
		switch e.Operator {
		case "-":
			c.emit(types.OpNegate, 0)
		case "not":
			c.emit(types.OpNot, 0)
		default:
			panic(fmt.Sprintf("Unknown prefix operator: %s", e.Operator))
		}

	case *parser.GroupedExpression:
		c.compileExpression(e.Expr)

	case *parser.IfExpression:
		// compileBranch compiles one arm of the conditional so that it leaves
		// exactly one value on the stack. That value is the trailing
		// expression statement's value, or nil when the arm is absent, empty,
		// or ends in another kind of statement.
		compileBranch := func(block *parser.BlockStatement) {
			if block == nil || len(block.Statements) == 0 {
				nullIndex := c.addConstant(nil)
				c.emit(types.OpConstant, nullIndex)
				return
			}
			last := len(block.Statements) - 1
			for _, stmt := range block.Statements[:last] {
				c.compileStatement(stmt)
			}
			if exprStmt, ok := block.Statements[last].(*parser.ExpressionStatement); ok {
				c.compileExpression(exprStmt.Expression)
				return
			}
			c.compileStatement(block.Statements[last])
			nullIndex := c.addConstant(nil)
			c.emit(types.OpConstant, nullIndex)
		}

		// cond; jump-if-false else; then-arm; jump end; else: else-arm; end:
		c.compileExpression(e.Condition)
		jumpIfFalsePos := len(c.instructions)
		c.emit(types.OpJumpIfFalse, 0) // backpatched to the else arm

		compileBranch(e.Consequence)

		jumpToEndPos := len(c.instructions)
		c.emit(types.OpJump, 0) // backpatched to just past the else arm

		c.instructions[jumpIfFalsePos].Operand = len(c.instructions)
		compileBranch(e.Alternative)
		c.instructions[jumpToEndPos].Operand = len(c.instructions)

	default:
		panic(fmt.Sprintf("Unknown expression type: %T", expr))
	}
}
// addConstant appends value to the constant pool and returns its index.
// Values are not deduplicated; each call adds a new entry.
func (c *compiler) addConstant(value any) int {
	index := len(c.constants)
	c.constants = append(c.constants, value)
	return index
}
// emit appends one instruction with the given opcode and operand to the
// instruction stream. Its position is the stream length before the call,
// which is what the jump-backpatching code records.
func (c *compiler) emit(op types.Opcode, operand int) {
	c.instructions = append(c.instructions, types.Instruction{Opcode: op, Operand: operand})
}