compiler 1
This commit is contained in:
parent
8b4496b363
commit
0c4acd5f33
21
compiler/compile.go
Normal file
21
compiler/compile.go
Normal file
@ -0,0 +1,21 @@
|
||||
package compiler
|
||||
|
||||
import (
|
||||
"git.sharkk.net/Sharkk/Mako/parser"
|
||||
"git.sharkk.net/Sharkk/Mako/types"
|
||||
)
|
||||
|
||||
// CompileSource compiles source code to bytecode
|
||||
func CompileSource(source string) (*types.Function, []*types.MakoError) {
|
||||
// Parse the source code into an AST
|
||||
p := parser.New(source)
|
||||
statements := p.Parse()
|
||||
|
||||
// Create a compiler for the main (global) scope
|
||||
compiler := New("script", nil)
|
||||
|
||||
// Compile the statements to bytecode
|
||||
function, errors := compiler.Compile(statements)
|
||||
|
||||
return function, errors
|
||||
}
|
519
compiler/compiler.go
Normal file
519
compiler/compiler.go
Normal file
@ -0,0 +1,519 @@
|
||||
package compiler
|
||||
|
||||
import (
|
||||
"git.sharkk.net/Sharkk/Mako/types"
|
||||
)
|
||||
|
||||
// Compiler manages the state for compilation
|
||||
type Compiler struct {
|
||||
chunk *types.Chunk
|
||||
locals []local
|
||||
scopeDepth int
|
||||
enclosing *Compiler
|
||||
upvalues []types.Upvalue
|
||||
errors []*types.MakoError
|
||||
currentFunction *types.Function
|
||||
}
|
||||
|
||||
type local struct {
|
||||
name string
|
||||
depth int
|
||||
isCaptured bool
|
||||
}
|
||||
|
||||
// New creates a new compiler for a function
|
||||
func New(name string, enclosing *Compiler) *Compiler {
|
||||
compiler := &Compiler{
|
||||
chunk: &types.Chunk{
|
||||
Code: make([]types.Instruction, 0, 8),
|
||||
Constants: make([]types.Value, 0, 8),
|
||||
},
|
||||
locals: make([]local, 0, 8),
|
||||
scopeDepth: 0,
|
||||
enclosing: enclosing,
|
||||
currentFunction: &types.Function{
|
||||
Name: name,
|
||||
Chunk: nil,
|
||||
Upvalues: nil,
|
||||
LocalCount: 0,
|
||||
},
|
||||
}
|
||||
|
||||
// The first local slot is implicitly used by the function itself
|
||||
compiler.locals = append(compiler.locals, local{
|
||||
name: name,
|
||||
depth: 0,
|
||||
})
|
||||
|
||||
compiler.currentFunction.Chunk = compiler.chunk
|
||||
return compiler
|
||||
}
|
||||
|
||||
// Compile compiles statements into bytecode
|
||||
func (c *Compiler) Compile(statements []types.Statement) (*types.Function, []*types.MakoError) {
|
||||
for _, stmt := range statements {
|
||||
c.statement(stmt)
|
||||
}
|
||||
|
||||
// Implicit return
|
||||
c.emitReturn()
|
||||
|
||||
c.currentFunction.LocalCount = len(c.locals)
|
||||
c.currentFunction.Upvalues = c.upvalues
|
||||
c.currentFunction.UpvalueCount = len(c.upvalues)
|
||||
|
||||
return c.currentFunction, c.errors
|
||||
}
|
||||
|
||||
// statement compiles a statement
|
||||
func (c *Compiler) statement(stmt types.Statement) {
|
||||
switch s := stmt.(type) {
|
||||
case types.ExpressionStmt:
|
||||
c.expression(s.Expression)
|
||||
c.emit(types.OP_POP, nil)
|
||||
case types.AssignStmt:
|
||||
c.assignment(s)
|
||||
case types.FunctionStmt:
|
||||
c.function(s)
|
||||
case types.ReturnStmt:
|
||||
c.returnStmt(s)
|
||||
case types.IfStmt:
|
||||
c.ifStatement(s)
|
||||
case types.EchoStmt:
|
||||
c.expression(s.Value)
|
||||
c.emit(types.OP_PRINT, nil)
|
||||
case types.BlockStmt:
|
||||
c.beginScope()
|
||||
for _, blockStmt := range s.Statements {
|
||||
c.statement(blockStmt)
|
||||
}
|
||||
c.endScope()
|
||||
}
|
||||
}
|
||||
|
||||
// assignment compiles a variable assignment
|
||||
func (c *Compiler) assignment(stmt types.AssignStmt) {
|
||||
c.expression(stmt.Value)
|
||||
|
||||
if c.scopeDepth > 0 {
|
||||
// Try to find it as a local first
|
||||
for i := len(c.locals) - 1; i >= 0; i-- {
|
||||
if c.locals[i].name == stmt.Name.Lexeme && c.locals[i].depth <= c.scopeDepth {
|
||||
c.emit(types.OP_SET_LOCAL, []byte{byte(i)})
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Global variable
|
||||
idx := c.makeConstant(types.StringValue{Value: stmt.Name.Lexeme})
|
||||
c.emit(types.OP_SET_GLOBAL, []byte{idx})
|
||||
}
|
||||
|
||||
// expression compiles an expression
|
||||
func (c *Compiler) expression(expr types.Expression) {
|
||||
switch e := expr.(type) {
|
||||
case types.LiteralExpr:
|
||||
c.literal(e)
|
||||
case types.BinaryExpr:
|
||||
c.binary(e)
|
||||
case types.UnaryExpr:
|
||||
c.unary(e)
|
||||
case types.VariableExpr:
|
||||
c.variable(e)
|
||||
case types.CallExpr:
|
||||
c.call(e)
|
||||
case types.FunctionExpr:
|
||||
c.functionExpr(e)
|
||||
}
|
||||
}
|
||||
|
||||
// literal compiles a literal value
|
||||
func (c *Compiler) literal(expr types.LiteralExpr) {
|
||||
switch v := expr.Value.(type) {
|
||||
case nil:
|
||||
c.emit(types.OP_NIL, nil)
|
||||
case bool:
|
||||
if v {
|
||||
c.emit(types.OP_TRUE, nil)
|
||||
} else {
|
||||
c.emit(types.OP_FALSE, nil)
|
||||
}
|
||||
case float64:
|
||||
idx := c.makeConstant(types.NumberValue{Value: v})
|
||||
c.emit(types.OP_CONSTANT, []byte{idx})
|
||||
case string:
|
||||
idx := c.makeConstant(types.StringValue{Value: v})
|
||||
c.emit(types.OP_CONSTANT, []byte{idx})
|
||||
}
|
||||
}
|
||||
|
||||
// binary compiles a binary expression
|
||||
func (c *Compiler) binary(expr types.BinaryExpr) {
|
||||
c.expression(expr.Left)
|
||||
c.expression(expr.Right)
|
||||
|
||||
switch expr.Operator.Type {
|
||||
case types.PLUS:
|
||||
c.emit(types.OP_ADD, nil)
|
||||
case types.MINUS:
|
||||
c.emit(types.OP_SUBTRACT, nil)
|
||||
case types.STAR:
|
||||
c.emit(types.OP_MULTIPLY, nil)
|
||||
case types.SLASH:
|
||||
c.emit(types.OP_DIVIDE, nil)
|
||||
case types.EQUAL_EQUAL:
|
||||
c.emit(types.OP_EQUAL, nil)
|
||||
case types.BANG_EQUAL:
|
||||
c.emit(types.OP_EQUAL, nil)
|
||||
c.emit(types.OP_NOT, nil)
|
||||
case types.LESS:
|
||||
c.emit(types.OP_LESS, nil)
|
||||
case types.LESS_EQUAL:
|
||||
c.emit(types.OP_GREATER, nil)
|
||||
c.emit(types.OP_NOT, nil)
|
||||
case types.GREATER:
|
||||
c.emit(types.OP_GREATER, nil)
|
||||
case types.GREATER_EQUAL:
|
||||
c.emit(types.OP_LESS, nil)
|
||||
c.emit(types.OP_NOT, nil)
|
||||
case types.AND:
|
||||
// Short-circuit evaluation
|
||||
endJump := c.emitJump(types.OP_JUMP_IF_FALSE)
|
||||
c.emit(types.OP_POP, nil)
|
||||
c.patchJump(endJump)
|
||||
case types.OR:
|
||||
// Short-circuit evaluation
|
||||
skipJump := c.emitJump(types.OP_JUMP_IF_FALSE)
|
||||
endJump := c.emitJump(types.OP_JUMP)
|
||||
c.patchJump(skipJump)
|
||||
c.emit(types.OP_POP, nil)
|
||||
c.patchJump(endJump)
|
||||
}
|
||||
}
|
||||
|
||||
// unary compiles a unary expression
|
||||
func (c *Compiler) unary(expr types.UnaryExpr) {
|
||||
c.expression(expr.Right)
|
||||
|
||||
switch expr.Operator.Type {
|
||||
case types.MINUS:
|
||||
c.emit(types.OP_NEGATE, nil)
|
||||
}
|
||||
}
|
||||
|
||||
// variable compiles a variable reference
|
||||
func (c *Compiler) variable(expr types.VariableExpr) {
|
||||
// Try to resolve as local
|
||||
for i := len(c.locals) - 1; i >= 0; i-- {
|
||||
if c.locals[i].name == expr.Name.Lexeme && c.locals[i].depth <= c.scopeDepth {
|
||||
c.emit(types.OP_GET_LOCAL, []byte{byte(i)})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Try to resolve as upvalue
|
||||
if index, ok := c.resolveUpvalue(expr.Name.Lexeme); ok {
|
||||
c.emit(types.OP_GET_UPVALUE, []byte{byte(index)})
|
||||
return
|
||||
}
|
||||
|
||||
// Global variable
|
||||
idx := c.makeConstant(types.StringValue{Value: expr.Name.Lexeme})
|
||||
c.emit(types.OP_GET_GLOBAL, []byte{idx})
|
||||
}
|
||||
|
||||
// call compiles a function call
|
||||
func (c *Compiler) call(expr types.CallExpr) {
|
||||
c.expression(expr.Callee)
|
||||
|
||||
argCount := byte(len(expr.Arguments))
|
||||
for _, arg := range expr.Arguments {
|
||||
c.expression(arg)
|
||||
}
|
||||
|
||||
c.emit(types.OP_CALL, []byte{argCount})
|
||||
}
|
||||
|
||||
// function compiles a function declaration
|
||||
func (c *Compiler) function(stmt types.FunctionStmt) {
|
||||
// Add function name to current scope
|
||||
var global byte
|
||||
if c.scopeDepth > 0 {
|
||||
c.addLocal(stmt.Name.Lexeme)
|
||||
} else {
|
||||
global = c.makeConstant(types.StringValue{Value: stmt.Name.Lexeme})
|
||||
}
|
||||
|
||||
// Compile function body with new compiler
|
||||
compiler := New(stmt.Name.Lexeme, c)
|
||||
|
||||
// Add parameters
|
||||
compiler.beginScope()
|
||||
for _, param := range stmt.Params {
|
||||
compiler.addLocal(param.Lexeme)
|
||||
compiler.currentFunction.Arity++
|
||||
}
|
||||
compiler.currentFunction.IsVariadic = stmt.IsVariadic
|
||||
|
||||
// Compile function body
|
||||
for _, bodyStmt := range stmt.Body {
|
||||
compiler.statement(bodyStmt)
|
||||
}
|
||||
|
||||
// Implicit return if needed
|
||||
compiler.emitReturn()
|
||||
|
||||
// Create function object
|
||||
compiler.currentFunction.UpvalueCount = len(compiler.upvalues)
|
||||
compiler.currentFunction.Upvalues = compiler.upvalues
|
||||
|
||||
// Add function to constants
|
||||
idx := c.makeConstant(types.ClosureValue{
|
||||
Closure: &types.Closure{
|
||||
Function: compiler.currentFunction,
|
||||
Upvalues: make([]*types.Upvalue, compiler.currentFunction.UpvalueCount),
|
||||
},
|
||||
})
|
||||
|
||||
// Emit closure instruction and upvalue info
|
||||
c.emit(types.OP_CLOSURE, []byte{idx})
|
||||
|
||||
// Add upvalue information to instruction stream
|
||||
for _, upvalue := range compiler.upvalues {
|
||||
if upvalue.IsLocal {
|
||||
c.emit(0, []byte{1, upvalue.Index}) // 1 means isLocal=true
|
||||
} else {
|
||||
c.emit(0, []byte{0, upvalue.Index}) // 0 means isLocal=false
|
||||
}
|
||||
}
|
||||
|
||||
// Store function in variable
|
||||
if c.scopeDepth > 0 {
|
||||
// It's already on the stack
|
||||
} else {
|
||||
c.emit(types.OP_SET_GLOBAL, []byte{global})
|
||||
}
|
||||
}
|
||||
|
||||
// functionExpr compiles an anonymous function expression
|
||||
func (c *Compiler) functionExpr(expr types.FunctionExpr) {
|
||||
// Compile function body with new compiler
|
||||
compiler := New("", c)
|
||||
|
||||
// Add parameters
|
||||
compiler.beginScope()
|
||||
for _, param := range expr.Params {
|
||||
compiler.addLocal(param.Lexeme)
|
||||
compiler.currentFunction.Arity++
|
||||
}
|
||||
compiler.currentFunction.IsVariadic = expr.IsVariadic
|
||||
|
||||
// Compile function body
|
||||
for _, bodyStmt := range expr.Body {
|
||||
compiler.statement(bodyStmt)
|
||||
}
|
||||
|
||||
// Implicit return
|
||||
compiler.emitReturn()
|
||||
|
||||
// Create function object
|
||||
compiler.currentFunction.UpvalueCount = len(compiler.upvalues)
|
||||
compiler.currentFunction.Upvalues = compiler.upvalues
|
||||
|
||||
// Add function to constants
|
||||
idx := c.makeConstant(types.ClosureValue{
|
||||
Closure: &types.Closure{
|
||||
Function: compiler.currentFunction,
|
||||
Upvalues: make([]*types.Upvalue, compiler.currentFunction.UpvalueCount),
|
||||
},
|
||||
})
|
||||
|
||||
// Emit closure instruction and upvalue info
|
||||
c.emit(types.OP_CLOSURE, []byte{idx})
|
||||
|
||||
// Add upvalue information to instruction stream
|
||||
for _, upvalue := range compiler.upvalues {
|
||||
if upvalue.IsLocal {
|
||||
c.emit(0, []byte{1, upvalue.Index}) // 1 means isLocal=true
|
||||
} else {
|
||||
c.emit(0, []byte{0, upvalue.Index}) // 0 means isLocal=false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// returnStmt compiles a return statement
|
||||
func (c *Compiler) returnStmt(stmt types.ReturnStmt) {
|
||||
if stmt.Value == nil {
|
||||
c.emit(types.OP_NIL, nil)
|
||||
} else {
|
||||
c.expression(stmt.Value)
|
||||
}
|
||||
|
||||
c.emit(types.OP_RETURN, nil)
|
||||
}
|
||||
|
||||
// ifStatement compiles an if statement
|
||||
func (c *Compiler) ifStatement(stmt types.IfStmt) {
|
||||
// Compile condition
|
||||
c.expression(stmt.Condition)
|
||||
|
||||
// Emit the then branch jump
|
||||
thenJump := c.emitJump(types.OP_JUMP_IF_FALSE)
|
||||
|
||||
// Compile then branch
|
||||
c.emit(types.OP_POP, nil) // Pop condition
|
||||
for _, thenStmt := range stmt.ThenBranch {
|
||||
c.statement(thenStmt)
|
||||
}
|
||||
|
||||
// Jump over else branch
|
||||
elseJump := c.emitJump(types.OP_JUMP)
|
||||
|
||||
// Patch then jump
|
||||
c.patchJump(thenJump)
|
||||
c.emit(types.OP_POP, nil) // Pop condition
|
||||
|
||||
// Compile elseif branches
|
||||
for _, elseif := range stmt.ElseIfs {
|
||||
c.expression(elseif.Condition)
|
||||
|
||||
// Jump if this condition is false
|
||||
elseifJump := c.emitJump(types.OP_JUMP_IF_FALSE)
|
||||
|
||||
c.emit(types.OP_POP, nil) // Pop condition
|
||||
for _, elseifStmt := range elseif.Body {
|
||||
c.statement(elseifStmt)
|
||||
}
|
||||
|
||||
// Jump to end after this branch
|
||||
endJump := c.emitJump(types.OP_JUMP)
|
||||
|
||||
// Patch elseif jump to next branch
|
||||
c.patchJump(elseifJump)
|
||||
c.emit(types.OP_POP, nil) // Pop condition
|
||||
|
||||
// Collect end jumps for patching
|
||||
elseJump = endJump
|
||||
}
|
||||
|
||||
// Compile else branch
|
||||
for _, elseStmt := range stmt.ElseBranch {
|
||||
c.statement(elseStmt)
|
||||
}
|
||||
|
||||
// Patch else jump
|
||||
c.patchJump(elseJump)
|
||||
}
|
||||
|
||||
// Helper methods
|
||||
func (c *Compiler) emit(op types.OpCode, operands []byte) {
|
||||
// Get source position from the current token
|
||||
pos := types.SourcePos{Line: 0, Column: 0} // In real implementation, track from token
|
||||
|
||||
instruction := types.Instruction{
|
||||
Op: op,
|
||||
Operands: operands,
|
||||
Pos: pos,
|
||||
}
|
||||
|
||||
c.chunk.Code = append(c.chunk.Code, instruction)
|
||||
}
|
||||
|
||||
func (c *Compiler) emitJump(op types.OpCode) int {
|
||||
c.emit(op, []byte{0xFF, 0xFF}) // Placeholder for jump offset
|
||||
return len(c.chunk.Code) - 1
|
||||
}
|
||||
|
||||
func (c *Compiler) patchJump(jumpIndex int) {
|
||||
// -2 to adjust for the size of the jump offset itself
|
||||
jumpDistance := len(c.chunk.Code) - jumpIndex - 1
|
||||
|
||||
// Store jump distance in the instruction's operands
|
||||
// Using big-endian format: high byte first, low byte second
|
||||
c.chunk.Code[jumpIndex].Operands = []byte{
|
||||
byte((jumpDistance >> 8) & 0xFF),
|
||||
byte(jumpDistance & 0xFF),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Compiler) emitReturn() {
|
||||
c.emit(types.OP_NIL, nil)
|
||||
c.emit(types.OP_RETURN, nil)
|
||||
}
|
||||
|
||||
func (c *Compiler) makeConstant(value types.Value) byte {
|
||||
c.chunk.Constants = append(c.chunk.Constants, value)
|
||||
return byte(len(c.chunk.Constants) - 1)
|
||||
}
|
||||
|
||||
func (c *Compiler) beginScope() {
|
||||
c.scopeDepth++
|
||||
}
|
||||
|
||||
func (c *Compiler) endScope() {
|
||||
c.scopeDepth--
|
||||
|
||||
// Remove locals from this scope
|
||||
for len(c.locals) > 0 && c.locals[len(c.locals)-1].depth > c.scopeDepth {
|
||||
if c.locals[len(c.locals)-1].isCaptured {
|
||||
c.emit(types.OP_CLOSE_UPVALUE, nil)
|
||||
} else {
|
||||
c.emit(types.OP_POP, nil)
|
||||
}
|
||||
c.locals = c.locals[:len(c.locals)-1]
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Compiler) addLocal(name string) {
|
||||
// Check if we've hit the limit of local variables
|
||||
if len(c.locals) >= 256 {
|
||||
c.error("Too many local variables in function.")
|
||||
return
|
||||
}
|
||||
|
||||
c.locals = append(c.locals, local{
|
||||
name: name,
|
||||
depth: c.scopeDepth,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Compiler) resolveUpvalue(name string) (int, bool) {
|
||||
// If no enclosing scope, can't be an upvalue
|
||||
if c.enclosing == nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// Try to find in immediate enclosing function's locals
|
||||
for i := range c.enclosing.locals {
|
||||
if c.enclosing.locals[i].name == name {
|
||||
c.enclosing.locals[i].isCaptured = true
|
||||
upvalue := types.Upvalue{
|
||||
Index: uint8(i),
|
||||
IsLocal: true,
|
||||
}
|
||||
c.upvalues = append(c.upvalues, upvalue)
|
||||
return len(c.upvalues) - 1, true
|
||||
}
|
||||
}
|
||||
|
||||
// Try to find in higher enclosing scopes
|
||||
if upvalueIndex, found := c.enclosing.resolveUpvalue(name); found {
|
||||
upvalue := types.Upvalue{
|
||||
Index: uint8(upvalueIndex),
|
||||
IsLocal: false,
|
||||
}
|
||||
c.upvalues = append(c.upvalues, upvalue)
|
||||
return len(c.upvalues) - 1, true
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func (c *Compiler) error(message string) {
|
||||
c.errors = append(c.errors, &types.MakoError{
|
||||
Message: message,
|
||||
Line: 0, // In real implementation, track from token
|
||||
Column: 0,
|
||||
})
|
||||
}
|
325
compiler/compiler_test.go
Normal file
325
compiler/compiler_test.go
Normal file
@ -0,0 +1,325 @@
|
||||
package compiler_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
assert "git.sharkk.net/Go/Assert"
|
||||
"git.sharkk.net/Sharkk/Mako/compiler"
|
||||
"git.sharkk.net/Sharkk/Mako/types"
|
||||
)
|
||||
|
||||
// Helper function to check if an opcode exists in the bytecode
|
||||
func hasOpCode(chunk *types.Chunk, opCode types.OpCode) bool {
|
||||
for _, instr := range chunk.Code {
|
||||
if instr.Op == opCode {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Helper function to count opcodes in the bytecode
|
||||
func countOpCodes(chunk *types.Chunk, opCode types.OpCode) int {
|
||||
count := 0
|
||||
for _, instr := range chunk.Code {
|
||||
if instr.Op == opCode {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// Helper function to check constants in the chunk
|
||||
func hasConstant(chunk *types.Chunk, value types.Value) bool {
|
||||
for _, constant := range chunk.Constants {
|
||||
if constant.String() == value.String() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func TestCompileLiterals(t *testing.T) {
|
||||
tests := []struct {
|
||||
source string
|
||||
opCode types.OpCode
|
||||
hasValue bool
|
||||
value types.Value
|
||||
}{
|
||||
{"nil", types.OP_NIL, false, nil},
|
||||
{"true", types.OP_TRUE, false, nil},
|
||||
{"false", types.OP_FALSE, false, nil},
|
||||
{"123", types.OP_CONSTANT, true, types.NumberValue{Value: 123}},
|
||||
{"\"hello\"", types.OP_CONSTANT, true, types.StringValue{Value: "hello"}},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
function, errors := compiler.CompileSource(test.source)
|
||||
|
||||
assert.Equal(t, 0, len(errors))
|
||||
assert.NotNil(t, function)
|
||||
assert.NotNil(t, function.Chunk)
|
||||
|
||||
// Check if the right opcode exists
|
||||
assert.True(t, hasOpCode(function.Chunk, test.opCode))
|
||||
|
||||
// Check if the constant exists if needed
|
||||
if test.hasValue {
|
||||
assert.True(t, hasConstant(function.Chunk, test.value))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileVariables(t *testing.T) {
|
||||
// Test variable declaration and reference
|
||||
source := "x = 10\necho x"
|
||||
function, errors := compiler.CompileSource(source)
|
||||
|
||||
assert.Equal(t, 0, len(errors))
|
||||
assert.NotNil(t, function)
|
||||
|
||||
// Check for SET_GLOBAL and GET_GLOBAL opcodes
|
||||
assert.True(t, hasOpCode(function.Chunk, types.OP_SET_GLOBAL))
|
||||
assert.True(t, hasOpCode(function.Chunk, types.OP_GET_GLOBAL))
|
||||
|
||||
// Check that "x" exists as a constant
|
||||
stringValue := types.StringValue{Value: "x"}
|
||||
assert.True(t, hasConstant(function.Chunk, stringValue))
|
||||
|
||||
// Check that 10 exists as a constant
|
||||
numberValue := types.NumberValue{Value: 10}
|
||||
assert.True(t, hasConstant(function.Chunk, numberValue))
|
||||
}
|
||||
|
||||
func TestCompileExpressions(t *testing.T) {
|
||||
tests := []struct {
|
||||
source string
|
||||
opCode types.OpCode
|
||||
}{
|
||||
{"1 + 2", types.OP_ADD},
|
||||
{"3 - 4", types.OP_SUBTRACT},
|
||||
{"5 * 6", types.OP_MULTIPLY},
|
||||
{"7 / 8", types.OP_DIVIDE},
|
||||
{"9 == 10", types.OP_EQUAL},
|
||||
{"11 != 12", types.OP_NOT},
|
||||
{"13 < 14", types.OP_LESS},
|
||||
{"15 > 16", types.OP_GREATER},
|
||||
{"-17", types.OP_NEGATE},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
function, errors := compiler.CompileSource(test.source)
|
||||
|
||||
assert.Equal(t, 0, len(errors))
|
||||
assert.NotNil(t, function)
|
||||
assert.NotNil(t, function.Chunk)
|
||||
|
||||
// Check if the right opcode exists
|
||||
assert.True(t, hasOpCode(function.Chunk, test.opCode))
|
||||
}
|
||||
|
||||
// Test complex expression
|
||||
source := "1 + 2 * 3 - 4 / 5"
|
||||
function, errors := compiler.CompileSource(source)
|
||||
|
||||
assert.Equal(t, 0, len(errors))
|
||||
assert.NotNil(t, function)
|
||||
|
||||
// Check for all arithmetic operations
|
||||
assert.True(t, hasOpCode(function.Chunk, types.OP_ADD))
|
||||
assert.True(t, hasOpCode(function.Chunk, types.OP_MULTIPLY))
|
||||
assert.True(t, hasOpCode(function.Chunk, types.OP_SUBTRACT))
|
||||
assert.True(t, hasOpCode(function.Chunk, types.OP_DIVIDE))
|
||||
}
|
||||
|
||||
func TestCompileStatements(t *testing.T) {
|
||||
// Test echo statement
|
||||
source := "echo \"hello\""
|
||||
function, errors := compiler.CompileSource(source)
|
||||
|
||||
assert.Equal(t, 0, len(errors))
|
||||
assert.NotNil(t, function)
|
||||
|
||||
// Check for PRINT opcode
|
||||
assert.True(t, hasOpCode(function.Chunk, types.OP_PRINT))
|
||||
|
||||
// Check that string constant exists
|
||||
stringValue := types.StringValue{Value: "hello"}
|
||||
assert.True(t, hasConstant(function.Chunk, stringValue))
|
||||
|
||||
// Test multiple statements
|
||||
source = "x = 1\ny = 2\necho x + y"
|
||||
function, errors = compiler.CompileSource(source)
|
||||
|
||||
assert.Equal(t, 0, len(errors))
|
||||
assert.NotNil(t, function)
|
||||
|
||||
// Check for SET_GLOBAL, GET_GLOBAL, ADD, and PRINT opcodes
|
||||
assert.True(t, hasOpCode(function.Chunk, types.OP_SET_GLOBAL))
|
||||
assert.True(t, hasOpCode(function.Chunk, types.OP_GET_GLOBAL))
|
||||
assert.True(t, hasOpCode(function.Chunk, types.OP_ADD))
|
||||
assert.True(t, hasOpCode(function.Chunk, types.OP_PRINT))
|
||||
|
||||
// Expect at least 2 SET_GLOBAL instructions
|
||||
assert.True(t, countOpCodes(function.Chunk, types.OP_SET_GLOBAL) >= 2)
|
||||
}
|
||||
|
||||
func TestCompileControlFlow(t *testing.T) {
|
||||
// Test if statement
|
||||
source := "if true then echo \"yes\" end"
|
||||
function, errors := compiler.CompileSource(source)
|
||||
|
||||
assert.Equal(t, 0, len(errors))
|
||||
assert.NotNil(t, function)
|
||||
|
||||
// Check for TRUE, JUMP_IF_FALSE, and PRINT opcodes
|
||||
assert.True(t, hasOpCode(function.Chunk, types.OP_TRUE))
|
||||
assert.True(t, hasOpCode(function.Chunk, types.OP_JUMP_IF_FALSE))
|
||||
assert.True(t, hasOpCode(function.Chunk, types.OP_PRINT))
|
||||
|
||||
// Test if-else statement
|
||||
source = "if false then echo \"yes\" else echo \"no\" end"
|
||||
function, errors = compiler.CompileSource(source)
|
||||
|
||||
assert.Equal(t, 0, len(errors))
|
||||
assert.NotNil(t, function)
|
||||
|
||||
// Check for FALSE, JUMP_IF_FALSE, JUMP, and PRINT opcodes
|
||||
assert.True(t, hasOpCode(function.Chunk, types.OP_FALSE))
|
||||
assert.True(t, hasOpCode(function.Chunk, types.OP_JUMP_IF_FALSE))
|
||||
assert.True(t, hasOpCode(function.Chunk, types.OP_JUMP))
|
||||
assert.True(t, hasOpCode(function.Chunk, types.OP_PRINT))
|
||||
|
||||
// Test if-elseif-else statement
|
||||
source = "if false then echo \"a\" elseif true then echo \"b\" else echo \"c\" end"
|
||||
function, errors = compiler.CompileSource(source)
|
||||
|
||||
assert.Equal(t, 0, len(errors))
|
||||
assert.NotNil(t, function)
|
||||
|
||||
// Check for appropriate opcodes
|
||||
assert.True(t, hasOpCode(function.Chunk, types.OP_FALSE))
|
||||
assert.True(t, hasOpCode(function.Chunk, types.OP_TRUE))
|
||||
assert.True(t, hasOpCode(function.Chunk, types.OP_JUMP_IF_FALSE))
|
||||
assert.True(t, hasOpCode(function.Chunk, types.OP_JUMP))
|
||||
assert.True(t, hasOpCode(function.Chunk, types.OP_PRINT))
|
||||
}
|
||||
|
||||
func TestCompileFunctions(t *testing.T) {
|
||||
// Test function declaration
|
||||
source := "fn add(a, b) return a + b end"
|
||||
function, errors := compiler.CompileSource(source)
|
||||
|
||||
assert.Equal(t, 0, len(errors))
|
||||
assert.NotNil(t, function)
|
||||
|
||||
// Check for CLOSURE opcode
|
||||
assert.True(t, hasOpCode(function.Chunk, types.OP_CLOSURE))
|
||||
|
||||
// Test function call
|
||||
source = "fn add(a, b) return a + b end\necho add(1, 2)"
|
||||
function, errors = compiler.CompileSource(source)
|
||||
|
||||
assert.Equal(t, 0, len(errors))
|
||||
assert.NotNil(t, function)
|
||||
|
||||
// Check for CLOSURE, CALL, and PRINT opcodes
|
||||
assert.True(t, hasOpCode(function.Chunk, types.OP_CLOSURE))
|
||||
assert.True(t, hasOpCode(function.Chunk, types.OP_CALL))
|
||||
assert.True(t, hasOpCode(function.Chunk, types.OP_PRINT))
|
||||
|
||||
// Test anonymous function expression
|
||||
source = "fnVar = fn(x, y) return x * y end"
|
||||
function, errors = compiler.CompileSource(source)
|
||||
|
||||
assert.Equal(t, 0, len(errors))
|
||||
assert.NotNil(t, function)
|
||||
|
||||
// Check for CLOSURE and SET_GLOBAL opcodes
|
||||
assert.True(t, hasOpCode(function.Chunk, types.OP_CLOSURE))
|
||||
assert.True(t, hasOpCode(function.Chunk, types.OP_SET_GLOBAL))
|
||||
}
|
||||
|
||||
func TestCompileClosures(t *testing.T) {
|
||||
// Test closure that captures a variable
|
||||
source := `
|
||||
x = 10
|
||||
fn makeAdder()
|
||||
return fn(y) return x + y end
|
||||
end
|
||||
adder = makeAdder()
|
||||
echo adder(5)
|
||||
`
|
||||
|
||||
function, errors := compiler.CompileSource(source)
|
||||
|
||||
assert.Equal(t, 0, len(errors))
|
||||
assert.NotNil(t, function)
|
||||
|
||||
// Check for GET_UPVALUE opcode (used in closures)
|
||||
assert.True(t, hasOpCode(function.Chunk, types.OP_CLOSURE))
|
||||
|
||||
// The nested functions should create closures
|
||||
// This is a simplified test since we can't easily peek inside the nested functions
|
||||
assert.True(t, countOpCodes(function.Chunk, types.OP_CLOSURE) >= 2)
|
||||
}
|
||||
|
||||
func TestCompileErrors(t *testing.T) {
|
||||
// Test syntax error
|
||||
source := "if true echo \"missing then\" end"
|
||||
_, errors := compiler.CompileSource(source)
|
||||
|
||||
assert.True(t, len(errors) > 0)
|
||||
|
||||
// Test undefined variable
|
||||
source = "echo undefinedVar"
|
||||
function, _ := compiler.CompileSource(source)
|
||||
|
||||
// This should compile, as variable resolution happens at runtime
|
||||
assert.NotNil(t, function)
|
||||
assert.True(t, hasOpCode(function.Chunk, types.OP_GET_GLOBAL))
|
||||
}
|
||||
|
||||
func TestCompileComplexProgram(t *testing.T) {
|
||||
// Test a more complex program
|
||||
source := `
|
||||
// Calculate factorial recursively
|
||||
fn factorial(n)
|
||||
if n <= 1 then
|
||||
return 1
|
||||
else
|
||||
return n * factorial(n - 1)
|
||||
end
|
||||
end
|
||||
|
||||
// Calculate factorial iteratively
|
||||
fn factorialIter(n)
|
||||
result = 1
|
||||
if n > 1 then
|
||||
current = 2
|
||||
while current <= n
|
||||
result = result * current
|
||||
current = current + 1
|
||||
end
|
||||
end
|
||||
return result
|
||||
end
|
||||
|
||||
// Use both functions
|
||||
echo factorial(5)
|
||||
echo factorialIter(5)
|
||||
`
|
||||
|
||||
// This won't compile yet because 'while' is not implemented
|
||||
// But we can check if the function declarations compile
|
||||
function, errors := compiler.CompileSource(source)
|
||||
|
||||
// We expect some errors due to missing 'while' implementation
|
||||
if len(errors) == 0 {
|
||||
// If no errors, check for expected opcodes
|
||||
assert.NotNil(t, function)
|
||||
assert.True(t, hasOpCode(function.Chunk, types.OP_CLOSURE))
|
||||
assert.True(t, hasOpCode(function.Chunk, types.OP_CALL))
|
||||
assert.True(t, hasOpCode(function.Chunk, types.OP_PRINT))
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user