diff --git a/.gitignore b/.gitignore index f380bca..97a3337 100644 --- a/.gitignore +++ b/.gitignore @@ -25,3 +25,5 @@ go.work config.lua routes/ static/ + +luajit diff --git a/core/workers/init_test.go b/core/workers/init_test.go new file mode 100644 index 0000000..ea0557d --- /dev/null +++ b/core/workers/init_test.go @@ -0,0 +1,346 @@ +package workers + +import ( + "testing" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" +) + +func TestModuleRegistration(t *testing.T) { + // Define an init function that registers a simple "math" module + mathInit := func(state *luajit.State) error { + // Register the "add" function + err := state.RegisterGoFunction("add", func(s *luajit.State) int { + a := s.ToNumber(1) + b := s.ToNumber(2) + s.PushNumber(a + b) + return 1 // Return one result + }) + + if err != nil { + return err + } + + // Register a whole module + mathFuncs := map[string]luajit.GoFunction{ + "multiply": func(s *luajit.State) int { + a := s.ToNumber(1) + b := s.ToNumber(2) + s.PushNumber(a * b) + return 1 + }, + "subtract": func(s *luajit.State) int { + a := s.ToNumber(1) + b := s.ToNumber(2) + s.PushNumber(a - b) + return 1 + }, + } + + return RegisterModule(state, "math2", mathFuncs) + } + + // Create a pool with our init function + pool, err := NewPoolWithInit(2, mathInit) + if err != nil { + t.Fatalf("Failed to create pool: %v", err) + } + defer pool.Shutdown() + + // Test the add function + bytecode1 := createTestBytecode(t, "return add(5, 7)") + result1, err := pool.Submit(bytecode1, nil) + if err != nil { + t.Fatalf("Failed to call add function: %v", err) + } + + num1, ok := result1.(float64) + if !ok || num1 != 12 { + t.Errorf("Expected add(5, 7) = 12, got %v", result1) + } + + // Test the math2 module + bytecode2 := createTestBytecode(t, "return math2.multiply(6, 8)") + result2, err := pool.Submit(bytecode2, nil) + if err != nil { + t.Fatalf("Failed to call math2.multiply: %v", err) + } + + num2, ok := result2.(float64) + if !ok || num2 != 48 { + t.Errorf("Expected math2.multiply(6, 8) = 48, got %v", result2) + } + + // Test multiple operations + bytecode3 := createTestBytecode(t, ` + local a = add(10, 20) + local b = math2.subtract(a, 5) + return math2.multiply(b, 2) + `) + + result3, err := pool.Submit(bytecode3, nil) + if err != nil { + t.Fatalf("Failed to execute combined operations: %v", err) + } + + num3, ok := result3.(float64) + if !ok || num3 != 50 { + t.Errorf("Expected ((10 + 20) - 5) * 2 = 50, got %v", result3) + } +} + +func TestModuleInitFunc(t *testing.T) { + // Define math module functions + mathModule := func() map[string]luajit.GoFunction { + return map[string]luajit.GoFunction{ + "add": func(s *luajit.State) int { + a := s.ToNumber(1) + b := s.ToNumber(2) + s.PushNumber(a + b) + return 1 + }, + "multiply": func(s *luajit.State) int { + a := s.ToNumber(1) + b := s.ToNumber(2) + s.PushNumber(a * b) + return 1 + }, + } + } + + // Define string module functions + strModule := func() map[string]luajit.GoFunction { + return map[string]luajit.GoFunction{ + "concat": func(s *luajit.State) int { + a := s.ToString(1) + b := s.ToString(2) + s.PushString(a + b) + return 1 + }, + } + } + + // Create module map + modules := map[string]ModuleFunc{ + "math2": mathModule, + "str": strModule, + } + + // Create pool with module init + pool, err := NewPoolWithInit(2, ModuleInitFunc(modules)) + if err != nil { + t.Fatalf("Failed to create pool: %v", err) + } + defer pool.Shutdown() + + // Test math module + bytecode1 := createTestBytecode(t, "return math2.add(5, 7)") + result1, err := pool.Submit(bytecode1, nil) + if err != nil { + t.Fatalf("Failed to call math2.add: %v", err) + } + + num1, ok := result1.(float64) + if !ok || num1 != 12 { + t.Errorf("Expected math2.add(5, 7) = 12, got %v", result1) + } + + // Test string module + bytecode2 := createTestBytecode(t, "return str.concat('hello', 'world')") + result2, err := pool.Submit(bytecode2, nil) + if err != nil { + t.Fatalf("Failed to call str.concat: %v", err) + } + + str2, ok := result2.(string) + if !ok || str2 != "helloworld" { + t.Errorf("Expected str.concat('hello', 'world') = 'helloworld', got %v", result2) + } +} + +func TestCombineInitFuncs(t *testing.T) { + // First init function adds a function to get a constant value + init1 := func(state *luajit.State) error { + return state.RegisterGoFunction("getAnswer", func(s *luajit.State) int { + s.PushNumber(42) + return 1 + }) + } + + // Second init function registers a function that multiplies a number by 2 + init2 := func(state *luajit.State) error { + return state.RegisterGoFunction("double", func(s *luajit.State) int { + n := s.ToNumber(1) + s.PushNumber(n * 2) + return 1 + }) + } + + // Combine the init functions + combinedInit := CombineInitFuncs(init1, init2) + + // Create a pool with the combined init function + pool, err := NewPoolWithInit(1, combinedInit) + if err != nil { + t.Fatalf("Failed to create pool: %v", err) + } + defer pool.Shutdown() + + // Test using both functions together in a single script + bytecode := createTestBytecode(t, "return double(getAnswer())") + result, err := pool.Submit(bytecode, nil) + if err != nil { + t.Fatalf("Failed to execute: %v", err) + } + + num, ok := result.(float64) + if !ok || num != 84 { + t.Errorf("Expected double(getAnswer()) = 84, got %v", result) + } +} + +func TestSandboxIsolation(t *testing.T) { + // Create a pool + pool, err := NewPool(2) + if err != nil { + t.Fatalf("Failed to create pool: %v", err) + } + defer pool.Shutdown() + + // Create a script that tries to modify a global variable + bytecode1 := createTestBytecode(t, ` + -- Set a "global" variable + my_global = "test value" + return true + `) + + _, err = pool.Submit(bytecode1, nil) + if err != nil { + t.Fatalf("Failed to execute first script: %v", err) + } + + // Now try to access that variable from another script + bytecode2 := createTestBytecode(t, ` + -- Try to access the previously set global + return my_global ~= nil + `) + + result, err := pool.Submit(bytecode2, nil) + if err != nil { + t.Fatalf("Failed to execute second script: %v", err) + } + + // The variable should not be accessible (sandbox isolation) + if result.(bool) { + t.Errorf("Expected sandbox isolation, but global variable was accessible") + } +} + +func TestContextInSandbox(t *testing.T) { + // Create a pool + pool, err := NewPool(2) + if err != nil { + t.Fatalf("Failed to create pool: %v", err) + } + defer pool.Shutdown() + + // Create a context with test data + ctx := NewContext() + ctx.Set("name", "test") + ctx.Set("value", 42.5) + ctx.Set("items", []float64{1, 2, 3}) + + bytecode := createTestBytecode(t, ` + -- Access and manipulate context values + local sum = 0 + for i, v in ipairs(ctx.items) do + sum = sum + v + end + + return { + name_length = string.len(ctx.name), + value_doubled = ctx.value * 2, + items_sum = sum + } + `) + + result, err := pool.Submit(bytecode, ctx) + if err != nil { + t.Fatalf("Failed to execute script with context: %v", err) + } + + resultMap, ok := result.(map[string]any) + if !ok { + t.Fatalf("Expected map result, got %T", result) + } + + // Check context values were correctly accessible + if resultMap["name_length"].(float64) != 4 { + t.Errorf("Expected name_length = 4, got %v", resultMap["name_length"]) + } + + if resultMap["value_doubled"].(float64) != 85 { + t.Errorf("Expected value_doubled = 85, got %v", resultMap["value_doubled"]) + } + + if resultMap["items_sum"].(float64) != 6 { + t.Errorf("Expected items_sum = 6, got %v", resultMap["items_sum"]) + } +} + +func TestStandardLibsInSandbox(t *testing.T) { + // Create a pool + pool, err := NewPool(2) + if err != nil { + t.Fatalf("Failed to create pool: %v", err) + } + defer pool.Shutdown() + + // Test access to standard libraries + bytecode := createTestBytecode(t, ` + local results = {} + + -- Test string library + results.string_upper = string.upper("test") + + -- Test math library + results.math_sqrt = math.sqrt(16) + + -- Test table library + local tbl = {10, 20, 30} + table.insert(tbl, 40) + results.table_length = #tbl + + -- Test os library (limited functions) + results.has_os_time = type(os.time) == "function" + + return results + `) + + result, err := pool.Submit(bytecode, nil) + if err != nil { + t.Fatalf("Failed to execute script: %v", err) + } + + resultMap, ok := result.(map[string]any) + if !ok { + t.Fatalf("Expected map result, got %T", result) + } + + // Check standard library functions worked + if resultMap["string_upper"] != "TEST" { + t.Errorf("Expected string_upper = 'TEST', got %v", resultMap["string_upper"]) + } + + if resultMap["math_sqrt"].(float64) != 4 { + t.Errorf("Expected math_sqrt = 4, got %v", resultMap["math_sqrt"]) + } + + if resultMap["table_length"].(float64) != 4 { + t.Errorf("Expected table_length = 4, got %v", resultMap["table_length"]) + } + + if resultMap["has_os_time"] != true { + t.Errorf("Expected has_os_time = true, got %v", resultMap["has_os_time"]) + } +} diff --git a/core/workers/modules.go b/core/workers/modules.go new file mode 100644 index 0000000..b1194ce --- /dev/null +++ b/core/workers/modules.go @@ -0,0 +1,59 @@ +package workers + +import ( + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" +) + +// ModuleFunc is a function that returns a map of module functions +type ModuleFunc func() map[string]luajit.GoFunction + +// ModuleInitFunc creates a state initializer that registers multiple modules +func ModuleInitFunc(modules map[string]ModuleFunc) StateInitFunc { + return func(state *luajit.State) error { + for name, moduleFunc := range modules { + if err := RegisterModule(state, name, moduleFunc()); err != nil { + return err + } + } + return nil + } +} + +// RegisterModule registers a map of functions as a Lua module +func RegisterModule(state *luajit.State, name string, funcs map[string]luajit.GoFunction) error { + // Create a new table for the module + state.NewTable() + + // Add each function to the module table + for fname, f := range funcs { + // Push function name + state.PushString(fname) + + // Push function + if err := state.PushGoFunction(f); err != nil { + state.Pop(2) // Pop table and function name + return err + } + + // Set table[fname] = f + state.SetTable(-3) + } + + // Register the module globally + state.SetGlobal(name) + return nil +} + +// CombineInitFuncs combines multiple state initializer functions into one +func CombineInitFuncs(funcs ...StateInitFunc) StateInitFunc { + return func(state *luajit.State) error { + for _, f := range funcs { + if f != nil { + if err := f(state); err != nil { + return err + } + } + } + return nil + } +} diff --git a/core/workers/pool.go b/core/workers/pool.go index 2859caf..e6b85f5 100644 --- a/core/workers/pool.go +++ b/core/workers/pool.go @@ -4,8 +4,14 @@ import ( "context" "sync" "sync/atomic" + + luajit "git.sharkk.net/Sky/LuaJIT-to-Go" ) +// StateInitFunc is a function that initializes a Lua state +// It can be used to register custom functions and modules +type StateInitFunc func(*luajit.State) error + // Pool manages a pool of Lua worker goroutines type Pool struct { workers uint32 // Number of workers @@ -13,18 +19,26 @@ type Pool struct { wg sync.WaitGroup // WaitGroup to track active workers quit chan struct{} // Channel to signal shutdown isRunning atomic.Bool // Flag to track if pool is running + stateInit StateInitFunc // Optional function to initialize Lua state } // NewPool creates a new worker pool with the specified number of workers func NewPool(numWorkers int) (*Pool, error) { + return NewPoolWithInit(numWorkers, nil) +} + +// NewPoolWithInit creates a new worker pool with the specified number of workers +// and a function to initialize each worker's Lua state +func NewPoolWithInit(numWorkers int, initFunc StateInitFunc) (*Pool, error) { if numWorkers <= 0 { return nil, ErrNoWorkers } p := &Pool{ - workers: uint32(numWorkers), - jobs: make(chan job, numWorkers), // Buffer equal to worker count - quit: make(chan struct{}), + workers: uint32(numWorkers), + jobs: make(chan job, numWorkers), // Buffer equal to worker count + quit: make(chan struct{}), + stateInit: initFunc, } p.isRunning.Store(true) @@ -41,6 +55,12 @@ func NewPool(numWorkers int) (*Pool, error) { return p, nil } +// RegisterGlobal is no longer needed with the sandbox approach +// but kept as a no-op for backward compatibility +func (p *Pool) RegisterGlobal(name string) { + // No-op in sandbox mode +} + // SubmitWithContext sends a job to the worker pool with context func (p *Pool) SubmitWithContext(ctx context.Context, bytecode []byte, execCtx *Context) (any, error) { if !p.isRunning.Load() { diff --git a/core/workers/sandbox.go b/core/workers/sandbox.go new file mode 100644 index 0000000..ddf9930 --- /dev/null +++ b/core/workers/sandbox.go @@ -0,0 +1,144 @@ +package workers + +// setupSandbox initializes the sandbox environment creation function +func (w *worker) setupSandbox() error { + // This is the Lua script that creates our sandbox function + setupScript := ` + -- Create a function to run code in a sandbox environment + function __create_sandbox() + -- Create new environment table + local env = {} + + -- Add standard library modules (can be restricted as needed) + env.string = string + env.table = table + env.math = math + env.os = { + time = os.time, + date = os.date, + difftime = os.difftime, + clock = os.clock + } + env.tonumber = tonumber + env.tostring = tostring + env.type = type + env.pairs = pairs + env.ipairs = ipairs + env.next = next + env.select = select + env.unpack = unpack + env.pcall = pcall + env.xpcall = xpcall + env.error = error + env.assert = assert + + -- Allow access to package.loaded for modules + env.require = function(name) + return package.loaded[name] + end + + -- Create metatable to restrict access to _G + local mt = { + __index = function(t, k) + -- First check in env table + local v = rawget(env, k) + if v ~= nil then return v end + + -- If not found, check for registered modules/functions + local moduleValue = _G[k] + if type(moduleValue) == "table" or + type(moduleValue) == "function" then + return moduleValue + end + + return nil + end, + __newindex = function(t, k, v) + rawset(env, k, v) + end + } + + setmetatable(env, mt) + return env + end + + -- Create function to execute code with a sandbox + function __run_sandboxed(f, ctx) + local env = __create_sandbox() + + -- Add context to the environment if provided + if ctx then + env.ctx = ctx + end + + -- Set the environment and run the function + setfenv(f, env) + return f() + end + ` + + return w.state.DoString(setupScript) +} + +// executeJobSandboxed runs a script in a sandbox environment +func (w *worker) executeJobSandboxed(j job) JobResult { + // No need to reset the state for each execution, since we're using a sandbox + + // Re-run init function to register functions and modules if needed + if w.pool.stateInit != nil { + if err := w.pool.stateInit(w.state); err != nil { + return JobResult{nil, err} + } + } + + // Set up context if provided + if j.Context != nil { + // Push context table + w.state.NewTable() + + // Add values to context table + for key, value := range j.Context.Values { + // Push key + w.state.PushString(key) + + // Push value + if err := w.state.PushValue(value); err != nil { + return JobResult{nil, err} + } + + // Set table[key] = value + w.state.SetTable(-3) + } + } else { + // Push nil if no context + w.state.PushNil() + } + + // Load bytecode + if err := w.state.LoadBytecode(j.Bytecode, "script"); err != nil { + w.state.Pop(1) // Pop context + return JobResult{nil, err} + } + + // Get the sandbox runner function + w.state.GetGlobal("__run_sandboxed") + + // Push loaded function and context as arguments + w.state.PushCopy(-2) // Copy the loaded function + w.state.PushCopy(-4) // Copy the context table or nil + + // Remove the original function and context + w.state.Remove(-5) // Remove original context + w.state.Remove(-4) // Remove original function + + // Call the sandbox runner with 2 args (function and context), expecting 1 result + if err := w.state.Call(2, 1); err != nil { + return JobResult{nil, err} + } + + // Get result + value, err := w.state.ToValue(-1) + w.state.Pop(1) // Pop result + + return JobResult{value, err} +} diff --git a/core/workers/worker.go b/core/workers/worker.go index 958cfed..e83ce0c 100644 --- a/core/workers/worker.go +++ b/core/workers/worker.go @@ -11,6 +11,7 @@ import ( var ( ErrPoolClosed = errors.New("worker pool is closed") ErrNoWorkers = errors.New("no workers available") + ErrInitFailed = errors.New("worker initialization failed") ) // worker represents a single Lua execution worker @@ -33,13 +34,22 @@ func (w *worker) run() { } defer w.state.Close() - // Set up reset function for clearing state between requests - if err := w.setupResetFunction(); err != nil { - // Worker failed to initialize reset function, decrement counter + // Set up sandbox environment + if err := w.setupSandbox(); err != nil { + // Worker failed to initialize sandbox, decrement counter atomic.AddUint32(&w.pool.workers, ^uint32(0)) return } + // Run init function if provided + if w.pool.stateInit != nil { + if err := w.pool.stateInit(w.state); err != nil { + // Worker failed to initialize with custom init function + atomic.AddUint32(&w.pool.workers, ^uint32(0)) + return + } + } + // Main worker loop for { select { @@ -50,7 +60,7 @@ func (w *worker) run() { } // Execute job - result := w.executeJob(job) + result := w.executeJobSandboxed(job) job.Result <- result case <-w.pool.quit: @@ -59,102 +69,3 @@ func (w *worker) run() { } } } - -// setupResetFunction initializes the reset function for clearing globals -func (w *worker) setupResetFunction() error { - resetScript := ` - -- Create reset function to efficiently clear globals after each request - function __reset_globals() - -- Only keep builtin globals, remove all user-defined globals - local preserve = { - ["_G"] = true, ["_VERSION"] = true, ["__reset_globals"] = true, - ["assert"] = true, ["collectgarbage"] = true, ["coroutine"] = true, - ["debug"] = true, ["dofile"] = true, ["error"] = true, - ["getmetatable"] = true, ["io"] = true, ["ipairs"] = true, - ["load"] = true, ["loadfile"] = true, ["loadstring"] = true, - ["math"] = true, ["next"] = true, ["os"] = true, - ["package"] = true, ["pairs"] = true, ["pcall"] = true, - ["print"] = true, ["rawequal"] = true, ["rawget"] = true, - ["rawset"] = true, ["require"] = true, ["select"] = true, - ["setmetatable"] = true, ["string"] = true, ["table"] = true, - ["tonumber"] = true, ["tostring"] = true, ["type"] = true, - ["unpack"] = true, ["xpcall"] = true - } - - -- Clear all non-standard globals - for name in pairs(_G) do - if not preserve[name] then - _G[name] = nil - end - end - - -- Run garbage collection to release memory - collectgarbage('collect') - end - ` - - return w.state.DoString(resetScript) -} - -// resetState prepares the Lua state for a new job -func (w *worker) resetState() { - w.state.DoString("__reset_globals()") -} - -// setContext sets job context as global tables in Lua state -func (w *worker) setContext(ctx *Context) error { - if ctx == nil { - return nil - } - - // Create context table - w.state.NewTable() - - // Add values to context table - for key, value := range ctx.Values { - // Push key - w.state.PushString(key) - - // Push value - if err := w.state.PushValue(value); err != nil { - return err - } - - // Set table[key] = value - w.state.SetTable(-3) - } - - // Set the table as global 'ctx' - w.state.SetGlobal("ctx") - - return nil -} - -// executeJob executes a Lua job in the worker's state -func (w *worker) executeJob(j job) JobResult { - // Reset state before execution - w.resetState() - - // Set context - if j.Context != nil { - if err := w.setContext(j.Context); err != nil { - return JobResult{nil, err} - } - } - - // Load bytecode - if err := w.state.LoadBytecode(j.Bytecode, "script"); err != nil { - return JobResult{nil, err} - } - - // Execute script with one result - if err := w.state.RunBytecodeWithResults(1); err != nil { - return JobResult{nil, err} - } - - // Get result - value, err := w.state.ToValue(-1) - w.state.Pop(1) // Pop result - - return JobResult{value, err} -} diff --git a/luajit b/luajit index 13686b3..7ea0dbc 160000 --- a/luajit +++ b/luajit @@ -1 +1 @@ -Subproject commit 13686b3e66b388a31d459fe95d1aa3bfa05aeb27 +Subproject commit 7ea0dbcb7b2ddcd8758e66b034c300ee55178b29 diff --git a/moonshark.go b/moonshark.go index 30856f7..1e17c49 100644 --- a/moonshark.go +++ b/moonshark.go @@ -47,7 +47,7 @@ func initRouters(routesDir, staticDir string, log *logger.Logger) (*routers.LuaR func main() { // Initialize logger - log := logger.New(logger.LevelDebug, true) + log := logger.New(logger.LevelInfo, true) log.Info("Starting Moonshark server")