This commit is contained in:
Sky Johnson 2025-04-10 09:48:58 -05:00
parent ba9a3db0a0
commit ab6135e98a
4 changed files with 54 additions and 88 deletions

View File

@ -24,7 +24,7 @@ func precompileSandboxCode() {
// Create temporary state for compilation // Create temporary state for compilation
tempState := luajit.New() tempState := luajit.New()
if tempState == nil { if tempState == nil {
logger.Error("Failed to create temp Lua state for bytecode compilation") logger.ErrorCont("Failed to create temp Lua state for bytecode compilation")
return return
} }
defer tempState.Close() defer tempState.Close()
@ -32,7 +32,7 @@ func precompileSandboxCode() {
code, err := tempState.CompileBytecode(sandboxLuaCode, "sandbox.lua") code, err := tempState.CompileBytecode(sandboxLuaCode, "sandbox.lua")
if err != nil { if err != nil {
logger.Error("Failed to compile sandbox code: %v", err) logger.ErrorCont("Failed to compile sandbox code: %v", err)
return return
} }
@ -40,22 +40,20 @@ func precompileSandboxCode() {
copy(bytecode, code) copy(bytecode, code)
sandboxBytecode.Store(&bytecode) sandboxBytecode.Store(&bytecode)
logger.Debug("Successfully precompiled sandbox.lua to bytecode (%d bytes)", len(code)) logger.ServerCont("Successfully precompiled sandbox.lua to bytecode (%d bytes)", len(code))
} }
// loadSandboxIntoState loads the sandbox code into a Lua state // loadSandboxIntoState loads the sandbox code into a Lua state
func loadSandboxIntoState(state *luajit.State) error { func loadSandboxIntoState(state *luajit.State) error {
// Initialize bytecode once
bytecodeOnce.Do(precompileSandboxCode) bytecodeOnce.Do(precompileSandboxCode)
// Use precompiled bytecode if available
bytecode := sandboxBytecode.Load() bytecode := sandboxBytecode.Load()
if bytecode != nil && len(*bytecode) > 0 { if bytecode != nil && len(*bytecode) > 0 {
logger.Debug("Loading sandbox.lua from precompiled bytecode") logger.ServerCont("Loading sandbox.lua from precompiled bytecode") // piggyback off Sandbox.go's Setup()
return state.LoadAndRunBytecode(*bytecode, "sandbox.lua") return state.LoadAndRunBytecode(*bytecode, "sandbox.lua")
} }
// Fallback to direct execution // Fallback to direct execution
logger.Warning("Using non-precompiled sandbox.lua (bytecode compilation failed)") logger.WarningCont("Using non-precompiled sandbox.lua (bytecode compilation failed)")
return state.DoString(sandboxLuaCode) return state.DoString(sandboxLuaCode)
} }

View File

@ -147,17 +147,12 @@ func (r *Runner) createState(index int) (*State, error) {
r.debugLog("Creating Lua state %d", index) r.debugLog("Creating Lua state %d", index)
} }
// Create a new state
L := luajit.New() L := luajit.New()
if L == nil { if L == nil {
return nil, errors.New("failed to create Lua state") return nil, errors.New("failed to create Lua state")
} }
// Create sandbox
sb := NewSandbox() sb := NewSandbox()
if r.debug {
sb.EnableDebug()
}
// Set up sandbox // Set up sandbox
if err := sb.Setup(L); err != nil { if err := sb.Setup(L); err != nil {

View File

@ -40,67 +40,50 @@ func NewSandbox() *Sandbox {
} }
} }
// EnableDebug turns on debug logging
func (s *Sandbox) EnableDebug() {
s.debug = true
}
// debugLog logs a message if debug mode is enabled
func (s *Sandbox) debugLog(format string, args ...interface{}) {
if s.debug {
logger.Debug("Sandbox "+format, args...)
}
}
// AddModule adds a module to the sandbox environment // AddModule adds a module to the sandbox environment
func (s *Sandbox) AddModule(name string, module any) { func (s *Sandbox) AddModule(name string, module any) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
s.modules[name] = module s.modules[name] = module
s.debugLog("Added module: %s", name) logger.Debug("Added module: %s", name)
} }
// Setup initializes the sandbox in a Lua state // Setup initializes the sandbox in a Lua state
func (s *Sandbox) Setup(state *luajit.State) error { func (s *Sandbox) Setup(state *luajit.State) error {
s.debugLog("Setting up sandbox...") logger.Server("Setting up sandbox...")
// Load the sandbox code
if err := loadSandboxIntoState(state); err != nil { if err := loadSandboxIntoState(state); err != nil {
s.debugLog("Failed to load sandbox: %v", err) logger.ErrorCont("Failed to load sandbox: %v", err)
return err return err
} }
// Register core functions
if err := s.registerCoreFunctions(state); err != nil { if err := s.registerCoreFunctions(state); err != nil {
s.debugLog("Failed to register core functions: %v", err) logger.ErrorCont("Failed to register core functions: %v", err)
return err return err
} }
// Register custom modules in the global environment
s.mu.RLock() s.mu.RLock()
for name, module := range s.modules { for name, module := range s.modules {
s.debugLog("Registering module: %s", name) logger.DebugCont("Registering module: %s", name)
if err := state.PushValue(module); err != nil { if err := state.PushValue(module); err != nil {
s.mu.RUnlock() s.mu.RUnlock()
s.debugLog("Failed to register module %s: %v", name, err) logger.ErrorCont("Failed to register module %s: %v", name, err)
return err return err
} }
state.SetGlobal(name) state.SetGlobal(name)
} }
s.mu.RUnlock() s.mu.RUnlock()
s.debugLog("Sandbox setup complete") logger.ServerCont("Sandbox setup complete")
return nil return nil
} }
// registerCoreFunctions registers all built-in functions in the Lua state // registerCoreFunctions registers all built-in functions in the Lua state
func (s *Sandbox) registerCoreFunctions(state *luajit.State) error { func (s *Sandbox) registerCoreFunctions(state *luajit.State) error {
// Register HTTP functions
if err := state.RegisterGoFunction("__http_request", httpRequest); err != nil { if err := state.RegisterGoFunction("__http_request", httpRequest); err != nil {
return err return err
} }
// Register utility functions
if err := state.RegisterGoFunction("__generate_token", generateToken); err != nil { if err := state.RegisterGoFunction("__generate_token", generateToken); err != nil {
return err return err
} }
@ -112,43 +95,38 @@ func (s *Sandbox) registerCoreFunctions(state *luajit.State) error {
// Execute runs a Lua script in the sandbox with the given context // Execute runs a Lua script in the sandbox with the given context
func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*Response, error) { func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*Response, error) {
// Create a response object
response := NewResponse()
// Get the execution function first // Get the execution function first
state.GetGlobal("__execute_script") state.GetGlobal("__execute_script")
if !state.IsFunction(-1) { if !state.IsFunction(-1) {
state.Pop(1) state.Pop(1)
ReleaseResponse(response)
return nil, ErrSandboxNotInitialized return nil, ErrSandboxNotInitialized
} }
// Load bytecode // Load bytecode
if err := state.LoadBytecode(bytecode, "script"); err != nil { if err := state.LoadBytecode(bytecode, "script"); err != nil {
state.Pop(1) // Pop the __execute_script function state.Pop(1) // Pop the __execute_script function
ReleaseResponse(response)
return nil, fmt.Errorf("failed to load script: %w", err) return nil, fmt.Errorf("failed to load script: %w", err)
} }
// Push context values // Push context values
if err := state.PushTable(ctx.Values); err != nil { if err := state.PushTable(ctx.Values); err != nil {
state.Pop(2) // Pop bytecode and __execute_script state.Pop(2) // Pop bytecode and __execute_script
ReleaseResponse(response)
return nil, err return nil, err
} }
// Execute with 2 args, 1 result // Execute with 2 args, 1 result
if err := state.Call(2, 1); err != nil { if err := state.Call(2, 1); err != nil {
ReleaseResponse(response)
return nil, fmt.Errorf("script execution failed: %w", err) return nil, fmt.Errorf("script execution failed: %w", err)
} }
// Set response body from result // Get result value
body, err := state.ToValue(-1) body, err := state.ToValue(-1)
state.Pop(1)
response := NewResponse()
if err == nil { if err == nil {
response.Body = body response.Body = body
} }
state.Pop(1)
extractHTTPResponseData(state, response) extractHTTPResponseData(state, response)

View File

@ -57,7 +57,6 @@ end
-- HTTP MODULE -- HTTP MODULE
-- ====================================================================== -- ======================================================================
-- HTTP module implementation
local http = { local http = {
-- Set HTTP status code -- Set HTTP status code
set_status = function(code) set_status = function(code)
@ -112,45 +111,6 @@ local http = {
return result return result
end, end,
-- Simple GET request
get = function(url, options)
return http.client.request("GET", url, nil, options)
end,
-- Simple POST request with automatic content-type
post = function(url, body, options)
options = options or {}
return http.client.request("POST", url, body, options)
end,
-- Simple PUT request with automatic content-type
put = function(url, body, options)
options = options or {}
return http.client.request("PUT", url, body, options)
end,
-- Simple DELETE request
delete = function(url, options)
return http.client.request("DELETE", url, nil, options)
end,
-- Simple PATCH request
patch = function(url, body, options)
options = options or {}
return http.client.request("PATCH", url, body, options)
end,
-- Simple HEAD request
head = function(url, options)
options = options or {}
return http.client.request("HEAD", url, nil, options)
end,
-- Simple OPTIONS request
options = function(url, options)
return http.client.request("OPTIONS", url, nil, options)
end,
-- Shorthand function to directly get JSON -- Shorthand function to directly get JSON
get_json = function(url, options) get_json = function(url, options)
options = options or {} options = options or {}
@ -191,11 +151,30 @@ local http = {
} }
} }
local function make_method(method, needs_body)
return function(url, body_or_options, options)
if needs_body then
options = options or {}
return http.client.request(method, url, body_or_options, options)
else
body_or_options = body_or_options or {}
return http.client.request(method, url, nil, body_or_options)
end
end
end
http.client.get = make_method("GET", false)
http.client.delete = make_method("DELETE", false)
http.client.head = make_method("HEAD", false)
http.client.options = make_method("OPTIONS", false)
http.client.post = make_method("POST", true)
http.client.put = make_method("PUT", true)
http.client.patch = make_method("PATCH", true)
-- ====================================================================== -- ======================================================================
-- COOKIE MODULE -- COOKIE MODULE
-- ====================================================================== -- ======================================================================
-- Cookie module implementation
local cookie = { local cookie = {
-- Set a cookie -- Set a cookie
set = function(name, value, options) set = function(name, value, options)
@ -231,8 +210,25 @@ local cookie = {
cookie.secure = (opts.secure ~= false) cookie.secure = (opts.secure ~= false)
cookie.http_only = (opts.http_only ~= false) cookie.http_only = (opts.http_only ~= false)
table.insert(resp.cookies, cookie) if opts.same_site then
local valid_values = {none = true, lax = true, strict = true}
local same_site = string.lower(opts.same_site)
if not valid_values[same_site] then
error("cookie.set: same_site must be one of 'None', 'Lax', or 'Strict'", 2)
end
-- If SameSite=None, the cookie must be secure
if same_site == "none" and not cookie.secure then
cookie.secure = true
end
cookie.same_site = opts.same_site
else
cookie.same_site = "Lax"
end
table.insert(resp.cookies, cookie)
return true return true
end, end,
@ -353,7 +349,6 @@ local util = {
-- REGISTER MODULES GLOBALLY -- REGISTER MODULES GLOBALLY
-- ====================================================================== -- ======================================================================
-- Install modules in global scope
_G.http = http _G.http = http
_G.cookie = cookie _G.cookie = cookie
_G.util = util _G.util = util