functions 2

This commit is contained in:
Sky Johnson 2025-05-06 23:20:10 -05:00
parent 306f34cb73
commit 29707b4b02
5 changed files with 308 additions and 50 deletions

View File

@ -135,6 +135,10 @@ func (c *compiler) compileStatement(stmt parser.Statement) {
}
c.emit(types.OpReturn, 0)
case *parser.FunctionStatement:
// Use the dedicated function for function statements
c.compileFunctionDeclaration(s)
// BlockStatement now should only be used for keyword blocks like if-then-else-end
case *parser.BlockStatement:
for _, blockStmt := range s.Statements {
@ -207,14 +211,14 @@ func (c *compiler) compileExpression(expr parser.Expression) {
c.compileFunctionLiteral(e)
case *parser.CallExpression:
// Compile the arguments first
// Compile the function expression first
c.compileExpression(e.Function)
// Then compile the arguments
for _, arg := range e.Arguments {
c.compileExpression(arg)
}
// Compile the function expression (which might be an Identifier or more complex)
c.compileExpression(e.Function)
// Emit the call instruction with the number of arguments
c.emit(types.OpCall, len(e.Arguments))
@ -434,7 +438,8 @@ func (c *compiler) compileFunctionLiteral(fn *parser.FunctionLiteral) {
// Ensure the function always returns a value
// If the last instruction is not a return, add one
if len(c.instructions) == 0 || c.instructions[len(c.instructions)-1].Opcode != types.OpReturn {
if len(fnCompiler.instructions) == 0 ||
(len(fnCompiler.instructions) > 0 && fnCompiler.instructions[len(fnCompiler.instructions)-1].Opcode != types.OpReturn) {
nullIndex := c.addConstant(nil)
c.emit(types.OpConstant, nullIndex)
c.emit(types.OpReturn, 0)
@ -446,13 +451,13 @@ func (c *compiler) compileFunctionLiteral(fn *parser.FunctionLiteral) {
// Restore the parent compiler
c.currentFunction = parentCompiler
// Create a Function object and add it to the constants
// Extract upvalue information for closure creation
upvalueIndexes := make([]int, len(fnCompiler.upvalues))
for i, upvalue := range fnCompiler.upvalues {
upvalueIndexes[i] = upvalue.index
}
// Create a Function object and add it to the constants
function := types.NewFunction(
fnCompiler.instructions,
fnCompiler.numParams,
@ -486,3 +491,75 @@ func (c *compiler) emit(op types.Opcode, operand int) {
c.instructions = append(c.instructions, instruction)
}
}
// compileFunctionDeclaration compiles a named function declaration
// (`function name(params) ... end`) and binds the resulting function value
// to its name.
//
// The body is compiled while c.currentFunction points at a fresh
// functionCompiler; the previous value is restored afterwards. The finished
// types.Function is added as a constant, pushed with OpFunction, and then
// stored with OpSetGlobal when at most one scope is open, or declared and
// stored with OpSetLocal otherwise.
//
// A nil fn (or one without a name or body) is ignored and emits nothing.
func (c *compiler) compileFunctionDeclaration(fn *parser.FunctionStatement) {
	// parseFunctionStatement returns a nil *FunctionStatement on a syntax
	// error. Once parseStatement hands that back as a Statement interface it
	// is a non-nil interface holding a nil pointer, so it passes the parser's
	// `stmt != nil` filters and still matches the *parser.FunctionStatement
	// case in compileStatement. Without this guard the field accesses below
	// would dereference nil. The parser has already recorded the error.
	if fn == nil || fn.Name == nil || fn.Body == nil {
		return
	}

	// Save the current compiler state.
	parentCompiler := c.currentFunction

	// Create a new function compiler and make it current for the body.
	fnCompiler := &functionCompiler{
		constants:    []any{},
		instructions: []types.Instruction{},
		numParams:    len(fn.Parameters),
		upvalues:     []upvalueInfo{},
	}
	c.currentFunction = fnCompiler

	// Enter a new scope for the function body.
	c.enterScope()

	// Declare each parameter as a local and emit a store for it, keyed by the
	// parameter name's constant index.
	for _, param := range fn.Parameters {
		c.declareVariable(param.Value)
		paramIndex := c.addConstant(param.Value)
		c.emit(types.OpSetLocal, paramIndex)
	}

	// Compile the function body.
	for _, stmt := range fn.Body.Statements {
		c.compileStatement(stmt)
	}

	// Ensure the function always returns a value: if the last instruction is
	// not a return, append an implicit `return nil`.
	if len(fnCompiler.instructions) == 0 || fnCompiler.instructions[len(fnCompiler.instructions)-1].Opcode != types.OpReturn {
		nullIndex := c.addConstant(nil)
		c.emit(types.OpConstant, nullIndex)
		c.emit(types.OpReturn, 0)
	}

	// Exit the function scope and restore the parent compiler before emitting
	// anything that belongs to the enclosing code.
	c.exitScope()
	c.currentFunction = parentCompiler

	// Extract upvalue information for closure creation.
	upvalueIndexes := make([]int, len(fnCompiler.upvalues))
	for i, upvalue := range fnCompiler.upvalues {
		upvalueIndexes[i] = upvalue.index
	}

	// Create a Function object, add it to the constants and push it.
	function := types.NewFunction(
		fnCompiler.instructions,
		fnCompiler.numParams,
		fnCompiler.constants,
		upvalueIndexes,
	)
	functionIndex := c.addConstant(function)
	c.emit(types.OpFunction, functionIndex)

	// Bind the function to its name. SetGlobal is used for top-level
	// declarations so they persist between REPL lines.
	nameIndex := c.addConstant(fn.Name.Value)
	if len(c.scopes) <= 1 {
		c.emit(types.OpSetGlobal, nameIndex)
	} else {
		c.declareVariable(fn.Name.Value)
		c.emit(types.OpSetLocal, nameIndex)
	}
}

View File

@ -191,3 +191,13 @@ type ReturnStatement struct {
func (rs *ReturnStatement) statementNode() {}
func (rs *ReturnStatement) TokenLiteral() string { return rs.Token.Value }
// FunctionStatement is the AST node for a named function declaration,
// e.g. `function add(a, b) ... end`.
type FunctionStatement struct {
	Token      lexer.Token     // The 'function' token
	Name       *Identifier     // Name the function is declared under
	Parameters []*Identifier   // Declared parameters, in source order
	Body       *BlockStatement // Statements making up the function body
}

// statementNode is an empty marker method, mirroring the one on
// ReturnStatement.
func (fs *FunctionStatement) statementNode() {
}

// TokenLiteral returns the Value of the node's token.
func (fs *FunctionStatement) TokenLiteral() string {
	return fs.Token.Value
}

View File

@ -77,7 +77,8 @@ func New(l *lexer.Lexer) *Parser {
p.registerPrefix(lexer.TokenFalse, p.parseBooleanLiteral)
p.registerPrefix(lexer.TokenNot, p.parsePrefixExpression)
p.registerPrefix(lexer.TokenNil, p.parseNilLiteral)
p.registerPrefix(lexer.TokenFunction, p.parseFunctionLiteral) // Add function literal parsing
p.registerPrefix(lexer.TokenFunction, p.parseFunctionLiteral)
p.registerPrefix(lexer.TokenRightParen, p.parseUnexpectedToken)
// Initialize infix parse functions
p.infixParseFns = make(map[lexer.TokenType]infixParseFn)
@ -86,7 +87,7 @@ func New(l *lexer.Lexer) *Parser {
p.registerInfix(lexer.TokenStar, p.parseInfixExpression)
p.registerInfix(lexer.TokenSlash, p.parseInfixExpression)
p.registerInfix(lexer.TokenLeftBracket, p.parseIndexExpression)
p.registerInfix(lexer.TokenLeftParen, p.parseCallExpression) // Add function call parsing
p.registerInfix(lexer.TokenLeftParen, p.parseCallExpression)
p.registerInfix(lexer.TokenAnd, p.parseInfixExpression)
p.registerInfix(lexer.TokenOr, p.parseInfixExpression)
@ -184,6 +185,13 @@ func (p *Parser) parseStatement() Statement {
return p.parseEchoStatement()
case lexer.TokenReturn:
return p.parseReturnStatement()
case lexer.TokenFunction:
// If the next token is an identifier, it's a function declaration
if p.peekTokenIs(lexer.TokenIdentifier) {
return p.parseFunctionStatement()
}
// Otherwise, it's a function expression
return p.parseExpressionStatement()
default:
return p.parseExpressionStatement()
}
@ -601,24 +609,74 @@ func (p *Parser) parseNilLiteral() Expression {
return &NilLiteral{Token: p.curToken}
}
// New methods for function literals and call expressions
func (p *Parser) parseFunctionLiteral() Expression {
lit := &FunctionLiteral{Token: p.curToken}
// Check for opening paren for parameters
// Check if next token is a left paren
if !p.expectPeek(lexer.TokenLeftParen) {
return nil
}
lit.Parameters = p.parseFunctionParameters()
// Parse the parameters
lit.Parameters = []*Identifier{}
// Expect a block for the function body
if !p.expectPeek(lexer.TokenLeftBrace) {
// Check for empty parameter list
if p.peekTokenIs(lexer.TokenRightParen) {
p.nextToken() // Skip to the right paren
} else {
p.nextToken() // Skip the left paren
// Parse first parameter
if !p.curTokenIs(lexer.TokenIdentifier) {
p.errors = append(p.errors, fmt.Sprintf("line %d: expected parameter name, got %s",
p.curToken.Line, p.curToken.Value))
return nil
}
ident := &Identifier{Token: p.curToken, Value: p.curToken.Value}
lit.Parameters = append(lit.Parameters, ident)
// Parse additional parameters
for p.peekTokenIs(lexer.TokenComma) {
p.nextToken() // Skip current parameter
p.nextToken() // Skip comma
if !p.curTokenIs(lexer.TokenIdentifier) {
p.errors = append(p.errors, fmt.Sprintf("line %d: expected parameter name after comma",
p.curToken.Line))
return nil
}
ident := &Identifier{Token: p.curToken, Value: p.curToken.Value}
lit.Parameters = append(lit.Parameters, ident)
}
// After parsing parameters, expect closing parenthesis
if !p.expectPeek(lexer.TokenRightParen) {
return nil
}
}
// Parse function body
bodyStmts := []Statement{}
for p.nextToken(); !p.curTokenIs(lexer.TokenEnd) && !p.curTokenIs(lexer.TokenEOF); p.nextToken() {
stmt := p.parseStatement()
if stmt != nil {
bodyStmts = append(bodyStmts, stmt)
}
}
// Expect 'end' token
if !p.curTokenIs(lexer.TokenEnd) {
p.errors = append(p.errors, fmt.Sprintf("line %d: expected 'end' to close function",
p.curToken.Line))
return nil
}
lit.Body = p.parseBlockStatement()
lit.Body = &BlockStatement{
Token: p.curToken,
Statements: bodyStmts,
}
return lit
}
@ -628,13 +686,24 @@ func (p *Parser) parseFunctionParameters() []*Identifier {
// Empty parameter list
if p.peekTokenIs(lexer.TokenRightParen) {
p.nextToken()
p.nextToken() // Skip to right paren
p.nextToken() // Skip right paren
return identifiers
}
p.nextToken() // Skip '('
p.nextToken() // Skip left paren
// First parameter
if !p.curTokenIs(lexer.TokenIdentifier) {
// Expected identifier for parameter but didn't get one
p.errors = append(p.errors, fmt.Sprintf("line %d: expected parameter name, got %d",
p.curToken.Line, p.curToken.Type))
if p.expectPeek(lexer.TokenRightParen) {
return identifiers
}
return nil
}
ident := &Identifier{Token: p.curToken, Value: p.curToken.Value}
identifiers = append(identifiers, ident)
@ -642,34 +711,26 @@ func (p *Parser) parseFunctionParameters() []*Identifier {
for p.peekTokenIs(lexer.TokenComma) {
p.nextToken() // Skip current identifier
p.nextToken() // Skip comma
if !p.curTokenIs(lexer.TokenIdentifier) {
p.errors = append(p.errors, fmt.Sprintf("line %d: expected parameter name after comma",
p.curToken.Line))
break
}
ident := &Identifier{Token: p.curToken, Value: p.curToken.Value}
identifiers = append(identifiers, ident)
}
if !p.expectPeek(lexer.TokenRightParen) {
p.errors = append(p.errors, fmt.Sprintf("line %d: expected ')' to close parameter list",
p.curToken.Line))
return nil
}
return identifiers
}
// parseBlockStatement collects statements into a BlockStatement until it
// reaches a right brace or end of input.
//
// The block's Token is the token that is current on entry. The parser then
// steps past it and repeatedly calls parseStatement, keeping every non-nil
// result and advancing one token after each. On return curToken is the right
// brace or EOF token that ended the loop.
func (p *Parser) parseBlockStatement() *BlockStatement {
	body := &BlockStatement{
		Token:      p.curToken,
		Statements: []Statement{},
	}

	for p.nextToken(); !(p.curTokenIs(lexer.TokenRightBrace) || p.curTokenIs(lexer.TokenEOF)); p.nextToken() {
		if parsed := p.parseStatement(); parsed != nil {
			body.Statements = append(body.Statements, parsed)
		}
	}

	return body
}
func (p *Parser) parseCallExpression(function Expression) Expression {
exp := &CallExpression{
Token: p.curToken,
@ -701,3 +762,85 @@ func (p *Parser) parseCallExpression(function Expression) Expression {
return exp
}
// parseFunctionStatement parses a named function declaration:
//
//	function name(a, b)
//	    ...
//	end
//
// It is entered with curToken on the 'function' keyword. On success it
// returns the populated node with curToken on the closing 'end'. On any
// syntax error it returns nil; the errors detected directly in this method
// are appended to p.errors first.
//
// NOTE(review): the nil result is a nil *FunctionStatement. A caller that
// returns it as a Statement interface gets a non-nil interface value, so a
// plain `stmt != nil` check will not catch it.
func (p *Parser) parseFunctionStatement() *FunctionStatement {
	decl := &FunctionStatement{Token: p.curToken}

	// Step off 'function'; the next token must be the function's name.
	p.nextToken()
	if !p.curTokenIs(lexer.TokenIdentifier) {
		p.errors = append(p.errors, fmt.Sprintf("line %d: expected function name", p.curToken.Line))
		return nil
	}
	decl.Name = &Identifier{Token: p.curToken, Value: p.curToken.Value}

	// The parameter list must open with '('.
	if !p.expectPeek(lexer.TokenLeftParen) {
		return nil
	}

	decl.Parameters = []*Identifier{}

	// Either way we advance once from '(': onto ')' for an empty list,
	// otherwise onto what should be the first parameter.
	emptyList := p.peekTokenIs(lexer.TokenRightParen)
	p.nextToken()

	if !emptyList {
		if !p.curTokenIs(lexer.TokenIdentifier) {
			p.errors = append(p.errors, fmt.Sprintf("line %d: expected parameter name, got %s",
				p.curToken.Line, p.curToken.Value))
			return nil
		}
		decl.Parameters = append(decl.Parameters, &Identifier{Token: p.curToken, Value: p.curToken.Value})

		// Each further parameter is introduced by a comma.
		for p.peekTokenIs(lexer.TokenComma) {
			p.nextToken() // onto the comma
			p.nextToken() // onto the token after the comma

			if !p.curTokenIs(lexer.TokenIdentifier) {
				p.errors = append(p.errors, fmt.Sprintf("line %d: expected parameter name after comma",
					p.curToken.Line))
				return nil
			}
			decl.Parameters = append(decl.Parameters, &Identifier{Token: p.curToken, Value: p.curToken.Value})
		}

		// The list must be closed by ')'.
		if !p.expectPeek(lexer.TokenRightParen) {
			return nil
		}
	}

	// Collect body statements until 'end' or end of input, advancing one
	// token after each parsed statement.
	body := []Statement{}
	p.nextToken()
	for !p.curTokenIs(lexer.TokenEnd) && !p.curTokenIs(lexer.TokenEOF) {
		if parsed := p.parseStatement(); parsed != nil {
			body = append(body, parsed)
		}
		p.nextToken()
	}

	// Running out of input before 'end' is an error.
	if !p.curTokenIs(lexer.TokenEnd) {
		p.errors = append(p.errors, fmt.Sprintf("line %d: expected 'end' to close function",
			p.curToken.Line))
		return nil
	}

	// The body block is tagged with the current token, i.e. the closing 'end'.
	decl.Body = &BlockStatement{
		Token:      p.curToken,
		Statements: body,
	}

	return decl
}

View File

@ -1,5 +1,12 @@
function add(a, b)
return a+b
// Assignment style
add = function(a, b)
return a + b
end
echo add(1, 2)
// Direct declaration style
function multiply(a, b)
return a * b
end
echo add(1, 2)
echo multiply(2, 2)

View File

@ -241,23 +241,37 @@ func (vm *VM) runCode(instructions []types.Instruction, basePointer int) types.V
case types.OpJump:
ip = instruction.Operand - 1 // -1 because loop will increment
// Function instructions
// Function instructions
case types.OpFunction:
constIndex := instruction.Operand
function := vm.constants[constIndex].(*types.Function)
// Use the helper function to create a proper function value
vm.push(types.NewFunctionValue(function))
case types.OpCall:
numArgs := instruction.Operand
fnVal := vm.stack[vm.sp-numArgs-1]
if fnVal.Type != types.TypeFunction {
fmt.Println("Error: attempt to call non-function value")
// The function is at position sp-numArgs-1
if vm.sp <= numArgs {
fmt.Println("Error: stack underflow during function call")
vm.push(types.NewNull())
continue
}
function := fnVal.Data.(*types.Function)
fnVal := vm.stack[vm.sp-numArgs-1]
if fnVal.Type != types.TypeFunction {
fmt.Printf("Error: attempt to call non-function value\n")
vm.push(types.NewNull())
continue
}
function, ok := fnVal.Data.(*types.Function)
if !ok {
fmt.Printf("Error: function data is invalid\n")
vm.push(types.NewNull())
continue
}
// Check if we have the correct number of arguments
if numArgs != function.NumParams {
@ -286,18 +300,25 @@ func (vm *VM) runCode(instructions []types.Instruction, basePointer int) types.V
}
vm.frames[vm.fp] = frame
// Set up the call
// We keep the function value and args on the stack
// They become local variables in the function's scope
// Save the current constants
oldConstants := vm.constants
// Switch to function's constants
vm.constants = function.Constants
// Run the function code
returnValue := vm.runCode(function.Instructions, frame.BasePointer)
// Restore the old constants
vm.constants = oldConstants
// Restore state
vm.fp--
// Return value from function is already on stack from OpReturn
return returnValue
// Replace the function with the return value
vm.stack[frame.BasePointer] = returnValue
// Adjust the stack pointer to remove the arguments
vm.sp = frame.BasePointer + 1
case types.OpReturn:
returnValue := vm.pop()
@ -306,11 +327,11 @@ func (vm *VM) runCode(instructions []types.Instruction, basePointer int) types.V
if vm.fp >= 0 {
frame := vm.frames[vm.fp]
// Restore the stack
vm.sp = frame.BasePointer + 1 // Keep the function
// Restore the stack to just below the function
vm.sp = frame.BasePointer
// Push the return value
vm.stack[vm.sp-1] = returnValue
vm.push(returnValue)
// Return to the caller
return returnValue