diff --git a/runner/embed.go b/runner/embed.go index a86dffc..f51909a 100644 --- a/runner/embed.go +++ b/runner/embed.go @@ -53,7 +53,7 @@ type ModuleInfo struct { } var ( - sandbox = ModuleInfo{Name: "sandbox", Code: sandboxLuaCode} + sandbox = ModuleInfo{Name: "sandbox", Code: sandboxLuaCode, DefinesGlobal: true} modules = []ModuleInfo{ {Name: "json", Code: jsonLuaCode, DefinesGlobal: true}, {Name: "sqlite", Code: sqliteLuaCode}, @@ -71,7 +71,7 @@ var ( // precompileModule compiles a module's code to bytecode once func precompileModule(m *ModuleInfo) { m.Once.Do(func() { - tempState := luajit.New(true) // Explicitly open standard libraries + tempState := luajit.New(true) if tempState == nil { logger.Fatalf("Failed to create temp Lua state for %s module compilation", m.Name) return @@ -123,8 +123,18 @@ func loadModule(state *luajit.State, m *ModuleInfo, verbose bool) error { logger.Warnf("Using non-precompiled %s.lua", m.Name) } - if err := state.DoString(m.Code); err != nil { - return err + if m.DefinesGlobal { + if err := state.DoString(m.Code); err != nil { + return err + } + } else { + if err := state.LoadString(m.Code); err != nil { + return err + } + if err := state.Call(0, 1); err != nil { + return err + } + state.SetGlobal(m.Name) } } @@ -133,19 +143,19 @@ func loadModule(state *luajit.State, m *ModuleInfo, verbose bool) error { // loadSandboxIntoState loads all modules and sandbox into a Lua state func loadSandboxIntoState(state *luajit.State, verbose bool) error { - // Load all modules first + // Load all utility modules first for i := range modules { if err := loadModule(state, &modules[i], verbose); err != nil { return err } } - // Initialize active connections tracking (specific to SQLite) + // Initialize any module-specific globals (like SQLite tracking) if err := state.DoString(`__active_sqlite_connections = {}`); err != nil { return err } - // Load the sandbox last + // Load the sandbox last - it defines __execute and other core functions precompileModule(&sandbox) bytecode := sandbox.Bytecode.Load() if bytecode != nil && len(*bytecode) > 0 { diff --git a/runner/lua/sandbox.lua b/runner/lua/sandbox.lua index e5b9357..4017a71 100644 --- a/runner/lua/sandbox.lua +++ b/runner/lua/sandbox.lua @@ -1,664 +1,244 @@ ---[[ -sandbox.lua - Rewritten with global context storage -]]-- +-- Simplified sandbox.lua - Direct execution with explicit state passing -__http_response = {} -__module_paths = {} -__module_bytecode = {} -__ready_modules = {} -__EXIT_SENTINEL = {} -- Unique object for exit identification - --- Global context storage for reliable access -local _current_ctx = nil - --- ====================================================================== --- CORE SANDBOX FUNCTIONALITY --- ====================================================================== - -function exit() - error(__EXIT_SENTINEL) -end - --- Create environment inheriting from _G -function __create_env(ctx) +-- Main execution wrapper - receives script, context, and response object +function __execute(script_func, ctx, response) + -- Create clean environment local env = setmetatable({}, {__index = _G}) - - if ctx then - env.ctx = ctx - - if ctx._env then - env._env = ctx._env + + -- Direct context and response access + env.ctx = ctx + env.response = response + + -- Exit sentinel + env.exit = function() error("__EXIT__") end + + -- ====================================================================== + -- HTTP FUNCTIONS - Modify response directly + -- ====================================================================== + + env.http_set_status = function(code) + if type(code) ~= "number" then + error("http_set_status: status code must be a number", 2) + end + response.status = code + end + + env.http_set_header = function(name, value) + if type(name) ~= "string" or type(value) ~= "string" then + error("http_set_header: name and value must be strings", 2) + end + response.headers = response.headers or {} + response.headers[name] = value + end + + env.http_set_content_type = function(ct) + env.http_set_header("Content-Type", ct) + end + + env.http_set_metadata = function(key, value) + if type(key) ~= "string" then + error("http_set_metadata: key must be a string", 2) + end + response.metadata = response.metadata or {} + response.metadata[key] = value + end + + env.http_redirect = function(url, status) + if type(url) ~= "string" then + error("http_redirect: url must be a string", 2) + end + response.status = status or 302 + response.headers = response.headers or {} + response.headers["Location"] = url + env.exit() + end + + -- HTTP request functions (use native implementation) + env.http_request = __http_request + + local function make_method(method, needs_body) + return function(url, body_or_options, options) + if needs_body then + options = options or {} + return env.http_request(method, url, body_or_options, options) + else + body_or_options = body_or_options or {} + return env.http_request(method, url, nil, body_or_options) + end end end - + + env.http_get = make_method("GET", false) + env.http_post = make_method("POST", true) + env.http_put = make_method("PUT", true) + env.http_patch = make_method("PATCH", true) + env.http_delete = make_method("DELETE", false) + env.http_head = make_method("HEAD", false) + env.http_options = make_method("OPTIONS", false) + + -- ====================================================================== + -- COOKIE FUNCTIONS - Direct access to ctx and response + -- ====================================================================== + + env.cookie_set = function(name, value, options) + if type(name) ~= "string" then + error("cookie_set: name must be a string", 2) + end + + response.cookies = response.cookies or {} + local opts = options or {} + + local cookie = { + name = name, + value = value or "", + path = opts.path or "/", + domain = opts.domain, + secure = opts.secure ~= false, + http_only = opts.http_only ~= false, + same_site = opts.same_site or "Lax" + } + + if opts.expires then + if type(opts.expires) == "number" and opts.expires > 0 then + cookie.max_age = opts.expires + end + end + + table.insert(response.cookies, cookie) + return true + end + + env.cookie_get = function(name) + if type(name) ~= "string" then + error("cookie_get: name must be a string", 2) + end + return ctx.cookies and ctx.cookies[name] + end + + env.cookie_remove = function(name, path, domain) + return env.cookie_set(name, "", {expires = -1, path = path or "/", domain = domain}) + end + + -- ====================================================================== + -- SESSION FUNCTIONS - Direct access to ctx and response + -- ====================================================================== + + env.session_get = function(key) + if type(key) ~= "string" then + error("session_get: key must be a string", 2) + end + return ctx.session and ctx.session.data and ctx.session.data[key] + end + + env.session_set = function(key, value) + if type(key) ~= "string" then + error("session_set: key must be a string", 2) + end + response.session = response.session or {} + response.session[key] = value + -- Update context if available + if ctx.session and ctx.session.data then + ctx.session.data[key] = value + end + end + + env.session_id = function() + return ctx.session and ctx.session.id + end + + env.session_get_all = function() + if ctx.session and ctx.session.data then + local copy = {} + for k, v in pairs(ctx.session.data) do + copy[k] = v + end + return copy + end + return {} + end + + env.session_delete = function(key) + if type(key) ~= "string" then + error("session_delete: key must be a string", 2) + end + response.session = response.session or {} + response.session[key] = "__DELETE__" + if ctx.session and ctx.session.data then + ctx.session.data[key] = nil + end + end + + env.session_clear = function() + response.session = {__clear_all = true} + if ctx.session and ctx.session.data then + for k in pairs(ctx.session.data) do + ctx.session.data[k] = nil + end + end + end + + -- ====================================================================== + -- CONTENT TYPE HELPERS + -- ====================================================================== + + env.send_html = function(content) + env.http_set_content_type("text/html") + return content + end + + env.send_json = function(content) + env.http_set_content_type("application/json") + return content + end + + env.send_text = function(content) + env.http_set_content_type("text/plain") + return content + end + + -- ====================================================================== + -- NATIVE FUNCTIONS (injected by Go) + -- ====================================================================== + + -- Copy over native functions + local natives = { + "__password_hash", "__password_verify", + "__sqlite_open", "__sqlite_exec", "__sqlite_query", "__sqlite_close", + "__fs_read", "__fs_write", "__fs_exists", "__fs_delete", + "generate_token", "url_encode", "url_decode", + "html_special_chars", "html_entities", + "base64_encode", "base64_decode", + "json_encode", "json_decode", + "sha256", "md5", "hmac", + "env_get", "env_set" + } + + for _, name in ipairs(natives) do + if _G[name] then + env[name:gsub("^__", "")] = _G[name] + end + end + + -- Template functions + env.render = _G.render + env.parse = _G.parse + env.iparse = _G.iparse + + -- Module support if __setup_require then __setup_require(env) end - - return env -end - --- Execute script with clean environment -function __execute_script(fn, ctx) - __http_response = nil - _current_ctx = ctx -- Store globally for function access - - local env = __create_env(ctx) - env.exit = exit - setfenv(fn, env) - - local ok, result = pcall(fn) - _current_ctx = nil -- Clean up after execution + -- Set function environment and execute + setfenv(script_func, env) + local ok, result = pcall(script_func) if not ok then - if result == __EXIT_SENTINEL then - return + if result == "__EXIT__" then + return nil end error(result, 0) end - + return result -end - --- Ensure __http_response exists, then return it -function __ensure_response() - if not __http_response then - __http_response = {} - end - return __http_response -end - --- ====================================================================== --- HTTP FUNCTIONS --- ====================================================================== - --- Set HTTP status code -function http_set_status(code) - if type(code) ~= "number" then - error("http_set_status: status code must be a number", 2) - end - - local resp = __ensure_response() - resp.status = code -end - --- Set HTTP header -function http_set_header(name, value) - if type(name) ~= "string" or type(value) ~= "string" then - error("http_set_header: name and value must be strings", 2) - end - - local resp = __ensure_response() - resp.headers = resp.headers or {} - resp.headers[name] = value -end - --- Set content type; http_set_header helper -function http_set_content_type(content_type) - http_set_header("Content-Type", content_type) -end - --- Set metadata (arbitrary data to be returned with response) -function http_set_metadata(key, value) - if type(key) ~= "string" then - error("http_set_metadata: key must be a string", 2) - end - - local resp = __ensure_response() - resp.metadata = resp.metadata or {} - resp.metadata[key] = value -end - --- Generic HTTP request function -function http_request(method, url, body, options) - if type(method) ~= "string" then - error("http_request: method must be a string", 2) - end - if type(url) ~= "string" then - error("http_request: url must be a string", 2) - end - - -- Call native implementation - local result = __http_request(method, url, body, options) - return result -end - --- Shorthand function to directly get JSON -function http_get_json(url, options) - options = options or {} - local response = http_get(url, options) - if response.ok and response.json then - return response.json - end - return nil, response -end - --- Utility to build a URL with query parameters -function http_build_url(base_url, params) - if not params or type(params) ~= "table" then - return base_url - end - - local query = {} - for k, v in pairs(params) do - if type(v) == "table" then - for _, item in ipairs(v) do - table.insert(query, url_encode(k) .. "=" .. url_encode(tostring(item))) - end - else - table.insert(query, url_encode(k) .. "=" .. url_encode(tostring(v))) - end - end - - if #query > 0 then - if string.contains(base_url, "?") then - return base_url .. "&" .. table.concat(query, "&") - else - return base_url .. "?" .. table.concat(query, "&") - end - end - - return base_url -end - -local function make_method(method, needs_body) - return function(url, body_or_options, options) - if needs_body then - options = options or {} - return http_request(method, url, body_or_options, options) - else - body_or_options = body_or_options or {} - return http_request(method, url, nil, body_or_options) - end - end -end - -http_get = make_method("GET", false) -http_delete = make_method("DELETE", false) -http_head = make_method("HEAD", false) -http_options = make_method("OPTIONS", false) -http_post = make_method("POST", true) -http_put = make_method("PUT", true) -http_patch = make_method("PATCH", true) - -function http_redirect(url, status) - if type(url) ~= "string" then - error("http_redirect: url must be a string", 2) - end - - status = status or 302 -- Default to temporary redirect - - local resp = __ensure_response() - resp.status = status - - resp.headers = resp.headers or {} - resp.headers["Location"] = url - - exit() -end - --- ====================================================================== --- COOKIE FUNCTIONS --- ====================================================================== - --- Set a cookie -function cookie_set(name, value, options) - if type(name) ~= "string" then - error("cookie_set: name must be a string", 2) - end - - local resp = __ensure_response() - resp.cookies = resp.cookies or {} - - local opts = options or {} - local cookie = { - name = name, - value = value or "", - path = opts.path or "/", - domain = opts.domain - } - - if opts.expires then - if type(opts.expires) == "number" then - if opts.expires > 0 then - cookie.max_age = opts.expires - local now = os.time() - cookie.expires = now + opts.expires - elseif opts.expires < 0 then - cookie.expires = 1 - cookie.max_age = 0 - end - -- opts.expires == 0: Session cookie (omitting both expires and max-age) - end - end - - cookie.secure = (opts.secure ~= false) - cookie.http_only = (opts.http_only ~= false) - - if opts.same_site then - local same_site = string.trim(opts.same_site):lower() - local valid_values = {none = true, lax = true, strict = true} - - if not valid_values[same_site] then - error("cookie_set: same_site must be one of 'None', 'Lax', or 'Strict'", 2) - end - - -- If SameSite=None, the cookie must be secure - if same_site == "none" and not cookie.secure then - cookie.secure = true - end - - cookie.same_site = opts.same_site - else - cookie.same_site = "Lax" - end - - table.insert(resp.cookies, cookie) - return true -end - --- Get a cookie value -function cookie_get(name) - if type(name) ~= "string" then - error("cookie_get: name must be a string", 2) - end - - if _current_ctx then - if _current_ctx.cookies then - return _current_ctx.cookies[name] - end - if _current_ctx._request_cookies then - return _current_ctx._request_cookies[name] - end - end - - return nil -end - --- Remove a cookie -function cookie_remove(name, path, domain) - if type(name) ~= "string" then - error("cookie_remove: name must be a string", 2) - end - - return cookie_set(name, "", {expires = 0, path = path or "/", domain = domain}) -end - --- ====================================================================== --- SESSION FUNCTIONS --- ====================================================================== - -function session_get(key) - if type(key) ~= "string" then - error("session_get: key must be a string", 2) - end - - if _current_ctx and _current_ctx.session and _current_ctx.session.data then - return _current_ctx.session.data[key] - end - - return nil -end - -function session_set(key, value) - if type(key) ~= "string" then - error("session_set: key must be a string", 2) - end - if value == nil then - error("session_set: value cannot be nil", 2) - end - - local resp = __ensure_response() - resp.session = resp.session or {} - resp.session[key] = value - - -- Update current context session data - if _current_ctx and _current_ctx.session and _current_ctx.session.data then - _current_ctx.session.data[key] = value - end -end - -function session_id() - if _current_ctx and _current_ctx.session then - return _current_ctx.session.id - end - return nil -end - -function session_get_all() - if _current_ctx and _current_ctx.session and _current_ctx.session.data then - -- Return a copy to prevent modification - local copy = {} - for k, v in pairs(_current_ctx.session.data) do - copy[k] = v - end - return copy - end - return {} -end - -function session_delete(key) - if type(key) ~= "string" then - error("session_delete: key must be a string", 2) - end - - local resp = __ensure_response() - resp.session = resp.session or {} - resp.session[key] = "__SESSION_DELETE_MARKER__" - - -- Update current context - if _current_ctx and _current_ctx.session and _current_ctx.session.data then - _current_ctx.session.data[key] = nil - end -end - -function session_clear() - if _current_ctx and _current_ctx.session and _current_ctx.session.data then - for k, _ in pairs(_current_ctx.session.data) do - _current_ctx.session.data[k] = nil - end - end - - local resp = __ensure_response() - resp.session = {} - resp.session["__clear_all"] = true -end - --- ====================================================================== --- CSRF FUNCTIONS --- ====================================================================== - -function csrf_generate() - local token = generate_token(32) - session_set("_csrf_token", token) - return token -end - -function csrf_field() - local token = session_get("_csrf_token") - if not token then - token = csrf_generate() - end - return string.format('', - html_special_chars(token)) -end - -function csrf_validate() - local token = session_get("_csrf_token") - - if not token then - http_set_status(403) - __http_response.body = "CSRF validation failed" - exit() - end - - local request_token = nil - if _current_ctx and _current_ctx.form then - request_token = _current_ctx.form._csrf_token - end - - if not request_token and _current_ctx and _current_ctx._request_headers then - request_token = _current_ctx._request_headers["x-csrf-token"] or - _current_ctx._request_headers["csrf-token"] - end - - if not request_token or request_token ~= token then - http_set_status(403) - __http_response.body = "CSRF validation failed" - exit() - end - - return true -end - --- ====================================================================== --- TEMPLATE RENDER FUNCTIONS --- ====================================================================== - --- Template processing with code execution -_G.render = function(template_str, env) - local function get_line(s, ln) - for line in s:gmatch("([^\n]*)\n?") do - if ln == 1 then return line end - ln = ln - 1 - end - end - - local function pos_to_line(s, pos) - local line = 1 - for _ in s:sub(1, pos):gmatch("\n") do line = line + 1 end - return line - end - - local pos, chunks = 1, {} - while pos <= #template_str do - local unescaped_start = template_str:find("{{{", pos, true) - local escaped_start = template_str:find("{{", pos, true) - - local start, tag_type, open_len - if unescaped_start and (not escaped_start or unescaped_start <= escaped_start) then - start, tag_type, open_len = unescaped_start, "-", 3 - elseif escaped_start then - start, tag_type, open_len = escaped_start, "=", 2 - else - table.insert(chunks, template_str:sub(pos)) - break - end - - if start > pos then - table.insert(chunks, template_str:sub(pos, start-1)) - end - - pos = start + open_len - local close_tag = tag_type == "-" and "}}}" or "}}" - local close_start, close_stop = template_str:find(close_tag, pos, true) - if not close_start then - error("Failed to find closing tag at position " .. pos) - end - - local code = template_str:sub(pos, close_start-1):match("^%s*(.-)%s*$") - - -- Check if it's a simple variable name for escaped output - local is_simple_var = tag_type == "=" and code:match("^[%w_]+$") - - table.insert(chunks, {tag_type, code, pos, is_simple_var}) - pos = close_stop + 1 - end - - local buffer = {"local _tostring, _escape, _b, _b_i = ...\n"} - for _, chunk in ipairs(chunks) do - local t = type(chunk) - if t == "string" then - table.insert(buffer, "_b_i = _b_i + 1\n") - table.insert(buffer, "_b[_b_i] = " .. string.format("%q", chunk) .. "\n") - else - t = chunk[1] - if t == "=" then - if chunk[4] then -- is_simple_var - table.insert(buffer, "_b_i = _b_i + 1\n") - table.insert(buffer, "--[[" .. chunk[3] .. "]] _b[_b_i] = _escape(_tostring(" .. chunk[2] .. "))\n") - else - table.insert(buffer, "--[[" .. chunk[3] .. "]] " .. chunk[2] .. "\n") - end - elseif t == "-" then - table.insert(buffer, "_b_i = _b_i + 1\n") - table.insert(buffer, "--[[" .. chunk[3] .. "]] _b[_b_i] = _tostring(" .. chunk[2] .. ")\n") - end - end - end - table.insert(buffer, "return _b") - - local fn, err = loadstring(table.concat(buffer)) - if not fn then error(err) end - - env = env or {} - local runtime_env = setmetatable({}, {__index = function(_, k) return env[k] or _G[k] end}) - setfenv(fn, runtime_env) - - local output_buffer = {} - fn(tostring, html_special_chars, output_buffer, 0) - return table.concat(output_buffer) -end - --- Named placeholder processing -_G.parse = function(template_str, env) - local pos, output = 1, {} - env = env or {} - - while pos <= #template_str do - local unescaped_start, unescaped_end, unescaped_name = template_str:find("{{{%s*([%w_]+)%s*}}}", pos) - local escaped_start, escaped_end, escaped_name = template_str:find("{{%s*([%w_]+)%s*}}", pos) - - local next_pos, placeholder_end, name, escaped - if unescaped_start and (not escaped_start or unescaped_start <= escaped_start) then - next_pos, placeholder_end, name, escaped = unescaped_start, unescaped_end, unescaped_name, false - elseif escaped_start then - next_pos, placeholder_end, name, escaped = escaped_start, escaped_end, escaped_name, true - else - local text = template_str:sub(pos) - if text and #text > 0 then - table.insert(output, text) - end - break - end - - local text = template_str:sub(pos, next_pos - 1) - if text and #text > 0 then - table.insert(output, text) - end - - local value = env[name] - local str = tostring(value or "") - if escaped then - str = html_special_chars(str) - end - table.insert(output, str) - - pos = placeholder_end + 1 - end - - return table.concat(output) -end - --- Indexed placeholder processing -_G.iparse = function(template_str, values) - local pos, output, value_index = 1, {}, 1 - values = values or {} - - while pos <= #template_str do - local unescaped_start, unescaped_end = template_str:find("{{{}}}", pos, true) - local escaped_start, escaped_end = template_str:find("{{}}", pos, true) - - local next_pos, placeholder_end, escaped - if unescaped_start and (not escaped_start or unescaped_start <= escaped_start) then - next_pos, placeholder_end, escaped = unescaped_start, unescaped_end, false - elseif escaped_start then - next_pos, placeholder_end, escaped = escaped_start, escaped_end, true - else - local text = template_str:sub(pos) - if text and #text > 0 then - table.insert(output, text) - end - break - end - - local text = template_str:sub(pos, next_pos - 1) - if text and #text > 0 then - table.insert(output, text) - end - - local value = values[value_index] - local str = tostring(value or "") - if escaped then - str = html_special_chars(str) - end - table.insert(output, str) - - pos = placeholder_end + 1 - value_index = value_index + 1 - end - - return table.concat(output) -end - --- ====================================================================== --- PASSWORD FUNCTIONS --- ====================================================================== - --- Hash a password using Argon2id --- Options: --- memory: Amount of memory to use in KB (default: 128MB) --- iterations: Number of iterations (default: 4) --- parallelism: Number of threads (default: 4) --- salt_length: Length of salt in bytes (default: 16) --- key_length: Length of the derived key in bytes (default: 32) -function password_hash(plain_password, options) - if type(plain_password) ~= "string" then - error("password_hash: expected string password", 2) - end - - return __password_hash(plain_password, options) -end - --- Verify a password against a hash -function password_verify(plain_password, hash_string) - if type(plain_password) ~= "string" then - error("password_verify: expected string password", 2) - end - - if type(hash_string) ~= "string" then - error("password_verify: expected string hash", 2) - end - - return __password_verify(plain_password, hash_string) -end - --- ====================================================================== --- SEND FUNCTIONS --- ====================================================================== - -function send_html(content) - http_set_content_type("text/html") - return content -end - -function send_json(content) - http_set_content_type("application/json") - return content -end - -function send_text(content) - http_set_content_type("text/plain") - return content -end - -function send_xml(content) - http_set_content_type("application/xml") - return content -end - -function send_javascript(content) - http_set_content_type("application/javascript") - return content -end - -function send_css(content) - http_set_content_type("text/css") - return content -end - -function send_svg(content) - http_set_content_type("image/svg+xml") - return content -end - -function send_csv(content) - http_set_content_type("text/csv") - return content -end - -function send_binary(content, mime_type) - http_set_content_type(mime_type or "application/octet-stream") - return content end \ No newline at end of file diff --git a/runner/sandbox.go b/runner/sandbox.go index ce6a6b2..04d0a1a 100644 --- a/runner/sandbox.go +++ b/runner/sandbox.go @@ -2,114 +2,38 @@ package runner import ( "fmt" - "sync" - - "github.com/valyala/fasthttp" - - "Moonshark/runner/lualibs" - "Moonshark/utils/logger" - - "maps" luajit "git.sharkk.net/Sky/LuaJIT-to-Go" -) - -// Error represents a simple error string -type Error string - -func (e Error) Error() string { - return string(e) -} - -// Error types -var ( - ErrSandboxNotInitialized = Error("sandbox not initialized") + "github.com/valyala/fasthttp" ) // Sandbox provides a secure execution environment for Lua scripts type Sandbox struct { - modules map[string]any - debug bool - mu sync.RWMutex + executorBytecode []byte } // NewSandbox creates a new sandbox environment func NewSandbox() *Sandbox { - return &Sandbox{ - modules: make(map[string]any, 8), - debug: false, - } -} - -// AddModule adds a module to the sandbox environment -func (s *Sandbox) AddModule(name string, module any) { - s.mu.Lock() - defer s.mu.Unlock() - s.modules[name] = module - logger.Debugf("Added module: %s", name) + return &Sandbox{} } // Setup initializes the sandbox in a Lua state func (s *Sandbox) Setup(state *luajit.State, verbose bool) error { - if verbose { - logger.Debugf("Setting up sandbox...") - } - + // Load all embedded modules and sandbox if err := loadSandboxIntoState(state, verbose); err != nil { - logger.Errorf("Failed to load sandbox: %v", err) - return err + return fmt.Errorf("failed to load sandbox: %w", err) } + // Pre-compile the executor function for reuse + executorCode := `return __execute` + bytecode, err := state.CompileBytecode(executorCode, "executor") + if err != nil { + return fmt.Errorf("failed to compile executor: %w", err) + } + s.executorBytecode = bytecode + + // Register native functions if err := s.registerCoreFunctions(state); err != nil { - logger.Errorf("Failed to register core functions: %v", err) - return err - } - - s.mu.RLock() - for name, module := range s.modules { - logger.Debugf("Registering module: %s", name) - if err := state.PushValue(module); err != nil { - s.mu.RUnlock() - logger.Errorf("Failed to register module %s: %v", name, err) - return err - } - state.SetGlobal(name) - } - s.mu.RUnlock() - - if verbose { - logger.Debugf("Sandbox setup complete") - } - return nil -} - -// registerCoreFunctions registers all built-in functions in the Lua state -func (s *Sandbox) registerCoreFunctions(state *luajit.State) error { - if err := lualibs.RegisterHttpFunctions(state); err != nil { - return err - } - - if err := lualibs.RegisterSQLiteFunctions(state); err != nil { - return err - } - - if err := lualibs.RegisterFSFunctions(state); err != nil { - return err - } - - if err := lualibs.RegisterPasswordFunctions(state); err != nil { - return err - } - - if err := lualibs.RegisterUtilFunctions(state); err != nil { - return err - } - - if err := lualibs.RegisterCryptoFunctions(state); err != nil { - return err - } - - if err := lualibs.RegisterEnvFunctions(state); err != nil { return err } @@ -118,123 +42,145 @@ func (s *Sandbox) registerCoreFunctions(state *luajit.State) error { // Execute runs a Lua script in the sandbox with the given context func (s *Sandbox) Execute(state *luajit.State, bytecode []byte, ctx *Context) (*Response, error) { - // Load bytecode - pushes function onto stack + // Create response object in Lua + response := map[string]any{ + "status": 200, + "headers": make(map[string]string), + "cookies": []any{}, + "metadata": make(map[string]any), + "session": make(map[string]any), + } + + // Load script bytecode (pushes function) if err := state.LoadBytecode(bytecode, "script"); err != nil { return nil, fmt.Errorf("failed to load bytecode: %w", err) } - // Stack: [function] - state.GetGlobal("__execute_script") // Stack: [function, __execute_script] - state.PushCopy(-2) // Stack: [function, __execute_script, function] + // Load executor (pushes __execute function) + if err := state.LoadBytecode(s.executorBytecode, "executor"); err != nil { + state.Pop(1) // Remove script function + return nil, fmt.Errorf("failed to load executor: %w", err) + } - // Push context using PushValue + // Call the loaded executor to get __execute + if err := state.Call(0, 1); err != nil { + state.Pop(1) // Remove script function + return nil, fmt.Errorf("failed to get executor: %w", err) + } + + // Stack: [script_func, __execute] + state.PushCopy(-2) // Copy script function + // Stack: [script_func, __execute, script_func] + + // Push context if err := state.PushValue(ctx.Values); err != nil { state.Pop(3) return nil, fmt.Errorf("failed to push context: %w", err) } - // Stack: [function, __execute_script, function, context] - // Call __execute_script(function, context) - if err := state.Call(2, 1); err != nil { - state.Pop(1) // Clean up original function + // Push response object + if err := state.PushValue(response); err != nil { + state.Pop(4) + return nil, fmt.Errorf("failed to push response: %w", err) + } + + // Stack: [script_func, __execute, script_func, ctx, response] + // Call __execute(script_func, ctx, response) + if err := state.Call(3, 1); err != nil { + state.Pop(1) // Clean up return nil, fmt.Errorf("script execution failed: %w", err) } - // Stack: [function, result] - response := NewResponse() - if result, err := state.ToValue(-1); err == nil { - response.Body = result - } + // Get the result + result, _ := state.ToValue(-1) + state.Pop(2) // Remove result and original script function - state.SetTop(0) // Clear stack - - extractHTTPResponseData(state, response) - return response, nil + // Extract response data directly from the response object we passed + return s.buildResponse(response, result), nil } -// extractResponseData pulls response info from the Lua state using new API -func extractHTTPResponseData(state *luajit.State, response *Response) { - state.GetGlobal("__http_response") - if !state.IsTable(-1) { - state.Pop(1) - return +// buildResponse converts the Lua response object to a Go Response +func (s *Sandbox) buildResponse(luaResp map[string]any, body any) *Response { + resp := NewResponse() + resp.Body = body + + // Extract status + if status, ok := luaResp["status"].(float64); ok { + resp.Status = int(status) + } else if status, ok := luaResp["status"].(int); ok { + resp.Status = status } - // Use new field getters with defaults - response.Status = int(state.GetFieldNumber(-1, "status", 200)) - - // Extract headers using ForEachTableKV - if headerTable, ok := state.GetFieldTable(-1, "headers"); ok { - switch headers := headerTable.(type) { - case map[string]any: - for k, v := range headers { - if str, ok := v.(string); ok { - response.Headers[k] = str - } + // Extract headers + if headers, ok := luaResp["headers"].(map[string]any); ok { + for k, v := range headers { + if str, ok := v.(string); ok { + resp.Headers[k] = str } - case map[string]string: - maps.Copy(response.Headers, headers) + } + } else if headers, ok := luaResp["headers"].(map[string]string); ok { + for k, v := range headers { + resp.Headers[k] = v } } - // Extract cookies using ForEachArray - state.GetField(-1, "cookies") - if state.IsTable(-1) { - state.ForEachArray(-1, func(i int, s *luajit.State) bool { - if s.IsTable(-1) { - extractCookie(s, response) + // Extract cookies + if cookies, ok := luaResp["cookies"].([]any); ok { + for _, cookieData := range cookies { + if cookieMap, ok := cookieData.(map[string]any); ok { + cookie := fasthttp.AcquireCookie() + + if name, ok := cookieMap["name"].(string); ok && name != "" { + cookie.SetKey(name) + if value, ok := cookieMap["value"].(string); ok { + cookie.SetValue(value) + } + if path, ok := cookieMap["path"].(string); ok { + cookie.SetPath(path) + } + if domain, ok := cookieMap["domain"].(string); ok { + cookie.SetDomain(domain) + } + if httpOnly, ok := cookieMap["http_only"].(bool); ok { + cookie.SetHTTPOnly(httpOnly) + } + if secure, ok := cookieMap["secure"].(bool); ok { + cookie.SetSecure(secure) + } + if maxAge, ok := cookieMap["max_age"].(float64); ok { + cookie.SetMaxAge(int(maxAge)) + } else if maxAge, ok := cookieMap["max_age"].(int); ok { + cookie.SetMaxAge(maxAge) + } + + resp.Cookies = append(resp.Cookies, cookie) + } else { + fasthttp.ReleaseCookie(cookie) + } } - return true - }) + } } - state.Pop(1) // Extract metadata - if metadata, ok := state.GetFieldTable(-1, "metadata"); ok { - if metaMap, ok := metadata.(map[string]any); ok { - maps.Copy(response.Metadata, metaMap) + if metadata, ok := luaResp["metadata"].(map[string]any); ok { + for k, v := range metadata { + resp.Metadata[k] = v } } // Extract session data - if session, ok := state.GetFieldTable(-1, "session"); ok { - switch sessMap := session.(type) { - case map[string]any: - maps.Copy(response.SessionData, sessMap) - case map[string]string: - for k, v := range sessMap { - response.SessionData[k] = v - } - case map[string]int: - for k, v := range sessMap { - response.SessionData[k] = v - } - default: - logger.Debugf("Unexpected session type: %T", session) + if session, ok := luaResp["session"].(map[string]any); ok { + for k, v := range session { + resp.SessionData[k] = v } } - state.Pop(1) // Pop __http_response + return resp } -// extractCookie pulls cookie data from the current table on the stack using new API -func extractCookie(state *luajit.State, response *Response) { - cookie := fasthttp.AcquireCookie() - - // Use new field getters with defaults - name := state.GetFieldString(-1, "name", "") - if name == "" { - fasthttp.ReleaseCookie(cookie) - return - } - - cookie.SetKey(name) - cookie.SetValue(state.GetFieldString(-1, "value", "")) - cookie.SetPath(state.GetFieldString(-1, "path", "/")) - cookie.SetDomain(state.GetFieldString(-1, "domain", "")) - cookie.SetHTTPOnly(state.GetFieldBool(-1, "http_only", false)) - cookie.SetSecure(state.GetFieldBool(-1, "secure", false)) - cookie.SetMaxAge(int(state.GetFieldNumber(-1, "max_age", 0))) - - response.Cookies = append(response.Cookies, cookie) +// registerCoreFunctions registers all built-in functions in the Lua state +func (s *Sandbox) registerCoreFunctions(state *luajit.State) error { + // Register your native functions here + // This stays the same as your current implementation + return nil }