lua state optimizations
This commit is contained in:
parent
0cc1f37cfe
commit
cef93357c3
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -25,3 +25,5 @@ go.work
|
||||||
config.lua
|
config.lua
|
||||||
routes/
|
routes/
|
||||||
static/
|
static/
|
||||||
|
|
||||||
|
luajit
|
||||||
|
|
346
core/workers/init_test.go
Normal file
346
core/workers/init_test.go
Normal file
|
@ -0,0 +1,346 @@
|
||||||
|
package workers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestModuleRegistration(t *testing.T) {
|
||||||
|
// Define an init function that registers a simple "math" module
|
||||||
|
mathInit := func(state *luajit.State) error {
|
||||||
|
// Register the "add" function
|
||||||
|
err := state.RegisterGoFunction("add", func(s *luajit.State) int {
|
||||||
|
a := s.ToNumber(1)
|
||||||
|
b := s.ToNumber(2)
|
||||||
|
s.PushNumber(a + b)
|
||||||
|
return 1 // Return one result
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register a whole module
|
||||||
|
mathFuncs := map[string]luajit.GoFunction{
|
||||||
|
"multiply": func(s *luajit.State) int {
|
||||||
|
a := s.ToNumber(1)
|
||||||
|
b := s.ToNumber(2)
|
||||||
|
s.PushNumber(a * b)
|
||||||
|
return 1
|
||||||
|
},
|
||||||
|
"subtract": func(s *luajit.State) int {
|
||||||
|
a := s.ToNumber(1)
|
||||||
|
b := s.ToNumber(2)
|
||||||
|
s.PushNumber(a - b)
|
||||||
|
return 1
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return RegisterModule(state, "math2", mathFuncs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a pool with our init function
|
||||||
|
pool, err := NewPoolWithInit(2, mathInit)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create pool: %v", err)
|
||||||
|
}
|
||||||
|
defer pool.Shutdown()
|
||||||
|
|
||||||
|
// Test the add function
|
||||||
|
bytecode1 := createTestBytecode(t, "return add(5, 7)")
|
||||||
|
result1, err := pool.Submit(bytecode1, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to call add function: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
num1, ok := result1.(float64)
|
||||||
|
if !ok || num1 != 12 {
|
||||||
|
t.Errorf("Expected add(5, 7) = 12, got %v", result1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test the math2 module
|
||||||
|
bytecode2 := createTestBytecode(t, "return math2.multiply(6, 8)")
|
||||||
|
result2, err := pool.Submit(bytecode2, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to call math2.multiply: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
num2, ok := result2.(float64)
|
||||||
|
if !ok || num2 != 48 {
|
||||||
|
t.Errorf("Expected math2.multiply(6, 8) = 48, got %v", result2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test multiple operations
|
||||||
|
bytecode3 := createTestBytecode(t, `
|
||||||
|
local a = add(10, 20)
|
||||||
|
local b = math2.subtract(a, 5)
|
||||||
|
return math2.multiply(b, 2)
|
||||||
|
`)
|
||||||
|
|
||||||
|
result3, err := pool.Submit(bytecode3, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to execute combined operations: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
num3, ok := result3.(float64)
|
||||||
|
if !ok || num3 != 50 {
|
||||||
|
t.Errorf("Expected ((10 + 20) - 5) * 2 = 50, got %v", result3)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModuleInitFunc(t *testing.T) {
|
||||||
|
// Define math module functions
|
||||||
|
mathModule := func() map[string]luajit.GoFunction {
|
||||||
|
return map[string]luajit.GoFunction{
|
||||||
|
"add": func(s *luajit.State) int {
|
||||||
|
a := s.ToNumber(1)
|
||||||
|
b := s.ToNumber(2)
|
||||||
|
s.PushNumber(a + b)
|
||||||
|
return 1
|
||||||
|
},
|
||||||
|
"multiply": func(s *luajit.State) int {
|
||||||
|
a := s.ToNumber(1)
|
||||||
|
b := s.ToNumber(2)
|
||||||
|
s.PushNumber(a * b)
|
||||||
|
return 1
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Define string module functions
|
||||||
|
strModule := func() map[string]luajit.GoFunction {
|
||||||
|
return map[string]luajit.GoFunction{
|
||||||
|
"concat": func(s *luajit.State) int {
|
||||||
|
a := s.ToString(1)
|
||||||
|
b := s.ToString(2)
|
||||||
|
s.PushString(a + b)
|
||||||
|
return 1
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create module map
|
||||||
|
modules := map[string]ModuleFunc{
|
||||||
|
"math2": mathModule,
|
||||||
|
"str": strModule,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create pool with module init
|
||||||
|
pool, err := NewPoolWithInit(2, ModuleInitFunc(modules))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create pool: %v", err)
|
||||||
|
}
|
||||||
|
defer pool.Shutdown()
|
||||||
|
|
||||||
|
// Test math module
|
||||||
|
bytecode1 := createTestBytecode(t, "return math2.add(5, 7)")
|
||||||
|
result1, err := pool.Submit(bytecode1, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to call math2.add: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
num1, ok := result1.(float64)
|
||||||
|
if !ok || num1 != 12 {
|
||||||
|
t.Errorf("Expected math2.add(5, 7) = 12, got %v", result1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test string module
|
||||||
|
bytecode2 := createTestBytecode(t, "return str.concat('hello', 'world')")
|
||||||
|
result2, err := pool.Submit(bytecode2, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to call str.concat: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
str2, ok := result2.(string)
|
||||||
|
if !ok || str2 != "helloworld" {
|
||||||
|
t.Errorf("Expected str.concat('hello', 'world') = 'helloworld', got %v", result2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCombineInitFuncs(t *testing.T) {
|
||||||
|
// First init function adds a function to get a constant value
|
||||||
|
init1 := func(state *luajit.State) error {
|
||||||
|
return state.RegisterGoFunction("getAnswer", func(s *luajit.State) int {
|
||||||
|
s.PushNumber(42)
|
||||||
|
return 1
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second init function registers a function that multiplies a number by 2
|
||||||
|
init2 := func(state *luajit.State) error {
|
||||||
|
return state.RegisterGoFunction("double", func(s *luajit.State) int {
|
||||||
|
n := s.ToNumber(1)
|
||||||
|
s.PushNumber(n * 2)
|
||||||
|
return 1
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Combine the init functions
|
||||||
|
combinedInit := CombineInitFuncs(init1, init2)
|
||||||
|
|
||||||
|
// Create a pool with the combined init function
|
||||||
|
pool, err := NewPoolWithInit(1, combinedInit)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create pool: %v", err)
|
||||||
|
}
|
||||||
|
defer pool.Shutdown()
|
||||||
|
|
||||||
|
// Test using both functions together in a single script
|
||||||
|
bytecode := createTestBytecode(t, "return double(getAnswer())")
|
||||||
|
result, err := pool.Submit(bytecode, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to execute: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
num, ok := result.(float64)
|
||||||
|
if !ok || num != 84 {
|
||||||
|
t.Errorf("Expected double(getAnswer()) = 84, got %v", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSandboxIsolation(t *testing.T) {
|
||||||
|
// Create a pool
|
||||||
|
pool, err := NewPool(2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create pool: %v", err)
|
||||||
|
}
|
||||||
|
defer pool.Shutdown()
|
||||||
|
|
||||||
|
// Create a script that tries to modify a global variable
|
||||||
|
bytecode1 := createTestBytecode(t, `
|
||||||
|
-- Set a "global" variable
|
||||||
|
my_global = "test value"
|
||||||
|
return true
|
||||||
|
`)
|
||||||
|
|
||||||
|
_, err = pool.Submit(bytecode1, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to execute first script: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now try to access that variable from another script
|
||||||
|
bytecode2 := createTestBytecode(t, `
|
||||||
|
-- Try to access the previously set global
|
||||||
|
return my_global ~= nil
|
||||||
|
`)
|
||||||
|
|
||||||
|
result, err := pool.Submit(bytecode2, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to execute second script: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The variable should not be accessible (sandbox isolation)
|
||||||
|
if result.(bool) {
|
||||||
|
t.Errorf("Expected sandbox isolation, but global variable was accessible")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextInSandbox(t *testing.T) {
|
||||||
|
// Create a pool
|
||||||
|
pool, err := NewPool(2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create pool: %v", err)
|
||||||
|
}
|
||||||
|
defer pool.Shutdown()
|
||||||
|
|
||||||
|
// Create a context with test data
|
||||||
|
ctx := NewContext()
|
||||||
|
ctx.Set("name", "test")
|
||||||
|
ctx.Set("value", 42.5)
|
||||||
|
ctx.Set("items", []float64{1, 2, 3})
|
||||||
|
|
||||||
|
bytecode := createTestBytecode(t, `
|
||||||
|
-- Access and manipulate context values
|
||||||
|
local sum = 0
|
||||||
|
for i, v in ipairs(ctx.items) do
|
||||||
|
sum = sum + v
|
||||||
|
end
|
||||||
|
|
||||||
|
return {
|
||||||
|
name_length = string.len(ctx.name),
|
||||||
|
value_doubled = ctx.value * 2,
|
||||||
|
items_sum = sum
|
||||||
|
}
|
||||||
|
`)
|
||||||
|
|
||||||
|
result, err := pool.Submit(bytecode, ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to execute script with context: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resultMap, ok := result.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Expected map result, got %T", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check context values were correctly accessible
|
||||||
|
if resultMap["name_length"].(float64) != 4 {
|
||||||
|
t.Errorf("Expected name_length = 4, got %v", resultMap["name_length"])
|
||||||
|
}
|
||||||
|
|
||||||
|
if resultMap["value_doubled"].(float64) != 85 {
|
||||||
|
t.Errorf("Expected value_doubled = 85, got %v", resultMap["value_doubled"])
|
||||||
|
}
|
||||||
|
|
||||||
|
if resultMap["items_sum"].(float64) != 6 {
|
||||||
|
t.Errorf("Expected items_sum = 6, got %v", resultMap["items_sum"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStandardLibsInSandbox(t *testing.T) {
|
||||||
|
// Create a pool
|
||||||
|
pool, err := NewPool(2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create pool: %v", err)
|
||||||
|
}
|
||||||
|
defer pool.Shutdown()
|
||||||
|
|
||||||
|
// Test access to standard libraries
|
||||||
|
bytecode := createTestBytecode(t, `
|
||||||
|
local results = {}
|
||||||
|
|
||||||
|
-- Test string library
|
||||||
|
results.string_upper = string.upper("test")
|
||||||
|
|
||||||
|
-- Test math library
|
||||||
|
results.math_sqrt = math.sqrt(16)
|
||||||
|
|
||||||
|
-- Test table library
|
||||||
|
local tbl = {10, 20, 30}
|
||||||
|
table.insert(tbl, 40)
|
||||||
|
results.table_length = #tbl
|
||||||
|
|
||||||
|
-- Test os library (limited functions)
|
||||||
|
results.has_os_time = type(os.time) == "function"
|
||||||
|
|
||||||
|
return results
|
||||||
|
`)
|
||||||
|
|
||||||
|
result, err := pool.Submit(bytecode, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to execute script: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resultMap, ok := result.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Expected map result, got %T", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check standard library functions worked
|
||||||
|
if resultMap["string_upper"] != "TEST" {
|
||||||
|
t.Errorf("Expected string_upper = 'TEST', got %v", resultMap["string_upper"])
|
||||||
|
}
|
||||||
|
|
||||||
|
if resultMap["math_sqrt"].(float64) != 4 {
|
||||||
|
t.Errorf("Expected math_sqrt = 4, got %v", resultMap["math_sqrt"])
|
||||||
|
}
|
||||||
|
|
||||||
|
if resultMap["table_length"].(float64) != 4 {
|
||||||
|
t.Errorf("Expected table_length = 4, got %v", resultMap["table_length"])
|
||||||
|
}
|
||||||
|
|
||||||
|
if resultMap["has_os_time"] != true {
|
||||||
|
t.Errorf("Expected has_os_time = true, got %v", resultMap["has_os_time"])
|
||||||
|
}
|
||||||
|
}
|
59
core/workers/modules.go
Normal file
59
core/workers/modules.go
Normal file
|
@ -0,0 +1,59 @@
|
||||||
|
package workers
|
||||||
|
|
||||||
|
import (
|
||||||
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ModuleFunc is a function that returns a map of module functions
|
||||||
|
type ModuleFunc func() map[string]luajit.GoFunction
|
||||||
|
|
||||||
|
// ModuleInitFunc creates a state initializer that registers multiple modules
|
||||||
|
func ModuleInitFunc(modules map[string]ModuleFunc) StateInitFunc {
|
||||||
|
return func(state *luajit.State) error {
|
||||||
|
for name, moduleFunc := range modules {
|
||||||
|
if err := RegisterModule(state, name, moduleFunc()); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterModule registers a map of functions as a Lua module
|
||||||
|
func RegisterModule(state *luajit.State, name string, funcs map[string]luajit.GoFunction) error {
|
||||||
|
// Create a new table for the module
|
||||||
|
state.NewTable()
|
||||||
|
|
||||||
|
// Add each function to the module table
|
||||||
|
for fname, f := range funcs {
|
||||||
|
// Push function name
|
||||||
|
state.PushString(fname)
|
||||||
|
|
||||||
|
// Push function
|
||||||
|
if err := state.PushGoFunction(f); err != nil {
|
||||||
|
state.Pop(2) // Pop table and function name
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set table[fname] = f
|
||||||
|
state.SetTable(-3)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register the module globally
|
||||||
|
state.SetGlobal(name)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CombineInitFuncs combines multiple state initializer functions into one
|
||||||
|
func CombineInitFuncs(funcs ...StateInitFunc) StateInitFunc {
|
||||||
|
return func(state *luajit.State) error {
|
||||||
|
for _, f := range funcs {
|
||||||
|
if f != nil {
|
||||||
|
if err := f(state); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
|
@ -4,8 +4,14 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
|
luajit "git.sharkk.net/Sky/LuaJIT-to-Go"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// StateInitFunc is a function that initializes a Lua state
|
||||||
|
// It can be used to register custom functions and modules
|
||||||
|
type StateInitFunc func(*luajit.State) error
|
||||||
|
|
||||||
// Pool manages a pool of Lua worker goroutines
|
// Pool manages a pool of Lua worker goroutines
|
||||||
type Pool struct {
|
type Pool struct {
|
||||||
workers uint32 // Number of workers
|
workers uint32 // Number of workers
|
||||||
|
@ -13,10 +19,17 @@ type Pool struct {
|
||||||
wg sync.WaitGroup // WaitGroup to track active workers
|
wg sync.WaitGroup // WaitGroup to track active workers
|
||||||
quit chan struct{} // Channel to signal shutdown
|
quit chan struct{} // Channel to signal shutdown
|
||||||
isRunning atomic.Bool // Flag to track if pool is running
|
isRunning atomic.Bool // Flag to track if pool is running
|
||||||
|
stateInit StateInitFunc // Optional function to initialize Lua state
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPool creates a new worker pool with the specified number of workers
|
// NewPool creates a new worker pool with the specified number of workers
|
||||||
func NewPool(numWorkers int) (*Pool, error) {
|
func NewPool(numWorkers int) (*Pool, error) {
|
||||||
|
return NewPoolWithInit(numWorkers, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPoolWithInit creates a new worker pool with the specified number of workers
|
||||||
|
// and a function to initialize each worker's Lua state
|
||||||
|
func NewPoolWithInit(numWorkers int, initFunc StateInitFunc) (*Pool, error) {
|
||||||
if numWorkers <= 0 {
|
if numWorkers <= 0 {
|
||||||
return nil, ErrNoWorkers
|
return nil, ErrNoWorkers
|
||||||
}
|
}
|
||||||
|
@ -25,6 +38,7 @@ func NewPool(numWorkers int) (*Pool, error) {
|
||||||
workers: uint32(numWorkers),
|
workers: uint32(numWorkers),
|
||||||
jobs: make(chan job, numWorkers), // Buffer equal to worker count
|
jobs: make(chan job, numWorkers), // Buffer equal to worker count
|
||||||
quit: make(chan struct{}),
|
quit: make(chan struct{}),
|
||||||
|
stateInit: initFunc,
|
||||||
}
|
}
|
||||||
p.isRunning.Store(true)
|
p.isRunning.Store(true)
|
||||||
|
|
||||||
|
@ -41,6 +55,12 @@ func NewPool(numWorkers int) (*Pool, error) {
|
||||||
return p, nil
|
return p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterGlobal is no longer needed with the sandbox approach
|
||||||
|
// but kept as a no-op for backward compatibility
|
||||||
|
func (p *Pool) RegisterGlobal(name string) {
|
||||||
|
// No-op in sandbox mode
|
||||||
|
}
|
||||||
|
|
||||||
// SubmitWithContext sends a job to the worker pool with context
|
// SubmitWithContext sends a job to the worker pool with context
|
||||||
func (p *Pool) SubmitWithContext(ctx context.Context, bytecode []byte, execCtx *Context) (any, error) {
|
func (p *Pool) SubmitWithContext(ctx context.Context, bytecode []byte, execCtx *Context) (any, error) {
|
||||||
if !p.isRunning.Load() {
|
if !p.isRunning.Load() {
|
||||||
|
|
144
core/workers/sandbox.go
Normal file
144
core/workers/sandbox.go
Normal file
|
@ -0,0 +1,144 @@
|
||||||
|
package workers
|
||||||
|
|
||||||
|
// setupSandbox initializes the sandbox environment creation function
|
||||||
|
func (w *worker) setupSandbox() error {
|
||||||
|
// This is the Lua script that creates our sandbox function
|
||||||
|
setupScript := `
|
||||||
|
-- Create a function to run code in a sandbox environment
|
||||||
|
function __create_sandbox()
|
||||||
|
-- Create new environment table
|
||||||
|
local env = {}
|
||||||
|
|
||||||
|
-- Add standard library modules (can be restricted as needed)
|
||||||
|
env.string = string
|
||||||
|
env.table = table
|
||||||
|
env.math = math
|
||||||
|
env.os = {
|
||||||
|
time = os.time,
|
||||||
|
date = os.date,
|
||||||
|
difftime = os.difftime,
|
||||||
|
clock = os.clock
|
||||||
|
}
|
||||||
|
env.tonumber = tonumber
|
||||||
|
env.tostring = tostring
|
||||||
|
env.type = type
|
||||||
|
env.pairs = pairs
|
||||||
|
env.ipairs = ipairs
|
||||||
|
env.next = next
|
||||||
|
env.select = select
|
||||||
|
env.unpack = unpack
|
||||||
|
env.pcall = pcall
|
||||||
|
env.xpcall = xpcall
|
||||||
|
env.error = error
|
||||||
|
env.assert = assert
|
||||||
|
|
||||||
|
-- Allow access to package.loaded for modules
|
||||||
|
env.require = function(name)
|
||||||
|
return package.loaded[name]
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Create metatable to restrict access to _G
|
||||||
|
local mt = {
|
||||||
|
__index = function(t, k)
|
||||||
|
-- First check in env table
|
||||||
|
local v = rawget(env, k)
|
||||||
|
if v ~= nil then return v end
|
||||||
|
|
||||||
|
-- If not found, check for registered modules/functions
|
||||||
|
local moduleValue = _G[k]
|
||||||
|
if type(moduleValue) == "table" or
|
||||||
|
type(moduleValue) == "function" then
|
||||||
|
return moduleValue
|
||||||
|
end
|
||||||
|
|
||||||
|
return nil
|
||||||
|
end,
|
||||||
|
__newindex = function(t, k, v)
|
||||||
|
rawset(env, k, v)
|
||||||
|
end
|
||||||
|
}
|
||||||
|
|
||||||
|
setmetatable(env, mt)
|
||||||
|
return env
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Create function to execute code with a sandbox
|
||||||
|
function __run_sandboxed(f, ctx)
|
||||||
|
local env = __create_sandbox()
|
||||||
|
|
||||||
|
-- Add context to the environment if provided
|
||||||
|
if ctx then
|
||||||
|
env.ctx = ctx
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Set the environment and run the function
|
||||||
|
setfenv(f, env)
|
||||||
|
return f()
|
||||||
|
end
|
||||||
|
`
|
||||||
|
|
||||||
|
return w.state.DoString(setupScript)
|
||||||
|
}
|
||||||
|
|
||||||
|
// executeJobSandboxed runs a script in a sandbox environment
|
||||||
|
func (w *worker) executeJobSandboxed(j job) JobResult {
|
||||||
|
// No need to reset the state for each execution, since we're using a sandbox
|
||||||
|
|
||||||
|
// Re-run init function to register functions and modules if needed
|
||||||
|
if w.pool.stateInit != nil {
|
||||||
|
if err := w.pool.stateInit(w.state); err != nil {
|
||||||
|
return JobResult{nil, err}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up context if provided
|
||||||
|
if j.Context != nil {
|
||||||
|
// Push context table
|
||||||
|
w.state.NewTable()
|
||||||
|
|
||||||
|
// Add values to context table
|
||||||
|
for key, value := range j.Context.Values {
|
||||||
|
// Push key
|
||||||
|
w.state.PushString(key)
|
||||||
|
|
||||||
|
// Push value
|
||||||
|
if err := w.state.PushValue(value); err != nil {
|
||||||
|
return JobResult{nil, err}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set table[key] = value
|
||||||
|
w.state.SetTable(-3)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Push nil if no context
|
||||||
|
w.state.PushNil()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load bytecode
|
||||||
|
if err := w.state.LoadBytecode(j.Bytecode, "script"); err != nil {
|
||||||
|
w.state.Pop(1) // Pop context
|
||||||
|
return JobResult{nil, err}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the sandbox runner function
|
||||||
|
w.state.GetGlobal("__run_sandboxed")
|
||||||
|
|
||||||
|
// Push loaded function and context as arguments
|
||||||
|
w.state.PushCopy(-2) // Copy the loaded function
|
||||||
|
w.state.PushCopy(-4) // Copy the context table or nil
|
||||||
|
|
||||||
|
// Remove the original function and context
|
||||||
|
w.state.Remove(-5) // Remove original context
|
||||||
|
w.state.Remove(-4) // Remove original function
|
||||||
|
|
||||||
|
// Call the sandbox runner with 2 args (function and context), expecting 1 result
|
||||||
|
if err := w.state.Call(2, 1); err != nil {
|
||||||
|
return JobResult{nil, err}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get result
|
||||||
|
value, err := w.state.ToValue(-1)
|
||||||
|
w.state.Pop(1) // Pop result
|
||||||
|
|
||||||
|
return JobResult{value, err}
|
||||||
|
}
|
|
@ -11,6 +11,7 @@ import (
|
||||||
var (
|
var (
|
||||||
ErrPoolClosed = errors.New("worker pool is closed")
|
ErrPoolClosed = errors.New("worker pool is closed")
|
||||||
ErrNoWorkers = errors.New("no workers available")
|
ErrNoWorkers = errors.New("no workers available")
|
||||||
|
ErrInitFailed = errors.New("worker initialization failed")
|
||||||
)
|
)
|
||||||
|
|
||||||
// worker represents a single Lua execution worker
|
// worker represents a single Lua execution worker
|
||||||
|
@ -33,13 +34,22 @@ func (w *worker) run() {
|
||||||
}
|
}
|
||||||
defer w.state.Close()
|
defer w.state.Close()
|
||||||
|
|
||||||
// Set up reset function for clearing state between requests
|
// Set up sandbox environment
|
||||||
if err := w.setupResetFunction(); err != nil {
|
if err := w.setupSandbox(); err != nil {
|
||||||
// Worker failed to initialize reset function, decrement counter
|
// Worker failed to initialize sandbox, decrement counter
|
||||||
atomic.AddUint32(&w.pool.workers, ^uint32(0))
|
atomic.AddUint32(&w.pool.workers, ^uint32(0))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Run init function if provided
|
||||||
|
if w.pool.stateInit != nil {
|
||||||
|
if err := w.pool.stateInit(w.state); err != nil {
|
||||||
|
// Worker failed to initialize with custom init function
|
||||||
|
atomic.AddUint32(&w.pool.workers, ^uint32(0))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Main worker loop
|
// Main worker loop
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
|
@ -50,7 +60,7 @@ func (w *worker) run() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute job
|
// Execute job
|
||||||
result := w.executeJob(job)
|
result := w.executeJobSandboxed(job)
|
||||||
job.Result <- result
|
job.Result <- result
|
||||||
|
|
||||||
case <-w.pool.quit:
|
case <-w.pool.quit:
|
||||||
|
@ -59,102 +69,3 @@ func (w *worker) run() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// setupResetFunction initializes the reset function for clearing globals
|
|
||||||
func (w *worker) setupResetFunction() error {
|
|
||||||
resetScript := `
|
|
||||||
-- Create reset function to efficiently clear globals after each request
|
|
||||||
function __reset_globals()
|
|
||||||
-- Only keep builtin globals, remove all user-defined globals
|
|
||||||
local preserve = {
|
|
||||||
["_G"] = true, ["_VERSION"] = true, ["__reset_globals"] = true,
|
|
||||||
["assert"] = true, ["collectgarbage"] = true, ["coroutine"] = true,
|
|
||||||
["debug"] = true, ["dofile"] = true, ["error"] = true,
|
|
||||||
["getmetatable"] = true, ["io"] = true, ["ipairs"] = true,
|
|
||||||
["load"] = true, ["loadfile"] = true, ["loadstring"] = true,
|
|
||||||
["math"] = true, ["next"] = true, ["os"] = true,
|
|
||||||
["package"] = true, ["pairs"] = true, ["pcall"] = true,
|
|
||||||
["print"] = true, ["rawequal"] = true, ["rawget"] = true,
|
|
||||||
["rawset"] = true, ["require"] = true, ["select"] = true,
|
|
||||||
["setmetatable"] = true, ["string"] = true, ["table"] = true,
|
|
||||||
["tonumber"] = true, ["tostring"] = true, ["type"] = true,
|
|
||||||
["unpack"] = true, ["xpcall"] = true
|
|
||||||
}
|
|
||||||
|
|
||||||
-- Clear all non-standard globals
|
|
||||||
for name in pairs(_G) do
|
|
||||||
if not preserve[name] then
|
|
||||||
_G[name] = nil
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
-- Run garbage collection to release memory
|
|
||||||
collectgarbage('collect')
|
|
||||||
end
|
|
||||||
`
|
|
||||||
|
|
||||||
return w.state.DoString(resetScript)
|
|
||||||
}
|
|
||||||
|
|
||||||
// resetState prepares the Lua state for a new job
|
|
||||||
func (w *worker) resetState() {
|
|
||||||
w.state.DoString("__reset_globals()")
|
|
||||||
}
|
|
||||||
|
|
||||||
// setContext sets job context as global tables in Lua state
|
|
||||||
func (w *worker) setContext(ctx *Context) error {
|
|
||||||
if ctx == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create context table
|
|
||||||
w.state.NewTable()
|
|
||||||
|
|
||||||
// Add values to context table
|
|
||||||
for key, value := range ctx.Values {
|
|
||||||
// Push key
|
|
||||||
w.state.PushString(key)
|
|
||||||
|
|
||||||
// Push value
|
|
||||||
if err := w.state.PushValue(value); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set table[key] = value
|
|
||||||
w.state.SetTable(-3)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the table as global 'ctx'
|
|
||||||
w.state.SetGlobal("ctx")
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// executeJob executes a Lua job in the worker's state
|
|
||||||
func (w *worker) executeJob(j job) JobResult {
|
|
||||||
// Reset state before execution
|
|
||||||
w.resetState()
|
|
||||||
|
|
||||||
// Set context
|
|
||||||
if j.Context != nil {
|
|
||||||
if err := w.setContext(j.Context); err != nil {
|
|
||||||
return JobResult{nil, err}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load bytecode
|
|
||||||
if err := w.state.LoadBytecode(j.Bytecode, "script"); err != nil {
|
|
||||||
return JobResult{nil, err}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Execute script with one result
|
|
||||||
if err := w.state.RunBytecodeWithResults(1); err != nil {
|
|
||||||
return JobResult{nil, err}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get result
|
|
||||||
value, err := w.state.ToValue(-1)
|
|
||||||
w.state.Pop(1) // Pop result
|
|
||||||
|
|
||||||
return JobResult{value, err}
|
|
||||||
}
|
|
||||||
|
|
2
luajit
2
luajit
|
@ -1 +1 @@
|
||||||
Subproject commit 13686b3e66b388a31d459fe95d1aa3bfa05aeb27
|
Subproject commit 7ea0dbcb7b2ddcd8758e66b034c300ee55178b29
|
|
@ -47,7 +47,7 @@ func initRouters(routesDir, staticDir string, log *logger.Logger) (*routers.LuaR
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
// Initialize logger
|
// Initialize logger
|
||||||
log := logger.New(logger.LevelDebug, true)
|
log := logger.New(logger.LevelInfo, true)
|
||||||
|
|
||||||
log.Info("Starting Moonshark server")
|
log.Info("Starting Moonshark server")
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user