diff --git a/runner/embed.go b/runner/embed.go index 6824d62..74e4d85 100644 --- a/runner/embed.go +++ b/runner/embed.go @@ -42,22 +42,23 @@ var mathLuaCode string // ModuleInfo holds information about an embeddable Lua module type ModuleInfo struct { - Name string // Module name - Code string // Module source code - Bytecode atomic.Pointer[[]byte] // Cached bytecode - Once sync.Once // For one-time compilation + Name string // Module name + Code string // Module source code + Bytecode atomic.Pointer[[]byte] // Cached bytecode + Once sync.Once // For one-time compilation + DefinesGlobal bool // Whether module defines globals directly } var ( sandbox = ModuleInfo{Name: "sandbox", Code: sandboxLuaCode} modules = []ModuleInfo{ - {Name: "json", Code: jsonLuaCode}, + {Name: "json", Code: jsonLuaCode, DefinesGlobal: true}, {Name: "sqlite", Code: sqliteLuaCode}, - {Name: "fs", Code: fsLuaCode}, - {Name: "util", Code: utilLuaCode}, + {Name: "fs", Code: fsLuaCode, DefinesGlobal: true}, + {Name: "util", Code: utilLuaCode, DefinesGlobal: true}, {Name: "string", Code: stringLuaCode}, {Name: "table", Code: tableLuaCode}, - {Name: "crypto", Code: cryptoLuaCode}, + {Name: "crypto", Code: cryptoLuaCode, DefinesGlobal: true}, {Name: "time", Code: timeLuaCode}, {Name: "math", Code: mathLuaCode}, } @@ -104,11 +105,18 @@ func loadModule(state *luajit.State, m *ModuleInfo, verbose bool) error { return err } - if err := state.RunBytecodeWithResults(1); err != nil { - return err + if m.DefinesGlobal { + // Module defines its own globals, just run it + if err := state.RunBytecode(); err != nil { + return err + } + } else { + // Module returns a table, capture and set as global + if err := state.RunBytecodeWithResults(1); err != nil { + return err + } + state.SetGlobal(m.Name) } - - state.SetGlobal(m.Name) } else { // Fallback to interpreting the source if verbose { diff --git a/runner/lua/crypto.lua b/runner/lua/crypto.lua index ea50c06..238cf86 100644 --- a/runner/lua/crypto.lua +++ b/runner/lua/crypto.lua @@ -2,8 +2,6 @@ crypto.lua - Cryptographic functions powered by Go ]]-- -local crypto = {} - -- ====================================================================== -- HASHING FUNCTIONS -- ====================================================================== @@ -11,9 +9,9 @@ local crypto = {} -- Generate hash digest using various algorithms -- Algorithms: md5, sha1, sha256, sha512 -- Formats: hex (default), binary -function crypto.hash(data, algorithm, format) +function hash(data, algorithm, format) if type(data) ~= "string" then - error("crypto.hash: data must be a string", 2) + error("hash: data must be a string", 2) end algorithm = algorithm or "sha256" @@ -22,21 +20,20 @@ function crypto.hash(data, algorithm, format) return __crypto_hash(data, algorithm, format) end --- Convenience functions for common hash algorithms -function crypto.md5(data, format) - return crypto.hash(data, "md5", format) +function md5(data, format) + return hash(data, "md5", format) end -function crypto.sha1(data, format) - return crypto.hash(data, "sha1", format) +function sha1(data, format) + return hash(data, "sha1", format) end -function crypto.sha256(data, format) - return crypto.hash(data, "sha256", format) +function sha256(data, format) + return hash(data, "sha256", format) end -function crypto.sha512(data, format) - return crypto.hash(data, "sha512", format) +function sha512(data, format) + return hash(data, "sha512", format) end -- ====================================================================== @@ -46,13 +43,13 @@ end -- Generate HMAC using various algorithms -- Algorithms: md5, sha1, sha256, sha512 -- Formats: hex (default), binary -function crypto.hmac(data, key, algorithm, format) +function hmac(data, key, algorithm, format) if type(data) ~= "string" then - error("crypto.hmac: data must be a string", 2) + error("hmac: data must be a string", 2) end if type(key) ~= "string" then - error("crypto.hmac: key must be a string", 2) + error("hmac: key must be a string", 2) end algorithm = algorithm or "sha256" @@ -61,21 +58,20 @@ function crypto.hmac(data, key, algorithm, format) return __crypto_hmac(data, key, algorithm, format) end --- Convenience functions for common HMAC algorithms -function crypto.hmac_md5(data, key, format) - return crypto.hmac(data, key, "md5", format) +function hmac_md5(data, key, format) + return hmac(data, key, "md5", format) end -function crypto.hmac_sha1(data, key, format) - return crypto.hmac(data, key, "sha1", format) +function hmac_sha1(data, key, format) + return hmac(data, key, "sha1", format) end -function crypto.hmac_sha256(data, key, format) - return crypto.hmac(data, key, "sha256", format) +function hmac_sha256(data, key, format) + return hmac(data, key, "sha256", format) end -function crypto.hmac_sha512(data, key, format) - return crypto.hmac(data, key, "sha512", format) +function hmac_sha512(data, key, format) + return hmac(data, key, "sha512", format) end -- ====================================================================== @@ -84,9 +80,9 @@ end -- Generate random bytes -- Formats: binary (default), hex -function crypto.random_bytes(length, secure, format) +function random_bytes(length, secure, format) if type(length) ~= "number" or length <= 0 then - error("crypto.random_bytes: length must be positive", 2) + error("random_bytes: length must be positive", 2) end secure = secure ~= false -- Default to secure @@ -96,13 +92,13 @@ function crypto.random_bytes(length, secure, format) end -- Generate random integer in range [min, max] -function crypto.random_int(min, max, secure) +function random_int(min, max, secure) if type(min) ~= "number" or type(max) ~= "number" then - error("crypto.random_int: min and max must be numbers", 2) + error("random_int: min and max must be numbers", 2) end if max <= min then - error("crypto.random_int: max must be greater than min", 2) + error("random_int: max must be greater than min", 2) end secure = secure ~= false -- Default to secure @@ -111,9 +107,9 @@ function crypto.random_int(min, max, secure) end -- Generate random string of specified length -function crypto.random_string(length, charset, secure) +function random_string(length, charset, secure) if type(length) ~= "number" or length <= 0 then - error("crypto.random_string: length must be positive", 2) + error("random_string: length must be positive", 2) end secure = secure ~= false -- Default to secure @@ -122,14 +118,14 @@ function crypto.random_string(length, charset, secure) charset = charset or "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" if type(charset) ~= "string" or #charset == 0 then - error("crypto.random_string: charset must be non-empty", 2) + error("random_string: charset must be non-empty", 2) end local result = "" local charset_length = #charset for i = 1, length do - local index = crypto.random_int(1, charset_length, secure) + local index = random_int(1, charset_length, secure) result = result .. charset:sub(index, index) end @@ -141,8 +137,6 @@ end -- ====================================================================== -- Generate random UUID (v4) -function crypto.uuid() +function uuid() return __crypto_uuid() -end - -return crypto \ No newline at end of file +end \ No newline at end of file diff --git a/runner/lua/fs.lua b/runner/lua/fs.lua index 6f38bc7..7735f0c 100644 --- a/runner/lua/fs.lua +++ b/runner/lua/fs.lua @@ -1,49 +1,47 @@ -local fs = {} - -fs.read = function(path) +function fs_read(path) if type(path) ~= "string" then - error("fs.read: path must be a string", 2) + error("fs_read: path must be a string", 2) end return __fs_read_file(path) end -fs.write = function(path, content) +function fs_write(path, content) if type(path) ~= "string" then - error("fs.write: path must be a string", 2) + error("fs_write: path must be a string", 2) end if type(content) ~= "string" then - error("fs.write: content must be a string", 2) + error("fs_write: content must be a string", 2) end return __fs_write_file(path, content) end -fs.append = function(path, content) +function fs_append(path, content) if type(path) ~= "string" then - error("fs.append: path must be a string", 2) + error("fs_append: path must be a string", 2) end if type(content) ~= "string" then - error("fs.append: content must be a string", 2) + error("fs_append: content must be a string", 2) end return __fs_append_file(path, content) end -fs.exists = function(path) +function fs_exists(path) if type(path) ~= "string" then - error("fs.exists: path must be a string", 2) + error("fs_exists: path must be a string", 2) end return __fs_exists(path) end -fs.remove = function(path) +function fs_remove(path) if type(path) ~= "string" then - error("fs.remove: path must be a string", 2) + error("fs_remove: path must be a string", 2) end return __fs_remove_file(path) end -fs.info = function(path) +function fs_info(path) if type(path) ~= "string" then - error("fs.info: path must be a string", 2) + error("fs_info: path must be a string", 2) end local info = __fs_get_info(path) @@ -56,58 +54,58 @@ fs.info = function(path) end -- Directory Operations -fs.mkdir = function(path, mode) +function fs_mkdir(path, mode) if type(path) ~= "string" then - error("fs.mkdir: path must be a string", 2) + error("fs_mkdir: path must be a string", 2) end mode = mode or 0755 return __fs_make_dir(path, mode) end -fs.ls = function(path) +function fs_ls(path) if type(path) ~= "string" then - error("fs.ls: path must be a string", 2) + error("fs_ls: path must be a string", 2) end return __fs_list_dir(path) end -fs.rmdir = function(path, recursive) +function fs_rmdir(path, recursive) if type(path) ~= "string" then - error("fs.rmdir: path must be a string", 2) + error("fs_rmdir: path must be a string", 2) end recursive = recursive or false return __fs_remove_dir(path, recursive) end -- Path Operations -fs.join_paths = function(...) +function fs_join_paths(...) return __fs_join_paths(...) end -fs.dir_name = function(path) +function fs_dir_name(path) if type(path) ~= "string" then - error("fs.dir_name: path must be a string", 2) + error("fs_dir_name: path must be a string", 2) end return __fs_dir_name(path) end -fs.base_name = function(path) +function fs_base_name(path) if type(path) ~= "string" then - error("fs.base_name: path must be a string", 2) + error("fs_base_name: path must be a string", 2) end return __fs_base_name(path) end -fs.extension = function(path) +function fs_extension(path) if type(path) ~= "string" then - error("fs.extension: path must be a string", 2) + error("fs_extension: path must be a string", 2) end return __fs_extension(path) end -- Utility Functions -fs.read_json = function(path) - local content = fs.read_file(path) +function fs_read_json(path) + local content = fs_read(path) if not content then return nil, "Could not read file" end @@ -120,9 +118,9 @@ fs.read_json = function(path) return result end -fs.write_json = function(path, data, pretty) +function fs_write_json(path, data, pretty) if type(data) ~= "table" then - error("fs.write_json: data must be a table", 2) + error("fs_write_json: data must be a table", 2) end local content @@ -132,7 +130,5 @@ fs.write_json = function(path, data, pretty) content = json.encode(data) end - return fs.write_file(path, content) -end - -return fs + return fs_write(path, content) +end \ No newline at end of file diff --git a/runner/lua/json.lua b/runner/lua/json.lua index 3539a9a..8790463 100644 --- a/runner/lua/json.lua +++ b/runner/lua/json.lua @@ -1,18 +1,23 @@ -- json.lua: High-performance JSON module for Moonshark -local json = {} -function json.go_encode(value) - return __json_marshal(value) +-- Pre-computed escape sequences to avoid recreating table +local escape_chars = { + ['"'] = '\\"', ['\\'] = '\\\\', + ['\n'] = '\\n', ['\r'] = '\\r', ['\t'] = '\\t' +} + +function json_go_encode(value) + return __json_marshal(value) end -function json.go_decode(str) - if type(str) ~= "string" then - error("json.decode: expected string, got " .. type(str), 2) - end - return __json_unmarshal(str) +function json_go_decode(str) + if type(str) ~= "string" then + error("json_decode: expected string, got " .. type(str), 2) + end + return __json_unmarshal(str) end -function json.encode(data) +function json_encode(data) local t = type(data) if t == "nil" then return "null" end @@ -20,50 +25,36 @@ function json.encode(data) if t == "number" then return tostring(data) end if t == "string" then - local escape_chars = { - ['"'] = '\\"', ['\\'] = '\\\\', - ['\n'] = '\\n', ['\r'] = '\\r', ['\t'] = '\\t' - } return '"' .. data:gsub('[\\"\n\r\t]', escape_chars) .. '"' end if t == "table" then local isArray = true local count = 0 - local max_index = 0 + -- Check if it's an array in one pass for k, _ in pairs(data) do count = count + 1 - if type(k) == "number" and k > 0 and math.floor(k) == k then - max_index = math.max(max_index, k) - else - isArray = false - break + if type(k) ~= "number" or k ~= count or k < 1 then + isArray = false + break end end - local result = {} - if isArray then - for i, v in ipairs(data) do - result[i] = json.encode(v) + local result = {} + for i = 1, count do + result[i] = json_encode(data[i]) end return "[" .. table.concat(result, ",") .. "]" else - local size = 0 - for k, v in pairs(data) do - if type(k) == "string" and type(v) ~= "function" and type(v) ~= "userdata" then - size = size + 1 - end - end - - result = {} + local result = {} local index = 1 for k, v in pairs(data) do - if type(k) == "string" and type(v) ~= "function" and type(v) ~= "userdata" then - result[index] = json.encode(k) .. ":" .. json.encode(v) - index = index + 1 - end + if type(k) == "string" and type(v) ~= "function" and type(v) ~= "userdata" then + result[index] = json_encode(k) .. ":" .. json_encode(v) + index = index + 1 + end end return "{" .. table.concat(result, ",") .. "}" end @@ -72,7 +63,7 @@ function json.encode(data) return "null" -- Unsupported type end -function json.decode(data) +function json_decode(data) local pos = 1 local len = #data @@ -100,14 +91,14 @@ function json.decode(data) -- Skip whitespace more efficiently local function skip() - local b - while pos <= len do - b = data:byte(pos) - if b > b_space or (b ~= b_space and b ~= b_tab and b ~= b_cr and b ~= b_lf) then - break + local b + while pos <= len do + b = data:byte(pos) + if b > b_space or (b ~= b_space and b ~= b_tab and b ~= b_cr and b ~= b_lf) then + break + end + pos = pos + 1 end - pos = pos + 1 - end end -- Forward declarations @@ -115,250 +106,250 @@ function json.decode(data) -- Parse a string more efficiently parse_string = function() - pos = pos + 1 -- Skip opening quote + pos = pos + 1 -- Skip opening quote - if pos > len then - error("Unterminated string") - end - - -- Use a table to build the string - local result = {} - local result_pos = 1 - local start = pos - local c, b - - while pos <= len do - b = data:byte(pos) - - if b == b_backslash then - -- Add the chunk before the escape character - if pos > start then - result[result_pos] = data:sub(start, pos - 1) - result_pos = result_pos + 1 - end - - pos = pos + 1 - if pos > len then - error("Unterminated string escape") - end - - c = data:byte(pos) - if c == b_quote then - result[result_pos] = '"' - elseif c == b_backslash then - result[result_pos] = '\\' - elseif c == b_slash then - result[result_pos] = '/' - elseif c == string.byte('b') then - result[result_pos] = '\b' - elseif c == string.byte('f') then - result[result_pos] = '\f' - elseif c == string.byte('n') then - result[result_pos] = '\n' - elseif c == string.byte('r') then - result[result_pos] = '\r' - elseif c == string.byte('t') then - result[result_pos] = '\t' - else - result[result_pos] = data:sub(pos, pos) - end - - result_pos = result_pos + 1 - pos = pos + 1 - start = pos - elseif b == b_quote then - -- Add the final chunk - if pos > start then - result[result_pos] = data:sub(start, pos - 1) - result_pos = result_pos + 1 - end - - pos = pos + 1 - return table.concat(result) - else - pos = pos + 1 + if pos > len then + error("Unterminated string") end - end - error("Unterminated string") + -- Use a table to build the string + local result = {} + local result_pos = 1 + local start = pos + local c, b + + while pos <= len do + b = data:byte(pos) + + if b == b_backslash then + -- Add the chunk before the escape character + if pos > start then + result[result_pos] = data:sub(start, pos - 1) + result_pos = result_pos + 1 + end + + pos = pos + 1 + if pos > len then + error("Unterminated string escape") + end + + c = data:byte(pos) + if c == b_quote then + result[result_pos] = '"' + elseif c == b_backslash then + result[result_pos] = '\\' + elseif c == b_slash then + result[result_pos] = '/' + elseif c == string.byte('b') then + result[result_pos] = '\b' + elseif c == string.byte('f') then + result[result_pos] = '\f' + elseif c == string.byte('n') then + result[result_pos] = '\n' + elseif c == string.byte('r') then + result[result_pos] = '\r' + elseif c == string.byte('t') then + result[result_pos] = '\t' + else + result[result_pos] = data:sub(pos, pos) + end + + result_pos = result_pos + 1 + pos = pos + 1 + start = pos + elseif b == b_quote then + -- Add the final chunk + if pos > start then + result[result_pos] = data:sub(start, pos - 1) + result_pos = result_pos + 1 + end + + pos = pos + 1 + return table.concat(result) + else + pos = pos + 1 + end + end + + error("Unterminated string") end -- Parse a number more efficiently parse_number = function() - local start = pos - local b = data:byte(pos) + local start = pos + local b = data:byte(pos) - -- Skip any sign - if b == b_minus then - pos = pos + 1 - if pos > len then - error("Malformed number") - end - b = data:byte(pos) - end - - -- Integer part - if b < b_0 or b > b_9 then - error("Malformed number") - end - - repeat - pos = pos + 1 - if pos > len then break end - b = data:byte(pos) - until b < b_0 or b > b_9 - - -- Fractional part - if pos <= len and b == b_dot then - pos = pos + 1 - if pos > len or data:byte(pos) < b_0 or data:byte(pos) > b_9 then - error("Malformed number") - end - - repeat - pos = pos + 1 - if pos > len then break end - b = data:byte(pos) - until b < b_0 or b > b_9 - end - - -- Exponent - if pos <= len and (b == b_e or b == b_E) then - pos = pos + 1 - if pos > len then - error("Malformed number") - end - - b = data:byte(pos) - if b == b_plus or b == b_minus then - pos = pos + 1 - if pos > len then - error("Malformed number") - end - b = data:byte(pos) + -- Skip any sign + if b == b_minus then + pos = pos + 1 + if pos > len then + error("Malformed number") + end + b = data:byte(pos) end + -- Integer part if b < b_0 or b > b_9 then - error("Malformed number") + error("Malformed number") end repeat - pos = pos + 1 - if pos > len then break end - b = data:byte(pos) + pos = pos + 1 + if pos > len then break end + b = data:byte(pos) until b < b_0 or b > b_9 - end - return tonumber(data:sub(start, pos - 1)) + -- Fractional part + if pos <= len and b == b_dot then + pos = pos + 1 + if pos > len or data:byte(pos) < b_0 or data:byte(pos) > b_9 then + error("Malformed number") + end + + repeat + pos = pos + 1 + if pos > len then break end + b = data:byte(pos) + until b < b_0 or b > b_9 + end + + -- Exponent + if pos <= len and (b == b_e or b == b_E) then + pos = pos + 1 + if pos > len then + error("Malformed number") + end + + b = data:byte(pos) + if b == b_plus or b == b_minus then + pos = pos + 1 + if pos > len then + error("Malformed number") + end + b = data:byte(pos) + end + + if b < b_0 or b > b_9 then + error("Malformed number") + end + + repeat + pos = pos + 1 + if pos > len then break end + b = data:byte(pos) + until b < b_0 or b > b_9 + end + + return tonumber(data:sub(start, pos - 1)) end -- Parse an object more efficiently parse_object = function() - pos = pos + 1 -- Skip opening brace - local obj = {} + pos = pos + 1 -- Skip opening brace + local obj = {} - skip() - if pos <= len and data:byte(pos) == b_rcurly then - pos = pos + 1 - return obj - end - - while pos <= len do skip() - - if data:byte(pos) ~= b_quote then - error("Expected string key") + if pos <= len and data:byte(pos) == b_rcurly then + pos = pos + 1 + return obj end - local key = parse_string() - skip() + while pos <= len do + skip() - if data:byte(pos) ~= b_colon then - error("Expected colon") - end - pos = pos + 1 + if data:byte(pos) ~= b_quote then + error("Expected string key") + end - obj[key] = parse_value() - skip() + local key = parse_string() + skip() - local b = data:byte(pos) - if b == b_rcurly then - pos = pos + 1 - return obj + if data:byte(pos) ~= b_colon then + error("Expected colon") + end + pos = pos + 1 + + obj[key] = parse_value() + skip() + + local b = data:byte(pos) + if b == b_rcurly then + pos = pos + 1 + return obj + end + + if b ~= b_comma then + error("Expected comma or closing brace") + end + pos = pos + 1 end - if b ~= b_comma then - error("Expected comma or closing brace") - end - pos = pos + 1 - end - - error("Unterminated object") + error("Unterminated object") end -- Parse an array more efficiently parse_array = function() - pos = pos + 1 -- Skip opening bracket - local arr = {} - local index = 1 - - skip() - if pos <= len and data:byte(pos) == b_rbracket then - pos = pos + 1 - return arr - end - - while pos <= len do - arr[index] = parse_value() - index = index + 1 + pos = pos + 1 -- Skip opening bracket + local arr = {} + local index = 1 skip() - - local b = data:byte(pos) - if b == b_rbracket then - pos = pos + 1 - return arr + if pos <= len and data:byte(pos) == b_rbracket then + pos = pos + 1 + return arr end - if b ~= b_comma then - error("Expected comma or closing bracket") - end - pos = pos + 1 - end + while pos <= len do + arr[index] = parse_value() + index = index + 1 - error("Unterminated array") + skip() + + local b = data:byte(pos) + if b == b_rbracket then + pos = pos + 1 + return arr + end + + if b ~= b_comma then + error("Expected comma or closing bracket") + end + pos = pos + 1 + end + + error("Unterminated array") end -- Parse a value more efficiently parse_value = function() - skip() + skip() - if pos > len then - error("Unexpected end of input") - end + if pos > len then + error("Unexpected end of input") + end - local b = data:byte(pos) + local b = data:byte(pos) - if b == b_quote then - return parse_string() - elseif b == b_lcurly then - return parse_object() - elseif b == b_lbracket then - return parse_array() - elseif b == string.byte('n') and pos + 3 <= len and data:sub(pos, pos + 3) == "null" then - pos = pos + 4 - return nil - elseif b == string.byte('t') and pos + 3 <= len and data:sub(pos, pos + 3) == "true" then - pos = pos + 4 - return true - elseif b == string.byte('f') and pos + 4 <= len and data:sub(pos, pos + 4) == "false" then - pos = pos + 5 - return false - elseif b == b_minus or (b >= b_0 and b <= b_9) then - return parse_number() - else - error("Unexpected character: " .. string.char(b)) - end + if b == b_quote then + return parse_string() + elseif b == b_lcurly then + return parse_object() + elseif b == b_lbracket then + return parse_array() + elseif b == string.byte('n') and pos + 3 <= len and data:sub(pos, pos + 3) == "null" then + pos = pos + 4 + return nil + elseif b == string.byte('t') and pos + 3 <= len and data:sub(pos, pos + 3) == "true" then + pos = pos + 4 + return true + elseif b == string.byte('f') and pos + 4 <= len and data:sub(pos, pos + 4) == "false" then + pos = pos + 5 + return false + elseif b == b_minus or (b >= b_0 and b <= b_9) then + return parse_number() + else + error("Unexpected character: " .. string.char(b)) + end end skip() @@ -366,68 +357,66 @@ function json.decode(data) skip() if pos <= len then - error("Unexpected trailing characters") + error("Unexpected trailing characters") end return result end -function json.is_valid(str) - if type(str) ~= "string" then return false end - local status, _ = pcall(json.decode, str) - return status +function json_is_valid(str) + if type(str) ~= "string" then return false end + local status, _ = pcall(json_decode, str) + return status end -function json.pretty_print(value) - if type(value) == "string" then - value = json.decode(value) - end +function json_pretty_print(value) + if type(value) == "string" then + value = json_decode(value) + end - local function stringify(val, indent, visited) - visited = visited or {} - indent = indent or 0 - local spaces = string.rep(" ", indent) + local function stringify(val, indent, visited) + visited = visited or {} + indent = indent or 0 + local spaces = string.rep(" ", indent) - if type(val) == "table" then - if visited[val] then return "{...}" end - visited[val] = true + if type(val) == "table" then + if visited[val] then return "{...}" end + visited[val] = true - local isArray = true - local i = 1 - for k in pairs(val) do - if type(k) ~= "number" or k ~= i then - isArray = false - break - end - i = i + 1 - end + local isArray = true + local i = 1 + for k in pairs(val) do + if type(k) ~= "number" or k ~= i then + isArray = false + break + end + i = i + 1 + end - local result = isArray and "[\n" or "{\n" - local first = true + local result = isArray and "[\n" or "{\n" + local first = true - if isArray then - for i, v in ipairs(val) do - if not first then result = result .. ",\n" end - first = false - result = result .. spaces .. " " .. stringify(v, indent + 1, visited) - end - else - for k, v in pairs(val) do - if not first then result = result .. ",\n" end - first = false - result = result .. spaces .. " \"" .. tostring(k) .. "\": " .. stringify(v, indent + 1, visited) - end - end + if isArray then + for i, v in ipairs(val) do + if not first then result = result .. ",\n" end + first = false + result = result .. spaces .. " " .. stringify(v, indent + 1, visited) + end + else + for k, v in pairs(val) do + if not first then result = result .. ",\n" end + first = false + result = result .. spaces .. " \"" .. tostring(k) .. "\": " .. stringify(v, indent + 1, visited) + end + end - return result .. "\n" .. spaces .. (isArray and "]" or "}") - elseif type(val) == "string" then - return "\"" .. val:gsub('\\', '\\\\'):gsub('"', '\\"'):gsub('\n', '\\n') .. "\"" - else - return tostring(val) - end - end + return result .. "\n" .. spaces .. (isArray and "]" or "}") + elseif type(val) == "string" then + return "\"" .. val:gsub('\\', '\\\\'):gsub('"', '\\"'):gsub('\n', '\\n') .. "\"" + else + return tostring(val) + end + end - return stringify(value) -end - -return json + return stringify(value) +end \ No newline at end of file diff --git a/runner/lua/sandbox.lua b/runner/lua/sandbox.lua index fd00ce8..b205a51 100644 --- a/runner/lua/sandbox.lua +++ b/runner/lua/sandbox.lua @@ -60,126 +60,121 @@ function __ensure_response() end -- ====================================================================== --- HTTP MODULE +-- HTTP FUNCTIONS -- ====================================================================== -local http = { - -- Set HTTP status code - set_status = function(code) - if type(code) ~= "number" then - error("http.set_status: status code must be a number", 2) +-- Set HTTP status code +function http_set_status(code) + if type(code) ~= "number" then + error("http_set_status: status code must be a number", 2) + end + + local resp = __ensure_response() + resp.status = code +end + +-- Set HTTP header +function http_set_header(name, value) + if type(name) ~= "string" or type(value) ~= "string" then + error("http_set_header: name and value must be strings", 2) + end + + local resp = __ensure_response() + resp.headers = resp.headers or {} + resp.headers[name] = value +end + +-- Set content type; http_set_header helper +function http_set_content_type(content_type) + http_set_header("Content-Type", content_type) +end + +-- Set metadata (arbitrary data to be returned with response) +function http_set_metadata(key, value) + if type(key) ~= "string" then + error("http_set_metadata: key must be a string", 2) + end + + local resp = __ensure_response() + resp.metadata = resp.metadata or {} + resp.metadata[key] = value +end + +-- Generic HTTP request function +function http_request(method, url, body, options) + if type(method) ~= "string" then + error("http_request: method must be a string", 2) + end + if type(url) ~= "string" then + error("http_request: url must be a string", 2) + end + + -- Call native implementation + local result = __http_request(method, url, body, options) + return result +end + +-- Shorthand function to directly get JSON +function http_get_json(url, options) + options = options or {} + local response = http_get(url, options) + if response.ok and response.json then + return response.json + end + return nil, response +end + +-- Utility to build a URL with query parameters +function http_build_url(base_url, params) + if not params or type(params) ~= "table" then + return base_url + end + + local query = {} + for k, v in pairs(params) do + if type(v) == "table" then + for _, item in ipairs(v) do + table.insert(query, url_encode(k) .. "=" .. url_encode(tostring(item))) + end + else + table.insert(query, url_encode(k) .. "=" .. url_encode(tostring(v))) end + end - local resp = __ensure_response() - resp.status = code - end, - - -- Set HTTP header - set_header = function(name, value) - if type(name) ~= "string" or type(value) ~= "string" then - error("http.set_header: name and value must be strings", 2) + if #query > 0 then + if string.contains(base_url, "?") then + return base_url .. "&" .. table.concat(query, "&") + else + return base_url .. "?" .. table.concat(query, "&") end + end - local resp = __ensure_response() - resp.headers = resp.headers or {} - resp.headers[name] = value - end, - - -- Set content type; set_header helper - set_content_type = function(content_type) - http.set_header("Content-Type", content_type) - end, - - -- Set metadata (arbitrary data to be returned with response) - set_metadata = function(key, value) - if type(key) ~= "string" then - error("http.set_metadata: key must be a string", 2) - end - - local resp = __ensure_response() - resp.metadata = resp.metadata or {} - resp.metadata[key] = value - end, - - -- HTTP client submodule - client = { - -- Generic request function - request = function(method, url, body, options) - if type(method) ~= "string" then - error("http.client.request: method must be a string", 2) - end - if type(url) ~= "string" then - error("http.client.request: url must be a string", 2) - end - - -- Call native implementation - local result = __http_request(method, url, body, options) - return result - end, - - -- Shorthand function to directly get JSON - get_json = function(url, options) - options = options or {} - local response = http.client.get(url, options) - if response.ok and response.json then - return response.json - end - return nil, response - end, - - -- Utility to build a URL with query parameters - build_url = function(base_url, params) - if not params or type(params) ~= "table" then - return base_url - end - - local query = {} - for k, v in pairs(params) do - if type(v) == "table" then - for _, item in ipairs(v) do - table.insert(query, util.url_encode(k) .. "=" .. util.url_encode(tostring(item))) - end - else - table.insert(query, util.url_encode(k) .. "=" .. util.url_encode(tostring(v))) - end - end - - if #query > 0 then - if string.contains(base_url, "?") then - return base_url .. "&" .. table.concat(query, "&") - else - return base_url .. "?" .. table.concat(query, "&") - end - end - - return base_url - end - } -} + return base_url +end local function make_method(method, needs_body) return function(url, body_or_options, options) if needs_body then options = options or {} - return http.client.request(method, url, body_or_options, options) + return http_request(method, url, body_or_options, options) else body_or_options = body_or_options or {} - return http.client.request(method, url, nil, body_or_options) + return http_request(method, url, nil, body_or_options) end end end -http.client.get = make_method("GET", false) -http.client.delete = make_method("DELETE", false) -http.client.head = make_method("HEAD", false) -http.client.options = make_method("OPTIONS", false) -http.client.post = make_method("POST", true) -http.client.put = make_method("PUT", true) -http.client.patch = make_method("PATCH", true) +http_get = make_method("GET", false) +http_delete = make_method("DELETE", false) +http_head = make_method("HEAD", false) +http_options = make_method("OPTIONS", false) +http_post = make_method("POST", true) +http_put = make_method("PUT", true) +http_patch = make_method("PATCH", true) -http.redirect = function(url, status) +function http_redirect(url, status) if type(url) ~= "string" then - error("http.redirect: url must be a string", 2) + error("http_redirect: url must be a string", 2) end status = status or 302 -- Default to temporary redirect @@ -194,228 +189,227 @@ http.redirect = function(url, status) end -- ====================================================================== --- COOKIE MODULE +-- COOKIE FUNCTIONS -- ====================================================================== -local cookie = { - -- Set a cookie - set = function(name, value, options) - if type(name) ~= "string" then - error("cookie.set: name must be a string", 2) - end - - local resp = __ensure_response() - resp.cookies = resp.cookies or {} - - local opts = options or {} - local cookie = { - name = name, - value = value or "", - path = opts.path or "/", - domain = opts.domain - } - - if opts.expires then - if type(opts.expires) == "number" then - if opts.expires > 0 then - cookie.max_age = opts.expires - local now = os.time() - cookie.expires = now + opts.expires - elseif opts.expires < 0 then - cookie.expires = 1 - cookie.max_age = 0 - end - -- opts.expires == 0: Session cookie (omitting both expires and max-age) - end - end - - cookie.secure = (opts.secure ~= false) - cookie.http_only = (opts.http_only ~= false) - - if opts.same_site then - local same_site = string.trim(opts.same_site):lower() - local valid_values = {none = true, lax = true, strict = true} - - if not valid_values[same_site] then - error("cookie.set: same_site must be one of 'None', 'Lax', or 'Strict'", 2) - end - - -- If SameSite=None, the cookie must be secure - if same_site == "none" and not cookie.secure then - cookie.secure = true - end - - cookie.same_site = opts.same_site - else - cookie.same_site = "Lax" - end - - table.insert(resp.cookies, cookie) - return true - end, - - -- Get a cookie value - get = function(name) - if type(name) ~= "string" then - error("cookie.get: name must be a string", 2) - end - - local env = getfenv(2) - - if env.ctx and env.ctx.cookies then - return env.ctx.cookies[name] - end - - if env.ctx and env.ctx._request_cookies then - return env.ctx._request_cookies[name] - end - - return nil - end, - - -- Remove a cookie - remove = function(name, path, domain) - if type(name) ~= "string" then - error("cookie.remove: name must be a string", 2) - end - - return cookie.set(name, "", {expires = 0, path = path or "/", domain = domain}) +-- Set a cookie +function cookie_set(name, value, options) + if type(name) ~= "string" then + error("cookie_set: name must be a string", 2) end -} --- ====================================================================== --- SESSION MODULE --- ====================================================================== + local resp = __ensure_response() + resp.cookies = resp.cookies or {} -local session = { - get = function(key) - if type(key) ~= "string" then - error("session.get: key must be a string", 2) - end + local opts = options or {} + local cookie = { + name = name, + value = value or "", + path = opts.path or "/", + domain = opts.domain + } - local env = getfenv(2) - - if env.ctx and env.ctx.session and env.ctx.session.data then - return env.ctx.session.data[key] - end - - return nil - end, - - set = function(key, value) - if type(key) ~= "string" then - error("session.set: key must be a string", 2) - end - if type(value) == nil then - error("session.set: value cannot be nil", 2) - end - - local resp = __ensure_response() - resp.session = resp.session or {} - resp.session[key] = value - end, - - id = function() - local env = getfenv(2) - - if env.ctx and env.ctx.session then - return env.ctx.session.id - end - - return nil - end, - - get_all = function() - local env = getfenv(2) - - if env.ctx and env.ctx.session then - return env.ctx.session.data - end - - return nil - end, - - delete = function(key) - if type(key) ~= "string" then - error("session.delete: key must be a string", 2) - end - - local resp = __ensure_response() - resp.session = resp.session or {} - resp.session[key] = "__SESSION_DELETE_MARKER__" - - local env = getfenv(2) - if env.ctx and env.ctx.session and env.ctx.session.data then - env.ctx.session.data[key] = nil - end - end, - - clear = function() - local env = getfenv(2) - if env.ctx and env.ctx.session and env.ctx.session.data then - for k, _ in pairs(env.ctx.session.data) do - env.ctx.session.data[k] = nil + if opts.expires then + if type(opts.expires) == "number" then + if opts.expires > 0 then + cookie.max_age = opts.expires + local now = os.time() + cookie.expires = now + opts.expires + elseif opts.expires < 0 then + cookie.expires = 1 + cookie.max_age = 0 end + -- opts.expires == 0: Session cookie (omitting both expires and max-age) + end + end + + cookie.secure = (opts.secure ~= false) + cookie.http_only = (opts.http_only ~= false) + + if opts.same_site then + local same_site = string.trim(opts.same_site):lower() + local valid_values = {none = true, lax = true, strict = true} + + if not valid_values[same_site] then + error("cookie_set: same_site must be one of 'None', 'Lax', or 'Strict'", 2) end - local resp = __ensure_response() - resp.session = {} - resp.session["__clear_all"] = true + -- If SameSite=None, the cookie must be secure + if same_site == "none" and not cookie.secure then + cookie.secure = true + end + + cookie.same_site = opts.same_site + else + cookie.same_site = "Lax" end -} + + table.insert(resp.cookies, cookie) + return true +end + +-- Get a cookie value +function cookie_get(name) + if type(name) ~= "string" then + error("cookie_get: name must be a string", 2) + end + + local env = getfenv(2) + + if env.ctx and env.ctx.cookies then + return env.ctx.cookies[name] + end + + if env.ctx and env.ctx._request_cookies then + return env.ctx._request_cookies[name] + end + + return nil +end + +-- Remove a cookie +function cookie_remove(name, path, domain) + if type(name) ~= "string" then + error("cookie_remove: name must be a string", 2) + end + + return cookie_set(name, "", {expires = 0, path = path or "/", domain = domain}) +end -- ====================================================================== --- CSRF MODULE +-- SESSION FUNCTIONS -- ====================================================================== -local csrf = { - generate = function() - local token = util.generate_token(32) - session.set("_csrf_token", token) - return token - end, - - field = function() - local token = session.get("_csrf_token") - if not token then - token = csrf.generate() - end - return string.format('', - util.html_special_chars(token)) - end, - - validate = function() - local env = getfenv(2) - local token = false - if env.ctx and env.ctx.session and env.ctx.session.data then - token = env.ctx.session.data["_csrf_token"] - end - - if not token then - http.set_status(403) - __http_response.body = "CSRF validation failed" - exit() - end - - local request_token = nil - if env.ctx and env.ctx.form then - request_token = env.ctx.form._csrf_token - end - - if not request_token and env.ctx and env.ctx._request_headers then - request_token = env.ctx._request_headers["x-csrf-token"] or - env.ctx._request_headers["csrf-token"] - end - - if not request_token or request_token ~= token then - http.set_status(403) - __http_response.body = "CSRF validation failed" - exit() - end - - return true +function session_get(key) + if type(key) ~= "string" then + error("session_get: key must be a string", 2) end -} + + local env = getfenv(2) + + if env.ctx and env.ctx.session and env.ctx.session.data then + return env.ctx.session.data[key] + end + + return nil +end + +function session_set(key, value) + if type(key) ~= "string" then + error("session_set: key must be a string", 2) + end + if type(value) == nil then + error("session_set: value cannot be nil", 2) + end + + local resp = __ensure_response() + resp.session = resp.session or {} + resp.session[key] = value + + local env = getfenv(2) + if env.ctx and env.ctx.session and env.ctx.session.data then + env.ctx.session.data[key] = value + end +end + +function session_id() + local env = getfenv(2) + + if env.ctx and env.ctx.session then + return env.ctx.session.id + end + + return nil +end + +function session_get_all() + local env = getfenv(2) + + if env.ctx and env.ctx.session then + return env.ctx.session.data + end + + return nil +end + +function session_delete(key) + if type(key) ~= "string" then + error("session_delete: key must be a string", 2) + end + + local resp = __ensure_response() + resp.session = resp.session or {} + resp.session[key] = "__SESSION_DELETE_MARKER__" + + local env = getfenv(2) + if env.ctx and env.ctx.session and env.ctx.session.data then + env.ctx.session.data[key] = nil + end +end + +function session_clear() + local env = getfenv(2) + if env.ctx and env.ctx.session and env.ctx.session.data then + for k, _ in pairs(env.ctx.session.data) do + env.ctx.session.data[k] = nil + end + end + + local resp = __ensure_response() + resp.session = {} + resp.session["__clear_all"] = true +end + +-- ====================================================================== +-- CSRF FUNCTIONS +-- ====================================================================== + +function csrf_generate() + local token = generate_token(32) + session_set("_csrf_token", token) + return token +end + +function csrf_field() + local token = session_get("_csrf_token") + if not token then + token = csrf_generate() + end + return string.format('', + html_special_chars(token)) +end + +function csrf_validate() + local env = getfenv(2) + local token = false + if env.ctx and env.ctx.session and env.ctx.session.data then + token = env.ctx.session.data["_csrf_token"] + end + + if not token then + http_set_status(403) + __http_response.body = "CSRF validation failed" + exit() + end + + local request_token = nil + if env.ctx and env.ctx.form then + request_token = env.ctx.form._csrf_token + end + + if not request_token and env.ctx and env.ctx._request_headers then + request_token = env.ctx._request_headers["x-csrf-token"] or + env.ctx._request_headers["csrf-token"] + end + + if not request_token or request_token ~= token then + http_set_status(403) + __http_response.body = "CSRF validation failed" + exit() + end + + return true +end -- ====================================================================== -- TEMPLATE RENDER FUNCTIONS @@ -502,7 +496,7 @@ _G.render = function(template_str, env) setfenv(fn, runtime_env) local output_buffer = {} - fn(tostring, util.html_special_chars, output_buffer, 0) + fn(tostring, html_special_chars, output_buffer, 0) return table.concat(output_buffer) end @@ -536,7 +530,7 @@ _G.parse = function(template_str, env) local value = env[name] local str = tostring(value or "") if escaped then - str = util.html_special_chars(str) + str = html_special_chars(str) end table.insert(output, str) @@ -576,7 +570,7 @@ _G.iparse = function(template_str, values) local value = values[value_index] local str = tostring(value or "") if escaped then - str = util.html_special_chars(str) + str = html_special_chars(str) end table.insert(output, str) @@ -588,11 +582,9 @@ _G.iparse = function(template_str, values) end -- ====================================================================== --- PASSWORD MODULE +-- PASSWORD FUNCTIONS -- ====================================================================== -local password = {} - -- Hash a password using Argon2id -- Options: -- memory: Amount of memory to use in KB (default: 128MB) @@ -600,85 +592,72 @@ local password = {} -- parallelism: Number of threads (default: 4) -- salt_length: Length of salt in bytes (default: 16) -- key_length: Length of the derived key in bytes (default: 32) -function password.hash(plain_password, options) +function password_hash(plain_password, options) if type(plain_password) ~= "string" then - error("password.hash: expected string password", 2) + error("password_hash: expected string password", 2) end return __password_hash(plain_password, options) end -- Verify a password against a hash -function password.verify(plain_password, hash_string) +function password_verify(plain_password, hash_string) if type(plain_password) ~= "string" then - error("password.verify: expected string password", 2) + error("password_verify: expected string password", 2) end if type(hash_string) ~= "string" then - error("password.verify: expected string hash", 2) + error("password_verify: expected string hash", 2) end return __password_verify(plain_password, hash_string) end -- ====================================================================== --- SEND MODULE +-- SEND FUNCTIONS -- ====================================================================== -local send = {} - -function send.html(content) - http.set_content_type("text/html") +function send_html(content) + http_set_content_type("text/html") return content end -function send.json(content) - http.set_content_type("application/json") +function send_json(content) + http_set_content_type("application/json") return content end -function send.text(content) - http.set_content_type("text/plain") +function send_text(content) + http_set_content_type("text/plain") return content end -function send.xml(content) - http.set_content_type("application/xml") +function send_xml(content) + http_set_content_type("application/xml") return content end -function send.javascript(content) - http.set_content_type("application/javascript") +function send_javascript(content) + http_set_content_type("application/javascript") return content end -function send.css(content) - http.set_content_type("text/css") +function send_css(content) + http_set_content_type("text/css") return content end -function send.svg(content) - http.set_content_type("image/svg+xml") +function send_svg(content) + http_set_content_type("image/svg+xml") return content end -function send.csv(content) - http.set_content_type("text/csv") +function send_csv(content) + http_set_content_type("text/csv") return content end -function send.binary(content, mime_type) - http.set_content_type(mime_type or "application/octet-stream") +function send_binary(content, mime_type) + http_set_content_type(mime_type or "application/octet-stream") return content -end - --- ====================================================================== --- REGISTER MODULES GLOBALLY --- ====================================================================== - -_G.http = http -_G.session = session -_G.csrf = csrf -_G.cookie = cookie -_G.password = password -_G.send = send +end \ No newline at end of file diff --git a/runner/lua/util.lua b/runner/lua/util.lua index 99ecef8..a3e7871 100644 --- a/runner/lua/util.lua +++ b/runner/lua/util.lua @@ -1,16 +1,13 @@ --[[ util.lua - Utility functions for the Lua sandbox -Enhanced with web development utilities ]]-- -local util = {} - -- ====================================================================== -- CORE UTILITY FUNCTIONS -- ====================================================================== -- Generate a random token -function util.generate_token(length) +function generate_token(length) return __generate_token(length or 32) end @@ -18,20 +15,8 @@ end -- HTML ENTITY FUNCTIONS -- ====================================================================== --- HTML entity mapping for common characters -local html_entities = { - ["&"] = "&", - ["<"] = "<", - [">"] = ">", - ['"'] = """, - ["'"] = "'", - ["/"] = "/", - ["`"] = "`", - ["="] = "=" -} - -- Convert special characters to HTML entities (like htmlspecialchars) -function util.html_special_chars(str) +function html_special_chars(str) if type(str) ~= "string" then return str end @@ -40,7 +25,7 @@ function util.html_special_chars(str) end -- Convert all applicable characters to HTML entities (like htmlentities) -function util.html_entities(str) +function html_entities(str) if type(str) ~= "string" then return str end @@ -49,7 +34,7 @@ function util.html_entities(str) end -- Convert HTML entities back to characters (simple version) -function util.html_entity_decode(str) +function html_entity_decode(str) if type(str) ~= "string" then return str end @@ -64,7 +49,7 @@ function util.html_entity_decode(str) end -- Convert newlines to
tags -function util.nl2br(str) +function nl2br(str) if type(str) ~= "string" then return str end @@ -77,7 +62,7 @@ end -- ====================================================================== -- URL encode a string -function util.url_encode(str) +function url_encode(str) if type(str) ~= "string" then return str end @@ -91,7 +76,7 @@ function util.url_encode(str) end -- URL decode a string -function util.url_decode(str) +function url_decode(str) if type(str) ~= "string" then return str end @@ -108,7 +93,7 @@ end -- ====================================================================== -- Email validation -function util.is_email(str) +function is_email(str) if type(str) ~= "string" then return false end @@ -119,7 +104,7 @@ function util.is_email(str) end -- URL validation -function util.is_url(str) +function is_url(str) if type(str) ~= "string" then return false end @@ -130,7 +115,7 @@ function util.is_url(str) end -- IP address validation (IPv4) -function util.is_ipv4(str) +function is_ipv4(str) if type(str) ~= "string" then return false end @@ -147,7 +132,7 @@ function util.is_ipv4(str) end -- Integer validation -function util.is_int(str) +function is_int(str) if type(str) == "number" then return math.floor(str) == str elseif type(str) ~= "string" then @@ -158,7 +143,7 @@ function util.is_int(str) end -- Float validation -function util.is_float(str) +function is_float(str) if type(str) == "number" then return true elseif type(str) ~= "string" then @@ -169,7 +154,7 @@ function util.is_float(str) end -- Boolean validation -function util.is_bool(value) +function is_bool(value) if type(value) == "boolean" then return true elseif type(value) ~= "string" and type(value) ~= "number" then @@ -183,7 +168,7 @@ function util.is_bool(value) end -- Convert to boolean -function util.to_bool(value) +function to_bool(value) if type(value) == "boolean" then return value elseif type(value) ~= "string" and type(value) ~= "number" then @@ -195,16 +180,16 @@ function util.to_bool(value) end -- Sanitize string (simple version) -function util.sanitize_string(str) +function sanitize_string(str) if type(str) ~= "string" then return "" end - return util.html_special_chars(str) + return html_special_chars(str) end -- Sanitize to integer -function util.sanitize_int(value) +function sanitize_int(value) if type(value) ~= "string" and type(value) ~= "number" then return 0 end @@ -215,7 +200,7 @@ function util.sanitize_int(value) end -- Sanitize to float -function util.sanitize_float(value) +function sanitize_float(value) if type(value) ~= "string" and type(value) ~= "number" then return 0 end @@ -226,7 +211,7 @@ function util.sanitize_float(value) end -- Sanitize URL -function util.sanitize_url(str) +function sanitize_url(str) if type(str) ~= "string" then return "" end @@ -235,12 +220,12 @@ function util.sanitize_url(str) str = str:gsub("[\000-\031]", "") -- Make sure it's a valid URL - if util.is_url(str) then + if is_url(str) then return str end -- Try to prepend http:// if it's missing - if not str:match("^https?://") and util.is_url("http://" .. str) then + if not str:match("^https?://") and is_url("http://" .. str) then return "http://" .. str end @@ -248,7 +233,7 @@ function util.sanitize_url(str) end -- Sanitize email -function util.sanitize_email(str) +function sanitize_email(str) if type(str) ~= "string" then return "" end @@ -257,7 +242,7 @@ function util.sanitize_email(str) str = str:gsub("[^%a%d%!%#%$%%%&%'%*%+%-%/%=%?%^%_%`%{%|%}%~%@%.%[%]]", "") -- Return only if it's a valid email - if util.is_email(str) then + if is_email(str) then return str end @@ -269,13 +254,13 @@ end -- ====================================================================== -- Basic XSS prevention -function util.xss_clean(str) +function xss_clean(str) if type(str) ~= "string" then return str end -- Convert problematic characters to entities - local result = util.html_special_chars(str) + local result = html_special_chars(str) -- Remove JavaScript event handlers result = result:gsub("on%w+%s*=", "") @@ -290,7 +275,7 @@ function util.xss_clean(str) end -- Base64 encode -function util.base64_encode(str) +function base64_encode(str) if type(str) ~= "string" then return str end @@ -299,12 +284,10 @@ function util.base64_encode(str) end -- Base64 decode -function util.base64_decode(str) +function base64_decode(str) if type(str) ~= "string" then return str end return __base64_decode(str) -end - -return util \ No newline at end of file +end \ No newline at end of file diff --git a/sessions/session.go b/sessions/session.go index c01fdf6..5c11823 100644 --- a/sessions/session.go +++ b/sessions/session.go @@ -81,6 +81,9 @@ func (s *Session) GetAll() map[string]any { // Set stores a value in the session func (s *Session) Set(key string, value any) { + if existing, ok := s.Data[key]; ok && deepEqual(existing, value) { + return // No change + } s.Data[key] = value s.UpdatedAt = time.Now() s.dirty = true @@ -346,3 +349,81 @@ func validate(v any) error { } return nil } + +// deepEqual efficiently compares two values for deep equality +func deepEqual(a, b any) bool { + if a == b { + return true + } + + if a == nil || b == nil { + return false + } + + switch va := a.(type) { + case string: + if vb, ok := b.(string); ok { + return va == vb + } + case int: + if vb, ok := b.(int); ok { + return va == vb + } + if vb, ok := b.(int64); ok { + return int64(va) == vb + } + case int64: + if vb, ok := b.(int64); ok { + return va == vb + } + if vb, ok := b.(int); ok { + return va == int64(vb) + } + case float64: + if vb, ok := b.(float64); ok { + return va == vb + } + case bool: + if vb, ok := b.(bool); ok { + return va == vb + } + case []byte: + if vb, ok := b.([]byte); ok { + if len(va) != len(vb) { + return false + } + for i, v := range va { + if v != vb[i] { + return false + } + } + return true + } + case map[string]any: + if vb, ok := b.(map[string]any); ok { + if len(va) != len(vb) { + return false + } + for k, v := range va { + if bv, exists := vb[k]; !exists || !deepEqual(v, bv) { + return false + } + } + return true + } + case []any: + if vb, ok := b.([]any); ok { + if len(va) != len(vb) { + return false + } + for i, v := range va { + if !deepEqual(v, vb[i]) { + return false + } + } + return true + } + } + + return false +}