From 55f27c6f68fefd8df69756b38fe45a554e6c3160 Mon Sep 17 00:00:00 2001 From: Sky Johnson Date: Wed, 19 Mar 2025 16:50:39 -0500 Subject: [PATCH] require --- .gitignore | 1 + core/http/httplogger.go | 8 +- core/http/server.go | 8 +- core/logger/logger.go | 21 +-- core/routers/luarouter.go | 6 +- core/routers/luarouter_test.go | 6 +- core/runner/job.go | 7 +- core/runner/luarunner.go | 88 ++++++++--- core/runner/luarunner_test.go | 30 ++-- core/runner/require.go | 212 +++++++++++++++++++++++++ core/runner/require_test.go | 280 +++++++++++++++++++++++++++++++++ moonshark.go | 3 + 12 files changed, 609 insertions(+), 61 deletions(-) create mode 100644 core/runner/require.go create mode 100644 core/runner/require_test.go diff --git a/.gitignore b/.gitignore index 88778fe..ed0c89d 100644 --- a/.gitignore +++ b/.gitignore @@ -26,3 +26,4 @@ go.work config.lua routes/ static/ +libs/ diff --git a/core/http/httplogger.go b/core/http/httplogger.go index e86fa08..13ecac4 100644 --- a/core/http/httplogger.go +++ b/core/http/httplogger.go @@ -14,6 +14,7 @@ const ( colorYellow = "\033[33m" // 4xx - Client Errors colorRed = "\033[31m" // 5xx - Server Errors colorReset = "\033[0m" // Reset color + colorGray = "\033[90m" ) // LogRequest logs an HTTP request with custom formatting @@ -21,10 +22,9 @@ func LogRequest(log *logger.Logger, statusCode int, r *http.Request, duration ti statusColor := getStatusColor(statusCode) // Use the logger's raw message writer to bypass the standard format - log.LogRaw("%s [ %s%d%s] %s %s (%v)", - time.Now().Format(log.TimeFormat()), - statusColor, statusCode, colorReset, - r.Method, r.URL.Path, duration) + log.LogRaw("%s%s%s %s%d %s%s %s %s(%v)%s", + colorGray, time.Now().Format(log.TimeFormat()), colorReset, + statusColor, statusCode, r.Method, colorReset, r.URL.Path, colorGray, duration, colorReset) } // getStatusColor returns the ANSI color code for a status code diff --git a/core/http/server.go b/core/http/server.go index a546eda..8d8b873 100644 --- a/core/http/server.go +++ b/core/http/server.go @@ -97,9 +97,9 @@ func (s *Server) handleRequest(w http.ResponseWriter, r *http.Request) { // Try Lua routes first params := &routers.Params{} - if bytecode, found := s.luaRouter.GetBytecode(r.Method, r.URL.Path, params); found { + if bytecode, scriptPath, found := s.luaRouter.GetBytecode(r.Method, r.URL.Path, params); found { s.logger.Debug("Found Lua route match for %s %s with %d params", r.Method, r.URL.Path, params.Count) - s.handleLuaRoute(w, r, bytecode, params) + s.handleLuaRoute(w, r, bytecode, scriptPath, params) return } @@ -114,7 +114,7 @@ func (s *Server) handleRequest(w http.ResponseWriter, r *http.Request) { } // handleLuaRoute executes a Lua route -func (s *Server) handleLuaRoute(w http.ResponseWriter, r *http.Request, bytecode []byte, params *routers.Params) { +func (s *Server) handleLuaRoute(w http.ResponseWriter, r *http.Request, bytecode []byte, scriptPath string, params *routers.Params) { ctx := runner.NewContext() // Log bytecode size @@ -148,7 +148,7 @@ func (s *Server) handleLuaRoute(w http.ResponseWriter, r *http.Request, bytecode } // Execute Lua script - result, err := s.luaRunner.Run(bytecode, ctx) + result, err := s.luaRunner.Run(bytecode, ctx, scriptPath) if err != nil { s.logger.Error("Error executing Lua route: %v", err) http.Error(w, "Internal Server Error", http.StatusInternalServerError) diff --git a/core/logger/logger.go b/core/logger/logger.go index 9229074..247907c 100644 --- a/core/logger/logger.go +++ b/core/logger/logger.go @@ -18,6 +18,7 @@ const ( colorPurple = "\033[35m" colorCyan = "\033[36m" colorWhite = "\033[37m" + colorGray = "\033[90m" ) // Log levels @@ -26,8 +27,8 @@ const ( LevelInfo LevelWarning LevelError - LevelFatal LevelServer + LevelFatal ) // Level names and colors @@ -35,12 +36,12 @@ var levelProps = map[int]struct { tag string color string }{ - LevelDebug: {" DBG", colorCyan}, - LevelInfo: {"INFO", colorBlue}, - LevelWarning: {"WARN", colorYellow}, - LevelError: {" ERR", colorRed}, - LevelFatal: {"FATL", colorPurple}, - LevelServer: {"SRVR", colorGreen}, + LevelDebug: {"DEBUG", colorCyan}, + LevelInfo: {" INFO", colorBlue}, + LevelWarning: {" WARN", colorYellow}, + LevelError: {"ERROR", colorRed}, + LevelServer: {" SYS", colorGreen}, + LevelFatal: {"FATAL", colorPurple}, } // Time format for log messages @@ -110,10 +111,10 @@ func (l *Logger) writeMessage(level int, message string, rawMode bool) { props := levelProps[level] if l.useColors { - logLine = fmt.Sprintf("%s %s[%s]%s %s\n", - now, props.color, props.tag, colorReset, message) + logLine = fmt.Sprintf("%s%s%s %s%s%s %s\n", + colorGray, now, colorReset, props.color, props.tag, colorReset, message) } else { - logLine = fmt.Sprintf("%s [%s] %s\n", + logLine = fmt.Sprintf("%s %s %s\n", now, props.tag, message) } } diff --git a/core/routers/luarouter.go b/core/routers/luarouter.go index 4e7595f..e660159 100644 --- a/core/routers/luarouter.go +++ b/core/routers/luarouter.go @@ -251,12 +251,12 @@ func (r *LuaRouter) compileHandler(n *node) error { } // GetBytecode returns the compiled bytecode for a matched route -func (r *LuaRouter) GetBytecode(method, path string, params *Params) ([]byte, bool) { +func (r *LuaRouter) GetBytecode(method, path string, params *Params) ([]byte, string, bool) { node, found := r.Match(method, path, params) if !found { - return nil, false + return nil, "", false } - return node.bytecode, true + return node.bytecode, node.handler, true } // Refresh rebuilds the router by rescanning the routes directory diff --git a/core/routers/luarouter_test.go b/core/routers/luarouter_test.go index 8fc336d..61b2c5f 100644 --- a/core/routers/luarouter_test.go +++ b/core/routers/luarouter_test.go @@ -173,7 +173,7 @@ func TestGetBytecode(t *testing.T) { } var params Params - bytecode, found := router.GetBytecode("GET", "/api/users/123", ¶ms) + bytecode, _, found := router.GetBytecode("GET", "/api/users/123", ¶ms) if !found { t.Fatalf("Route not found") @@ -212,7 +212,7 @@ func TestRefresh(t *testing.T) { // Before refresh, route should not be found var params Params - _, found := router.GetBytecode("GET", "/new", ¶ms) + _, _, found := router.GetBytecode("GET", "/new", ¶ms) if found { t.Errorf("New route should not be found before refresh") } @@ -224,7 +224,7 @@ func TestRefresh(t *testing.T) { } // After refresh, route should be found - bytecode, found := router.GetBytecode("GET", "/new", ¶ms) + bytecode, _, found := router.GetBytecode("GET", "/new", ¶ms) if !found { t.Errorf("New route should be found after refresh") } diff --git a/core/runner/job.go b/core/runner/job.go index 70f6264..2040016 100644 --- a/core/runner/job.go +++ b/core/runner/job.go @@ -8,7 +8,8 @@ type JobResult struct { // job represents a Lua script execution request type job struct { - Bytecode []byte // Compiled LuaJIT bytecode - Context *Context // Execution context - Result chan<- JobResult // Channel to send result back + Bytecode []byte // Compiled LuaJIT bytecode + Context *Context // Execution context + ScriptPath string // Path to the original script (for require resolution) + Result chan<- JobResult // Channel to send result back } diff --git a/core/runner/luarunner.go b/core/runner/luarunner.go index 7dc591c..7154fe0 100644 --- a/core/runner/luarunner.go +++ b/core/runner/luarunner.go @@ -3,6 +3,7 @@ package runner import ( "context" "errors" + "path/filepath" "sync" "sync/atomic" @@ -20,20 +21,28 @@ type StateInitFunc func(*luajit.State) error // LuaRunner runs Lua scripts using a single Lua state type LuaRunner struct { - state *luajit.State // The Lua state - jobQueue chan job // Channel for incoming jobs - isRunning atomic.Bool // Flag indicating if the runner is active - mu sync.RWMutex // Mutex for thread safety - wg sync.WaitGroup // WaitGroup for clean shutdown - initFunc StateInitFunc // Optional function to initialize Lua state - bufferSize int // Size of the job queue buffer + state *luajit.State // The Lua state + jobQueue chan job // Channel for incoming jobs + isRunning atomic.Bool // Flag indicating if the runner is active + mu sync.RWMutex // Mutex for thread safety + wg sync.WaitGroup // WaitGroup for clean shutdown + initFunc StateInitFunc // Optional function to initialize Lua state + bufferSize int // Size of the job queue buffer + requireCache *RequireCache // Cache for required modules + requireCfg RequireConfig // Configuration for require paths + scriptDir string // Base directory for scripts + libDirs []string // Additional library directories } // NewRunner creates a new LuaRunner func NewRunner(options ...RunnerOption) (*LuaRunner, error) { // Default configuration runner := &LuaRunner{ - bufferSize: 10, // Default buffer size + bufferSize: 10, // Default buffer size + requireCache: NewRequireCache(), + requireCfg: RequireConfig{ + LibDirs: []string{}, + }, } // Apply options @@ -52,6 +61,12 @@ func NewRunner(options ...RunnerOption) (*LuaRunner, error) { runner.jobQueue = make(chan job, runner.bufferSize) runner.isRunning.Store(true) + // Set up require functionality + if err := SetupRequire(state, runner.requireCache, runner.requireCfg); err != nil { + state.Close() + return nil, ErrInitFailed + } + // Set up sandbox if err := runner.setupSandbox(); err != nil { state.Close() @@ -92,6 +107,22 @@ func WithInitFunc(initFunc StateInitFunc) RunnerOption { } } +// WithScriptDir sets the base directory for scripts +func WithScriptDir(dir string) RunnerOption { + return func(r *LuaRunner) { + r.scriptDir = dir + r.requireCfg.ScriptDir = dir + } +} + +// WithLibDirs sets additional library directories +func WithLibDirs(dirs ...string) RunnerOption { + return func(r *LuaRunner) { + r.libDirs = dirs + r.requireCfg.LibDirs = dirs + } +} + // setupSandbox initializes the sandbox environment func (r *LuaRunner) setupSandbox() error { // This is the Lua script that creates our sandbox function @@ -124,10 +155,13 @@ func (r *LuaRunner) setupSandbox() error { env.error = error env.assert = assert - -- Allow access to package.loaded for modules - env.require = function(name) - return package.loaded[name] - end + -- Set up the standard library package table + env.package = { + loaded = {} -- Table to store loaded modules + } + + -- Set up secure require function + env = __setup_secure_require(env) -- Create metatable to restrict access to _G local mt = { @@ -192,6 +226,16 @@ func (r *LuaRunner) eventLoop() { // executeJob runs a script in the sandbox environment func (r *LuaRunner) executeJob(j job) JobResult { + // If the job has a script path, update the require context + if j.ScriptPath != "" { + // Update the script directory for require + scriptDir := filepath.Dir(j.ScriptPath) + r.requireCfg.ScriptDir = scriptDir + + // Update in the require cache config + SetupRequire(r.state, r.requireCache, r.requireCfg) + } + // Re-run init function if needed if r.initFunc != nil { if err := r.initFunc(r.state); err != nil { @@ -223,7 +267,7 @@ func (r *LuaRunner) executeJob(j job) JobResult { } // Load bytecode - if err := r.state.LoadBytecode(j.Bytecode, "script"); err != nil { + if err := r.state.LoadBytecode(j.Bytecode, j.ScriptPath); err != nil { r.state.Pop(1) // Pop context return JobResult{nil, err} } @@ -252,7 +296,7 @@ func (r *LuaRunner) executeJob(j job) JobResult { } // RunWithContext executes a Lua script with context and timeout -func (r *LuaRunner) RunWithContext(ctx context.Context, bytecode []byte, execCtx *Context) (any, error) { +func (r *LuaRunner) RunWithContext(ctx context.Context, bytecode []byte, execCtx *Context, scriptPath string) (any, error) { r.mu.RLock() if !r.isRunning.Load() { r.mu.RUnlock() @@ -262,9 +306,10 @@ func (r *LuaRunner) RunWithContext(ctx context.Context, bytecode []byte, execCtx resultChan := make(chan JobResult, 1) j := job{ - Bytecode: bytecode, - Context: execCtx, - Result: resultChan, + Bytecode: bytecode, + Context: execCtx, + ScriptPath: scriptPath, + Result: resultChan, } // Submit job with context @@ -285,8 +330,8 @@ func (r *LuaRunner) RunWithContext(ctx context.Context, bytecode []byte, execCtx } // Run executes a Lua script -func (r *LuaRunner) Run(bytecode []byte, execCtx *Context) (any, error) { - return r.RunWithContext(context.Background(), bytecode, execCtx) +func (r *LuaRunner) Run(bytecode []byte, execCtx *Context, scriptPath string) (any, error) { + return r.RunWithContext(context.Background(), bytecode, execCtx, scriptPath) } // Close gracefully shuts down the LuaRunner @@ -306,3 +351,8 @@ func (r *LuaRunner) Close() error { return nil } + +// ClearRequireCache clears the cache of loaded modules +func (r *LuaRunner) ClearRequireCache() { + r.requireCache = NewRequireCache() +} diff --git a/core/runner/luarunner_test.go b/core/runner/luarunner_test.go index 7debb43..f11d934 100644 --- a/core/runner/luarunner_test.go +++ b/core/runner/luarunner_test.go @@ -33,7 +33,7 @@ func TestRunnerBasic(t *testing.T) { bytecode := createTestBytecode(t, "return 42") - result, err := runner.Run(bytecode, nil) + result, err := runner.Run(bytecode, nil, "") if err != nil { t.Fatalf("Failed to run script: %v", err) } @@ -70,7 +70,7 @@ func TestRunnerWithContext(t *testing.T) { execCtx.Set("enabled", true) execCtx.Set("table", []float64{10, 20, 30}) - result, err := runner.Run(bytecode, execCtx) + result, err := runner.Run(bytecode, execCtx, "") if err != nil { t.Fatalf("Failed to run job: %v", err) } @@ -124,7 +124,7 @@ func TestRunnerWithTimeout(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() - result, err := runner.RunWithContext(ctx, bytecode, nil) + result, err := runner.RunWithContext(ctx, bytecode, nil, "") if err != nil { t.Fatalf("Unexpected error with sufficient timeout: %v", err) } @@ -136,7 +136,7 @@ func TestRunnerWithTimeout(t *testing.T) { ctx, cancel = context.WithTimeout(context.Background(), 50*time.Millisecond) defer cancel() - _, err = runner.RunWithContext(ctx, bytecode, nil) + _, err = runner.RunWithContext(ctx, bytecode, nil, "") if err == nil { t.Errorf("Expected timeout error, got nil") } @@ -156,7 +156,7 @@ func TestSandboxIsolation(t *testing.T) { return true `) - _, err = runner.Run(bytecode1, nil) + _, err = runner.Run(bytecode1, nil, "") if err != nil { t.Fatalf("Failed to execute first script: %v", err) } @@ -167,7 +167,7 @@ func TestSandboxIsolation(t *testing.T) { return my_global ~= nil `) - result, err := runner.Run(bytecode2, nil) + result, err := runner.Run(bytecode2, nil, "") if err != nil { t.Fatalf("Failed to execute second script: %v", err) } @@ -221,7 +221,7 @@ func TestRunnerWithInit(t *testing.T) { // Test the add function bytecode1 := createTestBytecode(t, "return add(5, 7)") - result1, err := runner.Run(bytecode1, nil) + result1, err := runner.Run(bytecode1, nil, "") if err != nil { t.Fatalf("Failed to call add function: %v", err) } @@ -233,7 +233,7 @@ func TestRunnerWithInit(t *testing.T) { // Test the math2 module bytecode2 := createTestBytecode(t, "return math2.multiply(6, 8)") - result2, err := runner.Run(bytecode2, nil) + result2, err := runner.Run(bytecode2, nil, "") if err != nil { t.Fatalf("Failed to call math2.multiply: %v", err) } @@ -264,7 +264,7 @@ func TestConcurrentExecution(t *testing.T) { execCtx := NewContext() execCtx.Set("n", float64(i)) - result, err := runner.Run(bytecode, execCtx) + result, err := runner.Run(bytecode, execCtx, "") if err != nil { t.Errorf("Job %d failed: %v", i, err) results <- -1 @@ -305,7 +305,7 @@ func TestRunnerClose(t *testing.T) { // Submit a job to verify runner works bytecode := createTestBytecode(t, "return 42") - _, err = runner.Run(bytecode, nil) + _, err = runner.Run(bytecode, nil, "") if err != nil { t.Fatalf("Failed to run job: %v", err) } @@ -316,7 +316,7 @@ func TestRunnerClose(t *testing.T) { } // Run after close should fail - _, err = runner.Run(bytecode, nil) + _, err = runner.Run(bytecode, nil, "") if err != ErrRunnerClosed { t.Errorf("Expected ErrRunnerClosed, got %v", err) } @@ -335,7 +335,7 @@ func TestErrorHandling(t *testing.T) { defer runner.Close() // Test invalid bytecode - _, err = runner.Run([]byte("not valid bytecode"), nil) + _, err = runner.Run([]byte("not valid bytecode"), nil, "") if err == nil { t.Errorf("Expected error for invalid bytecode, got nil") } @@ -346,14 +346,14 @@ func TestErrorHandling(t *testing.T) { return true `) - _, err = runner.Run(bytecode, nil) + _, err = runner.Run(bytecode, nil, "") if err == nil { t.Errorf("Expected error from Lua error() call, got nil") } // Test with nil context bytecode = createTestBytecode(t, "return ctx == nil") - result, err := runner.Run(bytecode, nil) + result, err := runner.Run(bytecode, nil, "") if err != nil { t.Errorf("Unexpected error with nil context: %v", err) } @@ -366,7 +366,7 @@ func TestErrorHandling(t *testing.T) { execCtx.Set("param", complex128(1+2i)) // Unsupported type bytecode = createTestBytecode(t, "return ctx.param") - _, err = runner.Run(bytecode, execCtx) + _, err = runner.Run(bytecode, execCtx, "") if err == nil { t.Errorf("Expected error for unsupported context value type, got nil") } diff --git a/core/runner/require.go b/core/runner/require.go new file mode 100644 index 0000000..27ae50e --- /dev/null +++ b/core/runner/require.go @@ -0,0 +1,212 @@ +package runner + +import ( + "errors" + "os" + "path/filepath" + "strings" + "sync" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" +) + +// Common errors +var ( + ErrModuleNotFound = errors.New("module not found") + ErrPathTraversal = errors.New("path traversal not allowed") +) + +// RequireConfig holds configuration for Lua's require function +type RequireConfig struct { + ScriptDir string // Base directory for script being executed + LibDirs []string // Additional library directories +} + +// RequireCache is a thread-safe cache for loaded Lua modules +type RequireCache struct { + modules sync.Map // Maps full file paths to compiled bytecode +} + +// NewRequireCache creates a new, empty require cache +func NewRequireCache() *RequireCache { + return &RequireCache{ + modules: sync.Map{}, + } +} + +// SetupRequire configures the Lua state with a secure require function +func SetupRequire(state *luajit.State, cache *RequireCache, config RequireConfig) error { + // Register the loader function + err := state.RegisterGoFunction("__go_load_module", func(s *luajit.State) int { + // Get module name + modName := s.ToString(1) + if modName == "" { + s.PushString("module name required") + return -1 // Return error + } + + // Try to load the module + bytecode, err := findAndCompileModule(s, cache, config, modName) + if err != nil { + if err == ErrModuleNotFound { + s.PushString("module '" + modName + "' not found") + } else { + s.PushString("error loading module: " + err.Error()) + } + return -1 // Return error + } + + // Load the bytecode + if err := s.LoadBytecode(bytecode, modName); err != nil { + s.PushString("error loading bytecode: " + err.Error()) + return -1 // Return error + } + + // Return the loaded function + return 1 + }) + + if err != nil { + return err + } + + // Set up the secure require implementation + setupScript := ` + -- Create a secure require function for sandboxed environments + function __setup_secure_require(env) + -- Replace env.require with our secure version + env.require = function(modname) + -- Check if already loaded in package.loaded + if package.loaded[modname] then + return package.loaded[modname] + end + + -- Try to load the module using our Go loader + local loader = __go_load_module + + -- Load the module + local f, err = loader(modname) + if not f then + error(err or "failed to load module: " .. modname) + end + + -- Set the environment for the module + setfenv(f, env) + + -- Execute the module + local result = f() + + -- If module didn't return a value, use true + if result == nil then + result = true + end + + -- Cache the result + package.loaded[modname] = result + + return result + end + + return env + end + ` + + return state.DoString(setupScript) +} + +// findAndCompileModule finds a module in allowed directories and compiles it to bytecode +func findAndCompileModule( + state *luajit.State, + cache *RequireCache, + config RequireConfig, + modName string, +) ([]byte, error) { + // Convert module name to relative path + modPath := strings.ReplaceAll(modName, ".", string(filepath.Separator)) + + // List of paths to check + paths := []string{} + + // 1. Check adjacent to script directory first + if config.ScriptDir != "" { + paths = append(paths, filepath.Join(config.ScriptDir, modPath+".lua")) + } + + // 2. Check in lib directories + for _, libDir := range config.LibDirs { + if libDir != "" { + paths = append(paths, filepath.Join(libDir, modPath+".lua")) + } + } + + // Try each path + for _, path := range paths { + // Clean the path to handle .. and such (security) + cleanPath := filepath.Clean(path) + + // Check for path traversal (extra safety) + if !isSubPath(config.ScriptDir, cleanPath) { + isValidLib := false + for _, libDir := range config.LibDirs { + if isSubPath(libDir, cleanPath) { + isValidLib = true + break + } + } + + if !isValidLib { + continue // Skip paths outside allowed directories + } + } + + // Check if already in cache + if bytecode, ok := cache.modules.Load(cleanPath); ok { + return bytecode.([]byte), nil + } + + // Check if file exists + _, err := os.Stat(cleanPath) + if os.IsNotExist(err) { + continue + } + + // Read and compile the file + content, err := os.ReadFile(cleanPath) + if err != nil { + return nil, err + } + + // Compile to bytecode + bytecode, err := state.CompileBytecode(string(content), cleanPath) + if err != nil { + return nil, err + } + + // Store in cache + cache.modules.Store(cleanPath, bytecode) + + return bytecode, nil + } + + return nil, ErrModuleNotFound +} + +// isSubPath checks if path is contained within base directory +func isSubPath(baseDir, path string) bool { + if baseDir == "" { + return false + } + + // Clean and normalize paths + baseDir = filepath.Clean(baseDir) + path = filepath.Clean(path) + + // Get relative path + rel, err := filepath.Rel(baseDir, path) + if err != nil { + return false + } + + // Check if path goes outside baseDir + return !strings.HasPrefix(rel, ".."+string(filepath.Separator)) && rel != ".." +} diff --git a/core/runner/require_test.go b/core/runner/require_test.go new file mode 100644 index 0000000..0663c65 --- /dev/null +++ b/core/runner/require_test.go @@ -0,0 +1,280 @@ +package runner_test + +import ( + "fmt" + "os" + "path/filepath" + "testing" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" + "git.sharkk.net/Sky/Moonshark/core/runner" +) + +func TestRequireFunctionality(t *testing.T) { + // Create temporary directories for test + tempDir, err := os.MkdirTemp("", "luarunner-test-*") + if err != nil { + t.Fatalf("Failed to create temp directory: %v", err) + } + defer os.RemoveAll(tempDir) + + // Create script directory and lib directory + scriptDir := filepath.Join(tempDir, "scripts") + libDir := filepath.Join(tempDir, "libs") + + if err := os.Mkdir(scriptDir, 0755); err != nil { + t.Fatalf("Failed to create script directory: %v", err) + } + if err := os.Mkdir(libDir, 0755); err != nil { + t.Fatalf("Failed to create lib directory: %v", err) + } + + // Create a module in the lib directory + libModule := ` + local lib = {} + + function lib.add(a, b) + return a + b + end + + function lib.mul(a, b) + return a * b + end + + return lib + ` + if err := os.WriteFile(filepath.Join(libDir, "mathlib.lua"), []byte(libModule), 0644); err != nil { + t.Fatalf("Failed to write lib module: %v", err) + } + + // Create a helper module in the script directory + helperModule := ` + local helper = {} + + function helper.square(x) + return x * x + end + + function helper.cube(x) + return x * x * x + end + + return helper + ` + if err := os.WriteFile(filepath.Join(scriptDir, "helper.lua"), []byte(helperModule), 0644); err != nil { + t.Fatalf("Failed to write helper module: %v", err) + } + + // Create main script that requires both modules + mainScript := ` + -- Require from the same directory + local helper = require("helper") + + -- Require from the lib directory + local mathlib = require("mathlib") + + -- Use both modules + local result = { + add = mathlib.add(10, 5), + mul = mathlib.mul(10, 5), + square = helper.square(5), + cube = helper.cube(3) + } + + return result + ` + mainScriptPath := filepath.Join(scriptDir, "main.lua") + if err := os.WriteFile(mainScriptPath, []byte(mainScript), 0644); err != nil { + t.Fatalf("Failed to write main script: %v", err) + } + + // Create LuaRunner + luaRunner, err := runner.NewRunner( + runner.WithScriptDir(scriptDir), + runner.WithLibDirs(libDir), + ) + if err != nil { + t.Fatalf("Failed to create LuaRunner: %v", err) + } + defer luaRunner.Close() + + // Compile the main script + state := luajit.New() + if state == nil { + t.Fatal("Failed to create Lua state") + } + defer state.Close() + + bytecode, err := state.CompileBytecode(mainScript, "main.lua") + if err != nil { + t.Fatalf("Failed to compile script: %v", err) + } + + // Run the script + result, err := luaRunner.Run(bytecode, nil, mainScriptPath) + if err != nil { + t.Fatalf("Failed to run script: %v", err) + } + + // Check result + resultMap, ok := result.(map[string]any) + if !ok { + t.Fatalf("Expected map result, got %T", result) + } + + // Validate results + expectedResults := map[string]float64{ + "add": 15, // 10 + 5 + "mul": 50, // 10 * 5 + "square": 25, // 5^2 + "cube": 27, // 3^3 + } + + for key, expected := range expectedResults { + if val, ok := resultMap[key]; !ok { + t.Errorf("Missing result key: %s", key) + } else if val != expected { + t.Errorf("For %s: expected %.1f, got %v", key, expected, val) + } + } +} + +func TestRequireSecurityBoundaries(t *testing.T) { + // Create temporary directories for test + tempDir, err := os.MkdirTemp("", "luarunner-security-test-*") + if err != nil { + t.Fatalf("Failed to create temp directory: %v", err) + } + defer os.RemoveAll(tempDir) + + // Create script directory and lib directory + scriptDir := filepath.Join(tempDir, "scripts") + libDir := filepath.Join(tempDir, "libs") + secretDir := filepath.Join(tempDir, "secret") + + if err := os.Mkdir(scriptDir, 0755); err != nil { + t.Fatalf("Failed to create script directory: %v", err) + } + if err := os.Mkdir(libDir, 0755); err != nil { + t.Fatalf("Failed to create lib directory: %v", err) + } + if err := os.Mkdir(secretDir, 0755); err != nil { + t.Fatalf("Failed to create secret directory: %v", err) + } + + // Create a "secret" module that should not be accessible + secretModule := ` + local secret = "TOP SECRET" + return secret + ` + if err := os.WriteFile(filepath.Join(secretDir, "secret.lua"), []byte(secretModule), 0644); err != nil { + t.Fatalf("Failed to write secret module: %v", err) + } + + // Create a normal module in lib + normalModule := `return "normal module"` + if err := os.WriteFile(filepath.Join(libDir, "normal.lua"), []byte(normalModule), 0644); err != nil { + t.Fatalf("Failed to write normal module: %v", err) + } + + // Test attempting to access file outside allowed paths + pathTraversalTests := []struct { + name string + script string + }{ + { + name: "Direct path traversal", + script: ` + -- Try path traversal + local secret = require("../secret/secret") + return secret + `, + }, + { + name: "Double dot traversal", + script: ` + local secret = require("..secret.secret") + return secret + `, + }, + { + name: "Absolute path", + script: fmt.Sprintf(` + local secret = require("%s") + return secret + `, filepath.Join(secretDir, "secret")), + }, + } + + // Create and configure runner + luaRunner, err := runner.NewRunner( + runner.WithScriptDir(scriptDir), + runner.WithLibDirs(libDir), + ) + if err != nil { + t.Fatalf("Failed to create LuaRunner: %v", err) + } + defer luaRunner.Close() + + // Test each attempt at path traversal + for _, tt := range pathTraversalTests { + t.Run(tt.name, func(t *testing.T) { + // Write the script + scriptPath := filepath.Join(scriptDir, tt.name+".lua") + if err := os.WriteFile(scriptPath, []byte(tt.script), 0644); err != nil { + t.Fatalf("Failed to write test script: %v", err) + } + + // Compile + state := luajit.New() + if state == nil { + t.Fatal("Failed to create Lua state") + } + defer state.Close() + + bytecode, err := state.CompileBytecode(tt.script, tt.name+".lua") + if err != nil { + t.Fatalf("Failed to compile script: %v", err) + } + + // Run and expect error + _, err = luaRunner.Run(bytecode, nil, scriptPath) + if err == nil { + t.Error("Expected error for path traversal, got nil") + } + }) + } + + // Test that we can still require valid modules + normalScript := ` + local normal = require("normal") + return normal + ` + scriptPath := filepath.Join(scriptDir, "normal_test.lua") + if err := os.WriteFile(scriptPath, []byte(normalScript), 0644); err != nil { + t.Fatalf("Failed to write normal test script: %v", err) + } + + // Compile + state := luajit.New() + if state == nil { + t.Fatal("Failed to create Lua state") + } + defer state.Close() + + bytecode, err := state.CompileBytecode(normalScript, "normal_test.lua") + if err != nil { + t.Fatalf("Failed to compile script: %v", err) + } + + // Run and expect success + result, err := luaRunner.Run(bytecode, nil, scriptPath) + if err != nil { + t.Fatalf("Failed to run normal script: %v", err) + } + + // Check result + if result != "normal module" { + t.Errorf("Expected 'normal module', got %v", result) + } +} diff --git a/moonshark.go b/moonshark.go index 953084a..60d23c9 100644 --- a/moonshark.go +++ b/moonshark.go @@ -66,6 +66,8 @@ func main() { log.SetLevel(logger.LevelWarning) case "error": log.SetLevel(logger.LevelError) + case "server": + log.SetLevel(logger.LevelServer) case "fatal": log.SetLevel(logger.LevelFatal) default: @@ -108,6 +110,7 @@ func main() { // Initialize Lua runner (replacing worker pool) runner, err := runner.NewRunner( runner.WithBufferSize(bufferSize), + runner.WithLibDirs("./libs"), ) if err != nil { log.Fatal("Failed to initialize Lua runner: %v", err)