optimize session set, move lua libs to global

This commit is contained in:
Sky Johnson 2025-05-28 18:28:24 -05:00
parent 6264407d02
commit 39d14d0025
7 changed files with 807 additions and 777 deletions

View File

@ -46,18 +46,19 @@ type ModuleInfo struct {
Code string // Module source code Code string // Module source code
Bytecode atomic.Pointer[[]byte] // Cached bytecode Bytecode atomic.Pointer[[]byte] // Cached bytecode
Once sync.Once // For one-time compilation Once sync.Once // For one-time compilation
DefinesGlobal bool // Whether module defines globals directly
} }
var ( var (
sandbox = ModuleInfo{Name: "sandbox", Code: sandboxLuaCode} sandbox = ModuleInfo{Name: "sandbox", Code: sandboxLuaCode}
modules = []ModuleInfo{ modules = []ModuleInfo{
{Name: "json", Code: jsonLuaCode}, {Name: "json", Code: jsonLuaCode, DefinesGlobal: true},
{Name: "sqlite", Code: sqliteLuaCode}, {Name: "sqlite", Code: sqliteLuaCode},
{Name: "fs", Code: fsLuaCode}, {Name: "fs", Code: fsLuaCode, DefinesGlobal: true},
{Name: "util", Code: utilLuaCode}, {Name: "util", Code: utilLuaCode, DefinesGlobal: true},
{Name: "string", Code: stringLuaCode}, {Name: "string", Code: stringLuaCode},
{Name: "table", Code: tableLuaCode}, {Name: "table", Code: tableLuaCode},
{Name: "crypto", Code: cryptoLuaCode}, {Name: "crypto", Code: cryptoLuaCode, DefinesGlobal: true},
{Name: "time", Code: timeLuaCode}, {Name: "time", Code: timeLuaCode},
{Name: "math", Code: mathLuaCode}, {Name: "math", Code: mathLuaCode},
} }
@ -104,11 +105,18 @@ func loadModule(state *luajit.State, m *ModuleInfo, verbose bool) error {
return err return err
} }
if m.DefinesGlobal {
// Module defines its own globals, just run it
if err := state.RunBytecode(); err != nil {
return err
}
} else {
// Module returns a table, capture and set as global
if err := state.RunBytecodeWithResults(1); err != nil { if err := state.RunBytecodeWithResults(1); err != nil {
return err return err
} }
state.SetGlobal(m.Name) state.SetGlobal(m.Name)
}
} else { } else {
// Fallback to interpreting the source // Fallback to interpreting the source
if verbose { if verbose {

View File

@ -2,8 +2,6 @@
crypto.lua - Cryptographic functions powered by Go crypto.lua - Cryptographic functions powered by Go
]]-- ]]--
local crypto = {}
-- ====================================================================== -- ======================================================================
-- HASHING FUNCTIONS -- HASHING FUNCTIONS
-- ====================================================================== -- ======================================================================
@ -11,9 +9,9 @@ local crypto = {}
-- Generate hash digest using various algorithms -- Generate hash digest using various algorithms
-- Algorithms: md5, sha1, sha256, sha512 -- Algorithms: md5, sha1, sha256, sha512
-- Formats: hex (default), binary -- Formats: hex (default), binary
function crypto.hash(data, algorithm, format) function hash(data, algorithm, format)
if type(data) ~= "string" then if type(data) ~= "string" then
error("crypto.hash: data must be a string", 2) error("hash: data must be a string", 2)
end end
algorithm = algorithm or "sha256" algorithm = algorithm or "sha256"
@ -22,21 +20,20 @@ function crypto.hash(data, algorithm, format)
return __crypto_hash(data, algorithm, format) return __crypto_hash(data, algorithm, format)
end end
-- Convenience functions for common hash algorithms function md5(data, format)
function crypto.md5(data, format) return hash(data, "md5", format)
return crypto.hash(data, "md5", format)
end end
function crypto.sha1(data, format) function sha1(data, format)
return crypto.hash(data, "sha1", format) return hash(data, "sha1", format)
end end
function crypto.sha256(data, format) function sha256(data, format)
return crypto.hash(data, "sha256", format) return hash(data, "sha256", format)
end end
function crypto.sha512(data, format) function sha512(data, format)
return crypto.hash(data, "sha512", format) return hash(data, "sha512", format)
end end
-- ====================================================================== -- ======================================================================
@ -46,13 +43,13 @@ end
-- Generate HMAC using various algorithms -- Generate HMAC using various algorithms
-- Algorithms: md5, sha1, sha256, sha512 -- Algorithms: md5, sha1, sha256, sha512
-- Formats: hex (default), binary -- Formats: hex (default), binary
function crypto.hmac(data, key, algorithm, format) function hmac(data, key, algorithm, format)
if type(data) ~= "string" then if type(data) ~= "string" then
error("crypto.hmac: data must be a string", 2) error("hmac: data must be a string", 2)
end end
if type(key) ~= "string" then if type(key) ~= "string" then
error("crypto.hmac: key must be a string", 2) error("hmac: key must be a string", 2)
end end
algorithm = algorithm or "sha256" algorithm = algorithm or "sha256"
@ -61,21 +58,20 @@ function crypto.hmac(data, key, algorithm, format)
return __crypto_hmac(data, key, algorithm, format) return __crypto_hmac(data, key, algorithm, format)
end end
-- Convenience functions for common HMAC algorithms function hmac_md5(data, key, format)
function crypto.hmac_md5(data, key, format) return hmac(data, key, "md5", format)
return crypto.hmac(data, key, "md5", format)
end end
function crypto.hmac_sha1(data, key, format) function hmac_sha1(data, key, format)
return crypto.hmac(data, key, "sha1", format) return hmac(data, key, "sha1", format)
end end
function crypto.hmac_sha256(data, key, format) function hmac_sha256(data, key, format)
return crypto.hmac(data, key, "sha256", format) return hmac(data, key, "sha256", format)
end end
function crypto.hmac_sha512(data, key, format) function hmac_sha512(data, key, format)
return crypto.hmac(data, key, "sha512", format) return hmac(data, key, "sha512", format)
end end
-- ====================================================================== -- ======================================================================
@ -84,9 +80,9 @@ end
-- Generate random bytes -- Generate random bytes
-- Formats: binary (default), hex -- Formats: binary (default), hex
function crypto.random_bytes(length, secure, format) function random_bytes(length, secure, format)
if type(length) ~= "number" or length <= 0 then if type(length) ~= "number" or length <= 0 then
error("crypto.random_bytes: length must be positive", 2) error("random_bytes: length must be positive", 2)
end end
secure = secure ~= false -- Default to secure secure = secure ~= false -- Default to secure
@ -96,13 +92,13 @@ function crypto.random_bytes(length, secure, format)
end end
-- Generate random integer in range [min, max] -- Generate random integer in range [min, max]
function crypto.random_int(min, max, secure) function random_int(min, max, secure)
if type(min) ~= "number" or type(max) ~= "number" then if type(min) ~= "number" or type(max) ~= "number" then
error("crypto.random_int: min and max must be numbers", 2) error("random_int: min and max must be numbers", 2)
end end
if max <= min then if max <= min then
error("crypto.random_int: max must be greater than min", 2) error("random_int: max must be greater than min", 2)
end end
secure = secure ~= false -- Default to secure secure = secure ~= false -- Default to secure
@ -111,9 +107,9 @@ function crypto.random_int(min, max, secure)
end end
-- Generate random string of specified length -- Generate random string of specified length
function crypto.random_string(length, charset, secure) function random_string(length, charset, secure)
if type(length) ~= "number" or length <= 0 then if type(length) ~= "number" or length <= 0 then
error("crypto.random_string: length must be positive", 2) error("random_string: length must be positive", 2)
end end
secure = secure ~= false -- Default to secure secure = secure ~= false -- Default to secure
@ -122,14 +118,14 @@ function crypto.random_string(length, charset, secure)
charset = charset or "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" charset = charset or "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
if type(charset) ~= "string" or #charset == 0 then if type(charset) ~= "string" or #charset == 0 then
error("crypto.random_string: charset must be non-empty", 2) error("random_string: charset must be non-empty", 2)
end end
local result = "" local result = ""
local charset_length = #charset local charset_length = #charset
for i = 1, length do for i = 1, length do
local index = crypto.random_int(1, charset_length, secure) local index = random_int(1, charset_length, secure)
result = result .. charset:sub(index, index) result = result .. charset:sub(index, index)
end end
@ -141,8 +137,6 @@ end
-- ====================================================================== -- ======================================================================
-- Generate random UUID (v4) -- Generate random UUID (v4)
function crypto.uuid() function uuid()
return __crypto_uuid() return __crypto_uuid()
end end
return crypto

View File

@ -1,49 +1,47 @@
local fs = {} function fs_read(path)
fs.read = function(path)
if type(path) ~= "string" then if type(path) ~= "string" then
error("fs.read: path must be a string", 2) error("fs_read: path must be a string", 2)
end end
return __fs_read_file(path) return __fs_read_file(path)
end end
fs.write = function(path, content) function fs_write(path, content)
if type(path) ~= "string" then if type(path) ~= "string" then
error("fs.write: path must be a string", 2) error("fs_write: path must be a string", 2)
end end
if type(content) ~= "string" then if type(content) ~= "string" then
error("fs.write: content must be a string", 2) error("fs_write: content must be a string", 2)
end end
return __fs_write_file(path, content) return __fs_write_file(path, content)
end end
fs.append = function(path, content) function fs_append(path, content)
if type(path) ~= "string" then if type(path) ~= "string" then
error("fs.append: path must be a string", 2) error("fs_append: path must be a string", 2)
end end
if type(content) ~= "string" then if type(content) ~= "string" then
error("fs.append: content must be a string", 2) error("fs_append: content must be a string", 2)
end end
return __fs_append_file(path, content) return __fs_append_file(path, content)
end end
fs.exists = function(path) function fs_exists(path)
if type(path) ~= "string" then if type(path) ~= "string" then
error("fs.exists: path must be a string", 2) error("fs_exists: path must be a string", 2)
end end
return __fs_exists(path) return __fs_exists(path)
end end
fs.remove = function(path) function fs_remove(path)
if type(path) ~= "string" then if type(path) ~= "string" then
error("fs.remove: path must be a string", 2) error("fs_remove: path must be a string", 2)
end end
return __fs_remove_file(path) return __fs_remove_file(path)
end end
fs.info = function(path) function fs_info(path)
if type(path) ~= "string" then if type(path) ~= "string" then
error("fs.info: path must be a string", 2) error("fs_info: path must be a string", 2)
end end
local info = __fs_get_info(path) local info = __fs_get_info(path)
@ -56,58 +54,58 @@ fs.info = function(path)
end end
-- Directory Operations -- Directory Operations
fs.mkdir = function(path, mode) function fs_mkdir(path, mode)
if type(path) ~= "string" then if type(path) ~= "string" then
error("fs.mkdir: path must be a string", 2) error("fs_mkdir: path must be a string", 2)
end end
mode = mode or 0755 mode = mode or 0755
return __fs_make_dir(path, mode) return __fs_make_dir(path, mode)
end end
fs.ls = function(path) function fs_ls(path)
if type(path) ~= "string" then if type(path) ~= "string" then
error("fs.ls: path must be a string", 2) error("fs_ls: path must be a string", 2)
end end
return __fs_list_dir(path) return __fs_list_dir(path)
end end
fs.rmdir = function(path, recursive) function fs_rmdir(path, recursive)
if type(path) ~= "string" then if type(path) ~= "string" then
error("fs.rmdir: path must be a string", 2) error("fs_rmdir: path must be a string", 2)
end end
recursive = recursive or false recursive = recursive or false
return __fs_remove_dir(path, recursive) return __fs_remove_dir(path, recursive)
end end
-- Path Operations -- Path Operations
fs.join_paths = function(...) function fs_join_paths(...)
return __fs_join_paths(...) return __fs_join_paths(...)
end end
fs.dir_name = function(path) function fs_dir_name(path)
if type(path) ~= "string" then if type(path) ~= "string" then
error("fs.dir_name: path must be a string", 2) error("fs_dir_name: path must be a string", 2)
end end
return __fs_dir_name(path) return __fs_dir_name(path)
end end
fs.base_name = function(path) function fs_base_name(path)
if type(path) ~= "string" then if type(path) ~= "string" then
error("fs.base_name: path must be a string", 2) error("fs_base_name: path must be a string", 2)
end end
return __fs_base_name(path) return __fs_base_name(path)
end end
fs.extension = function(path) function fs_extension(path)
if type(path) ~= "string" then if type(path) ~= "string" then
error("fs.extension: path must be a string", 2) error("fs_extension: path must be a string", 2)
end end
return __fs_extension(path) return __fs_extension(path)
end end
-- Utility Functions -- Utility Functions
fs.read_json = function(path) function fs_read_json(path)
local content = fs.read_file(path) local content = fs_read(path)
if not content then if not content then
return nil, "Could not read file" return nil, "Could not read file"
end end
@ -120,9 +118,9 @@ fs.read_json = function(path)
return result return result
end end
fs.write_json = function(path, data, pretty) function fs_write_json(path, data, pretty)
if type(data) ~= "table" then if type(data) ~= "table" then
error("fs.write_json: data must be a table", 2) error("fs_write_json: data must be a table", 2)
end end
local content local content
@ -132,7 +130,5 @@ fs.write_json = function(path, data, pretty)
content = json.encode(data) content = json.encode(data)
end end
return fs.write_file(path, content) return fs_write(path, content)
end end
return fs

View File

@ -1,18 +1,23 @@
-- json.lua: High-performance JSON module for Moonshark -- json.lua: High-performance JSON module for Moonshark
local json = {}
function json.go_encode(value) -- Pre-computed escape sequences to avoid recreating table
local escape_chars = {
['"'] = '\\"', ['\\'] = '\\\\',
['\n'] = '\\n', ['\r'] = '\\r', ['\t'] = '\\t'
}
function json_go_encode(value)
return __json_marshal(value) return __json_marshal(value)
end end
function json.go_decode(str) function json_go_decode(str)
if type(str) ~= "string" then if type(str) ~= "string" then
error("json.decode: expected string, got " .. type(str), 2) error("json_decode: expected string, got " .. type(str), 2)
end end
return __json_unmarshal(str) return __json_unmarshal(str)
end end
function json.encode(data) function json_encode(data)
local t = type(data) local t = type(data)
if t == "nil" then return "null" end if t == "nil" then return "null" end
@ -20,48 +25,34 @@ function json.encode(data)
if t == "number" then return tostring(data) end if t == "number" then return tostring(data) end
if t == "string" then if t == "string" then
local escape_chars = {
['"'] = '\\"', ['\\'] = '\\\\',
['\n'] = '\\n', ['\r'] = '\\r', ['\t'] = '\\t'
}
return '"' .. data:gsub('[\\"\n\r\t]', escape_chars) .. '"' return '"' .. data:gsub('[\\"\n\r\t]', escape_chars) .. '"'
end end
if t == "table" then if t == "table" then
local isArray = true local isArray = true
local count = 0 local count = 0
local max_index = 0
-- Check if it's an array in one pass
for k, _ in pairs(data) do for k, _ in pairs(data) do
count = count + 1 count = count + 1
if type(k) == "number" and k > 0 and math.floor(k) == k then if type(k) ~= "number" or k ~= count or k < 1 then
max_index = math.max(max_index, k)
else
isArray = false isArray = false
break break
end end
end end
local result = {}
if isArray then if isArray then
for i, v in ipairs(data) do local result = {}
result[i] = json.encode(v) for i = 1, count do
result[i] = json_encode(data[i])
end end
return "[" .. table.concat(result, ",") .. "]" return "[" .. table.concat(result, ",") .. "]"
else else
local size = 0 local result = {}
for k, v in pairs(data) do
if type(k) == "string" and type(v) ~= "function" and type(v) ~= "userdata" then
size = size + 1
end
end
result = {}
local index = 1 local index = 1
for k, v in pairs(data) do for k, v in pairs(data) do
if type(k) == "string" and type(v) ~= "function" and type(v) ~= "userdata" then if type(k) == "string" and type(v) ~= "function" and type(v) ~= "userdata" then
result[index] = json.encode(k) .. ":" .. json.encode(v) result[index] = json_encode(k) .. ":" .. json_encode(v)
index = index + 1 index = index + 1
end end
end end
@ -72,7 +63,7 @@ function json.encode(data)
return "null" -- Unsupported type return "null" -- Unsupported type
end end
function json.decode(data) function json_decode(data)
local pos = 1 local pos = 1
local len = #data local len = #data
@ -372,15 +363,15 @@ function json.decode(data)
return result return result
end end
function json.is_valid(str) function json_is_valid(str)
if type(str) ~= "string" then return false end if type(str) ~= "string" then return false end
local status, _ = pcall(json.decode, str) local status, _ = pcall(json_decode, str)
return status return status
end end
function json.pretty_print(value) function json_pretty_print(value)
if type(value) == "string" then if type(value) == "string" then
value = json.decode(value) value = json_decode(value)
end end
local function stringify(val, indent, visited) local function stringify(val, indent, visited)
@ -429,5 +420,3 @@ function json.pretty_print(value)
return stringify(value) return stringify(value)
end end
return json

View File

@ -60,75 +60,72 @@ function __ensure_response()
end end
-- ====================================================================== -- ======================================================================
-- HTTP MODULE -- HTTP FUNCTIONS
-- ====================================================================== -- ======================================================================
local http = {
-- Set HTTP status code -- Set HTTP status code
set_status = function(code) function http_set_status(code)
if type(code) ~= "number" then if type(code) ~= "number" then
error("http.set_status: status code must be a number", 2) error("http_set_status: status code must be a number", 2)
end end
local resp = __ensure_response() local resp = __ensure_response()
resp.status = code resp.status = code
end, end
-- Set HTTP header -- Set HTTP header
set_header = function(name, value) function http_set_header(name, value)
if type(name) ~= "string" or type(value) ~= "string" then if type(name) ~= "string" or type(value) ~= "string" then
error("http.set_header: name and value must be strings", 2) error("http_set_header: name and value must be strings", 2)
end end
local resp = __ensure_response() local resp = __ensure_response()
resp.headers = resp.headers or {} resp.headers = resp.headers or {}
resp.headers[name] = value resp.headers[name] = value
end, end
-- Set content type; set_header helper -- Set content type; http_set_header helper
set_content_type = function(content_type) function http_set_content_type(content_type)
http.set_header("Content-Type", content_type) http_set_header("Content-Type", content_type)
end, end
-- Set metadata (arbitrary data to be returned with response) -- Set metadata (arbitrary data to be returned with response)
set_metadata = function(key, value) function http_set_metadata(key, value)
if type(key) ~= "string" then if type(key) ~= "string" then
error("http.set_metadata: key must be a string", 2) error("http_set_metadata: key must be a string", 2)
end end
local resp = __ensure_response() local resp = __ensure_response()
resp.metadata = resp.metadata or {} resp.metadata = resp.metadata or {}
resp.metadata[key] = value resp.metadata[key] = value
end, end
-- HTTP client submodule -- Generic HTTP request function
client = { function http_request(method, url, body, options)
-- Generic request function
request = function(method, url, body, options)
if type(method) ~= "string" then if type(method) ~= "string" then
error("http.client.request: method must be a string", 2) error("http_request: method must be a string", 2)
end end
if type(url) ~= "string" then if type(url) ~= "string" then
error("http.client.request: url must be a string", 2) error("http_request: url must be a string", 2)
end end
-- Call native implementation -- Call native implementation
local result = __http_request(method, url, body, options) local result = __http_request(method, url, body, options)
return result return result
end, end
-- Shorthand function to directly get JSON -- Shorthand function to directly get JSON
get_json = function(url, options) function http_get_json(url, options)
options = options or {} options = options or {}
local response = http.client.get(url, options) local response = http_get(url, options)
if response.ok and response.json then if response.ok and response.json then
return response.json return response.json
end end
return nil, response return nil, response
end, end
-- Utility to build a URL with query parameters -- Utility to build a URL with query parameters
build_url = function(base_url, params) function http_build_url(base_url, params)
if not params or type(params) ~= "table" then if not params or type(params) ~= "table" then
return base_url return base_url
end end
@ -137,10 +134,10 @@ local http = {
for k, v in pairs(params) do for k, v in pairs(params) do
if type(v) == "table" then if type(v) == "table" then
for _, item in ipairs(v) do for _, item in ipairs(v) do
table.insert(query, util.url_encode(k) .. "=" .. util.url_encode(tostring(item))) table.insert(query, url_encode(k) .. "=" .. url_encode(tostring(item)))
end end
else else
table.insert(query, util.url_encode(k) .. "=" .. util.url_encode(tostring(v))) table.insert(query, url_encode(k) .. "=" .. url_encode(tostring(v)))
end end
end end
@ -154,32 +151,30 @@ local http = {
return base_url return base_url
end end
}
}
local function make_method(method, needs_body) local function make_method(method, needs_body)
return function(url, body_or_options, options) return function(url, body_or_options, options)
if needs_body then if needs_body then
options = options or {} options = options or {}
return http.client.request(method, url, body_or_options, options) return http_request(method, url, body_or_options, options)
else else
body_or_options = body_or_options or {} body_or_options = body_or_options or {}
return http.client.request(method, url, nil, body_or_options) return http_request(method, url, nil, body_or_options)
end end
end end
end end
http.client.get = make_method("GET", false) http_get = make_method("GET", false)
http.client.delete = make_method("DELETE", false) http_delete = make_method("DELETE", false)
http.client.head = make_method("HEAD", false) http_head = make_method("HEAD", false)
http.client.options = make_method("OPTIONS", false) http_options = make_method("OPTIONS", false)
http.client.post = make_method("POST", true) http_post = make_method("POST", true)
http.client.put = make_method("PUT", true) http_put = make_method("PUT", true)
http.client.patch = make_method("PATCH", true) http_patch = make_method("PATCH", true)
http.redirect = function(url, status) function http_redirect(url, status)
if type(url) ~= "string" then if type(url) ~= "string" then
error("http.redirect: url must be a string", 2) error("http_redirect: url must be a string", 2)
end end
status = status or 302 -- Default to temporary redirect status = status or 302 -- Default to temporary redirect
@ -194,14 +189,13 @@ http.redirect = function(url, status)
end end
-- ====================================================================== -- ======================================================================
-- COOKIE MODULE -- COOKIE FUNCTIONS
-- ====================================================================== -- ======================================================================
local cookie = {
-- Set a cookie -- Set a cookie
set = function(name, value, options) function cookie_set(name, value, options)
if type(name) ~= "string" then if type(name) ~= "string" then
error("cookie.set: name must be a string", 2) error("cookie_set: name must be a string", 2)
end end
local resp = __ensure_response() local resp = __ensure_response()
@ -237,7 +231,7 @@ local cookie = {
local valid_values = {none = true, lax = true, strict = true} local valid_values = {none = true, lax = true, strict = true}
if not valid_values[same_site] then if not valid_values[same_site] then
error("cookie.set: same_site must be one of 'None', 'Lax', or 'Strict'", 2) error("cookie_set: same_site must be one of 'None', 'Lax', or 'Strict'", 2)
end end
-- If SameSite=None, the cookie must be secure -- If SameSite=None, the cookie must be secure
@ -252,12 +246,12 @@ local cookie = {
table.insert(resp.cookies, cookie) table.insert(resp.cookies, cookie)
return true return true
end, end
-- Get a cookie value -- Get a cookie value
get = function(name) function cookie_get(name)
if type(name) ~= "string" then if type(name) ~= "string" then
error("cookie.get: name must be a string", 2) error("cookie_get: name must be a string", 2)
end end
local env = getfenv(2) local env = getfenv(2)
@ -271,26 +265,24 @@ local cookie = {
end end
return nil return nil
end, end
-- Remove a cookie -- Remove a cookie
remove = function(name, path, domain) function cookie_remove(name, path, domain)
if type(name) ~= "string" then if type(name) ~= "string" then
error("cookie.remove: name must be a string", 2) error("cookie_remove: name must be a string", 2)
end end
return cookie.set(name, "", {expires = 0, path = path or "/", domain = domain}) return cookie_set(name, "", {expires = 0, path = path or "/", domain = domain})
end end
}
-- ====================================================================== -- ======================================================================
-- SESSION MODULE -- SESSION FUNCTIONS
-- ====================================================================== -- ======================================================================
local session = { function session_get(key)
get = function(key)
if type(key) ~= "string" then if type(key) ~= "string" then
error("session.get: key must be a string", 2) error("session_get: key must be a string", 2)
end end
local env = getfenv(2) local env = getfenv(2)
@ -300,22 +292,27 @@ local session = {
end end
return nil return nil
end, end
set = function(key, value) function session_set(key, value)
if type(key) ~= "string" then if type(key) ~= "string" then
error("session.set: key must be a string", 2) error("session_set: key must be a string", 2)
end end
if type(value) == nil then if type(value) == nil then
error("session.set: value cannot be nil", 2) error("session_set: value cannot be nil", 2)
end end
local resp = __ensure_response() local resp = __ensure_response()
resp.session = resp.session or {} resp.session = resp.session or {}
resp.session[key] = value resp.session[key] = value
end,
id = function() local env = getfenv(2)
if env.ctx and env.ctx.session and env.ctx.session.data then
env.ctx.session.data[key] = value
end
end
function session_id()
local env = getfenv(2) local env = getfenv(2)
if env.ctx and env.ctx.session then if env.ctx and env.ctx.session then
@ -323,9 +320,9 @@ local session = {
end end
return nil return nil
end, end
get_all = function() function session_get_all()
local env = getfenv(2) local env = getfenv(2)
if env.ctx and env.ctx.session then if env.ctx and env.ctx.session then
@ -333,11 +330,11 @@ local session = {
end end
return nil return nil
end, end
delete = function(key) function session_delete(key)
if type(key) ~= "string" then if type(key) ~= "string" then
error("session.delete: key must be a string", 2) error("session_delete: key must be a string", 2)
end end
local resp = __ensure_response() local resp = __ensure_response()
@ -348,9 +345,9 @@ local session = {
if env.ctx and env.ctx.session and env.ctx.session.data then if env.ctx and env.ctx.session and env.ctx.session.data then
env.ctx.session.data[key] = nil env.ctx.session.data[key] = nil
end end
end, end
clear = function() function session_clear()
local env = getfenv(2) local env = getfenv(2)
if env.ctx and env.ctx.session and env.ctx.session.data then if env.ctx and env.ctx.session and env.ctx.session.data then
for k, _ in pairs(env.ctx.session.data) do for k, _ in pairs(env.ctx.session.data) do
@ -362,29 +359,27 @@ local session = {
resp.session = {} resp.session = {}
resp.session["__clear_all"] = true resp.session["__clear_all"] = true
end end
}
-- ====================================================================== -- ======================================================================
-- CSRF MODULE -- CSRF FUNCTIONS
-- ====================================================================== -- ======================================================================
local csrf = { function csrf_generate()
generate = function() local token = generate_token(32)
local token = util.generate_token(32) session_set("_csrf_token", token)
session.set("_csrf_token", token)
return token return token
end, end
field = function() function csrf_field()
local token = session.get("_csrf_token") local token = session_get("_csrf_token")
if not token then if not token then
token = csrf.generate() token = csrf_generate()
end end
return string.format('<input type="hidden" name="_csrf_token" value="%s" />', return string.format('<input type="hidden" name="_csrf_token" value="%s" />',
util.html_special_chars(token)) html_special_chars(token))
end, end
validate = function() function csrf_validate()
local env = getfenv(2) local env = getfenv(2)
local token = false local token = false
if env.ctx and env.ctx.session and env.ctx.session.data then if env.ctx and env.ctx.session and env.ctx.session.data then
@ -392,7 +387,7 @@ local csrf = {
end end
if not token then if not token then
http.set_status(403) http_set_status(403)
__http_response.body = "CSRF validation failed" __http_response.body = "CSRF validation failed"
exit() exit()
end end
@ -408,14 +403,13 @@ local csrf = {
end end
if not request_token or request_token ~= token then if not request_token or request_token ~= token then
http.set_status(403) http_set_status(403)
__http_response.body = "CSRF validation failed" __http_response.body = "CSRF validation failed"
exit() exit()
end end
return true return true
end end
}
-- ====================================================================== -- ======================================================================
-- TEMPLATE RENDER FUNCTIONS -- TEMPLATE RENDER FUNCTIONS
@ -502,7 +496,7 @@ _G.render = function(template_str, env)
setfenv(fn, runtime_env) setfenv(fn, runtime_env)
local output_buffer = {} local output_buffer = {}
fn(tostring, util.html_special_chars, output_buffer, 0) fn(tostring, html_special_chars, output_buffer, 0)
return table.concat(output_buffer) return table.concat(output_buffer)
end end
@ -536,7 +530,7 @@ _G.parse = function(template_str, env)
local value = env[name] local value = env[name]
local str = tostring(value or "") local str = tostring(value or "")
if escaped then if escaped then
str = util.html_special_chars(str) str = html_special_chars(str)
end end
table.insert(output, str) table.insert(output, str)
@ -576,7 +570,7 @@ _G.iparse = function(template_str, values)
local value = values[value_index] local value = values[value_index]
local str = tostring(value or "") local str = tostring(value or "")
if escaped then if escaped then
str = util.html_special_chars(str) str = html_special_chars(str)
end end
table.insert(output, str) table.insert(output, str)
@ -588,11 +582,9 @@ _G.iparse = function(template_str, values)
end end
-- ====================================================================== -- ======================================================================
-- PASSWORD MODULE -- PASSWORD FUNCTIONS
-- ====================================================================== -- ======================================================================
local password = {}
-- Hash a password using Argon2id -- Hash a password using Argon2id
-- Options: -- Options:
-- memory: Amount of memory to use in KB (default: 128MB) -- memory: Amount of memory to use in KB (default: 128MB)
@ -600,85 +592,72 @@ local password = {}
-- parallelism: Number of threads (default: 4) -- parallelism: Number of threads (default: 4)
-- salt_length: Length of salt in bytes (default: 16) -- salt_length: Length of salt in bytes (default: 16)
-- key_length: Length of the derived key in bytes (default: 32) -- key_length: Length of the derived key in bytes (default: 32)
function password.hash(plain_password, options) function password_hash(plain_password, options)
if type(plain_password) ~= "string" then if type(plain_password) ~= "string" then
error("password.hash: expected string password", 2) error("password_hash: expected string password", 2)
end end
return __password_hash(plain_password, options) return __password_hash(plain_password, options)
end end
-- Verify a password against a hash -- Verify a password against a hash
function password.verify(plain_password, hash_string) function password_verify(plain_password, hash_string)
if type(plain_password) ~= "string" then if type(plain_password) ~= "string" then
error("password.verify: expected string password", 2) error("password_verify: expected string password", 2)
end end
if type(hash_string) ~= "string" then if type(hash_string) ~= "string" then
error("password.verify: expected string hash", 2) error("password_verify: expected string hash", 2)
end end
return __password_verify(plain_password, hash_string) return __password_verify(plain_password, hash_string)
end end
-- ====================================================================== -- ======================================================================
-- SEND MODULE -- SEND FUNCTIONS
-- ====================================================================== -- ======================================================================
local send = {} function send_html(content)
http_set_content_type("text/html")
function send.html(content)
http.set_content_type("text/html")
return content return content
end end
function send.json(content) function send_json(content)
http.set_content_type("application/json") http_set_content_type("application/json")
return content return content
end end
function send.text(content) function send_text(content)
http.set_content_type("text/plain") http_set_content_type("text/plain")
return content return content
end end
function send.xml(content) function send_xml(content)
http.set_content_type("application/xml") http_set_content_type("application/xml")
return content return content
end end
function send.javascript(content) function send_javascript(content)
http.set_content_type("application/javascript") http_set_content_type("application/javascript")
return content return content
end end
function send.css(content) function send_css(content)
http.set_content_type("text/css") http_set_content_type("text/css")
return content return content
end end
function send.svg(content) function send_svg(content)
http.set_content_type("image/svg+xml") http_set_content_type("image/svg+xml")
return content return content
end end
function send.csv(content) function send_csv(content)
http.set_content_type("text/csv") http_set_content_type("text/csv")
return content return content
end end
function send.binary(content, mime_type) function send_binary(content, mime_type)
http.set_content_type(mime_type or "application/octet-stream") http_set_content_type(mime_type or "application/octet-stream")
return content return content
end end
-- ======================================================================
-- REGISTER MODULES GLOBALLY
-- ======================================================================
_G.http = http
_G.session = session
_G.csrf = csrf
_G.cookie = cookie
_G.password = password
_G.send = send

View File

@ -1,16 +1,13 @@
--[[ --[[
util.lua - Utility functions for the Lua sandbox util.lua - Utility functions for the Lua sandbox
Enhanced with web development utilities
]]-- ]]--
local util = {}
-- ====================================================================== -- ======================================================================
-- CORE UTILITY FUNCTIONS -- CORE UTILITY FUNCTIONS
-- ====================================================================== -- ======================================================================
-- Generate a random token -- Generate a random token
function util.generate_token(length) function generate_token(length)
return __generate_token(length or 32) return __generate_token(length or 32)
end end
@ -18,20 +15,8 @@ end
-- HTML ENTITY FUNCTIONS -- HTML ENTITY FUNCTIONS
-- ====================================================================== -- ======================================================================
-- HTML entity mapping for common characters
local html_entities = {
["&"] = "&amp;",
["<"] = "&lt;",
[">"] = "&gt;",
['"'] = "&quot;",
["'"] = "&#39;",
["/"] = "&#x2F;",
["`"] = "&#x60;",
["="] = "&#x3D;"
}
-- Convert special characters to HTML entities (like htmlspecialchars) -- Convert special characters to HTML entities (like htmlspecialchars)
function util.html_special_chars(str) function html_special_chars(str)
if type(str) ~= "string" then if type(str) ~= "string" then
return str return str
end end
@ -40,7 +25,7 @@ function util.html_special_chars(str)
end end
-- Convert all applicable characters to HTML entities (like htmlentities) -- Convert all applicable characters to HTML entities (like htmlentities)
function util.html_entities(str) function html_entities(str)
if type(str) ~= "string" then if type(str) ~= "string" then
return str return str
end end
@ -49,7 +34,7 @@ function util.html_entities(str)
end end
-- Convert HTML entities back to characters (simple version) -- Convert HTML entities back to characters (simple version)
function util.html_entity_decode(str) function html_entity_decode(str)
if type(str) ~= "string" then if type(str) ~= "string" then
return str return str
end end
@ -64,7 +49,7 @@ function util.html_entity_decode(str)
end end
-- Convert newlines to <br> tags -- Convert newlines to <br> tags
function util.nl2br(str) function nl2br(str)
if type(str) ~= "string" then if type(str) ~= "string" then
return str return str
end end
@ -77,7 +62,7 @@ end
-- ====================================================================== -- ======================================================================
-- URL encode a string -- URL encode a string
function util.url_encode(str) function url_encode(str)
if type(str) ~= "string" then if type(str) ~= "string" then
return str return str
end end
@ -91,7 +76,7 @@ function util.url_encode(str)
end end
-- URL decode a string -- URL decode a string
function util.url_decode(str) function url_decode(str)
if type(str) ~= "string" then if type(str) ~= "string" then
return str return str
end end
@ -108,7 +93,7 @@ end
-- ====================================================================== -- ======================================================================
-- Email validation -- Email validation
function util.is_email(str) function is_email(str)
if type(str) ~= "string" then if type(str) ~= "string" then
return false return false
end end
@ -119,7 +104,7 @@ function util.is_email(str)
end end
-- URL validation -- URL validation
function util.is_url(str) function is_url(str)
if type(str) ~= "string" then if type(str) ~= "string" then
return false return false
end end
@ -130,7 +115,7 @@ function util.is_url(str)
end end
-- IP address validation (IPv4) -- IP address validation (IPv4)
function util.is_ipv4(str) function is_ipv4(str)
if type(str) ~= "string" then if type(str) ~= "string" then
return false return false
end end
@ -147,7 +132,7 @@ function util.is_ipv4(str)
end end
-- Integer validation -- Integer validation
function util.is_int(str) function is_int(str)
if type(str) == "number" then if type(str) == "number" then
return math.floor(str) == str return math.floor(str) == str
elseif type(str) ~= "string" then elseif type(str) ~= "string" then
@ -158,7 +143,7 @@ function util.is_int(str)
end end
-- Float validation -- Float validation
function util.is_float(str) function is_float(str)
if type(str) == "number" then if type(str) == "number" then
return true return true
elseif type(str) ~= "string" then elseif type(str) ~= "string" then
@ -169,7 +154,7 @@ function util.is_float(str)
end end
-- Boolean validation -- Boolean validation
function util.is_bool(value) function is_bool(value)
if type(value) == "boolean" then if type(value) == "boolean" then
return true return true
elseif type(value) ~= "string" and type(value) ~= "number" then elseif type(value) ~= "string" and type(value) ~= "number" then
@ -183,7 +168,7 @@ function util.is_bool(value)
end end
-- Convert to boolean -- Convert to boolean
function util.to_bool(value) function to_bool(value)
if type(value) == "boolean" then if type(value) == "boolean" then
return value return value
elseif type(value) ~= "string" and type(value) ~= "number" then elseif type(value) ~= "string" and type(value) ~= "number" then
@ -195,16 +180,16 @@ function util.to_bool(value)
end end
-- Sanitize string (simple version) -- Sanitize string (simple version)
function util.sanitize_string(str) function sanitize_string(str)
if type(str) ~= "string" then if type(str) ~= "string" then
return "" return ""
end end
return util.html_special_chars(str) return html_special_chars(str)
end end
-- Sanitize to integer -- Sanitize to integer
function util.sanitize_int(value) function sanitize_int(value)
if type(value) ~= "string" and type(value) ~= "number" then if type(value) ~= "string" and type(value) ~= "number" then
return 0 return 0
end end
@ -215,7 +200,7 @@ function util.sanitize_int(value)
end end
-- Sanitize to float -- Sanitize to float
function util.sanitize_float(value) function sanitize_float(value)
if type(value) ~= "string" and type(value) ~= "number" then if type(value) ~= "string" and type(value) ~= "number" then
return 0 return 0
end end
@ -226,7 +211,7 @@ function util.sanitize_float(value)
end end
-- Sanitize URL -- Sanitize URL
function util.sanitize_url(str) function sanitize_url(str)
if type(str) ~= "string" then if type(str) ~= "string" then
return "" return ""
end end
@ -235,12 +220,12 @@ function util.sanitize_url(str)
str = str:gsub("[\000-\031]", "") str = str:gsub("[\000-\031]", "")
-- Make sure it's a valid URL -- Make sure it's a valid URL
if util.is_url(str) then if is_url(str) then
return str return str
end end
-- Try to prepend http:// if it's missing -- Try to prepend http:// if it's missing
if not str:match("^https?://") and util.is_url("http://" .. str) then if not str:match("^https?://") and is_url("http://" .. str) then
return "http://" .. str return "http://" .. str
end end
@ -248,7 +233,7 @@ function util.sanitize_url(str)
end end
-- Sanitize email -- Sanitize email
function util.sanitize_email(str) function sanitize_email(str)
if type(str) ~= "string" then if type(str) ~= "string" then
return "" return ""
end end
@ -257,7 +242,7 @@ function util.sanitize_email(str)
str = str:gsub("[^%a%d%!%#%$%%%&%'%*%+%-%/%=%?%^%_%`%{%|%}%~%@%.%[%]]", "") str = str:gsub("[^%a%d%!%#%$%%%&%'%*%+%-%/%=%?%^%_%`%{%|%}%~%@%.%[%]]", "")
-- Return only if it's a valid email -- Return only if it's a valid email
if util.is_email(str) then if is_email(str) then
return str return str
end end
@ -269,13 +254,13 @@ end
-- ====================================================================== -- ======================================================================
-- Basic XSS prevention -- Basic XSS prevention
function util.xss_clean(str) function xss_clean(str)
if type(str) ~= "string" then if type(str) ~= "string" then
return str return str
end end
-- Convert problematic characters to entities -- Convert problematic characters to entities
local result = util.html_special_chars(str) local result = html_special_chars(str)
-- Remove JavaScript event handlers -- Remove JavaScript event handlers
result = result:gsub("on%w+%s*=", "") result = result:gsub("on%w+%s*=", "")
@ -290,7 +275,7 @@ function util.xss_clean(str)
end end
-- Base64 encode -- Base64 encode
function util.base64_encode(str) function base64_encode(str)
if type(str) ~= "string" then if type(str) ~= "string" then
return str return str
end end
@ -299,12 +284,10 @@ function util.base64_encode(str)
end end
-- Base64 decode -- Base64 decode
function util.base64_decode(str) function base64_decode(str)
if type(str) ~= "string" then if type(str) ~= "string" then
return str return str
end end
return __base64_decode(str) return __base64_decode(str)
end end
return util

View File

@ -81,6 +81,9 @@ func (s *Session) GetAll() map[string]any {
// Set stores a value in the session // Set stores a value in the session
func (s *Session) Set(key string, value any) { func (s *Session) Set(key string, value any) {
if existing, ok := s.Data[key]; ok && deepEqual(existing, value) {
return // No change
}
s.Data[key] = value s.Data[key] = value
s.UpdatedAt = time.Now() s.UpdatedAt = time.Now()
s.dirty = true s.dirty = true
@ -346,3 +349,81 @@ func validate(v any) error {
} }
return nil return nil
} }
// deepEqual efficiently compares two values for deep equality
func deepEqual(a, b any) bool {
if a == b {
return true
}
if a == nil || b == nil {
return false
}
switch va := a.(type) {
case string:
if vb, ok := b.(string); ok {
return va == vb
}
case int:
if vb, ok := b.(int); ok {
return va == vb
}
if vb, ok := b.(int64); ok {
return int64(va) == vb
}
case int64:
if vb, ok := b.(int64); ok {
return va == vb
}
if vb, ok := b.(int); ok {
return va == int64(vb)
}
case float64:
if vb, ok := b.(float64); ok {
return va == vb
}
case bool:
if vb, ok := b.(bool); ok {
return va == vb
}
case []byte:
if vb, ok := b.([]byte); ok {
if len(va) != len(vb) {
return false
}
for i, v := range va {
if v != vb[i] {
return false
}
}
return true
}
case map[string]any:
if vb, ok := b.(map[string]any); ok {
if len(va) != len(vb) {
return false
}
for k, v := range va {
if bv, exists := vb[k]; !exists || !deepEqual(v, bv) {
return false
}
}
return true
}
case []any:
if vb, ok := b.([]any); ok {
if len(va) != len(vb) {
return false
}
for i, v := range va {
if !deepEqual(v, vb[i]) {
return false
}
}
return true
}
}
return false
}