// Mako/vm/vm.go
// 2025-05-06 23:20:10 -05:00
//
// 605 lines
// 15 KiB
// Go

package vm
import (
	"fmt"
	"strings"

	"git.sharkk.net/Sharkk/Mako/types"
)
// Scope is one level of lexically scoped variables. OpEnterScope pushes a new
// Scope and OpExitScope pops it; variables are stored by name.
type Scope struct {
	Variables map[string]types.Value
}
// Frame describes one active function call on the VM's call stack.
type Frame struct {
	// Function is the callee whose instructions are being executed.
	Function *types.Function

	// IP is the callee's instruction pointer (initialised to 0 per call),
	// BasePointer is the stack index holding the callee value itself (its
	// arguments sit directly above it), and ReturnAddr is the caller's
	// instruction index at the point of the call.
	IP, BasePointer, ReturnAddr int

	// Upvalues holds the captured variables, one slot per entry in
	// Function.UpvalueIndexes.
	Upvalues []*types.Upvalue
}
// VM is a stack-based bytecode interpreter. Construct one with New.
type VM struct {
	constants []any                  // constant pool of the code currently executing
	globals   map[string]types.Value // global variables, keyed by name
	scopes    []Scope                // local scopes, innermost last

	stack []types.Value // operand stack; only stack[:sp] holds live values
	sp    int           // index of the next free stack slot

	frames []Frame // call frames; frames[fp] is the active call
	fp     int     // index of the active frame, -1 when no call is active

	upvalues []*types.Upvalue // upvalues for closures
}
// New returns a VM ready to execute bytecode: no globals, no local scopes,
// a 1024-slot operand stack, a 64-entry frame stack and no active frame.
// The initial state is exactly the one Reset establishes.
func New() *VM {
	machine := &VM{}
	return machine.Reset()
}
// Reset discards all interpreter state (constants, globals, scopes, stack
// contents, call frames and upvalues) and reinitialises vm in place to the
// same state New produces. It returns vm so the call can be chained.
func (vm *VM) Reset() *VM {
	// Overwriting the whole struct zeroes constants and sp as a side effect.
	*vm = VM{
		globals:  make(map[string]types.Value),
		scopes:   []Scope{},
		stack:    make([]types.Value, 1024),
		frames:   make([]Frame, 64),
		fp:       -1,
		upvalues: []*types.Upvalue{},
	}
	return vm
}
// GetGlobal looks up the global variable called name. The boolean result
// reports whether it is defined; when it is not, the returned Value is the
// zero Value.
func (vm *VM) GetGlobal(name string) (types.Value, bool) {
	if val, exists := vm.globals[name]; exists {
		return val, true
	}
	var missing types.Value
	return missing, false
}
// Globals returns every global variable, keyed by name, for testing purposes.
//
// The VM's own map is returned, not a copy: it reflects later assignments
// made by running code, and writes by the caller change the VM's globals.
func (vm *VM) Globals() map[string]types.Value {
	return vm.globals
}
// CurrentStack exposes the live operand-stack values, bottom first, for
// tests. The slice shares storage with the VM, so later pushes may overwrite
// its elements.
func (vm *VM) CurrentStack() []types.Value {
	live := vm.stack[0:vm.sp]
	return live
}
// Run executes a compiled program. The program's constant pool becomes the
// VM's active pool, and its top-level instructions are interpreted from the
// first one. The value produced by the top-level code is not returned; a
// top-level OpReturn leaves its value on the operand stack instead.
func (vm *VM) Run(bytecode *types.Bytecode) {
	vm.constants = bytecode.Constants
	program := bytecode.Instructions
	_ = vm.runCode(program, 0)
}
// runCode interprets instructions from the first one until the slice is
// exhausted or an OpReturn executes.
//
// Function calls re-enter runCode recursively (see callFunction), with
// vm.constants switched to the callee's constant pool for the duration of the
// call. basePointer is the stack index of the callee's function slot. It is
// not read here; OpReturn takes the base from vm.frames[vm.fp] instead.
//
// Runtime errors (type mismatches, bad calls, division by zero, ...) are
// printed and replaced by a placeholder result (null, or false for
// comparisons) so that execution continues.
//
// It returns the operand of OpReturn, or null when execution runs off the end
// of instructions without an explicit return.
func (vm *VM) runCode(instructions []types.Instruction, basePointer int) types.Value {
	for ip := 0; ip < len(instructions); ip++ {
		instruction := instructions[ip]

		switch instruction.Opcode {
		// Constants and variables
		case types.OpConstant:
			switch v := vm.constants[instruction.Operand].(type) {
			case string:
				vm.push(types.NewString(v))
			case float64:
				vm.push(types.NewNumber(v))
			case bool:
				vm.push(types.NewBoolean(v))
			case nil:
				vm.push(types.NewNull())
			case *types.Function:
				vm.push(types.NewFunctionValue(v))
			default:
				// Every OpConstant must push exactly one value; pushing
				// nothing would leave the stack unbalanced for later ops.
				fmt.Printf("Error: unsupported constant type %T\n", v)
				vm.push(types.NewNull())
			}

		case types.OpSetLocal:
			name := vm.constants[instruction.Operand].(string)
			value := vm.pop()
			if len(vm.scopes) > 0 {
				// Assign in the innermost scope.
				vm.scopes[len(vm.scopes)-1].Variables[name] = value
			} else {
				// Outside any scope, a "local" is a global.
				vm.globals[name] = value
			}

		case types.OpGetLocal:
			name := vm.constants[instruction.Operand].(string)
			vm.push(vm.lookupVariable(name))

		case types.OpSetGlobal:
			name := vm.constants[instruction.Operand].(string)
			value := vm.pop()
			vm.globals[name] = value

		case types.OpGetGlobal:
			name := vm.constants[instruction.Operand].(string)
			if val, ok := vm.globals[name]; ok {
				vm.push(val)
			} else {
				vm.push(types.NewNull())
			}

		case types.OpEnterScope:
			vm.scopes = append(vm.scopes, Scope{
				Variables: make(map[string]types.Value),
			})

		case types.OpExitScope:
			if len(vm.scopes) > 0 {
				vm.scopes = vm.scopes[:len(vm.scopes)-1]
			}

		// Tables
		case types.OpNewTable:
			vm.push(types.NewTableValue())

		case types.OpSetIndex:
			value := vm.pop()
			key := vm.pop()
			tableVal := vm.pop()
			if tableVal.Type != types.TypeTable {
				fmt.Println("Error: attempt to index non-table value")
				vm.push(types.NewNull())
				continue
			}
			table := tableVal.Data.(*types.Table)
			table.Set(key, value)
			// The table stays on the stack so assignments can be chained.
			vm.push(tableVal)

		case types.OpGetIndex:
			key := vm.pop()
			tableVal := vm.pop()
			if tableVal.Type != types.TypeTable {
				fmt.Println("Error: attempt to index non-table value")
				vm.push(types.NewNull())
				continue
			}
			table := tableVal.Data.(*types.Table)
			vm.push(table.Get(key))

		// Stack manipulation and output
		case types.OpDup:
			if vm.sp > 0 {
				vm.push(vm.stack[vm.sp-1])
			}

		case types.OpPop:
			vm.pop()

		case types.OpEcho:
			value := vm.pop()
			switch value.Type {
			case types.TypeString:
				fmt.Println(value.Data.(string))
			case types.TypeNumber:
				fmt.Println(value.Data.(float64))
			case types.TypeBoolean:
				fmt.Println(value.Data.(bool))
			case types.TypeNull:
				fmt.Println("nil")
			case types.TypeTable:
				fmt.Println(vm.formatTable(value.Data.(*types.Table)))
			case types.TypeFunction:
				fmt.Println("<function>")
			}

		// Jump instructions
		case types.OpJumpIfFalse:
			if isFalsy(vm.pop()) {
				ip = instruction.Operand - 1 // -1 because the loop increments ip
			}

		case types.OpJump:
			ip = instruction.Operand - 1 // -1 because the loop increments ip

		// Function instructions
		case types.OpFunction:
			function := vm.constants[instruction.Operand].(*types.Function)
			vm.push(types.NewFunctionValue(function))

		case types.OpCall:
			vm.callFunction(instruction.Operand, ip)

		case types.OpReturn:
			returnValue := vm.pop()
			if vm.fp >= 0 {
				// Inside a call: drop everything the callee left on the
				// stack so the return value lands in the callee's slot.
				vm.sp = vm.frames[vm.fp].BasePointer
			}
			vm.push(returnValue)
			return returnValue

		// Arithmetic operations
		case types.OpAdd:
			right := vm.pop()
			left := vm.pop()
			if left.Type == types.TypeNumber && right.Type == types.TypeNumber {
				vm.push(types.NewNumber(left.Data.(float64) + right.Data.(float64)))
			} else if left.Type == types.TypeString && right.Type == types.TypeString {
				// String concatenation
				vm.push(types.NewString(left.Data.(string) + right.Data.(string)))
			} else {
				fmt.Println("Error: cannot add values of different types")
				vm.push(types.NewNull())
			}

		case types.OpSubtract:
			right := vm.pop()
			left := vm.pop()
			if left.Type == types.TypeNumber && right.Type == types.TypeNumber {
				vm.push(types.NewNumber(left.Data.(float64) - right.Data.(float64)))
			} else {
				fmt.Println("Error: cannot subtract non-number values")
				vm.push(types.NewNull())
			}

		case types.OpMultiply:
			right := vm.pop()
			left := vm.pop()
			if left.Type == types.TypeNumber && right.Type == types.TypeNumber {
				vm.push(types.NewNumber(left.Data.(float64) * right.Data.(float64)))
			} else {
				fmt.Println("Error: cannot multiply non-number values")
				vm.push(types.NewNull())
			}

		case types.OpDivide:
			right := vm.pop()
			left := vm.pop()
			switch {
			case left.Type != types.TypeNumber || right.Type != types.TypeNumber:
				fmt.Println("Error: cannot divide non-number values")
				vm.push(types.NewNull())
			case right.Data.(float64) == 0:
				fmt.Println("Error: division by zero")
				vm.push(types.NewNull())
			default:
				vm.push(types.NewNumber(left.Data.(float64) / right.Data.(float64)))
			}

		case types.OpNegate:
			operand := vm.pop()
			if operand.Type == types.TypeNumber {
				vm.push(types.NewNumber(-operand.Data.(float64)))
			} else {
				fmt.Println("Error: cannot negate non-number value")
				vm.push(types.NewNull())
			}

		// Comparison and logic
		case types.OpEqual:
			if vm.sp < 2 {
				fmt.Println("Error: not enough operands for equality comparison")
				vm.push(types.NewBoolean(false))
				continue
			}
			right := vm.pop()
			left := vm.pop()
			vm.push(types.NewBoolean(left.Equal(right)))

		case types.OpNotEqual:
			if vm.sp < 2 {
				fmt.Println("Error: not enough operands for inequality comparison")
				vm.push(types.NewBoolean(true))
				continue
			}
			right := vm.pop()
			left := vm.pop()
			vm.push(types.NewBoolean(!left.Equal(right)))

		case types.OpLessThan:
			vm.compareNumbers("less-than", "<", func(a, b float64) bool { return a < b })

		case types.OpGreaterThan:
			vm.compareNumbers("greater-than", ">", func(a, b float64) bool { return a > b })

		case types.OpLessEqual:
			vm.compareNumbers("less-equal", "<=", func(a, b float64) bool { return a <= b })

		case types.OpGreaterEqual:
			vm.compareNumbers("greater-equal", ">=", func(a, b float64) bool { return a >= b })

		case types.OpNot:
			vm.push(types.NewBoolean(isFalsy(vm.pop())))
		}
	}

	// No explicit return was executed.
	return types.NewNull()
}

// isFalsy reports whether v counts as false in a condition. The boolean
// false, null and the number 0 are falsy; every other value is truthy.
func isFalsy(v types.Value) bool {
	switch v.Type {
	case types.TypeBoolean:
		return !v.Data.(bool)
	case types.TypeNull:
		return true
	case types.TypeNumber:
		return v.Data.(float64) == 0
	default:
		return false
	}
}

// lookupVariable resolves name against the local scopes from innermost to
// outermost, then against the globals. Undefined names resolve to null.
func (vm *VM) lookupVariable(name string) types.Value {
	for i := len(vm.scopes) - 1; i >= 0; i-- {
		if val, ok := vm.scopes[i].Variables[name]; ok {
			return val
		}
	}
	if val, ok := vm.globals[name]; ok {
		return val
	}
	return types.NewNull()
}

// compareNumbers implements the ordered comparison opcodes. It pops the right
// and then the left operand and pushes cmp(left, right). If fewer than two
// operands are available, or either one is not a number, it prints an error
// mentioning opName or symbol and pushes false.
func (vm *VM) compareNumbers(opName, symbol string, cmp func(left, right float64) bool) {
	if vm.sp < 2 {
		fmt.Println("Error: not enough operands for " + opName + " comparison")
		vm.push(types.NewBoolean(false))
		return
	}
	right := vm.pop()
	left := vm.pop()
	if left.Type != types.TypeNumber || right.Type != types.TypeNumber {
		fmt.Println("Error: cannot compare non-number values with " + symbol)
		vm.push(types.NewBoolean(false))
		return
	}
	vm.push(types.NewBoolean(cmp(left.Data.(float64), right.Data.(float64))))
}

// callFunction implements OpCall. The callee value sits on the stack directly
// below its numArgs arguments. When the call finishes, the callee and its
// arguments are replaced by the single result value. returnAddr is the
// caller's instruction index and is recorded in the new Frame.
//
// If the callee is not a valid function, or the argument count does not match
// its NumParams, an error is printed and the callee and arguments are
// replaced by null. OpCall therefore always leaves exactly one value behind.
func (vm *VM) callFunction(numArgs int, returnAddr int) {
	if vm.sp <= numArgs {
		// There is no callee slot to discard; report and push a placeholder.
		fmt.Println("Error: stack underflow during function call")
		vm.push(types.NewNull())
		return
	}

	calleeSlot := vm.sp - numArgs - 1

	// fail discards the callee and its arguments and leaves null in their place.
	fail := func() {
		vm.sp = calleeSlot
		vm.push(types.NewNull())
	}

	fnVal := vm.stack[calleeSlot]
	if fnVal.Type != types.TypeFunction {
		fmt.Println("Error: attempt to call non-function value")
		fail()
		return
	}
	function, ok := fnVal.Data.(*types.Function)
	if !ok {
		fmt.Println("Error: function data is invalid")
		fail()
		return
	}
	if numArgs != function.NumParams {
		fmt.Printf("Error: function expects %d arguments, got %d\n",
			function.NumParams, numArgs)
		fail()
		return
	}

	// Push a new call frame, growing the frame stack if it is full.
	vm.fp++
	if vm.fp >= len(vm.frames) {
		grown := make([]Frame, len(vm.frames)*2)
		copy(grown, vm.frames)
		vm.frames = grown
	}
	vm.frames[vm.fp] = Frame{
		Function:    function,
		IP:          0,
		BasePointer: calleeSlot,
		ReturnAddr:  returnAddr,
		Upvalues:    make([]*types.Upvalue, len(function.UpvalueIndexes)),
	}

	callerConstants := vm.constants
	callerScopeDepth := len(vm.scopes)

	vm.constants = function.Constants
	result := vm.runCode(function.Instructions, calleeSlot)

	vm.constants = callerConstants
	// An OpReturn inside a scoped block skips that block's OpExitScope. Drop
	// any scopes the callee left open so its locals cannot shadow the
	// caller's variables.
	if len(vm.scopes) > callerScopeDepth {
		vm.scopes = vm.scopes[:callerScopeDepth]
	}
	vm.fp--

	// Replace the callee and its arguments with the result.
	vm.stack[calleeSlot] = result
	vm.sp = calleeSlot + 1
}
// push places value on top of the operand stack. When the stack is full,
// its length is doubled first, keeping the existing contents.
func (vm *VM) push(value types.Value) {
	if vm.sp >= len(vm.stack) {
		vm.stack = append(vm.stack, make([]types.Value, len(vm.stack))...)
	}
	vm.stack[vm.sp] = value
	vm.sp++
}
// pop removes and returns the top value of the operand stack. Popping an
// empty stack prints "Stack underflow error" and yields null instead of
// panicking.
func (vm *VM) pop() types.Value {
	if vm.sp > 0 {
		vm.sp--
		return vm.stack[vm.sp]
	}
	fmt.Println("Stack underflow error")
	return types.NewNull()
}
// formatTable renders table as "{k1 = v1, k2 = v2}", in the order of
// table.Entries. Keys and values are rendered with formatValue, so strings
// are quoted and nested tables are rendered recursively.
//
// NOTE(review): a table that contains itself would recurse without bound.
func (vm *VM) formatTable(table *types.Table) string {
	// A Builder avoids the quadratic copying of repeated string "+=".
	var b strings.Builder
	b.WriteByte('{')
	for i, entry := range table.Entries {
		if i > 0 {
			b.WriteString(", ")
		}
		b.WriteString(vm.formatValue(entry.Key))
		b.WriteString(" = ")
		b.WriteString(vm.formatValue(entry.Value))
	}
	b.WriteByte('}')
	return b.String()
}
// formatValue renders a single value as it appears inside a formatted table.
// Strings are wrapped in double quotes (contents are not escaped), numbers
// and booleans use Go's default formatting, tables are rendered recursively
// via formatTable, and unrecognised types yield "unknown".
func (vm *VM) formatValue(value types.Value) string {
	var text string
	switch value.Type {
	case types.TypeString:
		text = `"` + value.Data.(string) + `"`
	case types.TypeNumber:
		text = fmt.Sprint(value.Data.(float64))
	case types.TypeBoolean:
		text = fmt.Sprint(value.Data.(bool))
	case types.TypeNull:
		text = "nil"
	case types.TypeTable:
		text = vm.formatTable(value.Data.(*types.Table))
	case types.TypeFunction:
		text = "<function>"
	default:
		text = "unknown"
	}
	return text
}