functions 2
This commit is contained in:
parent
306f34cb73
commit
29707b4b02
@ -135,6 +135,10 @@ func (c *compiler) compileStatement(stmt parser.Statement) {
|
||||
}
|
||||
c.emit(types.OpReturn, 0)
|
||||
|
||||
case *parser.FunctionStatement:
|
||||
// Use the dedicated function for function statements
|
||||
c.compileFunctionDeclaration(s)
|
||||
|
||||
// BlockStatement now should only be used for keyword blocks like if-then-else-end
|
||||
case *parser.BlockStatement:
|
||||
for _, blockStmt := range s.Statements {
|
||||
@ -207,14 +211,14 @@ func (c *compiler) compileExpression(expr parser.Expression) {
|
||||
c.compileFunctionLiteral(e)
|
||||
|
||||
case *parser.CallExpression:
|
||||
// Compile the arguments first
|
||||
// Compile the function expression first
|
||||
c.compileExpression(e.Function)
|
||||
|
||||
// Then compile the arguments
|
||||
for _, arg := range e.Arguments {
|
||||
c.compileExpression(arg)
|
||||
}
|
||||
|
||||
// Compile the function expression (which might be an Identifier or more complex)
|
||||
c.compileExpression(e.Function)
|
||||
|
||||
// Emit the call instruction with the number of arguments
|
||||
c.emit(types.OpCall, len(e.Arguments))
|
||||
|
||||
@ -434,7 +438,8 @@ func (c *compiler) compileFunctionLiteral(fn *parser.FunctionLiteral) {
|
||||
|
||||
// Ensure the function always returns a value
|
||||
// If the last instruction is not a return, add one
|
||||
if len(c.instructions) == 0 || c.instructions[len(c.instructions)-1].Opcode != types.OpReturn {
|
||||
if len(fnCompiler.instructions) == 0 ||
|
||||
(len(fnCompiler.instructions) > 0 && fnCompiler.instructions[len(fnCompiler.instructions)-1].Opcode != types.OpReturn) {
|
||||
nullIndex := c.addConstant(nil)
|
||||
c.emit(types.OpConstant, nullIndex)
|
||||
c.emit(types.OpReturn, 0)
|
||||
@ -446,13 +451,13 @@ func (c *compiler) compileFunctionLiteral(fn *parser.FunctionLiteral) {
|
||||
// Restore the parent compiler
|
||||
c.currentFunction = parentCompiler
|
||||
|
||||
// Create a Function object and add it to the constants
|
||||
// Extract upvalue information for closure creation
|
||||
upvalueIndexes := make([]int, len(fnCompiler.upvalues))
|
||||
for i, upvalue := range fnCompiler.upvalues {
|
||||
upvalueIndexes[i] = upvalue.index
|
||||
}
|
||||
|
||||
// Create a Function object and add it to the constants
|
||||
function := types.NewFunction(
|
||||
fnCompiler.instructions,
|
||||
fnCompiler.numParams,
|
||||
@ -486,3 +491,75 @@ func (c *compiler) emit(op types.Opcode, operand int) {
|
||||
c.instructions = append(c.instructions, instruction)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *compiler) compileFunctionDeclaration(fn *parser.FunctionStatement) {
|
||||
// Save the current compiler state
|
||||
parentCompiler := c.currentFunction
|
||||
|
||||
// Create a new function compiler
|
||||
fnCompiler := &functionCompiler{
|
||||
constants: []any{},
|
||||
instructions: []types.Instruction{},
|
||||
numParams: len(fn.Parameters),
|
||||
upvalues: []upvalueInfo{},
|
||||
}
|
||||
|
||||
c.currentFunction = fnCompiler
|
||||
|
||||
// Enter a new scope for the function body
|
||||
c.enterScope()
|
||||
|
||||
// Declare parameters as local variables
|
||||
for _, param := range fn.Parameters {
|
||||
c.declareVariable(param.Value)
|
||||
paramIndex := c.addConstant(param.Value)
|
||||
c.emit(types.OpSetLocal, paramIndex)
|
||||
}
|
||||
|
||||
// Compile the function body
|
||||
for _, stmt := range fn.Body.Statements {
|
||||
c.compileStatement(stmt)
|
||||
}
|
||||
|
||||
// Ensure the function always returns a value
|
||||
// If the last instruction is not a return, add one
|
||||
if len(fnCompiler.instructions) == 0 || fnCompiler.instructions[len(fnCompiler.instructions)-1].Opcode != types.OpReturn {
|
||||
nullIndex := c.addConstant(nil)
|
||||
c.emit(types.OpConstant, nullIndex)
|
||||
c.emit(types.OpReturn, 0)
|
||||
}
|
||||
|
||||
// Exit the function scope
|
||||
c.exitScope()
|
||||
|
||||
// Restore the parent compiler
|
||||
c.currentFunction = parentCompiler
|
||||
|
||||
// Extract upvalue information for closure creation
|
||||
upvalueIndexes := make([]int, len(fnCompiler.upvalues))
|
||||
for i, upvalue := range fnCompiler.upvalues {
|
||||
upvalueIndexes[i] = upvalue.index
|
||||
}
|
||||
|
||||
// Create a Function object and add it to the constants
|
||||
function := types.NewFunction(
|
||||
fnCompiler.instructions,
|
||||
fnCompiler.numParams,
|
||||
fnCompiler.constants,
|
||||
upvalueIndexes,
|
||||
)
|
||||
|
||||
functionIndex := c.addConstant(function)
|
||||
c.emit(types.OpFunction, functionIndex)
|
||||
|
||||
// Store the function in a global variable
|
||||
nameIndex := c.addConstant(fn.Name.Value)
|
||||
|
||||
// Use SetGlobal for top-level variables to persist between REPL lines
|
||||
if len(c.scopes) <= 1 {
|
||||
c.emit(types.OpSetGlobal, nameIndex)
|
||||
} else {
|
||||
c.declareVariable(fn.Name.Value)
|
||||
c.emit(types.OpSetLocal, nameIndex)
|
||||
}
|
||||
}
|
||||
|
@ -191,3 +191,13 @@ type ReturnStatement struct {
|
||||
|
||||
func (rs *ReturnStatement) statementNode() {}
|
||||
func (rs *ReturnStatement) TokenLiteral() string { return rs.Token.Value }
|
||||
|
||||
type FunctionStatement struct {
|
||||
Token lexer.Token // The 'function' token
|
||||
Name *Identifier
|
||||
Parameters []*Identifier
|
||||
Body *BlockStatement
|
||||
}
|
||||
|
||||
func (fs *FunctionStatement) statementNode() {}
|
||||
func (fs *FunctionStatement) TokenLiteral() string { return fs.Token.Value }
|
||||
|
199
parser/parser.go
199
parser/parser.go
@ -77,7 +77,8 @@ func New(l *lexer.Lexer) *Parser {
|
||||
p.registerPrefix(lexer.TokenFalse, p.parseBooleanLiteral)
|
||||
p.registerPrefix(lexer.TokenNot, p.parsePrefixExpression)
|
||||
p.registerPrefix(lexer.TokenNil, p.parseNilLiteral)
|
||||
p.registerPrefix(lexer.TokenFunction, p.parseFunctionLiteral) // Add function literal parsing
|
||||
p.registerPrefix(lexer.TokenFunction, p.parseFunctionLiteral)
|
||||
p.registerPrefix(lexer.TokenRightParen, p.parseUnexpectedToken)
|
||||
|
||||
// Initialize infix parse functions
|
||||
p.infixParseFns = make(map[lexer.TokenType]infixParseFn)
|
||||
@ -86,7 +87,7 @@ func New(l *lexer.Lexer) *Parser {
|
||||
p.registerInfix(lexer.TokenStar, p.parseInfixExpression)
|
||||
p.registerInfix(lexer.TokenSlash, p.parseInfixExpression)
|
||||
p.registerInfix(lexer.TokenLeftBracket, p.parseIndexExpression)
|
||||
p.registerInfix(lexer.TokenLeftParen, p.parseCallExpression) // Add function call parsing
|
||||
p.registerInfix(lexer.TokenLeftParen, p.parseCallExpression)
|
||||
p.registerInfix(lexer.TokenAnd, p.parseInfixExpression)
|
||||
p.registerInfix(lexer.TokenOr, p.parseInfixExpression)
|
||||
|
||||
@ -184,6 +185,13 @@ func (p *Parser) parseStatement() Statement {
|
||||
return p.parseEchoStatement()
|
||||
case lexer.TokenReturn:
|
||||
return p.parseReturnStatement()
|
||||
case lexer.TokenFunction:
|
||||
// If the next token is an identifier, it's a function declaration
|
||||
if p.peekTokenIs(lexer.TokenIdentifier) {
|
||||
return p.parseFunctionStatement()
|
||||
}
|
||||
// Otherwise, it's a function expression
|
||||
return p.parseExpressionStatement()
|
||||
default:
|
||||
return p.parseExpressionStatement()
|
||||
}
|
||||
@ -601,24 +609,74 @@ func (p *Parser) parseNilLiteral() Expression {
|
||||
return &NilLiteral{Token: p.curToken}
|
||||
}
|
||||
|
||||
// New methods for function literals and call expressions
|
||||
|
||||
func (p *Parser) parseFunctionLiteral() Expression {
|
||||
lit := &FunctionLiteral{Token: p.curToken}
|
||||
|
||||
// Check for opening paren for parameters
|
||||
// Check if next token is a left paren
|
||||
if !p.expectPeek(lexer.TokenLeftParen) {
|
||||
return nil
|
||||
}
|
||||
|
||||
lit.Parameters = p.parseFunctionParameters()
|
||||
// Parse the parameters
|
||||
lit.Parameters = []*Identifier{}
|
||||
|
||||
// Expect a block for the function body
|
||||
if !p.expectPeek(lexer.TokenLeftBrace) {
|
||||
// Check for empty parameter list
|
||||
if p.peekTokenIs(lexer.TokenRightParen) {
|
||||
p.nextToken() // Skip to the right paren
|
||||
} else {
|
||||
p.nextToken() // Skip the left paren
|
||||
|
||||
// Parse first parameter
|
||||
if !p.curTokenIs(lexer.TokenIdentifier) {
|
||||
p.errors = append(p.errors, fmt.Sprintf("line %d: expected parameter name, got %s",
|
||||
p.curToken.Line, p.curToken.Value))
|
||||
return nil
|
||||
}
|
||||
|
||||
ident := &Identifier{Token: p.curToken, Value: p.curToken.Value}
|
||||
lit.Parameters = append(lit.Parameters, ident)
|
||||
|
||||
// Parse additional parameters
|
||||
for p.peekTokenIs(lexer.TokenComma) {
|
||||
p.nextToken() // Skip current parameter
|
||||
p.nextToken() // Skip comma
|
||||
|
||||
if !p.curTokenIs(lexer.TokenIdentifier) {
|
||||
p.errors = append(p.errors, fmt.Sprintf("line %d: expected parameter name after comma",
|
||||
p.curToken.Line))
|
||||
return nil
|
||||
}
|
||||
|
||||
ident := &Identifier{Token: p.curToken, Value: p.curToken.Value}
|
||||
lit.Parameters = append(lit.Parameters, ident)
|
||||
}
|
||||
|
||||
// After parsing parameters, expect closing parenthesis
|
||||
if !p.expectPeek(lexer.TokenRightParen) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Parse function body
|
||||
bodyStmts := []Statement{}
|
||||
for p.nextToken(); !p.curTokenIs(lexer.TokenEnd) && !p.curTokenIs(lexer.TokenEOF); p.nextToken() {
|
||||
stmt := p.parseStatement()
|
||||
if stmt != nil {
|
||||
bodyStmts = append(bodyStmts, stmt)
|
||||
}
|
||||
}
|
||||
|
||||
// Expect 'end' token
|
||||
if !p.curTokenIs(lexer.TokenEnd) {
|
||||
p.errors = append(p.errors, fmt.Sprintf("line %d: expected 'end' to close function",
|
||||
p.curToken.Line))
|
||||
return nil
|
||||
}
|
||||
|
||||
lit.Body = p.parseBlockStatement()
|
||||
lit.Body = &BlockStatement{
|
||||
Token: p.curToken,
|
||||
Statements: bodyStmts,
|
||||
}
|
||||
|
||||
return lit
|
||||
}
|
||||
@ -628,13 +686,24 @@ func (p *Parser) parseFunctionParameters() []*Identifier {
|
||||
|
||||
// Empty parameter list
|
||||
if p.peekTokenIs(lexer.TokenRightParen) {
|
||||
p.nextToken()
|
||||
p.nextToken() // Skip to right paren
|
||||
p.nextToken() // Skip right paren
|
||||
return identifiers
|
||||
}
|
||||
|
||||
p.nextToken() // Skip '('
|
||||
p.nextToken() // Skip left paren
|
||||
|
||||
// First parameter
|
||||
if !p.curTokenIs(lexer.TokenIdentifier) {
|
||||
// Expected identifier for parameter but didn't get one
|
||||
p.errors = append(p.errors, fmt.Sprintf("line %d: expected parameter name, got %d",
|
||||
p.curToken.Line, p.curToken.Type))
|
||||
if p.expectPeek(lexer.TokenRightParen) {
|
||||
return identifiers
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
ident := &Identifier{Token: p.curToken, Value: p.curToken.Value}
|
||||
identifiers = append(identifiers, ident)
|
||||
|
||||
@ -642,34 +711,26 @@ func (p *Parser) parseFunctionParameters() []*Identifier {
|
||||
for p.peekTokenIs(lexer.TokenComma) {
|
||||
p.nextToken() // Skip current identifier
|
||||
p.nextToken() // Skip comma
|
||||
|
||||
if !p.curTokenIs(lexer.TokenIdentifier) {
|
||||
p.errors = append(p.errors, fmt.Sprintf("line %d: expected parameter name after comma",
|
||||
p.curToken.Line))
|
||||
break
|
||||
}
|
||||
|
||||
ident := &Identifier{Token: p.curToken, Value: p.curToken.Value}
|
||||
identifiers = append(identifiers, ident)
|
||||
}
|
||||
|
||||
if !p.expectPeek(lexer.TokenRightParen) {
|
||||
p.errors = append(p.errors, fmt.Sprintf("line %d: expected ')' to close parameter list",
|
||||
p.curToken.Line))
|
||||
return nil
|
||||
}
|
||||
|
||||
return identifiers
|
||||
}
|
||||
|
||||
func (p *Parser) parseBlockStatement() *BlockStatement {
|
||||
block := &BlockStatement{Token: p.curToken}
|
||||
block.Statements = []Statement{}
|
||||
|
||||
p.nextToken() // Skip '{'
|
||||
|
||||
for !p.curTokenIs(lexer.TokenRightBrace) && !p.curTokenIs(lexer.TokenEOF) {
|
||||
stmt := p.parseStatement()
|
||||
if stmt != nil {
|
||||
block.Statements = append(block.Statements, stmt)
|
||||
}
|
||||
p.nextToken()
|
||||
}
|
||||
|
||||
return block
|
||||
}
|
||||
|
||||
func (p *Parser) parseCallExpression(function Expression) Expression {
|
||||
exp := &CallExpression{
|
||||
Token: p.curToken,
|
||||
@ -701,3 +762,85 @@ func (p *Parser) parseCallExpression(function Expression) Expression {
|
||||
|
||||
return exp
|
||||
}
|
||||
|
||||
func (p *Parser) parseFunctionStatement() *FunctionStatement {
|
||||
stmt := &FunctionStatement{Token: p.curToken}
|
||||
|
||||
p.nextToken() // Skip 'function'
|
||||
|
||||
// Parse function name
|
||||
if !p.curTokenIs(lexer.TokenIdentifier) {
|
||||
p.errors = append(p.errors, fmt.Sprintf("line %d: expected function name", p.curToken.Line))
|
||||
return nil
|
||||
}
|
||||
|
||||
stmt.Name = &Identifier{Token: p.curToken, Value: p.curToken.Value}
|
||||
|
||||
// Check if next token is a left paren
|
||||
if !p.expectPeek(lexer.TokenLeftParen) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse parameters
|
||||
stmt.Parameters = []*Identifier{}
|
||||
|
||||
// Check for empty parameter list
|
||||
if p.peekTokenIs(lexer.TokenRightParen) {
|
||||
p.nextToken() // Skip to the right paren
|
||||
} else {
|
||||
p.nextToken() // Skip the left paren
|
||||
|
||||
// Parse first parameter
|
||||
if !p.curTokenIs(lexer.TokenIdentifier) {
|
||||
p.errors = append(p.errors, fmt.Sprintf("line %d: expected parameter name, got %s",
|
||||
p.curToken.Line, p.curToken.Value))
|
||||
return nil
|
||||
}
|
||||
|
||||
ident := &Identifier{Token: p.curToken, Value: p.curToken.Value}
|
||||
stmt.Parameters = append(stmt.Parameters, ident)
|
||||
|
||||
// Parse additional parameters
|
||||
for p.peekTokenIs(lexer.TokenComma) {
|
||||
p.nextToken() // Skip current parameter
|
||||
p.nextToken() // Skip comma
|
||||
|
||||
if !p.curTokenIs(lexer.TokenIdentifier) {
|
||||
p.errors = append(p.errors, fmt.Sprintf("line %d: expected parameter name after comma",
|
||||
p.curToken.Line))
|
||||
return nil
|
||||
}
|
||||
|
||||
ident := &Identifier{Token: p.curToken, Value: p.curToken.Value}
|
||||
stmt.Parameters = append(stmt.Parameters, ident)
|
||||
}
|
||||
|
||||
// After parsing parameters, expect closing parenthesis
|
||||
if !p.expectPeek(lexer.TokenRightParen) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Parse function body
|
||||
bodyStmts := []Statement{}
|
||||
for p.nextToken(); !p.curTokenIs(lexer.TokenEnd) && !p.curTokenIs(lexer.TokenEOF); p.nextToken() {
|
||||
stmt := p.parseStatement()
|
||||
if stmt != nil {
|
||||
bodyStmts = append(bodyStmts, stmt)
|
||||
}
|
||||
}
|
||||
|
||||
// Expect 'end' token
|
||||
if !p.curTokenIs(lexer.TokenEnd) {
|
||||
p.errors = append(p.errors, fmt.Sprintf("line %d: expected 'end' to close function",
|
||||
p.curToken.Line))
|
||||
return nil
|
||||
}
|
||||
|
||||
stmt.Body = &BlockStatement{
|
||||
Token: p.curToken,
|
||||
Statements: bodyStmts,
|
||||
}
|
||||
|
||||
return stmt
|
||||
}
|
||||
|
@ -1,5 +1,12 @@
|
||||
function add(a, b)
|
||||
return a+b
|
||||
// Assignment style
|
||||
add = function(a, b)
|
||||
return a + b
|
||||
end
|
||||
|
||||
echo add(1, 2)
|
||||
// Direct declaration style
|
||||
function multiply(a, b)
|
||||
return a * b
|
||||
end
|
||||
|
||||
echo add(1, 2)
|
||||
echo multiply(2, 2)
|
47
vm/vm.go
47
vm/vm.go
@ -241,23 +241,37 @@ func (vm *VM) runCode(instructions []types.Instruction, basePointer int) types.V
|
||||
case types.OpJump:
|
||||
ip = instruction.Operand - 1 // -1 because loop will increment
|
||||
|
||||
// Function instructions
|
||||
// Function instructions
|
||||
case types.OpFunction:
|
||||
constIndex := instruction.Operand
|
||||
function := vm.constants[constIndex].(*types.Function)
|
||||
// Use the helper function to create a proper function value
|
||||
vm.push(types.NewFunctionValue(function))
|
||||
|
||||
case types.OpCall:
|
||||
numArgs := instruction.Operand
|
||||
fnVal := vm.stack[vm.sp-numArgs-1]
|
||||
|
||||
if fnVal.Type != types.TypeFunction {
|
||||
fmt.Println("Error: attempt to call non-function value")
|
||||
// The function is at position sp-numArgs-1
|
||||
if vm.sp <= numArgs {
|
||||
fmt.Println("Error: stack underflow during function call")
|
||||
vm.push(types.NewNull())
|
||||
continue
|
||||
}
|
||||
|
||||
function := fnVal.Data.(*types.Function)
|
||||
fnVal := vm.stack[vm.sp-numArgs-1]
|
||||
|
||||
if fnVal.Type != types.TypeFunction {
|
||||
fmt.Printf("Error: attempt to call non-function value\n")
|
||||
vm.push(types.NewNull())
|
||||
continue
|
||||
}
|
||||
|
||||
function, ok := fnVal.Data.(*types.Function)
|
||||
if !ok {
|
||||
fmt.Printf("Error: function data is invalid\n")
|
||||
vm.push(types.NewNull())
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if we have the correct number of arguments
|
||||
if numArgs != function.NumParams {
|
||||
@ -286,18 +300,25 @@ func (vm *VM) runCode(instructions []types.Instruction, basePointer int) types.V
|
||||
}
|
||||
vm.frames[vm.fp] = frame
|
||||
|
||||
// Set up the call
|
||||
// We keep the function value and args on the stack
|
||||
// They become local variables in the function's scope
|
||||
// Save the current constants
|
||||
oldConstants := vm.constants
|
||||
// Switch to function's constants
|
||||
vm.constants = function.Constants
|
||||
|
||||
// Run the function code
|
||||
returnValue := vm.runCode(function.Instructions, frame.BasePointer)
|
||||
|
||||
// Restore the old constants
|
||||
vm.constants = oldConstants
|
||||
|
||||
// Restore state
|
||||
vm.fp--
|
||||
|
||||
// Return value from function is already on stack from OpReturn
|
||||
return returnValue
|
||||
// Replace the function with the return value
|
||||
vm.stack[frame.BasePointer] = returnValue
|
||||
|
||||
// Adjust the stack pointer to remove the arguments
|
||||
vm.sp = frame.BasePointer + 1
|
||||
|
||||
case types.OpReturn:
|
||||
returnValue := vm.pop()
|
||||
@ -306,11 +327,11 @@ func (vm *VM) runCode(instructions []types.Instruction, basePointer int) types.V
|
||||
if vm.fp >= 0 {
|
||||
frame := vm.frames[vm.fp]
|
||||
|
||||
// Restore the stack
|
||||
vm.sp = frame.BasePointer + 1 // Keep the function
|
||||
// Restore the stack to just below the function
|
||||
vm.sp = frame.BasePointer
|
||||
|
||||
// Push the return value
|
||||
vm.stack[vm.sp-1] = returnValue
|
||||
vm.push(returnValue)
|
||||
|
||||
// Return to the caller
|
||||
return returnValue
|
||||
|
Loading…
x
Reference in New Issue
Block a user