require
This commit is contained in:
parent
95a4187d3f
commit
55f27c6f68
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -26,3 +26,4 @@ go.work
|
|||
config.lua
|
||||
routes/
|
||||
static/
|
||||
libs/
|
||||
|
|
|
@ -14,6 +14,7 @@ const (
|
|||
colorYellow = "\033[33m" // 4xx - Client Errors
|
||||
colorRed = "\033[31m" // 5xx - Server Errors
|
||||
colorReset = "\033[0m" // Reset color
|
||||
colorGray = "\033[90m"
|
||||
)
|
||||
|
||||
// LogRequest logs an HTTP request with custom formatting
|
||||
|
@ -21,10 +22,9 @@ func LogRequest(log *logger.Logger, statusCode int, r *http.Request, duration ti
|
|||
statusColor := getStatusColor(statusCode)
|
||||
|
||||
// Use the logger's raw message writer to bypass the standard format
|
||||
log.LogRaw("%s [ %s%d%s] %s %s (%v)",
|
||||
time.Now().Format(log.TimeFormat()),
|
||||
statusColor, statusCode, colorReset,
|
||||
r.Method, r.URL.Path, duration)
|
||||
log.LogRaw("%s%s%s %s%d %s%s %s %s(%v)%s",
|
||||
colorGray, time.Now().Format(log.TimeFormat()), colorReset,
|
||||
statusColor, statusCode, r.Method, colorReset, r.URL.Path, colorGray, duration, colorReset)
|
||||
}
|
||||
|
||||
// getStatusColor returns the ANSI color code for a status code
|
||||
|
|
|
@ -97,9 +97,9 @@ func (s *Server) handleRequest(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// Try Lua routes first
|
||||
params := &routers.Params{}
|
||||
if bytecode, found := s.luaRouter.GetBytecode(r.Method, r.URL.Path, params); found {
|
||||
if bytecode, scriptPath, found := s.luaRouter.GetBytecode(r.Method, r.URL.Path, params); found {
|
||||
s.logger.Debug("Found Lua route match for %s %s with %d params", r.Method, r.URL.Path, params.Count)
|
||||
s.handleLuaRoute(w, r, bytecode, params)
|
||||
s.handleLuaRoute(w, r, bytecode, scriptPath, params)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -114,7 +114,7 @@ func (s *Server) handleRequest(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// handleLuaRoute executes a Lua route
|
||||
func (s *Server) handleLuaRoute(w http.ResponseWriter, r *http.Request, bytecode []byte, params *routers.Params) {
|
||||
func (s *Server) handleLuaRoute(w http.ResponseWriter, r *http.Request, bytecode []byte, scriptPath string, params *routers.Params) {
|
||||
ctx := runner.NewContext()
|
||||
|
||||
// Log bytecode size
|
||||
|
@ -148,7 +148,7 @@ func (s *Server) handleLuaRoute(w http.ResponseWriter, r *http.Request, bytecode
|
|||
}
|
||||
|
||||
// Execute Lua script
|
||||
result, err := s.luaRunner.Run(bytecode, ctx)
|
||||
result, err := s.luaRunner.Run(bytecode, ctx, scriptPath)
|
||||
if err != nil {
|
||||
s.logger.Error("Error executing Lua route: %v", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
|
|
|
@ -18,6 +18,7 @@ const (
|
|||
colorPurple = "\033[35m"
|
||||
colorCyan = "\033[36m"
|
||||
colorWhite = "\033[37m"
|
||||
colorGray = "\033[90m"
|
||||
)
|
||||
|
||||
// Log levels
|
||||
|
@ -26,8 +27,8 @@ const (
|
|||
LevelInfo
|
||||
LevelWarning
|
||||
LevelError
|
||||
LevelFatal
|
||||
LevelServer
|
||||
LevelFatal
|
||||
)
|
||||
|
||||
// Level names and colors
|
||||
|
@ -35,12 +36,12 @@ var levelProps = map[int]struct {
|
|||
tag string
|
||||
color string
|
||||
}{
|
||||
LevelDebug: {" DBG", colorCyan},
|
||||
LevelInfo: {"INFO", colorBlue},
|
||||
LevelWarning: {"WARN", colorYellow},
|
||||
LevelError: {" ERR", colorRed},
|
||||
LevelFatal: {"FATL", colorPurple},
|
||||
LevelServer: {"SRVR", colorGreen},
|
||||
LevelDebug: {"DEBUG", colorCyan},
|
||||
LevelInfo: {" INFO", colorBlue},
|
||||
LevelWarning: {" WARN", colorYellow},
|
||||
LevelError: {"ERROR", colorRed},
|
||||
LevelServer: {" SYS", colorGreen},
|
||||
LevelFatal: {"FATAL", colorPurple},
|
||||
}
|
||||
|
||||
// Time format for log messages
|
||||
|
@ -110,10 +111,10 @@ func (l *Logger) writeMessage(level int, message string, rawMode bool) {
|
|||
props := levelProps[level]
|
||||
|
||||
if l.useColors {
|
||||
logLine = fmt.Sprintf("%s %s[%s]%s %s\n",
|
||||
now, props.color, props.tag, colorReset, message)
|
||||
logLine = fmt.Sprintf("%s%s%s %s%s%s %s\n",
|
||||
colorGray, now, colorReset, props.color, props.tag, colorReset, message)
|
||||
} else {
|
||||
logLine = fmt.Sprintf("%s [%s] %s\n",
|
||||
logLine = fmt.Sprintf("%s %s %s\n",
|
||||
now, props.tag, message)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -251,12 +251,12 @@ func (r *LuaRouter) compileHandler(n *node) error {
|
|||
}
|
||||
|
||||
// GetBytecode returns the compiled bytecode for a matched route
|
||||
func (r *LuaRouter) GetBytecode(method, path string, params *Params) ([]byte, bool) {
|
||||
func (r *LuaRouter) GetBytecode(method, path string, params *Params) ([]byte, string, bool) {
|
||||
node, found := r.Match(method, path, params)
|
||||
if !found {
|
||||
return nil, false
|
||||
return nil, "", false
|
||||
}
|
||||
return node.bytecode, true
|
||||
return node.bytecode, node.handler, true
|
||||
}
|
||||
|
||||
// Refresh rebuilds the router by rescanning the routes directory
|
||||
|
|
|
@ -173,7 +173,7 @@ func TestGetBytecode(t *testing.T) {
|
|||
}
|
||||
|
||||
var params Params
|
||||
bytecode, found := router.GetBytecode("GET", "/api/users/123", ¶ms)
|
||||
bytecode, _, found := router.GetBytecode("GET", "/api/users/123", ¶ms)
|
||||
|
||||
if !found {
|
||||
t.Fatalf("Route not found")
|
||||
|
@ -212,7 +212,7 @@ func TestRefresh(t *testing.T) {
|
|||
|
||||
// Before refresh, route should not be found
|
||||
var params Params
|
||||
_, found := router.GetBytecode("GET", "/new", ¶ms)
|
||||
_, _, found := router.GetBytecode("GET", "/new", ¶ms)
|
||||
if found {
|
||||
t.Errorf("New route should not be found before refresh")
|
||||
}
|
||||
|
@ -224,7 +224,7 @@ func TestRefresh(t *testing.T) {
|
|||
}
|
||||
|
||||
// After refresh, route should be found
|
||||
bytecode, found := router.GetBytecode("GET", "/new", ¶ms)
|
||||
bytecode, _, found := router.GetBytecode("GET", "/new", ¶ms)
|
||||
if !found {
|
||||
t.Errorf("New route should be found after refresh")
|
||||
}
|
||||
|
|
|
@ -10,5 +10,6 @@ type JobResult struct {
|
|||
type job struct {
|
||||
Bytecode []byte // Compiled LuaJIT bytecode
|
||||
Context *Context // Execution context
|
||||
ScriptPath string // Path to the original script (for require resolution)
|
||||
Result chan<- JobResult // Channel to send result back
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package runner
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
|
@ -27,6 +28,10 @@ type LuaRunner struct {
|
|||
wg sync.WaitGroup // WaitGroup for clean shutdown
|
||||
initFunc StateInitFunc // Optional function to initialize Lua state
|
||||
bufferSize int // Size of the job queue buffer
|
||||
requireCache *RequireCache // Cache for required modules
|
||||
requireCfg RequireConfig // Configuration for require paths
|
||||
scriptDir string // Base directory for scripts
|
||||
libDirs []string // Additional library directories
|
||||
}
|
||||
|
||||
// NewRunner creates a new LuaRunner
|
||||
|
@ -34,6 +39,10 @@ func NewRunner(options ...RunnerOption) (*LuaRunner, error) {
|
|||
// Default configuration
|
||||
runner := &LuaRunner{
|
||||
bufferSize: 10, // Default buffer size
|
||||
requireCache: NewRequireCache(),
|
||||
requireCfg: RequireConfig{
|
||||
LibDirs: []string{},
|
||||
},
|
||||
}
|
||||
|
||||
// Apply options
|
||||
|
@ -52,6 +61,12 @@ func NewRunner(options ...RunnerOption) (*LuaRunner, error) {
|
|||
runner.jobQueue = make(chan job, runner.bufferSize)
|
||||
runner.isRunning.Store(true)
|
||||
|
||||
// Set up require functionality
|
||||
if err := SetupRequire(state, runner.requireCache, runner.requireCfg); err != nil {
|
||||
state.Close()
|
||||
return nil, ErrInitFailed
|
||||
}
|
||||
|
||||
// Set up sandbox
|
||||
if err := runner.setupSandbox(); err != nil {
|
||||
state.Close()
|
||||
|
@ -92,6 +107,22 @@ func WithInitFunc(initFunc StateInitFunc) RunnerOption {
|
|||
}
|
||||
}
|
||||
|
||||
// WithScriptDir sets the base directory for scripts
|
||||
func WithScriptDir(dir string) RunnerOption {
|
||||
return func(r *LuaRunner) {
|
||||
r.scriptDir = dir
|
||||
r.requireCfg.ScriptDir = dir
|
||||
}
|
||||
}
|
||||
|
||||
// WithLibDirs sets additional library directories
|
||||
func WithLibDirs(dirs ...string) RunnerOption {
|
||||
return func(r *LuaRunner) {
|
||||
r.libDirs = dirs
|
||||
r.requireCfg.LibDirs = dirs
|
||||
}
|
||||
}
|
||||
|
||||
// setupSandbox initializes the sandbox environment
|
||||
func (r *LuaRunner) setupSandbox() error {
|
||||
// This is the Lua script that creates our sandbox function
|
||||
|
@ -124,10 +155,13 @@ func (r *LuaRunner) setupSandbox() error {
|
|||
env.error = error
|
||||
env.assert = assert
|
||||
|
||||
-- Allow access to package.loaded for modules
|
||||
env.require = function(name)
|
||||
return package.loaded[name]
|
||||
end
|
||||
-- Set up the standard library package table
|
||||
env.package = {
|
||||
loaded = {} -- Table to store loaded modules
|
||||
}
|
||||
|
||||
-- Set up secure require function
|
||||
env = __setup_secure_require(env)
|
||||
|
||||
-- Create metatable to restrict access to _G
|
||||
local mt = {
|
||||
|
@ -192,6 +226,16 @@ func (r *LuaRunner) eventLoop() {
|
|||
|
||||
// executeJob runs a script in the sandbox environment
|
||||
func (r *LuaRunner) executeJob(j job) JobResult {
|
||||
// If the job has a script path, update the require context
|
||||
if j.ScriptPath != "" {
|
||||
// Update the script directory for require
|
||||
scriptDir := filepath.Dir(j.ScriptPath)
|
||||
r.requireCfg.ScriptDir = scriptDir
|
||||
|
||||
// Update in the require cache config
|
||||
SetupRequire(r.state, r.requireCache, r.requireCfg)
|
||||
}
|
||||
|
||||
// Re-run init function if needed
|
||||
if r.initFunc != nil {
|
||||
if err := r.initFunc(r.state); err != nil {
|
||||
|
@ -223,7 +267,7 @@ func (r *LuaRunner) executeJob(j job) JobResult {
|
|||
}
|
||||
|
||||
// Load bytecode
|
||||
if err := r.state.LoadBytecode(j.Bytecode, "script"); err != nil {
|
||||
if err := r.state.LoadBytecode(j.Bytecode, j.ScriptPath); err != nil {
|
||||
r.state.Pop(1) // Pop context
|
||||
return JobResult{nil, err}
|
||||
}
|
||||
|
@ -252,7 +296,7 @@ func (r *LuaRunner) executeJob(j job) JobResult {
|
|||
}
|
||||
|
||||
// RunWithContext executes a Lua script with context and timeout
|
||||
func (r *LuaRunner) RunWithContext(ctx context.Context, bytecode []byte, execCtx *Context) (any, error) {
|
||||
func (r *LuaRunner) RunWithContext(ctx context.Context, bytecode []byte, execCtx *Context, scriptPath string) (any, error) {
|
||||
r.mu.RLock()
|
||||
if !r.isRunning.Load() {
|
||||
r.mu.RUnlock()
|
||||
|
@ -264,6 +308,7 @@ func (r *LuaRunner) RunWithContext(ctx context.Context, bytecode []byte, execCtx
|
|||
j := job{
|
||||
Bytecode: bytecode,
|
||||
Context: execCtx,
|
||||
ScriptPath: scriptPath,
|
||||
Result: resultChan,
|
||||
}
|
||||
|
||||
|
@ -285,8 +330,8 @@ func (r *LuaRunner) RunWithContext(ctx context.Context, bytecode []byte, execCtx
|
|||
}
|
||||
|
||||
// Run executes a Lua script
|
||||
func (r *LuaRunner) Run(bytecode []byte, execCtx *Context) (any, error) {
|
||||
return r.RunWithContext(context.Background(), bytecode, execCtx)
|
||||
func (r *LuaRunner) Run(bytecode []byte, execCtx *Context, scriptPath string) (any, error) {
|
||||
return r.RunWithContext(context.Background(), bytecode, execCtx, scriptPath)
|
||||
}
|
||||
|
||||
// Close gracefully shuts down the LuaRunner
|
||||
|
@ -306,3 +351,8 @@ func (r *LuaRunner) Close() error {
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClearRequireCache clears the cache of loaded modules
|
||||
func (r *LuaRunner) ClearRequireCache() {
|
||||
r.requireCache = NewRequireCache()
|
||||
}
|
||||
|
|
|
@ -33,7 +33,7 @@ func TestRunnerBasic(t *testing.T) {
|
|||
|
||||
bytecode := createTestBytecode(t, "return 42")
|
||||
|
||||
result, err := runner.Run(bytecode, nil)
|
||||
result, err := runner.Run(bytecode, nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to run script: %v", err)
|
||||
}
|
||||
|
@ -70,7 +70,7 @@ func TestRunnerWithContext(t *testing.T) {
|
|||
execCtx.Set("enabled", true)
|
||||
execCtx.Set("table", []float64{10, 20, 30})
|
||||
|
||||
result, err := runner.Run(bytecode, execCtx)
|
||||
result, err := runner.Run(bytecode, execCtx, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to run job: %v", err)
|
||||
}
|
||||
|
@ -124,7 +124,7 @@ func TestRunnerWithTimeout(t *testing.T) {
|
|||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
result, err := runner.RunWithContext(ctx, bytecode, nil)
|
||||
result, err := runner.RunWithContext(ctx, bytecode, nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error with sufficient timeout: %v", err)
|
||||
}
|
||||
|
@ -136,7 +136,7 @@ func TestRunnerWithTimeout(t *testing.T) {
|
|||
ctx, cancel = context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, err = runner.RunWithContext(ctx, bytecode, nil)
|
||||
_, err = runner.RunWithContext(ctx, bytecode, nil, "")
|
||||
if err == nil {
|
||||
t.Errorf("Expected timeout error, got nil")
|
||||
}
|
||||
|
@ -156,7 +156,7 @@ func TestSandboxIsolation(t *testing.T) {
|
|||
return true
|
||||
`)
|
||||
|
||||
_, err = runner.Run(bytecode1, nil)
|
||||
_, err = runner.Run(bytecode1, nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute first script: %v", err)
|
||||
}
|
||||
|
@ -167,7 +167,7 @@ func TestSandboxIsolation(t *testing.T) {
|
|||
return my_global ~= nil
|
||||
`)
|
||||
|
||||
result, err := runner.Run(bytecode2, nil)
|
||||
result, err := runner.Run(bytecode2, nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute second script: %v", err)
|
||||
}
|
||||
|
@ -221,7 +221,7 @@ func TestRunnerWithInit(t *testing.T) {
|
|||
|
||||
// Test the add function
|
||||
bytecode1 := createTestBytecode(t, "return add(5, 7)")
|
||||
result1, err := runner.Run(bytecode1, nil)
|
||||
result1, err := runner.Run(bytecode1, nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to call add function: %v", err)
|
||||
}
|
||||
|
@ -233,7 +233,7 @@ func TestRunnerWithInit(t *testing.T) {
|
|||
|
||||
// Test the math2 module
|
||||
bytecode2 := createTestBytecode(t, "return math2.multiply(6, 8)")
|
||||
result2, err := runner.Run(bytecode2, nil)
|
||||
result2, err := runner.Run(bytecode2, nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to call math2.multiply: %v", err)
|
||||
}
|
||||
|
@ -264,7 +264,7 @@ func TestConcurrentExecution(t *testing.T) {
|
|||
execCtx := NewContext()
|
||||
execCtx.Set("n", float64(i))
|
||||
|
||||
result, err := runner.Run(bytecode, execCtx)
|
||||
result, err := runner.Run(bytecode, execCtx, "")
|
||||
if err != nil {
|
||||
t.Errorf("Job %d failed: %v", i, err)
|
||||
results <- -1
|
||||
|
@ -305,7 +305,7 @@ func TestRunnerClose(t *testing.T) {
|
|||
|
||||
// Submit a job to verify runner works
|
||||
bytecode := createTestBytecode(t, "return 42")
|
||||
_, err = runner.Run(bytecode, nil)
|
||||
_, err = runner.Run(bytecode, nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to run job: %v", err)
|
||||
}
|
||||
|
@ -316,7 +316,7 @@ func TestRunnerClose(t *testing.T) {
|
|||
}
|
||||
|
||||
// Run after close should fail
|
||||
_, err = runner.Run(bytecode, nil)
|
||||
_, err = runner.Run(bytecode, nil, "")
|
||||
if err != ErrRunnerClosed {
|
||||
t.Errorf("Expected ErrRunnerClosed, got %v", err)
|
||||
}
|
||||
|
@ -335,7 +335,7 @@ func TestErrorHandling(t *testing.T) {
|
|||
defer runner.Close()
|
||||
|
||||
// Test invalid bytecode
|
||||
_, err = runner.Run([]byte("not valid bytecode"), nil)
|
||||
_, err = runner.Run([]byte("not valid bytecode"), nil, "")
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for invalid bytecode, got nil")
|
||||
}
|
||||
|
@ -346,14 +346,14 @@ func TestErrorHandling(t *testing.T) {
|
|||
return true
|
||||
`)
|
||||
|
||||
_, err = runner.Run(bytecode, nil)
|
||||
_, err = runner.Run(bytecode, nil, "")
|
||||
if err == nil {
|
||||
t.Errorf("Expected error from Lua error() call, got nil")
|
||||
}
|
||||
|
||||
// Test with nil context
|
||||
bytecode = createTestBytecode(t, "return ctx == nil")
|
||||
result, err := runner.Run(bytecode, nil)
|
||||
result, err := runner.Run(bytecode, nil, "")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error with nil context: %v", err)
|
||||
}
|
||||
|
@ -366,7 +366,7 @@ func TestErrorHandling(t *testing.T) {
|
|||
execCtx.Set("param", complex128(1+2i)) // Unsupported type
|
||||
|
||||
bytecode = createTestBytecode(t, "return ctx.param")
|
||||
_, err = runner.Run(bytecode, execCtx)
|
||||
_, err = runner.Run(bytecode, execCtx, "")
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for unsupported context value type, got nil")
|
||||
}
|
||||
|
|
212
core/runner/require.go
Normal file
212
core/runner/require.go
Normal file
|
@ -0,0 +1,212 @@
|
|||
package runner
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
)
|
||||
|
||||
// Common errors
|
||||
var (
|
||||
ErrModuleNotFound = errors.New("module not found")
|
||||
ErrPathTraversal = errors.New("path traversal not allowed")
|
||||
)
|
||||
|
||||
// RequireConfig holds configuration for Lua's require function
|
||||
type RequireConfig struct {
|
||||
ScriptDir string // Base directory for script being executed
|
||||
LibDirs []string // Additional library directories
|
||||
}
|
||||
|
||||
// RequireCache is a thread-safe cache for loaded Lua modules
|
||||
type RequireCache struct {
|
||||
modules sync.Map // Maps full file paths to compiled bytecode
|
||||
}
|
||||
|
||||
// NewRequireCache creates a new, empty require cache
|
||||
func NewRequireCache() *RequireCache {
|
||||
return &RequireCache{
|
||||
modules: sync.Map{},
|
||||
}
|
||||
}
|
||||
|
||||
// SetupRequire configures the Lua state with a secure require function
|
||||
func SetupRequire(state *luajit.State, cache *RequireCache, config RequireConfig) error {
|
||||
// Register the loader function
|
||||
err := state.RegisterGoFunction("__go_load_module", func(s *luajit.State) int {
|
||||
// Get module name
|
||||
modName := s.ToString(1)
|
||||
if modName == "" {
|
||||
s.PushString("module name required")
|
||||
return -1 // Return error
|
||||
}
|
||||
|
||||
// Try to load the module
|
||||
bytecode, err := findAndCompileModule(s, cache, config, modName)
|
||||
if err != nil {
|
||||
if err == ErrModuleNotFound {
|
||||
s.PushString("module '" + modName + "' not found")
|
||||
} else {
|
||||
s.PushString("error loading module: " + err.Error())
|
||||
}
|
||||
return -1 // Return error
|
||||
}
|
||||
|
||||
// Load the bytecode
|
||||
if err := s.LoadBytecode(bytecode, modName); err != nil {
|
||||
s.PushString("error loading bytecode: " + err.Error())
|
||||
return -1 // Return error
|
||||
}
|
||||
|
||||
// Return the loaded function
|
||||
return 1
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set up the secure require implementation
|
||||
setupScript := `
|
||||
-- Create a secure require function for sandboxed environments
|
||||
function __setup_secure_require(env)
|
||||
-- Replace env.require with our secure version
|
||||
env.require = function(modname)
|
||||
-- Check if already loaded in package.loaded
|
||||
if package.loaded[modname] then
|
||||
return package.loaded[modname]
|
||||
end
|
||||
|
||||
-- Try to load the module using our Go loader
|
||||
local loader = __go_load_module
|
||||
|
||||
-- Load the module
|
||||
local f, err = loader(modname)
|
||||
if not f then
|
||||
error(err or "failed to load module: " .. modname)
|
||||
end
|
||||
|
||||
-- Set the environment for the module
|
||||
setfenv(f, env)
|
||||
|
||||
-- Execute the module
|
||||
local result = f()
|
||||
|
||||
-- If module didn't return a value, use true
|
||||
if result == nil then
|
||||
result = true
|
||||
end
|
||||
|
||||
-- Cache the result
|
||||
package.loaded[modname] = result
|
||||
|
||||
return result
|
||||
end
|
||||
|
||||
return env
|
||||
end
|
||||
`
|
||||
|
||||
return state.DoString(setupScript)
|
||||
}
|
||||
|
||||
// findAndCompileModule finds a module in allowed directories and compiles it to bytecode
|
||||
func findAndCompileModule(
|
||||
state *luajit.State,
|
||||
cache *RequireCache,
|
||||
config RequireConfig,
|
||||
modName string,
|
||||
) ([]byte, error) {
|
||||
// Convert module name to relative path
|
||||
modPath := strings.ReplaceAll(modName, ".", string(filepath.Separator))
|
||||
|
||||
// List of paths to check
|
||||
paths := []string{}
|
||||
|
||||
// 1. Check adjacent to script directory first
|
||||
if config.ScriptDir != "" {
|
||||
paths = append(paths, filepath.Join(config.ScriptDir, modPath+".lua"))
|
||||
}
|
||||
|
||||
// 2. Check in lib directories
|
||||
for _, libDir := range config.LibDirs {
|
||||
if libDir != "" {
|
||||
paths = append(paths, filepath.Join(libDir, modPath+".lua"))
|
||||
}
|
||||
}
|
||||
|
||||
// Try each path
|
||||
for _, path := range paths {
|
||||
// Clean the path to handle .. and such (security)
|
||||
cleanPath := filepath.Clean(path)
|
||||
|
||||
// Check for path traversal (extra safety)
|
||||
if !isSubPath(config.ScriptDir, cleanPath) {
|
||||
isValidLib := false
|
||||
for _, libDir := range config.LibDirs {
|
||||
if isSubPath(libDir, cleanPath) {
|
||||
isValidLib = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !isValidLib {
|
||||
continue // Skip paths outside allowed directories
|
||||
}
|
||||
}
|
||||
|
||||
// Check if already in cache
|
||||
if bytecode, ok := cache.modules.Load(cleanPath); ok {
|
||||
return bytecode.([]byte), nil
|
||||
}
|
||||
|
||||
// Check if file exists
|
||||
_, err := os.Stat(cleanPath)
|
||||
if os.IsNotExist(err) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Read and compile the file
|
||||
content, err := os.ReadFile(cleanPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Compile to bytecode
|
||||
bytecode, err := state.CompileBytecode(string(content), cleanPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Store in cache
|
||||
cache.modules.Store(cleanPath, bytecode)
|
||||
|
||||
return bytecode, nil
|
||||
}
|
||||
|
||||
return nil, ErrModuleNotFound
|
||||
}
|
||||
|
||||
// isSubPath checks if path is contained within base directory
|
||||
func isSubPath(baseDir, path string) bool {
|
||||
if baseDir == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Clean and normalize paths
|
||||
baseDir = filepath.Clean(baseDir)
|
||||
path = filepath.Clean(path)
|
||||
|
||||
// Get relative path
|
||||
rel, err := filepath.Rel(baseDir, path)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if path goes outside baseDir
|
||||
return !strings.HasPrefix(rel, ".."+string(filepath.Separator)) && rel != ".."
|
||||
}
|
280
core/runner/require_test.go
Normal file
280
core/runner/require_test.go
Normal file
|
@ -0,0 +1,280 @@
|
|||
package runner_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||
"git.sharkk.net/Sky/Moonshark/core/runner"
|
||||
)
|
||||
|
||||
func TestRequireFunctionality(t *testing.T) {
|
||||
// Create temporary directories for test
|
||||
tempDir, err := os.MkdirTemp("", "luarunner-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Create script directory and lib directory
|
||||
scriptDir := filepath.Join(tempDir, "scripts")
|
||||
libDir := filepath.Join(tempDir, "libs")
|
||||
|
||||
if err := os.Mkdir(scriptDir, 0755); err != nil {
|
||||
t.Fatalf("Failed to create script directory: %v", err)
|
||||
}
|
||||
if err := os.Mkdir(libDir, 0755); err != nil {
|
||||
t.Fatalf("Failed to create lib directory: %v", err)
|
||||
}
|
||||
|
||||
// Create a module in the lib directory
|
||||
libModule := `
|
||||
local lib = {}
|
||||
|
||||
function lib.add(a, b)
|
||||
return a + b
|
||||
end
|
||||
|
||||
function lib.mul(a, b)
|
||||
return a * b
|
||||
end
|
||||
|
||||
return lib
|
||||
`
|
||||
if err := os.WriteFile(filepath.Join(libDir, "mathlib.lua"), []byte(libModule), 0644); err != nil {
|
||||
t.Fatalf("Failed to write lib module: %v", err)
|
||||
}
|
||||
|
||||
// Create a helper module in the script directory
|
||||
helperModule := `
|
||||
local helper = {}
|
||||
|
||||
function helper.square(x)
|
||||
return x * x
|
||||
end
|
||||
|
||||
function helper.cube(x)
|
||||
return x * x * x
|
||||
end
|
||||
|
||||
return helper
|
||||
`
|
||||
if err := os.WriteFile(filepath.Join(scriptDir, "helper.lua"), []byte(helperModule), 0644); err != nil {
|
||||
t.Fatalf("Failed to write helper module: %v", err)
|
||||
}
|
||||
|
||||
// Create main script that requires both modules
|
||||
mainScript := `
|
||||
-- Require from the same directory
|
||||
local helper = require("helper")
|
||||
|
||||
-- Require from the lib directory
|
||||
local mathlib = require("mathlib")
|
||||
|
||||
-- Use both modules
|
||||
local result = {
|
||||
add = mathlib.add(10, 5),
|
||||
mul = mathlib.mul(10, 5),
|
||||
square = helper.square(5),
|
||||
cube = helper.cube(3)
|
||||
}
|
||||
|
||||
return result
|
||||
`
|
||||
mainScriptPath := filepath.Join(scriptDir, "main.lua")
|
||||
if err := os.WriteFile(mainScriptPath, []byte(mainScript), 0644); err != nil {
|
||||
t.Fatalf("Failed to write main script: %v", err)
|
||||
}
|
||||
|
||||
// Create LuaRunner
|
||||
luaRunner, err := runner.NewRunner(
|
||||
runner.WithScriptDir(scriptDir),
|
||||
runner.WithLibDirs(libDir),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create LuaRunner: %v", err)
|
||||
}
|
||||
defer luaRunner.Close()
|
||||
|
||||
// Compile the main script
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
bytecode, err := state.CompileBytecode(mainScript, "main.lua")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to compile script: %v", err)
|
||||
}
|
||||
|
||||
// Run the script
|
||||
result, err := luaRunner.Run(bytecode, nil, mainScriptPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to run script: %v", err)
|
||||
}
|
||||
|
||||
// Check result
|
||||
resultMap, ok := result.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("Expected map result, got %T", result)
|
||||
}
|
||||
|
||||
// Validate results
|
||||
expectedResults := map[string]float64{
|
||||
"add": 15, // 10 + 5
|
||||
"mul": 50, // 10 * 5
|
||||
"square": 25, // 5^2
|
||||
"cube": 27, // 3^3
|
||||
}
|
||||
|
||||
for key, expected := range expectedResults {
|
||||
if val, ok := resultMap[key]; !ok {
|
||||
t.Errorf("Missing result key: %s", key)
|
||||
} else if val != expected {
|
||||
t.Errorf("For %s: expected %.1f, got %v", key, expected, val)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireSecurityBoundaries(t *testing.T) {
|
||||
// Create temporary directories for test
|
||||
tempDir, err := os.MkdirTemp("", "luarunner-security-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Create script directory and lib directory
|
||||
scriptDir := filepath.Join(tempDir, "scripts")
|
||||
libDir := filepath.Join(tempDir, "libs")
|
||||
secretDir := filepath.Join(tempDir, "secret")
|
||||
|
||||
if err := os.Mkdir(scriptDir, 0755); err != nil {
|
||||
t.Fatalf("Failed to create script directory: %v", err)
|
||||
}
|
||||
if err := os.Mkdir(libDir, 0755); err != nil {
|
||||
t.Fatalf("Failed to create lib directory: %v", err)
|
||||
}
|
||||
if err := os.Mkdir(secretDir, 0755); err != nil {
|
||||
t.Fatalf("Failed to create secret directory: %v", err)
|
||||
}
|
||||
|
||||
// Create a "secret" module that should not be accessible
|
||||
secretModule := `
|
||||
local secret = "TOP SECRET"
|
||||
return secret
|
||||
`
|
||||
if err := os.WriteFile(filepath.Join(secretDir, "secret.lua"), []byte(secretModule), 0644); err != nil {
|
||||
t.Fatalf("Failed to write secret module: %v", err)
|
||||
}
|
||||
|
||||
// Create a normal module in lib
|
||||
normalModule := `return "normal module"`
|
||||
if err := os.WriteFile(filepath.Join(libDir, "normal.lua"), []byte(normalModule), 0644); err != nil {
|
||||
t.Fatalf("Failed to write normal module: %v", err)
|
||||
}
|
||||
|
||||
// Test attempting to access file outside allowed paths
|
||||
pathTraversalTests := []struct {
|
||||
name string
|
||||
script string
|
||||
}{
|
||||
{
|
||||
name: "Direct path traversal",
|
||||
script: `
|
||||
-- Try path traversal
|
||||
local secret = require("../secret/secret")
|
||||
return secret
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "Double dot traversal",
|
||||
script: `
|
||||
local secret = require("..secret.secret")
|
||||
return secret
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "Absolute path",
|
||||
script: fmt.Sprintf(`
|
||||
local secret = require("%s")
|
||||
return secret
|
||||
`, filepath.Join(secretDir, "secret")),
|
||||
},
|
||||
}
|
||||
|
||||
// Create and configure runner
|
||||
luaRunner, err := runner.NewRunner(
|
||||
runner.WithScriptDir(scriptDir),
|
||||
runner.WithLibDirs(libDir),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create LuaRunner: %v", err)
|
||||
}
|
||||
defer luaRunner.Close()
|
||||
|
||||
// Test each attempt at path traversal
|
||||
for _, tt := range pathTraversalTests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Write the script
|
||||
scriptPath := filepath.Join(scriptDir, tt.name+".lua")
|
||||
if err := os.WriteFile(scriptPath, []byte(tt.script), 0644); err != nil {
|
||||
t.Fatalf("Failed to write test script: %v", err)
|
||||
}
|
||||
|
||||
// Compile
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
bytecode, err := state.CompileBytecode(tt.script, tt.name+".lua")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to compile script: %v", err)
|
||||
}
|
||||
|
||||
// Run and expect error
|
||||
_, err = luaRunner.Run(bytecode, nil, scriptPath)
|
||||
if err == nil {
|
||||
t.Error("Expected error for path traversal, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Test that we can still require valid modules
|
||||
normalScript := `
|
||||
local normal = require("normal")
|
||||
return normal
|
||||
`
|
||||
scriptPath := filepath.Join(scriptDir, "normal_test.lua")
|
||||
if err := os.WriteFile(scriptPath, []byte(normalScript), 0644); err != nil {
|
||||
t.Fatalf("Failed to write normal test script: %v", err)
|
||||
}
|
||||
|
||||
// Compile
|
||||
state := luajit.New()
|
||||
if state == nil {
|
||||
t.Fatal("Failed to create Lua state")
|
||||
}
|
||||
defer state.Close()
|
||||
|
||||
bytecode, err := state.CompileBytecode(normalScript, "normal_test.lua")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to compile script: %v", err)
|
||||
}
|
||||
|
||||
// Run and expect success
|
||||
result, err := luaRunner.Run(bytecode, nil, scriptPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to run normal script: %v", err)
|
||||
}
|
||||
|
||||
// Check result
|
||||
if result != "normal module" {
|
||||
t.Errorf("Expected 'normal module', got %v", result)
|
||||
}
|
||||
}
|
|
@ -66,6 +66,8 @@ func main() {
|
|||
log.SetLevel(logger.LevelWarning)
|
||||
case "error":
|
||||
log.SetLevel(logger.LevelError)
|
||||
case "server":
|
||||
log.SetLevel(logger.LevelServer)
|
||||
case "fatal":
|
||||
log.SetLevel(logger.LevelFatal)
|
||||
default:
|
||||
|
@ -108,6 +110,7 @@ func main() {
|
|||
// Initialize Lua runner (replacing worker pool)
|
||||
runner, err := runner.NewRunner(
|
||||
runner.WithBufferSize(bufferSize),
|
||||
runner.WithLibDirs("./libs"),
|
||||
)
|
||||
if err != nil {
|
||||
log.Fatal("Failed to initialize Lua runner: %v", err)
|
||||
|
|
Loading…
Reference in New Issue
Block a user