mem leak fix

This commit is contained in:
Sky Johnson 2025-03-19 20:24:47 -05:00
parent 55f27c6f68
commit 03a03af96c
2 changed files with 24 additions and 15 deletions

View File

@ -3,7 +3,6 @@ package runner
import ( import (
"context" "context"
"errors" "errors"
"path/filepath"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -29,7 +28,7 @@ type LuaRunner struct {
initFunc StateInitFunc // Optional function to initialize Lua state initFunc StateInitFunc // Optional function to initialize Lua state
bufferSize int // Size of the job queue buffer bufferSize int // Size of the job queue buffer
requireCache *RequireCache // Cache for required modules requireCache *RequireCache // Cache for required modules
requireCfg RequireConfig // Configuration for require paths requireCfg *RequireConfig // Configuration for require paths
scriptDir string // Base directory for scripts scriptDir string // Base directory for scripts
libDirs []string // Additional library directories libDirs []string // Additional library directories
} }
@ -40,7 +39,7 @@ func NewRunner(options ...RunnerOption) (*LuaRunner, error) {
runner := &LuaRunner{ runner := &LuaRunner{
bufferSize: 10, // Default buffer size bufferSize: 10, // Default buffer size
requireCache: NewRequireCache(), requireCache: NewRequireCache(),
requireCfg: RequireConfig{ requireCfg: &RequireConfig{
LibDirs: []string{}, LibDirs: []string{},
}, },
} }
@ -61,7 +60,13 @@ func NewRunner(options ...RunnerOption) (*LuaRunner, error) {
runner.jobQueue = make(chan job, runner.bufferSize) runner.jobQueue = make(chan job, runner.bufferSize)
runner.isRunning.Store(true) runner.isRunning.Store(true)
// Set up require functionality // Create a shared config pointer that will be updated per request
runner.requireCfg = &RequireConfig{
ScriptDir: runner.scriptDir,
LibDirs: runner.libDirs,
}
// Set up require functionality ONCE
if err := SetupRequire(state, runner.requireCache, runner.requireCfg); err != nil { if err := SetupRequire(state, runner.requireCache, runner.requireCfg); err != nil {
state.Close() state.Close()
return nil, ErrInitFailed return nil, ErrInitFailed
@ -226,14 +231,11 @@ func (r *LuaRunner) eventLoop() {
// executeJob runs a script in the sandbox environment // executeJob runs a script in the sandbox environment
func (r *LuaRunner) executeJob(j job) JobResult { func (r *LuaRunner) executeJob(j job) JobResult {
// If the job has a script path, update the require context // If the job has a script path, update paths without re-registering
if j.ScriptPath != "" { if j.ScriptPath != "" {
// Update the script directory for require r.mu.Lock()
scriptDir := filepath.Dir(j.ScriptPath) UpdateRequirePaths(r.requireCfg, j.ScriptPath)
r.requireCfg.ScriptDir = scriptDir r.mu.Unlock()
// Update in the require cache config
SetupRequire(r.state, r.requireCache, r.requireCfg)
} }
// Re-run init function if needed // Re-run init function if needed

View File

@ -35,18 +35,18 @@ func NewRequireCache() *RequireCache {
} }
// SetupRequire configures the Lua state with a secure require function // SetupRequire configures the Lua state with a secure require function
func SetupRequire(state *luajit.State, cache *RequireCache, config RequireConfig) error { func SetupRequire(state *luajit.State, cache *RequireCache, config *RequireConfig) error {
// Register the loader function // Register the loader function
err := state.RegisterGoFunction("__go_load_module", func(s *luajit.State) int { err := state.RegisterGoFunction("__go_load_module", func(s *luajit.State) int {
// Get module name // Get module name
modName := s.ToString(1) modName := s.ToString(1)
if modName == "" { if modName == "" {
s.PushString("module name required") s.PushString("module name required")
return -1 // Return error return -1
} }
// Try to load the module // Use the pointer to the shared config
bytecode, err := findAndCompileModule(s, cache, config, modName) bytecode, err := findAndCompileModule(s, cache, *config, modName)
if err != nil { if err != nil {
if err == ErrModuleNotFound { if err == ErrModuleNotFound {
s.PushString("module '" + modName + "' not found") s.PushString("module '" + modName + "' not found")
@ -114,6 +114,13 @@ func SetupRequire(state *luajit.State, cache *RequireCache, config RequireConfig
return state.DoString(setupScript) return state.DoString(setupScript)
} }
// UpdateRequirePaths updates the require paths in the config without further allocations or re-registering the loader.
func UpdateRequirePaths(config *RequireConfig, scriptPath string) {
if scriptPath != "" {
config.ScriptDir = filepath.Dir(scriptPath)
}
}
// findAndCompileModule finds a module in allowed directories and compiles it to bytecode // findAndCompileModule finds a module in allowed directories and compiles it to bytecode
func findAndCompileModule( func findAndCompileModule(
state *luajit.State, state *luajit.State,