From 29707b4b02704383635b0abe4ca3214c1df317fb Mon Sep 17 00:00:00 2001 From: Sky Johnson Date: Tue, 6 May 2025 23:20:10 -0500 Subject: [PATCH] functions 2 --- compiler/compiler.go | 89 +++++++++++++++++-- parser/ast.go | 10 +++ parser/parser.go | 199 +++++++++++++++++++++++++++++++++++++------ tests/funcs.mako | 13 ++- vm/vm.go | 47 +++++++--- 5 files changed, 308 insertions(+), 50 deletions(-) diff --git a/compiler/compiler.go b/compiler/compiler.go index fbf152a..9213b6d 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -135,6 +135,10 @@ func (c *compiler) compileStatement(stmt parser.Statement) { } c.emit(types.OpReturn, 0) + case *parser.FunctionStatement: + // Use the dedicated function for function statements + c.compileFunctionDeclaration(s) + // BlockStatement now should only be used for keyword blocks like if-then-else-end case *parser.BlockStatement: for _, blockStmt := range s.Statements { @@ -207,14 +211,14 @@ func (c *compiler) compileExpression(expr parser.Expression) { c.compileFunctionLiteral(e) case *parser.CallExpression: - // Compile the arguments first + // Compile the function expression first + c.compileExpression(e.Function) + + // Then compile the arguments for _, arg := range e.Arguments { c.compileExpression(arg) } - // Compile the function expression (which might be an Identifier or more complex) - c.compileExpression(e.Function) - // Emit the call instruction with the number of arguments c.emit(types.OpCall, len(e.Arguments)) @@ -434,7 +438,8 @@ func (c *compiler) compileFunctionLiteral(fn *parser.FunctionLiteral) { // Ensure the function always returns a value // If the last instruction is not a return, add one - if len(c.instructions) == 0 || c.instructions[len(c.instructions)-1].Opcode != types.OpReturn { + if len(fnCompiler.instructions) == 0 || + (len(fnCompiler.instructions) > 0 && fnCompiler.instructions[len(fnCompiler.instructions)-1].Opcode != types.OpReturn) { nullIndex := c.addConstant(nil) c.emit(types.OpConstant, nullIndex) c.emit(types.OpReturn, 0) @@ -446,13 +451,13 @@ func (c *compiler) compileFunctionLiteral(fn *parser.FunctionLiteral) { // Restore the parent compiler c.currentFunction = parentCompiler - // Create a Function object and add it to the constants // Extract upvalue information for closure creation upvalueIndexes := make([]int, len(fnCompiler.upvalues)) for i, upvalue := range fnCompiler.upvalues { upvalueIndexes[i] = upvalue.index } + // Create a Function object and add it to the constants function := types.NewFunction( fnCompiler.instructions, fnCompiler.numParams, @@ -486,3 +491,75 @@ func (c *compiler) emit(op types.Opcode, operand int) { c.instructions = append(c.instructions, instruction) } } + +func (c *compiler) compileFunctionDeclaration(fn *parser.FunctionStatement) { + // Save the current compiler state + parentCompiler := c.currentFunction + + // Create a new function compiler + fnCompiler := &functionCompiler{ + constants: []any{}, + instructions: []types.Instruction{}, + numParams: len(fn.Parameters), + upvalues: []upvalueInfo{}, + } + + c.currentFunction = fnCompiler + + // Enter a new scope for the function body + c.enterScope() + + // Declare parameters as local variables + for _, param := range fn.Parameters { + c.declareVariable(param.Value) + paramIndex := c.addConstant(param.Value) + c.emit(types.OpSetLocal, paramIndex) + } + + // Compile the function body + for _, stmt := range fn.Body.Statements { + c.compileStatement(stmt) + } + + // Ensure the function always returns a value + // If the last instruction is not a return, add one + if len(fnCompiler.instructions) == 0 || fnCompiler.instructions[len(fnCompiler.instructions)-1].Opcode != types.OpReturn { + nullIndex := c.addConstant(nil) + c.emit(types.OpConstant, nullIndex) + c.emit(types.OpReturn, 0) + } + + // Exit the function scope + c.exitScope() + + // Restore the parent compiler + c.currentFunction = parentCompiler + + // Extract upvalue information for closure creation + upvalueIndexes := make([]int, len(fnCompiler.upvalues)) + for i, upvalue := range fnCompiler.upvalues { + upvalueIndexes[i] = upvalue.index + } + + // Create a Function object and add it to the constants + function := types.NewFunction( + fnCompiler.instructions, + fnCompiler.numParams, + fnCompiler.constants, + upvalueIndexes, + ) + + functionIndex := c.addConstant(function) + c.emit(types.OpFunction, functionIndex) + + // Store the function in a global variable + nameIndex := c.addConstant(fn.Name.Value) + + // Use SetGlobal for top-level variables to persist between REPL lines + if len(c.scopes) <= 1 { + c.emit(types.OpSetGlobal, nameIndex) + } else { + c.declareVariable(fn.Name.Value) + c.emit(types.OpSetLocal, nameIndex) + } +} diff --git a/parser/ast.go b/parser/ast.go index 8470f79..815e55f 100644 --- a/parser/ast.go +++ b/parser/ast.go @@ -191,3 +191,13 @@ type ReturnStatement struct { func (rs *ReturnStatement) statementNode() {} func (rs *ReturnStatement) TokenLiteral() string { return rs.Token.Value } + +type FunctionStatement struct { + Token lexer.Token // The 'function' token + Name *Identifier + Parameters []*Identifier + Body *BlockStatement +} + +func (fs *FunctionStatement) statementNode() {} +func (fs *FunctionStatement) TokenLiteral() string { return fs.Token.Value } diff --git a/parser/parser.go b/parser/parser.go index 719c954..266f464 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -77,7 +77,8 @@ func New(l *lexer.Lexer) *Parser { p.registerPrefix(lexer.TokenFalse, p.parseBooleanLiteral) p.registerPrefix(lexer.TokenNot, p.parsePrefixExpression) p.registerPrefix(lexer.TokenNil, p.parseNilLiteral) - p.registerPrefix(lexer.TokenFunction, p.parseFunctionLiteral) // Add function literal parsing + p.registerPrefix(lexer.TokenFunction, p.parseFunctionLiteral) + p.registerPrefix(lexer.TokenRightParen, p.parseUnexpectedToken) // Initialize infix parse functions p.infixParseFns = make(map[lexer.TokenType]infixParseFn) @@ -86,7 +87,7 @@ func New(l *lexer.Lexer) *Parser { p.registerInfix(lexer.TokenStar, p.parseInfixExpression) p.registerInfix(lexer.TokenSlash, p.parseInfixExpression) p.registerInfix(lexer.TokenLeftBracket, p.parseIndexExpression) - p.registerInfix(lexer.TokenLeftParen, p.parseCallExpression) // Add function call parsing + p.registerInfix(lexer.TokenLeftParen, p.parseCallExpression) p.registerInfix(lexer.TokenAnd, p.parseInfixExpression) p.registerInfix(lexer.TokenOr, p.parseInfixExpression) @@ -184,6 +185,13 @@ func (p *Parser) parseStatement() Statement { return p.parseEchoStatement() case lexer.TokenReturn: return p.parseReturnStatement() + case lexer.TokenFunction: + // If the next token is an identifier, it's a function declaration + if p.peekTokenIs(lexer.TokenIdentifier) { + return p.parseFunctionStatement() + } + // Otherwise, it's a function expression + return p.parseExpressionStatement() default: return p.parseExpressionStatement() } @@ -601,24 +609,74 @@ func (p *Parser) parseNilLiteral() Expression { return &NilLiteral{Token: p.curToken} } -// New methods for function literals and call expressions - func (p *Parser) parseFunctionLiteral() Expression { lit := &FunctionLiteral{Token: p.curToken} - // Check for opening paren for parameters + // Check if next token is a left paren if !p.expectPeek(lexer.TokenLeftParen) { return nil } - lit.Parameters = p.parseFunctionParameters() + // Parse the parameters + lit.Parameters = []*Identifier{} - // Expect a block for the function body - if !p.expectPeek(lexer.TokenLeftBrace) { + // Check for empty parameter list + if p.peekTokenIs(lexer.TokenRightParen) { + p.nextToken() // Skip to the right paren + } else { + p.nextToken() // Skip the left paren + + // Parse first parameter + if !p.curTokenIs(lexer.TokenIdentifier) { + p.errors = append(p.errors, fmt.Sprintf("line %d: expected parameter name, got %s", + p.curToken.Line, p.curToken.Value)) + return nil + } + + ident := &Identifier{Token: p.curToken, Value: p.curToken.Value} + lit.Parameters = append(lit.Parameters, ident) + + // Parse additional parameters + for p.peekTokenIs(lexer.TokenComma) { + p.nextToken() // Skip current parameter + p.nextToken() // Skip comma + + if !p.curTokenIs(lexer.TokenIdentifier) { + p.errors = append(p.errors, fmt.Sprintf("line %d: expected parameter name after comma", + p.curToken.Line)) + return nil + } + + ident := &Identifier{Token: p.curToken, Value: p.curToken.Value} + lit.Parameters = append(lit.Parameters, ident) + } + + // After parsing parameters, expect closing parenthesis + if !p.expectPeek(lexer.TokenRightParen) { + return nil + } + } + + // Parse function body + bodyStmts := []Statement{} + for p.nextToken(); !p.curTokenIs(lexer.TokenEnd) && !p.curTokenIs(lexer.TokenEOF); p.nextToken() { + stmt := p.parseStatement() + if stmt != nil { + bodyStmts = append(bodyStmts, stmt) + } + } + + // Expect 'end' token + if !p.curTokenIs(lexer.TokenEnd) { + p.errors = append(p.errors, fmt.Sprintf("line %d: expected 'end' to close function", + p.curToken.Line)) return nil } - lit.Body = p.parseBlockStatement() + lit.Body = &BlockStatement{ + Token: p.curToken, + Statements: bodyStmts, + } return lit } @@ -628,13 +686,24 @@ func (p *Parser) parseFunctionParameters() []*Identifier { // Empty parameter list if p.peekTokenIs(lexer.TokenRightParen) { - p.nextToken() + p.nextToken() // Skip to right paren + p.nextToken() // Skip right paren return identifiers } - p.nextToken() // Skip '(' + p.nextToken() // Skip left paren // First parameter + if !p.curTokenIs(lexer.TokenIdentifier) { + // Expected identifier for parameter but didn't get one + p.errors = append(p.errors, fmt.Sprintf("line %d: expected parameter name, got %d", + p.curToken.Line, p.curToken.Type)) + if p.expectPeek(lexer.TokenRightParen) { + return identifiers + } + return nil + } + ident := &Identifier{Token: p.curToken, Value: p.curToken.Value} identifiers = append(identifiers, ident) @@ -642,34 +711,26 @@ func (p *Parser) parseFunctionParameters() []*Identifier { for p.peekTokenIs(lexer.TokenComma) { p.nextToken() // Skip current identifier p.nextToken() // Skip comma + + if !p.curTokenIs(lexer.TokenIdentifier) { + p.errors = append(p.errors, fmt.Sprintf("line %d: expected parameter name after comma", + p.curToken.Line)) + break + } + ident := &Identifier{Token: p.curToken, Value: p.curToken.Value} identifiers = append(identifiers, ident) } if !p.expectPeek(lexer.TokenRightParen) { + p.errors = append(p.errors, fmt.Sprintf("line %d: expected ')' to close parameter list", + p.curToken.Line)) return nil } return identifiers } -func (p *Parser) parseBlockStatement() *BlockStatement { - block := &BlockStatement{Token: p.curToken} - block.Statements = []Statement{} - - p.nextToken() // Skip '{' - - for !p.curTokenIs(lexer.TokenRightBrace) && !p.curTokenIs(lexer.TokenEOF) { - stmt := p.parseStatement() - if stmt != nil { - block.Statements = append(block.Statements, stmt) - } - p.nextToken() - } - - return block -} - func (p *Parser) parseCallExpression(function Expression) Expression { exp := &CallExpression{ Token: p.curToken, @@ -701,3 +762,85 @@ func (p *Parser) parseCallExpression(function Expression) Expression { return exp } + +func (p *Parser) parseFunctionStatement() *FunctionStatement { + stmt := &FunctionStatement{Token: p.curToken} + + p.nextToken() // Skip 'function' + + // Parse function name + if !p.curTokenIs(lexer.TokenIdentifier) { + p.errors = append(p.errors, fmt.Sprintf("line %d: expected function name", p.curToken.Line)) + return nil + } + + stmt.Name = &Identifier{Token: p.curToken, Value: p.curToken.Value} + + // Check if next token is a left paren + if !p.expectPeek(lexer.TokenLeftParen) { + return nil + } + + // Parse parameters + stmt.Parameters = []*Identifier{} + + // Check for empty parameter list + if p.peekTokenIs(lexer.TokenRightParen) { + p.nextToken() // Skip to the right paren + } else { + p.nextToken() // Skip the left paren + + // Parse first parameter + if !p.curTokenIs(lexer.TokenIdentifier) { + p.errors = append(p.errors, fmt.Sprintf("line %d: expected parameter name, got %s", + p.curToken.Line, p.curToken.Value)) + return nil + } + + ident := &Identifier{Token: p.curToken, Value: p.curToken.Value} + stmt.Parameters = append(stmt.Parameters, ident) + + // Parse additional parameters + for p.peekTokenIs(lexer.TokenComma) { + p.nextToken() // Skip current parameter + p.nextToken() // Skip comma + + if !p.curTokenIs(lexer.TokenIdentifier) { + p.errors = append(p.errors, fmt.Sprintf("line %d: expected parameter name after comma", + p.curToken.Line)) + return nil + } + + ident := &Identifier{Token: p.curToken, Value: p.curToken.Value} + stmt.Parameters = append(stmt.Parameters, ident) + } + + // After parsing parameters, expect closing parenthesis + if !p.expectPeek(lexer.TokenRightParen) { + return nil + } + } + + // Parse function body + bodyStmts := []Statement{} + for p.nextToken(); !p.curTokenIs(lexer.TokenEnd) && !p.curTokenIs(lexer.TokenEOF); p.nextToken() { + stmt := p.parseStatement() + if stmt != nil { + bodyStmts = append(bodyStmts, stmt) + } + } + + // Expect 'end' token + if !p.curTokenIs(lexer.TokenEnd) { + p.errors = append(p.errors, fmt.Sprintf("line %d: expected 'end' to close function", + p.curToken.Line)) + return nil + } + + stmt.Body = &BlockStatement{ + Token: p.curToken, + Statements: bodyStmts, + } + + return stmt +} diff --git a/tests/funcs.mako b/tests/funcs.mako index c44383c..2cc77ba 100644 --- a/tests/funcs.mako +++ b/tests/funcs.mako @@ -1,5 +1,12 @@ -function add(a, b) - return a+b +// Assignment style +add = function(a, b) + return a + b end -echo add(1, 2) \ No newline at end of file +// Direct declaration style +function multiply(a, b) + return a * b +end + +echo add(1, 2) +echo multiply(2, 2) \ No newline at end of file diff --git a/vm/vm.go b/vm/vm.go index 5e1a686..6a25214 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -241,23 +241,37 @@ func (vm *VM) runCode(instructions []types.Instruction, basePointer int) types.V case types.OpJump: ip = instruction.Operand - 1 // -1 because loop will increment - // Function instructions + // Function instructions case types.OpFunction: constIndex := instruction.Operand function := vm.constants[constIndex].(*types.Function) + // Use the helper function to create a proper function value vm.push(types.NewFunctionValue(function)) case types.OpCall: numArgs := instruction.Operand - fnVal := vm.stack[vm.sp-numArgs-1] - if fnVal.Type != types.TypeFunction { - fmt.Println("Error: attempt to call non-function value") + // The function is at position sp-numArgs-1 + if vm.sp <= numArgs { + fmt.Println("Error: stack underflow during function call") vm.push(types.NewNull()) continue } - function := fnVal.Data.(*types.Function) + fnVal := vm.stack[vm.sp-numArgs-1] + + if fnVal.Type != types.TypeFunction { + fmt.Printf("Error: attempt to call non-function value\n") + vm.push(types.NewNull()) + continue + } + + function, ok := fnVal.Data.(*types.Function) + if !ok { + fmt.Printf("Error: function data is invalid\n") + vm.push(types.NewNull()) + continue + } // Check if we have the correct number of arguments if numArgs != function.NumParams { @@ -286,18 +300,25 @@ func (vm *VM) runCode(instructions []types.Instruction, basePointer int) types.V } vm.frames[vm.fp] = frame - // Set up the call - // We keep the function value and args on the stack - // They become local variables in the function's scope + // Save the current constants + oldConstants := vm.constants + // Switch to function's constants + vm.constants = function.Constants // Run the function code returnValue := vm.runCode(function.Instructions, frame.BasePointer) + // Restore the old constants + vm.constants = oldConstants + // Restore state vm.fp-- - // Return value from function is already on stack from OpReturn - return returnValue + // Replace the function with the return value + vm.stack[frame.BasePointer] = returnValue + + // Adjust the stack pointer to remove the arguments + vm.sp = frame.BasePointer + 1 case types.OpReturn: returnValue := vm.pop() @@ -306,11 +327,11 @@ func (vm *VM) runCode(instructions []types.Instruction, basePointer int) types.V if vm.fp >= 0 { frame := vm.frames[vm.fp] - // Restore the stack - vm.sp = frame.BasePointer + 1 // Keep the function + // Restore the stack to just below the function + vm.sp = frame.BasePointer // Push the return value - vm.stack[vm.sp-1] = returnValue + vm.push(returnValue) // Return to the caller return returnValue